Add dynamic_temperature_low parameter (#5198)

This commit is contained in:
oobabooga 2024-01-07 17:03:47 -03:00 committed by GitHub
parent b8a0b3f925
commit 0d07b3a6a1
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 30 additions and 88 deletions

View file

@ -155,7 +155,8 @@ def transformers_samplers():
return {
'temperature',
'temperature_last',
'dynatemp',
'dynamic_temperature',
'dynamic_temperature_low',
'top_p',
'min_p',
'top_k',
@ -221,7 +222,8 @@ loaders_samplers = {
'ExLlamav2_HF': {
'temperature',
'temperature_last',
'dynatemp',
'dynamic_temperature',
'dynamic_temperature_low',
'top_p',
'min_p',
'top_k',
@ -274,7 +276,8 @@ loaders_samplers = {
'llamacpp_HF': {
'temperature',
'temperature_last',
'dynatemp',
'dynamic_temperature',
'dynamic_temperature_low',
'top_p',
'min_p',
'top_k',

View file

@ -12,7 +12,8 @@ def default_preset():
return {
'temperature': 1,
'temperature_last': False,
'dynatemp': 0,
'dynamic_temperature': False,
'dynamic_temperature_low': 0.1,
'top_p': 1,
'min_p': 0,
'top_k': 0,
@ -53,7 +54,6 @@ def load_preset(name):
for k in preset:
generate_params[k] = preset[k]
generate_params['temperature'] = min(1.99, generate_params['temperature'])
return generate_params

View file

@ -16,7 +16,7 @@ global_scores = None
class TemperatureLogitsWarperWithDynatemp(LogitsWarper):
def __init__(self, temperature: float, dynatemp: float):
def __init__(self, temperature: float, dynamic_temperature: bool, dynamic_temperature_low: float):
if not isinstance(temperature, float) or not (temperature > 0):
except_msg = (
f"`temperature` (={temperature}) has to be a strictly positive float, otherwise your next token "
@ -28,19 +28,20 @@ class TemperatureLogitsWarperWithDynatemp(LogitsWarper):
raise ValueError(except_msg)
self.temperature = temperature
self.dynatemp = dynatemp
self.dynamic_temperature = dynamic_temperature
self.dynamic_temperature_low = dynamic_temperature_low
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
# Regular temperature
if self.dynatemp == 0:
if not self.dynamic_temperature:
scores = scores / self.temperature
return scores
# Dynamic temperature
else:
min_temp = max(0.0, self.temperature - self.dynatemp)
max_temp = self.temperature + self.dynatemp
min_temp = self.dynamic_temperature_low
max_temp = self.temperature
exponent_val = 1.0
# Convert logits to probabilities
@ -283,7 +284,7 @@ def get_logits_warper_patch(self, generation_config):
generation_config.temperature = float(generation_config.temperature)
temperature = generation_config.temperature
if generation_config.dynatemp > 0:
if generation_config.dynamic_temperature:
# Make sure TemperatureLogitsWarper will be created by temporarily
# setting temperature to a value != 1.
generation_config.temperature = 1.1
@ -291,7 +292,7 @@ def get_logits_warper_patch(self, generation_config):
warpers = self._get_logits_warper_old(generation_config)
for i in range(len(warpers)):
if warpers[i].__class__.__name__ == 'TemperatureLogitsWarper':
warpers[i] = TemperatureLogitsWarperWithDynatemp(temperature, generation_config.dynatemp)
warpers[i] = TemperatureLogitsWarperWithDynatemp(temperature, generation_config.dynamic_temperature, generation_config.dynamic_temperature_low)
warpers_to_add = LogitsProcessorList()
min_tokens_to_keep = 2 if generation_config.num_beams > 1 else 1
@ -359,7 +360,8 @@ def get_logits_processor_patch(self, **kwargs):
def generation_config_init_patch(self, **kwargs):
self.__init___old(**kwargs)
self.min_p = kwargs.pop("min_p", 0.0)
self.dynatemp = kwargs.pop("dynatemp", 0.0)
self.dynamic_temperature = kwargs.pop("dynamic_temperature", False)
self.dynamic_temperature_low = kwargs.pop("dynamic_temperature_low", 0.1)
self.tfs = kwargs.pop("tfs", 1.0)
self.top_a = kwargs.pop("top_a", 0.0)
self.mirostat_mode = kwargs.pop("mirostat_mode", 0)

View file

@ -285,7 +285,7 @@ def get_reply_from_output_ids(output_ids, state, starting_from=0):
def generate_reply_HF(question, original_question, seed, state, stopping_strings=None, is_chat=False):
generate_params = {}
for k in ['max_new_tokens', 'temperature', 'temperature_last', 'dynatemp', 'top_p', 'min_p', 'top_k', 'repetition_penalty', 'presence_penalty', 'frequency_penalty', 'repetition_penalty_range', 'typical_p', 'tfs', 'top_a', 'guidance_scale', 'penalty_alpha', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'do_sample', 'encoder_repetition_penalty', 'no_repeat_ngram_size', 'min_length', 'num_beams', 'length_penalty', 'early_stopping']:
for k in ['max_new_tokens', 'temperature', 'temperature_last', 'dynamic_temperature', 'dynamic_temperature_low', 'top_p', 'min_p', 'top_k', 'repetition_penalty', 'presence_penalty', 'frequency_penalty', 'repetition_penalty_range', 'typical_p', 'tfs', 'top_a', 'guidance_scale', 'penalty_alpha', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'do_sample', 'encoder_repetition_penalty', 'no_repeat_ngram_size', 'min_length', 'num_beams', 'length_penalty', 'early_stopping']:
generate_params[k] = state[k]
if state['negative_prompt'] != '':

View file

@ -115,7 +115,8 @@ def list_interface_input_elements():
'seed',
'temperature',
'temperature_last',
'dynatemp',
'dynamic_temperature',
'dynamic_temperature_low',
'top_p',
'min_p',
'top_k',

View file

@ -49,7 +49,8 @@ def create_ui(default_preset):
shared.gradio['mirostat_mode'] = gr.Slider(0, 2, step=1, value=generate_params['mirostat_mode'], label='mirostat_mode', info='mode=1 is for llama.cpp only.')
shared.gradio['mirostat_tau'] = gr.Slider(0, 10, step=0.01, value=generate_params['mirostat_tau'], label='mirostat_tau')
shared.gradio['mirostat_eta'] = gr.Slider(0, 1, step=0.01, value=generate_params['mirostat_eta'], label='mirostat_eta')
shared.gradio['dynatemp'] = gr.Slider(0, 5, value=generate_params['dynatemp'], step=0.01, label='dynatemp')
shared.gradio['dynamic_temperature_low'] = gr.Slider(0.01, 5, value=generate_params['dynamic_temperature_low'], step=0.01, label='dynamic_temperature_low', info='Only used when dynamic_temperature is checked.')
shared.gradio['dynamic_temperature'] = gr.Checkbox(value=generate_params['dynamic_temperature'], label='dynamic_temperature')
shared.gradio['temperature_last'] = gr.Checkbox(value=generate_params['temperature_last'], label='temperature_last', info='Makes temperature the last sampler instead of the first.')
shared.gradio['do_sample'] = gr.Checkbox(value=generate_params['do_sample'], label='do_sample')
shared.gradio['seed'] = gr.Number(value=shared.settings['seed'], label='Seed (-1 for random)')