Jinja templates for Instruct and Chat (#4874)

This commit is contained in:
oobabooga 2023-12-12 17:23:14 -03:00 committed by GitHub
parent aab0dd962d
commit 39d2fe1ed9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
71 changed files with 1774 additions and 518 deletions

View file

@ -5,10 +5,12 @@ import html
import json
import re
from datetime import datetime
from functools import partial
from pathlib import Path
import gradio as gr
import yaml
from jinja2.sandbox import ImmutableSandboxedEnvironment
from PIL import Image
import modules.shared as shared
@ -20,12 +22,10 @@ from modules.text_generation import (
get_encoded_length,
get_max_prompt_length
)
from modules.utils import (
delete_file,
get_available_characters,
replace_all,
save_file
)
from modules.utils import delete_file, get_available_characters, save_file
# Copied from the Transformers library
jinja_env = ImmutableSandboxedEnvironment(trim_blocks=True, lstrip_blocks=True)
def str_presenter(dumper, data):
@ -44,31 +44,34 @@ yaml.add_representer(str, str_presenter)
yaml.representer.SafeRepresenter.add_representer(str, str_presenter)
def get_turn_substrings(state, instruct=False):
if instruct:
if 'turn_template' not in state or state['turn_template'] == '':
template = '<|user|>\n<|user-message|>\n<|bot|>\n<|bot-message|>\n'
else:
template = state['turn_template'].replace(r'\n', '\n')
def get_generation_prompt(renderer, impersonate=False, strip_trailing_spaces=True):
'''
Given a Jinja template, reverse-engineers the prefix and the suffix for
an assistant message (if impersonate=False) or an user message
(if impersonate=True)
'''
if impersonate:
messages = [
{"role": "user", "content": "<<|user-message-1|>>"},
{"role": "user", "content": "<<|user-message-2|>>"},
]
else:
template = '<|user|>: <|user-message|>\n<|bot|>: <|bot-message|>\n'
messages = [
{"role": "assistant", "content": "<<|user-message-1|>>"},
{"role": "assistant", "content": "<<|user-message-2|>>"},
]
replacements = {
'<|user|>': state['name1_instruct' if instruct else 'name1'].strip(),
'<|bot|>': state['name2_instruct' if instruct else 'name2'].strip(),
}
prompt = renderer(messages=messages)
output = {
'user_turn': template.split('<|bot|>')[0],
'bot_turn': '<|bot|>' + template.split('<|bot|>')[1],
'user_turn_stripped': template.split('<|bot|>')[0].split('<|user-message|>')[0],
'bot_turn_stripped': '<|bot|>' + template.split('<|bot|>')[1].split('<|bot-message|>')[0],
}
suffix_plus_prefix = prompt.split("<<|user-message-1|>>")[1].split("<<|user-message-2|>>")[0]
suffix = prompt.split("<<|user-message-2|>>")[1]
prefix = suffix_plus_prefix[len(suffix):]
for k in output:
output[k] = replace_all(output[k], replacements)
if strip_trailing_spaces:
prefix = prefix.rstrip(' ')
return output
return prefix, suffix
def generate_chat_prompt(user_input, state, **kwargs):
@ -76,121 +79,130 @@ def generate_chat_prompt(user_input, state, **kwargs):
_continue = kwargs.get('_continue', False)
also_return_rows = kwargs.get('also_return_rows', False)
history = kwargs.get('history', state['history'])['internal']
is_instruct = state['mode'] == 'instruct'
# Find the maximum prompt size
max_length = get_max_prompt_length(state)
all_substrings = {
'chat': get_turn_substrings(state, instruct=False) if state['mode'] in ['chat', 'chat-instruct'] else None,
'instruct': get_turn_substrings(state, instruct=True)
}
# Templates
chat_template = jinja_env.from_string(state['chat_template_str'])
instruction_template = jinja_env.from_string(state['instruction_template_str'])
chat_renderer = partial(chat_template.render, add_generation_prompt=False, name1=state['name1'], name2=state['name2'])
instruct_renderer = partial(instruction_template.render, add_generation_prompt=False)
substrings = all_substrings['instruct' if is_instruct else 'chat']
messages = []
# Create the template for "chat-instruct" mode
if state['mode'] == 'chat-instruct':
wrapper = ''
command = state['chat-instruct_command'].replace('<|character|>', state['name2'] if not impersonate else state['name1'])
context_instruct = state['context_instruct']
if state['mode'] == 'instruct':
renderer = instruct_renderer
if state['custom_system_message'].strip() != '':
context_instruct = context_instruct.replace('<|system-message|>', state['custom_system_message'])
else:
context_instruct = context_instruct.replace('<|system-message|>', state['system_message'])
wrapper += context_instruct
wrapper += all_substrings['instruct']['user_turn'].replace('<|user-message|>', command)
wrapper += all_substrings['instruct']['bot_turn_stripped']
if impersonate:
wrapper += substrings['user_turn_stripped'].rstrip(' ')
elif _continue:
wrapper += apply_extensions('bot_prefix', substrings['bot_turn_stripped'], state)
wrapper += history[-1][1]
else:
wrapper += apply_extensions('bot_prefix', substrings['bot_turn_stripped'].rstrip(' '), state)
messages.append({"role": "system", "content": state['custom_system_message']})
else:
wrapper = '<|prompt|>'
renderer = chat_renderer
if state['context'].strip() != '':
messages.append({"role": "system", "content": state['context']})
if is_instruct:
context = state['context_instruct']
if state['custom_system_message'].strip() != '':
context = context.replace('<|system-message|>', state['custom_system_message'])
insert_pos = len(messages)
for user_msg, assistant_msg in reversed(history):
user_msg = user_msg.strip()
assistant_msg = assistant_msg.strip()
if assistant_msg:
messages.insert(insert_pos, {"role": "assistant", "content": assistant_msg})
if user_msg not in ['', '<|BEGIN-VISIBLE-CHAT|>']:
messages.insert(insert_pos, {"role": "user", "content": user_msg})
user_input = user_input.strip()
if user_input and not impersonate and not _continue:
messages.append({"role": "user", "content": user_input})
def make_prompt(messages):
if state['mode'] == 'chat-instruct' and _continue:
prompt = renderer(messages=messages[:-1])
else:
context = context.replace('<|system-message|>', state['system_message'])
else:
context = replace_character_names(
f"{state['context'].strip()}\n",
state['name1'],
state['name2']
)
prompt = renderer(messages=messages)
# Build the prompt
rows = [context]
min_rows = 3
i = len(history) - 1
while i >= 0 and get_encoded_length(wrapper.replace('<|prompt|>', ''.join(rows))) < max_length:
if _continue and i == len(history) - 1:
if state['mode'] != 'chat-instruct':
rows.insert(1, substrings['bot_turn_stripped'] + history[i][1].strip())
else:
rows.insert(1, substrings['bot_turn'].replace('<|bot-message|>', history[i][1].strip()))
string = history[i][0]
if string not in ['', '<|BEGIN-VISIBLE-CHAT|>']:
rows.insert(1, replace_all(substrings['user_turn'], {'<|user-message|>': string.strip(), '<|round|>': str(i)}))
i -= 1
if impersonate:
if state['mode'] == 'chat-instruct':
min_rows = 1
outer_messages = []
if state['custom_system_message'].strip() != '':
outer_messages.append({"role": "system", "content": state['custom_system_message']})
command = state['chat-instruct_command']
command = command.replace('<|character|>', state['name2'] if not impersonate else state['name1'])
command = command.replace('<|prompt|>', prompt)
if _continue:
prefix = get_generation_prompt(renderer, impersonate=impersonate, strip_trailing_spaces=False)[0]
prefix += messages[-1]["content"]
else:
prefix = get_generation_prompt(renderer, impersonate=impersonate)[0]
if not impersonate:
prefix = apply_extensions('bot_prefix', prefix, state)
outer_messages.append({"role": "user", "content": command})
outer_messages.append({"role": "assistant", "content": prefix})
prompt = instruction_template.render(messages=outer_messages)
suffix = get_generation_prompt(instruct_renderer, impersonate=False)[1]
prompt = prompt[:-len(suffix)]
else:
min_rows = 2
rows.append(substrings['user_turn_stripped'].rstrip(' '))
elif not _continue:
# Add the user message
if len(user_input) > 0:
rows.append(replace_all(substrings['user_turn'], {'<|user-message|>': user_input.strip(), '<|round|>': str(len(history))}))
if _continue:
suffix = get_generation_prompt(renderer, impersonate=impersonate)[1]
prompt = prompt[:-len(suffix)]
else:
prefix = get_generation_prompt(renderer, impersonate=impersonate)[0]
if state['mode'] == 'chat' and not impersonate:
prefix = apply_extensions('bot_prefix', prefix, state)
# Add the character prefix
if state['mode'] != 'chat-instruct':
rows.append(apply_extensions('bot_prefix', substrings['bot_turn_stripped'].rstrip(' '), state))
prompt += prefix
while len(rows) > min_rows and get_encoded_length(wrapper.replace('<|prompt|>', ''.join(rows))) >= max_length:
rows.pop(1)
return prompt
prompt = make_prompt(messages)
# Handle truncation
max_length = get_max_prompt_length(state)
while len(messages) > 0 and get_encoded_length(prompt) > max_length:
# Try to save the system message
if len(messages) > 1 and messages[0]['role'] == 'system':
messages.pop(1)
else:
messages.pop(0)
prompt = make_prompt(messages)
prompt = wrapper.replace('<|prompt|>', ''.join(rows))
if also_return_rows:
return prompt, rows
return prompt, [message['content'] for message in messages]
else:
return prompt
def get_stopping_strings(state):
stopping_strings = []
renderers = []
if state['mode'] in ['instruct', 'chat-instruct']:
stopping_strings += [
state['turn_template'].split('<|user-message|>')[1].split('<|bot|>')[0] + '<|bot|>',
state['turn_template'].split('<|bot-message|>')[1] + '<|user|>'
]
replacements = {
'<|user|>': state['name1_instruct'],
'<|bot|>': state['name2_instruct']
}
for i in range(len(stopping_strings)):
stopping_strings[i] = replace_all(stopping_strings[i], replacements).rstrip(' ').replace(r'\n', '\n')
template = jinja_env.from_string(state['instruction_template_str'])
renderer = partial(template.render, add_generation_prompt=False)
renderers.append(renderer)
if state['mode'] in ['chat', 'chat-instruct']:
template = jinja_env.from_string(state['chat_template_str'])
renderer = partial(template.render, add_generation_prompt=False, name1=state['name1'], name2=state['name2'])
renderers.append(renderer)
for renderer in renderers:
prefix_bot, suffix_bot = get_generation_prompt(renderer, impersonate=False)
prefix_user, suffix_user = get_generation_prompt(renderer, impersonate=True)
stopping_strings += [
f"\n{state['name1']}:",
f"\n{state['name2']}:"
suffix_user + prefix_bot,
suffix_user + prefix_user,
suffix_bot + prefix_bot,
suffix_bot + prefix_user,
]
if 'stopping_strings' in state and isinstance(state['stopping_strings'], list):
stopping_strings += state.pop('stopping_strings')
return stopping_strings
return list(set(stopping_strings))
def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_message=True):
@ -556,32 +568,26 @@ def generate_pfp_cache(character):
return None
def load_character(character, name1, name2, instruct=False):
context = greeting = turn_template = system_message = ""
def load_character(character, name1, name2):
context = greeting = ""
greeting_field = 'greeting'
picture = None
if instruct:
name1 = name2 = ''
folder = 'instruction-templates'
else:
folder = 'characters'
filepath = None
for extension in ["yml", "yaml", "json"]:
filepath = Path(f'{folder}/{character}.{extension}')
filepath = Path(f'characters/{character}.{extension}')
if filepath.exists():
break
if filepath is None or not filepath.exists():
logger.error(f"Could not find the character \"{character}\" inside {folder}/. No character has been loaded.")
logger.error(f"Could not find the character \"{character}\" inside characters/. No character has been loaded.")
raise ValueError
file_contents = open(filepath, 'r', encoding='utf-8').read()
data = json.loads(file_contents) if extension == "json" else yaml.safe_load(file_contents)
for path in [Path("cache/pfp_character.png"), Path("cache/pfp_character_thumb.png")]:
if path.exists() and not instruct:
if path.exists():
path.unlink()
picture = generate_pfp_cache(character)
@ -599,23 +605,38 @@ def load_character(character, name1, name2, instruct=False):
break
if 'context' in data:
context = data['context']
if not instruct:
context = context.strip() + '\n'
context = data['context'].strip()
elif "char_persona" in data:
context = build_pygmalion_style_context(data)
greeting_field = 'char_greeting'
greeting = data.get(greeting_field, greeting)
turn_template = data.get('turn_template', turn_template)
system_message = data.get('system_message', system_message)
return name1, name2, picture, greeting, context
return name1, name2, picture, greeting, context, turn_template.replace("\n", r"\n"), system_message
def load_instruction_template(template):
for filepath in [Path(f'instruction-templates/{template}.yaml'), Path('instruction-templates/Alpaca.yaml')]:
if filepath.exists():
break
else:
return ''
file_contents = open(filepath, 'r', encoding='utf-8').read()
data = yaml.safe_load(file_contents)
if 'instruction_template' in data:
return data['instruction_template']
else:
return jinja_template_from_old_format(data)
@functools.cache
def load_character_memoized(character, name1, name2, instruct=False):
return load_character(character, name1, name2, instruct=instruct)
def load_character_memoized(character, name1, name2):
return load_character(character, name1, name2)
@functools.cache
def load_instruction_template_memoized(template):
return load_instruction_template(template)
def upload_character(file, img, tavern=False):
@ -707,17 +728,12 @@ def generate_character_yaml(name, greeting, context):
return yaml.dump(data, sort_keys=False, width=float("inf"))
def generate_instruction_template_yaml(user, bot, context, turn_template, system_message):
def generate_instruction_template_yaml(instruction_template):
data = {
'user': user,
'bot': bot,
'turn_template': turn_template,
'context': context,
'system_message': system_message,
'instruction_template': instruction_template
}
data = {k: v for k, v in data.items() if v} # Strip falsy
return yaml.dump(data, sort_keys=False, width=float("inf"))
return my_yaml_output(data)
def save_character(name, greeting, context, picture, filename):
@ -739,3 +755,95 @@ def delete_character(name, instruct=False):
delete_file(Path(f'characters/{name}.{extension}'))
delete_file(Path(f'characters/{name}.png'))
def jinja_template_from_old_format(params, verbose=False):
MASTER_TEMPLATE = """
{%- set found_item = false -%}
{%- for message in messages -%}
{%- if message['role'] == 'system' -%}
{%- set found_item = true -%}
{%- endif -%}
{%- endfor -%}
{%- if not found_item -%}
{{- '<|PRE-SYSTEM|>' + '<|SYSTEM-MESSAGE|>' + '<|POST-SYSTEM|>' -}}
{%- endif %}
{%- for message in messages %}
{%- if message['role'] == 'system' -%}
{{- '<|PRE-SYSTEM|>' + message['content'] + '<|POST-SYSTEM|>' -}}
{%- else -%}
{%- if message['role'] == 'user' -%}
{{-'<|PRE-USER|>' + message['content'] + '<|POST-USER|>'-}}
{%- else -%}
{{-'<|PRE-ASSISTANT|>' + message['content'] + '<|POST-ASSISTANT|>' -}}
{%- endif -%}
{%- endif -%}
{%- endfor -%}
{%- if add_generation_prompt -%}
{{-'<|PRE-ASSISTANT-GENERATE|>'-}}
{%- endif -%}
"""
if 'context' in params and '<|system-message|>' in params['context']:
pre_system = params['context'].split('<|system-message|>')[0]
post_system = params['context'].split('<|system-message|>')[1]
else:
pre_system = ''
post_system = ''
pre_user = params['turn_template'].split('<|user-message|>')[0].replace('<|user|>', params['user'])
post_user = params['turn_template'].split('<|user-message|>')[1].split('<|bot|>')[0]
pre_assistant = '<|bot|>' + params['turn_template'].split('<|bot-message|>')[0].split('<|bot|>')[1]
pre_assistant = pre_assistant.replace('<|bot|>', params['bot'])
post_assistant = params['turn_template'].split('<|bot-message|>')[1]
pre_system = pre_system.replace('\n', '\\n')
post_system = post_system.replace('\n', '\\n')
pre_user = pre_user.replace('\n', '\\n')
post_user = post_user.replace('\n', '\\n')
pre_assistant = pre_assistant.replace('\n', '\\n')
post_assistant = post_assistant.replace('\n', '\\n')
if verbose:
print(
'\n',
repr(pre_system) + '\n',
repr(post_system) + '\n',
repr(pre_user) + '\n',
repr(post_user) + '\n',
repr(pre_assistant) + '\n',
repr(post_assistant) + '\n',
)
result = MASTER_TEMPLATE
if 'system_message' in params:
result = result.replace('<|SYSTEM-MESSAGE|>', params['system_message'].replace('\n', '\\n'))
else:
result = result.replace('<|SYSTEM-MESSAGE|>', '')
result = result.replace('<|PRE-SYSTEM|>', pre_system)
result = result.replace('<|POST-SYSTEM|>', post_system)
result = result.replace('<|PRE-USER|>', pre_user)
result = result.replace('<|POST-USER|>', post_user)
result = result.replace('<|PRE-ASSISTANT|>', pre_assistant)
result = result.replace('<|PRE-ASSISTANT-GENERATE|>', pre_assistant.strip())
result = result.replace('<|POST-ASSISTANT|>', post_assistant)
result = result.strip()
return result
def my_yaml_output(data):
'''
pyyaml is very inconsistent with multiline strings.
for simple instruction template outputs, this is enough.
'''
result = ""
for k in data:
result += k + ": |-\n"
for line in data[k].splitlines():
result += " " + line.rstrip(' ') + "\n"
return result