Two new options: truncation length and ban eos token
This commit is contained in:
parent
749c08a4ff
commit
cacbcda208
6 changed files with 62 additions and 48 deletions
|
@ -18,35 +18,35 @@ from modules.text_generation import (encode, generate_reply,
|
|||
get_max_prompt_length)
|
||||
|
||||
|
||||
def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat_prompt_size, **kwargs):
|
||||
is_instruct = kwargs['is_instruct'] if 'is_instruct' in kwargs else False
|
||||
end_of_turn = kwargs['end_of_turn'] if 'end_of_turn' in kwargs else ''
|
||||
def generate_chat_prompt(user_input, state, **kwargs):
|
||||
impersonate = kwargs['impersonate'] if 'impersonate' in kwargs else False
|
||||
_continue = kwargs['_continue'] if '_continue' in kwargs else False
|
||||
also_return_rows = kwargs['also_return_rows'] if 'also_return_rows' in kwargs else False
|
||||
rows = [f"{context.strip()}\n"]
|
||||
is_instruct = state['mode'] == 'instruct'
|
||||
rows = [f"{state['context'].strip()}\n"]
|
||||
|
||||
# Finding the maximum prompt size
|
||||
chat_prompt_size = state['chat_prompt_size']
|
||||
if shared.soft_prompt:
|
||||
chat_prompt_size -= shared.soft_prompt_tensor.shape[1]
|
||||
max_length = min(get_max_prompt_length(max_new_tokens), chat_prompt_size)
|
||||
max_length = min(get_max_prompt_length(state), chat_prompt_size)
|
||||
|
||||
if is_instruct:
|
||||
prefix1 = f"{name1}\n"
|
||||
prefix2 = f"{name2}\n"
|
||||
prefix1 = f"{state['name1']}\n"
|
||||
prefix2 = f"{state['name2']}\n"
|
||||
else:
|
||||
prefix1 = f"{name1}: "
|
||||
prefix2 = f"{name2}: "
|
||||
prefix1 = f"{state['name1']}: "
|
||||
prefix2 = f"{state['name2']}: "
|
||||
|
||||
i = len(shared.history['internal']) - 1
|
||||
while i >= 0 and len(encode(''.join(rows), max_new_tokens)[0]) < max_length:
|
||||
while i >= 0 and len(encode(''.join(rows))[0]) < max_length:
|
||||
if _continue and i == len(shared.history['internal']) - 1:
|
||||
rows.insert(1, f"{prefix2}{shared.history['internal'][i][1]}")
|
||||
else:
|
||||
rows.insert(1, f"{prefix2}{shared.history['internal'][i][1].strip()}{end_of_turn}\n")
|
||||
rows.insert(1, f"{prefix2}{shared.history['internal'][i][1].strip()}{state['end_of_turn']}\n")
|
||||
string = shared.history['internal'][i][0]
|
||||
if string not in ['', '<|BEGIN-VISIBLE-CHAT|>']:
|
||||
rows.insert(1, f"{prefix1}{string.strip()}{end_of_turn}\n")
|
||||
rows.insert(1, f"{prefix1}{string.strip()}{state['end_of_turn']}\n")
|
||||
i -= 1
|
||||
|
||||
if impersonate:
|
||||
|
@ -58,13 +58,13 @@ def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat
|
|||
# Adding the user message
|
||||
user_input = fix_newlines(user_input)
|
||||
if len(user_input) > 0:
|
||||
rows.append(f"{prefix1}{user_input}{end_of_turn}\n")
|
||||
rows.append(f"{prefix1}{user_input}{state['end_of_turn']}\n")
|
||||
|
||||
# Adding the Character prefix
|
||||
rows.append(apply_extensions(f"{prefix2.strip() if not is_instruct else prefix2}", "bot_prefix"))
|
||||
limit = 3
|
||||
|
||||
while len(rows) > limit and len(encode(''.join(rows), max_new_tokens)[0]) >= max_length:
|
||||
while len(rows) > limit and len(encode(''.join(rows))[0]) >= max_length:
|
||||
rows.pop(1)
|
||||
prompt = ''.join(rows)
|
||||
|
||||
|
@ -139,15 +139,10 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False):
|
|||
text = apply_extensions(text, "input")
|
||||
|
||||
# Generating the prompt
|
||||
kwargs = {
|
||||
'end_of_turn': state['end_of_turn'],
|
||||
'is_instruct': state['mode'] == 'instruct',
|
||||
'_continue': _continue
|
||||
}
|
||||
if custom_generate_chat_prompt is None:
|
||||
prompt = generate_chat_prompt(text, state['max_new_tokens'], state['name1'], state['name2'], state['context'], state['chat_prompt_size'], **kwargs)
|
||||
prompt = generate_chat_prompt(text, state)
|
||||
else:
|
||||
prompt = custom_generate_chat_prompt(text, state['max_new_tokens'], state['name1'], state['name2'], state['context'], state['chat_prompt_size'], **kwargs)
|
||||
prompt = custom_generate_chat_prompt(text, state)
|
||||
|
||||
# Yield *Is typing...*
|
||||
if not any((regenerate, _continue)):
|
||||
|
@ -197,7 +192,7 @@ def impersonate_wrapper(text, state):
|
|||
# Defining some variables
|
||||
cumulative_reply = ''
|
||||
eos_token = '\n' if state['stop_at_newline'] else None
|
||||
prompt = generate_chat_prompt(text, state['max_new_tokens'], state['name1'], state['name2'], state['context'], state['chat_prompt_size'], end_of_turn=state['end_of_turn'], impersonate=True)
|
||||
prompt = generate_chat_prompt(text, state, impersonate=True)
|
||||
stopping_strings = get_stopping_strings(state)
|
||||
|
||||
# Yield *Is typing...*
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue