Add auto_max_new_tokens parameter (#3419)

This commit is contained in:
oobabooga 2023-08-02 14:52:20 -03:00 committed by GitHub
parent 0d9932815c
commit e931844fe2
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 17 additions and 0 deletions

View file

@@ -247,6 +247,8 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings
input_ids = encode(question, add_bos_token=state['add_bos_token'], truncation_length=get_max_prompt_length(state))
output = input_ids[0]
cuda = not any((shared.args.cpu, shared.args.deepspeed))
if state['auto_max_new_tokens']:
generate_params['max_new_tokens'] = state['truncation_length'] - input_ids.shape[-1]
# Add the encoded tokens to generate_params
question, input_ids, inputs_embeds = apply_extensions('tokenizer', state, question, input_ids, None)