Make stop_everything work with non-streamed generation (#2848)
This commit is contained in:
parent
ec482f3dae
commit
e356f69b36
2 changed files with 12 additions and 2 deletions
|
@ -9,6 +9,14 @@ import transformers
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
|
|
||||||
|
|
||||||
|
class _StopEverythingStoppingCriteria(transformers.StoppingCriteria):
|
||||||
|
def __init__(self):
|
||||||
|
transformers.StoppingCriteria.__init__(self)
|
||||||
|
|
||||||
|
def __call__(self, input_ids: torch.LongTensor, _scores: torch.FloatTensor) -> bool:
|
||||||
|
return shared.stop_everything
|
||||||
|
|
||||||
|
|
||||||
class Stream(transformers.StoppingCriteria):
|
class Stream(transformers.StoppingCriteria):
|
||||||
def __init__(self, callback_func=None):
|
def __init__(self, callback_func=None):
|
||||||
self.callback_func = callback_func
|
self.callback_func = callback_func
|
||||||
|
|
|
@ -9,7 +9,8 @@ import torch
|
||||||
import transformers
|
import transformers
|
||||||
|
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
from modules.callbacks import Iteratorize, Stream
|
from modules.callbacks import (Iteratorize, Stream,
|
||||||
|
_StopEverythingStoppingCriteria)
|
||||||
from modules.extensions import apply_extensions
|
from modules.extensions import apply_extensions
|
||||||
from modules.html_generator import generate_4chan_html, generate_basic_html
|
from modules.html_generator import generate_4chan_html, generate_basic_html
|
||||||
from modules.logging_colors import logger
|
from modules.logging_colors import logger
|
||||||
|
@ -252,10 +253,11 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings
|
||||||
if inputs_embeds is not None:
|
if inputs_embeds is not None:
|
||||||
generate_params.update({'inputs_embeds': inputs_embeds})
|
generate_params.update({'inputs_embeds': inputs_embeds})
|
||||||
|
|
||||||
# Find the eos tokens
|
# Stopping criteria / eos token
|
||||||
eos_token_ids = [shared.tokenizer.eos_token_id] if shared.tokenizer.eos_token_id is not None else []
|
eos_token_ids = [shared.tokenizer.eos_token_id] if shared.tokenizer.eos_token_id is not None else []
|
||||||
generate_params['eos_token_id'] = eos_token_ids
|
generate_params['eos_token_id'] = eos_token_ids
|
||||||
generate_params['stopping_criteria'] = transformers.StoppingCriteriaList()
|
generate_params['stopping_criteria'] = transformers.StoppingCriteriaList()
|
||||||
|
generate_params['stopping_criteria'].append(_StopEverythingStoppingCriteria());
|
||||||
|
|
||||||
t0 = time.time()
|
t0 = time.time()
|
||||||
try:
|
try:
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue