AutoGPTQ: Add --disable_exllamav2 flag (Mixtral CPU offloading needs this)
This commit is contained in:
parent 7de10f4c8e
commit 3bbf6c601d

7 changed files with 16 additions and 4 deletions
@@ -156,7 +156,7 @@ def huggingface_loader(model_name):
         LoaderClass = AutoModelForCausalLM

     # Load the model in simple 16-bit mode by default
-    if not any([shared.args.cpu, shared.args.load_in_8bit, shared.args.load_in_4bit, shared.args.auto_devices, shared.args.disk, shared.args.deepspeed, shared.args.gpu_memory is not None, shared.args.cpu_memory is not None, shared.args.compress_pos_emb > 1, shared.args.alpha_value > 1, shared.args.disable_exllama]):
+    if not any([shared.args.cpu, shared.args.load_in_8bit, shared.args.load_in_4bit, shared.args.auto_devices, shared.args.disk, shared.args.deepspeed, shared.args.gpu_memory is not None, shared.args.cpu_memory is not None, shared.args.compress_pos_emb > 1, shared.args.alpha_value > 1, shared.args.disable_exllama, shared.args.disable_exllamav2]):
         model = LoaderClass.from_pretrained(path_to_model, **params)
         if torch.backends.mps.is_available():
             device = torch.device('mps')
@@ -221,11 +221,16 @@ def huggingface_loader(model_name):
         if shared.args.disk:
             params['offload_folder'] = shared.args.disk_cache_dir

-        if shared.args.disable_exllama:
+        if shared.args.disable_exllama or shared.args.disable_exllamav2:
             try:
-                gptq_config = GPTQConfig(bits=config.quantization_config.get('bits', 4), disable_exllama=True)
+                gptq_config = GPTQConfig(
+                    bits=config.quantization_config.get('bits', 4),
+                    disable_exllama=shared.args.disable_exllama,
+                    disable_exllamav2=shared.args.disable_exllamav2,
+                )

                 params['quantization_config'] = gptq_config
-                logger.info('Loading with ExLlama kernel disabled.')
+                logger.info(f'Loading with disable_exllama={shared.args.disable_exllama} and disable_exllamav2={shared.args.disable_exllamav2}.')
             except:
                 exc = traceback.format_exc()
                 logger.error('Failed to disable exllama. Does the config.json for this model contain the necessary quantization info?')
|
Loading…
Add table
Add a link
Reference in a new issue