diff --git a/modules/chat.py b/modules/chat.py index f40f829..69d81e9 100644 --- a/modules/chat.py +++ b/modules/chat.py @@ -115,14 +115,18 @@ def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical visible_text = visible_text.replace('\n', '
') text = apply_extensions(text, "input") - if custom_generate_chat_prompt is None: - prompt = generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size) - else: - prompt = custom_generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size) - # Generate reply = '' for i in range(chat_generation_attempts): + + # The prompt needs to be generated here because, as the reply + # grows, it may become necessary to remove more old messages to + # fit into the 2048 tokens window. + if custom_generate_chat_prompt is None: + prompt = generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size-len(encode(' '+reply)[0])) + else: + prompt = custom_generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size-len(encode(' '+reply)[0])) + for reply in generate_reply(f"{prompt}{' ' if len(reply) > 0 else ''}{reply}", max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=eos_token, stopping_string=f"\n{name1}:"): # Extracting the reply @@ -156,10 +160,10 @@ def impersonate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typ if 'pygmalion' in shared.model_name.lower(): name1 = "You" - prompt = generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size, impersonate=True) reply = '' for i in range(chat_generation_attempts): + prompt = generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size-len(encode(' '+reply)[0]), impersonate=True) for reply in generate_reply(prompt+reply, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=eos_token, stopping_string=f"\n{name2}:"): reply, next_character_found, substring_found = extract_message_from_reply(prompt, reply, name1, name2, check, impersonate=True) if not substring_found: diff --git a/modules/text_generation.py b/modules/text_generation.py index 7f5aad5..2460df4 100644 --- a/modules/text_generation.py +++ b/modules/text_generation.py @@ -159,35 +159,53 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi else: generate_params.insert(0, "inputs=input_ids") - # Generate the entire reply at once. - if shared.args.no_stream: - with torch.no_grad(): - output = eval(f"shared.model.generate({', '.join(generate_params)}){cuda}")[0] - if shared.soft_prompt: - output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:])) - - reply = decode(output) - if not (shared.args.chat or shared.args.cai_chat): - reply = original_question + apply_extensions(reply[len(question):], "output") - - yield formatted_outputs(reply, shared.model_name) - - # Stream the reply 1 token at a time. - # This is based on the trick of using 'stopping_criteria' to create an iterator. - elif not shared.args.flexgen: - - def generate_with_callback(callback=None, **kwargs): - kwargs['stopping_criteria'].append(Stream(callback_func=callback)) - clear_torch_cache() + try: + # Generate the entire reply at once. + if shared.args.no_stream: with torch.no_grad(): - shared.model.generate(**kwargs) + output = eval(f"shared.model.generate({', '.join(generate_params)}){cuda}")[0] + if shared.soft_prompt: + output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:])) - def generate_with_streaming(**kwargs): - return Iteratorize(generate_with_callback, kwargs, callback=None) + reply = decode(output) + if not (shared.args.chat or shared.args.cai_chat): + reply = original_question + apply_extensions(reply[len(question):], "output") - yield formatted_outputs(original_question, shared.model_name) - with eval(f"generate_with_streaming({', '.join(generate_params)})") as generator: - for output in generator: + yield formatted_outputs(reply, shared.model_name) + + # Stream the reply 1 token at a time. + # This is based on the trick of using 'stopping_criteria' to create an iterator. + elif not shared.args.flexgen: + + def generate_with_callback(callback=None, **kwargs): + kwargs['stopping_criteria'].append(Stream(callback_func=callback)) + clear_torch_cache() + with torch.no_grad(): + shared.model.generate(**kwargs) + + def generate_with_streaming(**kwargs): + return Iteratorize(generate_with_callback, kwargs, callback=None) + + yield formatted_outputs(original_question, shared.model_name) + with eval(f"generate_with_streaming({', '.join(generate_params)})") as generator: + for output in generator: + if shared.soft_prompt: + output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:])) + reply = decode(output) + + if not (shared.args.chat or shared.args.cai_chat): + reply = original_question + apply_extensions(reply[len(question):], "output") + yield formatted_outputs(reply, shared.model_name) + + if output[-1] == n: + break + + # Stream the output naively for FlexGen since it doesn't support 'stopping_criteria' + else: + for i in range(max_new_tokens//8+1): + clear_torch_cache() + with torch.no_grad(): + output = eval(f"shared.model.generate({', '.join(generate_params)})")[0] if shared.soft_prompt: output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:])) reply = decode(output) @@ -196,30 +214,14 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi reply = original_question + apply_extensions(reply[len(question):], "output") yield formatted_outputs(reply, shared.model_name) - if output[-1] == n: + if np.count_nonzero(input_ids[0] == n) < np.count_nonzero(output == n): break - # Stream the output naively for FlexGen since it doesn't support 'stopping_criteria' - else: - for i in range(max_new_tokens//8+1): - clear_torch_cache() - with torch.no_grad(): - output = eval(f"shared.model.generate({', '.join(generate_params)})")[0] - if shared.soft_prompt: - output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:])) - reply = decode(output) + input_ids = np.reshape(output, (1, output.shape[0])) + if shared.soft_prompt: + inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids) - if not (shared.args.chat or shared.args.cai_chat): - reply = original_question + apply_extensions(reply[len(question):], "output") - yield formatted_outputs(reply, shared.model_name) - - if np.count_nonzero(input_ids[0] == n) < np.count_nonzero(output == n): - break - - input_ids = np.reshape(output, (1, output.shape[0])) - if shared.soft_prompt: - inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids) - - t1 = time.time() - print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output)-len(original_input_ids[0]))/(t1-t0):.2f} tokens/s, {len(output)-len(original_input_ids[0])} tokens)") - return + finally: + t1 = time.time() + print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output)-len(original_input_ids[0]))/(t1-t0):.2f} tokens/s, {len(output)-len(original_input_ids[0])} tokens)") + return