From cf6caf18301ecd2cbc872dbb0db27b77187e83dc Mon Sep 17 00:00:00 2001 From: Maks <10959136+maksmaisak@users.noreply.github.com> Date: Tue, 9 May 2023 16:12:53 +0200 Subject: [PATCH] Make the RWKV model cache the RNN state between messages (#1354) --- modules/RWKV.py | 72 ++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 71 insertions(+), 1 deletion(-) diff --git a/modules/RWKV.py b/modules/RWKV.py index 957bc00..35d650e 100644 --- a/modules/RWKV.py +++ b/modules/RWKV.py @@ -1,3 +1,4 @@ +import copy import os from pathlib import Path @@ -32,6 +33,10 @@ class RWKVModel: result = self() result.pipeline = pipeline + result.model = model + result.cached_context = "" + result.cached_model_state = None + result.cached_output_logits = None return result def generate(self, context="", token_count=20, temperature=1, top_p=1, top_k=50, repetition_penalty=None, alpha_frequency=0.1, alpha_presence=0.1, token_ban=None, token_stop=None, callback=None): @@ -45,7 +50,17 @@ class RWKVModel: token_stop=token_stop or [] ) - return self.pipeline.generate(context, token_count=token_count, args=args, callback=callback) + if self.cached_context != "": + if context.startswith(self.cached_context): + context = context[len(self.cached_context):] + else: + self.cached_context = "" + self.cached_model_state = None + self.cached_output_logits = None + + # out = self.pipeline.generate(context, token_count=token_count, args=args, callback=callback) + out = self.generate_from_cached_state(context, token_count=token_count, args=args, callback=callback) + return out def generate_with_streaming(self, **kwargs): with Iteratorize(self.generate, kwargs, callback=None) as generator: @@ -54,6 +69,61 @@ class RWKVModel: reply += token yield reply + # Similar to the PIPELINE.generate, but lets us maintain the cached_model_state + def generate_from_cached_state(self, ctx="", token_count=20, args=None, callback=None): + all_tokens = [] + out_str = '' + occurrence = {} + state = 
copy.deepcopy(self.cached_model_state) if self.cached_model_state is not None else None + + # if we ended up with an empty context, just reuse the cached logits + # this can happen if a user undoes a message and then sends the exact message again + # in that case the full context ends up being the same as the cached_context, so the remaining context is empty. + if ctx == "": + out = self.cached_output_logits + + for i in range(token_count): + + # forward + tokens = self.pipeline.encode(ctx) if i == 0 else [token] + while len(tokens) > 0: + out, state = self.model.forward(tokens[:args.chunk_len], state) + tokens = tokens[args.chunk_len:] + + # cache the model state after scanning the context + # we don't cache the state after processing our own generated tokens because + # the output string might be post-processed arbitrarily. Therefore, what's fed into the model + # on the next round of chat might be slightly different from what it output on the previous round + if i == 0: + self.cached_context += ctx + self.cached_model_state = copy.deepcopy(state) + self.cached_output_logits = copy.deepcopy(out) + + # adjust probabilities + for n in args.token_ban: + out[n] = -float('inf') + for n in occurrence: + out[n] -= (args.alpha_presence + occurrence[n] * args.alpha_frequency) + + # sampler + token = self.pipeline.sample_logits(out, temperature=args.temperature, top_p=args.top_p, top_k=args.top_k) + if token in args.token_stop: + break + all_tokens += [token] + if token not in occurrence: + occurrence[token] = 1 + else: + occurrence[token] += 1 + + # output + tmp = self.pipeline.decode([token]) + if '\ufffd' not in tmp: # is valid utf-8 string? + if callback: + callback(tmp) + out_str += tmp + + return out_str + class RWKVTokenizer: def __init__(self):