LLaVA: small fixes (#1664)
* change multimodal projector to the correct one
* remove reference to custom stopping strings from readme
* fix stopping strings if tokenizer extension adds/removes tokens
* add API example
* LLaVA 7B just dropped, add to readme that there is no support for it currently
This commit is contained in:
parent
c31b0f15a7
commit
80c2f25131
3 changed files with 56 additions and 31 deletions
```diff
@@ -21,10 +21,16 @@ params = {
     "clip_device": None,
     # bits to load clip in either 32 or 16 (it doesn't support 8-bit)
     "clip_bits": 32,
+    # clip repository
+    "clip_repo": "openai/clip-vit-large-patch14",
     # device to run projector on
     "projector_device": None,
     # projector bits, either 32 or 16
-    "projector_bits": 32
+    "projector_bits": 32,
+    # projector repository
+    "projector_repo": "liuhaotian/LLaVA-13b-delta-v0",
+    # file with the projector weights
+    "projector_file": "mm_projector.bin"
 }
```
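The new `clip_repo`, `projector_repo`, and `projector_file` keys make the previously hardcoded checkpoints configurable. A minimal sketch of overriding them before the models are loaded; the repo and file values shown are just the defaults from this diff, and 16-bit is the only other width the comments say is supported:

```python
# Sketch only: override the new params keys before LLaVAEmbedder loads models.
params.update({
    "projector_repo": "liuhaotian/LLaVA-13b-delta-v0",  # HF repo holding the projector
    "projector_file": "mm_projector.bin",               # weights file inside that repo
    "projector_bits": 16,                               # 32 or 16; 16 halves projector memory
    "clip_bits": 16,                                    # CLIP supports 32 or 16 (not 8-bit)
})
```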
```diff
@@ -49,9 +55,6 @@ class LLaVAEmbedder:
     IM_PATCH = Token("<im_patch>", 32000)
     IM_START = Token("<im_start>", 32001)
     IM_END = Token("<im_end>", 32002)
-    CLIP_VIT_HUB_NAME = 'openai/clip-vit-large-patch14'
-    PROJECTOR_HUB_NAME = 'liuhaotian/LLaVA-13b-pretrain-projector-v0'
-    PROJECTOR_FILE = 'LLaVA-13b-pretrain-projector-v0-CC3M-595K-original_caption.bin'

     def __init__(self):
         self.clip_device = self._get_device("clip_device")
```
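This hunk removes the three hardcoded hub names but keeps the special image tokens. A hedged sketch of how such tokens are conventionally used in LLaVA-style prompts; the patch count of 256 is an assumption matching CLIP ViT-L/14 at 224 px, not something this diff states:

```python
# Sketch, not the extension's code: reserve space for image features in the
# prompt using the special tokens defined above.
NUM_PATCHES = 256  # assumption: ViT-L/14 at 224x224 -> 16x16 = 256 patches

def image_placeholder() -> str:
    # <im_start>, 256 x <im_patch>, <im_end>; each <im_patch> slot is later
    # replaced by one projected CLIP patch embedding
    return "<im_start>" + "<im_patch>" * NUM_PATCHES + "<im_end>"
```

Note that the IDs 32000-32002 sit directly after the 32000-token LLaMA vocabulary, which is presumably why a tokenizer extension that adds or removes tokens shifts them, the situation the stopping-strings fix in this commit addresses.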
```diff
@@ -71,12 +74,12 @@ class LLaVAEmbedder:
     def _load_models(self):
         start_ts = time.time()

-        print(f"LLaVA - Loading {LLaVAEmbedder.CLIP_VIT_HUB_NAME} as {self.clip_dtype} on {self.clip_device}...")
-        image_processor = CLIPImageProcessor.from_pretrained(LLaVAEmbedder.CLIP_VIT_HUB_NAME, torch_dtype=self.clip_dtype)
-        vision_tower = CLIPVisionModel.from_pretrained(LLaVAEmbedder.CLIP_VIT_HUB_NAME, torch_dtype=self.clip_dtype).to(self.clip_device)
+        print(f"LLaVA - Loading CLIP from {params['clip_repo']} as {self.clip_dtype} on {self.clip_device}...")
+        image_processor = CLIPImageProcessor.from_pretrained(params["clip_repo"], torch_dtype=self.clip_dtype)
+        vision_tower = CLIPVisionModel.from_pretrained(params["clip_repo"], torch_dtype=self.clip_dtype).to(self.clip_device)

-        print(f"LLaVA - Loading {LLaVAEmbedder.PROJECTOR_HUB_NAME} as {self.projector_dtype} on {self.projector_device}...")
-        projector_path = hf_hub_download(LLaVAEmbedder.PROJECTOR_HUB_NAME, LLaVAEmbedder.PROJECTOR_FILE)
+        print(f"LLaVA - Loading projector from {params['projector_repo']} as {self.projector_dtype} on {self.projector_device}...")
+        projector_path = hf_hub_download(params["projector_repo"], params["projector_file"])
         mm_projector = torch.nn.Linear(1024, 5120)
         projector_data = torch.load(projector_path)
         mm_projector.weight = torch.nn.Parameter(projector_data['model.mm_projector.weight'].to(dtype=self.projector_dtype), False)
```
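With the repos now read from `params`, the load path produces three objects: an image processor, a CLIP vision tower, and a 1024-to-5120 linear projector (5120 being the hidden size of LLaMA-13B). A sketch of how they would plausibly combine at embedding time; the penultimate-layer feature selection and CLS-token drop are assumptions based on the LLaVA reference implementation, not shown in this diff:

```python
import torch
from PIL import Image

def embed_image(image: Image.Image, image_processor, vision_tower, mm_projector,
                clip_device, projector_device, projector_dtype):
    # Preprocess to pixel tensors on the CLIP device
    pixels = image_processor(images=image, return_tensors="pt")["pixel_values"].to(clip_device)
    with torch.no_grad():
        out = vision_tower(pixels, output_hidden_states=True)
        # Assumption (LLaVA-style): penultimate hidden layer, CLS token dropped,
        # leaving one 1024-dim feature per image patch
        feats = out.hidden_states[-2][:, 1:]
        # Project into the 5120-dim LLaMA-13B embedding space
        return mm_projector(feats.to(projector_device, dtype=projector_dtype))
```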