AutoGPTQ: Add --disable_exllamav2 flag (Mixtral CPU offloading needs this)

This commit is contained in:
oobabooga 2023-12-15 06:46:13 -08:00
parent 7de10f4c8e
commit 3bbf6c601d
7 changed files with 16 additions and 4 deletions

View file

@ -156,7 +156,7 @@ def huggingface_loader(model_name):
LoaderClass = AutoModelForCausalLM
# Load the model in simple 16-bit mode by default
if not any([shared.args.cpu, shared.args.load_in_8bit, shared.args.load_in_4bit, shared.args.auto_devices, shared.args.disk, shared.args.deepspeed, shared.args.gpu_memory is not None, shared.args.cpu_memory is not None, shared.args.compress_pos_emb > 1, shared.args.alpha_value > 1, shared.args.disable_exllama]):
if not any([shared.args.cpu, shared.args.load_in_8bit, shared.args.load_in_4bit, shared.args.auto_devices, shared.args.disk, shared.args.deepspeed, shared.args.gpu_memory is not None, shared.args.cpu_memory is not None, shared.args.compress_pos_emb > 1, shared.args.alpha_value > 1, shared.args.disable_exllama, shared.args.disable_exllamav2]):
model = LoaderClass.from_pretrained(path_to_model, **params)
if torch.backends.mps.is_available():
device = torch.device('mps')
@ -221,11 +221,16 @@ def huggingface_loader(model_name):
if shared.args.disk:
params['offload_folder'] = shared.args.disk_cache_dir
if shared.args.disable_exllama:
if shared.args.disable_exllama or shared.args.disable_exllamav2:
try:
gptq_config = GPTQConfig(bits=config.quantization_config.get('bits', 4), disable_exllama=True)
gptq_config = GPTQConfig(
bits=config.quantization_config.get('bits', 4),
disable_exllama=shared.args.disable_exllama,
disable_exllamav2=shared.args.disable_exllamav2,
)
params['quantization_config'] = gptq_config
logger.info('Loading with ExLlama kernel disabled.')
logger.info(f'Loading with disable_exllama={shared.args.disable_exllama} and disable_exllamav2={shared.args.disable_exllamav2}.')
except:
exc = traceback.format_exc()
logger.error('Failed to disable exllama. Does the config.json for this model contain the necessary quantization info?')