extensions/api: models api for blocking_api (updated) (#2539)

This commit is contained in:
matatonic 2023-06-08 10:34:36 -04:00 committed by GitHub
parent 084b006cfe
commit 7be6fe126b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 256 additions and 2 deletions

View file

@ -56,7 +56,12 @@ class LLaVA_v0_Pipeline(AbstractMultimodalPipeline):
@staticmethod
def embed_tokens(input_ids: torch.Tensor) -> torch.Tensor:
return shared.model.model.embed_tokens(input_ids).to(shared.model.device, dtype=shared.model.dtype)
if hasattr(shared.model.model, 'embed_tokens'):
func = shared.model.model.embed_tokens
else:
func = shared.model.model.model.embed_tokens # AutoGPTQ case
return func(input_ids).to(shared.model.device, dtype=shared.model.dtype)
@staticmethod
def placeholder_embeddings() -> torch.Tensor: