Add AutoGPTQ support (basic) (#2132)

This commit is contained in:
oobabooga 2023-05-17 11:12:12 -03:00 committed by GitHub
parent 10cf7831f7
commit 1a8151a2b6
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 56 additions and 2 deletions

View file

@@ -72,7 +72,10 @@ def load_model(model_name):
shared.model_type = find_model_type(model_name)
if shared.args.wbits > 0:
load_func = GPTQ_loader
if shared.args.autogptq:
load_func = AutoGPTQ_loader
else:
load_func = GPTQ_loader
elif shared.model_type == 'llamacpp':
load_func = llamacpp_loader
elif shared.model_type == 'rwkv':
@@ -261,6 +264,12 @@ def GPTQ_loader(model_name):
return model
def AutoGPTQ_loader(model_name):
    """Load a quantized model through the AutoGPTQ backend.

    Thin wrapper that defers the heavyweight backend import to call time
    (so AutoGPTQ is only required when this loader is actually selected)
    and delegates all loading logic to modules.AutoGPTQ_loader.
    """
    from modules.AutoGPTQ_loader import load_quantized

    model = load_quantized(model_name)
    return model
def get_max_memory_dict():
max_memory = {}
if shared.args.gpu_memory:
@@ -283,7 +292,7 @@ def get_max_memory_dict():
logging.warning(f"Auto-assigning --gpu-memory {suggestion} for your GPU to try to prevent out-of-memory errors. You can manually set other values.")
max_memory = {0: f'{suggestion}GiB', 'cpu': f'{shared.args.cpu_memory or 99}GiB'}
return max_memory
return max_memory if len(max_memory) > 0 else None
def clear_torch_cache():