Generalize multimodality (llava/minigpt4 7b and 13b now supported) (#1741)

This commit is contained in:
Wojtab 2023-05-10 01:18:02 +02:00 committed by GitHub
parent a2b25322f0
commit e9e75a9ec7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
22 changed files with 812 additions and 371 deletions

View file

@ -1,71 +0,0 @@
# LLaVA
## Description
Adds [LLaVA 13B](https://github.com/haotian-liu/LLaVA) multimodality support to text-generation-webui.
https://user-images.githubusercontent.com/3718215/233817203-69b57e77-0c55-4fd6-b742-3204bb13b8fc.mp4
## LLaVA-7B
7B version currently isn't supported. It will be supported if/when [more generic multimodality support](https://github.com/oobabooga/text-generation-webui/discussions/1687) gets implemented.
## Usage
To run this extension, download LLaVA weights, for example from [here](https://huggingface.co/wojtab/llava-13b-v0-4bit-128g) (note: it's a 4-bit [GPTQ quantization](https://github.com/oobabooga/text-generation-webui/tree/main/docs/GPTQ-models-(4-bit-mode).md), done on "old CUDA" branch), and then start server.py with `--extensions llava` argument.
Do note, that each image takes up 258 tokens, so adjust max_new_tokens to be at most 1700 (recommended value is between 200 to 500), so the images don't get truncated.
To send an image, just upload it to the extension field below chat, and send a prompt as always. The image will be added to the end of your message. If you wish to modify the placement, include a string `<image>` in your prompt.
Additionally, there is *Embed all images, not only the last one* checkbox. It modifies the image embeddings, by default (if it's unchecked), all but the most recent images have their embeddings empty, so they are not fed to the network. From initial testing, it seems as LLaVA considers the features in all images at the same time, so by default the extension skips previous images. If you want to include them anyway, just tick this checkbox.
## Extension config
This extension uses following parameters (from settings.json):
|Parameter|Description|
|---------|-----------|
|`llava-clip_bits`|Number of bits to load CLIP feature extractor in (either 32 or 16, default=32)|
|`llava-clip_device`|Torch device to run the extractor on, for example `cpu` or `cuda:0`, by default `cuda:0` if available|
|`llava-clip_repo`|Huggingface repository of CLIP model, `openai/clip-vit-large-patch14` by default. There should be no need to change it|
|`llava-projector_bits`|Number of bits to load CLIP->LLaMA feature projector in (either 32 or 16, default=32)|
|`llava-projector_device`|Torch device to run the CLIP->LLaMA feature projector on, for example `cpu` or `cuda:0`, by default `cuda:0` if available|
|`llava-projector_repo`|Huggingface repository of multimodal projector, `liuhaotian/LLaVA-13b-delta-v0` by default. There should be no need to change it|
|`llava-projector_filename`|The filename of multimodal projector weights, `mm_projector.bin` by default. There should be no need to change it|
|`llava-add_all_images_to_prompt`|Default value of "Embed all images, not only the last one" checkbox|
## Technical description
### Original LLaVA
The default LLaVA implementation uses modified `transformers` library, however this extension forgoes this requirement. The transformers are modified in LLaVA in such a way, that the entire LLaVA model gets loaded, and the inference now looks as follows:
```
images --> CLIP --> projector --> input embeddings for images --> |
| --> LLaMA
prompt -------------------------> input embeddings for text ----> |
```
The images are represented in the prompt by the following token IDs:
- 32000 - `<im_patch>` - placeholder token for embeddings from projector
- 32001 - `<im_start>` - token marking start of an image
- 32002 - `<im_end>` - token marking end of an image
By default, image will be represented as `<im_start><im_patch>*256<im_end>`. The input embeddings for an image are converted with a single linear layer of the projector, then they are placed instead of `<im_patch>` tokens.
The concatenated prompt then gets fed to fine-tuned LLaMA.
### In this extension
Using default transformers, they only load the LLaMA part of LLaVA, ignoring the added projector weights, and not loading CLIP. We then reconstruct the `images -> CLIP -> projector` pipeline ourselves, then concatenate the input embeddings, and feed it to LLaMA loaded by transformers. This allows us to use normal flow from webui to load this model, and just hijack the model input with additional features.
Splitting it to 3 separate models, allows us to configure each of them, and to move them to different devices(for example we can run CLIP+projector on CPU and LLaMA on GPU). Also, it enables us to use 4-bit GPTQ quantization for LLaVA, massively cutting down the VRAM requirement (it should be possible to fit on 12GB of VRAM with full context size by moving CLIP and projector to CPU).
### Usage through API
You can run the multimodal inference through API, by inputting the images to prompt. Images are embedded like so: `f'<img src="data:image/jpeg;base64,{img_str}">'`, where `img_str` is base-64 jpeg data. Python example:
```Python
import base64
import requests
CONTEXT = "You are LLaVA, a large language and vision assistant trained by UW Madison WAIV Lab. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language. Follow the instructions carefully and explain your answers in detail.\n### Human: \nHi!\n### Assistant: \nHi there! How can I help you today?\n"
with open('extreme_ironing.jpg', 'rb') as f:
img_str = base64.b64encode(f.read()).decode('utf-8')
prompt = CONTEXT + f'### Human: \nWhat is unusual about this image: \n<img src="data:image/jpeg;base64,{img_str}">\n### Assistant: \n'
print(requests.post('http://127.0.0.1:5000/api/v1/generate', json={'prompt': prompt, 'stopping_strings': ['\n###']}).json())
```
script output:
```Python
{'results': [{'text': "The unusual aspect of this image is that a man is standing on top of a yellow minivan while doing his laundry. He has set up a makeshift clothes line using the car's rooftop as an outdoor drying area. This scene is uncommon because people typically do their laundry indoors, in a dedicated space like a laundromat or a room in their home, rather than on top of a moving vehicle. Additionally, hanging clothes on the car could be potentially hazardous or illegal in some jurisdictions due to the risk of damaging the vehicle or causing accidents on the road.\n##"}]}
```

View file

@ -1,272 +1,6 @@
import base64
import re
import time
from dataclasses import dataclass
from functools import partial
from io import BytesIO
import gradio as gr
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
from transformers import CLIPImageProcessor, CLIPVisionModel
from modules import shared
from modules.extensions import apply_extensions
from modules.text_generation import encode, get_max_prompt_length
params = {
"add_all_images_to_prompt": False,
# device to run CLIP on
"clip_device": None,
# bits to load clip in either 32 or 16 (it doesn't support 8-bit)
"clip_bits": 32,
# clip repository
"clip_repo": "openai/clip-vit-large-patch14",
# device to run projector on
"projector_device": None,
# projector bits, either 32 or 16
"projector_bits": 32,
# projector repository
"projector_repo": "liuhaotian/LLaVA-13b-delta-v0",
# file with the projector weights
"projector_file": "mm_projector.bin"
}
# If 'state' is True, will hijack the next chat generation
input_hijack = {
'state': False,
'value': ["", ""]
}
# initialized in ui, so that params are loaded from settings
llava_embedder = None
@dataclass
class Token:
token: str
id: int
class LLaVAEmbedder:
IM_PATCH = Token("<im_patch>", 32000)
IM_START = Token("<im_start>", 32001)
IM_END = Token("<im_end>", 32002)
def __init__(self):
self.clip_device = self._get_device("clip_device")
self.clip_dtype = self._get_dtype("clip_bits")
self.projector_device = self._get_device("projector_device")
self.projector_dtype = self._get_dtype("projector_bits")
self.image_processor, self.vision_tower, self.mm_projector = self._load_models()
def _get_device(self, setting_name):
if params[setting_name] is None:
return torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
return torch.device(params[setting_name])
def _get_dtype(self, setting_name):
return torch.float32 if int(params[setting_name]) == 32 else torch.float16
def _load_models(self):
start_ts = time.time()
print(f"LLaVA - Loading CLIP from {params['clip_repo']} as {self.clip_dtype} on {self.clip_device}...")
image_processor = CLIPImageProcessor.from_pretrained(params["clip_repo"], torch_dtype=self.clip_dtype)
vision_tower = CLIPVisionModel.from_pretrained(params["clip_repo"], torch_dtype=self.clip_dtype).to(self.clip_device)
print(f"LLaVA - Loading projector from {params['projector_repo']} as {self.projector_dtype} on {self.projector_device}...")
projector_path = hf_hub_download(params["projector_repo"], params["projector_file"])
mm_projector = torch.nn.Linear(1024, 5120)
projector_data = torch.load(projector_path)
mm_projector.weight = torch.nn.Parameter(projector_data['model.mm_projector.weight'].to(dtype=self.projector_dtype), False)
mm_projector.bias = torch.nn.Parameter(projector_data['model.mm_projector.bias'].to(dtype=self.projector_dtype), False)
mm_projector = mm_projector.to(self.projector_device)
print(f"LLaVA supporting models loaded, took {time.time() - start_ts:.2f} seconds")
return image_processor, vision_tower, mm_projector
def _update_prompt(self, prompt, images):
for _ in images:
# replace the image token with the image patch token in the prompt (each occurrence)
replace_token = LLaVAEmbedder.IM_PATCH.token * 256
replace_token = LLaVAEmbedder.IM_START.token + replace_token + LLaVAEmbedder.IM_END.token
prompt = re.sub(r'<img src="data:image/jpeg;base64,([A-Za-z0-9+/=]+)">', replace_token, prompt, 1)
return prompt
def _extract_image_features(self, images):
images = self.image_processor(images, return_tensors='pt')['pixel_values']
images = images.to(self.clip_device, dtype=self.clip_dtype)
with torch.no_grad():
image_forward_outs = self.vision_tower(images, output_hidden_states=True)
select_hidden_state_layer = -2
select_hidden_state = image_forward_outs.hidden_states[select_hidden_state_layer]
image_features = select_hidden_state[:, 1:].to(self.projector_device, dtype=self.projector_dtype)
image_features = self.mm_projector(image_features)
return image_features
def forward(self, prompt, images, state):
prompt = self._update_prompt(prompt, images)
input_ids = encode(prompt, add_bos_token=state['add_bos_token'], truncation_length=get_max_prompt_length(state))[0]
input_embeds = shared.model.model.embed_tokens(input_ids).to(self.projector_device)
if input_ids[0] == LLaVAEmbedder.IM_PATCH.id:
# prompt got truncated in the middle of an image, remove the image data
im_end = torch.where(input_ids == LLaVAEmbedder.IM_END.id)[0][0]
input_ids = input_ids[im_end+1:]
input_embeds = input_embeds[im_end+1:]
leftover_images = torch.where(input_ids == LLaVAEmbedder.IM_START.id)[0].shape[0]
print(f"LLaVA - WARNING: removed {len(images) - leftover_images} image(s) from prompt. The generation might be broken, try decreasing max_new_tokens")
images = images[-leftover_images:]
if len(images) == 0:
return prompt, input_ids, input_embeds, 0
total_embedded = 0
image_features = self._extract_image_features(images).to(self.projector_device)
image_start_tokens = torch.where(input_ids == LLaVAEmbedder.IM_START.id)[0]
if not torch.any(input_ids == LLaVAEmbedder.IM_PATCH.id) or len(image_start_tokens) == 0:
# multimodal LLM, but the current prompt is not multimodal/truncated
return prompt, input_ids, input_embeds, total_embedded
cur_image_idx = 0
if not params['add_all_images_to_prompt']:
image_start_tokens = [image_start_tokens[-1]]
cur_image_idx = -1
for image_start_token_pos in image_start_tokens:
cur_image_features = image_features[cur_image_idx]
num_patches = cur_image_features.shape[0]
input_embeds = torch.cat((input_embeds[:image_start_token_pos+1], cur_image_features, input_embeds[image_start_token_pos + num_patches + 1:]), dim=0)
cur_image_idx += 1
total_embedded += 1
return prompt, input_ids, input_embeds, total_embedded
@staticmethod
def len_in_tokens(text):
images = re.findall(r'<img src="data:image/jpeg;base64,[A-Za-z0-9+/=]+">', text)
image_tokens = 0
for _ in images:
image_tokens += 258
return len(encode(re.sub(r'<img src="data:image/jpeg;base64,[A-Za-z0-9+/=]+">', '', text))[0]) + image_tokens
def add_chat_picture(picture, text, visible_text):
# resize the image, so that shortest edge is at least 224 (size for CLIP), and at most 300 (to keep history manageable)
max_hw, min_hw = max(picture.size), min(picture.size)
aspect_ratio = max_hw / min_hw
shortest_edge = int(max(300 / aspect_ratio, 224))
longest_edge = int(shortest_edge * aspect_ratio)
w = shortest_edge if picture.width < picture.height else longest_edge
h = shortest_edge if picture.width >= picture.height else longest_edge
picture = picture.resize((w,h))
buffer = BytesIO()
picture.save(buffer, format="JPEG")
img_str = base64.b64encode(buffer.getvalue()).decode('utf-8')
image = f'<img src="data:image/jpeg;base64,{img_str}">'
if '<image>' in text:
text = text.replace('<image>', image)
else:
text = text + '\n' + image
if visible_text == '' or visible_text is None:
visible_text = text
elif '<image>' in visible_text:
visible_text = visible_text.replace('<image>', image)
else:
visible_text = visible_text + '\n' + image
return text, visible_text
def custom_generate_chat_prompt(user_input, state, **kwargs):
impersonate = kwargs['impersonate'] if 'impersonate' in kwargs else False
_continue = kwargs['_continue'] if '_continue' in kwargs else False
also_return_rows = kwargs['also_return_rows'] if 'also_return_rows' in kwargs else False
rows = [f"{state['context'].strip()}\n"]
min_rows = 3
# Finding the maximum prompt size
chat_prompt_size = state['chat_prompt_size']
if shared.soft_prompt:
chat_prompt_size -= shared.soft_prompt_tensor.shape[1]
max_length = min(get_max_prompt_length(state), chat_prompt_size)
prefix1 = f"{state['name1']}: "
prefix2 = f"{state['name2']}: "
i = len(shared.history['internal']) - 1
while i >= 0 and LLaVAEmbedder.len_in_tokens(''.join(rows)) < max_length:
if _continue and i == len(shared.history['internal']) - 1:
rows.insert(1, f"{prefix2}{shared.history['internal'][i][1]}")
else:
rows.insert(1, f"{prefix2}{shared.history['internal'][i][1].strip()}\n")
string = shared.history['internal'][i][0]
if string != '':
rows.insert(1, f"{prefix1}{string.strip()}\n")
i -= 1
if impersonate:
min_rows = 2
rows.append(f"{prefix1}")
elif not _continue:
# Adding the user message
if len(user_input) > 0:
rows.append(f"{prefix1}{user_input}\n")
# Adding the Character prefix
rows.append(apply_extensions("bot_prefix", f"{prefix2}"))
while len(rows) > min_rows and LLaVAEmbedder.len_in_tokens(''.join(rows)) >= max_length:
rows.pop(1)
prompt = ''.join(rows)
if also_return_rows:
return prompt, rows
else:
return prompt
def tokenizer_modifier(state, prompt, input_ids, input_embeds):
global params
start_ts = time.time()
image_matches = re.finditer(r'<img src="data:image/jpeg;base64,([A-Za-z0-9+/=]+)">', prompt)
images = [Image.open(BytesIO(base64.b64decode(match.group(1)))) for match in image_matches]
if len(images) == 0:
return prompt, input_ids, input_embeds
prompt, input_ids, input_embeds, total_embedded = llava_embedder.forward(prompt, images, state)
print(f'LLaVA - Embedded {total_embedded} image(s) in {time.time()-start_ts:.2f}s')
return (prompt,
input_ids.unsqueeze(0).to(shared.model.device, dtype=torch.int64),
input_embeds.unsqueeze(0).to(shared.model.device, dtype=shared.model.dtype))
import logging
def ui():
global llava_embedder
llava_embedder = LLaVAEmbedder()
with gr.Column():
picture_select = gr.Image(label='Send a picture', type='pil')
# I found that it doesn't deal super well with multiple images, and demo ui had a bug where it included only the last image anyway
single_image_checkbox = gr.Checkbox(False, label='Embed all images, not only the last one')
# Prepare the input hijack
picture_select.upload(
lambda picture: input_hijack.update({"state": True, "value": partial(add_chat_picture, picture)}),
[picture_select],
None
)
picture_select.clear(lambda: input_hijack.update({"state": False, "value": ["",""]}), None, None)
single_image_checkbox.change(lambda x: params.update({"add_all_images_to_prompt": x}), single_image_checkbox, None)
shared.gradio['Generate'].click(lambda: None, None, picture_select)
shared.gradio['textbox'].submit(lambda: None, None, picture_select)
gr.Markdown("### This extension is deprecated, use \"multimodal\" extension instead")
logging.error("LLaVA extension is deprecated, use \"multimodal\" extension instead")