diff --git a/server.py b/server.py index 1309c17..4c3497c 100644 --- a/server.py +++ b/server.py @@ -64,9 +64,7 @@ def load_model_wrapper(selected_model): return selected_model def reload_model(): - if not shared.args.cpu: - gc.collect() - torch.cuda.empty_cache() + unload_model() shared.model, shared.tokenizer = load_model(shared.model_name) def unload_model(): @@ -74,6 +72,7 @@ def unload_model(): if not shared.args.cpu: gc.collect() torch.cuda.empty_cache() + print("Model weights unloaded.") def load_lora_wrapper(selected_lora): shared.lora_name = selected_lora