diff --git a/server.py b/server.py index 2773b92..c158636 100644 --- a/server.py +++ b/server.py @@ -71,9 +71,9 @@ settings = { 'prompt': 'Common sense questions and answers\n\nQuestion: \nFactual answer:', 'prompt_gpt4chan': '-----\n--- 865467536\nInput text\n--- 865467537\n', 'stop_at_newline': True, - 'history_size': 0, - 'history_size_min': 0, - 'history_size_max': 64, + 'chat_prompt_size': 2048, + 'chat_prompt_size_min': 0, + 'chat_prompt_size_max': 2048, 'preset_pygmalion': 'Pygmalion', 'name1_pygmalion': 'You', 'name2_pygmalion': 'Kawaii', @@ -503,13 +503,13 @@ def clean_chat_message(text): text = text.strip() return text -def generate_chat_prompt(text, tokens, name1, name2, context, history_size, impersonate=False): +def generate_chat_prompt(text, tokens, name1, name2, context, chat_prompt_size, impersonate=False): text = clean_chat_message(text) rows = [f"{context.strip()}\n"] i = len(history['internal'])-1 count = 0 - max_length = get_max_prompt_length(tokens) + max_length = min(get_max_prompt_length(tokens), chat_prompt_size) while i >= 0 and len(encode(''.join(rows), tokens)[0]) < max_length: rows.insert(1, f"{name2}: {history['internal'][i][1].strip()}\n") count += 1 @@ -517,8 +517,6 @@ def generate_chat_prompt(text, tokens, name1, name2, context, history_size, impe rows.insert(1, f"{name1}: {history['internal'][i][0].strip()}\n") count += 1 i -= 1 - if history_size != 0 and count >= history_size: - break if not impersonate: rows.append(f"{name1}: {text}\n") @@ -566,14 +564,14 @@ def extract_message_from_reply(question, reply, current, other, check, extension return reply, next_character_found, substring_found -def chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, history_size, picture=None): +def chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, picture=None): if args.picture and picture is not None: text, visible_text = generate_chat_picture(picture, name1, name2) else: visible_text = text text = apply_extensions(text, "input") - question = generate_chat_prompt(text, tokens, name1, name2, context, history_size) + question = generate_chat_prompt(text, tokens, name1, name2, context, chat_prompt_size) history['internal'].append(['', '']) history['visible'].append(['', '']) eos_token = '\n' if check else None @@ -587,8 +585,8 @@ def chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, break yield history['visible'] -def impersonate_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, history_size, picture=None): - question = generate_chat_prompt(text, tokens, name1, name2, context, history_size, impersonate=True) +def impersonate_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, picture=None): + question = generate_chat_prompt(text, tokens, name1, name2, context, chat_prompt_size, impersonate=True) eos_token = '\n' if check else None for reply in generate_reply(question, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=eos_token, stopping_string=f"\n{name2}:"): reply, next_character_found, substring_found = extract_message_from_reply(question, reply, name1, name2, check, extensions=False) @@ -598,19 +596,19 @@ def impersonate_wrapper(text, tokens, do_sample, max_new_tokens, temperature, to break yield apply_extensions(reply, "output") -def cai_chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, history_size, picture=None): - for _history in chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, history_size, picture): +def cai_chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, picture=None): + for _history in chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, picture): yield generate_chat_html(_history, name1, name2, character) -def regenerate_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, history_size, picture=None): +def regenerate_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, picture=None): last = history['visible'].pop() history['internal'].pop() text = last[0] if args.cai_chat: - for i in cai_chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, history_size, picture): + for i in cai_chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, picture): yield i else: - for i in chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, history_size, picture): + for i in chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, picture): yield i def remove_last_message(name1, name2): @@ -886,7 +884,7 @@ if args.chat or args.cai_chat: with gr.Column(): max_new_tokens = gr.Slider(minimum=settings['max_new_tokens_min'], maximum=settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=settings['max_new_tokens']) with gr.Column(): - history_size_slider = gr.Slider(minimum=settings['history_size_min'], maximum=settings['history_size_max'], step=1, label='Chat history size in prompt (0 for no limit)', value=settings['history_size']) + chat_prompt_size_slider = gr.Slider(minimum=settings['chat_prompt_size_min'], maximum=settings['chat_prompt_size_max'], step=1, label='Maximum prompt size in tokens', value=settings['chat_prompt_size']) preset_menu, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping = create_settings_menus() @@ -926,7 +924,7 @@ if args.chat or args.cai_chat: if args.extensions is not None: create_extensions_block() - input_params = [textbox, max_new_tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, history_size_slider] + input_params = [textbox, max_new_tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size_slider] if args.picture: input_params.append(picture_select) if args.cai_chat: diff --git a/settings-template.json b/settings-template.json index 6c1099c..dae7696 100644 --- a/settings-template.json +++ b/settings-template.json @@ -9,9 +9,9 @@ "prompt": "Common sense questions and answers\n\nQuestion: \nFactual answer:", "prompt_gpt4chan": "-----\n--- 865467536\nInput text\n--- 865467537\n", "stop_at_newline": true, - "history_size": 0, - "history_size_min": 0, - "history_size_max": 64, + "chat_prompt_size": 2048, + "chat_prompt_size_min": 0, + "chat_prompt_size_max": 2048, "preset_pygmalion": "Pygmalion", "name1_pygmalion": "You", "name2_pygmalion": "Kawaii",