Add various checks to model loading functions

This commit is contained in:
oobabooga 2023-05-17 15:52:23 -03:00
parent abd361b3a0
commit ef10ffc6b4
2 changed files with 28 additions and 19 deletions

View file

@@ -13,19 +13,18 @@ def load_quantized(model_name):
use_safetensors = False
# Find the model checkpoint
found_pts = list(path_to_model.glob("*.pt"))
found_safetensors = list(path_to_model.glob("*.safetensors"))
if len(found_safetensors) > 0:
if len(found_safetensors) > 1:
logging.warning('More than one .safetensors model has been found. The last one will be selected. It could be wrong.')
for ext in ['.safetensors', '.pt', '.bin']:
found = list(path_to_model.glob(f"*{ext}"))
if len(found) > 0:
if len(found) > 1:
logging.warning(f'More than one {ext} model has been found. The last one will be selected. It could be wrong.')
use_safetensors = True
pt_path = found_safetensors[-1]
elif len(found_pts) > 0:
if len(found_pts) > 1:
logging.warning('More than one .pt model has been found. The last one will be selected. It could be wrong.')
pt_path = found[-1]
break
pt_path = found_pts[-1]
if pt_path is None:
logging.error("The model could not be loaded because its checkpoint file in .bin/.pt/.safetensors format could not be located.")
return
# Define the params for AutoGPTQForCausalLM.from_quantized
params = {