Prevent unwanted log messages from modules
This commit is contained in:
parent
fb91406e93
commit
e116d31180
20 changed files with 120 additions and 111 deletions
|
@ -1,4 +1,3 @@
|
|||
import logging
|
||||
import math
|
||||
import sys
|
||||
from typing import Optional, Tuple
|
||||
|
@ -8,21 +7,22 @@ import torch.nn as nn
|
|||
import transformers.models.llama.modeling_llama
|
||||
|
||||
import modules.shared as shared
|
||||
from modules.logging_colors import logger
|
||||
|
||||
if shared.args.xformers:
|
||||
try:
|
||||
import xformers.ops
|
||||
except Exception:
|
||||
logging.error("xformers not found! Please install it before trying to use it.", file=sys.stderr)
|
||||
logger.error("xformers not found! Please install it before trying to use it.", file=sys.stderr)
|
||||
|
||||
|
||||
def hijack_llama_attention():
|
||||
if shared.args.xformers:
|
||||
transformers.models.llama.modeling_llama.LlamaAttention.forward = xformers_forward
|
||||
logging.info("Replaced attention with xformers_attention")
|
||||
logger.info("Replaced attention with xformers_attention")
|
||||
elif shared.args.sdp_attention:
|
||||
transformers.models.llama.modeling_llama.LlamaAttention.forward = sdp_attention_forward
|
||||
logging.info("Replaced attention with sdp_attention")
|
||||
logger.info("Replaced attention with sdp_attention")
|
||||
|
||||
|
||||
def xformers_forward(
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue