Added xformers support to Llama (#950)
parent 625d81f495
commit 992663fa20

4 changed files with 185 additions and 0 deletions
@@ -14,6 +14,7 @@ from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer,
                           BitsAndBytesConfig, LlamaTokenizer)
 
 import modules.shared as shared
+from modules import llama_attn_hijack
 
 transformers.logging.set_verbosity_error()
 
@@ -169,6 +170,10 @@ def load_model(model_name):
 
         model = AutoModelForCausalLM.from_pretrained(checkpoint, **params)
 
+    # Hijack attention with xformers
+    if any((shared.args.xformers, shared.args.sdp_attention)):
+        llama_attn_hijack.hijack_llama_attention()
+
     # Loading the tokenizer
     if any((k in shared.model_name.lower() for k in ['gpt4chan', 'gpt-4chan'])) and Path(f"{shared.args.model_dir}/gpt-j-6B/").exists():
         tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/gpt-j-6B/"))