diff --git a/modules/callbacks.py b/modules/callbacks.py index 2ae9d90..8d30d61 100644 --- a/modules/callbacks.py +++ b/modules/callbacks.py @@ -25,7 +25,7 @@ class _SentinelTokenStoppingCriteria(transformers.StoppingCriteria): if trimmed_sample.shape[-1] < self.sentinel_token_ids[i].shape[-1]: continue for window in trimmed_sample.unfold(0, self.sentinel_token_ids[i].shape[-1], 1): - if torch.all(torch.eq(self.sentinel_token_ids[i], window)): + if torch.all(torch.eq(self.sentinel_token_ids[i][0], window)): return True return False