LLaVA support (#1487)
This commit is contained in:
parent
9197d3fec8
commit
12212cf6be
12 changed files with 426 additions and 42 deletions
|
@ -50,6 +50,8 @@ def find_model_type(model_name):
|
|||
return 'chatglm'
|
||||
elif 'galactica' in model_name:
|
||||
return 'galactica'
|
||||
elif 'llava' in model_name:
|
||||
return 'llava'
|
||||
elif any((k in model_name for k in ['gpt4chan', 'gpt-4chan'])):
|
||||
return 'gpt4chan'
|
||||
else:
|
||||
|
@ -217,11 +219,12 @@ def load_model(model_name):
|
|||
tokenizer = None
|
||||
|
||||
# Try to load an universal LLaMA tokenizer
|
||||
for p in [Path(f"{shared.args.model_dir}/llama-tokenizer/"), Path(f"{shared.args.model_dir}/oobabooga_llama-tokenizer/")]:
|
||||
if p.exists():
|
||||
print(f"Loading the universal LLaMA tokenizer from {p}...")
|
||||
tokenizer = LlamaTokenizer.from_pretrained(p, clean_up_tokenization_spaces=True)
|
||||
break
|
||||
if shared.model_type != 'llava':
|
||||
for p in [Path(f"{shared.args.model_dir}/llama-tokenizer/"), Path(f"{shared.args.model_dir}/oobabooga_llama-tokenizer/")]:
|
||||
if p.exists():
|
||||
print(f"Loading the universal LLaMA tokenizer from {p}...")
|
||||
tokenizer = LlamaTokenizer.from_pretrained(p, clean_up_tokenization_spaces=True)
|
||||
break
|
||||
|
||||
# Otherwise, load it from the model folder and hope that these
|
||||
# are not outdated tokenizer files.
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue