Add skip_special_tokens checkbox for Dolly model (#1218)
This commit is contained in:
parent
a9c7ef4159
commit
b937c9d8c2
9 changed files with 35 additions and 15 deletions
|
@ -41,6 +41,7 @@ settings = {
|
|||
'stop_at_newline': False,
|
||||
'add_bos_token': True,
|
||||
'ban_eos_token': False,
|
||||
'skip_special_tokens': True,
|
||||
'truncation_length': 2048,
|
||||
'truncation_length_min': 0,
|
||||
'truncation_length_max': 4096,
|
||||
|
|
|
@ -57,14 +57,13 @@ def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_lengt
|
|||
return input_ids.cuda()
|
||||
|
||||
|
||||
def decode(output_ids):
|
||||
# Open Assistant relies on special tokens like <|endoftext|>
|
||||
if re.match('.*(oasst|galactica)-*', shared.model_name.lower()):
|
||||
return shared.tokenizer.decode(output_ids, skip_special_tokens=False)
|
||||
else:
|
||||
def decode(output_ids, skip_special_tokens=True):
|
||||
if skip_special_tokens:
|
||||
reply = shared.tokenizer.decode(output_ids, skip_special_tokens=True)
|
||||
reply = reply.replace(r'<|endoftext|>', '')
|
||||
return reply
|
||||
else:
|
||||
return shared.tokenizer.decode(output_ids, skip_special_tokens=False)
|
||||
|
||||
|
||||
def generate_softprompt_input_tensors(input_ids):
|
||||
|
@ -184,7 +183,7 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
|
|||
output = input_ids[0]
|
||||
|
||||
if shared.args.verbose:
|
||||
print(f'\n\n{decode(input_ids[0])}\n--------------------\n')
|
||||
print(f'\n\n{decode(input_ids[0], state["skip_special_tokens"])}\n--------------------\n')
|
||||
|
||||
cuda = not any((shared.args.cpu, shared.args.deepspeed, shared.args.flexgen))
|
||||
eos_token_ids = [shared.tokenizer.eos_token_id] if shared.tokenizer.eos_token_id is not None else []
|
||||
|
@ -231,11 +230,12 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
|
|||
output = shared.model.generate(**generate_params)[0]
|
||||
if cuda:
|
||||
output = output.cuda()
|
||||
|
||||
if shared.soft_prompt:
|
||||
output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
|
||||
|
||||
new_tokens = len(output) - len(input_ids[0])
|
||||
reply = decode(output[-new_tokens:])
|
||||
reply = decode(output[-new_tokens:], state['skip_special_tokens'])
|
||||
if not shared.is_chat():
|
||||
reply = original_question + apply_extensions(reply, 'output')
|
||||
|
||||
|
@ -256,18 +256,20 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
|
|||
|
||||
if not shared.is_chat():
|
||||
yield formatted_outputs(original_question, shared.model_name)
|
||||
|
||||
with generate_with_streaming(**generate_params) as generator:
|
||||
for output in generator:
|
||||
if shared.soft_prompt:
|
||||
output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
|
||||
|
||||
new_tokens = len(output) - len(input_ids[0])
|
||||
reply = decode(output[-new_tokens:])
|
||||
reply = decode(output[-new_tokens:], state['skip_special_tokens'])
|
||||
if not shared.is_chat():
|
||||
reply = original_question + apply_extensions(reply, 'output')
|
||||
|
||||
if output[-1] in eos_token_ids:
|
||||
break
|
||||
|
||||
yield formatted_outputs(reply, shared.model_name)
|
||||
|
||||
# Stream the output naively for FlexGen since it doesn't support 'stopping_criteria'
|
||||
|
@ -276,18 +278,19 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
|
|||
clear_torch_cache()
|
||||
with torch.no_grad():
|
||||
output = shared.model.generate(**generate_params)[0]
|
||||
|
||||
if shared.soft_prompt:
|
||||
output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
|
||||
|
||||
new_tokens = len(output) - len(original_input_ids[0])
|
||||
reply = decode(output[-new_tokens:])
|
||||
reply = decode(output[-new_tokens:], state['skip_special_tokens'])
|
||||
if not shared.is_chat():
|
||||
reply = original_question + apply_extensions(reply, 'output')
|
||||
|
||||
if np.count_nonzero(np.isin(input_ids[0], eos_token_ids)) < np.count_nonzero(np.isin(output, eos_token_ids)):
|
||||
break
|
||||
yield formatted_outputs(reply, shared.model_name)
|
||||
|
||||
yield formatted_outputs(reply, shared.model_name)
|
||||
input_ids = np.reshape(output, (1, output.shape[0]))
|
||||
if shared.soft_prompt:
|
||||
inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids)
|
||||
|
|
|
@ -25,7 +25,7 @@ def list_model_elements():
|
|||
|
||||
|
||||
def list_interface_input_elements(chat=False):
|
||||
elements = ['max_new_tokens', 'seed', 'temperature', 'top_p', 'top_k', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'no_repeat_ngram_size', 'min_length', 'do_sample', 'penalty_alpha', 'num_beams', 'length_penalty', 'early_stopping', 'add_bos_token', 'ban_eos_token', 'truncation_length', 'custom_stopping_strings']
|
||||
elements = ['max_new_tokens', 'seed', 'temperature', 'top_p', 'top_k', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'no_repeat_ngram_size', 'min_length', 'do_sample', 'penalty_alpha', 'num_beams', 'length_penalty', 'early_stopping', 'add_bos_token', 'ban_eos_token', 'truncation_length', 'custom_stopping_strings', 'skip_special_tokens']
|
||||
if chat:
|
||||
elements += ['name1', 'name2', 'greeting', 'context', 'end_of_turn', 'chat_prompt_size', 'chat_generation_attempts', 'stop_at_newline', 'mode', 'instruction_template']
|
||||
elements += list_model_elements()
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue