Generalize multimodality (llava/minigpt4 7b and 13b now supported) (#1741)
This commit is contained in:
parent
a2b25322f0
commit
e9e75a9ec7
22 changed files with 812 additions and 371 deletions
52
extensions/multimodal/pipeline_loader.py
Normal file
52
extensions/multimodal/pipeline_loader.py
Normal file
|
|
@ -0,0 +1,52 @@
|
|||
import logging
|
||||
import traceback
|
||||
from importlib import import_module
|
||||
from pathlib import Path
|
||||
from typing import Tuple
|
||||
|
||||
from extensions.multimodal.abstract_pipeline import AbstractMultimodalPipeline
|
||||
from modules import shared
|
||||
|
||||
|
||||
def _get_available_pipeline_modules():
|
||||
pipeline_path = Path(__file__).parent / 'pipelines'
|
||||
modules = [p for p in pipeline_path.iterdir() if p.is_dir()]
|
||||
return [m.name for m in modules if (m / 'pipelines.py').exists()]
|
||||
|
||||
|
||||
def load_pipeline(params: dict) -> Tuple[AbstractMultimodalPipeline, str]:
|
||||
pipeline_modules = {}
|
||||
available_pipeline_modules = _get_available_pipeline_modules()
|
||||
for name in available_pipeline_modules:
|
||||
try:
|
||||
pipeline_modules[name] = import_module(f'extensions.multimodal.pipelines.{name}.pipelines')
|
||||
except:
|
||||
logging.warning(f'Failed to get multimodal pipelines from {name}')
|
||||
logging.warning(traceback.format_exc())
|
||||
|
||||
if shared.args.multimodal_pipeline is not None:
|
||||
for k in pipeline_modules:
|
||||
if hasattr(pipeline_modules[k], 'get_pipeline'):
|
||||
pipeline = getattr(pipeline_modules[k], 'get_pipeline')(shared.args.multimodal_pipeline, params)
|
||||
if pipeline is not None:
|
||||
return (pipeline, k)
|
||||
else:
|
||||
model_name = shared.args.model.lower()
|
||||
for k in pipeline_modules:
|
||||
if hasattr(pipeline_modules[k], 'get_pipeline_from_model_name'):
|
||||
pipeline = getattr(pipeline_modules[k], 'get_pipeline_from_model_name')(model_name, params)
|
||||
if pipeline is not None:
|
||||
return (pipeline, k)
|
||||
|
||||
available = []
|
||||
for k in pipeline_modules:
|
||||
if hasattr(pipeline_modules[k], 'available_pipelines'):
|
||||
pipelines = getattr(pipeline_modules[k], 'available_pipelines')
|
||||
available += pipelines
|
||||
|
||||
if shared.args.multimodal_pipeline is not None:
|
||||
log = f'Multimodal - ERROR: Failed to load multimodal pipeline "{shared.args.multimodal_pipeline}", available pipelines are: {available}.'
|
||||
else:
|
||||
log = f'Multimodal - ERROR: Failed to determine multimodal pipeline for model {shared.args.model}, please select one manually using --multimodal-pipeline [PIPELINE]. Available pipelines are: {available}.'
|
||||
logging.critical(f'{log} Please specify a correct pipeline, or disable the extension')
|
||||
raise RuntimeError(f'{log} Please specify a correct pipeline, or disable the extension')
|
||||
Loading…
Add table
Add a link
Reference in a new issue