Better HF grammar implementation (#4953)

This commit is contained in:
oobabooga 2023-12-17 02:01:23 -03:00 committed by GitHub
parent aa200f8723
commit 12690d3ffc
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
19 changed files with 830 additions and 116 deletions

View file

@ -18,7 +18,8 @@ from modules.callbacks import (
_StopEverythingStoppingCriteria
)
from modules.extensions import apply_extensions
from modules.grammar import GrammarLogitsProcessor
from modules.grammar.grammar_utils import initialize_grammar
from modules.grammar.logits_process import GrammarConstrainedLogitsProcessor
from modules.html_generator import generate_4chan_html, generate_basic_html
from modules.logging_colors import logger
from modules.models import clear_torch_cache, local_rank
@ -317,11 +318,17 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings
generate_params['stopping_criteria'] = transformers.StoppingCriteriaList()
generate_params['stopping_criteria'].append(_StopEverythingStoppingCriteria())
# Logits processor
processor = state.get('logits_processor', LogitsProcessorList([]))
# In case a processor is passed by itself.
if not isinstance(processor, LogitsProcessorList):
processor = LogitsProcessorList([processor])
processor.append(GrammarLogitsProcessor(state['grammar_string']))
# Grammar
if state['grammar_string'].strip() != '':
grammar = initialize_grammar(state['grammar_string'])
grammar_processor = GrammarConstrainedLogitsProcessor(grammar)
processor.append(grammar_processor)
apply_extensions('logits_processor', processor, input_ids)
generate_params['logits_processor'] = processor