Add auto_max_new_tokens parameter (#3419)

This commit is contained in:
oobabooga 2023-08-02 14:52:20 -03:00 committed by GitHub
parent 0d9932815c
commit e931844fe2
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 17 additions and 0 deletions

View file

@ -116,6 +116,7 @@ loaders_samplers = {
'ban_eos_token',
'add_bos_token',
'skip_special_tokens',
'auto_max_new_tokens',
},
'ExLlama_HF': {
'temperature',
@ -139,6 +140,7 @@ loaders_samplers = {
'ban_eos_token',
'add_bos_token',
'skip_special_tokens',
'auto_max_new_tokens',
},
'ExLlama': {
'temperature',
@ -176,6 +178,7 @@ loaders_samplers = {
'ban_eos_token',
'add_bos_token',
'skip_special_tokens',
'auto_max_new_tokens',
},
'GPTQ-for-LLaMa': {
'temperature',
@ -203,6 +206,7 @@ loaders_samplers = {
'ban_eos_token',
'add_bos_token',
'skip_special_tokens',
'auto_max_new_tokens',
},
'llama.cpp': {
'temperature',
@ -237,6 +241,7 @@ loaders_samplers = {
'ban_eos_token',
'add_bos_token',
'skip_special_tokens',
'auto_max_new_tokens',
},
}

View file

@ -36,6 +36,7 @@ settings = {
'max_new_tokens': 200,
'max_new_tokens_min': 1,
'max_new_tokens_max': 4096,
'auto_max_new_tokens': False,
'seed': -1,
'character': 'None',
'name1': 'You',

View file

@ -247,6 +247,8 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings
input_ids = encode(question, add_bos_token=state['add_bos_token'], truncation_length=get_max_prompt_length(state))
output = input_ids[0]
cuda = not any((shared.args.cpu, shared.args.deepspeed))
if state['auto_max_new_tokens']:
generate_params['max_new_tokens'] = state['truncation_length'] - input_ids.shape[-1]
# Add the encoded tokens to generate_params
question, input_ids, inputs_embeds = apply_extensions('tokenizer', state, question, input_ids, None)

View file

@ -79,6 +79,7 @@ def list_model_elements():
def list_interface_input_elements():
elements = [
'max_new_tokens',
'auto_max_new_tokens',
'seed',
'temperature',
'top_p',