From bf22d16ebcee96430d6845c9786bbdab5e74af17 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Thu, 23 Mar 2023 21:56:26 -0300
Subject: [PATCH] Clear cache while switching LoRAs

---
 modules/LoRA.py      | 15 +++++++++------
 modules/callbacks.py |  8 +-------
 server.py            | 14 +++-----------
 3 files changed, 13 insertions(+), 24 deletions(-)

diff --git a/modules/LoRA.py b/modules/LoRA.py
index 5f77e34..1c03826 100644
--- a/modules/LoRA.py
+++ b/modules/LoRA.py
@@ -2,19 +2,22 @@ from pathlib import Path
 
 import modules.shared as shared
 from modules.models import load_model
+from modules.text_generation import clear_torch_cache
 
 
+def reload_model():
+    shared.model = shared.tokenizer = None
+    clear_torch_cache()
+    shared.model, shared.tokenizer = load_model(shared.model_name)
+
 def add_lora_to_model(lora_name):
 
     from peft import PeftModel
 
-    # Is there a more efficient way of returning to the base model?
-    if lora_name == "None":
-        print("Reloading the model to remove the LoRA...")
-        shared.model, shared.tokenizer = load_model(shared.model_name)
-    else:
+    reload_model()
+
+    if lora_name != "None":
         print(f"Adding the LoRA {lora_name} to the model...")
-
         params = {}
         if not shared.args.cpu:
             params['dtype'] = shared.model.dtype
diff --git a/modules/callbacks.py b/modules/callbacks.py
index 2ae9d90..50a6918 100644
--- a/modules/callbacks.py
+++ b/modules/callbacks.py
@@ -1,11 +1,10 @@
-import gc
 from queue import Queue
 from threading import Thread
 
 import torch
 import transformers
 
-import modules.shared as shared
+from modules.text_generation import clear_torch_cache
 
 
 # Copied from https://github.com/PygmalionAI/gradio-ui/
@@ -90,8 +89,3 @@ class Iteratorize:
     def __exit__(self, exc_type, exc_val, exc_tb):
         self.stop_now = True
         clear_torch_cache()
-
-def clear_torch_cache():
-    gc.collect()
-    if not shared.args.cpu:
-        torch.cuda.empty_cache()
diff --git a/server.py b/server.py
index cdf7aa9..068f380 100644
--- a/server.py
+++ b/server.py
@@ -1,4 +1,3 @@
-import gc
 import io
 import json
 import re
@@ -8,7 +7,6 @@ import zipfile
 from pathlib import Path
 
 import gradio as gr
-import torch
 
 import modules.chat as chat
 import modules.extensions as extensions_module
@@ -17,7 +15,7 @@ import modules.ui as ui
 from modules.html_generator import generate_chat_html
 from modules.LoRA import add_lora_to_model
 from modules.models import load_model, load_soft_prompt
-from modules.text_generation import generate_reply
+from modules.text_generation import clear_torch_cache, generate_reply
 
 # Loading custom settings
 settings_file = None
@@ -56,21 +54,15 @@ def load_model_wrapper(selected_model):
     if selected_model != shared.model_name:
         shared.model_name = selected_model
         shared.model = shared.tokenizer = None
-        if not shared.args.cpu:
-            gc.collect()
-            torch.cuda.empty_cache()
+        clear_torch_cache()
 
         shared.model, shared.tokenizer = load_model(shared.model_name)
 
     return selected_model
 
 def load_lora_wrapper(selected_lora):
     shared.lora_name = selected_lora
-    default_text = shared.settings['lora_prompts'][next((k for k in shared.settings['lora_prompts'] if re.match(k.lower(), shared.lora_name.lower())), 'default')]
-
-    if not shared.args.cpu:
-        gc.collect()
-        torch.cuda.empty_cache()
     add_lora_to_model(selected_lora)
+    default_text = shared.settings['lora_prompts'][next((k for k in shared.settings['lora_prompts'] if re.match(k.lower(), shared.lora_name.lower())), 'default')]
 
     return selected_lora, default_text
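
Note: all three touched files now import clear_torch_cache from modules.text_generation, but its definition there is not part of this diff. Presumably it mirrors the helper removed from modules/callbacks.py; a minimal sketch of what modules/text_generation.py is assumed to provide:

    import gc

    import torch

    import modules.shared as shared


    def clear_torch_cache():
        # Drop dangling Python references first, then release cached CUDA blocks
        gc.collect()
        if not shared.args.cpu:
            torch.cuda.empty_cache()

With this helper centralized, both load_model_wrapper and the new reload_model path free GPU memory the same way before loading new weights, which is what prevents stale allocations from accumulating when switching LoRAs.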