LLaVA support (#1487)

This commit is contained in:
Wojtab 2023-04-24 01:32:22 +02:00 committed by GitHub
parent 9197d3fec8
commit 12212cf6be
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 426 additions and 42 deletions

View file

@@ -138,7 +138,7 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
original_question = question
if not shared.is_chat():
question = apply_extensions(question, 'input')
question = apply_extensions('input', question)
# These models are not part of Hugging Face, so we handle them
# separately and terminate the function call earlier
@@ -155,7 +155,7 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
reply = shared.model.generate(context=question, **generate_params)
output = original_question + reply
if not shared.is_chat():
reply = original_question + apply_extensions(reply, 'output')
reply = original_question + apply_extensions('output', reply)
yield formatted_outputs(reply, shared.model_name)
else:
if not shared.is_chat():
@@ -166,7 +166,7 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
for reply in shared.model.generate_with_streaming(context=question, **generate_params):
output = original_question + reply
if not shared.is_chat():
reply = original_question + apply_extensions(reply, 'output')
reply = original_question + apply_extensions('output', reply)
yield formatted_outputs(reply, shared.model_name)
except Exception:
@@ -179,7 +179,6 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
return
input_ids = encode(question, add_bos_token=state['add_bos_token'], truncation_length=get_max_prompt_length(state))
original_input_ids = input_ids
output = input_ids[0]
if shared.args.verbose:
@@ -218,10 +217,16 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
generate_params.update({'synced_gpus': True})
if shared.soft_prompt:
inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids)
question, filler_input_ids, inputs_embeds = apply_extensions('tokenizer', state, question, filler_input_ids, inputs_embeds)
original_input_ids = input_ids
generate_params.update({'inputs_embeds': inputs_embeds})
generate_params.update({'inputs': filler_input_ids})
else:
question, input_ids, inputs_embeds = apply_extensions('tokenizer', state, question, input_ids, None)
original_input_ids = input_ids
generate_params.update({'inputs': input_ids})
if inputs_embeds is not None:
generate_params.update({'inputs_embeds': inputs_embeds})
try:
# Generate the entire reply at once.
@@ -237,7 +242,7 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
new_tokens = len(output) - len(input_ids[0])
reply = decode(output[-new_tokens:], state['skip_special_tokens'])
if not shared.is_chat():
reply = original_question + apply_extensions(reply, 'output')
reply = original_question + apply_extensions('output', reply)
yield formatted_outputs(reply, shared.model_name)
@@ -265,7 +270,7 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
new_tokens = len(output) - len(input_ids[0])
reply = decode(output[-new_tokens:], state['skip_special_tokens'])
if not shared.is_chat():
reply = original_question + apply_extensions(reply, 'output')
reply = original_question + apply_extensions('output', reply)
if output[-1] in eos_token_ids:
break
@@ -285,7 +290,7 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
new_tokens = len(output) - len(original_input_ids[0])
reply = decode(output[-new_tokens:], state['skip_special_tokens'])
if not shared.is_chat():
reply = original_question + apply_extensions(reply, 'output')
reply = original_question + apply_extensions('output', reply)
if np.count_nonzero(np.isin(input_ids[0], eos_token_ids)) < np.count_nonzero(np.isin(output, eos_token_ids)):
break