Add auto_max_new_tokens parameter (#3419)
This commit is contained in:
parent
0d9932815c
commit
e931844fe2
12 changed files with 17 additions and 0 deletions
|
@ -247,6 +247,8 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings
|
|||
input_ids = encode(question, add_bos_token=state['add_bos_token'], truncation_length=get_max_prompt_length(state))
|
||||
output = input_ids[0]
|
||||
cuda = not any((shared.args.cpu, shared.args.deepspeed))
|
||||
if state['auto_max_new_tokens']:
|
||||
generate_params['max_new_tokens'] = state['truncation_length'] - input_ids.shape[-1]
|
||||
|
||||
# Add the encoded tokens to generate_params
|
||||
question, input_ids, inputs_embeds = apply_extensions('tokenizer', state, question, input_ids, None)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue