Implement sessions + add basic multi-user support (#2991)
This commit is contained in:
parent
1f8cae14f9
commit
4b1804a438
17 changed files with 595 additions and 414 deletions
|
@ -3,9 +3,10 @@ from pathlib import Path
|
|||
|
||||
import gradio as gr
|
||||
import torch
|
||||
from modules import chat, shared
|
||||
|
||||
from extensions.silero_tts import tts_preprocessor
|
||||
from modules import chat, shared
|
||||
from modules.utils import gradio
|
||||
|
||||
torch._C._jit_set_profiling_mode(False)
|
||||
|
||||
|
@ -56,20 +57,24 @@ def load_model():
|
|||
return model
|
||||
|
||||
|
||||
def remove_tts_from_history():
|
||||
for i, entry in enumerate(shared.history['internal']):
|
||||
shared.history['visible'][i] = [shared.history['visible'][i][0], entry[1]]
|
||||
def remove_tts_from_history(history):
|
||||
for i, entry in enumerate(history['internal']):
|
||||
history['visible'][i] = [history['visible'][i][0], entry[1]]
|
||||
|
||||
return history
|
||||
|
||||
|
||||
def toggle_text_in_history():
|
||||
for i, entry in enumerate(shared.history['visible']):
|
||||
def toggle_text_in_history(history):
|
||||
for i, entry in enumerate(history['visible']):
|
||||
visible_reply = entry[1]
|
||||
if visible_reply.startswith('<audio'):
|
||||
if params['show_text']:
|
||||
reply = shared.history['internal'][i][1]
|
||||
shared.history['visible'][i] = [shared.history['visible'][i][0], f"{visible_reply.split('</audio>')[0]}</audio>\n\n{reply}"]
|
||||
reply = history['internal'][i][1]
|
||||
history['visible'][i] = [history['visible'][i][0], f"{visible_reply.split('</audio>')[0]}</audio>\n\n{reply}"]
|
||||
else:
|
||||
shared.history['visible'][i] = [shared.history['visible'][i][0], f"{visible_reply.split('</audio>')[0]}</audio>"]
|
||||
history['visible'][i] = [history['visible'][i][0], f"{visible_reply.split('</audio>')[0]}</audio>"]
|
||||
|
||||
return history
|
||||
|
||||
|
||||
def state_modifier(state):
|
||||
|
@ -80,7 +85,7 @@ def state_modifier(state):
|
|||
return state
|
||||
|
||||
|
||||
def input_modifier(string):
|
||||
def input_modifier(string, state):
|
||||
if not params['activate']:
|
||||
return string
|
||||
|
||||
|
@ -99,7 +104,7 @@ def history_modifier(history):
|
|||
return history
|
||||
|
||||
|
||||
def output_modifier(string):
|
||||
def output_modifier(string, state):
|
||||
global model, current_params, streaming_state
|
||||
for i in params:
|
||||
if params[i] != current_params[i]:
|
||||
|
@ -116,7 +121,7 @@ def output_modifier(string):
|
|||
if string == '':
|
||||
string = '*Empty reply, try regenerating*'
|
||||
else:
|
||||
output_file = Path(f'extensions/silero_tts/outputs/{shared.character}_{int(time.time())}.wav')
|
||||
output_file = Path(f'extensions/silero_tts/outputs/{state["character_menu"]}_{int(time.time())}.wav')
|
||||
prosody = '<prosody rate="{}" pitch="{}">'.format(params['voice_speed'], params['voice_pitch'])
|
||||
silero_input = f'<speak>{prosody}{xmlesc(string)}</prosody></speak>'
|
||||
model.save_wav(ssml_text=silero_input, speaker=params['speaker'], sample_rate=int(params['sample_rate']), audio_path=str(output_file))
|
||||
|
@ -155,23 +160,24 @@ def ui():
|
|||
|
||||
gr.Markdown('[Click here for Silero audio samples](https://oobabooga.github.io/silero-samples/index.html)')
|
||||
|
||||
# Convert history with confirmation
|
||||
convert_arr = [convert_confirm, convert, convert_cancel]
|
||||
convert.click(lambda: [gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, convert_arr)
|
||||
convert_confirm.click(
|
||||
lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, convert_arr).then(
|
||||
remove_tts_from_history, None, None).then(
|
||||
chat.save_history, shared.gradio['mode'], None, show_progress=False).then(
|
||||
chat.redraw_html, shared.reload_inputs, shared.gradio['display'])
|
||||
if shared.is_chat():
|
||||
# Convert history with confirmation
|
||||
convert_arr = [convert_confirm, convert, convert_cancel]
|
||||
convert.click(lambda: [gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, convert_arr)
|
||||
convert_confirm.click(
|
||||
lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, convert_arr).then(
|
||||
remove_tts_from_history, gradio('history'), gradio('history')).then(
|
||||
chat.save_persistent_history, gradio('history', 'character_menu', 'mode'), None).then(
|
||||
chat.redraw_html, shared.reload_inputs, gradio('display'))
|
||||
|
||||
convert_cancel.click(lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, convert_arr)
|
||||
convert_cancel.click(lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, convert_arr)
|
||||
|
||||
# Toggle message text in history
|
||||
show_text.change(
|
||||
lambda x: params.update({"show_text": x}), show_text, None).then(
|
||||
toggle_text_in_history, None, None).then(
|
||||
chat.save_history, shared.gradio['mode'], None, show_progress=False).then(
|
||||
chat.redraw_html, shared.reload_inputs, shared.gradio['display'])
|
||||
# Toggle message text in history
|
||||
show_text.change(
|
||||
lambda x: params.update({"show_text": x}), show_text, None).then(
|
||||
toggle_text_in_history, gradio('history'), gradio('history')).then(
|
||||
chat.save_persistent_history, gradio('history', 'character_menu', 'mode'), None).then(
|
||||
chat.redraw_html, shared.reload_inputs, gradio('display'))
|
||||
|
||||
# Event functions to update the parameters in the backend
|
||||
activate.change(lambda x: params.update({"activate": x}), activate, None)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue