diff --git a/modules/chat.py b/modules/chat.py index db79e7d..2e491d9 100644 --- a/modules/chat.py +++ b/modules/chat.py @@ -119,6 +119,7 @@ def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical # Generate cumulative_reply = '' for i in range(chat_generation_attempts): + reply = None for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, eos_token=eos_token, stopping_strings=[f"\n{name1}:", f"\n{name2}:"]): reply = cumulative_reply + reply @@ -145,7 +146,8 @@ def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical if next_character_found: break - cumulative_reply = reply + if reply is not None: + cumulative_reply = reply yield shared.history['visible'] @@ -162,6 +164,7 @@ def impersonate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typ cumulative_reply = '' for i in range(chat_generation_attempts): + reply = None for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, eos_token=eos_token, stopping_strings=[f"\n{name1}:", f"\n{name2}:"]): reply = cumulative_reply + reply reply, next_character_found = extract_message_from_reply(reply, name1, name2, check) @@ -169,7 +172,8 @@ def impersonate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typ if next_character_found: break - cumulative_reply = reply + if reply is not None: + cumulative_reply = reply yield reply