Generalize superbooga to chat mode

This commit is contained in:
oobabooga 2023-05-07 15:01:14 -03:00
parent ec1cda0e1f
commit 6b67cb6611
2 changed files with 65 additions and 17 deletions

View file

@ -27,6 +27,11 @@ def replace_all(text, dic):
def generate_chat_prompt(user_input, state, **kwargs):
# Check if an extension is sending its modified history.
# If not, use the regular history
history = state['history'] if 'history' in state else shared.history['internal']
# Define some variables
impersonate = kwargs['impersonate'] if 'impersonate' in kwargs else False
_continue = kwargs['_continue'] if '_continue' in kwargs else False
also_return_rows = kwargs['also_return_rows'] if 'also_return_rows' in kwargs else False
@ -61,14 +66,14 @@ def generate_chat_prompt(user_input, state, **kwargs):
bot_turn_stripped = replace_all(bot_turn.split('<|bot-message|>')[0], replacements)
# Building the prompt
i = len(shared.history['internal']) - 1
i = len(history) - 1
while i >= 0 and len(encode(''.join(rows))[0]) < max_length:
if _continue and i == len(shared.history['internal']) - 1:
rows.insert(1, bot_turn_stripped + shared.history['internal'][i][1].strip())
if _continue and i == len(history) - 1:
rows.insert(1, bot_turn_stripped + history[i][1].strip())
else:
rows.insert(1, bot_turn.replace('<|bot-message|>', shared.history['internal'][i][1].strip()))
rows.insert(1, bot_turn.replace('<|bot-message|>', history[i][1].strip()))
string = shared.history['internal'][i][0]
string = history[i][0]
if string not in ['', '<|BEGIN-VISIBLE-CHAT|>']:
rows.insert(1, replace_all(user_turn, {'<|user-message|>': string.strip(), '<|round|>': str(i)}))
@ -80,7 +85,7 @@ def generate_chat_prompt(user_input, state, **kwargs):
elif not _continue:
# Adding the user message
if len(user_input) > 0:
rows.append(replace_all(user_turn, {'<|user-message|>': user_input.strip(), '<|round|>': str(len(shared.history["internal"]))}))
rows.append(replace_all(user_turn, {'<|user-message|>': user_input.strip(), '<|round|>': str(len(history))}))
# Adding the Character prefix
rows.append(apply_extensions("bot_prefix", bot_turn_stripped.rstrip(' ')))