Add epsilon_cutoff/eta_cutoff parameters (#2258)

This commit is contained in:
oobabooga 2023-05-21 15:11:57 -03:00 committed by GitHub
parent 767a767989
commit 8ac3636966
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 36 additions and 9 deletions

View file

@ -15,6 +15,8 @@ def build_parameters(body, chat=False):
'temperature': float(body.get('temperature', 0.5)),
'top_p': float(body.get('top_p', 1)),
'typical_p': float(body.get('typical_p', body.get('typical', 1))),
'epsilon_cutoff': float(body.get('epsilon_cutoff', 0)),
'eta_cutoff': float(body.get('eta_cutoff', 0)),
'repetition_penalty': float(body.get('repetition_penalty', body.get('rep_pen', 1.1))),
'encoder_repetition_penalty': float(body.get('encoder_repetition_penalty', 1.0)),
'top_k': int(body.get('top_k', 0)),

View file

@ -208,6 +208,8 @@ class Handler(BaseHTTPRequestHandler):
'add_bos_token': shared.settings.get('add_bos_token', True),
'do_sample': True,
'typical_p': 1.0,
'epsilon_cutoff': 0, # In units of 1e-4
'eta_cutoff': 0, # In units of 1e-4
'min_length': 0,
'no_repeat_ngram_size': 0,
'num_beams': 1,
@ -516,6 +518,8 @@ class Handler(BaseHTTPRequestHandler):
'add_bos_token': shared.settings.get('add_bos_token', True),
'do_sample': True,
'typical_p': 1.0,
'epsilon_cutoff': 0, # In units of 1e-4
'eta_cutoff': 0, # In units of 1e-4
'min_length': 0,
'no_repeat_ngram_size': 0,
'num_beams': 1,