Make OpenAI API the default API (#4430)

This commit is contained in:
oobabooga 2023-11-06 02:38:29 -03:00 committed by GitHub
parent 84d957ba62
commit ec17a5d2b7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
22 changed files with 769 additions and 1432 deletions

View file

@ -1,351 +1,255 @@
import json
import os
import ssl
import traceback
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from threading import Thread
import extensions.openai.completions as OAIcompletions
import extensions.openai.edits as OAIedits
import extensions.openai.embeddings as OAIembeddings
import extensions.openai.images as OAIimages
import extensions.openai.models as OAImodels
import extensions.openai.moderations as OAImoderations
from extensions.openai.defaults import clamp, default, get_default_req_params
from extensions.openai.errors import (
InvalidRequestError,
OpenAIError,
ServiceUnavailableError
)
from extensions.openai.tokens import token_count, token_decode, token_encode
from extensions.openai.utils import debug_msg
from modules import shared
import cgi
import speech_recognition as sr
import uvicorn
from extensions.openai.errors import ServiceUnavailableError
from extensions.openai.tokens import token_count, token_decode, token_encode
from extensions.openai.utils import _start_cloudflared
from fastapi import Depends, FastAPI, Header, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.requests import Request
from fastapi.responses import JSONResponse
from modules import shared
from modules.logging_colors import logger
from pydub import AudioSegment
from sse_starlette import EventSourceResponse
from .typing import (
ChatCompletionRequest,
ChatCompletionResponse,
CompletionRequest,
CompletionResponse,
to_dict
)
params = {
# default params
'port': 5001,
'embedding_device': 'cpu',
'embedding_model': 'all-mpnet-base-v2',
# optional params
'sd_webui_url': '',
'debug': 0
}
class Handler(BaseHTTPRequestHandler):
def send_access_control_headers(self):
self.send_header("Access-Control-Allow-Origin", "*")
self.send_header("Access-Control-Allow-Credentials", "true")
self.send_header(
"Access-Control-Allow-Methods",
"GET,HEAD,OPTIONS,POST,PUT"
)
self.send_header(
"Access-Control-Allow-Headers",
"Origin, Accept, X-Requested-With, Content-Type, "
"Access-Control-Request-Method, Access-Control-Request-Headers, "
"Authorization"
)
def do_OPTIONS(self):
self.send_response(200)
self.send_access_control_headers()
self.send_header('Content-Type', 'application/json')
self.end_headers()
self.wfile.write("OK".encode('utf-8'))
def verify_api_key(authorization: str = Header(None)) -> None:
expected_api_key = shared.args.api_key
if expected_api_key and (authorization is None or authorization != f"Bearer {expected_api_key}"):
raise HTTPException(status_code=401, detail="Unauthorized")
def start_sse(self):
self.send_response(200)
self.send_access_control_headers()
self.send_header('Content-Type', 'text/event-stream')
self.send_header('Cache-Control', 'no-cache')
# self.send_header('Connection', 'keep-alive')
self.end_headers()
def send_sse(self, chunk: dict):
response = 'data: ' + json.dumps(chunk) + '\r\n\r\n'
debug_msg(response[:-4])
self.wfile.write(response.encode('utf-8'))
app = FastAPI(dependencies=[Depends(verify_api_key)])
def end_sse(self):
response = 'data: [DONE]\r\n\r\n'
debug_msg(response[:-4])
self.wfile.write(response.encode('utf-8'))
# Configure CORS settings to allow all origins, methods, and headers
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["GET", "HEAD", "OPTIONS", "POST", "PUT"],
allow_headers=[
"Origin",
"Accept",
"X-Requested-With",
"Content-Type",
"Access-Control-Request-Method",
"Access-Control-Request-Headers",
"Authorization",
],
)
def return_json(self, ret: dict, code: int = 200, no_debug=False):
self.send_response(code)
self.send_access_control_headers()
self.send_header('Content-Type', 'application/json')
response = json.dumps(ret)
r_utf8 = response.encode('utf-8')
@app.options("/")
async def options_route():
return JSONResponse(content="OK")
self.send_header('Content-Length', str(len(r_utf8)))
self.end_headers()
self.wfile.write(r_utf8)
if not no_debug:
debug_msg(r_utf8)
@app.post('/v1/completions', response_model=CompletionResponse)
@app.post('/v1/generate', response_model=CompletionResponse)
async def openai_completions(request: Request, request_data: CompletionRequest):
path = request.url.path
is_legacy = "/generate" in path
def openai_error(self, message, code=500, error_type='APIError', param='', internal_message=''):
if request_data.stream:
async def generator():
response = OAIcompletions.stream_completions(to_dict(request_data), is_legacy=is_legacy)
for resp in response:
yield {"data": json.dumps(resp)}
error_resp = {
'error': {
'message': message,
'code': code,
'type': error_type,
'param': param,
}
}
if internal_message:
print(error_type, message)
print(internal_message)
# error_resp['internal_message'] = internal_message
return EventSourceResponse(generator()) # SSE streaming
self.return_json(error_resp, code)
else:
response = OAIcompletions.completions(to_dict(request_data), is_legacy=is_legacy)
return JSONResponse(response)
def openai_error_handler(func):
def wrapper(self):
try:
func(self)
except InvalidRequestError as e:
self.openai_error(e.message, e.code, e.__class__.__name__, e.param, internal_message=e.internal_message)
except OpenAIError as e:
self.openai_error(e.message, e.code, e.__class__.__name__, internal_message=e.internal_message)
except Exception as e:
self.openai_error(repr(e), 500, 'OpenAIError', internal_message=traceback.format_exc())
return wrapper
@app.post('/v1/chat/completions', response_model=ChatCompletionResponse)
async def openai_chat_completions(request: Request, request_data: ChatCompletionRequest):
path = request.url.path
is_legacy = "/generate" in path
@openai_error_handler
def do_GET(self):
debug_msg(self.requestline)
debug_msg(self.headers)
if request_data.stream:
async def generator():
response = OAIcompletions.stream_chat_completions(to_dict(request_data), is_legacy=is_legacy)
for resp in response:
yield {"data": json.dumps(resp)}
if self.path.startswith('/v1/engines') or self.path.startswith('/v1/models'):
is_legacy = 'engines' in self.path
is_list = self.path.split('?')[0].split('#')[0] in ['/v1/engines', '/v1/models']
if is_legacy and not is_list:
model_name = self.path[self.path.find('/v1/engines/') + len('/v1/engines/'):]
resp = OAImodels.load_model(model_name)
elif is_list:
resp = OAImodels.list_models(is_legacy)
else:
model_name = self.path[len('/v1/models/'):]
resp = OAImodels.model_info(model_name)
return EventSourceResponse(generator()) # SSE streaming
self.return_json(resp)
else:
response = OAIcompletions.chat_completions(to_dict(request_data), is_legacy=is_legacy)
return JSONResponse(response)
elif '/billing/usage' in self.path:
# Ex. /v1/dashboard/billing/usage?start_date=2023-05-01&end_date=2023-05-31
self.return_json({"total_usage": 0}, no_debug=True)
else:
self.send_error(404)
@app.get("/v1/models")
@app.get("/v1/engines")
async def handle_models(request: Request):
path = request.url.path
is_legacy = 'engines' in path
is_list = request.url.path.split('?')[0].split('#')[0] in ['/v1/engines', '/v1/models']
@openai_error_handler
def do_POST(self):
if is_legacy and not is_list:
model_name = path[path.find('/v1/engines/') + len('/v1/engines/'):]
resp = OAImodels.load_model(model_name)
elif is_list:
resp = OAImodels.list_models(is_legacy)
else:
model_name = path[len('/v1/models/'):]
resp = OAImodels.model_info(model_name)
if '/v1/audio/transcriptions' in self.path:
r = sr.Recognizer()
return JSONResponse(content=resp)
# Parse the form data
form = cgi.FieldStorage(
fp=self.rfile,
headers=self.headers,
environ={'REQUEST_METHOD': 'POST', 'CONTENT_TYPE': self.headers['Content-Type']}
)
audio_file = form['file'].file
audio_data = AudioSegment.from_file(audio_file)
# Convert AudioSegment to raw data
raw_data = audio_data.raw_data
# Create AudioData object
audio_data = sr.AudioData(raw_data, audio_data.frame_rate, audio_data.sample_width)
whipser_language = form.getvalue('language', None)
whipser_model = form.getvalue('model', 'tiny') # Use the model from the form data if it exists, otherwise default to tiny
transcription = {"text": ""}
try:
transcription["text"] = r.recognize_whisper(audio_data, language=whipser_language, model=whipser_model)
except sr.UnknownValueError:
print("Whisper could not understand audio")
transcription["text"] = "Whisper could not understand audio UnknownValueError"
except sr.RequestError as e:
print("Could not request results from Whisper", e)
transcription["text"] = "Whisper could not understand audio RequestError"
self.return_json(transcription, no_debug=True)
return
debug_msg(self.requestline)
debug_msg(self.headers)
@app.get('/v1/billing/usage')
def handle_billing_usage():
'''
Ex. /v1/dashboard/billing/usage?start_date=2023-05-01&end_date=2023-05-31
'''
return JSONResponse(content={"total_usage": 0})
content_length = self.headers.get('Content-Length')
transfer_encoding = self.headers.get('Transfer-Encoding')
if content_length:
body = json.loads(self.rfile.read(int(content_length)).decode('utf-8'))
elif transfer_encoding == 'chunked':
chunks = []
while True:
chunk_size = int(self.rfile.readline(), 16) # Read the chunk size
if chunk_size == 0:
break # End of chunks
chunks.append(self.rfile.read(chunk_size))
self.rfile.readline() # Consume the trailing newline after each chunk
body = json.loads(b''.join(chunks).decode('utf-8'))
else:
self.send_response(400, "Bad Request: Either Content-Length or Transfer-Encoding header expected.")
self.end_headers()
return
@app.post('/v1/audio/transcriptions')
async def handle_audio_transcription(request: Request):
r = sr.Recognizer()
debug_msg(body)
form = await request.form()
audio_file = await form["file"].read()
audio_data = AudioSegment.from_file(audio_file)
if '/completions' in self.path or '/generate' in self.path:
# Convert AudioSegment to raw data
raw_data = audio_data.raw_data
if not shared.model:
raise ServiceUnavailableError("No model loaded.")
# Create AudioData object
audio_data = sr.AudioData(raw_data, audio_data.frame_rate, audio_data.sample_width)
whipser_language = form.getvalue('language', None)
whipser_model = form.getvalue('model', 'tiny') # Use the model from the form data if it exists, otherwise default to tiny
is_legacy = '/generate' in self.path
is_streaming = body.get('stream', False)
transcription = {"text": ""}
if is_streaming:
self.start_sse()
try:
transcription["text"] = r.recognize_whisper(audio_data, language=whipser_language, model=whipser_model)
except sr.UnknownValueError:
print("Whisper could not understand audio")
transcription["text"] = "Whisper could not understand audio UnknownValueError"
except sr.RequestError as e:
print("Could not request results from Whisper", e)
transcription["text"] = "Whisper could not understand audio RequestError"
response = []
if 'chat' in self.path:
response = OAIcompletions.stream_chat_completions(body, is_legacy=is_legacy)
else:
response = OAIcompletions.stream_completions(body, is_legacy=is_legacy)
return JSONResponse(content=transcription)
for resp in response:
self.send_sse(resp)
self.end_sse()
@app.post('/v1/images/generations')
async def handle_image_generation(request: Request):
else:
response = ''
if 'chat' in self.path:
response = OAIcompletions.chat_completions(body, is_legacy=is_legacy)
else:
response = OAIcompletions.completions(body, is_legacy=is_legacy)
if not os.environ.get('SD_WEBUI_URL', params.get('sd_webui_url', '')):
raise ServiceUnavailableError("Stable Diffusion not available. SD_WEBUI_URL not set.")
self.return_json(response)
body = await request.json()
prompt = body['prompt']
size = body.get('size', '1024x1024')
response_format = body.get('response_format', 'url') # or b64_json
n = body.get('n', 1) # ignore the batch limits of max 10
elif '/edits' in self.path:
# deprecated
response = await OAIimages.generations(prompt=prompt, size=size, response_format=response_format, n=n)
return JSONResponse(response)
if not shared.model:
raise ServiceUnavailableError("No model loaded.")
req_params = get_default_req_params()
@app.post("/v1/embeddings")
async def handle_embeddings(request: Request):
body = await request.json()
encoding_format = body.get("encoding_format", "")
instruction = body['instruction']
input = body.get('input', '')
temperature = clamp(default(body, 'temperature', req_params['temperature']), 0.001, 1.999) # fixup absolute 0.0
top_p = clamp(default(body, 'top_p', req_params['top_p']), 0.001, 1.0)
input = body.get('input', body.get('text', ''))
if not input:
raise HTTPException(status_code=400, detail="Missing required argument input")
response = OAIedits.edits(instruction, input, temperature, top_p)
if type(input) is str:
input = [input]
self.return_json(response)
response = OAIembeddings.embeddings(input, encoding_format)
return JSONResponse(response)
elif '/images/generations' in self.path:
if not os.environ.get('SD_WEBUI_URL', params.get('sd_webui_url', '')):
raise ServiceUnavailableError("Stable Diffusion not available. SD_WEBUI_URL not set.")
prompt = body['prompt']
size = default(body, 'size', '1024x1024')
response_format = default(body, 'response_format', 'url') # or b64_json
n = default(body, 'n', 1) # ignore the batch limits of max 10
@app.post("/v1/moderations")
async def handle_moderations(request: Request):
body = await request.json()
input = body["input"]
if not input:
raise HTTPException(status_code=400, detail="Missing required argument input")
response = OAIimages.generations(prompt=prompt, size=size, response_format=response_format, n=n)
response = OAImoderations.moderations(input)
return JSONResponse(response)
self.return_json(response, no_debug=True)
elif '/embeddings' in self.path:
encoding_format = body.get('encoding_format', '')
@app.post("/api/v1/token-count")
async def handle_token_count(request: Request):
body = await request.json()
response = token_count(body['prompt'])
return JSONResponse(response)
input = body.get('input', body.get('text', ''))
if not input:
raise InvalidRequestError("Missing required argument input", params='input')
if type(input) is str:
input = [input]
@app.post("/api/v1/token/encode")
async def handle_token_encode(request: Request):
body = await request.json()
encoding_format = body.get("encoding_format", "")
response = token_encode(body["input"], encoding_format)
return JSONResponse(response)
response = OAIembeddings.embeddings(input, encoding_format)
self.return_json(response, no_debug=True)
elif '/moderations' in self.path:
input = body['input']
if not input:
raise InvalidRequestError("Missing required argument input", params='input')
response = OAImoderations.moderations(input)
self.return_json(response, no_debug=True)
elif self.path == '/api/v1/token-count':
# NOT STANDARD. lifted from the api extension, but it's still very useful to calculate tokenized length client side.
response = token_count(body['prompt'])
self.return_json(response, no_debug=True)
elif self.path == '/api/v1/token/encode':
# NOT STANDARD. needed to support logit_bias, logprobs and token arrays for native models
encoding_format = body.get('encoding_format', '')
response = token_encode(body['input'], encoding_format)
self.return_json(response, no_debug=True)
elif self.path == '/api/v1/token/decode':
# NOT STANDARD. needed to support logit_bias, logprobs and token arrays for native models
encoding_format = body.get('encoding_format', '')
response = token_decode(body['input'], encoding_format)
self.return_json(response, no_debug=True)
else:
self.send_error(404)
@app.post("/api/v1/token/decode")
async def handle_token_decode(request: Request):
body = await request.json()
encoding_format = body.get("encoding_format", "")
response = token_decode(body["input"], encoding_format)
return JSONResponse(response, no_debug=True)
def run_server():
port = int(os.environ.get('OPENEDAI_PORT', params.get('port', 5001)))
server_addr = ('0.0.0.0' if shared.args.listen else '127.0.0.1', port)
server = ThreadingHTTPServer(server_addr, Handler)
ssl_certfile=os.environ.get('OPENEDAI_CERT_PATH', shared.args.ssl_certfile)
ssl_keyfile=os.environ.get('OPENEDAI_KEY_PATH', shared.args.ssl_keyfile)
ssl_verify=True if (ssl_keyfile and ssl_certfile) else False
if ssl_verify:
context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
context.load_cert_chain(ssl_certfile, ssl_keyfile)
server.socket = context.wrap_socket(server.socket, server_side=True)
if shared.args.share:
try:
from flask_cloudflared import _run_cloudflared
public_url = _run_cloudflared(port, port + 1)
print(f'OpenAI compatible API ready at: OPENAI_API_BASE={public_url}/v1')
except ImportError:
print('You should install flask_cloudflared manually')
server_addr = '0.0.0.0' if shared.args.listen else '127.0.0.1'
port = int(os.environ.get('OPENEDAI_PORT', shared.args.api_port))
ssl_certfile = os.environ.get('OPENEDAI_CERT_PATH', shared.args.ssl_certfile)
ssl_keyfile = os.environ.get('OPENEDAI_KEY_PATH', shared.args.ssl_keyfile)
if shared.args.public_api:
def on_start(public_url: str):
logger.info(f'OpenAI compatible API URL:\n\n{public_url}/v1\n')
_start_cloudflared(port, shared.args.public_api_id, max_attempts=3, on_start=on_start)
else:
if ssl_verify:
print(f'OpenAI compatible API ready at: OPENAI_API_BASE=https://{server_addr[0]}:{server_addr[1]}/v1')
if ssl_keyfile and ssl_certfile:
logger.info(f'OpenAI compatible API URL:\n\nhttps://{server_addr}:{port}/v1\n')
else:
print(f'OpenAI compatible API ready at: OPENAI_API_BASE=http://{server_addr[0]}:{server_addr[1]}/v1')
server.serve_forever()
logger.info(f'OpenAI compatible API URL:\n\nhttp://{server_addr}:{port}/v1\n')
if shared.args.api_key:
logger.info(f'OpenAI API key:\n\n{shared.args.api_key}\n')
uvicorn.run(app, host=server_addr, port=port, ssl_certfile=ssl_certfile, ssl_keyfile=ssl_keyfile)
def setup():