Improve the imports
This commit is contained in:
parent
364529d0c7
commit
7224343a70
10 changed files with 30 additions and 29 deletions
|
@ -1,6 +1,5 @@
|
|||
import torch
|
||||
from transformers import BlipForConditionalGeneration
|
||||
from transformers import BlipProcessor
|
||||
from transformers import BlipForConditionalGeneration, BlipProcessor
|
||||
|
||||
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
|
||||
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base", torch_dtype=torch.float32).to("cpu")
|
||||
|
|
|
@ -7,13 +7,12 @@ from datetime import datetime
|
|||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
|
||||
from PIL import Image
|
||||
|
||||
import modules.shared as shared
|
||||
from modules.extensions import apply_extensions
|
||||
from modules.html_generator import generate_chat_html
|
||||
from modules.text_generation import encode
|
||||
from modules.text_generation import generate_reply
|
||||
from modules.text_generation import get_max_prompt_length
|
||||
from PIL import Image
|
||||
from modules.text_generation import encode, generate_reply, get_max_prompt_length
|
||||
|
||||
if shared.args.picture and (shared.args.cai_chat or shared.args.chat):
|
||||
import modules.bot_picture as bot_picture
|
||||
|
|
|
@ -1,6 +1,5 @@
|
|||
import modules.shared as shared
|
||||
|
||||
import extensions
|
||||
import modules.shared as shared
|
||||
|
||||
extension_state = {}
|
||||
available_extensions = []
|
||||
|
|
|
@ -3,6 +3,7 @@
|
|||
This is a library for formatting GPT-4chan and chat outputs as nice HTML.
|
||||
|
||||
'''
|
||||
|
||||
import base64
|
||||
import os
|
||||
import re
|
||||
|
|
|
@ -4,23 +4,27 @@ import time
|
|||
import zipfile
|
||||
from pathlib import Path
|
||||
|
||||
import modules.shared as shared
|
||||
import numpy as np
|
||||
import torch
|
||||
import transformers
|
||||
from transformers import AutoModelForCausalLM
|
||||
from transformers import AutoTokenizer
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
import modules.shared as shared
|
||||
|
||||
transformers.logging.set_verbosity_error()
|
||||
|
||||
local_rank = None
|
||||
|
||||
if shared.args.flexgen:
|
||||
from flexgen.flex_opt import (Policy, OptLM, TorchDevice, TorchDisk, TorchMixedDevice, CompressionConfig, Env, get_opt_config)
|
||||
from flexgen.flex_opt import (CompressionConfig, Env, OptLM, Policy,
|
||||
TorchDevice, TorchDisk, TorchMixedDevice,
|
||||
get_opt_config)
|
||||
|
||||
if shared.args.deepspeed:
|
||||
import deepspeed
|
||||
from transformers.deepspeed import HfDeepSpeedConfig, is_deepspeed_zero3_enabled
|
||||
from transformers.deepspeed import (HfDeepSpeedConfig,
|
||||
is_deepspeed_zero3_enabled)
|
||||
|
||||
from modules.deepspeed_parameters import generate_ds_config
|
||||
|
||||
# Distributed setup
|
||||
|
|
|
@ -4,9 +4,11 @@ This code was copied from
|
|||
https://github.com/PygmalionAI/gradio-ui/
|
||||
|
||||
'''
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
|
||||
|
||||
class _SentinelTokenStoppingCriteria(transformers.StoppingCriteria):
|
||||
|
||||
def __init__(self, sentinel_token_ids: torch.LongTensor,
|
||||
|
|
|
@ -1,16 +1,17 @@
|
|||
import re
|
||||
import time
|
||||
|
||||
import modules.shared as shared
|
||||
import numpy as np
|
||||
import torch
|
||||
import transformers
|
||||
from tqdm import tqdm
|
||||
|
||||
import modules.shared as shared
|
||||
from modules.extensions import apply_extensions
|
||||
from modules.html_generator import generate_4chan_html
|
||||
from modules.html_generator import generate_basic_html
|
||||
from modules.html_generator import generate_4chan_html, generate_basic_html
|
||||
from modules.models import local_rank
|
||||
from modules.stopping_criteria import _SentinelTokenStoppingCriteria
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def get_max_prompt_length(tokens):
|
||||
max_length = 2048-tokens
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue