Don't require llama.cpp models to be placed in subfolders
This commit is contained in:
parent
06b6ff6c2e
commit
fcb594b90e
4 changed files with 41 additions and 39 deletions
|
@ -38,13 +38,30 @@ if shared.args.deepspeed:
|
|||
dschf = HfDeepSpeedConfig(ds_config) # Keep this object alive for the Transformers integration
|
||||
|
||||
|
||||
def find_model_type(model_name):
|
||||
model_name = model_name.lower()
|
||||
if 'rwkv-' in model_name.lower():
|
||||
return 'rwkv'
|
||||
elif len(list(Path(f'{shared.args.model_dir}/{model_name}').glob('*ggml*.bin'))) > 0:
|
||||
return 'llamacpp'
|
||||
elif re.match('.*ggml.*\.bin', model_name):
|
||||
return 'llamacpp'
|
||||
elif 'chatglm' in model_name:
|
||||
return 'chatglm'
|
||||
elif 'galactica' in model_name:
|
||||
return 'galactica'
|
||||
elif any((k in model_name for k in ['gpt4chan', 'gpt-4chan'])):
|
||||
return 'gpt4chan'
|
||||
else:
|
||||
return 'HF_generic'
|
||||
|
||||
|
||||
def load_model(model_name):
|
||||
print(f"Loading {model_name}...")
|
||||
t0 = time.time()
|
||||
|
||||
shared.is_RWKV = 'rwkv-' in model_name.lower()
|
||||
shared.is_llamacpp = len(list(Path(f'{shared.args.model_dir}/{model_name}').glob('ggml*.bin'))) > 0
|
||||
if 'chatglm' in model_name.lower():
|
||||
shared.model_type = find_model_type(model_name)
|
||||
if shared.model_type == 'chatglm':
|
||||
LoaderClass = AutoModel
|
||||
trust_remote_code = shared.args.trust_remote_code
|
||||
else:
|
||||
|
@ -52,7 +69,7 @@ def load_model(model_name):
|
|||
trust_remote_code = False
|
||||
|
||||
# Load the model in simple 16-bit mode by default
|
||||
if not any([shared.args.cpu, shared.args.load_in_8bit, shared.args.wbits, shared.args.auto_devices, shared.args.disk, shared.args.gpu_memory is not None, shared.args.cpu_memory is not None, shared.args.deepspeed, shared.args.flexgen, shared.is_RWKV, shared.is_llamacpp]):
|
||||
if not any([shared.args.cpu, shared.args.load_in_8bit, shared.args.wbits, shared.args.auto_devices, shared.args.disk, shared.args.gpu_memory is not None, shared.args.cpu_memory is not None, shared.args.deepspeed, shared.args.flexgen, shared.model_type in ['rwkv', 'llamacpp']]):
|
||||
model = LoaderClass.from_pretrained(Path(f"{shared.args.model_dir}/{model_name}"), low_cpu_mem_usage=True, torch_dtype=torch.bfloat16 if shared.args.bf16 else torch.float16, trust_remote_code=trust_remote_code)
|
||||
if torch.has_mps:
|
||||
device = torch.device('mps')
|
||||
|
@ -91,7 +108,7 @@ def load_model(model_name):
|
|||
print(f"DeepSpeed ZeRO-3 is enabled: {is_deepspeed_zero3_enabled()}")
|
||||
|
||||
# RMKV model (not on HuggingFace)
|
||||
elif shared.is_RWKV:
|
||||
elif shared.model_type == 'rwkv':
|
||||
from modules.RWKV import RWKVModel, RWKVTokenizer
|
||||
|
||||
model = RWKVModel.from_pretrained(Path(f'{shared.args.model_dir}/{model_name}'), dtype="fp32" if shared.args.cpu else "bf16" if shared.args.bf16 else "fp16", device="cpu" if shared.args.cpu else "cuda")
|
||||
|
@ -100,12 +117,16 @@ def load_model(model_name):
|
|||
return model, tokenizer
|
||||
|
||||
# llamacpp model
|
||||
elif shared.is_llamacpp:
|
||||
elif shared.model_type == 'llamacpp':
|
||||
from modules.llamacpp_model_alternative import LlamaCppModel
|
||||
|
||||
model_file = list(Path(f'{shared.args.model_dir}/{model_name}').glob('ggml*.bin'))[0]
|
||||
print(f"llama.cpp weights detected: {model_file}\n")
|
||||
path = Path(f'{shared.args.model_dir}/{model_name}')
|
||||
if path.is_file():
|
||||
model_file = path
|
||||
else:
|
||||
model_file = list(Path(f'{shared.args.model_dir}/{model_name}').glob('*ggml*.bin'))[0]
|
||||
|
||||
print(f"llama.cpp weights detected: {model_file}\n")
|
||||
model, tokenizer = LlamaCppModel.from_pretrained(model_file)
|
||||
return model, tokenizer
|
||||
|
||||
|
@ -190,7 +211,7 @@ def load_model(model_name):
|
|||
llama_attn_hijack.hijack_llama_attention()
|
||||
|
||||
# Loading the tokenizer
|
||||
if any((k in model_name.lower() for k in ['gpt4chan', 'gpt-4chan'])) and Path(f"{shared.args.model_dir}/gpt-j-6B/").exists():
|
||||
if shared.model_type == 'gpt4chan' and Path(f"{shared.args.model_dir}/gpt-j-6B/").exists():
|
||||
tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/gpt-j-6B/"))
|
||||
elif type(model) is transformers.LlamaForCausalLM:
|
||||
tokenizer = None
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue