Reorder some functions
This commit is contained in:
parent
e2fddd9584
commit
13ac55fa18
4 changed files with 33 additions and 34 deletions
|
@ -62,6 +62,22 @@ class Exllamav2Model:
|
|||
result.generator = generator
|
||||
return result, result
|
||||
|
||||
def encode(self, string, **kwargs):
|
||||
return self.tokenizer.encode(string, add_bos=True)
|
||||
|
||||
def decode(self, ids, **kwargs):
|
||||
if isinstance(ids, list):
|
||||
ids = torch.tensor([ids])
|
||||
elif isinstance(ids, torch.Tensor) and ids.numel() == 1:
|
||||
ids = ids.view(1, -1)
|
||||
|
||||
return self.tokenizer.decode(ids)[0]
|
||||
|
||||
def get_logits(self, token_ids, **kwargs):
|
||||
self.cache.current_seq_len = 0
|
||||
self.model.forward(token_ids[:, :-1], self.cache, input_mask=None, preprocess_only=True)
|
||||
return self.model.forward(token_ids[:, -1:], self.cache, input_mask=None, **kwargs).float().cpu()
|
||||
|
||||
def generate_with_streaming(self, prompt, state):
|
||||
settings = ExLlamaV2Sampler.Settings()
|
||||
settings.temperature = state['temperature']
|
||||
|
@ -114,19 +130,3 @@ class Exllamav2Model:
|
|||
pass
|
||||
|
||||
return output
|
||||
|
||||
def encode(self, string, **kwargs):
|
||||
return self.tokenizer.encode(string, add_bos=True)
|
||||
|
||||
def decode(self, ids, **kwargs):
|
||||
if isinstance(ids, list):
|
||||
ids = torch.tensor([ids])
|
||||
elif isinstance(ids, torch.Tensor) and ids.numel() == 1:
|
||||
ids = ids.view(1, -1)
|
||||
|
||||
return self.tokenizer.decode(ids)[0]
|
||||
|
||||
def get_logits(self, token_ids, **kwargs):
|
||||
self.cache.current_seq_len = 0
|
||||
self.model.forward(token_ids[:, :-1], self.cache, input_mask=None, preprocess_only=True)
|
||||
return self.model.forward(token_ids[:, -1:], self.cache, input_mask=None, **kwargs).float().cpu()
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue