diff --git a/server.py b/server.py index 77feecb..5ff89ff 100644 --- a/server.py +++ b/server.py @@ -9,6 +9,7 @@ import gradio as gr import transformers from html_generator import * from transformers import AutoTokenizer, AutoModelForCausalLM +import warnings parser = argparse.ArgumentParser() @@ -20,12 +21,15 @@ parser.add_argument('--auto-devices', action='store_true', help='Automatically s parser.add_argument('--load-in-8bit', action='store_true', help='Load the model with 8-bit precision.') parser.add_argument('--no-listen', action='store_true', help='Make the webui unreachable from your local network.') args = parser.parse_args() + loaded_preset = None available_models = sorted(set(map(lambda x : str(x.name).replace('.pt', ''), list(Path('models/').glob('*'))+list(Path('torch-dumps/').glob('*'))))) available_models = [item for item in available_models if not item.endswith('.txt')] available_models = sorted(available_models, key=str.lower) available_presets = sorted(set(map(lambda x : str(x.name).split('.')[0], list(Path('presets').glob('*.txt'))))) +transformers.logging.set_verbosity_error() + def load_model(model_name): print(f"Loading {model_name}...") t0 = time.time() @@ -188,10 +192,15 @@ if args.notebook: elif args.chat: history = [] - def chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check): + # This gets the new line characters right. 
+ def chat_response_cleaner(text): text = text.replace('\n', '\n\n') text = re.sub(r"\n{3,}", "\n\n", text) text = text.strip() + return text + + def chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check): + text = chat_response_cleaner(text) question = context+'\n\n' for i in range(len(history)): @@ -209,9 +218,7 @@ elif args.chat: idx = reply.find(f"\n{name1}:") if idx != -1: reply = reply[:idx] - reply = reply.replace('\n', '\n\n') - reply = re.sub(r"\n{3,}", "\n\n", reply) - reply = reply.strip() + reply = chat_response_cleaner(reply) history.append((text, reply)) return history