Better warning messages

This commit is contained in:
oobabooga 2023-05-03 21:43:17 -03:00
parent 0a48b29cd8
commit 95d04d6a8d
13 changed files with 194 additions and 83 deletions

View file

@@ -1,4 +1,5 @@
import inspect
import logging
import re
import sys
from pathlib import Path
@@ -71,7 +72,6 @@ def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exc
del layers
print('Loading model ...')
if checkpoint.endswith('.safetensors'):
from safetensors.torch import load_file as safe_load
model.load_state_dict(safe_load(checkpoint), strict=False)
@@ -90,8 +90,6 @@ def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exc
quant.autotune_warmup_fused(model)
model.seqlen = 2048
print('Done.')
return model
@@ -119,11 +117,13 @@ def find_quantized_model_file(model_name):
if len(found_pts) > 0:
if len(found_pts) > 1:
print('Warning: more than one .pt model has been found. The last one will be selected. It could be wrong.')
logging.warning('More than one .pt model has been found. The last one will be selected. It could be wrong.')
pt_path = found_pts[-1]
elif len(found_safetensors) > 0:
if len(found_safetensors) > 1:
print('Warning: more than one .safetensors model has been found. The last one will be selected. It could be wrong.')
logging.warning('More than one .safetensors model has been found. The last one will be selected. It could be wrong.')
pt_path = found_safetensors[-1]
return pt_path
@@ -142,8 +142,7 @@ def load_quantized(model_name):
elif any((k in name for k in ['gpt-j', 'pygmalion-6b'])):
model_type = 'gptj'
else:
print("Can't determine model type from model name. Please specify it manually using --model_type "
"argument")
logging.error("Can't determine model type from model name. Please specify it manually using --model_type argument")
exit()
else:
model_type = shared.args.model_type.lower()
@@ -153,20 +152,21 @@ def load_quantized(model_name):
load_quant = llama_inference_offload.load_quant
elif model_type in ('llama', 'opt', 'gptj'):
if shared.args.pre_layer:
print("Warning: ignoring --pre_layer because it only works for llama model type.")
logging.warning("Ignoring --pre_layer because it only works for llama model type.")
load_quant = _load_quant
else:
print("Unknown pre-quantized model type specified. Only 'llama', 'opt' and 'gptj' are supported")
logging.error("Unknown pre-quantized model type specified. Only 'llama', 'opt' and 'gptj' are supported")
exit()
# Find the quantized model weights file (.pt/.safetensors)
path_to_model = Path(f'{shared.args.model_dir}/{model_name}')
pt_path = find_quantized_model_file(model_name)
if not pt_path:
print("Could not find the quantized model in .pt or .safetensors format, exiting...")
logging.error("Could not find the quantized model in .pt or .safetensors format, exiting...")
exit()
else:
print(f"Found the following quantized model: {pt_path}")
logging.info(f"Found the following quantized model: {pt_path}")
# qwopqwop200's offload
if model_type == 'llama' and shared.args.pre_layer:
@@ -188,7 +188,7 @@ def load_quantized(model_name):
max_memory = accelerate.utils.get_balanced_memory(model)
device_map = accelerate.infer_auto_device_map(model, max_memory=max_memory, no_split_module_classes=["LlamaDecoderLayer"])
print("Using the following device map for the quantized model:", device_map)
logging.info("Using the following device map for the quantized model: %s", device_map)
# https://huggingface.co/docs/accelerate/package_reference/big_modeling#accelerate.dispatch_model
model = accelerate.dispatch_model(model, device_map=device_map, offload_buffers=True)