LLaVA support (#1487)
This commit is contained in:
parent
9197d3fec8
commit
12212cf6be
12 changed files with 426 additions and 42 deletions
|
@@ -138,7 +138,7 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
|
|||
|
||||
original_question = question
|
||||
if not shared.is_chat():
|
||||
question = apply_extensions(question, 'input')
|
||||
question = apply_extensions('input', question)
|
||||
|
||||
# These models are not part of Hugging Face, so we handle them
|
||||
# separately and terminate the function call earlier
|
||||
|
@@ -155,7 +155,7 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
|
|||
reply = shared.model.generate(context=question, **generate_params)
|
||||
output = original_question + reply
|
||||
if not shared.is_chat():
|
||||
reply = original_question + apply_extensions(reply, 'output')
|
||||
reply = original_question + apply_extensions('output', reply)
|
||||
yield formatted_outputs(reply, shared.model_name)
|
||||
else:
|
||||
if not shared.is_chat():
|
||||
|
@@ -166,7 +166,7 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
|
|||
for reply in shared.model.generate_with_streaming(context=question, **generate_params):
|
||||
output = original_question + reply
|
||||
if not shared.is_chat():
|
||||
reply = original_question + apply_extensions(reply, 'output')
|
||||
reply = original_question + apply_extensions('output', reply)
|
||||
yield formatted_outputs(reply, shared.model_name)
|
||||
|
||||
except Exception:
|
||||
|
@@ -179,7 +179,6 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
|
|||
return
|
||||
|
||||
input_ids = encode(question, add_bos_token=state['add_bos_token'], truncation_length=get_max_prompt_length(state))
|
||||
original_input_ids = input_ids
|
||||
output = input_ids[0]
|
||||
|
||||
if shared.args.verbose:
|
||||
|
@@ -218,10 +217,16 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
|
|||
generate_params.update({'synced_gpus': True})
|
||||
if shared.soft_prompt:
|
||||
inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids)
|
||||
question, filler_input_ids, inputs_embeds = apply_extensions('tokenizer', state, question, filler_input_ids, inputs_embeds)
|
||||
original_input_ids = input_ids
|
||||
generate_params.update({'inputs_embeds': inputs_embeds})
|
||||
generate_params.update({'inputs': filler_input_ids})
|
||||
else:
|
||||
question, input_ids, inputs_embeds = apply_extensions('tokenizer', state, question, input_ids, None)
|
||||
original_input_ids = input_ids
|
||||
generate_params.update({'inputs': input_ids})
|
||||
if inputs_embeds is not None:
|
||||
generate_params.update({'inputs_embeds': inputs_embeds})
|
||||
|
||||
try:
|
||||
# Generate the entire reply at once.
|
||||
|
@@ -237,7 +242,7 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
|
|||
new_tokens = len(output) - len(input_ids[0])
|
||||
reply = decode(output[-new_tokens:], state['skip_special_tokens'])
|
||||
if not shared.is_chat():
|
||||
reply = original_question + apply_extensions(reply, 'output')
|
||||
reply = original_question + apply_extensions('output', reply)
|
||||
|
||||
yield formatted_outputs(reply, shared.model_name)
|
||||
|
||||
|
@@ -265,7 +270,7 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
|
|||
new_tokens = len(output) - len(input_ids[0])
|
||||
reply = decode(output[-new_tokens:], state['skip_special_tokens'])
|
||||
if not shared.is_chat():
|
||||
reply = original_question + apply_extensions(reply, 'output')
|
||||
reply = original_question + apply_extensions('output', reply)
|
||||
|
||||
if output[-1] in eos_token_ids:
|
||||
break
|
||||
|
@@ -285,7 +290,7 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
|
|||
new_tokens = len(output) - len(original_input_ids[0])
|
||||
reply = decode(output[-new_tokens:], state['skip_special_tokens'])
|
||||
if not shared.is_chat():
|
||||
reply = original_question + apply_extensions(reply, 'output')
|
||||
reply = original_question + apply_extensions('output', reply)
|
||||
|
||||
if np.count_nonzero(np.isin(input_ids[0], eos_token_ids)) < np.count_nonzero(np.isin(output, eos_token_ids)):
|
||||
break
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue