Add extension example, replace input_hijack with chat_input_modifier (#3307)
This commit is contained in:
parent
08c622df2e
commit
ef8637e32d
10 changed files with 335 additions and 100 deletions
|
|
@ -175,7 +175,7 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_mess
|
|||
|
||||
# Preparing the input
|
||||
if not any((regenerate, _continue)):
|
||||
text, visible_text = apply_extensions('input_hijack', text, visible_text)
|
||||
text, visible_text = apply_extensions('chat_input', text, visible_text, state)
|
||||
if visible_text is None:
|
||||
visible_text = text
|
||||
|
||||
|
|
|
|||
|
|
@ -1,13 +1,12 @@
|
|||
import traceback
|
||||
from functools import partial
|
||||
from inspect import signature
|
||||
|
||||
import gradio as gr
|
||||
|
||||
import extensions
|
||||
import modules.shared as shared
|
||||
from modules.logging_colors import logger
|
||||
from inspect import signature
|
||||
|
||||
|
||||
state = {}
|
||||
available_extensions = []
|
||||
|
|
@ -66,15 +65,11 @@ def _apply_string_extensions(function_name, text, state):
|
|||
return text
|
||||
|
||||
|
||||
# Input hijack of extensions
|
||||
def _apply_input_hijack(text, visible_text):
|
||||
# Extension functions that map string -> string
|
||||
def _apply_chat_input_extensions(text, visible_text, state):
|
||||
for extension, _ in iterator():
|
||||
if hasattr(extension, 'input_hijack') and extension.input_hijack['state']:
|
||||
extension.input_hijack['state'] = False
|
||||
if callable(extension.input_hijack['value']):
|
||||
text, visible_text = extension.input_hijack['value'](text, visible_text)
|
||||
else:
|
||||
text, visible_text = extension.input_hijack['value']
|
||||
if hasattr(extension, 'chat_input_modifier'):
|
||||
text, visible_text = extension.chat_input_modifier(text, visible_text, state)
|
||||
|
||||
return text, visible_text
|
||||
|
||||
|
|
@ -120,7 +115,11 @@ def _apply_tokenizer_extensions(function_name, state, prompt, input_ids, input_e
|
|||
def _apply_logits_processor_extensions(function_name, processor_list, input_ids):
|
||||
for extension, _ in iterator():
|
||||
if hasattr(extension, function_name):
|
||||
getattr(extension, function_name)(processor_list, input_ids)
|
||||
result = getattr(extension, function_name)(processor_list, input_ids)
|
||||
if type(result) is list:
|
||||
processor_list = result
|
||||
|
||||
return processor_list
|
||||
|
||||
|
||||
# Get prompt length in tokens after applying extension functions which override the default tokenizer output
|
||||
|
|
@ -187,12 +186,12 @@ def create_extensions_tabs():
|
|||
EXTENSION_MAP = {
|
||||
"input": partial(_apply_string_extensions, "input_modifier"),
|
||||
"output": partial(_apply_string_extensions, "output_modifier"),
|
||||
"chat_input": _apply_chat_input_extensions,
|
||||
"state": _apply_state_modifier_extensions,
|
||||
"history": _apply_history_modifier_extensions,
|
||||
"bot_prefix": partial(_apply_string_extensions, "bot_prefix_modifier"),
|
||||
"tokenizer": partial(_apply_tokenizer_extensions, "tokenizer_modifier"),
|
||||
'logits_processor': partial(_apply_logits_processor_extensions, 'logits_processor_modifier'),
|
||||
"input_hijack": _apply_input_hijack,
|
||||
"custom_generate_chat_prompt": _apply_custom_generate_chat_prompt,
|
||||
"custom_generate_reply": _apply_custom_generate_reply,
|
||||
"tokenized_length": _apply_custom_tokenized_length,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue