Refactor the UI
A single dictionary called 'interface_state' is now passed as input to all functions. The values are updated only when necessary. The goal is to make it easier to add new elements to the UI.
This commit is contained in:
parent
64f5c90ee7
commit
0f212093a3
3 changed files with 136 additions and 100 deletions
|
@ -74,16 +74,16 @@ def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat
|
|||
return prompt
|
||||
|
||||
|
||||
def extract_message_from_reply(reply, name1, name2, stop_at_newline):
|
||||
def extract_message_from_reply(reply, state):
|
||||
next_character_found = False
|
||||
|
||||
if stop_at_newline:
|
||||
if state['stop_at_newline']:
|
||||
lines = reply.split('\n')
|
||||
reply = lines[0].strip()
|
||||
if len(lines) > 1:
|
||||
next_character_found = True
|
||||
else:
|
||||
for string in [f"\n{name1}:", f"\n{name2}:"]:
|
||||
for string in [f"\n{state['name1']}:", f"\n{state['name2']}:"]:
|
||||
idx = reply.find(string)
|
||||
if idx != -1:
|
||||
reply = reply[:idx]
|
||||
|
@ -92,7 +92,7 @@ def extract_message_from_reply(reply, name1, name2, stop_at_newline):
|
|||
# If something like "\nYo" is generated just before "\nYou:"
|
||||
# is completed, trim it
|
||||
if not next_character_found:
|
||||
for string in [f"\n{name1}:", f"\n{name2}:"]:
|
||||
for string in [f"\n{state['name1']}:", f"\n{state['name2']}:"]:
|
||||
for j in range(len(string) - 1, 0, -1):
|
||||
if reply[-j:] == string[:j]:
|
||||
reply = reply[:-j]
|
||||
|
@ -105,21 +105,18 @@ def extract_message_from_reply(reply, name1, name2, stop_at_newline):
|
|||
return reply, next_character_found
|
||||
|
||||
|
||||
def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn, regenerate=False, _continue=False):
|
||||
if mode == 'instruct':
|
||||
stopping_strings = [f"\n{name1}", f"\n{name2}"]
|
||||
def chatbot_wrapper(text, state, regenerate=False, _continue=False):
|
||||
if state['mode'] == 'instruct':
|
||||
stopping_strings = [f"\n{state['name1']}", f"\n{state['name2']}"]
|
||||
else:
|
||||
stopping_strings = [f"\n{name1}:", f"\n{name2}:"]
|
||||
stopping_strings = [f"\n{state['name1']}:", f"\n{state['name2']}:"]
|
||||
|
||||
# Defining some variables
|
||||
cumulative_reply = ''
|
||||
last_reply = [shared.history['internal'][-1][1], shared.history['visible'][-1][1]] if _continue else None
|
||||
just_started = True
|
||||
name1_original = name1
|
||||
visible_text = custom_generate_chat_prompt = None
|
||||
eos_token = '\n' if generate_state['stop_at_newline'] else None
|
||||
if 'pygmalion' in shared.model_name.lower():
|
||||
name1 = "You"
|
||||
eos_token = '\n' if state['stop_at_newline'] else None
|
||||
|
||||
# Check if any extension wants to hijack this function call
|
||||
for extension, _ in extensions_module.iterator():
|
||||
|
@ -136,28 +133,28 @@ def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_tu
|
|||
|
||||
# Generating the prompt
|
||||
kwargs = {
|
||||
'end_of_turn': end_of_turn,
|
||||
'is_instruct': mode == 'instruct',
|
||||
'end_of_turn': state['end_of_turn'],
|
||||
'is_instruct': state['mode'] == 'instruct',
|
||||
'_continue': _continue
|
||||
}
|
||||
if custom_generate_chat_prompt is None:
|
||||
prompt = generate_chat_prompt(text, generate_state['max_new_tokens'], name1, name2, context, generate_state['chat_prompt_size'], **kwargs)
|
||||
prompt = generate_chat_prompt(text, state['max_new_tokens'], state['name1'], state['name2'], state['context'], state['chat_prompt_size'], **kwargs)
|
||||
else:
|
||||
prompt = custom_generate_chat_prompt(text, generate_state['max_new_tokens'], name1, name2, context, generate_state['chat_prompt_size'], **kwargs)
|
||||
prompt = custom_generate_chat_prompt(text, state['max_new_tokens'], state['name1'], state['name2'], state['context'], state['chat_prompt_size'], **kwargs)
|
||||
|
||||
# Yield *Is typing...*
|
||||
if not any((regenerate, _continue)):
|
||||
yield shared.history['visible'] + [[visible_text, shared.processing_message]]
|
||||
|
||||
# Generate
|
||||
for i in range(generate_state['chat_generation_attempts']):
|
||||
for i in range(state['chat_generation_attempts']):
|
||||
reply = None
|
||||
for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", generate_state, eos_token=eos_token, stopping_strings=stopping_strings):
|
||||
for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", state, eos_token=eos_token, stopping_strings=stopping_strings):
|
||||
reply = cumulative_reply + reply
|
||||
|
||||
# Extracting the reply
|
||||
reply, next_character_found = extract_message_from_reply(reply, name1, name2, generate_state['stop_at_newline'])
|
||||
visible_reply = re.sub("(<USER>|<user>|{{user}})", name1_original, reply)
|
||||
reply, next_character_found = extract_message_from_reply(reply, state)
|
||||
visible_reply = re.sub("(<USER>|<user>|{{user}})", state['name1'], reply)
|
||||
visible_reply = apply_extensions(visible_reply, "output")
|
||||
|
||||
# We need this global variable to handle the Stop event,
|
||||
|
@ -171,7 +168,7 @@ def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_tu
|
|||
shared.history['visible'].append(['', ''])
|
||||
|
||||
if _continue:
|
||||
sep = list(map(lambda x : ' ' if x[-1] != ' ' else '', last_reply))
|
||||
sep = list(map(lambda x: ' ' if x[-1] != ' ' else '', last_reply))
|
||||
shared.history['internal'][-1] = [text, f'{last_reply[0]}{sep[0]}{reply}']
|
||||
shared.history['visible'][-1] = [visible_text, f'{last_reply[1]}{sep[1]}{visible_reply}']
|
||||
else:
|
||||
|
@ -188,28 +185,25 @@ def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_tu
|
|||
yield shared.history['visible']
|
||||
|
||||
|
||||
def impersonate_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn):
|
||||
if mode == 'instruct':
|
||||
stopping_strings = [f"\n{name1}", f"\n{name2}"]
|
||||
def impersonate_wrapper(text, state):
|
||||
if state['mode'] == 'instruct':
|
||||
stopping_strings = [f"\n{state['name1']}", f"\n{state['name2']}"]
|
||||
else:
|
||||
stopping_strings = [f"\n{name1}:", f"\n{name2}:"]
|
||||
stopping_strings = [f"\n{state['name1']}:", f"\n{state['name2']}:"]
|
||||
|
||||
# Defining some variables
|
||||
cumulative_reply = ''
|
||||
eos_token = '\n' if generate_state['stop_at_newline'] else None
|
||||
if 'pygmalion' in shared.model_name.lower():
|
||||
name1 = "You"
|
||||
|
||||
prompt = generate_chat_prompt(text, generate_state['max_new_tokens'], name1, name2, context, generate_state['chat_prompt_size'], impersonate=True, end_of_turn=end_of_turn)
|
||||
eos_token = '\n' if state['stop_at_newline'] else None
|
||||
prompt = generate_chat_prompt(text, state['max_new_tokens'], state['name1'], state['name2'], state['context'], state['chat_prompt_size'], end_of_turn=state['end_of_turn'], impersonate=True)
|
||||
|
||||
# Yield *Is typing...*
|
||||
yield shared.processing_message
|
||||
|
||||
for i in range(generate_state['chat_generation_attempts']):
|
||||
for i in range(state['chat_generation_attempts']):
|
||||
reply = None
|
||||
for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", generate_state, eos_token=eos_token, stopping_strings=stopping_strings):
|
||||
for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", state, eos_token=eos_token, stopping_strings=stopping_strings):
|
||||
reply = cumulative_reply + reply
|
||||
reply, next_character_found = extract_message_from_reply(reply, name1, name2, generate_state['stop_at_newline'])
|
||||
reply, next_character_found = extract_message_from_reply(reply, state)
|
||||
yield reply
|
||||
if next_character_found:
|
||||
break
|
||||
|
@ -220,32 +214,32 @@ def impersonate_wrapper(text, generate_state, name1, name2, context, mode, end_o
|
|||
yield reply
|
||||
|
||||
|
||||
def cai_chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn):
|
||||
for history in chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn):
|
||||
yield chat_html_wrapper(history, name1, name2, mode)
|
||||
def cai_chatbot_wrapper(text, state):
|
||||
for history in chatbot_wrapper(text, state):
|
||||
yield chat_html_wrapper(history, state['name1'], state['name2'], state['mode'])
|
||||
|
||||
|
||||
def regenerate_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn):
|
||||
def regenerate_wrapper(text, state):
|
||||
if (len(shared.history['visible']) == 1 and not shared.history['visible'][0][0]) or len(shared.history['internal']) == 0:
|
||||
yield chat_html_wrapper(shared.history['visible'], name1, name2, mode)
|
||||
yield chat_html_wrapper(shared.history['visible'], state['name1'], state['name2'], state['mode'])
|
||||
else:
|
||||
last_visible = shared.history['visible'].pop()
|
||||
last_internal = shared.history['internal'].pop()
|
||||
# Yield '*Is typing...*'
|
||||
yield chat_html_wrapper(shared.history['visible'] + [[last_visible[0], shared.processing_message]], name1, name2, mode)
|
||||
for history in chatbot_wrapper(last_internal[0], generate_state, name1, name2, context, mode, end_of_turn, regenerate=True):
|
||||
yield chat_html_wrapper(shared.history['visible'] + [[last_visible[0], shared.processing_message]], state['name1'], state['name2'], state['mode'])
|
||||
for history in chatbot_wrapper(last_internal[0], state, regenerate=True):
|
||||
shared.history['visible'][-1] = [last_visible[0], history[-1][1]]
|
||||
yield chat_html_wrapper(shared.history['visible'], name1, name2, mode)
|
||||
yield chat_html_wrapper(shared.history['visible'], state['name1'], state['name2'], state['mode'])
|
||||
|
||||
|
||||
def continue_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn):
|
||||
def continue_wrapper(text, state):
|
||||
if (len(shared.history['visible']) == 1 and not shared.history['visible'][0][0]) or len(shared.history['internal']) == 0:
|
||||
yield chat_html_wrapper(shared.history['visible'], name1, name2, mode)
|
||||
yield chat_html_wrapper(shared.history['visible'], state['name1'], state['name2'], state['mode'])
|
||||
else:
|
||||
# Yield ' ...'
|
||||
yield chat_html_wrapper(shared.history['visible'][:-1] + [[shared.history['visible'][-1][0], shared.history['visible'][-1][1] + ' ...']], name1, name2, mode)
|
||||
for history in chatbot_wrapper(shared.history['internal'][-1][0], generate_state, name1, name2, context, mode, end_of_turn, _continue=True):
|
||||
yield chat_html_wrapper(shared.history['visible'], name1, name2, mode)
|
||||
yield chat_html_wrapper(shared.history['visible'][:-1] + [[shared.history['visible'][-1][0], shared.history['visible'][-1][1] + ' ...']], state['name1'], state['name2'], state['mode'])
|
||||
for history in chatbot_wrapper(shared.history['internal'][-1][0], state, _continue=True):
|
||||
yield chat_html_wrapper(shared.history['visible'], state['name1'], state['name2'], state['mode'])
|
||||
|
||||
|
||||
def remove_last_message(name1, name2, mode):
|
||||
|
@ -284,7 +278,7 @@ def clear_chat_log(name1, name2, greeting, mode):
|
|||
if greeting != '':
|
||||
shared.history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', greeting]]
|
||||
shared.history['visible'] += [['', apply_extensions(greeting, "output")]]
|
||||
|
||||
|
||||
# Save cleared logs
|
||||
save_history(mode)
|
||||
|
||||
|
@ -452,7 +446,7 @@ def load_character(character, name1, name2, mode):
|
|||
if greeting != "":
|
||||
shared.history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', greeting]]
|
||||
shared.history['visible'] += [['', apply_extensions(greeting, "output")]]
|
||||
|
||||
|
||||
# Create .json log files since they don't already exist
|
||||
save_history(mode)
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue