initial multi-lora support (#1103)
--------- Co-authored-by: oobabooga <112222186+oobabooga@users.noreply.github.com>
This commit is contained in:
parent
ebb81eb176
commit
64e3b44e0f
4 changed files with 43 additions and 24 deletions
25
server.py
25
server.py
|
@ -88,9 +88,10 @@ def load_model_wrapper(selected_model):
|
|||
yield traceback.format_exc()
|
||||
|
||||
|
||||
def load_lora_wrapper(selected_lora):
|
||||
add_lora_to_model(selected_lora)
|
||||
return selected_lora
|
||||
def load_lora_wrapper(selected_loras):
|
||||
yield ("Applying the following LoRAs to {}:\n\n{}".format(shared.model_name, '\n'.join(selected_loras)))
|
||||
add_lora_to_model(selected_loras)
|
||||
yield ("Successfuly applied the LoRAs")
|
||||
|
||||
|
||||
def load_preset_values(preset_menu, state, return_dict=False):
|
||||
|
@ -275,12 +276,14 @@ def create_model_menus():
|
|||
|
||||
with gr.Column():
|
||||
with gr.Row():
|
||||
shared.gradio['lora_menu'] = gr.Dropdown(choices=get_available_loras(), value=shared.lora_name, label='LoRA')
|
||||
ui.create_refresh_button(shared.gradio['lora_menu'], lambda: None, lambda: {'choices': get_available_loras()}, 'refresh-button')
|
||||
shared.gradio['lora_menu'] = gr.Dropdown(multiselect=True, choices=get_available_loras(), value=shared.lora_names, label='LoRA(s)')
|
||||
ui.create_refresh_button(shared.gradio['lora_menu'], lambda: None, lambda: {'choices': get_available_loras(), 'value': shared.lora_names}, 'refresh-button')
|
||||
|
||||
with gr.Column():
|
||||
unload = gr.Button("Unload the model")
|
||||
reload = gr.Button("Reload the model")
|
||||
shared.gradio['lora_menu_apply'] = gr.Button(value='Apply the selected LoRAs')
|
||||
with gr.Row():
|
||||
unload = gr.Button("Unload the model")
|
||||
reload = gr.Button("Reload the model")
|
||||
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
|
@ -338,7 +341,7 @@ def create_model_menus():
|
|||
update_model_parameters, shared.gradio['interface_state'], None).then(
|
||||
load_model_wrapper, shared.gradio['model_menu'], shared.gradio['model_status'], show_progress=True)
|
||||
|
||||
shared.gradio['lora_menu'].change(load_lora_wrapper, shared.gradio['lora_menu'], shared.gradio['lora_menu'], show_progress=True)
|
||||
shared.gradio['lora_menu_apply'].click(load_lora_wrapper, shared.gradio['lora_menu'], shared.gradio['model_status'], show_progress=False)
|
||||
shared.gradio['download_button'].click(download_model_wrapper, shared.gradio['custom_model_menu'], shared.gradio['model_status'], show_progress=False)
|
||||
|
||||
|
||||
|
@ -428,8 +431,8 @@ def create_interface():
|
|||
# Defining some variables
|
||||
gen_events = []
|
||||
default_preset = shared.settings['presets'][next((k for k in shared.settings['presets'] if re.match(k.lower(), shared.model_name.lower())), 'default')]
|
||||
if shared.lora_name != "None":
|
||||
default_text = load_prompt(shared.settings['lora_prompts'][next((k for k in shared.settings['lora_prompts'] if re.match(k.lower(), shared.lora_name.lower())), 'default')])
|
||||
if len(shared.lora_names) == 1:
|
||||
default_text = load_prompt(shared.settings['lora_prompts'][next((k for k in shared.settings['lora_prompts'] if re.match(k.lower(), shared.lora_names[0].lower())), 'default')])
|
||||
else:
|
||||
default_text = load_prompt(shared.settings['prompts'][next((k for k in shared.settings['prompts'] if re.match(k.lower(), shared.model_name.lower())), 'default')])
|
||||
title = 'Text generation web UI'
|
||||
|
@ -861,7 +864,7 @@ if __name__ == "__main__":
|
|||
# Load the model
|
||||
shared.model, shared.tokenizer = load_model(shared.model_name)
|
||||
if shared.args.lora:
|
||||
add_lora_to_model(shared.args.lora)
|
||||
add_lora_to_model([shared.args.lora])
|
||||
|
||||
# Launch the web UI
|
||||
create_interface()
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue