Add tail-free and top-a sampling (#2357)
This commit is contained in:
parent
b4662bf4af
commit
9e7204bef4
5 changed files with 113 additions and 5 deletions
|
|
@ -15,7 +15,7 @@ from transformers import (AutoConfig, AutoModel, AutoModelForCausalLM,
|
|||
BitsAndBytesConfig, LlamaTokenizer)
|
||||
|
||||
import modules.shared as shared
|
||||
from modules import llama_attn_hijack
|
||||
from modules import llama_attn_hijack, sampler_hijack
|
||||
from modules.logging_colors import logger
|
||||
|
||||
transformers.logging.set_verbosity_error()
|
||||
|
|
@ -36,6 +36,8 @@ if shared.args.deepspeed:
|
|||
ds_config = generate_ds_config(shared.args.bf16, 1 * world_size, shared.args.nvme_offload_dir)
|
||||
dschf = HfDeepSpeedConfig(ds_config) # Keep this object alive for the Transformers integration
|
||||
|
||||
sampler_hijack.hijack_samplers()
|
||||
|
||||
|
||||
# Some models require special treatment in various parts of the code.
|
||||
# This function detects those models
|
||||
|
|
|
|||
102
modules/sampler_hijack.py
Normal file
102
modules/sampler_hijack.py
Normal file
|
|
@ -0,0 +1,102 @@
|
|||
import torch
|
||||
import transformers
|
||||
from transformers import LogitsWarper
|
||||
from transformers.generation.logits_process import LogitNormalization, LogitsProcessorList
|
||||
|
||||
|
||||
class TailFreeLogitsWarper(LogitsWarper):
|
||||
def __init__(self, tfs: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
|
||||
tfs = float(tfs)
|
||||
if tfs < 0 or tfs > 1.0:
|
||||
raise ValueError(f"`tfs` has to be a float >= 0 and <= 1, but is {tfs}")
|
||||
self.tfs = tfs
|
||||
self.filter_value = filter_value
|
||||
self.min_tokens_to_keep = min_tokens_to_keep
|
||||
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||
sorted_logits, sorted_indices = torch.sort(scores, descending=True)
|
||||
probs = sorted_logits.softmax(dim=-1)
|
||||
|
||||
# Compute second derivative normalized CDF
|
||||
d2 = probs.diff().diff().abs()
|
||||
normalized_d2 = d2 / d2.sum(dim=-1, keepdim=True)
|
||||
normalized_d2_cdf = normalized_d2.cumsum(dim=-1)
|
||||
|
||||
# Remove tokens with CDF value above the threshold (token with 0 are kept)
|
||||
sorted_indices_to_remove = normalized_d2_cdf > self.tfs
|
||||
|
||||
# Centre the distribution around the cutoff as in the original implementation of the algorithm
|
||||
sorted_indices_to_remove = torch.cat(
|
||||
(
|
||||
torch.zeros(scores.shape[0], 1, dtype=torch.bool, device=scores.device),
|
||||
sorted_indices_to_remove,
|
||||
torch.ones(scores.shape[0], 1, dtype=torch.bool, device=scores.device),
|
||||
),
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
if self.min_tokens_to_keep > 1:
|
||||
# Keep at least min_tokens_to_keep
|
||||
sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0
|
||||
|
||||
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
|
||||
scores = scores.masked_fill(indices_to_remove, self.filter_value)
|
||||
return scores
|
||||
|
||||
|
||||
class TopALogitsWarper(LogitsWarper):
|
||||
def __init__(self, top_a: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
|
||||
top_a = float(top_a)
|
||||
if top_a < 0 or top_a > 1.0:
|
||||
raise ValueError(f"`top_a` has to be a float >= 0 and <= 1, but is {top_a}")
|
||||
self.top_a = top_a
|
||||
self.filter_value = filter_value
|
||||
self.min_tokens_to_keep = min_tokens_to_keep
|
||||
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||
sorted_logits, sorted_indices = torch.sort(scores, descending=True)
|
||||
probs = sorted_logits.softmax(dim=-1)
|
||||
|
||||
# Remove tokens with probability less than top_a*(max(probs))^2 (token with 0 are kept)
|
||||
probs_max = probs[..., 0, None]
|
||||
sorted_indices_to_remove = probs < probs_max * probs_max * self.top_a
|
||||
|
||||
if self.min_tokens_to_keep > 1:
|
||||
# Keep at least min_tokens_to_keep
|
||||
sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0
|
||||
|
||||
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
|
||||
scores = scores.masked_fill(indices_to_remove, self.filter_value)
|
||||
return scores
|
||||
|
||||
|
||||
def get_logits_warper_patch(self, generation_config):
|
||||
warpers = self._get_logits_warper_old(generation_config)
|
||||
warpers_to_add = LogitsProcessorList()
|
||||
min_tokens_to_keep = 2 if generation_config.num_beams > 1 else 1
|
||||
|
||||
if generation_config.tfs is not None and 0.0 <= generation_config.tfs <= 1.0:
|
||||
warpers_to_add.append(TailFreeLogitsWarper(tfs=generation_config.tfs, min_tokens_to_keep=min_tokens_to_keep))
|
||||
if generation_config.top_a is not None and 0.0 <= generation_config.top_a <= 1.0:
|
||||
warpers_to_add.append(TopALogitsWarper(top_a=generation_config.top_a, min_tokens_to_keep=min_tokens_to_keep))
|
||||
|
||||
if warpers and isinstance(warpers[-1], LogitNormalization):
|
||||
warpers = warpers[:-1] + warpers_to_add + [warpers[-1]]
|
||||
else:
|
||||
warpers += warpers_to_add
|
||||
|
||||
return warpers
|
||||
|
||||
|
||||
def generation_config_init_patch(self, **kwargs):
|
||||
self.__init___old(**kwargs)
|
||||
self.tfs = kwargs.pop("tfs", 1.0)
|
||||
self.top_a = kwargs.pop("top_a", 0.0)
|
||||
|
||||
|
||||
def hijack_samplers():
|
||||
transformers.GenerationMixin._get_logits_warper_old = transformers.GenerationMixin._get_logits_warper
|
||||
transformers.GenerationMixin._get_logits_warper = get_logits_warper_patch
|
||||
|
||||
transformers.GenerationConfig.__init___old = transformers.GenerationConfig.__init__
|
||||
transformers.GenerationConfig.__init__ = generation_config_init_patch
|
||||
|
|
@ -194,7 +194,7 @@ def _generate_reply(question, state, eos_token=None, stopping_strings=None, is_c
|
|||
|
||||
def generate_reply_HF(question, original_question, seed, state, eos_token=None, stopping_strings=None, is_chat=False):
|
||||
generate_params = {}
|
||||
for k in ['max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']:
|
||||
for k in ['max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'tfs', 'top_a']:
|
||||
generate_params[k] = state[k]
|
||||
|
||||
for k in ['epsilon_cutoff', 'eta_cutoff']:
|
||||
|
|
|
|||
|
|
@ -38,7 +38,7 @@ def list_model_elements():
|
|||
|
||||
|
||||
def list_interface_input_elements(chat=False):
|
||||
elements = ['max_new_tokens', 'seed', 'temperature', 'top_p', 'top_k', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'repetition_penalty', 'encoder_repetition_penalty', 'no_repeat_ngram_size', 'min_length', 'do_sample', 'penalty_alpha', 'num_beams', 'length_penalty', 'early_stopping', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'add_bos_token', 'ban_eos_token', 'truncation_length', 'custom_stopping_strings', 'skip_special_tokens', 'preset_menu', 'stream']
|
||||
elements = ['max_new_tokens', 'seed', 'temperature', 'top_p', 'top_k', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'repetition_penalty', 'encoder_repetition_penalty', 'no_repeat_ngram_size', 'min_length', 'do_sample', 'penalty_alpha', 'num_beams', 'length_penalty', 'early_stopping', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'add_bos_token', 'ban_eos_token', 'truncation_length', 'custom_stopping_strings', 'skip_special_tokens', 'preset_menu', 'stream', 'tfs', 'top_a']
|
||||
if chat:
|
||||
elements += ['name1', 'name2', 'greeting', 'context', 'chat_prompt_size', 'chat_generation_attempts', 'stop_at_newline', 'mode', 'instruction_template', 'character_menu', 'name1_instruct', 'name2_instruct', 'context_instruct', 'turn_template', 'chat_style', 'chat-instruct_command']
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue