Adapt to the new model names

This commit is contained in:
oobabooga 2023-03-29 21:47:36 -03:00
parent 0345e04249
commit 1cb9246160
6 changed files with 18 additions and 25 deletions

View file

@ -51,11 +51,12 @@ def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exc
def load_quantized(model_name):
if not shared.args.model_type:
# Try to determine model type from model name
if model_name.lower().startswith(('llama', 'alpaca')):
name = model_name.lower()
if any((k in name for k in ['llama', 'alpaca'])):
model_type = 'llama'
elif model_name.lower().startswith(('opt', 'galactica')):
elif any((k in name for k in ['opt-', 'galactica'])):
model_type = 'opt'
elif model_name.lower().startswith(('gpt-j', 'pygmalion-6b')):
elif any((k in name for k in ['gpt-j', 'pygmalion-6b'])):
model_type = 'gptj'
else:
print("Can't determine model type from model name. Please specify it manually using --model_type "