Add chat API (#2233)
This commit is contained in:
parent
2aa01e2303
commit
c5af549d4b
8 changed files with 317 additions and 67 deletions
|
@ -50,7 +50,7 @@ def generate_chat_prompt(user_input, state, **kwargs):
|
|||
impersonate = kwargs.get('impersonate', False)
|
||||
_continue = kwargs.get('_continue', False)
|
||||
also_return_rows = kwargs.get('also_return_rows', False)
|
||||
history = state.get('history', shared.history['internal'])
|
||||
history = kwargs.get('history', shared.history)['internal']
|
||||
is_instruct = state['mode'] == 'instruct'
|
||||
|
||||
# Finding the maximum prompt size
|
||||
|
@ -59,11 +59,11 @@ def generate_chat_prompt(user_input, state, **kwargs):
|
|||
chat_prompt_size -= shared.soft_prompt_tensor.shape[1]
|
||||
|
||||
max_length = min(get_max_prompt_length(state), chat_prompt_size)
|
||||
|
||||
all_substrings = {
|
||||
'chat': get_turn_substrings(state, instruct=False),
|
||||
'instruct': get_turn_substrings(state, instruct=True)
|
||||
}
|
||||
|
||||
substrings = all_substrings['instruct' if is_instruct else 'chat']
|
||||
|
||||
# Creating the template for "chat-instruct" mode
|
||||
|
@ -179,10 +179,11 @@ def extract_message_from_reply(reply, state):
|
|||
return reply, next_character_found
|
||||
|
||||
|
||||
def chatbot_wrapper(text, state, regenerate=False, _continue=False):
|
||||
def chatbot_wrapper(text, history, state, regenerate=False, _continue=False, loading_message=True):
|
||||
output = copy.deepcopy(history)
|
||||
if shared.model_name == 'None' or shared.model is None:
|
||||
logging.error("No model is loaded! Select one in the Model tab.")
|
||||
yield shared.history['visible']
|
||||
yield output
|
||||
return
|
||||
|
||||
# Defining some variables
|
||||
|
@ -200,20 +201,27 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False):
|
|||
|
||||
text = apply_extensions('input', text)
|
||||
# *Is typing...*
|
||||
yield shared.history['visible'] + [[visible_text, shared.processing_message]]
|
||||
if loading_message:
|
||||
yield {'visible': output['visible'] + [[visible_text, shared.processing_message]], 'internal': output['internal']}
|
||||
else:
|
||||
text, visible_text = shared.history['internal'][-1][0], shared.history['visible'][-1][0]
|
||||
text, visible_text = output['internal'][-1][0], output['visible'][-1][0]
|
||||
if regenerate:
|
||||
shared.history['visible'].pop()
|
||||
shared.history['internal'].pop()
|
||||
output['visible'].pop()
|
||||
output['internal'].pop()
|
||||
# *Is typing...*
|
||||
yield shared.history['visible'] + [[visible_text, shared.processing_message]]
|
||||
if loading_message:
|
||||
yield {'visible': output['visible'] + [[visible_text, shared.processing_message]], 'internal': output['internal']}
|
||||
elif _continue:
|
||||
last_reply = [shared.history['internal'][-1][1], shared.history['visible'][-1][1]]
|
||||
yield shared.history['visible'][:-1] + [[visible_text, last_reply[1] + '...']]
|
||||
last_reply = [output['internal'][-1][1], output['visible'][-1][1]]
|
||||
if loading_message:
|
||||
yield {'visible': output['visible'][:-1] + [[visible_text, last_reply[1] + '...']], 'internal': output['internal']}
|
||||
|
||||
# Generating the prompt
|
||||
kwargs = {'_continue': _continue}
|
||||
kwargs = {
|
||||
'_continue': _continue,
|
||||
'history': output,
|
||||
}
|
||||
|
||||
prompt = apply_extensions('custom_generate_chat_prompt', text, state, **kwargs)
|
||||
if prompt is None:
|
||||
prompt = generate_chat_prompt(text, state, **kwargs)
|
||||
|
@ -232,22 +240,23 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False):
|
|||
# We need this global variable to handle the Stop event,
|
||||
# otherwise gradio gets confused
|
||||
if shared.stop_everything:
|
||||
return shared.history['visible']
|
||||
yield output
|
||||
return
|
||||
|
||||
if just_started:
|
||||
just_started = False
|
||||
if not _continue:
|
||||
shared.history['internal'].append(['', ''])
|
||||
shared.history['visible'].append(['', ''])
|
||||
output['internal'].append(['', ''])
|
||||
output['visible'].append(['', ''])
|
||||
|
||||
if _continue:
|
||||
shared.history['internal'][-1] = [text, last_reply[0] + reply]
|
||||
shared.history['visible'][-1] = [visible_text, last_reply[1] + visible_reply]
|
||||
yield shared.history['visible']
|
||||
output['internal'][-1] = [text, last_reply[0] + reply]
|
||||
output['visible'][-1] = [visible_text, last_reply[1] + visible_reply]
|
||||
yield output
|
||||
elif not (j == 0 and visible_reply.strip() == ''):
|
||||
shared.history['internal'][-1] = [text, reply]
|
||||
shared.history['visible'][-1] = [visible_text, visible_reply]
|
||||
yield shared.history['visible']
|
||||
output['internal'][-1] = [text, reply.lstrip(' ')]
|
||||
output['visible'][-1] = [visible_text, visible_reply.lstrip(' ')]
|
||||
yield output
|
||||
|
||||
if next_character_found:
|
||||
break
|
||||
|
@ -257,7 +266,7 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False):
|
|||
else:
|
||||
cumulative_reply = reply
|
||||
|
||||
yield shared.history['visible']
|
||||
yield output
|
||||
|
||||
|
||||
def impersonate_wrapper(text, state):
|
||||
|
@ -291,21 +300,24 @@ def impersonate_wrapper(text, state):
|
|||
yield cumulative_reply
|
||||
|
||||
|
||||
def generate_chat_reply(text, state, regenerate=False, _continue=False):
|
||||
def generate_chat_reply(text, history, state, regenerate=False, _continue=False, loading_message=True):
|
||||
if regenerate or _continue:
|
||||
text = ''
|
||||
if (len(shared.history['visible']) == 1 and not shared.history['visible'][0][0]) or len(shared.history['internal']) == 0:
|
||||
yield shared.history['visible']
|
||||
if (len(history['visible']) == 1 and not history['visible'][0][0]) or len(history['internal']) == 0:
|
||||
yield history
|
||||
return
|
||||
|
||||
for history in chatbot_wrapper(text, state, regenerate=regenerate, _continue=_continue):
|
||||
for history in chatbot_wrapper(text, history, state, regenerate=regenerate, _continue=_continue, loading_message=loading_message):
|
||||
yield history
|
||||
|
||||
|
||||
# Same as above but returns HTML
|
||||
def generate_chat_reply_wrapper(text, state, regenerate=False, _continue=False):
|
||||
for history in generate_chat_reply(text, state, regenerate, _continue):
|
||||
yield chat_html_wrapper(history, state['name1'], state['name2'], state['mode'], state['chat_style'])
|
||||
for i, history in enumerate(generate_chat_reply(text, shared.history, state, regenerate, _continue, loading_message=True)):
|
||||
if i != 0:
|
||||
shared.history = copy.deepcopy(history)
|
||||
|
||||
yield chat_html_wrapper(history['visible'], state['name1'], state['name2'], state['mode'], state['chat_style'])
|
||||
|
||||
|
||||
def remove_last_message():
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue