Add LLaMA 8-bit support
This commit is contained in:
parent
c93f1fa99b
commit
bd8aac8fa4
2 changed files with 137 additions and 4 deletions
|
@@ -88,12 +88,20 @@ def load_model(model_name):
|
|||
|
||||
# LLaMA model (not on HuggingFace)
|
||||
elif shared.is_LLaMA:
|
||||
import modules.LLaMA
|
||||
from modules.LLaMA import LLaMAModel
|
||||
if shared.args.load_in_8bit:
|
||||
import modules.LLaMA_8bit
|
||||
from modules.LLaMA_8bit import LLaMAModel_8bit
|
||||
|
||||
model = LLaMAModel.from_pretrained(Path(f'models/{model_name}'))
|
||||
model = LLaMAModel_8bit.from_pretrained(Path(f'models/{model_name}'))
|
||||
|
||||
return model, None
|
||||
return model, None
|
||||
else:
|
||||
import modules.LLaMA
|
||||
from modules.LLaMA import LLaMAModel
|
||||
|
||||
model = LLaMAModel.from_pretrained(Path(f'models/{model_name}'))
|
||||
|
||||
return model, None
|
||||
|
||||
# Custom
|
||||
else:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue