Add a "Filter by loader" menu to the Parameters tab

This commit is contained in:
oobabooga 2023-07-31 18:44:00 -07:00
parent abea8d9ad3
commit 84297d05c4
2 changed files with 174 additions and 3 deletions

View file

@ -89,6 +89,175 @@ loaders_and_params = {
]
}
loaders_samplers = {
'Transformers': {
'temperature',
'top_p',
'top_k',
'typical_p',
'epsilon_cutoff',
'eta_cutoff',
'tfs',
'top_a',
'repetition_penalty',
'repetition_penalty_range',
'encoder_repetition_penalty',
'no_repeat_ngram_size',
'min_length',
'seed',
'do_sample',
'penalty_alpha',
'num_beams',
'length_penalty',
'early_stopping',
'mirostat_mode',
'mirostat_tau',
'mirostat_eta',
'ban_eos_token',
'add_bos_token',
'skip_special_tokens',
},
'ExLlama_HF': {
'temperature',
'top_p',
'top_k',
'typical_p',
'epsilon_cutoff',
'eta_cutoff',
'tfs',
'top_a',
'repetition_penalty',
'repetition_penalty_range',
'encoder_repetition_penalty',
'no_repeat_ngram_size',
'min_length',
'seed',
'do_sample',
'mirostat_mode',
'mirostat_tau',
'mirostat_eta',
'ban_eos_token',
'add_bos_token',
'skip_special_tokens',
},
'ExLlama': {
'temperature',
'top_p',
'top_k',
'typical_p',
'repetition_penalty',
'repetition_penalty_range',
'seed',
'ban_eos_token',
},
'AutoGPTQ': {
'temperature',
'top_p',
'top_k',
'typical_p',
'epsilon_cutoff',
'eta_cutoff',
'tfs',
'top_a',
'repetition_penalty',
'repetition_penalty_range',
'encoder_repetition_penalty',
'no_repeat_ngram_size',
'min_length',
'seed',
'do_sample',
'penalty_alpha',
'num_beams',
'length_penalty',
'early_stopping',
'mirostat_mode',
'mirostat_tau',
'mirostat_eta',
'ban_eos_token',
'add_bos_token',
'skip_special_tokens',
},
'GPTQ-for-LLaMa': {
'temperature',
'top_p',
'top_k',
'typical_p',
'epsilon_cutoff',
'eta_cutoff',
'tfs',
'top_a',
'repetition_penalty',
'repetition_penalty_range',
'encoder_repetition_penalty',
'no_repeat_ngram_size',
'min_length',
'seed',
'do_sample',
'penalty_alpha',
'num_beams',
'length_penalty',
'early_stopping',
'mirostat_mode',
'mirostat_tau',
'mirostat_eta',
'ban_eos_token',
'add_bos_token',
'skip_special_tokens',
},
'llama.cpp': {
'temperature',
'top_p',
'top_k',
'tfs',
'repetition_penalty',
'mirostat_mode',
'mirostat_tau',
'mirostat_eta',
'ban_eos_token',
},
'llamacpp_HF': {
'temperature',
'top_p',
'top_k',
'typical_p',
'epsilon_cutoff',
'eta_cutoff',
'tfs',
'top_a',
'repetition_penalty',
'repetition_penalty_range',
'encoder_repetition_penalty',
'no_repeat_ngram_size',
'min_length',
'seed',
'do_sample',
'mirostat_mode',
'mirostat_tau',
'mirostat_eta',
'ban_eos_token',
'add_bos_token',
'skip_special_tokens',
},
}
@functools.cache
def list_all_samplers():
all_samplers = set()
for k in loaders_samplers:
for sampler in loaders_samplers[k]:
all_samplers.add(sampler)
return sorted(all_samplers)
def blacklist_samplers(loader):
all_samplers = list_all_samplers()
if loader == 'All':
return [gr.update(visible=True) for sampler in all_samplers]
else:
return [gr.update(visible=True) if sampler in loaders_samplers[loader] else gr.update(visible=False) for sampler in all_samplers]
def get_gpu_memory_keys():
return [k for k in shared.gradio if k.startswith('gpu_memory')]