Add chat API (#2233)
This commit is contained in:
parent
2aa01e2303
commit
c5af549d4b
8 changed files with 317 additions and 67 deletions
|
@ -4,6 +4,7 @@ from threading import Thread
|
|||
|
||||
from extensions.api.util import build_parameters, try_start_cloudflared
|
||||
from modules import shared
|
||||
from modules.chat import generate_chat_reply
|
||||
from modules.text_generation import encode, generate_reply
|
||||
|
||||
|
||||
|
@ -46,7 +47,37 @@ class Handler(BaseHTTPRequestHandler):
|
|||
'text': answer
|
||||
}]
|
||||
})
|
||||
|
||||
self.wfile.write(response.encode('utf-8'))
|
||||
|
||||
elif self.path == '/api/v1/chat':
|
||||
self.send_response(200)
|
||||
self.send_header('Content-Type', 'application/json')
|
||||
self.end_headers()
|
||||
|
||||
user_input = body['user_input']
|
||||
history = body['history']
|
||||
regenerate = body.get('regenerate', False)
|
||||
_continue = body.get('_continue', False)
|
||||
|
||||
generate_params = build_parameters(body, chat=True)
|
||||
generate_params['stream'] = False
|
||||
|
||||
generator = generate_chat_reply(
|
||||
user_input, history, generate_params, regenerate=regenerate, _continue=_continue, loading_message=False)
|
||||
|
||||
answer = history
|
||||
for a in generator:
|
||||
answer = a
|
||||
|
||||
response = json.dumps({
|
||||
'results': [{
|
||||
'history': answer
|
||||
}]
|
||||
})
|
||||
|
||||
self.wfile.write(response.encode('utf-8'))
|
||||
|
||||
elif self.path == '/api/v1/token-count':
|
||||
self.send_response(200)
|
||||
self.send_header('Content-Type', 'application/json')
|
||||
|
@ -58,6 +89,7 @@ class Handler(BaseHTTPRequestHandler):
|
|||
'tokens': len(tokens)
|
||||
}]
|
||||
})
|
||||
|
||||
self.wfile.write(response.encode('utf-8'))
|
||||
else:
|
||||
self.send_error(404)
|
||||
|
|
|
@ -2,6 +2,7 @@ import extensions.api.blocking_api as blocking_api
|
|||
import extensions.api.streaming_api as streaming_api
|
||||
from modules import shared
|
||||
|
||||
|
||||
def setup():
|
||||
blocking_api.start_server(shared.args.api_blocking_port, share=shared.args.public_api)
|
||||
streaming_api.start_server(shared.args.api_streaming_port, share=shared.args.public_api)
|
||||
|
|
|
@ -6,6 +6,7 @@ from websockets.server import serve
|
|||
|
||||
from extensions.api.util import build_parameters, try_start_cloudflared
|
||||
from modules import shared
|
||||
from modules.chat import generate_chat_reply
|
||||
from modules.text_generation import generate_reply
|
||||
|
||||
PATH = '/api/v1/stream'
|
||||
|
@ -13,42 +14,72 @@ PATH = '/api/v1/stream'
|
|||
|
||||
async def _handle_connection(websocket, path):
|
||||
|
||||
if path != PATH:
|
||||
print(f'Streaming api: unknown path: {path}')
|
||||
return
|
||||
if path == '/api/v1/stream':
|
||||
async for message in websocket:
|
||||
message = json.loads(message)
|
||||
|
||||
async for message in websocket:
|
||||
message = json.loads(message)
|
||||
prompt = message['prompt']
|
||||
generate_params = build_parameters(message)
|
||||
stopping_strings = generate_params.pop('stopping_strings')
|
||||
generate_params['stream'] = True
|
||||
|
||||
prompt = message['prompt']
|
||||
generate_params = build_parameters(message)
|
||||
stopping_strings = generate_params.pop('stopping_strings')
|
||||
generate_params['stream'] = True
|
||||
generator = generate_reply(
|
||||
prompt, generate_params, stopping_strings=stopping_strings, is_chat=False)
|
||||
|
||||
generator = generate_reply(
|
||||
prompt, generate_params, stopping_strings=stopping_strings, is_chat=False)
|
||||
# As we stream, only send the new bytes.
|
||||
skip_index = 0
|
||||
message_num = 0
|
||||
|
||||
# As we stream, only send the new bytes.
|
||||
skip_index = 0
|
||||
message_num = 0
|
||||
for a in generator:
|
||||
to_send = a[skip_index:]
|
||||
await websocket.send(json.dumps({
|
||||
'event': 'text_stream',
|
||||
'message_num': message_num,
|
||||
'text': to_send
|
||||
}))
|
||||
|
||||
await asyncio.sleep(0)
|
||||
skip_index += len(to_send)
|
||||
message_num += 1
|
||||
|
||||
for a in generator:
|
||||
to_send = a[skip_index:]
|
||||
await websocket.send(json.dumps({
|
||||
'event': 'text_stream',
|
||||
'message_num': message_num,
|
||||
'text': to_send
|
||||
'event': 'stream_end',
|
||||
'message_num': message_num
|
||||
}))
|
||||
|
||||
await asyncio.sleep(0)
|
||||
elif path == '/api/v1/chat-stream':
|
||||
async for message in websocket:
|
||||
body = json.loads(message)
|
||||
|
||||
skip_index += len(to_send)
|
||||
message_num += 1
|
||||
user_input = body['user_input']
|
||||
history = body['history']
|
||||
generate_params = build_parameters(body, chat=True)
|
||||
generate_params['stream'] = True
|
||||
regenerate = body.get('regenerate', False)
|
||||
_continue = body.get('_continue', False)
|
||||
|
||||
await websocket.send(json.dumps({
|
||||
'event': 'stream_end',
|
||||
'message_num': message_num
|
||||
}))
|
||||
generator = generate_chat_reply(
|
||||
user_input, history, generate_params, regenerate=regenerate, _continue=_continue, loading_message=False)
|
||||
|
||||
message_num = 0
|
||||
for a in generator:
|
||||
await websocket.send(json.dumps({
|
||||
'event': 'text_stream',
|
||||
'message_num': message_num,
|
||||
'history': a
|
||||
}))
|
||||
|
||||
await asyncio.sleep(0)
|
||||
message_num += 1
|
||||
|
||||
await websocket.send(json.dumps({
|
||||
'event': 'stream_end',
|
||||
'message_num': message_num
|
||||
}))
|
||||
|
||||
else:
|
||||
print(f'Streaming api: unknown path: {path}')
|
||||
return
|
||||
|
||||
|
||||
async def _run(host: str, port: int):
|
||||
|
|
|
@ -3,18 +3,11 @@ import traceback
|
|||
from threading import Thread
|
||||
from typing import Callable, Optional
|
||||
|
||||
from modules.text_generation import get_encoded_length
|
||||
from modules import shared
|
||||
from modules.chat import load_character
|
||||
|
||||
|
||||
def build_parameters(body):
|
||||
prompt = body['prompt']
|
||||
|
||||
prompt_lines = [k.strip() for k in prompt.split('\n')]
|
||||
max_context = body.get('max_context_length', 2048)
|
||||
while len(prompt_lines) >= 0 and get_encoded_length('\n'.join(prompt_lines)) > max_context:
|
||||
prompt_lines.pop(0)
|
||||
|
||||
prompt = '\n'.join(prompt_lines)
|
||||
def build_parameters(body, chat=False):
|
||||
|
||||
generate_params = {
|
||||
'max_new_tokens': int(body.get('max_new_tokens', body.get('max_length', 200))),
|
||||
|
@ -33,13 +26,34 @@ def build_parameters(body):
|
|||
'early_stopping': bool(body.get('early_stopping', False)),
|
||||
'seed': int(body.get('seed', -1)),
|
||||
'add_bos_token': bool(body.get('add_bos_token', True)),
|
||||
'truncation_length': int(body.get('truncation_length', 2048)),
|
||||
'truncation_length': int(body.get('truncation_length', body.get('max_context_length', 2048))),
|
||||
'ban_eos_token': bool(body.get('ban_eos_token', False)),
|
||||
'skip_special_tokens': bool(body.get('skip_special_tokens', True)),
|
||||
'custom_stopping_strings': '', # leave this blank
|
||||
'stopping_strings': body.get('stopping_strings', []),
|
||||
}
|
||||
|
||||
if chat:
|
||||
character = body.get('character')
|
||||
instruction_template = body.get('instruction_template')
|
||||
name1, name2, _, greeting, context, _ = load_character(character, shared.settings['name1'], shared.settings['name2'], instruct=False)
|
||||
name1_instruct, name2_instruct, _, _, context_instruct, turn_template = load_character(instruction_template, '', '', instruct=True)
|
||||
generate_params.update({
|
||||
'stop_at_newline': bool(body.get('stop_at_newline', shared.settings['stop_at_newline'])),
|
||||
'chat_prompt_size': int(body.get('chat_prompt_size', shared.settings['chat_prompt_size'])),
|
||||
'chat_generation_attempts': int(body.get('chat_generation_attempts', shared.settings['chat_generation_attempts'])),
|
||||
'mode': str(body.get('mode', 'chat')),
|
||||
'name1': name1,
|
||||
'name2': name2,
|
||||
'context': context,
|
||||
'greeting': greeting,
|
||||
'name1_instruct': name1_instruct,
|
||||
'name2_instruct': name2_instruct,
|
||||
'context_instruct': context_instruct,
|
||||
'turn_template': turn_template,
|
||||
'chat-instruct_command': str(body.get('chat-instruct_command', shared.settings['chat-instruct_command'])),
|
||||
})
|
||||
|
||||
return generate_params
|
||||
|
||||
|
||||
|
|
|
@ -4,12 +4,12 @@ import textwrap
|
|||
|
||||
import gradio as gr
|
||||
from bs4 import BeautifulSoup
|
||||
|
||||
from modules import chat, shared
|
||||
|
||||
from .chromadb import add_chunks_to_collector, make_collector
|
||||
from .download_urls import download_urls
|
||||
|
||||
|
||||
params = {
|
||||
'chunk_count': 5,
|
||||
'chunk_length': 700,
|
||||
|
@ -40,6 +40,7 @@ def feed_data_into_collector(corpus, chunk_len, chunk_sep):
|
|||
data_chunks = [x for y in data_chunks for x in y]
|
||||
else:
|
||||
data_chunks = [corpus[i:i + chunk_len] for i in range(0, len(corpus), chunk_len)]
|
||||
|
||||
cumulative += f"{len(data_chunks)} chunks have been found.\n\nAdding the chunks to the database...\n\n"
|
||||
yield cumulative
|
||||
add_chunks_to_collector(data_chunks, collector)
|
||||
|
@ -124,7 +125,10 @@ def custom_generate_chat_prompt(user_input, state, **kwargs):
|
|||
|
||||
logging.warning(f'Adding the following new context:\n{additional_context}')
|
||||
state['context'] = state['context'].strip() + '\n' + additional_context
|
||||
state['history'] = [shared.history['internal'][i] for i in range(hist_size) if i not in best_ids]
|
||||
kwargs['history'] = {
|
||||
'internal': [shared.history['internal'][i] for i in range(hist_size) if i not in best_ids],
|
||||
'visible': ''
|
||||
}
|
||||
except RuntimeError:
|
||||
logging.error("Couldn't query the database, moving on...")
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue