Rename additive_repetition_penalty to presence_penalty, add frequency_penalty (#4376)
This commit is contained in:
parent
ef1489cd4d
commit
72f6fc6923
14 changed files with 64 additions and 30 deletions
|
@ -139,24 +139,35 @@ class RepetitionPenaltyLogitsProcessorWithRange(LogitsProcessor):
|
|||
Copied from the transformers library
|
||||
'''
|
||||
|
||||
def __init__(self, penalty: float, additive_penalty: float, _range: int):
|
||||
def __init__(self, penalty: float, presence_penalty: float, frequency_penalty: float, _range: int):
|
||||
if not (penalty > 0):
|
||||
raise ValueError(f"`penalty` has to be strictly positive, but is {penalty}")
|
||||
|
||||
self.penalty = penalty
|
||||
self.additive_penalty = additive_penalty
|
||||
self.presence_penalty = presence_penalty
|
||||
self.frequency_penalty = frequency_penalty
|
||||
self._range = _range
|
||||
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||
|
||||
input_ids = input_ids[:, -self._range:]
|
||||
score = torch.gather(scores, 1, input_ids)
|
||||
|
||||
# if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
|
||||
score = torch.where(score < 0, score * self.penalty, score / self.penalty)
|
||||
score -= self.additive_penalty
|
||||
# We loop here because torch.unique() needs to process each row separately in the
|
||||
# case that batch_size > 1.
|
||||
for input_ids_row, scores_row in zip(input_ids, scores):
|
||||
unique_ids, counts = torch.unique(input_ids_row, return_counts=True)
|
||||
score = torch.gather(scores_row, 0, unique_ids)
|
||||
|
||||
# multiplicative repetition penalty
|
||||
# if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
|
||||
score = torch.where(score < 0, score * self.penalty, score / self.penalty)
|
||||
scores_row.scatter_(0, unique_ids, score)
|
||||
|
||||
# presence_penalty and frequency_penalty
|
||||
raw_presence_penalty = (counts > 0).to(scores.dtype)
|
||||
raw_frequency_penalty = counts.to(scores.dtype)
|
||||
additive_penalty = raw_presence_penalty*self.presence_penalty + raw_frequency_penalty*self.frequency_penalty
|
||||
scores_row.scatter_add_(0, unique_ids, -additive_penalty)
|
||||
|
||||
scores.scatter_(1, input_ids, score)
|
||||
return scores
|
||||
|
||||
|
||||
|
@ -188,9 +199,10 @@ def get_logits_warper_patch(self, generation_config):
|
|||
|
||||
def get_logits_processor_patch(self, **kwargs):
|
||||
repetition_penalty = kwargs['generation_config'].repetition_penalty
|
||||
additive_repetition_penalty = kwargs['generation_config'].additive_repetition_penalty
|
||||
presence_penalty = kwargs['generation_config'].presence_penalty
|
||||
frequency_penalty = kwargs['generation_config'].frequency_penalty
|
||||
repetition_penalty_range = kwargs['generation_config'].repetition_penalty_range
|
||||
do_rep_pen_hijack = (repetition_penalty > 1) or (additive_repetition_penalty > 0)
|
||||
do_rep_pen_hijack = (repetition_penalty > 1) or (presence_penalty != 0) or (frequency_penalty != 0)
|
||||
if do_rep_pen_hijack:
|
||||
# Make sure that a RepetitionPenaltyLogitsProcessor will be created
|
||||
kwargs['generation_config'].repetition_penalty = 1.1 # must set to some value > 1
|
||||
|
@ -200,7 +212,7 @@ def get_logits_processor_patch(self, **kwargs):
|
|||
if do_rep_pen_hijack:
|
||||
for i in range(len(result)):
|
||||
if result[i].__class__.__name__ == 'RepetitionPenaltyLogitsProcessor':
|
||||
result[i] = RepetitionPenaltyLogitsProcessorWithRange(repetition_penalty, additive_repetition_penalty, repetition_penalty_range)
|
||||
result[i] = RepetitionPenaltyLogitsProcessorWithRange(repetition_penalty, presence_penalty, frequency_penalty, repetition_penalty_range)
|
||||
|
||||
return result
|
||||
|
||||
|
@ -213,7 +225,8 @@ def generation_config_init_patch(self, **kwargs):
|
|||
self.mirostat_eta = kwargs.pop("mirostat_eta", 0.1)
|
||||
self.mirostat_tau = kwargs.pop("mirostat_tau", 5)
|
||||
self.repetition_penalty_range = kwargs.pop("repetition_penalty_range", 0)
|
||||
self.additive_repetition_penalty = kwargs.pop("additive_repetition_penalty", 0)
|
||||
self.presence_penalty = kwargs.pop("presence_penalty", 0)
|
||||
self.frequency_penalty = kwargs.pop("frequency_penalty", 0)
|
||||
|
||||
|
||||
def hijack_samplers():
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue