Jinja templates for Instruct and Chat (#4874)

This commit is contained in:
oobabooga 2023-12-12 17:23:14 -03:00 committed by GitHub
parent aab0dd962d
commit 39d2fe1ed9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
71 changed files with 1774 additions and 518 deletions

View file

@@ -4,7 +4,7 @@ from pathlib import Path
import yaml
from modules import loaders, metadata_gguf, shared, ui
from modules import chat, loaders, metadata_gguf, shared, ui
def get_fallback_settings():
@@ -33,7 +33,6 @@ def get_model_metadata(model):
for k in settings[pat]:
model_settings[k] = settings[pat][k]
path = Path(f'{shared.args.model_dir}/{model}/config.json')
if path.exists():
hf_metadata = json.loads(open(path, 'r').read())
@@ -100,6 +99,31 @@ def get_model_metadata(model):
if 'desc_act' in metadata:
model_settings['desc_act'] = metadata['desc_act']
# Try to find the Jinja instruct template
path = Path(f'{shared.args.model_dir}/{model}') / 'tokenizer_config.json'
if path.exists():
metadata = json.loads(open(path, 'r').read())
if 'chat_template' in metadata:
template = metadata['chat_template']
for k in ['eos_token', 'bos_token']:
if k in metadata:
value = metadata[k]
if type(value) is dict:
value = value['content']
template = template.replace(k, "'{}'".format(value))
template = re.sub(r'raise_exception\([^)]*\)', "''", template)
model_settings['instruction_template'] = 'Custom (obtained from model metadata)'
model_settings['instruction_template_str'] = template
if 'instruction_template' not in model_settings:
model_settings['instruction_template'] = 'Alpaca'
if model_settings['instruction_template'] != 'Custom (obtained from model metadata)':
model_settings['instruction_template_str'] = chat.load_instruction_template(model_settings['instruction_template'])
# Ignore rope_freq_base if set to the default value
if 'rope_freq_base' in model_settings and model_settings['rope_freq_base'] == 10000:
model_settings.pop('rope_freq_base')