From 8bb3bb39b395ae50dae9f4a62f23e30adc373088 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Sat, 24 Jun 2023 09:43:00 -0300
Subject: [PATCH] Implement stopping string search in string space (#2847)

---
 README.md                  |  2 +-
 modules/callbacks.py       | 27 -----------
 modules/chat.py            | 49 ++------------------
 modules/text_generation.py | 95 ++++++++++++++++++++++----------------
 4 files changed, 61 insertions(+), 112 deletions(-)

diff --git a/README.md b/README.md
index 8199639..2c836f4 100644
--- a/README.md
+++ b/README.md
@@ -350,4 +350,4 @@ The presets that are included by default are the result of a contest that receiv
 
 - Gradio dropdown menu refresh button, code for reloading the interface: https://github.com/AUTOMATIC1111/stable-diffusion-webui
 - Godlike preset: https://github.com/KoboldAI/KoboldAI-Client/wiki/Settings-Presets
-- Code for early stopping in chat mode, code for some of the sliders: https://github.com/PygmalionAI/gradio-ui/
+- Code for some of the sliders: https://github.com/PygmalionAI/gradio-ui/
diff --git a/modules/callbacks.py b/modules/callbacks.py
index fb92e18..c61bddf 100644
--- a/modules/callbacks.py
+++ b/modules/callbacks.py
@@ -9,33 +9,6 @@ import transformers
 
 import modules.shared as shared
 
-class _SentinelTokenStoppingCriteria(transformers.StoppingCriteria):
-
-    def __init__(self, sentinel_token_ids: list, starting_idx: int):
-        transformers.StoppingCriteria.__init__(self)
-        self.sentinel_token_ids = sentinel_token_ids
-        self.starting_idx = starting_idx
-        self.shortest = min([x.shape[-1] for x in sentinel_token_ids])
-
-    def __call__(self, input_ids: torch.LongTensor, _scores: torch.FloatTensor) -> bool:
-        for sample in input_ids:
-            trimmed_sample = sample[self.starting_idx:]
-            trimmed_len = trimmed_sample.shape[-1]
-            if trimmed_len < self.shortest:
-                continue
-
-            for sentinel in self.sentinel_token_ids:
-                sentinel_len = sentinel.shape[-1]
-                if trimmed_len < sentinel_len:
-                    continue
-
-                window = trimmed_sample[-sentinel_len:]
-                if torch.all(torch.eq(sentinel, window)):
-                    return True
-
-        return False
-
-
 class Stream(transformers.StoppingCriteria):
     def __init__(self, callback_func=None):
         self.callback_func = callback_func
diff --git a/modules/chat.py b/modules/chat.py
index f4acd0c..7d1cea1 100644
--- a/modules/chat.py
+++ b/modules/chat.py
@@ -1,4 +1,3 @@
-import ast
 import base64
 import copy
 import functools
@@ -144,40 +143,10 @@ def get_stopping_strings(state):
             f"\n{state['name2']}:"
         ]
 
-    stopping_strings += ast.literal_eval(f"[{state['custom_stopping_strings']}]")
-    return stopping_strings
-
-
-def extract_message_from_reply(reply, state):
-    next_character_found = False
-    stopping_strings = get_stopping_strings(state)
-    if state['stop_at_newline']:
-        lines = reply.split('\n')
-        reply = lines[0].strip()
-        if len(lines) > 1:
-            next_character_found = True
-    else:
-        for string in stopping_strings:
-            idx = reply.find(string)
-            if idx != -1:
-                reply = reply[:idx]
-                next_character_found = True
+    if state['stop_at_newline']:
+        stopping_strings.append("\n")
 
-    # If something like "\nYo" is generated just before "\nYou:"
-    # is completed, trim it
-    if not next_character_found:
-        for string in stopping_strings:
-            for j in range(len(string) - 1, 0, -1):
-                if reply[-j:] == string[:j]:
-                    reply = reply[:-j]
-                    break
-            else:
-                continue
-
-            break
-
-    return reply, next_character_found
+    return stopping_strings
 
 
 def chatbot_wrapper(text, history, state, regenerate=False, _continue=False, loading_message=True):
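
Worth pausing on the get_stopping_strings() rework above: the stop-at-newline
toggle is now just one more stopping string rather than a separate eos_token
code path, and custom_stopping_strings is no longer parsed here (that parsing
moves into _generate_reply() below, so every generation path sees the same
list). A minimal sketch of the new chat-mode behavior, with a stand-in state
dict in place of the webui's real state and the instruct-mode turn templates
omitted:

    # Sketch only: the chat-mode branch of the reworked get_stopping_strings().
    def get_stopping_strings_sketch(state):
        # A new speaker turn always begins with "\n<name>:".
        stopping_strings = [f"\n{state['name1']}:", f"\n{state['name2']}:"]
        if state['stop_at_newline']:
            stopping_strings.append("\n")

        return stopping_strings

    state = {'name1': 'You', 'name2': 'Assistant', 'stop_at_newline': True}
    print(get_stopping_strings_sketch(state))
    # ['\nYou:', '\nAssistant:', '\n']
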
@@ -191,7 +160,6 @@ def chatbot_wrapper(text, history, state, regenerate=False, _continue=False, loa
     # Defining some variables
     just_started = True
     visible_text = None
-    eos_token = '\n' if state['stop_at_newline'] else None
     stopping_strings = get_stopping_strings(state)
 
     # Preparing the input
@@ -231,11 +199,10 @@ def chatbot_wrapper(text, history, state, regenerate=False, _continue=False, loa
     cumulative_reply = ''
     for i in range(state['chat_generation_attempts']):
         reply = None
-        for j, reply in enumerate(generate_reply(prompt + cumulative_reply, state, eos_token=eos_token, stopping_strings=stopping_strings, is_chat=True)):
+        for j, reply in enumerate(generate_reply(prompt + cumulative_reply, state, stopping_strings=stopping_strings, is_chat=True)):
             reply = cumulative_reply + reply
 
             # Extract the reply
-            reply, next_character_found = extract_message_from_reply(reply, state)
             visible_reply = re.sub("(<USER>|<user>|{{user}})", state['name1'], reply)
             visible_reply = apply_extensions("output", visible_reply)
 
@@ -262,9 +229,6 @@ def chatbot_wrapper(text, history, state, regenerate=False, _continue=False, loa
             if state['stream']:
                 yield output
 
-            if next_character_found:
-                break
-
         if reply in [None, cumulative_reply]:
             break
         else:
@@ -281,7 +245,6 @@ def impersonate_wrapper(text, start_with, state):
 
     # Defining some variables
     cumulative_reply = ''
-    eos_token = '\n' if state['stop_at_newline'] else None
     prompt = generate_chat_prompt('', state, impersonate=True)
     stopping_strings = get_stopping_strings(state)
 
@@ -289,16 +252,12 @@ def impersonate_wrapper(text, start_with, state):
     cumulative_reply = text
     for i in range(state['chat_generation_attempts']):
         reply = None
-        for reply in generate_reply(prompt + cumulative_reply, state, eos_token=eos_token, stopping_strings=stopping_strings, is_chat=True):
+        for reply in generate_reply(prompt + cumulative_reply, state, stopping_strings=stopping_strings, is_chat=True):
             reply = cumulative_reply + reply
-            reply, next_character_found = extract_message_from_reply(reply, state)
             yield reply.lstrip(' ')
             if shared.stop_everything:
                 return
 
-            if next_character_found:
-                break
-
         if reply in [None, cumulative_reply]:
             break
         else:
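
With extract_message_from_reply() deleted, chatbot_wrapper() and
impersonate_wrapper() reduce to the same accumulate-and-retry loop: each
attempt appends the text gathered so far to the prompt, and generate_reply()
already yields trimmed text. Schematically (a sketch of the shared pattern,
not the exact webui code; generate_reply here stands for any generator that
yields progressively longer, already-trimmed strings):

    cumulative_reply = ''
    for i in range(state['chat_generation_attempts']):
        reply = None
        for reply in generate_reply(prompt + cumulative_reply, state,
                                    stopping_strings=stopping_strings, is_chat=True):
            reply = cumulative_reply + reply
            # ... render or yield the partial reply here ...

        # A None or unchanged reply means the attempt added nothing new.
        if reply in [None, cumulative_reply]:
            break
        else:
            cumulative_reply = reply

The next_character_found bookkeeping disappears because the core generator now
stops at the first stopping string on its own.
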
diff --git a/modules/text_generation.py b/modules/text_generation.py
index d0965b8..81ada7e 100644
--- a/modules/text_generation.py
+++ b/modules/text_generation.py
@@ -9,8 +9,7 @@ import torch
 import transformers
 
 import modules.shared as shared
-from modules.callbacks import (Iteratorize, Stream,
-                               _SentinelTokenStoppingCriteria)
+from modules.callbacks import Iteratorize, Stream
 from modules.extensions import apply_extensions
 from modules.html_generator import generate_4chan_html, generate_basic_html
 from modules.logging_colors import logger
@@ -42,11 +41,6 @@ def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_lengt
     if not add_bos_token and input_ids[0][0] == shared.tokenizer.bos_token_id:
         input_ids = input_ids[:, 1:]
 
-    # Llama adds this extra token when the first character is '\n', and this
-    # compromises the stopping criteria, so we just remove it
-    if type(shared.tokenizer) is transformers.LlamaTokenizer and input_ids[0][0] == 29871:
-        input_ids = input_ids[:, 1:]
-
     # Handling truncation
     if truncation_length is not None:
         input_ids = input_ids[:, -truncation_length:]
@@ -139,15 +133,43 @@ def stop_everything_event():
     shared.stop_everything = True
 
 
-def generate_reply_wrapper(question, state, eos_token=None, stopping_strings=None):
-    for reply in generate_reply(question, state, eos_token, stopping_strings, is_chat=False):
+def generate_reply_wrapper(question, state, stopping_strings=None):
+    reply = question if not shared.is_seq2seq else ''
+    yield formatted_outputs(reply, shared.model_name)
+
+    for reply in generate_reply(question, state, stopping_strings, is_chat=False):
         if not shared.is_seq2seq:
             reply = question + reply
 
         yield formatted_outputs(reply, shared.model_name)
 
 
-def _generate_reply(question, state, eos_token=None, stopping_strings=None, is_chat=False):
+def apply_stopping_strings(reply, all_stop_strings):
+    stop_found = False
+    for string in all_stop_strings:
+        idx = reply.find(string)
+        if idx != -1:
+            reply = reply[:idx]
+            stop_found = True
+            break
+
+    if not stop_found:
+        # If something like "\nYo" is generated just before "\nYou:"
+        # is completed, trim it
+        for string in all_stop_strings:
+            for j in range(len(string) - 1, 0, -1):
+                if reply[-j:] == string[:j]:
+                    reply = reply[:-j]
+                    break
+            else:
+                continue
+
+            break
+
+    return reply, stop_found
+
+
+def _generate_reply(question, state, stopping_strings=None, is_chat=False):
     state = apply_extensions('state', state)
     generate_func = apply_extensions('custom_generate_reply')
     if generate_func is None:
@@ -168,29 +190,39 @@ def _generate_reply(question, state, eos_token=None, stopping_strings=None, is_c
     if not is_chat:
         question = apply_extensions('input', question)
 
+    # Finding the stopping strings
+    all_stop_strings = []
+    for st in (stopping_strings, ast.literal_eval(f"[{state['custom_stopping_strings']}]")):
+        if type(st) is list and len(st) > 0:
+            all_stop_strings += st
+
     if shared.args.verbose:
         print(f'\n\n{question}\n--------------------\n')
 
     shared.stop_everything = False
     clear_torch_cache()
     seed = set_manual_seed(state['seed'])
-    is_stream = state['stream']
     last_update = -1
     reply = ''
-    for reply in generate_func(question, original_question, seed, state, eos_token, stopping_strings, is_chat=is_chat):
+    is_stream = state['stream']
+    if len(all_stop_strings) > 0 and not state['stream']:
+        state['stream'] = True
+
+    for reply in generate_func(question, original_question, seed, state, stopping_strings, is_chat=is_chat):
+        reply, stop_found = apply_stopping_strings(reply, all_stop_strings)
         if is_stream:
             cur_time = time.time()
             if cur_time - last_update > 0.041666666666666664:  # Limit streaming to 24 fps
                 last_update = cur_time
                 yield reply
-        else:
-            yield reply
 
-    if is_stream:
-        yield reply
+        if stop_found:
+            break
+
+    yield reply
 
 
-def generate_reply_HF(question, original_question, seed, state, eos_token=None, stopping_strings=None, is_chat=False):
+def generate_reply_HF(question, original_question, seed, state, stopping_strings=None, is_chat=False):
     generate_params = {}
     for k in ['max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'tfs', 'top_a', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta']:
         generate_params[k] = state[k]
@@ -213,11 +245,6 @@ def generate_reply_HF(question, original_question, seed, state, eos_token=None,
     output = input_ids[0]
     cuda = not any((shared.args.cpu, shared.args.deepspeed))
 
-    # Find the eos tokens
-    eos_token_ids = [shared.tokenizer.eos_token_id] if shared.tokenizer.eos_token_id is not None else []
-    if eos_token is not None:
-        eos_token_ids.append(int(encode(eos_token)[0][-1]))
-
     # Add the encoded tokens to generate_params
     question, input_ids, inputs_embeds = apply_extensions('tokenizer', state, question, input_ids, None)
     original_input_ids = input_ids
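
This is the heart of the change: stopping strings are matched against decoded
text instead of token windows, so tokenizer quirks (such as the Llama 29871
token whose workaround is deleted from encode() above) can no longer break the
match. For illustration, reusing apply_stopping_strings() exactly as added
above:

    reply, found = apply_stopping_strings("Sure.\nYou: thanks", ["\nYou:"])
    # -> ('Sure.', True): cut at the first complete stopping string

    reply, found = apply_stopping_strings("Sure.\nYo", ["\nYou:"])
    # -> ('Sure.', False): the dangling prefix "\nYo" is trimmed so that a
    #    half-generated "\nYou:" never reaches the user

Because the check needs partial outputs to inspect, _generate_reply() forces
streaming internally whenever stopping strings are present, even if the UI
asked for a non-streamed reply; is_stream is captured beforehand, so in that
case the UI still receives only the single final yield.
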
@@ -225,17 +252,10 @@ def generate_reply_HF(question, original_question, seed, state, eos_token=None,
     if inputs_embeds is not None:
         generate_params.update({'inputs_embeds': inputs_embeds})
 
-    # Create the StoppingCriteriaList with the stopping strings (needs to be done after tokenizer extensions)
-    stopping_criteria_list = transformers.StoppingCriteriaList()
-    for st in (stopping_strings, ast.literal_eval(f"[{state['custom_stopping_strings']}]")):
-        if type(st) is list and len(st) > 0:
-            sentinel_token_ids = [encode(string, add_special_tokens=False) for string in st]
-            stopping_criteria_list.append(_SentinelTokenStoppingCriteria(sentinel_token_ids=sentinel_token_ids, starting_idx=len(input_ids[0])))
-            break
-
-    # Update generate_params with the eos token and the stopping strings
+    # Find the eos tokens
+    eos_token_ids = [shared.tokenizer.eos_token_id] if shared.tokenizer.eos_token_id is not None else []
     generate_params['eos_token_id'] = eos_token_ids
-    generate_params['stopping_criteria'] = stopping_criteria_list
+    generate_params['stopping_criteria'] = transformers.StoppingCriteriaList()
 
     t0 = time.time()
     try:
@@ -280,7 +300,7 @@ def generate_reply_HF(question, original_question, seed, state, eos_token=None,
     return
 
 
-def generate_reply_custom(question, original_question, seed, state, eos_token=None, stopping_strings=None, is_chat=False):
+def generate_reply_custom(question, original_question, seed, state, stopping_strings=None, is_chat=False):
     seed = set_manual_seed(state['seed'])
 
     t0 = time.time()
@@ -312,7 +332,7 @@ def generate_reply_custom(question, original_question, seed, state, eos_token=No
     return
 
 
-def generate_reply_flexgen(question, original_question, seed, state, eos_token=None, stopping_strings=None, is_chat=False):
+def generate_reply_flexgen(question, original_question, seed, state, stopping_strings=None, is_chat=False):
     generate_params = {}
     for k in ['max_new_tokens', 'do_sample', 'temperature']:
         generate_params[k] = state[k]
@@ -326,8 +346,8 @@ def generate_reply_flexgen(question, original_question, seed, state, eos_token=N
     # Find the eos tokens
     eos_token_ids = [shared.tokenizer.eos_token_id] if shared.tokenizer.eos_token_id is not None else []
-    if eos_token is not None:
-        eos_token_ids.append(int(encode(eos_token)[0][-1]))
+    if not state['ban_eos_token']:
+        generate_params['stop'] = eos_token_ids[-1]
 
     # Add the encoded tokens to generate_params
     question, input_ids, inputs_embeds = apply_extensions('tokenizer', state, question, input_ids, None)
@@ -336,9 +356,6 @@ def generate_reply_flexgen(question, original_question, seed, state, eos_token=N
     if inputs_embeds is not None:
         generate_params.update({'inputs_embeds': inputs_embeds})
 
-    # Update generate_params with the eos token and the stopping strings
-    generate_params['stop'] = eos_token_ids[-1]
-
     t0 = time.time()
     try:
         if not is_chat:
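
Two closing observations. First, the UI's custom stopping strings are now
parsed once, centrally, with the ast.literal_eval expression shown in
_generate_reply() above; for illustration (the value below is hypothetical
user input for the "Custom stopping strings" field):

    import ast

    custom_stopping_strings = '"\\nYou:", "###"'  # as typed into the UI field
    parsed = ast.literal_eval(f"[{custom_stopping_strings}]")
    print(parsed)  # ['\nYou:', '###']

Second, on the HF path generate() is left with only genuine EOS token ids and
a fresh StoppingCriteriaList carrying no sentinel-token criteria; everything
string-shaped happens after decoding, which is what allows both
_SentinelTokenStoppingCriteria and the Llama workaround in encode() to be
deleted.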