[extension/openai] add edits & image endpoints & fix prompt return in non --chat modes (#1935)

This commit is contained in:
matatonic 2023-05-11 10:06:39 -04:00 committed by GitHub
parent 23d3f6909a
commit 309b72e549
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 203 additions and 11 deletions

View file

@ -2,6 +2,8 @@ import base64
import json
import os
import time
import requests
import yaml
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from threading import Thread
@ -48,6 +50,31 @@ def clamp(value, minvalue, maxvalue):
return max(minvalue, min(value, maxvalue))
def deduce_template():
# Alpaca is verbose so a good default prompt
default_template = (
"Below is an instruction that describes a task, paired with an input that provides further context. "
"Write a response that appropriately completes the request.\n\n"
"### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
)
# Use the special instruction/input/response template for anything trained like Alpaca
if shared.settings['instruction_template'] in ['Alpaca', 'Alpaca-Input']:
return default_template
try:
instruct = yaml.safe_load(open(f"characters/instruction-following/{shared.settings['instruction_template']}.yaml", 'r'))
template = instruct['turn_template']
template = template\
.replace('<|user|>', instruct.get('user', ''))\
.replace('<|bot|>', instruct.get('bot', ''))\
.replace('<|user-message|>', '{instruction}\n{input}')
return instruct.get('context', '') + template[:template.find('<|bot-message|>')]
except:
return default_template
def float_list_to_base64(float_list):
# Convert the list to a float32 array that the OpenAPI client expects
float_array = np.array(float_list, dtype="float32")
@ -120,11 +147,20 @@ class Handler(BaseHTTPRequestHandler):
self.send_error(404)
def do_POST(self):
# ... haaack.
is_chat = shared.args.chat
try:
shared.args.chat = True
self.do_POST_wrap()
finally:
shared.args.chat = is_chat
def do_POST_wrap(self):
if debug:
print(self.headers) # did you know... python-openai sends your linux kernel & python version?
content_length = int(self.headers['Content-Length'])
body = json.loads(self.rfile.read(content_length).decode('utf-8'))
if debug:
print(self.headers) # did you know... python-openai sends your linux kernel & python version?
if debug:
print(body)
@ -150,7 +186,7 @@ class Handler(BaseHTTPRequestHandler):
truncation_length = default(shared.settings, 'truncation_length', 2048)
truncation_length = clamp(default(body, 'truncation_length', truncation_length), 1, truncation_length)
default_max_tokens = truncation_length if is_chat else 16 # completions default, chat default is 'inf' so we need to cap it., the default for chat is "inf"
default_max_tokens = truncation_length if is_chat else 16 # completions default, chat default is 'inf' so we need to cap it.
max_tokens_str = 'length' if is_legacy else 'max_tokens'
max_tokens = default(body, max_tokens_str, default(shared.settings, 'max_new_tokens', default_max_tokens))
@ -440,6 +476,129 @@ class Handler(BaseHTTPRequestHandler):
else:
resp[resp_list][0]["text"] = answer
response = json.dumps(resp)
self.wfile.write(response.encode('utf-8'))
elif '/edits' in self.path:
self.send_response(200)
self.send_header('Content-Type', 'application/json')
self.end_headers()
created_time = int(time.time())
# Using Alpaca format, this may work with other models too.
instruction = body['instruction']
input = body.get('input', '')
instruction_template = deduce_template()
edit_task = instruction_template.format(instruction=instruction, input=input)
truncation_length = default(shared.settings, 'truncation_length', 2048)
token_count = len(encode(edit_task)[0])
max_tokens = truncation_length - token_count
req_params = {
'max_new_tokens': max_tokens,
'temperature': clamp(default(body, 'temperature', 1.0), 0.001, 1.999),
'top_p': clamp(default(body, 'top_p', 1.0), 0.001, 1.0),
'top_k': 1,
'repetition_penalty': 1.18,
'encoder_repetition_penalty': 1.0,
'suffix': None,
'stream': False,
'echo': False,
'seed': shared.settings.get('seed', -1),
# 'n' : default(body, 'n', 1), # 'n' doesn't have a direct map
'truncation_length': truncation_length,
'add_bos_token': shared.settings.get('add_bos_token', True),
'do_sample': True,
'typical_p': 1.0,
'min_length': 0,
'no_repeat_ngram_size': 0,
'num_beams': 1,
'penalty_alpha': 0.0,
'length_penalty': 1,
'early_stopping': False,
'ban_eos_token': False,
'skip_special_tokens': True,
'custom_stopping_strings': [],
}
if debug:
print({'edit_template': edit_task, 'req_params': req_params, 'token_count': token_count})
generator = generate_reply(edit_task, req_params, stopping_strings=standard_stopping_strings)
answer = ''
for a in generator:
if isinstance(a, str):
answer = a
else:
answer = a[0]
completion_token_count = len(encode(answer)[0])
resp = {
"object": "edit",
"created": created_time,
"choices": [{
"text": answer,
"index": 0,
}],
"usage": {
"prompt_tokens": token_count,
"completion_tokens": completion_token_count,
"total_tokens": token_count + completion_token_count
}
}
if debug:
print({'answer': answer, 'completion_token_count': completion_token_count})
response = json.dumps(resp)
self.wfile.write(response.encode('utf-8'))
elif '/images/generations' in self.path and 'SD_WEBUI_URL' in os.environ:
# Stable Diffusion callout wrapper for txt2img
# Low effort implementation for compatibility. With only "prompt" being passed and assuming DALL-E
# the results will be limited and likely poor. SD has hundreds of models and dozens of settings.
# If you want high quality tailored results you should just use the Stable Diffusion API directly.
# it's too general an API to try and shape the result with specific tags like "masterpiece", etc,
# Will probably work best with the stock SD models.
# SD configuration is beyond the scope of this API.
# At this point I will not add the edits and variations endpoints (ie. img2img) because they
# require changing the form data handling to accept multipart form data, also to properly support
# url return types will require file management and a web serving files... Perhaps later!
self.send_response(200)
self.send_header('Content-Type', 'application/json')
self.end_headers()
width, height = [ int(x) for x in default(body, 'size', '1024x1024').split('x') ] # ignore the restrictions on size
response_format = default(body, 'response_format', 'url') # or b64_json
payload = {
'prompt': body['prompt'], # ignore prompt limit of 1000 characters
'width': width,
'height': height,
'batch_size': default(body, 'n', 1) # ignore the batch limits of max 10
}
resp = {
'created': int(time.time()),
'data': []
}
# TODO: support SD_WEBUI_AUTH username:password pair.
sd_url = f"{os.environ['SD_WEBUI_URL']}/sdapi/v1/txt2img"
response = requests.post(url=sd_url, json=payload)
r = response.json()
# r['parameters']...
for b64_json in r['images']:
if response_format == 'b64_json':
resp['data'].extend([{'b64_json': b64_json}])
else:
resp['data'].extend([{'url': f'data:image/png;base64,{b64_json}'}]) # yeah it's lazy. requests.get() will not work with this
response = json.dumps(resp)
self.wfile.write(response.encode('utf-8'))
elif '/embeddings' in self.path and embedding_model is not None:
@ -540,11 +699,12 @@ def run_server():
try:
from flask_cloudflared import _run_cloudflared
public_url = _run_cloudflared(params['port'], params['port'] + 1)
print(f'Starting OpenAI compatible api at {public_url}/')
print(f'Starting OpenAI compatible api at\nOPENAI_API_BASE={public_url}/v1')
except ImportError:
print('You should install flask_cloudflared manually')
else:
print(f'Starting OpenAI compatible api at http://{server_addr[0]}:{server_addr[1]}/')
print(f'Starting OpenAI compatible api:\nOPENAI_API_BASE=http://{server_addr[0]}:{server_addr[1]}/v1')
server.serve_forever()