Merge branch 'main' into Brawlence-main

This commit is contained in:
oobabooga 2023-03-26 23:40:51 -03:00
commit e07c9e3093
20 changed files with 270 additions and 203 deletions

View file

@ -1,4 +1,3 @@
import gc
import io
import json
import re
@ -8,7 +7,6 @@ import zipfile
from pathlib import Path
import gradio as gr
import torch
import modules.chat as chat
import modules.extensions as extensions_module
@ -17,7 +15,7 @@ import modules.ui as ui
from modules.html_generator import generate_chat_html
from modules.LoRA import add_lora_to_model
from modules.models import load_model, load_soft_prompt
from modules.text_generation import generate_reply
from modules.text_generation import clear_torch_cache, generate_reply
# Loading custom settings
settings_file = None
@ -56,9 +54,7 @@ def load_model_wrapper(selected_model):
if selected_model != shared.model_name:
shared.model_name = selected_model
shared.model = shared.tokenizer = None
if not shared.args.cpu:
gc.collect()
torch.cuda.empty_cache()
clear_torch_cache()
shared.model, shared.tokenizer = load_model(shared.model_name)
return selected_model
@ -75,13 +71,8 @@ def unload_model():
print("Model weights unloaded.")
def load_lora_wrapper(selected_lora):
shared.lora_name = selected_lora
default_text = shared.settings['lora_prompts'][next((k for k in shared.settings['lora_prompts'] if re.match(k.lower(), shared.lora_name.lower())), 'default')]
if not shared.args.cpu:
gc.collect()
torch.cuda.empty_cache()
add_lora_to_model(selected_lora)
default_text = shared.settings['lora_prompts'][next((k for k in shared.settings['lora_prompts'] if re.match(k.lower(), shared.lora_name.lower())), 'default')]
return selected_lora, default_text
@ -258,14 +249,13 @@ else:
shared.model_name = available_models[i]
shared.model, shared.tokenizer = load_model(shared.model_name)
if shared.args.lora:
print(shared.args.lora)
shared.lora_name = shared.args.lora
add_lora_to_model(shared.lora_name)
add_lora_to_model(shared.args.lora)
# Default UI settings
default_preset = shared.settings['presets'][next((k for k in shared.settings['presets'] if re.match(k.lower(), shared.model_name.lower())), 'default')]
default_text = shared.settings['lora_prompts'][next((k for k in shared.settings['lora_prompts'] if re.match(k.lower(), shared.lora_name.lower())), 'default')]
if default_text == '':
if shared.lora_name != "None":
default_text = shared.settings['lora_prompts'][next((k for k in shared.settings['lora_prompts'] if re.match(k.lower(), shared.lora_name.lower())), 'default')]
else:
default_text = shared.settings['prompts'][next((k for k in shared.settings['prompts'] if re.match(k.lower(), shared.model_name.lower())), 'default')]
title ='Text generation web UI'
description = '\n\n# Text generation lab\nGenerate text using Large Language Models.\n'
@ -354,7 +344,7 @@ def create_interface():
gen_events.append(shared.gradio['textbox'].submit(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
gen_events.append(shared.gradio['Regenerate'].click(chat.regenerate_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
gen_events.append(shared.gradio['Impersonate'].click(chat.impersonate_wrapper, shared.input_params, shared.gradio['textbox'], show_progress=shared.args.no_stream))
shared.gradio['Stop'].click(chat.stop_everything_event, [], [], cancels=gen_events)
shared.gradio['Stop'].click(chat.stop_everything_event, [], [], cancels=gen_events, queue=False)
shared.gradio['Copy last reply'].click(chat.send_last_reply_to_input, [], shared.gradio['textbox'], show_progress=shared.args.no_stream)
shared.gradio['Replace last reply'].click(chat.replace_last_reply, [shared.gradio['textbox'], shared.gradio['name1'], shared.gradio['name2']], shared.gradio['display'], show_progress=shared.args.no_stream)
@ -395,19 +385,22 @@ def create_interface():
elif shared.args.notebook:
with gr.Tab("Text generation", elem_id="main"):
with gr.Tab('Raw'):
shared.gradio['textbox'] = gr.Textbox(value=default_text, lines=25)
with gr.Tab('Markdown'):
shared.gradio['markdown'] = gr.Markdown()
with gr.Tab('HTML'):
shared.gradio['html'] = gr.HTML()
with gr.Row():
shared.gradio['Stop'] = gr.Button('Stop')
shared.gradio['Generate'] = gr.Button('Generate')
shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens'])
with gr.Column(scale=4):
with gr.Tab('Raw'):
shared.gradio['textbox'] = gr.Textbox(value=default_text, elem_id="textbox", lines=25)
with gr.Tab('Markdown'):
shared.gradio['markdown'] = gr.Markdown()
with gr.Tab('HTML'):
shared.gradio['html'] = gr.HTML()
create_model_and_preset_menus()
with gr.Row():
shared.gradio['Stop'] = gr.Button('Stop')
shared.gradio['Generate'] = gr.Button('Generate')
with gr.Column(scale=1):
shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens'])
create_model_and_preset_menus()
with gr.Tab("Parameters", elem_id="parameters"):
create_settings_menus(default_preset)