Add some checks to AutoGPTQ loader

This commit is contained in:
oobabooga 2023-06-14 18:44:43 -03:00
parent 134430bbe2
commit 4d508cbe58

View file

@ -53,13 +53,16 @@ def load_quantized(model_name):
model = AutoGPTQForCausalLM.from_quantized(path_to_model, **params) model = AutoGPTQForCausalLM.from_quantized(path_to_model, **params)
# These lines fix the multimodal extension when used with AutoGPTQ # These lines fix the multimodal extension when used with AutoGPTQ
if not hasattr(model, 'dtype'): if hasattr(model, 'model'):
model.dtype = model.model.dtype if not hasattr(model, 'dtype'):
if hasattr(model.model, 'dtype'):
model.dtype = model.model.dtype
if not hasattr(model, 'embed_tokens'): if hasattr(model.model, 'model') and hasattr(model.model.model, 'embed_tokens'):
model.embed_tokens = model.model.model.embed_tokens if not hasattr(model, 'embed_tokens'):
model.embed_tokens = model.model.model.embed_tokens
if not hasattr(model.model, 'embed_tokens'): if not hasattr(model.model, 'embed_tokens'):
model.model.embed_tokens = model.model.model.embed_tokens model.model.embed_tokens = model.model.model.embed_tokens
return model return model