Refactor chat functions (#2003)
This commit is contained in:
parent
4e9da22c58
commit
638c6a65a2
8 changed files with 138 additions and 157 deletions
|
|
@ -35,18 +35,15 @@ class Handler(BaseHTTPRequestHandler):
|
|||
generate_params['stream'] = False
|
||||
|
||||
generator = generate_reply(
|
||||
prompt, generate_params, stopping_strings=stopping_strings)
|
||||
prompt, generate_params, stopping_strings=stopping_strings, is_chat=False)
|
||||
|
||||
answer = ''
|
||||
for a in generator:
|
||||
if isinstance(a, str):
|
||||
answer = a
|
||||
else:
|
||||
answer = a[0]
|
||||
answer = a
|
||||
|
||||
response = json.dumps({
|
||||
'results': [{
|
||||
'text': answer if shared.is_chat() else answer[len(prompt):]
|
||||
'text': answer[len(prompt):]
|
||||
}]
|
||||
})
|
||||
self.wfile.write(response.encode('utf-8'))
|
||||
|
|
|
|||
|
|
@ -26,19 +26,14 @@ async def _handle_connection(websocket, path):
|
|||
generate_params['stream'] = True
|
||||
|
||||
generator = generate_reply(
|
||||
prompt, generate_params, stopping_strings=stopping_strings)
|
||||
prompt, generate_params, stopping_strings=stopping_strings, is_chat=False)
|
||||
|
||||
# As we stream, only send the new bytes.
|
||||
skip_index = len(prompt) if not shared.is_chat() else 0
|
||||
skip_index = len(prompt)
|
||||
message_num = 0
|
||||
|
||||
for a in generator:
|
||||
to_send = ''
|
||||
if isinstance(a, str):
|
||||
to_send = a[skip_index:]
|
||||
else:
|
||||
to_send = a[0][skip_index:]
|
||||
|
||||
to_send = a[skip_index:]
|
||||
await websocket.send(json.dumps({
|
||||
'event': 'text_stream',
|
||||
'message_num': message_num,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue