SD Api Pics extension, v.1.1 (#596)
This commit is contained in:
parent
5543a5089d
commit
ffd102e5c0
6 changed files with 282 additions and 102 deletions
|
@ -1,3 +1,4 @@
|
|||
import gc
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
|
@ -16,11 +17,10 @@ import modules.shared as shared
|
|||
|
||||
transformers.logging.set_verbosity_error()
|
||||
|
||||
local_rank = None
|
||||
|
||||
if shared.args.flexgen:
|
||||
from flexgen.flex_opt import CompressionConfig, ExecutionEnv, OptLM, Policy
|
||||
|
||||
local_rank = None
|
||||
if shared.args.deepspeed:
|
||||
import deepspeed
|
||||
from transformers.deepspeed import (HfDeepSpeedConfig,
|
||||
|
@ -182,6 +182,23 @@ def load_model(model_name):
|
|||
return model, tokenizer
|
||||
|
||||
|
||||
def clear_torch_cache():
|
||||
gc.collect()
|
||||
if not shared.args.cpu:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
def unload_model():
|
||||
shared.model = shared.tokenizer = None
|
||||
clear_torch_cache()
|
||||
|
||||
|
||||
def reload_model():
|
||||
shared.model = shared.tokenizer = None
|
||||
clear_torch_cache()
|
||||
shared.model, shared.tokenizer = load_model(shared.model_name)
|
||||
|
||||
|
||||
def load_soft_prompt(name):
|
||||
if name == 'None':
|
||||
shared.soft_prompt = False
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue