Reorganize model loading UI completely (#2720)
This commit is contained in:
parent
57be2eecdf
commit
7ef6a50e84
16 changed files with 365 additions and 243 deletions
|
@ -14,6 +14,7 @@ from transformers import (AutoConfig, AutoModel, AutoModelForCausalLM,
|
|||
import modules.shared as shared
|
||||
from modules import llama_attn_hijack, sampler_hijack
|
||||
from modules.logging_colors import logger
|
||||
from modules.models_settings import infer_loader
|
||||
|
||||
transformers.logging.set_verbosity_error()
|
||||
|
||||
|
@ -36,62 +37,31 @@ if shared.args.deepspeed:
|
|||
sampler_hijack.hijack_samplers()
|
||||
|
||||
|
||||
# Some models require special treatment in various parts of the code.
|
||||
# This function detects those models
|
||||
def find_model_type(model_name):
    """Classify a model so callers can apply model-specific handling.

    The model is looked up at ``{shared.args.model_dir}/{model_name}``.
    Checks run in a fixed order and the first match wins:

    1. The path does not exist -> ``'None'``.
    2. The name matches ``*rwkv*.pth`` -> ``'rwkv'``.
    3. The model path contains a ``*ggml*.bin`` entry, or the name itself
       matches ``*ggml*.bin`` -> ``'llamacpp'``.
    4. The name contains a known family marker -> ``'chatglm'``,
       ``'galactica'``, ``'llava'``, ``'oasst'`` or ``'gpt4chan'``.
    5. Otherwise the Hugging Face config is loaded and its
       ``is_encoder_decoder`` flag selects ``'HF_seq2seq'`` or
       ``'HF_generic'``.

    Args:
        model_name: File or directory name of the model, relative to
            ``shared.args.model_dir``. Name matching is case-insensitive.

    Returns:
        One of the type strings listed above. Note that a missing path is
        reported as the *string* ``'None'``, not the ``None`` object.
    """
    path_to_model = Path(f'{shared.args.model_dir}/{model_name}')
    if not path_to_model.exists():
        return 'None'

    model_name_lower = model_name.lower()

    # Raw strings keep `\.` as a literal-dot regex escape without relying on
    # Python passing an invalid string escape through unchanged.
    if re.match(r'.*rwkv.*\.pth', model_name_lower):
        return 'rwkv'

    # any() stops at the first hit instead of materializing every match.
    if any(path_to_model.glob('*ggml*.bin')):
        return 'llamacpp'

    if re.match(r'.*ggml.*\.bin', model_name_lower):
        return 'llamacpp'

    # Order matters: the first family whose marker appears in the name wins.
    name_markers = (
        ('chatglm', ('chatglm',)),
        ('galactica', ('galactica',)),
        ('llava', ('llava',)),
        ('oasst', ('oasst',)),
        ('gpt4chan', ('gpt4chan', 'gpt-4chan')),
    )
    for model_type, markers in name_markers:
        if any(marker in model_name_lower for marker in markers):
            return model_type

    config = AutoConfig.from_pretrained(path_to_model, trust_remote_code=shared.args.trust_remote_code)

    # Not a "catch all", but fairly accurate
    if config.to_dict().get("is_encoder_decoder", False):
        return 'HF_seq2seq'

    return 'HF_generic'
|
||||
|
||||
|
||||
def load_model(model_name):
|
||||
def load_model(model_name, loader=None):
|
||||
logger.info(f"Loading {model_name}...")
|
||||
t0 = time.time()
|
||||
|
||||
shared.model_type = find_model_type(model_name)
|
||||
if shared.model_type == 'None':
|
||||
logger.error('The path to the model does not exist. Exiting.')
|
||||
return None, None
|
||||
shared.is_seq2seq = False
|
||||
load_func_map = {
|
||||
'Transformers': huggingface_loader,
|
||||
'AutoGPTQ': AutoGPTQ_loader,
|
||||
'GPTQ-for-LLaMa': GPTQ_loader,
|
||||
'llama.cpp': llamacpp_loader,
|
||||
'FlexGen': flexgen_loader,
|
||||
'RWKV': RWKV_loader
|
||||
}
|
||||
|
||||
if shared.args.gptq_for_llama:
|
||||
load_func = GPTQ_loader
|
||||
elif Path(f'{shared.args.model_dir}/{model_name}/quantize_config.json').exists() or shared.args.wbits > 0:
|
||||
load_func = AutoGPTQ_loader
|
||||
elif shared.model_type == 'llamacpp':
|
||||
load_func = llamacpp_loader
|
||||
elif shared.model_type == 'rwkv':
|
||||
load_func = RWKV_loader
|
||||
elif shared.args.flexgen:
|
||||
load_func = flexgen_loader
|
||||
else:
|
||||
load_func = huggingface_loader
|
||||
if loader is None:
|
||||
if shared.args.loader is not None:
|
||||
loader = shared.args.loader
|
||||
else:
|
||||
loader = infer_loader(model_name)
|
||||
if loader is None:
|
||||
logger.error('The path to the model does not exist. Exiting.')
|
||||
return None, None
|
||||
|
||||
output = load_func(model_name)
|
||||
shared.args.loader = loader
|
||||
output = load_func_map[loader](model_name)
|
||||
if type(output) is tuple:
|
||||
model, tokenizer = output
|
||||
else:
|
||||
|
@ -111,11 +81,11 @@ def load_model(model_name):
|
|||
|
||||
def load_tokenizer(model_name, model):
|
||||
tokenizer = None
|
||||
if shared.model_type == 'gpt4chan' and Path(f"{shared.args.model_dir}/gpt-j-6B/").exists():
|
||||
if any(s in model_name.lower() for s in ['gpt-4chan', 'gpt4chan']) and Path(f"{shared.args.model_dir}/gpt-j-6B/").exists():
|
||||
tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/gpt-j-6B/"))
|
||||
elif type(model) is transformers.LlamaForCausalLM or "LlamaGPTQForCausalLM" in str(type(model)):
|
||||
# Try to load a universal LLaMA tokenizer
|
||||
if shared.model_type not in ['llava', 'oasst']:
|
||||
if any(s in shared.model_name.lower() for s in ['llava', 'oasst']):
|
||||
for p in [Path(f"{shared.args.model_dir}/llama-tokenizer/"), Path(f"{shared.args.model_dir}/oobabooga_llama-tokenizer/")]:
|
||||
if p.exists():
|
||||
logger.info(f"Loading the universal LLaMA tokenizer from {p}...")
|
||||
|
@ -140,12 +110,16 @@ def load_tokenizer(model_name, model):
|
|||
|
||||
|
||||
def huggingface_loader(model_name):
|
||||
if shared.model_type == 'chatglm':
|
||||
path_to_model = Path(f'{shared.args.model_dir}/{model_name}')
|
||||
if 'chatglm' in model_name.lower():
|
||||
LoaderClass = AutoModel
|
||||
elif shared.model_type == 'HF_seq2seq':
|
||||
LoaderClass = AutoModelForSeq2SeqLM
|
||||
else:
|
||||
LoaderClass = AutoModelForCausalLM
|
||||
config = AutoConfig.from_pretrained(path_to_model, trust_remote_code=shared.args.trust_remote_code)
|
||||
if config.to_dict().get("is_encoder_decoder", False):
|
||||
LoaderClass = AutoModelForSeq2SeqLM
|
||||
shared.is_seq2seq = True
|
||||
else:
|
||||
LoaderClass = AutoModelForCausalLM
|
||||
|
||||
# Load the model in simple 16-bit mode by default
|
||||
if not any([shared.args.cpu, shared.args.load_in_8bit, shared.args.load_in_4bit, shared.args.auto_devices, shared.args.disk, shared.args.deepspeed, shared.args.gpu_memory is not None, shared.args.cpu_memory is not None]):
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue