LLaVA support (#1487)

This commit is contained in:
Wojtab 2023-04-24 01:32:22 +02:00 committed by GitHub
parent 9197d3fec8
commit 12212cf6be
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 426 additions and 42 deletions

View file

@ -135,7 +135,7 @@ def load_quantized(model_name):
# Find the model type
if not shared.args.model_type:
name = model_name.lower()
if any((k in name for k in ['llama', 'alpaca', 'vicuna'])):
if any((k in name for k in ['llama', 'alpaca', 'vicuna', 'llava'])):
model_type = 'llama'
elif any((k in name for k in ['opt-', 'galactica'])):
model_type = 'opt'

View file

@ -64,7 +64,7 @@ def generate_chat_prompt(user_input, state, **kwargs):
rows.append(f"{this_prefix1}{user_input}{state['end_of_turn']}\n")
# Adding the Character prefix
rows.append(apply_extensions(f"{prefix2.strip() if not is_instruct else prefix2}", "bot_prefix"))
rows.append(apply_extensions("bot_prefix", f"{prefix2.strip() if not is_instruct else prefix2}"))
while len(rows) > min_rows and len(encode(''.join(rows))[0]) >= max_length:
rows.pop(1)
@ -127,29 +127,22 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False):
cumulative_reply = ''
last_reply = [shared.history['internal'][-1][1], shared.history['visible'][-1][1]] if _continue else None
just_started = True
visible_text = custom_generate_chat_prompt = None
visible_text = None
eos_token = '\n' if state['stop_at_newline'] else None
stopping_strings = get_stopping_strings(state)
# Check if any extension wants to hijack this function call
for extension, _ in extensions_module.iterator():
if hasattr(extension, 'input_hijack') and extension.input_hijack['state']:
extension.input_hijack['state'] = False
text, visible_text = extension.input_hijack['value']
if custom_generate_chat_prompt is None and hasattr(extension, 'custom_generate_chat_prompt'):
custom_generate_chat_prompt = extension.custom_generate_chat_prompt
text, visible_text = apply_extensions('input_hijack', text, visible_text)
if visible_text is None:
visible_text = text
if not _continue:
text = apply_extensions(text, "input")
text = apply_extensions("input", text)
# Generating the prompt
kwargs = {'_continue': _continue}
if custom_generate_chat_prompt is None:
prompt = apply_extensions('custom_generate_chat_prompt', text, state, **kwargs)
if prompt is None:
prompt = generate_chat_prompt(text, state, **kwargs)
else:
prompt = custom_generate_chat_prompt(text, state, **kwargs)
# Yield *Is typing...*
if not any((regenerate, _continue)):
@ -164,7 +157,7 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False):
# Extracting the reply
reply, next_character_found = extract_message_from_reply(reply, state)
visible_reply = re.sub("(<USER>|<user>|{{user}})", state['name1'], reply)
visible_reply = apply_extensions(visible_reply, "output")
visible_reply = apply_extensions("output", visible_reply)
# We need this global variable to handle the Stop event,
# otherwise gradio gets confused
@ -273,14 +266,14 @@ def send_last_reply_to_input():
def replace_last_reply(text, name1, name2, mode):
if len(shared.history['visible']) > 0:
shared.history['visible'][-1][1] = text
shared.history['internal'][-1][1] = apply_extensions(text, "input")
shared.history['internal'][-1][1] = apply_extensions("input", text)
return chat_html_wrapper(shared.history['visible'], name1, name2, mode)
def send_dummy_message(text, name1, name2, mode):
shared.history['visible'].append([text, ''])
shared.history['internal'].append([apply_extensions(text, "input"), ''])
shared.history['internal'].append([apply_extensions("input", text), ''])
return chat_html_wrapper(shared.history['visible'], name1, name2, mode)
@ -289,7 +282,7 @@ def send_dummy_reply(text, name1, name2, mode):
shared.history['visible'].append(['', ''])
shared.history['internal'].append(['', ''])
shared.history['visible'][-1][1] = text
shared.history['internal'][-1][1] = apply_extensions(text, "input")
shared.history['internal'][-1][1] = apply_extensions("input", text)
return chat_html_wrapper(shared.history['visible'], name1, name2, mode)
@ -303,7 +296,7 @@ def clear_chat_log(name1, name2, greeting, mode):
if greeting != '':
shared.history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', greeting]]
shared.history['visible'] += [['', apply_extensions(greeting, "output")]]
shared.history['visible'] += [['', apply_extensions("output", greeting)]]
# Save cleared logs
save_history(mode)
@ -475,7 +468,7 @@ def load_character(character, name1, name2, mode):
# Insert greeting if it exists
if greeting != "":
shared.history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', greeting]]
shared.history['visible'] += [['', apply_extensions(greeting, "output")]]
shared.history['visible'] += [['', apply_extensions("output", greeting)]]
# Create .json log files since they don't already exist
save_history(mode)

View file

@ -1,4 +1,5 @@
import traceback
from functools import partial
import gradio as gr
@ -39,17 +40,60 @@ def iterator():
# Extension functions that map string -> string
def apply_extensions(text, typ):
def _apply_string_extensions(function_name, text):
for extension, _ in iterator():
if typ == "input" and hasattr(extension, "input_modifier"):
text = extension.input_modifier(text)
elif typ == "output" and hasattr(extension, "output_modifier"):
text = extension.output_modifier(text)
elif typ == "bot_prefix" and hasattr(extension, "bot_prefix_modifier"):
text = extension.bot_prefix_modifier(text)
if hasattr(extension, function_name):
text = getattr(extension, function_name)(text)
return text
# Input hijack of extensions
def _apply_input_hijack(text, visible_text):
    """Let extensions with an armed ``input_hijack`` replace the user input.

    For every extension yielded by ``iterator()`` that has an
    ``input_hijack`` mapping whose ``'state'`` entry is truthy, the state is
    reset to ``False`` and the mapping's ``'value'`` supplies the new
    ``(text, visible_text)`` pair: a callable value is invoked with the
    current pair, anything else is unpacked directly as the pair.

    Returns the (possibly replaced) ``(text, visible_text)`` tuple.
    """
    for extension, _ in iterator():
        if not hasattr(extension, 'input_hijack'):
            continue

        hijack = extension.input_hijack
        if not hijack['state']:
            continue

        # Disarm first so the hijack fires only once per arming.
        hijack['state'] = False

        replacement = hijack['value']
        if callable(replacement):
            text, visible_text = replacement(text, visible_text)
        else:
            text, visible_text = replacement

    return text, visible_text
# custom_generate_chat_prompt handling
def _apply_custom_generate_chat_prompt(text, state, **kwargs):
    """Build the chat prompt with the first extension-provided generator.

    Scans the extensions yielded by ``iterator()`` and keeps the first one
    that defines ``custom_generate_chat_prompt``; later extensions are
    ignored. Returns that function's result for ``(text, state, **kwargs)``,
    or ``None`` when no extension provides one.
    """
    prompt_builder = None
    for extension, _ in iterator():
        if prompt_builder is not None:
            # A builder was already chosen; only the first match is used.
            continue
        if hasattr(extension, 'custom_generate_chat_prompt'):
            prompt_builder = extension.custom_generate_chat_prompt

    if prompt_builder is None:
        return None
    return prompt_builder(text, state, **kwargs)
# Extension functions that override the default tokenizer output
def _apply_tokenizer_extensions(function_name, state, prompt, input_ids, input_embeds):
    """Thread the tokenizer outputs through every extension's hook.

    Each extension yielded by ``iterator()`` that has an attribute named
    ``function_name`` is called with
    ``(state, prompt, input_ids, input_embeds)`` and must return the updated
    ``(prompt, input_ids, input_embeds)`` triple, which is then passed to the
    next extension.

    Returns the final ``(prompt, input_ids, input_embeds)`` triple.
    """
    for extension, _ in iterator():
        if not hasattr(extension, function_name):
            continue
        hook = getattr(extension, function_name)
        prompt, input_ids, input_embeds = hook(state, prompt, input_ids, input_embeds)

    return prompt, input_ids, input_embeds
# Dispatch table consulted by apply_extensions(): maps an extension "type"
# name to the callable that handles it.
EXTENSION_MAP = {
    # String hooks: partial() pre-binds the extension attribute name to call.
    "input": partial(_apply_string_extensions, "input_modifier"),
    "output": partial(_apply_string_extensions, "output_modifier"),
    "bot_prefix": partial(_apply_string_extensions, "bot_prefix_modifier"),
    # Tokenizer hook: remaining args are (state, prompt, input_ids, input_embeds).
    "tokenizer": partial(_apply_tokenizer_extensions, "tokenizer_modifier"),
    # These two handlers take their arguments directly, nothing is pre-bound.
    "input_hijack": _apply_input_hijack,
    "custom_generate_chat_prompt": _apply_custom_generate_chat_prompt
}
def apply_extensions(typ, *args, **kwargs):
    """Run the extension handler registered for ``typ``.

    ``typ`` selects an entry of ``EXTENSION_MAP``; all remaining positional
    and keyword arguments are forwarded to that handler unchanged and its
    result is returned.

    Raises:
        ValueError: if ``typ`` is not a key of ``EXTENSION_MAP``.
    """
    handler = EXTENSION_MAP.get(typ)
    if handler is None:
        raise ValueError(f"Invalid extension type {typ}")

    return handler(*args, **kwargs)
def create_extensions_block():
global setup_called

View file

@ -50,6 +50,8 @@ def find_model_type(model_name):
return 'chatglm'
elif 'galactica' in model_name:
return 'galactica'
elif 'llava' in model_name:
return 'llava'
elif any((k in model_name for k in ['gpt4chan', 'gpt-4chan'])):
return 'gpt4chan'
else:
@ -217,11 +219,12 @@ def load_model(model_name):
tokenizer = None
# Try to load a universal LLaMA tokenizer
for p in [Path(f"{shared.args.model_dir}/llama-tokenizer/"), Path(f"{shared.args.model_dir}/oobabooga_llama-tokenizer/")]:
if p.exists():
print(f"Loading the universal LLaMA tokenizer from {p}...")
tokenizer = LlamaTokenizer.from_pretrained(p, clean_up_tokenization_spaces=True)
break
if shared.model_type != 'llava':
for p in [Path(f"{shared.args.model_dir}/llama-tokenizer/"), Path(f"{shared.args.model_dir}/oobabooga_llama-tokenizer/")]:
if p.exists():
print(f"Loading the universal LLaMA tokenizer from {p}...")
tokenizer = LlamaTokenizer.from_pretrained(p, clean_up_tokenization_spaces=True)
break
# Otherwise, load it from the model folder and hope that these
# are not outdated tokenizer files.

View file

@ -56,7 +56,7 @@ settings = {
'chat_default_extensions': ["gallery"],
'presets': {
'default': 'Default',
'.*(alpaca|llama)': "LLaMA-Precise",
'.*(alpaca|llama|llava)': "LLaMA-Precise",
'.*pygmalion': 'NovelAI-Storywriter',
'.*RWKV': 'Naive',
},

View file

@ -138,7 +138,7 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
original_question = question
if not shared.is_chat():
question = apply_extensions(question, 'input')
question = apply_extensions('input', question)
# These models are not part of Hugging Face, so we handle them
# separately and terminate the function call earlier
@ -155,7 +155,7 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
reply = shared.model.generate(context=question, **generate_params)
output = original_question + reply
if not shared.is_chat():
reply = original_question + apply_extensions(reply, 'output')
reply = original_question + apply_extensions('output', reply)
yield formatted_outputs(reply, shared.model_name)
else:
if not shared.is_chat():
@ -166,7 +166,7 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
for reply in shared.model.generate_with_streaming(context=question, **generate_params):
output = original_question + reply
if not shared.is_chat():
reply = original_question + apply_extensions(reply, 'output')
reply = original_question + apply_extensions('output', reply)
yield formatted_outputs(reply, shared.model_name)
except Exception:
@ -179,7 +179,6 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
return
input_ids = encode(question, add_bos_token=state['add_bos_token'], truncation_length=get_max_prompt_length(state))
original_input_ids = input_ids
output = input_ids[0]
if shared.args.verbose:
@ -218,10 +217,16 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
generate_params.update({'synced_gpus': True})
if shared.soft_prompt:
inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids)
question, filler_input_ids, inputs_embeds = apply_extensions('tokenizer', state, question, filler_input_ids, inputs_embeds)
original_input_ids = input_ids
generate_params.update({'inputs_embeds': inputs_embeds})
generate_params.update({'inputs': filler_input_ids})
else:
question, input_ids, inputs_embeds = apply_extensions('tokenizer', state, question, input_ids, None)
original_input_ids = input_ids
generate_params.update({'inputs': input_ids})
if inputs_embeds is not None:
generate_params.update({'inputs_embeds': inputs_embeds})
try:
# Generate the entire reply at once.
@ -237,7 +242,7 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
new_tokens = len(output) - len(input_ids[0])
reply = decode(output[-new_tokens:], state['skip_special_tokens'])
if not shared.is_chat():
reply = original_question + apply_extensions(reply, 'output')
reply = original_question + apply_extensions('output', reply)
yield formatted_outputs(reply, shared.model_name)
@ -265,7 +270,7 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
new_tokens = len(output) - len(input_ids[0])
reply = decode(output[-new_tokens:], state['skip_special_tokens'])
if not shared.is_chat():
reply = original_question + apply_extensions(reply, 'output')
reply = original_question + apply_extensions('output', reply)
if output[-1] in eos_token_ids:
break
@ -285,7 +290,7 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
new_tokens = len(output) - len(original_input_ids[0])
reply = decode(output[-new_tokens:], state['skip_special_tokens'])
if not shared.is_chat():
reply = original_question + apply_extensions(reply, 'output')
reply = original_question + apply_extensions('output', reply)
if np.count_nonzero(np.isin(input_ids[0], eos_token_ids)) < np.count_nonzero(np.isin(output, eos_token_ids)):
break