Prevent unwanted log messages from modules
This commit is contained in:
parent
fb91406e93
commit
e116d31180
20 changed files with 120 additions and 111 deletions
|
@ -1,16 +1,17 @@
|
|||
import logging
|
||||
import time
|
||||
from abc import abstractmethod
|
||||
from typing import List, Tuple
|
||||
|
||||
import torch
|
||||
from extensions.multimodal.abstract_pipeline import AbstractMultimodalPipeline
|
||||
from huggingface_hub import hf_hub_download
|
||||
from modules import shared
|
||||
from modules.text_generation import encode
|
||||
from PIL import Image
|
||||
from transformers import CLIPImageProcessor, CLIPVisionModel
|
||||
|
||||
from extensions.multimodal.abstract_pipeline import AbstractMultimodalPipeline
|
||||
from modules import shared
|
||||
from modules.logging_colors import logger
|
||||
from modules.text_generation import encode
|
||||
|
||||
|
||||
class LLaVA_v0_Pipeline(AbstractMultimodalPipeline):
|
||||
CLIP_REPO = "openai/clip-vit-large-patch14"
|
||||
|
@ -26,11 +27,11 @@ class LLaVA_v0_Pipeline(AbstractMultimodalPipeline):
|
|||
def _load_models(self):
|
||||
start_ts = time.time()
|
||||
|
||||
logging.info(f"LLaVA - Loading CLIP from {LLaVA_v0_Pipeline.CLIP_REPO} as {self.clip_dtype} on {self.clip_device}...")
|
||||
logger.info(f"LLaVA - Loading CLIP from {LLaVA_v0_Pipeline.CLIP_REPO} as {self.clip_dtype} on {self.clip_device}...")
|
||||
image_processor = CLIPImageProcessor.from_pretrained(LLaVA_v0_Pipeline.CLIP_REPO, torch_dtype=self.clip_dtype)
|
||||
vision_tower = CLIPVisionModel.from_pretrained(LLaVA_v0_Pipeline.CLIP_REPO, torch_dtype=self.clip_dtype).to(self.clip_device)
|
||||
|
||||
logging.info(f"LLaVA - Loading projector from {self.llava_projector_repo()} as {self.projector_dtype} on {self.projector_device}...")
|
||||
logger.info(f"LLaVA - Loading projector from {self.llava_projector_repo()} as {self.projector_dtype} on {self.projector_device}...")
|
||||
projector_path = hf_hub_download(self.llava_projector_repo(), self.llava_projector_filename())
|
||||
mm_projector = torch.nn.Linear(*self.llava_projector_shape())
|
||||
projector_data = torch.load(projector_path)
|
||||
|
@ -38,7 +39,7 @@ class LLaVA_v0_Pipeline(AbstractMultimodalPipeline):
|
|||
mm_projector.bias = torch.nn.Parameter(projector_data['model.mm_projector.bias'].to(dtype=self.projector_dtype), False)
|
||||
mm_projector = mm_projector.to(self.projector_device)
|
||||
|
||||
logging.info(f"LLaVA supporting models loaded, took {time.time() - start_ts:.2f} seconds")
|
||||
logger.info(f"LLaVA supporting models loaded, took {time.time() - start_ts:.2f} seconds")
|
||||
return image_processor, vision_tower, mm_projector
|
||||
|
||||
@staticmethod
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue