Add support for logits processors in extensions (#3029)

This commit is contained in:
Morgan Schweers 2023-07-13 13:22:41 -07:00 committed by GitHub
parent eb823fce96
commit 6d1e911577
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 19 additions and 2 deletions

View file

@ -8,6 +8,7 @@ import traceback
import numpy as np
import torch
import transformers
from transformers import LogitsProcessorList
import modules.shared as shared
from modules.callbacks import (
@ -264,6 +265,13 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings
generate_params['stopping_criteria'] = transformers.StoppingCriteriaList()
generate_params['stopping_criteria'].append(_StopEverythingStoppingCriteria())
processor = state.get('logits_processor', LogitsProcessorList([]))
# In case folks just pass in a processor by itself.
if type(processor) != LogitsProcessorList:
processor = LogitsProcessorList([processor])
apply_extensions('logits_processor', processor, input_ids)
generate_params['logits_processor'] = processor
t0 = time.time()
try:
if not is_chat and not shared.is_seq2seq: