Add extension example, replace input_hijack with chat_input_modifier (#3307)

This commit is contained in:
oobabooga 2023-07-25 18:49:56 -03:00 committed by GitHub
parent 08c622df2e
commit ef8637e32d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 335 additions and 100 deletions

View file

@ -7,10 +7,15 @@ from modules import shared
from modules.chat import generate_chat_reply
from modules.LoRA import add_lora_to_model
from modules.models import load_model, unload_model
from modules.models_settings import (get_model_settings_from_yamls,
update_model_parameters)
from modules.text_generation import (encode, generate_reply,
stop_everything_event)
from modules.models_settings import (
get_model_settings_from_yamls,
update_model_parameters
)
from modules.text_generation import (
encode,
generate_reply,
stop_everything_event
)
from modules.utils import get_available_models

View file

@ -2,12 +2,15 @@ import asyncio
import json
from threading import Thread
from websockets.server import serve
from extensions.api.util import build_parameters, try_start_cloudflared, with_api_lock
from extensions.api.util import (
build_parameters,
try_start_cloudflared,
with_api_lock
)
from modules import shared
from modules.chat import generate_chat_reply
from modules.text_generation import generate_reply
from websockets.server import serve
PATH = '/api/v1/stream'

View file

@ -10,7 +10,6 @@ from modules import shared
from modules.chat import load_character_memoized
from modules.presets import load_preset_memoized
# We use a thread local to store the asyncio lock, so that each thread
# has its own lock. This isn't strictly necessary, but it makes it
# such that if we can support multiple worker threads in the future,

View file

@ -0,0 +1,129 @@
"""
An example of extension. It does nothing, but you can add transformations
before the return statements to customize the webui behavior.
Starting from history_modifier and ending in output_modifier, the
functions are declared in the same order that they are called at
generation time.
"""
import torch
from modules import chat
from modules.text_generation import (
decode,
encode,
generate_reply,
)
from transformers import LogitsProcessor
params = {
"display_name": "Example Extension",
"is_tab": False,
}
class MyLogits(LogitsProcessor):
"""
Manipulates the probabilities for the next token before it gets sampled.
It gets used in the custom_logits_processor function below.
"""
def __init__(self):
pass
def __call__(self, input_ids, scores):
# probs = torch.softmax(scores, dim=-1, dtype=torch.float)
# probs[0] /= probs[0].sum()
# scores = torch.log(probs / (1 - probs))
return scores
def history_modifier(history):
"""
Modifies the chat history.
Only used in chat mode.
"""
return history
def state_modifier(state):
"""
Modifies the state variable, which is a dictionary containing the input
values in the UI like sliders and checkboxes.
"""
return state
def chat_input_modifier(text, visible_text, state):
"""
Modifies the internal and visible input strings in chat mode.
"""
return text, visible_text
def input_modifier(string, state):
"""
In chat mode, modifies the user input. The modified version goes into
history['internal'], and the original version goes into history['visible'].
In default/notebook modes, modifies the whole prompt.
"""
return string
def bot_prefix_modifier(string, state):
"""
Modifies the prefix for the next bot reply in chat mode.
By default, the prefix will be something like "Bot Name:".
"""
return string
def tokenizer_modifier(state, prompt, input_ids, input_embeds):
"""
Modifies the input ids and embeds.
Used by the multimodal extension to put image embeddings in the prompt.
Only used by loaders that use the transformers library for sampling.
"""
return prompt, input_ids, input_embeds
def logits_processor_modifier(processor_list, input_ids):
"""
Adds logits processors to the list.
Only used by loaders that use the transformers library for sampling.
"""
processor_list.append(MyLogits())
return processor_list
def output_modifier(string, state):
"""
Modifies the LLM output before it gets presented.
In chat mode, the modified version goes into history['internal'], and the original version goes into history['visible'].
"""
return string
def custom_generate_chat_prompt(user_input, state, **kwargs):
"""
Replaces the function that generates the prompt from the chat history.
Only used in chat mode.
"""
result = chat.generate_chat_prompt(user_input, state, **kwargs)
return result
def custom_css():
"""
Returns a CSS string that gets appended to the CSS for the webui.
"""
return ''
def custom_js():
"""
Returns a javascript string that gets appended to the javascript for the webui.
"""
return ''
def setup():
"""
Gets executed only once, when the extension is imported.
"""
pass
def ui():
"""
Gets executed when the UI is drawn. Custom gradio elements and their corresponding
event handlers should be defined here.
"""
pass

View file

@ -35,6 +35,15 @@ input_hijack = {
multimodal_embedder: MultimodalEmbedder = None
def chat_input_modifier(text, visible_text, state):
global input_hijack
if input_hijack['state']:
input_hijack['state'] = False
return input_hijack['value'](text, visible_text)
else:
return text, visible_text
def add_chat_picture(picture, text, visible_text):
# resize the image, so that shortest edge is at least 224 (size for CLIP), and at most 300 (to keep history manageable)
max_hw, min_hw = max(picture.size), min(picture.size)

View file

@ -9,8 +9,6 @@ from modules import chat, shared
from modules.ui import gather_interface_values
from modules.utils import gradio
# If 'state' is True, will hijack the next chat generation with
# custom input text given by 'value' in the format [text, visible_text]
input_hijack = {
'state': False,
'value': ["", ""]
@ -20,6 +18,15 @@ processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base", torch_dtype=torch.float32).to("cpu")
def chat_input_modifier(text, visible_text, state):
global input_hijack
if input_hijack['state']:
input_hijack['state'] = False
return input_hijack['value']
else:
return text, visible_text
def caption_image(raw_image):
inputs = processor(raw_image.convert('RGB'), return_tensors="pt").to("cpu", torch.float32)
out = model.generate(**inputs, max_new_tokens=100)
@ -42,7 +49,10 @@ def ui():
# Prepare the input hijack, update the interface values, call the generation function, and clear the picture
picture_select.upload(
lambda picture, name1, name2: input_hijack.update({"state": True, "value": generate_chat_picture(picture, name1, name2)}), [picture_select, shared.gradio['name1'], shared.gradio['name2']], None).then(
lambda picture, name1, name2: input_hijack.update({
"state": True,
"value": generate_chat_picture(picture, name1, name2)
}), [picture_select, shared.gradio['name1'], shared.gradio['name2']], None).then(
gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
chat.generate_chat_reply_wrapper, shared.input_params, gradio('display', 'history'), show_progress=False).then(
lambda: None, None, picture_select, show_progress=False)

View file

@ -16,6 +16,15 @@ params = {
}
def chat_input_modifier(text, visible_text, state):
global input_hijack
if input_hijack['state']:
input_hijack['state'] = False
return input_hijack['value']
else:
return text, visible_text
def do_stt(audio, whipser_model, whipser_language):
transcription = ""
r = sr.Recognizer()
@ -56,6 +65,7 @@ def ui():
audio.change(
auto_transcribe, [audio, auto_submit, whipser_model, whipser_language], [shared.gradio['textbox'], audio]).then(
None, auto_submit, None, _js="(check) => {if (check) { document.getElementById('Generate').click() }}")
whipser_model.change(lambda x: params.update({"whipser_model": x}), whipser_model, None)
whipser_language.change(lambda x: params.update({"whipser_language": x}), whipser_language, None)
auto_submit.change(lambda x: params.update({"auto_submit": x}), auto_submit, None)