Add chat API (#2233)

This commit is contained in:
oobabooga 2023-05-20 18:42:17 -03:00 committed by GitHub
parent 2aa01e2303
commit c5af549d4b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 317 additions and 67 deletions

View file

@ -6,6 +6,7 @@ from websockets.server import serve
from extensions.api.util import build_parameters, try_start_cloudflared
from modules import shared
from modules.chat import generate_chat_reply
from modules.text_generation import generate_reply
PATH = '/api/v1/stream'
@ -13,42 +14,72 @@ PATH = '/api/v1/stream'
async def _handle_connection(websocket, path):
    """Serve one websocket connection of the streaming API.

    The request path selects the endpoint:

    * ``/api/v1/stream`` -- every incoming JSON message must carry a
      ``prompt``; the reply is streamed as ``text_stream`` events whose
      ``text`` field holds only the characters not sent before.
    * ``/api/v1/chat-stream`` -- every incoming JSON message must carry
      ``user_input`` and ``history`` (``regenerate`` and ``_continue`` are
      optional and default to ``False``); each value yielded by
      ``generate_chat_reply`` is streamed in the ``history`` field of a
      ``text_stream`` event.

    On both endpoints a ``stream_end`` event is sent after the last chunk
    of each request. Any other path is reported on stdout and the handler
    returns without reading from the socket.

    Args:
        websocket: The connection object; iterated for incoming messages
            and used to send the JSON-encoded events.
        path: The request path of the connection.
    """
    if path == '/api/v1/stream':
        async for message in websocket:
            message = json.loads(message)

            prompt = message['prompt']
            generate_params = build_parameters(message)
            # 'stopping_strings' is handed to generate_reply as its own
            # keyword argument, so it is removed from the parameter dict.
            stopping_strings = generate_params.pop('stopping_strings')
            generate_params['stream'] = True

            generator = generate_reply(
                prompt, generate_params, stopping_strings=stopping_strings, is_chat=False)

            # As we stream, only send the new bytes.
            skip_index = 0
            message_num = 0

            for a in generator:
                to_send = a[skip_index:]
                await websocket.send(json.dumps({
                    'event': 'text_stream',
                    'message_num': message_num,
                    'text': to_send
                }))

                # The generator is iterated synchronously; yield to the
                # event loop between chunks so other tasks can run.
                await asyncio.sleep(0)
                skip_index += len(to_send)
                message_num += 1

            await websocket.send(json.dumps({
                'event': 'stream_end',
                'message_num': message_num
            }))

    elif path == '/api/v1/chat-stream':
        async for message in websocket:
            body = json.loads(message)

            user_input = body['user_input']
            history = body['history']
            generate_params = build_parameters(body, chat=True)
            generate_params['stream'] = True
            regenerate = body.get('regenerate', False)
            _continue = body.get('_continue', False)

            generator = generate_chat_reply(
                user_input, history, generate_params, regenerate=regenerate, _continue=_continue, loading_message=False)

            message_num = 0
            for a in generator:
                # Unlike the text endpoint, each yielded value is sent in
                # full rather than as a delta against the previous one.
                await websocket.send(json.dumps({
                    'event': 'text_stream',
                    'message_num': message_num,
                    'history': a
                }))

                # Yield to the event loop between chunks (see above).
                await asyncio.sleep(0)
                message_num += 1

            await websocket.send(json.dumps({
                'event': 'stream_end',
                'message_num': message_num
            }))

    else:
        print(f'Streaming api: unknown path: {path}')
        return
async def _run(host: str, port: int):