New feature: "random preset" button (#4647)
This commit is contained in:
parent
d1a58da52f
commit
83b64e7fc1
3 changed files with 51 additions and 2 deletions
|
@ -1,8 +1,12 @@
|
|||
import functools
|
||||
import random
|
||||
from pathlib import Path
|
||||
|
||||
import yaml
|
||||
|
||||
from modules import shared
|
||||
from modules.loaders import loaders_samplers
|
||||
|
||||
|
||||
def default_preset():
|
||||
return {
|
||||
|
@ -63,6 +67,45 @@ def load_preset_for_ui(name, state):
|
|||
return state, *[generate_params[k] for k in presets_params()]
|
||||
|
||||
|
||||
def random_preset(state):
|
||||
params_and_values = {
|
||||
'remove_tail_tokens': {
|
||||
'top_p': [0.5, 0.8, 0.9, 0.95, 0.99],
|
||||
'min_p': [0.5, 0.2, 0.1, 0.05, 0.01],
|
||||
'top_k': [3, 5, 10, 20, 30, 40],
|
||||
'typical_p': [0.2, 0.575, 0.95],
|
||||
'tfs': [0.5, 0.8, 0.9, 0.95, 0.99],
|
||||
'top_a': [0.5, 0.2, 0.1, 0.05, 0.01],
|
||||
'epsilon_cutoff': [1, 3, 5, 7, 9],
|
||||
'eta_cutoff': [3, 6, 9, 12, 15, 18],
|
||||
},
|
||||
'flatten_distribution': {
|
||||
'temperature': [0.5, 0.7, 0.8, 1, 1.2, 1.5, 2.0],
|
||||
},
|
||||
'repetition': {
|
||||
'repetition_penalty': [1, 1.05, 1.1, 1.15, 1.20, 1.25],
|
||||
'presence_penalty': [0, 0.1, 0.2, 0.4, 0.6, 0.8, 1.0, 2.0],
|
||||
'frequency_penalty': [0, 0.1, 0.2, 0.4, 0.6, 0.8, 1.0, 2.0],
|
||||
},
|
||||
'other': {
|
||||
'temperature_last': [True, False],
|
||||
}
|
||||
}
|
||||
|
||||
generate_params = default_preset()
|
||||
for cat in params_and_values:
|
||||
choices = list(params_and_values[cat].keys())
|
||||
if shared.args.loader is not None:
|
||||
choices = [x for x in choices if x in loaders_samplers[shared.args.loader]]
|
||||
|
||||
if len(choices) > 0:
|
||||
choice = random.choice(choices)
|
||||
generate_params[choice] = random.choice(params_and_values[cat][choice])
|
||||
|
||||
state.update(generate_params)
|
||||
return state, *[generate_params[k] for k in presets_params()]
|
||||
|
||||
|
||||
def generate_preset_yaml(state):
|
||||
defaults = default_preset()
|
||||
data = {k: state[k] for k in presets_params()}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue