Implement sessions + add basic multi-user support (#2991)

This commit is contained in:
oobabooga 2023-07-04 00:03:30 -03:00 committed by GitHub
parent 1f8cae14f9
commit 4b1804a438
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
17 changed files with 595 additions and 414 deletions

View file

@ -3,7 +3,9 @@ from pathlib import Path
import elevenlabs
import gradio as gr
from modules import chat, shared
from modules.utils import gradio
params = {
'activate': True,
@ -35,24 +37,24 @@ def refresh_voices_dd():
return gr.Dropdown.update(value=all_voices[0], choices=all_voices)
def remove_tts_from_history():
for i, entry in enumerate(shared.history['internal']):
shared.history['visible'][i] = [shared.history['visible'][i][0], entry[1]]
def remove_tts_from_history(history):
for i, entry in enumerate(history['internal']):
history['visible'][i] = [history['visible'][i][0], entry[1]]
return history
def toggle_text_in_history():
for i, entry in enumerate(shared.history['visible']):
def toggle_text_in_history(history):
for i, entry in enumerate(history['visible']):
visible_reply = entry[1]
if visible_reply.startswith('<audio'):
if params['show_text']:
reply = shared.history['internal'][i][1]
shared.history['visible'][i] = [
shared.history['visible'][i][0], f"{visible_reply.split('</audio>')[0]}</audio>\n\n{reply}"
]
reply = history['internal'][i][1]
history['visible'][i] = [history['visible'][i][0], f"{visible_reply.split('</audio>')[0]}</audio>\n\n{reply}"]
else:
shared.history['visible'][i] = [
shared.history['visible'][i][0], f"{visible_reply.split('</audio>')[0]}</audio>"
]
history['visible'][i] = [history['visible'][i][0], f"{visible_reply.split('</audio>')[0]}</audio>"]
return history
def remove_surrounded_chars(string):
@ -150,25 +152,24 @@ def ui():
convert_cancel = gr.Button('Cancel', visible=False)
convert_confirm = gr.Button('Confirm (cannot be undone)', variant="stop", visible=False)
# Convert history with confirmation
convert_arr = [convert_confirm, convert, convert_cancel]
convert.click(lambda: [gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, convert_arr)
convert_confirm.click(
lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, convert_arr).then(
remove_tts_from_history, None, None).then(
chat.save_history, shared.gradio['mode'], None, show_progress=False).then(
chat.redraw_html, shared.reload_inputs, shared.gradio['display'])
if shared.is_chat():
# Convert history with confirmation
convert_arr = [convert_confirm, convert, convert_cancel]
convert.click(lambda: [gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, convert_arr)
convert_confirm.click(
lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, convert_arr).then(
remove_tts_from_history, gradio('history'), gradio('history')).then(
chat.save_persistent_history, gradio('history', 'character_menu', 'mode'), None).then(
chat.redraw_html, shared.reload_inputs, gradio('display'))
convert_cancel.click(lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, convert_arr)
convert_cancel.click(lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, convert_arr)
# Toggle message text in history
show_text.change(
lambda x: params.update({"show_text": x}), show_text, None).then(
toggle_text_in_history, None, None).then(
chat.save_history, shared.gradio['mode'], None, show_progress=False).then(
chat.redraw_html, shared.reload_inputs, shared.gradio['display'])
convert_cancel.click(lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, convert_arr)
# Toggle message text in history
show_text.change(
lambda x: params.update({"show_text": x}), show_text, None).then(
toggle_text_in_history, gradio('history'), gradio('history')).then(
chat.save_persistent_history, gradio('history', 'character_menu', 'mode'), None).then(
chat.redraw_html, shared.reload_inputs, gradio('display'))
# Event functions to update the parameters in the backend
activate.change(lambda x: params.update({'activate': x}), activate, None)

View file

@ -10,7 +10,6 @@ import requests
import torch
from PIL import Image
import modules.shared as shared
from modules.models import reload_model, unload_model
from modules.ui import create_refresh_button
@ -126,7 +125,7 @@ def input_modifier(string):
return string
# Get and save the Stable Diffusion-generated picture
def get_SD_pictures(description):
def get_SD_pictures(description, character):
global params
@ -160,7 +159,7 @@ def get_SD_pictures(description):
if params['save_img']:
img_data = base64.b64decode(img_str)
variadic = f'{date.today().strftime("%Y_%m_%d")}/{shared.character}_{int(time.time())}'
variadic = f'{date.today().strftime("%Y_%m_%d")}/{character}_{int(time.time())}'
output_file = Path(f'extensions/sd_api_pictures/outputs/{variadic}.png')
output_file.parent.mkdir(parents=True, exist_ok=True)
@ -186,7 +185,7 @@ def get_SD_pictures(description):
# TODO: how do I make the UI history ignore the resulting pictures (I don't want HTML to appear in history)
# and replace it with 'text' for the purposes of logging?
def output_modifier(string):
def output_modifier(string, state):
"""
This function is applied to the model outputs.
"""
@ -213,7 +212,7 @@ def output_modifier(string):
else:
text = string
string = get_SD_pictures(string) + "\n" + text
string = get_SD_pictures(string, state['character_menu']) + "\n" + text
return string

View file

@ -3,9 +3,10 @@ from pathlib import Path
import gradio as gr
import torch
from modules import chat, shared
from extensions.silero_tts import tts_preprocessor
from modules import chat, shared
from modules.utils import gradio
torch._C._jit_set_profiling_mode(False)
@ -56,20 +57,24 @@ def load_model():
return model
def remove_tts_from_history():
for i, entry in enumerate(shared.history['internal']):
shared.history['visible'][i] = [shared.history['visible'][i][0], entry[1]]
def remove_tts_from_history(history):
for i, entry in enumerate(history['internal']):
history['visible'][i] = [history['visible'][i][0], entry[1]]
return history
def toggle_text_in_history():
for i, entry in enumerate(shared.history['visible']):
def toggle_text_in_history(history):
for i, entry in enumerate(history['visible']):
visible_reply = entry[1]
if visible_reply.startswith('<audio'):
if params['show_text']:
reply = shared.history['internal'][i][1]
shared.history['visible'][i] = [shared.history['visible'][i][0], f"{visible_reply.split('</audio>')[0]}</audio>\n\n{reply}"]
reply = history['internal'][i][1]
history['visible'][i] = [history['visible'][i][0], f"{visible_reply.split('</audio>')[0]}</audio>\n\n{reply}"]
else:
shared.history['visible'][i] = [shared.history['visible'][i][0], f"{visible_reply.split('</audio>')[0]}</audio>"]
history['visible'][i] = [history['visible'][i][0], f"{visible_reply.split('</audio>')[0]}</audio>"]
return history
def state_modifier(state):
@ -80,7 +85,7 @@ def state_modifier(state):
return state
def input_modifier(string):
def input_modifier(string, state):
if not params['activate']:
return string
@ -99,7 +104,7 @@ def history_modifier(history):
return history
def output_modifier(string):
def output_modifier(string, state):
global model, current_params, streaming_state
for i in params:
if params[i] != current_params[i]:
@ -116,7 +121,7 @@ def output_modifier(string):
if string == '':
string = '*Empty reply, try regenerating*'
else:
output_file = Path(f'extensions/silero_tts/outputs/{shared.character}_{int(time.time())}.wav')
output_file = Path(f'extensions/silero_tts/outputs/{state["character_menu"]}_{int(time.time())}.wav')
prosody = '<prosody rate="{}" pitch="{}">'.format(params['voice_speed'], params['voice_pitch'])
silero_input = f'<speak>{prosody}{xmlesc(string)}</prosody></speak>'
model.save_wav(ssml_text=silero_input, speaker=params['speaker'], sample_rate=int(params['sample_rate']), audio_path=str(output_file))
@ -155,23 +160,24 @@ def ui():
gr.Markdown('[Click here for Silero audio samples](https://oobabooga.github.io/silero-samples/index.html)')
# Convert history with confirmation
convert_arr = [convert_confirm, convert, convert_cancel]
convert.click(lambda: [gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, convert_arr)
convert_confirm.click(
lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, convert_arr).then(
remove_tts_from_history, None, None).then(
chat.save_history, shared.gradio['mode'], None, show_progress=False).then(
chat.redraw_html, shared.reload_inputs, shared.gradio['display'])
if shared.is_chat():
# Convert history with confirmation
convert_arr = [convert_confirm, convert, convert_cancel]
convert.click(lambda: [gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, convert_arr)
convert_confirm.click(
lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, convert_arr).then(
remove_tts_from_history, gradio('history'), gradio('history')).then(
chat.save_persistent_history, gradio('history', 'character_menu', 'mode'), None).then(
chat.redraw_html, shared.reload_inputs, gradio('display'))
convert_cancel.click(lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, convert_arr)
convert_cancel.click(lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, convert_arr)
# Toggle message text in history
show_text.change(
lambda x: params.update({"show_text": x}), show_text, None).then(
toggle_text_in_history, None, None).then(
chat.save_history, shared.gradio['mode'], None, show_progress=False).then(
chat.redraw_html, shared.reload_inputs, shared.gradio['display'])
# Toggle message text in history
show_text.change(
lambda x: params.update({"show_text": x}), show_text, None).then(
toggle_text_in_history, gradio('history'), gradio('history')).then(
chat.save_persistent_history, gradio('history', 'character_menu', 'mode'), None).then(
chat.redraw_html, shared.reload_inputs, gradio('display'))
# Event functions to update the parameters in the backend
activate.change(lambda x: params.update({"activate": x}), activate, None)

View file

@ -96,6 +96,8 @@ def apply_settings(chunk_count, chunk_count_initial, time_weight):
def custom_generate_chat_prompt(user_input, state, **kwargs):
global chat_collector
history = state['history']
if state['mode'] == 'instruct':
results = collector.get_sorted(user_input, n_results=params['chunk_count'])
additional_context = '\nYour reply should be based on the context below:\n\n' + '\n'.join(results)
@ -104,29 +106,29 @@ def custom_generate_chat_prompt(user_input, state, **kwargs):
def make_single_exchange(id_):
output = ''
output += f"{state['name1']}: {shared.history['internal'][id_][0]}\n"
output += f"{state['name2']}: {shared.history['internal'][id_][1]}\n"
output += f"{state['name1']}: {history['internal'][id_][0]}\n"
output += f"{state['name2']}: {history['internal'][id_][1]}\n"
return output
if len(shared.history['internal']) > params['chunk_count'] and user_input != '':
if len(history['internal']) > params['chunk_count'] and user_input != '':
chunks = []
hist_size = len(shared.history['internal'])
hist_size = len(history['internal'])
for i in range(hist_size-1):
chunks.append(make_single_exchange(i))
add_chunks_to_collector(chunks, chat_collector)
query = '\n'.join(shared.history['internal'][-1] + [user_input])
query = '\n'.join(history['internal'][-1] + [user_input])
try:
best_ids = chat_collector.get_ids_sorted(query, n_results=params['chunk_count'], n_initial=params['chunk_count_initial'], time_weight=params['time_weight'])
additional_context = '\n'
for id_ in best_ids:
if shared.history['internal'][id_][0] != '<|BEGIN-VISIBLE-CHAT|>':
if history['internal'][id_][0] != '<|BEGIN-VISIBLE-CHAT|>':
additional_context += make_single_exchange(id_)
logger.warning(f'Adding the following new context:\n{additional_context}')
state['context'] = state['context'].strip() + '\n' + additional_context
kwargs['history'] = {
'internal': [shared.history['internal'][i] for i in range(hist_size) if i not in best_ids],
'internal': [history['internal'][i] for i in range(hist_size) if i not in best_ids],
'visible': ''
}
except RuntimeError: