Generalize multimodality (llava/minigpt4 7b and 13b now supported) (#1741)
This commit is contained in:
parent
a2b25322f0
commit
e9e75a9ec7
22 changed files with 812 additions and 371 deletions
|
@ -14,7 +14,7 @@ from PIL import Image
|
|||
import modules.shared as shared
|
||||
from modules.extensions import apply_extensions
|
||||
from modules.html_generator import chat_html_wrapper, make_thumbnail
|
||||
from modules.text_generation import (encode, generate_reply,
|
||||
from modules.text_generation import (generate_reply, get_encoded_length,
|
||||
get_max_prompt_length)
|
||||
|
||||
|
||||
|
@ -67,7 +67,7 @@ def generate_chat_prompt(user_input, state, **kwargs):
|
|||
|
||||
# Building the prompt
|
||||
i = len(history) - 1
|
||||
while i >= 0 and len(encode(''.join(rows))[0]) < max_length:
|
||||
while i >= 0 and get_encoded_length(''.join(rows)) < max_length:
|
||||
if _continue and i == len(history) - 1:
|
||||
rows.insert(1, bot_turn_stripped + history[i][1].strip())
|
||||
else:
|
||||
|
@ -90,7 +90,7 @@ def generate_chat_prompt(user_input, state, **kwargs):
|
|||
# Adding the Character prefix
|
||||
rows.append(apply_extensions("bot_prefix", bot_turn_stripped.rstrip(' ')))
|
||||
|
||||
while len(rows) > min_rows and len(encode(''.join(rows))[0]) >= max_length:
|
||||
while len(rows) > min_rows and get_encoded_length(''.join(rows)) >= max_length:
|
||||
rows.pop(1)
|
||||
|
||||
prompt = ''.join(rows)
|
||||
|
|
|
@ -7,6 +7,7 @@ import gradio as gr
|
|||
import extensions
|
||||
import modules.shared as shared
|
||||
|
||||
|
||||
state = {}
|
||||
available_extensions = []
|
||||
setup_called = set()
|
||||
|
@ -73,15 +74,12 @@ def _apply_input_hijack(text, visible_text):
|
|||
return text, visible_text
|
||||
|
||||
|
||||
# custom_generate_chat_prompt handling
|
||||
# custom_generate_chat_prompt handling - currently only the first one will work
|
||||
def _apply_custom_generate_chat_prompt(text, state, **kwargs):
    """Build the chat prompt using the first extension that offers a builder.

    Walks the extensions yielded by ``iterator()`` and delegates to the first
    one that exposes a ``custom_generate_chat_prompt`` attribute; any later
    extensions providing the same hook are ignored.

    Args:
        text: Passed through unchanged as the hook's first argument.
        state: Passed through unchanged as the hook's second argument.
        **kwargs: Forwarded verbatim to the hook.

    Returns:
        Whatever the extension's hook returns, or ``None`` when no extension
        defines ``custom_generate_chat_prompt``.
    """
    for extension, _ in iterator():
        if hasattr(extension, 'custom_generate_chat_prompt'):
            # The hook must be looked up on the extension: this function no
            # longer keeps a local variable of that name, so calling the bare
            # name would raise NameError.
            return extension.custom_generate_chat_prompt(text, state, **kwargs)

    return None
|
||||
|
||||
|
@ -95,16 +93,26 @@ def _apply_state_modifier_extensions(state):
|
|||
return state
|
||||
|
||||
|
||||
# Extension functions that override the default tokenizer output
|
||||
# Extension functions that override the default tokenizer output - currently only the first one will work
|
||||
def _apply_tokenizer_extensions(function_name, state, prompt, input_ids, input_embeds):
    """Let the first extension defining ``function_name`` override tokenizer output.

    Only the first extension (in ``iterator()`` order) that has an attribute
    named ``function_name`` is consulted; its return value is handed back
    as-is. When no extension defines the hook, the inputs are returned
    untouched as a ``(prompt, input_ids, input_embeds)`` tuple.
    """
    for extension, _ in iterator():
        if not hasattr(extension, function_name):
            continue
        modifier = getattr(extension, function_name)
        return modifier(state, prompt, input_ids, input_embeds)

    return prompt, input_ids, input_embeds
|
||||
|
||||
|
||||
# Custom generate reply handling
|
||||
# Get prompt length in tokens after applying extension functions which override the default tokenizer output
|
||||
# currently only the first one will work
|
||||
def _apply_custom_tokenized_length(prompt):
    """Ask the first extension with ``custom_tokenized_length`` for the prompt length.

    Returns that extension's result for ``prompt``, or ``None`` when no
    extension yielded by ``iterator()`` defines the hook.
    """
    for extension, _ in iterator():
        if not hasattr(extension, 'custom_tokenized_length'):
            continue
        return extension.custom_tokenized_length(prompt)

    return None
|
||||
|
||||
|
||||
# Custom generate reply handling - currently only the first one will work
|
||||
def _apply_custom_generate_reply():
|
||||
for extension, _ in iterator():
|
||||
if hasattr(extension, 'custom_generate_reply'):
|
||||
|
@ -121,7 +129,8 @@ EXTENSION_MAP = {
|
|||
"tokenizer": partial(_apply_tokenizer_extensions, "tokenizer_modifier"),
|
||||
"input_hijack": _apply_input_hijack,
|
||||
"custom_generate_chat_prompt": _apply_custom_generate_chat_prompt,
|
||||
"custom_generate_reply": _apply_custom_generate_reply
|
||||
"custom_generate_reply": _apply_custom_generate_reply,
|
||||
"tokenized_length": _apply_custom_tokenized_length
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -166,6 +166,8 @@ parser.add_argument("--gradio-auth-path", type=str, help='Set the gradio authent
|
|||
parser.add_argument('--api', action='store_true', help='Enable the API extension.')
|
||||
parser.add_argument('--public-api', action='store_true', help='Create a public URL for the API using Cloudfare.')
|
||||
|
||||
# Multimodal
|
||||
parser.add_argument('--multimodal-pipeline', type=str, default=None, help='The multimodal pipeline to use. Examples: llava-7b, llava-13b.')
|
||||
|
||||
args = parser.parse_args()
|
||||
args_defaults = parser.parse_args([])
|
||||
|
@ -183,12 +185,21 @@ if args.trust_remote_code:
|
|||
if args.share:
|
||||
logging.warning("The gradio \"share link\" feature downloads a proprietary and unaudited blob to create a reverse tunnel. This is potentially dangerous.")
|
||||
|
||||
|
||||
def add_extension(name):
    """Ensure the extension ``name`` is present in ``args.extensions``.

    Creates the list when ``args.extensions`` is ``None``; otherwise appends
    ``name`` only if it is not already listed, so repeated calls do not
    produce duplicates.
    """
    if args.extensions is None:
        args.extensions = [name]
    # Membership must be tested against ``name`` itself; checking a fixed
    # 'api' string would skip other extensions whenever 'api' is enabled and
    # append duplicates whenever it is not.
    elif name not in args.extensions:
        args.extensions.append(name)
|
||||
|
||||
|
||||
# Activating the API extension
|
||||
if args.api or args.public_api:
|
||||
if args.extensions is None:
|
||||
args.extensions = ['api']
|
||||
elif 'api' not in args.extensions:
|
||||
args.extensions.append('api')
|
||||
add_extension('api')
|
||||
|
||||
# Activating the multimodal extension
|
||||
if args.multimodal_pipeline is not None:
|
||||
add_extension('multimodal')
|
||||
|
||||
|
||||
def is_chat():
|
||||
|
|
|
@ -59,6 +59,14 @@ def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_lengt
|
|||
return input_ids.cuda()
|
||||
|
||||
|
||||
def get_encoded_length(prompt):
    """Return the length of ``prompt`` in tokens.

    The 'tokenized_length' extension hook is consulted first; a non-``None``
    answer from it is returned directly. Otherwise the prompt is run through
    ``encode`` and the length of the first element of its result is used.
    """
    extension_length = apply_extensions('tokenized_length', prompt)
    if extension_length is None:
        return len(encode(prompt)[0])

    return extension_length
|
||||
|
||||
|
||||
def decode(output_ids, skip_special_tokens=True):
    """Decode ``output_ids`` with the shared tokenizer.

    ``skip_special_tokens`` is forwarded positionally as the second argument
    of ``shared.tokenizer.decode``.
    """
    tokenizer = shared.tokenizer
    return tokenizer.decode(output_ids, skip_special_tokens)
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue