Implement stopping string search in string space (#2847)
This commit is contained in:
parent
0f9088f730
commit
8bb3bb39b3
4 changed files with 61 additions and 112 deletions
|
|
@ -9,33 +9,6 @@ import transformers
|
|||
import modules.shared as shared
|
||||
|
||||
|
||||
class _SentinelTokenStoppingCriteria(transformers.StoppingCriteria):
|
||||
|
||||
def __init__(self, sentinel_token_ids: list, starting_idx: int):
|
||||
transformers.StoppingCriteria.__init__(self)
|
||||
self.sentinel_token_ids = sentinel_token_ids
|
||||
self.starting_idx = starting_idx
|
||||
self.shortest = min([x.shape[-1] for x in sentinel_token_ids])
|
||||
|
||||
def __call__(self, input_ids: torch.LongTensor, _scores: torch.FloatTensor) -> bool:
|
||||
for sample in input_ids:
|
||||
trimmed_sample = sample[self.starting_idx:]
|
||||
trimmed_len = trimmed_sample.shape[-1]
|
||||
if trimmed_len < self.shortest:
|
||||
continue
|
||||
|
||||
for sentinel in self.sentinel_token_ids:
|
||||
sentinel_len = sentinel.shape[-1]
|
||||
if trimmed_len < sentinel_len:
|
||||
continue
|
||||
|
||||
window = trimmed_sample[-sentinel_len:]
|
||||
if torch.all(torch.eq(sentinel, window)):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
class Stream(transformers.StoppingCriteria):
|
||||
def __init__(self, callback_func=None):
|
||||
self.callback_func = callback_func
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue