Separate context and system message in instruction formats (#4499)

This commit is contained in:
oobabooga 2023-11-07 20:02:58 -03:00 committed by GitHub
parent 322c170566
commit 6e2e0317af
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
61 changed files with 130 additions and 62 deletions

View file

@ -106,6 +106,10 @@ def generate_chat_prompt(user_input, state, **kwargs):
if is_instruct:
context = state['context_instruct']
if state['custom_system_message'].strip() != '':
context = context.replace('<|system-message|>', state['custom_system_message'])
else:
context = context.replace('<|system-message|>', state['system_message'])
else:
context = replace_character_names(
f"{state['context'].strip()}\n",
@ -543,7 +547,7 @@ def generate_pfp_cache(character):
def load_character(character, name1, name2, instruct=False):
context = greeting = turn_template = ""
context = greeting = turn_template = system_message = ""
greeting_field = 'greeting'
picture = None
@ -591,13 +595,11 @@ def load_character(character, name1, name2, instruct=False):
context = build_pygmalion_style_context(data)
greeting_field = 'char_greeting'
if greeting_field in data:
greeting = data[greeting_field]
greeting = data.get(greeting_field, greeting)
turn_template = data.get('turn_template', turn_template)
system_message = data.get('system_message', system_message)
if 'turn_template' in data:
turn_template = data['turn_template']
return name1, name2, picture, greeting, context, turn_template.replace("\n", r"\n")
return name1, name2, picture, greeting, context, turn_template.replace("\n", r"\n"), system_message
@functools.cache
@ -694,12 +696,13 @@ def generate_character_yaml(name, greeting, context):
return yaml.dump(data, sort_keys=False, width=float("inf"))
def generate_instruction_template_yaml(user, bot, context, turn_template):
def generate_instruction_template_yaml(user, bot, context, turn_template, system_message):
data = {
'user': user,
'bot': bot,
'turn_template': turn_template,
'context': context,
'system_message': system_message,
}
data = {k: v for k, v in data.items() if v} # Strip falsy

View file

@ -55,6 +55,7 @@ settings = {
'character': 'Assistant',
'name1': 'You',
'instruction_template': 'Alpaca',
'custom_system_message': '',
'chat-instruct_command': 'Continue the chat dialogue below. Write a single reply for the character "<|character|>".\n\n<|prompt|>',
'autoload_model': False,
'default_extensions': ['gallery'],

View file

@ -157,6 +157,8 @@ def list_interface_input_elements():
'name1_instruct',
'name2_instruct',
'context_instruct',
'system_message',
'custom_system_message',
'turn_template',
'chat_style',
'chat-instruct_command',

View file

@ -112,10 +112,12 @@ def create_chat_settings_ui():
shared.gradio['save_template'] = gr.Button('💾', elem_classes='refresh-button', interactive=not mu)
shared.gradio['delete_template'] = gr.Button('🗑️ ', elem_classes='refresh-button', interactive=not mu)
shared.gradio['name1_instruct'] = gr.Textbox(value='', lines=2, label='User string')
shared.gradio['name2_instruct'] = gr.Textbox(value='', lines=1, label='Bot string')
shared.gradio['context_instruct'] = gr.Textbox(value='', lines=4, label='Context', elem_classes=['add_scrollbar'])
shared.gradio['custom_system_message'] = gr.Textbox(value=shared.settings['custom_system_message'], lines=2, label='Custom system message', info='If not empty, will be used instead of the default one.', elem_classes=['add_scrollbar'])
shared.gradio['turn_template'] = gr.Textbox(value='', lines=1, label='Turn template', info='Used to precisely define the placement of spaces and new line characters in instruction prompts.', elem_classes=['add_scrollbar'])
shared.gradio['name1_instruct'] = gr.Textbox(value='', lines=2, label='User string', info='Replaces <|user|> in the turn template.')
shared.gradio['name2_instruct'] = gr.Textbox(value='', lines=1, label='Bot string', info='Replaces <|bot|> in the turn template.')
shared.gradio['context_instruct'] = gr.Textbox(value='', lines=4, label='Context', elem_classes=['add_scrollbar'])
shared.gradio['system_message'] = gr.Textbox(value='', lines=2, label='Default system message', info='Replaces <|system-message|> in the context.', elem_classes=['add_scrollbar'])
with gr.Row():
shared.gradio['send_instruction_to_default'] = gr.Button('Send to default', elem_classes=['small-button'])
shared.gradio['send_instruction_to_notebook'] = gr.Button('Send to notebook', elem_classes=['small-button'])
@ -269,7 +271,7 @@ def create_event_handlers():
lambda: None, None, None, _js=f'() => {{{ui.switch_tabs_js}; switch_to_chat()}}')
shared.gradio['character_menu'].change(
partial(chat.load_character, instruct=False), gradio('character_menu', 'name1', 'name2'), gradio('name1', 'name2', 'character_picture', 'greeting', 'context', 'dummy')).success(
partial(chat.load_character, instruct=False), gradio('character_menu', 'name1', 'name2'), gradio('name1', 'name2', 'character_picture', 'greeting', 'context', 'dummy', 'dummy')).success(
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
chat.load_latest_history, gradio('interface_state'), gradio('history')).then(
chat.redraw_html, gradio(reload_arr), gradio('display')).then(
@ -285,7 +287,7 @@ def create_event_handlers():
shared.gradio['chat_style'].change(chat.redraw_html, gradio(reload_arr), gradio('display'))
shared.gradio['instruction_template'].change(
partial(chat.load_character, instruct=True), gradio('instruction_template', 'name1_instruct', 'name2_instruct'), gradio('name1_instruct', 'name2_instruct', 'dummy', 'dummy', 'context_instruct', 'turn_template'))
partial(chat.load_character, instruct=True), gradio('instruction_template', 'name1_instruct', 'name2_instruct'), gradio('name1_instruct', 'name2_instruct', 'dummy', 'dummy', 'context_instruct', 'turn_template', 'system_message'))
shared.gradio['Copy last reply'].click(chat.send_last_reply_to_input, gradio('history'), gradio('textbox'), show_progress=False)
@ -299,7 +301,7 @@ def create_event_handlers():
shared.gradio['save_template'].click(
lambda: 'My Template.yaml', None, gradio('save_filename')).then(
lambda: 'instruction-templates/', None, gradio('save_root')).then(
chat.generate_instruction_template_yaml, gradio('name1_instruct', 'name2_instruct', 'context_instruct', 'turn_template'), gradio('save_contents')).then(
chat.generate_instruction_template_yaml, gradio('name1_instruct', 'name2_instruct', 'context_instruct', 'turn_template', 'system_message'), gradio('save_contents')).then(
lambda: gr.update(visible=True), None, gradio('file_saver'))
shared.gradio['delete_template'].click(