Add Mirostat v2 sampling to transformer models (#2571)

This commit is contained in:
brandonj60 2023-06-09 19:26:31 -05:00 committed by GitHub
parent aff3e04df4
commit b04e18d10c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 66 additions and 7 deletions

View file

@ -193,7 +193,7 @@ def _generate_reply(question, state, eos_token=None, stopping_strings=None, is_c
def generate_reply_HF(question, original_question, seed, state, eos_token=None, stopping_strings=None, is_chat=False):
generate_params = {}
for k in ['max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'tfs', 'top_a']:
for k in ['max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'tfs', 'top_a', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta']:
generate_params[k] = state[k]
for k in ['epsilon_cutoff', 'eta_cutoff']: