diff --git a/server.py b/server.py index ea05dec..06e6529 100644 --- a/server.py +++ b/server.py @@ -99,7 +99,7 @@ def load_model(model_name): if args.gpu_memory: settings.append(f"max_memory={{0: '{args.gpu_memory or '99'}GiB', 'cpu': '{args.cpu_memory or '99'}GiB'}}") - elif not args.load_in_8bit: + elif (args.gpu_memory or args.cpu_memory) and not args.load_in_8bit: total_mem = (torch.cuda.get_device_properties(0).total_memory/(1024*1024)) suggestion = round((total_mem-1000)/1000)*1000 if total_mem-suggestion < 800: