Add additive_repetition_penalty sampler setting. (#3627)

This commit is contained in:
tdrussell 2023-10-23 00:28:07 -05:00 committed by GitHub
parent 6086768309
commit 4440f87722
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
13 changed files with 34 additions and 8 deletions

View file

@@ -139,11 +139,12 @@ class RepetitionPenaltyLogitsProcessorWithRange(LogitsProcessor):
     Copied from the transformers library
     '''

-    def __init__(self, penalty: float, _range: int):
-        if not isinstance(penalty, float) or not (penalty > 0):
-            raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}")
+    def __init__(self, penalty: float, additive_penalty: float, _range: int):
+        if not (penalty > 0):
+            raise ValueError(f"`penalty` has to be strictly positive, but is {penalty}")

         self.penalty = penalty
+        self.additive_penalty = additive_penalty
         self._range = _range

     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
@@ -153,6 +154,7 @@ class RepetitionPenaltyLogitsProcessorWithRange(LogitsProcessor):
         # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
         score = torch.where(score < 0, score * self.penalty, score / self.penalty)
+        score -= self.additive_penalty

         scores.scatter_(1, input_ids, score)
         return scores
@@ -185,14 +187,20 @@ def get_logits_warper_patch(self, generation_config):

 def get_logits_processor_patch(self, **kwargs):
-    result = self._get_logits_processor_old(**kwargs)
-    repetition_penalty_range = kwargs['generation_config'].repetition_penalty_range
     repetition_penalty = kwargs['generation_config'].repetition_penalty
+    additive_repetition_penalty = kwargs['generation_config'].additive_repetition_penalty
+    repetition_penalty_range = kwargs['generation_config'].repetition_penalty_range
+    do_rep_pen_hijack = (repetition_penalty > 1) or (additive_repetition_penalty > 0)
+    if do_rep_pen_hijack:
+        # Make sure that a RepetitionPenaltyLogitsProcessor will be created
+        kwargs['generation_config'].repetition_penalty = 1.1  # must set to some value > 1
+
+    result = self._get_logits_processor_old(**kwargs)

-    if repetition_penalty_range > 0:
+    if do_rep_pen_hijack:
         for i in range(len(result)):
             if result[i].__class__.__name__ == 'RepetitionPenaltyLogitsProcessor':
-                result[i] = RepetitionPenaltyLogitsProcessorWithRange(repetition_penalty, repetition_penalty_range)
+                result[i] = RepetitionPenaltyLogitsProcessorWithRange(repetition_penalty, additive_repetition_penalty, repetition_penalty_range)

     return result
@@ -205,6 +213,7 @@ def generation_config_init_patch(self, **kwargs):
     self.mirostat_eta = kwargs.pop("mirostat_eta", 0.1)
     self.mirostat_tau = kwargs.pop("mirostat_tau", 5)
     self.repetition_penalty_range = kwargs.pop("repetition_penalty_range", 0)
+    self.additive_repetition_penalty = kwargs.pop("additive_repetition_penalty", 0)

 def hijack_samplers():