Add support for logits processors in extensions (#3029)
This commit is contained in:
parent
eb823fce96
commit
6d1e911577
2 changed files with 19 additions and 2 deletions
|
@ -8,6 +8,7 @@ import traceback
|
|||
import numpy as np
|
||||
import torch
|
||||
import transformers
|
||||
from transformers import LogitsProcessorList
|
||||
|
||||
import modules.shared as shared
|
||||
from modules.callbacks import (
|
||||
|
@ -264,6 +265,13 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings
|
|||
generate_params['stopping_criteria'] = transformers.StoppingCriteriaList()
|
||||
generate_params['stopping_criteria'].append(_StopEverythingStoppingCriteria())
|
||||
|
||||
processor = state.get('logits_processor', LogitsProcessorList([]))
|
||||
# In case folks just pass in a processor by itself.
|
||||
if type(processor) != LogitsProcessorList:
|
||||
processor = LogitsProcessorList([processor])
|
||||
apply_extensions('logits_processor', processor, input_ids)
|
||||
generate_params['logits_processor'] = processor
|
||||
|
||||
t0 = time.time()
|
||||
try:
|
||||
if not is_chat and not shared.is_seq2seq:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue