From ef17da70af109f3b543a6d0550dd5a2e8c0067a3 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Sun, 20 Aug 2023 08:50:32 -0700 Subject: [PATCH] Fix ExLlama truncation --- modules/exllama.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/modules/exllama.py b/modules/exllama.py index 30c3763..25bf0e5 100644 --- a/modules/exllama.py +++ b/modules/exllama.py @@ -111,7 +111,7 @@ class ExllamaModel: self.generator.end_beam_search() # Tokenizing the input - ids = self.generator.tokenizer.encode(prompt) + ids = self.generator.tokenizer.encode(prompt, max_seq_len=self.model.config.max_seq_len) ids = ids[:, -get_max_prompt_length(state):] if state['auto_max_new_tokens']: max_new_tokens = state['truncation_length'] - ids.shape[-1] @@ -141,7 +141,7 @@ class ExllamaModel: alpha = state['guidance_scale'] prompts = [prompt, state['negative_prompt'] or ''] - ids, mask = self.tokenizer.encode(prompts, return_mask=True) + ids, mask = self.tokenizer.encode(prompts, return_mask=True, max_seq_len=self.model.config.max_seq_len) if state['auto_max_new_tokens']: max_new_tokens = state['truncation_length'] - ids[0].shape[-1] else: @@ -181,7 +181,7 @@ class ExllamaModel: return output def encode(self, string, **kwargs): - return self.tokenizer.encode(string) + return self.tokenizer.encode(string, max_seq_len=self.model.config.max_seq_len) def decode(self, string, **kwargs): return self.tokenizer.decode(string)[0]