gpu-memory arg change
This commit is contained in:
parent
9ae063e42b
commit
5abdc99a7c
2 changed files with 7 additions and 2 deletions
|
@@ -96,7 +96,12 @@ def load_model(model_name):
|
|||
params.append("load_in_8bit=True" if shared.args.load_in_8bit else "torch_dtype=torch.bfloat16" if shared.args.bf16 else "torch_dtype=torch.float16")
|
||||
|
||||
if shared.args.gpu_memory:
|
||||
params.append(f"max_memory={{0: '{shared.args.gpu_memory or '99'}GiB', 'cpu': '{shared.args.cpu_memory or '99'}GiB'}}")
|
||||
memory_map = shared.args.gpu_memory.split(",")
|
||||
max_memory = f"max_memory={{0: '{memory_map[0]}GiB'"
|
||||
for i in range(1,len(memory_map)):
|
||||
max_memory+=(f", {i}: '{memory_map[i]}GiB'")
|
||||
max_memory+=(f", 'cpu': '{shared.args.cpu_memory or '99'}GiB'" + "}")
|
||||
params.append(max_memory)
|
||||
elif not shared.args.load_in_8bit:
|
||||
total_mem = (torch.cuda.get_device_properties(0).total_memory/(1024*1024))
|
||||
suggestion = round((total_mem-1000)/1000)*1000
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue