Add support for custom chat styles (#1917)

This commit is contained in:
oobabooga 2023-05-08 12:35:03 -03:00 committed by GitHub
parent b040b4110d
commit b5260b24f1
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
15 changed files with 234 additions and 74 deletions

View file

@ -12,6 +12,8 @@ from pathlib import Path
import markdown
from PIL import Image, ImageOps
from modules.utils import get_available_chat_styles
# This is to store the paths to the thumbnails of the profile pictures
image_cache = {}
@ -19,13 +21,14 @@ with open(Path(__file__).resolve().parent / '../css/html_readable_style.css', 'r
readable_css = f.read()
with open(Path(__file__).resolve().parent / '../css/html_4chan_style.css', 'r') as css_f:
_4chan_css = css_f.read()
with open(Path(__file__).resolve().parent / '../css/html_cai_style.css', 'r') as f:
cai_css = f.read()
with open(Path(__file__).resolve().parent / '../css/html_bubble_chat_style.css', 'r') as f:
bubble_chat_css = f.read()
with open(Path(__file__).resolve().parent / '../css/html_instruct_style.css', 'r') as f:
instruct_css = f.read()
# Custom chat styles
chat_styles = {}
for k in get_available_chat_styles():
chat_styles[k] = open(Path(f'css/chat_style-{k}.css'), 'r').read()
def fix_newlines(string):
string = string.replace('\n', '\n\n')
@ -185,8 +188,8 @@ def generate_instruct_html(history):
return output
def generate_cai_chat_html(history, name1, name2, reset_cache=False):
output = f'<style>{cai_css}</style><div class="chat" id="chat">'
def generate_cai_chat_html(history, name1, name2, style, reset_cache=False):
output = f'<style>{chat_styles[style]}</style><div class="chat" id="chat">'
# We use ?name2 and ?time.time() to force the browser to reset caches
img_bot = f'<img src="file/cache/pfp_character.png?{name2}">' if Path("cache/pfp_character.png").exists() else ''
@ -235,7 +238,7 @@ def generate_cai_chat_html(history, name1, name2, reset_cache=False):
def generate_chat_html(history, name1, name2, reset_cache=False):
output = f'<style>{bubble_chat_css}</style><div class="chat" id="chat">'
output = f'<style>{chat_styles["wpp"]}</style><div class="chat" id="chat">'
for i, _row in enumerate(history[::-1]):
row = [convert_to_markdown(entry) for entry in _row]
@ -267,12 +270,10 @@ def generate_chat_html(history, name1, name2, reset_cache=False):
return output
def chat_html_wrapper(history, name1, name2, mode, reset_cache=False):
if mode == "cai-chat":
return generate_cai_chat_html(history, name1, name2, reset_cache)
elif mode == "chat":
return generate_chat_html(history, name1, name2)
elif mode == "instruct":
def chat_html_wrapper(history, name1, name2, mode, style, reset_cache=False):
if mode == 'instruct':
return generate_instruct_html(history)
elif style == 'wpp':
return generate_chat_html(history, name1, name2)
else:
return ''
return generate_cai_chat_html(history, name1, name2, style, reset_cache)