Add HTTPS support to APIs (openai and default) (#4270)

---------

Co-authored-by: oobabooga <112222186+oobabooga@users.noreply.github.com>
This commit is contained in:
Jesus Alvarez 2023-10-12 21:31:13 -07:00 committed by GitHub
parent 43be1be598
commit ed66ca3cdf
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 53 additions and 11 deletions

View file

@ -1,7 +1,10 @@
import asyncio
import json
import ssl
from threading import Thread
from websockets.server import serve
from extensions.api.util import (
build_parameters,
try_start_cloudflared,
@ -10,7 +13,7 @@ from extensions.api.util import (
from modules import shared
from modules.chat import generate_chat_reply
from modules.text_generation import generate_reply
from websockets.server import serve
from modules.logging_colors import logger
PATH = '/api/v1/stream'
@ -98,16 +101,28 @@ async def _handle_connection(websocket, path):
async def _run(host: str, port: int):
async with serve(_handle_connection, host, port, ping_interval=None):
await asyncio.Future() # run forever
ssl_certfile = shared.args.ssl_certfile
ssl_keyfile = shared.args.ssl_keyfile
ssl_verify = True if (ssl_keyfile and ssl_certfile) else False
if ssl_verify:
context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
context.load_cert_chain(ssl_certfile, ssl_keyfile)
else:
context = None
async with serve(_handle_connection, host, port, ping_interval=None, ssl=context):
await asyncio.Future() # Run the server forever
def _run_server(port: int, share: bool = False, tunnel_id=str):
address = '0.0.0.0' if shared.args.listen else '127.0.0.1'
ssl_certfile = shared.args.ssl_certfile
ssl_keyfile = shared.args.ssl_keyfile
ssl_verify = True if (ssl_keyfile and ssl_certfile) else False
def on_start(public_url: str):
public_url = public_url.replace('https://', 'wss://')
print(f'Starting streaming server at public url {public_url}{PATH}')
logger.info(f'Starting streaming server at public url {public_url}{PATH}')
if share:
try:
@ -115,7 +130,10 @@ def _run_server(port: int, share: bool = False, tunnel_id=str):
except Exception as e:
print(e)
else:
print(f'Starting streaming server at ws://{address}:{port}{PATH}')
if ssl_verify:
logger.info(f'Starting streaming server at wss://{address}:{port}{PATH}')
else:
logger.info(f'Starting streaming server at ws://{address}:{port}{PATH}')
asyncio.run(_run(host=address, port=port))