Add LLaMA support
This commit is contained in:
parent
2bff646130
commit
ea5c5eb3da
4 changed files with 110 additions and 2 deletions
|
@ -39,9 +39,10 @@ def load_model(model_name):
|
|||
t0 = time.time()
|
||||
|
||||
shared.is_RWKV = model_name.lower().startswith('rwkv-')
|
||||
shared.is_LLaMA = model_name.lower().startswith('llama-')
|
||||
|
||||
# Default settings
|
||||
if not (shared.args.cpu or shared.args.load_in_8bit or shared.args.auto_devices or shared.args.disk or shared.args.gpu_memory is not None or shared.args.cpu_memory is not None or shared.args.deepspeed or shared.args.flexgen or shared.is_RWKV):
|
||||
if not (shared.args.cpu or shared.args.load_in_8bit or shared.args.auto_devices or shared.args.disk or shared.args.gpu_memory is not None or shared.args.cpu_memory is not None or shared.args.deepspeed or shared.args.flexgen or shared.is_RWKV or shared.is_LLaMA):
|
||||
if any(size in shared.model_name.lower() for size in ('13b', '20b', '30b')):
|
||||
model = AutoModelForCausalLM.from_pretrained(Path(f"models/{shared.model_name}"), device_map='auto', load_in_8bit=True)
|
||||
else:
|
||||
|
@ -85,6 +86,15 @@ def load_model(model_name):
|
|||
|
||||
return model, None
|
||||
|
||||
# LLaMA model (not on HuggingFace)
|
||||
elif shared.is_LLaMA:
|
||||
import modules.LLaMA
|
||||
from modules.LLaMA import LLaMAModel
|
||||
|
||||
model = LLaMAModel.from_pretrained(Path(f'models/{model_name}'))
|
||||
|
||||
return model, None
|
||||
|
||||
# Custom
|
||||
else:
|
||||
command = "AutoModelForCausalLM.from_pretrained"
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue