Read GGUF metadata (#3873)

This commit is contained in:
oobabooga 2023-09-11 18:49:30 -03:00 committed by GitHub
parent 39f4800d94
commit 9331ab4798
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 154 additions and 52 deletions

View file

@ -3,23 +3,57 @@ from pathlib import Path
import yaml
from modules import loaders, shared, ui
from modules import loaders, metadata_gguf, shared, ui
def get_model_settings_from_yamls(model):
settings = shared.model_config
def get_fallback_settings():
return {
'wbits': 'None',
'model_type': 'None',
'groupsize': 'None',
'pre_layer': 0,
'skip_special_tokens': shared.settings['skip_special_tokens'],
'custom_stopping_strings': shared.settings['custom_stopping_strings'],
'truncation_length': shared.settings['truncation_length'],
'n_ctx': 2048,
'rope_freq_base': 0,
}
def get_model_metadata(model):
model_settings = {}
# Get settings from models/config.yaml and models/config-user.yaml
settings = shared.model_config
for pat in settings:
if re.match(pat.lower(), model.lower()):
for k in settings[pat]:
model_settings[k] = settings[pat][k]
if 'loader' not in model_settings:
loader = infer_loader(model, model_settings)
if 'wbits' in model_settings and type(model_settings['wbits']) is int and model_settings['wbits'] > 0:
loader = 'AutoGPTQ'
model_settings['loader'] = loader
# Read GGUF metadata
if model_settings['loader'] in ['llama.cpp', 'llamacpp_HF', 'ctransformers']:
path = Path(f'{shared.args.model_dir}/{model}')
if path.is_file():
model_file = path
else:
model_file = list(path.glob('*.gguf'))[0]
metadata = metadata_gguf.load_metadata(model_file)
if 'llama.context_length' in metadata:
model_settings['n_ctx'] = metadata['llama.context_length']
return model_settings
def infer_loader(model_name):
def infer_loader(model_name, model_settings):
path_to_model = Path(f'{shared.args.model_dir}/{model_name}')
model_settings = get_model_settings_from_yamls(model_name)
if not path_to_model.exists():
loader = None
elif Path(f'{shared.args.model_dir}/{model_name}/quantize_config.json').exists() or ('wbits' in model_settings and type(model_settings['wbits']) is int and model_settings['wbits'] > 0):
@ -85,11 +119,9 @@ def update_model_parameters(state, initial=False):
# UI: update the state variable with the model settings
def apply_model_settings_to_state(model, state):
model_settings = get_model_settings_from_yamls(model)
if 'loader' not in model_settings:
loader = infer_loader(model)
if 'wbits' in model_settings and type(model_settings['wbits']) is int and model_settings['wbits'] > 0:
loader = 'AutoGPTQ'
model_settings = get_model_metadata(model)
if 'loader' in model_settings:
loader = model_settings.pop('loader')
# If the user is using an alternative loader for the same model type, let them keep using it
if not (loader == 'AutoGPTQ' and state['loader'] in ['GPTQ-for-LLaMa', 'ExLlama', 'ExLlama_HF']) and not (loader == 'llama.cpp' and state['loader'] in ['llamacpp_HF', 'ctransformers']):