Add --rwkv-cuda-on parameter, bump rwkv version

This commit is contained in:
oobabooga 2023-03-06 20:12:54 -03:00
parent eebec65075
commit 153dfeb4dd
3 changed files with 4 additions and 3 deletions

View file

@ -9,7 +9,7 @@ import modules.shared as shared
np.set_printoptions(precision=4, suppress=True, linewidth=200)
os.environ['RWKV_JIT_ON'] = '1'
os.environ["RWKV_CUDA_ON"] = '0' # '1' : use CUDA kernel for seq mode (much faster)
os.environ["RWKV_CUDA_ON"] = '1' if shared.args.rwkv_cuda_on else '0' # use CUDA kernel for seq mode (much faster)
from rwkv.model import RWKV
from rwkv.utils import PIPELINE, PIPELINE_ARGS