Update truncation length based on max_seq_len/n_ctx
This commit is contained in:
parent
e6eda5c2da
commit
0c9e818bb8
2 changed files with 25 additions and 4 deletions
|
|
@ -145,12 +145,14 @@ def create_event_handlers():
|
|||
apply_model_settings_to_state, gradio('model_menu', 'interface_state'), gradio('interface_state')).then(
|
||||
ui.apply_interface_values, gradio('interface_state'), gradio(ui.list_interface_input_elements()), show_progress=False).then(
|
||||
update_model_parameters, gradio('interface_state'), None).then(
|
||||
load_model_wrapper, gradio('model_menu', 'loader', 'autoload_model'), gradio('model_status'), show_progress=False)
|
||||
load_model_wrapper, gradio('model_menu', 'loader', 'autoload_model'), gradio('model_status'), show_progress=False).success(
|
||||
update_truncation_length, gradio('truncation_length', 'interface_state'), gradio('truncation_length'))
|
||||
|
||||
shared.gradio['load_model'].click(
|
||||
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
||||
update_model_parameters, gradio('interface_state'), None).then(
|
||||
partial(load_model_wrapper, autoload=True), gradio('model_menu', 'loader'), gradio('model_status'), show_progress=False)
|
||||
partial(load_model_wrapper, autoload=True), gradio('model_menu', 'loader'), gradio('model_status'), show_progress=False).success(
|
||||
update_truncation_length, gradio('truncation_length', 'interface_state'), gradio('truncation_length'))
|
||||
|
||||
shared.gradio['unload_model'].click(
|
||||
unload_model, None, None).then(
|
||||
|
|
@ -160,7 +162,8 @@ def create_event_handlers():
|
|||
unload_model, None, None).then(
|
||||
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
||||
update_model_parameters, gradio('interface_state'), None).then(
|
||||
partial(load_model_wrapper, autoload=True), gradio('model_menu', 'loader'), gradio('model_status'), show_progress=False)
|
||||
partial(load_model_wrapper, autoload=True), gradio('model_menu', 'loader'), gradio('model_status'), show_progress=False).success(
|
||||
update_truncation_length, gradio('truncation_length', 'interface_state'), gradio('truncation_length'))
|
||||
|
||||
shared.gradio['save_model_settings'].click(
|
||||
ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
||||
|
|
@ -235,3 +238,12 @@ def download_model_wrapper(repo_id, progress=gr.Progress()):
|
|||
except:
|
||||
progress(1.0)
|
||||
yield traceback.format_exc().replace('\n', '\n\n')
|
||||
|
||||
|
||||
def update_truncation_length(current_length, state):
|
||||
if state['loader'] in ['ExLlama', 'ExLlama_HF']:
|
||||
return state['max_seq_len']
|
||||
elif state['loader'] in ['llama.cpp', 'llamacpp_HF', 'ctransformers']:
|
||||
return state['n_ctx']
|
||||
else:
|
||||
return current_length
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue