Generalize multimodality (llava/minigpt4 7b and 13b now supported) (#1741)
This commit is contained in:
parent
a2b25322f0
commit
e9e75a9ec7
22 changed files with 812 additions and 371 deletions
103
extensions/multimodal/script.py
Normal file
103
extensions/multimodal/script.py
Normal file
|
@ -0,0 +1,103 @@
|
|||
import base64
|
||||
import logging
|
||||
import re
|
||||
import time
|
||||
from functools import partial
|
||||
from io import BytesIO
|
||||
|
||||
import gradio as gr
|
||||
import torch
|
||||
from extensions.multimodal.multimodal_embedder import MultimodalEmbedder
|
||||
from modules import shared
|
||||
|
||||
params = {
|
||||
"add_all_images_to_prompt": False,
|
||||
# device to run vision encoder on
|
||||
"vision_device": None,
|
||||
# bits to load vision encoder in, either 16 or 32
|
||||
"vision_bits": 32,
|
||||
# device to run multimodal projector on
|
||||
"projector_device": None,
|
||||
# multimodal projector bits, either 32 or 16
|
||||
"projector_bits": 32
|
||||
}
|
||||
|
||||
|
||||
# If 'state' is True, will hijack the next chat generation
|
||||
input_hijack = {
|
||||
'state': False,
|
||||
'value': ["", ""]
|
||||
}
|
||||
|
||||
|
||||
# initialized in ui, so that params are loaded from settings
|
||||
multimodal_embedder: MultimodalEmbedder = None
|
||||
|
||||
|
||||
def add_chat_picture(picture, text, visible_text):
|
||||
# resize the image, so that shortest edge is at least 224 (size for CLIP), and at most 300 (to keep history manageable)
|
||||
max_hw, min_hw = max(picture.size), min(picture.size)
|
||||
aspect_ratio = max_hw / min_hw
|
||||
shortest_edge = int(max(300 / aspect_ratio, 224))
|
||||
longest_edge = int(shortest_edge * aspect_ratio)
|
||||
w = shortest_edge if picture.width < picture.height else longest_edge
|
||||
h = shortest_edge if picture.width >= picture.height else longest_edge
|
||||
picture = picture.resize((w,h))
|
||||
|
||||
buffer = BytesIO()
|
||||
picture.save(buffer, format="JPEG")
|
||||
img_str = base64.b64encode(buffer.getvalue()).decode('utf-8')
|
||||
image = f'<img src="data:image/jpeg;base64,{img_str}">'
|
||||
|
||||
|
||||
if '<image>' in text:
|
||||
text = text.replace('<image>', image)
|
||||
else:
|
||||
text = text + '\n' + image
|
||||
|
||||
if visible_text == '' or visible_text is None:
|
||||
visible_text = text
|
||||
elif '<image>' in visible_text:
|
||||
visible_text = visible_text.replace('<image>', image)
|
||||
else:
|
||||
visible_text = visible_text + '\n' + image
|
||||
|
||||
return text, visible_text
|
||||
|
||||
|
||||
def custom_tokenized_length(prompt):
|
||||
return multimodal_embedder.len_in_tokens(prompt)
|
||||
|
||||
|
||||
def tokenizer_modifier(state, prompt, input_ids, input_embeds):
|
||||
global params
|
||||
start_ts = time.time()
|
||||
image_match = re.search(r'<img src="data:image/jpeg;base64,[A-Za-z0-9+/=]+">', prompt)
|
||||
|
||||
if image_match is None:
|
||||
return prompt, input_ids, input_embeds
|
||||
|
||||
prompt, input_ids, input_embeds, total_embedded = multimodal_embedder.forward(prompt, state, params)
|
||||
logging.info(f'Embedded {total_embedded} image(s) in {time.time()-start_ts:.2f}s')
|
||||
return (prompt,
|
||||
input_ids.unsqueeze(0).to(shared.model.device, dtype=torch.int64),
|
||||
input_embeds.unsqueeze(0).to(shared.model.device, dtype=shared.model.dtype))
|
||||
|
||||
|
||||
def ui():
|
||||
global multimodal_embedder
|
||||
multimodal_embedder = MultimodalEmbedder(params)
|
||||
with gr.Column():
|
||||
picture_select = gr.Image(label='Send a picture', type='pil')
|
||||
# The models don't seem to deal well with multiple images
|
||||
single_image_checkbox = gr.Checkbox(False, label='Embed all images, not only the last one')
|
||||
# Prepare the input hijack
|
||||
picture_select.upload(
|
||||
lambda picture: input_hijack.update({"state": True, "value": partial(add_chat_picture, picture)}),
|
||||
[picture_select],
|
||||
None
|
||||
)
|
||||
picture_select.clear(lambda: input_hijack.update({"state": False, "value": ["",""]}), None, None)
|
||||
single_image_checkbox.change(lambda x: params.update({"add_all_images_to_prompt": x}), single_image_checkbox, None)
|
||||
shared.gradio['Generate'].click(lambda: None, None, picture_select)
|
||||
shared.gradio['textbox'].submit(lambda: None, None, picture_select)
|
Loading…
Add table
Add a link
Reference in a new issue