Implement sessions + add basic multi-user support (#2991)
This commit is contained in:
parent
1f8cae14f9
commit
4b1804a438
17 changed files with 595 additions and 414 deletions
|
@ -3,7 +3,9 @@ from pathlib import Path
|
|||
|
||||
import elevenlabs
|
||||
import gradio as gr
|
||||
|
||||
from modules import chat, shared
|
||||
from modules.utils import gradio
|
||||
|
||||
params = {
|
||||
'activate': True,
|
||||
|
@ -35,24 +37,24 @@ def refresh_voices_dd():
|
|||
return gr.Dropdown.update(value=all_voices[0], choices=all_voices)
|
||||
|
||||
|
||||
def remove_tts_from_history():
|
||||
for i, entry in enumerate(shared.history['internal']):
|
||||
shared.history['visible'][i] = [shared.history['visible'][i][0], entry[1]]
|
||||
def remove_tts_from_history(history):
|
||||
for i, entry in enumerate(history['internal']):
|
||||
history['visible'][i] = [history['visible'][i][0], entry[1]]
|
||||
|
||||
return history
|
||||
|
||||
|
||||
def toggle_text_in_history():
|
||||
for i, entry in enumerate(shared.history['visible']):
|
||||
def toggle_text_in_history(history):
|
||||
for i, entry in enumerate(history['visible']):
|
||||
visible_reply = entry[1]
|
||||
if visible_reply.startswith('<audio'):
|
||||
if params['show_text']:
|
||||
reply = shared.history['internal'][i][1]
|
||||
shared.history['visible'][i] = [
|
||||
shared.history['visible'][i][0], f"{visible_reply.split('</audio>')[0]}</audio>\n\n{reply}"
|
||||
]
|
||||
reply = history['internal'][i][1]
|
||||
history['visible'][i] = [history['visible'][i][0], f"{visible_reply.split('</audio>')[0]}</audio>\n\n{reply}"]
|
||||
else:
|
||||
shared.history['visible'][i] = [
|
||||
shared.history['visible'][i][0], f"{visible_reply.split('</audio>')[0]}</audio>"
|
||||
]
|
||||
history['visible'][i] = [history['visible'][i][0], f"{visible_reply.split('</audio>')[0]}</audio>"]
|
||||
|
||||
return history
|
||||
|
||||
|
||||
def remove_surrounded_chars(string):
|
||||
|
@ -150,25 +152,24 @@ def ui():
|
|||
convert_cancel = gr.Button('Cancel', visible=False)
|
||||
convert_confirm = gr.Button('Confirm (cannot be undone)', variant="stop", visible=False)
|
||||
|
||||
# Convert history with confirmation
|
||||
convert_arr = [convert_confirm, convert, convert_cancel]
|
||||
convert.click(lambda: [gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, convert_arr)
|
||||
convert_confirm.click(
|
||||
lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, convert_arr).then(
|
||||
remove_tts_from_history, None, None).then(
|
||||
chat.save_history, shared.gradio['mode'], None, show_progress=False).then(
|
||||
chat.redraw_html, shared.reload_inputs, shared.gradio['display'])
|
||||
if shared.is_chat():
|
||||
# Convert history with confirmation
|
||||
convert_arr = [convert_confirm, convert, convert_cancel]
|
||||
convert.click(lambda: [gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, convert_arr)
|
||||
convert_confirm.click(
|
||||
lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, convert_arr).then(
|
||||
remove_tts_from_history, gradio('history'), gradio('history')).then(
|
||||
chat.save_persistent_history, gradio('history', 'character_menu', 'mode'), None).then(
|
||||
chat.redraw_html, shared.reload_inputs, gradio('display'))
|
||||
|
||||
convert_cancel.click(lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, convert_arr)
|
||||
convert_cancel.click(lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, convert_arr)
|
||||
|
||||
# Toggle message text in history
|
||||
show_text.change(
|
||||
lambda x: params.update({"show_text": x}), show_text, None).then(
|
||||
toggle_text_in_history, None, None).then(
|
||||
chat.save_history, shared.gradio['mode'], None, show_progress=False).then(
|
||||
chat.redraw_html, shared.reload_inputs, shared.gradio['display'])
|
||||
|
||||
convert_cancel.click(lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, convert_arr)
|
||||
# Toggle message text in history
|
||||
show_text.change(
|
||||
lambda x: params.update({"show_text": x}), show_text, None).then(
|
||||
toggle_text_in_history, gradio('history'), gradio('history')).then(
|
||||
chat.save_persistent_history, gradio('history', 'character_menu', 'mode'), None).then(
|
||||
chat.redraw_html, shared.reload_inputs, gradio('display'))
|
||||
|
||||
# Event functions to update the parameters in the backend
|
||||
activate.change(lambda x: params.update({'activate': x}), activate, None)
|
||||
|
|
|
@ -10,7 +10,6 @@ import requests
|
|||
import torch
|
||||
from PIL import Image
|
||||
|
||||
import modules.shared as shared
|
||||
from modules.models import reload_model, unload_model
|
||||
from modules.ui import create_refresh_button
|
||||
|
||||
|
@ -126,7 +125,7 @@ def input_modifier(string):
|
|||
return string
|
||||
|
||||
# Get and save the Stable Diffusion-generated picture
|
||||
def get_SD_pictures(description):
|
||||
def get_SD_pictures(description, character):
|
||||
|
||||
global params
|
||||
|
||||
|
@ -160,7 +159,7 @@ def get_SD_pictures(description):
|
|||
if params['save_img']:
|
||||
img_data = base64.b64decode(img_str)
|
||||
|
||||
variadic = f'{date.today().strftime("%Y_%m_%d")}/{shared.character}_{int(time.time())}'
|
||||
variadic = f'{date.today().strftime("%Y_%m_%d")}/{character}_{int(time.time())}'
|
||||
output_file = Path(f'extensions/sd_api_pictures/outputs/{variadic}.png')
|
||||
output_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
@ -186,7 +185,7 @@ def get_SD_pictures(description):
|
|||
|
||||
# TODO: how do I make the UI history ignore the resulting pictures (I don't want HTML to appear in history)
|
||||
# and replace it with 'text' for the purposes of logging?
|
||||
def output_modifier(string):
|
||||
def output_modifier(string, state):
|
||||
"""
|
||||
This function is applied to the model outputs.
|
||||
"""
|
||||
|
@ -213,7 +212,7 @@ def output_modifier(string):
|
|||
else:
|
||||
text = string
|
||||
|
||||
string = get_SD_pictures(string) + "\n" + text
|
||||
string = get_SD_pictures(string, state['character_menu']) + "\n" + text
|
||||
|
||||
return string
|
||||
|
||||
|
|
|
@ -3,9 +3,10 @@ from pathlib import Path
|
|||
|
||||
import gradio as gr
|
||||
import torch
|
||||
from modules import chat, shared
|
||||
|
||||
from extensions.silero_tts import tts_preprocessor
|
||||
from modules import chat, shared
|
||||
from modules.utils import gradio
|
||||
|
||||
torch._C._jit_set_profiling_mode(False)
|
||||
|
||||
|
@ -56,20 +57,24 @@ def load_model():
|
|||
return model
|
||||
|
||||
|
||||
def remove_tts_from_history():
|
||||
for i, entry in enumerate(shared.history['internal']):
|
||||
shared.history['visible'][i] = [shared.history['visible'][i][0], entry[1]]
|
||||
def remove_tts_from_history(history):
|
||||
for i, entry in enumerate(history['internal']):
|
||||
history['visible'][i] = [history['visible'][i][0], entry[1]]
|
||||
|
||||
return history
|
||||
|
||||
|
||||
def toggle_text_in_history():
|
||||
for i, entry in enumerate(shared.history['visible']):
|
||||
def toggle_text_in_history(history):
|
||||
for i, entry in enumerate(history['visible']):
|
||||
visible_reply = entry[1]
|
||||
if visible_reply.startswith('<audio'):
|
||||
if params['show_text']:
|
||||
reply = shared.history['internal'][i][1]
|
||||
shared.history['visible'][i] = [shared.history['visible'][i][0], f"{visible_reply.split('</audio>')[0]}</audio>\n\n{reply}"]
|
||||
reply = history['internal'][i][1]
|
||||
history['visible'][i] = [history['visible'][i][0], f"{visible_reply.split('</audio>')[0]}</audio>\n\n{reply}"]
|
||||
else:
|
||||
shared.history['visible'][i] = [shared.history['visible'][i][0], f"{visible_reply.split('</audio>')[0]}</audio>"]
|
||||
history['visible'][i] = [history['visible'][i][0], f"{visible_reply.split('</audio>')[0]}</audio>"]
|
||||
|
||||
return history
|
||||
|
||||
|
||||
def state_modifier(state):
|
||||
|
@ -80,7 +85,7 @@ def state_modifier(state):
|
|||
return state
|
||||
|
||||
|
||||
def input_modifier(string):
|
||||
def input_modifier(string, state):
|
||||
if not params['activate']:
|
||||
return string
|
||||
|
||||
|
@ -99,7 +104,7 @@ def history_modifier(history):
|
|||
return history
|
||||
|
||||
|
||||
def output_modifier(string):
|
||||
def output_modifier(string, state):
|
||||
global model, current_params, streaming_state
|
||||
for i in params:
|
||||
if params[i] != current_params[i]:
|
||||
|
@ -116,7 +121,7 @@ def output_modifier(string):
|
|||
if string == '':
|
||||
string = '*Empty reply, try regenerating*'
|
||||
else:
|
||||
output_file = Path(f'extensions/silero_tts/outputs/{shared.character}_{int(time.time())}.wav')
|
||||
output_file = Path(f'extensions/silero_tts/outputs/{state["character_menu"]}_{int(time.time())}.wav')
|
||||
prosody = '<prosody rate="{}" pitch="{}">'.format(params['voice_speed'], params['voice_pitch'])
|
||||
silero_input = f'<speak>{prosody}{xmlesc(string)}</prosody></speak>'
|
||||
model.save_wav(ssml_text=silero_input, speaker=params['speaker'], sample_rate=int(params['sample_rate']), audio_path=str(output_file))
|
||||
|
@ -155,23 +160,24 @@ def ui():
|
|||
|
||||
gr.Markdown('[Click here for Silero audio samples](https://oobabooga.github.io/silero-samples/index.html)')
|
||||
|
||||
# Convert history with confirmation
|
||||
convert_arr = [convert_confirm, convert, convert_cancel]
|
||||
convert.click(lambda: [gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, convert_arr)
|
||||
convert_confirm.click(
|
||||
lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, convert_arr).then(
|
||||
remove_tts_from_history, None, None).then(
|
||||
chat.save_history, shared.gradio['mode'], None, show_progress=False).then(
|
||||
chat.redraw_html, shared.reload_inputs, shared.gradio['display'])
|
||||
if shared.is_chat():
|
||||
# Convert history with confirmation
|
||||
convert_arr = [convert_confirm, convert, convert_cancel]
|
||||
convert.click(lambda: [gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, convert_arr)
|
||||
convert_confirm.click(
|
||||
lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, convert_arr).then(
|
||||
remove_tts_from_history, gradio('history'), gradio('history')).then(
|
||||
chat.save_persistent_history, gradio('history', 'character_menu', 'mode'), None).then(
|
||||
chat.redraw_html, shared.reload_inputs, gradio('display'))
|
||||
|
||||
convert_cancel.click(lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, convert_arr)
|
||||
convert_cancel.click(lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, convert_arr)
|
||||
|
||||
# Toggle message text in history
|
||||
show_text.change(
|
||||
lambda x: params.update({"show_text": x}), show_text, None).then(
|
||||
toggle_text_in_history, None, None).then(
|
||||
chat.save_history, shared.gradio['mode'], None, show_progress=False).then(
|
||||
chat.redraw_html, shared.reload_inputs, shared.gradio['display'])
|
||||
# Toggle message text in history
|
||||
show_text.change(
|
||||
lambda x: params.update({"show_text": x}), show_text, None).then(
|
||||
toggle_text_in_history, gradio('history'), gradio('history')).then(
|
||||
chat.save_persistent_history, gradio('history', 'character_menu', 'mode'), None).then(
|
||||
chat.redraw_html, shared.reload_inputs, gradio('display'))
|
||||
|
||||
# Event functions to update the parameters in the backend
|
||||
activate.change(lambda x: params.update({"activate": x}), activate, None)
|
||||
|
|
|
@ -96,6 +96,8 @@ def apply_settings(chunk_count, chunk_count_initial, time_weight):
|
|||
def custom_generate_chat_prompt(user_input, state, **kwargs):
|
||||
global chat_collector
|
||||
|
||||
history = state['history']
|
||||
|
||||
if state['mode'] == 'instruct':
|
||||
results = collector.get_sorted(user_input, n_results=params['chunk_count'])
|
||||
additional_context = '\nYour reply should be based on the context below:\n\n' + '\n'.join(results)
|
||||
|
@ -104,29 +106,29 @@ def custom_generate_chat_prompt(user_input, state, **kwargs):
|
|||
|
||||
def make_single_exchange(id_):
|
||||
output = ''
|
||||
output += f"{state['name1']}: {shared.history['internal'][id_][0]}\n"
|
||||
output += f"{state['name2']}: {shared.history['internal'][id_][1]}\n"
|
||||
output += f"{state['name1']}: {history['internal'][id_][0]}\n"
|
||||
output += f"{state['name2']}: {history['internal'][id_][1]}\n"
|
||||
return output
|
||||
|
||||
if len(shared.history['internal']) > params['chunk_count'] and user_input != '':
|
||||
if len(history['internal']) > params['chunk_count'] and user_input != '':
|
||||
chunks = []
|
||||
hist_size = len(shared.history['internal'])
|
||||
hist_size = len(history['internal'])
|
||||
for i in range(hist_size-1):
|
||||
chunks.append(make_single_exchange(i))
|
||||
|
||||
add_chunks_to_collector(chunks, chat_collector)
|
||||
query = '\n'.join(shared.history['internal'][-1] + [user_input])
|
||||
query = '\n'.join(history['internal'][-1] + [user_input])
|
||||
try:
|
||||
best_ids = chat_collector.get_ids_sorted(query, n_results=params['chunk_count'], n_initial=params['chunk_count_initial'], time_weight=params['time_weight'])
|
||||
additional_context = '\n'
|
||||
for id_ in best_ids:
|
||||
if shared.history['internal'][id_][0] != '<|BEGIN-VISIBLE-CHAT|>':
|
||||
if history['internal'][id_][0] != '<|BEGIN-VISIBLE-CHAT|>':
|
||||
additional_context += make_single_exchange(id_)
|
||||
|
||||
logger.warning(f'Adding the following new context:\n{additional_context}')
|
||||
state['context'] = state['context'].strip() + '\n' + additional_context
|
||||
kwargs['history'] = {
|
||||
'internal': [shared.history['internal'][i] for i in range(hist_size) if i not in best_ids],
|
||||
'internal': [history['internal'][i] for i in range(hist_size) if i not in best_ids],
|
||||
'visible': ''
|
||||
}
|
||||
except RuntimeError:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue