From 763ea3bcb284317c64185a9cb85e03fdf37d4ee1 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Sun, 17 Sep 2023 09:22:16 -0700 Subject: [PATCH] Improved multimodal error message --- extensions/multimodal/README.md | 8 ++++---- extensions/multimodal/pipelines/llava/llava.py | 9 ++++++--- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/extensions/multimodal/README.md b/extensions/multimodal/README.md index 10bbc7f..5068103 100644 --- a/extensions/multimodal/README.md +++ b/extensions/multimodal/README.md @@ -11,10 +11,10 @@ https://user-images.githubusercontent.com/3718215/233817203-69b57e77-0c55-4fd6-b To run this extension, download a LLM that supports multimodality, and then start server.py with the appropriate `--multimodal-pipeline` argument. Examples: ``` -python server.py --model wojtab_llava-7b-v0-4bit-128g --multimodal-pipeline llava-7b --chat -python3 server.py --model wojtab_llava-13b-v0-4bit-128g --multimodal-pipeline llava-13b --chat -python server.py --model anon8231489123_vicuna-13b-GPTQ-4bit-128g --multimodal-pipeline minigpt4-13b --chat -python server.py --model llama-7b-4bit --multimodal-pipeline minigpt4-7b --chat +python server.py --model wojtab_llava-7b-v0-4bit-128g --multimodal-pipeline llava-7b +python3 server.py --model wojtab_llava-13b-v0-4bit-128g --multimodal-pipeline llava-13b +python server.py --model anon8231489123_vicuna-13b-GPTQ-4bit-128g --multimodal-pipeline minigpt4-13b +python server.py --model llama-7b-4bit --multimodal-pipeline minigpt4-7b ``` There is built-in support for LLaVA-v0-13B and LLaVA-v0-7b. To install `minigpt4`: diff --git a/extensions/multimodal/pipelines/llava/llava.py b/extensions/multimodal/pipelines/llava/llava.py index eca2be5..306ab22 100644 --- a/extensions/multimodal/pipelines/llava/llava.py +++ b/extensions/multimodal/pipelines/llava/llava.py @@ -56,10 +56,13 @@ class LLaVA_v0_Pipeline(AbstractMultimodalPipeline): @staticmethod def embed_tokens(input_ids: torch.Tensor) -> torch.Tensor: - if hasattr(shared.model.model, 'embed_tokens'): - func = shared.model.model.embed_tokens + for attr in ['', 'model', 'model.model', 'model.model.model']: + tmp = getattr(shared.model, attr, None) if attr != '' else shared.model + if tmp is not None and hasattr(tmp, 'embed_tokens'): + func = tmp.embed_tokens + break else: - func = shared.model.model.model.embed_tokens # AutoGPTQ case + raise ValueError('The embed_tokens method has not been found for this loader.') return func(input_ids).to(shared.model.device, dtype=shared.model.dtype)