From 7c2babfe39e96cde98b437f71e87c60583a20a50 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Sat, 25 Feb 2023 01:42:19 -0300 Subject: [PATCH] Rename greed to "generation attempts" --- modules/chat.py | 31 ++++++++++++++++--------------- modules/shared.py | 3 +++ server.py | 4 ++-- settings-template.json | 3 +++ 4 files changed, 24 insertions(+), 17 deletions(-) diff --git a/modules/chat.py b/modules/chat.py index 76bbf99..3dac40c 100644 --- a/modules/chat.py +++ b/modules/chat.py @@ -84,7 +84,7 @@ def extract_message_from_reply(question, reply, current, other, check, extension def stop_everything_event(): shared.stop_everything = True -def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, greed=1): +def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1): shared.stop_everything = False just_started = True eos_token = '\n' if check else None @@ -113,7 +113,7 @@ def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical # Generate reply = '' - for i in range(greed): + for i in range(chat_generation_attempts): for reply in generate_reply(prompt+reply, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=eos_token, stopping_string=f"\n{name1}:"): # Extracting the reply @@ -138,9 +138,8 @@ def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical if next_character_found: break yield shared.history['visible'] - print(i, reply) -def impersonate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size): +def impersonate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, generation_attempts=1): eos_token = '\n' if check else None if 'pygmalion' in shared.model_name.lower(): @@ -148,19 +147,21 @@ def impersonate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typ prompt = generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size, impersonate=True) - for reply in generate_reply(prompt, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=eos_token, stopping_string=f"\n{name2}:"): - reply, next_character_found, substring_found = extract_message_from_reply(prompt, reply, name1, name2, check, extensions=False) - if not substring_found: - yield reply - if next_character_found: - break - yield reply + reply = '' + for i in range(generation_attempts): + for reply in generate_reply(prompt+reply, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=eos_token, stopping_string=f"\n{name2}:"): + reply, next_character_found, substring_found = extract_message_from_reply(prompt, reply, name1, name2, check, extensions=False) + if not substring_found: + yield reply + if next_character_found: + break + yield reply -def cai_chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, greed=1): - for _history in chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, greed): +def cai_chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1): + for _history in chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts): yield generate_chat_html(_history, name1, name2, shared.character) -def regenerate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, greed=1): +def regenerate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1): if shared.character != 'None' and len(shared.history['visible']) == 1: if shared.args.cai_chat: yield generate_chat_html(shared.history['visible'], name1, name2, shared.character) @@ -170,7 +171,7 @@ def regenerate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typi last_visible = shared.history['visible'].pop() last_internal = shared.history['internal'].pop() - for _history in chatbot_wrapper(last_internal[0], max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, greed): + for _history in chatbot_wrapper(last_internal[0], max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts): if shared.args.cai_chat: shared.history['visible'][-1] = [last_visible[0], _history[-1][1]] yield generate_chat_html(shared.history['visible'], name1, name2, shared.character) diff --git a/modules/shared.py b/modules/shared.py index 7b87d28..d59cee9 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -31,6 +31,9 @@ settings = { 'chat_prompt_size': 2048, 'chat_prompt_size_min': 0, 'chat_prompt_size_max': 2048, + 'chat_generation_attempts': 1, + 'chat_generation_attempts_min': 1, + 'chat_generation_attempts_max': 5, 'preset_pygmalion': 'Pygmalion', 'name1_pygmalion': 'You', 'name2_pygmalion': 'Kawaii', diff --git a/server.py b/server.py index a060860..e715193 100644 --- a/server.py +++ b/server.py @@ -241,10 +241,10 @@ if shared.args.chat or shared.args.cai_chat: shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens']) with gr.Column(): shared.gradio['chat_prompt_size_slider'] = gr.Slider(minimum=shared.settings['chat_prompt_size_min'], maximum=shared.settings['chat_prompt_size_max'], step=1, label='Maximum prompt size in tokens', value=shared.settings['chat_prompt_size']) - shared.gradio['greed'] = gr.Slider(minimum=1, maximum=5, value=1, step=1) + shared.gradio['chat_generation_attempts'] = gr.Slider(minimum=shared.settings['chat_generation_attempts_min'], maximum=shared.settings['chat_generation_attempts_max'], value=shared.settings['chat_generation_attempts'], step=1, label='Generation attempts') create_settings_menus() - shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'name1', 'name2', 'context', 'check', 'chat_prompt_size_slider', 'greed']] + shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'name1', 'name2', 'context', 'check', 'chat_prompt_size_slider', 'chat_generation_attempts']] if shared.args.extensions is not None: with gr.Tab('Extensions'): extensions_module.create_extensions_block() diff --git a/settings-template.json b/settings-template.json index dae7696..1316564 100644 --- a/settings-template.json +++ b/settings-template.json @@ -12,6 +12,9 @@ "chat_prompt_size": 2048, "chat_prompt_size_min": 0, "chat_prompt_size_max": 2048, + "chat_generation_attempts": 1, + "chat_generation_attempts_min": 1, + "chat_generation_attempts_max": 5, "preset_pygmalion": "Pygmalion", "name1_pygmalion": "You", "name2_pygmalion": "Kawaii",