LLaVA: small fixes (#1664)
* change multimodal projector to the correct one
* remove reference to custom stopping strings from readme
* fix stopping strings if tokenizer extension adds/removes tokens
* add API example
* LLaVA 7B just dropped, add to readme that there is no support for it currently
This commit is contained in:
parent
c31b0f15a7
commit
80c2f25131
3 changed files with 56 additions and 31 deletions
|
@@ -236,21 +236,6 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
|
|||
if eos_token is not None:
|
||||
eos_token_ids.append(int(encode(eos_token)[0][-1]))
|
||||
|
||||
# Create the StoppingCriteriaList with the stopping strings
|
||||
stopping_criteria_list = transformers.StoppingCriteriaList()
|
||||
for st in (stopping_strings, ast.literal_eval(f"[{state['custom_stopping_strings']}]")):
|
||||
if type(st) is list and len(st) > 0:
|
||||
sentinel_token_ids = [encode(string, add_special_tokens=False) for string in st]
|
||||
stopping_criteria_list.append(_SentinelTokenStoppingCriteria(sentinel_token_ids=sentinel_token_ids, starting_idx=len(input_ids[0])))
|
||||
break
|
||||
|
||||
# Update generate_params with the eos token and the stopping strings
|
||||
if shared.args.flexgen:
|
||||
generate_params['stop'] = eos_token_ids[-1]
|
||||
else:
|
||||
generate_params['eos_token_id'] = eos_token_ids
|
||||
generate_params['stopping_criteria'] = stopping_criteria_list
|
||||
|
||||
# Add the encoded tokens to generate_params
|
||||
if shared.soft_prompt:
|
||||
inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids)
|
||||
|
@@ -265,6 +250,21 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
|
|||
if inputs_embeds is not None:
|
||||
generate_params.update({'inputs_embeds': inputs_embeds})
|
||||
|
||||
# Create the StoppingCriteriaList with the stopping strings (needs to be done after tokenizer extensions)
|
||||
stopping_criteria_list = transformers.StoppingCriteriaList()
|
||||
for st in (stopping_strings, ast.literal_eval(f"[{state['custom_stopping_strings']}]")):
|
||||
if type(st) is list and len(st) > 0:
|
||||
sentinel_token_ids = [encode(string, add_special_tokens=False) for string in st]
|
||||
stopping_criteria_list.append(_SentinelTokenStoppingCriteria(sentinel_token_ids=sentinel_token_ids, starting_idx=len(input_ids[0])))
|
||||
break
|
||||
|
||||
# Update generate_params with the eos token and the stopping strings
|
||||
if shared.args.flexgen:
|
||||
generate_params['stop'] = eos_token_ids[-1]
|
||||
else:
|
||||
generate_params['eos_token_id'] = eos_token_ids
|
||||
generate_params['stopping_criteria'] = stopping_criteria_list
|
||||
|
||||
try:
|
||||
# Generate the entire reply at once.
|
||||
if shared.args.no_stream:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue