Move "send picture" into an extension

I am not proud of how I did it for now.
This commit is contained in:
oobabooga 2023-02-25 00:23:51 -03:00
parent e51ece21c0
commit 7a527a5581
4 changed files with 34 additions and 49 deletions

View file

@ -4,19 +4,16 @@ import io
import json
import re
from datetime import datetime
from io import BytesIO
from pathlib import Path
from PIL import Image
import modules.shared as shared
import modules.extensions as extensions_module
from modules.extensions import apply_extensions
from modules.html_generator import generate_chat_html
from modules.text_generation import encode, generate_reply, get_max_prompt_length
if shared.args.picture and (shared.args.cai_chat or shared.args.chat):
import modules.bot_picture as bot_picture
# This gets the new line characters right.
def clean_chat_message(text):
text = text.replace('\n', '\n\n')
@ -84,18 +81,10 @@ def extract_message_from_reply(question, reply, current, other, check, extension
return reply, next_character_found, substring_found
def generate_chat_picture(picture, name1, name2):
text = f'*{name1} sends {name2} a picture that contains the following: "{bot_picture.caption_image(picture)}"*'
buffer = BytesIO()
picture.save(buffer, format="JPEG")
img_str = base64.b64encode(buffer.getvalue()).decode('utf-8')
visible_text = f'<img src="data:image/jpeg;base64,{img_str}">'
return text, visible_text
def stop_everything_event():
shared.stop_everything = True
def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, picture=None):
def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size):
shared.stop_everything = False
just_started = True
eos_token = '\n' if check else None
@ -103,13 +92,18 @@ def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical
if 'pygmalion' in shared.model_name.lower():
name1 = "You"
if shared.args.picture and picture is not None:
text, visible_text = generate_chat_picture(picture, name1, name2)
else:
# Hijacking the input using an extension
visible_text = None
for extension, _ in extensions_module.iterator():
if hasattr(extension, 'input_hijack') and extension.input_hijack['state'] == True:
text, visible_text = extension.input_hijack['value']
extension.input_hijack['state'] = False
if visible_text is None:
visible_text = text
if shared.args.chat:
visible_text = visible_text.replace('\n', '<br>')
text = apply_extensions(text, "input")
text = apply_extensions(text, "input")
prompt = generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size)
# Generate
@ -138,7 +132,7 @@ def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical
break
yield shared.history['visible']
def impersonate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, picture=None):
def impersonate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size):
eos_token = '\n' if check else None
if 'pygmalion' in shared.model_name.lower():
@ -154,11 +148,11 @@ def impersonate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typ
break
yield reply
def cai_chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, picture=None):
for _history in chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, picture):
def cai_chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size):
for _history in chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size):
yield generate_chat_html(_history, name1, name2, shared.character)
def regenerate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, picture=None):
def regenerate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size):
if shared.character != 'None' and len(shared.history['visible']) == 1:
if shared.args.cai_chat:
yield generate_chat_html(shared.history['visible'], name1, name2, shared.character)
@ -168,7 +162,7 @@ def regenerate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typi
last_visible = shared.history['visible'].pop()
last_internal = shared.history['internal'].pop()
for _history in chatbot_wrapper(last_internal[0], max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, picture):
for _history in chatbot_wrapper(last_internal[0], max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size):
if shared.args.cai_chat:
shared.history['visible'][-1] = [last_visible[0], _history[-1][1]]
yield generate_chat_html(shared.history['visible'], name1, name2, shared.character)

View file

@ -1,5 +1,3 @@
import gradio as gr
import extensions
import modules.shared as shared

View file

@ -14,6 +14,9 @@ stop_everything = False
# UI elements (buttons, sliders, HTML, etc)
gradio = {}
# Generation input parameters
input_params = []
settings = {
'max_new_tokens': 200,
'max_new_tokens_min': 1,
@ -40,7 +43,6 @@ parser.add_argument('--model', type=str, help='Name of the model to load by defa
parser.add_argument('--notebook', action='store_true', help='Launch the web UI in notebook mode, where the output is written to the same text box as the input.')
parser.add_argument('--chat', action='store_true', help='Launch the web UI in chat mode.')
parser.add_argument('--cai-chat', action='store_true', help='Launch the web UI in chat mode with a style similar to Character.AI\'s. If the file img_bot.png or img_bot.jpg exists in the same folder as server.py, this image will be used as the bot\'s profile picture. Similarly, img_me.png or img_me.jpg will be used as your profile picture.')
parser.add_argument('--picture', action='store_true', help='Adds an ability to send pictures in chat UI modes. Captions are generated by BLIP.')
parser.add_argument('--cpu', action='store_true', help='Use the CPU to generate text.')
parser.add_argument('--load-in-8bit', action='store_true', help='Load the model with 8-bit precision.')
parser.add_argument('--bf16', action='store_true', help='Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU.')