Support LLaVA v1.5 (#4305)
parent bb71272903
commit 32984ea2f0
6 changed files with 111 additions and 18 deletions
extensions/multimodal/pipelines/llava/llava.py

@@ -13,6 +13,20 @@ from modules.logging_colors import logger
 from modules.text_generation import encode
 
 
+def expand2square(pil_img: Image.Image, background_color: Tuple[int]) -> Image.Image:
+    width, height = pil_img.size
+    if width == height:
+        return pil_img
+    elif width > height:
+        result = Image.new(pil_img.mode, (width, width), background_color)
+        result.paste(pil_img, (0, (width - height) // 2))
+        return result
+    else:
+        result = Image.new(pil_img.mode, (height, height), background_color)
+        result.paste(pil_img, ((height - width) // 2, 0))
+        return result
+
+
 class LLaVA_v0_Pipeline(AbstractMultimodalPipeline):
     CLIP_REPO = "openai/clip-vit-large-patch14"
 
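
The new expand2square helper pads a non-square image onto a square canvas of the given background color, centering the original along its shorter axis. A minimal sketch of the behavior (illustrative, not part of the commit; assumes Pillow and the helper above are in scope):

from PIL import Image

# A 200x100 landscape image padded with mid-gray becomes 200x200,
# with the original pasted at y = (200 - 100) // 2 = 50.
img = Image.new("RGB", (200, 100), (255, 0, 0))
padded = expand2square(img, (127, 127, 127))
assert padded.size == (200, 200)
assert padded.getpixel((0, 0)) == (127, 127, 127)    # corner: background fill
assert padded.getpixel((100, 100)) == (255, 0, 0)    # center: original pixels
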
@@ -27,21 +41,33 @@ class LLaVA_v0_Pipeline(AbstractMultimodalPipeline):
     def _load_models(self):
         start_ts = time.time()
 
-        logger.info(f"LLaVA - Loading CLIP from {LLaVA_v0_Pipeline.CLIP_REPO} as {self.clip_dtype} on {self.clip_device}...")
-        image_processor = CLIPImageProcessor.from_pretrained(LLaVA_v0_Pipeline.CLIP_REPO, torch_dtype=self.clip_dtype)
-        vision_tower = CLIPVisionModel.from_pretrained(LLaVA_v0_Pipeline.CLIP_REPO, torch_dtype=self.clip_dtype).to(self.clip_device)
+        logger.info(f"LLaVA - Loading CLIP from {self.CLIP_REPO} as {self.clip_dtype} on {self.clip_device}...")
+        image_processor = CLIPImageProcessor.from_pretrained(self.CLIP_REPO, torch_dtype=self.clip_dtype)
+        vision_tower = CLIPVisionModel.from_pretrained(self.CLIP_REPO, torch_dtype=self.clip_dtype).to(self.clip_device)
 
         logger.info(f"LLaVA - Loading projector from {self.llava_projector_repo()} as {self.projector_dtype} on {self.projector_device}...")
         projector_path = hf_hub_download(self.llava_projector_repo(), self.llava_projector_filename())
-        mm_projector = torch.nn.Linear(*self.llava_projector_shape())
+        mm_projector = self.build_mm_projector()
         projector_data = torch.load(projector_path)
-        mm_projector.weight = torch.nn.Parameter(projector_data['model.mm_projector.weight'].to(dtype=self.projector_dtype), False)
-        mm_projector.bias = torch.nn.Parameter(projector_data['model.mm_projector.bias'].to(dtype=self.projector_dtype), False)
+        projector_data = {k[19:]: v for k, v in projector_data.items() if k.startswith('model.mm_projector.')}
+        mm_projector.load_state_dict(projector_data)
         mm_projector = mm_projector.to(self.projector_device)
 
         logger.info(f"LLaVA supporting models loaded, took {time.time() - start_ts:.2f} seconds")
         return image_processor, vision_tower, mm_projector
 
+    def build_mm_projector(self) -> torch.nn.Module:
+        projector_shape = self.llava_projector_shape()
+        if len(projector_shape) == 2:
+            return torch.nn.Linear(*projector_shape)
+        else:
+            modules = []
+            modules.append(torch.nn.Linear(projector_shape[0], projector_shape[1]))
+            for i in range(2, len(projector_shape)):
+                modules.append(torch.nn.GELU())
+                modules.append(torch.nn.Linear(projector_shape[i-1], projector_shape[i]))
+            return torch.nn.Sequential(*modules)
+
     @staticmethod
     def image_start() -> str:
         return "<im_start>"
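
The refactored loader builds the projector from its shape tuple: two dimensions still mean a single Linear layer (LLaVA v0), while longer tuples produce a Linear/GELU/.../Linear MLP as used by LLaVA v1.5. The k[19:] slice strips the 'model.mm_projector.' prefix (19 characters) so the checkpoint keys line up with the module's own state_dict. A rough sketch of both points (illustrative, not part of the commit):

import torch

# v1.5 shape (1024, 5120, 5120) -> Linear(1024, 5120), GELU, Linear(5120, 5120)
proj = torch.nn.Sequential(
    torch.nn.Linear(1024, 5120),
    torch.nn.GELU(),
    torch.nn.Linear(5120, 5120),
)

# Checkpoint keys carry a 'model.mm_projector.' prefix; len(prefix) == 19,
# so k[19:] maps e.g. 'model.mm_projector.0.weight' -> '0.weight', which is
# exactly what proj.state_dict() expects.
prefix = 'model.mm_projector.'
assert len(prefix) == 19
assert 'model.mm_projector.0.weight'[len(prefix):] in proj.state_dict()
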
@@ -175,3 +201,50 @@ class LLaVA_LLaMA_2_13B_Pipeline(LLaVA_v0_13B_Pipeline):
     @staticmethod
     def placeholder_embeddings() -> torch.Tensor:
         return LLaVA_v0_Pipeline.embed_tokens(encode("<unk>"*256, add_bos_token=False)[0])
+
+
+class LLaVA_v1_5_13B_Pipeline(LLaVA_v0_13B_Pipeline):
+    CLIP_REPO = "openai/clip-vit-large-patch14-336"
+
+    def __init__(self, params: dict) -> None:
+        super().__init__(params)
+
+    @staticmethod
+    def name() -> str:
+        return "llava-v1.5-13b"
+
+    @staticmethod
+    def llava_projector_shape() -> Tuple[int, ...]:
+        return (1024, 5120, 5120)
+
+    @staticmethod
+    def placeholder_token_id() -> int:
+        return 0
+
+    @staticmethod
+    def llava_projector_repo() -> str:
+        return "liuhaotian/llava-v1.5-13b"
+
+    @staticmethod
+    def image_start() -> str:
+        return ""
+
+    @staticmethod
+    def image_end() -> str:
+        return ""
+
+    @staticmethod
+    def num_image_embeds() -> int:
+        return 576
+
+    def embed_images(self, images: List[Image.Image]) -> torch.Tensor:
+        # pad it to square first
+        images = [
+            expand2square(image, tuple(int(x*255) for x in self.image_processor.image_mean))
+            for image in images
+        ]
+        return super().embed_images(images)
+
+    @staticmethod
+    def placeholder_embeddings() -> torch.Tensor:
+        return LLaVA_v0_Pipeline.embed_tokens(encode("<unk>"*576, add_bos_token=False)[0])
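
The v1.5 constants follow from the larger CLIP checkpoint: openai/clip-vit-large-patch14-336 takes 336x336 input with 14x14 patches, giving (336/14)^2 = 576 patch embeddings, hence num_image_embeds() and the 576-token placeholder. The padding color passed to expand2square is CLIP's normalization mean rescaled to 0-255, so padded borders normalize to roughly zero. Quick checks (illustrative, not part of the commit):

# 336px input, 14px patches -> 24x24 = 576 image embeddings
image_size, patch_size = 336, 14
assert (image_size // patch_size) ** 2 == 576

# CLIP's image_mean scaled to 0-255, as in embed_images' background color
image_mean = (0.48145466, 0.4578275, 0.40821073)
assert tuple(int(x * 255) for x in image_mean) == (122, 116, 104)
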
extensions/multimodal/pipelines/llava/pipelines.py

@@ -2,7 +2,7 @@ from typing import Optional
 
 from extensions.multimodal.abstract_pipeline import AbstractMultimodalPipeline
 
-available_pipelines = ['llava-7b', 'llava-13b', 'llava-llama-2-13b']
+available_pipelines = ['llava-7b', 'llava-13b', 'llava-llama-2-13b', 'llava-v1.5-13b']
 
 
 def get_pipeline(name: str, params: dict) -> Optional[AbstractMultimodalPipeline]:
@@ -15,6 +15,9 @@ def get_pipeline(name: str, params: dict) -> Optional[AbstractMultimodalPipeline]:
     if name == 'llava-llama-2-13b':
         from .llava import LLaVA_LLaMA_2_13B_Pipeline
         return LLaVA_LLaMA_2_13B_Pipeline(params)
+    if name == 'llava-v1.5-13b':
+        from .llava import LLaVA_v1_5_13B_Pipeline
+        return LLaVA_v1_5_13B_Pipeline(params)
     return None
 
 
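
get_pipeline resolves pipelines by exact name, so the new entry must be requested as the literal string 'llava-v1.5-13b'. A simplified sketch of the dispatch (hypothetical pipeline_class_for helper that returns class names instead of instantiating, since the real constructors load models):

def pipeline_class_for(name: str):
    # Mirrors get_pipeline's chain of exact-name checks.
    mapping = {
        'llava-7b': 'LLaVA_v0_7B_Pipeline',
        'llava-13b': 'LLaVA_v0_13B_Pipeline',
        'llava-llama-2-13b': 'LLaVA_LLaMA_2_13B_Pipeline',
        'llava-v1.5-13b': 'LLaVA_v1_5_13B_Pipeline',
    }
    return mapping.get(name)  # None for unknown names, like the original

assert pipeline_class_for('llava-v1.5-13b') == 'LLaVA_v1_5_13B_Pipeline'
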
@@ -25,10 +28,15 @@ def get_pipeline_from_model_name(model_name: str, params: dict) -> Optional[AbstractMultimodalPipeline]:
         if '13b' in model_name.lower():
             from .llava import LLaVA_LLaMA_2_13B_Pipeline
             return LLaVA_LLaMA_2_13B_Pipeline(params)
-    if '7b' in model_name.lower():
-        from .llava import LLaVA_v0_7B_Pipeline
-        return LLaVA_v0_7B_Pipeline(params)
-    if '13b' in model_name.lower():
-        from .llava import LLaVA_v0_13B_Pipeline
-        return LLaVA_v0_13B_Pipeline(params)
+    elif 'llava-v1.5' in model_name.lower():
+        if '13b' in model_name.lower():
+            from .llava import LLaVA_v1_5_13B_Pipeline
+            return LLaVA_v1_5_13B_Pipeline(params)
+    else:
+        if '7b' in model_name.lower():
+            from .llava import LLaVA_v0_7B_Pipeline
+            return LLaVA_v0_7B_Pipeline(params)
+        if '13b' in model_name.lower():
+            from .llava import LLaVA_v0_13B_Pipeline
+            return LLaVA_v0_13B_Pipeline(params)
     return None
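
The restructured matcher checks the most specific substrings first: 'llava-v1.5' names are claimed by the v1.5 branch before the generic '7b'/'13b' fallbacks run, so 'llava-v1.5-13b' no longer falls through to a v0 pipeline. A small sketch of that precedence (hypothetical match_model helper; the enclosing LLaMA-2 condition is not visible in this hunk and is assumed to be a 'llama-2' substring check):

def match_model(model_name: str):
    name = model_name.lower()
    if 'llama-2' in name:           # assumed outer condition (not in this hunk)
        if '13b' in name:
            return 'LLaVA_LLaMA_2_13B_Pipeline'
    elif 'llava-v1.5' in name:      # new, more specific branch wins first
        if '13b' in name:
            return 'LLaVA_v1_5_13B_Pipeline'
    else:                           # v0 fallbacks only for plain llava names
        if '7b' in name:
            return 'LLaVA_v0_7B_Pipeline'
        if '13b' in name:
            return 'LLaVA_v0_13B_Pipeline'
    return None

assert match_model('llava-v1.5-13b') == 'LLaVA_v1_5_13B_Pipeline'
assert match_model('llava-13b') == 'LLaVA_v0_13B_Pipeline'
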