diff --git a/modules/quant_loader.py b/modules/quant_loader.py
index 7a5f846..c272349 100644
--- a/modules/quant_loader.py
+++ b/modules/quant_loader.py
@@ -7,6 +7,8 @@ import torch
 import modules.shared as shared
 
 sys.path.insert(0, str(Path("repositories/GPTQ-for-LLaMa")))
+import llama
+import opt
 
 
 def load_quantized(model_name):
@@ -21,9 +23,9 @@ def load_quantized(model_name):
     model_type = shared.args.gptq_model_type.lower()
 
     if model_type == 'llama':
-        from llama import load_quant
+        load_quant = llama.load_quant
     elif model_type == 'opt':
-        from opt import load_quant
+        load_quant = opt.load_quant
     else:
         print("Unknown pre-quantized model type specified. Only 'llama' and 'opt' are supported")
         exit()
@@ -50,7 +52,7 @@ def load_quantized(model_name):
         print(f"Could not find {pt_model}, exiting...")
         exit()
 
-    model = load_quant(path_to_model, str(pt_path), shared.args.gptq_bits)
+    model = load_quant(str(path_to_model), str(pt_path), shared.args.gptq_bits)
 
     # Multiple GPUs or GPU+CPU
     if shared.args.gpu_memory: