extensions/openai: load extension settings via settings.yaml (#3953)

This commit is contained in:
Chenxiao Wang 2023-09-18 09:39:29 +08:00 committed by GitHub
parent cc8eda298a
commit 347aed4254
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 48 additions and 16 deletions

View file

@ -25,10 +25,16 @@ import speech_recognition as sr
from pydub import AudioSegment
params = {
'port': int(os.environ.get('OPENEDAI_PORT')) if 'OPENEDAI_PORT' in os.environ else 5001,
# default params
'port': 5001,
'embedding_device': 'cpu',
'embedding_model': 'all-mpnet-base-v2',
# optional params
'sd_webui_url': '',
'debug': 0
}
class Handler(BaseHTTPRequestHandler):
def send_access_control_headers(self):
self.send_header("Access-Control-Allow-Origin", "*")
@ -251,7 +257,7 @@ class Handler(BaseHTTPRequestHandler):
self.return_json(response)
elif '/images/generations' in self.path:
if 'SD_WEBUI_URL' not in os.environ:
if not os.environ.get('SD_WEBUI_URL', params.get('sd_webui_url', '')):
raise ServiceUnavailableError("Stable Diffusion not available. SD_WEBUI_URL not set.")
prompt = body['prompt']
@ -313,12 +319,13 @@ class Handler(BaseHTTPRequestHandler):
def run_server():
server_addr = ('0.0.0.0' if shared.args.listen else '127.0.0.1', params['port'])
port = int(os.environ.get('OPENEDAI_PORT', params.get('port', 5001)))
server_addr = ('0.0.0.0' if shared.args.listen else '127.0.0.1', port)
server = ThreadingHTTPServer(server_addr, Handler)
if shared.args.share:
try:
from flask_cloudflared import _run_cloudflared
public_url = _run_cloudflared(params['port'], params['port'] + 1)
public_url = _run_cloudflared(port, port + 1)
print(f'OpenAI compatible API ready at: OPENAI_API_BASE={public_url}/v1')
except ImportError:
print('You should install flask_cloudflared manually')