Fix output_ids decoding for Qwen/Qwen-7B-Chat (#5045)
This commit is contained in:
parent
dbe438564e
commit
1b8b61b928
1 changed file with 9 additions and 2 deletions
|
@@ -265,8 +265,15 @@ def apply_stopping_strings(reply, all_stop_strings):
|
||||||
|
|
||||||
def get_reply_from_output_ids(output_ids, state, starting_from=0):
|
def get_reply_from_output_ids(output_ids, state, starting_from=0):
|
||||||
reply = decode(output_ids[starting_from:], state['skip_special_tokens'])
|
reply = decode(output_ids[starting_from:], state['skip_special_tokens'])
|
||||||
if (hasattr(shared.tokenizer, 'convert_ids_to_tokens') and len(output_ids) > starting_from and shared.tokenizer.convert_ids_to_tokens(int(output_ids[starting_from])).startswith('▁')) and not reply.startswith(' '):
|
|
||||||
reply = ' ' + reply
|
# Handle tokenizers that do not add the leading space for the first token
|
||||||
|
if (hasattr(shared.tokenizer, 'convert_ids_to_tokens') and len(output_ids) > starting_from) and not reply.startswith(' '):
|
||||||
|
first_token = shared.tokenizer.convert_ids_to_tokens(int(output_ids[starting_from]))
|
||||||
|
if isinstance(first_token, (bytes,)):
|
||||||
|
first_token = first_token.decode('utf8')
|
||||||
|
|
||||||
|
if first_token.startswith('▁'):
|
||||||
|
reply = ' ' + reply
|
||||||
|
|
||||||
return reply
|
return reply
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue