'''
This code was copied from

https://github.com/PygmalionAI/gradio-ui/

'''

import torch
import transformers


class _SentinelTokenStoppingCriteria(transformers.StoppingCriteria):
    """Stop generation as soon as a sentinel token sequence appears in the
    newly generated part of any sample in the batch.

    Only tokens at or after ``starting_idx`` are scanned, so sentinel
    sequences already present in the prompt do not trigger a stop.
    """

    def __init__(self, sentinel_token_ids: torch.LongTensor,
                 starting_idx: int):
        # starting_idx is the prompt length: the scan window begins there
        # so the prompt itself can never match the sentinel.
        super().__init__()
        self.sentinel_token_ids = sentinel_token_ids
        self.starting_idx = starting_idx

    def __call__(self, input_ids: torch.LongTensor,
                 _scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True when any sample contains the sentinel sequence.

        ``**kwargs`` absorbs extra keyword arguments that newer
        ``transformers`` versions pass to stopping criteria.
        """
        for sample in input_ids:
            trimmed_sample = sample[self.starting_idx:]
            # Can't unfold, output is still too tiny. Skip.
            if trimmed_sample.shape[-1] < self.sentinel_token_ids.shape[-1]:
                continue

            # Slide a window of sentinel length over the generated tokens
            # and test each window for an exact match.
            for window in trimmed_sample.unfold(
                    0, self.sentinel_token_ids.shape[-1], 1):
                if torch.all(torch.eq(self.sentinel_token_ids, window)):
                    return True
        return False