Bump transformers (16-bit llama must be reconverted/redownloaded)
This commit is contained in:
parent
5f4f38ca5d
commit
113f94b61e
3 changed files with 8 additions and 2 deletions
|
@ -10,7 +10,7 @@ import torch
|
|||
import transformers
|
||||
from accelerate import infer_auto_device_map, init_empty_weights
|
||||
from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer,
|
||||
BitsAndBytesConfig)
|
||||
BitsAndBytesConfig, LlamaTokenizer)
|
||||
|
||||
import modules.shared as shared
|
||||
|
||||
|
@ -172,6 +172,8 @@ def load_model(model_name):
|
|||
# Loading the tokenizer
|
||||
if any((k in shared.model_name.lower() for k in ['gpt4chan', 'gpt-4chan'])) and Path(f"{shared.args.model_dir}/gpt-j-6B/").exists():
|
||||
tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/gpt-j-6B/"))
|
||||
elif type(model) is transformers.LlamaForCausalLM:
|
||||
tokenizer = LlamaTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}/"), clean_up_tokenization_spaces=True)
|
||||
else:
|
||||
tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}/"))
|
||||
tokenizer.truncation_side = 'left'
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue