Dynamic Temperature HF loader support (#5174)
--------- Co-authored-by: oobabooga <112222186+oobabooga@users.noreply.github.com>
This commit is contained in:
parent
3eca20c015
commit
48327cc5c4
14 changed files with 184 additions and 8 deletions
|
@ -10,9 +10,84 @@ from transformers.generation.logits_process import (
|
|||
TemperatureLogitsWarper
|
||||
)
|
||||
|
||||
from modules import shared
|
||||
|
||||
global_scores = None
|
||||
|
||||
|
||||
class TemperatureLogitsWarperWithDynatemp(LogitsWarper):
|
||||
def __init__(self, temperature: float, dynatemp: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
|
||||
if not isinstance(temperature, float) or not (temperature > 0):
|
||||
except_msg = (
|
||||
f"`temperature` (={temperature}) has to be a strictly positive float, otherwise your next token "
|
||||
"scores will be invalid."
|
||||
)
|
||||
if isinstance(temperature, float) and temperature == 0.0:
|
||||
except_msg += " If you're looking for greedy decoding strategies, set `do_sample=False`."
|
||||
|
||||
raise ValueError(except_msg)
|
||||
|
||||
self.temperature = temperature
|
||||
self.dynatemp = dynatemp
|
||||
self.filter_value = filter_value
|
||||
self.min_tokens_to_keep = min_tokens_to_keep
|
||||
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||
|
||||
# Regular temperature
|
||||
if self.dynatemp == 0:
|
||||
scores = scores / self.temperature
|
||||
return scores
|
||||
|
||||
# Dynamic temperature
|
||||
else:
|
||||
min_temp = max(0.0, self.temperature - self.dynatemp)
|
||||
max_temp = self.temperature + self.dynatemp
|
||||
exponent_val = 1.0
|
||||
|
||||
# Convert logits to probabilities
|
||||
probs = torch.softmax(scores, dim=-1)
|
||||
|
||||
# Calculate entropy of the softmax probabilities
|
||||
entropy = -1.0 * torch.where(probs > 0, probs * torch.log(probs), torch.zeros_like(probs)).sum()
|
||||
|
||||
# Guard against future possible division by zero
|
||||
entropy = max(entropy, torch.tensor(1e-10)) # Ensures entropy is slightly greater than 0
|
||||
|
||||
# Any logits which are not -Infinity will be considered for calculating max entropy.
|
||||
num_valid_tokens = torch.sum(scores > -float('inf')).item()
|
||||
|
||||
# Now, calculate the max entropy by using only the valid tokens' count
|
||||
max_entropy = math.log(num_valid_tokens)
|
||||
|
||||
# Guard against future possible division by zero
|
||||
max_entropy = max_entropy if max_entropy > 0.0 else 1e-10
|
||||
|
||||
# Normalize the entropy
|
||||
normalized_entropy = entropy / max_entropy
|
||||
|
||||
# Map the normalized entropy to the desired temperature range using the power function
|
||||
dyn_temp = min_temp + (max_temp - min_temp) * (normalized_entropy.pow(exponent_val))
|
||||
|
||||
# Apply the dynamically calculated temperature scaling
|
||||
scores = scores / dyn_temp
|
||||
|
||||
# print("----------------------\nTemperature from generation_config:", self.temperature)
|
||||
# print("min_temp:", min_temp)
|
||||
# print("max_temp:", max_temp)
|
||||
# print("Entropy:", entropy.item())
|
||||
# print("Max Possible Entropy considering valid tokens only:", max_entropy)
|
||||
# print("Normalized Entropy:", normalized_entropy.item())
|
||||
# print("Dynamic Temperature (dyn_temp):", dyn_temp.item())
|
||||
# print("----------------------")
|
||||
|
||||
# max_prob_token_id = torch.argmax(scores, dim=-1) # Get the token ID with the highest probability
|
||||
# max_prob_token = shared.tokenizer.convert_ids_to_tokens(int(max_prob_token_id)) # Convert ID to token
|
||||
# print("--- T=", float(dyn_temp), "token=", max_prob_token, "min=", min_temp, "max=", max_temp)
|
||||
|
||||
return scores
|
||||
|
||||
|
||||
class MinPLogitsWarper(LogitsWarper):
|
||||
def __init__(self, min_p: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
|
||||
if min_p < 0 or min_p > 1.0:
|
||||
|
@ -198,14 +273,28 @@ class RepetitionPenaltyLogitsProcessorWithRange(LogitsProcessor):
|
|||
# presence_penalty and frequency_penalty
|
||||
raw_presence_penalty = (counts > 0).to(scores.dtype)
|
||||
raw_frequency_penalty = counts.to(scores.dtype)
|
||||
additive_penalty = raw_presence_penalty*self.presence_penalty + raw_frequency_penalty*self.frequency_penalty
|
||||
additive_penalty = raw_presence_penalty * self.presence_penalty + raw_frequency_penalty * self.frequency_penalty
|
||||
scores_row.scatter_add_(0, unique_ids, -additive_penalty)
|
||||
|
||||
return scores
|
||||
|
||||
|
||||
def get_logits_warper_patch(self, generation_config):
|
||||
# Make sure that temperature is float and not int
|
||||
if isinstance(generation_config.temperature, int):
|
||||
generation_config.temperature = float(generation_config.temperature)
|
||||
|
||||
temperature = generation_config.temperature
|
||||
if generation_config.dynatemp > 0:
|
||||
# Make sure TemperatureLogitsWarper will be created by temporarily
|
||||
# setting temperature to a value != 1.
|
||||
generation_config.temperature = 1.1
|
||||
|
||||
warpers = self._get_logits_warper_old(generation_config)
|
||||
for i in range(len(warpers)):
|
||||
if warpers[i].__class__.__name__ == 'TemperatureLogitsWarper':
|
||||
warpers[i] = TemperatureLogitsWarperWithDynatemp(temperature, generation_config.dynatemp)
|
||||
|
||||
warpers_to_add = LogitsProcessorList()
|
||||
min_tokens_to_keep = 2 if generation_config.num_beams > 1 else 1
|
||||
|
||||
|
@ -232,18 +321,18 @@ def get_logits_warper_patch(self, generation_config):
|
|||
if generation_config.temperature_last:
|
||||
temperature_idx = None
|
||||
for i in range(len(warpers)):
|
||||
if warpers[i].__class__.__name__ == 'TemperatureLogitsWarper':
|
||||
if warpers[i].__class__.__name__ in ['TemperatureLogitsWarper', 'TemperatureLogitsWarperWithDynatemp']:
|
||||
temperature_idx = i
|
||||
break
|
||||
|
||||
if temperature_idx is not None:
|
||||
warpers = warpers[:temperature_idx] + warpers[temperature_idx + 1:] + [warpers[temperature_idx]]
|
||||
warpers = LogitsProcessorList(warpers)
|
||||
warpers.append(warpers.pop(temperature_idx))
|
||||
|
||||
if normalize is not None:
|
||||
warpers.append(normalize)
|
||||
|
||||
warpers.append(SpyLogitsWarper())
|
||||
warpers = LogitsProcessorList(warpers)
|
||||
# for i in range(len(warpers)):
|
||||
# print(warpers[i].__class__.__name__)
|
||||
return warpers
|
||||
|
@ -272,6 +361,7 @@ def get_logits_processor_patch(self, **kwargs):
|
|||
def generation_config_init_patch(self, **kwargs):
|
||||
self.__init___old(**kwargs)
|
||||
self.min_p = kwargs.pop("min_p", 0.0)
|
||||
self.dynatemp = kwargs.pop("dynatemp", 0.0)
|
||||
self.tfs = kwargs.pop("tfs", 1.0)
|
||||
self.top_a = kwargs.pop("top_a", 0.0)
|
||||
self.mirostat_mode = kwargs.pop("mirostat_mode", 0)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue