Unify the 3 interface modes (#3554)
This commit is contained in:
parent
bf70c19603
commit
a1a9ec895d
29 changed files with 660 additions and 714 deletions
|
@ -175,9 +175,6 @@ def get_stopping_strings(state):
|
|||
f"\n{state['name2']}:"
|
||||
]
|
||||
|
||||
if state['stop_at_newline']:
|
||||
stopping_strings.append("\n")
|
||||
|
||||
return stopping_strings
|
||||
|
||||
|
||||
|
@ -201,7 +198,7 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_mess
|
|||
if not any((regenerate, _continue)):
|
||||
visible_text = text
|
||||
text, visible_text = apply_extensions('chat_input', text, visible_text, state)
|
||||
text = apply_extensions('input', text, state)
|
||||
text = apply_extensions('input', text, state, is_chat=True)
|
||||
|
||||
# *Is typing...*
|
||||
if loading_message:
|
||||
|
@ -230,45 +227,37 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_mess
|
|||
prompt = generate_chat_prompt(text, state, **kwargs)
|
||||
|
||||
# Generate
|
||||
cumulative_reply = ''
|
||||
for i in range(state['chat_generation_attempts']):
|
||||
reply = None
|
||||
for j, reply in enumerate(generate_reply(prompt + cumulative_reply, state, stopping_strings=stopping_strings, is_chat=True)):
|
||||
reply = cumulative_reply + reply
|
||||
reply = None
|
||||
for j, reply in enumerate(generate_reply(prompt, state, stopping_strings=stopping_strings, is_chat=True)):
|
||||
|
||||
# Extract the reply
|
||||
visible_reply = re.sub("(<USER>|<user>|{{user}})", state['name1'], reply)
|
||||
# Extract the reply
|
||||
visible_reply = re.sub("(<USER>|<user>|{{user}})", state['name1'], reply)
|
||||
|
||||
# We need this global variable to handle the Stop event,
|
||||
# otherwise gradio gets confused
|
||||
if shared.stop_everything:
|
||||
output['visible'][-1][1] = apply_extensions('output', output['visible'][-1][1], state)
|
||||
# We need this global variable to handle the Stop event,
|
||||
# otherwise gradio gets confused
|
||||
if shared.stop_everything:
|
||||
output['visible'][-1][1] = apply_extensions('output', output['visible'][-1][1], state, is_chat=True)
|
||||
yield output
|
||||
return
|
||||
|
||||
if just_started:
|
||||
just_started = False
|
||||
if not _continue:
|
||||
output['internal'].append(['', ''])
|
||||
output['visible'].append(['', ''])
|
||||
|
||||
if _continue:
|
||||
output['internal'][-1] = [text, last_reply[0] + reply]
|
||||
output['visible'][-1] = [visible_text, last_reply[1] + visible_reply]
|
||||
if is_stream:
|
||||
yield output
|
||||
elif not (j == 0 and visible_reply.strip() == ''):
|
||||
output['internal'][-1] = [text, reply.lstrip(' ')]
|
||||
output['visible'][-1] = [visible_text, visible_reply.lstrip(' ')]
|
||||
if is_stream:
|
||||
yield output
|
||||
return
|
||||
|
||||
if just_started:
|
||||
just_started = False
|
||||
if not _continue:
|
||||
output['internal'].append(['', ''])
|
||||
output['visible'].append(['', ''])
|
||||
|
||||
if _continue:
|
||||
output['internal'][-1] = [text, last_reply[0] + reply]
|
||||
output['visible'][-1] = [visible_text, last_reply[1] + visible_reply]
|
||||
if is_stream:
|
||||
yield output
|
||||
elif not (j == 0 and visible_reply.strip() == ''):
|
||||
output['internal'][-1] = [text, reply.lstrip(' ')]
|
||||
output['visible'][-1] = [visible_text, visible_reply.lstrip(' ')]
|
||||
if is_stream:
|
||||
yield output
|
||||
|
||||
if reply in [None, cumulative_reply]:
|
||||
break
|
||||
else:
|
||||
cumulative_reply = reply
|
||||
|
||||
output['visible'][-1][1] = apply_extensions('output', output['visible'][-1][1], state)
|
||||
output['visible'][-1][1] = apply_extensions('output', output['visible'][-1][1], state, is_chat=True)
|
||||
yield output
|
||||
|
||||
|
||||
|
@ -278,27 +267,15 @@ def impersonate_wrapper(text, start_with, state):
|
|||
yield ''
|
||||
return
|
||||
|
||||
# Defining some variables
|
||||
cumulative_reply = ''
|
||||
prompt = generate_chat_prompt('', state, impersonate=True)
|
||||
stopping_strings = get_stopping_strings(state)
|
||||
|
||||
yield text + '...'
|
||||
cumulative_reply = text
|
||||
for i in range(state['chat_generation_attempts']):
|
||||
reply = None
|
||||
for reply in generate_reply(prompt + cumulative_reply, state, stopping_strings=stopping_strings, is_chat=True):
|
||||
reply = cumulative_reply + reply
|
||||
yield reply.lstrip(' ')
|
||||
if shared.stop_everything:
|
||||
return
|
||||
|
||||
if reply in [None, cumulative_reply]:
|
||||
break
|
||||
else:
|
||||
cumulative_reply = reply
|
||||
|
||||
yield cumulative_reply.lstrip(' ')
|
||||
reply = None
|
||||
for reply in generate_reply(prompt, state, stopping_strings=stopping_strings, is_chat=True):
|
||||
yield reply.lstrip(' ')
|
||||
if shared.stop_everything:
|
||||
return
|
||||
|
||||
|
||||
def generate_chat_reply(text, state, regenerate=False, _continue=False, loading_message=True):
|
||||
|
@ -352,7 +329,7 @@ def replace_last_reply(text, state):
|
|||
return history
|
||||
elif len(history['visible']) > 0:
|
||||
history['visible'][-1][1] = text
|
||||
history['internal'][-1][1] = apply_extensions('input', text, state)
|
||||
history['internal'][-1][1] = apply_extensions('input', text, state, is_chat=True)
|
||||
|
||||
return history
|
||||
|
||||
|
@ -360,7 +337,7 @@ def replace_last_reply(text, state):
|
|||
def send_dummy_message(text, state):
|
||||
history = state['history']
|
||||
history['visible'].append([text, ''])
|
||||
history['internal'].append([apply_extensions('input', text, state), ''])
|
||||
history['internal'].append([apply_extensions('input', text, state, is_chat=True), ''])
|
||||
return history
|
||||
|
||||
|
||||
|
@ -371,7 +348,7 @@ def send_dummy_reply(text, state):
|
|||
history['internal'].append(['', ''])
|
||||
|
||||
history['visible'][-1][1] = text
|
||||
history['internal'][-1][1] = apply_extensions('input', text, state)
|
||||
history['internal'][-1][1] = apply_extensions('input', text, state, is_chat=True)
|
||||
return history
|
||||
|
||||
|
||||
|
@ -385,7 +362,7 @@ def clear_chat_log(state):
|
|||
if mode != 'instruct':
|
||||
if greeting != '':
|
||||
history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', greeting]]
|
||||
history['visible'] += [['', apply_extensions('output', greeting, state)]]
|
||||
history['visible'] += [['', apply_extensions('output', greeting, state, is_chat=True)]]
|
||||
|
||||
return history
|
||||
|
||||
|
@ -452,7 +429,7 @@ def load_persistent_history(state):
|
|||
history = {'internal': [], 'visible': []}
|
||||
if greeting != "":
|
||||
history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', greeting]]
|
||||
history['visible'] += [['', apply_extensions('output', greeting, state)]]
|
||||
history['visible'] += [['', apply_extensions('output', greeting, state, is_chat=True)]]
|
||||
|
||||
return history
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue