Read GGUF metadata (#3873)
This commit is contained in:
parent
39f4800d94
commit
9331ab4798
8 changed files with 154 additions and 52 deletions
|
@ -3,23 +3,57 @@ from pathlib import Path
|
|||
|
||||
import yaml
|
||||
|
||||
from modules import loaders, shared, ui
|
||||
from modules import loaders, metadata_gguf, shared, ui
|
||||
|
||||
|
||||
def get_model_settings_from_yamls(model):
|
||||
settings = shared.model_config
|
||||
def get_fallback_settings():
|
||||
return {
|
||||
'wbits': 'None',
|
||||
'model_type': 'None',
|
||||
'groupsize': 'None',
|
||||
'pre_layer': 0,
|
||||
'skip_special_tokens': shared.settings['skip_special_tokens'],
|
||||
'custom_stopping_strings': shared.settings['custom_stopping_strings'],
|
||||
'truncation_length': shared.settings['truncation_length'],
|
||||
'n_ctx': 2048,
|
||||
'rope_freq_base': 0,
|
||||
}
|
||||
|
||||
|
||||
def get_model_metadata(model):
|
||||
model_settings = {}
|
||||
|
||||
# Get settings from models/config.yaml and models/config-user.yaml
|
||||
settings = shared.model_config
|
||||
for pat in settings:
|
||||
if re.match(pat.lower(), model.lower()):
|
||||
for k in settings[pat]:
|
||||
model_settings[k] = settings[pat][k]
|
||||
|
||||
if 'loader' not in model_settings:
|
||||
loader = infer_loader(model, model_settings)
|
||||
if 'wbits' in model_settings and type(model_settings['wbits']) is int and model_settings['wbits'] > 0:
|
||||
loader = 'AutoGPTQ'
|
||||
|
||||
model_settings['loader'] = loader
|
||||
|
||||
# Read GGUF metadata
|
||||
if model_settings['loader'] in ['llama.cpp', 'llamacpp_HF', 'ctransformers']:
|
||||
path = Path(f'{shared.args.model_dir}/{model}')
|
||||
if path.is_file():
|
||||
model_file = path
|
||||
else:
|
||||
model_file = list(path.glob('*.gguf'))[0]
|
||||
|
||||
metadata = metadata_gguf.load_metadata(model_file)
|
||||
if 'llama.context_length' in metadata:
|
||||
model_settings['n_ctx'] = metadata['llama.context_length']
|
||||
|
||||
return model_settings
|
||||
|
||||
|
||||
def infer_loader(model_name):
|
||||
def infer_loader(model_name, model_settings):
|
||||
path_to_model = Path(f'{shared.args.model_dir}/{model_name}')
|
||||
model_settings = get_model_settings_from_yamls(model_name)
|
||||
if not path_to_model.exists():
|
||||
loader = None
|
||||
elif Path(f'{shared.args.model_dir}/{model_name}/quantize_config.json').exists() or ('wbits' in model_settings and type(model_settings['wbits']) is int and model_settings['wbits'] > 0):
|
||||
|
@ -85,11 +119,9 @@ def update_model_parameters(state, initial=False):
|
|||
|
||||
# UI: update the state variable with the model settings
|
||||
def apply_model_settings_to_state(model, state):
|
||||
model_settings = get_model_settings_from_yamls(model)
|
||||
if 'loader' not in model_settings:
|
||||
loader = infer_loader(model)
|
||||
if 'wbits' in model_settings and type(model_settings['wbits']) is int and model_settings['wbits'] > 0:
|
||||
loader = 'AutoGPTQ'
|
||||
model_settings = get_model_metadata(model)
|
||||
if 'loader' in model_settings:
|
||||
loader = model_settings.pop('loader')
|
||||
|
||||
# If the user is using an alternative loader for the same model type, let them keep using it
|
||||
if not (loader == 'AutoGPTQ' and state['loader'] in ['GPTQ-for-LLaMa', 'ExLlama', 'ExLlama_HF']) and not (loader == 'llama.cpp' and state['loader'] in ['llamacpp_HF', 'ctransformers']):
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue