Change how spaces are handled in continue/generation attempts
This commit is contained in:
parent
2eeb27659d
commit
e283ddc559
4 changed files with 13 additions and 12 deletions
|
@ -185,18 +185,13 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False):
|
|||
# Generate
|
||||
for i in range(state['chat_generation_attempts']):
|
||||
reply = None
|
||||
for j, reply in enumerate(generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", state, eos_token=eos_token, stopping_strings=stopping_strings, is_chat=True)):
|
||||
for j, reply in enumerate(generate_reply(prompt + cumulative_reply, state, eos_token=eos_token, stopping_strings=stopping_strings, is_chat=True)):
|
||||
reply = cumulative_reply + reply
|
||||
|
||||
# Extracting the reply
|
||||
reply, next_character_found = extract_message_from_reply(reply, state)
|
||||
visible_reply = re.sub("(<USER>|<user>|{{user}})", state['name1'], reply)
|
||||
visible_reply = apply_extensions("output", visible_reply)
|
||||
if _continue:
|
||||
sep = ' ' if last_reply[0][-1] not in [' ', '\n'] else ''
|
||||
reply = last_reply[0] + sep + reply
|
||||
sep = ' ' if last_reply[1][-1] not in [' ', '\n'] else ''
|
||||
visible_reply = last_reply[1] + sep + visible_reply
|
||||
|
||||
# We need this global variable to handle the Stop event,
|
||||
# otherwise gradio gets confused
|
||||
|
@ -209,7 +204,11 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False):
|
|||
shared.history['internal'].append(['', ''])
|
||||
shared.history['visible'].append(['', ''])
|
||||
|
||||
if not (j == 0 and visible_reply.strip() == ''):
|
||||
if _continue:
|
||||
shared.history['internal'][-1] = [text, last_reply[0] + reply]
|
||||
shared.history['visible'][-1] = [visible_text, last_reply[1] + visible_reply]
|
||||
yield shared.history['visible']
|
||||
elif not (j == 0 and visible_reply.strip() == ''):
|
||||
shared.history['internal'][-1] = [text, reply]
|
||||
shared.history['visible'][-1] = [visible_text, visible_reply]
|
||||
yield shared.history['visible']
|
||||
|
@ -217,7 +216,9 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False):
|
|||
if next_character_found:
|
||||
break
|
||||
|
||||
if reply is not None:
|
||||
if reply in [None, '']:
|
||||
break
|
||||
else:
|
||||
cumulative_reply = reply
|
||||
|
||||
yield shared.history['visible']
|
||||
|
@ -239,7 +240,7 @@ def impersonate_wrapper(text, state):
|
|||
cumulative_reply = text
|
||||
for i in range(state['chat_generation_attempts']):
|
||||
reply = None
|
||||
for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", state, eos_token=eos_token, stopping_strings=stopping_strings, is_chat=True):
|
||||
for reply in generate_reply(prompt + cumulative_reply, state, eos_token=eos_token, stopping_strings=stopping_strings, is_chat=True):
|
||||
reply = cumulative_reply + reply
|
||||
reply, next_character_found = extract_message_from_reply(reply, state)
|
||||
yield reply
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue