Add various checks to model loading functions
This commit is contained in:
parent
abd361b3a0
commit
ef10ffc6b4
2 changed files with 28 additions and 19 deletions
|
@ -40,10 +40,14 @@ if shared.args.deepspeed:
|
|||
# Some models require special treatment in various parts of the code.
|
||||
# This function detects those models
|
||||
def find_model_type(model_name):
|
||||
path_to_model = Path(f'{shared.args.model_dir}/{model_name}')
|
||||
if not path_to_model.exists():
|
||||
return 'None'
|
||||
|
||||
model_name_lower = model_name.lower()
|
||||
if 'rwkv-' in model_name_lower:
|
||||
return 'rwkv'
|
||||
elif len(list(Path(f'{shared.args.model_dir}/{model_name}').glob('*ggml*.bin'))) > 0:
|
||||
elif len(list(path_to_model.glob('*ggml*.bin'))) > 0:
|
||||
return 'llamacpp'
|
||||
elif re.match('.*ggml.*\.bin', model_name_lower):
|
||||
return 'llamacpp'
|
||||
|
@ -58,7 +62,7 @@ def find_model_type(model_name):
|
|||
elif any((k in model_name_lower for k in ['gpt4chan', 'gpt-4chan'])):
|
||||
return 'gpt4chan'
|
||||
else:
|
||||
config = AutoConfig.from_pretrained(Path(f'{shared.args.model_dir}/{model_name}'), trust_remote_code=shared.args.trust_remote_code)
|
||||
config = AutoConfig.from_pretrained(path_to_model, trust_remote_code=shared.args.trust_remote_code)
|
||||
# Not a "catch all", but fairly accurate
|
||||
if config.to_dict().get("is_encoder_decoder", False):
|
||||
return 'HF_seq2seq'
|
||||
|
@ -71,11 +75,14 @@ def load_model(model_name):
|
|||
t0 = time.time()
|
||||
|
||||
shared.model_type = find_model_type(model_name)
|
||||
if shared.args.wbits > 0 or shared.args.autogptq:
|
||||
if shared.args.autogptq:
|
||||
load_func = AutoGPTQ_loader
|
||||
else:
|
||||
load_func = GPTQ_loader
|
||||
if shared.model_type == 'None':
|
||||
logging.error('The path to the model does not exist. Exiting.')
|
||||
return None, None
|
||||
|
||||
if shared.args.autogptq:
|
||||
load_func = AutoGPTQ_loader
|
||||
elif shared.args.wbits > 0:
|
||||
load_func = GPTQ_loader
|
||||
elif shared.model_type == 'llamacpp':
|
||||
load_func = llamacpp_loader
|
||||
elif shared.model_type == 'rwkv':
|
||||
|
@ -101,6 +108,7 @@ def load_model(model_name):
|
|||
|
||||
|
||||
def load_tokenizer(model_name, model):
|
||||
tokenizer = None
|
||||
if shared.model_type == 'gpt4chan' and Path(f"{shared.args.model_dir}/gpt-j-6B/").exists():
|
||||
tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/gpt-j-6B/"))
|
||||
elif type(model) is transformers.LlamaForCausalLM:
|
||||
|
@ -122,7 +130,9 @@ def load_tokenizer(model_name, model):
|
|||
except:
|
||||
pass
|
||||
else:
|
||||
tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{model_name}/"), trust_remote_code=shared.args.trust_remote_code)
|
||||
path_to_model = Path(f"{shared.args.model_dir}/{model_name}/")
|
||||
if path_to_model.exists():
|
||||
tokenizer = AutoTokenizer.from_pretrained(path_to_model, trust_remote_code=shared.args.trust_remote_code)
|
||||
|
||||
return tokenizer
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue