Remove duplicate code

This commit is contained in:
oobabooga 2023-05-10 01:34:04 -03:00
parent ba445cf59f
commit bdf1274b5d
34 changed files with 32 additions and 180 deletions

View file

@ -131,6 +131,23 @@ def save_prompt(text):
def load_prompt(fname):
if fname in ['None', '']:
return ''
elif fname.startswith('Instruct-'):
fname = re.sub('^Instruct-', '', fname)
with open(Path(f'characters/instruction-following/{fname}.yaml'), 'r', encoding='utf-8') as f:
data = yaml.safe_load(f)
output = ''
if 'context' in data:
output += data['context']
replacements = {
'<|user|>': data['user'],
'<|bot|>': data['bot'],
'<|user-message|>': 'Input',
}
output += utils.replace_all(data['turn_template'].split('<|bot-message|>')[0], replacements)
return output
else:
with open(Path(f'prompts/{fname}.txt'), 'r', encoding='utf-8') as f:
text = f.read()
@ -472,7 +489,7 @@ def create_interface():
gen_events = []
default_preset = shared.settings['presets'][next((k for k in shared.settings['presets'] if re.match(k.lower(), shared.model_name.lower())), 'default')]
if len(shared.lora_names) == 1:
default_text = load_prompt(shared.settings['lora_prompts'][next((k for k in shared.settings['lora_prompts'] if re.match(k.lower(), shared.lora_names[0].lower())), 'default')])
default_text = load_prompt(shared.settings['prompts'][next((k for k in shared.settings['prompts'] if re.match(k.lower(), shared.lora_names[0].lower())), 'default')])
else:
default_text = load_prompt(shared.settings['prompts'][next((k for k in shared.settings['prompts'] if re.match(k.lower(), shared.model_name.lower())), 'default')])
title = 'Text generation web UI'