Style improvements (#1957)
parent 334486f527
commit 3913155c1f

23 changed files with 64 additions and 50 deletions
@@ -2,11 +2,10 @@ import json
 from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
 from threading import Thread
 
+from extensions.api.util import build_parameters, try_start_cloudflared
 from modules import shared
 from modules.text_generation import encode, generate_reply
 
-from extensions.api.util import build_parameters, try_start_cloudflared
-
 
 class Handler(BaseHTTPRequestHandler):
     def do_GET(self):
@@ -5,6 +5,7 @@ from modules import shared
 BLOCKING_PORT = 5000
 STREAMING_PORT = 5005
 
+
 def setup():
     blocking_api.start_server(BLOCKING_PORT, share=shared.args.public_api)
     streaming_api.start_server(STREAMING_PORT, share=shared.args.public_api)
@@ -1,12 +1,12 @@
-import json
 import asyncio
-from websockets.server import serve
+import json
 from threading import Thread
 
-from modules import shared
-from modules.text_generation import generate_reply
+from websockets.server import serve
 
 from extensions.api.util import build_parameters, try_start_cloudflared
+from modules import shared
+from modules.text_generation import generate_reply
 
 PATH = '/api/v1/stream'
 
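
Note: the reshuffles in these API hunks, and in most hunks below, follow the standard isort/PEP8 layout: standard-library imports first, third-party packages second, project-local modules last, with one blank line between groups; elsewhere the commit also normalizes two blank lines before top-level definitions. A minimal sketch of the target layout, using names from the hunk above:

import asyncio
import json
from threading import Thread                        # standard library, alphabetized

from websockets.server import serve                 # third-party packages

from extensions.api.util import build_parameters, try_start_cloudflared
from modules import shared
from modules.text_generation import generate_reply  # project-local modules


PATH = '/api/v1/stream'  # module-level code starts after two blank lines
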
@@ -1,6 +1,7 @@
-import gradio as gr
 import os
 
+import gradio as gr
+
 # get the current directory of the script
 current_dir = os.path.dirname(os.path.abspath(__file__))
 
@@ -1,6 +1,8 @@
-import gradio as gr
 import logging
 
+import gradio as gr
+
+
 def ui():
     gr.Markdown("### This extension is deprecated, use \"multimodal\" extension instead")
     logging.error("LLaVA extension is deprecated, use \"multimodal\" extension instead")
@@ -6,10 +6,11 @@ from io import BytesIO
 from typing import Any, List, Optional
 
 import torch
+from PIL import Image
+
 from extensions.multimodal.pipeline_loader import load_pipeline
 from modules import shared
 from modules.text_generation import encode, get_max_prompt_length
-from PIL import Image
 
 
 @dataclass
@@ -7,6 +7,7 @@ from io import BytesIO
 
 import gradio as gr
 import torch
+
 from extensions.multimodal.multimodal_embedder import MultimodalEmbedder
 from modules import shared
 
@@ -1,11 +1,12 @@
 import base64
 import json
-import numpy as np
 import os
 import time
 from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
 from threading import Thread
 
+import numpy as np
+
 from modules import shared
 from modules.text_generation import encode, generate_reply
 
@@ -61,6 +62,7 @@ def float_list_to_base64(float_list):
     ascii_string = encoded_bytes.decode('ascii')
     return ascii_string
 
+
 class Handler(BaseHTTPRequestHandler):
     def do_GET(self):
         if self.path.startswith('/v1/models'):
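
The hunk above ends inside float_list_to_base64, whose visible tail base64-decodes to ASCII. For orientation, a self-contained helper of this shape might look like the sketch below; the float32 packing is an assumption, not necessarily the file's exact body:

import base64

import numpy as np


def float_list_to_base64(float_list):
    # Pack the floats into raw float32 bytes
    float_array = np.array(float_list, dtype=np.float32)
    encoded_bytes = base64.b64encode(float_array.tobytes())
    # Return the embedding as a plain ASCII string, as in the hunk above
    ascii_string = encoded_bytes.decode('ascii')
    return ascii_string
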
@@ -387,8 +389,8 @@ class Handler(BaseHTTPRequestHandler):
             "created": created_time,
             "model": model, # TODO: add Lora info?
             resp_list: [{
-                    "index": 0,
-                    "finish_reason": "stop",
+                "index": 0,
+                "finish_reason": "stop",
             }],
             "usage": {
                 "prompt_tokens": token_count,
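
In the hunk above, resp_list is a variable used as a dictionary key, not a string literal: OpenAI-style responses put completion results under "choices" and list results such as embeddings under "data", so the handler picks the key name at runtime. A hedged sketch of the pattern, with the function name and selection rule illustrative only:

import time


def build_response(object_type, model, token_count):
    # 'data' for list objects, 'choices' for completion objects
    resp_list = 'data' if object_type == 'list' else 'choices'
    return {
        "object": object_type,
        "created": int(time.time()),
        "model": model,
        resp_list: [{
            "index": 0,
            "finish_reason": "stop",
        }],
        "usage": {
            "prompt_tokens": token_count,
        },
    }
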
@@ -6,12 +6,13 @@ from datetime import date
 from pathlib import Path
 
 import gradio as gr
-import modules.shared as shared
 import requests
 import torch
-from modules.models import reload_model, unload_model
 from PIL import Image
 
+import modules.shared as shared
+from modules.models import reload_model, unload_model
+
 torch._C._jit_set_profiling_mode(False)
 
 # parameters which can be customized in settings.json of webui
@@ -77,6 +78,7 @@ SD_models = ['NeverEndingDream'] # TODO: get with http://{address}}/sdapi/v1/sd
 
 picture_response = False # specifies if the next model response should appear as a picture
 
+
 def remove_surrounded_chars(string):
     # this expression matches to 'as few symbols as possible (0 upwards) between any asterisks' OR
     # 'as few symbols as possible (0 upwards) between an asterisk and the end of the string'
@@ -122,7 +124,6 @@ def input_modifier(string):
 
 # Get and save the Stable Diffusion-generated picture
 def get_SD_pictures(description):
-
     global params
 
     if params['manage_VRAM']:
@@ -259,6 +260,7 @@ def SD_api_address_update(address):
 
     return gr.Textbox.update(label=msg)
 
+
 def ui():
 
     # Gradio elements
@@ -290,12 +292,11 @@ def ui():
                 cfg_scale = gr.Number(label="CFG Scale", value=params['cfg_scale'], elem_id="cfg_box")
                 with gr.Column() as hr_options:
                     restore_faces = gr.Checkbox(value=params['restore_faces'], label='Restore faces')
-                        enable_hr = gr.Checkbox(value=params['enable_hr'], label='Hires. fix')
+                    enable_hr = gr.Checkbox(value=params['enable_hr'], label='Hires. fix')
             with gr.Row(visible=params['enable_hr'], elem_classes="hires_opts") as hr_options:
-                    hr_scale = gr.Slider(1, 4, value=params['hr_scale'], step=0.1, label='Upscale by')
-                    denoising_strength = gr.Slider(0, 1, value=params['denoising_strength'], step=0.01, label='Denoising strength')
-                    hr_upscaler = gr.Textbox(placeholder=params['hr_upscaler'], value=params['hr_upscaler'], label='Upscaler')
-
+                hr_scale = gr.Slider(1, 4, value=params['hr_scale'], step=0.1, label='Upscale by')
+                denoising_strength = gr.Slider(0, 1, value=params['denoising_strength'], step=0.01, label='Denoising strength')
+                hr_upscaler = gr.Textbox(placeholder=params['hr_upscaler'], value=params['hr_upscaler'], label='Upscaler')
 
         # Event functions to update the parameters in the backend
         address.change(lambda x: params.update({"address": filter_address(x)}), address, None)
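
The gr.Row above starts with its visibility tied to params['enable_hr']; for the checkbox to show and hide that row at runtime it needs an event binding, which in gradio 3-style apps typically looks like the sketch below (a minimal standalone example, not the extension's exact wiring):

import gradio as gr

with gr.Blocks() as demo:
    enable_hr = gr.Checkbox(value=False, label='Hires. fix')
    with gr.Row(visible=False) as hr_options:
        hr_scale = gr.Slider(1, 4, value=2, step=0.1, label='Upscale by')
    # Mirror the checkbox state onto the row's visibility
    enable_hr.change(lambda x: gr.update(visible=x), enable_hr, hr_options)
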
@@ -4,6 +4,7 @@ from pathlib import Path
 
 import gradio as gr
 import torch
+
 from extensions.silero_tts import tts_preprocessor
 from modules import chat, shared
 from modules.html_generator import chat_html_wrapper
@@ -216,4 +217,4 @@ def ui():
 
     # Play preview
     preview_text.submit(voice_preview, preview_text, preview_audio)
-    preview_play.click(voice_preview, preview_text, preview_audio)
+    preview_play.click(voice_preview, preview_text, preview_audio)
@@ -2,7 +2,6 @@ import time
 from pathlib import Path
 
 import torch
-
 import tts_preprocessor
 
 torch._C._jit_set_profiling_mode(False)
@@ -69,7 +69,7 @@ def remove_surrounded_chars(string):
     # first this expression will check if there is a string nested exclusively between a alt=
     # and a style= string. This would correspond to only a the alt text of an embedded image
     # If it matches it will only keep that part as the string, and rend it for further processing
-    # Afterwards this expression matches to 'as few symbols as possible (0 upwards) between any
+    # Afterwards this expression matches to 'as few symbols as possible (0 upwards) between any
     # asterisks' OR' as few symbols as possible (0 upwards) between an asterisk and the end of the string'
     if re.search(r'(?<=alt=)(.*)(?=style=)', string, re.DOTALL):
         m = re.search(r'(?<=alt=)(.*)(?=style=)', string, re.DOTALL)
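
To see what the lookaround pattern above keeps, here is a quick illustration (the sample string is invented):

import re

string = 'alt=a portrait of a cat style="width: 100px"'
m = re.search(r'(?<=alt=)(.*)(?=style=)', string, re.DOTALL)
if m:
    # Only the text between 'alt=' and 'style=' survives
    string = m.group(0)  # 'a portrait of a cat '
print(string)
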
@@ -59,7 +59,7 @@ class ChromaCollector(Collecter):
     def get_ids(self, search_strings: list[str], n_results: int) -> list[str]:
         n_results = min(len(self.ids), n_results)
         result = self.collection.query(query_texts=search_strings, n_results=n_results, include=['documents'])['ids'][0]
-        return list(map(lambda x : int(x[2:]), result))
+        return list(map(lambda x: int(x[2:]), result))
 
     def clear(self):
         self.collection.delete(ids=self.ids)
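
The int(x[2:]) above relies on the collection's documents having been added with ids of the form 'id0', 'id1', and so on, so dropping the two-character prefix recovers the chunk index. A minimal sketch of that round trip, with the chromadb add() call abbreviated to a comment:

chunks = ['first chunk', 'second chunk', 'third chunk']
ids = [f"id{i}" for i in range(len(chunks))]        # 'id0', 'id1', 'id2'
# collection.add(documents=chunks, ids=ids)         # how the ids are created
result = ['id2', 'id0']                             # a query might return these
indices = list(map(lambda x: int(x[2:]), result))   # -> [2, 0]
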
@@ -162,13 +162,13 @@ def input_modifier(string):
 def custom_generate_chat_prompt(user_input, state, **kwargs):
     if len(shared.history['internal']) > 2 and user_input != '':
         chunks = []
-        for i in range(len(shared.history['internal'])-1):
+        for i in range(len(shared.history['internal']) - 1):
             chunks.append('\n'.join(shared.history['internal'][i]))
 
         add_chunks_to_collector(chunks)
         query = '\n'.join(shared.history['internal'][-1] + [user_input])
         try:
-            best_ids = collector.get_ids(query, n_results=len(shared.history['internal'])-1)
+            best_ids = collector.get_ids(query, n_results=len(shared.history['internal']) - 1)
 
             # Sort the history by relevance instead of by chronological order,
             # except for the latest message
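
Each chunk in this function is one internal history pair joined by a newline, and the query is the latest exchange plus the new input. A worked example with invented history contents:

history = [['hello', 'hi there'], ['how are you?', 'fine']]
user_input = 'tell me a story'
chunks = ['\n'.join(pair) for pair in history[:-1]]  # all pairs but the latest
query = '\n'.join(history[-1] + [user_input])
# chunks == ['hello\nhi there']
# query  == 'how are you?\nfine\ntell me a story'
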
@@ -226,7 +226,7 @@ def ui():
 
     ## Chat mode
 
-    In chat mode, the extension automatically sorts the history by relevance instead of chronologically, except for the very latest input/reply pair.
+    In chat mode, the extension automatically sorts the history by relevance instead of chronologically, except for the very latest input/reply pair.
 
     That is, the prompt will include (starting from the end):
 
@@ -1,5 +1,6 @@
 import gradio as gr
 import speech_recognition as sr
+
 from modules import shared
 
 input_hijack = {