Read more metadata (config.json & quantize_config.json)
This commit is contained in:
parent
56b5a4af74
commit
96da2e1c0d
3 changed files with 59 additions and 67 deletions
|
@ -10,16 +10,16 @@ from modules import loaders, metadata_gguf, shared, ui
|
|||
def get_fallback_settings():
|
||||
return {
|
||||
'wbits': 'None',
|
||||
'model_type': 'None',
|
||||
'groupsize': 'None',
|
||||
'pre_layer': 0,
|
||||
'skip_special_tokens': shared.settings['skip_special_tokens'],
|
||||
'custom_stopping_strings': shared.settings['custom_stopping_strings'],
|
||||
'truncation_length': shared.settings['truncation_length'],
|
||||
'desc_act': False,
|
||||
'model_type': 'None',
|
||||
'max_seq_len': 2048,
|
||||
'n_ctx': 2048,
|
||||
'rope_freq_base': 0,
|
||||
'compress_pos_emb': 1,
|
||||
'truncation_length': shared.settings['truncation_length'],
|
||||
'skip_special_tokens': shared.settings['skip_special_tokens'],
|
||||
'custom_stopping_strings': shared.settings['custom_stopping_strings'],
|
||||
}
|
||||
|
||||
|
||||
|
@ -56,8 +56,8 @@ def get_model_metadata(model):
|
|||
if 'llama.rope.freq_base' in metadata:
|
||||
model_settings['rope_freq_base'] = metadata['llama.rope.freq_base']
|
||||
|
||||
# Read transformers metadata. In particular, the sequence length for the model.
|
||||
else:
|
||||
# Read transformers metadata
|
||||
path = Path(f'{shared.args.model_dir}/{model}/config.json')
|
||||
if path.exists():
|
||||
metadata = json.loads(open(path, 'r').read())
|
||||
|
@ -65,6 +65,32 @@ def get_model_metadata(model):
|
|||
model_settings['truncation_length'] = metadata['max_position_embeddings']
|
||||
model_settings['max_seq_len'] = metadata['max_position_embeddings']
|
||||
|
||||
if 'rope_theta' in metadata:
|
||||
model_settings['rope_freq_base'] = metadata['rope_theta']
|
||||
|
||||
if 'rope_scaling' in metadata and type(metadata['rope_scaling']) is dict and all(key in metadata['rope_scaling'] for key in ('type', 'factor')):
|
||||
if metadata['rope_scaling']['type'] == 'linear':
|
||||
model_settings['compress_pos_emb'] = metadata['rope_scaling']['factor']
|
||||
|
||||
if 'quantization_config' in metadata:
|
||||
if 'bits' in metadata['quantization_config']:
|
||||
model_settings['wbits'] = metadata['quantization_config']['bits']
|
||||
if 'group_size' in metadata['quantization_config']:
|
||||
model_settings['groupsize'] = metadata['quantization_config']['group_size']
|
||||
if 'desc_act' in metadata['quantization_config']:
|
||||
model_settings['desc_act'] = metadata['quantization_config']['desc_act']
|
||||
|
||||
# Read AutoGPTQ metadata
|
||||
path = Path(f'{shared.args.model_dir}/{model}/quantize_config.json')
|
||||
if path.exists():
|
||||
metadata = json.loads(open(path, 'r').read())
|
||||
if 'bits' in metadata:
|
||||
model_settings['wbits'] = metadata['bits']
|
||||
if 'group_size' in metadata:
|
||||
model_settings['groupsize'] = metadata['group_size']
|
||||
if 'desc_act' in metadata:
|
||||
model_settings['desc_act'] = metadata['desc_act']
|
||||
|
||||
# Apply user settings from models/config-user.yaml
|
||||
settings = shared.user_config
|
||||
for pat in settings:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue