initial multi-lora support (#1103)

---------

Co-authored-by: oobabooga <112222186+oobabooga@users.noreply.github.com>
Authored by Alex "mcmonkey" Goodwin on 2023-04-14 10:52:06 -07:00; committed by GitHub
parent ebb81eb176
commit 64e3b44e0f
4 changed files with 43 additions and 24 deletions


@@ -4,19 +4,31 @@ import torch
 from peft import PeftModel
 
 import modules.shared as shared
 from modules.models import reload_model
 
 
-def add_lora_to_model(lora_name):
+def add_lora_to_model(lora_names):
+    prior_set = set(shared.lora_names)
+    added_set = set(lora_names) - prior_set
+    removed_set = prior_set - set(lora_names)
+    shared.lora_names = list(lora_names)
 
-    # If a LoRA had been previously loaded, or if we want
-    # to unload a LoRA, reload the model
-    if shared.lora_name not in ['None', ''] or lora_name in ['None', '']:
-        reload_model()
-    shared.lora_name = lora_name
+    # Nothing to do = skip.
+    if len(added_set) == 0 and len(removed_set) == 0:
+        return
 
-    if lora_name not in ['None', '']:
-        print(f"Adding the LoRA {lora_name} to the model...")
+    # Only adding, and already peft? Do it the easy way.
+    if len(removed_set) == 0 and len(prior_set) > 0:
+        print(f"Adding the LoRA(s) named {added_set} to the model...")
+        for lora in added_set:
+            shared.model.load_adapter(Path(f"{shared.args.lora_dir}/{lora}"), lora)
+        return
+
+    # If removing anything, disable all and re-add.
+    if len(removed_set) > 0:
+        shared.model.disable_adapter()
+
+    if len(lora_names) > 0:
+        print("Applying the following LoRAs to {}: {}".format(shared.model_name, ', '.join(lora_names)))
         params = {}
         if not shared.args.cpu:
             params['dtype'] = shared.model.dtype
@@ -25,7 +37,11 @@ def add_lora_to_model(lora_name):
             elif shared.args.load_in_8bit:
                 params['device_map'] = {'': 0}
 
-        shared.model = PeftModel.from_pretrained(shared.model, Path(f"{shared.args.lora_dir}/{lora_name}"), **params)
+        shared.model = PeftModel.from_pretrained(shared.model, Path(f"{shared.args.lora_dir}/{lora_names[0]}"), **params)
+
+        for lora in lora_names[1:]:
+            shared.model.load_adapter(Path(f"{shared.args.lora_dir}/{lora}"), lora)
+
         if not shared.args.load_in_8bit and not shared.args.cpu:
             shared.model.half()
             if not hasattr(shared.model, "hf_device_map"):
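
As a quick illustration of the new interface, here is a usage sketch (not part of the commit): it assumes the function is importable from the repo's LoRA module, that a base model is already loaded into shared.model, that shared.lora_names starts out as an empty list, and the adapter names are made-up placeholders.

# Usage sketch (illustrative only, not part of this commit).
from modules.LoRA import add_lora_to_model  # assumed import path
import modules.shared as shared

# First call: no prior adapters, so the full path runs -> PeftModel.from_pretrained()
# wraps the model with the first adapter, then load_adapter() stacks the second.
add_lora_to_model(['alpaca-lora', 'sql-lora'])  # placeholder adapter names

# Adding another name without removing any takes the "easy way" branch:
# only load_adapter() is called for the newly added adapter.
add_lora_to_model(['alpaca-lora', 'sql-lora', 'story-lora'])

# Dropping adapters makes removed_set non-empty, so disable_adapter() runs first;
# with an empty list nothing is re-applied afterwards.
add_lora_to_model([])
print(shared.lora_names)  # []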