From 1b8b61b9284845acfc7252a2058f4c64339fbc5a Mon Sep 17 00:00:00 2001
From: zhangningboo <110680007+zhangningboo@users.noreply.github.com>
Date: Sat, 23 Dec 2023 10:11:02 +0800
Subject: [PATCH] Fix output_ids decoding for Qwen/Qwen-7B-Chat (#5045)

---
Review notes (text between the three-dash line and the diff is not part of
the commit message):

* The patch had been flattened onto a single line; the header, diffstat and
  hunk rows are restored to one row per line so the patch can be applied.
* Revised relative to the original commit: a bytes token is now decoded with
  errors='ignore'. A byte-level vocabulary can split a multi-byte character
  across tokens, and a strict decode of such a fragment raises
  UnicodeDecodeError. The diffstat and hunk header are updated to match; the
  post-image blob id on the "index" line is carried over from the original
  commit and no longer corresponds to the revised content.

 modules/text_generation.py | 13 +++++++++++--
 1 file changed, 11 insertions(+), 2 deletions(-)

diff --git a/modules/text_generation.py b/modules/text_generation.py
index f640b2c..2bcbd9b 100644
--- a/modules/text_generation.py
+++ b/modules/text_generation.py
@@ -265,8 +265,17 @@ def apply_stopping_strings(reply, all_stop_strings):
 
 def get_reply_from_output_ids(output_ids, state, starting_from=0):
     reply = decode(output_ids[starting_from:], state['skip_special_tokens'])
-    if (hasattr(shared.tokenizer, 'convert_ids_to_tokens') and len(output_ids) > starting_from and shared.tokenizer.convert_ids_to_tokens(int(output_ids[starting_from])).startswith('▁')) and not reply.startswith(' '):
-        reply = ' ' + reply
+
+    # Handle tokenizers that do not add the leading space for the first token
+    if (hasattr(shared.tokenizer, 'convert_ids_to_tokens') and len(output_ids) > starting_from) and not reply.startswith(' '):
+        first_token = shared.tokenizer.convert_ids_to_tokens(int(output_ids[starting_from]))
+        if isinstance(first_token, (bytes,)):
+            # A byte-level token may hold an incomplete UTF-8 sequence; drop the
+            # undecodable bytes instead of raising, we only need the leading marker.
+            first_token = first_token.decode('utf8', errors='ignore')
+
+        if first_token.startswith('▁'):
+            reply = ' ' + reply
 
     return reply
 