Implement stopping string search in string space (#2847)

This commit is contained in:
oobabooga 2023-06-24 09:43:00 -03:00 committed by GitHub
parent 0f9088f730
commit 8bb3bb39b3
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 61 additions and 112 deletions

View file

@ -9,33 +9,6 @@ import transformers
import modules.shared as shared
class _SentinelTokenStoppingCriteria(transformers.StoppingCriteria):
def __init__(self, sentinel_token_ids: list, starting_idx: int):
transformers.StoppingCriteria.__init__(self)
self.sentinel_token_ids = sentinel_token_ids
self.starting_idx = starting_idx
self.shortest = min([x.shape[-1] for x in sentinel_token_ids])
def __call__(self, input_ids: torch.LongTensor, _scores: torch.FloatTensor) -> bool:
for sample in input_ids:
trimmed_sample = sample[self.starting_idx:]
trimmed_len = trimmed_sample.shape[-1]
if trimmed_len < self.shortest:
continue
for sentinel in self.sentinel_token_ids:
sentinel_len = sentinel.shape[-1]
if trimmed_len < sentinel_len:
continue
window = trimmed_sample[-sentinel_len:]
if torch.all(torch.eq(sentinel, window)):
return True
return False
class Stream(transformers.StoppingCriteria):
def __init__(self, callback_func=None):
self.callback_func = callback_func