Generalize multimodality (llava/minigpt4 7b and 13b now supported) (#1741)
This commit is contained in:
parent
a2b25322f0
commit
e9e75a9ec7
22 changed files with 812 additions and 371 deletions
9
extensions/multimodal/pipelines/llava/README.md
Normal file
9
extensions/multimodal/pipelines/llava/README.md
Normal file
|
@ -0,0 +1,9 @@
|
|||
## LLaVA pipeline
|
||||
|
||||
This module provides 2 pipelines:
|
||||
- `llava-7b` - for use with LLaVA v0 7B model (finetuned LLaMa 7B)
|
||||
- `llava-13b` - for use with LLaVA v0 13B model (finetuned LLaMa 13B)
|
||||
|
||||
[LLaVA](https://github.com/haotian-liu/LLaVA) uses CLIP `openai/clip-vit-large-patch14` as the vision model, and then a single linear layer. For 13B the projector weights are in `liuhaotian/LLaVA-13b-delta-v0`, and for 7B they are in `liuhaotian/LLaVA-7b-delta-v0`.
|
||||
|
||||
The supported parameter combinations for both the vision model, and the projector are: CUDA/32bit, CUDA/16bit, CPU/32bit
|
139
extensions/multimodal/pipelines/llava/llava.py
Normal file
139
extensions/multimodal/pipelines/llava/llava.py
Normal file
|
@ -0,0 +1,139 @@
|
|||
import logging
|
||||
import time
|
||||
from abc import abstractmethod
|
||||
from typing import List, Tuple
|
||||
|
||||
import torch
|
||||
from extensions.multimodal.abstract_pipeline import AbstractMultimodalPipeline
|
||||
from huggingface_hub import hf_hub_download
|
||||
from modules import shared
|
||||
from modules.text_generation import encode
|
||||
from PIL import Image
|
||||
from transformers import CLIPImageProcessor, CLIPVisionModel
|
||||
|
||||
|
||||
class LLaVA_v0_Pipeline(AbstractMultimodalPipeline):
|
||||
CLIP_REPO = "openai/clip-vit-large-patch14"
|
||||
|
||||
def __init__(self, params: dict) -> None:
|
||||
super().__init__()
|
||||
self.clip_device = self._get_device("vision_device", params)
|
||||
self.clip_dtype = self._get_dtype("vision_bits", params)
|
||||
self.projector_device = self._get_device("projector_device", params)
|
||||
self.projector_dtype = self._get_dtype("projector_bits", params)
|
||||
self.image_processor, self.vision_tower, self.mm_projector = self._load_models()
|
||||
|
||||
def _load_models(self):
|
||||
start_ts = time.time()
|
||||
|
||||
logging.info(f"LLaVA - Loading CLIP from {LLaVA_v0_Pipeline.CLIP_REPO} as {self.clip_dtype} on {self.clip_device}...")
|
||||
image_processor = CLIPImageProcessor.from_pretrained(LLaVA_v0_Pipeline.CLIP_REPO, torch_dtype=self.clip_dtype)
|
||||
vision_tower = CLIPVisionModel.from_pretrained(LLaVA_v0_Pipeline.CLIP_REPO, torch_dtype=self.clip_dtype).to(self.clip_device)
|
||||
|
||||
logging.info(f"LLaVA - Loading projector from {self.llava_projector_repo()} as {self.projector_dtype} on {self.projector_device}...")
|
||||
projector_path = hf_hub_download(self.llava_projector_repo(), self.llava_projector_filename())
|
||||
mm_projector = torch.nn.Linear(*self.llava_projector_shape())
|
||||
projector_data = torch.load(projector_path)
|
||||
mm_projector.weight = torch.nn.Parameter(projector_data['model.mm_projector.weight'].to(dtype=self.projector_dtype), False)
|
||||
mm_projector.bias = torch.nn.Parameter(projector_data['model.mm_projector.bias'].to(dtype=self.projector_dtype), False)
|
||||
mm_projector = mm_projector.to(self.projector_device)
|
||||
|
||||
logging.info(f"LLaVA supporting models loaded, took {time.time() - start_ts:.2f} seconds")
|
||||
return image_processor, vision_tower, mm_projector
|
||||
|
||||
@staticmethod
|
||||
def image_start() -> str:
|
||||
return "<im_start>"
|
||||
|
||||
@staticmethod
|
||||
def image_end() -> str:
|
||||
return "<im_end>"
|
||||
|
||||
@staticmethod
|
||||
def num_image_embeds() -> int:
|
||||
return 256
|
||||
|
||||
@staticmethod
|
||||
def embed_tokens(input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return shared.model.model.embed_tokens(input_ids).to(shared.model.device, dtype=shared.model.dtype)
|
||||
|
||||
@staticmethod
|
||||
def placeholder_embeddings() -> torch.Tensor:
|
||||
return LLaVA_v0_Pipeline.embed_tokens(encode("<im_patch>"*256, add_bos_token=False)[0])
|
||||
|
||||
def embed_images(self, images: List[Image.Image]) -> torch.Tensor:
|
||||
images = self.image_processor(images, return_tensors='pt')['pixel_values']
|
||||
images = images.to(self.clip_device, dtype=self.clip_dtype)
|
||||
|
||||
with torch.no_grad():
|
||||
image_forward_outs = self.vision_tower(images, output_hidden_states=True)
|
||||
select_hidden_state_layer = -2
|
||||
select_hidden_state = image_forward_outs.hidden_states[select_hidden_state_layer]
|
||||
image_features = select_hidden_state[:, 1:].to(self.projector_device, dtype=self.projector_dtype)
|
||||
image_features = self.mm_projector(image_features)
|
||||
return image_features.to(shared.model.device, dtype=shared.model.dtype)
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def llava_projector_repo() -> str:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def llava_projector_filename() -> str:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def llava_projector_shape() -> Tuple[int, int]:
|
||||
pass
|
||||
|
||||
|
||||
class LLaVA_v0_13B_Pipeline(LLaVA_v0_Pipeline):
|
||||
def __init__(self, params: dict) -> None:
|
||||
super().__init__(params)
|
||||
|
||||
@staticmethod
|
||||
def name() -> str:
|
||||
return "llava-13b"
|
||||
|
||||
@staticmethod
|
||||
def placeholder_token_id() -> int:
|
||||
return 32000
|
||||
|
||||
@staticmethod
|
||||
def llava_projector_shape() -> Tuple[int, int]:
|
||||
return (1024, 5120)
|
||||
|
||||
@staticmethod
|
||||
def llava_projector_filename() -> str:
|
||||
return "mm_projector.bin"
|
||||
|
||||
@staticmethod
|
||||
def llava_projector_repo() -> str:
|
||||
return "liuhaotian/LLaVA-13b-delta-v0"
|
||||
|
||||
|
||||
class LLaVA_v0_7B_Pipeline(LLaVA_v0_Pipeline):
|
||||
def __init__(self, params: dict) -> None:
|
||||
super().__init__(params)
|
||||
|
||||
@staticmethod
|
||||
def name() -> str:
|
||||
return "llava-7b"
|
||||
|
||||
@staticmethod
|
||||
def placeholder_token_id() -> int:
|
||||
return 32001
|
||||
|
||||
@staticmethod
|
||||
def llava_projector_shape() -> Tuple[int, int]:
|
||||
return (1024, 4096)
|
||||
|
||||
@staticmethod
|
||||
def llava_projector_filename() -> str:
|
||||
return "mm_projector.bin"
|
||||
|
||||
@staticmethod
|
||||
def llava_projector_repo() -> str:
|
||||
return "liuhaotian/LLaVA-7b-delta-v0"
|
27
extensions/multimodal/pipelines/llava/pipelines.py
Normal file
27
extensions/multimodal/pipelines/llava/pipelines.py
Normal file
|
@ -0,0 +1,27 @@
|
|||
from typing import Optional
|
||||
|
||||
from extensions.multimodal.abstract_pipeline import AbstractMultimodalPipeline
|
||||
|
||||
available_pipelines = ['llava-7b', 'llava-13b']
|
||||
|
||||
|
||||
def get_pipeline(name: str, params: dict) -> Optional[AbstractMultimodalPipeline]:
|
||||
if name == 'llava-7b':
|
||||
from .llava import LLaVA_v0_7B_Pipeline
|
||||
return LLaVA_v0_7B_Pipeline(params)
|
||||
if name == 'llava-13b':
|
||||
from .llava import LLaVA_v0_13B_Pipeline
|
||||
return LLaVA_v0_13B_Pipeline(params)
|
||||
return None
|
||||
|
||||
|
||||
def get_pipeline_from_model_name(model_name: str, params: dict) -> Optional[AbstractMultimodalPipeline]:
|
||||
if 'llava' not in model_name.lower():
|
||||
return None
|
||||
if '7b' in model_name.lower():
|
||||
from .llava import LLaVA_v0_7B_Pipeline
|
||||
return LLaVA_v0_7B_Pipeline(params)
|
||||
if '13b' in model_name.lower():
|
||||
from .llava import LLaVA_v0_13B_Pipeline
|
||||
return LLaVA_v0_13B_Pipeline(params)
|
||||
return None
|
Loading…
Add table
Add a link
Reference in a new issue