Refactor chat functions (#2003)

This commit is contained in:
oobabooga 2023-05-11 15:37:04 -03:00 committed by GitHub
parent 4e9da22c58
commit 638c6a65a2
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 138 additions and 157 deletions

View file

@ -35,18 +35,15 @@ class Handler(BaseHTTPRequestHandler):
generate_params['stream'] = False
generator = generate_reply(
prompt, generate_params, stopping_strings=stopping_strings)
prompt, generate_params, stopping_strings=stopping_strings, is_chat=False)
answer = ''
for a in generator:
if isinstance(a, str):
answer = a
else:
answer = a[0]
answer = a
response = json.dumps({
'results': [{
'text': answer if shared.is_chat() else answer[len(prompt):]
'text': answer[len(prompt):]
}]
})
self.wfile.write(response.encode('utf-8'))

View file

@ -26,19 +26,14 @@ async def _handle_connection(websocket, path):
generate_params['stream'] = True
generator = generate_reply(
prompt, generate_params, stopping_strings=stopping_strings)
prompt, generate_params, stopping_strings=stopping_strings, is_chat=False)
# As we stream, only send the new bytes.
skip_index = len(prompt) if not shared.is_chat() else 0
skip_index = len(prompt)
message_num = 0
for a in generator:
to_send = ''
if isinstance(a, str):
to_send = a[skip_index:]
else:
to_send = a[0][skip_index:]
to_send = a[skip_index:]
await websocket.send(json.dumps({
'event': 'text_stream',
'message_num': message_num,