Add QuIP# support (#4803)

QuIP# has to be installed manually for now.
This commit is contained in:
oobabooga 2023-12-06 00:01:01 -03:00 committed by GitHub
parent 6430acadde
commit 98361af4d5
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 95 additions and 8 deletions

View file

@ -33,14 +33,24 @@ def get_model_metadata(model):
for k in settings[pat]:
model_settings[k] = settings[pat][k]
path = Path(f'{shared.args.model_dir}/{model}/config.json')
if path.exists():
hf_metadata = json.loads(open(path, 'r').read())
else:
hf_metadata = None
if 'loader' not in model_settings:
loader = infer_loader(model, model_settings)
if 'wbits' in model_settings and type(model_settings['wbits']) is int and model_settings['wbits'] > 0:
loader = 'AutoGPTQ'
if hf_metadata is not None and 'quip_params' in hf_metadata:
model_settings['loader'] = 'QuIP#'
else:
loader = infer_loader(model, model_settings)
if 'wbits' in model_settings and type(model_settings['wbits']) is int and model_settings['wbits'] > 0:
loader = 'AutoGPTQ'
model_settings['loader'] = loader
model_settings['loader'] = loader
# Read GGUF metadata
# GGUF metadata
if model_settings['loader'] in ['llama.cpp', 'llamacpp_HF', 'ctransformers']:
path = Path(f'{shared.args.model_dir}/{model}')
if path.is_file():
@ -57,9 +67,8 @@ def get_model_metadata(model):
model_settings['rope_freq_base'] = metadata['llama.rope.freq_base']
else:
# Read transformers metadata
path = Path(f'{shared.args.model_dir}/{model}/config.json')
if path.exists():
# Transformers metadata
if hf_metadata is not None:
metadata = json.loads(open(path, 'r').read())
if 'max_position_embeddings' in metadata:
model_settings['truncation_length'] = metadata['max_position_embeddings']