Added xformers support to Llama (#950)

This commit is contained in:
MarkovInequality 2023-04-09 22:08:40 -04:00 committed by GitHub
parent 625d81f495
commit 992663fa20
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 185 additions and 0 deletions

View file

@@ -14,6 +14,7 @@ from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer,
BitsAndBytesConfig, LlamaTokenizer)
import modules.shared as shared
from modules import llama_attn_hijack
transformers.logging.set_verbosity_error()
@@ -169,6 +170,10 @@ def load_model(model_name):
model = AutoModelForCausalLM.from_pretrained(checkpoint, **params)
# Hijack attention with xformers
if any((shared.args.xformers, shared.args.sdp_attention)):
llama_attn_hijack.hijack_llama_attention()
# Loading the tokenizer
if any((k in shared.model_name.lower() for k in ['gpt4chan', 'gpt-4chan'])) and Path(f"{shared.args.model_dir}/gpt-j-6B/").exists():
tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/gpt-j-6B/"))