Add skip_special_tokens checkbox for Dolly model (#1218)

This commit is contained in:
oobabooga 2023-04-16 14:24:49 -03:00 committed by GitHub
parent a9c7ef4159
commit b937c9d8c2
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 35 additions and 15 deletions

View file

@@ -57,14 +57,13 @@ def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_lengt
return input_ids.cuda()
def decode(output_ids):
# Open Assistant relies on special tokens like <|endoftext|>
if re.match('.*(oasst|galactica)-*', shared.model_name.lower()):
return shared.tokenizer.decode(output_ids, skip_special_tokens=False)
else:
def decode(output_ids, skip_special_tokens=True):
    """Turn a sequence of token ids back into text via the shared tokenizer.

    With skip_special_tokens=True (the default), special tokens are dropped
    by the tokenizer and any literal '<|endoftext|>' remnant is stripped from
    the result; with False, the raw decode (special tokens included) is
    returned unchanged.
    """
    if not skip_special_tokens:
        # Some models (e.g. Open Assistant) rely on special tokens surviving.
        return shared.tokenizer.decode(output_ids, skip_special_tokens=False)
    text = shared.tokenizer.decode(output_ids, skip_special_tokens=True)
    # Defensive cleanup: drop any endoftext marker the tokenizer left behind.
    return text.replace(r'<|endoftext|>', '')
def generate_softprompt_input_tensors(input_ids):
@@ -184,7 +183,7 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
output = input_ids[0]
if shared.args.verbose:
print(f'\n\n{decode(input_ids[0])}\n--------------------\n')
print(f'\n\n{decode(input_ids[0], state["skip_special_tokens"])}\n--------------------\n')
cuda = not any((shared.args.cpu, shared.args.deepspeed, shared.args.flexgen))
eos_token_ids = [shared.tokenizer.eos_token_id] if shared.tokenizer.eos_token_id is not None else []
@@ -231,11 +230,12 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
output = shared.model.generate(**generate_params)[0]
if cuda:
output = output.cuda()
if shared.soft_prompt:
output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
new_tokens = len(output) - len(input_ids[0])
reply = decode(output[-new_tokens:])
reply = decode(output[-new_tokens:], state['skip_special_tokens'])
if not shared.is_chat():
reply = original_question + apply_extensions(reply, 'output')
@@ -256,18 +256,20 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
if not shared.is_chat():
yield formatted_outputs(original_question, shared.model_name)
with generate_with_streaming(**generate_params) as generator:
for output in generator:
if shared.soft_prompt:
output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
new_tokens = len(output) - len(input_ids[0])
reply = decode(output[-new_tokens:])
reply = decode(output[-new_tokens:], state['skip_special_tokens'])
if not shared.is_chat():
reply = original_question + apply_extensions(reply, 'output')
if output[-1] in eos_token_ids:
break
yield formatted_outputs(reply, shared.model_name)
# Stream the output naively for FlexGen since it doesn't support 'stopping_criteria'
@@ -276,18 +278,19 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
clear_torch_cache()
with torch.no_grad():
output = shared.model.generate(**generate_params)[0]
if shared.soft_prompt:
output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
new_tokens = len(output) - len(original_input_ids[0])
reply = decode(output[-new_tokens:])
reply = decode(output[-new_tokens:], state['skip_special_tokens'])
if not shared.is_chat():
reply = original_question + apply_extensions(reply, 'output')
if np.count_nonzero(np.isin(input_ids[0], eos_token_ids)) < np.count_nonzero(np.isin(output, eos_token_ids)):
break
yield formatted_outputs(reply, shared.model_name)
yield formatted_outputs(reply, shared.model_name)
input_ids = np.reshape(output, (1, output.shape[0]))
if shared.soft_prompt:
inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids)