Experimental jank multiGPU inference that's 2x faster than native somehow (#2100)

This commit is contained in:
Alex "mcmonkey" Goodwin 2023-05-17 06:41:09 -07:00 committed by GitHub
parent fd743a0207
commit 1f50dbe352
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 10 additions and 3 deletions

View file

@@ -172,7 +172,12 @@ def load_quantized(model_name):
# qwopqwop200's offload
if model_type == 'llama' and shared.args.pre_layer:
model = load_quant(str(path_to_model), str(pt_path), shared.args.wbits, shared.args.groupsize, shared.args.pre_layer)
if len(shared.args.pre_layer) == 1:
pre_layer = shared.args.pre_layer[0]
else:
pre_layer = shared.args.pre_layer
model = load_quant(str(path_to_model), str(pt_path), shared.args.wbits, shared.args.groupsize, pre_layer)
else:
threshold = False if model_type == 'gptj' else 128
model = load_quant(str(path_to_model), str(pt_path), shared.args.wbits, shared.args.groupsize, kernel_switch_threshold=threshold)