Unify the 3 interface modes (#3554)

This commit is contained in:
oobabooga 2023-08-13 01:12:15 -03:00 committed by GitHub
parent bf70c19603
commit a1a9ec895d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
29 changed files with 660 additions and 714 deletions

View file

@ -69,28 +69,28 @@ def create_interface():
# Force some events to be triggered on page load
shared.persistent_interface_state.update({
'loader': shared.args.loader or 'Transformers',
'mode': shared.settings['mode'],
'character_menu': shared.args.character or shared.settings['character'],
'instruction_template': shared.settings['instruction_template']
})
if shared.is_chat():
shared.persistent_interface_state.update({
'mode': shared.settings['mode'],
'character_menu': shared.args.character or shared.settings['character'],
'instruction_template': shared.settings['instruction_template']
})
if Path("cache/pfp_character.png").exists():
Path("cache/pfp_character.png").unlink()
if Path("cache/pfp_character.png").exists():
Path("cache/pfp_character.png").unlink()
# css/js strings
css = ui.css if not shared.is_chat() else ui.css + ui.chat_css
js = ui.main_js
css = ui.css
js = ui.js
css += apply_extensions('css')
js += apply_extensions('js')
# The input elements for the generation functions
# Interface state elements
shared.input_elements = ui.list_interface_input_elements()
with gr.Blocks(css=css, analytics_enabled=False, title=title, theme=ui.theme) as shared.gradio['interface']:
# Interface state
shared.gradio['interface_state'] = gr.State({k: None for k in shared.input_elements})
# Audio notification
if Path("notification.mp3").exists():
shared.gradio['audio_notification'] = gr.Audio(interactive=False, value="notification.mp3", elem_id="audio_notification", visible=False)
@ -102,12 +102,9 @@ def create_interface():
shared.gradio['temporary_text'] = gr.Textbox(visible=False)
# Text Generation tab
if shared.is_chat():
ui_chat.create_ui()
elif shared.args.notebook:
ui_notebook.create_ui()
else:
ui_default.create_ui()
ui_chat.create_ui()
ui_default.create_ui()
ui_notebook.create_ui()
ui_parameters.create_ui(shared.settings['preset']) # Parameters tab
ui_model_menu.create_ui() # Model tab
@ -115,12 +112,9 @@ def create_interface():
ui_session.create_ui() # Session tab
# Generation events
if shared.is_chat():
ui_chat.create_event_handlers()
elif shared.args.notebook:
ui_notebook.create_event_handlers()
else:
ui_default.create_event_handlers()
ui_chat.create_event_handlers()
ui_default.create_event_handlers()
ui_notebook.create_event_handlers()
# Other events
ui_file_saving.create_event_handlers()
@ -130,11 +124,10 @@ def create_interface():
# Interface launch events
if shared.settings['dark_theme']:
shared.gradio['interface'].load(lambda: None, None, None, _js="() => document.getElementsByTagName('body')[0].classList.add('dark')")
shared.gradio['interface'].load(lambda: None, None, None, _js=f"() => {{{js}}}")
shared.gradio['interface'].load(partial(ui.apply_interface_values, {}, use_persistent=True), None, gradio(ui.list_interface_input_elements()), show_progress=False)
if shared.is_chat():
shared.gradio['interface'].load(chat.redraw_html, shared.reload_inputs, gradio('display'))
shared.gradio['interface'].load(chat.redraw_html, gradio(ui_chat.reload_arr), gradio('display'))
extensions_module.create_extensions_tabs() # Extensions tabs
extensions_module.create_extensions_block() # Extensions block
@ -190,16 +183,10 @@ if __name__ == "__main__":
# Activate the extensions listed on settings.yaml
extensions_module.available_extensions = utils.get_available_extensions()
if shared.is_chat():
for extension in shared.settings['chat_default_extensions']:
shared.args.extensions = shared.args.extensions or []
if extension not in shared.args.extensions:
shared.args.extensions.append(extension)
else:
for extension in shared.settings['default_extensions']:
shared.args.extensions = shared.args.extensions or []
if extension not in shared.args.extensions:
shared.args.extensions.append(extension)
for extension in shared.settings['default_extensions']:
shared.args.extensions = shared.args.extensions or []
if extension not in shared.args.extensions:
shared.args.extensions.append(extension)
available_models = utils.get_available_models()