colossalai/kernel/kernel_loader.py (4 additions, 0 deletions)
@@ -119,6 +119,10 @@ class FlashAttentionLoader(KernelLoader):
     ]


+class FlashAttentionDaoLoader(KernelLoader):
+    REGISTRY = [FlashAttentionDaoCudaExtension]
+
+
 class FlashAttentionWithCustomMaskLoader(KernelLoader):
     REGISTRY = [FlashAttentionNpuExtension, FlashAttentionSdpaCudaExtension]
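For context, a KernelLoader subclass such as the new FlashAttentionDaoLoader is consumed by instantiating it and calling .load(), which resolves the first extension in REGISTRY available on the current platform. A minimal sketch, not part of this PR; the fallback handling is illustrative only:

from colossalai.kernel.kernel_loader import FlashAttentionDaoLoader

loader = FlashAttentionDaoLoader()
try:
    # Resolves the kernel provided by FlashAttentionDaoCudaExtension.
    flash_attn_func = loader.load()
except Exception:
    # e.g. flash-attn not installed or no CUDA device present.
    flash_attn_func = None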
colossalai/shardformer/layer/attn.py (38 additions, 10 deletions)
@@ -8,6 +8,7 @@
 from einops import rearrange

 from colossalai.kernel.kernel_loader import (
+    FlashAttentionDaoLoader,
     FlashAttentionForFloatAndCustomMaskLoader,
     FlashAttentionLoader,
     FlashAttentionWithCustomMaskLoader,
@@ -17,6 +18,8 @@

 from .utils import RingComm, get_half_index, split_varlen_zigzag

+MEMORY_BOUND = 10 * 1e9
+
 __all__ = [
     "AttnMaskType",
     "ColoAttention",
@@ -77,6 +80,7 @@ def get_pad_info(

 class ColoAttention:
     _kernel_dispatch_map: Optional[Dict[torch.dtype, Dict[Optional[AttnMaskType], Callable]]] = None
+    _flash_kernel_dispatch: Optional[Dict[torch.dtype, Dict[Optional[AttnMaskType], Callable]]] = None

     @staticmethod
     def _init_kernels_dispatch():
@@ -102,9 +106,11 @@ def _init_kernels_dispatch():
             torch.bfloat16: half_dispatch_map,
             torch.float32: float_dispatch_map,
         }
+        if ColoAttention._flash_kernel_dispatch is None:
+            ColoAttention._flash_kernel_dispatch = FlashAttentionDaoLoader()

     @staticmethod
-    def _dispatch_kernel(dtype: torch.dtype, mask_type: Optional[AttnMaskType]) -> Callable:
+    def _dispatch_kernel(dtype: torch.dtype, mask_type: Optional[AttnMaskType], size) -> Callable:
         ColoAttention._init_kernels_dispatch()
         if (
             dtype not in ColoAttention._kernel_dispatch_map
@@ -113,12 +119,19 @@ def _dispatch_kernel(dtype: torch.dtype, mask_type: Optional[AttnMaskType]) -> Callable:
             raise ValueError(
                 "FlashAttention kernel is not available for dtype {} and mask_type {}".format(dtype, mask_type)
             )
+
+        if size >= MEMORY_BOUND:
+            ColoAttention._flash_kernel_dispatch = ColoAttention._flash_kernel_dispatch.load()
         # lazy load
         if isinstance(ColoAttention._kernel_dispatch_map[dtype][mask_type], KernelLoader):
             ColoAttention._kernel_dispatch_map[dtype][mask_type] = ColoAttention._kernel_dispatch_map[dtype][
                 mask_type
             ].load()
-        return ColoAttention._kernel_dispatch_map[dtype][mask_type]
+
+        if size >= MEMORY_BOUND and mask_type in (AttnMaskType.PADDED_CAUSAL, AttnMaskType.CAUSAL):
+            return ColoAttention._flash_kernel_dispatch
+        else:
+            return ColoAttention._kernel_dispatch_map[dtype][mask_type]

     @staticmethod
     def prepare_attn_kwargs(
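Back-of-the-envelope check of the new dispatch rule: size is the byte size of a dense (s_q, s_kv) mask, so with MEMORY_BOUND = 1e10 bytes an fp16 run switches to the Dao kernel (for causal and padded-causal masks) once s_q * s_kv * 2 >= 1e10, roughly a 70,711-token square mask. A standalone sketch; the helper name would_use_flash_kernel is hypothetical, while the arithmetic mirrors the diff:

import torch

MEMORY_BOUND = 10 * 1e9  # 10 GB, as defined at the top of attn.py

def would_use_flash_kernel(s_q: int, s_kv: int, dtype: torch.dtype) -> bool:
    # Same computation as memory_size in the diff: bytes of a dense mask.
    element_size = torch.tensor([], dtype=dtype).element_size()
    return s_q * s_kv * element_size >= MEMORY_BOUND

print(would_use_flash_kernel(4_096, 4_096, torch.float16))    # False: ~34 MB mask
print(would_use_flash_kernel(80_000, 80_000, torch.float16))  # True: ~12.8 GB mask
print(int((MEMORY_BOUND / 2) ** 0.5))                         # fp16 break-even: ~70710 tokens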
@@ -154,17 +167,22 @@
             return {}
         assert len(shape_4d) == 4 and shape_4d[1] == 1
         b, _, s_q, s_kv = shape_4d
+        element_size = torch.tensor([], dtype=dtype).element_size()
+        memory_size = s_q * s_kv * element_size
         outputs = {}
         if (q_padding_mask is None or q_padding_mask.bool().all()) and (
             kv_padding_mask is None or kv_padding_mask.bool().all()
         ):
             # no padding
             assert is_causal
             outputs["attention_mask_type"] = AttnMaskType.CAUSAL
-            attention_mask = torch.ones(s_q, s_kv, dtype=dtype, device=device)
-            if s_q != 1:
-                attention_mask = attention_mask.tril(diagonal=0)
-            attention_mask = attention_mask.expand(b, s_q, s_kv)
+            if memory_size < MEMORY_BOUND:
+                attention_mask = torch.ones(s_q, s_kv, dtype=dtype, device=device)
+                if s_q != 1:
+                    attention_mask.tril_(diagonal=0)
+                attention_mask = attention_mask.expand(b, s_q, s_kv)
+            else:
+                attention_mask = torch.empty((0,), dtype=dtype, device=device)
         else:
             max_seqlen_q, cu_seqlens_q, q_indices = get_pad_info(q_padding_mask)
             if kv_padding_mask is None:
@@ -177,7 +195,6 @@
                     b,
                     s_kv,
                 ), f"Padding mask shape {kv_padding_mask.shape} should align with shape 4d ({b}, {s_kv})"
-            attention_mask = kv_padding_mask[:, None, :].expand(b, s_q, s_kv).to(dtype=dtype, device=device)
             outputs.update(
                 {
                     "cu_seqlens_q": cu_seqlens_q,
@@ -190,10 +207,17 @@
             )
             if is_causal:
                 outputs["attention_mask_type"] = AttnMaskType.PADDED_CAUSAL
-                if s_q != 1:
-                    attention_mask = attention_mask * attention_mask.new_ones(s_q, s_kv).tril(diagonal=0)
+                if memory_size < MEMORY_BOUND:
+                    if s_q != 1:
+                        attention_mask = kv_padding_mask[:, None, :].expand(b, s_q, s_kv).to(dtype=dtype, device=device)
+                        attention_mask = attention_mask * attention_mask.new_ones(s_q, s_kv).tril(diagonal=0)
+                else:
+                    attention_mask = torch.empty((0,), dtype=dtype, device=device)
             else:
                 outputs["attention_mask_type"] = AttnMaskType.PADDED
+            if memory_size < MEMORY_BOUND:
+                attention_mask = kv_padding_mask[:, None, :].expand(b, s_q, s_kv).to(dtype=dtype, device=device)

         if invert:
             attention_mask = invert_mask(attention_mask).unsqueeze(1)
         outputs["attention_mask"] = attention_mask
@@ -278,8 +302,12 @@ def attention(
             assert attention_mask_type == AttnMaskType.CUSTOM

         # kernel dispatch
+        b, _, s_q, _ = q.shape
+        b, _, s_kv, _ = v.shape
+        element_size = torch.tensor([], dtype=q.dtype).element_size()
+        memory_size = s_q * s_kv * element_size
         mask_type = attention_mask_type if attention_mask is not None else None
-        attn_func = ColoAttention._dispatch_kernel(q.dtype, mask_type)
+        attn_func = ColoAttention._dispatch_kernel(q.dtype, mask_type, memory_size)
         is_causal = attention_mask is not None and attention_mask_type in (
             AttnMaskType.CAUSAL,
             AttnMaskType.PADDED_CAUSAL,
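Putting the pieces together, a hypothetical end-to-end call, not taken from the PR; the (b, num_heads, seq_len, head_dim) tensor layout is assumed from the q.shape / v.shape unpacking above, and dispatch now depends on q.dtype, the mask type, and the mask's would-be byte size:

import torch
from colossalai.shardformer.layer.attn import ColoAttention

b, h, s, d = 2, 16, 4096, 64
q = torch.randn(b, h, s, d, dtype=torch.float16, device="cuda")
k, v = torch.randn_like(q), torch.randn_like(q)

attn_kwargs = ColoAttention.prepare_attn_kwargs(
    (b, 1, s, s), dtype=q.dtype, device=q.device, is_causal=True
)
out = ColoAttention.attention(q, k, v, **attn_kwargs)  # memory-size-aware dispatch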