Improve the imports

oobabooga 2023-02-23 14:41:42 -03:00
parent 364529d0c7
commit 7224343a70
10 changed files with 30 additions and 29 deletions

modules/bot_picture.py

@@ -1,6 +1,5 @@
 import torch
-from transformers import BlipForConditionalGeneration
-from transformers import BlipProcessor
+from transformers import BlipForConditionalGeneration, BlipProcessor
 
 processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
 model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base", torch_dtype=torch.float32).to("cpu")
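
Note: the processor/model pair loaded above is the standard BLIP captioning setup. A minimal usage sketch, not part of this commit (the image path is a placeholder):

    from PIL import Image

    image = Image.open("example.jpg").convert("RGB")    # placeholder input image
    inputs = processor(image, return_tensors="pt")      # preprocess to pixel values
    output_ids = model.generate(**inputs, max_new_tokens=30)
    caption = processor.decode(output_ids[0], skip_special_tokens=True)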

modules/chat.py

@@ -7,13 +7,12 @@ from datetime import datetime
 from io import BytesIO
 from pathlib import Path
 
+from PIL import Image
+
 import modules.shared as shared
 from modules.extensions import apply_extensions
 from modules.html_generator import generate_chat_html
-from modules.text_generation import encode
-from modules.text_generation import generate_reply
-from modules.text_generation import get_max_prompt_length
-from PIL import Image
+from modules.text_generation import encode, generate_reply, get_max_prompt_length
 
 if shared.args.picture and (shared.args.cai_chat or shared.args.chat):
     import modules.bot_picture as bot_picture

modules/extensions.py

@ -1,6 +1,5 @@
import modules.shared as shared
import extensions
import modules.shared as shared
extension_state = {}
available_extensions = []

modules/html_generator.py

@@ -3,6 +3,7 @@
 This is a library for formatting GPT-4chan and chat outputs as nice HTML.
 '''
 
 import base64
+import os
 import re

modules/models.py

@@ -4,23 +4,27 @@ import time
 import zipfile
 from pathlib import Path
 
-import modules.shared as shared
 import numpy as np
 import torch
 import transformers
-from transformers import AutoModelForCausalLM
-from transformers import AutoTokenizer
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+import modules.shared as shared
 
 transformers.logging.set_verbosity_error()
 
 local_rank = None
+
 if shared.args.flexgen:
-    from flexgen.flex_opt import (Policy, OptLM, TorchDevice, TorchDisk, TorchMixedDevice, CompressionConfig, Env, get_opt_config)
+    from flexgen.flex_opt import (CompressionConfig, Env, OptLM, Policy,
+                                  TorchDevice, TorchDisk, TorchMixedDevice,
+                                  get_opt_config)
 
 if shared.args.deepspeed:
     import deepspeed
-    from transformers.deepspeed import HfDeepSpeedConfig, is_deepspeed_zero3_enabled
+    from transformers.deepspeed import (HfDeepSpeedConfig,
+                                        is_deepspeed_zero3_enabled)
     from modules.deepspeed_parameters import generate_ds_config
 
     # Distributed setup

modules/stopping_criteria.py

@@ -4,9 +4,11 @@ This code was copied from
 https://github.com/PygmalionAI/gradio-ui/
 '''
 
 import torch
 import transformers
 
+
 class _SentinelTokenStoppingCriteria(transformers.StoppingCriteria):
+
     def __init__(self, sentinel_token_ids: torch.LongTensor,
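
For context, criteria like this are passed to generate() through a transformers.StoppingCriteriaList. A rough sketch of the wiring (the tokenizer, model, input_ids, and the second constructor argument are assumptions, not shown in this hunk):

    # input_ids: the already-encoded prompt tensor
    stop_ids = tokenizer("\nYou:", return_tensors="pt").input_ids  # assumed sentinel text
    criteria = transformers.StoppingCriteriaList([
        _SentinelTokenStoppingCriteria(
            sentinel_token_ids=stop_ids,       # parameter name confirmed by the hunk
            starting_idx=input_ids.shape[1],   # assumed: skip sentinel matches inside the prompt
        )
    ])
    output = model.generate(input_ids, stopping_criteria=criteria)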

modules/text_generation.py

@@ -1,16 +1,17 @@
 import re
 import time
 
-import modules.shared as shared
 import numpy as np
 import torch
 import transformers
+from tqdm import tqdm
+
+import modules.shared as shared
 from modules.extensions import apply_extensions
-from modules.html_generator import generate_4chan_html
-from modules.html_generator import generate_basic_html
+from modules.html_generator import generate_4chan_html, generate_basic_html
 from modules.models import local_rank
 from modules.stopping_criteria import _SentinelTokenStoppingCriteria
-from tqdm import tqdm
 
+
 def get_max_prompt_length(tokens):
     max_length = 2048-tokens
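
The helper at the end of this hunk budgets the prompt against a fixed 2048-token context window: with, say, 200 tokens reserved for the reply, get_max_prompt_length(200) allows at most 2048 - 200 = 1848 prompt tokens.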