Read Transformers config.json metadata
This commit is contained in:
parent
9ccaf5eebb
commit
1dd13e4643
2 changed files with 11 additions and 20 deletions
|
@ -1,3 +1,4 @@
|
|||
import json
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
|
@ -15,6 +16,7 @@ def get_fallback_settings():
|
|||
'skip_special_tokens': shared.settings['skip_special_tokens'],
|
||||
'custom_stopping_strings': shared.settings['custom_stopping_strings'],
|
||||
'truncation_length': shared.settings['truncation_length'],
|
||||
'max_seq_len': 2048,
|
||||
'n_ctx': 2048,
|
||||
'rope_freq_base': 0,
|
||||
'compress_pos_emb': 1,
|
||||
|
@ -54,6 +56,15 @@ def get_model_metadata(model):
|
|||
if 'llama.rope.freq_base' in metadata:
|
||||
model_settings['rope_freq_base'] = metadata['llama.rope.freq_base']
|
||||
|
||||
# Read transformers metadata. In particular, the sequence length for the model.
|
||||
else:
|
||||
path = Path(f'{shared.args.model_dir}/{model}/config.json')
|
||||
if path.exists():
|
||||
metadata = json.loads(open(path, 'r').read())
|
||||
if 'max_position_embeddings' in metadata:
|
||||
model_settings['truncation_length'] = metadata['max_position_embeddings']
|
||||
model_settings['max_seq_len'] = metadata['max_position_embeddings']
|
||||
|
||||
# Apply user settings from models/config-user.yaml
|
||||
settings = shared.user_config
|
||||
for pat in settings:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue