diff --git a/.gitignore b/.gitignore index 3685291..a9c47a5 100644 --- a/.gitignore +++ b/.gitignore @@ -19,3 +19,4 @@ repositories settings.json img_bot* img_me* +prompts/[0-9]* diff --git a/css/main.css b/css/main.css index 97879f0..3f04409 100644 --- a/css/main.css +++ b/css/main.css @@ -37,12 +37,6 @@ text-decoration: none !important; } -svg { - display: unset !important; - vertical-align: middle !important; - margin: 5px; -} - ol li p, ul li p { display: inline-block; } @@ -64,3 +58,8 @@ ol li p, ul li p { padding: 15px; padding: 15px; } + +span.math.inline { + font-size: 27px; + vertical-align: baseline !important; +} diff --git a/modules/callbacks.py b/modules/callbacks.py index 8d30d61..d85f406 100644 --- a/modules/callbacks.py +++ b/modules/callbacks.py @@ -54,7 +54,7 @@ class Iteratorize: self.stop_now = False def _callback(val): - if self.stop_now: + if self.stop_now or shared.stop_everything: raise ValueError self.q.put(val) diff --git a/modules/chat.py b/modules/chat.py index 1a43cf3..cc3c45c 100644 --- a/modules/chat.py +++ b/modules/chat.py @@ -80,11 +80,7 @@ def extract_message_from_reply(reply, name1, name2, check): reply = fix_newlines(reply) return reply, next_character_found -def stop_everything_event(): - shared.stop_everything = True - def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1, regenerate=False): - shared.stop_everything = False just_started = True eos_token = '\n' if check else None name1_original = name1 diff --git a/modules/text_generation.py b/modules/text_generation.py index 9b2c233..20a07ca 100644 --- a/modules/text_generation.py +++ b/modules/text_generation.py @@ -99,9 +99,13 @@ def set_manual_seed(seed): if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed) +def stop_everything_event(): + shared.stop_everything = True + def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, eos_token=None, stopping_strings=[]): clear_torch_cache() set_manual_seed(seed) + shared.stop_everything = False t0 = time.time() original_question = question @@ -236,8 +240,6 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi break yield formatted_outputs(reply, shared.model_name) - yield formatted_outputs(reply, shared.model_name) - # Stream the output naively for FlexGen since it doesn't support 'stopping_criteria' else: for i in range(max_new_tokens//8+1): diff --git a/server.py b/server.py index caca85c..cf37dc5 100644 --- a/server.py +++ b/server.py @@ -14,7 +14,8 @@ import modules.extensions as extensions_module from modules.html_generator import generate_chat_html from modules.LoRA import add_lora_to_model from modules.models import load_model, load_soft_prompt -from modules.text_generation import clear_torch_cache, generate_reply +from modules.text_generation import (clear_torch_cache, generate_reply, + stop_everything_event) # Loading custom settings settings_file = None @@ -133,7 +134,7 @@ def save_prompt(text): fname = f"{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}.txt" with open(Path(f'prompts/{fname}'), 'w', encoding='utf-8') as f: f.write(text) - return f"Saved prompt to prompts/{fname}" + return f"Saved to prompts/{fname}" def load_prompt(fname): if fname in ['None', '']: @@ -154,7 +155,7 @@ def create_prompt_menus(): shared.gradio['save_prompt'] = gr.Button('Save prompt') shared.gradio['status'] = gr.Markdown('Ready') - shared.gradio['prompt_menu'].change(load_prompt, [shared.gradio['prompt_menu']], [shared.gradio['textbox']], show_progress=True) + shared.gradio['prompt_menu'].change(load_prompt, [shared.gradio['prompt_menu']], [shared.gradio['textbox']], show_progress=False) shared.gradio['save_prompt'].click(save_prompt, [shared.gradio['textbox']], [shared.gradio['status']], show_progress=False) def create_settings_menus(default_preset): @@ -364,7 +365,7 @@ def create_interface(): gen_events.append(shared.gradio['textbox'].submit(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)) gen_events.append(shared.gradio['Regenerate'].click(chat.regenerate_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)) gen_events.append(shared.gradio['Impersonate'].click(chat.impersonate_wrapper, shared.input_params, shared.gradio['textbox'], show_progress=shared.args.no_stream)) - shared.gradio['Stop'].click(chat.stop_everything_event, [], [], cancels=gen_events, queue=False) + shared.gradio['Stop'].click(stop_everything_event, [], [], queue=False, cancels=gen_events if shared.args.no_stream else None) shared.gradio['Copy last reply'].click(chat.send_last_reply_to_input, [], shared.gradio['textbox'], show_progress=shared.args.no_stream) shared.gradio['Replace last reply'].click(chat.replace_last_reply, [shared.gradio['textbox'], shared.gradio['name1'], shared.gradio['name2']], shared.gradio['display'], show_progress=shared.args.no_stream) @@ -415,11 +416,15 @@ def create_interface(): shared.gradio['html'] = gr.HTML() with gr.Row(): - shared.gradio['Generate'] = gr.Button('Generate') - shared.gradio['Stop'] = gr.Button('Stop') + with gr.Column(): + with gr.Row(): + shared.gradio['Generate'] = gr.Button('Generate') + shared.gradio['Stop'] = gr.Button('Stop') + with gr.Column(): + pass with gr.Column(scale=1): - gr.Markdown("\n") + gr.HTML('
') shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens']) create_prompt_menus() @@ -431,7 +436,7 @@ def create_interface(): output_params = [shared.gradio[k] for k in ['textbox', 'markdown', 'html']] gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream, api_name='textgen')) gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream)) - shared.gradio['Stop'].click(None, None, None, cancels=gen_events) + shared.gradio['Stop'].click(stop_everything_event, [], [], queue=False, cancels=gen_events if shared.args.no_stream else None) shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}") else: @@ -465,7 +470,7 @@ def create_interface(): gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream, api_name='textgen')) gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream)) gen_events.append(shared.gradio['Continue'].click(generate_reply, [shared.gradio['output_textbox']] + shared.input_params[1:], output_params, show_progress=shared.args.no_stream)) - shared.gradio['Stop'].click(None, None, None, cancels=gen_events) + shared.gradio['Stop'].click(stop_everything_event, [], [], queue=False, cancels=gen_events if shared.args.no_stream else None) shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}") with gr.Tab("Training", elem_id="training-tab"):