This commit is contained in:
oobabooga 2023-07-12 11:33:25 -07:00
parent 9b55d3a9f9
commit e202190c4f
24 changed files with 146 additions and 125 deletions

View file

@ -3,6 +3,7 @@ import time
import requests
from extensions.openai.errors import *
def generations(prompt: str, size: str, response_format: str, n: int):
# Stable Diffusion callout wrapper for txt2img
# Low effort implementation for compatibility. With only "prompt" being passed and assuming DALL-E
@ -15,7 +16,7 @@ def generations(prompt: str, size: str, response_format: str, n: int):
# require changing the form data handling to accept multipart form data, also to properly support
# url return types will require file management and a web serving files... Perhaps later!
width, height = [ int(x) for x in size.split('x') ] # ignore the restrictions on size
width, height = [int(x) for x in size.split('x')] # ignore the restrictions on size
# to hack on better generation, edit default payload.
payload = {
@ -23,7 +24,7 @@ def generations(prompt: str, size: str, response_format: str, n: int):
'width': width,
'height': height,
'batch_size': n,
'restore_faces': True, # slightly less horrible
'restore_faces': True, # slightly less horrible
}
resp = {
@ -37,7 +38,7 @@ def generations(prompt: str, size: str, response_format: str, n: int):
response = requests.post(url=sd_url, json=payload)
r = response.json()
if response.status_code != 200 or 'images' not in r:
raise ServiceUnavailableError(r.get('detail', [{'msg': 'Unknown error calling Stable Diffusion'}])[0]['msg'], code = response.status_code)
raise ServiceUnavailableError(r.get('detail', [{'msg': 'Unknown error calling Stable Diffusion'}])[0]['msg'], code=response.status_code)
# r['parameters']...
for b64_json in r['images']:
if response_format == 'b64_json':
@ -45,4 +46,4 @@ def generations(prompt: str, size: str, response_format: str, n: int):
else:
resp['data'].extend([{'url': f'data:image/png;base64,{b64_json}'}]) # yeah it's lazy. requests.get() will not work with this
return resp
return resp