diff --git a/modules/exllamav2.py b/modules/exllamav2.py
index a75ede4..0287a17 100644
--- a/modules/exllamav2.py
+++ b/modules/exllamav2.py
@@ -64,7 +64,7 @@ class Exllamav2Model:
         return result, result
 
     def encode(self, string, **kwargs):
-        return self.tokenizer.encode(string, add_bos=True)
+        return self.tokenizer.encode(string, add_bos=True, encode_special_tokens=True)
 
     def decode(self, ids, **kwargs):
         if isinstance(ids, list):
@@ -72,7 +72,7 @@ class Exllamav2Model:
         elif isinstance(ids, torch.Tensor) and ids.numel() == 1:
             ids = ids.view(1, -1)
 
-        return self.tokenizer.decode(ids)[0]
+        return self.tokenizer.decode(ids, decode_special_tokens=True)[0]
 
     def get_logits(self, token_ids, **kwargs):
         self.cache.current_seq_len = 0
@@ -97,7 +97,7 @@ class Exllamav2Model:
         if len(to_ban) > 0:
             settings.disallow_tokens(self.tokenizer, to_ban)
 
-        ids = self.tokenizer.encode(prompt, add_bos=state['add_bos_token'])
+        ids = self.tokenizer.encode(prompt, add_bos=state['add_bos_token'], encode_special_tokens=True)
         ids = ids[:, -get_max_prompt_length(state):]
         initial_len = ids.shape[-1]
 
@@ -119,7 +119,7 @@ class Exllamav2Model:
             if i == 0 and self.tokenizer.tokenizer.IdToPiece(int(token)).startswith('▁'):
                 has_leading_space = True
 
-            decoded_text = self.tokenizer.decode(ids[:, initial_len:])[0]
+            decoded_text = self.tokenizer.decode(ids[:, initial_len:], decode_special_tokens=not state['skip_special_tokens'])[0]
             if has_leading_space:
                 decoded_text = ' ' + decoded_text
 
diff --git a/modules/loaders.py b/modules/loaders.py
index ab10e0a..bd3f04a 100644
--- a/modules/loaders.py
+++ b/modules/loaders.py
@@ -231,6 +231,7 @@ loaders_samplers = {
         'ban_eos_token',
         'add_bos_token',
         'custom_token_bans',
+        'skip_special_tokens',
         'auto_max_new_tokens',
     },
     'ExLlamav2_HF': {