token probs for non HF loaders (#3957)
This commit is contained in:
parent
0668f4e67f
commit
cd08eb0753
5 changed files with 53 additions and 5 deletions
|
@ -113,3 +113,8 @@ class Exllamav2Model:
|
|||
ids = ids.view(1, -1)
|
||||
|
||||
return self.tokenizer.decode(ids)[0]
|
||||
|
||||
def get_logits(self, token_ids, **kwargs):
    """Return the model's logits for the final position of *token_ids*.

    The cache is rewound to the start, the prompt (all tokens except the
    last) is prefilled in a single preprocess-only pass, and then one real
    forward step over the last token produces the logits.

    Args:
        token_ids: 2-D tensor of token ids, shape (1, seq_len) —
            presumably batch size 1; TODO confirm against callers.
        **kwargs: extra keyword arguments forwarded verbatim to the final
            ``model.forward`` call.

    Returns:
        The last-position logits as a float tensor moved to the CPU.
    """
    cache = self.cache

    # Restart generation from position 0 so the prefill below rebuilds
    # the cache from scratch for this sequence.
    cache.current_seq_len = 0

    # Prefill pass over the prompt: populates the cache only, so the
    # (discarded) output is never materialized as logits.
    prompt_ids = token_ids[:, :-1]
    self.model.forward(prompt_ids, cache, input_mask=None, preprocess_only=True)

    # One real step over the final token yields the logits we want.
    last_id = token_ids[:, -1:]
    logits = self.model.forward(last_id, cache, input_mask=None, **kwargs)
    return logits.float().cpu()
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue