Add Classifier Free Guidance (CFG) for Transformers/ExLlama (#3325)
This commit is contained in:
parent
5134878344
commit
0af10ab49b
17 changed files with 131 additions and 42 deletions
|
@ -9,6 +9,7 @@ def default_preset():
|
|||
'do_sample': True,
|
||||
'temperature': 1,
|
||||
'top_p': 1,
|
||||
'top_k': 0,
|
||||
'typical_p': 1,
|
||||
'epsilon_cutoff': 0,
|
||||
'eta_cutoff': 0,
|
||||
|
@ -17,19 +18,23 @@ def default_preset():
|
|||
'repetition_penalty': 1,
|
||||
'repetition_penalty_range': 0,
|
||||
'encoder_repetition_penalty': 1,
|
||||
'top_k': 0,
|
||||
'num_beams': 1,
|
||||
'penalty_alpha': 0,
|
||||
'min_length': 0,
|
||||
'length_penalty': 1,
|
||||
'no_repeat_ngram_size': 0,
|
||||
'early_stopping': False,
|
||||
'min_length': 0,
|
||||
'guidance_scale': 1,
|
||||
'mirostat_mode': 0,
|
||||
'mirostat_tau': 5.0,
|
||||
'mirostat_eta': 0.1,
|
||||
'penalty_alpha': 0,
|
||||
'num_beams': 1,
|
||||
'length_penalty': 1,
|
||||
'early_stopping': False,
|
||||
}
|
||||
|
||||
|
||||
def presets_params():
|
||||
return [k for k in default_preset()]
|
||||
|
||||
|
||||
def load_preset(name):
|
||||
generate_params = default_preset()
|
||||
if name not in ['None', None, '']:
|
||||
|
@ -51,12 +56,12 @@ def load_preset_memoized(name):
|
|||
def load_preset_for_ui(name, state):
|
||||
generate_params = load_preset(name)
|
||||
state.update(generate_params)
|
||||
return state, *[generate_params[k] for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'repetition_penalty', 'repetition_penalty_range', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'tfs', 'top_a']]
|
||||
return state, *[generate_params[k] for k in presets_params()]
|
||||
|
||||
|
||||
def generate_preset_yaml(state):
|
||||
defaults = default_preset()
|
||||
data = {k: state[k] for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'repetition_penalty', 'repetition_penalty_range', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'tfs', 'top_a']}
|
||||
data = {k: state[k] for k in presets_params()}
|
||||
|
||||
# Remove entries that are identical to the defaults
|
||||
for k in list(data.keys()):
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue