Add top_k to RWKV

This commit is contained in:
oobabooga 2023-03-07 17:24:28 -03:00
parent 827ae51f72
commit 8660227e1b
3 changed files with 5 additions and 4 deletions

View file

@ -33,10 +33,11 @@ class RWKVModel:
result.pipeline = pipeline
return result
def generate(self, context, token_count=20, temperature=1, top_p=1, alpha_frequency=0.1, alpha_presence=0.1, token_ban=[0], token_stop=[], callback=None):
def generate(self, context, token_count=20, temperature=1, top_p=1, top_k=50, alpha_frequency=0.1, alpha_presence=0.1, token_ban=[0], token_stop=[], callback=None):
args = PIPELINE_ARGS(
temperature = temperature,
top_p = top_p,
top_k = top_k,
alpha_frequency = alpha_frequency, # Frequency Penalty (as in GPT-3)
alpha_presence = alpha_presence, # Presence Penalty (as in GPT-3)
token_ban = token_ban, # ban the generation of some tokens