diff --git a/modules/models.py b/modules/models.py index 26a10f7..a683931 100644 --- a/modules/models.py +++ b/modules/models.py @@ -76,7 +76,7 @@ def load_model(model_name): num_bits=4, group_size=64, group_dim=2, symmetric=False)) - model = OptLM(f"facebook/{shared.model_name}", env, shared.model_name, policy) + model = OptLM(f"facebook/{shared.model_name}", env, shared.args.model_dir, policy) # DeepSpeed ZeRO-3 elif shared.args.deepspeed: