From 3443219cbced52cd10836f479cd5095038738bc0 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Thu, 29 Jun 2023 13:40:13 -0300 Subject: [PATCH] Add repetition penalty range parameter to transformers (#2916) --- api-examples/api-example-chat-stream.py | 1 + api-examples/api-example-chat.py | 1 + api-examples/api-example-stream.py | 1 + api-examples/api-example.py | 1 + extensions/api/util.py | 1 + extensions/openai/script.py | 1 + modules/exllama.py | 1 + modules/presets.py | 5 +-- modules/sampler_hijack.py | 41 +++++++++++++++++++++++++ modules/text_generation.py | 2 +- modules/ui.py | 2 +- server.py | 3 +- 12 files changed, 55 insertions(+), 5 deletions(-) diff --git a/api-examples/api-example-chat-stream.py b/api-examples/api-example-chat-stream.py index 046f3d0..8e37b56 100644 --- a/api-examples/api-example-chat-stream.py +++ b/api-examples/api-example-chat-stream.py @@ -44,6 +44,7 @@ async def run(user_input, history): 'tfs': 1, 'top_a': 0, 'repetition_penalty': 1.18, + 'repetition_penalty_range': 0, 'top_k': 40, 'min_length': 0, 'no_repeat_ngram_size': 0, diff --git a/api-examples/api-example-chat.py b/api-examples/api-example-chat.py index 7048043..23f2f18 100644 --- a/api-examples/api-example-chat.py +++ b/api-examples/api-example-chat.py @@ -38,6 +38,7 @@ def run(user_input, history): 'tfs': 1, 'top_a': 0, 'repetition_penalty': 1.18, + 'repetition_penalty_range': 0, 'top_k': 40, 'min_length': 0, 'no_repeat_ngram_size': 0, diff --git a/api-examples/api-example-stream.py b/api-examples/api-example-stream.py index 64d4e05..79a01e4 100644 --- a/api-examples/api-example-stream.py +++ b/api-examples/api-example-stream.py @@ -33,6 +33,7 @@ async def run(context): 'tfs': 1, 'top_a': 0, 'repetition_penalty': 1.18, + 'repetition_penalty_range': 0, 'top_k': 40, 'min_length': 0, 'no_repeat_ngram_size': 0, diff --git a/api-examples/api-example.py b/api-examples/api-example.py index 54a4f37..b09823c 100644 --- a/api-examples/api-example.py +++ b/api-examples/api-example.py @@ -25,6 +25,7 @@ def run(prompt): 'tfs': 1, 'top_a': 0, 'repetition_penalty': 1.18, + 'repetition_penalty_range': 0, 'top_k': 40, 'min_length': 0, 'no_repeat_ngram_size': 0, diff --git a/extensions/api/util.py b/extensions/api/util.py index 01ae163..d575c60 100644 --- a/extensions/api/util.py +++ b/extensions/api/util.py @@ -21,6 +21,7 @@ def build_parameters(body, chat=False): 'tfs': float(body.get('tfs', 1)), 'top_a': float(body.get('top_a', 0)), 'repetition_penalty': float(body.get('repetition_penalty', body.get('rep_pen', 1.1))), + 'repetition_penalty_range': int(body.get('repetition_penalty_range', 0)), 'encoder_repetition_penalty': float(body.get('encoder_repetition_penalty', 1.0)), 'top_k': int(body.get('top_k', 0)), 'min_length': int(body.get('min_length', 0)), diff --git a/extensions/openai/script.py b/extensions/openai/script.py index e7bbf6e..323d682 100644 --- a/extensions/openai/script.py +++ b/extensions/openai/script.py @@ -29,6 +29,7 @@ default_req_params = { 'top_p': 1.0, 'top_k': 1, 'repetition_penalty': 1.18, + 'repetition_penalty_range': 0, 'encoder_repetition_penalty': 1.0, 'suffix': None, 'stream': False, diff --git a/modules/exllama.py b/modules/exllama.py index 449926e..8543890 100644 --- a/modules/exllama.py +++ b/modules/exllama.py @@ -71,6 +71,7 @@ class ExllamaModel: self.generator.settings.top_k = state['top_k'] self.generator.settings.typical = state['typical_p'] self.generator.settings.token_repetition_penalty_max = state['repetition_penalty'] + self.generator.settings.token_repetition_penalty_sustain = state['repetition_penalty_range'] if state['ban_eos_token']: self.generator.disallow_tokens([self.tokenizer.eos_token_id]) else: diff --git a/modules/presets.py b/modules/presets.py index bb8dc41..d8ae6e1 100644 --- a/modules/presets.py +++ b/modules/presets.py @@ -15,6 +15,7 @@ def load_preset(name): 'tfs': 1, 'top_a': 0, 'repetition_penalty': 1, + 'repetition_penalty_range': 0, 'encoder_repetition_penalty': 1, 'top_k': 0, 'num_beams': 1, @@ -46,9 +47,9 @@ def load_preset_memoized(name): def load_preset_for_ui(name, state): generate_params = load_preset(name) state.update(generate_params) - return state, *[generate_params[k] for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'tfs', 'top_a']] + return state, *[generate_params[k] for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'repetition_penalty', 'repetition_penalty_range', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'tfs', 'top_a']] def generate_preset_yaml(state): - data = {k: state[k] for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'tfs', 'top_a']} + data = {k: state[k] for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'repetition_penalty', 'repetition_penalty_range', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'tfs', 'top_a']} return yaml.dump(data, sort_keys=False) diff --git a/modules/sampler_hijack.py b/modules/sampler_hijack.py index bcec250..391ece9 100644 --- a/modules/sampler_hijack.py +++ b/modules/sampler_hijack.py @@ -5,6 +5,7 @@ import transformers from transformers import LogitsWarper from transformers.generation.logits_process import ( LogitNormalization, + LogitsProcessor, LogitsProcessorList, TemperatureLogitsWarper ) @@ -121,6 +122,29 @@ class MirostatLogitsWarper(LogitsWarper): return scores +class RepetitionPenaltyLogitsProcessorWithRange(LogitsProcessor): + ''' + Copied from the transformers library + ''' + def __init__(self, penalty: float, _range: int): + if not isinstance(penalty, float) or not (penalty > 0): + raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}") + + self.penalty = penalty + self._range = _range + + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + + input_ids = input_ids[:, -self._range:] + score = torch.gather(scores, 1, input_ids) + + # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability + score = torch.where(score < 0, score * self.penalty, score / self.penalty) + + scores.scatter_(1, input_ids, score) + return scores + + def get_logits_warper_patch(self, generation_config): warpers = self._get_logits_warper_old(generation_config) warpers_to_add = LogitsProcessorList() @@ -146,6 +170,19 @@ def get_logits_warper_patch(self, generation_config): return warpers +def get_logits_processor_patch(self, **kwargs): + result = self._get_logits_processor_old(**kwargs) + repetition_penalty_range = kwargs['generation_config'].repetition_penalty_range + repetition_penalty = kwargs['generation_config'].repetition_penalty + + if repetition_penalty_range > 0: + for i in range(len(result)): + if result[i].__class__.__name__ == 'RepetitionPenaltyLogitsProcessor': + result[i] = RepetitionPenaltyLogitsProcessorWithRange(repetition_penalty, repetition_penalty_range) + + return result + + def generation_config_init_patch(self, **kwargs): self.__init___old(**kwargs) self.tfs = kwargs.pop("tfs", 1.0) @@ -153,11 +190,15 @@ def generation_config_init_patch(self, **kwargs): self.mirostat_mode = kwargs.pop("mirostat_mode", 0) self.mirostat_eta = kwargs.pop("mirostat_eta", 0.1) self.mirostat_tau = kwargs.pop("mirostat_tau", 5) + self.repetition_penalty_range = kwargs.pop("repetition_penalty_range", 0) def hijack_samplers(): transformers.GenerationMixin._get_logits_warper_old = transformers.GenerationMixin._get_logits_warper transformers.GenerationMixin._get_logits_warper = get_logits_warper_patch + transformers.GenerationMixin._get_logits_processor_old = transformers.GenerationMixin._get_logits_processor + transformers.GenerationMixin._get_logits_processor = get_logits_processor_patch + transformers.GenerationConfig.__init___old = transformers.GenerationConfig.__init__ transformers.GenerationConfig.__init__ = generation_config_init_patch diff --git a/modules/text_generation.py b/modules/text_generation.py index 49639cf..171da53 100644 --- a/modules/text_generation.py +++ b/modules/text_generation.py @@ -230,7 +230,7 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False): def generate_reply_HF(question, original_question, seed, state, stopping_strings=None, is_chat=False): generate_params = {} - for k in ['max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'tfs', 'top_a', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta']: + for k in ['max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'repetition_penalty_range', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'tfs', 'top_a', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta']: generate_params[k] = state[k] for k in ['epsilon_cutoff', 'eta_cutoff']: diff --git a/modules/ui.py b/modules/ui.py index 9f8cd5a..101fb02 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -38,7 +38,7 @@ def list_model_elements(): def list_interface_input_elements(chat=False): - elements = ['max_new_tokens', 'seed', 'temperature', 'top_p', 'top_k', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'repetition_penalty', 'encoder_repetition_penalty', 'no_repeat_ngram_size', 'min_length', 'do_sample', 'penalty_alpha', 'num_beams', 'length_penalty', 'early_stopping', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'add_bos_token', 'ban_eos_token', 'truncation_length', 'custom_stopping_strings', 'skip_special_tokens', 'preset_menu', 'stream', 'tfs', 'top_a'] + elements = ['max_new_tokens', 'seed', 'temperature', 'top_p', 'top_k', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'repetition_penalty', 'repetition_penalty_range', 'encoder_repetition_penalty', 'no_repeat_ngram_size', 'min_length', 'do_sample', 'penalty_alpha', 'num_beams', 'length_penalty', 'early_stopping', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'add_bos_token', 'ban_eos_token', 'truncation_length', 'custom_stopping_strings', 'skip_special_tokens', 'preset_menu', 'stream', 'tfs', 'top_a'] if chat: elements += ['name1', 'name2', 'greeting', 'context', 'chat_generation_attempts', 'stop_at_newline', 'mode', 'instruction_template', 'character_menu', 'name1_instruct', 'name2_instruct', 'context_instruct', 'turn_template', 'chat_style', 'chat-instruct_command'] diff --git a/server.py b/server.py index ddd7e8c..9fdbfd9 100644 --- a/server.py +++ b/server.py @@ -336,6 +336,7 @@ def create_settings_menus(default_preset): with gr.Column(): shared.gradio['repetition_penalty'] = gr.Slider(1.0, 1.5, value=generate_params['repetition_penalty'], step=0.01, label='repetition_penalty', info='Exponential penalty factor for repeating prior tokens. 1 means no penalty, higher value = less repetition, lower value = more repetition.') + shared.gradio['repetition_penalty_range'] = gr.Slider(0, 4096, step=64, value=generate_params['repetition_penalty_range'], label='repetition_penalty_range', info='The number of most recent tokens to consider for repetition penalty. 0 makes all tokens be used.') shared.gradio['encoder_repetition_penalty'] = gr.Slider(0.8, 1.5, value=generate_params['encoder_repetition_penalty'], step=0.01, label='encoder_repetition_penalty', info='Also known as the "Hallucinations filter". Used to penalize tokens that are *not* in the prior text. Higher value = more likely to stay in context, lower value = more likely to diverge.') shared.gradio['no_repeat_ngram_size'] = gr.Slider(0, 20, step=1, value=generate_params['no_repeat_ngram_size'], label='no_repeat_ngram_size', info='If not set to 0, specifies the length of token sets that are completely blocked from repeating at all. Higher values = blocks larger phrases, lower values = blocks words or letters from repeating. Only 0 or high values are a good idea in most cases.') shared.gradio['min_length'] = gr.Slider(0, 2000, step=1, value=generate_params['min_length'], label='min_length', info='Minimum generation length in tokens.') @@ -376,7 +377,7 @@ def create_settings_menus(default_preset): shared.gradio['skip_special_tokens'] = gr.Checkbox(value=shared.settings['skip_special_tokens'], label='Skip special tokens', info='Some specific models need this unset.') shared.gradio['stream'] = gr.Checkbox(value=not shared.args.no_stream, label='Activate text streaming') - shared.gradio['preset_menu'].change(presets.load_preset_for_ui, [shared.gradio[k] for k in ['preset_menu', 'interface_state']], [shared.gradio[k] for k in ['interface_state', 'do_sample', 'temperature', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'tfs', 'top_a']]) + shared.gradio['preset_menu'].change(presets.load_preset_for_ui, [shared.gradio[k] for k in ['preset_menu', 'interface_state']], [shared.gradio[k] for k in ['interface_state', 'do_sample', 'temperature', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'repetition_penalty', 'repetition_penalty_range', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'tfs', 'top_a']]) def create_file_saving_menus():