Add grammar to llama.cpp loader (closes #4019)

This commit is contained in:
oobabooga 2023-09-24 07:08:41 -07:00
parent a3ad9fe6c0
commit b227e65d86
18 changed files with 162 additions and 0 deletions

View file

@ -1,5 +1,6 @@
import re
from functools import partial
from pathlib import Path
import numpy as np
import torch
@ -42,6 +43,8 @@ def custom_token_ban_logits_processor(token_ids, input_ids, logits):
class LlamaCppModel:
def __init__(self):
self.initialized = False
self.grammar_file = 'None'
self.grammar = None
def __del__(self):
self.model.__del__()
@ -107,6 +110,17 @@ class LlamaCppModel:
logits = np.expand_dims(logits, 0) # batch dim is expected
return torch.tensor(logits, dtype=torch.float32)
def load_grammar(self, fname):
if fname != self.grammar_file:
self.grammar_file = fname
p = Path(f'grammars/{fname}')
print(p)
if p.exists():
logger.info(f'Loading the following grammar file: {p}')
self.grammar = llama_cpp_lib().LlamaGrammar.from_file(str(p))
else:
self.grammar = None
def generate(self, prompt, state, callback=None):
LogitsProcessorList = llama_cpp_lib().LogitsProcessorList
@ -118,6 +132,7 @@ class LlamaCppModel:
prompt = prompt[-get_max_prompt_length(state):]
prompt = self.decode(prompt)
self.load_grammar(state['grammar_file'])
logit_processors = LogitsProcessorList()
if state['ban_eos_token']:
logit_processors.append(partial(ban_eos_logits_processor, self.model.token_eos()))
@ -140,6 +155,7 @@ class LlamaCppModel:
mirostat_eta=state['mirostat_eta'],
stream=True,
logits_processor=logit_processors,
grammar=self.grammar
)
output = ""