Exllamav2 lora support (#4229)

---------

Co-authored-by: oobabooga <112222186+oobabooga@users.noreply.github.com>
This commit is contained in:
Forkoz 2023-10-14 19:12:41 +00:00 committed by GitHub
parent 1f5a2c5597
commit 8cce1f1126
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 47 additions and 12 deletions

View file

@ -33,8 +33,8 @@ class Exllamav2HF(PreTrainedModel):
split = [float(alloc) for alloc in shared.args.gpu_split.split(",")]
self.ex_model.load(split)
self.generation_config = GenerationConfig()
self.loras = None
self.ex_cache = ExLlamaV2Cache(self.ex_model)
self.past_seq = None
@ -97,7 +97,7 @@ class Exllamav2HF(PreTrainedModel):
reset = False
ex_cache.current_seq_len = longest_prefix
if len(seq_tensor) - longest_prefix > 1:
self.ex_model.forward(seq_tensor[longest_prefix:-1].view(1, -1), ex_cache, preprocess_only=True)
self.ex_model.forward(seq_tensor[longest_prefix:-1].view(1, -1), ex_cache, preprocess_only=True, loras=self.loras)
elif len(seq_tensor) == longest_prefix:
# Very tricky: if the prefix we are reusing *is* the input_ids, then we have to back up the cache pointer by one,
# because we feed input_ids[-1] to forward() below, but that last token is already in the cache!
@ -106,12 +106,12 @@ class Exllamav2HF(PreTrainedModel):
if reset:
ex_cache.current_seq_len = 0
if len(seq_tensor) > 1:
self.ex_model.forward(seq_tensor[:-1].view(1, -1), ex_cache, preprocess_only=True)
self.ex_model.forward(seq_tensor[:-1].view(1, -1), ex_cache, preprocess_only=True, loras=self.loras)
logits = self.ex_model.forward(seq_tensor[-1:].view(1, -1), ex_cache).to(input_ids.device)
logits = self.ex_model.forward(seq_tensor[-1:].view(1, -1), ex_cache, loras=self.loras).to(input_ids.device)
else:
ex_cache.current_seq_len = 0
logits = self.ex_model.forward(seq_tensor.view(1, -1), ex_cache, last_id_only=False)
logits = self.ex_model.forward(seq_tensor.view(1, -1), ex_cache, last_id_only=False, loras=self.loras)
if is_negative:
self.past_seq_negative = seq_tensor