Add LoRA support to ExLlama_HF
This commit is contained in:
parent
b7c627f9a0
commit
22d455b072
3 changed files with 17 additions and 7 deletions
|
@ -30,6 +30,7 @@ class ExllamaHF(PreTrainedModel):
|
|||
self.ex_config = config
|
||||
self.ex_model = ExLlama(self.ex_config)
|
||||
self.generation_config = GenerationConfig()
|
||||
self.lora = None
|
||||
|
||||
def _validate_model_class(self):
|
||||
pass
|
||||
|
@ -53,9 +54,9 @@ class ExllamaHF(PreTrainedModel):
|
|||
cache = kwargs['past_key_values'] if 'past_key_values' in kwargs else None
|
||||
if cache is None:
|
||||
cache = ExLlamaCache(self.ex_model)
|
||||
self.ex_model.forward(torch.tensor([seq[:-1]], dtype=torch.long), cache, preprocess_only=True)
|
||||
self.ex_model.forward(torch.tensor([seq[:-1]], dtype=torch.long), cache, preprocess_only=True, lora=self.lora)
|
||||
|
||||
logits = self.ex_model.forward(torch.tensor([seq[-1:]], dtype=torch.long), cache).to(kwargs['input_ids'].device)
|
||||
logits = self.ex_model.forward(torch.tensor([seq[-1:]], dtype=torch.long), cache, lora=self.lora).to(kwargs['input_ids'].device)
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue