Make the BOS (beginning-of-sequence) token optional
This commit is contained in:
parent
4961f43702
commit
bd04ff27ad
3 changed files with 12 additions and 4 deletions
|
@ -35,6 +35,7 @@ settings = {
|
|||
'greeting': 'Hello there!',
|
||||
'end_of_turn': '',
|
||||
'stop_at_newline': False,
|
||||
'add_bos_token': True,
|
||||
'chat_prompt_size': 2048,
|
||||
'chat_prompt_size_min': 0,
|
||||
'chat_prompt_size_max': 2048,
|
||||
|
|
|
@ -22,7 +22,7 @@ def get_max_prompt_length(tokens):
|
|||
return max_length
|
||||
|
||||
|
||||
def encode(prompt, tokens_to_generate=0, add_special_tokens=True):
|
||||
def encode(prompt, tokens_to_generate=0, add_special_tokens=True, add_bos_token=True):
|
||||
if any((shared.is_RWKV, shared.is_llamacpp)):
|
||||
input_ids = shared.tokenizer.encode(str(prompt))
|
||||
input_ids = np.array(input_ids).reshape(1, len(input_ids))
|
||||
|
@ -30,6 +30,12 @@ def encode(prompt, tokens_to_generate=0, add_special_tokens=True):
|
|||
else:
|
||||
input_ids = shared.tokenizer.encode(str(prompt), return_tensors='pt', truncation=True, max_length=get_max_prompt_length(tokens_to_generate), add_special_tokens=add_special_tokens)
|
||||
|
||||
# This is a hack for making replies more creative: when add_bos_token is False, drop the leading BOS token if the tokenizer prepended one.
|
||||
if not add_bos_token and input_ids[0][0] == shared.tokenizer.bos_token_id:
|
||||
input_ids = input_ids[:, 1:]
|
||||
|
||||
# Llama adds this extra token when the first character is '\n', and this
|
||||
# compromises the stopping criteria, so we just remove it
|
||||
if type(shared.tokenizer) is transformers.LlamaTokenizer and input_ids[0][0] == 29871:
|
||||
input_ids = input_ids[:, 1:]
|
||||
|
||||
|
@ -158,7 +164,7 @@ def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]
|
|||
print(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens}, seed {seed})')
|
||||
return
|
||||
|
||||
input_ids = encode(question, generate_state['max_new_tokens'])
|
||||
input_ids = encode(question, generate_state['max_new_tokens'], add_bos_token=generate_state['add_bos_token'])
|
||||
original_input_ids = input_ids
|
||||
output = input_ids[0]
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue