Better HF grammar implementation (#4953)
This commit is contained in:
parent
aa200f8723
commit
12690d3ffc
19 changed files with 830 additions and 116 deletions
|
@ -18,7 +18,8 @@ from modules.callbacks import (
|
|||
_StopEverythingStoppingCriteria
|
||||
)
|
||||
from modules.extensions import apply_extensions
|
||||
from modules.grammar import GrammarLogitsProcessor
|
||||
from modules.grammar.grammar_utils import initialize_grammar
|
||||
from modules.grammar.logits_process import GrammarConstrainedLogitsProcessor
|
||||
from modules.html_generator import generate_4chan_html, generate_basic_html
|
||||
from modules.logging_colors import logger
|
||||
from modules.models import clear_torch_cache, local_rank
|
||||
|
@ -317,11 +318,17 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings
|
|||
generate_params['stopping_criteria'] = transformers.StoppingCriteriaList()
|
||||
generate_params['stopping_criteria'].append(_StopEverythingStoppingCriteria())
|
||||
|
||||
# Logits processor
|
||||
processor = state.get('logits_processor', LogitsProcessorList([]))
|
||||
# In case a processor is passed by itself.
|
||||
if not isinstance(processor, LogitsProcessorList):
|
||||
processor = LogitsProcessorList([processor])
|
||||
processor.append(GrammarLogitsProcessor(state['grammar_string']))
|
||||
|
||||
# Grammar
|
||||
if state['grammar_string'].strip() != '':
|
||||
grammar = initialize_grammar(state['grammar_string'])
|
||||
grammar_processor = GrammarConstrainedLogitsProcessor(grammar)
|
||||
processor.append(grammar_processor)
|
||||
|
||||
apply_extensions('logits_processor', processor, input_ids)
|
||||
generate_params['logits_processor'] = processor
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue