Add a grammar editor to the UI (#4061)

This commit is contained in:
oobabooga 2023-09-24 18:05:24 -03:00 committed by GitHub
parent 08c4fb12ae
commit 08cf150c0c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 75 additions and 62 deletions

View file

@ -43,7 +43,7 @@ def custom_token_ban_logits_processor(token_ids, input_ids, logits):
class LlamaCppModel:
def __init__(self):
self.initialized = False
self.grammar_file = 'None'
self.grammar_string = ''
self.grammar = None
def __del__(self):
@ -110,13 +110,11 @@ class LlamaCppModel:
logits = np.expand_dims(logits, 0) # batch dim is expected
return torch.tensor(logits, dtype=torch.float32)
def load_grammar(self, fname):
if fname != self.grammar_file:
self.grammar_file = fname
p = Path(f'grammars/{fname}')
if p.exists():
logger.info(f'Loading the following grammar file: {p}')
self.grammar = llama_cpp_lib().LlamaGrammar.from_file(str(p))
def load_grammar(self, string):
if string != self.grammar_string:
self.grammar_string = string
if string.strip() != '':
self.grammar = llama_cpp_lib().LlamaGrammar.from_string(string)
else:
self.grammar = None
@ -131,7 +129,7 @@ class LlamaCppModel:
prompt = prompt[-get_max_prompt_length(state):]
prompt = self.decode(prompt)
self.load_grammar(state['grammar_file'])
self.load_grammar(state['grammar_string'])
logit_processors = LogitsProcessorList()
if state['ban_eos_token']:
logit_processors.append(partial(ban_eos_logits_processor, self.model.token_eos()))