Generalize superbooga to chat mode
This commit is contained in:
parent
ec1cda0e1f
commit
6b67cb6611
2 changed files with 65 additions and 17 deletions
|
@ -27,6 +27,11 @@ def replace_all(text, dic):
|
|||
|
||||
|
||||
def generate_chat_prompt(user_input, state, **kwargs):
|
||||
# Check if an extension is sending its modified history.
|
||||
# If not, use the regular history
|
||||
history = state['history'] if 'history' in state else shared.history['internal']
|
||||
|
||||
# Define some variables
|
||||
impersonate = kwargs['impersonate'] if 'impersonate' in kwargs else False
|
||||
_continue = kwargs['_continue'] if '_continue' in kwargs else False
|
||||
also_return_rows = kwargs['also_return_rows'] if 'also_return_rows' in kwargs else False
|
||||
|
@ -61,14 +66,14 @@ def generate_chat_prompt(user_input, state, **kwargs):
|
|||
bot_turn_stripped = replace_all(bot_turn.split('<|bot-message|>')[0], replacements)
|
||||
|
||||
# Building the prompt
|
||||
i = len(shared.history['internal']) - 1
|
||||
i = len(history) - 1
|
||||
while i >= 0 and len(encode(''.join(rows))[0]) < max_length:
|
||||
if _continue and i == len(shared.history['internal']) - 1:
|
||||
rows.insert(1, bot_turn_stripped + shared.history['internal'][i][1].strip())
|
||||
if _continue and i == len(history) - 1:
|
||||
rows.insert(1, bot_turn_stripped + history[i][1].strip())
|
||||
else:
|
||||
rows.insert(1, bot_turn.replace('<|bot-message|>', shared.history['internal'][i][1].strip()))
|
||||
rows.insert(1, bot_turn.replace('<|bot-message|>', history[i][1].strip()))
|
||||
|
||||
string = shared.history['internal'][i][0]
|
||||
string = history[i][0]
|
||||
if string not in ['', '<|BEGIN-VISIBLE-CHAT|>']:
|
||||
rows.insert(1, replace_all(user_turn, {'<|user-message|>': string.strip(), '<|round|>': str(i)}))
|
||||
|
||||
|
@ -80,7 +85,7 @@ def generate_chat_prompt(user_input, state, **kwargs):
|
|||
elif not _continue:
|
||||
# Adding the user message
|
||||
if len(user_input) > 0:
|
||||
rows.append(replace_all(user_turn, {'<|user-message|>': user_input.strip(), '<|round|>': str(len(shared.history["internal"]))}))
|
||||
rows.append(replace_all(user_turn, {'<|user-message|>': user_input.strip(), '<|round|>': str(len(history))}))
|
||||
|
||||
# Adding the Character prefix
|
||||
rows.append(apply_extensions("bot_prefix", bot_turn_stripped.rstrip(' ')))
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue