Add various checks to model loading functions
This commit is contained in:
parent
abd361b3a0
commit
ef10ffc6b4
2 changed files with 28 additions and 19 deletions
|
|
@ -13,19 +13,18 @@ def load_quantized(model_name):
|
|||
use_safetensors = False
|
||||
|
||||
# Find the model checkpoint
|
||||
found_pts = list(path_to_model.glob("*.pt"))
|
||||
found_safetensors = list(path_to_model.glob("*.safetensors"))
|
||||
if len(found_safetensors) > 0:
|
||||
if len(found_safetensors) > 1:
|
||||
logging.warning('More than one .safetensors model has been found. The last one will be selected. It could be wrong.')
|
||||
for ext in ['.safetensors', '.pt', '.bin']:
|
||||
found = list(path_to_model.glob(f"*{ext}"))
|
||||
if len(found) > 0:
|
||||
if len(found) > 1:
|
||||
logging.warning(f'More than one {ext} model has been found. The last one will be selected. It could be wrong.')
|
||||
|
||||
use_safetensors = True
|
||||
pt_path = found_safetensors[-1]
|
||||
elif len(found_pts) > 0:
|
||||
if len(found_pts) > 1:
|
||||
logging.warning('More than one .pt model has been found. The last one will be selected. It could be wrong.')
|
||||
pt_path = found[-1]
|
||||
break
|
||||
|
||||
pt_path = found_pts[-1]
|
||||
if pt_path is None:
|
||||
logging.error("The model could not be loaded because its checkpoint file in .bin/.pt/.safetensors format could not be located.")
|
||||
return
|
||||
|
||||
# Define the params for AutoGPTQForCausalLM.from_quantized
|
||||
params = {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue