Don't show .pt models in the list

This commit is contained in:
oobabooga 2023-03-09 21:54:50 -03:00
parent 1a3d25f75d
commit 9849aac0f1
2 changed files with 4 additions and 1 deletions

View file

@ -105,6 +105,9 @@ def load_model(model_name):
if not Path(f"models/{pt_model}").exists():
print(f"Could not find models/{pt_model}, exiting...")
exit()
elif pt_model == '':
print(f"Could not find the .pt model for {model_name}, exiting...")
exit()
model = load_quant(path_to_model, Path(f"models/{pt_model}"), 4)
model = model.to(torch.device('cuda:0'))