Add LoRA support
This commit is contained in:
parent
ee164d1821
commit
104293f411
6 changed files with 51 additions and 8 deletions
25
server.py
25
server.py
|
@ -17,6 +17,7 @@ import modules.ui as ui
|
|||
from modules.html_generator import generate_chat_html
|
||||
from modules.models import load_model, load_soft_prompt
|
||||
from modules.text_generation import generate_reply
|
||||
from modules.LoRA import add_lora_to_model
|
||||
|
||||
# Loading custom settings
|
||||
settings_file = None
|
||||
|
@ -48,6 +49,9 @@ def get_available_extensions():
|
|||
def get_available_softprompts():
|
||||
return ['None'] + sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('softprompts').glob('*.zip'))), key=str.lower)
|
||||
|
||||
def get_available_loras():
|
||||
return ['None'] + sorted([item.name for item in list(Path('loras/').glob('*')) if not item.name.endswith(('.txt', '-np', '.pt', '.json'))], key=str.lower)
|
||||
|
||||
def load_model_wrapper(selected_model):
|
||||
if selected_model != shared.model_name:
|
||||
shared.model_name = selected_model
|
||||
|
@ -59,6 +63,13 @@ def load_model_wrapper(selected_model):
|
|||
|
||||
return selected_model
|
||||
|
||||
def load_lora_wrapper(selected_lora):
|
||||
if not shared.args.cpu:
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
add_lora_to_model(selected_lora)
|
||||
return selected_lora
|
||||
|
||||
def load_preset_values(preset_menu, return_dict=False):
|
||||
generate_params = {
|
||||
'do_sample': True,
|
||||
|
@ -181,6 +192,7 @@ available_models = get_available_models()
|
|||
available_presets = get_available_presets()
|
||||
available_characters = get_available_characters()
|
||||
available_softprompts = get_available_softprompts()
|
||||
available_loras = get_available_loras()
|
||||
|
||||
# Default extensions
|
||||
extensions_module.available_extensions = get_available_extensions()
|
||||
|
@ -401,6 +413,19 @@ def create_interface():
|
|||
shared.gradio['Stop'].click(None, None, None, cancels=gen_events)
|
||||
shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}")
|
||||
|
||||
with gr.Tab("LoRA", elem_id="lora"):
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
gr.Markdown("Load")
|
||||
with gr.Row():
|
||||
shared.gradio['lora_menu'] = gr.Dropdown(choices=available_loras, value=shared.lora_name, label='LoRA')
|
||||
ui.create_refresh_button(shared.gradio['lora_menu'], lambda : None, lambda : {'choices': get_available_loras()}, 'refresh-button')
|
||||
with gr.Column():
|
||||
gr.Markdown("Train (TODO)")
|
||||
gr.Button("Practice your button clicking skills")
|
||||
|
||||
shared.gradio['lora_menu'].change(load_lora_wrapper, [shared.gradio['lora_menu']], [shared.gradio['lora_menu']], show_progress=True)
|
||||
|
||||
with gr.Tab("Interface mode", elem_id="interface-mode"):
|
||||
modes = ["default", "notebook", "chat", "cai_chat"]
|
||||
current_mode = "default"
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue