Add repetition penalty range parameter to transformers (#2916)
This commit is contained in:
parent
c6cae106e7
commit
3443219cbc
12 changed files with 55 additions and 5 deletions
|
@ -230,7 +230,7 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False):
|
|||
|
||||
def generate_reply_HF(question, original_question, seed, state, stopping_strings=None, is_chat=False):
|
||||
generate_params = {}
|
||||
for k in ['max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'tfs', 'top_a', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta']:
|
||||
for k in ['max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'repetition_penalty_range', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'tfs', 'top_a', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta']:
|
||||
generate_params[k] = state[k]
|
||||
|
||||
for k in ['epsilon_cutoff', 'eta_cutoff']:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue