Streamline GPTQ-for-LLaMa support
Commit: bee73cedbd
Parent: a3295dd666
5 changed files with 21 additions and 55 deletions
@@ -11,26 +11,9 @@ from transformers import AutoConfig, AutoModelForCausalLM
 import modules.shared as shared
 from modules.logging_colors import logger
 
-sys.path.insert(0, str(Path("repositories/GPTQ-for-LLaMa")))
-
-try:
-    import llama_inference_offload
-except ImportError:
-    logger.error('Failed to load GPTQ-for-LLaMa')
-    logger.error('See https://github.com/oobabooga/text-generation-webui/blob/main/docs/GPTQ-models-(4-bit-mode).md')
-    sys.exit(-1)
-
-try:
-    from modelutils import find_layers
-except ImportError:
-    from utils import find_layers
-
-try:
-    from quant import make_quant
-    is_triton = False
-except ImportError:
-    import quant
-    is_triton = True
+from gptq_for_llama import llama_inference_offload
+from gptq_for_llama.modelutils import find_layers
+from gptq_for_llama.quant import make_quant
 
 
 # This function is a replacement for the load_quant function in the
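The hunk above replaces the sys.path hack and the try/except import fallbacks with plain package imports. A rough before/after sketch follows; every identifier in it comes from the diff itself, the comments are illustrative, and it assumes gptq_for_llama is installed as a package in the environment:

# Old approach (removed): make a local clone importable by mutating sys.path,
# then probe for whichever module layout the clone happens to provide.
import sys
from pathlib import Path
sys.path.insert(0, str(Path("repositories/GPTQ-for-LLaMa")))
import llama_inference_offload  # only resolves if the clone exists

# New approach (added): import from an installed gptq_for_llama package, so a
# missing dependency surfaces as a normal ImportError at startup instead of a
# logged error followed by sys.exit().
from gptq_for_llama import llama_inference_offload
from gptq_for_llama.modelutils import find_layers
from gptq_for_llama.quant import make_quant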
@@ -59,24 +42,21 @@ def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exc
         if name in layers:
             del layers[name]
 
-    if not is_triton:
-        gptq_args = inspect.getfullargspec(make_quant).args
+    gptq_args = inspect.getfullargspec(make_quant).args
 
-        make_quant_kwargs = {
-            'module': model,
-            'names': layers,
-            'bits': wbits,
-        }
-        if 'groupsize' in gptq_args:
-            make_quant_kwargs['groupsize'] = groupsize
-        if 'faster' in gptq_args:
-            make_quant_kwargs['faster'] = faster_kernel
-        if 'kernel_switch_threshold' in gptq_args:
-            make_quant_kwargs['kernel_switch_threshold'] = kernel_switch_threshold
+    make_quant_kwargs = {
+        'module': model,
+        'names': layers,
+        'bits': wbits,
+    }
+    if 'groupsize' in gptq_args:
+        make_quant_kwargs['groupsize'] = groupsize
+    if 'faster' in gptq_args:
+        make_quant_kwargs['faster'] = faster_kernel
+    if 'kernel_switch_threshold' in gptq_args:
+        make_quant_kwargs['kernel_switch_threshold'] = kernel_switch_threshold
 
-        make_quant(**make_quant_kwargs)
-    else:
-        quant.make_quant_linear(model, layers, wbits, groupsize)
+    make_quant(**make_quant_kwargs)
 
     del layers
     if checkpoint.endswith('.safetensors'):
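This hunk drops the is_triton branch but keeps the kwargs filtering: inspect.getfullargspec is used so that optional arguments are passed to make_quant only when the installed version of it accepts them. A minimal, self-contained sketch of that signature-probing pattern, using made-up functions that are not part of the codebase:

import inspect

def call_with_supported_kwargs(fn, required, optional):
    # Forward an optional kwarg only if it appears in fn's signature, so the
    # caller keeps working across versions of fn that add or drop parameters.
    accepted = inspect.getfullargspec(fn).args
    kwargs = dict(required)
    kwargs.update({k: v for k, v in optional.items() if k in accepted})
    return fn(**kwargs)

def make_thing(module, bits, groupsize=-1):  # hypothetical callee
    return (module, bits, groupsize)

# 'faster' is not in make_thing's signature, so it is silently dropped.
print(call_with_supported_kwargs(
    make_thing,
    required={'module': 'model', 'bits': 4},
    optional={'groupsize': 128, 'faster': True},
))
# -> ('model', 4, 128)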
@@ -85,18 +65,6 @@ def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exc
     else:
         model.load_state_dict(torch.load(checkpoint), strict=False)
 
-    if is_triton:
-        if shared.args.quant_attn:
-            quant.make_quant_attn(model)
-
-        if eval and shared.args.fused_mlp:
-            quant.make_fused_mlp(model)
-
-        if shared.args.warmup_autotune:
-            quant.autotune_warmup_linear(model, transpose=not eval)
-            if eval and shared.args.fused_mlp:
-                quant.autotune_warmup_fused(model)
-
     model.seqlen = 2048
     return model
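For context on why this tail of _load_quant can be deleted: in the old imports (first hunk), is_triton was set by which quant module happened to import, and this post-load block only ran on the Triton path. With the single gptq_for_llama import, there is no Triton/CUDA split left to branch on. The removed coupling, condensed from the deleted lines for reference rather than as new behavior:

# Before this commit (condensed from the deleted lines):
try:
    from quant import make_quant          # CUDA-kernel branch
    is_triton = False
except ImportError:
    import quant                          # Triton branch
    is_triton = True

# ...later, at the end of _load_quant():
if is_triton:
    if shared.args.quant_attn:
        quant.make_quant_attn(model)
    if eval and shared.args.fused_mlp:
        quant.make_fused_mlp(model)
    if shared.args.warmup_autotune:
        quant.autotune_warmup_linear(model, transpose=not eval)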