token probs for non HF loaders (#3957)

This commit is contained in:
saltacc 2023-09-17 13:42:32 +00:00 committed by GitHub
parent 0668f4e67f
commit cd08eb0753
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 53 additions and 5 deletions

View file

@ -113,3 +113,8 @@ class Exllamav2Model:
ids = ids.view(1, -1)
return self.tokenizer.decode(ids)[0]
def get_logits(self, token_ids, **kwargs):
    """Run ``token_ids`` through the model and return the last-position output.

    The cache's ``current_seq_len`` is reset to 0, every token except the
    final one is forwarded with ``preprocess_only=True``, and the final
    token is then forwarded on its own.

    Args:
        token_ids: Token id tensor, indexed here as ``[:, :-1]`` and
            ``[:, -1:]`` (presumably shaped (batch, seq_len) -- confirm
            against callers).
        **kwargs: Extra keyword arguments passed through unchanged to the
            final ``self.model.forward`` call.

    Returns:
        The result of the final forward call, converted with
        ``.float().cpu()``.
    """
    # Start from position 0 so whatever a previous call left in the cache
    # is not treated as part of this sequence.
    self.cache.current_seq_len = 0

    # Forward the prefix first (presumably this only populates the cache).
    # With a single-token input the prefix slice is zero-length, so skip
    # the pass rather than hand the model an empty tensor.
    if token_ids.shape[-1] > 1:
        self.model.forward(token_ids[:, :-1], self.cache, input_mask=None, preprocess_only=True)

    return self.model.forward(token_ids[:, -1:], self.cache, input_mask=None, **kwargs).float().cpu()