LLaVA support (#1487)
This commit is contained in:
parent
9197d3fec8
commit
12212cf6be
12 changed files with 426 additions and 42 deletions
|
@ -135,7 +135,7 @@ def load_quantized(model_name):
|
|||
# Find the model type
|
||||
if not shared.args.model_type:
|
||||
name = model_name.lower()
|
||||
if any((k in name for k in ['llama', 'alpaca', 'vicuna'])):
|
||||
if any((k in name for k in ['llama', 'alpaca', 'vicuna', 'llava'])):
|
||||
model_type = 'llama'
|
||||
elif any((k in name for k in ['opt-', 'galactica'])):
|
||||
model_type = 'opt'
|
||||
|
|
|
@ -64,7 +64,7 @@ def generate_chat_prompt(user_input, state, **kwargs):
|
|||
rows.append(f"{this_prefix1}{user_input}{state['end_of_turn']}\n")
|
||||
|
||||
# Adding the Character prefix
|
||||
rows.append(apply_extensions(f"{prefix2.strip() if not is_instruct else prefix2}", "bot_prefix"))
|
||||
rows.append(apply_extensions("bot_prefix", f"{prefix2.strip() if not is_instruct else prefix2}"))
|
||||
|
||||
while len(rows) > min_rows and len(encode(''.join(rows))[0]) >= max_length:
|
||||
rows.pop(1)
|
||||
|
@ -127,29 +127,22 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False):
|
|||
cumulative_reply = ''
|
||||
last_reply = [shared.history['internal'][-1][1], shared.history['visible'][-1][1]] if _continue else None
|
||||
just_started = True
|
||||
visible_text = custom_generate_chat_prompt = None
|
||||
visible_text = None
|
||||
eos_token = '\n' if state['stop_at_newline'] else None
|
||||
stopping_strings = get_stopping_strings(state)
|
||||
|
||||
# Check if any extension wants to hijack this function call
|
||||
for extension, _ in extensions_module.iterator():
|
||||
if hasattr(extension, 'input_hijack') and extension.input_hijack['state']:
|
||||
extension.input_hijack['state'] = False
|
||||
text, visible_text = extension.input_hijack['value']
|
||||
if custom_generate_chat_prompt is None and hasattr(extension, 'custom_generate_chat_prompt'):
|
||||
custom_generate_chat_prompt = extension.custom_generate_chat_prompt
|
||||
text, visible_text = apply_extensions('input_hijack', text, visible_text)
|
||||
|
||||
if visible_text is None:
|
||||
visible_text = text
|
||||
if not _continue:
|
||||
text = apply_extensions(text, "input")
|
||||
text = apply_extensions("input", text)
|
||||
|
||||
# Generating the prompt
|
||||
kwargs = {'_continue': _continue}
|
||||
if custom_generate_chat_prompt is None:
|
||||
prompt = apply_extensions('custom_generate_chat_prompt', text, state, **kwargs)
|
||||
if prompt is None:
|
||||
prompt = generate_chat_prompt(text, state, **kwargs)
|
||||
else:
|
||||
prompt = custom_generate_chat_prompt(text, state, **kwargs)
|
||||
|
||||
# Yield *Is typing...*
|
||||
if not any((regenerate, _continue)):
|
||||
|
@ -164,7 +157,7 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False):
|
|||
# Extracting the reply
|
||||
reply, next_character_found = extract_message_from_reply(reply, state)
|
||||
visible_reply = re.sub("(<USER>|<user>|{{user}})", state['name1'], reply)
|
||||
visible_reply = apply_extensions(visible_reply, "output")
|
||||
visible_reply = apply_extensions("output", visible_reply)
|
||||
|
||||
# We need this global variable to handle the Stop event,
|
||||
# otherwise gradio gets confused
|
||||
|
@ -273,14 +266,14 @@ def send_last_reply_to_input():
|
|||
def replace_last_reply(text, name1, name2, mode):
|
||||
if len(shared.history['visible']) > 0:
|
||||
shared.history['visible'][-1][1] = text
|
||||
shared.history['internal'][-1][1] = apply_extensions(text, "input")
|
||||
shared.history['internal'][-1][1] = apply_extensions("input", text)
|
||||
|
||||
return chat_html_wrapper(shared.history['visible'], name1, name2, mode)
|
||||
|
||||
|
||||
def send_dummy_message(text, name1, name2, mode):
|
||||
shared.history['visible'].append([text, ''])
|
||||
shared.history['internal'].append([apply_extensions(text, "input"), ''])
|
||||
shared.history['internal'].append([apply_extensions("input", text), ''])
|
||||
return chat_html_wrapper(shared.history['visible'], name1, name2, mode)
|
||||
|
||||
|
||||
|
@ -289,7 +282,7 @@ def send_dummy_reply(text, name1, name2, mode):
|
|||
shared.history['visible'].append(['', ''])
|
||||
shared.history['internal'].append(['', ''])
|
||||
shared.history['visible'][-1][1] = text
|
||||
shared.history['internal'][-1][1] = apply_extensions(text, "input")
|
||||
shared.history['internal'][-1][1] = apply_extensions("input", text)
|
||||
return chat_html_wrapper(shared.history['visible'], name1, name2, mode)
|
||||
|
||||
|
||||
|
@ -303,7 +296,7 @@ def clear_chat_log(name1, name2, greeting, mode):
|
|||
|
||||
if greeting != '':
|
||||
shared.history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', greeting]]
|
||||
shared.history['visible'] += [['', apply_extensions(greeting, "output")]]
|
||||
shared.history['visible'] += [['', apply_extensions("output", greeting)]]
|
||||
|
||||
# Save cleared logs
|
||||
save_history(mode)
|
||||
|
@ -475,7 +468,7 @@ def load_character(character, name1, name2, mode):
|
|||
# Insert greeting if it exists
|
||||
if greeting != "":
|
||||
shared.history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', greeting]]
|
||||
shared.history['visible'] += [['', apply_extensions(greeting, "output")]]
|
||||
shared.history['visible'] += [['', apply_extensions("output", greeting)]]
|
||||
|
||||
# Create .json log files since they don't already exist
|
||||
save_history(mode)
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import traceback
|
||||
from functools import partial
|
||||
|
||||
import gradio as gr
|
||||
|
||||
|
@ -39,17 +40,60 @@ def iterator():
|
|||
|
||||
|
||||
# Extension functions that map string -> string
|
||||
def apply_extensions(text, typ):
|
||||
def _apply_string_extensions(function_name, text):
|
||||
for extension, _ in iterator():
|
||||
if typ == "input" and hasattr(extension, "input_modifier"):
|
||||
text = extension.input_modifier(text)
|
||||
elif typ == "output" and hasattr(extension, "output_modifier"):
|
||||
text = extension.output_modifier(text)
|
||||
elif typ == "bot_prefix" and hasattr(extension, "bot_prefix_modifier"):
|
||||
text = extension.bot_prefix_modifier(text)
|
||||
if hasattr(extension, function_name):
|
||||
text = getattr(extension, function_name)(text)
|
||||
return text
|
||||
|
||||
|
||||
# Input hijack of extensions
|
||||
def _apply_input_hijack(text, visible_text):
|
||||
for extension, _ in iterator():
|
||||
if hasattr(extension, 'input_hijack') and extension.input_hijack['state']:
|
||||
extension.input_hijack['state'] = False
|
||||
if callable(extension.input_hijack['value']):
|
||||
text, visible_text = extension.input_hijack['value'](text, visible_text)
|
||||
else:
|
||||
text, visible_text = extension.input_hijack['value']
|
||||
return text, visible_text
|
||||
|
||||
|
||||
# custom_generate_chat_prompt handling
|
||||
def _apply_custom_generate_chat_prompt(text, state, **kwargs):
|
||||
custom_generate_chat_prompt = None
|
||||
for extension, _ in iterator():
|
||||
if custom_generate_chat_prompt is None and hasattr(extension, 'custom_generate_chat_prompt'):
|
||||
custom_generate_chat_prompt = extension.custom_generate_chat_prompt
|
||||
if custom_generate_chat_prompt is not None:
|
||||
return custom_generate_chat_prompt(text, state, **kwargs)
|
||||
return None
|
||||
|
||||
|
||||
# Extension functions that override the default tokenizer output
|
||||
def _apply_tokenizer_extensions(function_name, state, prompt, input_ids, input_embeds):
|
||||
for extension, _ in iterator():
|
||||
if hasattr(extension, function_name):
|
||||
prompt, input_ids, input_embeds = getattr(extension, function_name)(state, prompt, input_ids, input_embeds)
|
||||
return prompt, input_ids, input_embeds
|
||||
|
||||
|
||||
EXTENSION_MAP = {
|
||||
"input": partial(_apply_string_extensions, "input_modifier"),
|
||||
"output": partial(_apply_string_extensions, "output_modifier"),
|
||||
"bot_prefix": partial(_apply_string_extensions, "bot_prefix_modifier"),
|
||||
"tokenizer": partial(_apply_tokenizer_extensions, "tokenizer_modifier"),
|
||||
"input_hijack": _apply_input_hijack,
|
||||
"custom_generate_chat_prompt": _apply_custom_generate_chat_prompt
|
||||
}
|
||||
|
||||
|
||||
def apply_extensions(typ, *args, **kwargs):
|
||||
if typ not in EXTENSION_MAP:
|
||||
raise ValueError(f"Invalid extension type {typ}")
|
||||
return EXTENSION_MAP[typ](*args, **kwargs)
|
||||
|
||||
|
||||
def create_extensions_block():
|
||||
global setup_called
|
||||
|
||||
|
|
|
@ -50,6 +50,8 @@ def find_model_type(model_name):
|
|||
return 'chatglm'
|
||||
elif 'galactica' in model_name:
|
||||
return 'galactica'
|
||||
elif 'llava' in model_name:
|
||||
return 'llava'
|
||||
elif any((k in model_name for k in ['gpt4chan', 'gpt-4chan'])):
|
||||
return 'gpt4chan'
|
||||
else:
|
||||
|
@ -217,11 +219,12 @@ def load_model(model_name):
|
|||
tokenizer = None
|
||||
|
||||
# Try to load an universal LLaMA tokenizer
|
||||
for p in [Path(f"{shared.args.model_dir}/llama-tokenizer/"), Path(f"{shared.args.model_dir}/oobabooga_llama-tokenizer/")]:
|
||||
if p.exists():
|
||||
print(f"Loading the universal LLaMA tokenizer from {p}...")
|
||||
tokenizer = LlamaTokenizer.from_pretrained(p, clean_up_tokenization_spaces=True)
|
||||
break
|
||||
if shared.model_type != 'llava':
|
||||
for p in [Path(f"{shared.args.model_dir}/llama-tokenizer/"), Path(f"{shared.args.model_dir}/oobabooga_llama-tokenizer/")]:
|
||||
if p.exists():
|
||||
print(f"Loading the universal LLaMA tokenizer from {p}...")
|
||||
tokenizer = LlamaTokenizer.from_pretrained(p, clean_up_tokenization_spaces=True)
|
||||
break
|
||||
|
||||
# Otherwise, load it from the model folder and hope that these
|
||||
# are not outdated tokenizer files.
|
||||
|
|
|
@ -56,7 +56,7 @@ settings = {
|
|||
'chat_default_extensions': ["gallery"],
|
||||
'presets': {
|
||||
'default': 'Default',
|
||||
'.*(alpaca|llama)': "LLaMA-Precise",
|
||||
'.*(alpaca|llama|llava)': "LLaMA-Precise",
|
||||
'.*pygmalion': 'NovelAI-Storywriter',
|
||||
'.*RWKV': 'Naive',
|
||||
},
|
||||
|
|
|
@ -138,7 +138,7 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
|
|||
|
||||
original_question = question
|
||||
if not shared.is_chat():
|
||||
question = apply_extensions(question, 'input')
|
||||
question = apply_extensions('input', question)
|
||||
|
||||
# These models are not part of Hugging Face, so we handle them
|
||||
# separately and terminate the function call earlier
|
||||
|
@ -155,7 +155,7 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
|
|||
reply = shared.model.generate(context=question, **generate_params)
|
||||
output = original_question + reply
|
||||
if not shared.is_chat():
|
||||
reply = original_question + apply_extensions(reply, 'output')
|
||||
reply = original_question + apply_extensions('output', reply)
|
||||
yield formatted_outputs(reply, shared.model_name)
|
||||
else:
|
||||
if not shared.is_chat():
|
||||
|
@ -166,7 +166,7 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
|
|||
for reply in shared.model.generate_with_streaming(context=question, **generate_params):
|
||||
output = original_question + reply
|
||||
if not shared.is_chat():
|
||||
reply = original_question + apply_extensions(reply, 'output')
|
||||
reply = original_question + apply_extensions('output', reply)
|
||||
yield formatted_outputs(reply, shared.model_name)
|
||||
|
||||
except Exception:
|
||||
|
@ -179,7 +179,6 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
|
|||
return
|
||||
|
||||
input_ids = encode(question, add_bos_token=state['add_bos_token'], truncation_length=get_max_prompt_length(state))
|
||||
original_input_ids = input_ids
|
||||
output = input_ids[0]
|
||||
|
||||
if shared.args.verbose:
|
||||
|
@ -218,10 +217,16 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
|
|||
generate_params.update({'synced_gpus': True})
|
||||
if shared.soft_prompt:
|
||||
inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids)
|
||||
question, filler_input_ids, inputs_embeds = apply_extensions('tokenizer', state, question, filler_input_ids, inputs_embeds)
|
||||
original_input_ids = input_ids
|
||||
generate_params.update({'inputs_embeds': inputs_embeds})
|
||||
generate_params.update({'inputs': filler_input_ids})
|
||||
else:
|
||||
question, input_ids, inputs_embeds = apply_extensions('tokenizer', state, question, input_ids, None)
|
||||
original_input_ids = input_ids
|
||||
generate_params.update({'inputs': input_ids})
|
||||
if inputs_embeds is not None:
|
||||
generate_params.update({'inputs_embeds': inputs_embeds})
|
||||
|
||||
try:
|
||||
# Generate the entire reply at once.
|
||||
|
@ -237,7 +242,7 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
|
|||
new_tokens = len(output) - len(input_ids[0])
|
||||
reply = decode(output[-new_tokens:], state['skip_special_tokens'])
|
||||
if not shared.is_chat():
|
||||
reply = original_question + apply_extensions(reply, 'output')
|
||||
reply = original_question + apply_extensions('output', reply)
|
||||
|
||||
yield formatted_outputs(reply, shared.model_name)
|
||||
|
||||
|
@ -265,7 +270,7 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
|
|||
new_tokens = len(output) - len(input_ids[0])
|
||||
reply = decode(output[-new_tokens:], state['skip_special_tokens'])
|
||||
if not shared.is_chat():
|
||||
reply = original_question + apply_extensions(reply, 'output')
|
||||
reply = original_question + apply_extensions('output', reply)
|
||||
|
||||
if output[-1] in eos_token_ids:
|
||||
break
|
||||
|
@ -285,7 +290,7 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
|
|||
new_tokens = len(output) - len(original_input_ids[0])
|
||||
reply = decode(output[-new_tokens:], state['skip_special_tokens'])
|
||||
if not shared.is_chat():
|
||||
reply = original_question + apply_extensions(reply, 'output')
|
||||
reply = original_question + apply_extensions('output', reply)
|
||||
|
||||
if np.count_nonzero(np.isin(input_ids[0], eos_token_ids)) < np.count_nonzero(np.isin(output, eos_token_ids)):
|
||||
break
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue