diff --git a/modules/RWKV.py b/modules/RWKV.py
index 98b1184..88f1ec2 100644
--- a/modules/RWKV.py
+++ b/modules/RWKV.py
@@ -42,4 +42,4 @@ class RWKVModel:
             token_stop = token_stop
         )
 
-        return self.pipeline.generate(context, token_count=token_count, args=args, callback=callback)
+        return context+self.pipeline.generate(context, token_count=token_count, args=args, callback=callback)
diff --git a/modules/text_generation.py b/modules/text_generation.py
index 4c9d1f0..cc8b62d 100644
--- a/modules/text_generation.py
+++ b/modules/text_generation.py
@@ -86,15 +86,14 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
 
     if shared.is_RWKV:
         if shared.args.no_stream:
-            reply = question + shared.model.generate(question, token_count=max_new_tokens, temperature=temperature)
+            reply = shared.model.generate(question, token_count=max_new_tokens, temperature=temperature)
             yield formatted_outputs(reply, None)
-            return formatted_outputs(reply, None)
         else:
             for i in range(max_new_tokens//8):
-                reply = question + shared.model.generate(question, token_count=8, temperature=temperature)
+                reply = shared.model.generate(question, token_count=8, temperature=temperature)
                 yield formatted_outputs(reply, None)
                 question = reply
-            return formatted_outputs(reply, None)
+        return formatted_outputs(reply, None)
 
     original_question = question
     if not (shared.args.chat or shared.args.cai_chat):