Add some checks to AutoGPTQ loader

This commit is contained in:
oobabooga 2023-06-14 18:44:43 -03:00
parent 134430bbe2
commit 4d508cbe58

View file

@ -53,9 +53,12 @@ def load_quantized(model_name):
model = AutoGPTQForCausalLM.from_quantized(path_to_model, **params) model = AutoGPTQForCausalLM.from_quantized(path_to_model, **params)
# These lines fix the multimodal extension when used with AutoGPTQ # These lines fix the multimodal extension when used with AutoGPTQ
if hasattr(model, 'model'):
if not hasattr(model, 'dtype'): if not hasattr(model, 'dtype'):
if hasattr(model.model, 'dtype'):
model.dtype = model.model.dtype model.dtype = model.model.dtype
if hasattr(model.model, 'model') and hasattr(model.model.model, 'embed_tokens'):
if not hasattr(model, 'embed_tokens'): if not hasattr(model, 'embed_tokens'):
model.embed_tokens = model.model.model.embed_tokens model.embed_tokens = model.model.model.embed_tokens