Add mirostat parameters for llama.cpp (#2287)

This commit is contained in:
oobabooga 2023-05-22 19:37:24 -03:00 committed by GitHub
parent ec7437f00a
commit c0fd7f3257
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
13 changed files with 80 additions and 15 deletions

View file

@@ -59,7 +59,7 @@ class LlamaCppModel:
string = string.encode()
return self.model.tokenize(string)
def generate(self, context="", token_count=20, temperature=1, top_p=1, top_k=50, repetition_penalty=1, callback=None):
def generate(self, context="", token_count=20, temperature=1, top_p=1, top_k=50, repetition_penalty=1, mirostat_mode=0, mirostat_tau=5, mirostat_eta=0.1, callback=None):
context = context if type(context) is str else context.decode()
completion_chunks = self.model.create_completion(
prompt=context,
@@ -68,6 +68,9 @@ class LlamaCppModel:
top_p=top_p,
top_k=top_k,
repeat_penalty=repetition_penalty,
mirostat_mode=int(mirostat_mode),
mirostat_tau=mirostat_tau,
mirostat_eta=mirostat_eta,
stream=True
)
output = ""