gpu-memory arg change: accept comma-separated per-GPU values and build a per-device max_memory map

This commit is contained in:
luis 2023-02-23 18:43:55 -05:00
parent 9ae063e42b
commit 5abdc99a7c
2 changed files with 7 additions and 2 deletions

View file

@ -96,7 +96,12 @@ def load_model(model_name):
params.append("load_in_8bit=True" if shared.args.load_in_8bit else "torch_dtype=torch.bfloat16" if shared.args.bf16 else "torch_dtype=torch.float16")
if shared.args.gpu_memory:
params.append(f"max_memory={{0: '{shared.args.gpu_memory or '99'}GiB', 'cpu': '{shared.args.cpu_memory or '99'}GiB'}}")
memory_map = shared.args.gpu_memory.split(",")
max_memory = f"max_memory={{0: '{memory_map[0]}GiB'"
for i in range(1,len(memory_map)):
max_memory+=(f", {i}: '{memory_map[i]}GiB'")
max_memory+=(f", 'cpu': '{shared.args.cpu_memory or '99'}GiB'" + "}")
params.append(max_memory)
elif not shared.args.load_in_8bit:
total_mem = (torch.cuda.get_device_properties(0).total_memory/(1024*1024))
suggestion = round((total_mem-1000)/1000)*1000