Add mirostat parameters for llama.cpp (#2287)

commit c0fd7f3257 (parent ec7437f00a)
Author: oobabooga, 2023-05-22 19:37:24 -03:00 (committed by GitHub)
13 changed files with 80 additions and 15 deletions


@@ -59,7 +59,7 @@ class LlamaCppModel:
         string = string.encode()
         return self.model.tokenize(string)
 
-    def generate(self, context="", token_count=20, temperature=1, top_p=1, top_k=50, repetition_penalty=1, callback=None):
+    def generate(self, context="", token_count=20, temperature=1, top_p=1, top_k=50, repetition_penalty=1, mirostat_mode=0, mirostat_tau=5, mirostat_eta=0.1, callback=None):
        context = context if type(context) is str else context.decode()
        completion_chunks = self.model.create_completion(
            prompt=context,
@@ -68,6 +68,9 @@ class LlamaCppModel:
             top_p=top_p,
             top_k=top_k,
             repeat_penalty=repetition_penalty,
+            mirostat_mode=int(mirostat_mode),
+            mirostat_tau=mirostat_tau,
+            mirostat_eta=mirostat_eta,
             stream=True
         )
         output = ""


@@ -294,6 +294,10 @@ def generate_reply_custom(question, original_question, seed, state, eos_token=None
     for k in ['temperature', 'top_p', 'top_k', 'repetition_penalty']:
         generate_params[k] = state[k]
 
+    if shared.model_type == 'llamacpp':
+        for k in ['mirostat_mode', 'mirostat_tau', 'mirostat_eta']:
+            generate_params[k] = state[k]
+
     t0 = time.time()
     reply = ''
     try:
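
The guard matters because mirostat is llama.cpp-specific: the keys are copied into generate_params only for the llamacpp model type, so other backends never receive keyword arguments their generate functions would reject. A condensed sketch of the plumbing (simplified, not the full generate_reply_custom; the token_count baseline key is an assumption):

    # Condensed sketch of the parameter plumbing around the new block.
    generate_params = {'token_count': state['max_new_tokens']}  # assumed baseline key
    for k in ['temperature', 'top_p', 'top_k', 'repetition_penalty']:
        generate_params[k] = state[k]

    if shared.model_type == 'llamacpp':  # mirostat keys only for the llama.cpp backend
        for k in ['mirostat_mode', 'mirostat_tau', 'mirostat_eta']:
            generate_params[k] = state[k]

    # The dict is then splatted into the backend's generate call.
    reply = shared.model.generate(context=question, **generate_params)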


@@ -37,7 +37,7 @@ def list_model_elements():
 
 
 def list_interface_input_elements(chat=False):
-    elements = ['max_new_tokens', 'seed', 'temperature', 'top_p', 'top_k', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'repetition_penalty', 'encoder_repetition_penalty', 'no_repeat_ngram_size', 'min_length', 'do_sample', 'penalty_alpha', 'num_beams', 'length_penalty', 'early_stopping', 'add_bos_token', 'ban_eos_token', 'truncation_length', 'custom_stopping_strings', 'skip_special_tokens', 'preset_menu', 'stream']
+    elements = ['max_new_tokens', 'seed', 'temperature', 'top_p', 'top_k', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'repetition_penalty', 'encoder_repetition_penalty', 'no_repeat_ngram_size', 'min_length', 'do_sample', 'penalty_alpha', 'num_beams', 'length_penalty', 'early_stopping', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'add_bos_token', 'ban_eos_token', 'truncation_length', 'custom_stopping_strings', 'skip_special_tokens', 'preset_menu', 'stream']
     if chat:
         elements += ['name1', 'name2', 'greeting', 'context', 'chat_prompt_size', 'chat_generation_attempts', 'stop_at_newline', 'mode', 'instruction_template', 'character_menu', 'name1_instruct', 'name2_instruct', 'context_instruct', 'turn_template', 'chat_style', 'chat-instruct_command']
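
Each string in elements must match a Gradio component registered under the same key in shared.gradio; that mapping is how the slider values reach the state dict read in generate_reply_custom above. A plausible sketch of matching sliders (ranges and defaults are illustrative assumptions, not taken from this diff):

    # Illustrative slider definitions for the three new keys; ranges/defaults assumed.
    import gradio as gr

    shared.gradio['mirostat_mode'] = gr.Slider(0, 2, value=0, step=1, label='mirostat_mode')
    shared.gradio['mirostat_tau'] = gr.Slider(0, 10, value=5.0, step=0.1, label='mirostat_tau')
    shared.gradio['mirostat_eta'] = gr.Slider(0, 1, value=0.1, step=0.01, label='mirostat_eta')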