initial multi-lora support (#1103)

---------

Co-authored-by: oobabooga <112222186+oobabooga@users.noreply.github.com>
This commit is contained in:
Alex "mcmonkey" Goodwin 2023-04-14 10:52:06 -07:00 committed by GitHub
parent ebb81eb176
commit 64e3b44e0f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 43 additions and 24 deletions

View file

@ -88,9 +88,10 @@ def load_model_wrapper(selected_model):
yield traceback.format_exc()
def load_lora_wrapper(selected_lora):
add_lora_to_model(selected_lora)
return selected_lora
def load_lora_wrapper(selected_loras):
yield ("Applying the following LoRAs to {}:\n\n{}".format(shared.model_name, '\n'.join(selected_loras)))
add_lora_to_model(selected_loras)
yield ("Successfuly applied the LoRAs")
def load_preset_values(preset_menu, state, return_dict=False):
@ -275,12 +276,14 @@ def create_model_menus():
with gr.Column():
with gr.Row():
shared.gradio['lora_menu'] = gr.Dropdown(choices=get_available_loras(), value=shared.lora_name, label='LoRA')
ui.create_refresh_button(shared.gradio['lora_menu'], lambda: None, lambda: {'choices': get_available_loras()}, 'refresh-button')
shared.gradio['lora_menu'] = gr.Dropdown(multiselect=True, choices=get_available_loras(), value=shared.lora_names, label='LoRA(s)')
ui.create_refresh_button(shared.gradio['lora_menu'], lambda: None, lambda: {'choices': get_available_loras(), 'value': shared.lora_names}, 'refresh-button')
with gr.Column():
unload = gr.Button("Unload the model")
reload = gr.Button("Reload the model")
shared.gradio['lora_menu_apply'] = gr.Button(value='Apply the selected LoRAs')
with gr.Row():
unload = gr.Button("Unload the model")
reload = gr.Button("Reload the model")
with gr.Row():
with gr.Column():
@ -338,7 +341,7 @@ def create_model_menus():
update_model_parameters, shared.gradio['interface_state'], None).then(
load_model_wrapper, shared.gradio['model_menu'], shared.gradio['model_status'], show_progress=True)
shared.gradio['lora_menu'].change(load_lora_wrapper, shared.gradio['lora_menu'], shared.gradio['lora_menu'], show_progress=True)
shared.gradio['lora_menu_apply'].click(load_lora_wrapper, shared.gradio['lora_menu'], shared.gradio['model_status'], show_progress=False)
shared.gradio['download_button'].click(download_model_wrapper, shared.gradio['custom_model_menu'], shared.gradio['model_status'], show_progress=False)
@ -428,8 +431,8 @@ def create_interface():
# Defining some variables
gen_events = []
default_preset = shared.settings['presets'][next((k for k in shared.settings['presets'] if re.match(k.lower(), shared.model_name.lower())), 'default')]
if shared.lora_name != "None":
default_text = load_prompt(shared.settings['lora_prompts'][next((k for k in shared.settings['lora_prompts'] if re.match(k.lower(), shared.lora_name.lower())), 'default')])
if len(shared.lora_names) == 1:
default_text = load_prompt(shared.settings['lora_prompts'][next((k for k in shared.settings['lora_prompts'] if re.match(k.lower(), shared.lora_names[0].lower())), 'default')])
else:
default_text = load_prompt(shared.settings['prompts'][next((k for k in shared.settings['prompts'] if re.match(k.lower(), shared.model_name.lower())), 'default')])
title = 'Text generation web UI'
@ -861,7 +864,7 @@ if __name__ == "__main__":
# Load the model
shared.model, shared.tokenizer = load_model(shared.model_name)
if shared.args.lora:
add_lora_to_model(shared.args.lora)
add_lora_to_model([shared.args.lora])
# Launch the web UI
create_interface()