Make stop_everything work with non-streamed generation (#2848)
This commit is contained in:
parent
ec482f3dae
commit
e356f69b36
2 changed files with 12 additions and 2 deletions
|
@ -9,6 +9,14 @@ import transformers
|
|||
import modules.shared as shared
|
||||
|
||||
|
||||
class _StopEverythingStoppingCriteria(transformers.StoppingCriteria):
|
||||
def __init__(self):
|
||||
transformers.StoppingCriteria.__init__(self)
|
||||
|
||||
def __call__(self, input_ids: torch.LongTensor, _scores: torch.FloatTensor) -> bool:
|
||||
return shared.stop_everything
|
||||
|
||||
|
||||
class Stream(transformers.StoppingCriteria):
|
||||
def __init__(self, callback_func=None):
|
||||
self.callback_func = callback_func
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue