token probs for non HF loaders (#3957)
This commit is contained in:
parent
0668f4e67f
commit
cd08eb0753
5 changed files with 53 additions and 5 deletions
|
@ -1,6 +1,7 @@
|
|||
import re
|
||||
from functools import partial
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from modules import RoPE, shared
|
||||
|
@ -100,6 +101,12 @@ class LlamaCppModel:
|
|||
def decode(self, tokens):
    """Convert a sequence of token ids back into text via the wrapped llama.cpp model."""
    # Delegate straight to llama-cpp-python's detokenizer; its return value
    # (bytes in recent versions) is passed through unchanged.
    detokenized = self.model.detokenize(tokens)
    return detokenized
|
||||
|
||||
def get_logits(self, tokens):
    """Run a forward pass over *tokens* and return the logits as a float32
    torch tensor with a leading batch dimension of 1.
    """
    # Evaluate the prompt so the model's internal score buffer is filled.
    self.model.eval(tokens)
    # NOTE(review): reads llama-cpp-python's private `_scores` buffer —
    # presumably the raw per-token logits; verify against the library version in use.
    scores = self.model._scores
    # Callers expect a batch axis, so wrap the 2-D score matrix as shape (1, ...).
    batched = np.expand_dims(scores, 0)
    return torch.tensor(batched, dtype=torch.float32)
|
||||
|
||||
def generate(self, prompt, state, callback=None):
|
||||
|
||||
LogitsProcessorList = llama_cpp_lib().LogitsProcessorList
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue