Make llamacpp_HF 6x faster

This commit is contained in:
oobabooga 2023-08-01 13:15:14 -07:00
parent 385229313f
commit b53ed70a70
2 changed files with 2 additions and 2 deletions

View file

@ -56,7 +56,7 @@ class LlamacppHF(PreTrainedModel):
else:
self.model.eval([seq[-1]])
logits = torch.tensor(self.model.eval_logits[-1]).view(1, 1, -1).to(kwargs['input_ids'].device)
logits = torch.tensor(self.model.scores[self.model.n_tokens-1, :]).view(1, 1, -1).to(kwargs['input_ids'].device)
else:
self.model.reset()
self.model.eval(seq)