Merge branch 'main' into add-train-lora-tab
This commit is contained in:
commit
e439228ed8
17 changed files with 186 additions and 95 deletions
104
server.py
104
server.py
|
@ -4,6 +4,7 @@ import re
|
|||
import sys
|
||||
import time
|
||||
import zipfile
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
import gradio as gr
|
||||
|
@ -36,6 +37,13 @@ def get_available_models():
|
|||
def get_available_presets():
|
||||
return sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('presets').glob('*.txt'))), key=str.lower)
|
||||
|
||||
def get_available_prompts():
|
||||
prompts = []
|
||||
prompts += sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('prompts').glob('[0-9]*.txt'))), key=str.lower, reverse=True)
|
||||
prompts += sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('prompts').glob('*.txt'))), key=str.lower)
|
||||
prompts += ['None']
|
||||
return prompts
|
||||
|
||||
def get_available_characters():
|
||||
return ['None'] + sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('characters').glob('*.json'))), key=str.lower)
|
||||
|
||||
|
@ -48,12 +56,17 @@ def get_available_softprompts():
|
|||
def get_available_loras():
|
||||
return ['None'] + sorted([item.name for item in list(Path('loras/').glob('*')) if not item.name.endswith(('.txt', '-np', '.pt', '.json'))], key=str.lower)
|
||||
|
||||
def unload_model():
|
||||
shared.model = shared.tokenizer = None
|
||||
clear_torch_cache()
|
||||
|
||||
def load_model_wrapper(selected_model):
|
||||
if selected_model != shared.model_name:
|
||||
shared.model_name = selected_model
|
||||
shared.model = shared.tokenizer = None
|
||||
clear_torch_cache()
|
||||
shared.model, shared.tokenizer = load_model(shared.model_name)
|
||||
|
||||
unload_model()
|
||||
if selected_model != '':
|
||||
shared.model, shared.tokenizer = load_model(shared.model_name)
|
||||
|
||||
return selected_model
|
||||
|
||||
|
@ -91,7 +104,7 @@ def load_preset_values(preset_menu, return_dict=False):
|
|||
if return_dict:
|
||||
return generate_params
|
||||
else:
|
||||
return preset_menu, generate_params['do_sample'], generate_params['temperature'], generate_params['top_p'], generate_params['typical_p'], generate_params['repetition_penalty'], generate_params['encoder_repetition_penalty'], generate_params['top_k'], generate_params['min_length'], generate_params['no_repeat_ngram_size'], generate_params['num_beams'], generate_params['penalty_alpha'], generate_params['length_penalty'], generate_params['early_stopping']
|
||||
return generate_params['do_sample'], generate_params['temperature'], generate_params['top_p'], generate_params['typical_p'], generate_params['repetition_penalty'], generate_params['encoder_repetition_penalty'], generate_params['top_k'], generate_params['min_length'], generate_params['no_repeat_ngram_size'], generate_params['num_beams'], generate_params['penalty_alpha'], generate_params['length_penalty'], generate_params['early_stopping']
|
||||
|
||||
def upload_soft_prompt(file):
|
||||
with zipfile.ZipFile(io.BytesIO(file)) as zf:
|
||||
|
@ -116,9 +129,43 @@ def create_model_and_preset_menus():
|
|||
shared.gradio['preset_menu'] = gr.Dropdown(choices=available_presets, value=default_preset if not shared.args.flexgen else 'Naive', label='Generation parameters preset')
|
||||
ui.create_refresh_button(shared.gradio['preset_menu'], lambda : None, lambda : {'choices': get_available_presets()}, 'refresh-button')
|
||||
|
||||
def save_prompt(text):
|
||||
fname = f"{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}.txt"
|
||||
with open(Path(f'prompts/{fname}'), 'w', encoding='utf-8') as f:
|
||||
f.write(text)
|
||||
return f"Saved prompt to prompts/{fname}"
|
||||
|
||||
def load_prompt(fname):
|
||||
if fname in ['None', '']:
|
||||
return ''
|
||||
else:
|
||||
with open(Path(f'prompts/{fname}.txt'), 'r', encoding='utf-8') as f:
|
||||
return f.read()
|
||||
|
||||
def create_prompt_menus():
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
with gr.Row():
|
||||
shared.gradio['prompt_menu'] = gr.Dropdown(choices=get_available_prompts(), value='None', label='Prompt')
|
||||
ui.create_refresh_button(shared.gradio['prompt_menu'], lambda : None, lambda : {'choices': get_available_prompts()}, 'refresh-button')
|
||||
|
||||
with gr.Column():
|
||||
with gr.Column():
|
||||
shared.gradio['save_prompt'] = gr.Button('Save prompt')
|
||||
shared.gradio['status'] = gr.Markdown('Ready')
|
||||
|
||||
shared.gradio['prompt_menu'].change(load_prompt, [shared.gradio['prompt_menu']], [shared.gradio['textbox']], show_progress=True)
|
||||
shared.gradio['save_prompt'].click(save_prompt, [shared.gradio['textbox']], [shared.gradio['status']], show_progress=False)
|
||||
|
||||
def create_settings_menus(default_preset):
|
||||
generate_params = load_preset_values(default_preset if not shared.args.flexgen else 'Naive', return_dict=True)
|
||||
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
create_model_and_preset_menus()
|
||||
with gr.Column():
|
||||
shared.gradio['seed'] = gr.Number(value=-1, label='Seed (-1 for random)')
|
||||
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
with gr.Box():
|
||||
|
@ -149,12 +196,6 @@ def create_settings_menus(default_preset):
|
|||
shared.gradio['length_penalty'] = gr.Slider(-5, 5, value=generate_params['length_penalty'], label='length_penalty')
|
||||
shared.gradio['early_stopping'] = gr.Checkbox(value=generate_params['early_stopping'], label='early_stopping')
|
||||
|
||||
shared.gradio['seed'] = gr.Number(value=-1, label='Seed (-1 for random)')
|
||||
|
||||
with gr.Row():
|
||||
shared.gradio['preset_menu_mirror'] = gr.Dropdown(choices=available_presets, value=default_preset if not shared.args.flexgen else 'Naive', label='Generation parameters preset')
|
||||
ui.create_refresh_button(shared.gradio['preset_menu_mirror'], lambda : None, lambda : {'choices': get_available_presets()}, 'refresh-button')
|
||||
|
||||
with gr.Row():
|
||||
shared.gradio['lora_menu'] = gr.Dropdown(choices=available_loras, value=shared.lora_name, label='LoRA')
|
||||
ui.create_refresh_button(shared.gradio['lora_menu'], lambda : None, lambda : {'choices': get_available_loras()}, 'refresh-button')
|
||||
|
@ -169,8 +210,7 @@ def create_settings_menus(default_preset):
|
|||
shared.gradio['upload_softprompt'] = gr.File(type='binary', file_types=['.zip'])
|
||||
|
||||
shared.gradio['model_menu'].change(load_model_wrapper, [shared.gradio['model_menu']], [shared.gradio['model_menu']], show_progress=True)
|
||||
shared.gradio['preset_menu'].change(load_preset_values, [shared.gradio['preset_menu']], [shared.gradio[k] for k in ['preset_menu_mirror', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']])
|
||||
shared.gradio['preset_menu_mirror'].change(load_preset_values, [shared.gradio['preset_menu_mirror']], [shared.gradio[k] for k in ['preset_menu', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']])
|
||||
shared.gradio['preset_menu'].change(load_preset_values, [shared.gradio['preset_menu']], [shared.gradio[k] for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']])
|
||||
shared.gradio['lora_menu'].change(load_lora_wrapper, [shared.gradio['lora_menu']], [shared.gradio['lora_menu'], shared.gradio['textbox']], show_progress=True)
|
||||
shared.gradio['softprompts_menu'].change(load_soft_prompt, [shared.gradio['softprompts_menu']], [shared.gradio['softprompts_menu']], show_progress=True)
|
||||
shared.gradio['upload_softprompt'].upload(upload_soft_prompt, [shared.gradio['upload_softprompt']], [shared.gradio['softprompts_menu']])
|
||||
|
@ -235,8 +275,9 @@ if shared.args.lora:
|
|||
|
||||
# Default UI settings
|
||||
default_preset = shared.settings['presets'][next((k for k in shared.settings['presets'] if re.match(k.lower(), shared.model_name.lower())), 'default')]
|
||||
default_text = shared.settings['lora_prompts'][next((k for k in shared.settings['lora_prompts'] if re.match(k.lower(), shared.lora_name.lower())), 'default')]
|
||||
if default_text == '':
|
||||
if shared.lora_name != "None":
|
||||
default_text = shared.settings['lora_prompts'][next((k for k in shared.settings['lora_prompts'] if re.match(k.lower(), shared.lora_name.lower())), 'default')]
|
||||
else:
|
||||
default_text = shared.settings['prompts'][next((k for k in shared.settings['prompts'] if re.match(k.lower(), shared.model_name.lower())), 'default')]
|
||||
title ='Text generation web UI'
|
||||
description = '\n\n# Text generation lab\nGenerate text using Large Language Models.\n'
|
||||
|
@ -257,8 +298,8 @@ def create_interface():
|
|||
shared.gradio['display'] = gr.Chatbot(value=shared.history['visible']).style(color_map=("#326efd", "#212528"))
|
||||
shared.gradio['textbox'] = gr.Textbox(label='Input')
|
||||
with gr.Row():
|
||||
shared.gradio['Stop'] = gr.Button('Stop', elem_id="stop")
|
||||
shared.gradio['Generate'] = gr.Button('Generate')
|
||||
shared.gradio['Stop'] = gr.Button('Stop', elem_id="stop")
|
||||
with gr.Row():
|
||||
shared.gradio['Impersonate'] = gr.Button('Impersonate')
|
||||
shared.gradio['Regenerate'] = gr.Button('Regenerate')
|
||||
|
@ -271,8 +312,6 @@ def create_interface():
|
|||
shared.gradio['Clear history-confirm'] = gr.Button('Confirm', variant="stop", visible=False)
|
||||
shared.gradio['Clear history-cancel'] = gr.Button('Cancel', visible=False)
|
||||
|
||||
create_model_and_preset_menus()
|
||||
|
||||
with gr.Tab("Character", elem_id="chat-settings"):
|
||||
shared.gradio['name1'] = gr.Textbox(value=shared.settings[f'name1{suffix}'], lines=1, label='Your name')
|
||||
shared.gradio['name2'] = gr.Textbox(value=shared.settings[f'name2{suffix}'], lines=1, label='Bot\'s name')
|
||||
|
@ -366,19 +405,25 @@ def create_interface():
|
|||
|
||||
elif shared.args.notebook:
|
||||
with gr.Tab("Text generation", elem_id="main"):
|
||||
with gr.Tab('Raw'):
|
||||
shared.gradio['textbox'] = gr.Textbox(value=default_text, lines=25)
|
||||
with gr.Tab('Markdown'):
|
||||
shared.gradio['markdown'] = gr.Markdown()
|
||||
with gr.Tab('HTML'):
|
||||
shared.gradio['html'] = gr.HTML()
|
||||
|
||||
with gr.Row():
|
||||
shared.gradio['Stop'] = gr.Button('Stop')
|
||||
shared.gradio['Generate'] = gr.Button('Generate')
|
||||
shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens'])
|
||||
with gr.Column(scale=4):
|
||||
with gr.Tab('Raw'):
|
||||
shared.gradio['textbox'] = gr.Textbox(value=default_text, elem_id="textbox", lines=25)
|
||||
with gr.Tab('Markdown'):
|
||||
shared.gradio['markdown'] = gr.Markdown()
|
||||
with gr.Tab('HTML'):
|
||||
shared.gradio['html'] = gr.HTML()
|
||||
|
||||
with gr.Row():
|
||||
shared.gradio['Generate'] = gr.Button('Generate')
|
||||
shared.gradio['Stop'] = gr.Button('Stop')
|
||||
|
||||
with gr.Column(scale=1):
|
||||
gr.Markdown("\n")
|
||||
shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens'])
|
||||
|
||||
create_prompt_menus()
|
||||
|
||||
create_model_and_preset_menus()
|
||||
with gr.Tab("Parameters", elem_id="parameters"):
|
||||
create_settings_menus(default_preset)
|
||||
|
||||
|
@ -402,7 +447,7 @@ def create_interface():
|
|||
with gr.Column():
|
||||
shared.gradio['Stop'] = gr.Button('Stop')
|
||||
|
||||
create_model_and_preset_menus()
|
||||
create_prompt_menus()
|
||||
|
||||
with gr.Column():
|
||||
with gr.Tab('Raw'):
|
||||
|
@ -411,6 +456,7 @@ def create_interface():
|
|||
shared.gradio['markdown'] = gr.Markdown()
|
||||
with gr.Tab('HTML'):
|
||||
shared.gradio['html'] = gr.HTML()
|
||||
|
||||
with gr.Tab("Parameters", elem_id="parameters"):
|
||||
create_settings_menus(default_preset)
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue