diff --git a/README.md b/README.md index 1c8a4d9..40af19c 100644 --- a/README.md +++ b/README.md @@ -72,6 +72,7 @@ Then browse to Optionally, you can use the following command-line flags: --model model-name: load this model by default. + --notebook: Launch the webui in notebook mode, where the output is written to the same text box as the input. ## Presets diff --git a/server.py b/server.py index 2783743..613b9b5 100644 --- a/server.py +++ b/server.py @@ -11,7 +11,8 @@ from transformers import AutoTokenizer from transformers import GPTJForCausalLM, AutoModelForCausalLM, AutoModelForSeq2SeqLM, OPTForCausalLM, T5Tokenizer, T5ForConditionalGeneration, GPTJModel, AutoModel parser = argparse.ArgumentParser() -parser.add_argument('--model', type=str, help='Name of the model to load by default') +parser.add_argument('--model', type=str, help='Name of the model to load by default.') +parser.add_argument('--notebook', action='store_true', help='Launch the webui in notebook mode, where the output is written to the same text box as the input.') args = parser.parse_args() loaded_preset = None available_models = sorted(set(map(lambda x : x.split('/')[-1].replace('.pt', ''), glob.glob("models/*[!\.][!t][!x][!t]")+ glob.glob("torch-dumps/*[!\.][!t][!x][!t]")))) @@ -79,7 +80,10 @@ def generate_reply(question, temperature, max_length, inference_settings, select if model_name.startswith('gpt4chan'): reply = fix_gpt4chan(reply) - return reply + if model_name.lower().startswith('galactica'): + return reply, reply + else: + return reply, '' # Choosing the default model if args.model is not None: @@ -104,20 +108,40 @@ if model_name.startswith('gpt4chan'): else: default_text = "Common sense questions and answers\n\nQuestion: \nFactual answer:" -interface = gr.Interface( - generate_reply, - inputs=[ - gr.Textbox(value=default_text, lines=15), - gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Temperature', value=0.7), - gr.Slider(minimum=1, maximum=2000, step=1, label='max_length', value=200), - gr.Dropdown(choices=list(map(lambda x : x.split('/')[-1].split('.')[0], glob.glob("presets/*.txt"))), value="Default"), - gr.Dropdown(choices=available_models, value=model_name), - ], - outputs=[ - gr.Textbox(placeholder="", lines=15), - ], - title="Text generation lab", - description=f"Generate text using Large Language Models.", -) +if args.notebook: + with gr.Blocks() as interface: + gr.Markdown( + f""" + # Text generation lab + Generate text using Large Language Models. + """ + ) + + textbox = gr.Textbox(value=default_text, lines=23) + temp_slider = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Temperature', value=0.7) + length_slider = gr.Slider(minimum=1, maximum=2000, step=1, label='max_length', value=200) + preset_menu = gr.Dropdown(choices=list(map(lambda x : x.split('/')[-1].split('.')[0], glob.glob("presets/*.txt"))), value="Default") + model_menu = gr.Dropdown(choices=available_models, value=model_name) + btn = gr.Button("Generate") + markdown = gr.Markdown() + + btn.click(generate_reply, [textbox, temp_slider, length_slider, preset_menu, model_menu], [textbox, markdown], show_progress=False) +else: + interface = gr.Interface( + generate_reply, + inputs=[ + gr.Textbox(value=default_text, lines=15), + gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Temperature', value=0.7), + gr.Slider(minimum=1, maximum=2000, step=1, label='max_length', value=200), + gr.Dropdown(choices=list(map(lambda x : x.split('/')[-1].split('.')[0], glob.glob("presets/*.txt"))), value="Default"), + gr.Dropdown(choices=available_models, value=model_name), + ], + outputs=[ + gr.Textbox(placeholder="", lines=15), + gr.Markdown() + ], + title="Text generation lab", + description=f"Generate text using Large Language Models.", + ) interface.launch(share=False, server_name="0.0.0.0")