Better warning messages

oobabooga 2023-05-03 21:43:17 -03:00
parent 0a48b29cd8
commit 95d04d6a8d
13 changed files with 194 additions and 83 deletions


@@ -1,29 +1,28 @@
+import logging
 import math
 import sys
+from typing import Optional, Tuple

 import torch
 import torch.nn as nn
 import transformers.models.llama.modeling_llama
-from typing import Optional
-from typing import Tuple

 import modules.shared as shared

 if shared.args.xformers:
     try:
         import xformers.ops
     except Exception:
-        print("🔴 xformers not found! Please install it before trying to use it.", file=sys.stderr)
+        logging.error("xformers not found! Please install it before trying to use it.")


 def hijack_llama_attention():
     if shared.args.xformers:
         transformers.models.llama.modeling_llama.LlamaAttention.forward = xformers_forward
-        print("Replaced attention with xformers_attention")
+        logging.info("Replaced attention with xformers_attention")
     elif shared.args.sdp_attention:
         transformers.models.llama.modeling_llama.LlamaAttention.forward = sdp_attention_forward
-        print("Replaced attention with sdp_attention")
+        logging.info("Replaced attention with sdp_attention")


 def xformers_forward(
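
The functional change in this hunk is the switch from bare print() calls to the logging module, which makes severity and destination configurable by the caller (note that logging.error does not accept print's file= keyword, which is why file=sys.stderr is dropped on the new line). A minimal sketch of how the hijack might be wired up; the module path modules.llama_attn_hijack and the logging setup are assumptions, only hijack_llama_attention() and the shared.args flags come from the diff:

import logging

import modules.shared as shared  # shared argparse namespace, as imported in the diff

logging.basicConfig(level=logging.INFO)  # unlike print(), verbosity is now configurable
shared.args.xformers = True              # or: shared.args.sdp_attention = True

# Assumed module path; importing it runs the module-level xformers import check above.
from modules.llama_attn_hijack import hijack_llama_attention

hijack_llama_attention()  # must run before the model executes any forward pass
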
@@ -55,16 +54,14 @@ def xformers_forward(
     past_key_value = (key_states, value_states) if use_cache else None

-    #We only apply xformers optimizations if we don't need to output the whole attention matrix
+    # We only apply xformers optimizations if we don't need to output the whole attention matrix
     if not output_attentions:
         dtype = query_states.dtype

         query_states = query_states.transpose(1, 2)
         key_states = key_states.transpose(1, 2)
         value_states = value_states.transpose(1, 2)

-        #This is a nasty hack. We know attention_mask in transformers is either LowerTriangular or all Zeros.
-        #We therefore check if one element in the upper triangular portion is zero. If it is, then the mask is all zeros.
+        # This is a nasty hack. We know attention_mask in transformers is either LowerTriangular or all Zeros.
+        # We therefore check if one element in the upper triangular portion is zero. If it is, then the mask is all zeros.
         if attention_mask is None or attention_mask[0, 0, 0, 1] == 0:
             # input and output should be of form (bsz, q_len, num_heads, head_dim)
             attn_output = xformers.ops.memory_efficient_attention(query_states, key_states, value_states, attn_bias=None)
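
The "nasty hack" comment compresses a real invariant: transformers builds the LLaMA attention mask as an additive tensor (assumed layout (bsz, 1, q_len, kv_len)) that is either causal, with large negative values above the diagonal, or all zeros, so a single upper-triangular element distinguishes the two cases. A hedged illustration with toy stand-in tensors, not the model's real mask:

import torch

q_len = 4
neg = torch.finfo(torch.float32).min

# Stand-ins for the two mask shapes transformers can produce:
causal_mask = torch.triu(torch.full((1, 1, q_len, q_len), neg), diagonal=1)
zero_mask = torch.zeros(1, 1, q_len, q_len)

# Element [0, 0, 0, 1] sits above the diagonal: very negative in a causal mask, zero otherwise.
print(causal_mask[0, 0, 0, 1] == 0)  # tensor(False) -> falls through to the causal branch (elided from this hunk)
print(zero_mask[0, 0, 0, 1] == 0)    # tensor(True)  -> takes the attn_bias=None fast path shown above
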
@@ -102,9 +99,7 @@ def xformers_forward(
         attn_output = attn_output.transpose(1, 2)

     attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
     attn_output = self.o_proj(attn_output)
     return attn_output, attn_weights, past_key_value
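
For readers tracing shapes through the tail of xformers_forward, a comment-only sketch; the dimensions follow the standard LLaMA attention layout, which this hunk assumes but does not show:

# attn_output on the standard (non-xformers) path: (bsz, num_heads, q_len, head_dim)
#   .transpose(1, 2)                        -> (bsz, q_len, num_heads, head_dim)
#   .reshape(bsz, q_len, self.hidden_size)  -> flattens heads, since hidden_size == num_heads * head_dim
#   self.o_proj(...)                        -> (bsz, q_len, hidden_size), the attention block's output
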
@@ -137,7 +132,7 @@ def sdp_attention_forward(
     past_key_value = (key_states, value_states) if use_cache else None

-    #We only apply sdp attention if we don't need to output the whole attention matrix
+    # We only apply sdp attention if we don't need to output the whole attention matrix
     if not output_attentions:
         attn_output = torch.nn.functional.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask=attention_mask, is_causal=False)
         attn_weights = None
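
The sdp path delegates to PyTorch 2.x's fused torch.nn.functional.scaled_dot_product_attention; is_causal=False here because causality is already baked into the additive attn_mask it receives. A self-contained toy call, with illustrative shapes rather than the model's real tensors:

import torch
import torch.nn.functional as F

bsz, num_heads, q_len, head_dim = 1, 2, 4, 8
q = torch.randn(bsz, num_heads, q_len, head_dim)
k = torch.randn(bsz, num_heads, q_len, head_dim)
v = torch.randn(bsz, num_heads, q_len, head_dim)

# Equivalent causal call without a prebuilt mask: let the fused kernel apply the triangle itself.
out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=True)
print(out.shape)  # torch.Size([1, 2, 4, 8])
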