Better warning messages
This commit is contained in:
parent
0a48b29cd8
commit
95d04d6a8d
13 changed files with 194 additions and 83 deletions
|
@ -1,29 +1,28 @@
|
|||
import logging
|
||||
import math
|
||||
import sys
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import transformers.models.llama.modeling_llama
|
||||
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
|
||||
import modules.shared as shared
|
||||
|
||||
|
||||
if shared.args.xformers:
|
||||
try:
|
||||
import xformers.ops
|
||||
except Exception:
|
||||
print("🔴 xformers not found! Please install it before trying to use it.", file=sys.stderr)
|
||||
logging.error("xformers not found! Please install it before trying to use it.", file=sys.stderr)
|
||||
|
||||
|
||||
def hijack_llama_attention():
|
||||
if shared.args.xformers:
|
||||
transformers.models.llama.modeling_llama.LlamaAttention.forward = xformers_forward
|
||||
print("Replaced attention with xformers_attention")
|
||||
logging.info("Replaced attention with xformers_attention")
|
||||
elif shared.args.sdp_attention:
|
||||
transformers.models.llama.modeling_llama.LlamaAttention.forward = sdp_attention_forward
|
||||
print("Replaced attention with sdp_attention")
|
||||
logging.info("Replaced attention with sdp_attention")
|
||||
|
||||
|
||||
def xformers_forward(
|
||||
|
@ -55,16 +54,14 @@ def xformers_forward(
|
|||
|
||||
past_key_value = (key_states, value_states) if use_cache else None
|
||||
|
||||
#We only apply xformers optimizations if we don't need to output the whole attention matrix
|
||||
# We only apply xformers optimizations if we don't need to output the whole attention matrix
|
||||
if not output_attentions:
|
||||
dtype = query_states.dtype
|
||||
|
||||
query_states = query_states.transpose(1, 2)
|
||||
key_states = key_states.transpose(1, 2)
|
||||
value_states = value_states.transpose(1, 2)
|
||||
|
||||
#This is a nasty hack. We know attention_mask in transformers is either LowerTriangular or all Zeros.
|
||||
#We therefore check if one element in the upper triangular portion is zero. If it is, then the mask is all zeros.
|
||||
|
||||
# This is a nasty hack. We know attention_mask in transformers is either LowerTriangular or all Zeros.
|
||||
# We therefore check if one element in the upper triangular portion is zero. If it is, then the mask is all zeros.
|
||||
if attention_mask is None or attention_mask[0, 0, 0, 1] == 0:
|
||||
# input and output should be of form (bsz, q_len, num_heads, head_dim)
|
||||
attn_output = xformers.ops.memory_efficient_attention(query_states, key_states, value_states, attn_bias=None)
|
||||
|
@ -102,9 +99,7 @@ def xformers_forward(
|
|||
attn_output = attn_output.transpose(1, 2)
|
||||
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
return attn_output, attn_weights, past_key_value
|
||||
|
||||
|
||||
|
@ -137,7 +132,7 @@ def sdp_attention_forward(
|
|||
|
||||
past_key_value = (key_states, value_states) if use_cache else None
|
||||
|
||||
#We only apply sdp attention if we don't need to output the whole attention matrix
|
||||
# We only apply sdp attention if we don't need to output the whole attention matrix
|
||||
if not output_attentions:
|
||||
attn_output = torch.nn.functional.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask=attention_mask, is_causal=False)
|
||||
attn_weights = None
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue