Add temperature_last parameter (#4472)
This commit is contained in:
parent
1ab8700d94
commit
aa5d671579
7 changed files with 38 additions and 7 deletions
|
@ -12,6 +12,7 @@ from transformers.generation.logits_process import (
|
|||
|
||||
global_scores = None
|
||||
|
||||
|
||||
class MinPLogitsWarper(LogitsWarper):
|
||||
def __init__(self, min_p: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
|
||||
if min_p < 0 or min_p > 1.0:
|
||||
|
@ -41,6 +42,7 @@ class MinPLogitsWarper(LogitsWarper):
|
|||
scores = scores.masked_fill(indices_to_remove, self.filter_value)
|
||||
return scores
|
||||
|
||||
|
||||
class TailFreeLogitsWarper(LogitsWarper):
|
||||
def __init__(self, tfs: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
|
||||
tfs = float(tfs)
|
||||
|
@ -214,19 +216,36 @@ def get_logits_warper_patch(self, generation_config):
|
|||
if not isinstance(warper, TemperatureLogitsWarper):
|
||||
warpers.remove(warper)
|
||||
else:
|
||||
if generation_config.tfs is not None and 0.0 <= generation_config.tfs <= 1.0:
|
||||
if generation_config.tfs is not None and 0.0 <= generation_config.tfs < 1.0:
|
||||
warpers_to_add.append(TailFreeLogitsWarper(tfs=generation_config.tfs, min_tokens_to_keep=min_tokens_to_keep))
|
||||
if generation_config.top_a is not None and 0.0 <= generation_config.top_a <= 1.0:
|
||||
if generation_config.top_a is not None and 0.0 < generation_config.top_a <= 1.0:
|
||||
warpers_to_add.append(TopALogitsWarper(top_a=generation_config.top_a, min_tokens_to_keep=min_tokens_to_keep))
|
||||
if generation_config.min_p is not None and 0.0 <= generation_config.min_p <= 1.0:
|
||||
if generation_config.min_p is not None and 0.0 < generation_config.min_p <= 1.0:
|
||||
warpers_to_add.append(MinPLogitsWarper(min_p=generation_config.min_p, min_tokens_to_keep=min_tokens_to_keep))
|
||||
|
||||
if warpers and isinstance(warpers[-1], LogitNormalization):
|
||||
warpers = warpers[:-1] + warpers_to_add + [warpers[-1]]
|
||||
if len(warpers) > 0 and isinstance(warpers[-1], LogitNormalization):
|
||||
normalize = warpers.pop(-1)
|
||||
else:
|
||||
warpers += warpers_to_add
|
||||
normalize = None
|
||||
|
||||
warpers += warpers_to_add
|
||||
if generation_config.temperature_last:
|
||||
temperature_idx = None
|
||||
for i in range(len(warpers)):
|
||||
if warpers[i].__class__.__name__ == 'TemperatureLogitsWarper':
|
||||
temperature_idx = i
|
||||
break
|
||||
|
||||
if temperature_idx is not None:
|
||||
warpers = warpers[:temperature_idx] + warpers[temperature_idx + 1:] + [warpers[temperature_idx]]
|
||||
warpers = LogitsProcessorList(warpers)
|
||||
|
||||
if normalize is not None:
|
||||
warpers.append(normalize)
|
||||
|
||||
warpers.append(SpyLogitsWarper())
|
||||
# for i in range(len(warpers)):
|
||||
# print(warpers[i].__class__.__name__)
|
||||
return warpers
|
||||
|
||||
|
||||
|
@ -261,6 +280,7 @@ def generation_config_init_patch(self, **kwargs):
|
|||
self.repetition_penalty_range = kwargs.pop("repetition_penalty_range", 0)
|
||||
self.presence_penalty = kwargs.pop("presence_penalty", 0)
|
||||
self.frequency_penalty = kwargs.pop("frequency_penalty", 0)
|
||||
self.temperature_last = kwargs.pop("temperature_last", False)
|
||||
|
||||
|
||||
def hijack_samplers():
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue