Improved multimodal error message

This commit is contained in:
oobabooga 2023-09-17 09:22:16 -07:00
parent 37e2980e05
commit 763ea3bcb2
2 changed files with 10 additions and 7 deletions

View file

@ -56,10 +56,13 @@ class LLaVA_v0_Pipeline(AbstractMultimodalPipeline):
@staticmethod
def embed_tokens(input_ids: torch.Tensor) -> torch.Tensor:
    """Embed token ids using the underlying language model's embedding layer.

    Different loaders wrap the base model at different nesting depths under
    ``shared.model``, so each candidate path is probed for an
    ``embed_tokens`` attribute and the first match is used.

    Args:
        input_ids: Tensor of token ids to embed.

    Returns:
        The embedding tensor, moved to ``shared.model``'s device and dtype.

    Raises:
        ValueError: If no candidate object exposes ``embed_tokens``
            (i.e. the current loader is not supported).
    """
    for path in ['', 'model', 'model.model', 'model.model.model']:
        # Walk the dotted path one attribute at a time: a single
        # getattr(obj, 'model.model') would look for a literal attribute
        # named "model.model" and always miss the nested candidates.
        obj = shared.model
        for part in filter(None, path.split('.')):
            obj = getattr(obj, part, None)
            if obj is None:
                break
        if obj is not None and hasattr(obj, 'embed_tokens'):
            func = obj.embed_tokens
            break
    else:
        raise ValueError('The embed_tokens method has not been found for this loader.')

    return func(input_ids).to(shared.model.device, dtype=shared.model.dtype)