Refactor chat functions (#2003)

This commit is contained in:
oobabooga 2023-05-11 15:37:04 -03:00 committed by GitHub
parent 4e9da22c58
commit 638c6a65a2
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 138 additions and 157 deletions

View file

@ -35,18 +35,15 @@ class Handler(BaseHTTPRequestHandler):
generate_params['stream'] = False
generator = generate_reply(
prompt, generate_params, stopping_strings=stopping_strings)
prompt, generate_params, stopping_strings=stopping_strings, is_chat=False)
answer = ''
for a in generator:
if isinstance(a, str):
answer = a
else:
answer = a[0]
answer = a
response = json.dumps({
'results': [{
'text': answer if shared.is_chat() else answer[len(prompt):]
'text': answer[len(prompt):]
}]
})
self.wfile.write(response.encode('utf-8'))

View file

@ -26,19 +26,14 @@ async def _handle_connection(websocket, path):
generate_params['stream'] = True
generator = generate_reply(
prompt, generate_params, stopping_strings=stopping_strings)
prompt, generate_params, stopping_strings=stopping_strings, is_chat=False)
# As we stream, only send the new bytes.
skip_index = len(prompt) if not shared.is_chat() else 0
skip_index = len(prompt)
message_num = 0
for a in generator:
to_send = ''
if isinstance(a, str):
to_send = a[skip_index:]
else:
to_send = a[0][skip_index:]
to_send = a[skip_index:]
await websocket.send(json.dumps({
'event': 'text_stream',
'message_num': message_num,

View file

@ -3,9 +3,7 @@ from pathlib import Path
import elevenlabs
import gradio as gr
from modules import chat, shared
from modules.html_generator import chat_html_wrapper
params = {
'activate': True,
@ -31,14 +29,12 @@ def refresh_voices_dd():
return gr.Dropdown.update(value=all_voices[0], choices=all_voices)
def remove_tts_from_history(name1, name2, mode, style):
def remove_tts_from_history():
for i, entry in enumerate(shared.history['internal']):
shared.history['visible'][i] = [shared.history['visible'][i][0], entry[1]]
return chat_html_wrapper(shared.history['visible'], name1, name2, mode, style)
def toggle_text_in_history(name1, name2, mode, style):
def toggle_text_in_history():
for i, entry in enumerate(shared.history['visible']):
visible_reply = entry[1]
if visible_reply.startswith('<audio'):
@ -52,8 +48,6 @@ def toggle_text_in_history(name1, name2, mode, style):
shared.history['visible'][i][0], f"{visible_reply.split('</audio>')[0]}</audio>"
]
return chat_html_wrapper(shared.history['visible'], name1, name2, mode, style)
def remove_surrounded_chars(string):
# this expression matches to 'as few symbols as possible (0 upwards) between any asterisks' OR
@ -152,22 +146,23 @@ def ui():
# Convert history with confirmation
convert_arr = [convert_confirm, convert, convert_cancel]
convert.click(
lambda: [gr.update(visible=True), gr.update(visible=False),
gr.update(visible=True)], None, convert_arr
)
convert.click(lambda: [gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, convert_arr)
convert_confirm.click(
lambda: [gr.update(visible=False), gr.update(visible=True),
gr.update(visible=False)], None, convert_arr
)
convert_confirm.click(
remove_tts_from_history, [shared.gradio[k] for k in ['name1', 'name2', 'mode', 'chat_style']], shared.gradio['display']
)
convert_confirm.click(chat.save_history, shared.gradio['mode'], [], show_progress=False)
convert_cancel.click(
lambda: [gr.update(visible=False), gr.update(visible=True),
gr.update(visible=False)], None, convert_arr
)
lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, convert_arr).then(
remove_tts_from_history, None, None).then(
chat.save_history, shared.gradio['mode'], None, show_progress=False).then(
chat.redraw_html, shared.reload_inputs, shared.gradio['display'])
convert_cancel.click(lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, convert_arr)
# Toggle message text in history
show_text.change(
lambda x: params.update({"show_text": x}), show_text, None).then(
toggle_text_in_history, None, None).then(
chat.save_history, shared.gradio['mode'], None, show_progress=False).then(
chat.redraw_html, shared.reload_inputs, shared.gradio['display'])
convert_cancel.click(lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, convert_arr)
# Event functions to update the parameters in the backend
activate.change(lambda x: params.update({'activate': x}), activate, None)
@ -175,11 +170,5 @@ def ui():
api_key.change(lambda x: params.update({'api_key': x}), api_key, None)
# connect.click(check_valid_api, [], connection_status)
refresh.click(refresh_voices_dd, [], voice)
# Toggle message text in history
show_text.change(lambda x: params.update({"show_text": x}), show_text, None)
show_text.change(
toggle_text_in_history, [shared.gradio[k] for k in ['name1', 'name2', 'mode', 'chat_style']], shared.gradio['display']
)
show_text.change(chat.save_history, shared.gradio['mode'], [], show_progress=False)
# Event functions to update the parameters in the backend
autoplay.change(lambda x: params.update({"autoplay": x}), autoplay, None)

View file

@ -43,5 +43,5 @@ def ui():
picture_select.upload(
lambda picture, name1, name2: input_hijack.update({"state": True, "value": generate_chat_picture(picture, name1, name2)}), [picture_select, shared.gradio['name1'], shared.gradio['name2']], None).then(
gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=False).then(
chat.generate_chat_reply_wrapper, shared.input_params, shared.gradio['display'], show_progress=False).then(
lambda: None, None, picture_select, show_progress=False)

View file

@ -3,9 +3,9 @@ from pathlib import Path
import gradio as gr
import torch
from extensions.silero_tts import tts_preprocessor
from modules import chat, shared
from modules.html_generator import chat_html_wrapper
from extensions.silero_tts import tts_preprocessor
torch._C._jit_set_profiling_mode(False)
@ -56,14 +56,12 @@ def load_model():
return model
def remove_tts_from_history(name1, name2, mode, style):
def remove_tts_from_history():
for i, entry in enumerate(shared.history['internal']):
shared.history['visible'][i] = [shared.history['visible'][i][0], entry[1]]
return chat_html_wrapper(shared.history['visible'], name1, name2, mode, style)
def toggle_text_in_history(name1, name2, mode, style):
def toggle_text_in_history():
for i, entry in enumerate(shared.history['visible']):
visible_reply = entry[1]
if visible_reply.startswith('<audio'):
@ -73,8 +71,6 @@ def toggle_text_in_history(name1, name2, mode, style):
else:
shared.history['visible'][i] = [shared.history['visible'][i][0], f"{visible_reply.split('</audio>')[0]}</audio>"]
return chat_html_wrapper(shared.history['visible'], name1, name2, mode, style)
def state_modifier(state):
state['stream'] = False
@ -169,15 +165,20 @@ def ui():
# Convert history with confirmation
convert_arr = [convert_confirm, convert, convert_cancel]
convert.click(lambda: [gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, convert_arr)
convert_confirm.click(lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, convert_arr)
convert_confirm.click(remove_tts_from_history, [shared.gradio[k] for k in ['name1', 'name2', 'mode', 'chat_style']], shared.gradio['display'])
convert_confirm.click(chat.save_history, shared.gradio['mode'], [], show_progress=False)
convert_confirm.click(
lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, convert_arr).then(
remove_tts_from_history, None, None).then(
chat.save_history, shared.gradio['mode'], None, show_progress=False).then(
chat.redraw_html, shared.reload_inputs, shared.gradio['display'])
convert_cancel.click(lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, convert_arr)
# Toggle message text in history
show_text.change(lambda x: params.update({"show_text": x}), show_text, None)
show_text.change(toggle_text_in_history, [shared.gradio[k] for k in ['name1', 'name2', 'mode', 'chat_style']], shared.gradio['display'])
show_text.change(chat.save_history, shared.gradio['mode'], [], show_progress=False)
show_text.change(
lambda x: params.update({"show_text": x}), show_text, None).then(
toggle_text_in_history, None, None).then(
chat.save_history, shared.gradio['mode'], None, show_progress=False).then(
chat.redraw_html, shared.reload_inputs, shared.gradio['display'])
# Event functions to update the parameters in the backend
activate.change(lambda x: params.update({"activate": x}), activate, None)