Add auto_max_new_tokens parameter (#3419)
This commit is contained in:
parent 0d9932815c
commit e931844fe2
12 changed files with 17 additions and 0 deletions
|
@ -116,6 +116,7 @@ loaders_samplers = {
|
|||
'ban_eos_token',
|
||||
'add_bos_token',
|
||||
'skip_special_tokens',
|
||||
'auto_max_new_tokens',
|
||||
},
|
||||
'ExLlama_HF': {
|
||||
'temperature',
|
||||
|
@ -139,6 +140,7 @@ loaders_samplers = {
|
|||
'ban_eos_token',
|
||||
'add_bos_token',
|
||||
'skip_special_tokens',
|
||||
'auto_max_new_tokens',
|
||||
},
|
||||
'ExLlama': {
|
||||
'temperature',
|
||||
|
@ -176,6 +178,7 @@ loaders_samplers = {
|
|||
'ban_eos_token',
|
||||
'add_bos_token',
|
||||
'skip_special_tokens',
|
||||
'auto_max_new_tokens',
|
||||
},
|
||||
'GPTQ-for-LLaMa': {
|
||||
'temperature',
|
||||
|
@ -203,6 +206,7 @@ loaders_samplers = {
|
|||
'ban_eos_token',
|
||||
'add_bos_token',
|
||||
'skip_special_tokens',
|
||||
'auto_max_new_tokens',
|
||||
},
|
||||
'llama.cpp': {
|
||||
'temperature',
|
||||
|
@ -237,6 +241,7 @@ loaders_samplers = {
|
|||
'ban_eos_token',
|
||||
'add_bos_token',
|
||||
'skip_special_tokens',
|
||||
'auto_max_new_tokens',
|
||||
},
|
||||
}
|
||||
|
||||
|
|
|
@ -36,6 +36,7 @@ settings = {
|
|||
'max_new_tokens': 200,
|
||||
'max_new_tokens_min': 1,
|
||||
'max_new_tokens_max': 4096,
|
||||
'auto_max_new_tokens': False,
|
||||
'seed': -1,
|
||||
'character': 'None',
|
||||
'name1': 'You',
|
||||
|
|
|
@ -247,6 +247,8 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings
|
|||
input_ids = encode(question, add_bos_token=state['add_bos_token'], truncation_length=get_max_prompt_length(state))
|
||||
output = input_ids[0]
|
||||
cuda = not any((shared.args.cpu, shared.args.deepspeed))
|
||||
if state['auto_max_new_tokens']:
|
||||
generate_params['max_new_tokens'] = state['truncation_length'] - input_ids.shape[-1]
|
||||
|
||||
# Add the encoded tokens to generate_params
|
||||
question, input_ids, inputs_embeds = apply_extensions('tokenizer', state, question, input_ids, None)
|
||||
|
|
|
@ -79,6 +79,7 @@ def list_model_elements():
|
|||
def list_interface_input_elements():
|
||||
elements = [
|
||||
'max_new_tokens',
|
||||
'auto_max_new_tokens',
|
||||
'seed',
|
||||
'temperature',
|
||||
'top_p',
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue