Add LLaMA support

This commit is contained in:
oobabooga 2023-03-03 14:39:14 -03:00
parent 2bff646130
commit ea5c5eb3da
4 changed files with 110 additions and 2 deletions

96
modules/LLaMA.py Normal file
View file

@ -0,0 +1,96 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the GNU General Public License version 3.
import json
import os
import sys
import time
from pathlib import Path
from typing import Tuple
import fire
import torch
from fairscale.nn.model_parallel.initialize import initialize_model_parallel
from llama import LLaMA, ModelArgs, Tokenizer, Transformer
os.environ['RANK'] = '0'
os.environ['WORLD_SIZE'] = '1'
os.environ['MP'] = '1'
os.environ['MASTER_ADDR'] = '127.0.0.1'
os.environ['MASTER_PORT'] = '2223'
def setup_model_parallel() -> Tuple[int, int]:
local_rank = int(os.environ.get("LOCAL_RANK", -1))
world_size = int(os.environ.get("WORLD_SIZE", -1))
torch.distributed.init_process_group("gloo")
initialize_model_parallel(world_size)
torch.cuda.set_device(local_rank)
# seed must be the same in all processes
torch.manual_seed(1)
return local_rank, world_size
def load(
ckpt_dir: str,
tokenizer_path: str,
local_rank: int,
world_size: int,
max_seq_len: int,
max_batch_size: int,
) -> LLaMA:
start_time = time.time()
checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
assert world_size == len(
checkpoints
), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {world_size}"
ckpt_path = checkpoints[local_rank]
print("Loading")
checkpoint = torch.load(ckpt_path, map_location="cpu")
with open(Path(ckpt_dir) / "params.json", "r") as f:
params = json.loads(f.read())
model_args: ModelArgs = ModelArgs(
max_seq_len=max_seq_len, max_batch_size=max_batch_size, **params
)
tokenizer = Tokenizer(model_path=tokenizer_path)
model_args.vocab_size = tokenizer.n_words
torch.set_default_tensor_type(torch.cuda.HalfTensor)
model = Transformer(model_args)
torch.set_default_tensor_type(torch.FloatTensor)
model.load_state_dict(checkpoint, strict=False)
generator = LLaMA(model, tokenizer)
print(f"Loaded in {time.time() - start_time:.2f} seconds")
return generator
class LLaMAModel:
def __init__(self):
pass
@classmethod
def from_pretrained(self, path, max_seq_len=512, max_batch_size=32):
tokenizer_path = path / "tokenizer.model"
path = os.path.abspath(path)
tokenizer_path = os.path.abspath(tokenizer_path)
local_rank, world_size = setup_model_parallel()
if local_rank > 0:
sys.stdout = open(os.devnull, "w")
generator = load(
path, tokenizer_path, local_rank, world_size, max_seq_len, max_batch_size
)
result = self()
result.pipeline = generator
return result
def generate(self, prompt, token_count=512, temperature=0.8, top_p=0.95):
results = self.pipeline.generate(
[prompt], max_gen_len=token_count, temperature=temperature, top_p=top_p
)
return results[0]