Read Transformers config.json metadata

This commit is contained in:
oobabooga 2023-09-28 19:19:47 -07:00
parent 9ccaf5eebb
commit 1dd13e4643
2 changed files with 11 additions and 20 deletions

View file

@ -1,3 +1,4 @@
import json
import re
from pathlib import Path
@ -15,6 +16,7 @@ def get_fallback_settings():
'skip_special_tokens': shared.settings['skip_special_tokens'],
'custom_stopping_strings': shared.settings['custom_stopping_strings'],
'truncation_length': shared.settings['truncation_length'],
'max_seq_len': 2048,
'n_ctx': 2048,
'rope_freq_base': 0,
'compress_pos_emb': 1,
@ -54,6 +56,15 @@ def get_model_metadata(model):
if 'llama.rope.freq_base' in metadata:
model_settings['rope_freq_base'] = metadata['llama.rope.freq_base']
# Read transformers metadata. In particular, the sequence length for the model.
else:
path = Path(f'{shared.args.model_dir}/{model}/config.json')
if path.exists():
metadata = json.loads(open(path, 'r').read())
if 'max_position_embeddings' in metadata:
model_settings['truncation_length'] = metadata['max_position_embeddings']
model_settings['max_seq_len'] = metadata['max_position_embeddings']
# Apply user settings from models/config-user.yaml
settings = shared.user_config
for pat in settings: