Separate context and system message in instruction formats (#4499)

This commit is contained in:
oobabooga 2023-11-07 20:02:58 -03:00 committed by GitHub
parent 322c170566
commit 6e2e0317af
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
61 changed files with 130 additions and 62 deletions

View file

@ -106,6 +106,10 @@ def generate_chat_prompt(user_input, state, **kwargs):
if is_instruct:
context = state['context_instruct']
if state['custom_system_message'].strip() != '':
context = context.replace('<|system-message|>', state['custom_system_message'])
else:
context = context.replace('<|system-message|>', state['system_message'])
else:
context = replace_character_names(
f"{state['context'].strip()}\n",
@ -543,7 +547,7 @@ def generate_pfp_cache(character):
def load_character(character, name1, name2, instruct=False):
context = greeting = turn_template = ""
context = greeting = turn_template = system_message = ""
greeting_field = 'greeting'
picture = None
@ -591,13 +595,11 @@ def load_character(character, name1, name2, instruct=False):
context = build_pygmalion_style_context(data)
greeting_field = 'char_greeting'
if greeting_field in data:
greeting = data[greeting_field]
greeting = data.get(greeting_field, greeting)
turn_template = data.get('turn_template', turn_template)
system_message = data.get('system_message', system_message)
if 'turn_template' in data:
turn_template = data['turn_template']
return name1, name2, picture, greeting, context, turn_template.replace("\n", r"\n")
return name1, name2, picture, greeting, context, turn_template.replace("\n", r"\n"), system_message
@functools.cache
@ -694,12 +696,13 @@ def generate_character_yaml(name, greeting, context):
return yaml.dump(data, sort_keys=False, width=float("inf"))
def generate_instruction_template_yaml(user, bot, context, turn_template):
def generate_instruction_template_yaml(user, bot, context, turn_template, system_message):
data = {
'user': user,
'bot': bot,
'turn_template': turn_template,
'context': context,
'system_message': system_message,
}
data = {k: v for k, v in data.items() if v} # Strip falsy