Fix a bug in llama.cpp get_logits() function
This commit is contained in:
parent
000b77a17d
commit
092a2c3516
1 changed files with 1 additions and 0 deletions
|
@ -105,6 +105,7 @@ class LlamaCppModel:
|
||||||
return self.model.detokenize(ids).decode('utf-8')
|
return self.model.detokenize(ids).decode('utf-8')
|
||||||
|
|
||||||
def get_logits(self, tokens):
|
def get_logits(self, tokens):
|
||||||
|
self.model.reset()
|
||||||
self.model.eval(tokens)
|
self.model.eval(tokens)
|
||||||
logits = self.model._scores
|
logits = self.model._scores
|
||||||
logits = np.expand_dims(logits, 0) # batch dim is expected
|
logits = np.expand_dims(logits, 0) # batch dim is expected
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue