Update comments

This commit is contained in:
oobabooga 2023-03-20 16:40:08 -03:00 committed by GitHub
parent 7618f3fe8c
commit db4219a340
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -57,13 +57,13 @@ def load_quantized(model_name):
print(f"Could not find {pt_model}, exiting...") print(f"Could not find {pt_model}, exiting...")
exit() exit()
# Using qwopqwop200's offload # qwopqwop200's offload
if shared.args.gptq_pre_layer: if shared.args.gptq_pre_layer:
model = load_quant(str(path_to_model), str(pt_path), shared.args.gptq_bits, shared.args.gptq_pre_layer) model = load_quant(str(path_to_model), str(pt_path), shared.args.gptq_bits, shared.args.gptq_pre_layer)
else: else:
model = load_quant(str(path_to_model), str(pt_path), shared.args.gptq_bits) model = load_quant(str(path_to_model), str(pt_path), shared.args.gptq_bits)
# Using accelerate offload (doesn't work properly) # accelerate offload (doesn't work properly)
if shared.args.gpu_memory: if shared.args.gpu_memory:
memory_map = list(map(lambda x : x.strip(), shared.args.gpu_memory)) memory_map = list(map(lambda x : x.strip(), shared.args.gpu_memory))
max_cpu_memory = shared.args.cpu_memory.strip() if shared.args.cpu_memory is not None else '99GiB' max_cpu_memory = shared.args.cpu_memory.strip() if shared.args.cpu_memory is not None else '99GiB'
@ -76,6 +76,8 @@ def load_quantized(model_name):
print("Using the following device map for the 4-bit model:", device_map) print("Using the following device map for the 4-bit model:", device_map)
# https://huggingface.co/docs/accelerate/package_reference/big_modeling#accelerate.dispatch_model # https://huggingface.co/docs/accelerate/package_reference/big_modeling#accelerate.dispatch_model
model = accelerate.dispatch_model(model, device_map=device_map, offload_buffers=True) model = accelerate.dispatch_model(model, device_map=device_map, offload_buffers=True)
# No offload
elif not shared.args.cpu: elif not shared.args.cpu:
model = model.to(torch.device('cuda:0')) model = model.to(torch.device('cuda:0'))