Add LLaMA 8-bit support

This commit is contained in:
oobabooga 2023-03-04 13:28:42 -03:00
parent c93f1fa99b
commit bd8aac8fa4
2 changed files with 137 additions and 4 deletions

View file

@@ -88,12 +88,20 @@ def load_model(model_name):
# LLaMA model (not on HuggingFace)
elif shared.is_LLaMA:
import modules.LLaMA
from modules.LLaMA import LLaMAModel
if shared.args.load_in_8bit:
import modules.LLaMA_8bit
from modules.LLaMA_8bit import LLaMAModel_8bit
model = LLaMAModel.from_pretrained(Path(f'models/{model_name}'))
model = LLaMAModel_8bit.from_pretrained(Path(f'models/{model_name}'))
return model, None
return model, None
else:
import modules.LLaMA
from modules.LLaMA import LLaMAModel
model = LLaMAModel.from_pretrained(Path(f'models/{model_name}'))
return model, None
# Custom
else: