From 25be9698c74d7af950cbcbf8ec4c0cd9bebc6d3c Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Sat, 25 Mar 2023 01:18:32 -0300 Subject: [PATCH] Fix LoRA on mps --- modules/LoRA.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/modules/LoRA.py b/modules/LoRA.py index aa68ad3..283fcf4 100644 --- a/modules/LoRA.py +++ b/modules/LoRA.py @@ -1,5 +1,7 @@ from pathlib import Path +import torch + import modules.shared as shared from modules.models import load_model from modules.text_generation import clear_torch_cache @@ -34,4 +36,8 @@ def add_lora_to_model(lora_name): if not shared.args.load_in_8bit and not shared.args.cpu: shared.model.half() if not hasattr(shared.model, "hf_device_map"): - shared.model.cuda() + if torch.has_mps: + device = torch.device('mps') + shared.model = shared.model.to(device) + else: + shared.model = shared.model.cuda()