Update truncation length based on max_seq_len/n_ctx

This commit is contained in:
oobabooga 2023-08-26 23:10:45 -07:00
parent e6eda5c2da
commit 0c9e818bb8
2 changed files with 25 additions and 4 deletions

View file

@ -113,7 +113,7 @@ def create_ui(default_preset):
with gr.Box():
with gr.Row():
with gr.Column():
shared.gradio['truncation_length'] = gr.Slider(value=shared.settings['truncation_length'], minimum=shared.settings['truncation_length_min'], maximum=shared.settings['truncation_length_max'], step=256, label='Truncate the prompt up to this length', info='The leftmost tokens are removed if the prompt exceeds this length. Most models require this to be at most 2048.')
shared.gradio['truncation_length'] = gr.Slider(value=get_truncation_length(), minimum=shared.settings['truncation_length_min'], maximum=shared.settings['truncation_length_max'], step=256, label='Truncate the prompt up to this length', info='The leftmost tokens are removed if the prompt exceeds this length. Most models require this to be at most 2048.')
shared.gradio['custom_stopping_strings'] = gr.Textbox(lines=1, value=shared.settings["custom_stopping_strings"] or None, label='Custom stopping strings', info='In addition to the defaults. Written between "" and separated by commas.', placeholder='"\\n", "\\nYou:"')
with gr.Column():
shared.gradio['auto_max_new_tokens'] = gr.Checkbox(value=shared.settings['auto_max_new_tokens'], label='auto_max_new_tokens', info='Expand max_new_tokens to the available context length.')
@ -129,3 +129,12 @@ def create_ui(default_preset):
def create_event_handlers():
    """Attach the ``change`` callbacks for the two components handled here.

    * ``filter_by_loader`` -> ``loaders.blacklist_samplers``; its outputs are
      the components named by ``loaders.list_all_samplers()``.
    * ``preset_menu`` -> ``presets.load_preset_for_ui``; its outputs are
      ``interface_state`` followed by the components named by
      ``presets.presets_params()``.
    """
    loader_filter = shared.gradio['filter_by_loader']
    loader_filter.change(
        loaders.blacklist_samplers,
        gradio('filter_by_loader'),
        gradio(loaders.list_all_samplers()),
        show_progress=False,
    )

    preset_menu = shared.gradio['preset_menu']
    preset_menu.change(
        presets.load_preset_for_ui,
        gradio('preset_menu', 'interface_state'),
        gradio('interface_state') + gradio(presets.presets_params()),
    )
def get_truncation_length():
    """Return the starting value for the ``truncation_length`` slider.

    The options are checked in priority order: ``max_seq_len`` first, then
    ``n_ctx``. The first one whose value in ``shared.args`` differs from its
    value in ``shared.args_defaults`` is returned. If both still match their
    defaults, ``shared.settings['truncation_length']`` is returned instead.
    """
    current, defaults = shared.args, shared.args_defaults

    # Order matters: max_seq_len takes precedence over n_ctx.
    for option in ('max_seq_len', 'n_ctx'):
        chosen = getattr(current, option)
        if chosen != getattr(defaults, option):
            return chosen

    return shared.settings['truncation_length']