Add 4-bit LoRA support (#1200)
This commit is contained in:
parent
ec3e869c27
commit
39099663a0
7 changed files with 100 additions and 34 deletions
|
@ -16,6 +16,8 @@ from modelutils import find_layers
|
|||
from quant import make_quant
|
||||
|
||||
|
||||
# This function is a replacement for the load_quant function in the
|
||||
# GPTQ-for-LLaMa repository. It supports more models and branches.
|
||||
def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exclude_layers=['lm_head'], kernel_switch_threshold=128):
|
||||
|
||||
def noop(*args, **kwargs):
|
||||
|
@ -64,6 +66,7 @@ def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exc
|
|||
|
||||
try:
|
||||
from quant import autotune_warmup, make_quant_attn
|
||||
|
||||
# triton branch
|
||||
make_quant_attn(model)
|
||||
if not shared.args.no_warmup_autotune:
|
||||
|
@ -77,6 +80,41 @@ def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exc
|
|||
return model
|
||||
|
||||
|
||||
# Used to locate the .pt/.safetensors quantized file
|
||||
def find_quantized_model_file(model_name):
    """Locate the quantized weights file (.pt or .safetensors) for a model.

    Well-behaved file names derived from ``shared.args.wbits`` and
    ``shared.args.groupsize`` are probed first, in priority order, both next
    to and inside the model folder under ``shared.args.model_dir``. If none
    of them exists, the model folder is scanned and the last ``.pt`` file
    (or, when there is no ``.pt`` file, the last ``.safetensors`` file) is
    used as a last resort.

    Args:
        model_name: Name of the model folder inside ``shared.args.model_dir``.

    Returns:
        A ``Path`` to the weights file, or ``None`` if nothing was found.
    """
    path_to_model = Path(f'{shared.args.model_dir}/{model_name}')
    pt_path = None

    # Candidate names, most specific first: with the group size suffix before
    # without it, .safetensors before .pt, and for each of those the three
    # supported locations/prefixes.
    priority_name_list = [
        Path(f'{shared.args.model_dir}/{model_name}{hyphen}{shared.args.wbits}bit{group}{ext}')
        for group in ([f'-{shared.args.groupsize}g', ''] if shared.args.groupsize > 0 else [''])
        for ext in ['.safetensors', '.pt']
        for hyphen in ['-', f'/{model_name}-', '/']
    ]
    for path in priority_name_list:
        if path.exists():
            pt_path = path
            break

    # If the model hasn't been found with a well-behaved name, pick the last .pt
    # or the last .safetensors found in its folder as a last resort
    if not pt_path:
        # NOTE: glob() yields entries in filesystem order, so "last" is not
        # guaranteed to be the alphabetically last file.
        found_pts = list(path_to_model.glob("*.pt"))
        found_safetensors = list(path_to_model.glob("*.safetensors"))

        if len(found_pts) > 0:
            if len(found_pts) > 1:
                print('Warning: more than one .pt model has been found. The last one will be selected. It could be wrong.')
            pt_path = found_pts[-1]
        elif len(found_safetensors) > 0:
            # Check the .safetensors list here: found_pts is always empty in
            # this branch, so testing it would never emit the warning.
            if len(found_safetensors) > 1:
                print('Warning: more than one .safetensors model has been found. The last one will be selected. It could be wrong.')
            pt_path = found_safetensors[-1]

    return pt_path
|
||||
|
||||
|
||||
# The function that loads the model in modules/models.py
|
||||
def load_quantized(model_name):
|
||||
|
||||
# Find the model type
|
||||
|
@ -106,37 +144,9 @@ def load_quantized(model_name):
|
|||
print("Unknown pre-quantized model type specified. Only 'llama', 'opt' and 'gptj' are supported")
|
||||
exit()
|
||||
|
||||
# Locate the quantized model file
|
||||
# Find the quantized model weights file (.pt/.safetensors)
|
||||
path_to_model = Path(f'{shared.args.model_dir}/{model_name}')
|
||||
pt_path = None
|
||||
priority_name_list = [
|
||||
Path(f'{shared.args.model_dir}/{model_name}{hyphen}{shared.args.wbits}bit{group}{ext}')
|
||||
for group in ([f'-{shared.args.groupsize}g', ''] if shared.args.groupsize > 0 else [''])
|
||||
for ext in ['.safetensors', '.pt']
|
||||
for hyphen in ['-', f'/{model_name}-', '/']
|
||||
]
|
||||
for path in priority_name_list:
|
||||
if path.exists():
|
||||
pt_path = path
|
||||
break
|
||||
|
||||
# If the model hasn't been found with a well-behaved name, pick the last .pt
|
||||
# or the last .safetensors found in its folder as a last resort
|
||||
if not pt_path:
|
||||
path_to_model = Path(f'{shared.args.model_dir}/{model_name}')
|
||||
found_pts = list(path_to_model.glob("*.pt"))
|
||||
found_safetensors = list(path_to_model.glob("*.safetensors"))
|
||||
pt_path = None
|
||||
|
||||
if len(found_pts) > 0:
|
||||
if len(found_pts) > 1:
|
||||
print('Warning: more than one .pt model has been found. The last one will be selected. It could be wrong.')
|
||||
pt_path = found_pts[-1]
|
||||
elif len(found_safetensors) > 0:
|
||||
if len(found_pts) > 1:
|
||||
print('Warning: more than one .safetensors model has been found. The last one will be selected. It could be wrong.')
|
||||
pt_path = found_safetensors[-1]
|
||||
|
||||
pt_path = find_quantized_model_file(model_name)
|
||||
if not pt_path:
|
||||
print("Could not find the quantized model in .pt or .safetensors format, exiting...")
|
||||
exit()
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue