LLaVA support (#1487)
This commit is contained in:
parent
9197d3fec8
commit
12212cf6be
12 changed files with 426 additions and 42 deletions
|
@ -64,7 +64,7 @@ def generate_chat_prompt(user_input, state, **kwargs):
|
|||
rows.append(f"{this_prefix1}{user_input}{state['end_of_turn']}\n")
|
||||
|
||||
# Adding the Character prefix
|
||||
rows.append(apply_extensions(f"{prefix2.strip() if not is_instruct else prefix2}", "bot_prefix"))
|
||||
rows.append(apply_extensions("bot_prefix", f"{prefix2.strip() if not is_instruct else prefix2}"))
|
||||
|
||||
while len(rows) > min_rows and len(encode(''.join(rows))[0]) >= max_length:
|
||||
rows.pop(1)
|
||||
|
@ -127,29 +127,22 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False):
|
|||
cumulative_reply = ''
|
||||
last_reply = [shared.history['internal'][-1][1], shared.history['visible'][-1][1]] if _continue else None
|
||||
just_started = True
|
||||
visible_text = custom_generate_chat_prompt = None
|
||||
visible_text = None
|
||||
eos_token = '\n' if state['stop_at_newline'] else None
|
||||
stopping_strings = get_stopping_strings(state)
|
||||
|
||||
# Check if any extension wants to hijack this function call
|
||||
for extension, _ in extensions_module.iterator():
|
||||
if hasattr(extension, 'input_hijack') and extension.input_hijack['state']:
|
||||
extension.input_hijack['state'] = False
|
||||
text, visible_text = extension.input_hijack['value']
|
||||
if custom_generate_chat_prompt is None and hasattr(extension, 'custom_generate_chat_prompt'):
|
||||
custom_generate_chat_prompt = extension.custom_generate_chat_prompt
|
||||
text, visible_text = apply_extensions('input_hijack', text, visible_text)
|
||||
|
||||
if visible_text is None:
|
||||
visible_text = text
|
||||
if not _continue:
|
||||
text = apply_extensions(text, "input")
|
||||
text = apply_extensions("input", text)
|
||||
|
||||
# Generating the prompt
|
||||
kwargs = {'_continue': _continue}
|
||||
if custom_generate_chat_prompt is None:
|
||||
prompt = apply_extensions('custom_generate_chat_prompt', text, state, **kwargs)
|
||||
if prompt is None:
|
||||
prompt = generate_chat_prompt(text, state, **kwargs)
|
||||
else:
|
||||
prompt = custom_generate_chat_prompt(text, state, **kwargs)
|
||||
|
||||
# Yield *Is typing...*
|
||||
if not any((regenerate, _continue)):
|
||||
|
@ -164,7 +157,7 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False):
|
|||
# Extracting the reply
|
||||
reply, next_character_found = extract_message_from_reply(reply, state)
|
||||
visible_reply = re.sub("(<USER>|<user>|{{user}})", state['name1'], reply)
|
||||
visible_reply = apply_extensions(visible_reply, "output")
|
||||
visible_reply = apply_extensions("output", visible_reply)
|
||||
|
||||
# We need this global variable to handle the Stop event,
|
||||
# otherwise gradio gets confused
|
||||
|
@ -273,14 +266,14 @@ def send_last_reply_to_input():
|
|||
def replace_last_reply(text, name1, name2, mode):
|
||||
if len(shared.history['visible']) > 0:
|
||||
shared.history['visible'][-1][1] = text
|
||||
shared.history['internal'][-1][1] = apply_extensions(text, "input")
|
||||
shared.history['internal'][-1][1] = apply_extensions("input", text)
|
||||
|
||||
return chat_html_wrapper(shared.history['visible'], name1, name2, mode)
|
||||
|
||||
|
||||
def send_dummy_message(text, name1, name2, mode):
|
||||
shared.history['visible'].append([text, ''])
|
||||
shared.history['internal'].append([apply_extensions(text, "input"), ''])
|
||||
shared.history['internal'].append([apply_extensions("input", text), ''])
|
||||
return chat_html_wrapper(shared.history['visible'], name1, name2, mode)
|
||||
|
||||
|
||||
|
@ -289,7 +282,7 @@ def send_dummy_reply(text, name1, name2, mode):
|
|||
shared.history['visible'].append(['', ''])
|
||||
shared.history['internal'].append(['', ''])
|
||||
shared.history['visible'][-1][1] = text
|
||||
shared.history['internal'][-1][1] = apply_extensions(text, "input")
|
||||
shared.history['internal'][-1][1] = apply_extensions("input", text)
|
||||
return chat_html_wrapper(shared.history['visible'], name1, name2, mode)
|
||||
|
||||
|
||||
|
@ -303,7 +296,7 @@ def clear_chat_log(name1, name2, greeting, mode):
|
|||
|
||||
if greeting != '':
|
||||
shared.history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', greeting]]
|
||||
shared.history['visible'] += [['', apply_extensions(greeting, "output")]]
|
||||
shared.history['visible'] += [['', apply_extensions("output", greeting)]]
|
||||
|
||||
# Save cleared logs
|
||||
save_history(mode)
|
||||
|
@ -475,7 +468,7 @@ def load_character(character, name1, name2, mode):
|
|||
# Insert greeting if it exists
|
||||
if greeting != "":
|
||||
shared.history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', greeting]]
|
||||
shared.history['visible'] += [['', apply_extensions(greeting, "output")]]
|
||||
shared.history['visible'] += [['', apply_extensions("output", greeting)]]
|
||||
|
||||
# Create .json log files since they don't already exist
|
||||
save_history(mode)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue