diff --git a/server.py b/server.py index 844d5e6..692c84b 100644 --- a/server.py +++ b/server.py @@ -178,7 +178,7 @@ def load_model(model_name): # DeepSpeed ZeRO-3 elif args.deepspeed: - model = AutoModelForCausalLM.from_pretrained(Path(f"models/{model_name}", no_split_module_classes=["GPTJBlock"])) + model = AutoModelForCausalLM.from_pretrained(Path(f"models/{model_name}")) model = deepspeed.initialize(model=model, config_params=ds_config, model_parameters=None,