Add types to the encode/decode/token-count endpoints

This commit is contained in:
oobabooga 2023-11-07 19:05:36 -08:00
parent f6ca9cfcdc
commit 1b69694fe9
5 changed files with 47 additions and 36 deletions

View file

@@ -3,34 +3,24 @@ from modules.text_generation import decode, encode
def token_count(prompt):
    """Count the tokens in *prompt* using the loaded model's tokenizer.

    Returns a dict of the shape {'length': <int>}.
    """
    # encode() presumably returns a batch; [0] takes the first (only)
    # sequence — TODO confirm against modules.text_generation.encode.
    tokens = encode(prompt)[0]
    return {
        'length': len(tokens)
    }
def token_encode(input):
    """Tokenize *input* and return the token ids together with their count.

    Returns a dict of the shape {'tokens': <list>, 'length': <int>}.
    """
    tokens = encode(input)[0]
    # encode() may hand back a torch Tensor or a numpy ndarray; convert to
    # a plain Python list so the response is JSON-serializable. The name
    # check avoids importing torch/numpy here just for an isinstance test.
    if tokens.__class__.__name__ in ['Tensor', 'ndarray']:
        tokens = tokens.tolist()
    return {
        'tokens': tokens,
        'length': len(tokens),
    }
def token_decode(tokens):
    """Decode a sequence of token ids back into text.

    Returns a dict of the shape {'text': <str>}.
    """
    # NOTE(review): the pre-change code used decode(tokens)[0]; this
    # version assumes decode() now returns the string directly — confirm
    # against modules.text_generation.decode.
    output = decode(tokens)
    return {
        'text': output
    }