Remove mutable defaults from function signature. (#1663)

This commit is contained in:
IJumpAround 2023-05-08 21:55:41 -04:00 committed by GitHub
parent 32ad47c898
commit 020fe7b50b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 11 additions and 10 deletions

View file

@@ -142,7 +142,7 @@ def stop_everything_event():
shared.stop_everything = True
def generate_reply(question, state, eos_token=None, stopping_strings=[]):
def generate_reply(question, state, eos_token=None, stopping_strings=None):
state = apply_extensions('state', state)
generate_func = apply_extensions('custom_generate_reply')
if generate_func is None:
@@ -173,7 +173,7 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
yield formatted_outputs(reply, shared.model_name)
def generate_reply_HF(question, original_question, seed, state, eos_token=None, stopping_strings=[]):
def generate_reply_HF(question, original_question, seed, state, eos_token=None, stopping_strings=None):
generate_params = {}
for k in ['max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']:
generate_params[k] = state[k]
@@ -272,7 +272,7 @@ def generate_reply_HF(question, original_question, seed, state, eos_token=None,
return
def generate_reply_custom(question, original_question, seed, state, eos_token=None, stopping_strings=[]):
def generate_reply_custom(question, original_question, seed, state, eos_token=None, stopping_strings=None):
seed = set_manual_seed(state['seed'])
generate_params = {'token_count': state['max_new_tokens']}
for k in ['temperature', 'top_p', 'top_k', 'repetition_penalty']:
@@ -309,7 +309,7 @@ def generate_reply_custom(question, original_question, seed, state, eos_token=No
return
def generate_reply_flexgen(question, original_question, seed, state, eos_token=None, stopping_strings=[]):
def generate_reply_flexgen(question, original_question, seed, state, eos_token=None, stopping_strings=None):
generate_params = {}
for k in ['max_new_tokens', 'do_sample', 'temperature']:
generate_params[k] = state[k]