Reorganize model loading UI completely (#2720)

oobabooga authored on 2023-06-16 19:00:37 -03:00 (committed by GitHub)
parent 57be2eecdf
commit 7ef6a50e84
16 changed files with 365 additions and 243 deletions

modules/models.py

@@ -14,6 +14,7 @@ from transformers import (AutoConfig, AutoModel, AutoModelForCausalLM,
 import modules.shared as shared
 from modules import llama_attn_hijack, sampler_hijack
 from modules.logging_colors import logger
+from modules.models_settings import infer_loader
 
 transformers.logging.set_verbosity_error()
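The newly imported `infer_loader` lives in `modules/models_settings`, which is not shown in this excerpt. A minimal sketch of the pattern it stands for, assuming the same name- and file-based heuristics the deleted `find_model_type` used (the function below is illustrative, not the actual implementation):

```python
import re
from pathlib import Path
from typing import Optional

def infer_loader_sketch(model_name: str, model_dir: str = "models") -> Optional[str]:
    """Map on-disk model artifacts to a loader name (illustrative only)."""
    path_to_model = Path(model_dir) / model_name
    if not path_to_model.exists():
        return None  # load_model() treats None as "path does not exist"

    name = model_name.lower()
    if re.match(r'.*rwkv.*\.pth', name):
        return 'RWKV'
    if list(path_to_model.glob('*ggml*.bin')):
        return 'llama.cpp'
    if (path_to_model / 'quantize_config.json').exists():
        return 'AutoGPTQ'
    return 'Transformers'
```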
@@ -36,62 +37,31 @@ if shared.args.deepspeed:
 sampler_hijack.hijack_samplers()
 
-# Some models require special treatment in various parts of the code.
-# This function detects those models
-def find_model_type(model_name):
-    path_to_model = Path(f'{shared.args.model_dir}/{model_name}')
-    if not path_to_model.exists():
-        return 'None'
-
-    model_name_lower = model_name.lower()
-    if re.match('.*rwkv.*\.pth', model_name_lower):
-        return 'rwkv'
-    elif len(list(path_to_model.glob('*ggml*.bin'))) > 0:
-        return 'llamacpp'
-    elif re.match('.*ggml.*\.bin', model_name_lower):
-        return 'llamacpp'
-    elif 'chatglm' in model_name_lower:
-        return 'chatglm'
-    elif 'galactica' in model_name_lower:
-        return 'galactica'
-    elif 'llava' in model_name_lower:
-        return 'llava'
-    elif 'oasst' in model_name_lower:
-        return 'oasst'
-    elif any((k in model_name_lower for k in ['gpt4chan', 'gpt-4chan'])):
-        return 'gpt4chan'
-    else:
-        config = AutoConfig.from_pretrained(path_to_model, trust_remote_code=shared.args.trust_remote_code)
-
-        # Not a "catch all", but fairly accurate
-        if config.to_dict().get("is_encoder_decoder", False):
-            return 'HF_seq2seq'
-        else:
-            return 'HF_generic'
-
-
-def load_model(model_name):
+def load_model(model_name, loader=None):
     logger.info(f"Loading {model_name}...")
     t0 = time.time()
 
-    shared.model_type = find_model_type(model_name)
-    if shared.model_type == 'None':
-        logger.error('The path to the model does not exist. Exiting.')
-        return None, None
+    shared.is_seq2seq = False
+    load_func_map = {
+        'Transformers': huggingface_loader,
+        'AutoGPTQ': AutoGPTQ_loader,
+        'GPTQ-for-LLaMa': GPTQ_loader,
+        'llama.cpp': llamacpp_loader,
+        'FlexGen': flexgen_loader,
+        'RWKV': RWKV_loader
+    }
 
-    if shared.args.gptq_for_llama:
-        load_func = GPTQ_loader
-    elif Path(f'{shared.args.model_dir}/{model_name}/quantize_config.json').exists() or shared.args.wbits > 0:
-        load_func = AutoGPTQ_loader
-    elif shared.model_type == 'llamacpp':
-        load_func = llamacpp_loader
-    elif shared.model_type == 'rwkv':
-        load_func = RWKV_loader
-    elif shared.args.flexgen:
-        load_func = flexgen_loader
-    else:
-        load_func = huggingface_loader
+    if loader is None:
+        if shared.args.loader is not None:
+            loader = shared.args.loader
+        else:
+            loader = infer_loader(model_name)
 
-    output = load_func(model_name)
+    if loader is None:
+        logger.error('The path to the model does not exist. Exiting.')
+        return None, None
+
+    shared.args.loader = loader
+    output = load_func_map[loader](model_name)
     if type(output) is tuple:
         model, tokenizer = output
     else:
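The rewritten `load_model` replaces the if/elif chain with a dispatch table keyed by loader name and resolves the loader in a fixed precedence order: the explicit function argument, then the `--loader` command-line value, then inference from the model files. A self-contained sketch of that resolution logic (names are illustrative, not part of the commit):

```python
from typing import Optional

def resolve_loader(explicit: Optional[str],
                   cli_value: Optional[str],
                   inferred: Optional[str]) -> Optional[str]:
    """First non-None value wins: argument, then CLI flag, then inference."""
    for candidate in (explicit, cli_value, inferred):
        if candidate is not None:
            return candidate
    return None  # load_model() logs an error and returns (None, None) here

loader = resolve_loader(None, 'AutoGPTQ', 'Transformers')  # -> 'AutoGPTQ'
```

Dispatch then becomes a single dictionary lookup, `load_func_map[loader](model_name)`, instead of branching on `shared.model_type`.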
@@ -111,11 +81,11 @@ def load_model(model_name):
 def load_tokenizer(model_name, model):
     tokenizer = None
-    if shared.model_type == 'gpt4chan' and Path(f"{shared.args.model_dir}/gpt-j-6B/").exists():
+    if any(s in model_name.lower() for s in ['gpt-4chan', 'gpt4chan']) and Path(f"{shared.args.model_dir}/gpt-j-6B/").exists():
         tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/gpt-j-6B/"))
     elif type(model) is transformers.LlamaForCausalLM or "LlamaGPTQForCausalLM" in str(type(model)):
         # Try to load an universal LLaMA tokenizer
-        if shared.model_type not in ['llava', 'oasst']:
+        if not any(s in shared.model_name.lower() for s in ['llava', 'oasst']):
             for p in [Path(f"{shared.args.model_dir}/llama-tokenizer/"), Path(f"{shared.args.model_dir}/oobabooga_llama-tokenizer/")]:
                 if p.exists():
                     logger.info(f"Loading the universal LLaMA tokenizer from {p}...")
@@ -140,12 +110,16 @@ def load_tokenizer(model_name, model):
 def huggingface_loader(model_name):
-    if shared.model_type == 'chatglm':
+    path_to_model = Path(f'{shared.args.model_dir}/{model_name}')
+    if 'chatglm' in model_name.lower():
         LoaderClass = AutoModel
-    elif shared.model_type == 'HF_seq2seq':
-        LoaderClass = AutoModelForSeq2SeqLM
     else:
-        LoaderClass = AutoModelForCausalLM
+        config = AutoConfig.from_pretrained(path_to_model, trust_remote_code=shared.args.trust_remote_code)
+        if config.to_dict().get("is_encoder_decoder", False):
+            LoaderClass = AutoModelForSeq2SeqLM
+            shared.is_seq2seq = True
+        else:
+            LoaderClass = AutoModelForCausalLM
 
     # Load the model in simple 16-bit mode by default
     if not any([shared.args.cpu, shared.args.load_in_8bit, shared.args.load_in_4bit, shared.args.auto_devices, shared.args.disk, shared.args.deepspeed, shared.args.gpu_memory is not None, shared.args.cpu_memory is not None]):
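`huggingface_loader` now chooses between causal-LM and seq2seq loading from the model's own config instead of a name heuristic, using the standard `is_encoder_decoder` field that transformers stores in every model config. For example (the model path below is a placeholder; any local directory or Hub id works the same way):

```python
from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForSeq2SeqLM

path_to_model = "google/flan-t5-small"  # placeholder model id

config = AutoConfig.from_pretrained(path_to_model)

# Encoder-decoder architectures (T5, BART, ...) set is_encoder_decoder=True;
# decoder-only models (GPT-style, LLaMA) leave it False.
if config.to_dict().get("is_encoder_decoder", False):
    LoaderClass = AutoModelForSeq2SeqLM
else:
    LoaderClass = AutoModelForCausalLM

model = LoaderClass.from_pretrained(path_to_model)
```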