LLaVA support (#1487)

This commit is contained in:
Wojtab 2023-04-24 01:32:22 +02:00 committed by GitHub
parent 9197d3fec8
commit 12212cf6be
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 426 additions and 42 deletions

View file

@ -64,7 +64,7 @@ def generate_chat_prompt(user_input, state, **kwargs):
rows.append(f"{this_prefix1}{user_input}{state['end_of_turn']}\n")
# Adding the Character prefix
rows.append(apply_extensions(f"{prefix2.strip() if not is_instruct else prefix2}", "bot_prefix"))
rows.append(apply_extensions("bot_prefix", f"{prefix2.strip() if not is_instruct else prefix2}"))
while len(rows) > min_rows and len(encode(''.join(rows))[0]) >= max_length:
rows.pop(1)
@ -127,29 +127,22 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False):
cumulative_reply = ''
last_reply = [shared.history['internal'][-1][1], shared.history['visible'][-1][1]] if _continue else None
just_started = True
visible_text = custom_generate_chat_prompt = None
visible_text = None
eos_token = '\n' if state['stop_at_newline'] else None
stopping_strings = get_stopping_strings(state)
# Check if any extension wants to hijack this function call
for extension, _ in extensions_module.iterator():
if hasattr(extension, 'input_hijack') and extension.input_hijack['state']:
extension.input_hijack['state'] = False
text, visible_text = extension.input_hijack['value']
if custom_generate_chat_prompt is None and hasattr(extension, 'custom_generate_chat_prompt'):
custom_generate_chat_prompt = extension.custom_generate_chat_prompt
text, visible_text = apply_extensions('input_hijack', text, visible_text)
if visible_text is None:
visible_text = text
if not _continue:
text = apply_extensions(text, "input")
text = apply_extensions("input", text)
# Generating the prompt
kwargs = {'_continue': _continue}
if custom_generate_chat_prompt is None:
prompt = apply_extensions('custom_generate_chat_prompt', text, state, **kwargs)
if prompt is None:
prompt = generate_chat_prompt(text, state, **kwargs)
else:
prompt = custom_generate_chat_prompt(text, state, **kwargs)
# Yield *Is typing...*
if not any((regenerate, _continue)):
@ -164,7 +157,7 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False):
# Extracting the reply
reply, next_character_found = extract_message_from_reply(reply, state)
visible_reply = re.sub("(<USER>|<user>|{{user}})", state['name1'], reply)
visible_reply = apply_extensions(visible_reply, "output")
visible_reply = apply_extensions("output", visible_reply)
# We need this global variable to handle the Stop event,
# otherwise gradio gets confused
@ -273,14 +266,14 @@ def send_last_reply_to_input():
def replace_last_reply(text, name1, name2, mode):
if len(shared.history['visible']) > 0:
shared.history['visible'][-1][1] = text
shared.history['internal'][-1][1] = apply_extensions(text, "input")
shared.history['internal'][-1][1] = apply_extensions("input", text)
return chat_html_wrapper(shared.history['visible'], name1, name2, mode)
def send_dummy_message(text, name1, name2, mode):
shared.history['visible'].append([text, ''])
shared.history['internal'].append([apply_extensions(text, "input"), ''])
shared.history['internal'].append([apply_extensions("input", text), ''])
return chat_html_wrapper(shared.history['visible'], name1, name2, mode)
@ -289,7 +282,7 @@ def send_dummy_reply(text, name1, name2, mode):
shared.history['visible'].append(['', ''])
shared.history['internal'].append(['', ''])
shared.history['visible'][-1][1] = text
shared.history['internal'][-1][1] = apply_extensions(text, "input")
shared.history['internal'][-1][1] = apply_extensions("input", text)
return chat_html_wrapper(shared.history['visible'], name1, name2, mode)
@ -303,7 +296,7 @@ def clear_chat_log(name1, name2, greeting, mode):
if greeting != '':
shared.history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', greeting]]
shared.history['visible'] += [['', apply_extensions(greeting, "output")]]
shared.history['visible'] += [['', apply_extensions("output", greeting)]]
# Save cleared logs
save_history(mode)
@ -475,7 +468,7 @@ def load_character(character, name1, name2, mode):
# Insert greeting if it exists
if greeting != "":
shared.history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', greeting]]
shared.history['visible'] += [['', apply_extensions(greeting, "output")]]
shared.history['visible'] += [['', apply_extensions("output", greeting)]]
# Create .json log files since they don't already exist
save_history(mode)