63 changes: 60 additions & 3 deletions aiter/ops/mha.py
@@ -32,6 +32,7 @@ def cmdGenFunc_mha_fwd(
q_descale: Optional[Tensor] = None,
k_descale: Optional[Tensor] = None,
v_descale: Optional[Tensor] = None,
sink_ptr: Optional[Tensor] = None,
gen: Optional[Generator] = None,
):
(_, seqlen_q, _, _) = q.shape
@@ -175,6 +176,7 @@ def gen_mha_fwd_fake_tensors(
q_descale: Optional[Tensor] = None,
k_descale: Optional[Tensor] = None,
v_descale: Optional[Tensor] = None,
sink_ptr: Optional[Tensor] = None,
gen: Optional[torch.Generator] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
return common_mha_fwd_fake_tensors(
@@ -208,6 +210,7 @@ def mha_fwd(
q_descale: Optional[Tensor] = None,
k_descale: Optional[Tensor] = None,
v_descale: Optional[Tensor] = None,
sink_ptr: Optional[Tensor] = None,
gen: Optional[Generator] = None,
) -> Tuple[Tensor, Tensor, Tensor, Tensor]: ...

@@ -285,6 +288,7 @@ def cmdGenFunc_mha_varlen_fwd(
gen: Optional[torch.Generator] = None,
cu_seqlens_q_padded: Optional[torch.Tensor] = None,
cu_seqlens_k_padded: Optional[torch.Tensor] = None,
sink_ptr: Optional[torch.Tensor] = None,
):
# causal=true is the same as causal=false in this case
causal = is_causal
@@ -445,6 +449,7 @@ def gen_mha_varlen_fwd_fake_tensor(
gen: Optional[torch.Generator] = None,
cu_seqlens_q_padded: Optional[torch.Tensor] = None,
cu_seqlens_k_padded: Optional[torch.Tensor] = None,
sink_ptr: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
device = q.device
dtype = q.dtype
@@ -513,6 +518,7 @@ def mha_varlen_fwd(
gen: Optional[torch.Generator] = None,
cu_seqlens_q_padded: Optional[torch.Tensor] = None,
cu_seqlens_k_padded: Optional[torch.Tensor] = None,
sink_ptr: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor, Tensor, Tensor]: ...


@@ -631,6 +637,7 @@ def cmdGenFunc_mha_bwd(
alibi_slopes: Optional[Tensor] = None,
rng_state: Optional[Tensor] = None,
gen: Optional[Generator] = None,
sink_ptr: Optional[Tensor] = None,
):
md_name = "mha_bwd"
filter1 = "*" # get_bwd_dot_do_o_blobs()
@@ -781,6 +788,7 @@ def gen_mha_bwd_fake_tensors(
alibi_slopes: Optional[Tensor] = None,
rng_state: Optional[Tensor] = None,
gen: Optional[Generator] = None,
sink_ptr: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
return common_mha_bwd_fake_tensors(q, k, v, dq, dk, dv)

@@ -812,6 +820,7 @@ def mha_bwd(
alibi_slopes: Optional[Tensor] = None,
rng_state: Optional[Tensor] = None,
gen: Optional[Generator] = None,
sink_ptr: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor, Tensor, Tensor]: ...


@@ -893,6 +902,7 @@ def cmdGenFunc_mha_varlen_bwd(
gen: Optional[Generator] = None,
cu_seqlens_q_padded: Optional[Tensor] = None,
cu_seqlens_k_padded: Optional[Tensor] = None,
sink_ptr: Optional[Tensor] = None,
) -> dict[str, Any]:
md_name = "mha_varlen_bwd"
filter1 = "*" # get_bwd_dot_do_o_blobs()
@@ -972,6 +982,7 @@ def cmdGenFunc_mha_batch_prefill(
q_descale: Optional[Tensor] = None,
k_descale: Optional[Tensor] = None,
v_descale: Optional[Tensor] = None,
sink_ptr: Optional[Tensor] = None,
gen: Optional[Generator] = None,
):
# causal=true is the same as causal=false in this case
@@ -1104,6 +1115,7 @@ def gen_mha_varlen_bwd_fake_tensors(
alibi_slopes: Optional[Tensor] = None,
rng_state: Optional[Tensor] = None,
gen: Optional[Generator] = None,
sink_ptr: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
return gen_mha_varlen_bwd_fake_tensors_common(
q, k, v, cu_seqlens_q, max_seqlen_q, zero_tensors, dq, dk, dv
@@ -1143,6 +1155,7 @@ def mha_varlen_bwd(
gen: Optional[Generator] = None,
cu_seqlens_q_padded: Optional[Tensor] = None,
cu_seqlens_k_padded: Optional[Tensor] = None,
sink_ptr: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor, Tensor, Tensor]: ...


@@ -1241,11 +1254,16 @@ def _flash_attn_forward(
how_v3_bf16_cvt: Optional[int] = 1,
cu_seqlens_q: Optional[torch.Tensor] = None,
cu_seqlens_kv: Optional[torch.Tensor] = None,
sink_ptr: Optional[Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:

(_, seqlen_q, nhead_q, hdim_q) = q.shape
(_, seqlen_k, nhead_k, hdim_v) = v.shape

if sink_ptr is not None:
assert sink_ptr.device == q.device, "sink_ptr must be on the same device as q"
assert sink_ptr.shape[0] == nhead_q, "sink_ptr has incorrect shape"
if sink_ptr.dtype != torch.float32:
sink_ptr = sink_ptr.to(torch.float32)
# mask
window_size_left = -1 if window_size_left >= seqlen_k else window_size_left
window_size_right = -1 if window_size_right >= seqlen_k else window_size_right
@@ -1319,6 +1337,7 @@ def _validate_cu(name: str, x: Optional[torch.Tensor]):
q_descale,
k_descale,
v_descale,
sink_ptr,
None,
# custom_build_args={"md_name": md_name, "blob_gen_cmd": blob_gen_cmd},
)
@@ -1527,6 +1546,7 @@ def _flash_attn_backward_fake(
rng_state: Optional[torch.Tensor] = None,
is_v3_atomic_fp32: Optional[bool] = True,
how_v3_bf16_cvt: Optional[int] = 1,
sink_ptr: Optional[Tensor] = None,
) -> torch.Tensor:
batch_size = q.size(0)
seqlen_q = q.size(1)
@@ -1563,6 +1583,7 @@ def _flash_attn_backward(
rng_state: Optional[torch.Tensor] = None,
is_v3_atomic_fp32: Optional[bool] = True,
how_v3_bf16_cvt: Optional[int] = 1,
sink_ptr: Optional[Tensor] = None,
) -> torch.Tensor:
# rtna & rtz are deprecated in gfx950
if get_gfx() == "gfx950" and how_v3_bf16_cvt != 0:
@@ -1681,6 +1702,7 @@ def can_impl_fmha_v3_bwd_gfx950():
alibi_slopes,
rng_state,
None,
sink_ptr,
# custom_build_args={"md_name": md_name, "blob_gen_cmd": blob_gen_cmd},
)
return softmax_d
@@ -1707,6 +1729,7 @@ def forward(
how_v3_bf16_cvt: Optional[int] = 1,
cu_seqlens_q: Optional[torch.Tensor] = None,
cu_seqlens_kv: Optional[torch.Tensor] = None,
sink_ptr: Optional[Tensor] = None,
):
is_grad = is_grad_enabled and any(x.requires_grad for x in [q, k, v])
if softmax_scale is None:
@@ -1738,6 +1761,7 @@ def forward(
how_v3_bf16_cvt=how_v3_bf16_cvt,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv,
sink_ptr=sink_ptr,
)
if is_grad:
assert return_lse
@@ -1795,6 +1819,7 @@ def backward(ctx, dout, *args):
rng_state,
ctx.is_v3_atomic_fp32,
ctx.how_v3_bf16_cvt,
sink_ptr=None,
)
dq = dq[..., :head_size_q_og] # We could have padded the head dimension
dk = dk[..., :head_size_q_og]
@@ -1836,6 +1861,7 @@ def backward(ctx, dout, *args):
None, # how_v3_bf16_cvt
None, # cu_seqlens_q
None, # cu_seqlens_kv
None, # sink_ptr
)


@@ -1855,6 +1881,7 @@ def flash_attn_func(
how_v3_bf16_cvt=1,
cu_seqlens_q: Optional[torch.Tensor] = None,
cu_seqlens_kv: Optional[torch.Tensor] = None,
sink_ptr: Optional[Tensor] = None,
):
"""dropout_p should be set to 0.0 during evaluation
Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
@@ -1925,6 +1952,7 @@ def flash_attn_func(
how_v3_bf16_cvt,
cu_seqlens_q,
cu_seqlens_kv,
sink_ptr,
)


@@ -1957,12 +1985,18 @@ def _flash_attn_varlen_forward(
block_table: Optional[torch.Tensor] = None,
out: Optional[torch.Tensor] = None,
zero_tensors: bool = False,
sink_ptr: Optional[Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:

(_, nhead_q, hdim_q) = q.shape

nhead_k = v.shape[-2]
hdim_v = v.shape[-1]
if sink_ptr is not None:
assert sink_ptr.device == q.device, "sink_ptr must be on the same device as q"
assert sink_ptr.shape[0] == nhead_q, "sink_ptr has incorrect shape"
if sink_ptr.dtype != torch.float32:
sink_ptr = sink_ptr.to(torch.float32)
# mask
window_size_left = -1 if window_size_left >= max_seqlen_k else window_size_left
window_size_right = -1 if window_size_right >= max_seqlen_k else window_size_right
@@ -2034,7 +2068,6 @@ def _validate(name: str, t: torch.Tensor):
_validate("cu_seqlens_q_padded", cu_seqlens_q_padded)
if cu_seqlens_k_padded is not None:
_validate("cu_seqlens_k_padded", cu_seqlens_k_padded)

out, softmax_lse, S_dmask, rng_state = mha_varlen_fwd(
q,
k,
@@ -2064,6 +2097,7 @@ def _validate(name: str, t: torch.Tensor):
gen=None,
cu_seqlens_q_padded=cu_seqlens_q_padded,
cu_seqlens_k_padded=cu_seqlens_k_padded,
sink_ptr=sink_ptr,
)
return out, softmax_lse, S_dmask, rng_state

@@ -2095,6 +2129,7 @@ def _flash_attn_varlen_backward(
zero_tensors: bool = False,
cu_seqlens_q_padded: Optional[torch.Tensor] = None,
cu_seqlens_k_padded: Optional[torch.Tensor] = None,
sink_ptr: Optional[Tensor] = None,
) -> torch.Tensor:

(_, nhead_q, hdim_q) = q.shape
@@ -2256,6 +2291,7 @@ def can_impl_fmha_v3_bwd_gfx950():
None,
cu_seqlens_q_padded,
cu_seqlens_k_padded,
sink_ptr=sink_ptr,
# custom_build_args={"md_name": md_name, "blob_gen_cmd": blob_gen_cmd},
)
return softmax_d
@@ -2290,6 +2326,7 @@ def forward(
cu_seqlens_k_padded=None,
is_v3_atomic_fp32: Optional[bool] = True,
how_v3_bf16_cvt: Optional[int] = 1,
sink_ptr=None,
):
is_grad = is_grad_enabled and any(x.requires_grad for x in [q, k, v])
if softmax_scale is None:
@@ -2301,7 +2338,6 @@ def forward(
k = torch.nn.functional.pad(k, [0, 8 - head_size_q_og % 8])
if head_size_v_og % 8 != 0:
v = torch.nn.functional.pad(v, [0, 8 - head_size_v_og % 8])

out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
q,
k,
@@ -2330,6 +2366,7 @@ def forward(
how_v3_bf16_cvt=how_v3_bf16_cvt,
block_table=block_table,
out=out,
sink_ptr=sink_ptr,
)
if is_grad:

@@ -2409,6 +2446,7 @@ def backward(ctx, dout, *args):
how_v3_bf16_cvt=ctx.how_v3_bf16_cvt,
cu_seqlens_q_padded=ctx.cu_seqlens_q_padded,
cu_seqlens_k_padded=ctx.cu_seqlens_k_padded,
sink_ptr=None,
)
dq = dq[..., :head_size_q_og] # We could have padded the head dimension
dk = dk[..., :head_size_q_og]
@@ -2457,6 +2495,7 @@ def backward(ctx, dout, *args):
None, # cu_seqlens_k_padded
None, # is_v3_atomic_fp32
None, # how_v3_bf16_cvt
None, # sink_ptr
)


@@ -2484,6 +2523,7 @@ def flash_attn_varlen_func(
out=None,
cu_seqlens_q_padded: Optional[torch.Tensor] = None,
cu_seqlens_k_padded: Optional[torch.Tensor] = None,
sink_ptr: Optional[Tensor] = None,
):
if block_table is not None and (
cu_seqlens_q_padded is not None or cu_seqlens_k_padded is not None
@@ -2574,6 +2614,7 @@ def flash_attn_varlen_func(
cu_seqlens_k_padded,
True,
how_v3_bf16_cvt,
sink_ptr,
)


@@ -2600,6 +2641,7 @@ def mha_batch_prefill_fake_tensors(
q_descale: Optional[torch.Tensor] = None,
k_descale: Optional[torch.Tensor] = None,
v_descale: Optional[torch.Tensor] = None,
sink_ptr: Optional[Tensor] = None,
gen: Optional[Generator] = None,
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
# ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
@@ -2669,6 +2711,7 @@ def mha_batch_prefill(
q_descale: Optional[torch.Tensor] = None,
k_descale: Optional[torch.Tensor] = None,
v_descale: Optional[torch.Tensor] = None,
sink_ptr: Optional[Tensor] = None,
gen: Optional[Generator] = None,
) -> Tuple[Tensor, Tensor, Tensor, Tensor]: ...

@@ -2697,6 +2740,7 @@ def _mha_batch_prefill(
q_descale: Optional[torch.Tensor] = None,
k_descale: Optional[torch.Tensor] = None,
v_descale: Optional[torch.Tensor] = None,
sink_ptr: Optional[Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:

q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
@@ -2724,6 +2768,8 @@ def _mha_batch_prefill(
q_descale,
k_descale,
v_descale,
sink_ptr,
None,
# custom_build_args={"md_name": md_name, "blob_gen_cmd": blob_gen_cmd},
)
return out, softmax_lse, S_dmask, rng_state
@@ -2751,11 +2797,17 @@ def mha_batch_prefill_func(
q_descale=None,
k_descale=None,
v_descale=None,
sink_ptr=None,
):
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
head_size_q_og = q.size(2)
head_size_v_og = v.size(2)
if sink_ptr is not None:
assert sink_ptr.device == q.device, "sink_ptr must be on the same device as q"
assert sink_ptr.shape[0] == q.size(1), "sink_ptr has incorrect shape"
if sink_ptr.dtype != torch.float32:
sink_ptr = sink_ptr.to(torch.float32)
if head_size_q_og % 8 != 0:
q = torch.nn.functional.pad(q, [0, 8 - head_size_q_og % 8])
k = torch.nn.functional.pad(k, [0, 8 - head_size_q_og % 8])
@@ -2783,6 +2835,7 @@ def mha_batch_prefill_func(
q_descale=q_descale,
k_descale=k_descale,
v_descale=v_descale,
sink_ptr=sink_ptr,
)
out = out_padded[..., :head_size_v_og]

@@ -2805,6 +2858,7 @@ def flash_attn_fp8_pertensor_func(
causal=False,
window_size=(-1, -1, 0), # -1 means infinite context window, 0 means no sink
softmax_scale=None,
sink_ptr=None,
):
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
@@ -2832,6 +2886,7 @@ def flash_attn_fp8_pertensor_func(
v_descale=v_descale,
return_lse=False,
return_softmax=False,
sink_ptr=sink_ptr,
)
out = out_padded[..., :head_size_v_og]
return out
@@ -2853,6 +2908,7 @@ def flash_attn_varlen_fp8_pertensor_func(
causal=False,
window_size=(-1, -1, 0), # -1 means infinite context window
softmax_scale=None,
sink_ptr=None,
):
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
@@ -2888,6 +2944,7 @@ def flash_attn_varlen_fp8_pertensor_func(
v_descale=v_descale,
return_lse=False,
return_softmax=False,
sink_ptr=sink_ptr,
)
out = out_padded[..., :head_size_v_og]
return out
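Note on the new `sink_ptr` argument (reviewer sketch, not part of the diff): the parameter is threaded as an optional per-head attention-sink tensor through every forward wrapper, and the Python-side checks added in `_flash_attn_forward`, `_flash_attn_varlen_forward`, and `mha_batch_prefill_func` require it to live on `q`'s device, to have a leading dimension equal to `nhead_q`, and cast it to `float32` when needed; the backward wrappers currently pass `None` for it. The snippet below is a minimal usage sketch under those assumptions; the import path and the defaults of the elided `flash_attn_func` parameters (dropout, scale, masking) are assumptions, not something this diff shows.

```python
# Hypothetical usage sketch -- shapes, import path, and defaults are assumptions.
import torch

from aiter.ops.mha import flash_attn_func  # assumed import path

batch, seqlen, nhead, hdim = 2, 1024, 8, 128
q = torch.randn(batch, seqlen, nhead, hdim, dtype=torch.bfloat16, device="cuda")
k = torch.randn(batch, seqlen, nhead, hdim, dtype=torch.bfloat16, device="cuda")
v = torch.randn(batch, seqlen, nhead, hdim, dtype=torch.bfloat16, device="cuda")

# One sink logit per query head: must sit on q's device and have shape (nhead_q,);
# the wrapper casts it to float32 if it arrives in another floating dtype.
sink = torch.zeros(nhead, device=q.device)

# All other parameters are left at their (assumed) defaults.
out = flash_attn_func(q, k, v, sink_ptr=sink)
```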