Add customizable ban tokens (#3899)

This commit is contained in:
saltacc 2023-09-15 14:27:27 -07:00 committed by GitHub
parent fb864dad7b
commit f01b9aa71f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
16 changed files with 56 additions and 5 deletions

View file

@ -30,7 +30,7 @@ class Exllamav2Model:
config.max_seq_len = shared.args.max_seq_len
config.scale_pos_emb = shared.args.compress_pos_emb
config.scale_alpha_value = shared.args.alpha_value
model = ExLlamaV2(config)
split = None
@ -60,6 +60,11 @@ class Exllamav2Model:
if state['ban_eos_token']:
settings.disallow_tokens(self.tokenizer, [self.tokenizer.eos_token_id])
if state['custom_token_bans']:
to_ban = [int(x) for x in state['custom_token_bans'].split(',')]
if len(to_ban) > 0:
settings.disallow_tokens(self.tokenizer, to_ban)
ids = self.tokenizer.encode(prompt)
ids = ids[:, -get_max_prompt_length(state):]
initial_len = ids.shape[-1]