From ac6065d5edf2a82035a037ecdf68fba16abd682f Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Thu, 26 Jan 2023 13:45:19 -0300 Subject: [PATCH] Fix character loading bug --- server.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/server.py b/server.py index 5f80ae4..66bd18b 100644 --- a/server.py +++ b/server.py @@ -183,9 +183,9 @@ def generate_reply(question, tokens, inference_settings, selected_model, eos_tok cuda = "" if args.cpu else ".cuda()" n = tokenizer.eos_token_id if eos_token is None else tokenizer.encode(eos_token, return_tensors='pt')[0][-1] input_ids = encode(question, tokens) - # The stopping_criteria code below was copied from - # https://github.com/PygmalionAI/gradio-ui/blob/master/src/model.py if stopping_string is not None: + # The stopping_criteria code below was copied from + # https://github.com/PygmalionAI/gradio-ui/blob/master/src/model.py t = encode(stopping_string, 0, add_special_tokens=False) stopping_criteria_list = transformers.StoppingCriteriaList([ _SentinelTokenStoppingCriteria( @@ -382,16 +382,19 @@ if args.chat or args.cai_chat: return generate_chat_html(_history, name1, name2, character) def tokenize_dialogue(dialogue, name1, name2): + history = [] + dialogue = re.sub('', '', dialogue) dialogue = re.sub('(\n|^)[Aa]non:', '\\1You:', dialogue) - idx = [m.start() for m in re.finditer(f"(^|\n)({name1}|{name2}):", dialogue)] + if len(idx) == 0: + return history + messages = [] for i in range(len(idx)-1): messages.append(dialogue[idx[i]:idx[i+1]].strip()) messages.append(dialogue[idx[-1]:].strip()) - history = [] entry = ['', ''] for i in messages: if i.startswith(f'{name1}:'):