LLaVA support (#1487)
This commit is contained in:
parent
9197d3fec8
commit
12212cf6be
12 changed files with 426 additions and 42 deletions
|
@ -1,4 +1,5 @@
|
|||
import traceback
|
||||
from functools import partial
|
||||
|
||||
import gradio as gr
|
||||
|
||||
|
@ -39,17 +40,60 @@ def iterator():
|
|||
|
||||
|
||||
# Extension functions that map string -> string
|
||||
def apply_extensions(text, typ):
|
||||
def _apply_string_extensions(function_name, text):
|
||||
for extension, _ in iterator():
|
||||
if typ == "input" and hasattr(extension, "input_modifier"):
|
||||
text = extension.input_modifier(text)
|
||||
elif typ == "output" and hasattr(extension, "output_modifier"):
|
||||
text = extension.output_modifier(text)
|
||||
elif typ == "bot_prefix" and hasattr(extension, "bot_prefix_modifier"):
|
||||
text = extension.bot_prefix_modifier(text)
|
||||
if hasattr(extension, function_name):
|
||||
text = getattr(extension, function_name)(text)
|
||||
return text
|
||||
|
||||
|
||||
# Input hijack of extensions
|
||||
def _apply_input_hijack(text, visible_text):
|
||||
for extension, _ in iterator():
|
||||
if hasattr(extension, 'input_hijack') and extension.input_hijack['state']:
|
||||
extension.input_hijack['state'] = False
|
||||
if callable(extension.input_hijack['value']):
|
||||
text, visible_text = extension.input_hijack['value'](text, visible_text)
|
||||
else:
|
||||
text, visible_text = extension.input_hijack['value']
|
||||
return text, visible_text
|
||||
|
||||
|
||||
# custom_generate_chat_prompt handling
|
||||
def _apply_custom_generate_chat_prompt(text, state, **kwargs):
|
||||
custom_generate_chat_prompt = None
|
||||
for extension, _ in iterator():
|
||||
if custom_generate_chat_prompt is None and hasattr(extension, 'custom_generate_chat_prompt'):
|
||||
custom_generate_chat_prompt = extension.custom_generate_chat_prompt
|
||||
if custom_generate_chat_prompt is not None:
|
||||
return custom_generate_chat_prompt(text, state, **kwargs)
|
||||
return None
|
||||
|
||||
|
||||
# Extension functions that override the default tokenizer output
|
||||
def _apply_tokenizer_extensions(function_name, state, prompt, input_ids, input_embeds):
|
||||
for extension, _ in iterator():
|
||||
if hasattr(extension, function_name):
|
||||
prompt, input_ids, input_embeds = getattr(extension, function_name)(state, prompt, input_ids, input_embeds)
|
||||
return prompt, input_ids, input_embeds
|
||||
|
||||
|
||||
EXTENSION_MAP = {
|
||||
"input": partial(_apply_string_extensions, "input_modifier"),
|
||||
"output": partial(_apply_string_extensions, "output_modifier"),
|
||||
"bot_prefix": partial(_apply_string_extensions, "bot_prefix_modifier"),
|
||||
"tokenizer": partial(_apply_tokenizer_extensions, "tokenizer_modifier"),
|
||||
"input_hijack": _apply_input_hijack,
|
||||
"custom_generate_chat_prompt": _apply_custom_generate_chat_prompt
|
||||
}
|
||||
|
||||
|
||||
def apply_extensions(typ, *args, **kwargs):
|
||||
if typ not in EXTENSION_MAP:
|
||||
raise ValueError(f"Invalid extension type {typ}")
|
||||
return EXTENSION_MAP[typ](*args, **kwargs)
|
||||
|
||||
|
||||
def create_extensions_block():
|
||||
global setup_called
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue