Add customizable ban tokens (#3899)

This commit is contained in:
saltacc 2023-09-15 14:27:27 -07:00 committed by GitHub
parent fb864dad7b
commit f01b9aa71f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
16 changed files with 56 additions and 5 deletions

View file

@ -266,6 +266,14 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings
if state['ban_eos_token']:
generate_params['suppress_tokens'] = [shared.tokenizer.eos_token_id]
if state['custom_token_bans']:
to_ban = [int(x) for x in state['custom_token_bans'].split(',')]
if len(to_ban) > 0:
if generate_params.get('suppress_tokens', None):
generate_params['suppress_tokens'] += to_ban
else:
generate_params['suppress_tokens'] = to_ban
generate_params.update({'use_cache': not shared.args.no_cache})
if shared.args.deepspeed:
generate_params.update({'synced_gpus': True})