Exllamav2 lora support (#4229)

---------

Co-authored-by: oobabooga <112222186+oobabooga@users.noreply.github.com>
This commit is contained in:
Forkoz 2023-10-14 19:12:41 +00:00 committed by GitHub
parent 1f5a2c5597
commit 8cce1f1126
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 47 additions and 12 deletions

View file

@@ -98,7 +98,9 @@ class ExllamaModel:
     def get_logits(self, token_ids, **kwargs):
         self.cache.current_seq_len = 0
-        self.model.forward(token_ids[:, :-1], self.cache, input_mask=None, preprocess_only=True)
+        if token_ids.shape[-1] > 1:
+            self.model.forward(token_ids[:, :-1], self.cache, input_mask=None, preprocess_only=True)
         return self.model.forward(token_ids[:, -1:], self.cache, **kwargs).float().cpu()
     def generate_with_streaming(self, prompt, state):