Make stop_everything work with non-streamed generation (#2848)

This commit is contained in:
快乐的我531 2023-06-24 22:19:16 +08:00 committed by GitHub
parent ec482f3dae
commit e356f69b36
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 12 additions and 2 deletions

View file

@ -9,6 +9,14 @@ import transformers
import modules.shared as shared
class _StopEverythingStoppingCriteria(transformers.StoppingCriteria):
def __init__(self):
transformers.StoppingCriteria.__init__(self)
def __call__(self, input_ids: torch.LongTensor, _scores: torch.FloatTensor) -> bool:
return shared.stop_everything
class Stream(transformers.StoppingCriteria):
def __init__(self, callback_func=None):
self.callback_func = callback_func