Fix exllama tokenizers (#3954)
--------- Co-authored-by: oobabooga <112222186+oobabooga@users.noreply.github.com>
This commit is contained in:
parent
8d85425e09
commit
ed6b6411fb
2 changed files with 31 additions and 10 deletions
|
@ -48,7 +48,7 @@ class Exllamav2Model:
|
|||
result.cache = cache
|
||||
result.tokenizer = tokenizer
|
||||
result.generator = generator
|
||||
return result, tokenizer
|
||||
return result, result
|
||||
|
||||
def generate_with_streaming(self, prompt, state):
|
||||
settings = ExLlamaV2Sampler.Settings()
|
||||
|
@ -65,7 +65,7 @@ class Exllamav2Model:
|
|||
if len(to_ban) > 0:
|
||||
settings.disallow_tokens(self.tokenizer, to_ban)
|
||||
|
||||
ids = self.tokenizer.encode(prompt)
|
||||
ids = self.tokenizer.encode(prompt, add_bos=state['add_bos_token'])
|
||||
ids = ids[:, -get_max_prompt_length(state):]
|
||||
initial_len = ids.shape[-1]
|
||||
|
||||
|
@ -104,7 +104,12 @@ class Exllamav2Model:
|
|||
return output
|
||||
|
||||
def encode(self, string, **kwargs):
|
||||
return self.tokenizer.encode(string)
|
||||
return self.tokenizer.encode(string, add_bos=True)
|
||||
|
||||
def decode(self, string, **kwargs):
|
||||
return self.tokenizer.decode(string)[0]
|
||||
def decode(self, ids, **kwargs):
|
||||
if isinstance(ids, int):
|
||||
ids = torch.tensor([[ids]])
|
||||
elif isinstance(ids, torch.Tensor) and ids.numel() == 1:
|
||||
ids = ids.view(1, -1)
|
||||
|
||||
return self.tokenizer.decode(ids)[0]
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue