Add chat API (#2233)

This commit is contained in:
oobabooga 2023-05-20 18:42:17 -03:00 committed by GitHub
parent 2aa01e2303
commit c5af549d4b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 317 additions and 67 deletions

View file

@ -4,6 +4,7 @@ from threading import Thread
from extensions.api.util import build_parameters, try_start_cloudflared
from modules import shared
from modules.chat import generate_chat_reply
from modules.text_generation import encode, generate_reply
@ -46,7 +47,37 @@ class Handler(BaseHTTPRequestHandler):
'text': answer
}]
})
self.wfile.write(response.encode('utf-8'))
elif self.path == '/api/v1/chat':
self.send_response(200)
self.send_header('Content-Type', 'application/json')
self.end_headers()
user_input = body['user_input']
history = body['history']
regenerate = body.get('regenerate', False)
_continue = body.get('_continue', False)
generate_params = build_parameters(body, chat=True)
generate_params['stream'] = False
generator = generate_chat_reply(
user_input, history, generate_params, regenerate=regenerate, _continue=_continue, loading_message=False)
answer = history
for a in generator:
answer = a
response = json.dumps({
'results': [{
'history': answer
}]
})
self.wfile.write(response.encode('utf-8'))
elif self.path == '/api/v1/token-count':
self.send_response(200)
self.send_header('Content-Type', 'application/json')
@ -58,6 +89,7 @@ class Handler(BaseHTTPRequestHandler):
'tokens': len(tokens)
}]
})
self.wfile.write(response.encode('utf-8'))
else:
self.send_error(404)