Add --load-in-4bit parameter (#2320)
This commit is contained in:
parent
63ce5f9c28
commit
361451ba60
6 changed files with 61 additions and 22 deletions
|
@ -149,7 +149,7 @@ def huggingface_loader(model_name):
|
|||
LoaderClass = AutoModelForCausalLM
|
||||
|
||||
# Load the model in simple 16-bit mode by default
|
||||
if not any([shared.args.cpu, shared.args.load_in_8bit, shared.args.auto_devices, shared.args.disk, shared.args.deepspeed, shared.args.gpu_memory is not None, shared.args.cpu_memory is not None]):
|
||||
if not any([shared.args.cpu, shared.args.load_in_8bit, shared.args.load_in_4bit, shared.args.auto_devices, shared.args.disk, shared.args.deepspeed, shared.args.gpu_memory is not None, shared.args.cpu_memory is not None]):
|
||||
model = LoaderClass.from_pretrained(Path(f"{shared.args.model_dir}/{model_name}"), low_cpu_mem_usage=True, torch_dtype=torch.bfloat16 if shared.args.bf16 else torch.float16, trust_remote_code=shared.args.trust_remote_code)
|
||||
if torch.has_mps:
|
||||
device = torch.device('mps')
|
||||
|
@ -179,7 +179,21 @@ def huggingface_loader(model_name):
|
|||
params["torch_dtype"] = torch.float32
|
||||
else:
|
||||
params["device_map"] = 'auto'
|
||||
if shared.args.load_in_8bit and any((shared.args.auto_devices, shared.args.gpu_memory)):
|
||||
if shared.args.load_in_4bit:
|
||||
|
||||
# See https://github.com/huggingface/transformers/pull/23479/files
|
||||
# and https://huggingface.co/blog/4bit-transformers-bitsandbytes
|
||||
quantization_config_params = {
|
||||
'load_in_4bit': True,
|
||||
'bnb_4bit_compute_dtype': eval("torch.{}".format(shared.args.compute_dtype)) if shared.args.compute_dtype in ["bfloat16", "float16", "float32"] else None,
|
||||
'bnb_4bit_quant_type': shared.args.quant_type,
|
||||
'bnb_4bit_use_double_quant': shared.args.use_double_quant,
|
||||
}
|
||||
|
||||
logger.warning("Using the following 4-bit params: " + str(quantization_config_params))
|
||||
params['quantization_config'] = BitsAndBytesConfig(**quantization_config_params)
|
||||
|
||||
elif shared.args.load_in_8bit and any((shared.args.auto_devices, shared.args.gpu_memory)):
|
||||
params['quantization_config'] = BitsAndBytesConfig(load_in_8bit=True, llm_int8_enable_fp32_cpu_offload=True)
|
||||
elif shared.args.load_in_8bit:
|
||||
params['quantization_config'] = BitsAndBytesConfig(load_in_8bit=True)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue