parent
5ac4e4da8b
commit
70b088843d
2 changed files with 112 additions and 60 deletions
|
@ -4,7 +4,7 @@ from threading import Thread
|
|||
|
||||
from websockets.server import serve
|
||||
|
||||
from extensions.api.util import build_parameters, try_start_cloudflared
|
||||
from extensions.api.util import build_parameters, try_start_cloudflared, with_api_lock
|
||||
from modules import shared
|
||||
from modules.chat import generate_chat_reply
|
||||
from modules.text_generation import generate_reply
|
||||
|
@ -12,72 +12,82 @@ from modules.text_generation import generate_reply
|
|||
# URL path of the plain-text streaming endpoint; `_handle_connection` matches
# the same literal when routing incoming websocket connections.
PATH = '/api/v1/stream'
|
||||
|
||||
|
||||
@with_api_lock
async def _handle_stream_message(websocket, message):
    """Stream the reply to one plain-text generation request.

    Wrapped by ``with_api_lock`` (imported from ``extensions.api.util``;
    presumably serialises API requests -- confirm there).

    Args:
        websocket: Connection the events are written to via ``send``.
        message: Raw JSON text of the request. It must contain a
            ``'prompt'`` key; the decoded object is also handed to
            ``build_parameters``.

    Sends one ``'text_stream'`` event per newly available piece of text,
    numbered from 0 in ``'message_num'``, followed by exactly one
    ``'stream_end'`` event carrying the number of text events sent.

    Raises:
        KeyError: If ``'prompt'`` is absent from the request, or if the
            built parameters lack ``'stopping_strings'``.
    """
    request = json.loads(message)

    prompt = request['prompt']
    params = build_parameters(request)
    stop_strings = params.pop('stopping_strings')
    params['stream'] = True

    replies = generate_reply(
        prompt,
        params,
        stopping_strings=stop_strings,
        is_chat=False,
    )

    # Each yielded item is treated as the full text so far, so only the
    # suffix that has not been sent yet goes out with every event.
    replacement_char = chr(0xfffd)
    sent_chars = 0
    sent_events = 0

    for text in replies:
        fresh = text[sent_chars:]
        # A replacement character means a partial unicode character: hold
        # the chunk back (without advancing) until a later item completes it.
        if fresh is not None and replacement_char not in fresh:
            payload = {
                'event': 'text_stream',
                'message_num': sent_events,
                'text': fresh,
            }
            await websocket.send(json.dumps(payload))

            # Yield to the event loop between chunks.
            await asyncio.sleep(0)
            sent_chars += len(fresh)
            sent_events += 1

    closing = {
        'event': 'stream_end',
        'message_num': sent_events,
    }
    await websocket.send(json.dumps(closing))
|
||||
|
||||
|
||||
@with_api_lock
async def _handle_chat_stream_message(websocket, message):
    """Stream the reply to one chat generation request.

    Wrapped by ``with_api_lock`` (imported from ``extensions.api.util``;
    presumably serialises API requests -- confirm there).

    Args:
        websocket: Connection the events are written to via ``send``.
        message: Raw JSON text of the request. It must contain a
            ``'user_input'`` key; ``'regenerate'`` and ``'_continue'`` are
            optional and default to ``False``.

    Sends one ``'text_stream'`` event per item yielded by
    ``generate_chat_reply`` (the item is sent under ``'history'``), numbered
    from 0 in ``'message_num'``, followed by exactly one ``'stream_end'``
    event carrying the number of text events sent.

    Raises:
        KeyError: If ``'user_input'`` is absent from the request.
    """
    body = json.loads(message)

    user_input = body['user_input']
    params = build_parameters(body, chat=True)
    params['stream'] = True

    histories = generate_chat_reply(
        user_input,
        params,
        regenerate=body.get('regenerate', False),
        _continue=body.get('_continue', False),
        loading_message=False,
    )

    sent_events = 0
    for history in histories:
        payload = {
            'event': 'text_stream',
            'message_num': sent_events,
            'history': history,
        }
        await websocket.send(json.dumps(payload))

        # Yield to the event loop between updates.
        await asyncio.sleep(0)
        sent_events += 1

    closing = {
        'event': 'stream_end',
        'message_num': sent_events,
    }
    await websocket.send(json.dumps(closing))
|
||||
|
||||
|
||||
async def _handle_connection(websocket, path):
    """Route every message on a websocket connection to its stream handler.

    Args:
        websocket: The accepted connection; it is iterated asynchronously
            to receive the raw text of each incoming message.
        path: Request path of the connection. ``'/api/v1/stream'`` selects
            plain-text streaming and ``'/api/v1/chat-stream'`` selects chat
            streaming; any other value is reported on stdout and the
            connection is not serviced.

    The raw message is passed through untouched: the helpers decode the
    JSON themselves and are the ones decorated with ``with_api_lock``, so
    decoding or generating inline here would process each message twice
    and bypass that decorator.
    """
    if path == '/api/v1/stream':
        async for message in websocket:
            await _handle_stream_message(websocket, message)

    elif path == '/api/v1/chat-stream':
        async for message in websocket:
            await _handle_chat_stream_message(websocket, message)

    else:
        print(f'Streaming api: unknown path: {path}')
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue