Make the bos token optional

This commit is contained in:
oobabooga 2023-04-10 16:44:22 -03:00
parent 4961f43702
commit bd04ff27ad
3 changed files with 12 additions and 4 deletions

View file

@ -35,6 +35,7 @@ settings = {
'greeting': 'Hello there!',
'end_of_turn': '',
'stop_at_newline': False,
'add_bos_token': True,
'chat_prompt_size': 2048,
'chat_prompt_size_min': 0,
'chat_prompt_size_max': 2048,

View file

@ -22,7 +22,7 @@ def get_max_prompt_length(tokens):
return max_length
def encode(prompt, tokens_to_generate=0, add_special_tokens=True):
def encode(prompt, tokens_to_generate=0, add_special_tokens=True, add_bos_token=True):
if any((shared.is_RWKV, shared.is_llamacpp)):
input_ids = shared.tokenizer.encode(str(prompt))
input_ids = np.array(input_ids).reshape(1, len(input_ids))
@ -30,6 +30,12 @@ def encode(prompt, tokens_to_generate=0, add_special_tokens=True):
else:
input_ids = shared.tokenizer.encode(str(prompt), return_tensors='pt', truncation=True, max_length=get_max_prompt_length(tokens_to_generate), add_special_tokens=add_special_tokens)
# This is a hack for making replies more creative.
if not add_bos_token and input_ids[0][0] == shared.tokenizer.bos_token_id:
input_ids = input_ids[:, 1:]
# Llama adds this extra token when the first character is '\n', and this
# compromises the stopping criteria, so we just remove it
if type(shared.tokenizer) is transformers.LlamaTokenizer and input_ids[0][0] == 29871:
input_ids = input_ids[:, 1:]
@ -158,7 +164,7 @@ def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]
print(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens}, seed {seed})')
return
input_ids = encode(question, generate_state['max_new_tokens'])
input_ids = encode(question, generate_state['max_new_tokens'], add_bos_token=generate_state['add_bos_token'])
original_input_ids = input_ids
output = input_ids[0]