From ae4ba3007f7bc825afcd656cbadab38da498f600 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Thu, 5 Oct 2023 10:01:36 -0300 Subject: [PATCH] Add grammar to transformers and _HF loaders (#4091) --- modules/grammar.py | 33 ++++++++++++++++++++++++++++++++ modules/loaders.py | 12 ++++++++++++ modules/text_generation.py | 2 ++ requirements.txt | 1 + requirements_amd.txt | 1 + requirements_amd_noavx2.txt | 1 + requirements_apple_intel.txt | 1 + requirements_apple_silicon.txt | 1 + requirements_cpu_only.txt | 1 + requirements_cpu_only_noavx2.txt | 1 + requirements_noavx2.txt | 1 + requirements_nowheels.txt | 1 + 12 files changed, 56 insertions(+) create mode 100644 modules/grammar.py diff --git a/modules/grammar.py b/modules/grammar.py new file mode 100644 index 0000000..5f6ad3a --- /dev/null +++ b/modules/grammar.py @@ -0,0 +1,33 @@ +from torch_grammar import GrammarSampler +from transformers.generation.logits_process import LogitsProcessor + +from modules import shared + +sampler = None +grammar = None +grammar_string = '' + + +class GrammarLogitsProcessor(LogitsProcessor): + def __init__(self, string): + + global sampler, grammar, grammar_string + + if string != grammar_string: + grammar_string = string + if string.strip() != '': + string = string.strip() + '\n' + sampler = GrammarSampler(string, 'root', shared.tokenizer) + else: + sampler = None + + if sampler is not None: + grammar = sampler.logits_processor() + else: + grammar = None + + def __call__(self, input_ids, scores): + if grammar is not None: + scores = grammar(input_ids, scores) + + return scores diff --git a/modules/loaders.py b/modules/loaders.py index 7580e30..964fb00 100644 --- a/modules/loaders.py +++ b/modules/loaders.py @@ -156,6 +156,8 @@ loaders_samplers = { 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', + 'grammar_file_row', + 'grammar_string', 'guidance_scale', 'negative_prompt', 'ban_eos_token', @@ -183,6 +185,8 @@ loaders_samplers = { 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', + 'grammar_file_row', + 'grammar_string', 'guidance_scale', 'negative_prompt', 'ban_eos_token', @@ -236,6 +240,8 @@ loaders_samplers = { 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', + 'grammar_file_row', + 'grammar_string', 'guidance_scale', 'negative_prompt', 'ban_eos_token', @@ -267,6 +273,8 @@ loaders_samplers = { 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', + 'grammar_file_row', + 'grammar_string', 'guidance_scale', 'negative_prompt', 'ban_eos_token', @@ -298,6 +306,8 @@ loaders_samplers = { 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', + 'grammar_file_row', + 'grammar_string', 'guidance_scale', 'negative_prompt', 'ban_eos_token', @@ -339,6 +349,8 @@ loaders_samplers = { 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', + 'grammar_file_row', + 'grammar_string', 'guidance_scale', 'negative_prompt', 'ban_eos_token', diff --git a/modules/text_generation.py b/modules/text_generation.py index ab556a9..a7f3509 100644 --- a/modules/text_generation.py +++ b/modules/text_generation.py @@ -18,6 +18,7 @@ from modules.callbacks import ( _StopEverythingStoppingCriteria ) from modules.extensions import apply_extensions +from modules.grammar import GrammarLogitsProcessor from modules.html_generator import generate_4chan_html, generate_basic_html from modules.logging_colors import logger from modules.models import clear_torch_cache, local_rank @@ -319,6 +320,7 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings # In case a processor is passed by itself. if not isinstance(processor, LogitsProcessorList): processor = LogitsProcessorList([processor]) + processor.append(GrammarLogitsProcessor(state['grammar_string'])) apply_extensions('logits_processor', processor, input_ids) generate_params['logits_processor'] = processor diff --git a/requirements.txt b/requirements.txt index 9d519b9..b651c5e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -25,6 +25,7 @@ tqdm wandb git+https://github.com/huggingface/transformers@211f93aab95d1c683494e61c3cf8ff10e1f5d6b7 +git+https://github.com/oobabooga/torch-grammar.git # bitsandbytes bitsandbytes==0.41.1; platform_system != "Windows" diff --git a/requirements_amd.txt b/requirements_amd.txt index b15e4d0..2ea6e35 100644 --- a/requirements_amd.txt +++ b/requirements_amd.txt @@ -25,6 +25,7 @@ tqdm wandb git+https://github.com/huggingface/transformers@211f93aab95d1c683494e61c3cf8ff10e1f5d6b7 +git+https://github.com/oobabooga/torch-grammar.git # bitsandbytes bitsandbytes==0.38.1; platform_system != "Windows" diff --git a/requirements_amd_noavx2.txt b/requirements_amd_noavx2.txt index d567d79..18a8133 100644 --- a/requirements_amd_noavx2.txt +++ b/requirements_amd_noavx2.txt @@ -25,6 +25,7 @@ tqdm wandb git+https://github.com/huggingface/transformers@211f93aab95d1c683494e61c3cf8ff10e1f5d6b7 +git+https://github.com/oobabooga/torch-grammar.git # bitsandbytes bitsandbytes==0.38.1; platform_system != "Windows" diff --git a/requirements_apple_intel.txt b/requirements_apple_intel.txt index 6a37726..3d3896a 100644 --- a/requirements_apple_intel.txt +++ b/requirements_apple_intel.txt @@ -25,6 +25,7 @@ tqdm wandb git+https://github.com/huggingface/transformers@211f93aab95d1c683494e61c3cf8ff10e1f5d6b7 +git+https://github.com/oobabooga/torch-grammar.git # bitsandbytes bitsandbytes==0.41.1; platform_system != "Windows" diff --git a/requirements_apple_silicon.txt b/requirements_apple_silicon.txt index 76024e2..fb8598e 100644 --- a/requirements_apple_silicon.txt +++ b/requirements_apple_silicon.txt @@ -25,6 +25,7 @@ tqdm wandb git+https://github.com/huggingface/transformers@211f93aab95d1c683494e61c3cf8ff10e1f5d6b7 +git+https://github.com/oobabooga/torch-grammar.git # bitsandbytes bitsandbytes==0.41.1; platform_system != "Windows" diff --git a/requirements_cpu_only.txt b/requirements_cpu_only.txt index 39f7e7d..cde13ef 100644 --- a/requirements_cpu_only.txt +++ b/requirements_cpu_only.txt @@ -25,6 +25,7 @@ tqdm wandb git+https://github.com/huggingface/transformers@211f93aab95d1c683494e61c3cf8ff10e1f5d6b7 +git+https://github.com/oobabooga/torch-grammar.git # bitsandbytes bitsandbytes==0.41.1; platform_system != "Windows" diff --git a/requirements_cpu_only_noavx2.txt b/requirements_cpu_only_noavx2.txt index ec66500..bf7a5fd 100644 --- a/requirements_cpu_only_noavx2.txt +++ b/requirements_cpu_only_noavx2.txt @@ -25,6 +25,7 @@ tqdm wandb git+https://github.com/huggingface/transformers@211f93aab95d1c683494e61c3cf8ff10e1f5d6b7 +git+https://github.com/oobabooga/torch-grammar.git # bitsandbytes bitsandbytes==0.41.1; platform_system != "Windows" diff --git a/requirements_noavx2.txt b/requirements_noavx2.txt index f0e3383..6057d46 100644 --- a/requirements_noavx2.txt +++ b/requirements_noavx2.txt @@ -25,6 +25,7 @@ tqdm wandb git+https://github.com/huggingface/transformers@211f93aab95d1c683494e61c3cf8ff10e1f5d6b7 +git+https://github.com/oobabooga/torch-grammar.git # bitsandbytes bitsandbytes==0.41.1; platform_system != "Windows" diff --git a/requirements_nowheels.txt b/requirements_nowheels.txt index d71d82d..2984bd3 100644 --- a/requirements_nowheels.txt +++ b/requirements_nowheels.txt @@ -25,6 +25,7 @@ tqdm wandb git+https://github.com/huggingface/transformers@211f93aab95d1c683494e61c3cf8ff10e1f5d6b7 +git+https://github.com/oobabooga/torch-grammar.git # bitsandbytes bitsandbytes==0.41.1; platform_system != "Windows"