SD Api Pics extension, v.1.1 (#596)

This commit is contained in:
Φφ 2023-04-08 03:36:04 +03:00 committed by GitHub
parent 5543a5089d
commit ffd102e5c0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 282 additions and 102 deletions

View file

@@ -4,14 +4,7 @@ import torch
from peft import PeftModel
import modules.shared as shared
from modules.models import load_model
from modules.text_generation import clear_torch_cache
def reload_model():
    """Discard the active model and tokenizer, free cached memory, and load
    the model named by ``shared.model_name`` again."""
    shared.model = None
    shared.tokenizer = None
    clear_torch_cache()
    shared.model, shared.tokenizer = load_model(shared.model_name)
from modules.models import reload_model
def add_lora_to_model(lora_name):

View file

@@ -1,3 +1,4 @@
import gc
import json
import os
import re
@@ -16,11 +17,10 @@ import modules.shared as shared
transformers.logging.set_verbosity_error()
local_rank = None
if shared.args.flexgen:
from flexgen.flex_opt import CompressionConfig, ExecutionEnv, OptLM, Policy
local_rank = None
if shared.args.deepspeed:
import deepspeed
from transformers.deepspeed import (HfDeepSpeedConfig,
@@ -182,6 +182,23 @@ def load_model(model_name):
return model, tokenizer
def clear_torch_cache():
    """Run Python garbage collection and, unless running in CPU mode,
    release PyTorch's cached CUDA memory."""
    gc.collect()
    if shared.args.cpu:
        return
    torch.cuda.empty_cache()
def unload_model():
    """Drop the references to the current model and tokenizer, then free
    the memory they held via clear_torch_cache()."""
    shared.model = None
    shared.tokenizer = None
    clear_torch_cache()
def reload_model():
    """Reload the current model from scratch.

    Unloads the active model/tokenizer (freeing GPU/CPU memory) and then
    loads the model named by ``shared.model_name`` again.
    """
    # Delegate teardown to unload_model() instead of duplicating its body
    # (it already clears shared.model/shared.tokenizer and the torch cache).
    unload_model()
    shared.model, shared.tokenizer = load_model(shared.model_name)
def load_soft_prompt(name):
if name == 'None':
shared.soft_prompt = False

View file

@@ -1,4 +1,3 @@
import gc
import re
import time
import traceback
@@ -12,7 +11,7 @@ from modules.callbacks import (Iteratorize, Stream,
_SentinelTokenStoppingCriteria)
from modules.extensions import apply_extensions
from modules.html_generator import generate_4chan_html, generate_basic_html
from modules.models import local_rank
from modules.models import clear_torch_cache, local_rank
def get_max_prompt_length(tokens):
@@ -101,12 +100,6 @@ def formatted_outputs(reply, model_name):
return reply
def clear_torch_cache():
    """Collect Python garbage; on non-CPU runs also empty the CUDA cache."""
    gc.collect()
    if not shared.args.cpu:
        torch.cuda.empty_cache()
def set_manual_seed(seed):
if seed != -1:
torch.manual_seed(seed)