Two new options: truncation length and ban eos token

This commit is contained in:
oobabooga 2023-04-11 18:46:06 -03:00 committed by GitHub
parent 749c08a4ff
commit cacbcda208
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 62 additions and 48 deletions

View file

@ -18,35 +18,35 @@ from modules.text_generation import (encode, generate_reply,
get_max_prompt_length)
def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat_prompt_size, **kwargs):
is_instruct = kwargs['is_instruct'] if 'is_instruct' in kwargs else False
end_of_turn = kwargs['end_of_turn'] if 'end_of_turn' in kwargs else ''
def generate_chat_prompt(user_input, state, **kwargs):
impersonate = kwargs['impersonate'] if 'impersonate' in kwargs else False
_continue = kwargs['_continue'] if '_continue' in kwargs else False
also_return_rows = kwargs['also_return_rows'] if 'also_return_rows' in kwargs else False
rows = [f"{context.strip()}\n"]
is_instruct = state['mode'] == 'instruct'
rows = [f"{state['context'].strip()}\n"]
# Finding the maximum prompt size
chat_prompt_size = state['chat_prompt_size']
if shared.soft_prompt:
chat_prompt_size -= shared.soft_prompt_tensor.shape[1]
max_length = min(get_max_prompt_length(max_new_tokens), chat_prompt_size)
max_length = min(get_max_prompt_length(state), chat_prompt_size)
if is_instruct:
prefix1 = f"{name1}\n"
prefix2 = f"{name2}\n"
prefix1 = f"{state['name1']}\n"
prefix2 = f"{state['name2']}\n"
else:
prefix1 = f"{name1}: "
prefix2 = f"{name2}: "
prefix1 = f"{state['name1']}: "
prefix2 = f"{state['name2']}: "
i = len(shared.history['internal']) - 1
while i >= 0 and len(encode(''.join(rows), max_new_tokens)[0]) < max_length:
while i >= 0 and len(encode(''.join(rows))[0]) < max_length:
if _continue and i == len(shared.history['internal']) - 1:
rows.insert(1, f"{prefix2}{shared.history['internal'][i][1]}")
else:
rows.insert(1, f"{prefix2}{shared.history['internal'][i][1].strip()}{end_of_turn}\n")
rows.insert(1, f"{prefix2}{shared.history['internal'][i][1].strip()}{state['end_of_turn']}\n")
string = shared.history['internal'][i][0]
if string not in ['', '<|BEGIN-VISIBLE-CHAT|>']:
rows.insert(1, f"{prefix1}{string.strip()}{end_of_turn}\n")
rows.insert(1, f"{prefix1}{string.strip()}{state['end_of_turn']}\n")
i -= 1
if impersonate:
@ -58,13 +58,13 @@ def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat
# Adding the user message
user_input = fix_newlines(user_input)
if len(user_input) > 0:
rows.append(f"{prefix1}{user_input}{end_of_turn}\n")
rows.append(f"{prefix1}{user_input}{state['end_of_turn']}\n")
# Adding the Character prefix
rows.append(apply_extensions(f"{prefix2.strip() if not is_instruct else prefix2}", "bot_prefix"))
limit = 3
while len(rows) > limit and len(encode(''.join(rows), max_new_tokens)[0]) >= max_length:
while len(rows) > limit and len(encode(''.join(rows))[0]) >= max_length:
rows.pop(1)
prompt = ''.join(rows)
@ -139,15 +139,10 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False):
text = apply_extensions(text, "input")
# Generating the prompt
kwargs = {
'end_of_turn': state['end_of_turn'],
'is_instruct': state['mode'] == 'instruct',
'_continue': _continue
}
if custom_generate_chat_prompt is None:
prompt = generate_chat_prompt(text, state['max_new_tokens'], state['name1'], state['name2'], state['context'], state['chat_prompt_size'], **kwargs)
prompt = generate_chat_prompt(text, state)
else:
prompt = custom_generate_chat_prompt(text, state['max_new_tokens'], state['name1'], state['name2'], state['context'], state['chat_prompt_size'], **kwargs)
prompt = custom_generate_chat_prompt(text, state)
# Yield *Is typing...*
if not any((regenerate, _continue)):
@ -197,7 +192,7 @@ def impersonate_wrapper(text, state):
# Defining some variables
cumulative_reply = ''
eos_token = '\n' if state['stop_at_newline'] else None
prompt = generate_chat_prompt(text, state['max_new_tokens'], state['name1'], state['name2'], state['context'], state['chat_prompt_size'], end_of_turn=state['end_of_turn'], impersonate=True)
prompt = generate_chat_prompt(text, state, impersonate=True)
stopping_strings = get_stopping_strings(state)
# Yield *Is typing...*