Add max_tokens_second param (#3533)

This commit is contained in:
oobabooga 2023-08-29 17:44:31 -03:00 committed by GitHub
parent fe1f7c6513
commit cec8db52e5
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 24 additions and 3 deletions

View file

@ -47,6 +47,7 @@ settings = {
'truncation_length_max': 16384,
'custom_stopping_strings': '',
'auto_max_new_tokens': False,
'max_tokens_second': 0,
'ban_eos_token': False,
'add_bos_token': True,
'skip_special_tokens': True,

View file

@ -80,10 +80,22 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False, escap
reply, stop_found = apply_stopping_strings(reply, all_stop_strings)
if is_stream:
cur_time = time.time()
if cur_time - last_update > 0.041666666666666664: # Limit streaming to 24 fps
last_update = cur_time
# Maximum number of tokens/second
if state['max_tokens_second'] > 0:
diff = 1 / state['max_tokens_second'] - (cur_time - last_update)
if diff > 0:
time.sleep(diff)
last_update = time.time()
yield reply
# Limit updates to 24 per second to not stress low latency networks
else:
if cur_time - last_update > 0.041666666666666664:
last_update = cur_time
yield reply
if stop_found:
break

View file

@ -93,6 +93,7 @@ def list_interface_input_elements():
elements = [
'max_new_tokens',
'auto_max_new_tokens',
'max_tokens_second',
'seed',
'temperature',
'top_p',

View file

@ -105,7 +105,6 @@ def create_ui(default_preset):
with gr.Column():
shared.gradio['penalty_alpha'] = gr.Slider(0, 5, value=generate_params['penalty_alpha'], label='penalty_alpha', info='For Contrastive Search. do_sample must be unchecked.')
shared.gradio['num_beams'] = gr.Slider(1, 20, step=1, value=generate_params['num_beams'], label='num_beams', info='For Beam Search, along with length_penalty and early_stopping.')
shared.gradio['length_penalty'] = gr.Slider(-5, 5, value=generate_params['length_penalty'], label='length_penalty')
shared.gradio['early_stopping'] = gr.Checkbox(value=generate_params['early_stopping'], label='early_stopping')
@ -114,6 +113,7 @@ def create_ui(default_preset):
with gr.Row():
with gr.Column():
shared.gradio['truncation_length'] = gr.Slider(value=get_truncation_length(), minimum=shared.settings['truncation_length_min'], maximum=shared.settings['truncation_length_max'], step=256, label='Truncate the prompt up to this length', info='The leftmost tokens are removed if the prompt exceeds this length. Most models require this to be at most 2048.')
shared.gradio['max_tokens_second'] = gr.Slider(value=shared.settings['max_tokens_second'], minimum=0, maximum=20, step=1, label='Maximum number of tokens/second', info='To make text readable in real time.')
shared.gradio['custom_stopping_strings'] = gr.Textbox(lines=1, value=shared.settings["custom_stopping_strings"] or None, label='Custom stopping strings', info='In addition to the defaults. Written between "" and separated by commas.', placeholder='"\\n", "\\nYou:"')
with gr.Column():
shared.gradio['auto_max_new_tokens'] = gr.Checkbox(value=shared.settings['auto_max_new_tokens'], label='auto_max_new_tokens', info='Expand max_new_tokens to the available context length.')