From 54bf55372b363e2492e6dd0ae7402643fd498723 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Mon, 16 Jan 2023 13:43:23 -0300 Subject: [PATCH] Truncate prompts to 2048 tokens --- server.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/server.py b/server.py index 41fb429..936ae86 100644 --- a/server.py +++ b/server.py @@ -96,6 +96,7 @@ def load_model(model_name): tokenizer = AutoTokenizer.from_pretrained(Path("models/gpt-j-6B/")) else: tokenizer = AutoTokenizer.from_pretrained(Path(f"models/{model_name}/")) + tokenizer.truncation_side = 'left' print(f"Loaded the model in {(time.time()-t0):.2f} seconds.") return model, tokenizer @@ -134,10 +135,10 @@ def generate_reply(question, tokens, inference_settings, selected_model, eos_tok if not args.cpu: torch.cuda.empty_cache() - input_ids = tokenizer.encode(str(question), return_tensors='pt').cuda() + input_ids = tokenizer.encode(str(question), return_tensors='pt', truncation=True, max_length=2048-tokens).cuda() cuda = ".cuda()" else: - input_ids = tokenizer.encode(str(question), return_tensors='pt') + input_ids = tokenizer.encode(str(question), return_tensors='pt', truncation=True, max_length=2048-tokens) cuda = "" if eos_token is None: @@ -231,10 +232,12 @@ elif args.chat or args.cai_chat: if check: reply = generate_reply(question, tokens, inference_settings, selected_model, eos_token='\n')[0] - reply = reply[len(question):].split('\n')[0].strip() + idx = reply.rfind(question[-500:]) + reply = reply[idx+min(500, len(question)):].split('\n')[0].strip() else: reply = generate_reply(question, tokens, inference_settings, selected_model)[0] - reply = reply[len(question):] + idx = reply.rfind(question[-500:]) + reply = reply[idx+min(500, len(question)):] idx = reply.find(f"\n{name1}:") if idx != -1: reply = reply[:idx]