Make the code more like PEP8 for readability (#862)
parent 848c4edfd5
commit ea6e77df72
28 changed files with 302 additions and 165 deletions
@@ -17,9 +17,11 @@ from quant import make_quant
 def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exclude_layers=['lm_head'], kernel_switch_threshold=128):
-    config = AutoConfig.from_pretrained(model)
+
     def noop(*args, **kwargs):
         pass
+
+    config = AutoConfig.from_pretrained(model)
     torch.nn.init.kaiming_uniform_ = noop
     torch.nn.init.uniform_ = noop
     torch.nn.init.normal_ = noop
@@ -34,11 +36,11 @@ def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exc
     for name in exclude_layers:
         if name in layers:
             del layers[name]

     gptq_args = inspect.getfullargspec(make_quant).args

     make_quant_kwargs = {
-        'module': model,
+        'module': model,
         'names': layers,
         'bits': wbits,
     }
@@ -48,7 +50,7 @@ def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exc
         make_quant_kwargs['faster'] = faster_kernel
     if 'kernel_switch_threshold' in gptq_args:
         make_quant_kwargs['kernel_switch_threshold'] = kernel_switch_threshold
-
+
     make_quant(**make_quant_kwargs)

     del layers
@@ -56,14 +58,15 @@ def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exc
     print('Loading model ...')
     if checkpoint.endswith('.safetensors'):
         from safetensors.torch import load_file as safe_load
-        model.load_state_dict(safe_load(checkpoint), strict = False)
+        model.load_state_dict(safe_load(checkpoint), strict=False)
     else:
-        model.load_state_dict(torch.load(checkpoint), strict = False)
+        model.load_state_dict(torch.load(checkpoint), strict=False)
     model.seqlen = 2048
     print('Done.')

     return model

+
 def load_quantized(model_name):
     if not shared.args.model_type:
         # Try to determine model type from model name
@@ -114,7 +117,7 @@ def load_quantized(model_name):
         pt_model = f'{model_name}-{shared.args.wbits}bit'

     # Try to find the .safetensors or .pt both in the model dir and in the subfolder
-    for path in [Path(p+ext) for ext in ['.safetensors', '.pt'] for p in [f"{shared.args.model_dir}/{pt_model}", f"{path_to_model}/{pt_model}"]]:
+    for path in [Path(p + ext) for ext in ['.safetensors', '.pt'] for p in [f"{shared.args.model_dir}/{pt_model}", f"{path_to_model}/{pt_model}"]]:
         if path.exists():
             print(f"Found {path}")
             pt_path = path
@@ -133,7 +136,7 @@ def load_quantized(model_name):

     # accelerate offload (doesn't work properly)
     if shared.args.gpu_memory:
-        memory_map = list(map(lambda x : x.strip(), shared.args.gpu_memory))
+        memory_map = list(map(lambda x: x.strip(), shared.args.gpu_memory))
         max_cpu_memory = shared.args.cpu_memory.strip() if shared.args.cpu_memory is not None else '99GiB'
         max_memory = {}
         for i in range(len(memory_map)):