LLaVA support (#1487)

This commit is contained in:
Wojtab 2023-04-24 01:32:22 +02:00 committed by GitHub
parent 9197d3fec8
commit 12212cf6be
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 426 additions and 42 deletions

View file

@ -50,6 +50,8 @@ def find_model_type(model_name):
return 'chatglm'
elif 'galactica' in model_name:
return 'galactica'
elif 'llava' in model_name:
return 'llava'
elif any((k in model_name for k in ['gpt4chan', 'gpt-4chan'])):
return 'gpt4chan'
else:
@ -217,11 +219,12 @@ def load_model(model_name):
tokenizer = None
# Try to load an universal LLaMA tokenizer
for p in [Path(f"{shared.args.model_dir}/llama-tokenizer/"), Path(f"{shared.args.model_dir}/oobabooga_llama-tokenizer/")]:
if p.exists():
print(f"Loading the universal LLaMA tokenizer from {p}...")
tokenizer = LlamaTokenizer.from_pretrained(p, clean_up_tokenization_spaces=True)
break
if shared.model_type != 'llava':
for p in [Path(f"{shared.args.model_dir}/llama-tokenizer/"), Path(f"{shared.args.model_dir}/oobabooga_llama-tokenizer/")]:
if p.exists():
print(f"Loading the universal LLaMA tokenizer from {p}...")
tokenizer = LlamaTokenizer.from_pretrained(p, clean_up_tokenization_spaces=True)
break
# Otherwise, load it from the model folder and hope that these
# are not outdated tokenizer files.