diff --git a/modules/LoRA.py b/modules/LoRA.py
index 2085033..0a2aaa7 100644
--- a/modules/LoRA.py
+++ b/modules/LoRA.py
@@ -18,10 +18,10 @@ def add_lora_to_model(lora_name):
         params = {}
         if shared.args.load_in_8bit:
             params['device_map'] = {'': 0}
-        else:
+        elif not shared.args.cpu:
             params['device_map'] = 'auto'
             params['dtype'] = shared.model.dtype
 
         shared.model = PeftModel.from_pretrained(shared.model, Path(f"loras/{lora_name}"), **params)
-        if not shared.args.load_in_8bit:
+        if not shared.args.load_in_8bit and not shared.args.cpu:
             shared.model.half()