LLaVA support (#1487)

This commit is contained in:
Wojtab 2023-04-24 01:32:22 +02:00 committed by GitHub
parent 9197d3fec8
commit 12212cf6be
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 426 additions and 42 deletions

View file

@ -1,4 +1,5 @@
import traceback
from functools import partial
import gradio as gr
@ -39,17 +40,60 @@ def iterator():
# Extension functions that map string -> string
def apply_extensions(text, typ):
def _apply_string_extensions(function_name, text):
for extension, _ in iterator():
if typ == "input" and hasattr(extension, "input_modifier"):
text = extension.input_modifier(text)
elif typ == "output" and hasattr(extension, "output_modifier"):
text = extension.output_modifier(text)
elif typ == "bot_prefix" and hasattr(extension, "bot_prefix_modifier"):
text = extension.bot_prefix_modifier(text)
if hasattr(extension, function_name):
text = getattr(extension, function_name)(text)
return text
# Input hijack of extensions
def _apply_input_hijack(text, visible_text):
for extension, _ in iterator():
if hasattr(extension, 'input_hijack') and extension.input_hijack['state']:
extension.input_hijack['state'] = False
if callable(extension.input_hijack['value']):
text, visible_text = extension.input_hijack['value'](text, visible_text)
else:
text, visible_text = extension.input_hijack['value']
return text, visible_text
# custom_generate_chat_prompt handling
def _apply_custom_generate_chat_prompt(text, state, **kwargs):
custom_generate_chat_prompt = None
for extension, _ in iterator():
if custom_generate_chat_prompt is None and hasattr(extension, 'custom_generate_chat_prompt'):
custom_generate_chat_prompt = extension.custom_generate_chat_prompt
if custom_generate_chat_prompt is not None:
return custom_generate_chat_prompt(text, state, **kwargs)
return None
# Extension functions that override the default tokenizer output
def _apply_tokenizer_extensions(function_name, state, prompt, input_ids, input_embeds):
for extension, _ in iterator():
if hasattr(extension, function_name):
prompt, input_ids, input_embeds = getattr(extension, function_name)(state, prompt, input_ids, input_embeds)
return prompt, input_ids, input_embeds
EXTENSION_MAP = {
"input": partial(_apply_string_extensions, "input_modifier"),
"output": partial(_apply_string_extensions, "output_modifier"),
"bot_prefix": partial(_apply_string_extensions, "bot_prefix_modifier"),
"tokenizer": partial(_apply_tokenizer_extensions, "tokenizer_modifier"),
"input_hijack": _apply_input_hijack,
"custom_generate_chat_prompt": _apply_custom_generate_chat_prompt
}
def apply_extensions(typ, *args, **kwargs):
if typ not in EXTENSION_MAP:
raise ValueError(f"Invalid extension type {typ}")
return EXTENSION_MAP[typ](*args, **kwargs)
def create_extensions_block():
global setup_called