colossalai/shardformer/modeling/t5.py (13 additions, 12 deletions)
@@ -593,10 +593,6 @@ def t5_encoder_model_forward(


def get_t5_flash_attention_forward():
-    try:
-        from xformers.ops import memory_efficient_attention as me_attention
-    except:
-        raise ImportError("Error: xformers module is not installed. Please install it to use flash attention.")
    from transformers.models.t5.modeling_t5 import T5Attention

    def forward(
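
The removed block made the T5 flash attention path depend on xformers; the rewritten forward relies on PyTorch's built-in scaled_dot_product_attention instead, so no third-party import is needed. If an explicit guard were still wanted, a minimal sketch could look like the following (the helper name is hypothetical and not part of the PR; it only assumes a PyTorch build that ships the fused kernels):

import torch

def _assert_sdpa_available():
    # scaled_dot_product_attention ships with PyTorch >= 2.0; the `scale`
    # keyword used further down requires a newer 2.x release still.
    if not hasattr(torch.nn.functional, "scaled_dot_product_attention"):
        raise ImportError(
            "torch.nn.functional.scaled_dot_product_attention is unavailable; "
            "please upgrade PyTorch to use the fused attention path."
        )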
@@ -632,11 +628,11 @@ def forward(

        def shape(states):
            """projection"""
-            return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim)
+            return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)

        def unshape(states):
            """reshape"""
-            return states.view(batch_size, -1, self.inner_dim)
+            return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim)

        def project(hidden_states, proj_layer, key_value_states, past_key_value):
            """projects hidden states correctly to key/query states"""
@@ -653,8 +649,8 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value):
                if key_value_states is None:
                    # self-attn
                    # (batch_size, n_heads, key_length, dim_per_head)
-                    hidden_states = torch.cat([past_key_value, hidden_states], dim=1)
-                elif past_key_value.shape[1] != key_value_states.shape[1]:
+                    hidden_states = torch.cat([past_key_value, hidden_states], dim=2)
+                elif past_key_value.shape[2] != key_value_states.shape[1]:
                    # checking that the `sequence_length` of the `past_key_value` is the same as
                    # the provided `key_value_states` to support prefix tuning
                    # cross-attn
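
Because key/value states are now kept in the (batch_size, n_heads, seq_len, head_dim) layout, the sequence axis is dim 2 rather than dim 1, so the past key/value cache is concatenated along dim=2 and the prefix-tuning length check reads shape[2]. A toy illustration with hypothetical sizes:

import torch

past_key_value = torch.randn(2, 8, 10, 64)  # 10 cached positions
new_states = torch.randn(2, 8, 1, 64)       # 1 freshly projected position
merged = torch.cat([past_key_value, new_states], dim=2)
assert merged.shape == (2, 8, 11, 64)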
@@ -701,10 +697,15 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value):
        else:
            position_bias_masked = position_bias

-        position_bias_masked = position_bias_masked.contiguous()
-        attn_output = me_attention(
-            query_states, key_states, value_states, attn_bias=position_bias_masked, p=self.dropout, scale=1.0
-        )
+        with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_mem_efficient=True):
+            attn_output = torch.nn.functional.scaled_dot_product_attention(
+                query_states,
+                key_states,
+                value_states,
+                attn_mask=position_bias_masked,
+                dropout_p=self.dropout,
+                scale=1.0,
+            )
        attn_output = unshape(attn_output)
        attn_output = self.o(attn_output)

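
For reference, a standalone sketch of the new attention call outside the T5 module: torch.backends.cuda.sdp_kernel narrows backend selection to the fused flash and memory-efficient kernels, the T5 position bias goes in as an additive attn_mask, and scale=1.0 preserves T5's unscaled dot products. The sizes below are made up, and the scale keyword assumes a sufficiently recent PyTorch 2.x release:

import torch
import torch.nn.functional as F

# (batch_size, n_heads, seq_len, head_dim), as produced by the new shape().
q = torch.randn(2, 8, 16, 64)
k = torch.randn(2, 8, 16, 64)
v = torch.randn(2, 8, 16, 64)
position_bias = torch.randn(2, 8, 16, 16)  # additive bias: (batch, n_heads, q_len, k_len)

with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_mem_efficient=True):
    out = F.scaled_dot_product_attention(
        q, k, v, attn_mask=position_bias, dropout_p=0.0, scale=1.0
    )
assert out.shape == q.shape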