Implement stopping string search in string space (#2847)
This commit is contained in:
parent
0f9088f730
commit
8bb3bb39b3
4 changed files with 61 additions and 112 deletions
|
@ -1,4 +1,3 @@
|
|||
import ast
|
||||
import base64
|
||||
import copy
|
||||
import functools
|
||||
|
@ -144,40 +143,10 @@ def get_stopping_strings(state):
|
|||
f"\n{state['name2']}:"
|
||||
]
|
||||
|
||||
stopping_strings += ast.literal_eval(f"[{state['custom_stopping_strings']}]")
|
||||
return stopping_strings
|
||||
|
||||
|
||||
def extract_message_from_reply(reply, state):
|
||||
next_character_found = False
|
||||
stopping_strings = get_stopping_strings(state)
|
||||
|
||||
if state['stop_at_newline']:
|
||||
lines = reply.split('\n')
|
||||
reply = lines[0].strip()
|
||||
if len(lines) > 1:
|
||||
next_character_found = True
|
||||
else:
|
||||
for string in stopping_strings:
|
||||
idx = reply.find(string)
|
||||
if idx != -1:
|
||||
reply = reply[:idx]
|
||||
next_character_found = True
|
||||
stopping_strings.append("\n")
|
||||
|
||||
# If something like "\nYo" is generated just before "\nYou:"
|
||||
# is completed, trim it
|
||||
if not next_character_found:
|
||||
for string in stopping_strings:
|
||||
for j in range(len(string) - 1, 0, -1):
|
||||
if reply[-j:] == string[:j]:
|
||||
reply = reply[:-j]
|
||||
break
|
||||
else:
|
||||
continue
|
||||
|
||||
break
|
||||
|
||||
return reply, next_character_found
|
||||
return stopping_strings
|
||||
|
||||
|
||||
def chatbot_wrapper(text, history, state, regenerate=False, _continue=False, loading_message=True):
|
||||
|
@ -191,7 +160,6 @@ def chatbot_wrapper(text, history, state, regenerate=False, _continue=False, loa
|
|||
# Defining some variables
|
||||
just_started = True
|
||||
visible_text = None
|
||||
eos_token = '\n' if state['stop_at_newline'] else None
|
||||
stopping_strings = get_stopping_strings(state)
|
||||
|
||||
# Preparing the input
|
||||
|
@ -231,11 +199,10 @@ def chatbot_wrapper(text, history, state, regenerate=False, _continue=False, loa
|
|||
cumulative_reply = ''
|
||||
for i in range(state['chat_generation_attempts']):
|
||||
reply = None
|
||||
for j, reply in enumerate(generate_reply(prompt + cumulative_reply, state, eos_token=eos_token, stopping_strings=stopping_strings, is_chat=True)):
|
||||
for j, reply in enumerate(generate_reply(prompt + cumulative_reply, state, stopping_strings=stopping_strings, is_chat=True)):
|
||||
reply = cumulative_reply + reply
|
||||
|
||||
# Extract the reply
|
||||
reply, next_character_found = extract_message_from_reply(reply, state)
|
||||
visible_reply = re.sub("(<USER>|<user>|{{user}})", state['name1'], reply)
|
||||
visible_reply = apply_extensions("output", visible_reply)
|
||||
|
||||
|
@ -262,9 +229,6 @@ def chatbot_wrapper(text, history, state, regenerate=False, _continue=False, loa
|
|||
if state['stream']:
|
||||
yield output
|
||||
|
||||
if next_character_found:
|
||||
break
|
||||
|
||||
if reply in [None, cumulative_reply]:
|
||||
break
|
||||
else:
|
||||
|
@ -281,7 +245,6 @@ def impersonate_wrapper(text, start_with, state):
|
|||
|
||||
# Defining some variables
|
||||
cumulative_reply = ''
|
||||
eos_token = '\n' if state['stop_at_newline'] else None
|
||||
prompt = generate_chat_prompt('', state, impersonate=True)
|
||||
stopping_strings = get_stopping_strings(state)
|
||||
|
||||
|
@ -289,16 +252,12 @@ def impersonate_wrapper(text, start_with, state):
|
|||
cumulative_reply = text
|
||||
for i in range(state['chat_generation_attempts']):
|
||||
reply = None
|
||||
for reply in generate_reply(prompt + cumulative_reply, state, eos_token=eos_token, stopping_strings=stopping_strings, is_chat=True):
|
||||
for reply in generate_reply(prompt + cumulative_reply, state, stopping_strings=stopping_strings, is_chat=True):
|
||||
reply = cumulative_reply + reply
|
||||
reply, next_character_found = extract_message_from_reply(reply, state)
|
||||
yield reply.lstrip(' ')
|
||||
if shared.stop_everything:
|
||||
return
|
||||
|
||||
if next_character_found:
|
||||
break
|
||||
|
||||
if reply in [None, cumulative_reply]:
|
||||
break
|
||||
else:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue