lint
This commit is contained in:
parent
9b55d3a9f9
commit
e202190c4f
24 changed files with 146 additions and 125 deletions
|
@ -22,6 +22,7 @@ params = {
|
|||
'port': int(os.environ.get('OPENEDAI_PORT')) if 'OPENEDAI_PORT' in os.environ else 5001,
|
||||
}
|
||||
|
||||
|
||||
class Handler(BaseHTTPRequestHandler):
|
||||
def send_access_control_headers(self):
|
||||
self.send_header("Access-Control-Allow-Origin", "*")
|
||||
|
@ -72,8 +73,8 @@ class Handler(BaseHTTPRequestHandler):
|
|||
if not no_debug:
|
||||
debug_msg(r_utf8)
|
||||
|
||||
def openai_error(self, message, code = 500, error_type = 'APIError', param = '', internal_message = ''):
|
||||
|
||||
def openai_error(self, message, code=500, error_type='APIError', param='', internal_message=''):
|
||||
|
||||
error_resp = {
|
||||
'error': {
|
||||
'message': message,
|
||||
|
@ -84,10 +85,10 @@ class Handler(BaseHTTPRequestHandler):
|
|||
}
|
||||
if internal_message:
|
||||
print(internal_message)
|
||||
#error_resp['internal_message'] = internal_message
|
||||
# error_resp['internal_message'] = internal_message
|
||||
|
||||
self.return_json(error_resp, code)
|
||||
|
||||
|
||||
def openai_error_handler(func):
|
||||
def wrapper(self):
|
||||
try:
|
||||
|
@ -156,7 +157,7 @@ class Handler(BaseHTTPRequestHandler):
|
|||
response = OAIcompletions.stream_chat_completions(body, is_legacy=is_legacy)
|
||||
else:
|
||||
response = OAIcompletions.stream_completions(body, is_legacy=is_legacy)
|
||||
|
||||
|
||||
for resp in response:
|
||||
self.send_sse(resp)
|
||||
|
||||
|
@ -182,7 +183,7 @@ class Handler(BaseHTTPRequestHandler):
|
|||
|
||||
instruction = body['instruction']
|
||||
input = body.get('input', '')
|
||||
temperature = clamp(default(body, 'temperature', req_params['temperature']), 0.001, 1.999) # fixup absolute 0.0
|
||||
temperature = clamp(default(body, 'temperature', req_params['temperature']), 0.001, 1.999) # fixup absolute 0.0
|
||||
top_p = clamp(default(body, 'top_p', req_params['top_p']), 0.001, 1.0)
|
||||
|
||||
response = OAIedits.edits(instruction, input, temperature, top_p)
|
||||
|
@ -205,7 +206,7 @@ class Handler(BaseHTTPRequestHandler):
|
|||
input = body.get('input', body.get('text', ''))
|
||||
if not input:
|
||||
raise InvalidRequestError("Missing required argument input", params='input')
|
||||
|
||||
|
||||
if type(input) is str:
|
||||
input = [input]
|
||||
|
||||
|
@ -225,15 +226,15 @@ class Handler(BaseHTTPRequestHandler):
|
|||
elif self.path == '/api/v1/token-count':
|
||||
# NOT STANDARD. lifted from the api extension, but it's still very useful to calculate tokenized length client side.
|
||||
response = token_count(body['prompt'])
|
||||
|
||||
|
||||
self.return_json(response, no_debug=True)
|
||||
|
||||
elif self.path == '/api/v1/token/encode':
|
||||
# NOT STANDARD. needed to support logit_bias, logprobs and token arrays for native models
|
||||
encoding_format = body.get('encoding_format', '')
|
||||
|
||||
|
||||
response = token_encode(body['input'], encoding_format)
|
||||
|
||||
|
||||
self.return_json(response, no_debug=True)
|
||||
|
||||
elif self.path == '/api/v1/token/decode':
|
||||
|
@ -241,7 +242,7 @@ class Handler(BaseHTTPRequestHandler):
|
|||
encoding_format = body.get('encoding_format', '')
|
||||
|
||||
response = token_decode(body['input'], encoding_format)
|
||||
|
||||
|
||||
self.return_json(response, no_debug=True)
|
||||
|
||||
else:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue