Allow API requests to use parameter presets
This commit is contained in:
parent
8936160e54
commit
474dc7355a
8 changed files with 96 additions and 58 deletions
60
server.py
60
server.py
|
@ -33,7 +33,6 @@ import re
|
|||
import sys
|
||||
import time
|
||||
import traceback
|
||||
from datetime import datetime
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from threading import Lock
|
||||
|
@ -44,7 +43,7 @@ import yaml
|
|||
from PIL import Image
|
||||
|
||||
import modules.extensions as extensions_module
|
||||
from modules import chat, shared, training, ui, utils
|
||||
from modules import chat, presets, shared, training, ui, utils
|
||||
from modules.extensions import apply_extensions
|
||||
from modules.github import clone_or_pull_repository
|
||||
from modules.html_generator import chat_html_wrapper
|
||||
|
@ -80,53 +79,6 @@ def load_lora_wrapper(selected_loras):
|
|||
yield ("Successfuly applied the LoRAs")
|
||||
|
||||
|
||||
def load_preset_values(preset_menu, state, return_dict=False):
|
||||
generate_params = {
|
||||
'do_sample': True,
|
||||
'temperature': 1,
|
||||
'top_p': 1,
|
||||
'typical_p': 1,
|
||||
'epsilon_cutoff': 0,
|
||||
'eta_cutoff': 0,
|
||||
'tfs': 1,
|
||||
'top_a': 0,
|
||||
'repetition_penalty': 1,
|
||||
'encoder_repetition_penalty': 1,
|
||||
'top_k': 0,
|
||||
'num_beams': 1,
|
||||
'penalty_alpha': 0,
|
||||
'min_length': 0,
|
||||
'length_penalty': 1,
|
||||
'no_repeat_ngram_size': 0,
|
||||
'early_stopping': False,
|
||||
'mirostat_mode': 0,
|
||||
'mirostat_tau': 5.0,
|
||||
'mirostat_eta': 0.1,
|
||||
}
|
||||
|
||||
with open(Path(f'presets/{preset_menu}.yaml'), 'r') as infile:
|
||||
preset = yaml.safe_load(infile)
|
||||
|
||||
for k in preset:
|
||||
generate_params[k] = preset[k]
|
||||
|
||||
generate_params['temperature'] = min(1.99, generate_params['temperature'])
|
||||
if return_dict:
|
||||
return generate_params
|
||||
else:
|
||||
state.update(generate_params)
|
||||
return state, *[generate_params[k] for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'tfs', 'top_a']]
|
||||
|
||||
|
||||
def generate_preset_yaml(state):
|
||||
data = {k: state[k] for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'tfs', 'top_a']}
|
||||
return yaml.dump(data, sort_keys=False)
|
||||
|
||||
|
||||
def current_time():
|
||||
return f"{datetime.now().strftime('%Y-%m-%d-%H%M%S')}"
|
||||
|
||||
|
||||
def load_prompt(fname):
|
||||
if fname in ['None', '']:
|
||||
return ''
|
||||
|
@ -251,7 +203,7 @@ def get_model_specific_settings(model):
|
|||
return model_settings
|
||||
|
||||
|
||||
def load_model_specific_settings(model, state, return_dict=False):
|
||||
def load_model_specific_settings(model, state):
|
||||
model_settings = get_model_specific_settings(model)
|
||||
for k in model_settings:
|
||||
if k in state:
|
||||
|
@ -448,7 +400,7 @@ def create_chat_settings_menus():
|
|||
|
||||
|
||||
def create_settings_menus(default_preset):
|
||||
generate_params = load_preset_values(default_preset if not shared.args.flexgen else 'Naive', {}, return_dict=True)
|
||||
generate_params = presets.load_preset(default_preset)
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
with gr.Row():
|
||||
|
@ -515,7 +467,7 @@ def create_settings_menus(default_preset):
|
|||
shared.gradio['skip_special_tokens'] = gr.Checkbox(value=shared.settings['skip_special_tokens'], label='Skip special tokens', info='Some specific models need this unset.')
|
||||
shared.gradio['stream'] = gr.Checkbox(value=not shared.args.no_stream, label='Activate text streaming')
|
||||
|
||||
shared.gradio['preset_menu'].change(load_preset_values, [shared.gradio[k] for k in ['preset_menu', 'interface_state']], [shared.gradio[k] for k in ['interface_state', 'do_sample', 'temperature', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'tfs', 'top_a']])
|
||||
shared.gradio['preset_menu'].change(presets.load_preset_for_ui, [shared.gradio[k] for k in ['preset_menu', 'interface_state']], [shared.gradio[k] for k in ['interface_state', 'do_sample', 'temperature', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'tfs', 'top_a']])
|
||||
|
||||
|
||||
def create_file_saving_menus():
|
||||
|
@ -578,7 +530,7 @@ def create_file_saving_event_handlers():
|
|||
|
||||
shared.gradio['save_preset'].click(
|
||||
ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
|
||||
generate_preset_yaml, shared.gradio['interface_state'], shared.gradio['save_contents']).then(
|
||||
presets.generate_preset_yaml, shared.gradio['interface_state'], shared.gradio['save_contents']).then(
|
||||
lambda: 'presets/', None, shared.gradio['save_root']).then(
|
||||
lambda: 'My Preset.yaml', None, shared.gradio['save_filename']).then(
|
||||
lambda: gr.update(visible=True), None, shared.gradio['file_saver'])
|
||||
|
@ -1043,7 +995,7 @@ def create_interface():
|
|||
shared.gradio['save_prompt'].click(
|
||||
lambda x: x, shared.gradio['textbox'], shared.gradio['save_contents']).then(
|
||||
lambda: 'prompts/', None, shared.gradio['save_root']).then(
|
||||
lambda: current_time() + '.txt', None, shared.gradio['save_filename']).then(
|
||||
lambda: utils.current_time() + '.txt', None, shared.gradio['save_filename']).then(
|
||||
lambda: gr.update(visible=True), None, shared.gradio['file_saver'])
|
||||
|
||||
shared.gradio['delete_prompt'].click(
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue