Style improvements (#1957)

This commit is contained in:
oobabooga 2023-05-09 22:49:39 -03:00 committed by GitHub
parent 334486f527
commit 3913155c1f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
23 changed files with 64 additions and 50 deletions

View file

@ -2,11 +2,10 @@ import json
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from threading import Thread
from extensions.api.util import build_parameters, try_start_cloudflared
from modules import shared
from modules.text_generation import encode, generate_reply
from extensions.api.util import build_parameters, try_start_cloudflared
class Handler(BaseHTTPRequestHandler):
def do_GET(self):

View file

@ -5,6 +5,7 @@ from modules import shared
BLOCKING_PORT = 5000
STREAMING_PORT = 5005
def setup():
blocking_api.start_server(BLOCKING_PORT, share=shared.args.public_api)
streaming_api.start_server(STREAMING_PORT, share=shared.args.public_api)

View file

@ -1,12 +1,12 @@
import json
import asyncio
from websockets.server import serve
import json
from threading import Thread
from modules import shared
from modules.text_generation import generate_reply
from websockets.server import serve
from extensions.api.util import build_parameters, try_start_cloudflared
from modules import shared
from modules.text_generation import generate_reply
PATH = '/api/v1/stream'

View file

@ -1,6 +1,7 @@
import gradio as gr
import os
import gradio as gr
# get the current directory of the script
current_dir = os.path.dirname(os.path.abspath(__file__))

View file

@ -1,6 +1,8 @@
import gradio as gr
import logging
import gradio as gr
def ui():
gr.Markdown("### This extension is deprecated, use \"multimodal\" extension instead")
logging.error("LLaVA extension is deprecated, use \"multimodal\" extension instead")

View file

@ -6,10 +6,11 @@ from io import BytesIO
from typing import Any, List, Optional
import torch
from PIL import Image
from extensions.multimodal.pipeline_loader import load_pipeline
from modules import shared
from modules.text_generation import encode, get_max_prompt_length
from PIL import Image
@dataclass

View file

@ -7,6 +7,7 @@ from io import BytesIO
import gradio as gr
import torch
from extensions.multimodal.multimodal_embedder import MultimodalEmbedder
from modules import shared

View file

@ -1,11 +1,12 @@
import base64
import json
import numpy as np
import os
import time
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from threading import Thread
import numpy as np
from modules import shared
from modules.text_generation import encode, generate_reply
@ -61,6 +62,7 @@ def float_list_to_base64(float_list):
ascii_string = encoded_bytes.decode('ascii')
return ascii_string
class Handler(BaseHTTPRequestHandler):
def do_GET(self):
if self.path.startswith('/v1/models'):
@ -387,8 +389,8 @@ class Handler(BaseHTTPRequestHandler):
"created": created_time,
"model": model, # TODO: add Lora info?
resp_list: [{
"index": 0,
"finish_reason": "stop",
"index": 0,
"finish_reason": "stop",
}],
"usage": {
"prompt_tokens": token_count,

View file

@ -6,12 +6,13 @@ from datetime import date
from pathlib import Path
import gradio as gr
import modules.shared as shared
import requests
import torch
from modules.models import reload_model, unload_model
from PIL import Image
import modules.shared as shared
from modules.models import reload_model, unload_model
torch._C._jit_set_profiling_mode(False)
# parameters which can be customized in settings.json of webui
@ -77,6 +78,7 @@ SD_models = ['NeverEndingDream'] # TODO: get with http://{address}}/sdapi/v1/sd
picture_response = False # specifies if the next model response should appear as a picture
def remove_surrounded_chars(string):
# this expression matches to 'as few symbols as possible (0 upwards) between any asterisks' OR
# 'as few symbols as possible (0 upwards) between an asterisk and the end of the string'
@ -122,7 +124,6 @@ def input_modifier(string):
# Get and save the Stable Diffusion-generated picture
def get_SD_pictures(description):
global params
if params['manage_VRAM']:
@ -259,6 +260,7 @@ def SD_api_address_update(address):
return gr.Textbox.update(label=msg)
def ui():
# Gradio elements
@ -290,12 +292,11 @@ def ui():
cfg_scale = gr.Number(label="CFG Scale", value=params['cfg_scale'], elem_id="cfg_box")
with gr.Column() as hr_options:
restore_faces = gr.Checkbox(value=params['restore_faces'], label='Restore faces')
enable_hr = gr.Checkbox(value=params['enable_hr'], label='Hires. fix')
enable_hr = gr.Checkbox(value=params['enable_hr'], label='Hires. fix')
with gr.Row(visible=params['enable_hr'], elem_classes="hires_opts") as hr_options:
hr_scale = gr.Slider(1, 4, value=params['hr_scale'], step=0.1, label='Upscale by')
denoising_strength = gr.Slider(0, 1, value=params['denoising_strength'], step=0.01, label='Denoising strength')
hr_upscaler = gr.Textbox(placeholder=params['hr_upscaler'], value=params['hr_upscaler'], label='Upscaler')
hr_scale = gr.Slider(1, 4, value=params['hr_scale'], step=0.1, label='Upscale by')
denoising_strength = gr.Slider(0, 1, value=params['denoising_strength'], step=0.01, label='Denoising strength')
hr_upscaler = gr.Textbox(placeholder=params['hr_upscaler'], value=params['hr_upscaler'], label='Upscaler')
# Event functions to update the parameters in the backend
address.change(lambda x: params.update({"address": filter_address(x)}), address, None)

View file

@ -4,6 +4,7 @@ from pathlib import Path
import gradio as gr
import torch
from extensions.silero_tts import tts_preprocessor
from modules import chat, shared
from modules.html_generator import chat_html_wrapper
@ -216,4 +217,4 @@ def ui():
# Play preview
preview_text.submit(voice_preview, preview_text, preview_audio)
preview_play.click(voice_preview, preview_text, preview_audio)
preview_play.click(voice_preview, preview_text, preview_audio)

View file

@ -2,7 +2,6 @@ import time
from pathlib import Path
import torch
import tts_preprocessor
torch._C._jit_set_profiling_mode(False)

View file

@ -69,7 +69,7 @@ def remove_surrounded_chars(string):
# first this expression will check if there is a string nested exclusively between a alt=
# and a style= string. This would correspond to only a the alt text of an embedded image
# If it matches it will only keep that part as the string, and rend it for further processing
# Afterwards this expression matches to 'as few symbols as possible (0 upwards) between any
# Afterwards this expression matches to 'as few symbols as possible (0 upwards) between any
# asterisks' OR' as few symbols as possible (0 upwards) between an asterisk and the end of the string'
if re.search(r'(?<=alt=)(.*)(?=style=)', string, re.DOTALL):
m = re.search(r'(?<=alt=)(.*)(?=style=)', string, re.DOTALL)

View file

@ -59,7 +59,7 @@ class ChromaCollector(Collecter):
def get_ids(self, search_strings: list[str], n_results: int) -> list[str]:
n_results = min(len(self.ids), n_results)
result = self.collection.query(query_texts=search_strings, n_results=n_results, include=['documents'])['ids'][0]
return list(map(lambda x : int(x[2:]), result))
return list(map(lambda x: int(x[2:]), result))
def clear(self):
self.collection.delete(ids=self.ids)
@ -162,13 +162,13 @@ def input_modifier(string):
def custom_generate_chat_prompt(user_input, state, **kwargs):
if len(shared.history['internal']) > 2 and user_input != '':
chunks = []
for i in range(len(shared.history['internal'])-1):
for i in range(len(shared.history['internal']) - 1):
chunks.append('\n'.join(shared.history['internal'][i]))
add_chunks_to_collector(chunks)
query = '\n'.join(shared.history['internal'][-1] + [user_input])
try:
best_ids = collector.get_ids(query, n_results=len(shared.history['internal'])-1)
best_ids = collector.get_ids(query, n_results=len(shared.history['internal']) - 1)
# Sort the history by relevance instead of by chronological order,
# except for the latest message
@ -226,7 +226,7 @@ def ui():
## Chat mode
In chat mode, the extension automatically sorts the history by relevance instead of chronologically, except for the very latest input/reply pair.
In chat mode, the extension automatically sorts the history by relevance instead of chronologically, except for the very latest input/reply pair.
That is, the prompt will include (starting from the end):

View file

@ -1,5 +1,6 @@
import gradio as gr
import speech_recognition as sr
from modules import shared
input_hijack = {