diff --git a/server.py b/server.py
index 2b76510..7bec72c 100644
--- a/server.py
+++ b/server.py
@@ -116,6 +116,14 @@ def fix_galactica(s):
     s = s.replace(r'$$', r'$')
     return s
 
+def encode(prompt, tokens):
+    if not args.cpu:
+        torch.cuda.empty_cache()
+        input_ids = tokenizer.encode(str(prompt), return_tensors='pt', truncation=True, max_length=2048-tokens).cuda()
+    else:
+        input_ids = tokenizer.encode(str(prompt), return_tensors='pt', truncation=True, max_length=2048-tokens)
+    return input_ids
+
 def generate_reply(question, tokens, inference_settings, selected_model, eos_token=None):
     global model, tokenizer, model_name, loaded_preset, preset
 
@@ -131,14 +139,9 @@ def generate_reply(question, tokens, inference_settings, selected_model, eos_tok
             preset = infile.read()
         loaded_preset = inference_settings
 
-    if not args.cpu:
-        torch.cuda.empty_cache()
-        input_ids = tokenizer.encode(str(question), return_tensors='pt', truncation=True, max_length=2048-tokens).cuda()
-        cuda = ".cuda()"
-    else:
-        input_ids = tokenizer.encode(str(question), return_tensors='pt', truncation=True, max_length=2048-tokens)
-        cuda = ""
+    input_ids = encode(question, tokens)
+    cuda = "" if args.cpu else ".cuda()"
 
     if eos_token is None:
         output = eval(f"model.generate(input_ids, {preset}){cuda}")
     else:
@@ -217,16 +220,20 @@ elif args.chat or args.cai_chat:
     def chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check):
         text = chat_response_cleaner(text)
 
-        question = f"{context}\n\n"
-        for i in range(len(history)):
-            if args.cai_chat:
-                question += f"{name1}: {history[i][0].strip()}\n"
-                question += f"{name2}: {history[i][1].strip()}\n"
-            else:
-                question += f"{name1}: {history[i][0][3:-5].strip()}\n"
-                question += f"{name2}: {history[i][1][3:-5].strip()}\n"
-        question += f"{name1}: {text}\n"
-        question += f"{name2}:"
+        rows = [f"{context}\n\n"]
+        i = len(history)-1
+        while i >= 0 and len(encode(''.join(rows), tokens)[0]) < 2048-tokens:
+            rows.insert(1, f"{name2}: {history[i][1].strip()}\n")
+            rows.insert(1, f"{name1}: {history[i][0].strip()}\n")
+            i -= 1
+        rows.append(f"{name1}: {text}\n")
+        rows.append(f"{name2}:")
+
+        while len(rows) > 3 and len(encode(''.join(rows), tokens)[0]) >= 2048-tokens:
+            rows.pop(1)
+            rows.pop(1)
+
+        question = ''.join(rows)
 
         if check:
             reply = generate_reply(question, tokens, inference_settings, selected_model, eos_token='\n')[0]