Move towards HF LLaMA implementation

This commit is contained in:
oobabooga 2023-03-05 01:20:31 -03:00
parent bd8aac8fa4
commit c33715ad5b
6 changed files with 4 additions and 245 deletions

View file

@ -39,10 +39,9 @@ def load_model(model_name):
t0 = time.time()
shared.is_RWKV = model_name.lower().startswith('rwkv-')
shared.is_LLaMA = model_name.lower().startswith('llama-')
# Default settings
if not (shared.args.cpu or shared.args.load_in_8bit or shared.args.auto_devices or shared.args.disk or shared.args.gpu_memory is not None or shared.args.cpu_memory is not None or shared.args.deepspeed or shared.args.flexgen or shared.is_RWKV or shared.is_LLaMA):
if not (shared.args.cpu or shared.args.load_in_8bit or shared.args.auto_devices or shared.args.disk or shared.args.gpu_memory is not None or shared.args.cpu_memory is not None or shared.args.deepspeed or shared.args.flexgen or shared.is_RWKV):
if any(size in shared.model_name.lower() for size in ('13b', '20b', '30b')):
model = AutoModelForCausalLM.from_pretrained(Path(f"models/{shared.model_name}"), device_map='auto', load_in_8bit=True)
else:
@ -86,23 +85,6 @@ def load_model(model_name):
return model, None
# LLaMA model (not on HuggingFace)
elif shared.is_LLaMA:
if shared.args.load_in_8bit:
import modules.LLaMA_8bit
from modules.LLaMA_8bit import LLaMAModel_8bit
model = LLaMAModel_8bit.from_pretrained(Path(f'models/{model_name}'))
return model, None
else:
import modules.LLaMA
from modules.LLaMA import LLaMAModel
model = LLaMAModel.from_pretrained(Path(f'models/{model_name}'))
return model, None
# Custom
else:
command = "AutoModelForCausalLM.from_pretrained"