Add RWKV tokenizer

This commit is contained in:
oobabooga 2023-03-06 08:45:49 -03:00
parent c855b828fe
commit e91f4bc25a
3 changed files with 34 additions and 15 deletions

View file

@ -2,6 +2,7 @@ import os
from pathlib import Path
import numpy as np
from tokenizers import Tokenizer
import modules.shared as shared
@ -43,3 +44,22 @@ class RWKVModel:
)
return context+self.pipeline.generate(context, token_count=token_count, args=args, callback=callback)
class RWKVTokenizer:
def __init__(self):
pass
@classmethod
def from_pretrained(self, path):
tokenizer_path = path / "20B_tokenizer.json"
tokenizer = Tokenizer.from_file(os.path.abspath(tokenizer_path))
result = self()
result.tokenizer = tokenizer
return result
def encode(self, prompt):
return self.tokenizer.encode(prompt).ids
def decode(self, ids):
return self.tokenizer.decode(ids)