From fcda3f87767e642d1c0411776e549e1d3894843d Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Sat, 1 Apr 2023 01:12:13 -0300 Subject: [PATCH] Add also_return_rows to generate_chat_prompt --- modules/chat.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/modules/chat.py b/modules/chat.py index cc3c45c..db79e7d 100644 --- a/modules/chat.py +++ b/modules/chat.py @@ -22,7 +22,7 @@ def generate_chat_output(history, name1, name2, character): else: return history -def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat_prompt_size, impersonate=False): +def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat_prompt_size, impersonate=False, also_return_rows=False): user_input = fix_newlines(user_input) rows = [f"{context.strip()}\n"] @@ -51,7 +51,11 @@ def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat rows.pop(1) prompt = ''.join(rows) - return prompt + + if also_return_rows: + return prompt, rows + else: + return prompt def extract_message_from_reply(reply, name1, name2, check): next_character_found = False