diff --git a/colossalai/kernel/cuda_native/flash_attention.py b/colossalai/kernel/cuda_native/flash_attention.py
index 3db7374509a0..17046e6f41f3 100644
--- a/colossalai/kernel/cuda_native/flash_attention.py
+++ b/colossalai/kernel/cuda_native/flash_attention.py
@@ -14,7 +14,7 @@
     HAS_MEM_EFF_ATTN = True
 except ImportError:
     HAS_MEM_EFF_ATTN = False
-    print('please install xformers from https://github.com/facebookresearch/xformers')
+    raise ImportError('please install xformers from https://github.com/facebookresearch/xformers')
 
 if HAS_MEM_EFF_ATTN:
 
@@ -22,7 +22,12 @@
     from einops import rearrange
     from xformers.ops.fmha import MemoryEfficientAttentionCutlassOp
-    from xformers.ops.fmha.attn_bias import BlockDiagonalMask, LowerTriangularMask, LowerTriangularMaskWithTensorBias
+    from xformers.ops.fmha.attn_bias import (
+        BlockDiagonalCausalMask,
+        BlockDiagonalMask,
+        LowerTriangularMask,
+        LowerTriangularMaskWithTensorBias,
+    )
 
     from .scaled_softmax import AttnMaskType
 
@@ -86,11 +91,14 @@ def backward(ctx, grad_output):
 
     class ColoAttention(torch.nn.Module):
 
-        def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0):
+        def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0, scale=None):
             super().__init__()
             assert embed_dim % num_heads == 0, \
                 f"the embed dim ({embed_dim}) is not divisible by the number of attention heads ({num_heads})."
-            self.scale = 1 / math.sqrt(embed_dim // num_heads)
+            if scale is not None:
+                self.scale = scale
+            else:
+                self.scale = 1 / math.sqrt(embed_dim // num_heads)
             self.dropout = dropout
 
         @staticmethod
@@ -116,7 +124,7 @@ def forward(self,
                     bias: Optional[torch.Tensor] = None):
             batch_size, tgt_len, src_len = query.shape[0], query.shape[1], key.shape[1]
             attn_bias = None
-            if attn_mask_type == AttnMaskType.padding:    # bert style
+            if attn_mask_type and attn_mask_type.value % 2 == 1:    # bert style
                 assert attn_mask is not None, \
                     f"attention mask {attn_mask} is not valid for attention mask type {attn_mask_type}."
                 assert attn_mask.dim() == 2, \
@@ -134,7 +142,10 @@ def forward(self,
                 if batch_size > 1:
                     query = rearrange(query, "b s ... -> c (b s) ...", c=1)
                     key, value = self.unpad(torch.stack([query, key, value], dim=2), kv_indices).unbind(dim=2)
-                attn_bias = BlockDiagonalMask.from_seqlens(q_seqlen, kv_seqlen)
+                if attn_mask_type == AttnMaskType.padding:
+                    attn_bias = BlockDiagonalMask.from_seqlens(q_seqlen, kv_seqlen)
+                elif attn_mask_type == AttnMaskType.paddedcausal:
+                    attn_bias = BlockDiagonalCausalMask.from_seqlens(q_seqlen, kv_seqlen)
             elif attn_mask_type == AttnMaskType.causal:    # gpt style
                 attn_bias = LowerTriangularMask()
 
@@ -146,7 +157,7 @@ def forward(self,
 
             out = memory_efficient_attention(query, key, value, attn_bias=attn_bias, p=self.dropout, scale=self.scale)
 
-            if attn_mask_type == AttnMaskType.padding and batch_size > 1:
+            if attn_mask_type and attn_mask_type.value % 2 == 1 and batch_size > 1:
                 out = self.repad(out, q_indices, batch_size, tgt_len)
 
             out = rearrange(out, 'b s h d -> b s (h d)')
diff --git a/colossalai/kernel/cuda_native/scaled_softmax.py b/colossalai/kernel/cuda_native/scaled_softmax.py
index 24e458bb3ea5..41cd4b20faa1 100644
--- a/colossalai/kernel/cuda_native/scaled_softmax.py
+++ b/colossalai/kernel/cuda_native/scaled_softmax.py
@@ -19,6 +19,7 @@ class AttnMaskType(enum.Enum):
     padding = 1
     causal = 2
+    paddedcausal = 3
 
 
 class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
@@ -139,7 +140,7 @@ def is_kernel_available(self, mask, b, np, sq, sk):
             if 0 <= sk <= 2048:
                 batch_per_block = self.get_batch_per_block(sq, sk, b, np)
 
-                if self.attn_mask_type == AttnMaskType.causal:
+                if self.attn_mask_type.value > 1:
                     if attn_batches % batch_per_block == 0:
                         return True
                 else:
@@ -151,7 +152,7 @@ def forward_fused_softmax(self, input, mask):
         b, np, sq, sk = input.size()
         scale = self.scale if self.scale is not None else 1.0
 
-        if self.attn_mask_type == AttnMaskType.causal:
+        if self.attn_mask_type.value > 1:
             assert sq == sk, "causal mask is only for self attention"
 
             # input is 3D tensor (attn_batches, sq, sk)
diff --git a/tests/test_utils/test_flash_attention.py b/tests/test_utils/test_flash_attention.py
index 2334c84dc778..02c6f5bd44c6 100644
--- a/tests/test_utils/test_flash_attention.py
+++ b/tests/test_utils/test_flash_attention.py
@@ -36,7 +36,11 @@ def test_attention_gpt(proj_shape, dtype=torch.float16):
 
     qkv = c_attn(x)
     q, k, v = rearrange(qkv, 'b s (n h d) -> n b s h d', n=3, h=H)
-    y = attn(q, k, v, attn_mask_type=AttnMaskType.causal)
+
+    mask = [torch.ones(S - i, dtype=dtype, device="cuda") for i in range(B)]
+    mask = torch.nn.utils.rnn.pad_sequence(mask, batch_first=True)
+
+    y = attn(q, k, v, attn_mask=mask, attn_mask_type=AttnMaskType.paddedcausal)
 
     assert list(y.shape) == [B, S, D]
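For reference, a minimal usage sketch of the new paddedcausal mask type and the optional scale argument, mirroring the updated test. This assumes a CUDA device with xformers installed; the tensor sizes below are illustrative and not taken from the test suite.

# Minimal usage sketch (assumptions: CUDA device available, xformers installed,
# illustrative sizes B/S/H/D_HEAD chosen here rather than from the test suite).
import torch

from colossalai.kernel.cuda_native.flash_attention import ColoAttention
from colossalai.kernel.cuda_native.scaled_softmax import AttnMaskType

B, S, H, D_HEAD = 4, 16, 4, 32    # batch, sequence length, heads, head dim
D = H * D_HEAD                    # embedding dim

# scale=None keeps the default 1/sqrt(head_dim); pass a float to override it.
attn = ColoAttention(embed_dim=D, num_heads=H, dropout=0.0, scale=None)

q = torch.randn(B, S, H, D_HEAD, dtype=torch.float16, device="cuda")
k = torch.randn(B, S, H, D_HEAD, dtype=torch.float16, device="cuda")
v = torch.randn(B, S, H, D_HEAD, dtype=torch.float16, device="cuda")

# Right-padded 2D mask of shape (batch_size, seq_len): 1 marks real tokens, 0 marks padding.
mask = [torch.ones(S - i, dtype=torch.float16, device="cuda") for i in range(B)]
mask = torch.nn.utils.rnn.pad_sequence(mask, batch_first=True)

# paddedcausal combines the padding mask with a causal (lower-triangular) bias.
y = attn(q, k, v, attn_mask=mask, attn_mask_type=AttnMaskType.paddedcausal)
assert list(y.shape) == [B, S, D]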