Lint
This commit is contained in:
parent
d93db3b486
commit
c4c7fc4ab3
3 changed files with 5 additions and 6 deletions
|
@ -16,7 +16,7 @@ global_scores = None
|
|||
|
||||
|
||||
class TemperatureLogitsWarperWithDynatemp(LogitsWarper):
|
||||
def __init__(self, temperature: float, dynatemp: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
|
||||
def __init__(self, temperature: float, dynatemp: float):
|
||||
if not isinstance(temperature, float) or not (temperature > 0):
|
||||
except_msg = (
|
||||
f"`temperature` (={temperature}) has to be a strictly positive float, otherwise your next token "
|
||||
|
@ -29,8 +29,6 @@ class TemperatureLogitsWarperWithDynatemp(LogitsWarper):
|
|||
|
||||
self.temperature = temperature
|
||||
self.dynatemp = dynatemp
|
||||
self.filter_value = filter_value
|
||||
self.min_tokens_to_keep = min_tokens_to_keep
|
||||
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue