Add the option to use samplers in the logit viewer
This commit is contained in:
parent
25e5eaa6a6
commit
8545052c9d
8 changed files with 75 additions and 18 deletions
|
@ -10,6 +10,8 @@ from transformers.generation.logits_process import (
|
|||
TemperatureLogitsWarper
|
||||
)
|
||||
|
||||
global_scores = None
|
||||
|
||||
|
||||
class TailFreeLogitsWarper(LogitsWarper):
|
||||
def __init__(self, tfs: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
|
||||
|
@ -122,6 +124,16 @@ class MirostatLogitsWarper(LogitsWarper):
|
|||
return scores
|
||||
|
||||
|
||||
class SpyLogitsWarper(LogitsWarper):
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||
global global_scores
|
||||
global_scores = scores
|
||||
return scores
|
||||
|
||||
|
||||
class RepetitionPenaltyLogitsProcessorWithRange(LogitsProcessor):
|
||||
'''
|
||||
Copied from the transformers library
|
||||
|
@ -168,6 +180,7 @@ def get_logits_warper_patch(self, generation_config):
|
|||
else:
|
||||
warpers += warpers_to_add
|
||||
|
||||
warpers.append(SpyLogitsWarper())
|
||||
return warpers
|
||||
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue