Generalize multimodality (llava/minigpt4 7b and 13b now supported) (#1741)

This commit is contained in:
Wojtab 2023-05-10 01:18:02 +02:00 committed by GitHub
parent a2b25322f0
commit e9e75a9ec7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
22 changed files with 812 additions and 371 deletions

View file

@ -0,0 +1,85 @@
# Technical description of multimodal extension
## Working principle
Multimodality extension does most of the stuff which is required for any image input:
- adds the UI
- saves the images as base64 JPEGs to history
- provides the hooks to the UI
- if there are images in the prompt, it:
- splits the prompt to text and image parts
- adds image start/end markers to text parts, then encodes and embeds the text parts
- calls the vision pipeline to embed the images
- stitches the embeddings together, and returns them to text generation
- loads the appropriate vision pipeline, selected either from model name, or by specifying --multimodal-pipeline parameter
Now, for the pipelines, they:
- load the required vision models
- return some consts, for example the number of tokens taken up by image
- and most importantly: return the embeddings for LLM, given a list of images
## Prompts/history
To save images in prompt/history, this extension is using a base64 JPEG, wrapped in a HTML tag, like so:
```
<img src="data:image/jpeg;base64,{img_str}">
```
where `{img_str}` is the actual image data. This format makes displaying them in the UI for free. Do note, that this format is required to be exactly the same, the regex used to find the images is: `<img src="data:image/jpeg;base64,([A-Za-z0-9+/=]+)">`.
## LLM input
To describe the input, let's see it on an example prompt:
```
text1<image1>text2<image2>text3
```
where `textN` is N-th text, `<imageN>` is N-th image, in HTML format specified above.
**The first step is to split the prompt into image/text parts**, so we get:
```
['text1', '<image1>', 'text2', '<image2>', 'text3']
```
this is done in `MultimodalEmbedder._split_prompt(...)` function, which returns a list of `PromptPart`s - dataclasses wrapping the separate parts.
This function also appends the image start/end markers to text, which are provided by `AbstractMultimodalPipeline.image_start()` / `AbstractMultimodalPipeline.image_end()` functions. If image start is `<Img>`, and end is `</Img>`, this function will return:
```
['text1<Img>', '<image1>', '</Img>text2<Img>', '<image2>', '</Img>text3']
```
**The returned prompt parts are then turned into token embeddings.**
First, they are modified to token IDs, for the text it is done using standard `modules.text_generation.encode()` function, and for the images the returned token IDs are changed to placeholders. The placeholder is a list of `N` times `placeholder token id`, where `N` is specified using `AbstractMultimodalPipeline.num_image_embeds()`, and placeholder token IDs using `AbstractMultimodalPipeline.placeholder_token_id()`.
Now, based on the token IDs, the prompt might get truncated, especially if `max_new_tokens` are unreasonably high. Unfortunately, it can't be done simply, just by trimming the prompt to be short enough. This way will lead to sometimes splitting the prompt in the middle of an image embedding, which usually breaks the generation. Therefore, in this case, the entire image needs to be removed from input. This is done inside `MultimodalEmbedder._encode_text(...)` function.
**After the tokenization, the tokens need to get embedded**, the text and images are once again treated separately.
The text parts are turned to embeddings, using `AbstractMultimodalPipeline.embed_tokens(...)` function. It uses standard embedding function from the model, but to support many LLMs, the actual function is returned by the pipeline (as it might be different for different LLMs), for LLaMA it is `shared.model.model.embed_tokens(...)`.
The image parts are turned to embeddings, using `AbstractMultimodalPipeline.embed_images(...)` function. This function is specific for a given pipeline, it takes the images as input, forwards them through vision model/projector, and returns the embeddings.
**Now, the returned embeddings are stitched together**, using `torch.cat()`, this is creating the final input to the LLM.
## Pipelines
All of the pipelines should subclass `AbstractMultimodalPipeline` class. The idea is to allow for new pipelines to be added in the same way as user extensions - git clone into `extensions/multimodal/pipelines`.
The pipelines are the description of the vision part, containing vision model/multimodal projector. All of the pipelines should have an unique `name()`, which is then selected by user, in `--multimodal-pipeline` CLI argument. For an example, see `pipelines/llava/llava.py`.
## Pipeline modules
Pipelines are organized into "pipeline modules" - subdirectories in `pipelines` directory. The pipeline modules should contain a file called `pipelines.py`, that should contain the following fields:
- `available_pipelines: List[str]` - list of pipelines provided by this module, shown as the list of available pipelines to the user
- `def get_pipeline(name: str, params: dict) -> Optional[AbstractMultimodalPipeline]`: - a function to get a concrete pipeline by `name`, if `name` doesn't match any, should return `None`. `params` is the user settings for multimodal extension
- `def get_pipeline_from_model_name(model_name: str, params: dict) -> Optional[AbstractMultimodalPipeline]`: - a function to get a pipeline from `model_name`, should be eager to return `None`, unless the determination can be done clearly (for example: minigpt-4 bases on vicuna - it should never return the pipeline, but llava can, as it has its own specific LLM finetune)
**NOTE**: A pipeline module should lazy-import the pipelines only when necessary, and it should keep its imports to minimum
## Pipeline params
The pipelines will get the extension `params` in the constructor. They should honor the following fields:
- `vision_device` - string, specifying `torch.device` to run the vision model (CLIP/ViT) on
- `vision_bits` - int, number of fp bits to load the vision model(s) in
- `projector_device` - string, specifying `torch.device` to run the projector models (Linear layers, QFormer, etc.) on
- `projector_bits` - int, number of fp bits to load the projector models in
As a helper, `AbstractMultimodalPipeline` has `_get_device(self, setting_name: str, params: dict)` and `_get_dtype(self, setting_name: str, params: dict)` helper functions, which parse string/int and return `torch.device` / `torch.dtype`.

View file

@ -0,0 +1,78 @@
# Multimodal
## Description
Adds support for multimodality (text+images) to text-generation-webui.
https://user-images.githubusercontent.com/3718215/233817203-69b57e77-0c55-4fd6-b742-3204bb13b8fc.mp4
## Usage
To run this extension, download a LLM that supports multimodality, and then start server.py with the appropriate `--multimodal-pipeline` argument. Examples:
```
python server.py --model wojtab_llava-7b-v0-4bit-128g --multimodal-pipeline llava-7b --chat
python3 server.py --model wojtab_llava-13b-v0-4bit-128g --multimodal-pipeline llava-13b --chat
python server.py --model anon8231489123_vicuna-13b-GPTQ-4bit-128g --multimodal-pipeline minigpt4-13b --chat
python server.py --model llama-7b-4bit --multimodal-pipeline minigpt4-7b --chat
```
There is built-in support for LLaVA-v0-13B and LLaVA-v0-7b. To install `minigpt4`:
- clone https://github.com/Wojtab/minigpt-4-pipeline into `extensions/multimodal/pipelines`
- install the requirements.txt
The same procedure should be used to install other pipelines, which can then me used with `--multimodal-pipeline [pipeline name]`. For additional multimodal pipelines refer to compatibility section below.
Do note, that each image takes up a considerable amount of tokens, so adjust `max_new_tokens` to be at most 1700 (recommended value is between 200 to 500), so the images don't get truncated.
To send an image, just upload it to the extension field below chat, and send a prompt as always. The image will be added to the end of your message. If you wish to modify the placement, include a string `<image>` in your prompt.
Additionally, there is *Embed all images, not only the last one* checkbox. It modifies the image embeddings, by default (if it's unchecked), all but the most recent images have their embeddings empty, so they are not fed to the network. It seems as some multimodal networks consider the features in all images at the same time as if they were a single image. Due to this behavior, by default the extension skips previous images. However, it can lead to sub-par generation on other pipelines. If you want to include all images, just tick this checkbox.
## Compatibility
As of now, the following multimodal pipelines are supported:
|Pipeline|`--multimodal-pipeline`|Default LLM|LLM info(for the linked model)|Pipeline repository|
|-|-|-|-|-|
|[LLaVA 13B](https://github.com/haotian-liu/LLaVA)|`llava-13b`|[LLaVA 13B](https://huggingface.co/wojtab/llava-13b-v0-4bit-128g)|GPTQ 4-bit quant, old CUDA|built-in|
|[LLaVA 7B](https://github.com/haotian-liu/LLaVA)|`llava-7b`|[LLaVA 7B](https://huggingface.co/wojtab/llava-7b-v0-4bit-128g)|GPTQ 4-bit quant, old CUDA|built-in|
|[MiniGPT-4 7B](https://github.com/Vision-CAIR/MiniGPT-4)|`minigpt4-7b`|[Vicuna v0 7B](https://huggingface.co/TheBloke/vicuna-7B-GPTQ-4bit-128g)|GPTQ 4-bit quant, new format|[Wojtab/minigpt-4-pipeline](https://github.com/Wojtab/minigpt-4-pipeline)|
|[MiniGPT-4 13B](https://github.com/Vision-CAIR/MiniGPT-4)|`minigpt4-13b`|[Vicuna v0 13B](https://huggingface.co/anon8231489123/vicuna-13b-GPTQ-4bit-128g)|GPTQ 4-bit quant, old CUDA|[Wojtab/minigpt-4-pipeline](https://github.com/Wojtab/minigpt-4-pipeline)|
Some pipelines could support different LLMs, but do note that while it might work, it isn't a supported configuration.
DO NOT report bugs if you are using a different LLM.
DO NOT report bugs with pipelines in this repository (unless they are built-in)
## Extension config
This extension uses following parameters (from settings.json):
|Parameter|Description|
|---------|-----------|
|`multimodal-vision_bits`|Number of bits to load vision models (CLIP/ViT) feature extractor in (most pipelines should support either 32 or 16, default=32)|
|`multimodal-vision_device`|Torch device to run the feature extractor on, for example `cpu` or `cuda:0`, by default `cuda:0` if available|
|`multimodal-projector_bits`|Number of bits to load feature projector model(s) in (most pipelines should support either 32 or 16, default=32)|
|`multimodal-projector_device`|Torch device to run the feature projector model(s) on, for example `cpu` or `cuda:0`, by default `cuda:0` if available|
|`multimodal-add_all_images_to_prompt`|Default value of "Embed all images, not only the last one" checkbox|
## Usage through API
You can run the multimodal inference through API, by inputting the images to prompt. Images are embedded like so: `f'<img src="data:image/jpeg;base64,{img_str}">'`, where `img_str` is base-64 jpeg data. Python example:
```Python
import base64
import requests
CONTEXT = "You are LLaVA, a large language and vision assistant trained by UW Madison WAIV Lab. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language. Follow the instructions carefully and explain your answers in detail.### Human: Hi!### Assistant: Hi there! How can I help you today?\n"
with open('extreme_ironing.jpg', 'rb') as f:
img_str = base64.b64encode(f.read()).decode('utf-8')
prompt = CONTEXT + f'### Human: What is unusual about this image: \n<img src="data:image/jpeg;base64,{img_str}">### Assistant: '
print(requests.post('http://127.0.0.1:5000/api/v1/generate', json={'prompt': prompt, 'stopping_strings': ['\n###']}).json())
```
script output:
```Python
{'results': [{'text': "The unusual aspect of this image is that a man is standing on top of a yellow minivan while doing his laundry. He has set up a makeshift clothes line using the car's rooftop as an outdoor drying area. This scene is uncommon because people typically do their laundry indoors, in a dedicated space like a laundromat or a room in their home, rather than on top of a moving vehicle. Additionally, hanging clothes on the car could be potentially hazardous or illegal in some jurisdictions due to the risk of damaging the vehicle or causing accidents on the road.\n##"}]}
```
## For pipeline developers/technical description
see [DOCS.md](https://github.com/oobabooga/text-generation-webui/blob/main/extensions/multimodal/DOCS.md)

View file

@ -0,0 +1,62 @@
from abc import ABC, abstractmethod
from typing import List, Optional
import torch
from PIL import Image
class AbstractMultimodalPipeline(ABC):
@staticmethod
@abstractmethod
def name() -> str:
'name of the pipeline, should be same as in --multimodal-pipeline'
pass
@staticmethod
@abstractmethod
def image_start() -> Optional[str]:
'return image start string, string representation of image start token, or None if not applicable'
pass
@staticmethod
@abstractmethod
def image_end() -> Optional[str]:
'return image end string, string representation of image end token, or None if not applicable'
pass
@staticmethod
@abstractmethod
def placeholder_token_id() -> int:
'return placeholder token id'
pass
@staticmethod
@abstractmethod
def num_image_embeds() -> int:
'return the number of embeds used by a single image (for example: 256 for LLaVA)'
pass
@abstractmethod
def embed_images(self, images: List[Image.Image]) -> torch.Tensor:
'forward the images through vision pipeline, and return their embeddings'
pass
@staticmethod
@abstractmethod
def embed_tokens(input_ids: torch.Tensor) -> torch.Tensor:
'embed tokens, the exact function varies by LLM, for LLaMA it is `shared.model.model.embed_tokens`'
pass
@staticmethod
@abstractmethod
def placeholder_embeddings() -> torch.Tensor:
'get placeholder embeddings if there are multiple images, and `add_all_images_to_prompt` is False'
pass
def _get_device(self, setting_name: str, params: dict):
if params[setting_name] is None:
return torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
return torch.device(params[setting_name])
def _get_dtype(self, setting_name: str, params: dict):
return torch.float32 if int(params[setting_name]) == 32 else torch.float16

View file

@ -0,0 +1,177 @@
import base64
import logging
import re
from dataclasses import dataclass
from io import BytesIO
from typing import Any, List, Optional
import torch
from extensions.multimodal.pipeline_loader import load_pipeline
from modules import shared
from modules.text_generation import encode, get_max_prompt_length
from PIL import Image
@dataclass
class PromptPart:
text: str
image: Optional[Image.Image] = None
is_image: bool = False
input_ids: Optional[torch.Tensor] = None
embedding: Optional[torch.Tensor] = None
class MultimodalEmbedder:
def __init__(self, params: dict):
pipeline, source = load_pipeline(params)
self.pipeline = pipeline
logging.info(f'Multimodal: loaded pipeline {self.pipeline.name()} from pipelines/{source} ({self.pipeline.__class__.__name__})')
def _split_prompt(self, prompt: str, load_images: bool = False) -> List[PromptPart]:
"""Splits a prompt into a list of `PromptParts` to separate image data from text.
It will also append `image_start` and `image_end` before and after the image, and optionally parse and load the images,
if `load_images` is `True`.
"""
parts: List[PromptPart] = []
curr = 0
while True:
match = re.search(r'<img src="data:image/jpeg;base64,([A-Za-z0-9+/=]+)">', prompt[curr:])
if match is None:
# no more image tokens, append the rest of the prompt
if curr > 0:
# add image end token after last image
parts.append(PromptPart(text=self.pipeline.image_end() + prompt[curr:]))
else:
parts.append(PromptPart(text=prompt))
break
# found an image, append image start token to the text
if match.start() > 0:
parts.append(PromptPart(text=prompt[curr:curr+match.start()]+self.pipeline.image_start()))
else:
parts.append(PromptPart(text=self.pipeline.image_start()))
# append the image
parts.append(PromptPart(
text=match.group(0),
image=Image.open(BytesIO(base64.b64decode(match.group(1)))) if load_images else None,
is_image=True
))
curr += match.end()
return parts
def _len_in_tokens_prompt_parts(self, parts: List[PromptPart]) -> int:
"""Total length in tokens of all `parts`"""
tokens = 0
for part in parts:
if part.is_image:
tokens += self.pipeline.num_image_embeds()
elif part.input_ids is not None:
tokens += len(part.input_ids)
else:
tokens += len(encode(part.text)[0])
return tokens
def len_in_tokens(self, prompt: str) -> int:
"""Total length in tokens for a given text `prompt`"""
parts = self._split_prompt(prompt, False)
return self._len_in_tokens_prompt_parts(parts)
def _encode_single_text(self, part: PromptPart, add_bos_token: bool) -> PromptPart:
"""Encode a single prompt `part` to `input_ids`. Returns a `PromptPart`"""
if part.is_image:
placeholders = torch.ones((self.pipeline.num_image_embeds())) * self.pipeline.placeholder_token_id()
part.input_ids = placeholders.to(shared.model.device, dtype=torch.int64)
else:
part.input_ids = encode(part.text, add_bos_token=add_bos_token)[0].to(shared.model.device, dtype=torch.int64)
return part
@staticmethod
def _num_images(parts: List[PromptPart]) -> int:
count = 0
for part in parts:
if part.is_image:
count += 1
return count
def _encode_text(self, state, parts: List[PromptPart]) -> List[PromptPart]:
"""Encode text to token_ids, also truncate the prompt, if necessary.
The chat/instruct mode should make prompts that fit in get_max_prompt_length, but if max_new_tokens are set
such that the context + min_rows don't fit, we can get a prompt which is too long.
We can't truncate image embeddings, as it leads to broken generation, so remove the images instead and warn the user
"""
encoded: List[PromptPart] = []
for i, part in enumerate(parts):
encoded.append(self._encode_single_text(part, i==0 and state['add_bos_token']))
# truncation:
max_len = get_max_prompt_length(state)
removed_images = 0
# 1. remove entire text/image blocks
while self._len_in_tokens_prompt_parts(encoded[1:]) > max_len:
if encoded[0].is_image:
removed_images += 1
encoded = encoded[1:]
# 2. check if the last prompt part doesn't need to get truncated
if self._len_in_tokens_prompt_parts(encoded) > max_len:
if encoded[0].is_image:
# don't truncate image embeddings, just remove the image, otherwise generation will be broken
removed_images += 1
encoded = encoded[1:]
elif len(encoded) > 1 and encoded[0].text.endswith(self.pipeline.image_start()):
# see if we can keep image_start token
len_image_start = len(encode(self.pipeline.image_start(), add_bos_token=state['add_bos_token'])[0])
if self._len_in_tokens_prompt_parts(encoded[1:]) + len_image_start > max_len:
# we can't -> remove this text, and the image
encoded = encoded[2:]
removed_images += 1
else:
# we can -> just truncate the text
trunc_len = self._len_in_tokens_prompt_parts(encoded) - max_len
encoded[0].input_ids = encoded[0].input_ids[trunc_len:]
elif len(encoded) > 0:
# only one text left, truncate it normally
trunc_len = self._len_in_tokens_prompt_parts(encoded) - max_len
encoded[0].input_ids = encoded[0].input_ids[trunc_len:]
# notify user if we truncated an image
if removed_images > 0:
logging.warning(f"Multimodal: removed {removed_images} image(s) from prompt. Try decreasing max_new_tokens if generation is broken")
return encoded
def _embed(self, parts: List[PromptPart]) -> List[PromptPart]:
# batch images
image_indicies = [i for i, part in enumerate(parts) if part.is_image]
embedded = self.pipeline.embed_images([parts[i].image for i in image_indicies])
for i, embeds in zip(image_indicies, embedded):
parts[i].embedding = embeds
# embed text
for (i, part) in enumerate(parts):
if not part.is_image:
parts[i].embedding = self.pipeline.embed_tokens(part.input_ids)
return parts
def _remove_old_images(self, parts: List[PromptPart], params: dict) -> List[PromptPart]:
if params['add_all_images_to_prompt']:
return parts
already_added = False
for i, part in reversed(list(enumerate(parts))):
if part.is_image:
if already_added:
parts[i].embedding = self.pipeline.placeholder_embeddings()
else:
already_added = True
return parts
def forward(self, prompt: str, state: Any, params: dict):
prompt_parts = self._split_prompt(prompt, True)
prompt_parts = self._encode_text(state, prompt_parts)
prompt_parts = self._embed(prompt_parts)
prompt_parts = self._remove_old_images(prompt_parts, params)
embeds = tuple(part.embedding for part in prompt_parts)
ids = tuple(part.input_ids for part in prompt_parts)
input_embeds = torch.cat(embeds, dim=0)
input_ids = torch.cat(ids, dim=0)
return prompt, input_ids, input_embeds, self._num_images(prompt_parts)

View file

@ -0,0 +1,52 @@
import logging
import traceback
from importlib import import_module
from pathlib import Path
from typing import Tuple
from extensions.multimodal.abstract_pipeline import AbstractMultimodalPipeline
from modules import shared
def _get_available_pipeline_modules():
pipeline_path = Path(__file__).parent / 'pipelines'
modules = [p for p in pipeline_path.iterdir() if p.is_dir()]
return [m.name for m in modules if (m / 'pipelines.py').exists()]
def load_pipeline(params: dict) -> Tuple[AbstractMultimodalPipeline, str]:
pipeline_modules = {}
available_pipeline_modules = _get_available_pipeline_modules()
for name in available_pipeline_modules:
try:
pipeline_modules[name] = import_module(f'extensions.multimodal.pipelines.{name}.pipelines')
except:
logging.warning(f'Failed to get multimodal pipelines from {name}')
logging.warning(traceback.format_exc())
if shared.args.multimodal_pipeline is not None:
for k in pipeline_modules:
if hasattr(pipeline_modules[k], 'get_pipeline'):
pipeline = getattr(pipeline_modules[k], 'get_pipeline')(shared.args.multimodal_pipeline, params)
if pipeline is not None:
return (pipeline, k)
else:
model_name = shared.args.model.lower()
for k in pipeline_modules:
if hasattr(pipeline_modules[k], 'get_pipeline_from_model_name'):
pipeline = getattr(pipeline_modules[k], 'get_pipeline_from_model_name')(model_name, params)
if pipeline is not None:
return (pipeline, k)
available = []
for k in pipeline_modules:
if hasattr(pipeline_modules[k], 'available_pipelines'):
pipelines = getattr(pipeline_modules[k], 'available_pipelines')
available += pipelines
if shared.args.multimodal_pipeline is not None:
log = f'Multimodal - ERROR: Failed to load multimodal pipeline "{shared.args.multimodal_pipeline}", available pipelines are: {available}.'
else:
log = f'Multimodal - ERROR: Failed to determine multimodal pipeline for model {shared.args.model}, please select one manually using --multimodal-pipeline [PIPELINE]. Available pipelines are: {available}.'
logging.critical(f'{log} Please specify a correct pipeline, or disable the extension')
raise RuntimeError(f'{log} Please specify a correct pipeline, or disable the extension')

View file

@ -0,0 +1,9 @@
## LLaVA pipeline
This module provides 2 pipelines:
- `llava-7b` - for use with LLaVA v0 7B model (finetuned LLaMa 7B)
- `llava-13b` - for use with LLaVA v0 13B model (finetuned LLaMa 13B)
[LLaVA](https://github.com/haotian-liu/LLaVA) uses CLIP `openai/clip-vit-large-patch14` as the vision model, and then a single linear layer. For 13B the projector weights are in `liuhaotian/LLaVA-13b-delta-v0`, and for 7B they are in `liuhaotian/LLaVA-7b-delta-v0`.
The supported parameter combinations for both the vision model, and the projector are: CUDA/32bit, CUDA/16bit, CPU/32bit

View file

@ -0,0 +1,139 @@
import logging
import time
from abc import abstractmethod
from typing import List, Tuple
import torch
from extensions.multimodal.abstract_pipeline import AbstractMultimodalPipeline
from huggingface_hub import hf_hub_download
from modules import shared
from modules.text_generation import encode
from PIL import Image
from transformers import CLIPImageProcessor, CLIPVisionModel
class LLaVA_v0_Pipeline(AbstractMultimodalPipeline):
CLIP_REPO = "openai/clip-vit-large-patch14"
def __init__(self, params: dict) -> None:
super().__init__()
self.clip_device = self._get_device("vision_device", params)
self.clip_dtype = self._get_dtype("vision_bits", params)
self.projector_device = self._get_device("projector_device", params)
self.projector_dtype = self._get_dtype("projector_bits", params)
self.image_processor, self.vision_tower, self.mm_projector = self._load_models()
def _load_models(self):
start_ts = time.time()
logging.info(f"LLaVA - Loading CLIP from {LLaVA_v0_Pipeline.CLIP_REPO} as {self.clip_dtype} on {self.clip_device}...")
image_processor = CLIPImageProcessor.from_pretrained(LLaVA_v0_Pipeline.CLIP_REPO, torch_dtype=self.clip_dtype)
vision_tower = CLIPVisionModel.from_pretrained(LLaVA_v0_Pipeline.CLIP_REPO, torch_dtype=self.clip_dtype).to(self.clip_device)
logging.info(f"LLaVA - Loading projector from {self.llava_projector_repo()} as {self.projector_dtype} on {self.projector_device}...")
projector_path = hf_hub_download(self.llava_projector_repo(), self.llava_projector_filename())
mm_projector = torch.nn.Linear(*self.llava_projector_shape())
projector_data = torch.load(projector_path)
mm_projector.weight = torch.nn.Parameter(projector_data['model.mm_projector.weight'].to(dtype=self.projector_dtype), False)
mm_projector.bias = torch.nn.Parameter(projector_data['model.mm_projector.bias'].to(dtype=self.projector_dtype), False)
mm_projector = mm_projector.to(self.projector_device)
logging.info(f"LLaVA supporting models loaded, took {time.time() - start_ts:.2f} seconds")
return image_processor, vision_tower, mm_projector
@staticmethod
def image_start() -> str:
return "<im_start>"
@staticmethod
def image_end() -> str:
return "<im_end>"
@staticmethod
def num_image_embeds() -> int:
return 256
@staticmethod
def embed_tokens(input_ids: torch.Tensor) -> torch.Tensor:
return shared.model.model.embed_tokens(input_ids).to(shared.model.device, dtype=shared.model.dtype)
@staticmethod
def placeholder_embeddings() -> torch.Tensor:
return LLaVA_v0_Pipeline.embed_tokens(encode("<im_patch>"*256, add_bos_token=False)[0])
def embed_images(self, images: List[Image.Image]) -> torch.Tensor:
images = self.image_processor(images, return_tensors='pt')['pixel_values']
images = images.to(self.clip_device, dtype=self.clip_dtype)
with torch.no_grad():
image_forward_outs = self.vision_tower(images, output_hidden_states=True)
select_hidden_state_layer = -2
select_hidden_state = image_forward_outs.hidden_states[select_hidden_state_layer]
image_features = select_hidden_state[:, 1:].to(self.projector_device, dtype=self.projector_dtype)
image_features = self.mm_projector(image_features)
return image_features.to(shared.model.device, dtype=shared.model.dtype)
@staticmethod
@abstractmethod
def llava_projector_repo() -> str:
pass
@staticmethod
@abstractmethod
def llava_projector_filename() -> str:
pass
@staticmethod
@abstractmethod
def llava_projector_shape() -> Tuple[int, int]:
pass
class LLaVA_v0_13B_Pipeline(LLaVA_v0_Pipeline):
def __init__(self, params: dict) -> None:
super().__init__(params)
@staticmethod
def name() -> str:
return "llava-13b"
@staticmethod
def placeholder_token_id() -> int:
return 32000
@staticmethod
def llava_projector_shape() -> Tuple[int, int]:
return (1024, 5120)
@staticmethod
def llava_projector_filename() -> str:
return "mm_projector.bin"
@staticmethod
def llava_projector_repo() -> str:
return "liuhaotian/LLaVA-13b-delta-v0"
class LLaVA_v0_7B_Pipeline(LLaVA_v0_Pipeline):
def __init__(self, params: dict) -> None:
super().__init__(params)
@staticmethod
def name() -> str:
return "llava-7b"
@staticmethod
def placeholder_token_id() -> int:
return 32001
@staticmethod
def llava_projector_shape() -> Tuple[int, int]:
return (1024, 4096)
@staticmethod
def llava_projector_filename() -> str:
return "mm_projector.bin"
@staticmethod
def llava_projector_repo() -> str:
return "liuhaotian/LLaVA-7b-delta-v0"

View file

@ -0,0 +1,27 @@
from typing import Optional
from extensions.multimodal.abstract_pipeline import AbstractMultimodalPipeline
available_pipelines = ['llava-7b', 'llava-13b']
def get_pipeline(name: str, params: dict) -> Optional[AbstractMultimodalPipeline]:
if name == 'llava-7b':
from .llava import LLaVA_v0_7B_Pipeline
return LLaVA_v0_7B_Pipeline(params)
if name == 'llava-13b':
from .llava import LLaVA_v0_13B_Pipeline
return LLaVA_v0_13B_Pipeline(params)
return None
def get_pipeline_from_model_name(model_name: str, params: dict) -> Optional[AbstractMultimodalPipeline]:
if 'llava' not in model_name.lower():
return None
if '7b' in model_name.lower():
from .llava import LLaVA_v0_7B_Pipeline
return LLaVA_v0_7B_Pipeline(params)
if '13b' in model_name.lower():
from .llava import LLaVA_v0_13B_Pipeline
return LLaVA_v0_13B_Pipeline(params)
return None

View file

@ -0,0 +1,103 @@
import base64
import logging
import re
import time
from functools import partial
from io import BytesIO
import gradio as gr
import torch
from extensions.multimodal.multimodal_embedder import MultimodalEmbedder
from modules import shared
params = {
"add_all_images_to_prompt": False,
# device to run vision encoder on
"vision_device": None,
# bits to load vision encoder in, either 16 or 32
"vision_bits": 32,
# device to run multimodal projector on
"projector_device": None,
# multimodal projector bits, either 32 or 16
"projector_bits": 32
}
# If 'state' is True, will hijack the next chat generation
input_hijack = {
'state': False,
'value': ["", ""]
}
# initialized in ui, so that params are loaded from settings
multimodal_embedder: MultimodalEmbedder = None
def add_chat_picture(picture, text, visible_text):
# resize the image, so that shortest edge is at least 224 (size for CLIP), and at most 300 (to keep history manageable)
max_hw, min_hw = max(picture.size), min(picture.size)
aspect_ratio = max_hw / min_hw
shortest_edge = int(max(300 / aspect_ratio, 224))
longest_edge = int(shortest_edge * aspect_ratio)
w = shortest_edge if picture.width < picture.height else longest_edge
h = shortest_edge if picture.width >= picture.height else longest_edge
picture = picture.resize((w,h))
buffer = BytesIO()
picture.save(buffer, format="JPEG")
img_str = base64.b64encode(buffer.getvalue()).decode('utf-8')
image = f'<img src="data:image/jpeg;base64,{img_str}">'
if '<image>' in text:
text = text.replace('<image>', image)
else:
text = text + '\n' + image
if visible_text == '' or visible_text is None:
visible_text = text
elif '<image>' in visible_text:
visible_text = visible_text.replace('<image>', image)
else:
visible_text = visible_text + '\n' + image
return text, visible_text
def custom_tokenized_length(prompt):
return multimodal_embedder.len_in_tokens(prompt)
def tokenizer_modifier(state, prompt, input_ids, input_embeds):
global params
start_ts = time.time()
image_match = re.search(r'<img src="data:image/jpeg;base64,[A-Za-z0-9+/=]+">', prompt)
if image_match is None:
return prompt, input_ids, input_embeds
prompt, input_ids, input_embeds, total_embedded = multimodal_embedder.forward(prompt, state, params)
logging.info(f'Embedded {total_embedded} image(s) in {time.time()-start_ts:.2f}s')
return (prompt,
input_ids.unsqueeze(0).to(shared.model.device, dtype=torch.int64),
input_embeds.unsqueeze(0).to(shared.model.device, dtype=shared.model.dtype))
def ui():
global multimodal_embedder
multimodal_embedder = MultimodalEmbedder(params)
with gr.Column():
picture_select = gr.Image(label='Send a picture', type='pil')
# The models don't seem to deal well with multiple images
single_image_checkbox = gr.Checkbox(False, label='Embed all images, not only the last one')
# Prepare the input hijack
picture_select.upload(
lambda picture: input_hijack.update({"state": True, "value": partial(add_chat_picture, picture)}),
[picture_select],
None
)
picture_select.clear(lambda: input_hijack.update({"state": False, "value": ["",""]}), None, None)
single_image_checkbox.change(lambda x: params.update({"add_all_images_to_prompt": x}), single_image_checkbox, None)
shared.gradio['Generate'].click(lambda: None, None, picture_select)
shared.gradio['textbox'].submit(lambda: None, None, picture_select)