dynatemp_low, dynatemp_high, dynatemp_exponent parameters (#5209)

This commit is contained in:
oobabooga 2024-01-08 23:28:35 -03:00 committed by GitHub
parent dc1df22a2b
commit 29c2693ea0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 51 additions and 27 deletions

View file

@ -16,7 +16,7 @@ global_scores = None
class TemperatureLogitsWarperWithDynatemp(LogitsWarper):
def __init__(self, temperature: float, dynamic_temperature: bool, dynamic_temperature_low: float):
def __init__(self, temperature: float, dynamic_temperature: bool, dynatemp_low: float, dynatemp_high: float, dynatemp_exponent: float):
if not isinstance(temperature, float) or not (temperature > 0):
except_msg = (
f"`temperature` (={temperature}) has to be a strictly positive float, otherwise your next token "
@ -29,7 +29,9 @@ class TemperatureLogitsWarperWithDynatemp(LogitsWarper):
self.temperature = temperature
self.dynamic_temperature = dynamic_temperature
self.dynamic_temperature_low = dynamic_temperature_low
self.dynatemp_low = dynatemp_low
self.dynatemp_high = dynatemp_high
self.dynatemp_exponent = dynatemp_exponent
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
@ -40,9 +42,9 @@ class TemperatureLogitsWarperWithDynatemp(LogitsWarper):
# Dynamic temperature
else:
min_temp = self.dynamic_temperature_low
max_temp = self.temperature
exponent_val = 1.0
min_temp = self.dynatemp_low
max_temp = self.dynatemp_high
exponent_val = self.dynatemp_exponent
# Convert logits to probabilities
probs = torch.softmax(scores, dim=-1)
@ -82,7 +84,7 @@ class TemperatureLogitsWarperWithDynatemp(LogitsWarper):
# max_prob_token_id = torch.argmax(scores, dim=-1) # Get the token ID with the highest probability
# max_prob_token = shared.tokenizer.convert_ids_to_tokens(int(max_prob_token_id)) # Convert ID to token
# print("--- T=", float(dyn_temp), "token=", max_prob_token, "min=", min_temp, "max=", max_temp)
# print("--- T=", float(dyn_temp), "token=", max_prob_token, "min=", min_temp, "max=", max_temp, "exponent=", exponent_val)
return scores
@ -292,7 +294,13 @@ def get_logits_warper_patch(self, generation_config):
warpers = self._get_logits_warper_old(generation_config)
for i in range(len(warpers)):
if warpers[i].__class__.__name__ == 'TemperatureLogitsWarper':
warpers[i] = TemperatureLogitsWarperWithDynatemp(temperature, generation_config.dynamic_temperature, generation_config.dynamic_temperature_low)
warpers[i] = TemperatureLogitsWarperWithDynatemp(
temperature,
generation_config.dynamic_temperature,
generation_config.dynatemp_low,
generation_config.dynatemp_high,
generation_config.dynatemp_exponent
)
warpers_to_add = LogitsProcessorList()
min_tokens_to_keep = 2 if generation_config.num_beams > 1 else 1
@ -361,7 +369,9 @@ def generation_config_init_patch(self, **kwargs):
self.__init___old(**kwargs)
self.min_p = kwargs.pop("min_p", 0.0)
self.dynamic_temperature = kwargs.pop("dynamic_temperature", False)
self.dynamic_temperature_low = kwargs.pop("dynamic_temperature_low", 0.1)
self.dynatemp_low = kwargs.pop("dynatemp_low", 1)
self.dynatemp_high = kwargs.pop("dynatemp_high", 1)
self.dynatemp_exponent = kwargs.pop("dynatemp_exponent", 1)
self.tfs = kwargs.pop("tfs", 1.0)
self.top_a = kwargs.pop("top_a", 0.0)
self.mirostat_mode = kwargs.pop("mirostat_mode", 0)