LLaVA: small fixes (#1664)

* change multimodal projector to the correct one

* remove reference to custom stopping strings from readme

* fix stopping strings if tokenizer extension adds/removes tokens

* add API example

* LLaVA 7B just dropped, add to readme that there is no support for it currently
This commit is contained in:
Wojtab 2023-05-03 04:12:22 +02:00 committed by GitHub
parent c31b0f15a7
commit 80c2f25131
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 56 additions and 31 deletions

View file

@ -236,21 +236,6 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
if eos_token is not None:
eos_token_ids.append(int(encode(eos_token)[0][-1]))
# Create the StoppingCriteriaList with the stopping strings
stopping_criteria_list = transformers.StoppingCriteriaList()
for st in (stopping_strings, ast.literal_eval(f"[{state['custom_stopping_strings']}]")):
if type(st) is list and len(st) > 0:
sentinel_token_ids = [encode(string, add_special_tokens=False) for string in st]
stopping_criteria_list.append(_SentinelTokenStoppingCriteria(sentinel_token_ids=sentinel_token_ids, starting_idx=len(input_ids[0])))
break
# Update generate_params with the eos token and the stopping strings
if shared.args.flexgen:
generate_params['stop'] = eos_token_ids[-1]
else:
generate_params['eos_token_id'] = eos_token_ids
generate_params['stopping_criteria'] = stopping_criteria_list
# Add the encoded tokens to generate_params
if shared.soft_prompt:
inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids)
@ -265,6 +250,21 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
if inputs_embeds is not None:
generate_params.update({'inputs_embeds': inputs_embeds})
# Create the StoppingCriteriaList with the stopping strings (needs to be done after tokenizer extensions)
stopping_criteria_list = transformers.StoppingCriteriaList()
for st in (stopping_strings, ast.literal_eval(f"[{state['custom_stopping_strings']}]")):
if type(st) is list and len(st) > 0:
sentinel_token_ids = [encode(string, add_special_tokens=False) for string in st]
stopping_criteria_list.append(_SentinelTokenStoppingCriteria(sentinel_token_ids=sentinel_token_ids, starting_idx=len(input_ids[0])))
break
# Update generate_params with the eos token and the stopping strings
if shared.args.flexgen:
generate_params['stop'] = eos_token_ids[-1]
else:
generate_params['eos_token_id'] = eos_token_ids
generate_params['stopping_criteria'] = stopping_criteria_list
try:
# Generate the entire reply at once.
if shared.args.no_stream: