From eac27f4f556b2e4fd149e65e2395fbc9ce2ea3c7 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Thu, 23 Mar 2023 00:55:33 -0300 Subject: [PATCH] Make LoRAs work in 16-bit mode --- modules/LoRA.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/modules/LoRA.py b/modules/LoRA.py index 6915e15..2085033 100644 --- a/modules/LoRA.py +++ b/modules/LoRA.py @@ -13,10 +13,15 @@ def add_lora_to_model(lora_name): print("Reloading the model to remove the LoRA...") shared.model, shared.tokenizer = load_model(shared.model_name) else: - # Why doesn't this work in 16-bit mode? print(f"Adding the LoRA {lora_name} to the model...") - + params = {} - params['device_map'] = {'': 0} - #params['dtype'] = shared.model.dtype + if shared.args.load_in_8bit: + params['device_map'] = {'': 0} + else: + params['device_map'] = 'auto' + params['dtype'] = shared.model.dtype + shared.model = PeftModel.from_pretrained(shared.model, Path(f"loras/{lora_name}"), **params) + if not shared.args.load_in_8bit: + shared.model.half()