Jinja templates for Instruct and Chat (#4874)
This commit is contained in:
parent
aab0dd962d
commit
39d2fe1ed9
71 changed files with 1774 additions and 518 deletions
|
@ -4,7 +4,7 @@ from pathlib import Path
|
|||
|
||||
import yaml
|
||||
|
||||
from modules import loaders, metadata_gguf, shared, ui
|
||||
from modules import chat, loaders, metadata_gguf, shared, ui
|
||||
|
||||
|
||||
def get_fallback_settings():
|
||||
|
@ -33,7 +33,6 @@ def get_model_metadata(model):
|
|||
for k in settings[pat]:
|
||||
model_settings[k] = settings[pat][k]
|
||||
|
||||
|
||||
path = Path(f'{shared.args.model_dir}/{model}/config.json')
|
||||
if path.exists():
|
||||
hf_metadata = json.loads(open(path, 'r').read())
|
||||
|
@ -100,6 +99,31 @@ def get_model_metadata(model):
|
|||
if 'desc_act' in metadata:
|
||||
model_settings['desc_act'] = metadata['desc_act']
|
||||
|
||||
# Try to find the Jinja instruct template
|
||||
path = Path(f'{shared.args.model_dir}/{model}') / 'tokenizer_config.json'
|
||||
if path.exists():
|
||||
metadata = json.loads(open(path, 'r').read())
|
||||
if 'chat_template' in metadata:
|
||||
template = metadata['chat_template']
|
||||
for k in ['eos_token', 'bos_token']:
|
||||
if k in metadata:
|
||||
value = metadata[k]
|
||||
if type(value) is dict:
|
||||
value = value['content']
|
||||
|
||||
template = template.replace(k, "'{}'".format(value))
|
||||
|
||||
template = re.sub(r'raise_exception\([^)]*\)', "''", template)
|
||||
|
||||
model_settings['instruction_template'] = 'Custom (obtained from model metadata)'
|
||||
model_settings['instruction_template_str'] = template
|
||||
|
||||
if 'instruction_template' not in model_settings:
|
||||
model_settings['instruction_template'] = 'Alpaca'
|
||||
|
||||
if model_settings['instruction_template'] != 'Custom (obtained from model metadata)':
|
||||
model_settings['instruction_template_str'] = chat.load_instruction_template(model_settings['instruction_template'])
|
||||
|
||||
# Ignore rope_freq_base if set to the default value
|
||||
if 'rope_freq_base' in model_settings and model_settings['rope_freq_base'] == 10000:
|
||||
model_settings.pop('rope_freq_base')
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue