Prevent unwanted log messages from modules

This commit is contained in:
oobabooga 2023-05-21 22:42:34 -03:00
parent fb91406e93
commit e116d31180
20 changed files with 120 additions and 111 deletions

View file

@ -1,4 +1,3 @@
import logging
import math
import sys
from typing import Optional, Tuple
@ -8,21 +7,22 @@ import torch.nn as nn
import transformers.models.llama.modeling_llama
import modules.shared as shared
from modules.logging_colors import logger
# Optional dependency guard: only try to import xformers when the user
# explicitly enabled it via the --xformers flag.
if shared.args.xformers:
    try:
        import xformers.ops
    except Exception:
        # NOTE(review): Logger.error() has no ``file`` kwarg (that belongs to
        # print()); passing ``file=sys.stderr`` raised TypeError whenever this
        # branch was hit. The logger already writes to the configured handler.
        logger.error("xformers not found! Please install it before trying to use it.")
def hijack_llama_attention():
    """Monkey-patch ``LlamaAttention.forward`` based on CLI flags.

    If ``--xformers`` is set, replace the forward method with
    ``xformers_forward``; otherwise, if ``--sdp-attention`` is set, replace
    it with ``sdp_attention_forward``. With neither flag the transformers
    default implementation is left untouched. No return value; the effect is
    a global patch on the transformers LlamaAttention class.
    """
    if shared.args.xformers:
        transformers.models.llama.modeling_llama.LlamaAttention.forward = xformers_forward
        logger.info("Replaced attention with xformers_attention")
    elif shared.args.sdp_attention:
        transformers.models.llama.modeling_llama.LlamaAttention.forward = sdp_attention_forward
        logger.info("Replaced attention with sdp_attention")
def xformers_forward(