Model-aware prompts and presets

This commit is contained in:
oobabooga 2023-03-02 11:25:04 -03:00
parent 024d30d1b4
commit 169209805d
3 changed files with 39 additions and 25 deletions

View file

@ -94,8 +94,8 @@ def upload_soft_prompt(file):
return name
def create_settings_menus():
generate_params = load_preset_values(shared.settings[f'preset{suffix}'] if not shared.args.flexgen else 'Naive', return_dict=True)
def create_settings_menus(default_preset):
generate_params = load_preset_values(default_preset if not shared.args.flexgen else 'Naive', return_dict=True)
with gr.Row():
with gr.Column():
@ -104,7 +104,7 @@ def create_settings_menus():
ui.create_refresh_button(shared.gradio['model_menu'], lambda : None, lambda : {'choices': get_available_models()}, 'refresh-button')
with gr.Column():
with gr.Row():
shared.gradio['preset_menu'] = gr.Dropdown(choices=available_presets, value=shared.settings[f'preset{suffix}'] if not shared.args.flexgen else 'Naive', label='Generation parameters preset')
shared.gradio['preset_menu'] = gr.Dropdown(choices=available_presets, value=default_preset if not shared.args.flexgen else 'Naive', label='Generation parameters preset')
ui.create_refresh_button(shared.gradio['preset_menu'], lambda : None, lambda : {'choices': get_available_presets()}, 'refresh-button')
with gr.Accordion('Custom generation parameters', open=False, elem_id='accordion'):
@ -150,8 +150,8 @@ available_presets = get_available_presets()
available_characters = get_available_characters()
available_softprompts = get_available_softprompts()
# Default extensions
extensions_module.available_extensions = get_available_extensions()
# Activate the default extensions
if shared.args.chat or shared.args.cai_chat:
for extension in shared.settings['chat_default_extensions']:
shared.args.extensions = shared.args.extensions or []
@ -165,7 +165,7 @@ else:
if shared.args.extensions is not None and len(shared.args.extensions) > 0:
extensions_module.load_extensions()
# Choosing the default model
# Default model
if shared.args.model is not None:
shared.model_name = shared.args.model
else:
@ -184,16 +184,12 @@ else:
shared.model_name = available_models[i]
shared.model, shared.tokenizer = load_model(shared.model_name)
# UI settings
# Default UI settings
gen_events = []
suffix = '_pygmalion' if 'pygmalion' in shared.model_name.lower() else ''
default_preset = shared.settings['presets'][next((k for k in shared.settings['presets'] if re.match(k.lower(), shared.model_name.lower())), 'default')]
default_text = shared.settings['prompts'][next((k for k in shared.settings['prompts'] if re.match(k.lower(), shared.model_name.lower())), 'default')]
description = '\n\n# Text generation lab\nGenerate text using Large Language Models.\n'
if shared.model_name.lower().startswith(('gpt4chan', 'gpt-4chan', '4chan')):
default_text = shared.settings['prompt_gpt4chan']
elif re.match('(rosey|chip|joi)_.*_instruct.*', shared.model_name.lower()) is not None:
default_text = 'User: \n'
else:
default_text = shared.settings['prompt']
suffix = '_pygmalion' if 'pygmalion' in shared.model_name.lower() else ''
if shared.args.chat or shared.args.cai_chat:
with gr.Blocks(css=ui.css+ui.chat_css, analytics_enabled=False) as shared.gradio['interface']:
@ -257,7 +253,7 @@ if shared.args.chat or shared.args.cai_chat:
with gr.Column():
shared.gradio['chat_prompt_size_slider'] = gr.Slider(minimum=shared.settings['chat_prompt_size_min'], maximum=shared.settings['chat_prompt_size_max'], step=1, label='Maximum prompt size in tokens', value=shared.settings['chat_prompt_size'])
shared.gradio['chat_generation_attempts'] = gr.Slider(minimum=shared.settings['chat_generation_attempts_min'], maximum=shared.settings['chat_generation_attempts_max'], value=shared.settings['chat_generation_attempts'], step=1, label='Generation attempts (for longer replies)')
create_settings_menus()
create_settings_menus(default_preset)
shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'name1', 'name2', 'context', 'check', 'chat_prompt_size_slider', 'chat_generation_attempts']]
if shared.args.extensions is not None:
@ -321,7 +317,7 @@ elif shared.args.notebook:
shared.gradio['Stop'] = gr.Button('Stop')
shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens'])
create_settings_menus()
create_settings_menus(default_preset)
if shared.args.extensions is not None:
extensions_module.create_extensions_block()
@ -345,7 +341,7 @@ else:
with gr.Column():
shared.gradio['Stop'] = gr.Button('Stop')
create_settings_menus()
create_settings_menus(default_preset)
if shared.args.extensions is not None:
extensions_module.create_extensions_block()