This commit is contained in:
oobabooga 2023-07-12 11:33:25 -07:00
parent 9b55d3a9f9
commit e202190c4f
24 changed files with 146 additions and 125 deletions

View file

@@ -22,6 +22,7 @@ params = {
'port': int(os.environ.get('OPENEDAI_PORT')) if 'OPENEDAI_PORT' in os.environ else 5001,
}
class Handler(BaseHTTPRequestHandler):
def send_access_control_headers(self):
self.send_header("Access-Control-Allow-Origin", "*")
@@ -72,8 +73,8 @@ class Handler(BaseHTTPRequestHandler):
if not no_debug:
debug_msg(r_utf8)
def openai_error(self, message, code = 500, error_type = 'APIError', param = '', internal_message = ''):
def openai_error(self, message, code=500, error_type='APIError', param='', internal_message=''):
error_resp = {
'error': {
'message': message,
@@ -84,10 +85,10 @@ class Handler(BaseHTTPRequestHandler):
}
if internal_message:
print(internal_message)
#error_resp['internal_message'] = internal_message
# error_resp['internal_message'] = internal_message
self.return_json(error_resp, code)
def openai_error_handler(func):
def wrapper(self):
try:
@@ -156,7 +157,7 @@ class Handler(BaseHTTPRequestHandler):
response = OAIcompletions.stream_chat_completions(body, is_legacy=is_legacy)
else:
response = OAIcompletions.stream_completions(body, is_legacy=is_legacy)
for resp in response:
self.send_sse(resp)
@@ -182,7 +183,7 @@ class Handler(BaseHTTPRequestHandler):
instruction = body['instruction']
input = body.get('input', '')
temperature = clamp(default(body, 'temperature', req_params['temperature']), 0.001, 1.999) # fixup absolute 0.0
temperature = clamp(default(body, 'temperature', req_params['temperature']), 0.001, 1.999) # fixup absolute 0.0
top_p = clamp(default(body, 'top_p', req_params['top_p']), 0.001, 1.0)
response = OAIedits.edits(instruction, input, temperature, top_p)
@@ -205,7 +206,7 @@ class Handler(BaseHTTPRequestHandler):
input = body.get('input', body.get('text', ''))
if not input:
raise InvalidRequestError("Missing required argument input", params='input')
if type(input) is str:
input = [input]
@@ -225,15 +226,15 @@ class Handler(BaseHTTPRequestHandler):
elif self.path == '/api/v1/token-count':
# NOT STANDARD. lifted from the api extension, but it's still very useful to calculate tokenized length client side.
response = token_count(body['prompt'])
self.return_json(response, no_debug=True)
elif self.path == '/api/v1/token/encode':
# NOT STANDARD. needed to support logit_bias, logprobs and token arrays for native models
encoding_format = body.get('encoding_format', '')
response = token_encode(body['input'], encoding_format)
self.return_json(response, no_debug=True)
elif self.path == '/api/v1/token/decode':
@@ -241,7 +242,7 @@ class Handler(BaseHTTPRequestHandler):
encoding_format = body.get('encoding_format', '')
response = token_decode(body['input'], encoding_format)
self.return_json(response, no_debug=True)
else: