Add customizable ban tokens (#3899)
This commit is contained in:
parent
fb864dad7b
commit
f01b9aa71f
16 changed files with 56 additions and 5 deletions
|
@ -31,6 +31,13 @@ def ban_eos_logits_processor(eos_token, input_ids, logits):
|
|||
return logits
|
||||
|
||||
|
||||
def custom_token_ban_logits_processor(token_ids, input_ids, logits):
    """Ban the given token ids by setting their logits to negative infinity.

    Args:
        token_ids: Iterable of integer token ids to ban.
        input_ids: Unused. Accepted so that, once ``token_ids`` is bound with
            ``functools.partial``, the remaining ``(input_ids, logits)``
            signature matches the other logits processors in this file.
        logits: Mutable sequence of scores indexed by token id. It is
            modified in place.

    Returns:
        The same ``logits`` object, with each in-range banned entry set to
        ``-inf``.
    """
    vocab_size = len(logits)
    banned_score = -float('inf')
    for token_id in token_ids:
        # The ids originate from user-entered text. A negative id would wrap
        # around and ban an unrelated token, and an id past the end would
        # raise IndexError in the middle of generation, so both are ignored.
        if 0 <= token_id < vocab_size:
            logits[token_id] = banned_score

    return logits
|
||||
|
||||
|
||||
class LlamaCppModel:
|
||||
def __init__(self):
|
||||
self.initialized = False
|
||||
|
@ -104,6 +111,15 @@ class LlamaCppModel:
|
|||
prompt = prompt[-get_max_prompt_length(state):]
|
||||
prompt = self.decode(prompt).decode('utf-8')
|
||||
|
||||
logit_processors = LogitsProcessorList()
|
||||
if state['ban_eos_token']:
|
||||
logit_processors.append(partial(ban_eos_logits_processor, self.model.tokenizer.eos_token_id))
|
||||
|
||||
if state['custom_token_bans']:
|
||||
to_ban = [int(x) for x in state['custom_token_bans'].split(',')]
|
||||
if len(to_ban) > 0:
|
||||
logit_processors.append(partial(custom_token_ban_logits_processor, to_ban))
|
||||
|
||||
completion_chunks = self.model.create_completion(
|
||||
prompt=prompt,
|
||||
max_tokens=state['max_new_tokens'],
|
||||
|
@ -116,9 +132,7 @@ class LlamaCppModel:
|
|||
mirostat_tau=state['mirostat_tau'],
|
||||
mirostat_eta=state['mirostat_eta'],
|
||||
stream=True,
|
||||
logits_processor=LogitsProcessorList([
|
||||
partial(ban_eos_logits_processor, self.model.token_eos()),
|
||||
]) if state['ban_eos_token'] else None,
|
||||
logits_processor=logit_processors,
|
||||
)
|
||||
|
||||
output = ""
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue