Refactor text_generation.py, add support for custom generation functions (#1817)
This commit is contained in:
parent
876fbb97c0
commit
8aafb1f796
12 changed files with 289 additions and 195 deletions
|
@ -86,6 +86,15 @@ def _apply_custom_generate_chat_prompt(text, state, **kwargs):
|
|||
return None
|
||||
|
||||
|
||||
# Extension that modifies the input parameters before they are used
|
||||
def _apply_state_modifier_extensions(state):
|
||||
for extension, _ in iterator():
|
||||
if hasattr(extension, "state_modifier"):
|
||||
state = getattr(extension, "state_modifier")(state)
|
||||
|
||||
return state
|
||||
|
||||
|
||||
# Extension functions that override the default tokenizer output
|
||||
def _apply_tokenizer_extensions(function_name, state, prompt, input_ids, input_embeds):
|
||||
for extension, _ in iterator():
|
||||
|
@ -95,13 +104,24 @@ def _apply_tokenizer_extensions(function_name, state, prompt, input_ids, input_e
|
|||
return prompt, input_ids, input_embeds
|
||||
|
||||
|
||||
# Custom generate reply handling
|
||||
def _apply_custom_generate_reply():
|
||||
for extension, _ in iterator():
|
||||
if hasattr(extension, 'custom_generate_reply'):
|
||||
return getattr(extension, 'custom_generate_reply')
|
||||
|
||||
return None
|
||||
|
||||
|
||||
EXTENSION_MAP = {
|
||||
"input": partial(_apply_string_extensions, "input_modifier"),
|
||||
"output": partial(_apply_string_extensions, "output_modifier"),
|
||||
"state": _apply_state_modifier_extensions,
|
||||
"bot_prefix": partial(_apply_string_extensions, "bot_prefix_modifier"),
|
||||
"tokenizer": partial(_apply_tokenizer_extensions, "tokenizer_modifier"),
|
||||
"input_hijack": _apply_input_hijack,
|
||||
"custom_generate_chat_prompt": _apply_custom_generate_chat_prompt
|
||||
"custom_generate_chat_prompt": _apply_custom_generate_chat_prompt,
|
||||
"custom_generate_reply": _apply_custom_generate_reply
|
||||
}
|
||||
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue