Add token-probability support for non-HuggingFace model loaders (#3957)

This commit is contained in:
saltacc 2023-09-17 13:42:32 +00:00 committed by GitHub
parent 0668f4e67f
commit cd08eb0753
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 53 additions and 5 deletions

View file

@@ -1,6 +1,7 @@
import re
from functools import partial
import numpy as np
import torch
from modules import RoPE, shared
@@ -100,6 +101,12 @@ class LlamaCppModel:
def decode(self, tokens):
return self.model.detokenize(tokens)
def get_logits(self, tokens):
self.model.eval(tokens)
logits = self.model._scores
logits = np.expand_dims(logits, 0) # batch dim is expected
return torch.tensor(logits, dtype=torch.float32)
def generate(self, prompt, state, callback=None):
LogitsProcessorList = llama_cpp_lib().LogitsProcessorList