diff --git a/3rdparty/composable_kernel b/3rdparty/composable_kernel index b0ea67e377..68ff3e2f2c 160000 --- a/3rdparty/composable_kernel +++ b/3rdparty/composable_kernel @@ -1 +1 @@ -Subproject commit b0ea67e37725c26860a3520dc31c1f7a01164db9 +Subproject commit 68ff3e2f2c2bb0a5273f010e4d24614f6178982c diff --git a/aiter/ops/mha.py b/aiter/ops/mha.py index 39a2ae4a2b..dee77a5e8f 100644 --- a/aiter/ops/mha.py +++ b/aiter/ops/mha.py @@ -32,6 +32,7 @@ def cmdGenFunc_mha_fwd( q_descale: Optional[Tensor] = None, k_descale: Optional[Tensor] = None, v_descale: Optional[Tensor] = None, + sink_ptr: Optional[Tensor] = None, gen: Optional[Generator] = None, ): (_, seqlen_q, _, _) = q.shape @@ -175,6 +176,7 @@ def gen_mha_fwd_fake_tensors( q_descale: Optional[Tensor] = None, k_descale: Optional[Tensor] = None, v_descale: Optional[Tensor] = None, + sink_ptr: Optional[Tensor] = None, gen: Optional[torch.Generator] = None, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: return common_mha_fwd_fake_tensors( @@ -208,6 +210,7 @@ def mha_fwd( q_descale: Optional[Tensor] = None, k_descale: Optional[Tensor] = None, v_descale: Optional[Tensor] = None, + sink_ptr: Optional[Tensor] = None, gen: Optional[Generator] = None, ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: ... @@ -285,6 +288,7 @@ def cmdGenFunc_mha_varlen_fwd( gen: Optional[torch.Generator] = None, cu_seqlens_q_padded: Optional[torch.Tensor] = None, cu_seqlens_k_padded: Optional[torch.Tensor] = None, + sink_ptr: Optional[torch.Tensor] = None, ): # causal=true is the same as causal=false in this case causal = is_causal @@ -445,6 +449,7 @@ def gen_mha_varlen_fwd_fake_tensor( gen: Optional[torch.Generator] = None, cu_seqlens_q_padded: Optional[torch.Tensor] = None, cu_seqlens_k_padded: Optional[torch.Tensor] = None, + sink_ptr: Optional[Tensor] = None, ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: device = q.device dtype = q.dtype @@ -513,6 +518,7 @@ def mha_varlen_fwd( gen: Optional[torch.Generator] = None, cu_seqlens_q_padded: Optional[torch.Tensor] = None, cu_seqlens_k_padded: Optional[torch.Tensor] = None, + sink_ptr: Optional[Tensor] = None, ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: ... @@ -631,6 +637,7 @@ def cmdGenFunc_mha_bwd( alibi_slopes: Optional[Tensor] = None, rng_state: Optional[Tensor] = None, gen: Optional[Generator] = None, + sink_ptr: Optional[Tensor] = None, ): md_name = "mha_bwd" filter1 = "*" # get_bwd_dot_do_o_blobs() @@ -781,6 +788,7 @@ def gen_mha_bwd_fake_tensors( alibi_slopes: Optional[Tensor] = None, rng_state: Optional[Tensor] = None, gen: Optional[Generator] = None, + sink_ptr: Optional[Tensor] = None, ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: return common_mha_bwd_fake_tensors(q, k, v, dq, dk, dv) @@ -812,6 +820,7 @@ def mha_bwd( alibi_slopes: Optional[Tensor] = None, rng_state: Optional[Tensor] = None, gen: Optional[Generator] = None, + sink_ptr: Optional[Tensor] = None, ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: ... 
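The Python wrappers patched above only thread the new sink_ptr argument (one float per query head, per the "[hq]" comments later in this diff) down to the composable_kernel launchers, so the actual math lives in the CK kernels behind the has_sink flag. For intuition only, here is a plain-PyTorch sketch of the usual attention-sink formulation, an extra per-head logit that joins the softmax normalization but contributes no value; the helper name, the MHA-only (nhead_q == nhead_k) shapes, and the exact kernel semantics are assumptions, not something this patch defines.

import torch

def attention_with_sink_ref(q, k, v, sinks, softmax_scale=None):
    # Assumed reference, not the CK kernel.
    # q, k, v: [batch, seqlen, nhead, hdim]; sinks: [nhead] float32 (cf. sink_ptr "[hq]")
    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** -0.5
    scores = torch.einsum("bqhd,bkhd->bhqk", q.float(), k.float()) * softmax_scale
    # Append one virtual "sink" column per head: it absorbs probability mass in the
    # softmax but is dropped before the value reduction, so weak rows attend "nowhere".
    b, h, sq, _ = scores.shape
    sink_col = sinks.float().view(1, h, 1, 1).expand(b, h, sq, 1)
    probs = torch.softmax(torch.cat([scores, sink_col], dim=-1), dim=-1)[..., :-1]
    return torch.einsum("bhqk,bkhd->bqhd", probs, v.float()).to(q.dtype)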
@@ -893,6 +902,7 @@ def cmdGenFunc_mha_varlen_bwd( gen: Optional[Generator] = None, cu_seqlens_q_padded: Optional[Tensor] = None, cu_seqlens_k_padded: Optional[Tensor] = None, + sink_ptr: Optional[Tensor] = None, ) -> dict[str, Any]: md_name = "mha_varlen_bwd" filter1 = "*" # get_bwd_dot_do_o_blobs() @@ -972,6 +982,7 @@ def cmdGenFunc_mha_batch_prefill( q_descale: Optional[Tensor] = None, k_descale: Optional[Tensor] = None, v_descale: Optional[Tensor] = None, + sink_ptr: Optional[Tensor] = None, gen: Optional[Generator] = None, ): # causal=true is the same as causal=false in this case @@ -1104,6 +1115,7 @@ def gen_mha_varlen_bwd_fake_tensors( alibi_slopes: Optional[Tensor] = None, rng_state: Optional[Tensor] = None, gen: Optional[Generator] = None, + sink_ptr: Optional[Tensor] = None, ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: return gen_mha_varlen_bwd_fake_tensors_common( q, k, v, cu_seqlens_q, max_seqlen_q, zero_tensors, dq, dk, dv @@ -1143,6 +1155,7 @@ def mha_varlen_bwd( gen: Optional[Generator] = None, cu_seqlens_q_padded: Optional[Tensor] = None, cu_seqlens_k_padded: Optional[Tensor] = None, + sink_ptr: Optional[Tensor] = None, ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: ... @@ -1241,11 +1254,16 @@ def _flash_attn_forward( how_v3_bf16_cvt: Optional[int] = 1, cu_seqlens_q: Optional[torch.Tensor] = None, cu_seqlens_kv: Optional[torch.Tensor] = None, + sink_ptr: Optional[Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: (_, seqlen_q, nhead_q, hdim_q) = q.shape (_, seqlen_k, nhead_k, hdim_v) = v.shape - + if sink_ptr is not None: + assert sink_ptr.device == q.device, "sink_ptr must be on the same device as q" + assert sink_ptr.shape[0] == nhead_q, "sink_ptr has incorrect shape" + if sink_ptr.dtype != torch.float32: + sink_ptr = sink_ptr.to(torch.float32) # mask window_size_left = -1 if window_size_left >= seqlen_k else window_size_left window_size_right = -1 if window_size_right >= seqlen_k else window_size_right @@ -1319,6 +1337,7 @@ def _validate_cu(name: str, x: Optional[torch.Tensor]): q_descale, k_descale, v_descale, + sink_ptr, None, # custom_build_args={"md_name": md_name, "blob_gen_cmd": blob_gen_cmd}, ) @@ -1527,6 +1546,7 @@ def _flash_attn_backward_fake( rng_state: Optional[torch.Tensor] = None, is_v3_atomic_fp32: Optional[bool] = True, how_v3_bf16_cvt: Optional[int] = 1, + sink_ptr: Optional[Tensor] = None, ) -> torch.Tensor: batch_size = q.size(0) seqlen_q = q.size(1) @@ -1563,6 +1583,7 @@ def _flash_attn_backward( rng_state: Optional[torch.Tensor] = None, is_v3_atomic_fp32: Optional[bool] = True, how_v3_bf16_cvt: Optional[int] = 1, + sink_ptr: Optional[Tensor] = None, ) -> torch.Tensor: # rtna & rtz are deprecated in gfx950 if get_gfx() == "gfx950" and how_v3_bf16_cvt != 0: @@ -1681,6 +1702,7 @@ def can_impl_fmha_v3_bwd_gfx950(): alibi_slopes, rng_state, None, + sink_ptr, # custom_build_args={"md_name": md_name, "blob_gen_cmd": blob_gen_cmd}, ) return softmax_d @@ -1707,6 +1729,7 @@ def forward( how_v3_bf16_cvt: Optional[int] = 1, cu_seqlens_q: Optional[torch.Tensor] = None, cu_seqlens_kv: Optional[torch.Tensor] = None, + sink_ptr: Optional[Tensor] = None, ): is_grad = is_grad_enabled and any(x.requires_grad for x in [q, k, v]) if softmax_scale is None: @@ -1738,6 +1761,7 @@ def forward( how_v3_bf16_cvt=how_v3_bf16_cvt, cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_kv, + sink_ptr=sink_ptr, ) if is_grad: assert return_lse @@ -1795,6 +1819,7 @@ def backward(ctx, dout, *args): rng_state, ctx.is_v3_atomic_fp32, ctx.how_v3_bf16_cvt, + 
sink_ptr=None, ) dq = dq[..., :head_size_q_og] # We could have padded the head dimension dk = dk[..., :head_size_q_og] @@ -1836,6 +1861,7 @@ def backward(ctx, dout, *args): None, # how_v3_bf16_cvt None, # cu_seqlens_q None, # cu_seqlens_kv + None, # sink_ptr ) @@ -1855,6 +1881,7 @@ def flash_attn_func( how_v3_bf16_cvt=1, cu_seqlens_q: Optional[torch.Tensor] = None, cu_seqlens_kv: Optional[torch.Tensor] = None, + sink_ptr: Optional[Tensor] = None, ): """dropout_p should be set to 0.0 during evaluation Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads @@ -1925,6 +1952,7 @@ def flash_attn_func( how_v3_bf16_cvt, cu_seqlens_q, cu_seqlens_kv, + sink_ptr, ) @@ -1957,12 +1985,18 @@ def _flash_attn_varlen_forward( block_table: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, zero_tensors: bool = False, + sink_ptr: Optional[Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: (_, nhead_q, hdim_q) = q.shape nhead_k = v.shape[-2] hdim_v = v.shape[-1] + if sink_ptr is not None: + assert sink_ptr.device == q.device, "sink_ptr must be on the same device as q" + assert sink_ptr.shape[0] == nhead_q, "sink_ptr has incorrect shape" + if sink_ptr.dtype != torch.float32: + sink_ptr = sink_ptr.to(torch.float32) # mask window_size_left = -1 if window_size_left >= max_seqlen_k else window_size_left window_size_right = -1 if window_size_right >= max_seqlen_k else window_size_right @@ -2034,7 +2068,6 @@ def _validate(name: str, t: torch.Tensor): _validate("cu_seqlens_q_padded", cu_seqlens_q_padded) if cu_seqlens_k_padded is not None: _validate("cu_seqlens_k_padded", cu_seqlens_k_padded) - out, softmax_lse, S_dmask, rng_state = mha_varlen_fwd( q, k, @@ -2064,6 +2097,7 @@ def _validate(name: str, t: torch.Tensor): gen=None, cu_seqlens_q_padded=cu_seqlens_q_padded, cu_seqlens_k_padded=cu_seqlens_k_padded, + sink_ptr=sink_ptr, ) return out, softmax_lse, S_dmask, rng_state @@ -2095,6 +2129,7 @@ def _flash_attn_varlen_backward( zero_tensors: bool = False, cu_seqlens_q_padded: Optional[torch.Tensor] = None, cu_seqlens_k_padded: Optional[torch.Tensor] = None, + sink_ptr: Optional[Tensor] = None, ) -> torch.Tensor: (_, nhead_q, hdim_q) = q.shape @@ -2256,6 +2291,7 @@ def can_impl_fmha_v3_bwd_gfx950(): None, cu_seqlens_q_padded, cu_seqlens_k_padded, + sink_ptr=sink_ptr, # custom_build_args={"md_name": md_name, "blob_gen_cmd": blob_gen_cmd}, ) return softmax_d @@ -2290,6 +2326,7 @@ def forward( cu_seqlens_k_padded=None, is_v3_atomic_fp32: Optional[bool] = True, how_v3_bf16_cvt: Optional[int] = 1, + sink_ptr=None, ): is_grad = is_grad_enabled and any(x.requires_grad for x in [q, k, v]) if softmax_scale is None: @@ -2301,7 +2338,6 @@ def forward( k = torch.nn.functional.pad(k, [0, 8 - head_size_q_og % 8]) if head_size_v_og % 8 != 0: v = torch.nn.functional.pad(v, [0, 8 - head_size_v_og % 8]) - out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward( q, k, @@ -2330,6 +2366,7 @@ def forward( how_v3_bf16_cvt=how_v3_bf16_cvt, block_table=block_table, out=out, + sink_ptr=sink_ptr, ) if is_grad: @@ -2409,6 +2446,7 @@ def backward(ctx, dout, *args): how_v3_bf16_cvt=ctx.how_v3_bf16_cvt, cu_seqlens_q_padded=ctx.cu_seqlens_q_padded, cu_seqlens_k_padded=ctx.cu_seqlens_k_padded, + sink_ptr=None, ) dq = dq[..., :head_size_q_og] # We could have padded the head dimension dk = dk[..., :head_size_q_og] @@ -2457,6 +2495,7 @@ def backward(ctx, dout, *args): None, # cu_seqlens_k_padded None, # is_v3_atomic_fp32 None, # how_v3_bf16_cvt + 
None, # sink_ptr ) @@ -2484,6 +2523,7 @@ def flash_attn_varlen_func( out=None, cu_seqlens_q_padded: Optional[torch.Tensor] = None, cu_seqlens_k_padded: Optional[torch.Tensor] = None, + sink_ptr: Optional[Tensor] = None, ): if block_table is not None and ( cu_seqlens_q_padded is not None or cu_seqlens_k_padded is not None @@ -2574,6 +2614,7 @@ def flash_attn_varlen_func( cu_seqlens_k_padded, True, how_v3_bf16_cvt, + sink_ptr, ) @@ -2600,6 +2641,7 @@ def mha_batch_prefill_fake_tensors( q_descale: Optional[torch.Tensor] = None, k_descale: Optional[torch.Tensor] = None, v_descale: Optional[torch.Tensor] = None, + sink_ptr: Optional[Tensor] = None, gen: Optional[Generator] = None, ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: # ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: @@ -2669,6 +2711,7 @@ def mha_batch_prefill( q_descale: Optional[torch.Tensor] = None, k_descale: Optional[torch.Tensor] = None, v_descale: Optional[torch.Tensor] = None, + sink_ptr: Optional[Tensor] = None, gen: Optional[Generator] = None, ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: ... @@ -2697,6 +2740,7 @@ def _mha_batch_prefill( q_descale: Optional[torch.Tensor] = None, k_descale: Optional[torch.Tensor] = None, v_descale: Optional[torch.Tensor] = None, + sink_ptr: Optional[Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: q, k, v = [maybe_contiguous(x) for x in (q, k, v)] @@ -2724,6 +2768,8 @@ def _mha_batch_prefill( q_descale, k_descale, v_descale, + sink_ptr, + None, # custom_build_args={"md_name": md_name, "blob_gen_cmd": blob_gen_cmd}, ) return out, softmax_lse, S_dmask, rng_state @@ -2751,11 +2797,17 @@ def mha_batch_prefill_func( q_descale=None, k_descale=None, v_descale=None, + sink_ptr=None, ): if softmax_scale is None: softmax_scale = q.shape[-1] ** (-0.5) head_size_q_og = q.size(2) head_size_v_og = v.size(2) + if sink_ptr is not None: + assert sink_ptr.device == q.device, "sink_ptr must be on the same device as q" + assert sink_ptr.shape[0] == q.size(1), "sink_ptr has incorrect shape" + if sink_ptr.dtype != torch.float32: + sink_ptr = sink_ptr.to(torch.float32) if head_size_q_og % 8 != 0: q = torch.nn.functional.pad(q, [0, 8 - head_size_q_og % 8]) k = torch.nn.functional.pad(k, [0, 8 - head_size_q_og % 8]) @@ -2783,6 +2835,7 @@ def mha_batch_prefill_func( q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, + sink_ptr=sink_ptr, ) out = out_padded[..., :head_size_v_og] @@ -2805,6 +2858,7 @@ def flash_attn_fp8_pertensor_func( causal=False, window_size=(-1, -1, 0), # -1 means infinite context window, 0 means no sink softmax_scale=None, + sink_ptr=None, ): if softmax_scale is None: softmax_scale = q.shape[-1] ** (-0.5) @@ -2832,6 +2886,7 @@ def flash_attn_fp8_pertensor_func( v_descale=v_descale, return_lse=False, return_softmax=False, + sink_ptr=sink_ptr, ) out = out_padded[..., :head_size_v_og] return out @@ -2853,6 +2908,7 @@ def flash_attn_varlen_fp8_pertensor_func( causal=False, window_size=(-1, -1, 0), # -1 means infinite context window softmax_scale=None, + sink_ptr=None, ): if softmax_scale is None: softmax_scale = q.shape[-1] ** (-0.5) @@ -2888,6 +2944,7 @@ def flash_attn_varlen_fp8_pertensor_func( v_descale=v_descale, return_lse=False, return_softmax=False, + sink_ptr=sink_ptr, ) out = out_padded[..., :head_size_v_og] return out diff --git a/csrc/cpp_itfs/mha_fwd.cpp b/csrc/cpp_itfs/mha_fwd.cpp index 633353a446..7d279390b1 100644 --- a/csrc/cpp_itfs/mha_fwd.cpp +++ b/csrc/cpp_itfs/mha_fwd.cpp @@ -273,6 +273,7 @@ float fmha_fwd_ck(mha_fwd_args 
a, const ck_tile::stream_config& s) a.seqlen_k_ptr, a.cu_seqlen_q_ptr, a.cu_seqlen_k_ptr, + a.sink_ptr, a.seqlen_q, a.seqlen_k, a.batch, diff --git a/csrc/include/mha_fwd.h b/csrc/include/mha_fwd.h index 607c8a55ef..1140e708f0 100644 --- a/csrc/include/mha_fwd.h +++ b/csrc/include/mha_fwd.h @@ -143,6 +143,7 @@ struct mha_fwd_args // array [batch + 1]. (Used with padding) const void* cu_seqlen_k_ptr = nullptr; // Cumulative logical (excluding padding) sequence length // array [batch + 1]. (Used with padding) + const void* sink_ptr; ck_tile::index_t seqlen_q; ck_tile::index_t seqlen_k; diff --git a/csrc/include/rocm_ops.hpp b/csrc/include/rocm_ops.hpp index 03a9adf70c..dc05657b7f 100644 --- a/csrc/include/rocm_ops.hpp +++ b/csrc/include/rocm_ops.hpp @@ -820,6 +820,7 @@ namespace py = pybind11; py::arg("q_descale") = std::nullopt, \ py::arg("k_descale") = std::nullopt, \ py::arg("v_descale") = std::nullopt, \ + py::arg("sink_ptr") = std::nullopt, \ py::arg("gen") = std::nullopt); #define LIBMHA_FWD_PYBIND \ @@ -1012,7 +1013,8 @@ namespace py = pybind11; py::arg("v_descale") = std::nullopt, \ py::arg("gen") = std::nullopt, \ py::arg("cu_seqlens_q_padded") = std::nullopt, \ - py::arg("cu_seqlens_k_padded") = std::nullopt); + py::arg("cu_seqlens_k_padded") = std::nullopt, \ + py::arg("sink_ptr") = std::nullopt); #define MHA_BATCH_PREFILL_PYBIND \ m.def("mha_batch_prefill", \ @@ -1035,11 +1037,12 @@ namespace py = pybind11; py::arg("return_softmax_lse"), \ py::arg("return_dropout_randval"), \ py::arg("out") = std::nullopt, \ - py::arg("bias") = std::nullopt, \ + py::arg("bias") = std::nullopt, \ py::arg("alibi_slopes") = std::nullopt, \ py::arg("q_descale") = std::nullopt, \ py::arg("k_descale") = std::nullopt, \ py::arg("v_descale") = std::nullopt, \ + py::arg("sink_ptr") = std::nullopt, \ py::arg("gen") = std::nullopt); #define MOE_OP_PYBIND \ diff --git a/csrc/include/torch/mha_batch_prefill.h b/csrc/include/torch/mha_batch_prefill.h index 9f035e0175..f8f8606600 100644 --- a/csrc/include/torch/mha_batch_prefill.h +++ b/csrc/include/torch/mha_batch_prefill.h @@ -29,6 +29,7 @@ mha_batch_prefill(at::Tensor& q, // [total_q, hq, d] std::optional q_descale, // [1] std::optional k_descale, // [1] std::optional v_descale, // [1] + std::optional sink_ptr_, // [hq]; std::optional gen_); } // namespace torch_itfs } // namespace aiter diff --git a/csrc/include/torch/mha_fwd.h b/csrc/include/torch/mha_fwd.h index 94b5cf056d..acc58fb3cf 100644 --- a/csrc/include/torch/mha_fwd.h +++ b/csrc/include/torch/mha_fwd.h @@ -24,6 +24,7 @@ std::vector mha_fwd(at::Tensor& q, // [b, sq, hq, d] std::optional q_descale, // [1] std::optional k_descale, // [1] std::optional v_descale, // [1] + std::optional sink_ptr, // [hq] std::optional gen); } // namespace torch_itfs } // namespace aiter diff --git a/csrc/include/torch/mha_varlen_fwd.h b/csrc/include/torch/mha_varlen_fwd.h index cbdabcfc85..c2bec0dbd7 100644 --- a/csrc/include/torch/mha_varlen_fwd.h +++ b/csrc/include/torch/mha_varlen_fwd.h @@ -33,6 +33,7 @@ mha_varlen_fwd(at::Tensor& q, // [total_q, hq, d std::optional v_descale, // [1] std::optional gen, std::optional cu_seqlens_q_padded = std::nullopt, - std::optional cu_seqlens_k_padded = std::nullopt); + std::optional cu_seqlens_k_padded = std::nullopt, + std::optional sink_ptr = std::nullopt); } // namespace torch_itfs } // namespace aiter diff --git a/csrc/py_itfs_ck/mha_batch_prefill_kernels.cu b/csrc/py_itfs_ck/mha_batch_prefill_kernels.cu index 68dc4d45f2..7a13debceb 100644 --- 
a/csrc/py_itfs_ck/mha_batch_prefill_kernels.cu +++ b/csrc/py_itfs_ck/mha_batch_prefill_kernels.cu @@ -27,6 +27,7 @@ get_ck_fmha_batch_prefill_args(bool has_lse, const at::Tensor seqlens_q, const at::Tensor kv_indptr, const at::Tensor kv_page_indices, + std::optional sink_ptr_, std::optional& bias_, std::optional& alibi_slopes_, std::optional& q_descale, @@ -109,6 +110,7 @@ get_ck_fmha_batch_prefill_args(bool has_lse, args.q_descale_ptr = q_descale.has_value() ? q_descale.value().data_ptr() : nullptr; args.k_descale_ptr = k_descale.has_value() ? k_descale.value().data_ptr() : nullptr; args.v_descale_ptr = v_descale.has_value() ? v_descale.value().data_ptr() : nullptr; + args.sink_ptr = sink_ptr_.has_value() ? sink_ptr_.value().data_ptr() : nullptr; args.rand_val_ptr = has_dropout_randval ? dropout_randval.data_ptr() : nullptr; args.lse_ptr = has_lse ? softmax_lse.data_ptr() : nullptr; args.o_ptr = out.data_ptr(); @@ -184,6 +186,7 @@ mha_batch_prefill(at::Tensor& q, // [total_q, hq, d] std::optional q_descale, // [1] std::optional k_descale, // [1] std::optional v_descale, // [1] + std::optional sink_ptr, // [hq] std::optional gen_) { auto q_dtype = q.scalar_type(); @@ -242,7 +245,6 @@ mha_batch_prefill(at::Tensor& q, // [total_q, hq, d] const int num_heads_k = k.size(1); const int num_blocks = k.size(0); - if(max_seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; @@ -399,6 +401,7 @@ mha_batch_prefill(at::Tensor& q, // [total_q, hq, d] cu_seqlens_q, kv_indptr, kv_page_indices, + sink_ptr, bias_, alibi_slopes_, q_descale, diff --git a/csrc/py_itfs_ck/mha_fwd_kernels.cu b/csrc/py_itfs_ck/mha_fwd_kernels.cu index 6cee830003..8c8680c963 100644 --- a/csrc/py_itfs_ck/mha_fwd_kernels.cu +++ b/csrc/py_itfs_ck/mha_fwd_kernels.cu @@ -40,7 +40,8 @@ mha_fwd_args get_ck_fmha_fwd_args(bool has_lse, const std::optional &cu_seqlens_kv_, const std::string& data_type, bias_enum bias_type, - quant_scale_enum qscale_type) + quant_scale_enum qscale_type, + const std::optional &sink_ptr_) { // q: (batch_size, seqlen_q, nheads, d) // k: (batch_size, seqlen_k, nheads_k, d) @@ -99,7 +100,7 @@ mha_fwd_args get_ck_fmha_fwd_args(bool has_lse, const ck_tile::index_t *cu_seqlen_q_ptr = cu_seqlens_q_.has_value() ? reinterpret_cast(cu_seqlens_q_.value().data_ptr()) : nullptr; const ck_tile::index_t *cu_seqlen_kv_ptr = cu_seqlens_kv_.has_value() ? reinterpret_cast(cu_seqlens_kv_.value().data_ptr()) : nullptr; - + const void *sink_ptr = sink_ptr_.has_value() ? 
sink_ptr_.value().data_ptr() : nullptr; return mha_fwd_args{false, // use_asm_v3 false, // v3_api_check 1, // how_v3_bf16_cvt @@ -108,7 +109,7 @@ mha_fwd_args get_ck_fmha_fwd_args(bool has_lse, static_cast(bias_type), has_lse, static_cast(qscale_type), - mask.sink > 0, // hsa_sink + mask.sink > 0, // has_sink q.data_ptr(), k.data_ptr(), v.data_ptr(), @@ -125,6 +126,7 @@ mha_fwd_args get_ck_fmha_fwd_args(bool has_lse, nullptr, // seqlen_k_ptr cu_seqlen_q_ptr, // cu_seqlen_q_ptr cu_seqlen_kv_ptr, // cu_seqlen_k_ptr + sink_ptr, // sink_ptr seqlen_q, seqlen_k, b, @@ -185,6 +187,7 @@ mha_fwd(at::Tensor &q, // [b, sq, hq, d] std::optional q_descale_, // [1] std::optional k_descale_, // [1] std::optional v_descale_, // [1] + std::optional sink_ptr, // [hq] std::optional gen_) { auto q_dtype = q.scalar_type(); @@ -369,7 +372,8 @@ mha_fwd(at::Tensor &q, // [b, sq, hq, d] cu_seqlens_kv_, dtype_str, bias_type, - qscale_type); + qscale_type, + sink_ptr); float t = aiter::mha_fwd(args, stream_config); TORCH_CHECK(t >= 0, "invalid argument for fmha_fwd"); diff --git a/csrc/py_itfs_ck/mha_varlen_fwd_kernels.cu b/csrc/py_itfs_ck/mha_varlen_fwd_kernels.cu index f20f9c1009..39ebd30774 100644 --- a/csrc/py_itfs_ck/mha_varlen_fwd_kernels.cu +++ b/csrc/py_itfs_ck/mha_varlen_fwd_kernels.cu @@ -45,7 +45,8 @@ mha_fwd_args get_ck_fmha_varlen_fwd_args(bool has_lse, std::optional &cu_seqlens_k_padded_, const std::string& data_type, bias_enum bias_type, - quant_scale_enum qscale_type) + quant_scale_enum qscale_type, + std::optional &sink_ptr_) { // q: (total_q, nheads, d) // k: (total_k, nheads_k, d) @@ -106,6 +107,7 @@ mha_fwd_args get_ck_fmha_varlen_fwd_args(bool has_lse, const void *q_descale_ptr = q_descale_.has_value() ? q_descale_.value().data_ptr() : nullptr; const void *k_descale_ptr = k_descale_.has_value() ? k_descale_.value().data_ptr() : nullptr; const void *v_descale_ptr = v_descale_.has_value() ? v_descale_.value().data_ptr() : nullptr; + const void *sink_ptr = sink_ptr_.has_value() ? sink_ptr_.value().data_ptr() : nullptr; const void* seqstart_k_ptr = nullptr; const void* seqstart_q_ptr = nullptr; @@ -151,6 +153,7 @@ mha_fwd_args get_ck_fmha_varlen_fwd_args(bool has_lse, seqlens_k.has_value() ? seqlens_k.value().data_ptr() : nullptr, // seqlen_k_ptr (per-sequence logical K lengths) cu_seqlen_q_ptr, // cu_seqlen_q_ptr cu_seqlen_k_ptr, // cu_seqlen_k_ptr + sink_ptr, total_q, total_k, b, @@ -216,7 +219,8 @@ fmha_fwd_splitkv_args get_ck_fmha_varlen_fwd_splitkv_args(bool has_lse, at::Tensor out, at::Tensor lse, at::Tensor lse_acc, - at::Tensor out_acc) + at::Tensor out_acc, + std::optional &sink_ptr_) { // q: (total_q, nheads, d) // k: (num_blocks, page_block_size, num_heads_k, d) @@ -270,7 +274,7 @@ fmha_fwd_splitkv_args get_ck_fmha_varlen_fwd_splitkv_args(bool has_lse, else { args.seqlen_k_ptr = nullptr; } - + args.sink_ptr = (sink_ptr_.has_value()) ? 
sink_ptr_.value().data_ptr() : nullptr; args.batch = b; args.max_seqlen_q = max_seqlen_q; args.hdim_q = d; @@ -369,7 +373,8 @@ mha_varlen_fwd( std::optional v_descale_, // [1] std::optional gen_, std::optional cu_seqlens_q_padded_, // [b+1] physical starts with PAD - std::optional cu_seqlens_k_padded_) // [b+1] + std::optional cu_seqlens_k_padded_, // [b+1] + std::optional sink_ptr) { auto q_dtype = q.scalar_type(); bool is_qkv_fp8 = q_dtype == at::ScalarType::Float8_e4m3fn || q_dtype == at::ScalarType::Float8_e4m3fnuz; @@ -600,8 +605,8 @@ mha_varlen_fwd( out, softmax_lse, softmax_lse_accum, - out_accum); - + out_accum, + sink_ptr); float t = aiter::mha_fwd_splitkv(args, stream_config, dtype_str, @@ -650,7 +655,8 @@ mha_varlen_fwd( const_cast&>(cu_seqlens_k_padded_), dtype_str, bias_type, - qscale_type); + qscale_type, + sink_ptr); float t = aiter::mha_fwd(args, stream_config); // how_v3_bf16_cvt TORCH_CHECK(t >= 0, "invalid argument for fmha_fwd"); } diff --git a/csrc/py_itfs_cu/asm_mha_fwd.cu b/csrc/py_itfs_cu/asm_mha_fwd.cu index c184cee1a7..efa9ecf734 100644 --- a/csrc/py_itfs_cu/asm_mha_fwd.cu +++ b/csrc/py_itfs_cu/asm_mha_fwd.cu @@ -113,6 +113,7 @@ mha_fwd_args get_asm_fmha_fwd_args(bool has_lse, nullptr, // seqlen_k_ptr nullptr, // cu_seqlen_q_ptr nullptr, // cu_seqlen_k_ptr + nullptr, //sink_ptr seqlen_q, seqlen_k, b, diff --git a/csrc/py_itfs_cu/asm_mha_varlen_fwd.cu b/csrc/py_itfs_cu/asm_mha_varlen_fwd.cu index 8abf542e52..6f2be697bd 100644 --- a/csrc/py_itfs_cu/asm_mha_varlen_fwd.cu +++ b/csrc/py_itfs_cu/asm_mha_varlen_fwd.cu @@ -142,6 +142,7 @@ mha_fwd_args get_asm_mha_varlen_fwd_args(bool has_lse, seqlens_k.has_value() ? seqlens_k.value().data_ptr() : nullptr, // seqlen_k_ptr cu_seqlen_q_ptr, // cu_seqlen_q_ptr cu_seqlen_k_ptr, // cu_seqlen_k_ptr + nullptr, // sink_ptr total_q, total_k, b,
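End to end, the plumbing in this patch lets callers hand a per-head sink tensor to the public wrappers; the asserts added in _flash_attn_forward require it to live on q's device with shape [nhead_q] and cast it to float32 when needed, while the backward paths patched here still pass sink_ptr=None, so the feature is forward/inference-oriented. A minimal, hypothetical call site follows; the import path is taken from the patched module, and keyword names other than sink_ptr follow the usual flash-attention signature and are assumed rather than shown in these hunks.

import torch
from aiter.ops.mha import flash_attn_func

batch, seqlen, nhead, hdim = 2, 1024, 8, 128
q = torch.randn(batch, seqlen, nhead, hdim, dtype=torch.bfloat16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

# One sink logit per query head, on q's device; the wrapper casts non-fp32
# inputs to float32 before forwarding them to the kernel as sink_ptr.
sinks = torch.zeros(nhead, dtype=torch.float32, device=q.device)

out = flash_attn_func(q, k, v, causal=True, sink_ptr=sinks)  # [batch, seqlen, nhead, hdim]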