Add temperature_last parameter (#4472)

This commit is contained in:
oobabooga 2023-11-04 13:09:07 -03:00 committed by GitHub
parent 1ab8700d94
commit aa5d671579
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 38 additions and 7 deletions

View file

@ -12,6 +12,7 @@ from transformers.generation.logits_process import (
global_scores = None
class MinPLogitsWarper(LogitsWarper):
def __init__(self, min_p: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
if min_p < 0 or min_p > 1.0:
@ -41,6 +42,7 @@ class MinPLogitsWarper(LogitsWarper):
scores = scores.masked_fill(indices_to_remove, self.filter_value)
return scores
class TailFreeLogitsWarper(LogitsWarper):
def __init__(self, tfs: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
tfs = float(tfs)
@ -214,19 +216,36 @@ def get_logits_warper_patch(self, generation_config):
if not isinstance(warper, TemperatureLogitsWarper):
warpers.remove(warper)
else:
if generation_config.tfs is not None and 0.0 <= generation_config.tfs <= 1.0:
if generation_config.tfs is not None and 0.0 <= generation_config.tfs < 1.0:
warpers_to_add.append(TailFreeLogitsWarper(tfs=generation_config.tfs, min_tokens_to_keep=min_tokens_to_keep))
if generation_config.top_a is not None and 0.0 <= generation_config.top_a <= 1.0:
if generation_config.top_a is not None and 0.0 < generation_config.top_a <= 1.0:
warpers_to_add.append(TopALogitsWarper(top_a=generation_config.top_a, min_tokens_to_keep=min_tokens_to_keep))
if generation_config.min_p is not None and 0.0 <= generation_config.min_p <= 1.0:
if generation_config.min_p is not None and 0.0 < generation_config.min_p <= 1.0:
warpers_to_add.append(MinPLogitsWarper(min_p=generation_config.min_p, min_tokens_to_keep=min_tokens_to_keep))
if warpers and isinstance(warpers[-1], LogitNormalization):
warpers = warpers[:-1] + warpers_to_add + [warpers[-1]]
if len(warpers) > 0 and isinstance(warpers[-1], LogitNormalization):
normalize = warpers.pop(-1)
else:
warpers += warpers_to_add
normalize = None
warpers += warpers_to_add
if generation_config.temperature_last:
temperature_idx = None
for i in range(len(warpers)):
if warpers[i].__class__.__name__ == 'TemperatureLogitsWarper':
temperature_idx = i
break
if temperature_idx is not None:
warpers = warpers[:temperature_idx] + warpers[temperature_idx + 1:] + [warpers[temperature_idx]]
warpers = LogitsProcessorList(warpers)
if normalize is not None:
warpers.append(normalize)
warpers.append(SpyLogitsWarper())
# for i in range(len(warpers)):
# print(warpers[i].__class__.__name__)
return warpers
@ -261,6 +280,7 @@ def generation_config_init_patch(self, **kwargs):
self.repetition_penalty_range = kwargs.pop("repetition_penalty_range", 0)
self.presence_penalty = kwargs.pop("presence_penalty", 0)
self.frequency_penalty = kwargs.pop("frequency_penalty", 0)
self.temperature_last = kwargs.pop("temperature_last", False)
def hijack_samplers():