Read more metadata (config.json & quantize_config.json)

This commit is contained in:
oobabooga 2023-09-29 06:14:16 -07:00
parent 56b5a4af74
commit 96da2e1c0d
3 changed files with 59 additions and 67 deletions

View file

@ -10,16 +10,16 @@ from modules import loaders, metadata_gguf, shared, ui
def get_fallback_settings():
return {
'wbits': 'None',
'model_type': 'None',
'groupsize': 'None',
'pre_layer': 0,
'skip_special_tokens': shared.settings['skip_special_tokens'],
'custom_stopping_strings': shared.settings['custom_stopping_strings'],
'truncation_length': shared.settings['truncation_length'],
'desc_act': False,
'model_type': 'None',
'max_seq_len': 2048,
'n_ctx': 2048,
'rope_freq_base': 0,
'compress_pos_emb': 1,
'truncation_length': shared.settings['truncation_length'],
'skip_special_tokens': shared.settings['skip_special_tokens'],
'custom_stopping_strings': shared.settings['custom_stopping_strings'],
}
@ -56,8 +56,8 @@ def get_model_metadata(model):
if 'llama.rope.freq_base' in metadata:
model_settings['rope_freq_base'] = metadata['llama.rope.freq_base']
# Read transformers metadata. In particular, the sequence length for the model.
else:
# Read transformers metadata
path = Path(f'{shared.args.model_dir}/{model}/config.json')
if path.exists():
metadata = json.loads(open(path, 'r').read())
@ -65,6 +65,32 @@ def get_model_metadata(model):
model_settings['truncation_length'] = metadata['max_position_embeddings']
model_settings['max_seq_len'] = metadata['max_position_embeddings']
if 'rope_theta' in metadata:
model_settings['rope_freq_base'] = metadata['rope_theta']
if 'rope_scaling' in metadata and type(metadata['rope_scaling']) is dict and all(key in metadata['rope_scaling'] for key in ('type', 'factor')):
if metadata['rope_scaling']['type'] == 'linear':
model_settings['compress_pos_emb'] = metadata['rope_scaling']['factor']
if 'quantization_config' in metadata:
if 'bits' in metadata['quantization_config']:
model_settings['wbits'] = metadata['quantization_config']['bits']
if 'group_size' in metadata['quantization_config']:
model_settings['groupsize'] = metadata['quantization_config']['group_size']
if 'desc_act' in metadata['quantization_config']:
model_settings['desc_act'] = metadata['quantization_config']['desc_act']
# Read AutoGPTQ metadata
path = Path(f'{shared.args.model_dir}/{model}/quantize_config.json')
if path.exists():
metadata = json.loads(open(path, 'r').read())
if 'bits' in metadata:
model_settings['wbits'] = metadata['bits']
if 'group_size' in metadata:
model_settings['groupsize'] = metadata['group_size']
if 'desc_act' in metadata:
model_settings['desc_act'] = metadata['desc_act']
# Apply user settings from models/config-user.yaml
settings = shared.user_config
for pat in settings: