Add customizable ban tokens (#3899)

This commit is contained in:
saltacc 2023-09-15 14:27:27 -07:00 committed by GitHub
parent fb864dad7b
commit f01b9aa71f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
16 changed files with 56 additions and 5 deletions

View file

@ -31,6 +31,13 @@ def ban_eos_logits_processor(eos_token, input_ids, logits):
return logits
def custom_token_ban_logits_processor(token_ids, input_ids, logits):
    """Mask out a user-supplied set of tokens so they cannot be sampled.

    Args:
        token_ids: Iterable of integer token ids to ban.
        input_ids: Unused here; accepted so the function keeps the
            ``(input_ids, logits)`` call shape once ``token_ids`` has been
            bound (e.g. with ``functools.partial``).
        logits: Mutable, indexable sequence of next-token scores. It is
            modified in place.

    Returns:
        The same ``logits`` object, with every valid banned id set to
        negative infinity.
    """
    vocab_size = len(logits)
    for token_id in token_ids:
        # Skip ids outside the vocabulary: a negative id would wrap around
        # and silently ban an unrelated token, and an id >= vocab_size would
        # raise IndexError in the middle of generation.
        if 0 <= token_id < vocab_size:
            logits[token_id] = -float('inf')
    return logits
class LlamaCppModel:
def __init__(self):
self.initialized = False
@ -104,6 +111,15 @@ class LlamaCppModel:
prompt = prompt[-get_max_prompt_length(state):]
prompt = self.decode(prompt).decode('utf-8')
logit_processors = LogitsProcessorList()
if state['ban_eos_token']:
logit_processors.append(partial(ban_eos_logits_processor, self.model.tokenizer.eos_token_id))
if state['custom_token_bans']:
to_ban = [int(x) for x in state['custom_token_bans'].split(',')]
if len(to_ban) > 0:
logit_processors.append(partial(custom_token_ban_logits_processor, to_ban))
completion_chunks = self.model.create_completion(
prompt=prompt,
max_tokens=state['max_new_tokens'],
@ -116,9 +132,7 @@ class LlamaCppModel:
mirostat_tau=state['mirostat_tau'],
mirostat_eta=state['mirostat_eta'],
stream=True,
logits_processor=LogitsProcessorList([
partial(ban_eos_logits_processor, self.model.token_eos()),
]) if state['ban_eos_token'] else None,
logits_processor=logit_processors,
)
output = ""