Add Classifier Free Guidance (CFG) for Transformers/ExLlama (#3325)

This commit is contained in:
oobabooga 2023-08-06 17:22:48 -03:00 committed by GitHub
parent 5134878344
commit 0af10ab49b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
17 changed files with 131 additions and 42 deletions

View file

@@ -9,6 +9,7 @@ def default_preset():
'do_sample': True,
'temperature': 1,
'top_p': 1,
'top_k': 0,
'typical_p': 1,
'epsilon_cutoff': 0,
'eta_cutoff': 0,
@@ -17,19 +18,23 @@ def default_preset():
'repetition_penalty': 1,
'repetition_penalty_range': 0,
'encoder_repetition_penalty': 1,
'top_k': 0,
'num_beams': 1,
'penalty_alpha': 0,
'min_length': 0,
'length_penalty': 1,
'no_repeat_ngram_size': 0,
'early_stopping': False,
'min_length': 0,
'guidance_scale': 1,
'mirostat_mode': 0,
'mirostat_tau': 5.0,
'mirostat_eta': 0.1,
'penalty_alpha': 0,
'num_beams': 1,
'length_penalty': 1,
'early_stopping': False,
}
def presets_params():
    """Return the names of every generation parameter, in the order
    they appear in the default preset."""
    return list(default_preset())
def load_preset(name):
generate_params = default_preset()
if name not in ['None', None, '']:
@@ -51,12 +56,12 @@ def load_preset_memoized(name):
def load_preset_for_ui(name, state):
    """Load the preset called `name`, merge its parameters into `state`,
    and return the updated state followed by each parameter value in
    presets_params() order (used to refresh the individual UI widgets).
    """
    generate_params = load_preset(name)
    state.update(generate_params)
    # Use presets_params() as the single source of truth for the parameter
    # list. The previous hard-coded key list duplicated here left two
    # consecutive return statements, making the second one unreachable and
    # the list prone to drifting out of sync with default_preset().
    return state, *[generate_params[k] for k in presets_params()]
def generate_preset_yaml(state):
defaults = default_preset()
data = {k: state[k] for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'repetition_penalty', 'repetition_penalty_range', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'tfs', 'top_a']}
data = {k: state[k] for k in presets_params()}
# Remove entries that are identical to the defaults
for k in list(data.keys()):