Add types to the encode/decode/token-count endpoints
This commit is contained in:
parent
f6ca9cfcdc
commit
1b69694fe9
5 changed files with 47 additions and 36 deletions
|
@ -3,34 +3,24 @@ from modules.text_generation import decode, encode
|
|||
|
||||
def token_count(prompt):
|
||||
tokens = encode(prompt)[0]
|
||||
|
||||
return {
|
||||
'results': [{
|
||||
'tokens': len(tokens)
|
||||
}]
|
||||
'length': len(tokens)
|
||||
}
|
||||
|
||||
|
||||
def token_encode(input, encoding_format):
|
||||
# if isinstance(input, list):
|
||||
def token_encode(input):
|
||||
tokens = encode(input)[0]
|
||||
if tokens.__class__.__name__ in ['Tensor', 'ndarray']:
|
||||
tokens = tokens.tolist()
|
||||
|
||||
return {
|
||||
'results': [{
|
||||
'tokens': tokens,
|
||||
'length': len(tokens),
|
||||
}]
|
||||
'tokens': tokens,
|
||||
'length': len(tokens),
|
||||
}
|
||||
|
||||
|
||||
def token_decode(tokens, encoding_format):
|
||||
# if isinstance(input, list):
|
||||
# if encoding_format == "base64":
|
||||
# tokens = base64_to_float_list(tokens)
|
||||
output = decode(tokens)[0]
|
||||
|
||||
def token_decode(tokens):
|
||||
output = decode(tokens)
|
||||
return {
|
||||
'results': [{
|
||||
'text': output
|
||||
}]
|
||||
'text': output
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue