LLaVA: small fixes (#1664)

* change multimodal projector to the correct one

* remove reference to custom stopping strings from readme

* fix stopping strings if tokenizer extension adds/removes tokens

* add API example

* LLaVA 7B just dropped, add to readme that there is no support for it currently
This commit is contained in:
Wojtab 2023-05-03 04:12:22 +02:00 committed by GitHub
parent c31b0f15a7
commit 80c2f25131
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 56 additions and 31 deletions

View file

@ -21,10 +21,16 @@ params = {
"clip_device": None,
# bits to load clip in either 32 or 16 (it doesn't support 8-bit)
"clip_bits": 32,
# clip repository
"clip_repo": "openai/clip-vit-large-patch14",
# device to run projector on
"projector_device": None,
# projector bits, either 32 or 16
"projector_bits": 32
"projector_bits": 32,
# projector repository
"projector_repo": "liuhaotian/LLaVA-13b-delta-v0",
# file with the projector weights
"projector_file": "mm_projector.bin"
}
@ -49,9 +55,6 @@ class LLaVAEmbedder:
IM_PATCH = Token("<im_patch>", 32000)
IM_START = Token("<im_start>", 32001)
IM_END = Token("<im_end>", 32002)
CLIP_VIT_HUB_NAME = 'openai/clip-vit-large-patch14'
PROJECTOR_HUB_NAME = 'liuhaotian/LLaVA-13b-pretrain-projector-v0'
PROJECTOR_FILE = 'LLaVA-13b-pretrain-projector-v0-CC3M-595K-original_caption.bin'
def __init__(self):
self.clip_device = self._get_device("clip_device")
@ -71,12 +74,12 @@ class LLaVAEmbedder:
def _load_models(self):
start_ts = time.time()
print(f"LLaVA - Loading {LLaVAEmbedder.CLIP_VIT_HUB_NAME} as {self.clip_dtype} on {self.clip_device}...")
image_processor = CLIPImageProcessor.from_pretrained(LLaVAEmbedder.CLIP_VIT_HUB_NAME, torch_dtype=self.clip_dtype)
vision_tower = CLIPVisionModel.from_pretrained(LLaVAEmbedder.CLIP_VIT_HUB_NAME, torch_dtype=self.clip_dtype).to(self.clip_device)
print(f"LLaVA - Loading CLIP from {params['clip_repo']} as {self.clip_dtype} on {self.clip_device}...")
image_processor = CLIPImageProcessor.from_pretrained(params["clip_repo"], torch_dtype=self.clip_dtype)
vision_tower = CLIPVisionModel.from_pretrained(params["clip_repo"], torch_dtype=self.clip_dtype).to(self.clip_device)
print(f"LLaVA - Loading {LLaVAEmbedder.PROJECTOR_HUB_NAME} as {self.projector_dtype} on {self.projector_device}...")
projector_path = hf_hub_download(LLaVAEmbedder.PROJECTOR_HUB_NAME, LLaVAEmbedder.PROJECTOR_FILE)
print(f"LLaVA - Loading projector from {params['projector_repo']} as {self.projector_dtype} on {self.projector_device}...")
projector_path = hf_hub_download(params["projector_repo"], params["projector_file"])
mm_projector = torch.nn.Linear(1024, 5120)
projector_data = torch.load(projector_path)
mm_projector.weight = torch.nn.Parameter(projector_data['model.mm_projector.weight'].to(dtype=self.projector_dtype), False)