Add refresh buttons for the model/preset/character menus

This commit is contained in:
oobabooga 2023-01-22 00:02:46 -03:00
parent bc664ecf3b
commit 434d4b128c
5 changed files with 73 additions and 18 deletions

View file

@ -1,18 +1,19 @@
import re
import gc
import time
import glob
from sys import exit
import torch
import argparse
import json
from sys import exit
from pathlib import Path
import gradio as gr
import transformers
from html_generator import *
from transformers import AutoTokenizer, AutoModelForCausalLM
import warnings
import gc
from tqdm import tqdm
import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM
from modules.html_generator import *
from modules.ui import *
transformers.logging.set_verbosity_error()
@ -36,9 +37,18 @@ parser.add_argument('--share', action='store_true', help='Create a public URL. T
args = parser.parse_args()
loaded_preset = None
available_models = sorted(set([item.replace('.pt', '') for item in map(lambda x : str(x.name), list(Path('models/').glob('*'))+list(Path('torch-dumps/').glob('*'))) if not item.endswith('.txt')]), key=str.lower)
available_presets = sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('presets').glob('*.txt'))), key=str.lower)
available_characters = sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('characters').glob('*.json'))), key=str.lower)
def get_available_models():
return sorted(set([item.replace('.pt', '') for item in map(lambda x : str(x.name), list(Path('models/').glob('*'))+list(Path('torch-dumps/').glob('*'))) if not item.endswith('.txt')]), key=str.lower)
def get_available_presets():
return sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('presets').glob('*.txt'))), key=str.lower)
def get_available_characters():
return ["None"] + sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('characters').glob('*.json'))), key=str.lower)
available_models = get_available_models()
available_presets = get_available_presets()
available_characters = get_available_characters()
settings = {
'max_new_tokens': 200,
@ -227,7 +237,7 @@ else:
default_text = settings['prompt']
description = f"\n\n# Text generation lab\nGenerate text using Large Language Models.\n"
css = ".my-4 {margin-top: 0} .py-6 {padding-top: 2.5rem}"
css = ".my-4 {margin-top: 0} .py-6 {padding-top: 2.5rem} #refresh-button {flex: none; margin: 0; padding: 0; min-width: 50px; border: none; box-shadow: none; border-radius: 0} #download-label, #upload-label {min-height: 0}"
if args.chat or args.cai_chat:
history = []
character = None
@ -413,24 +423,30 @@ if args.chat or args.cai_chat:
with gr.Row():
with gr.Column():
length_slider = gr.Slider(minimum=settings['max_new_tokens_min'], maximum=settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=settings['max_new_tokens'])
model_menu = gr.Dropdown(choices=available_models, value=model_name, label='Model')
with gr.Row():
model_menu = gr.Dropdown(choices=available_models, value=model_name, label='Model')
create_refresh_button(model_menu, lambda : None, lambda : {"choices": get_available_models()}, "refresh-button")
with gr.Column():
history_size_slider = gr.Slider(minimum=settings['history_size_min'], maximum=settings['history_size_max'], step=1, label='Chat history size (0 for no limit)', value=settings['history_size'])
preset_menu = gr.Dropdown(choices=available_presets, value=settings[f'preset{suffix}'], label='Settings preset')
with gr.Row():
preset_menu = gr.Dropdown(choices=available_presets, value=settings[f'preset{suffix}'], label='Settings preset')
create_refresh_button(preset_menu, lambda : None, lambda : {"choices": get_available_presets()}, "refresh-button")
name1 = gr.Textbox(value=settings[f'name1{suffix}'], lines=1, label='Your name')
name2 = gr.Textbox(value=settings[f'name2{suffix}'], lines=1, label='Bot\'s name')
context = gr.Textbox(value=settings[f'context{suffix}'], lines=2, label='Context')
with gr.Row():
character_menu = gr.Dropdown(choices=["None"]+available_characters, value="None", label='Character')
character_menu = gr.Dropdown(choices=available_characters, value="None", label='Character')
create_refresh_button(character_menu, lambda : None, lambda : {"choices": get_available_characters()}, "refresh-button")
with gr.Row():
check = gr.Checkbox(value=settings[f'stop_at_newline{suffix}'], label='Stop generating at new line character?')
with gr.Row():
with gr.Column():
gr.Markdown("Upload chat history")
gr.Markdown("Upload chat history", elem_id="upload-label")
upload = gr.File(type='binary')
with gr.Column():
gr.Markdown("Download chat history")
gr.Markdown("Download chat history", elem_id="download-label")
save_btn = gr.Button(value="Click me")
download = gr.File()
@ -473,9 +489,13 @@ elif args.notebook:
length_slider = gr.Slider(minimum=settings['max_new_tokens_min'], maximum=settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=settings['max_new_tokens'])
with gr.Row():
with gr.Column():
model_menu = gr.Dropdown(choices=available_models, value=model_name, label='Model')
with gr.Row():
model_menu = gr.Dropdown(choices=available_models, value=model_name, label='Model')
create_refresh_button(model_menu, lambda : None, lambda : {"choices": get_available_models()}, "refresh-button")
with gr.Column():
preset_menu = gr.Dropdown(choices=available_presets, value=settings['preset'], label='Settings preset')
with gr.Row():
preset_menu = gr.Dropdown(choices=available_presets, value=settings['preset'], label='Settings preset')
create_refresh_button(preset_menu, lambda : None, lambda : {"choices": get_available_presets()}, "refresh-button")
gen_event = btn.click(generate_reply, [textbox, length_slider, preset_menu, model_menu], [textbox, markdown, html], show_progress=args.no_stream, api_name="textgen")
gen_event2 = textbox.submit(generate_reply, [textbox, length_slider, preset_menu, model_menu], [textbox, markdown, html], show_progress=args.no_stream)
@ -488,8 +508,12 @@ else:
with gr.Column():
textbox = gr.Textbox(value=default_text, lines=15, label='Input')
length_slider = gr.Slider(minimum=settings['max_new_tokens_min'], maximum=settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=settings['max_new_tokens'])
preset_menu = gr.Dropdown(choices=available_presets, value=settings['preset'], label='Settings preset')
model_menu = gr.Dropdown(choices=available_models, value=model_name, label='Model')
with gr.Row():
preset_menu = gr.Dropdown(choices=available_presets, value=settings['preset'], label='Settings preset')
create_refresh_button(preset_menu, lambda : None, lambda : {"choices": get_available_presets()}, "refresh-button")
with gr.Row():
model_menu = gr.Dropdown(choices=available_models, value=model_name, label='Model')
create_refresh_button(model_menu, lambda : None, lambda : {"choices": get_available_models()}, "refresh-button")
btn = gr.Button("Generate")
with gr.Row():
with gr.Column():