Generalize multimodality (llava/minigpt4 7b and 13b now supported) (#1741)

This commit is contained in:
Wojtab 2023-05-10 01:18:02 +02:00 committed by GitHub
parent a2b25322f0
commit e9e75a9ec7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
22 changed files with 812 additions and 371 deletions

View file

@ -7,6 +7,7 @@ import gradio as gr
import extensions
import modules.shared as shared
state = {}
available_extensions = []
setup_called = set()
@ -73,15 +74,12 @@ def _apply_input_hijack(text, visible_text):
return text, visible_text
# custom_generate_chat_prompt handling
# custom_generate_chat_prompt handling - currently only the first one will work
def _apply_custom_generate_chat_prompt(text, state, **kwargs):
custom_generate_chat_prompt = None
for extension, _ in iterator():
if custom_generate_chat_prompt is None and hasattr(extension, 'custom_generate_chat_prompt'):
custom_generate_chat_prompt = extension.custom_generate_chat_prompt
if custom_generate_chat_prompt is not None:
return custom_generate_chat_prompt(text, state, **kwargs)
if hasattr(extension, 'custom_generate_chat_prompt'):
return custom_generate_chat_prompt(text, state, **kwargs)
return None
@ -95,16 +93,26 @@ def _apply_state_modifier_extensions(state):
return state
# Extension functions that override the default tokenizer output
# Extension functions that override the default tokenizer output - currently only the first one will work
def _apply_tokenizer_extensions(function_name, state, prompt, input_ids, input_embeds):
for extension, _ in iterator():
if hasattr(extension, function_name):
prompt, input_ids, input_embeds = getattr(extension, function_name)(state, prompt, input_ids, input_embeds)
return getattr(extension, function_name)(state, prompt, input_ids, input_embeds)
return prompt, input_ids, input_embeds
# Custom generate reply handling
# Get prompt length in tokens after applying extension functions which override the default tokenizer output
# currently only the first one will work
def _apply_custom_tokenized_length(prompt):
for extension, _ in iterator():
if hasattr(extension, 'custom_tokenized_length'):
return getattr(extension, 'custom_tokenized_length')(prompt)
return None
# Custom generate reply handling - currently only the first one will work
def _apply_custom_generate_reply():
for extension, _ in iterator():
if hasattr(extension, 'custom_generate_reply'):
@ -121,7 +129,8 @@ EXTENSION_MAP = {
"tokenizer": partial(_apply_tokenizer_extensions, "tokenizer_modifier"),
"input_hijack": _apply_input_hijack,
"custom_generate_chat_prompt": _apply_custom_generate_chat_prompt,
"custom_generate_reply": _apply_custom_generate_reply
"custom_generate_reply": _apply_custom_generate_reply,
"tokenized_length": _apply_custom_tokenized_length
}