Make the code more like PEP8 for readability (#862)
This commit is contained in:
parent
848c4edfd5
commit
ea6e77df72
28 changed files with 302 additions and 165 deletions
|
@ -9,6 +9,7 @@ params = {
|
|||
'port': 5000,
|
||||
}
|
||||
|
||||
|
||||
class Handler(BaseHTTPRequestHandler):
|
||||
def do_GET(self):
|
||||
if self.path == '/api/v1/model':
|
||||
|
@ -32,7 +33,7 @@ class Handler(BaseHTTPRequestHandler):
|
|||
self.end_headers()
|
||||
|
||||
prompt = body['prompt']
|
||||
prompt_lines = [l.strip() for l in prompt.split('\n')]
|
||||
prompt_lines = [k.strip() for k in prompt.split('\n')]
|
||||
|
||||
max_context = body.get('max_context_length', 2048)
|
||||
|
||||
|
@ -40,18 +41,18 @@ class Handler(BaseHTTPRequestHandler):
|
|||
prompt_lines.pop(0)
|
||||
|
||||
prompt = '\n'.join(prompt_lines)
|
||||
generate_params = {
|
||||
'max_new_tokens': int(body.get('max_length', 200)),
|
||||
generate_params = {
|
||||
'max_new_tokens': int(body.get('max_length', 200)),
|
||||
'do_sample': bool(body.get('do_sample', True)),
|
||||
'temperature': float(body.get('temperature', 0.5)),
|
||||
'top_p': float(body.get('top_p', 1)),
|
||||
'typical_p': float(body.get('typical', 1)),
|
||||
'repetition_penalty': float(body.get('rep_pen', 1.1)),
|
||||
'temperature': float(body.get('temperature', 0.5)),
|
||||
'top_p': float(body.get('top_p', 1)),
|
||||
'typical_p': float(body.get('typical', 1)),
|
||||
'repetition_penalty': float(body.get('rep_pen', 1.1)),
|
||||
'encoder_repetition_penalty': 1,
|
||||
'top_k': int(body.get('top_k', 0)),
|
||||
'top_k': int(body.get('top_k', 0)),
|
||||
'min_length': int(body.get('min_length', 0)),
|
||||
'no_repeat_ngram_size': int(body.get('no_repeat_ngram_size',0)),
|
||||
'num_beams': int(body.get('num_beams',1)),
|
||||
'no_repeat_ngram_size': int(body.get('no_repeat_ngram_size', 0)),
|
||||
'num_beams': int(body.get('num_beams', 1)),
|
||||
'penalty_alpha': float(body.get('penalty_alpha', 0)),
|
||||
'length_penalty': float(body.get('length_penalty', 1)),
|
||||
'early_stopping': bool(body.get('early_stopping', False)),
|
||||
|
@ -59,7 +60,7 @@ class Handler(BaseHTTPRequestHandler):
|
|||
}
|
||||
|
||||
generator = generate_reply(
|
||||
prompt,
|
||||
prompt,
|
||||
generate_params,
|
||||
stopping_strings=body.get('stopping_strings', []),
|
||||
)
|
||||
|
@ -84,9 +85,9 @@ class Handler(BaseHTTPRequestHandler):
|
|||
def run_server():
|
||||
server_addr = ('0.0.0.0' if shared.args.listen else '127.0.0.1', params['port'])
|
||||
server = ThreadingHTTPServer(server_addr, Handler)
|
||||
if shared.args.share:
|
||||
if shared.args.share:
|
||||
try:
|
||||
from flask_cloudflared import _run_cloudflared
|
||||
from flask_cloudflared import _run_cloudflared
|
||||
public_url = _run_cloudflared(params['port'], params['port'] + 1)
|
||||
print(f'Starting KoboldAI compatible api at {public_url}/api')
|
||||
except ImportError:
|
||||
|
@ -95,5 +96,6 @@ def run_server():
|
|||
print(f'Starting KoboldAI compatible api at http://{server_addr[0]}:{server_addr[1]}/api')
|
||||
server.serve_forever()
|
||||
|
||||
|
||||
def setup():
|
||||
Thread(target=run_server, daemon=True).start()
|
||||
|
|
|
@ -5,14 +5,16 @@ params = {
|
|||
"bias string": " *I am so happy*",
|
||||
}
|
||||
|
||||
|
||||
def input_modifier(string):
|
||||
"""
|
||||
This function is applied to your text inputs before
|
||||
they are fed into the model.
|
||||
"""
|
||||
"""
|
||||
|
||||
return string
|
||||
|
||||
|
||||
def output_modifier(string):
|
||||
"""
|
||||
This function is applied to the model outputs.
|
||||
|
@ -20,6 +22,7 @@ def output_modifier(string):
|
|||
|
||||
return string
|
||||
|
||||
|
||||
def bot_prefix_modifier(string):
|
||||
"""
|
||||
This function is only applied in chat mode. It modifies
|
||||
|
@ -27,11 +30,12 @@ def bot_prefix_modifier(string):
|
|||
behavior.
|
||||
"""
|
||||
|
||||
if params['activate'] == True:
|
||||
if params['activate']:
|
||||
return f'{string} {params["bias string"].strip()} '
|
||||
else:
|
||||
return string
|
||||
|
||||
|
||||
def ui():
|
||||
# Gradio elements
|
||||
activate = gr.Checkbox(value=params['activate'], label='Activate character bias')
|
||||
|
|
|
@ -20,16 +20,18 @@ user_info = None
|
|||
if not shared.args.no_stream:
|
||||
print("Please add --no-stream. This extension is not meant to be used with streaming.")
|
||||
raise ValueError
|
||||
|
||||
|
||||
# Check if the API is valid and refresh the UI accordingly.
|
||||
|
||||
|
||||
def check_valid_api():
|
||||
|
||||
|
||||
global user, user_info, params
|
||||
|
||||
user = ElevenLabsUser(params['api_key'])
|
||||
user_info = user._get_subscription_data()
|
||||
print('checking api')
|
||||
if params['activate'] == False:
|
||||
if not params['activate']:
|
||||
return gr.update(value='Disconnected')
|
||||
elif user_info is None:
|
||||
print('Incorrect API Key')
|
||||
|
@ -37,24 +39,28 @@ def check_valid_api():
|
|||
else:
|
||||
print('Got an API Key!')
|
||||
return gr.update(value='Connected')
|
||||
|
||||
|
||||
# Once the API is verified, get the available voices and update the dropdown list
|
||||
|
||||
|
||||
def refresh_voices():
|
||||
|
||||
|
||||
global user, user_info
|
||||
|
||||
|
||||
your_voices = [None]
|
||||
if user_info is not None:
|
||||
for voice in user.get_available_voices():
|
||||
your_voices.append(voice.initialName)
|
||||
return gr.Dropdown.update(choices=your_voices)
|
||||
return gr.Dropdown.update(choices=your_voices)
|
||||
else:
|
||||
return
|
||||
|
||||
|
||||
def remove_surrounded_chars(string):
|
||||
# this expression matches to 'as few symbols as possible (0 upwards) between any asterisks' OR
|
||||
# 'as few symbols as possible (0 upwards) between an asterisk and the end of the string'
|
||||
return re.sub('\*[^\*]*?(\*|$)','',string)
|
||||
return re.sub('\*[^\*]*?(\*|$)', '', string)
|
||||
|
||||
|
||||
def input_modifier(string):
|
||||
"""
|
||||
|
@ -64,16 +70,17 @@ def input_modifier(string):
|
|||
|
||||
return string
|
||||
|
||||
|
||||
def output_modifier(string):
|
||||
"""
|
||||
This function is applied to the model outputs.
|
||||
"""
|
||||
|
||||
global params, wav_idx, user, user_info
|
||||
|
||||
if params['activate'] == False:
|
||||
|
||||
if not params['activate']:
|
||||
return string
|
||||
elif user_info == None:
|
||||
elif user_info is None:
|
||||
return string
|
||||
|
||||
string = remove_surrounded_chars(string)
|
||||
|
@ -84,7 +91,7 @@ def output_modifier(string):
|
|||
|
||||
if string == '':
|
||||
string = 'empty reply, try regenerating'
|
||||
|
||||
|
||||
output_file = Path(f'extensions/elevenlabs_tts/outputs/{wav_idx:06d}.wav'.format(wav_idx))
|
||||
voice = user.get_voices_by_name(params['selected_voice'])[0]
|
||||
audio_data = voice.generate_audio_bytes(string)
|
||||
|
@ -94,6 +101,7 @@ def output_modifier(string):
|
|||
wav_idx += 1
|
||||
return string
|
||||
|
||||
|
||||
def ui():
|
||||
|
||||
# Gradio elements
|
||||
|
@ -110,4 +118,4 @@ def ui():
|
|||
voice.change(lambda x: params.update({'selected_voice': x}), voice, None)
|
||||
api_key.change(lambda x: params.update({'api_key': x}), api_key, None)
|
||||
connect.click(check_valid_api, [], connection_status)
|
||||
connect.click(refresh_voices, [], voice)
|
||||
connect.click(refresh_voices, [], voice)
|
||||
|
|
|
@ -85,7 +85,7 @@ def select_character(evt: gr.SelectData):
|
|||
def ui():
|
||||
with gr.Accordion("Character gallery", open=False):
|
||||
update = gr.Button("Refresh")
|
||||
gr.HTML(value="<style>"+generate_css()+"</style>")
|
||||
gr.HTML(value="<style>" + generate_css() + "</style>")
|
||||
gallery = gr.Dataset(components=[gr.HTML(visible=False)],
|
||||
label="",
|
||||
samples=generate_html(),
|
||||
|
@ -93,4 +93,4 @@ def ui():
|
|||
samples_per_page=50
|
||||
)
|
||||
update.click(generate_html, [], gallery)
|
||||
gallery.select(select_character, None, gradio['character_menu'])
|
||||
gallery.select(select_character, None, gradio['character_menu'])
|
||||
|
|
|
@ -7,14 +7,16 @@ params = {
|
|||
|
||||
language_codes = {'Afrikaans': 'af', 'Albanian': 'sq', 'Amharic': 'am', 'Arabic': 'ar', 'Armenian': 'hy', 'Azerbaijani': 'az', 'Basque': 'eu', 'Belarusian': 'be', 'Bengali': 'bn', 'Bosnian': 'bs', 'Bulgarian': 'bg', 'Catalan': 'ca', 'Cebuano': 'ceb', 'Chinese (Simplified)': 'zh-CN', 'Chinese (Traditional)': 'zh-TW', 'Corsican': 'co', 'Croatian': 'hr', 'Czech': 'cs', 'Danish': 'da', 'Dutch': 'nl', 'English': 'en', 'Esperanto': 'eo', 'Estonian': 'et', 'Finnish': 'fi', 'French': 'fr', 'Frisian': 'fy', 'Galician': 'gl', 'Georgian': 'ka', 'German': 'de', 'Greek': 'el', 'Gujarati': 'gu', 'Haitian Creole': 'ht', 'Hausa': 'ha', 'Hawaiian': 'haw', 'Hebrew': 'iw', 'Hindi': 'hi', 'Hmong': 'hmn', 'Hungarian': 'hu', 'Icelandic': 'is', 'Igbo': 'ig', 'Indonesian': 'id', 'Irish': 'ga', 'Italian': 'it', 'Japanese': 'ja', 'Javanese': 'jw', 'Kannada': 'kn', 'Kazakh': 'kk', 'Khmer': 'km', 'Korean': 'ko', 'Kurdish': 'ku', 'Kyrgyz': 'ky', 'Lao': 'lo', 'Latin': 'la', 'Latvian': 'lv', 'Lithuanian': 'lt', 'Luxembourgish': 'lb', 'Macedonian': 'mk', 'Malagasy': 'mg', 'Malay': 'ms', 'Malayalam': 'ml', 'Maltese': 'mt', 'Maori': 'mi', 'Marathi': 'mr', 'Mongolian': 'mn', 'Myanmar (Burmese)': 'my', 'Nepali': 'ne', 'Norwegian': 'no', 'Nyanja (Chichewa)': 'ny', 'Pashto': 'ps', 'Persian': 'fa', 'Polish': 'pl', 'Portuguese (Portugal, Brazil)': 'pt', 'Punjabi': 'pa', 'Romanian': 'ro', 'Russian': 'ru', 'Samoan': 'sm', 'Scots Gaelic': 'gd', 'Serbian': 'sr', 'Sesotho': 'st', 'Shona': 'sn', 'Sindhi': 'sd', 'Sinhala (Sinhalese)': 'si', 'Slovak': 'sk', 'Slovenian': 'sl', 'Somali': 'so', 'Spanish': 'es', 'Sundanese': 'su', 'Swahili': 'sw', 'Swedish': 'sv', 'Tagalog (Filipino)': 'tl', 'Tajik': 'tg', 'Tamil': 'ta', 'Telugu': 'te', 'Thai': 'th', 'Turkish': 'tr', 'Ukrainian': 'uk', 'Urdu': 'ur', 'Uzbek': 'uz', 'Vietnamese': 'vi', 'Welsh': 'cy', 'Xhosa': 'xh', 'Yiddish': 'yi', 'Yoruba': 'yo', 'Zulu': 'zu'}
|
||||
|
||||
|
||||
def input_modifier(string):
|
||||
"""
|
||||
This function is applied to your text inputs before
|
||||
they are fed into the model.
|
||||
"""
|
||||
"""
|
||||
|
||||
return GoogleTranslator(source=params['language string'], target='en').translate(string)
|
||||
|
||||
|
||||
def output_modifier(string):
|
||||
"""
|
||||
This function is applied to the model outputs.
|
||||
|
@ -22,6 +24,7 @@ def output_modifier(string):
|
|||
|
||||
return GoogleTranslator(source='en', target=params['language string']).translate(string)
|
||||
|
||||
|
||||
def bot_prefix_modifier(string):
|
||||
"""
|
||||
This function is only applied in chat mode. It modifies
|
||||
|
@ -31,6 +34,7 @@ def bot_prefix_modifier(string):
|
|||
|
||||
return string
|
||||
|
||||
|
||||
def ui():
|
||||
# Finding the language name from the language code to use as the default value
|
||||
language_name = list(language_codes.keys())[list(language_codes.values()).index(params['language string'])]
|
||||
|
|
|
@ -4,12 +4,14 @@ import pandas as pd
|
|||
|
||||
df = pd.read_csv("https://raw.githubusercontent.com/devbrones/llama-prompts/main/prompts/prompts.csv")
|
||||
|
||||
|
||||
def get_prompt_by_name(name):
|
||||
if name == 'None':
|
||||
return ''
|
||||
else:
|
||||
return df[df['Prompt name'] == name].iloc[0]['Prompt'].replace('\\n', '\n')
|
||||
|
||||
|
||||
def ui():
|
||||
if not shared.is_chat():
|
||||
choices = ['None'] + list(df['Prompt name'])
|
||||
|
|
|
@ -12,30 +12,33 @@ from PIL import Image
|
|||
|
||||
torch._C._jit_set_profiling_mode(False)
|
||||
|
||||
# parameters which can be customized in settings.json of webui
|
||||
# parameters which can be customized in settings.json of webui
|
||||
params = {
|
||||
'enable_SD_api': False,
|
||||
'address': 'http://127.0.0.1:7860',
|
||||
'save_img': False,
|
||||
'SD_model': 'NeverEndingDream', # not really used right now
|
||||
'SD_model': 'NeverEndingDream', # not really used right now
|
||||
'prompt_prefix': '(Masterpiece:1.1), (solo:1.3), detailed, intricate, colorful',
|
||||
'negative_prompt': '(worst quality, low quality:1.3)',
|
||||
'side_length': 512,
|
||||
'restore_faces': False
|
||||
}
|
||||
|
||||
SD_models = ['NeverEndingDream'] # TODO: get with http://{address}}/sdapi/v1/sd-models and allow user to select
|
||||
SD_models = ['NeverEndingDream'] # TODO: get with http://{address}}/sdapi/v1/sd-models and allow user to select
|
||||
|
||||
streaming_state = shared.args.no_stream # remember if chat streaming was enabled
|
||||
picture_response = False # specifies if the next model response should appear as a picture
|
||||
streaming_state = shared.args.no_stream # remember if chat streaming was enabled
|
||||
picture_response = False # specifies if the next model response should appear as a picture
|
||||
pic_id = 0
|
||||
|
||||
|
||||
def remove_surrounded_chars(string):
|
||||
# this expression matches to 'as few symbols as possible (0 upwards) between any asterisks' OR
|
||||
# 'as few symbols as possible (0 upwards) between an asterisk and the end of the string'
|
||||
return re.sub('\*[^\*]*?(\*|$)','',string)
|
||||
return re.sub('\*[^\*]*?(\*|$)', '', string)
|
||||
|
||||
# I don't even need input_hijack for this as visible text will be commited to history as the unmodified string
|
||||
|
||||
|
||||
def input_modifier(string):
|
||||
"""
|
||||
This function is applied to your text inputs before
|
||||
|
@ -51,7 +54,7 @@ def input_modifier(string):
|
|||
lowstr = string.lower()
|
||||
|
||||
# TODO: refactor out to separate handler and also replace detection with a regexp
|
||||
if any(command in lowstr for command in commands) and any(case in lowstr for case in mediums): # trigger the generation if a command signature and a medium signature is found
|
||||
if any(command in lowstr for command in commands) and any(case in lowstr for case in mediums): # trigger the generation if a command signature and a medium signature is found
|
||||
picture_response = True
|
||||
shared.args.no_stream = True # Disable streaming cause otherwise the SD-generated picture would return as a dud
|
||||
shared.processing_message = "*Is sending a picture...*"
|
||||
|
@ -62,6 +65,8 @@ def input_modifier(string):
|
|||
return string
|
||||
|
||||
# Get and save the Stable Diffusion-generated picture
|
||||
|
||||
|
||||
def get_SD_pictures(description):
|
||||
|
||||
global params, pic_id
|
||||
|
@ -77,13 +82,13 @@ def get_SD_pictures(description):
|
|||
"restore_faces": params['restore_faces'],
|
||||
"negative_prompt": params['negative_prompt']
|
||||
}
|
||||
|
||||
|
||||
response = requests.post(url=f'{params["address"]}/sdapi/v1/txt2img', json=payload)
|
||||
r = response.json()
|
||||
|
||||
visible_result = ""
|
||||
for img_str in r['images']:
|
||||
image = Image.open(io.BytesIO(base64.b64decode(img_str.split(",",1)[0])))
|
||||
image = Image.open(io.BytesIO(base64.b64decode(img_str.split(",", 1)[0])))
|
||||
if params['save_img']:
|
||||
output_file = Path(f'extensions/sd_api_pictures/outputs/{pic_id:06d}.png')
|
||||
image.save(output_file.as_posix())
|
||||
|
@ -96,11 +101,13 @@ def get_SD_pictures(description):
|
|||
image_bytes = buffered.getvalue()
|
||||
img_str = "data:image/jpeg;base64," + base64.b64encode(image_bytes).decode()
|
||||
visible_result = visible_result + f'<img src="{img_str}" alt="{description}">\n'
|
||||
|
||||
|
||||
return visible_result
|
||||
|
||||
# TODO: how do I make the UI history ignore the resulting pictures (I don't want HTML to appear in history)
|
||||
# and replace it with 'text' for the purposes of logging?
|
||||
|
||||
|
||||
def output_modifier(string):
|
||||
"""
|
||||
This function is applied to the model outputs.
|
||||
|
@ -130,6 +137,7 @@ def output_modifier(string):
|
|||
shared.args.no_stream = streaming_state
|
||||
return image + "\n" + text
|
||||
|
||||
|
||||
def bot_prefix_modifier(string):
|
||||
"""
|
||||
This function is only applied in chat mode. It modifies
|
||||
|
@ -139,10 +147,12 @@ def bot_prefix_modifier(string):
|
|||
|
||||
return string
|
||||
|
||||
|
||||
def force_pic():
|
||||
global picture_response
|
||||
picture_response = True
|
||||
|
||||
|
||||
def ui():
|
||||
|
||||
# Gradio elements
|
||||
|
@ -153,7 +163,7 @@ def ui():
|
|||
save_img = gr.Checkbox(value=params['save_img'], label='Keep original received images in the outputs subdir')
|
||||
with gr.Column():
|
||||
address = gr.Textbox(placeholder=params['address'], value=params['address'], label='Stable Diffusion host address')
|
||||
|
||||
|
||||
with gr.Row():
|
||||
force_btn = gr.Button("Force the next response to be a picture")
|
||||
generate_now_btn = gr.Button("Generate an image response to the input")
|
||||
|
@ -162,9 +172,9 @@ def ui():
|
|||
prompt_prefix = gr.Textbox(placeholder=params['prompt_prefix'], value=params['prompt_prefix'], label='Prompt Prefix (best used to describe the look of the character)')
|
||||
with gr.Row():
|
||||
negative_prompt = gr.Textbox(placeholder=params['negative_prompt'], value=params['negative_prompt'], label='Negative Prompt')
|
||||
dimensions = gr.Slider(256,702,value=params['side_length'],step=64,label='Image dimensions')
|
||||
dimensions = gr.Slider(256, 702, value=params['side_length'], step=64, label='Image dimensions')
|
||||
# model = gr.Dropdown(value=SD_models[0], choices=SD_models, label='Model')
|
||||
|
||||
|
||||
# Event functions to update the parameters in the backend
|
||||
enable.change(lambda x: params.update({"enable_SD_api": x}), enable, None)
|
||||
save_img.change(lambda x: params.update({"save_img": x}), save_img, None)
|
||||
|
@ -176,4 +186,4 @@ def ui():
|
|||
|
||||
force_btn.click(force_pic)
|
||||
generate_now_btn.click(force_pic)
|
||||
generate_now_btn.click(chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)
|
||||
generate_now_btn.click(chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)
|
||||
|
|
|
@ -17,11 +17,13 @@ input_hijack = {
|
|||
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
|
||||
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base", torch_dtype=torch.float32).to("cpu")
|
||||
|
||||
|
||||
def caption_image(raw_image):
|
||||
inputs = processor(raw_image.convert('RGB'), return_tensors="pt").to("cpu", torch.float32)
|
||||
out = model.generate(**inputs, max_new_tokens=100)
|
||||
return processor.decode(out[0], skip_special_tokens=True)
|
||||
|
||||
|
||||
def generate_chat_picture(picture, name1, name2):
|
||||
text = f'*{name1} sends {name2} a picture that contains the following: "{caption_image(picture)}"*'
|
||||
# lower the resolution of sent images for the chat, otherwise the log size gets out of control quickly with all the base64 values in visible history
|
||||
|
@ -32,6 +34,7 @@ def generate_chat_picture(picture, name1, name2):
|
|||
visible_text = f'<img src="data:image/jpeg;base64,{img_str}" alt="{text}">'
|
||||
return text, visible_text
|
||||
|
||||
|
||||
def ui():
|
||||
picture_select = gr.Image(label='Send a picture', type='pil')
|
||||
|
||||
|
@ -42,4 +45,4 @@ def ui():
|
|||
picture_select.upload(chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)
|
||||
|
||||
# Clear the picture from the upload field
|
||||
picture_select.upload(lambda : None, [], [picture_select], show_progress=False)
|
||||
picture_select.upload(lambda: None, [], [picture_select], show_progress=False)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue