Remove softprompt support
parent ccb4c9f178
commit 00b94847da
7 changed files with 8 additions and 106 deletions
@@ -27,11 +27,7 @@ def generate_reply(*args, **kwargs):


 def get_max_prompt_length(state):
-    max_length = state['truncation_length'] - state['max_new_tokens']
-    if shared.soft_prompt:
-        max_length -= shared.soft_prompt_tensor.shape[1]
-
-    return max_length
+    return state['truncation_length'] - state['max_new_tokens']


 def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_length=None):
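For context, a minimal sketch of the prompt budget this function now computes; the numbers are illustrative and not taken from the diff:

# Illustrative values only: the prompt budget after this change.
state = {'truncation_length': 2048, 'max_new_tokens': 200}

def get_max_prompt_length(state):
    return state['truncation_length'] - state['max_new_tokens']

print(get_max_prompt_length(state))  # 1848 tokens available for the prompt
# Before this commit, a loaded soft prompt of n_virtual learned tokens
# additionally subtracted n_virtual from this budget.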
@@ -80,14 +76,6 @@ def decode(output_ids, skip_special_tokens=True):
     return shared.tokenizer.decode(output_ids, skip_special_tokens)


-def generate_softprompt_input_tensors(input_ids):
-    inputs_embeds = shared.model.transformer.wte(input_ids)
-    inputs_embeds = torch.cat((shared.soft_prompt_tensor, inputs_embeds), dim=1)
-    filler_input_ids = torch.zeros((1, inputs_embeds.shape[1]), dtype=input_ids.dtype).to(shared.model.device)
-    # filler_input_ids += shared.model.config.bos_token_id # setting dummy input_ids to bos tokens
-    return inputs_embeds, filler_input_ids
-
-
 # Removes empty replies from gpt4chan outputs
 def fix_gpt4chan(s):
     for i in range(10):
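The deleted helper is the core of the feature being removed: it prepended a learned soft-prompt tensor to the token embeddings and built placeholder input ids of the combined length. A self-contained sketch of that idea, assuming a GPT-2-style model whose token embedding is transformer.wte; only the names that appear in the diff are real, the rest is illustrative:

import torch

def build_softprompt_inputs(model, soft_prompt_tensor, input_ids):
    # Illustrative restatement of the removed helper, not project code.
    # Embed the real tokens, then prepend the learned soft-prompt embeddings.
    inputs_embeds = model.transformer.wte(input_ids)                       # (1, seq, hidden)
    inputs_embeds = torch.cat((soft_prompt_tensor, inputs_embeds), dim=1)  # (1, soft + seq, hidden)

    # Placeholder ids of matching length; the content is carried by inputs_embeds.
    filler_input_ids = torch.zeros((1, inputs_embeds.shape[1]), dtype=input_ids.dtype, device=input_ids.device)
    return inputs_embeds, filler_input_ids

The placeholder ids are what the removed branch in the next hunk passed as 'inputs' alongside 'inputs_embeds'.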
@@ -232,18 +220,11 @@ def generate_reply_HF(question, original_question, seed, state, eos_token=None,
         eos_token_ids.append(int(encode(eos_token)[0][-1]))

     # Add the encoded tokens to generate_params
-    if shared.soft_prompt:
-        inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids)
-        question, filler_input_ids, inputs_embeds = apply_extensions('tokenizer', state, question, filler_input_ids, inputs_embeds)
-        original_input_ids = input_ids
-        generate_params.update({'inputs_embeds': inputs_embeds})
-        generate_params.update({'inputs': filler_input_ids})
-    else:
-        question, input_ids, inputs_embeds = apply_extensions('tokenizer', state, question, input_ids, None)
-        original_input_ids = input_ids
-        generate_params.update({'inputs': input_ids})
-        if inputs_embeds is not None:
-            generate_params.update({'inputs_embeds': inputs_embeds})
+    question, input_ids, inputs_embeds = apply_extensions('tokenizer', state, question, input_ids, None)
+    original_input_ids = input_ids
+    generate_params.update({'inputs': input_ids})
+    if inputs_embeds is not None:
+        generate_params.update({'inputs_embeds': inputs_embeds})

     # Create the StoppingCriteriaList with the stopping strings (needs to be done after tokenizer extensions)
     stopping_criteria_list = transformers.StoppingCriteriaList()
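With the branch gone, the 'tokenizer' extension hook always receives the real input_ids and may optionally return embeddings. A rough, self-contained sketch of how the resulting generate_params dict is assembled and later consumed; the stand-in values and the commented generate call are illustrative:

import torch

# Illustrative stand-ins for the variables used above.
input_ids = torch.tensor([[1, 2, 3]])   # tokenized prompt
inputs_embeds = None                    # an extension may return embeddings instead of None

generate_params = {'max_new_tokens': 200}
generate_params.update({'inputs': input_ids})
if inputs_embeds is not None:
    generate_params.update({'inputs_embeds': inputs_embeds})

# Later in generate_reply_HF the dict is unpacked into the Hugging Face call, roughly:
#   output = shared.model.generate(**generate_params)[0]
print(sorted(generate_params.keys()))   # ['inputs', 'max_new_tokens']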
@@ -269,9 +250,6 @@ def generate_reply_HF(question, original_question, seed, state, eos_token=None,
                 if cuda:
                     output = output.cuda()

-            if shared.soft_prompt:
-                output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
-
             yield get_reply_from_output_ids(output, input_ids, original_question, state, is_chat=is_chat)

         # Stream the reply 1 token at a time.
@@ -289,9 +267,6 @@ def generate_reply_HF(question, original_question, seed, state, eos_token=None,

             with generate_with_streaming(**generate_params) as generator:
                 for output in generator:
-                    if shared.soft_prompt:
-                        output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
-
                     yield get_reply_from_output_ids(output, input_ids, original_question, state, is_chat=is_chat)
                     if output[-1] in eos_token_ids:
                         break
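Both removed blocks above performed the same post-processing: with a soft prompt active, generate() echoed the placeholder filler tokens at the start of its output, so they were sliced off and the real prompt ids were put back before decoding. A small self-contained illustration with made-up tensors:

import torch

# Made-up shapes, mirroring the removed torch.cat(...) post-processing.
input_ids = torch.tensor([10, 11, 12])                    # real prompt tokens (input_ids[0] in the diff)
filler_input_ids = torch.zeros((1, 5), dtype=torch.long)  # placeholders: 2 soft-prompt + 3 prompt positions
output = torch.tensor([0, 0, 0, 0, 0, 20, 21])            # generate() output: placeholders + new tokens

# Drop the placeholder prefix and restore the real prompt in front,
# so decoding sees prompt + newly generated tokens only.
output = torch.cat((input_ids, output[filler_input_ids.shape[1]:]))
print(output)  # tensor([10, 11, 12, 20, 21])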