Add repetition penalty range parameter to transformers (#2916)
This commit is contained in:
parent
c6cae106e7
commit
3443219cbc
12 changed files with 55 additions and 5 deletions
|
@ -15,6 +15,7 @@ def load_preset(name):
|
|||
'tfs': 1,
|
||||
'top_a': 0,
|
||||
'repetition_penalty': 1,
|
||||
'repetition_penalty_range': 0,
|
||||
'encoder_repetition_penalty': 1,
|
||||
'top_k': 0,
|
||||
'num_beams': 1,
|
||||
|
@ -46,9 +47,9 @@ def load_preset_memoized(name):
|
|||
def load_preset_for_ui(name, state):
|
||||
generate_params = load_preset(name)
|
||||
state.update(generate_params)
|
||||
return state, *[generate_params[k] for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'tfs', 'top_a']]
|
||||
return state, *[generate_params[k] for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'repetition_penalty', 'repetition_penalty_range', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'tfs', 'top_a']]
|
||||
|
||||
|
||||
def generate_preset_yaml(state):
|
||||
data = {k: state[k] for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'tfs', 'top_a']}
|
||||
data = {k: state[k] for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'repetition_penalty', 'repetition_penalty_range', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'tfs', 'top_a']}
|
||||
return yaml.dump(data, sort_keys=False)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue