Support LLaVA v1.5 (#4305)

Haotian Liu 2023-10-20 00:28:14 -05:00 committed by GitHub
parent bb71272903
commit 32984ea2f0
6 changed files with 111 additions and 18 deletions

extensions/multimodal/pipelines/llava/llava.py

@@ -13,6 +13,20 @@ from modules.logging_colors import logger
 from modules.text_generation import encode
 
 
+def expand2square(pil_img: Image.Image, background_color: Tuple[int]) -> Image.Image:
+    width, height = pil_img.size
+    if width == height:
+        return pil_img
+    elif width > height:
+        result = Image.new(pil_img.mode, (width, width), background_color)
+        result.paste(pil_img, (0, (width - height) // 2))
+        return result
+    else:
+        result = Image.new(pil_img.mode, (height, height), background_color)
+        result.paste(pil_img, ((height - width) // 2, 0))
+        return result
+
+
 class LLaVA_v0_Pipeline(AbstractMultimodalPipeline):
     CLIP_REPO = "openai/clip-vit-large-patch14"
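
Note: `expand2square` pads a rectangular image onto a square canvas of the given fill color, centering it along the longer edge, so CLIP's square preprocessing does not crop away content. A minimal sketch of how it behaves, assuming `expand2square` from the hunk above is in scope (the image size and fill color are illustrative, not taken from the diff):

# Illustrative only: pad an 800x600 RGB image to 800x800 with a solid fill.
from PIL import Image

img = Image.new("RGB", (800, 600), (255, 0, 0))   # stand-in for a user-supplied image
padded = expand2square(img, (122, 116, 104))      # roughly CLIP's image_mean scaled to 0-255
assert padded.size == (800, 800)                  # original pasted 100 px down, centered vertically
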
@@ -27,21 +41,33 @@ class LLaVA_v0_Pipeline(AbstractMultimodalPipeline):
     def _load_models(self):
         start_ts = time.time()
 
-        logger.info(f"LLaVA - Loading CLIP from {LLaVA_v0_Pipeline.CLIP_REPO} as {self.clip_dtype} on {self.clip_device}...")
-        image_processor = CLIPImageProcessor.from_pretrained(LLaVA_v0_Pipeline.CLIP_REPO, torch_dtype=self.clip_dtype)
-        vision_tower = CLIPVisionModel.from_pretrained(LLaVA_v0_Pipeline.CLIP_REPO, torch_dtype=self.clip_dtype).to(self.clip_device)
+        logger.info(f"LLaVA - Loading CLIP from {self.CLIP_REPO} as {self.clip_dtype} on {self.clip_device}...")
+        image_processor = CLIPImageProcessor.from_pretrained(self.CLIP_REPO, torch_dtype=self.clip_dtype)
+        vision_tower = CLIPVisionModel.from_pretrained(self.CLIP_REPO, torch_dtype=self.clip_dtype).to(self.clip_device)
 
         logger.info(f"LLaVA - Loading projector from {self.llava_projector_repo()} as {self.projector_dtype} on {self.projector_device}...")
         projector_path = hf_hub_download(self.llava_projector_repo(), self.llava_projector_filename())
-        mm_projector = torch.nn.Linear(*self.llava_projector_shape())
+        mm_projector = self.build_mm_projector()
         projector_data = torch.load(projector_path)
-        mm_projector.weight = torch.nn.Parameter(projector_data['model.mm_projector.weight'].to(dtype=self.projector_dtype), False)
-        mm_projector.bias = torch.nn.Parameter(projector_data['model.mm_projector.bias'].to(dtype=self.projector_dtype), False)
+        projector_data = {k[19:]: v for k, v in projector_data.items() if k.startswith('model.mm_projector.')}
+        mm_projector.load_state_dict(projector_data)
         mm_projector = mm_projector.to(self.projector_device)
 
         logger.info(f"LLaVA supporting models loaded, took {time.time() - start_ts:.2f} seconds")
         return image_processor, vision_tower, mm_projector
 
+    def build_mm_projector(self) -> torch.nn.Module:
+        projector_shape = self.llava_projector_shape()
+        if len(projector_shape) == 2:
+            return torch.nn.Linear(*projector_shape)
+        else:
+            modules = []
+            modules.append(torch.nn.Linear(projector_shape[0], projector_shape[1]))
+            for i in range(2, len(projector_shape)):
+                modules.append(torch.nn.GELU())
+                modules.append(torch.nn.Linear(projector_shape[i-1], projector_shape[i]))
+            return torch.nn.Sequential(*modules)
+
     @staticmethod
     def image_start() -> str:
         return "<im_start>"
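
Note: the projector is no longer hard-coded as a single `torch.nn.Linear`. `build_mm_projector` reads `llava_projector_shape()` and, when it has more than two entries, stacks Linear layers with GELU activations in between, which matches the two-layer MLP projector used by LLaVA v1.5. Loading also switches to `load_state_dict`, after stripping the 19-character `'model.mm_projector.'` prefix so the checkpoint keys line up with the `nn.Sequential` module's own keys. A rough sketch of what the v1.5 shape expands to (the (1024, 5120, 5120) value comes from the diff; the rest is illustrative):

import torch

# (1024, 5120, 5120) -> Linear(1024, 5120) -> GELU -> Linear(5120, 5120)
projector = torch.nn.Sequential(
    torch.nn.Linear(1024, 5120),
    torch.nn.GELU(),
    torch.nn.Linear(5120, 5120),
)

# 576 CLIP patch features per image go in, 576 language-model-sized (5120-dim) embeddings come out.
features = torch.randn(1, 576, 1024)
print(projector(features).shape)  # torch.Size([1, 576, 5120])
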
@@ -175,3 +201,50 @@ class LLaVA_LLaMA_2_13B_Pipeline(LLaVA_v0_13B_Pipeline):
     @staticmethod
     def placeholder_embeddings() -> torch.Tensor:
         return LLaVA_v0_Pipeline.embed_tokens(encode("<unk>"*256, add_bos_token=False)[0])
+
+
+class LLaVA_v1_5_13B_Pipeline(LLaVA_v0_13B_Pipeline):
+    CLIP_REPO = "openai/clip-vit-large-patch14-336"
+
+    def __init__(self, params: dict) -> None:
+        super().__init__(params)
+
+    @staticmethod
+    def name() -> str:
+        return "llava-v1.5-13b"
+
+    @staticmethod
+    def llava_projector_shape() -> Tuple[int, int]:
+        return (1024, 5120, 5120)
+
+    @staticmethod
+    def placeholder_token_id() -> int:
+        return 0
+
+    @staticmethod
+    def llava_projector_repo() -> str:
+        return "liuhaotian/llava-v1.5-13b"
+
+    @staticmethod
+    def image_start() -> str:
+        return ""
+
+    @staticmethod
+    def image_end() -> str:
+        return ""
+
+    @staticmethod
+    def num_image_embeds() -> int:
+        return 576
+
+    def embed_images(self, images: List[Image.Image]) -> torch.Tensor:
+        # pad it to square first
+        images = [
+            expand2square(image, tuple(int(x*255) for x in self.image_processor.image_mean))
+            for image in images
+        ]
+        return super().embed_images(images)
+
+    @staticmethod
+    def placeholder_embeddings() -> torch.Tensor:
+        return LLaVA_v0_Pipeline.embed_tokens(encode("<unk>"*576, add_bos_token=False)[0])
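
Note: the constants in the new class follow from switching to the 336 px CLIP tower: `clip-vit-large-patch14-336` splits a 336x336 input into (336 / 14)^2 = 24^2 = 576 patches, hence `num_image_embeds()` == 576 and a placeholder of 576 `<unk>` tokens, and `llava_projector_shape()` now returns a 3-tuple describing the MLP projector built above (the `Tuple[int, int]` annotation is a slight understatement). `embed_images` pads every image to a square filled with CLIP's mean color before delegating to the base pipeline. A small sanity-check sketch (illustrative only):

# Patch-count arithmetic behind num_image_embeds() for the 336 px tower.
image_size, patch_size = 336, 14
print((image_size // patch_size) ** 2)  # 24 * 24 = 576 image embeddings per image

# Fill color used by embed_images: CLIP's normalization mean scaled back to 0-255.
clip_image_mean = (0.48145466, 0.4578275, 0.40821073)  # values from the CLIP preprocessor config
print(tuple(int(x * 255) for x in clip_image_mean))    # (122, 116, 104)
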

extensions/multimodal/pipelines/llava/pipelines.py

@@ -2,7 +2,7 @@ from typing import Optional
 
 from extensions.multimodal.abstract_pipeline import AbstractMultimodalPipeline
 
-available_pipelines = ['llava-7b', 'llava-13b', 'llava-llama-2-13b']
+available_pipelines = ['llava-7b', 'llava-13b', 'llava-llama-2-13b', 'llava-v1.5-13b']
 
 
 def get_pipeline(name: str, params: dict) -> Optional[AbstractMultimodalPipeline]:
@@ -15,6 +15,9 @@ def get_pipeline(name: str, params: dict) -> Optional[AbstractMultimodalPipeline]:
     if name == 'llava-llama-2-13b':
         from .llava import LLaVA_LLaMA_2_13B_Pipeline
         return LLaVA_LLaMA_2_13B_Pipeline(params)
+    if name == 'llava-v1.5-13b':
+        from .llava import LLaVA_v1_5_13B_Pipeline
+        return LLaVA_v1_5_13B_Pipeline(params)
     return None
@@ -25,10 +28,15 @@ def get_pipeline_from_model_name(model_name: str, params: dict) -> Optional[AbstractMultimodalPipeline]:
         if '13b' in model_name.lower():
             from .llava import LLaVA_LLaMA_2_13B_Pipeline
             return LLaVA_LLaMA_2_13B_Pipeline(params)
-    if '7b' in model_name.lower():
-        from .llava import LLaVA_v0_7B_Pipeline
-        return LLaVA_v0_7B_Pipeline(params)
-    if '13b' in model_name.lower():
-        from .llava import LLaVA_v0_13B_Pipeline
-        return LLaVA_v0_13B_Pipeline(params)
+    elif 'llava-v1.5' in model_name.lower():
+        if '13b' in model_name.lower():
+            from .llava import LLaVA_v1_5_13B_Pipeline
+            return LLaVA_v1_5_13B_Pipeline(params)
+    else:
+        if '7b' in model_name.lower():
+            from .llava import LLaVA_v0_7B_Pipeline
+            return LLaVA_v0_7B_Pipeline(params)
+        if '13b' in model_name.lower():
+            from .llava import LLaVA_v0_13B_Pipeline
+            return LLaVA_v0_13B_Pipeline(params)
     return None
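
Note: with this dispatch, a model name containing both `llava-v1.5` and `13b` is routed to the new pipeline, `llava-llama-2` names keep their dedicated branch, and any other `llava` name falls back to the v0 pipelines; there is no v1.5 7B branch yet. A rough usage sketch of the two selectors (the function names come from the diff; the import path, model-name strings, and params keys are assumptions for illustration, and constructing a pipeline downloads and loads the CLIP tower and projector):

# Assumed import path for this file within text-generation-webui.
from extensions.multimodal.pipelines.llava.pipelines import (
    get_pipeline,
    get_pipeline_from_model_name,
)

# Placeholder settings; the multimodal extension normally supplies the device/precision
# options the pipelines read from this dict (key names below are assumptions).
params = {
    "vision_device": None,
    "vision_bits": 32,
    "projector_device": None,
    "projector_bits": 32,
}

# Model-name based routing.
pipeline = get_pipeline_from_model_name("llava-v1.5-13b-gptq", params)   # -> LLaVA_v1_5_13B_Pipeline
fallback = get_pipeline_from_model_name("llava-13b-4bit", params)        # -> LLaVA_v0_13B_Pipeline

# Explicit selection by pipeline name.
explicit = get_pipeline("llava-v1.5-13b", params)
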