Adapt to the new model names
This commit is contained in:
parent
0345e04249
commit
1cb9246160
6 changed files with 18 additions and 25 deletions
|
@@ -51,11 +51,12 @@ def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exc
|
|||
def load_quantized(model_name):
|
||||
if not shared.args.model_type:
|
||||
# Try to determine model type from model name
|
||||
if model_name.lower().startswith(('llama', 'alpaca')):
|
||||
name = model_name.lower()
|
||||
if any((k in name for k in ['llama', 'alpaca'])):
|
||||
model_type = 'llama'
|
||||
elif model_name.lower().startswith(('opt', 'galactica')):
|
||||
elif any((k in name for k in ['opt-', 'galactica'])):
|
||||
model_type = 'opt'
|
||||
elif model_name.lower().startswith(('gpt-j', 'pygmalion-6b')):
|
||||
elif any((k in name for k in ['gpt-j', 'pygmalion-6b'])):
|
||||
model_type = 'gptj'
|
||||
else:
|
||||
print("Can't determine model type from model name. Please specify it manually using --model_type "
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue