Make the code more like PEP8 for readability (#862)

This commit is contained in:
oobabooga 2023-04-07 00:15:45 -03:00 committed by GitHub
parent 848c4edfd5
commit ea6e77df72
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
28 changed files with 302 additions and 165 deletions

View file

@ -9,6 +9,7 @@ params = {
'port': 5000,
}
class Handler(BaseHTTPRequestHandler):
def do_GET(self):
if self.path == '/api/v1/model':
@ -32,7 +33,7 @@ class Handler(BaseHTTPRequestHandler):
self.end_headers()
prompt = body['prompt']
prompt_lines = [l.strip() for l in prompt.split('\n')]
prompt_lines = [k.strip() for k in prompt.split('\n')]
max_context = body.get('max_context_length', 2048)
@ -40,18 +41,18 @@ class Handler(BaseHTTPRequestHandler):
prompt_lines.pop(0)
prompt = '\n'.join(prompt_lines)
generate_params = {
'max_new_tokens': int(body.get('max_length', 200)),
generate_params = {
'max_new_tokens': int(body.get('max_length', 200)),
'do_sample': bool(body.get('do_sample', True)),
'temperature': float(body.get('temperature', 0.5)),
'top_p': float(body.get('top_p', 1)),
'typical_p': float(body.get('typical', 1)),
'repetition_penalty': float(body.get('rep_pen', 1.1)),
'temperature': float(body.get('temperature', 0.5)),
'top_p': float(body.get('top_p', 1)),
'typical_p': float(body.get('typical', 1)),
'repetition_penalty': float(body.get('rep_pen', 1.1)),
'encoder_repetition_penalty': 1,
'top_k': int(body.get('top_k', 0)),
'top_k': int(body.get('top_k', 0)),
'min_length': int(body.get('min_length', 0)),
'no_repeat_ngram_size': int(body.get('no_repeat_ngram_size',0)),
'num_beams': int(body.get('num_beams',1)),
'no_repeat_ngram_size': int(body.get('no_repeat_ngram_size', 0)),
'num_beams': int(body.get('num_beams', 1)),
'penalty_alpha': float(body.get('penalty_alpha', 0)),
'length_penalty': float(body.get('length_penalty', 1)),
'early_stopping': bool(body.get('early_stopping', False)),
@ -59,7 +60,7 @@ class Handler(BaseHTTPRequestHandler):
}
generator = generate_reply(
prompt,
prompt,
generate_params,
stopping_strings=body.get('stopping_strings', []),
)
@ -84,9 +85,9 @@ class Handler(BaseHTTPRequestHandler):
def run_server():
server_addr = ('0.0.0.0' if shared.args.listen else '127.0.0.1', params['port'])
server = ThreadingHTTPServer(server_addr, Handler)
if shared.args.share:
if shared.args.share:
try:
from flask_cloudflared import _run_cloudflared
from flask_cloudflared import _run_cloudflared
public_url = _run_cloudflared(params['port'], params['port'] + 1)
print(f'Starting KoboldAI compatible api at {public_url}/api')
except ImportError:
@ -95,5 +96,6 @@ def run_server():
print(f'Starting KoboldAI compatible api at http://{server_addr[0]}:{server_addr[1]}/api')
server.serve_forever()
def setup():
Thread(target=run_server, daemon=True).start()

View file

@ -5,14 +5,16 @@ params = {
"bias string": " *I am so happy*",
}
def input_modifier(string):
"""
This function is applied to your text inputs before
they are fed into the model.
"""
"""
return string
def output_modifier(string):
"""
This function is applied to the model outputs.
@ -20,6 +22,7 @@ def output_modifier(string):
return string
def bot_prefix_modifier(string):
"""
This function is only applied in chat mode. It modifies
@ -27,11 +30,12 @@ def bot_prefix_modifier(string):
behavior.
"""
if params['activate'] == True:
if params['activate']:
return f'{string} {params["bias string"].strip()} '
else:
return string
def ui():
# Gradio elements
activate = gr.Checkbox(value=params['activate'], label='Activate character bias')

View file

@ -20,16 +20,18 @@ user_info = None
if not shared.args.no_stream:
print("Please add --no-stream. This extension is not meant to be used with streaming.")
raise ValueError
# Check if the API is valid and refresh the UI accordingly.
def check_valid_api():
global user, user_info, params
user = ElevenLabsUser(params['api_key'])
user_info = user._get_subscription_data()
print('checking api')
if params['activate'] == False:
if not params['activate']:
return gr.update(value='Disconnected')
elif user_info is None:
print('Incorrect API Key')
@ -37,24 +39,28 @@ def check_valid_api():
else:
print('Got an API Key!')
return gr.update(value='Connected')
# Once the API is verified, get the available voices and update the dropdown list
def refresh_voices():
global user, user_info
your_voices = [None]
if user_info is not None:
for voice in user.get_available_voices():
your_voices.append(voice.initialName)
return gr.Dropdown.update(choices=your_voices)
return gr.Dropdown.update(choices=your_voices)
else:
return
def remove_surrounded_chars(string):
# this expression matches to 'as few symbols as possible (0 upwards) between any asterisks' OR
# 'as few symbols as possible (0 upwards) between an asterisk and the end of the string'
return re.sub('\*[^\*]*?(\*|$)','',string)
return re.sub('\*[^\*]*?(\*|$)', '', string)
def input_modifier(string):
"""
@ -64,16 +70,17 @@ def input_modifier(string):
return string
def output_modifier(string):
"""
This function is applied to the model outputs.
"""
global params, wav_idx, user, user_info
if params['activate'] == False:
if not params['activate']:
return string
elif user_info == None:
elif user_info is None:
return string
string = remove_surrounded_chars(string)
@ -84,7 +91,7 @@ def output_modifier(string):
if string == '':
string = 'empty reply, try regenerating'
output_file = Path(f'extensions/elevenlabs_tts/outputs/{wav_idx:06d}.wav'.format(wav_idx))
voice = user.get_voices_by_name(params['selected_voice'])[0]
audio_data = voice.generate_audio_bytes(string)
@ -94,6 +101,7 @@ def output_modifier(string):
wav_idx += 1
return string
def ui():
# Gradio elements
@ -110,4 +118,4 @@ def ui():
voice.change(lambda x: params.update({'selected_voice': x}), voice, None)
api_key.change(lambda x: params.update({'api_key': x}), api_key, None)
connect.click(check_valid_api, [], connection_status)
connect.click(refresh_voices, [], voice)
connect.click(refresh_voices, [], voice)

View file

@ -85,7 +85,7 @@ def select_character(evt: gr.SelectData):
def ui():
with gr.Accordion("Character gallery", open=False):
update = gr.Button("Refresh")
gr.HTML(value="<style>"+generate_css()+"</style>")
gr.HTML(value="<style>" + generate_css() + "</style>")
gallery = gr.Dataset(components=[gr.HTML(visible=False)],
label="",
samples=generate_html(),
@ -93,4 +93,4 @@ def ui():
samples_per_page=50
)
update.click(generate_html, [], gallery)
gallery.select(select_character, None, gradio['character_menu'])
gallery.select(select_character, None, gradio['character_menu'])

View file

@ -7,14 +7,16 @@ params = {
language_codes = {'Afrikaans': 'af', 'Albanian': 'sq', 'Amharic': 'am', 'Arabic': 'ar', 'Armenian': 'hy', 'Azerbaijani': 'az', 'Basque': 'eu', 'Belarusian': 'be', 'Bengali': 'bn', 'Bosnian': 'bs', 'Bulgarian': 'bg', 'Catalan': 'ca', 'Cebuano': 'ceb', 'Chinese (Simplified)': 'zh-CN', 'Chinese (Traditional)': 'zh-TW', 'Corsican': 'co', 'Croatian': 'hr', 'Czech': 'cs', 'Danish': 'da', 'Dutch': 'nl', 'English': 'en', 'Esperanto': 'eo', 'Estonian': 'et', 'Finnish': 'fi', 'French': 'fr', 'Frisian': 'fy', 'Galician': 'gl', 'Georgian': 'ka', 'German': 'de', 'Greek': 'el', 'Gujarati': 'gu', 'Haitian Creole': 'ht', 'Hausa': 'ha', 'Hawaiian': 'haw', 'Hebrew': 'iw', 'Hindi': 'hi', 'Hmong': 'hmn', 'Hungarian': 'hu', 'Icelandic': 'is', 'Igbo': 'ig', 'Indonesian': 'id', 'Irish': 'ga', 'Italian': 'it', 'Japanese': 'ja', 'Javanese': 'jw', 'Kannada': 'kn', 'Kazakh': 'kk', 'Khmer': 'km', 'Korean': 'ko', 'Kurdish': 'ku', 'Kyrgyz': 'ky', 'Lao': 'lo', 'Latin': 'la', 'Latvian': 'lv', 'Lithuanian': 'lt', 'Luxembourgish': 'lb', 'Macedonian': 'mk', 'Malagasy': 'mg', 'Malay': 'ms', 'Malayalam': 'ml', 'Maltese': 'mt', 'Maori': 'mi', 'Marathi': 'mr', 'Mongolian': 'mn', 'Myanmar (Burmese)': 'my', 'Nepali': 'ne', 'Norwegian': 'no', 'Nyanja (Chichewa)': 'ny', 'Pashto': 'ps', 'Persian': 'fa', 'Polish': 'pl', 'Portuguese (Portugal, Brazil)': 'pt', 'Punjabi': 'pa', 'Romanian': 'ro', 'Russian': 'ru', 'Samoan': 'sm', 'Scots Gaelic': 'gd', 'Serbian': 'sr', 'Sesotho': 'st', 'Shona': 'sn', 'Sindhi': 'sd', 'Sinhala (Sinhalese)': 'si', 'Slovak': 'sk', 'Slovenian': 'sl', 'Somali': 'so', 'Spanish': 'es', 'Sundanese': 'su', 'Swahili': 'sw', 'Swedish': 'sv', 'Tagalog (Filipino)': 'tl', 'Tajik': 'tg', 'Tamil': 'ta', 'Telugu': 'te', 'Thai': 'th', 'Turkish': 'tr', 'Ukrainian': 'uk', 'Urdu': 'ur', 'Uzbek': 'uz', 'Vietnamese': 'vi', 'Welsh': 'cy', 'Xhosa': 'xh', 'Yiddish': 'yi', 'Yoruba': 'yo', 'Zulu': 'zu'}
def input_modifier(string):
"""
This function is applied to your text inputs before
they are fed into the model.
"""
"""
return GoogleTranslator(source=params['language string'], target='en').translate(string)
def output_modifier(string):
"""
This function is applied to the model outputs.
@ -22,6 +24,7 @@ def output_modifier(string):
return GoogleTranslator(source='en', target=params['language string']).translate(string)
def bot_prefix_modifier(string):
"""
This function is only applied in chat mode. It modifies
@ -31,6 +34,7 @@ def bot_prefix_modifier(string):
return string
def ui():
# Finding the language name from the language code to use as the default value
language_name = list(language_codes.keys())[list(language_codes.values()).index(params['language string'])]

View file

@ -4,12 +4,14 @@ import pandas as pd
df = pd.read_csv("https://raw.githubusercontent.com/devbrones/llama-prompts/main/prompts/prompts.csv")
def get_prompt_by_name(name):
if name == 'None':
return ''
else:
return df[df['Prompt name'] == name].iloc[0]['Prompt'].replace('\\n', '\n')
def ui():
if not shared.is_chat():
choices = ['None'] + list(df['Prompt name'])

View file

@ -12,30 +12,33 @@ from PIL import Image
torch._C._jit_set_profiling_mode(False)
# parameters which can be customized in settings.json of webui
# parameters which can be customized in settings.json of webui
params = {
'enable_SD_api': False,
'address': 'http://127.0.0.1:7860',
'save_img': False,
'SD_model': 'NeverEndingDream', # not really used right now
'SD_model': 'NeverEndingDream', # not really used right now
'prompt_prefix': '(Masterpiece:1.1), (solo:1.3), detailed, intricate, colorful',
'negative_prompt': '(worst quality, low quality:1.3)',
'side_length': 512,
'restore_faces': False
}
SD_models = ['NeverEndingDream'] # TODO: get with http://{address}}/sdapi/v1/sd-models and allow user to select
SD_models = ['NeverEndingDream'] # TODO: get with http://{address}}/sdapi/v1/sd-models and allow user to select
streaming_state = shared.args.no_stream # remember if chat streaming was enabled
picture_response = False # specifies if the next model response should appear as a picture
streaming_state = shared.args.no_stream # remember if chat streaming was enabled
picture_response = False # specifies if the next model response should appear as a picture
pic_id = 0
def remove_surrounded_chars(string):
# this expression matches to 'as few symbols as possible (0 upwards) between any asterisks' OR
# 'as few symbols as possible (0 upwards) between an asterisk and the end of the string'
return re.sub('\*[^\*]*?(\*|$)','',string)
return re.sub('\*[^\*]*?(\*|$)', '', string)
# I don't even need input_hijack for this as visible text will be commited to history as the unmodified string
def input_modifier(string):
"""
This function is applied to your text inputs before
@ -51,7 +54,7 @@ def input_modifier(string):
lowstr = string.lower()
# TODO: refactor out to separate handler and also replace detection with a regexp
if any(command in lowstr for command in commands) and any(case in lowstr for case in mediums): # trigger the generation if a command signature and a medium signature is found
if any(command in lowstr for command in commands) and any(case in lowstr for case in mediums): # trigger the generation if a command signature and a medium signature is found
picture_response = True
shared.args.no_stream = True # Disable streaming cause otherwise the SD-generated picture would return as a dud
shared.processing_message = "*Is sending a picture...*"
@ -62,6 +65,8 @@ def input_modifier(string):
return string
# Get and save the Stable Diffusion-generated picture
def get_SD_pictures(description):
global params, pic_id
@ -77,13 +82,13 @@ def get_SD_pictures(description):
"restore_faces": params['restore_faces'],
"negative_prompt": params['negative_prompt']
}
response = requests.post(url=f'{params["address"]}/sdapi/v1/txt2img', json=payload)
r = response.json()
visible_result = ""
for img_str in r['images']:
image = Image.open(io.BytesIO(base64.b64decode(img_str.split(",",1)[0])))
image = Image.open(io.BytesIO(base64.b64decode(img_str.split(",", 1)[0])))
if params['save_img']:
output_file = Path(f'extensions/sd_api_pictures/outputs/{pic_id:06d}.png')
image.save(output_file.as_posix())
@ -96,11 +101,13 @@ def get_SD_pictures(description):
image_bytes = buffered.getvalue()
img_str = "data:image/jpeg;base64," + base64.b64encode(image_bytes).decode()
visible_result = visible_result + f'<img src="{img_str}" alt="{description}">\n'
return visible_result
# TODO: how do I make the UI history ignore the resulting pictures (I don't want HTML to appear in history)
# and replace it with 'text' for the purposes of logging?
def output_modifier(string):
"""
This function is applied to the model outputs.
@ -130,6 +137,7 @@ def output_modifier(string):
shared.args.no_stream = streaming_state
return image + "\n" + text
def bot_prefix_modifier(string):
"""
This function is only applied in chat mode. It modifies
@ -139,10 +147,12 @@ def bot_prefix_modifier(string):
return string
def force_pic():
global picture_response
picture_response = True
def ui():
# Gradio elements
@ -153,7 +163,7 @@ def ui():
save_img = gr.Checkbox(value=params['save_img'], label='Keep original received images in the outputs subdir')
with gr.Column():
address = gr.Textbox(placeholder=params['address'], value=params['address'], label='Stable Diffusion host address')
with gr.Row():
force_btn = gr.Button("Force the next response to be a picture")
generate_now_btn = gr.Button("Generate an image response to the input")
@ -162,9 +172,9 @@ def ui():
prompt_prefix = gr.Textbox(placeholder=params['prompt_prefix'], value=params['prompt_prefix'], label='Prompt Prefix (best used to describe the look of the character)')
with gr.Row():
negative_prompt = gr.Textbox(placeholder=params['negative_prompt'], value=params['negative_prompt'], label='Negative Prompt')
dimensions = gr.Slider(256,702,value=params['side_length'],step=64,label='Image dimensions')
dimensions = gr.Slider(256, 702, value=params['side_length'], step=64, label='Image dimensions')
# model = gr.Dropdown(value=SD_models[0], choices=SD_models, label='Model')
# Event functions to update the parameters in the backend
enable.change(lambda x: params.update({"enable_SD_api": x}), enable, None)
save_img.change(lambda x: params.update({"save_img": x}), save_img, None)
@ -176,4 +186,4 @@ def ui():
force_btn.click(force_pic)
generate_now_btn.click(force_pic)
generate_now_btn.click(chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)
generate_now_btn.click(chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)

View file

@ -17,11 +17,13 @@ input_hijack = {
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base", torch_dtype=torch.float32).to("cpu")
def caption_image(raw_image):
inputs = processor(raw_image.convert('RGB'), return_tensors="pt").to("cpu", torch.float32)
out = model.generate(**inputs, max_new_tokens=100)
return processor.decode(out[0], skip_special_tokens=True)
def generate_chat_picture(picture, name1, name2):
text = f'*{name1} sends {name2} a picture that contains the following: "{caption_image(picture)}"*'
# lower the resolution of sent images for the chat, otherwise the log size gets out of control quickly with all the base64 values in visible history
@ -32,6 +34,7 @@ def generate_chat_picture(picture, name1, name2):
visible_text = f'<img src="data:image/jpeg;base64,{img_str}" alt="{text}">'
return text, visible_text
def ui():
picture_select = gr.Image(label='Send a picture', type='pil')
@ -42,4 +45,4 @@ def ui():
picture_select.upload(chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)
# Clear the picture from the upload field
picture_select.upload(lambda : None, [], [picture_select], show_progress=False)
picture_select.upload(lambda: None, [], [picture_select], show_progress=False)