Add grammar to llama.cpp loader (closes #4019)

This commit is contained in:
oobabooga 2023-09-24 07:08:41 -07:00
parent a3ad9fe6c0
commit b227e65d86
18 changed files with 162 additions and 0 deletions

View file

@ -1,5 +1,6 @@
import re
from functools import partial
from pathlib import Path
import numpy as np
import torch
@ -42,6 +43,8 @@ def custom_token_ban_logits_processor(token_ids, input_ids, logits):
class LlamaCppModel:
def __init__(self):
self.initialized = False
self.grammar_file = 'None'
self.grammar = None
def __del__(self):
self.model.__del__()
@ -107,6 +110,17 @@ class LlamaCppModel:
logits = np.expand_dims(logits, 0) # batch dim is expected
return torch.tensor(logits, dtype=torch.float32)
def load_grammar(self, fname):
if fname != self.grammar_file:
self.grammar_file = fname
p = Path(f'grammars/{fname}')
print(p)
if p.exists():
logger.info(f'Loading the following grammar file: {p}')
self.grammar = llama_cpp_lib().LlamaGrammar.from_file(str(p))
else:
self.grammar = None
def generate(self, prompt, state, callback=None):
LogitsProcessorList = llama_cpp_lib().LogitsProcessorList
@ -118,6 +132,7 @@ class LlamaCppModel:
prompt = prompt[-get_max_prompt_length(state):]
prompt = self.decode(prompt)
self.load_grammar(state['grammar_file'])
logit_processors = LogitsProcessorList()
if state['ban_eos_token']:
logit_processors.append(partial(ban_eos_logits_processor, self.model.token_eos()))
@ -140,6 +155,7 @@ class LlamaCppModel:
mirostat_eta=state['mirostat_eta'],
stream=True,
logits_processor=logit_processors,
grammar=self.grammar
)
output = ""

View file

@ -305,6 +305,7 @@ loaders_samplers = {
'mirostat_mode',
'mirostat_tau',
'mirostat_eta',
'grammar_file',
'ban_eos_token',
'custom_token_bans',
},

View file

@ -114,6 +114,7 @@ def list_interface_input_elements():
'mirostat_mode',
'mirostat_tau',
'mirostat_eta',
'grammar_file',
'negative_prompt',
'guidance_scale',
'add_bos_token',

View file

@ -108,6 +108,9 @@ def create_ui(default_preset):
shared.gradio['num_beams'] = gr.Slider(1, 20, step=1, value=generate_params['num_beams'], label='num_beams', info='For Beam Search, along with length_penalty and early_stopping.')
shared.gradio['length_penalty'] = gr.Slider(-5, 5, value=generate_params['length_penalty'], label='length_penalty')
shared.gradio['early_stopping'] = gr.Checkbox(value=generate_params['early_stopping'], label='early_stopping')
with gr.Row():
shared.gradio['grammar_file'] = gr.Dropdown(value='None', choices=utils.get_available_grammars(), label='Grammar file (GBNF)', elem_classes='slim-dropdown')
ui.create_refresh_button(shared.gradio['grammar_file'], lambda: None, lambda: {'choices': utils.get_available_grammars()}, 'refresh-button')
with gr.Box():
with gr.Row():

View file

@ -124,3 +124,7 @@ def get_datasets(path: str, ext: str):
def get_available_chat_styles():
return sorted(set(('-'.join(k.stem.split('-')[1:]) for k in Path('css').glob('chat_style*.css'))), key=natural_keys)
def get_available_grammars():
return ['None'] + sorted([item.name for item in list(Path('grammars').glob('*.gbnf'))], key=natural_keys)