Add repetition penalty range parameter to transformers (#2916)

This commit is contained in:
oobabooga 2023-06-29 13:40:13 -03:00 committed by GitHub
parent c6cae106e7
commit 3443219cbc
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 55 additions and 5 deletions

View file

@@ -5,6 +5,7 @@ import transformers
from transformers import LogitsWarper
from transformers.generation.logits_process import (
LogitNormalization,
LogitsProcessor,
LogitsProcessorList,
TemperatureLogitsWarper
)
@@ -121,6 +122,29 @@ class MirostatLogitsWarper(LogitsWarper):
return scores
class RepetitionPenaltyLogitsProcessorWithRange(LogitsProcessor):
    '''
    Repetition penalty that only looks at the last `_range` tokens of the
    context. Copied from the transformers library's
    RepetitionPenaltyLogitsProcessor, with the range restriction added.
    '''

    def __init__(self, penalty: float, _range: int):
        # Accept ints as well as floats (e.g. penalty=2): the original
        # float-only isinstance check rejected valid integer penalties.
        if not isinstance(penalty, (int, float)) or not (penalty > 0):
            raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}")

        self.penalty = float(penalty)
        self._range = _range

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # Only penalize tokens that appear in the last `_range` positions.
        # NOTE: when _range == 0, the slice is [0:] (since -0 == 0) and the
        # whole sequence is penalized, matching the stock processor; the
        # caller only installs this class when _range > 0.
        input_ids = input_ids[:, -self._range:]
        score = torch.gather(scores, 1, input_ids)

        # If score < 0 the repetition penalty has to be multiplied (not
        # divided) to reduce the previous token's probability.
        score = torch.where(score < 0, score * self.penalty, score / self.penalty)

        # In-place scatter back into `scores`; the (mutated) tensor is also
        # returned, as the LogitsProcessor contract expects.
        scores.scatter_(1, input_ids, score)
        return scores
def get_logits_warper_patch(self, generation_config):
warpers = self._get_logits_warper_old(generation_config)
warpers_to_add = LogitsProcessorList()
@@ -146,6 +170,19 @@ def get_logits_warper_patch(self, generation_config):
return warpers
def get_logits_processor_patch(self, **kwargs):
    '''
    Replacement for GenerationMixin._get_logits_processor: builds the stock
    processor list, then swaps the stock RepetitionPenaltyLogitsProcessor
    for the range-aware variant when repetition_penalty_range is set.
    '''
    processors = self._get_logits_processor_old(**kwargs)
    generation_config = kwargs['generation_config']
    penalty_range = generation_config.repetition_penalty_range
    penalty = generation_config.repetition_penalty

    if penalty_range > 0:
        # Match by class name (not isinstance) so the stock class does not
        # have to be imported here.
        for i, processor in enumerate(processors):
            if processor.__class__.__name__ == 'RepetitionPenaltyLogitsProcessor':
                processors[i] = RepetitionPenaltyLogitsProcessorWithRange(penalty, penalty_range)

    return processors
def generation_config_init_patch(self, **kwargs):
self.__init___old(**kwargs)
self.tfs = kwargs.pop("tfs", 1.0)
@@ -153,11 +190,15 @@ def generation_config_init_patch(self, **kwargs):
self.mirostat_mode = kwargs.pop("mirostat_mode", 0)
self.mirostat_eta = kwargs.pop("mirostat_eta", 0.1)
self.mirostat_tau = kwargs.pop("mirostat_tau", 5)
self.repetition_penalty_range = kwargs.pop("repetition_penalty_range", 0)
def hijack_samplers():
    '''
    Monkey-patch transformers so that generation uses the custom logits
    warpers/processors and GenerationConfig accepts the extra sampler
    parameters. The originals are saved under *_old attributes so the
    patches can delegate to them.

    Idempotent: calling this twice must not overwrite the saved originals
    with the patched versions, which would make the patches call themselves
    and recurse forever.
    '''
    if hasattr(transformers.GenerationMixin, '_get_logits_warper_old'):
        return  # already hijacked

    transformers.GenerationMixin._get_logits_warper_old = transformers.GenerationMixin._get_logits_warper
    transformers.GenerationMixin._get_logits_warper = get_logits_warper_patch

    transformers.GenerationMixin._get_logits_processor_old = transformers.GenerationMixin._get_logits_processor
    transformers.GenerationMixin._get_logits_processor = get_logits_processor_patch

    transformers.GenerationConfig.__init___old = transformers.GenerationConfig.__init__
    transformers.GenerationConfig.__init__ = generation_config_init_patch