Add repetition penalty range parameter to transformers (#2916)
This commit is contained in:
parent
c6cae106e7
commit
3443219cbc
12 changed files with 55 additions and 5 deletions
|
@ -5,6 +5,7 @@ import transformers
|
|||
from transformers import LogitsWarper
|
||||
from transformers.generation.logits_process import (
|
||||
LogitNormalization,
|
||||
LogitsProcessor,
|
||||
LogitsProcessorList,
|
||||
TemperatureLogitsWarper
|
||||
)
|
||||
|
@ -121,6 +122,29 @@ class MirostatLogitsWarper(LogitsWarper):
|
|||
return scores
|
||||
|
||||
|
||||
class RepetitionPenaltyLogitsProcessorWithRange(LogitsProcessor):
|
||||
'''
|
||||
Copied from the transformers library
|
||||
'''
|
||||
def __init__(self, penalty: float, _range: int):
|
||||
if not isinstance(penalty, float) or not (penalty > 0):
|
||||
raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}")
|
||||
|
||||
self.penalty = penalty
|
||||
self._range = _range
|
||||
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||
|
||||
input_ids = input_ids[:, -self._range:]
|
||||
score = torch.gather(scores, 1, input_ids)
|
||||
|
||||
# if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
|
||||
score = torch.where(score < 0, score * self.penalty, score / self.penalty)
|
||||
|
||||
scores.scatter_(1, input_ids, score)
|
||||
return scores
|
||||
|
||||
|
||||
def get_logits_warper_patch(self, generation_config):
|
||||
warpers = self._get_logits_warper_old(generation_config)
|
||||
warpers_to_add = LogitsProcessorList()
|
||||
|
@ -146,6 +170,19 @@ def get_logits_warper_patch(self, generation_config):
|
|||
return warpers
|
||||
|
||||
|
||||
def get_logits_processor_patch(self, **kwargs):
|
||||
result = self._get_logits_processor_old(**kwargs)
|
||||
repetition_penalty_range = kwargs['generation_config'].repetition_penalty_range
|
||||
repetition_penalty = kwargs['generation_config'].repetition_penalty
|
||||
|
||||
if repetition_penalty_range > 0:
|
||||
for i in range(len(result)):
|
||||
if result[i].__class__.__name__ == 'RepetitionPenaltyLogitsProcessor':
|
||||
result[i] = RepetitionPenaltyLogitsProcessorWithRange(repetition_penalty, repetition_penalty_range)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def generation_config_init_patch(self, **kwargs):
|
||||
self.__init___old(**kwargs)
|
||||
self.tfs = kwargs.pop("tfs", 1.0)
|
||||
|
@ -153,11 +190,15 @@ def generation_config_init_patch(self, **kwargs):
|
|||
self.mirostat_mode = kwargs.pop("mirostat_mode", 0)
|
||||
self.mirostat_eta = kwargs.pop("mirostat_eta", 0.1)
|
||||
self.mirostat_tau = kwargs.pop("mirostat_tau", 5)
|
||||
self.repetition_penalty_range = kwargs.pop("repetition_penalty_range", 0)
|
||||
|
||||
|
||||
def hijack_samplers():
|
||||
transformers.GenerationMixin._get_logits_warper_old = transformers.GenerationMixin._get_logits_warper
|
||||
transformers.GenerationMixin._get_logits_warper = get_logits_warper_patch
|
||||
|
||||
transformers.GenerationMixin._get_logits_processor_old = transformers.GenerationMixin._get_logits_processor
|
||||
transformers.GenerationMixin._get_logits_processor = get_logits_processor_patch
|
||||
|
||||
transformers.GenerationConfig.__init___old = transformers.GenerationConfig.__init__
|
||||
transformers.GenerationConfig.__init__ = generation_config_init_patch
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue