Merge branch 'main' into dev
This commit is contained in:
commit
878250d609
1 changed files with 6 additions and 1 deletions
|
@ -56,7 +56,12 @@ class LLaVA_v0_Pipeline(AbstractMultimodalPipeline):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def embed_tokens(input_ids: torch.Tensor) -> torch.Tensor:
|
def embed_tokens(input_ids: torch.Tensor) -> torch.Tensor:
|
||||||
return shared.model.model.embed_tokens(input_ids).to(shared.model.device, dtype=shared.model.dtype)
|
if hasattr(shared.model.model, 'embed_tokens'):
|
||||||
|
func = shared.model.model.embed_tokens
|
||||||
|
else:
|
||||||
|
func = shared.model.model.model.embed_tokens # AutoGPTQ case
|
||||||
|
|
||||||
|
return func(input_ids).to(shared.model.device, dtype=shared.model.dtype)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def placeholder_embeddings() -> torch.Tensor:
|
def placeholder_embeddings() -> torch.Tensor:
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue