Multiple histories for each character (#4022)

This commit is contained in:
oobabooga 2023-09-21 17:19:32 -03:00 committed by GitHub
parent 029da9563f
commit 00ab450c13
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 240 additions and 197 deletions

View file

@ -4,6 +4,7 @@ import functools
import html
import json
import re
from datetime import datetime
from pathlib import Path
import gradio as gr
@ -297,8 +298,25 @@ def generate_chat_reply(text, state, regenerate=False, _continue=False, loading_
yield history
# Same as above but returns HTML for the UI
def character_is_loaded(state, raise_exception=False):
if state['mode'] in ['chat', 'chat-instruct'] and state['name2'] == '':
logger.error('It looks like no character is loaded. Please load one under Parameters > Character.')
if raise_exception:
raise ValueError
return False
else:
return True
def generate_chat_reply_wrapper(text, state, regenerate=False, _continue=False):
'''
Same as above but returns HTML for the UI
'''
if not character_is_loaded(state):
return
if state['start_with'] != '' and not _continue:
if regenerate:
text, state['history'] = remove_last_message(state['history'])
@ -359,86 +377,132 @@ def send_dummy_reply(text, state):
return history
def clear_chat_log(state):
greeting = replace_character_names(state['greeting'], state['name1'], state['name2'])
mode = state['mode']
history = state['history']
history['visible'] = []
history['internal'] = []
if mode != 'instruct':
if greeting != '':
history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', greeting]]
history['visible'] += [['', apply_extensions('output', greeting, state, is_chat=True)]]
return history
def redraw_html(history, name1, name2, mode, style, reset_cache=False):
return chat_html_wrapper(history, name1, name2, mode, style, reset_cache=reset_cache)
def save_history(history, path=None):
p = path or Path('logs/exported_history.json')
def start_new_chat(state):
mode = state['mode']
history = {'internal': [], 'visible': []}
if mode != 'instruct':
greeting = replace_character_names(state['greeting'], state['name1'], state['name2'])
if greeting != '':
history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', greeting]]
history['visible'] += [['', apply_extensions('output', greeting, state, is_chat=True)]]
unique_id = datetime.now().strftime('%Y%m%d-%H-%M-%S')
save_history(history, unique_id, state['character_menu'], state['mode'])
return history
def get_history_file_path(unique_id, character, mode):
if mode == 'instruct':
p = Path(f'logs/instruct/{unique_id}.json')
else:
p = Path(f'logs/chat/{character}/{unique_id}.json')
return p
def save_history(history, unique_id, character, mode):
if shared.args.multi_user:
return
p = get_history_file_path(unique_id, character, mode)
if not p.parent.is_dir():
p.parent.mkdir(parents=True)
with open(p, 'w', encoding='utf-8') as f:
f.write(json.dumps(history, indent=4))
return p
def find_all_histories(state):
if shared.args.multi_user:
return ['']
if state['mode'] == 'instruct':
paths = Path('logs/instruct').glob('*.json')
else:
character = state['character_menu']
# Handle obsolete filenames and paths
old_p = Path(f'logs/{character}_persistent.json')
new_p = Path(f'logs/persistent_{character}.json')
if old_p.exists():
logger.warning(f"Renaming {old_p} to {new_p}")
old_p.rename(new_p)
if new_p.exists():
unique_id = datetime.now().strftime('%Y%m%d-%H-%M-%S')
p = get_history_file_path(unique_id, character, state['mode'])
logger.warning(f"Moving {new_p} to {p}")
p.parent.mkdir(exist_ok=True)
new_p.rename(p)
paths = Path(f'logs/chat/{character}').glob('*.json')
histories = sorted(paths, key=lambda x: x.stat().st_mtime, reverse=True)
histories = [path.stem for path in histories]
return histories
def load_history(file, history):
def load_latest_history(state):
'''
Loads the latest history for the given character in chat or chat-instruct
mode, or the latest instruct history for instruct mode.
'''
if shared.args.multi_user:
return start_new_chat(state)
histories = find_all_histories(state)
if len(histories) > 0:
unique_id = Path(histories[0]).stem
history = load_history(unique_id, state['character_menu'], state['mode'])
else:
history = start_new_chat(state)
return history
def load_history(unique_id, character, mode):
p = get_history_file_path(unique_id, character, mode)
f = json.loads(open(p, 'rb').read())
if 'internal' in f and 'visible' in f:
history = f
else:
history = {
'internal': f['data'],
'visible': f['data_visible']
}
return history
def load_history_json(file, history):
try:
file = file.decode('utf-8')
j = json.loads(file)
if 'internal' in j and 'visible' in j:
return j
f = json.loads(file)
if 'internal' in f and 'visible' in f:
history = f
else:
return history
history = {
'internal': f['data'],
'visible': f['data_visible']
}
return history
except:
return history
def save_persistent_history(history, character, mode):
if mode in ['chat', 'chat-instruct'] and character not in ['', 'None', None] and not shared.args.multi_user:
save_history(history, path=Path(f'logs/persistent_{character}.json'))
def load_persistent_history(state):
if shared.session_is_loading:
shared.session_is_loading = False
return state['history']
if state['mode'] == 'instruct':
return state['history']
character = state['character_menu']
greeting = replace_character_names(state['greeting'], state['name1'], state['name2'])
should_load_history = (not shared.args.multi_user and character not in ['None', '', None])
old_p = Path(f'logs/{character}_persistent.json')
p = Path(f'logs/persistent_{character}.json')
if should_load_history and old_p.exists():
logger.warning(f"Renaming {old_p} to {p}")
old_p.rename(p)
if should_load_history and p.exists():
f = json.loads(open(p, 'rb').read())
if 'internal' in f and 'visible' in f:
history = f
else:
history = {'internal': [], 'visible': []}
history['internal'] = f['data']
history['visible'] = f['data_visible']
else:
history = {'internal': [], 'visible': []}
if greeting != "":
history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', greeting]]
history['visible'] += [['', apply_extensions('output', greeting, state, is_chat=True)]]
return history
def delete_history(unique_id, character, mode):
p = get_history_file_path(unique_id, character, mode)
delete_file(p)
def replace_character_names(text, name1, name2):
@ -465,61 +529,55 @@ def load_character(character, name1, name2, instruct=False):
greeting_field = 'greeting'
picture = None
# Delete the profile picture cache, if any
if Path("cache/pfp_character.png").exists() and not instruct:
Path("cache/pfp_character.png").unlink()
if instruct:
name1 = name2 = ''
folder = 'instruction-templates'
else:
folder = 'characters'
if character not in ['None', '', None]:
picture = generate_pfp_cache(character)
filepath = None
for extension in ["yml", "yaml", "json"]:
filepath = Path(f'{folder}/{character}.{extension}')
if filepath.exists():
break
filepath = None
for extension in ["yml", "yaml", "json"]:
filepath = Path(f'{folder}/{character}.{extension}')
if filepath.exists():
break
if filepath is None:
logger.error(f"Could not find character file for {character} in {folder} folder. Please check your spelling.")
return name1, name2, picture, greeting, context, turn_template.replace("\n", r"\n")
if filepath is None or not filepath.exists():
logger.error(f"Could not find the character \"{character}\" inside {folder}/. No character has been loaded.")
raise ValueError
file_contents = open(filepath, 'r', encoding='utf-8').read()
data = json.loads(file_contents) if extension == "json" else yaml.safe_load(file_contents)
file_contents = open(filepath, 'r', encoding='utf-8').read()
data = json.loads(file_contents) if extension == "json" else yaml.safe_load(file_contents)
# Finding the bot's name
for k in ['name', 'bot', '<|bot|>', 'char_name']:
if k in data and data[k] != '':
name2 = data[k]
break
if Path("cache/pfp_character.png").exists() and not instruct:
Path("cache/pfp_character.png").unlink()
# Find the user name (if any)
for k in ['your_name', 'user', '<|user|>']:
if k in data and data[k] != '':
name1 = data[k]
break
picture = generate_pfp_cache(character)
if 'context' in data:
context = data['context']
if not instruct:
context = context.strip() + '\n'
elif "char_persona" in data:
context = build_pygmalion_style_context(data)
greeting_field = 'char_greeting'
# Finding the bot's name
for k in ['name', 'bot', '<|bot|>', 'char_name']:
if k in data and data[k] != '':
name2 = data[k]
break
if greeting_field in data:
greeting = data[greeting_field]
# Find the user name (if any)
for k in ['your_name', 'user', '<|user|>']:
if k in data and data[k] != '':
name1 = data[k]
break
if 'turn_template' in data:
turn_template = data['turn_template']
if 'context' in data:
context = data['context']
if not instruct:
context = context.strip() + '\n'
elif "char_persona" in data:
context = build_pygmalion_style_context(data)
greeting_field = 'char_greeting'
else:
context = shared.settings['context']
name2 = shared.settings['name2']
greeting = shared.settings['greeting']
if greeting_field in data:
greeting = data[greeting_field]
if 'turn_template' in data:
turn_template = data['turn_template']
return name1, name2, picture, greeting, context, turn_template.replace("\n", r"\n")