Dynamic Temperature HF loader support (#5174)

---------

Co-authored-by: oobabooga <112222186+oobabooga@users.noreply.github.com>
This commit is contained in:
kalomaze 2024-01-07 07:36:26 -06:00 committed by GitHub
parent 3eca20c015
commit 48327cc5c4
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
14 changed files with 184 additions and 8 deletions

View file

@ -10,6 +10,10 @@ from transformers import is_torch_xpu_available
import modules.shared as shared
class StopNowException(Exception):
pass
class _StopEverythingStoppingCriteria(transformers.StoppingCriteria):
def __init__(self):
transformers.StoppingCriteria.__init__(self)
@ -49,13 +53,13 @@ class Iteratorize:
def _callback(val):
if self.stop_now or shared.stop_everything:
raise ValueError
raise StopNowException
self.q.put(val)
def gentask():
try:
ret = self.mfunc(callback=_callback, *args, **self.kwargs)
except ValueError:
except StopNowException:
pass
except:
traceback.print_exc()

View file

@ -144,6 +144,9 @@ class LlamacppHF(PreTrainedModel):
self.model.n_tokens = longest_prefix
if len(seq_tensor) - longest_prefix > 0:
self.model.eval(seq[longest_prefix:])
else:
self.model.n_tokens -= 1
self.model.eval([seq[-1]])
if reset:
self.model.reset()

View file

@ -155,6 +155,7 @@ def transformers_samplers():
return {
'temperature',
'temperature_last',
'dynatemp',
'top_p',
'min_p',
'top_k',
@ -220,6 +221,7 @@ loaders_samplers = {
'ExLlamav2_HF': {
'temperature',
'temperature_last',
'dynatemp',
'top_p',
'min_p',
'top_k',
@ -272,6 +274,7 @@ loaders_samplers = {
'llamacpp_HF': {
'temperature',
'temperature_last',
'dynatemp',
'top_p',
'min_p',
'top_k',

View file

@ -8,7 +8,7 @@ from modules.text_generation import generate_reply
global_scores = None
def get_next_logits(prompt, state, use_samplers, previous, top_logits=50, return_dict=False):
def get_next_logits(prompt, state, use_samplers, previous, top_logits=25, return_dict=False):
if shared.model is None:
logger.error("No model is loaded! Select one in the Model tab.")
return 'Error: No model is loaded1 Select one in the Model tab.', previous

View file

@ -12,6 +12,7 @@ def default_preset():
return {
'temperature': 1,
'temperature_last': False,
'dynatemp': 0,
'top_p': 1,
'min_p': 0,
'top_k': 0,

View file

@ -10,9 +10,84 @@ from transformers.generation.logits_process import (
TemperatureLogitsWarper
)
from modules import shared
global_scores = None
class TemperatureLogitsWarperWithDynatemp(LogitsWarper):
def __init__(self, temperature: float, dynatemp: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
if not isinstance(temperature, float) or not (temperature > 0):
except_msg = (
f"`temperature` (={temperature}) has to be a strictly positive float, otherwise your next token "
"scores will be invalid."
)
if isinstance(temperature, float) and temperature == 0.0:
except_msg += " If you're looking for greedy decoding strategies, set `do_sample=False`."
raise ValueError(except_msg)
self.temperature = temperature
self.dynatemp = dynatemp
self.filter_value = filter_value
self.min_tokens_to_keep = min_tokens_to_keep
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
# Regular temperature
if self.dynatemp == 0:
scores = scores / self.temperature
return scores
# Dynamic temperature
else:
min_temp = max(0.0, self.temperature - self.dynatemp)
max_temp = self.temperature + self.dynatemp
exponent_val = 1.0
# Convert logits to probabilities
probs = torch.softmax(scores, dim=-1)
# Calculate entropy of the softmax probabilities
entropy = -1.0 * torch.where(probs > 0, probs * torch.log(probs), torch.zeros_like(probs)).sum()
# Guard against future possible division by zero
entropy = max(entropy, torch.tensor(1e-10)) # Ensures entropy is slightly greater than 0
# Any logits which are not -Infinity will be considered for calculating max entropy.
num_valid_tokens = torch.sum(scores > -float('inf')).item()
# Now, calculate the max entropy by using only the valid tokens' count
max_entropy = math.log(num_valid_tokens)
# Guard against future possible division by zero
max_entropy = max_entropy if max_entropy > 0.0 else 1e-10
# Normalize the entropy
normalized_entropy = entropy / max_entropy
# Map the normalized entropy to the desired temperature range using the power function
dyn_temp = min_temp + (max_temp - min_temp) * (normalized_entropy.pow(exponent_val))
# Apply the dynamically calculated temperature scaling
scores = scores / dyn_temp
# print("----------------------\nTemperature from generation_config:", self.temperature)
# print("min_temp:", min_temp)
# print("max_temp:", max_temp)
# print("Entropy:", entropy.item())
# print("Max Possible Entropy considering valid tokens only:", max_entropy)
# print("Normalized Entropy:", normalized_entropy.item())
# print("Dynamic Temperature (dyn_temp):", dyn_temp.item())
# print("----------------------")
# max_prob_token_id = torch.argmax(scores, dim=-1) # Get the token ID with the highest probability
# max_prob_token = shared.tokenizer.convert_ids_to_tokens(int(max_prob_token_id)) # Convert ID to token
# print("--- T=", float(dyn_temp), "token=", max_prob_token, "min=", min_temp, "max=", max_temp)
return scores
class MinPLogitsWarper(LogitsWarper):
def __init__(self, min_p: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
if min_p < 0 or min_p > 1.0:
@ -198,14 +273,28 @@ class RepetitionPenaltyLogitsProcessorWithRange(LogitsProcessor):
# presence_penalty and frequency_penalty
raw_presence_penalty = (counts > 0).to(scores.dtype)
raw_frequency_penalty = counts.to(scores.dtype)
additive_penalty = raw_presence_penalty*self.presence_penalty + raw_frequency_penalty*self.frequency_penalty
additive_penalty = raw_presence_penalty * self.presence_penalty + raw_frequency_penalty * self.frequency_penalty
scores_row.scatter_add_(0, unique_ids, -additive_penalty)
return scores
def get_logits_warper_patch(self, generation_config):
# Make sure that temperature is float and not int
if isinstance(generation_config.temperature, int):
generation_config.temperature = float(generation_config.temperature)
temperature = generation_config.temperature
if generation_config.dynatemp > 0:
# Make sure TemperatureLogitsWarper will be created by temporarily
# setting temperature to a value != 1.
generation_config.temperature = 1.1
warpers = self._get_logits_warper_old(generation_config)
for i in range(len(warpers)):
if warpers[i].__class__.__name__ == 'TemperatureLogitsWarper':
warpers[i] = TemperatureLogitsWarperWithDynatemp(temperature, generation_config.dynatemp)
warpers_to_add = LogitsProcessorList()
min_tokens_to_keep = 2 if generation_config.num_beams > 1 else 1
@ -232,18 +321,18 @@ def get_logits_warper_patch(self, generation_config):
if generation_config.temperature_last:
temperature_idx = None
for i in range(len(warpers)):
if warpers[i].__class__.__name__ == 'TemperatureLogitsWarper':
if warpers[i].__class__.__name__ in ['TemperatureLogitsWarper', 'TemperatureLogitsWarperWithDynatemp']:
temperature_idx = i
break
if temperature_idx is not None:
warpers = warpers[:temperature_idx] + warpers[temperature_idx + 1:] + [warpers[temperature_idx]]
warpers = LogitsProcessorList(warpers)
warpers.append(warpers.pop(temperature_idx))
if normalize is not None:
warpers.append(normalize)
warpers.append(SpyLogitsWarper())
warpers = LogitsProcessorList(warpers)
# for i in range(len(warpers)):
# print(warpers[i].__class__.__name__)
return warpers
@ -272,6 +361,7 @@ def get_logits_processor_patch(self, **kwargs):
def generation_config_init_patch(self, **kwargs):
self.__init___old(**kwargs)
self.min_p = kwargs.pop("min_p", 0.0)
self.dynatemp = kwargs.pop("dynatemp", 0.0)
self.tfs = kwargs.pop("tfs", 1.0)
self.top_a = kwargs.pop("top_a", 0.0)
self.mirostat_mode = kwargs.pop("mirostat_mode", 0)

View file

@ -283,7 +283,7 @@ def get_reply_from_output_ids(output_ids, state, starting_from=0):
def generate_reply_HF(question, original_question, seed, state, stopping_strings=None, is_chat=False):
generate_params = {}
for k in ['max_new_tokens', 'do_sample', 'temperature', 'temperature_last', 'top_p', 'min_p', 'typical_p', 'repetition_penalty', 'presence_penalty', 'frequency_penalty', 'repetition_penalty_range', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'tfs', 'top_a', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'guidance_scale']:
for k in ['max_new_tokens', 'do_sample', 'temperature', 'temperature_last', 'dynatemp', 'top_p', 'min_p', 'typical_p', 'repetition_penalty', 'presence_penalty', 'frequency_penalty', 'repetition_penalty_range', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'tfs', 'top_a', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'guidance_scale']:
generate_params[k] = state[k]
if state['negative_prompt'] != '':

View file

@ -115,6 +115,7 @@ def list_interface_input_elements():
'seed',
'temperature',
'temperature_last',
'dynatemp',
'top_p',
'min_p',
'top_k',

View file

@ -49,6 +49,7 @@ def create_ui(default_preset):
shared.gradio['mirostat_mode'] = gr.Slider(0, 2, step=1, value=generate_params['mirostat_mode'], label='mirostat_mode', info='mode=1 is for llama.cpp only.')
shared.gradio['mirostat_tau'] = gr.Slider(0, 10, step=0.01, value=generate_params['mirostat_tau'], label='mirostat_tau')
shared.gradio['mirostat_eta'] = gr.Slider(0, 1, step=0.01, value=generate_params['mirostat_eta'], label='mirostat_eta')
shared.gradio['dynatemp'] = gr.Slider(0, 5, value=generate_params['dynatemp'], step=0.01, label='dynatemp')
shared.gradio['temperature_last'] = gr.Checkbox(value=generate_params['temperature_last'], label='temperature_last', info='Makes temperature the last sampler instead of the first.')
shared.gradio['do_sample'] = gr.Checkbox(value=generate_params['do_sample'], label='do_sample')
shared.gradio['seed'] = gr.Number(value=shared.settings['seed'], label='Seed (-1 for random)')