diff --git a/modules/text_generation.py b/modules/text_generation.py
index 7b5fcd6..e18a76d 100644
--- a/modules/text_generation.py
+++ b/modules/text_generation.py
@@ -22,7 +22,7 @@ def get_max_prompt_length(tokens):
     return max_length
 
 def encode(prompt, tokens_to_generate=0, add_special_tokens=True):
-    if shared.is_RWKV:
+    if shared.is_RWKV or shared.is_llamacpp:
         input_ids = shared.tokenizer.encode(str(prompt))
         input_ids = np.array(input_ids).reshape(1, len(input_ids))
         return input_ids
@@ -142,6 +142,24 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
             input_ids = encode(question)
             print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output)-len(input_ids[0]))/(t1-t0):.2f} tokens/s, {len(output)-len(input_ids[0])} tokens)")
             return
+    elif shared.is_llamacpp:
+        try:
+            if shared.args.no_stream:
+                reply = shared.model.generate(context=question, num_tokens=max_new_tokens)
+                yield formatted_outputs(reply, shared.model_name)
+            else:
+                if not (shared.args.chat or shared.args.cai_chat):
+                    yield formatted_outputs(question, shared.model_name)
+                for reply in shared.model.generate_with_streaming(context=question, num_tokens=max_new_tokens):
+                    yield formatted_outputs(reply, shared.model_name)
+        except Exception as e:
+            print(e)
+        finally:
+            t1 = time.time()
+            output = encode(reply)[0]
+            input_ids = encode(question)
+            print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output)-len(input_ids[0]))/(t1-t0):.2f} tokens/s, {len(output)-len(input_ids[0])} tokens)")
+            return
 
     input_ids = encode(question, max_new_tokens)
     original_input_ids = input_ids
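
The new elif branch assumes that, when a llama.cpp model is loaded, shared.is_llamacpp is set and shared.model exposes generate(context, num_tokens) and generate_with_streaming(context, num_tokens), with the streaming method yielding the partial reply as it grows. Below is a minimal sketch of such a wrapper; the class name, constructor, and the use of the llama-cpp-python binding are assumptions made for illustration, not the code this diff ships with.

# Hypothetical wrapper sketch (not part of the diff above): it provides the
# generate() / generate_with_streaming() interface that the new llamacpp
# branch in generate_reply() calls on shared.model. Building it on the
# llama-cpp-python binding is an assumption for illustration; the actual PR
# may wire up a different llama.cpp binding.
from llama_cpp import Llama

class LlamaCppModel:
    def __init__(self, model_path):
        # Load the llama.cpp model file from disk.
        self.model = Llama(model_path=model_path)

    def generate(self, context="", num_tokens=10):
        # Non-streaming path (--no-stream): return the whole completion at once.
        result = self.model(context, max_tokens=num_tokens, stream=False)
        return result["choices"][0]["text"]

    def generate_with_streaming(self, context="", num_tokens=10):
        # Streaming path: yield the accumulated reply after every new chunk,
        # which generate_reply() then re-yields through formatted_outputs().
        reply = ""
        for chunk in self.model(context, max_tokens=num_tokens, stream=True):
            reply += chunk["choices"][0]["text"]
            yield reply

Loading such a model would also need to set shared.tokenizer to an object with an encode() method returning token ids, since the modified encode() above calls shared.tokenizer.encode(str(prompt)) for llama.cpp models as well.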