Add LoRA support to ExLlama_HF

oobabooga 2023-06-26 00:10:13 -03:00
parent b7c627f9a0
commit 22d455b072
3 changed files with 17 additions and 7 deletions


@@ -11,7 +11,7 @@ from modules.models import reload_model
 def add_lora_to_model(lora_names):
     if 'GPTQForCausalLM' in shared.model.__class__.__name__:
         add_lora_autogptq(lora_names)
-    elif shared.model.__class__.__name__ == 'ExllamaModel':
+    elif shared.model.__class__.__name__ in ['ExllamaModel', 'ExllamaHF']:
         add_lora_exllama(lora_names)
     else:
         add_lora_transformers(lora_names)
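The dispatcher keys purely off the wrapper's class name, so with ExLlama_HF loaded the existing add_lora_exllama() path is reused instead of the transformers/PEFT path. A minimal usage sketch, assuming a model has already been loaded with the exllama_hf loader (the adapter name is a placeholder for a folder in the loras directory):

    import modules.shared as shared
    from modules.LoRA import add_lora_to_model

    # shared.model is assumed to already hold a loaded ExllamaHF instance
    print(shared.model.__class__.__name__)  # 'ExllamaHF'
    add_lora_to_model(['my-lora'])          # 'my-lora' is a hypothetical adapter folder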
@@ -29,7 +29,11 @@ def add_lora_exllama(lora_names):
             return
 
     if len(lora_names) == 0:
-        shared.model.generator.lora = None
+        if shared.model.__class__.__name__ == 'ExllamaModel':
+            shared.model.generator.lora = None
+        else:
+            shared.model.lora = None
+
         shared.lora_names = []
         return
     else:
@@ -41,8 +45,13 @@ def add_lora_exllama(lora_names):
         lora_adapter_path = lora_path / "adapter_model.bin"
 
         logger.info("Applying the following LoRAs to {}: {}".format(shared.model_name, ', '.join([lora_names[0]])))
-        lora = ExLlamaLora(shared.model.model, str(lora_config_path), str(lora_adapter_path))
-        shared.model.generator.lora = lora
+        if shared.model.__class__.__name__ == 'ExllamaModel':
+            lora = ExLlamaLora(shared.model.model, str(lora_config_path), str(lora_adapter_path))
+            shared.model.generator.lora = lora
+        else:
+            lora = ExLlamaLora(shared.model.ex_model, str(lora_config_path), str(lora_adapter_path))
+            shared.model.lora = lora
+
         shared.lora_names = [lora_names[0]]
         return
 
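Both branches build the same ExLlamaLora object; they differ only in which ExLlama instance they bind to and where the adapter is stored. A condensed sketch of that distinction (names as in the hunks above; the imports and path handling of the surrounding function are assumed):

    # ExllamaModel: the raw ExLlama lives at shared.model.model and the
    # generator consumes the LoRA directly.
    lora = ExLlamaLora(shared.model.model, str(lora_config_path), str(lora_adapter_path))
    shared.model.generator.lora = lora

    # ExllamaHF: the raw ExLlama lives at shared.model.ex_model and the LoRA is
    # kept on the wrapper, to be passed explicitly on every forward call
    # (see the second file below).
    lora = ExLlamaLora(shared.model.ex_model, str(lora_config_path), str(lora_adapter_path))
    shared.model.lora = lora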


@@ -30,6 +30,7 @@ class ExllamaHF(PreTrainedModel):
         self.ex_config = config
         self.ex_model = ExLlama(self.ex_config)
         self.generation_config = GenerationConfig()
+        self.lora = None
 
     def _validate_model_class(self):
         pass
@@ -53,9 +54,9 @@ class ExllamaHF(PreTrainedModel):
         cache = kwargs['past_key_values'] if 'past_key_values' in kwargs else None
         if cache is None:
             cache = ExLlamaCache(self.ex_model)
-            self.ex_model.forward(torch.tensor([seq[:-1]], dtype=torch.long), cache, preprocess_only=True)
+            self.ex_model.forward(torch.tensor([seq[:-1]], dtype=torch.long), cache, preprocess_only=True, lora=self.lora)
 
-        logits = self.ex_model.forward(torch.tensor([seq[-1:]], dtype=torch.long), cache).to(kwargs['input_ids'].device)
+        logits = self.ex_model.forward(torch.tensor([seq[-1:]], dtype=torch.long), cache, lora=self.lora).to(kwargs['input_ids'].device)
 
         loss = None
         if labels is not None:
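With the new lora attribute threaded through both forward passes, the patched __call__ path reads roughly as sketched below (condensed; argument validation and the loss computation that follows are omitted, and torch/ExLlamaCache come from the imports already present in the file). Since self.lora defaults to None and is passed unconditionally, exllama's forward must accept lora=None, so plain generation without an adapter is unaffected:

    def __call__(self, *args, **kwargs):
        seq = kwargs['input_ids'][0].tolist()
        cache = kwargs.get('past_key_values', None)
        if cache is None:
            cache = ExLlamaCache(self.ex_model)
            # Prefill all but the last token; the attached LoRA (or None) is applied here.
            self.ex_model.forward(torch.tensor([seq[:-1]], dtype=torch.long), cache,
                                  preprocess_only=True, lora=self.lora)

        # Score only the final token, again threading self.lora through the forward pass.
        logits = self.ex_model.forward(torch.tensor([seq[-1:]], dtype=torch.long), cache,
                                       lora=self.lora).to(kwargs['input_ids'].device)
        # ... loss handling continues as in the original file.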