SD Api Pics extension, v.1.1 (#596)
This commit is contained in:
parent
5543a5089d
commit
ffd102e5c0
6 changed files with 282 additions and 102 deletions
|
@ -4,14 +4,7 @@ import torch
|
|||
from peft import PeftModel
|
||||
|
||||
import modules.shared as shared
|
||||
from modules.models import load_model
|
||||
from modules.text_generation import clear_torch_cache
|
||||
|
||||
|
||||
def reload_model():
|
||||
shared.model = shared.tokenizer = None
|
||||
clear_torch_cache()
|
||||
shared.model, shared.tokenizer = load_model(shared.model_name)
|
||||
from modules.models import reload_model
|
||||
|
||||
|
||||
def add_lora_to_model(lora_name):
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
import gc
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
|
@ -16,11 +17,10 @@ import modules.shared as shared
|
|||
|
||||
transformers.logging.set_verbosity_error()
|
||||
|
||||
local_rank = None
|
||||
|
||||
if shared.args.flexgen:
|
||||
from flexgen.flex_opt import CompressionConfig, ExecutionEnv, OptLM, Policy
|
||||
|
||||
local_rank = None
|
||||
if shared.args.deepspeed:
|
||||
import deepspeed
|
||||
from transformers.deepspeed import (HfDeepSpeedConfig,
|
||||
|
@ -182,6 +182,23 @@ def load_model(model_name):
|
|||
return model, tokenizer
|
||||
|
||||
|
||||
def clear_torch_cache():
|
||||
gc.collect()
|
||||
if not shared.args.cpu:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
def unload_model():
|
||||
shared.model = shared.tokenizer = None
|
||||
clear_torch_cache()
|
||||
|
||||
|
||||
def reload_model():
|
||||
shared.model = shared.tokenizer = None
|
||||
clear_torch_cache()
|
||||
shared.model, shared.tokenizer = load_model(shared.model_name)
|
||||
|
||||
|
||||
def load_soft_prompt(name):
|
||||
if name == 'None':
|
||||
shared.soft_prompt = False
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
import gc
|
||||
import re
|
||||
import time
|
||||
import traceback
|
||||
|
@ -12,7 +11,7 @@ from modules.callbacks import (Iteratorize, Stream,
|
|||
_SentinelTokenStoppingCriteria)
|
||||
from modules.extensions import apply_extensions
|
||||
from modules.html_generator import generate_4chan_html, generate_basic_html
|
||||
from modules.models import local_rank
|
||||
from modules.models import clear_torch_cache, local_rank
|
||||
|
||||
|
||||
def get_max_prompt_length(tokens):
|
||||
|
@ -101,12 +100,6 @@ def formatted_outputs(reply, model_name):
|
|||
return reply
|
||||
|
||||
|
||||
def clear_torch_cache():
|
||||
gc.collect()
|
||||
if not shared.args.cpu:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
def set_manual_seed(seed):
|
||||
if seed != -1:
|
||||
torch.manual_seed(seed)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue