Add 'hallucinations' filter #326

This breaks the API since a new parameter has been added.
It should be a one-line fix. See api-example.py.
This commit is contained in:
oobabooga 2023-03-15 11:04:30 -03:00
parent 128d18e298
commit 9d6a625bd6
5 changed files with 25 additions and 18 deletions

View file

@@ -89,7 +89,7 @@ def clear_torch_cache():
if not shared.args.cpu:
torch.cuda.empty_cache()
def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=None, stopping_string=None):
def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=None, stopping_string=None):
clear_torch_cache()
t0 = time.time()
@@ -143,6 +143,7 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
"top_p": top_p,
"typical_p": typical_p,
"repetition_penalty": repetition_penalty,
"encoder_repetition_penalty": encoder_repetition_penalty,
"top_k": top_k,
"min_length": min_length if shared.args.no_stream else 0,
"no_repeat_ngram_size": no_repeat_ngram_size,