Make OpenAI API the default API (#4430)
This commit is contained in:
parent
84d957ba62
commit
ec17a5d2b7
22 changed files with 769 additions and 1432 deletions
|
@ -1,351 +1,255 @@
|
|||
import json
|
||||
import os
|
||||
import ssl
|
||||
import traceback
|
||||
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
|
||||
from threading import Thread
|
||||
|
||||
import extensions.openai.completions as OAIcompletions
|
||||
import extensions.openai.edits as OAIedits
|
||||
import extensions.openai.embeddings as OAIembeddings
|
||||
import extensions.openai.images as OAIimages
|
||||
import extensions.openai.models as OAImodels
|
||||
import extensions.openai.moderations as OAImoderations
|
||||
from extensions.openai.defaults import clamp, default, get_default_req_params
|
||||
from extensions.openai.errors import (
|
||||
InvalidRequestError,
|
||||
OpenAIError,
|
||||
ServiceUnavailableError
|
||||
)
|
||||
from extensions.openai.tokens import token_count, token_decode, token_encode
|
||||
from extensions.openai.utils import debug_msg
|
||||
from modules import shared
|
||||
|
||||
import cgi
|
||||
import speech_recognition as sr
|
||||
import uvicorn
|
||||
from extensions.openai.errors import ServiceUnavailableError
|
||||
from extensions.openai.tokens import token_count, token_decode, token_encode
|
||||
from extensions.openai.utils import _start_cloudflared
|
||||
from fastapi import Depends, FastAPI, Header, HTTPException
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.requests import Request
|
||||
from fastapi.responses import JSONResponse
|
||||
from modules import shared
|
||||
from modules.logging_colors import logger
|
||||
from pydub import AudioSegment
|
||||
from sse_starlette import EventSourceResponse
|
||||
|
||||
from .typing import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
CompletionRequest,
|
||||
CompletionResponse,
|
||||
to_dict
|
||||
)
|
||||
|
||||
params = {
|
||||
# default params
|
||||
'port': 5001,
|
||||
'embedding_device': 'cpu',
|
||||
'embedding_model': 'all-mpnet-base-v2',
|
||||
|
||||
# optional params
|
||||
'sd_webui_url': '',
|
||||
'debug': 0
|
||||
}
|
||||
|
||||
class Handler(BaseHTTPRequestHandler):
|
||||
def send_access_control_headers(self):
|
||||
self.send_header("Access-Control-Allow-Origin", "*")
|
||||
self.send_header("Access-Control-Allow-Credentials", "true")
|
||||
self.send_header(
|
||||
"Access-Control-Allow-Methods",
|
||||
"GET,HEAD,OPTIONS,POST,PUT"
|
||||
)
|
||||
self.send_header(
|
||||
"Access-Control-Allow-Headers",
|
||||
"Origin, Accept, X-Requested-With, Content-Type, "
|
||||
"Access-Control-Request-Method, Access-Control-Request-Headers, "
|
||||
"Authorization"
|
||||
)
|
||||
|
||||
def do_OPTIONS(self):
|
||||
self.send_response(200)
|
||||
self.send_access_control_headers()
|
||||
self.send_header('Content-Type', 'application/json')
|
||||
self.end_headers()
|
||||
self.wfile.write("OK".encode('utf-8'))
|
||||
def verify_api_key(authorization: str = Header(None)) -> None:
    """Reject the request with HTTP 401 unless it carries the configured API key.

    When ``shared.args.api_key`` is empty or unset, no check is performed.
    Otherwise the ``Authorization`` header must equal ``Bearer <api_key>``.

    Raises:
        HTTPException: status 401 when a key is configured and the header is
            missing or does not match.
    """
    import hmac

    expected_api_key = shared.args.api_key
    if not expected_api_key:
        return

    expected = f"Bearer {expected_api_key}".encode("utf-8")
    # A missing header is treated as an empty value, which can never match.
    supplied = (authorization or "").encode("utf-8")

    # Constant-time comparison so response timing does not reveal how much of
    # the key a caller guessed correctly. Bytes are compared because
    # compare_digest rejects non-ASCII str arguments.
    if not hmac.compare_digest(supplied, expected):
        raise HTTPException(status_code=401, detail="Unauthorized")
|
||||
|
||||
def start_sse(self):
|
||||
self.send_response(200)
|
||||
self.send_access_control_headers()
|
||||
self.send_header('Content-Type', 'text/event-stream')
|
||||
self.send_header('Cache-Control', 'no-cache')
|
||||
# self.send_header('Connection', 'keep-alive')
|
||||
self.end_headers()
|
||||
|
||||
def send_sse(self, chunk: dict):
|
||||
response = 'data: ' + json.dumps(chunk) + '\r\n\r\n'
|
||||
debug_msg(response[:-4])
|
||||
self.wfile.write(response.encode('utf-8'))
|
||||
app = FastAPI(dependencies=[Depends(verify_api_key)])
|
||||
|
||||
def end_sse(self):
|
||||
response = 'data: [DONE]\r\n\r\n'
|
||||
debug_msg(response[:-4])
|
||||
self.wfile.write(response.encode('utf-8'))
|
||||
# Configure CORS: allow all origins, with an explicit list of allowed methods and headers
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["GET", "HEAD", "OPTIONS", "POST", "PUT"],
|
||||
allow_headers=[
|
||||
"Origin",
|
||||
"Accept",
|
||||
"X-Requested-With",
|
||||
"Content-Type",
|
||||
"Access-Control-Request-Method",
|
||||
"Access-Control-Request-Headers",
|
||||
"Authorization",
|
||||
],
|
||||
)
|
||||
|
||||
def return_json(self, ret: dict, code: int = 200, no_debug=False):
|
||||
self.send_response(code)
|
||||
self.send_access_control_headers()
|
||||
self.send_header('Content-Type', 'application/json')
|
||||
|
||||
response = json.dumps(ret)
|
||||
r_utf8 = response.encode('utf-8')
|
||||
@app.options("/")
async def options_route():
    """Answer OPTIONS requests on the root path with a JSON "OK" body."""
    ok_response = JSONResponse(content="OK")
    return ok_response
|
||||
|
||||
self.send_header('Content-Length', str(len(r_utf8)))
|
||||
self.end_headers()
|
||||
|
||||
self.wfile.write(r_utf8)
|
||||
if not no_debug:
|
||||
debug_msg(r_utf8)
|
||||
@app.post('/v1/completions', response_model=CompletionResponse)
|
||||
@app.post('/v1/generate', response_model=CompletionResponse)
|
||||
async def openai_completions(request: Request, request_data: CompletionRequest):
|
||||
path = request.url.path
|
||||
is_legacy = "/generate" in path
|
||||
|
||||
def openai_error(self, message, code=500, error_type='APIError', param='', internal_message=''):
|
||||
if request_data.stream:
|
||||
async def generator():
|
||||
response = OAIcompletions.stream_completions(to_dict(request_data), is_legacy=is_legacy)
|
||||
for resp in response:
|
||||
yield {"data": json.dumps(resp)}
|
||||
|
||||
error_resp = {
|
||||
'error': {
|
||||
'message': message,
|
||||
'code': code,
|
||||
'type': error_type,
|
||||
'param': param,
|
||||
}
|
||||
}
|
||||
if internal_message:
|
||||
print(error_type, message)
|
||||
print(internal_message)
|
||||
# error_resp['internal_message'] = internal_message
|
||||
return EventSourceResponse(generator()) # SSE streaming
|
||||
|
||||
self.return_json(error_resp, code)
|
||||
else:
|
||||
response = OAIcompletions.completions(to_dict(request_data), is_legacy=is_legacy)
|
||||
return JSONResponse(response)
|
||||
|
||||
def openai_error_handler(func):
|
||||
def wrapper(self):
|
||||
try:
|
||||
func(self)
|
||||
except InvalidRequestError as e:
|
||||
self.openai_error(e.message, e.code, e.__class__.__name__, e.param, internal_message=e.internal_message)
|
||||
except OpenAIError as e:
|
||||
self.openai_error(e.message, e.code, e.__class__.__name__, internal_message=e.internal_message)
|
||||
except Exception as e:
|
||||
self.openai_error(repr(e), 500, 'OpenAIError', internal_message=traceback.format_exc())
|
||||
|
||||
return wrapper
|
||||
@app.post('/v1/chat/completions', response_model=ChatCompletionResponse)
|
||||
async def openai_chat_completions(request: Request, request_data: ChatCompletionRequest):
|
||||
path = request.url.path
|
||||
is_legacy = "/generate" in path
|
||||
|
||||
@openai_error_handler
|
||||
def do_GET(self):
|
||||
debug_msg(self.requestline)
|
||||
debug_msg(self.headers)
|
||||
if request_data.stream:
|
||||
async def generator():
|
||||
response = OAIcompletions.stream_chat_completions(to_dict(request_data), is_legacy=is_legacy)
|
||||
for resp in response:
|
||||
yield {"data": json.dumps(resp)}
|
||||
|
||||
if self.path.startswith('/v1/engines') or self.path.startswith('/v1/models'):
|
||||
is_legacy = 'engines' in self.path
|
||||
is_list = self.path.split('?')[0].split('#')[0] in ['/v1/engines', '/v1/models']
|
||||
if is_legacy and not is_list:
|
||||
model_name = self.path[self.path.find('/v1/engines/') + len('/v1/engines/'):]
|
||||
resp = OAImodels.load_model(model_name)
|
||||
elif is_list:
|
||||
resp = OAImodels.list_models(is_legacy)
|
||||
else:
|
||||
model_name = self.path[len('/v1/models/'):]
|
||||
resp = OAImodels.model_info(model_name)
|
||||
return EventSourceResponse(generator()) # SSE streaming
|
||||
|
||||
self.return_json(resp)
|
||||
else:
|
||||
response = OAIcompletions.chat_completions(to_dict(request_data), is_legacy=is_legacy)
|
||||
return JSONResponse(response)
|
||||
|
||||
elif '/billing/usage' in self.path:
|
||||
# Ex. /v1/dashboard/billing/usage?start_date=2023-05-01&end_date=2023-05-31
|
||||
self.return_json({"total_usage": 0}, no_debug=True)
|
||||
|
||||
else:
|
||||
self.send_error(404)
|
||||
@app.get("/v1/models")
|
||||
@app.get("/v1/engines")
|
||||
async def handle_models(request: Request):
|
||||
path = request.url.path
|
||||
is_legacy = 'engines' in path
|
||||
is_list = request.url.path.split('?')[0].split('#')[0] in ['/v1/engines', '/v1/models']
|
||||
|
||||
@openai_error_handler
|
||||
def do_POST(self):
|
||||
if is_legacy and not is_list:
|
||||
model_name = path[path.find('/v1/engines/') + len('/v1/engines/'):]
|
||||
resp = OAImodels.load_model(model_name)
|
||||
elif is_list:
|
||||
resp = OAImodels.list_models(is_legacy)
|
||||
else:
|
||||
model_name = path[len('/v1/models/'):]
|
||||
resp = OAImodels.model_info(model_name)
|
||||
|
||||
if '/v1/audio/transcriptions' in self.path:
|
||||
r = sr.Recognizer()
|
||||
return JSONResponse(content=resp)
|
||||
|
||||
# Parse the form data
|
||||
form = cgi.FieldStorage(
|
||||
fp=self.rfile,
|
||||
headers=self.headers,
|
||||
environ={'REQUEST_METHOD': 'POST', 'CONTENT_TYPE': self.headers['Content-Type']}
|
||||
)
|
||||
|
||||
audio_file = form['file'].file
|
||||
audio_data = AudioSegment.from_file(audio_file)
|
||||
|
||||
# Convert AudioSegment to raw data
|
||||
raw_data = audio_data.raw_data
|
||||
|
||||
# Create AudioData object
|
||||
audio_data = sr.AudioData(raw_data, audio_data.frame_rate, audio_data.sample_width)
|
||||
whipser_language = form.getvalue('language', None)
|
||||
whipser_model = form.getvalue('model', 'tiny') # Use the model from the form data if it exists, otherwise default to tiny
|
||||
|
||||
transcription = {"text": ""}
|
||||
|
||||
try:
|
||||
transcription["text"] = r.recognize_whisper(audio_data, language=whipser_language, model=whipser_model)
|
||||
except sr.UnknownValueError:
|
||||
print("Whisper could not understand audio")
|
||||
transcription["text"] = "Whisper could not understand audio UnknownValueError"
|
||||
except sr.RequestError as e:
|
||||
print("Could not request results from Whisper", e)
|
||||
transcription["text"] = "Whisper could not understand audio RequestError"
|
||||
|
||||
self.return_json(transcription, no_debug=True)
|
||||
return
|
||||
|
||||
debug_msg(self.requestline)
|
||||
debug_msg(self.headers)
|
||||
@app.get('/v1/billing/usage')
@app.get('/v1/dashboard/billing/usage')
def handle_billing_usage():
    """Return a stub usage report whose ``total_usage`` is always 0.

    Registered under both paths so that the dashboard-style URL, e.g.
    /v1/dashboard/billing/usage?start_date=2023-05-01&end_date=2023-05-31,
    is answered as well. Query parameters are ignored.
    """
    return JSONResponse(content={"total_usage": 0})
|
||||
|
||||
content_length = self.headers.get('Content-Length')
|
||||
transfer_encoding = self.headers.get('Transfer-Encoding')
|
||||
|
||||
if content_length:
|
||||
body = json.loads(self.rfile.read(int(content_length)).decode('utf-8'))
|
||||
elif transfer_encoding == 'chunked':
|
||||
chunks = []
|
||||
while True:
|
||||
chunk_size = int(self.rfile.readline(), 16) # Read the chunk size
|
||||
if chunk_size == 0:
|
||||
break # End of chunks
|
||||
chunks.append(self.rfile.read(chunk_size))
|
||||
self.rfile.readline() # Consume the trailing newline after each chunk
|
||||
body = json.loads(b''.join(chunks).decode('utf-8'))
|
||||
else:
|
||||
self.send_response(400, "Bad Request: Either Content-Length or Transfer-Encoding header expected.")
|
||||
self.end_headers()
|
||||
return
|
||||
@app.post('/v1/audio/transcriptions')
|
||||
async def handle_audio_transcription(request: Request):
|
||||
r = sr.Recognizer()
|
||||
|
||||
debug_msg(body)
|
||||
form = await request.form()
|
||||
audio_file = await form["file"].read()
|
||||
audio_data = AudioSegment.from_file(audio_file)
|
||||
|
||||
if '/completions' in self.path or '/generate' in self.path:
|
||||
# Convert AudioSegment to raw data
|
||||
raw_data = audio_data.raw_data
|
||||
|
||||
if not shared.model:
|
||||
raise ServiceUnavailableError("No model loaded.")
|
||||
# Create AudioData object
|
||||
audio_data = sr.AudioData(raw_data, audio_data.frame_rate, audio_data.sample_width)
|
||||
whipser_language = form.getvalue('language', None)
|
||||
whipser_model = form.getvalue('model', 'tiny') # Use the model from the form data if it exists, otherwise default to tiny
|
||||
|
||||
is_legacy = '/generate' in self.path
|
||||
is_streaming = body.get('stream', False)
|
||||
transcription = {"text": ""}
|
||||
|
||||
if is_streaming:
|
||||
self.start_sse()
|
||||
try:
|
||||
transcription["text"] = r.recognize_whisper(audio_data, language=whipser_language, model=whipser_model)
|
||||
except sr.UnknownValueError:
|
||||
print("Whisper could not understand audio")
|
||||
transcription["text"] = "Whisper could not understand audio UnknownValueError"
|
||||
except sr.RequestError as e:
|
||||
print("Could not request results from Whisper", e)
|
||||
transcription["text"] = "Whisper could not understand audio RequestError"
|
||||
|
||||
response = []
|
||||
if 'chat' in self.path:
|
||||
response = OAIcompletions.stream_chat_completions(body, is_legacy=is_legacy)
|
||||
else:
|
||||
response = OAIcompletions.stream_completions(body, is_legacy=is_legacy)
|
||||
return JSONResponse(content=transcription)
|
||||
|
||||
for resp in response:
|
||||
self.send_sse(resp)
|
||||
|
||||
self.end_sse()
|
||||
@app.post('/v1/images/generations')
|
||||
async def handle_image_generation(request: Request):
|
||||
|
||||
else:
|
||||
response = ''
|
||||
if 'chat' in self.path:
|
||||
response = OAIcompletions.chat_completions(body, is_legacy=is_legacy)
|
||||
else:
|
||||
response = OAIcompletions.completions(body, is_legacy=is_legacy)
|
||||
if not os.environ.get('SD_WEBUI_URL', params.get('sd_webui_url', '')):
|
||||
raise ServiceUnavailableError("Stable Diffusion not available. SD_WEBUI_URL not set.")
|
||||
|
||||
self.return_json(response)
|
||||
body = await request.json()
|
||||
prompt = body['prompt']
|
||||
size = body.get('size', '1024x1024')
|
||||
response_format = body.get('response_format', 'url') # or b64_json
|
||||
n = body.get('n', 1) # ignore the batch limits of max 10
|
||||
|
||||
elif '/edits' in self.path:
|
||||
# deprecated
|
||||
response = await OAIimages.generations(prompt=prompt, size=size, response_format=response_format, n=n)
|
||||
return JSONResponse(response)
|
||||
|
||||
if not shared.model:
|
||||
raise ServiceUnavailableError("No model loaded.")
|
||||
|
||||
req_params = get_default_req_params()
|
||||
@app.post("/v1/embeddings")
|
||||
async def handle_embeddings(request: Request):
|
||||
body = await request.json()
|
||||
encoding_format = body.get("encoding_format", "")
|
||||
|
||||
instruction = body['instruction']
|
||||
input = body.get('input', '')
|
||||
temperature = clamp(default(body, 'temperature', req_params['temperature']), 0.001, 1.999) # fixup absolute 0.0
|
||||
top_p = clamp(default(body, 'top_p', req_params['top_p']), 0.001, 1.0)
|
||||
input = body.get('input', body.get('text', ''))
|
||||
if not input:
|
||||
raise HTTPException(status_code=400, detail="Missing required argument input")
|
||||
|
||||
response = OAIedits.edits(instruction, input, temperature, top_p)
|
||||
if type(input) is str:
|
||||
input = [input]
|
||||
|
||||
self.return_json(response)
|
||||
response = OAIembeddings.embeddings(input, encoding_format)
|
||||
return JSONResponse(response)
|
||||
|
||||
elif '/images/generations' in self.path:
|
||||
if not os.environ.get('SD_WEBUI_URL', params.get('sd_webui_url', '')):
|
||||
raise ServiceUnavailableError("Stable Diffusion not available. SD_WEBUI_URL not set.")
|
||||
|
||||
prompt = body['prompt']
|
||||
size = default(body, 'size', '1024x1024')
|
||||
response_format = default(body, 'response_format', 'url') # or b64_json
|
||||
n = default(body, 'n', 1) # ignore the batch limits of max 10
|
||||
@app.post("/v1/moderations")
|
||||
async def handle_moderations(request: Request):
|
||||
body = await request.json()
|
||||
input = body["input"]
|
||||
if not input:
|
||||
raise HTTPException(status_code=400, detail="Missing required argument input")
|
||||
|
||||
response = OAIimages.generations(prompt=prompt, size=size, response_format=response_format, n=n)
|
||||
response = OAImoderations.moderations(input)
|
||||
return JSONResponse(response)
|
||||
|
||||
self.return_json(response, no_debug=True)
|
||||
|
||||
elif '/embeddings' in self.path:
|
||||
encoding_format = body.get('encoding_format', '')
|
||||
@app.post("/api/v1/token-count")
async def handle_token_count(request: Request):
    """Pass the ``prompt`` field of the JSON body to ``token_count``.

    The result of ``token_count`` is returned as the JSON response.

    Raises:
        HTTPException: status 400 when the body is not a JSON object or has
            no ``prompt`` field (instead of failing with a KeyError).
    """
    body = await request.json()
    # Only the key's presence is required: an empty prompt is still a prompt.
    if not isinstance(body, dict) or 'prompt' not in body:
        raise HTTPException(status_code=400, detail="Missing required argument prompt")

    response = token_count(body['prompt'])
    return JSONResponse(response)
|
||||
|
||||
input = body.get('input', body.get('text', ''))
|
||||
if not input:
|
||||
raise InvalidRequestError("Missing required argument input", params='input')
|
||||
|
||||
if type(input) is str:
|
||||
input = [input]
|
||||
@app.post("/api/v1/token/encode")
async def handle_token_encode(request: Request):
    """Pass the ``input`` field of the JSON body to ``token_encode``.

    The optional ``encoding_format`` field (default ``""``) is forwarded
    unchanged. The result of ``token_encode`` is returned as the JSON response.

    Raises:
        HTTPException: status 400 when the body is not a JSON object or has
            no ``input`` field (instead of failing with a KeyError).
    """
    body = await request.json()
    if not isinstance(body, dict) or "input" not in body:
        raise HTTPException(status_code=400, detail="Missing required argument input")

    encoding_format = body.get("encoding_format", "")
    response = token_encode(body["input"], encoding_format)
    return JSONResponse(response)
|
||||
|
||||
response = OAIembeddings.embeddings(input, encoding_format)
|
||||
|
||||
self.return_json(response, no_debug=True)
|
||||
|
||||
elif '/moderations' in self.path:
|
||||
input = body['input']
|
||||
if not input:
|
||||
raise InvalidRequestError("Missing required argument input", params='input')
|
||||
|
||||
response = OAImoderations.moderations(input)
|
||||
|
||||
self.return_json(response, no_debug=True)
|
||||
|
||||
elif self.path == '/api/v1/token-count':
|
||||
# NOT STANDARD. lifted from the api extension, but it's still very useful to calculate tokenized length client side.
|
||||
response = token_count(body['prompt'])
|
||||
|
||||
self.return_json(response, no_debug=True)
|
||||
|
||||
elif self.path == '/api/v1/token/encode':
|
||||
# NOT STANDARD. needed to support logit_bias, logprobs and token arrays for native models
|
||||
encoding_format = body.get('encoding_format', '')
|
||||
|
||||
response = token_encode(body['input'], encoding_format)
|
||||
|
||||
self.return_json(response, no_debug=True)
|
||||
|
||||
elif self.path == '/api/v1/token/decode':
|
||||
# NOT STANDARD. needed to support logit_bias, logprobs and token arrays for native models
|
||||
encoding_format = body.get('encoding_format', '')
|
||||
|
||||
response = token_decode(body['input'], encoding_format)
|
||||
|
||||
self.return_json(response, no_debug=True)
|
||||
|
||||
else:
|
||||
self.send_error(404)
|
||||
@app.post("/api/v1/token/decode")
async def handle_token_decode(request: Request):
    """Pass the ``input`` field of the JSON body to ``token_decode``.

    The optional ``encoding_format`` field (default ``""``) is forwarded
    unchanged. The result of ``token_decode`` is returned as the JSON response.

    Raises:
        HTTPException: status 400 when the body is not a JSON object or has
            no ``input`` field (instead of failing with a KeyError).
    """
    body = await request.json()
    if not isinstance(body, dict) or "input" not in body:
        raise HTTPException(status_code=400, detail="Missing required argument input")

    encoding_format = body.get("encoding_format", "")
    response = token_decode(body["input"], encoding_format)
    # JSONResponse has no ``no_debug`` keyword; passing it raised TypeError on
    # every request, so it is dropped here.
    return JSONResponse(response)
|
||||
|
||||
|
||||
def run_server():
|
||||
port = int(os.environ.get('OPENEDAI_PORT', params.get('port', 5001)))
|
||||
server_addr = ('0.0.0.0' if shared.args.listen else '127.0.0.1', port)
|
||||
server = ThreadingHTTPServer(server_addr, Handler)
|
||||
|
||||
ssl_certfile=os.environ.get('OPENEDAI_CERT_PATH', shared.args.ssl_certfile)
|
||||
ssl_keyfile=os.environ.get('OPENEDAI_KEY_PATH', shared.args.ssl_keyfile)
|
||||
ssl_verify=True if (ssl_keyfile and ssl_certfile) else False
|
||||
if ssl_verify:
|
||||
context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
|
||||
context.load_cert_chain(ssl_certfile, ssl_keyfile)
|
||||
server.socket = context.wrap_socket(server.socket, server_side=True)
|
||||
|
||||
if shared.args.share:
|
||||
try:
|
||||
from flask_cloudflared import _run_cloudflared
|
||||
public_url = _run_cloudflared(port, port + 1)
|
||||
print(f'OpenAI compatible API ready at: OPENAI_API_BASE={public_url}/v1')
|
||||
except ImportError:
|
||||
print('You should install flask_cloudflared manually')
|
||||
server_addr = '0.0.0.0' if shared.args.listen else '127.0.0.1'
|
||||
port = int(os.environ.get('OPENEDAI_PORT', shared.args.api_port))
|
||||
|
||||
ssl_certfile = os.environ.get('OPENEDAI_CERT_PATH', shared.args.ssl_certfile)
|
||||
ssl_keyfile = os.environ.get('OPENEDAI_KEY_PATH', shared.args.ssl_keyfile)
|
||||
|
||||
if shared.args.public_api:
|
||||
def on_start(public_url: str):
|
||||
logger.info(f'OpenAI compatible API URL:\n\n{public_url}/v1\n')
|
||||
|
||||
_start_cloudflared(port, shared.args.public_api_id, max_attempts=3, on_start=on_start)
|
||||
else:
|
||||
if ssl_verify:
|
||||
print(f'OpenAI compatible API ready at: OPENAI_API_BASE=https://{server_addr[0]}:{server_addr[1]}/v1')
|
||||
if ssl_keyfile and ssl_certfile:
|
||||
logger.info(f'OpenAI compatible API URL:\n\nhttps://{server_addr}:{port}/v1\n')
|
||||
else:
|
||||
print(f'OpenAI compatible API ready at: OPENAI_API_BASE=http://{server_addr[0]}:{server_addr[1]}/v1')
|
||||
|
||||
server.serve_forever()
|
||||
logger.info(f'OpenAI compatible API URL:\n\nhttp://{server_addr}:{port}/v1\n')
|
||||
|
||||
if shared.args.api_key:
|
||||
logger.info(f'OpenAI API key:\n\n{shared.args.api_key}\n')
|
||||
|
||||
uvicorn.run(app, host=server_addr, port=port, ssl_certfile=ssl_certfile, ssl_keyfile=ssl_keyfile)
|
||||
|
||||
|
||||
def setup():
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue