Add chat API (#2233)

This commit is contained in:
oobabooga 2023-05-20 18:42:17 -03:00 committed by GitHub
parent 2aa01e2303
commit c5af549d4b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 317 additions and 67 deletions

View file

@@ -4,6 +4,7 @@ from threading import Thread
from extensions.api.util import build_parameters, try_start_cloudflared
from modules import shared
from modules.chat import generate_chat_reply
from modules.text_generation import encode, generate_reply
@@ -46,7 +47,37 @@ class Handler(BaseHTTPRequestHandler):
'text': answer
}]
})
self.wfile.write(response.encode('utf-8'))
elif self.path == '/api/v1/chat':
self.send_response(200)
self.send_header('Content-Type', 'application/json')
self.end_headers()
user_input = body['user_input']
history = body['history']
regenerate = body.get('regenerate', False)
_continue = body.get('_continue', False)
generate_params = build_parameters(body, chat=True)
generate_params['stream'] = False
generator = generate_chat_reply(
user_input, history, generate_params, regenerate=regenerate, _continue=_continue, loading_message=False)
answer = history
for a in generator:
answer = a
response = json.dumps({
'results': [{
'history': answer
}]
})
self.wfile.write(response.encode('utf-8'))
elif self.path == '/api/v1/token-count':
self.send_response(200)
self.send_header('Content-Type', 'application/json')
@@ -58,6 +89,7 @@ class Handler(BaseHTTPRequestHandler):
'tokens': len(tokens)
}]
})
self.wfile.write(response.encode('utf-8'))
else:
self.send_error(404)

View file

@@ -2,6 +2,7 @@ import extensions.api.blocking_api as blocking_api
import extensions.api.streaming_api as streaming_api
from modules import shared
def setup():
    """Launch the blocking and streaming API servers on their configured ports."""
    share = shared.args.public_api
    blocking_api.start_server(shared.args.api_blocking_port, share=share)
    streaming_api.start_server(shared.args.api_streaming_port, share=share)

View file

@@ -6,6 +6,7 @@ from websockets.server import serve
from extensions.api.util import build_parameters, try_start_cloudflared
from modules import shared
from modules.chat import generate_chat_reply
from modules.text_generation import generate_reply
PATH = '/api/v1/stream'
@@ -13,42 +14,72 @@ PATH = '/api/v1/stream'
async def _handle_connection(websocket, path):
if path != PATH:
print(f'Streaming api: unknown path: {path}')
return
if path == '/api/v1/stream':
async for message in websocket:
message = json.loads(message)
async for message in websocket:
message = json.loads(message)
prompt = message['prompt']
generate_params = build_parameters(message)
stopping_strings = generate_params.pop('stopping_strings')
generate_params['stream'] = True
prompt = message['prompt']
generate_params = build_parameters(message)
stopping_strings = generate_params.pop('stopping_strings')
generate_params['stream'] = True
generator = generate_reply(
prompt, generate_params, stopping_strings=stopping_strings, is_chat=False)
generator = generate_reply(
prompt, generate_params, stopping_strings=stopping_strings, is_chat=False)
# As we stream, only send the new bytes.
skip_index = 0
message_num = 0
# As we stream, only send the new bytes.
skip_index = 0
message_num = 0
for a in generator:
to_send = a[skip_index:]
await websocket.send(json.dumps({
'event': 'text_stream',
'message_num': message_num,
'text': to_send
}))
await asyncio.sleep(0)
skip_index += len(to_send)
message_num += 1
for a in generator:
to_send = a[skip_index:]
await websocket.send(json.dumps({
'event': 'text_stream',
'message_num': message_num,
'text': to_send
'event': 'stream_end',
'message_num': message_num
}))
await asyncio.sleep(0)
elif path == '/api/v1/chat-stream':
async for message in websocket:
body = json.loads(message)
skip_index += len(to_send)
message_num += 1
user_input = body['user_input']
history = body['history']
generate_params = build_parameters(body, chat=True)
generate_params['stream'] = True
regenerate = body.get('regenerate', False)
_continue = body.get('_continue', False)
await websocket.send(json.dumps({
'event': 'stream_end',
'message_num': message_num
}))
generator = generate_chat_reply(
user_input, history, generate_params, regenerate=regenerate, _continue=_continue, loading_message=False)
message_num = 0
for a in generator:
await websocket.send(json.dumps({
'event': 'text_stream',
'message_num': message_num,
'history': a
}))
await asyncio.sleep(0)
message_num += 1
await websocket.send(json.dumps({
'event': 'stream_end',
'message_num': message_num
}))
else:
print(f'Streaming api: unknown path: {path}')
return
async def _run(host: str, port: int):

View file

@@ -3,18 +3,11 @@ import traceback
from threading import Thread
from typing import Callable, Optional
from modules.text_generation import get_encoded_length
from modules import shared
from modules.chat import load_character
def build_parameters(body):
prompt = body['prompt']
prompt_lines = [k.strip() for k in prompt.split('\n')]
max_context = body.get('max_context_length', 2048)
while len(prompt_lines) >= 0 and get_encoded_length('\n'.join(prompt_lines)) > max_context:
prompt_lines.pop(0)
prompt = '\n'.join(prompt_lines)
def build_parameters(body, chat=False):
generate_params = {
'max_new_tokens': int(body.get('max_new_tokens', body.get('max_length', 200))),
@@ -33,13 +26,34 @@ def build_parameters(body):
'early_stopping': bool(body.get('early_stopping', False)),
'seed': int(body.get('seed', -1)),
'add_bos_token': bool(body.get('add_bos_token', True)),
'truncation_length': int(body.get('truncation_length', 2048)),
'truncation_length': int(body.get('truncation_length', body.get('max_context_length', 2048))),
'ban_eos_token': bool(body.get('ban_eos_token', False)),
'skip_special_tokens': bool(body.get('skip_special_tokens', True)),
'custom_stopping_strings': '', # leave this blank
'stopping_strings': body.get('stopping_strings', []),
}
if chat:
character = body.get('character')
instruction_template = body.get('instruction_template')
name1, name2, _, greeting, context, _ = load_character(character, shared.settings['name1'], shared.settings['name2'], instruct=False)
name1_instruct, name2_instruct, _, _, context_instruct, turn_template = load_character(instruction_template, '', '', instruct=True)
generate_params.update({
'stop_at_newline': bool(body.get('stop_at_newline', shared.settings['stop_at_newline'])),
'chat_prompt_size': int(body.get('chat_prompt_size', shared.settings['chat_prompt_size'])),
'chat_generation_attempts': int(body.get('chat_generation_attempts', shared.settings['chat_generation_attempts'])),
'mode': str(body.get('mode', 'chat')),
'name1': name1,
'name2': name2,
'context': context,
'greeting': greeting,
'name1_instruct': name1_instruct,
'name2_instruct': name2_instruct,
'context_instruct': context_instruct,
'turn_template': turn_template,
'chat-instruct_command': str(body.get('chat-instruct_command', shared.settings['chat-instruct_command'])),
})
return generate_params

View file

@@ -4,12 +4,12 @@ import textwrap
import gradio as gr
from bs4 import BeautifulSoup
from modules import chat, shared
from .chromadb import add_chunks_to_collector, make_collector
from .download_urls import download_urls
params = {
'chunk_count': 5,
'chunk_length': 700,
@@ -40,6 +40,7 @@ def feed_data_into_collector(corpus, chunk_len, chunk_sep):
data_chunks = [x for y in data_chunks for x in y]
else:
data_chunks = [corpus[i:i + chunk_len] for i in range(0, len(corpus), chunk_len)]
cumulative += f"{len(data_chunks)} chunks have been found.\n\nAdding the chunks to the database...\n\n"
yield cumulative
add_chunks_to_collector(data_chunks, collector)
@@ -124,7 +125,10 @@ def custom_generate_chat_prompt(user_input, state, **kwargs):
logging.warning(f'Adding the following new context:\n{additional_context}')
state['context'] = state['context'].strip() + '\n' + additional_context
state['history'] = [shared.history['internal'][i] for i in range(hist_size) if i not in best_ids]
kwargs['history'] = {
'internal': [shared.history['internal'][i] for i in range(hist_size) if i not in best_ids],
'visible': ''
}
except RuntimeError:
logging.error("Couldn't query the database, moving on...")