Add types to the encode/decode/token-count endpoints

This commit is contained in:
oobabooga 2023-11-07 19:05:36 -08:00
parent f6ca9cfcdc
commit 1b69694fe9
5 changed files with 47 additions and 36 deletions

View file

@@ -101,7 +101,7 @@ class LlamaCppModel:
return self.model.tokenize(string)
-    def decode(self, ids):
+    def decode(self, ids, **kwargs):
return self.model.detokenize(ids).decode('utf-8')
def get_logits(self, tokens):

View file

@@ -145,7 +145,7 @@ def decode(output_ids, skip_special_tokens=True):
if shared.tokenizer is None:
raise ValueError('No tokenizer is loaded')
-    return shared.tokenizer.decode(output_ids, skip_special_tokens)
+    return shared.tokenizer.decode(output_ids, skip_special_tokens=skip_special_tokens)
def get_encoded_length(prompt):