Remove softprompt support

This commit is contained in:
oobabooga 2023-06-06 07:42:23 -03:00
parent ccb4c9f178
commit 00b94847da
7 changed files with 8 additions and 106 deletions

View file

@ -27,11 +27,7 @@ def generate_reply(*args, **kwargs):
# Computes the prompt-token budget from `state`: the value under
# 'truncation_length' minus the value under 'max_new_tokens'.
# NOTE(review): this text looks like a rendered diff hunk whose +/- markers and
# indentation were lost, so two versions of the body appear back to back and the
# lines below are not runnable as written.
def get_max_prompt_length(state):
# Presumably the removed version (the commit title is "Remove softprompt
# support"): it starts from the same difference and, while shared.soft_prompt is
# truthy, also subtracts shared.soft_prompt_tensor.shape[1] before returning.
# TODO confirm against the raw diff.
max_length = state['truncation_length'] - state['max_new_tokens']
if shared.soft_prompt:
max_length -= shared.soft_prompt_tensor.shape[1]
return max_length
# Presumably the replacement line: the same difference with no soft-prompt term.
return state['truncation_length'] - state['max_new_tokens']
def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_length=None):
@ -80,14 +76,6 @@ def decode(output_ids, skip_special_tokens=True):
return shared.tokenizer.decode(output_ids, skip_special_tokens)
# Builds the two tensors the soft-prompt path feeds to generation: the
# embeddings with the soft prompt prepended, and an all-zero id tensor of
# matching length.
# Returns the tuple (inputs_embeds, filler_input_ids).
# NOTE(review): judging by the commit title ("Remove softprompt support") this
# whole function is presumably among the deleted lines of the hunk; the +/-
# markers and indentation are missing here. TODO confirm against the raw diff.
def generate_softprompt_input_tensors(input_ids):
# Pass the ids through shared.model.transformer.wte (presumably the model's
# token-embedding layer; verify against the model class).
inputs_embeds = shared.model.transformer.wte(input_ids)
# Prepend shared.soft_prompt_tensor to those embeddings along dim 1.
inputs_embeds = torch.cat((shared.soft_prompt_tensor, inputs_embeds), dim=1)
# Zero-filled ids of shape (1, combined length along dim 1), using the dtype of
# input_ids and moved to shared.model.device. The leading 1 suggests a single
# sequence per call; TODO confirm.
filler_input_ids = torch.zeros((1, inputs_embeds.shape[1]), dtype=input_ids.dtype).to(shared.model.device)
# Disabled alternative kept from the original: it would add the configured
# bos_token_id to the zero-filled ids.
# filler_input_ids += shared.model.config.bos_token_id # setting dummy input_ids to bos tokens
return inputs_embeds, filler_input_ids
# Removes empty replies from gpt4chan outputs
def fix_gpt4chan(s):
for i in range(10):
@ -232,18 +220,11 @@ def generate_reply_HF(question, original_question, seed, state, eos_token=None,
eos_token_ids.append(int(encode(eos_token)[0][-1]))
# Add the encoded tokens to generate_params
if shared.soft_prompt:
inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids)
question, filler_input_ids, inputs_embeds = apply_extensions('tokenizer', state, question, filler_input_ids, inputs_embeds)
original_input_ids = input_ids
question, input_ids, inputs_embeds = apply_extensions('tokenizer', state, question, input_ids, None)
original_input_ids = input_ids
generate_params.update({'inputs': input_ids})
if inputs_embeds is not None:
generate_params.update({'inputs_embeds': inputs_embeds})
generate_params.update({'inputs': filler_input_ids})
else:
question, input_ids, inputs_embeds = apply_extensions('tokenizer', state, question, input_ids, None)
original_input_ids = input_ids
generate_params.update({'inputs': input_ids})
if inputs_embeds is not None:
generate_params.update({'inputs_embeds': inputs_embeds})
# Create the StoppingCriteriaList with the stopping strings (needs to be done after tokenizer extensions)
stopping_criteria_list = transformers.StoppingCriteriaList()
@ -269,9 +250,6 @@ def generate_reply_HF(question, original_question, seed, state, eos_token=None,
if cuda:
output = output.cuda()
if shared.soft_prompt:
output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
yield get_reply_from_output_ids(output, input_ids, original_question, state, is_chat=is_chat)
# Stream the reply 1 token at a time.
@ -289,9 +267,6 @@ def generate_reply_HF(question, original_question, seed, state, eos_token=None,
with generate_with_streaming(**generate_params) as generator:
for output in generator:
if shared.soft_prompt:
output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
yield get_reply_from_output_ids(output, input_ids, original_question, state, is_chat=is_chat)
if output[-1] in eos_token_ids:
break