Update to support GPTQ triton commit c90adef (#1229)

sgsdxzy authored 2023-04-17 12:11:18 +08:00, committed by GitHub
parent 209fcd21d5
commit b57ffc2ec9
3 changed files with 38 additions and 23 deletions
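
In short: modules/GPTQ_loader.py now works with both upstream branches of GPTQ-for-LLaMa. At import time it probes for quant.make_quant, which only the CUDA branch exports, and records the result in a module-level is_triton flag. The CUDA path keeps the old signature-sniffing make_quant call; the triton path (as of upstream commit c90adef) uses quant.make_quant_linear instead, plus that branch's fused-attention, fused-MLP, and autotune-warmup helpers. A new eval parameter on _load_quant gates the inference-only optimizations.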

modules/GPTQ_loader.py

@@ -13,12 +13,18 @@ import modules.shared as shared
 sys.path.insert(0, str(Path("repositories/GPTQ-for-LLaMa")))
 
 import llama_inference_offload
 from modelutils import find_layers
-from quant import make_quant
+
+try:
+    from quant import make_quant
+    is_triton = False
+except ImportError:
+    import quant
+    is_triton = True
 
 # This function is a replacement for the load_quant function in the
 # GPTQ-for-LLaMa repository. It supports more models and branches.
-def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exclude_layers=['lm_head'], kernel_switch_threshold=128):
+def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exclude_layers=['lm_head'], kernel_switch_threshold=128, eval=True):
 
     def noop(*args, **kwargs):
         pass
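
The try/except above is a feature probe: the CUDA branch of GPTQ-for-LLaMa exports make_quant from its quant module, while the triton branch does not, so the ImportError itself identifies which branch is checked out. A self-contained sketch of the same pattern, using a stdlib symbol in place of quant.make_quant (the names here are illustrative, not part of the commit):

    # Probe for a symbol that only some versions of a dependency provide,
    # and use the failure to select a code path. math.nextafter (Python 3.9+)
    # stands in for quant.make_quant.
    try:
        from math import nextafter
        has_new_api = True
    except ImportError:
        import math  # older interpreter: fall back to the bare module
        has_new_api = False

    print('new API' if has_new_api else 'old API')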
@@ -33,27 +39,31 @@ def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exc
     torch.set_default_dtype(torch.half)
     model = AutoModelForCausalLM.from_config(config)
     torch.set_default_dtype(torch.float)
-    model = model.eval()
+    if eval:
+        model = model.eval()
     layers = find_layers(model)
     for name in exclude_layers:
         if name in layers:
             del layers[name]
 
-    gptq_args = inspect.getfullargspec(make_quant).args
+    if not is_triton:
+        gptq_args = inspect.getfullargspec(make_quant).args
 
-    make_quant_kwargs = {
-        'module': model,
-        'names': layers,
-        'bits': wbits,
-    }
-    if 'groupsize' in gptq_args:
-        make_quant_kwargs['groupsize'] = groupsize
-    if 'faster' in gptq_args:
-        make_quant_kwargs['faster'] = faster_kernel
-    if 'kernel_switch_threshold' in gptq_args:
-        make_quant_kwargs['kernel_switch_threshold'] = kernel_switch_threshold
+        make_quant_kwargs = {
+            'module': model,
+            'names': layers,
+            'bits': wbits,
+        }
+        if 'groupsize' in gptq_args:
+            make_quant_kwargs['groupsize'] = groupsize
+        if 'faster' in gptq_args:
+            make_quant_kwargs['faster'] = faster_kernel
+        if 'kernel_switch_threshold' in gptq_args:
+            make_quant_kwargs['kernel_switch_threshold'] = kernel_switch_threshold
 
-    make_quant(**make_quant_kwargs)
+        make_quant(**make_quant_kwargs)
+    else:
+        quant.make_quant_linear(model, layers, wbits, groupsize)
 
     del layers
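
The CUDA path tolerates different GPTQ-for-LLaMa revisions because it only passes keyword arguments that the checked-out make_quant actually declares. A runnable sketch of that signature-sniffing technique; the stand-in make_quant below is hypothetical and deliberately lacks one parameter:

    import inspect

    # Hypothetical stand-in for make_quant: note there is no
    # kernel_switch_threshold parameter in this revision.
    def make_quant(module, names, bits, groupsize=-1, faster=False):
        print(f'bits={bits}, groupsize={groupsize}, faster={faster}')

    accepted = inspect.getfullargspec(make_quant).args

    kwargs = {'module': None, 'names': {}, 'bits': 4}
    # Optional arguments are added only if the target signature declares them.
    for name, value in [('groupsize', 128),
                        ('faster', False),
                        ('kernel_switch_threshold', 128)]:
        if name in accepted:
            kwargs[name] = value

    make_quant(**kwargs)  # kernel_switch_threshold is silently dropped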
@@ -64,15 +74,16 @@ def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exc
     else:
         model.load_state_dict(torch.load(checkpoint), strict=False)
 
-    try:
-        from quant import autotune_warmup, make_quant_attn
+    if is_triton:
+        if not shared.args.no_quant_attn:
+            quant.make_quant_attn(model)
+        if eval and not shared.args.no_fused_mlp:
+            quant.make_fused_mlp(model)
 
-        # triton branch
-        make_quant_attn(model)
         if not shared.args.no_warmup_autotune:
-            autotune_warmup(model)
-    except ImportError:  # not triton branch
-        pass
+            quant.autotune_warmup_linear(model, transpose=not eval)
+            if eval and not shared.args.no_fused_mlp:
+                quant.autotune_warmup_fused(model)
 
     model.seqlen = 2048
     print('Done.')
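
The triton path reads three new shared.args attributes (no_quant_attn, no_fused_mlp, no_warmup_autotune), so one of the other two changed files, presumably modules/shared.py, must define the matching command-line flags. A hedged sketch of how such opt-out flags are typically declared with argparse; the exact spellings and help texts are assumptions, not the project's actual definitions:

    import argparse

    parser = argparse.ArgumentParser()
    # store_true flags default to False, so fused attention, fused MLP, and
    # the autotune warmup stay enabled unless the user opts out.
    parser.add_argument('--no_quant_attn', action='store_true',
                        help='(triton) Disable the fused attention rewrite.')
    parser.add_argument('--no_fused_mlp', action='store_true',
                        help='(triton) Disable the fused MLP rewrite.')
    parser.add_argument('--no_warmup_autotune', action='store_true',
                        help='(triton) Skip the autotune warmup passes.')

    args = parser.parse_args([])  # empty argv: all optimizations left on
    assert not args.no_quant_attn and not args.no_fused_mlp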