Generalize multimodality (llava/minigpt4 7b and 13b now supported) (#1741)
This commit is contained in:
parent
a2b25322f0
commit
e9e75a9ec7
22 changed files with 812 additions and 371 deletions
85
extensions/multimodal/DOCS.md
Normal file
85
extensions/multimodal/DOCS.md
Normal file
|
@ -0,0 +1,85 @@
|
|||
# Technical description of multimodal extension
|
||||
|
||||
## Working principle
|
||||
Multimodality extension does most of the stuff which is required for any image input:
|
||||
|
||||
- adds the UI
|
||||
- saves the images as base64 JPEGs to history
|
||||
- provides the hooks to the UI
|
||||
- if there are images in the prompt, it:
|
||||
- splits the prompt to text and image parts
|
||||
- adds image start/end markers to text parts, then encodes and embeds the text parts
|
||||
- calls the vision pipeline to embed the images
|
||||
- stitches the embeddings together, and returns them to text generation
|
||||
- loads the appropriate vision pipeline, selected either from model name, or by specifying --multimodal-pipeline parameter
|
||||
|
||||
Now, for the pipelines, they:
|
||||
|
||||
- load the required vision models
|
||||
- return some consts, for example the number of tokens taken up by image
|
||||
- and most importantly: return the embeddings for LLM, given a list of images
|
||||
|
||||
## Prompts/history
|
||||
|
||||
To save images in prompt/history, this extension is using a base64 JPEG, wrapped in a HTML tag, like so:
|
||||
```
|
||||
<img src="data:image/jpeg;base64,{img_str}">
|
||||
```
|
||||
where `{img_str}` is the actual image data. This format makes displaying them in the UI for free. Do note, that this format is required to be exactly the same, the regex used to find the images is: `<img src="data:image/jpeg;base64,([A-Za-z0-9+/=]+)">`.
|
||||
|
||||
## LLM input
|
||||
To describe the input, let's see it on an example prompt:
|
||||
```
|
||||
text1<image1>text2<image2>text3
|
||||
```
|
||||
where `textN` is N-th text, `<imageN>` is N-th image, in HTML format specified above.
|
||||
|
||||
**The first step is to split the prompt into image/text parts**, so we get:
|
||||
```
|
||||
['text1', '<image1>', 'text2', '<image2>', 'text3']
|
||||
```
|
||||
this is done in `MultimodalEmbedder._split_prompt(...)` function, which returns a list of `PromptPart`s - dataclasses wrapping the separate parts.
|
||||
|
||||
This function also appends the image start/end markers to text, which are provided by `AbstractMultimodalPipeline.image_start()` / `AbstractMultimodalPipeline.image_end()` functions. If image start is `<Img>`, and end is `</Img>`, this function will return:
|
||||
```
|
||||
['text1<Img>', '<image1>', '</Img>text2<Img>', '<image2>', '</Img>text3']
|
||||
```
|
||||
|
||||
**The returned prompt parts are then turned into token embeddings.**
|
||||
|
||||
First, they are modified to token IDs, for the text it is done using standard `modules.text_generation.encode()` function, and for the images the returned token IDs are changed to placeholders. The placeholder is a list of `N` times `placeholder token id`, where `N` is specified using `AbstractMultimodalPipeline.num_image_embeds()`, and placeholder token IDs using `AbstractMultimodalPipeline.placeholder_token_id()`.
|
||||
|
||||
Now, based on the token IDs, the prompt might get truncated, especially if `max_new_tokens` are unreasonably high. Unfortunately, it can't be done simply, just by trimming the prompt to be short enough. This way will lead to sometimes splitting the prompt in the middle of an image embedding, which usually breaks the generation. Therefore, in this case, the entire image needs to be removed from input. This is done inside `MultimodalEmbedder._encode_text(...)` function.
|
||||
|
||||
**After the tokenization, the tokens need to get embedded**, the text and images are once again treated separately.
|
||||
|
||||
The text parts are turned to embeddings, using `AbstractMultimodalPipeline.embed_tokens(...)` function. It uses standard embedding function from the model, but to support many LLMs, the actual function is returned by the pipeline (as it might be different for different LLMs), for LLaMA it is `shared.model.model.embed_tokens(...)`.
|
||||
|
||||
The image parts are turned to embeddings, using `AbstractMultimodalPipeline.embed_images(...)` function. This function is specific for a given pipeline, it takes the images as input, forwards them through vision model/projector, and returns the embeddings.
|
||||
|
||||
**Now, the returned embeddings are stitched together**, using `torch.cat()`, this is creating the final input to the LLM.
|
||||
|
||||
## Pipelines
|
||||
|
||||
All of the pipelines should subclass `AbstractMultimodalPipeline` class. The idea is to allow for new pipelines to be added in the same way as user extensions - git clone into `extensions/multimodal/pipelines`.
|
||||
|
||||
The pipelines are the description of the vision part, containing vision model/multimodal projector. All of the pipelines should have an unique `name()`, which is then selected by user, in `--multimodal-pipeline` CLI argument. For an example, see `pipelines/llava/llava.py`.
|
||||
|
||||
## Pipeline modules
|
||||
|
||||
Pipelines are organized into "pipeline modules" - subdirectories in `pipelines` directory. The pipeline modules should contain a file called `pipelines.py`, that should contain the following fields:
|
||||
- `available_pipelines: List[str]` - list of pipelines provided by this module, shown as the list of available pipelines to the user
|
||||
- `def get_pipeline(name: str, params: dict) -> Optional[AbstractMultimodalPipeline]`: - a function to get a concrete pipeline by `name`, if `name` doesn't match any, should return `None`. `params` is the user settings for multimodal extension
|
||||
- `def get_pipeline_from_model_name(model_name: str, params: dict) -> Optional[AbstractMultimodalPipeline]`: - a function to get a pipeline from `model_name`, should be eager to return `None`, unless the determination can be done clearly (for example: minigpt-4 bases on vicuna - it should never return the pipeline, but llava can, as it has its own specific LLM finetune)
|
||||
|
||||
**NOTE**: A pipeline module should lazy-import the pipelines only when necessary, and it should keep its imports to minimum
|
||||
|
||||
## Pipeline params
|
||||
|
||||
The pipelines will get the extension `params` in the constructor. They should honor the following fields:
|
||||
- `vision_device` - string, specifying `torch.device` to run the vision model (CLIP/ViT) on
|
||||
- `vision_bits` - int, number of fp bits to load the vision model(s) in
|
||||
- `projector_device` - string, specifying `torch.device` to run the projector models (Linear layers, QFormer, etc.) on
|
||||
- `projector_bits` - int, number of fp bits to load the projector models in
|
||||
|
||||
As a helper, `AbstractMultimodalPipeline` has `_get_device(self, setting_name: str, params: dict)` and `_get_dtype(self, setting_name: str, params: dict)` helper functions, which parse string/int and return `torch.device` / `torch.dtype`.
|
78
extensions/multimodal/README.md
Normal file
78
extensions/multimodal/README.md
Normal file
|
@ -0,0 +1,78 @@
|
|||
# Multimodal
|
||||
|
||||
## Description
|
||||
|
||||
Adds support for multimodality (text+images) to text-generation-webui.
|
||||
|
||||
https://user-images.githubusercontent.com/3718215/233817203-69b57e77-0c55-4fd6-b742-3204bb13b8fc.mp4
|
||||
|
||||
## Usage
|
||||
|
||||
To run this extension, download a LLM that supports multimodality, and then start server.py with the appropriate `--multimodal-pipeline` argument. Examples:
|
||||
|
||||
```
|
||||
python server.py --model wojtab_llava-7b-v0-4bit-128g --multimodal-pipeline llava-7b --chat
|
||||
python3 server.py --model wojtab_llava-13b-v0-4bit-128g --multimodal-pipeline llava-13b --chat
|
||||
python server.py --model anon8231489123_vicuna-13b-GPTQ-4bit-128g --multimodal-pipeline minigpt4-13b --chat
|
||||
python server.py --model llama-7b-4bit --multimodal-pipeline minigpt4-7b --chat
|
||||
```
|
||||
|
||||
There is built-in support for LLaVA-v0-13B and LLaVA-v0-7b. To install `minigpt4`:
|
||||
|
||||
- clone https://github.com/Wojtab/minigpt-4-pipeline into `extensions/multimodal/pipelines`
|
||||
- install the requirements.txt
|
||||
|
||||
The same procedure should be used to install other pipelines, which can then me used with `--multimodal-pipeline [pipeline name]`. For additional multimodal pipelines refer to compatibility section below.
|
||||
|
||||
Do note, that each image takes up a considerable amount of tokens, so adjust `max_new_tokens` to be at most 1700 (recommended value is between 200 to 500), so the images don't get truncated.
|
||||
|
||||
To send an image, just upload it to the extension field below chat, and send a prompt as always. The image will be added to the end of your message. If you wish to modify the placement, include a string `<image>` in your prompt.
|
||||
|
||||
Additionally, there is *Embed all images, not only the last one* checkbox. It modifies the image embeddings, by default (if it's unchecked), all but the most recent images have their embeddings empty, so they are not fed to the network. It seems as some multimodal networks consider the features in all images at the same time as if they were a single image. Due to this behavior, by default the extension skips previous images. However, it can lead to sub-par generation on other pipelines. If you want to include all images, just tick this checkbox.
|
||||
|
||||
## Compatibility
|
||||
As of now, the following multimodal pipelines are supported:
|
||||
|Pipeline|`--multimodal-pipeline`|Default LLM|LLM info(for the linked model)|Pipeline repository|
|
||||
|-|-|-|-|-|
|
||||
|[LLaVA 13B](https://github.com/haotian-liu/LLaVA)|`llava-13b`|[LLaVA 13B](https://huggingface.co/wojtab/llava-13b-v0-4bit-128g)|GPTQ 4-bit quant, old CUDA|built-in|
|
||||
|[LLaVA 7B](https://github.com/haotian-liu/LLaVA)|`llava-7b`|[LLaVA 7B](https://huggingface.co/wojtab/llava-7b-v0-4bit-128g)|GPTQ 4-bit quant, old CUDA|built-in|
|
||||
|[MiniGPT-4 7B](https://github.com/Vision-CAIR/MiniGPT-4)|`minigpt4-7b`|[Vicuna v0 7B](https://huggingface.co/TheBloke/vicuna-7B-GPTQ-4bit-128g)|GPTQ 4-bit quant, new format|[Wojtab/minigpt-4-pipeline](https://github.com/Wojtab/minigpt-4-pipeline)|
|
||||
|[MiniGPT-4 13B](https://github.com/Vision-CAIR/MiniGPT-4)|`minigpt4-13b`|[Vicuna v0 13B](https://huggingface.co/anon8231489123/vicuna-13b-GPTQ-4bit-128g)|GPTQ 4-bit quant, old CUDA|[Wojtab/minigpt-4-pipeline](https://github.com/Wojtab/minigpt-4-pipeline)|
|
||||
|
||||
Some pipelines could support different LLMs, but do note that while it might work, it isn't a supported configuration.
|
||||
|
||||
DO NOT report bugs if you are using a different LLM.
|
||||
|
||||
DO NOT report bugs with pipelines in this repository (unless they are built-in)
|
||||
|
||||
## Extension config
|
||||
This extension uses following parameters (from settings.json):
|
||||
|Parameter|Description|
|
||||
|---------|-----------|
|
||||
|`multimodal-vision_bits`|Number of bits to load vision models (CLIP/ViT) feature extractor in (most pipelines should support either 32 or 16, default=32)|
|
||||
|`multimodal-vision_device`|Torch device to run the feature extractor on, for example `cpu` or `cuda:0`, by default `cuda:0` if available|
|
||||
|`multimodal-projector_bits`|Number of bits to load feature projector model(s) in (most pipelines should support either 32 or 16, default=32)|
|
||||
|`multimodal-projector_device`|Torch device to run the feature projector model(s) on, for example `cpu` or `cuda:0`, by default `cuda:0` if available|
|
||||
|`multimodal-add_all_images_to_prompt`|Default value of "Embed all images, not only the last one" checkbox|
|
||||
|
||||
## Usage through API
|
||||
|
||||
You can run the multimodal inference through API, by inputting the images to prompt. Images are embedded like so: `f'<img src="data:image/jpeg;base64,{img_str}">'`, where `img_str` is base-64 jpeg data. Python example:
|
||||
```Python
|
||||
import base64
|
||||
import requests
|
||||
|
||||
CONTEXT = "You are LLaVA, a large language and vision assistant trained by UW Madison WAIV Lab. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language. Follow the instructions carefully and explain your answers in detail.### Human: Hi!### Assistant: Hi there! How can I help you today?\n"
|
||||
|
||||
with open('extreme_ironing.jpg', 'rb') as f:
|
||||
img_str = base64.b64encode(f.read()).decode('utf-8')
|
||||
prompt = CONTEXT + f'### Human: What is unusual about this image: \n<img src="data:image/jpeg;base64,{img_str}">### Assistant: '
|
||||
print(requests.post('http://127.0.0.1:5000/api/v1/generate', json={'prompt': prompt, 'stopping_strings': ['\n###']}).json())
|
||||
```
|
||||
script output:
|
||||
```Python
|
||||
{'results': [{'text': "The unusual aspect of this image is that a man is standing on top of a yellow minivan while doing his laundry. He has set up a makeshift clothes line using the car's rooftop as an outdoor drying area. This scene is uncommon because people typically do their laundry indoors, in a dedicated space like a laundromat or a room in their home, rather than on top of a moving vehicle. Additionally, hanging clothes on the car could be potentially hazardous or illegal in some jurisdictions due to the risk of damaging the vehicle or causing accidents on the road.\n##"}]}
|
||||
```
|
||||
|
||||
## For pipeline developers/technical description
|
||||
see [DOCS.md](https://github.com/oobabooga/text-generation-webui/blob/main/extensions/multimodal/DOCS.md)
|
62
extensions/multimodal/abstract_pipeline.py
Normal file
62
extensions/multimodal/abstract_pipeline.py
Normal file
|
@ -0,0 +1,62 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
|
||||
class AbstractMultimodalPipeline(ABC):
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def name() -> str:
|
||||
'name of the pipeline, should be same as in --multimodal-pipeline'
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def image_start() -> Optional[str]:
|
||||
'return image start string, string representation of image start token, or None if not applicable'
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def image_end() -> Optional[str]:
|
||||
'return image end string, string representation of image end token, or None if not applicable'
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def placeholder_token_id() -> int:
|
||||
'return placeholder token id'
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def num_image_embeds() -> int:
|
||||
'return the number of embeds used by a single image (for example: 256 for LLaVA)'
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def embed_images(self, images: List[Image.Image]) -> torch.Tensor:
|
||||
'forward the images through vision pipeline, and return their embeddings'
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def embed_tokens(input_ids: torch.Tensor) -> torch.Tensor:
|
||||
'embed tokens, the exact function varies by LLM, for LLaMA it is `shared.model.model.embed_tokens`'
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def placeholder_embeddings() -> torch.Tensor:
|
||||
'get placeholder embeddings if there are multiple images, and `add_all_images_to_prompt` is False'
|
||||
pass
|
||||
|
||||
def _get_device(self, setting_name: str, params: dict):
|
||||
if params[setting_name] is None:
|
||||
return torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
return torch.device(params[setting_name])
|
||||
|
||||
def _get_dtype(self, setting_name: str, params: dict):
|
||||
return torch.float32 if int(params[setting_name]) == 32 else torch.float16
|
177
extensions/multimodal/multimodal_embedder.py
Normal file
177
extensions/multimodal/multimodal_embedder.py
Normal file
|
@ -0,0 +1,177 @@
|
|||
import base64
|
||||
import logging
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from io import BytesIO
|
||||
from typing import Any, List, Optional
|
||||
|
||||
import torch
|
||||
from extensions.multimodal.pipeline_loader import load_pipeline
|
||||
from modules import shared
|
||||
from modules.text_generation import encode, get_max_prompt_length
|
||||
from PIL import Image
|
||||
|
||||
|
||||
@dataclass
|
||||
class PromptPart:
|
||||
text: str
|
||||
image: Optional[Image.Image] = None
|
||||
is_image: bool = False
|
||||
input_ids: Optional[torch.Tensor] = None
|
||||
embedding: Optional[torch.Tensor] = None
|
||||
|
||||
|
||||
class MultimodalEmbedder:
|
||||
def __init__(self, params: dict):
|
||||
pipeline, source = load_pipeline(params)
|
||||
self.pipeline = pipeline
|
||||
logging.info(f'Multimodal: loaded pipeline {self.pipeline.name()} from pipelines/{source} ({self.pipeline.__class__.__name__})')
|
||||
|
||||
def _split_prompt(self, prompt: str, load_images: bool = False) -> List[PromptPart]:
|
||||
"""Splits a prompt into a list of `PromptParts` to separate image data from text.
|
||||
It will also append `image_start` and `image_end` before and after the image, and optionally parse and load the images,
|
||||
if `load_images` is `True`.
|
||||
"""
|
||||
parts: List[PromptPart] = []
|
||||
curr = 0
|
||||
while True:
|
||||
match = re.search(r'<img src="data:image/jpeg;base64,([A-Za-z0-9+/=]+)">', prompt[curr:])
|
||||
if match is None:
|
||||
# no more image tokens, append the rest of the prompt
|
||||
if curr > 0:
|
||||
# add image end token after last image
|
||||
parts.append(PromptPart(text=self.pipeline.image_end() + prompt[curr:]))
|
||||
else:
|
||||
parts.append(PromptPart(text=prompt))
|
||||
break
|
||||
# found an image, append image start token to the text
|
||||
if match.start() > 0:
|
||||
parts.append(PromptPart(text=prompt[curr:curr+match.start()]+self.pipeline.image_start()))
|
||||
else:
|
||||
parts.append(PromptPart(text=self.pipeline.image_start()))
|
||||
# append the image
|
||||
parts.append(PromptPart(
|
||||
text=match.group(0),
|
||||
image=Image.open(BytesIO(base64.b64decode(match.group(1)))) if load_images else None,
|
||||
is_image=True
|
||||
))
|
||||
curr += match.end()
|
||||
return parts
|
||||
|
||||
def _len_in_tokens_prompt_parts(self, parts: List[PromptPart]) -> int:
|
||||
"""Total length in tokens of all `parts`"""
|
||||
tokens = 0
|
||||
for part in parts:
|
||||
if part.is_image:
|
||||
tokens += self.pipeline.num_image_embeds()
|
||||
elif part.input_ids is not None:
|
||||
tokens += len(part.input_ids)
|
||||
else:
|
||||
tokens += len(encode(part.text)[0])
|
||||
return tokens
|
||||
|
||||
def len_in_tokens(self, prompt: str) -> int:
|
||||
"""Total length in tokens for a given text `prompt`"""
|
||||
parts = self._split_prompt(prompt, False)
|
||||
return self._len_in_tokens_prompt_parts(parts)
|
||||
|
||||
def _encode_single_text(self, part: PromptPart, add_bos_token: bool) -> PromptPart:
|
||||
"""Encode a single prompt `part` to `input_ids`. Returns a `PromptPart`"""
|
||||
if part.is_image:
|
||||
placeholders = torch.ones((self.pipeline.num_image_embeds())) * self.pipeline.placeholder_token_id()
|
||||
part.input_ids = placeholders.to(shared.model.device, dtype=torch.int64)
|
||||
else:
|
||||
part.input_ids = encode(part.text, add_bos_token=add_bos_token)[0].to(shared.model.device, dtype=torch.int64)
|
||||
return part
|
||||
|
||||
@staticmethod
|
||||
def _num_images(parts: List[PromptPart]) -> int:
|
||||
count = 0
|
||||
for part in parts:
|
||||
if part.is_image:
|
||||
count += 1
|
||||
return count
|
||||
|
||||
def _encode_text(self, state, parts: List[PromptPart]) -> List[PromptPart]:
|
||||
"""Encode text to token_ids, also truncate the prompt, if necessary.
|
||||
|
||||
The chat/instruct mode should make prompts that fit in get_max_prompt_length, but if max_new_tokens are set
|
||||
such that the context + min_rows don't fit, we can get a prompt which is too long.
|
||||
We can't truncate image embeddings, as it leads to broken generation, so remove the images instead and warn the user
|
||||
"""
|
||||
encoded: List[PromptPart] = []
|
||||
for i, part in enumerate(parts):
|
||||
encoded.append(self._encode_single_text(part, i==0 and state['add_bos_token']))
|
||||
|
||||
# truncation:
|
||||
max_len = get_max_prompt_length(state)
|
||||
removed_images = 0
|
||||
|
||||
# 1. remove entire text/image blocks
|
||||
while self._len_in_tokens_prompt_parts(encoded[1:]) > max_len:
|
||||
if encoded[0].is_image:
|
||||
removed_images += 1
|
||||
encoded = encoded[1:]
|
||||
|
||||
# 2. check if the last prompt part doesn't need to get truncated
|
||||
if self._len_in_tokens_prompt_parts(encoded) > max_len:
|
||||
if encoded[0].is_image:
|
||||
# don't truncate image embeddings, just remove the image, otherwise generation will be broken
|
||||
removed_images += 1
|
||||
encoded = encoded[1:]
|
||||
elif len(encoded) > 1 and encoded[0].text.endswith(self.pipeline.image_start()):
|
||||
# see if we can keep image_start token
|
||||
len_image_start = len(encode(self.pipeline.image_start(), add_bos_token=state['add_bos_token'])[0])
|
||||
if self._len_in_tokens_prompt_parts(encoded[1:]) + len_image_start > max_len:
|
||||
# we can't -> remove this text, and the image
|
||||
encoded = encoded[2:]
|
||||
removed_images += 1
|
||||
else:
|
||||
# we can -> just truncate the text
|
||||
trunc_len = self._len_in_tokens_prompt_parts(encoded) - max_len
|
||||
encoded[0].input_ids = encoded[0].input_ids[trunc_len:]
|
||||
elif len(encoded) > 0:
|
||||
# only one text left, truncate it normally
|
||||
trunc_len = self._len_in_tokens_prompt_parts(encoded) - max_len
|
||||
encoded[0].input_ids = encoded[0].input_ids[trunc_len:]
|
||||
|
||||
# notify user if we truncated an image
|
||||
if removed_images > 0:
|
||||
logging.warning(f"Multimodal: removed {removed_images} image(s) from prompt. Try decreasing max_new_tokens if generation is broken")
|
||||
|
||||
return encoded
|
||||
|
||||
def _embed(self, parts: List[PromptPart]) -> List[PromptPart]:
|
||||
# batch images
|
||||
image_indicies = [i for i, part in enumerate(parts) if part.is_image]
|
||||
embedded = self.pipeline.embed_images([parts[i].image for i in image_indicies])
|
||||
for i, embeds in zip(image_indicies, embedded):
|
||||
parts[i].embedding = embeds
|
||||
# embed text
|
||||
for (i, part) in enumerate(parts):
|
||||
if not part.is_image:
|
||||
parts[i].embedding = self.pipeline.embed_tokens(part.input_ids)
|
||||
return parts
|
||||
|
||||
def _remove_old_images(self, parts: List[PromptPart], params: dict) -> List[PromptPart]:
|
||||
if params['add_all_images_to_prompt']:
|
||||
return parts
|
||||
already_added = False
|
||||
for i, part in reversed(list(enumerate(parts))):
|
||||
if part.is_image:
|
||||
if already_added:
|
||||
parts[i].embedding = self.pipeline.placeholder_embeddings()
|
||||
else:
|
||||
already_added = True
|
||||
return parts
|
||||
|
||||
def forward(self, prompt: str, state: Any, params: dict):
|
||||
prompt_parts = self._split_prompt(prompt, True)
|
||||
prompt_parts = self._encode_text(state, prompt_parts)
|
||||
prompt_parts = self._embed(prompt_parts)
|
||||
prompt_parts = self._remove_old_images(prompt_parts, params)
|
||||
embeds = tuple(part.embedding for part in prompt_parts)
|
||||
ids = tuple(part.input_ids for part in prompt_parts)
|
||||
input_embeds = torch.cat(embeds, dim=0)
|
||||
input_ids = torch.cat(ids, dim=0)
|
||||
return prompt, input_ids, input_embeds, self._num_images(prompt_parts)
|
52
extensions/multimodal/pipeline_loader.py
Normal file
52
extensions/multimodal/pipeline_loader.py
Normal file
|
@ -0,0 +1,52 @@
|
|||
import logging
|
||||
import traceback
|
||||
from importlib import import_module
|
||||
from pathlib import Path
|
||||
from typing import Tuple
|
||||
|
||||
from extensions.multimodal.abstract_pipeline import AbstractMultimodalPipeline
|
||||
from modules import shared
|
||||
|
||||
|
||||
def _get_available_pipeline_modules():
|
||||
pipeline_path = Path(__file__).parent / 'pipelines'
|
||||
modules = [p for p in pipeline_path.iterdir() if p.is_dir()]
|
||||
return [m.name for m in modules if (m / 'pipelines.py').exists()]
|
||||
|
||||
|
||||
def load_pipeline(params: dict) -> Tuple[AbstractMultimodalPipeline, str]:
|
||||
pipeline_modules = {}
|
||||
available_pipeline_modules = _get_available_pipeline_modules()
|
||||
for name in available_pipeline_modules:
|
||||
try:
|
||||
pipeline_modules[name] = import_module(f'extensions.multimodal.pipelines.{name}.pipelines')
|
||||
except:
|
||||
logging.warning(f'Failed to get multimodal pipelines from {name}')
|
||||
logging.warning(traceback.format_exc())
|
||||
|
||||
if shared.args.multimodal_pipeline is not None:
|
||||
for k in pipeline_modules:
|
||||
if hasattr(pipeline_modules[k], 'get_pipeline'):
|
||||
pipeline = getattr(pipeline_modules[k], 'get_pipeline')(shared.args.multimodal_pipeline, params)
|
||||
if pipeline is not None:
|
||||
return (pipeline, k)
|
||||
else:
|
||||
model_name = shared.args.model.lower()
|
||||
for k in pipeline_modules:
|
||||
if hasattr(pipeline_modules[k], 'get_pipeline_from_model_name'):
|
||||
pipeline = getattr(pipeline_modules[k], 'get_pipeline_from_model_name')(model_name, params)
|
||||
if pipeline is not None:
|
||||
return (pipeline, k)
|
||||
|
||||
available = []
|
||||
for k in pipeline_modules:
|
||||
if hasattr(pipeline_modules[k], 'available_pipelines'):
|
||||
pipelines = getattr(pipeline_modules[k], 'available_pipelines')
|
||||
available += pipelines
|
||||
|
||||
if shared.args.multimodal_pipeline is not None:
|
||||
log = f'Multimodal - ERROR: Failed to load multimodal pipeline "{shared.args.multimodal_pipeline}", available pipelines are: {available}.'
|
||||
else:
|
||||
log = f'Multimodal - ERROR: Failed to determine multimodal pipeline for model {shared.args.model}, please select one manually using --multimodal-pipeline [PIPELINE]. Available pipelines are: {available}.'
|
||||
logging.critical(f'{log} Please specify a correct pipeline, or disable the extension')
|
||||
raise RuntimeError(f'{log} Please specify a correct pipeline, or disable the extension')
|
9
extensions/multimodal/pipelines/llava/README.md
Normal file
9
extensions/multimodal/pipelines/llava/README.md
Normal file
|
@ -0,0 +1,9 @@
|
|||
## LLaVA pipeline
|
||||
|
||||
This module provides 2 pipelines:
|
||||
- `llava-7b` - for use with LLaVA v0 7B model (finetuned LLaMa 7B)
|
||||
- `llava-13b` - for use with LLaVA v0 13B model (finetuned LLaMa 13B)
|
||||
|
||||
[LLaVA](https://github.com/haotian-liu/LLaVA) uses CLIP `openai/clip-vit-large-patch14` as the vision model, and then a single linear layer. For 13B the projector weights are in `liuhaotian/LLaVA-13b-delta-v0`, and for 7B they are in `liuhaotian/LLaVA-7b-delta-v0`.
|
||||
|
||||
The supported parameter combinations for both the vision model, and the projector are: CUDA/32bit, CUDA/16bit, CPU/32bit
|
139
extensions/multimodal/pipelines/llava/llava.py
Normal file
139
extensions/multimodal/pipelines/llava/llava.py
Normal file
|
@ -0,0 +1,139 @@
|
|||
import logging
|
||||
import time
|
||||
from abc import abstractmethod
|
||||
from typing import List, Tuple
|
||||
|
||||
import torch
|
||||
from extensions.multimodal.abstract_pipeline import AbstractMultimodalPipeline
|
||||
from huggingface_hub import hf_hub_download
|
||||
from modules import shared
|
||||
from modules.text_generation import encode
|
||||
from PIL import Image
|
||||
from transformers import CLIPImageProcessor, CLIPVisionModel
|
||||
|
||||
|
||||
class LLaVA_v0_Pipeline(AbstractMultimodalPipeline):
|
||||
CLIP_REPO = "openai/clip-vit-large-patch14"
|
||||
|
||||
def __init__(self, params: dict) -> None:
|
||||
super().__init__()
|
||||
self.clip_device = self._get_device("vision_device", params)
|
||||
self.clip_dtype = self._get_dtype("vision_bits", params)
|
||||
self.projector_device = self._get_device("projector_device", params)
|
||||
self.projector_dtype = self._get_dtype("projector_bits", params)
|
||||
self.image_processor, self.vision_tower, self.mm_projector = self._load_models()
|
||||
|
||||
def _load_models(self):
|
||||
start_ts = time.time()
|
||||
|
||||
logging.info(f"LLaVA - Loading CLIP from {LLaVA_v0_Pipeline.CLIP_REPO} as {self.clip_dtype} on {self.clip_device}...")
|
||||
image_processor = CLIPImageProcessor.from_pretrained(LLaVA_v0_Pipeline.CLIP_REPO, torch_dtype=self.clip_dtype)
|
||||
vision_tower = CLIPVisionModel.from_pretrained(LLaVA_v0_Pipeline.CLIP_REPO, torch_dtype=self.clip_dtype).to(self.clip_device)
|
||||
|
||||
logging.info(f"LLaVA - Loading projector from {self.llava_projector_repo()} as {self.projector_dtype} on {self.projector_device}...")
|
||||
projector_path = hf_hub_download(self.llava_projector_repo(), self.llava_projector_filename())
|
||||
mm_projector = torch.nn.Linear(*self.llava_projector_shape())
|
||||
projector_data = torch.load(projector_path)
|
||||
mm_projector.weight = torch.nn.Parameter(projector_data['model.mm_projector.weight'].to(dtype=self.projector_dtype), False)
|
||||
mm_projector.bias = torch.nn.Parameter(projector_data['model.mm_projector.bias'].to(dtype=self.projector_dtype), False)
|
||||
mm_projector = mm_projector.to(self.projector_device)
|
||||
|
||||
logging.info(f"LLaVA supporting models loaded, took {time.time() - start_ts:.2f} seconds")
|
||||
return image_processor, vision_tower, mm_projector
|
||||
|
||||
@staticmethod
|
||||
def image_start() -> str:
|
||||
return "<im_start>"
|
||||
|
||||
@staticmethod
|
||||
def image_end() -> str:
|
||||
return "<im_end>"
|
||||
|
||||
@staticmethod
|
||||
def num_image_embeds() -> int:
|
||||
return 256
|
||||
|
||||
@staticmethod
|
||||
def embed_tokens(input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return shared.model.model.embed_tokens(input_ids).to(shared.model.device, dtype=shared.model.dtype)
|
||||
|
||||
@staticmethod
|
||||
def placeholder_embeddings() -> torch.Tensor:
|
||||
return LLaVA_v0_Pipeline.embed_tokens(encode("<im_patch>"*256, add_bos_token=False)[0])
|
||||
|
||||
def embed_images(self, images: List[Image.Image]) -> torch.Tensor:
|
||||
images = self.image_processor(images, return_tensors='pt')['pixel_values']
|
||||
images = images.to(self.clip_device, dtype=self.clip_dtype)
|
||||
|
||||
with torch.no_grad():
|
||||
image_forward_outs = self.vision_tower(images, output_hidden_states=True)
|
||||
select_hidden_state_layer = -2
|
||||
select_hidden_state = image_forward_outs.hidden_states[select_hidden_state_layer]
|
||||
image_features = select_hidden_state[:, 1:].to(self.projector_device, dtype=self.projector_dtype)
|
||||
image_features = self.mm_projector(image_features)
|
||||
return image_features.to(shared.model.device, dtype=shared.model.dtype)
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def llava_projector_repo() -> str:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def llava_projector_filename() -> str:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def llava_projector_shape() -> Tuple[int, int]:
|
||||
pass
|
||||
|
||||
|
||||
class LLaVA_v0_13B_Pipeline(LLaVA_v0_Pipeline):
|
||||
def __init__(self, params: dict) -> None:
|
||||
super().__init__(params)
|
||||
|
||||
@staticmethod
|
||||
def name() -> str:
|
||||
return "llava-13b"
|
||||
|
||||
@staticmethod
|
||||
def placeholder_token_id() -> int:
|
||||
return 32000
|
||||
|
||||
@staticmethod
|
||||
def llava_projector_shape() -> Tuple[int, int]:
|
||||
return (1024, 5120)
|
||||
|
||||
@staticmethod
|
||||
def llava_projector_filename() -> str:
|
||||
return "mm_projector.bin"
|
||||
|
||||
@staticmethod
|
||||
def llava_projector_repo() -> str:
|
||||
return "liuhaotian/LLaVA-13b-delta-v0"
|
||||
|
||||
|
||||
class LLaVA_v0_7B_Pipeline(LLaVA_v0_Pipeline):
|
||||
def __init__(self, params: dict) -> None:
|
||||
super().__init__(params)
|
||||
|
||||
@staticmethod
|
||||
def name() -> str:
|
||||
return "llava-7b"
|
||||
|
||||
@staticmethod
|
||||
def placeholder_token_id() -> int:
|
||||
return 32001
|
||||
|
||||
@staticmethod
|
||||
def llava_projector_shape() -> Tuple[int, int]:
|
||||
return (1024, 4096)
|
||||
|
||||
@staticmethod
|
||||
def llava_projector_filename() -> str:
|
||||
return "mm_projector.bin"
|
||||
|
||||
@staticmethod
|
||||
def llava_projector_repo() -> str:
|
||||
return "liuhaotian/LLaVA-7b-delta-v0"
|
27
extensions/multimodal/pipelines/llava/pipelines.py
Normal file
27
extensions/multimodal/pipelines/llava/pipelines.py
Normal file
|
@ -0,0 +1,27 @@
|
|||
from typing import Optional
|
||||
|
||||
from extensions.multimodal.abstract_pipeline import AbstractMultimodalPipeline
|
||||
|
||||
available_pipelines = ['llava-7b', 'llava-13b']
|
||||
|
||||
|
||||
def get_pipeline(name: str, params: dict) -> Optional[AbstractMultimodalPipeline]:
|
||||
if name == 'llava-7b':
|
||||
from .llava import LLaVA_v0_7B_Pipeline
|
||||
return LLaVA_v0_7B_Pipeline(params)
|
||||
if name == 'llava-13b':
|
||||
from .llava import LLaVA_v0_13B_Pipeline
|
||||
return LLaVA_v0_13B_Pipeline(params)
|
||||
return None
|
||||
|
||||
|
||||
def get_pipeline_from_model_name(model_name: str, params: dict) -> Optional[AbstractMultimodalPipeline]:
|
||||
if 'llava' not in model_name.lower():
|
||||
return None
|
||||
if '7b' in model_name.lower():
|
||||
from .llava import LLaVA_v0_7B_Pipeline
|
||||
return LLaVA_v0_7B_Pipeline(params)
|
||||
if '13b' in model_name.lower():
|
||||
from .llava import LLaVA_v0_13B_Pipeline
|
||||
return LLaVA_v0_13B_Pipeline(params)
|
||||
return None
|
103
extensions/multimodal/script.py
Normal file
103
extensions/multimodal/script.py
Normal file
|
@ -0,0 +1,103 @@
|
|||
import base64
|
||||
import logging
|
||||
import re
|
||||
import time
|
||||
from functools import partial
|
||||
from io import BytesIO
|
||||
|
||||
import gradio as gr
|
||||
import torch
|
||||
from extensions.multimodal.multimodal_embedder import MultimodalEmbedder
|
||||
from modules import shared
|
||||
|
||||
params = {
|
||||
"add_all_images_to_prompt": False,
|
||||
# device to run vision encoder on
|
||||
"vision_device": None,
|
||||
# bits to load vision encoder in, either 16 or 32
|
||||
"vision_bits": 32,
|
||||
# device to run multimodal projector on
|
||||
"projector_device": None,
|
||||
# multimodal projector bits, either 32 or 16
|
||||
"projector_bits": 32
|
||||
}
|
||||
|
||||
|
||||
# If 'state' is True, will hijack the next chat generation
|
||||
input_hijack = {
|
||||
'state': False,
|
||||
'value': ["", ""]
|
||||
}
|
||||
|
||||
|
||||
# initialized in ui, so that params are loaded from settings
|
||||
multimodal_embedder: MultimodalEmbedder = None
|
||||
|
||||
|
||||
def add_chat_picture(picture, text, visible_text):
|
||||
# resize the image, so that shortest edge is at least 224 (size for CLIP), and at most 300 (to keep history manageable)
|
||||
max_hw, min_hw = max(picture.size), min(picture.size)
|
||||
aspect_ratio = max_hw / min_hw
|
||||
shortest_edge = int(max(300 / aspect_ratio, 224))
|
||||
longest_edge = int(shortest_edge * aspect_ratio)
|
||||
w = shortest_edge if picture.width < picture.height else longest_edge
|
||||
h = shortest_edge if picture.width >= picture.height else longest_edge
|
||||
picture = picture.resize((w,h))
|
||||
|
||||
buffer = BytesIO()
|
||||
picture.save(buffer, format="JPEG")
|
||||
img_str = base64.b64encode(buffer.getvalue()).decode('utf-8')
|
||||
image = f'<img src="data:image/jpeg;base64,{img_str}">'
|
||||
|
||||
|
||||
if '<image>' in text:
|
||||
text = text.replace('<image>', image)
|
||||
else:
|
||||
text = text + '\n' + image
|
||||
|
||||
if visible_text == '' or visible_text is None:
|
||||
visible_text = text
|
||||
elif '<image>' in visible_text:
|
||||
visible_text = visible_text.replace('<image>', image)
|
||||
else:
|
||||
visible_text = visible_text + '\n' + image
|
||||
|
||||
return text, visible_text
|
||||
|
||||
|
||||
def custom_tokenized_length(prompt):
|
||||
return multimodal_embedder.len_in_tokens(prompt)
|
||||
|
||||
|
||||
def tokenizer_modifier(state, prompt, input_ids, input_embeds):
|
||||
global params
|
||||
start_ts = time.time()
|
||||
image_match = re.search(r'<img src="data:image/jpeg;base64,[A-Za-z0-9+/=]+">', prompt)
|
||||
|
||||
if image_match is None:
|
||||
return prompt, input_ids, input_embeds
|
||||
|
||||
prompt, input_ids, input_embeds, total_embedded = multimodal_embedder.forward(prompt, state, params)
|
||||
logging.info(f'Embedded {total_embedded} image(s) in {time.time()-start_ts:.2f}s')
|
||||
return (prompt,
|
||||
input_ids.unsqueeze(0).to(shared.model.device, dtype=torch.int64),
|
||||
input_embeds.unsqueeze(0).to(shared.model.device, dtype=shared.model.dtype))
|
||||
|
||||
|
||||
def ui():
|
||||
global multimodal_embedder
|
||||
multimodal_embedder = MultimodalEmbedder(params)
|
||||
with gr.Column():
|
||||
picture_select = gr.Image(label='Send a picture', type='pil')
|
||||
# The models don't seem to deal well with multiple images
|
||||
single_image_checkbox = gr.Checkbox(False, label='Embed all images, not only the last one')
|
||||
# Prepare the input hijack
|
||||
picture_select.upload(
|
||||
lambda picture: input_hijack.update({"state": True, "value": partial(add_chat_picture, picture)}),
|
||||
[picture_select],
|
||||
None
|
||||
)
|
||||
picture_select.clear(lambda: input_hijack.update({"state": False, "value": ["",""]}), None, None)
|
||||
single_image_checkbox.change(lambda x: params.update({"add_all_images_to_prompt": x}), single_image_checkbox, None)
|
||||
shared.gradio['Generate'].click(lambda: None, None, picture_select)
|
||||
shared.gradio['textbox'].submit(lambda: None, None, picture_select)
|
Loading…
Add table
Add a link
Reference in a new issue