extensions/openai: Fixes for: embeddings, tokens, better errors. +Docs update, +Images, +logit_bias/logprobs, +more. (#3122)
This commit is contained in:
parent
1141987a0d
commit
90a4ab631c
10 changed files with 215 additions and 143 deletions
|
@ -55,11 +55,13 @@ class Handler(BaseHTTPRequestHandler):
|
|||
|
||||
def send_sse(self, chunk: dict):
|
||||
response = 'data: ' + json.dumps(chunk) + '\r\n\r\n'
|
||||
debug_msg(response)
|
||||
debug_msg(response[:-4])
|
||||
self.wfile.write(response.encode('utf-8'))
|
||||
|
||||
def end_sse(self):
|
||||
self.wfile.write('data: [DONE]\r\n\r\n'.encode('utf-8'))
|
||||
response = 'data: [DONE]\r\n\r\n'
|
||||
debug_msg(response[:-4])
|
||||
self.wfile.write(response.encode('utf-8'))
|
||||
|
||||
def return_json(self, ret: dict, code: int = 200, no_debug=False):
|
||||
self.send_response(code)
|
||||
|
@ -84,6 +86,7 @@ class Handler(BaseHTTPRequestHandler):
|
|||
}
|
||||
}
|
||||
if internal_message:
|
||||
print(error_type, message)
|
||||
print(internal_message)
|
||||
# error_resp['internal_message'] = internal_message
|
||||
|
||||
|
@ -93,12 +96,10 @@ class Handler(BaseHTTPRequestHandler):
|
|||
def wrapper(self):
|
||||
try:
|
||||
func(self)
|
||||
except ServiceUnavailableError as e:
|
||||
self.openai_error(e.message, e.code, e.error_type, internal_message=e.internal_message)
|
||||
except InvalidRequestError as e:
|
||||
self.openai_error(e.message, e.code, e.error_type, e.param, internal_message=e.internal_message)
|
||||
self.openai_error(e.message, e.code, e.__class__.__name__, e.param, internal_message=e.internal_message)
|
||||
except OpenAIError as e:
|
||||
self.openai_error(e.message, e.code, e.error_type, internal_message=e.internal_message)
|
||||
self.openai_error(e.message, e.code, e.__class__.__name__, internal_message=e.internal_message)
|
||||
except Exception as e:
|
||||
self.openai_error(repr(e), 500, 'OpenAIError', internal_message=traceback.format_exc())
|
||||
|
||||
|
@ -143,8 +144,7 @@ class Handler(BaseHTTPRequestHandler):
|
|||
if '/completions' in self.path or '/generate' in self.path:
|
||||
|
||||
if not shared.model:
|
||||
self.openai_error("No model loaded.")
|
||||
return
|
||||
raise ServiceUnavailableError("No model loaded.")
|
||||
|
||||
is_legacy = '/generate' in self.path
|
||||
is_streaming = body.get('stream', False)
|
||||
|
@ -176,8 +176,7 @@ class Handler(BaseHTTPRequestHandler):
|
|||
# deprecated
|
||||
|
||||
if not shared.model:
|
||||
self.openai_error("No model loaded.")
|
||||
return
|
||||
raise ServiceUnavailableError("No model loaded.")
|
||||
|
||||
req_params = get_default_req_params()
|
||||
|
||||
|
@ -190,7 +189,10 @@ class Handler(BaseHTTPRequestHandler):
|
|||
|
||||
self.return_json(response)
|
||||
|
||||
elif '/images/generations' in self.path and 'SD_WEBUI_URL' in os.environ:
|
||||
elif '/images/generations' in self.path:
|
||||
if not 'SD_WEBUI_URL' in os.environ:
|
||||
raise ServiceUnavailableError("Stable Diffusion not available. SD_WEBUI_URL not set.")
|
||||
|
||||
prompt = body['prompt']
|
||||
size = default(body, 'size', '1024x1024')
|
||||
response_format = default(body, 'response_format', 'url') # or b64_json
|
||||
|
@ -256,11 +258,11 @@ def run_server():
|
|||
try:
|
||||
from flask_cloudflared import _run_cloudflared
|
||||
public_url = _run_cloudflared(params['port'], params['port'] + 1)
|
||||
print(f'Starting OpenAI compatible api at\nOPENAI_API_BASE={public_url}/v1')
|
||||
print(f'OpenAI compatible API ready at: OPENAI_API_BASE={public_url}/v1')
|
||||
except ImportError:
|
||||
print('You should install flask_cloudflared manually')
|
||||
else:
|
||||
print(f'Starting OpenAI compatible api:\nOPENAI_API_BASE=http://{server_addr[0]}:{server_addr[1]}/v1')
|
||||
print(f'OpenAI compatible API ready at: OPENAI_API_BASE=http://{server_addr[0]}:{server_addr[1]}/v1')
|
||||
|
||||
server.serve_forever()
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue