Add grammar to transformers and _HF loaders (#4091)

This commit is contained in:
oobabooga 2023-10-05 10:01:36 -03:00 committed by GitHub
parent 0197fdddf1
commit ae4ba3007f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 56 additions and 0 deletions

33
modules/grammar.py Normal file
View file

@ -0,0 +1,33 @@
from torch_grammar import GrammarSampler
from transformers.generation.logits_process import LogitsProcessor
from modules import shared
sampler = None
grammar = None
grammar_string = ''
class GrammarLogitsProcessor(LogitsProcessor):
def __init__(self, string):
global sampler, grammar, grammar_string
if string != grammar_string:
grammar_string = string
if string.strip() != '':
string = string.strip() + '\n'
sampler = GrammarSampler(string, 'root', shared.tokenizer)
else:
sampler = None
if sampler is not None:
grammar = sampler.logits_processor()
else:
grammar = None
def __call__(self, input_ids, scores):
if grammar is not None:
scores = grammar(input_ids, scores)
return scores