Clear cache while switching LoRAs

oobabooga 2023-03-23 21:56:26 -03:00 committed by GitHub
parent 4578e88ffd
commit bf22d16ebc
3 changed files with 13 additions and 24 deletions

View file

@@ -2,19 +2,22 @@ from pathlib import Path
 
 import modules.shared as shared
 from modules.models import load_model
+from modules.text_generation import clear_torch_cache
 
 
+def reload_model():
+    shared.model = shared.tokenizer = None
+    clear_torch_cache()
+    shared.model, shared.tokenizer = load_model(shared.model_name)
+
+
 def add_lora_to_model(lora_name):
 
     from peft import PeftModel
 
-    # Is there a more efficient way of returning to the base model?
-    if lora_name == "None":
-        print("Reloading the model to remove the LoRA...")
-        shared.model, shared.tokenizer = load_model(shared.model_name)
-    else:
-        print(f"Adding the LoRA {lora_name} to the model...")
+    reload_model()
+
+    if lora_name != "None":
+        print(f"Adding the LoRA {lora_name} to the model...")
         params = {}
         if not shared.args.cpu:
             params['dtype'] = shared.model.dtype
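
The hunk above (apparently modules/LoRA.py) cuts off before the code that actually attaches the adapter. As a reference for how the patched function presumably reads end to end, here is a hedged sketch, assuming the remainder attaches the adapter with peft's PeftModel.from_pretrained from a loras/<name> directory (neither detail is shown in this diff):

from pathlib import Path

import modules.shared as shared
from modules.models import load_model
from modules.text_generation import clear_torch_cache


def reload_model():
    # Drop the current references, release cached CUDA memory, and load
    # the base model again.
    shared.model = shared.tokenizer = None
    clear_torch_cache()
    shared.model, shared.tokenizer = load_model(shared.model_name)


def add_lora_to_model(lora_name):
    from peft import PeftModel

    # Always return to the clean base model first, whether the LoRA is
    # being removed ("None") or replaced by another one.
    reload_model()

    if lora_name != "None":
        print(f"Adding the LoRA {lora_name} to the model...")
        params = {}
        if not shared.args.cpu:
            params['dtype'] = shared.model.dtype
        # Assumed continuation (outside the hunk): wrap the base model
        # with the adapter loaded from loras/<lora_name>.
        shared.model = PeftModel.from_pretrained(
            shared.model, Path(f"loras/{lora_name}"), **params
        )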

View file

@@ -1,11 +1,10 @@
-import gc
 from queue import Queue
 from threading import Thread
 
 import torch
 import transformers
 
-import modules.shared as shared
+from modules.text_generation import clear_torch_cache
 
 
 # Copied from https://github.com/PygmalionAI/gradio-ui/
@@ -90,8 +89,3 @@ class Iteratorize:
     def __exit__(self, exc_type, exc_val, exc_tb):
         self.stop_now = True
         clear_torch_cache()
-
-def clear_torch_cache():
-    gc.collect()
-    if not shared.args.cpu:
-        torch.cuda.empty_cache()
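
In this second file (apparently modules/callbacks.py) the local clear_torch_cache helper is deleted and imported from modules.text_generation instead. modules/text_generation.py is not among the three changed files, so presumably an identical helper already lives there; for reference, this is the implementation being removed here:

import gc

import torch

import modules.shared as shared


def clear_torch_cache():
    # Collect garbage so stale model references are actually freed, then
    # ask PyTorch to return its cached CUDA allocations to the driver.
    gc.collect()
    if not shared.args.cpu:
        torch.cuda.empty_cache()

Routing every caller through this one helper keeps the gc.collect()/torch.cuda.empty_cache() pair in a single place instead of being repeated across LoRA.py, callbacks.py, and server.py.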

View file

@@ -1,4 +1,3 @@
-import gc
 import io
 import json
 import re
@@ -8,7 +7,6 @@ import zipfile
 from pathlib import Path
 
 import gradio as gr
-import torch
 
 import modules.chat as chat
 import modules.extensions as extensions_module
@@ -17,7 +15,7 @@ import modules.ui as ui
 from modules.html_generator import generate_chat_html
 from modules.LoRA import add_lora_to_model
 from modules.models import load_model, load_soft_prompt
-from modules.text_generation import generate_reply
+from modules.text_generation import clear_torch_cache, generate_reply
 
 # Loading custom settings
 settings_file = None
@@ -56,21 +54,15 @@ def load_model_wrapper(selected_model):
     if selected_model != shared.model_name:
         shared.model_name = selected_model
         shared.model = shared.tokenizer = None
-        if not shared.args.cpu:
-            gc.collect()
-            torch.cuda.empty_cache()
+        clear_torch_cache()
         shared.model, shared.tokenizer = load_model(shared.model_name)
 
     return selected_model
 
 
 def load_lora_wrapper(selected_lora):
     shared.lora_name = selected_lora
-    default_text = shared.settings['lora_prompts'][next((k for k in shared.settings['lora_prompts'] if re.match(k.lower(), shared.lora_name.lower())), 'default')]
-
-    if not shared.args.cpu:
-        gc.collect()
-        torch.cuda.empty_cache()
     add_lora_to_model(selected_lora)
+    default_text = shared.settings['lora_prompts'][next((k for k in shared.settings['lora_prompts'] if re.match(k.lower(), shared.lora_name.lower())), 'default')]
 
     return selected_lora, default_text
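
The wrappers in this last file (apparently server.py) are the functions bound to the model and LoRA dropdowns in the UI; the event wiring itself is outside this diff. A rough sketch of how load_lora_wrapper would typically be hooked up, assuming Gradio components stored in a shared.gradio dict under the names 'lora_menu' and 'textbox' (both names are assumptions):

# Hypothetical wiring, not part of this commit. The wrapper returns the
# selected LoRA name plus a default prompt, so one .change() event can
# refresh both the dropdown and the input textbox.
shared.gradio['lora_menu'].change(
    load_lora_wrapper,
    shared.gradio['lora_menu'],
    [shared.gradio['lora_menu'], shared.gradio['textbox']],
    show_progress=True,
)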