Allow API requests to use parameter presets

This commit is contained in:
oobabooga 2023-06-13 20:34:35 -03:00
parent 8936160e54
commit 474dc7355a
8 changed files with 96 additions and 58 deletions

View file

@ -33,7 +33,6 @@ import re
import sys
import time
import traceback
from datetime import datetime
from functools import partial
from pathlib import Path
from threading import Lock
@ -44,7 +43,7 @@ import yaml
from PIL import Image
import modules.extensions as extensions_module
from modules import chat, shared, training, ui, utils
from modules import chat, presets, shared, training, ui, utils
from modules.extensions import apply_extensions
from modules.github import clone_or_pull_repository
from modules.html_generator import chat_html_wrapper
@ -80,53 +79,6 @@ def load_lora_wrapper(selected_loras):
yield ("Successfuly applied the LoRAs")
def load_preset_values(preset_menu, state, return_dict=False):
generate_params = {
'do_sample': True,
'temperature': 1,
'top_p': 1,
'typical_p': 1,
'epsilon_cutoff': 0,
'eta_cutoff': 0,
'tfs': 1,
'top_a': 0,
'repetition_penalty': 1,
'encoder_repetition_penalty': 1,
'top_k': 0,
'num_beams': 1,
'penalty_alpha': 0,
'min_length': 0,
'length_penalty': 1,
'no_repeat_ngram_size': 0,
'early_stopping': False,
'mirostat_mode': 0,
'mirostat_tau': 5.0,
'mirostat_eta': 0.1,
}
with open(Path(f'presets/{preset_menu}.yaml'), 'r') as infile:
preset = yaml.safe_load(infile)
for k in preset:
generate_params[k] = preset[k]
generate_params['temperature'] = min(1.99, generate_params['temperature'])
if return_dict:
return generate_params
else:
state.update(generate_params)
return state, *[generate_params[k] for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'tfs', 'top_a']]
def generate_preset_yaml(state):
data = {k: state[k] for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'tfs', 'top_a']}
return yaml.dump(data, sort_keys=False)
def current_time():
return f"{datetime.now().strftime('%Y-%m-%d-%H%M%S')}"
def load_prompt(fname):
if fname in ['None', '']:
return ''
@ -251,7 +203,7 @@ def get_model_specific_settings(model):
return model_settings
def load_model_specific_settings(model, state, return_dict=False):
def load_model_specific_settings(model, state):
model_settings = get_model_specific_settings(model)
for k in model_settings:
if k in state:
@ -448,7 +400,7 @@ def create_chat_settings_menus():
def create_settings_menus(default_preset):
generate_params = load_preset_values(default_preset if not shared.args.flexgen else 'Naive', {}, return_dict=True)
generate_params = presets.load_preset(default_preset)
with gr.Row():
with gr.Column():
with gr.Row():
@ -515,7 +467,7 @@ def create_settings_menus(default_preset):
shared.gradio['skip_special_tokens'] = gr.Checkbox(value=shared.settings['skip_special_tokens'], label='Skip special tokens', info='Some specific models need this unset.')
shared.gradio['stream'] = gr.Checkbox(value=not shared.args.no_stream, label='Activate text streaming')
shared.gradio['preset_menu'].change(load_preset_values, [shared.gradio[k] for k in ['preset_menu', 'interface_state']], [shared.gradio[k] for k in ['interface_state', 'do_sample', 'temperature', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'tfs', 'top_a']])
shared.gradio['preset_menu'].change(presets.load_preset_for_ui, [shared.gradio[k] for k in ['preset_menu', 'interface_state']], [shared.gradio[k] for k in ['interface_state', 'do_sample', 'temperature', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'tfs', 'top_a']])
def create_file_saving_menus():
@ -578,7 +530,7 @@ def create_file_saving_event_handlers():
shared.gradio['save_preset'].click(
ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
generate_preset_yaml, shared.gradio['interface_state'], shared.gradio['save_contents']).then(
presets.generate_preset_yaml, shared.gradio['interface_state'], shared.gradio['save_contents']).then(
lambda: 'presets/', None, shared.gradio['save_root']).then(
lambda: 'My Preset.yaml', None, shared.gradio['save_filename']).then(
lambda: gr.update(visible=True), None, shared.gradio['file_saver'])
@ -1043,7 +995,7 @@ def create_interface():
shared.gradio['save_prompt'].click(
lambda x: x, shared.gradio['textbox'], shared.gradio['save_contents']).then(
lambda: 'prompts/', None, shared.gradio['save_root']).then(
lambda: current_time() + '.txt', None, shared.gradio['save_filename']).then(
lambda: utils.current_time() + '.txt', None, shared.gradio['save_filename']).then(
lambda: gr.update(visible=True), None, shared.gradio['file_saver'])
shared.gradio['delete_prompt'].click(