Exllamav2 lora support (#4229)

---------

Co-authored-by: oobabooga <112222186+oobabooga@users.noreply.github.com>
This commit is contained in:
Forkoz 2023-10-14 19:12:41 +00:00 committed by GitHub
parent 1f5a2c5597
commit 8cce1f1126
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 47 additions and 12 deletions

View file

@ -60,6 +60,7 @@ class Exllamav2Model:
result.cache = cache
result.tokenizer = tokenizer
result.generator = generator
result.loras = None
return result, result
def encode(self, string, **kwargs):
@ -75,8 +76,10 @@ class Exllamav2Model:
def get_logits(self, token_ids, **kwargs):
self.cache.current_seq_len = 0
self.model.forward(token_ids[:, :-1], self.cache, input_mask=None, preprocess_only=True)
return self.model.forward(token_ids[:, -1:], self.cache, input_mask=None, **kwargs).float().cpu()
if token_ids.shape[-1] > 1:
self.model.forward(token_ids[:, :-1], self.cache, input_mask=None, preprocess_only=True, loras=self.loras)
return self.model.forward(token_ids[:, -1:], self.cache, input_mask=None, loras=self.loras, **kwargs).float().cpu()
def generate_with_streaming(self, prompt, state):
settings = ExLlamaV2Sampler.Settings()
@ -105,12 +108,12 @@ class Exllamav2Model:
# _gen_begin_base
self.cache.current_seq_len = 0
self.model.forward(ids[:, :-1], self.cache, input_mask=None, preprocess_only=True)
self.model.forward(ids[:, :-1], self.cache, input_mask=None, preprocess_only=True, loras=self.loras)
has_leading_space = False
for i in range(max_new_tokens):
logits = self.model.forward(ids[:, -1:], self.cache, input_mask=None).float().cpu()
token, _, _= ExLlamaV2Sampler.sample(logits, settings, ids, random.random(), self.tokenizer)
logits = self.model.forward(ids[:, -1:], self.cache, input_mask=None, loras=self.loras).float().cpu()
token, _, _ = ExLlamaV2Sampler.sample(logits, settings, ids, random.random(), self.tokenizer)
ids = torch.cat([ids, token], dim=1)
if i == 0 and self.tokenizer.tokenizer.IdToPiece(int(token)).startswith(''):