Add dynamic_temperature_low parameter (#5198)
This commit is contained in:
parent
b8a0b3f925
commit
0d07b3a6a1
11 changed files with 30 additions and 88 deletions
|
@ -16,7 +16,7 @@ global_scores = None
|
|||
|
||||
|
||||
class TemperatureLogitsWarperWithDynatemp(LogitsWarper):
|
||||
def __init__(self, temperature: float, dynatemp: float):
|
||||
def __init__(self, temperature: float, dynamic_temperature: bool, dynamic_temperature_low: float):
|
||||
if not isinstance(temperature, float) or not (temperature > 0):
|
||||
except_msg = (
|
||||
f"`temperature` (={temperature}) has to be a strictly positive float, otherwise your next token "
|
||||
|
@ -28,19 +28,20 @@ class TemperatureLogitsWarperWithDynatemp(LogitsWarper):
|
|||
raise ValueError(except_msg)
|
||||
|
||||
self.temperature = temperature
|
||||
self.dynatemp = dynatemp
|
||||
self.dynamic_temperature = dynamic_temperature
|
||||
self.dynamic_temperature_low = dynamic_temperature_low
|
||||
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||
|
||||
# Regular temperature
|
||||
if self.dynatemp == 0:
|
||||
if not self.dynamic_temperature:
|
||||
scores = scores / self.temperature
|
||||
return scores
|
||||
|
||||
# Dynamic temperature
|
||||
else:
|
||||
min_temp = max(0.0, self.temperature - self.dynatemp)
|
||||
max_temp = self.temperature + self.dynatemp
|
||||
min_temp = self.dynamic_temperature_low
|
||||
max_temp = self.temperature
|
||||
exponent_val = 1.0
|
||||
|
||||
# Convert logits to probabilities
|
||||
|
@ -283,7 +284,7 @@ def get_logits_warper_patch(self, generation_config):
|
|||
generation_config.temperature = float(generation_config.temperature)
|
||||
|
||||
temperature = generation_config.temperature
|
||||
if generation_config.dynatemp > 0:
|
||||
if generation_config.dynamic_temperature:
|
||||
# Make sure TemperatureLogitsWarper will be created by temporarily
|
||||
# setting temperature to a value != 1.
|
||||
generation_config.temperature = 1.1
|
||||
|
@ -291,7 +292,7 @@ def get_logits_warper_patch(self, generation_config):
|
|||
warpers = self._get_logits_warper_old(generation_config)
|
||||
for i in range(len(warpers)):
|
||||
if warpers[i].__class__.__name__ == 'TemperatureLogitsWarper':
|
||||
warpers[i] = TemperatureLogitsWarperWithDynatemp(temperature, generation_config.dynatemp)
|
||||
warpers[i] = TemperatureLogitsWarperWithDynatemp(temperature, generation_config.dynamic_temperature, generation_config.dynamic_temperature_low)
|
||||
|
||||
warpers_to_add = LogitsProcessorList()
|
||||
min_tokens_to_keep = 2 if generation_config.num_beams > 1 else 1
|
||||
|
@ -359,7 +360,8 @@ def get_logits_processor_patch(self, **kwargs):
|
|||
def generation_config_init_patch(self, **kwargs):
|
||||
self.__init___old(**kwargs)
|
||||
self.min_p = kwargs.pop("min_p", 0.0)
|
||||
self.dynatemp = kwargs.pop("dynatemp", 0.0)
|
||||
self.dynamic_temperature = kwargs.pop("dynamic_temperature", False)
|
||||
self.dynamic_temperature_low = kwargs.pop("dynamic_temperature_low", 0.1)
|
||||
self.tfs = kwargs.pop("tfs", 1.0)
|
||||
self.top_a = kwargs.pop("top_a", 0.0)
|
||||
self.mirostat_mode = kwargs.pop("mirostat_mode", 0)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue