Add HTTPS support to APIs (openai and default) (#4270)

---------

Co-authored-by: oobabooga <112222186+oobabooga@users.noreply.github.com>
This commit is contained in:
Jesus Alvarez 2023-10-12 21:31:13 -07:00 committed by GitHub
parent 43be1be598
commit ed66ca3cdf
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 53 additions and 11 deletions

View file

@ -1,4 +1,5 @@
import json
import ssl
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from threading import Thread
@ -14,6 +15,7 @@ from modules.text_generation import (
stop_everything_event
)
from modules.utils import get_available_models
from modules.logging_colors import logger
def get_model_info():
@ -199,11 +201,18 @@ class Handler(BaseHTTPRequestHandler):
def _run_server(port: int, share: bool = False, tunnel_id=str):
address = '0.0.0.0' if shared.args.listen else '127.0.0.1'
server = ThreadingHTTPServer((address, port), Handler)
ssl_certfile = shared.args.ssl_certfile
ssl_keyfile = shared.args.ssl_keyfile
ssl_verify = True if (ssl_keyfile and ssl_certfile) else False
if ssl_verify:
context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
context.load_cert_chain(ssl_certfile, ssl_keyfile)
server.socket = context.wrap_socket(server.socket, server_side=True)
def on_start(public_url: str):
print(f'Starting non-streaming server at public url {public_url}/api')
logger.info(f'Starting non-streaming server at public url {public_url}/api')
if share:
try:
@ -211,8 +220,10 @@ def _run_server(port: int, share: bool = False, tunnel_id=str):
except Exception:
pass
else:
print(
f'Starting API at http://{address}:{port}/api')
if ssl_verify:
logger.info(f'Starting API at https://{address}:{port}/api')
else:
logger.info(f'Starting API at http://{address}:{port}/api')
server.serve_forever()