Add epsilon_cutoff/eta_cutoff parameters (#2258)
This commit is contained in:
parent
767a767989
commit
8ac3636966
9 changed files with 36 additions and 9 deletions
25
server.py
25
server.py
|
@ -84,6 +84,8 @@ def load_preset_values(preset_menu, state, return_dict=False):
|
|||
'temperature': 1,
|
||||
'top_p': 1,
|
||||
'typical_p': 1,
|
||||
'epsilon_cutoff': 0,
|
||||
'eta_cutoff': 0,
|
||||
'repetition_penalty': 1,
|
||||
'encoder_repetition_penalty': 1,
|
||||
'top_k': 50,
|
||||
|
@ -100,13 +102,13 @@ def load_preset_values(preset_menu, state, return_dict=False):
|
|||
i = i.rstrip(',').strip().split('=')
|
||||
if len(i) == 2 and i[0].strip() != 'tokens':
|
||||
generate_params[i[0].strip()] = eval(i[1].strip())
|
||||
generate_params['temperature'] = min(1.99, generate_params['temperature'])
|
||||
|
||||
generate_params['temperature'] = min(1.99, generate_params['temperature'])
|
||||
if return_dict:
|
||||
return generate_params
|
||||
else:
|
||||
state.update(generate_params)
|
||||
return state, *[generate_params[k] for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']]
|
||||
return state, *[generate_params[k] for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']]
|
||||
|
||||
|
||||
def upload_soft_prompt(file):
|
||||
|
@ -453,17 +455,24 @@ def create_settings_menus(default_preset):
|
|||
shared.gradio['do_sample'] = gr.Checkbox(value=generate_params['do_sample'], label='do_sample')
|
||||
with gr.Column():
|
||||
with gr.Box():
|
||||
gr.Markdown('Contrastive search')
|
||||
shared.gradio['penalty_alpha'] = gr.Slider(0, 5, value=generate_params['penalty_alpha'], label='penalty_alpha')
|
||||
|
||||
gr.Markdown('Beam search (uses a lot of VRAM)')
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
gr.Markdown('Contrastive search')
|
||||
shared.gradio['penalty_alpha'] = gr.Slider(0, 5, value=generate_params['penalty_alpha'], label='penalty_alpha')
|
||||
|
||||
gr.Markdown('Beam search (uses a lot of VRAM)')
|
||||
shared.gradio['num_beams'] = gr.Slider(1, 20, step=1, value=generate_params['num_beams'], label='num_beams')
|
||||
shared.gradio['length_penalty'] = gr.Slider(-5, 5, value=generate_params['length_penalty'], label='length_penalty')
|
||||
with gr.Column():
|
||||
shared.gradio['early_stopping'] = gr.Checkbox(value=generate_params['early_stopping'], label='early_stopping')
|
||||
|
||||
with gr.Column():
|
||||
gr.Markdown('Other')
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
shared.gradio['epsilon_cutoff'] = gr.Slider(0, 9, value=generate_params['epsilon_cutoff'], step=0.01, label='epsilon_cutoff', info='In units of 1e-4')
|
||||
with gr.Column():
|
||||
shared.gradio['eta_cutoff'] = gr.Slider(0, 20, value=generate_params['eta_cutoff'], step=0.01, label='eta_cutoff', info='In units of 1e-4')
|
||||
|
||||
with gr.Box():
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
|
@ -485,7 +494,7 @@ def create_settings_menus(default_preset):
|
|||
with gr.Row():
|
||||
shared.gradio['upload_softprompt'] = gr.File(type='binary', file_types=['.zip'])
|
||||
|
||||
shared.gradio['preset_menu'].change(load_preset_values, [shared.gradio[k] for k in ['preset_menu', 'interface_state']], [shared.gradio[k] for k in ['interface_state', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']])
|
||||
shared.gradio['preset_menu'].change(load_preset_values, [shared.gradio[k] for k in ['preset_menu', 'interface_state']], [shared.gradio[k] for k in ['interface_state', 'do_sample', 'temperature', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']])
|
||||
shared.gradio['softprompts_menu'].change(load_soft_prompt, shared.gradio['softprompts_menu'], shared.gradio['softprompts_menu'], show_progress=True)
|
||||
shared.gradio['upload_softprompt'].upload(upload_soft_prompt, shared.gradio['upload_softprompt'], shared.gradio['softprompts_menu'])
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue