Lint the openai extension

This commit is contained in:
oobabooga 2023-09-15 20:11:16 -07:00
parent 760510db52
commit 8f97e87cac
12 changed files with 79 additions and 69 deletions

View file

@ -1,7 +1,8 @@
import os
import time
import requests
from extensions.openai.errors import *
from extensions.openai.errors import ServiceUnavailableError
def generations(prompt: str, size: str, response_format: str, n: int):
@ -14,7 +15,7 @@ def generations(prompt: str, size: str, response_format: str, n: int):
# At this point I will not add the edits and variations endpoints (ie. img2img) because they
# require changing the form data handling to accept multipart form data, also to properly support
# url return types will require file management and a web serving files... Perhaps later!
base_model_size = 512 if not 'SD_BASE_MODEL_SIZE' in os.environ else int(os.environ.get('SD_BASE_MODEL_SIZE', 512))
base_model_size = 512 if 'SD_BASE_MODEL_SIZE' not in os.environ else int(os.environ.get('SD_BASE_MODEL_SIZE', 512))
sd_defaults = {
'sampler_name': 'DPM++ 2M Karras', # vast improvement
'steps': 30,
@ -56,7 +57,7 @@ def generations(prompt: str, size: str, response_format: str, n: int):
r = response.json()
if response.status_code != 200 or 'images' not in r:
print(r)
raise ServiceUnavailableError(r.get('error', 'Unknown error calling Stable Diffusion'), code=response.status_code, internal_message=r.get('errors',None))
raise ServiceUnavailableError(r.get('error', 'Unknown error calling Stable Diffusion'), code=response.status_code, internal_message=r.get('errors', None))
# r['parameters']...
for b64_json in r['images']:
if response_format == 'b64_json':