Make stop_everything work with non-streamed generation (#2848)
This commit is contained in:
parent
ec482f3dae
commit
e356f69b36
2 changed files with 12 additions and 2 deletions
|
@ -9,7 +9,8 @@ import torch
|
|||
import transformers
|
||||
|
||||
import modules.shared as shared
|
||||
from modules.callbacks import Iteratorize, Stream
|
||||
from modules.callbacks import (Iteratorize, Stream,
|
||||
_StopEverythingStoppingCriteria)
|
||||
from modules.extensions import apply_extensions
|
||||
from modules.html_generator import generate_4chan_html, generate_basic_html
|
||||
from modules.logging_colors import logger
|
||||
|
@ -252,10 +253,11 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings
|
|||
if inputs_embeds is not None:
|
||||
generate_params.update({'inputs_embeds': inputs_embeds})
|
||||
|
||||
# Find the eos tokens
|
||||
# Stopping criteria / eos token
|
||||
eos_token_ids = [shared.tokenizer.eos_token_id] if shared.tokenizer.eos_token_id is not None else []
|
||||
generate_params['eos_token_id'] = eos_token_ids
|
||||
generate_params['stopping_criteria'] = transformers.StoppingCriteriaList()
|
||||
generate_params['stopping_criteria'].append(_StopEverythingStoppingCriteria());
|
||||
|
||||
t0 = time.time()
|
||||
try:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue