Add the option to use samplers in the logit viewer

This commit is contained in:
oobabooga 2023-08-22 20:18:16 -07:00
parent 25e5eaa6a6
commit 8545052c9d
8 changed files with 75 additions and 18 deletions

View file

@ -10,6 +10,8 @@ from transformers.generation.logits_process import (
TemperatureLogitsWarper
)
# Holds the most recent logits tensor captured by SpyLogitsWarper during
# generation; presumably read elsewhere by the logit viewer — confirm against callers.
global_scores = None
class TailFreeLogitsWarper(LogitsWarper):
def __init__(self, tfs: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
@ -122,6 +124,16 @@ class MirostatLogitsWarper(LogitsWarper):
return scores
class SpyLogitsWarper(LogitsWarper):
    """Pass-through warper that snapshots the scores it is handed.

    Each call stashes the incoming ``scores`` tensor in the module-level
    ``global_scores`` variable and returns it unchanged, so the logits can
    be inspected after generation (e.g. by the logit viewer) without
    altering sampling behavior.
    """

    def __init__(self):
        # No configuration needed — this warper only observes.
        pass

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # Record the raw scores for external inspection, then hand them
        # back untouched so the rest of the warper chain is unaffected.
        global global_scores
        global_scores = scores
        return scores
class RepetitionPenaltyLogitsProcessorWithRange(LogitsProcessor):
'''
Copied from the transformers library
@ -168,6 +180,7 @@ def get_logits_warper_patch(self, generation_config):
else:
warpers += warpers_to_add
warpers.append(SpyLogitsWarper())
return warpers