Add "Save current settings for this model" button

This commit is contained in:
oobabooga 2023-04-15 12:54:02 -03:00
parent b9dcba7762
commit ac189011cb
3 changed files with 40 additions and 12 deletions

View file

@ -21,6 +21,7 @@ from pathlib import Path
import gradio as gr
import psutil
import torch
import yaml
from PIL import Image
import modules.extensions as extensions_module
@ -233,7 +234,7 @@ def get_model_specific_settings(model):
model_settings = {}
for pat in settings:
if re.match(pat, model.lower()):
if re.match(pat.lower(), model.lower()):
for k in settings[pat]:
model_settings[k] = settings[pat][k]
@ -249,6 +250,29 @@ def load_model_specific_settings(model, state, return_dict=False):
return state
def save_model_settings(model, state):
if model == 'None':
yield ("Not saving the settings because no model is loaded.")
return
with Path(f'{shared.args.model_dir}/config-user.yaml') as p:
if p.exists():
user_config = yaml.safe_load(open(p, 'r').read())
else:
user_config = {}
if model not in user_config:
user_config[model] = {}
for k in ui.list_model_elements():
user_config[model][k] = state[k]
with open(p, 'w') as f:
f.write(yaml.dump(user_config))
yield (f"Settings for {model} saved to {p}")
def create_model_menus():
# Finding the default values for the GPU and CPU memories
total_mem = []
@ -285,10 +309,12 @@ def create_model_menus():
ui.create_refresh_button(shared.gradio['lora_menu'], lambda: None, lambda: {'choices': get_available_loras(), 'value': shared.lora_names}, 'refresh-button')
with gr.Column():
shared.gradio['lora_menu_apply'] = gr.Button(value='Apply the selected LoRAs')
with gr.Row():
shared.gradio['lora_menu_apply'] = gr.Button(value='Apply the selected LoRAs')
with gr.Row():
unload = gr.Button("Unload the model")
reload = gr.Button("Reload the model")
save_settings = gr.Button("Save current settings for this model")
with gr.Row():
with gr.Column():
@ -344,7 +370,11 @@ def create_model_menus():
unload_model, None, None).then(
ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
update_model_parameters, shared.gradio['interface_state'], None).then(
load_model_wrapper, shared.gradio['model_menu'], shared.gradio['model_status'], show_progress=True)
load_model_wrapper, shared.gradio['model_menu'], shared.gradio['model_status'], show_progress=False)
save_settings.click(
ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
save_model_settings, [shared.gradio[k] for k in ['model_menu', 'interface_state']], shared.gradio['model_status'], show_progress=False)
shared.gradio['lora_menu_apply'].click(load_lora_wrapper, shared.gradio['lora_menu'], shared.gradio['model_status'], show_progress=False)
shared.gradio['download_model_button'].click(download_model_wrapper, shared.gradio['custom_model_menu'], shared.gradio['model_status'], show_progress=False)