Exllamav2 lora support (#4229)
--------- Co-authored-by: oobabooga <112222186+oobabooga@users.noreply.github.com>
This commit is contained in:
parent
1f5a2c5597
commit
8cce1f1126
4 changed files with 47 additions and 12 deletions
|
@ -98,7 +98,9 @@ class ExllamaModel:
|
|||
|
||||
def get_logits(self, token_ids, **kwargs):
|
||||
self.cache.current_seq_len = 0
|
||||
self.model.forward(token_ids[:, :-1], self.cache, input_mask=None, preprocess_only=True)
|
||||
if token_ids.shape[-1] > 1:
|
||||
self.model.forward(token_ids[:, :-1], self.cache, input_mask=None, preprocess_only=True)
|
||||
|
||||
return self.model.forward(token_ids[:, -1:], self.cache, **kwargs).float().cpu()
|
||||
|
||||
def generate_with_streaming(self, prompt, state):
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue