Add support for characters

This commit is contained in:
oobabooga 2023-01-19 16:46:46 -03:00
parent 3121f4788e
commit 8d788874d7
4 changed files with 84 additions and 40 deletions

View file

@ -35,6 +35,7 @@ args = parser.parse_args()
loaded_preset = None
available_models = sorted(set([item.replace('.pt', '') for item in map(lambda x : str(x.name), list(Path('models/').glob('*'))+list(Path('torch-dumps/').glob('*'))) if not item.endswith('.txt')]), key=str.lower)
available_presets = sorted(set(map(lambda x : str(x.name).split('.')[0], Path('presets').glob('*.txt'))), key=str.lower)
available_characters = sorted(set(map(lambda x : str(x.name).split('.')[0], Path('characters').glob('*.json'))), key=str.lower)
settings = {
'max_new_tokens': 200,
@ -50,6 +51,7 @@ settings = {
'prompt': 'Common sense questions and answers\n\nQuestion: \nFactual answer:',
'prompt_gpt4chan': '-----\n--- 865467536\nInput text\n--- 865467537\n',
'stop_at_newline': True,
'stop_at_newline_pygmalion': False,
}
if args.settings is not None and Path(args.settings).exists():
@ -217,6 +219,7 @@ description = f"\n\n# Text generation lab\nGenerate text using Large Language Mo
css = ".my-4 {margin-top: 0} .py-6 {padding-top: 2.5rem}"
if args.chat or args.cai_chat:
history = []
character = None
# This gets the new line characters right.
def clean_chat_message(text):
@ -284,12 +287,12 @@ if args.chat or args.cai_chat:
def cai_chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check):
for history in chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check):
yield generate_chat_html(history, name1, name2)
yield generate_chat_html(history, name1, name2, character)
def remove_last_message(name1, name2):
history.pop()
if args.cai_chat:
return generate_chat_html(history, name1, name2)
return generate_chat_html(history, name1, name2, character)
else:
return history
@ -298,11 +301,11 @@ if args.chat or args.cai_chat:
history = []
def clear_html():
return generate_chat_html([], "", "")
return generate_chat_html([], "", "", character)
def redraw_html(name1, name2):
global history
return generate_chat_html(history, name1, name2)
return generate_chat_html(history, name1, name2, character)
def save_history():
if not Path('logs').exists():
@ -315,18 +318,43 @@ if args.chat or args.cai_chat:
global history
history = json.loads(file.decode('utf-8'))['data']
if 'pygmalion' in model_name.lower():
context_str = settings['context_pygmalion']
name1_str = settings['name1_pygmalion']
name2_str = settings['name2_pygmalion']
else:
context_str = settings['context']
name1_str = settings['name1']
name2_str = settings['name2']
def load_character(_character, name1, name2):
global history, character
context = ""
history = []
if _character != 'None':
character = _character
with open(Path(f'characters/{_character}.json'), 'r') as f:
data = json.loads(f.read())
name2 = data['char_name']
if 'char_persona' in data and data['char_persona'] != '':
context += f"{data['char_name']}'s Persona: {data['char_persona']}\n"
if 'world_scenario' in data and data['world_scenario'] != '':
context += f"Scenario: {data['world_scenario']}\n"
if 'example_dialogue' in data and data['example_dialogue'] != '':
context += f"{data['example_dialogue']}"
context = f"{context.strip()}\n<START>"
if 'char_greeting' in data:
history = [['', data['char_greeting']]]
else:
character = None
context = settings['context_pygmalion']
name2 = settings['name2_pygmalion']
if args.cai_chat:
return name2, context, generate_chat_html(history, name1, name2, character)
else:
return name2, context, history
suffix = '_pygmalion' if 'pygmalion' in model_name.lower() else ''
context_str = settings[f'context{suffix}']
name1_str = settings[f'name1{suffix}']
name2_str = settings[f'name2{suffix}']
stop_at_newline = settings[f'stop_at_newline{suffix}']
with gr.Blocks(css=css+".h-\[40vh\] {height: 66.67vh} .gradio-container {max-width: 800px; margin-left: auto; margin-right: auto}", analytics_enabled=False) as interface:
if args.cai_chat:
display1 = gr.HTML(value=generate_chat_html([], "", ""))
display1 = gr.HTML(value=generate_chat_html([], "", "", character))
else:
display1 = gr.Chatbot()
textbox = gr.Textbox(lines=2, label='Input')
@ -347,7 +375,9 @@ if args.chat or args.cai_chat:
name2 = gr.Textbox(value=name2_str, lines=1, label='Bot\'s name')
context = gr.Textbox(value=context_str, lines=2, label='Context')
with gr.Row():
check = gr.Checkbox(value=settings['stop_at_newline'], label='Stop generating at new line character?')
character_menu = gr.Dropdown(choices=["None"]+available_characters, value="None", label='Character')
with gr.Row():
check = gr.Checkbox(value=stop_at_newline, label='Stop generating at new line character?')
with gr.Row():
with gr.Column():
gr.Markdown("Upload chat history")
@ -371,9 +401,10 @@ if args.chat or args.cai_chat:
btn.click(lambda x: "", textbox, textbox, show_progress=False)
textbox.submit(lambda x: "", textbox, textbox, show_progress=False)
stop.click(None, None, None, cancels=[gen_event, gen_event2])
save_btn.click(save_history, inputs=[], outputs=[download])
upload.upload(load_history, [upload], [])
character_menu.change(load_character, [character_menu, name1, name2], [name2, context, display1])
if args.cai_chat:
upload.upload(redraw_html, [name1, name2], [display1])
else: