Ensure that the chat prompt will always contain < 2048 tokens
This way, the context string stays at the top of the prompt even if you keep talking to the bot for hours. Before this commit, the prompt would simply be truncated, and the context string would eventually be lost.
This commit is contained in:
parent
6456777b09
commit
ca13acdfa0
1 changed file with 24 additions and 17 deletions
41
server.py
41
server.py
|
@ -116,6 +116,14 @@ def fix_galactica(s):
|
||||||
s = s.replace(r'$$', r'$')
|
s = s.replace(r'$$', r'$')
|
||||||
return s
|
return s
|
||||||
|
|
||||||
|
def encode(prompt, tokens):
    """Tokenize ``prompt`` into a PyTorch tensor of input ids.

    The encoding is truncated to at most ``2048 - tokens`` ids, so ``tokens``
    is a budget held back from the 2048 limit (presumably the number of
    tokens about to be generated -- confirm against callers).

    Unless ``args.cpu`` is set, the CUDA cache is emptied first and the
    resulting tensor is moved to the GPU.
    """
    on_gpu = not args.cpu
    if on_gpu:
        torch.cuda.empty_cache()

    ids = tokenizer.encode(
        str(prompt),
        return_tensors='pt',
        truncation=True,
        max_length=2048 - tokens,
    )
    return ids.cuda() if on_gpu else ids
|
||||||
|
|
||||||
def generate_reply(question, tokens, inference_settings, selected_model, eos_token=None):
|
def generate_reply(question, tokens, inference_settings, selected_model, eos_token=None):
|
||||||
global model, tokenizer, model_name, loaded_preset, preset
|
global model, tokenizer, model_name, loaded_preset, preset
|
||||||
|
|
||||||
|
@ -131,14 +139,9 @@ def generate_reply(question, tokens, inference_settings, selected_model, eos_tok
|
||||||
preset = infile.read()
|
preset = infile.read()
|
||||||
loaded_preset = inference_settings
|
loaded_preset = inference_settings
|
||||||
|
|
||||||
if not args.cpu:
|
input_ids = encode(question, tokens)
|
||||||
torch.cuda.empty_cache()
|
|
||||||
input_ids = tokenizer.encode(str(question), return_tensors='pt', truncation=True, max_length=2048-tokens).cuda()
|
|
||||||
cuda = ".cuda()"
|
|
||||||
else:
|
|
||||||
input_ids = tokenizer.encode(str(question), return_tensors='pt', truncation=True, max_length=2048-tokens)
|
|
||||||
cuda = ""
|
|
||||||
|
|
||||||
|
# Suffix appended to the model.generate(...) call below: the output is moved
# to the GPU only when NOT running in CPU mode (same rule encode() follows).
cuda = "" if args.cpu else ".cuda()"
|
||||||
if eos_token is None:
|
if eos_token is None:
|
||||||
output = eval(f"model.generate(input_ids, {preset}){cuda}")
|
output = eval(f"model.generate(input_ids, {preset}){cuda}")
|
||||||
else:
|
else:
|
||||||
|
@ -217,16 +220,20 @@ elif args.chat or args.cai_chat:
|
||||||
def chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check):
|
def chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check):
|
||||||
text = chat_response_cleaner(text)
|
text = chat_response_cleaner(text)
|
||||||
|
|
||||||
question = f"{context}\n\n"
|
rows = [f"{context}\n\n"]
|
||||||
for i in range(len(history)):
|
i = len(history)-1
|
||||||
if args.cai_chat:
|
while i >= 0 and len(encode(''.join(rows), tokens)[0]) < 2048-tokens:
|
||||||
question += f"{name1}: {history[i][0].strip()}\n"
|
rows.insert(1, f"{name2}: {history[i][1].strip()}\n")
|
||||||
question += f"{name2}: {history[i][1].strip()}\n"
|
rows.insert(1, f"{name1}: {history[i][0].strip()}\n")
|
||||||
else:
|
i -= 1
|
||||||
question += f"{name1}: {history[i][0][3:-5].strip()}\n"
|
rows.append(f"{name1}: {text}\n")
|
||||||
question += f"{name2}: {history[i][1][3:-5].strip()}\n"
|
rows.append(f"{name2}:")
|
||||||
question += f"{name1}: {text}\n"
|
|
||||||
question += f"{name2}:"
|
while len(rows) > 3 and len(encode(''.join(rows), tokens)[0]) >= 2048-tokens:
|
||||||
|
rows.pop(1)
|
||||||
|
rows.pop(1)
|
||||||
|
|
||||||
|
question = ''.join(rows)
|
||||||
|
|
||||||
if check:
|
if check:
|
||||||
reply = generate_reply(question, tokens, inference_settings, selected_model, eos_token='\n')[0]
|
reply = generate_reply(question, tokens, inference_settings, selected_model, eos_token='\n')[0]
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue