Add grammar to transformers and _HF loaders (#4091)

This commit is contained in:
oobabooga 2023-10-05 10:01:36 -03:00 committed by GitHub
parent 0197fdddf1
commit ae4ba3007f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 56 additions and 0 deletions

View file

@ -18,6 +18,7 @@ from modules.callbacks import (
_StopEverythingStoppingCriteria
)
from modules.extensions import apply_extensions
from modules.grammar import GrammarLogitsProcessor
from modules.html_generator import generate_4chan_html, generate_basic_html
from modules.logging_colors import logger
from modules.models import clear_torch_cache, local_rank
@ -319,6 +320,7 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings
# In case a processor is passed by itself.
if not isinstance(processor, LogitsProcessorList):
processor = LogitsProcessorList([processor])
processor.append(GrammarLogitsProcessor(state['grammar_string']))
apply_extensions('logits_processor', processor, input_ids)
generate_params['logits_processor'] = processor