Make the code more like PEP8 for readability (#862)

This commit is contained in:
oobabooga 2023-04-07 00:15:45 -03:00 committed by GitHub
parent 848c4edfd5
commit ea6e77df72
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
28 changed files with 302 additions and 165 deletions

View file

@ -12,30 +12,33 @@ from PIL import Image
torch._C._jit_set_profiling_mode(False)
# parameters which can be customized in settings.json of webui
# parameters which can be customized in settings.json of webui
params = {
'enable_SD_api': False,
'address': 'http://127.0.0.1:7860',
'save_img': False,
'SD_model': 'NeverEndingDream', # not really used right now
'SD_model': 'NeverEndingDream', # not really used right now
'prompt_prefix': '(Masterpiece:1.1), (solo:1.3), detailed, intricate, colorful',
'negative_prompt': '(worst quality, low quality:1.3)',
'side_length': 512,
'restore_faces': False
}
SD_models = ['NeverEndingDream'] # TODO: get with http://{address}}/sdapi/v1/sd-models and allow user to select
SD_models = ['NeverEndingDream'] # TODO: get with http://{address}}/sdapi/v1/sd-models and allow user to select
streaming_state = shared.args.no_stream # remember if chat streaming was enabled
picture_response = False # specifies if the next model response should appear as a picture
streaming_state = shared.args.no_stream # remember if chat streaming was enabled
picture_response = False # specifies if the next model response should appear as a picture
pic_id = 0
def remove_surrounded_chars(string):
# this expression matches to 'as few symbols as possible (0 upwards) between any asterisks' OR
# 'as few symbols as possible (0 upwards) between an asterisk and the end of the string'
return re.sub('\*[^\*]*?(\*|$)','',string)
return re.sub('\*[^\*]*?(\*|$)', '', string)
# I don't even need input_hijack for this as visible text will be commited to history as the unmodified string
def input_modifier(string):
"""
This function is applied to your text inputs before
@ -51,7 +54,7 @@ def input_modifier(string):
lowstr = string.lower()
# TODO: refactor out to separate handler and also replace detection with a regexp
if any(command in lowstr for command in commands) and any(case in lowstr for case in mediums): # trigger the generation if a command signature and a medium signature is found
if any(command in lowstr for command in commands) and any(case in lowstr for case in mediums): # trigger the generation if a command signature and a medium signature is found
picture_response = True
shared.args.no_stream = True # Disable streaming cause otherwise the SD-generated picture would return as a dud
shared.processing_message = "*Is sending a picture...*"
@ -62,6 +65,8 @@ def input_modifier(string):
return string
# Get and save the Stable Diffusion-generated picture
def get_SD_pictures(description):
global params, pic_id
@ -77,13 +82,13 @@ def get_SD_pictures(description):
"restore_faces": params['restore_faces'],
"negative_prompt": params['negative_prompt']
}
response = requests.post(url=f'{params["address"]}/sdapi/v1/txt2img', json=payload)
r = response.json()
visible_result = ""
for img_str in r['images']:
image = Image.open(io.BytesIO(base64.b64decode(img_str.split(",",1)[0])))
image = Image.open(io.BytesIO(base64.b64decode(img_str.split(",", 1)[0])))
if params['save_img']:
output_file = Path(f'extensions/sd_api_pictures/outputs/{pic_id:06d}.png')
image.save(output_file.as_posix())
@ -96,11 +101,13 @@ def get_SD_pictures(description):
image_bytes = buffered.getvalue()
img_str = "data:image/jpeg;base64," + base64.b64encode(image_bytes).decode()
visible_result = visible_result + f'<img src="{img_str}" alt="{description}">\n'
return visible_result
# TODO: how do I make the UI history ignore the resulting pictures (I don't want HTML to appear in history)
# and replace it with 'text' for the purposes of logging?
def output_modifier(string):
"""
This function is applied to the model outputs.
@ -130,6 +137,7 @@ def output_modifier(string):
shared.args.no_stream = streaming_state
return image + "\n" + text
def bot_prefix_modifier(string):
"""
This function is only applied in chat mode. It modifies
@ -139,10 +147,12 @@ def bot_prefix_modifier(string):
return string
def force_pic():
global picture_response
picture_response = True
def ui():
# Gradio elements
@ -153,7 +163,7 @@ def ui():
save_img = gr.Checkbox(value=params['save_img'], label='Keep original received images in the outputs subdir')
with gr.Column():
address = gr.Textbox(placeholder=params['address'], value=params['address'], label='Stable Diffusion host address')
with gr.Row():
force_btn = gr.Button("Force the next response to be a picture")
generate_now_btn = gr.Button("Generate an image response to the input")
@ -162,9 +172,9 @@ def ui():
prompt_prefix = gr.Textbox(placeholder=params['prompt_prefix'], value=params['prompt_prefix'], label='Prompt Prefix (best used to describe the look of the character)')
with gr.Row():
negative_prompt = gr.Textbox(placeholder=params['negative_prompt'], value=params['negative_prompt'], label='Negative Prompt')
dimensions = gr.Slider(256,702,value=params['side_length'],step=64,label='Image dimensions')
dimensions = gr.Slider(256, 702, value=params['side_length'], step=64, label='Image dimensions')
# model = gr.Dropdown(value=SD_models[0], choices=SD_models, label='Model')
# Event functions to update the parameters in the backend
enable.change(lambda x: params.update({"enable_SD_api": x}), enable, None)
save_img.change(lambda x: params.update({"save_img": x}), save_img, None)
@ -176,4 +186,4 @@ def ui():
force_btn.click(force_pic)
generate_now_btn.click(force_pic)
generate_now_btn.click(chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)
generate_now_btn.click(chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)