From ca13acdfa001862afd98d07958b0d8896dff9738 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Tue, 17 Jan 2023 20:16:23 -0300
Subject: [PATCH] Ensure that the chat prompt will always contain < 2048 tokens

This way, we can keep the context string at the top of the prompt even if
you keep talking to the bot for hours.

Before this commit, the prompt was simply truncated, and the context
string would eventually be lost.
---
 server.py | 41 ++++++++++++++++++++++++-----------------
 1 file changed, 24 insertions(+), 17 deletions(-)

diff --git a/server.py b/server.py
index 2b76510..7bec72c 100644
--- a/server.py
+++ b/server.py
@@ -116,6 +116,14 @@ def fix_galactica(s):
     s = s.replace(r'$$', r'$')
     return s
 
+def encode(prompt, tokens):
+    if not args.cpu:
+        torch.cuda.empty_cache()
+        input_ids = tokenizer.encode(str(prompt), return_tensors='pt', truncation=True, max_length=2048-tokens).cuda()
+    else:
+        input_ids = tokenizer.encode(str(prompt), return_tensors='pt', truncation=True, max_length=2048-tokens)
+    return input_ids
+
 def generate_reply(question, tokens, inference_settings, selected_model, eos_token=None):
     global model, tokenizer, model_name, loaded_preset, preset
 
@@ -131,14 +139,9 @@ def generate_reply(question, tokens, inference_settings, selected_model, eos_tok
             preset = infile.read()
         loaded_preset = inference_settings
 
-    if not args.cpu:
-        torch.cuda.empty_cache()
-        input_ids = tokenizer.encode(str(question), return_tensors='pt', truncation=True, max_length=2048-tokens).cuda()
-        cuda = ".cuda()"
-    else:
-        input_ids = tokenizer.encode(str(question), return_tensors='pt', truncation=True, max_length=2048-tokens)
-        cuda = ""
-
+    input_ids = encode(question, tokens)
+    cuda = "" if args.cpu else ".cuda()"
+
     if eos_token is None:
         output = eval(f"model.generate(input_ids, {preset}){cuda}")
     else:
@@ -217,16 +220,20 @@ elif args.chat or args.cai_chat:
     def chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check):
         text = chat_response_cleaner(text)
 
-        question = f"{context}\n\n"
-        for i in range(len(history)):
-            if args.cai_chat:
-                question += f"{name1}: {history[i][0].strip()}\n"
-                question += f"{name2}: {history[i][1].strip()}\n"
-            else:
-                question += f"{name1}: {history[i][0][3:-5].strip()}\n"
-                question += f"{name2}: {history[i][1][3:-5].strip()}\n"
-        question += f"{name1}: {text}\n"
-        question += f"{name2}:"
+        rows = [f"{context}\n\n"]
+        i = len(history)-1
+        while i >= 0 and len(encode(''.join(rows), tokens)[0]) < 2048-tokens:
+            rows.insert(1, f"{name2}: {history[i][1].strip()}\n")
+            rows.insert(1, f"{name1}: {history[i][0].strip()}\n")
+            i -= 1
+        rows.append(f"{name1}: {text}\n")
+        rows.append(f"{name2}:")
+
+        while len(rows) > 3 and len(encode(''.join(rows), tokens)[0]) >= 2048-tokens:
+            rows.pop(1)
+            rows.pop(1)
+
+        question = ''.join(rows)
 
         if check:
             reply = generate_reply(question, tokens, inference_settings, selected_model, eos_token='\n')[0]
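
Note for reviewers (not part of the patch): the windowing logic in
chatbot_wrapper can be sketched as a self-contained script. This is only an
illustration under assumptions: count_tokens is a hypothetical whitespace
stand-in for the real len(encode(text, tokens)[0]) call, max_tokens plays
the role of 2048-tokens, and build_prompt is an invented name.

def count_tokens(text):
    # Hypothetical stand-in for len(encode(text, tokens)[0]) in server.py:
    # counts whitespace-separated words instead of real tokenizer tokens.
    return len(text.split())

def build_prompt(context, history, user_text, name1, name2, max_tokens):
    # rows[0] is the pinned context string. History pairs are inserted at
    # index 1 while walking the history from newest to oldest, so
    # chronological order is preserved and the oldest exchanges are the
    # first to be left out.
    rows = [f"{context}\n\n"]
    i = len(history) - 1
    while i >= 0 and count_tokens(''.join(rows)) < max_tokens:
        rows.insert(1, f"{name2}: {history[i][1].strip()}\n")
        rows.insert(1, f"{name1}: {history[i][0].strip()}\n")
        i -= 1
    rows.append(f"{name1}: {user_text}\n")
    rows.append(f"{name2}:")

    # The loop above only checks the budget before inserting a pair, so it
    # can overshoot by one exchange. Trim the oldest pair until the prompt
    # fits, but never drop the context row or the final two rows.
    while len(rows) > 3 and count_tokens(''.join(rows)) >= max_tokens:
        rows.pop(1)
        rows.pop(1)
    return ''.join(rows)

if __name__ == '__main__':
    history = [[f"question {n}", f"answer {n}"] for n in range(50)]
    print(build_prompt("You are a helpful bot.", history, "hello",
                       "You", "Bot", max_tokens=60))

Because the first loop tests the budget before inserting and the second loop
pops from index 1, the context string at rows[0] and the final
"{name1}: ...\n{name2}:" rows are always kept, which is what keeps the
context alive no matter how long the conversation grows.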