Add bot prefix modifier option in extensions

This commit is contained in:
oobabooga 2023-01-29 10:11:59 -03:00
parent c1c129196e
commit e5ff4ddfc8
3 changed files with 25 additions and 5 deletions

View file

@ -235,10 +235,12 @@ def apply_extensions(text, typ):
for ext in sorted(extension_state, key=lambda x : extension_state[x][1]):
if extension_state[ext][0] == True:
ext_string = f"extensions.{ext}.script"
if typ == "input":
if typ == "input" and hasattr(eval(ext_string), "input_modifier"):
text = eval(f"{ext_string}.input_modifier(text)")
else:
elif typ == "output" and hasattr(eval(ext_string), "output_modifier"):
text = eval(f"{ext_string}.output_modifier(text)")
elif typ == "bot_prefix" and hasattr(eval(ext_string), "bot_prefix_modifier"):
text = eval(f"{ext_string}.bot_prefix_modifier(text)")
return text
def update_extensions_parameters(*kwargs):
@ -274,7 +276,6 @@ def create_extensions_block():
btn_extensions.click(update_extensions_parameters, [*extensions_ui_elements], [])
return extensions_ui_elements, btn_extensions
def get_available_models():
return sorted(set([item.replace('.pt', '') for item in map(lambda x : str(x.name), list(Path('models/').glob('*'))+list(Path('torch-dumps/').glob('*'))) if not item.endswith('.txt')]), key=str.lower)
@ -353,7 +354,7 @@ if args.chat or args.cai_chat:
if history_size != 0 and count >= history_size:
break
rows.append(f"{name1}: {text}\n")
rows.append(f"{name2}:")
rows.append(apply_extensions(f"{name2}:", "bot_prefix"))
while len(rows) > 3 and len(encode(''.join(rows), tokens)[0]) >= 2048-tokens:
rows.pop(1)
@ -376,7 +377,7 @@ if args.chat or args.cai_chat:
idx = [m.start() for m in re.finditer(f"(^|\n){name2}:", reply)]
idx = idx[len(previous_idx)-1]
reply = reply[idx + len(f"\n{name2}:"):]
reply = reply[idx + 1 + len(apply_extensions(f"{name2}:", "bot_prefix")):]
if check:
reply = reply.split('\n')[0].strip()
else: