Add chat API (#2233)

oobabooga 2023-05-20 18:42:17 -03:00 committed by GitHub
parent 2aa01e2303
commit c5af549d4b
8 changed files with 317 additions and 67 deletions

modules/chat.py

@@ -50,7 +50,7 @@ def generate_chat_prompt(user_input, state, **kwargs):
     impersonate = kwargs.get('impersonate', False)
     _continue = kwargs.get('_continue', False)
     also_return_rows = kwargs.get('also_return_rows', False)
-    history = state.get('history', shared.history['internal'])
+    history = kwargs.get('history', shared.history)['internal']
     is_instruct = state['mode'] == 'instruct'

     # Finding the maximum prompt size
@@ -59,11 +59,11 @@ def generate_chat_prompt(user_input, state, **kwargs):
         chat_prompt_size -= shared.soft_prompt_tensor.shape[1]

     max_length = min(get_max_prompt_length(state), chat_prompt_size)
     all_substrings = {
         'chat': get_turn_substrings(state, instruct=False),
         'instruct': get_turn_substrings(state, instruct=True)
     }
     substrings = all_substrings['instruct' if is_instruct else 'chat']

     # Creating the template for "chat-instruct" mode
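With this change, generate_chat_prompt() takes its history from kwargs (falling back to shared.history) rather than from state. A minimal sketch of the new call shape, assuming the {'internal', 'visible'} history format used throughout this diff; the sample messages and the state dict are placeholders, not part of the commit:

    from modules import shared
    from modules.chat import generate_chat_prompt

    # Hypothetical caller-supplied history; both keys hold lists of
    # [user_message, bot_reply] pairs.
    history = {
        'internal': [['Hi!', 'Hello, how can I help?']],
        'visible': [['Hi!', 'Hello, how can I help?']],
    }

    state = build_state()  # hypothetical: the usual generation-parameters dict

    # The function reads kwargs.get('history', shared.history) and indexes
    # ['internal'] itself, so the full dict is passed in.
    prompt = generate_chat_prompt('What is your name?', state, history=history)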
@@ -179,10 +179,11 @@ def extract_message_from_reply(reply, state):
     return reply, next_character_found


-def chatbot_wrapper(text, state, regenerate=False, _continue=False):
+def chatbot_wrapper(text, history, state, regenerate=False, _continue=False, loading_message=True):
+    output = copy.deepcopy(history)
     if shared.model_name == 'None' or shared.model is None:
         logging.error("No model is loaded! Select one in the Model tab.")
-        yield shared.history['visible']
+        yield output
         return

     # Defining some variables
@@ -200,20 +201,27 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False):
         text = apply_extensions('input', text)

         # *Is typing...*
-        yield shared.history['visible'] + [[visible_text, shared.processing_message]]
+        if loading_message:
+            yield {'visible': output['visible'] + [[visible_text, shared.processing_message]], 'internal': output['internal']}
     else:
-        text, visible_text = shared.history['internal'][-1][0], shared.history['visible'][-1][0]
+        text, visible_text = output['internal'][-1][0], output['visible'][-1][0]
         if regenerate:
-            shared.history['visible'].pop()
-            shared.history['internal'].pop()
+            output['visible'].pop()
+            output['internal'].pop()

             # *Is typing...*
-            yield shared.history['visible'] + [[visible_text, shared.processing_message]]
+            if loading_message:
+                yield {'visible': output['visible'] + [[visible_text, shared.processing_message]], 'internal': output['internal']}
         elif _continue:
-            last_reply = [shared.history['internal'][-1][1], shared.history['visible'][-1][1]]
-            yield shared.history['visible'][:-1] + [[visible_text, last_reply[1] + '...']]
+            last_reply = [output['internal'][-1][1], output['visible'][-1][1]]
+            if loading_message:
+                yield {'visible': output['visible'][:-1] + [[visible_text, last_reply[1] + '...']], 'internal': output['internal']}

     # Generating the prompt
-    kwargs = {'_continue': _continue}
+    kwargs = {
+        '_continue': _continue,
+        'history': output,
+    }
     prompt = apply_extensions('custom_generate_chat_prompt', text, state, **kwargs)
     if prompt is None:
         prompt = generate_chat_prompt(text, state, **kwargs)
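Because 'history' now rides along in kwargs, an extension's custom_generate_chat_prompt hook sees the same dict chatbot_wrapper is building. A sketch of a hypothetical extension that trims the context before delegating back to the stock prompt builder; the truncation policy is made up for illustration:

    from modules import shared
    from modules.chat import generate_chat_prompt

    def custom_generate_chat_prompt(user_input, state, **kwargs):
        # kwargs['history'] is the full {'internal', 'visible'} dict;
        # falling back to shared.history keeps older callers working.
        history = kwargs.get('history', shared.history)

        # Hypothetical policy: keep only the last 10 exchanges.
        kwargs['history'] = {
            'internal': history['internal'][-10:],
            'visible': history['visible'][-10:],
        }
        return generate_chat_prompt(user_input, state, **kwargs)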
@@ -232,22 +240,23 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False):
             # We need this global variable to handle the Stop event,
             # otherwise gradio gets confused
             if shared.stop_everything:
-                return shared.history['visible']
+                yield output
+                return

             if just_started:
                 just_started = False
                 if not _continue:
-                    shared.history['internal'].append(['', ''])
-                    shared.history['visible'].append(['', ''])
+                    output['internal'].append(['', ''])
+                    output['visible'].append(['', ''])

             if _continue:
-                shared.history['internal'][-1] = [text, last_reply[0] + reply]
-                shared.history['visible'][-1] = [visible_text, last_reply[1] + visible_reply]
-                yield shared.history['visible']
+                output['internal'][-1] = [text, last_reply[0] + reply]
+                output['visible'][-1] = [visible_text, last_reply[1] + visible_reply]
+                yield output
             elif not (j == 0 and visible_reply.strip() == ''):
-                shared.history['internal'][-1] = [text, reply]
-                shared.history['visible'][-1] = [visible_text, visible_reply]
-                yield shared.history['visible']
+                output['internal'][-1] = [text, reply.lstrip(' ')]
+                output['visible'][-1] = [visible_text, visible_reply.lstrip(' ')]
+                yield output

             if next_character_found:
                 break
@@ -257,7 +266,7 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False):
         else:
             cumulative_reply = reply

-    yield shared.history['visible']
+    yield output


 def impersonate_wrapper(text, state):
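After these hunks, chatbot_wrapper() mutates only a deep copy of the history it is handed and yields the whole {'visible', 'internal'} dict at each streaming step, leaving the caller to decide when to commit the result. A minimal consumption sketch under that reading; state is again a hypothetical placeholder:

    import copy
    from modules import shared
    from modules.chat import chatbot_wrapper

    result = shared.history
    for result in chatbot_wrapper('Hello!', shared.history, state):
        pass  # each yield is a partial {'visible', 'internal'} snapshot

    # Commit only the finished conversation, mirroring what
    # generate_chat_reply_wrapper does below.
    shared.history = copy.deepcopy(result)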
@@ -291,21 +300,24 @@ def impersonate_wrapper(text, state):
         yield cumulative_reply


-def generate_chat_reply(text, state, regenerate=False, _continue=False):
+def generate_chat_reply(text, history, state, regenerate=False, _continue=False, loading_message=True):
     if regenerate or _continue:
         text = ''
-        if (len(shared.history['visible']) == 1 and not shared.history['visible'][0][0]) or len(shared.history['internal']) == 0:
-            yield shared.history['visible']
+        if (len(history['visible']) == 1 and not history['visible'][0][0]) or len(history['internal']) == 0:
+            yield history
             return

-    for history in chatbot_wrapper(text, state, regenerate=regenerate, _continue=_continue):
+    for history in chatbot_wrapper(text, history, state, regenerate=regenerate, _continue=_continue, loading_message=loading_message):
         yield history


 # Same as above but returns HTML
 def generate_chat_reply_wrapper(text, state, regenerate=False, _continue=False):
-    for history in generate_chat_reply(text, state, regenerate, _continue):
-        yield chat_html_wrapper(history, state['name1'], state['name2'], state['mode'], state['chat_style'])
+    for i, history in enumerate(generate_chat_reply(text, shared.history, state, regenerate, _continue, loading_message=True)):
+        if i != 0:
+            shared.history = copy.deepcopy(history)
+
+        yield chat_html_wrapper(history['visible'], state['name1'], state['name2'], state['mode'], state['chat_style'])


 def remove_last_message():
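Taken together, this is what lets the chat API added by this commit drive generation with a caller-owned history: pass an explicit dict to generate_chat_reply() and set loading_message=False so the '*Is typing...*' placeholders are skipped and only real snapshots are yielded. A hedged sketch of such an API-side caller; the variable names and the empty starting history are assumptions, and the HTTP plumbing is omitted:

    from modules.chat import generate_chat_reply

    user_history = {'internal': [], 'visible': []}  # per-client, not shared

    answer = user_history
    for answer in generate_chat_reply('Tell me a joke.', user_history, state,
                                      regenerate=False, _continue=False,
                                      loading_message=False):
        pass  # stream each partial history back to the client here

    print(answer['visible'][-1][1])  # the bot's final reply, as displayed

Since chatbot_wrapper() deep-copies its input, the web UI's shared.history is never touched unless the caller assigns back to it, as generate_chat_reply_wrapper does above.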