25 changes: 18 additions & 7 deletions colossalai/kernel/cuda_native/flash_attention.py
@@ -14,15 +14,20 @@
HAS_MEM_EFF_ATTN = True
except ImportError:
HAS_MEM_EFF_ATTN = False
print('please install xformers from https://github.com/facebookresearch/xformers')
raise ImportError('please install xformers from https://github.com/facebookresearch/xformers')

if HAS_MEM_EFF_ATTN:

from typing import Optional

from einops import rearrange
from xformers.ops.fmha import MemoryEfficientAttentionCutlassOp
from xformers.ops.fmha.attn_bias import BlockDiagonalMask, LowerTriangularMask, LowerTriangularMaskWithTensorBias
from xformers.ops.fmha.attn_bias import (
BlockDiagonalCausalMask,
BlockDiagonalMask,
LowerTriangularMask,
LowerTriangularMaskWithTensorBias,
)

from .scaled_softmax import AttnMaskType

@@ -86,11 +91,14 @@ def backward(ctx, grad_output):

class ColoAttention(torch.nn.Module):

def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0):
def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0, scale=None):
super().__init__()
assert embed_dim % num_heads == 0, \
f"the embed dim ({embed_dim}) is not divisible by the number of attention heads ({num_heads})."
self.scale = 1 / math.sqrt(embed_dim // num_heads)
if scale is not None:
self.scale = scale
else:
self.scale = 1 / math.sqrt(embed_dim // num_heads)
self.dropout = dropout

@staticmethod
@@ -116,7 +124,7 @@ def forward(self,
bias: Optional[torch.Tensor] = None):
batch_size, tgt_len, src_len = query.shape[0], query.shape[1], key.shape[1]
attn_bias = None
if attn_mask_type == AttnMaskType.padding: # bert style
if attn_mask_type and attn_mask_type.value % 2 == 1: # bert style
assert attn_mask is not None, \
f"attention mask {attn_mask} is not valid for attention mask type {attn_mask_type}."
assert attn_mask.dim() == 2, \
@@ -134,7 +142,10 @@
if batch_size > 1:
query = rearrange(query, "b s ... -> c (b s) ...", c=1)
key, value = self.unpad(torch.stack([query, key, value], dim=2), kv_indices).unbind(dim=2)
attn_bias = BlockDiagonalMask.from_seqlens(q_seqlen, kv_seqlen)
if attn_mask_type == AttnMaskType.padding:
attn_bias = BlockDiagonalMask.from_seqlens(q_seqlen, kv_seqlen)
elif attn_mask_type == AttnMaskType.paddedcausal:
attn_bias = BlockDiagonalCausalMask.from_seqlens(q_seqlen, kv_seqlen)
elif attn_mask_type == AttnMaskType.causal: # gpt style
attn_bias = LowerTriangularMask()

@@ -146,7 +157,7 @@

out = memory_efficient_attention(query, key, value, attn_bias=attn_bias, p=self.dropout, scale=self.scale)

if attn_mask_type == AttnMaskType.padding and batch_size > 1:
if attn_mask_type and attn_mask_type.value % 2 == 1 and batch_size > 1:
out = self.repad(out, q_indices, batch_size, tgt_len)

out = rearrange(out, 'b s h d -> b s (h d)')
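
Taken together, the flash_attention.py changes let callers pass an explicit softmax scale and request a combined padding-plus-causal mask. Below is a minimal usage sketch, not part of the PR: the sizes, dtype, device, and mask contents are assumed for illustration and mirror the updated test further down; `ColoAttention`, `AttnMaskType`, `scale`, and `paddedcausal` are the identifiers introduced or touched by this diff.

```python
# Usage sketch (not part of the diff): exercising the new `scale` argument and
# the paddedcausal mask type. Assumes a CUDA device with xformers installed.
import torch

from colossalai.kernel.cuda_native.flash_attention import ColoAttention
from colossalai.kernel.cuda_native.scaled_softmax import AttnMaskType

B, S, H, D_HEAD = 2, 16, 4, 32    # batch, seq len, heads, head dim (assumed)
D = H * D_HEAD                    # embedding dim

# Passing the default scale explicitly, just to show the new parameter.
attn = ColoAttention(embed_dim=D, num_heads=H, dropout=0.0, scale=1 / D_HEAD**0.5)

q = torch.randn(B, S, H, D_HEAD, dtype=torch.float16, device="cuda")
k = torch.randn(B, S, H, D_HEAD, dtype=torch.float16, device="cuda")
v = torch.randn(B, S, H, D_HEAD, dtype=torch.float16, device="cuda")

# (B, S) padding mask: 1 for real tokens, 0 for padding.
mask = torch.ones(B, S, dtype=torch.float16, device="cuda")
mask[1, S // 2:] = 0              # second sequence is half padding

# paddedcausal applies the padding mask and a causal mask together.
out = attn(q, k, v, attn_mask=mask, attn_mask_type=AttnMaskType.paddedcausal)
assert out.shape == (B, S, D)
```
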
5 changes: 3 additions & 2 deletions colossalai/kernel/cuda_native/scaled_softmax.py
@@ -19,6 +19,7 @@
class AttnMaskType(enum.Enum):
padding = 1
causal = 2
paddedcausal = 3


class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
@@ -139,7 +140,7 @@ def is_kernel_available(self, mask, b, np, sq, sk):
if 0 <= sk <= 2048:
batch_per_block = self.get_batch_per_block(sq, sk, b, np)

if self.attn_mask_type == AttnMaskType.causal:
if self.attn_mask_type.value > 1:
if attn_batches % batch_per_block == 0:
return True
else:
@@ -151,7 +152,7 @@ def forward_fused_softmax(self, input, mask):
b, np, sq, sk = input.size()
scale = self.scale if self.scale is not None else 1.0

if self.attn_mask_type == AttnMaskType.causal:
if self.attn_mask_type.value > 1:
assert sq == sk, "causal mask is only for self attention"

# input is 3D tensor (attn_batches, sq, sk)
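
The new `paddedcausal` value is chosen so the branch conditions above can be expressed arithmetically: odd enum values carry a padding mask, and values greater than 1 imply a causal mask. A small standalone sketch of that encoding follows; the helper function names are illustrative only, not part of the PR.

```python
# Sketch of how the enum values drive the checks used in the diff above.
import enum

class AttnMaskType(enum.Enum):
    padding = 1
    causal = 2
    paddedcausal = 3

def needs_padding_mask(t: AttnMaskType) -> bool:
    # odd values carry a padding mask (bert style): padding, paddedcausal
    return t.value % 2 == 1

def is_causal(t: AttnMaskType) -> bool:
    # values greater than 1 apply a causal mask: causal, paddedcausal
    return t.value > 1

assert needs_padding_mask(AttnMaskType.paddedcausal) and is_causal(AttnMaskType.paddedcausal)
assert needs_padding_mask(AttnMaskType.padding) and not is_causal(AttnMaskType.padding)
assert is_causal(AttnMaskType.causal) and not needs_padding_mask(AttnMaskType.causal)
```
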
6 changes: 5 additions & 1 deletion tests/test_utils/test_flash_attention.py
@@ -36,7 +36,11 @@ def test_attention_gpt(proj_shape, dtype=torch.float16):

qkv = c_attn(x)
q, k, v = rearrange(qkv, 'b s (n h d) -> n b s h d', n=3, h=H)
y = attn(q, k, v, attn_mask_type=AttnMaskType.causal)

mask = [torch.ones(S - i, dtype=dtype, device="cuda") for i in range(B)]
mask = torch.nn.utils.rnn.pad_sequence(mask, batch_first=True)

y = attn(q, k, v, attn_mask=mask, attn_mask_type=AttnMaskType.paddedcausal)

assert list(y.shape) == [B, S, D]

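
For reference, the `pad_sequence` construction in the updated test yields a (B, S) float mask whose row i ends in i zeros, marking the padded positions. A tiny standalone illustration with assumed sizes:

```python
# Illustration (assumed B and S) of the mask shape built in the test above.
import torch

B, S = 4, 8
mask = [torch.ones(S - i) for i in range(B)]
mask = torch.nn.utils.rnn.pad_sequence(mask, batch_first=True)
print(mask.shape)    # torch.Size([4, 8])
print(mask[1])       # tensor([1., 1., 1., 1., 1., 1., 1., 0.])
```
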