Exllamav2 lora support (#4229)
Co-authored-by: oobabooga <112222186+oobabooga@users.noreply.github.com>
This commit is contained in:
parent
1f5a2c5597
commit
8cce1f1126
4 changed files with 47 additions and 12 deletions
@@ -33,8 +33,8 @@ class Exllamav2HF(PreTrainedModel):
             split = [float(alloc) for alloc in shared.args.gpu_split.split(",")]
 
         self.ex_model.load(split)
-
         self.generation_config = GenerationConfig()
+        self.loras = None
 
         self.ex_cache = ExLlamaV2Cache(self.ex_model)
         self.past_seq = None
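The new self.loras attribute starts out as None and is what every forward() call below now reads. The code that actually populates it lives in the other changed files, which are not shown on this page; the following is only a minimal sketch of how a loader could attach and detach an adapter, assuming exllamav2's ExLlamaV2Lora.from_directory and a hypothetical lora_path (the helper names are illustrative, not the repository's actual functions):

from exllamav2 import ExLlamaV2Lora

def apply_lora(model, lora_path):
    # Build an adapter from its directory against the underlying ExLlamaV2
    # instance and attach it; forward() then picks it up via loras=self.loras.
    lora = ExLlamaV2Lora.from_directory(model.ex_model, lora_path)
    model.loras = [lora]

def unload_lora(model):
    # Resetting to None makes subsequent forward() calls run the base model.
    model.loras = None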
@@ -97,7 +97,7 @@ class Exllamav2HF(PreTrainedModel):
                 reset = False
                 ex_cache.current_seq_len = longest_prefix
                 if len(seq_tensor) - longest_prefix > 1:
-                    self.ex_model.forward(seq_tensor[longest_prefix:-1].view(1, -1), ex_cache, preprocess_only=True)
+                    self.ex_model.forward(seq_tensor[longest_prefix:-1].view(1, -1), ex_cache, preprocess_only=True, loras=self.loras)
                 elif len(seq_tensor) == longest_prefix:
                     # Very tricky: if the prefix we are reusing *is* the input_ids, then we have to back up the cache pointer by one,
                     # because we feed input_ids[-1] to forward() below, but that last token is already in the cache!
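For orientation, longest_prefix in this hunk is the number of leading tokens shared by the cached sequence (past_seq) and the new input; only the tokens after that point need to be fed through forward() again. A rough, illustrative version of that computation (the module computes it with a similar vectorized comparison; the helper name here is made up):

import torch

def longest_common_prefix_len(past_seq, seq_tensor):
    # Count how many leading tokens match; everything up to that point can
    # stay in ex_cache instead of being re-forwarded.
    min_len = min(past_seq.shape[0], seq_tensor.shape[0])
    if min_len == 0:
        return 0
    mismatch = torch.nonzero(past_seq[:min_len] != seq_tensor[:min_len])
    return int(mismatch[0]) if len(mismatch) > 0 else min_len

print(longest_common_prefix_len(torch.tensor([5, 6, 7, 8]), torch.tensor([5, 6, 9])))  # 2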
@@ -106,12 +106,12 @@ class Exllamav2HF(PreTrainedModel):
             if reset:
                 ex_cache.current_seq_len = 0
                 if len(seq_tensor) > 1:
-                    self.ex_model.forward(seq_tensor[:-1].view(1, -1), ex_cache, preprocess_only=True)
+                    self.ex_model.forward(seq_tensor[:-1].view(1, -1), ex_cache, preprocess_only=True, loras=self.loras)
 
-                logits = self.ex_model.forward(seq_tensor[-1:].view(1, -1), ex_cache).to(input_ids.device)
+                logits = self.ex_model.forward(seq_tensor[-1:].view(1, -1), ex_cache, loras=self.loras).to(input_ids.device)
             else:
                 ex_cache.current_seq_len = 0
-                logits = self.ex_model.forward(seq_tensor.view(1, -1), ex_cache, last_id_only=False)
+                logits = self.ex_model.forward(seq_tensor.view(1, -1), ex_cache, last_id_only=False, loras=self.loras)
 
             if is_negative:
                 self.past_seq_negative = seq_tensor
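Taken together, every decoding path, including the negative-prompt branch used for CFG, now forwards the same self.loras list, and leaving it as None is equivalent to running the base model. A rough usage sketch, reusing the hypothetical apply_lora/unload_lora helpers above and assuming the wrapper lives at modules/exllamav2_hf.py with model folders resolved under the webui's model directory (paths and surrounding names are illustrative):

import torch
from modules.exllamav2_hf import Exllamav2HF

model = Exllamav2HF.from_pretrained("my-exl2-model")  # folder under the webui model dir
input_ids = torch.tensor([[1, 2, 3, 4]])

# Base model: self.loras is None, so forward() runs without any adapter.
base_logits = model(input_ids=input_ids).logits

# Attach an adapter; the same __call__ path now passes it on every forward().
apply_lora(model, "loras/my-adapter")
lora_logits = model(input_ids=input_ids).logits

# Detach to return to base-model behavior.
unload_lora(model)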