From 01bb614054fec49c8fed11d66d128a378f4fc7c6 Mon Sep 17 00:00:00 2001 From: Kingsley Kim Date: Fri, 10 Apr 2026 14:14:18 -0400 Subject: [PATCH 1/3] initialized gdn structure --- csrc/api/gdn_sm90.cu | 136 ++++++++++++++++ csrc/api/pybind.cu | 16 ++ cula/gdn/__init__.py | 20 +++ cula/gdn/gate.py | 168 ++++++++++++++++++++ cula/gdn/hopper_fused_fwd.py | 251 +++++++++++++++++++++++++++++ cula/utils.py | 30 ++++ tests/test_gdn_fused_fwd.py | 296 +++++++++++++++++++++++++++++++++++ 7 files changed, 917 insertions(+) create mode 100644 csrc/api/gdn_sm90.cu create mode 100644 cula/gdn/__init__.py create mode 100644 cula/gdn/gate.py create mode 100644 cula/gdn/hopper_fused_fwd.py create mode 100644 tests/test_gdn_fused_fwd.py diff --git a/csrc/api/gdn_sm90.cu b/csrc/api/gdn_sm90.cu new file mode 100644 index 0000000..1edc1ea --- /dev/null +++ b/csrc/api/gdn_sm90.cu @@ -0,0 +1,136 @@ +// Copyright 2025-2026 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include + +#include "gdn/sm90/prefill_kernel.hpp" + +using OptionalTensor = std::optional; + +std::tuple +gdn_fwd_prefill( + OptionalTensor output_, + OptionalTensor output_state_, + torch::Tensor const& q, + torch::Tensor const& k, + torch::Tensor const& v, + OptionalTensor input_state_, + OptionalTensor alpha_, + OptionalTensor beta_, + torch::Tensor const& cu_seqlens, + torch::Tensor workspace_buffer, + float scale, + bool safe_gate) { + // Q, K, V: [packed_seq, H, D] (already packed by Python layer) + auto packed_seq = q.size(0); + auto num_heads = q.size(1); + auto head_size = q.size(2); + auto num_seqs = cu_seqlens.size(0) - 1; + + // GDN constraint: all head counts must be the same + TORCH_CHECK(num_heads == k.size(1), "GDN requires num_q_heads == num_k_heads, got ", num_heads, " vs ", k.size(1)); + TORCH_CHECK(num_heads == v.size(1), "GDN requires num_q_heads == num_v_heads, got ", num_heads, " vs ", v.size(1)); + TORCH_CHECK(head_size == v.size(2), "GDN requires Q and V head dim to match, got ", head_size, " vs ", v.size(2)); + + // Allocate output if not provided + torch::Tensor output = output_.has_value() ? output_.value() + : torch::empty( + {packed_seq, num_heads, head_size}, + torch::TensorOptions().dtype(q.dtype()).device(q.device())); + + // Allocate output state if not provided + torch::Tensor output_state = output_state_.has_value() + ? output_state_.value() + : torch::zeros( + {num_seqs, num_heads, head_size, head_size}, + torch::TensorOptions().dtype(torch::kFloat32).device(q.device())); + + // Validate dtypes + TORCH_CHECK(q.dtype() == torch::kBFloat16, "q must be bfloat16"); + TORCH_CHECK(k.dtype() == torch::kBFloat16, "k must be bfloat16"); + TORCH_CHECK(v.dtype() == torch::kBFloat16, "v must be bfloat16"); + TORCH_CHECK(cu_seqlens.dtype() == torch::kInt32, "cu_seqlens must be int32"); + + // Validate contiguity + TORCH_CHECK(q.is_contiguous(), "q must be contiguous"); + TORCH_CHECK(k.is_contiguous(), "k must be contiguous"); + TORCH_CHECK(v.is_contiguous(), "v must be contiguous"); + TORCH_CHECK(output.is_contiguous(), "output must be contiguous"); + TORCH_CHECK(output_state.is_contiguous(), "output_state must be contiguous"); + TORCH_CHECK(cu_seqlens.is_contiguous(), "cu_seqlens must be contiguous"); + TORCH_CHECK(workspace_buffer.is_contiguous(), "workspace_buffer must be contiguous"); + + // Extract optional pointers + float const* alpha_ptr = nullptr; + float const* beta_ptr = nullptr; + float const* input_state_ptr = nullptr; + + if (alpha_.has_value()) { + auto& alpha = alpha_.value(); + TORCH_CHECK(alpha.dtype() == torch::kFloat32, "alpha must be float32"); + TORCH_CHECK(alpha.is_contiguous(), "alpha must be contiguous"); + TORCH_CHECK( + alpha.size(0) == packed_seq && alpha.size(1) == num_heads, "alpha shape must be [packed_seq, num_heads]"); + alpha_ptr = alpha.data_ptr(); + } + if (beta_.has_value()) { + auto& beta = beta_.value(); + TORCH_CHECK(beta.dtype() == torch::kFloat32, "beta must be float32"); + TORCH_CHECK(beta.is_contiguous(), "beta must be contiguous"); + TORCH_CHECK( + beta.size(0) == packed_seq && beta.size(1) == num_heads, "beta shape must be [packed_seq, num_heads]"); + beta_ptr = beta.data_ptr(); + } + if (input_state_.has_value()) { + auto& input_state = input_state_.value(); + TORCH_CHECK(input_state.dtype() == torch::kFloat32, "input_state must be float32"); + TORCH_CHECK(input_state.is_contiguous(), "input_state must be contiguous"); + input_state_ptr = input_state.data_ptr(); + } + + // Auto-compute scale if 0 + if (scale == 0.0f) { + scale = 1.0f / std::sqrt(static_cast(head_size)); + } + + auto stream = at::cuda::getCurrentCUDAStream(); + auto sm_count = at::cuda::getCurrentDeviceProperties()->multiProcessorCount; + + using bf16 = cute::bfloat16_t; + using Sm90 = cutlass::arch::Sm90; + gdn::sm90::launch_gdn_fwd_prefill_kernel( + stream, + reinterpret_cast(output.data_ptr()), + output_state.data_ptr(), + reinterpret_cast(q.data_ptr()), + reinterpret_cast(k.data_ptr()), + reinterpret_cast(v.data_ptr()), + input_state_ptr, + alpha_ptr, + beta_ptr, + cu_seqlens.data_ptr(), + workspace_buffer.data_ptr(), + static_cast(num_seqs), + static_cast(num_heads), + static_cast(head_size), + static_cast(packed_seq), + scale, + safe_gate, + static_cast(sm_count)); + + return {output, output_state}; +} diff --git a/csrc/api/pybind.cu b/csrc/api/pybind.cu index ba2deb6..e19311f 100644 --- a/csrc/api/pybind.cu +++ b/csrc/api/pybind.cu @@ -65,6 +65,21 @@ kda_fwd_prefill( torch::Tensor workspace_buffer, float scale, bool safe_gate); + +std::tuple +gdn_fwd_prefill( + std::optional output_, + std::optional output_state_, + torch::Tensor const& q, + torch::Tensor const& k, + torch::Tensor const& v, + std::optional input_state_, + std::optional alpha_, + std::optional beta_, + torch::Tensor const& cu_seqlens, + torch::Tensor workspace_buffer, + float scale, + bool safe_gate); #endif PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { @@ -75,5 +90,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { #endif #if defined(CULA_SM90A_ENABLED) m.def("kda_fwd_prefill", &kda_fwd_prefill); + m.def("gdn_fwd_prefill", &gdn_fwd_prefill); #endif } diff --git a/cula/gdn/__init__.py b/cula/gdn/__init__.py new file mode 100644 index 0000000..1b69a56 --- /dev/null +++ b/cula/gdn/__init__.py @@ -0,0 +1,20 @@ +# Copyright 2025-2026 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from cula.gdn.hopper_fused_fwd import cula_gdn_prefill as gdn_prefill_hopper + +__all__ = [ + "gdn_prefill_hopper" +] + diff --git a/cula/gdn/gate.py b/cula/gdn/gate.py new file mode 100644 index 0000000..7ee1b41 --- /dev/null +++ b/cula/gdn/gate.py @@ -0,0 +1,168 @@ +import torch +import torch.nn.functional as F +import triton +import triton.language as tl + +from fla.ops.utils.op import exp +from fla.ops.utils.softplus import softplus +from fla.ops.utils.index import prepare_chunk_indices +from fla.utils import autotune_cache_kwargs, input_guard + +BT_LIST_AUTOTUNE = [32, 64, 128] +NUM_WARPS_AUTOTUNE = [4, 8, 16, 32] + +def naive_gdn_gate( + g: torch.Tensor, + A_log: torch.Tensor, + dt_bias: torch.Tensor | None = None, + output_dtype: torch.dtype = torch.float32, +) -> torch.Tensor: + """ + Torch reference implementation for KDA gate computation. + + Computes: g = -A_log.exp().unsqueeze(-1) * softplus(g + dt_bias.view(g.shape[-1])) + + Args: + g (torch.Tensor): + Input tensor of shape `[..., H]`. + A_log (torch.Tensor): + Parameter tensor with `H` elements. + dt_bias (torch.Tensor | None): + Optional bias tensor added to `g` before activation, shape `[H]`. + + Returns: + Output tensor of shape `[..., H]` . + """ + H = g.shape[-1] + g = g.float() + if dt_bias is not None: + g = g + dt_bias + g = (-A_log.float().exp() * F.softplus(g.float())).to(output_dtype) + return g + +# naive gdn lowerbound method based off of fla.ops.kda.gate +def naive_gdn_lowerbound_gate( + g: torch.Tensor, + A_log: torch.Tensor, + dt_bias: torch.Tensor | None = None, + lower_bound: float = -5.0, + output_dtype: torch.dtype = torch.float32 +) -> torch.Tensor: + num_heads = g.shape[-1] + g = g.float() + if dt_bias is not None: + g = g + dt_bias + g = lower_bound * F.sigmoid(A_log.exp() * g) + return g.to(output_dtype) + +@triton.heuristics({ + "HAS_BIAS": lambda args: args["dt_bias"] is not None, + 'HAS_SCALE': lambda args: args['scale'] is not None, + 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None, + 'USE_LOWER_BOUND': lambda args: args['lower_bound'] is not None, +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps) + for num_warps in [2, 4, 8] + ], + key=['H', 'BT', 'IS_VARLEN', 'REVERSE'], + **autotune_cache_kwargs, +) +@triton.jit(do_not_specialize=['T']) +def gdn_gate_chunk_cumsum_scalar_kernel( + s, + A_log, + dt_bias, + o, + scale, + cu_seqlens, + chunk_indices, + lower_bound, + T, + H: tl.constexpr, + BT: tl.constexpr, + REVERSE: tl.constexpr, + HAS_BIAS: tl.constexpr, + HAS_SCALE: tl.constexpr, + IS_VARLEN: tl.constexpr, + USE_LOWER_BOUND: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + + p_s = tl.make_block_ptr(s + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) + p_o = tl.make_block_ptr(o + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) + # [BT] + b_s = tl.load(p_s, boundary_check=(0)).to(tl.float32) + + # Apply dt_bias if exists + if HAS_BIAS: + b_bias = tl.load(dt_bias + i_h).to(tl.float32) + b_s = b_s + b_bias + + b_A = tl.load(A_log + i_h).to(tl.float32) + if not USE_LOWER_BOUND: + # Apply gate: -exp(A_log) * softplus(g + bias) + b_gate = -exp(b_A) * softplus(b_s) + else: + b_gate = lower_bound * tl.sigmoid(exp(b_A) * b_s) + + # Apply chunk local cumsum + if REVERSE: + b_o = tl.cumsum(b_gate, axis=0, reverse=True) + else: + b_o = tl.cumsum(b_gate, axis=0) + + if HAS_SCALE: + b_o *= scale + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0,)) + + +@input_guard +def gdn_gate_chunk_cumsum_lowerbound( + g: torch.Tensor, + A_log: torch.Tensor, + chunk_size: int, + scale: float = None, + dt_bias: torch.Tensor | None = None, + cu_seqlens: torch.Tensor | None = None, + output_dtype: torch.dtype | None = torch.float, + chunk_indices: torch.LongTensor | None = None, + lower_bound: float | None = None, + **kwargs, +) -> torch.Tensor: + if cu_seqlens is not None: + assert g.shape[0] == 1, "Only batch_size 1 is supported when cu_seqlens is provided" + assert len(g.shape) == 3 + B, T, H = g.shape + BT = chunk_size + if chunk_indices is None and cu_seqlens is not None: + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + assert chunk_size == 2 ** (chunk_size.bit_length() - 1), "chunk_size must be a power of 2" + + g_org, g = g, torch.empty_like(g, dtype = output_dtype or g.dtype) + def grid(meta): return (NT, B * H) + gdn_gate_chunk_cumsum_scalar_kernel[grid]( + s=g_org, + A_log=A_log, + dt_bias=dt_bias, + o=g, + scale=scale, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + lower_bound=lower_bound, + T=T, + H=H, + BT=BT, + REVERSE=False, + ) + return g \ No newline at end of file diff --git a/cula/gdn/hopper_fused_fwd.py b/cula/gdn/hopper_fused_fwd.py new file mode 100644 index 0000000..087273e --- /dev/null +++ b/cula/gdn/hopper_fused_fwd.py @@ -0,0 +1,251 @@ +# Copyright 2025-2026 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from einops import rearrange + +from fla.modules.l2norm import l2norm_fwd +from fla.ops.utils import chunk_local_cumsum +from fla.ops.utils.constant import RCP_LN2 +from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard + +import cula.cudac as cula_cuda +from cula.utils import _get_cache_buf, assert_hopper, get_device_sm_count, prepare_uniform_cu_seqlens +from cula.gdn.gate import gdn_gate_chunk_cumsum_lowerbound + +class HopperChunkGDNFunction(torch.autograd.Function): + + @staticmethod + @input_guard + @autocast_custom_fwd + def forward( + ctx, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + A_log: torch.Tensor, + dt_bias: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + output_final_state : bool = False, + use_qk_l2norm_in_kernel : bool = False, + use_gate_in_kernel : bool = False, + safe_gate : bool = False, + lower_bound : float | None = None, + cu_seqlens : torch.IntTensor | None = None, + chunk_indices : torch.IntTensor | None = None, + ): + chunk_size = 64 + assert q.shape[-2] == v.shape[-2] == k.shape[-2], "Number of heads must be the same across q, k, v" + + batch_size, seq_len, num_heads, head_dim = q.shape + + if cu_seqlens is None: + cu_seqlens = prepare_uniform_cu_seqlens(batch_size, seq_len, q.device, torch.int32) + + # after setting up cu_seqlens, set batch size to 1 + if batch_size != 1: + q, k, v, g, beta = map(lambda x : rearrange(x, "b t ... -> 1 (b t) ..."), (q, k, v, g, beta)) + + # compute gate inside kernel + if use_gate_in_kernel: + if safe_gate: + assert lower_bound is not None, "lower_bound must be set when using safe_gate" + g = gdn_gate_chunk_cumsum_lowerbound( + g=g, + A_log=A_log, + dt_bias=dt_bias, + chunk_size=chunk_size, + scale=RCP_LN2, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + lower_bound=lower_bound + ) + torch.cuda.synchronize() # DEBUG: surface errors from gdn_gate_chunk_cumsum_lowerbound + + else: + print("launching FLA chunk local cumsum") + g = chunk_local_cumsum( + g=g, + chunk_size=chunk_size, + scale=RCP_LN2, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices + ) + torch.cuda.synchronize() # DEBUG: surface errors from chunk_local_cumsum + + q_rstd, k_rstd = None, None + if use_qk_l2norm_in_kernel: + q, q_rstd = l2norm_fwd(q) + k, k_rstd = l2norm_fwd(k) + torch.cuda.synchronize() # DEBUG: surface errors from l2norm + packed_seq = batch_size * seq_len + q = q.reshape(packed_seq, num_heads, head_dim).contiguous() + k = k.reshape(packed_seq, num_heads, head_dim).contiguous() + v = v.reshape(packed_seq, num_heads, head_dim).contiguous() + g = g.reshape(packed_seq, num_heads).contiguous() + beta = beta.reshape(packed_seq, num_heads).contiguous() + + # set up tensormap workspace buffer for CollectiveStoreTMA O + sm_count = get_device_sm_count(q.device) + workspace_count = sm_count * 128 + workspace_buffer = _get_cache_buf("hopper_gdn_fwd_workspace", workspace_count, q.device) + + # call the C++ kernel + # Signature:gdn_fwd_prefill(output_, output_state_, q, k, v, input_state_, alpha_, beta_, cu_seqlens, workspace, scale, safe_gate) + print("launching prefill kernel cpp") + o, final_state = cula_cuda.gdn_fwd_prefill( + None, # output_ (auto-allocate) + None, # output_state_ (auto-allocate) + q, + k, + v, + initial_state, # input_state_ + g, # alpha_ + beta, # beta_ + cu_seqlens, + workspace_buffer, + scale, + safe_gate, + ) + torch.cuda.synchronize() # DEBUG: surface errors from gdn_fwd_prefill kernel + print(f"DEBUG: o has nan={o.isnan().any().item()}, shape={o.shape}, dtype={o.dtype}") + print(f"DEBUG: final_state has nan={final_state.isnan().any().item()}, shape={final_state.shape}") + print(f"DEBUG: o[:3]={o.flatten()[:8].tolist()}") + + o = rearrange(o, "(b t) h d -> b t h d", b = batch_size) + return o.to(q.dtype), final_state + + @staticmethod + @input_guard + @autocast_custom_bwd + def backward( + ctx, do, dht + ): + raise NotImplementedError("Backward pass not implemented yet") + + +@torch.compiler.disable +def cula_gdn_prefill( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float = None, + initial_state: torch.Tensor = None, + output_final_state: bool = False, + use_qk_l2norm_in_kernel: bool = False, + use_gate_in_kernel: bool = False, + safe_gate: bool = False, + lower_bound: float | None = None, + cu_seqlens: torch.IntTensor | None = None, + chunk_indices: torch.IntTensor | None = None, + **kwargs, +): + r""" + Hopper (SM90) fully-fused GDN forward prefill using CUTLASS TMA warp-specialized kernel. + + Args: + q (torch.Tensor): + queries of shape `[B, T, H, K]`. + k (torch.Tensor): + keys of shape `[B, T, H, K]`. + v (torch.Tensor): + values of shape `[B, T, H, V]`. + g (torch.Tensor): + (forget) gating tensor (in log space!) of shape `[B, T, H]`. + beta (torch.Tensor): + betas of shape `[B, T, H]`. + scale (Optional[float]): + Scale factor for the KDA attention scores. + If not provided, it will default to `1 / sqrt(K)`. Default: `None`. + initial_state (Optional[torch.Tensor]): + Initial state of shape `[N, H, K, V]` for `N` input sequences. + Default: `None`. + output_final_state (Optional[bool]): + Whether to output the final state of shape `[N, H, K, V]`. Default: `False`. + use_qk_l2norm_in_kernel (bool): + Whether to apply L2norm to the q,k tensor internally. Default: `False`. + use_gate_in_kernel (bool): + Whether to compute the log-space KDA decay internally. Default: `False`. + safe_gate (bool): + Whether the kernel can assume the input gate values `g` are in a safe range. + When `True`, the kernel can use M=16 TensorCore acceleration. + The safe range is approximately [-5, 0). Default: `False`. + lower_bound (Optional[float]): + Lower bound for the forget gate activation function. Default: `None`. + cu_seqlens (torch.IntTensor): + Cumulative sequence lengths of shape `[N+1]`, int32. + chunk_indices (torch.IntTensor): + Chunk indices for variable-length training. + + Returns: + o (torch.Tensor): + Outputs of shape `[B, T, H, V]`. + final_state (torch.Tensor): + Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`. + """ + assert_hopper() + assert safe_gate, "Only support safe_gate=True." + if cu_seqlens is not None: + if q.shape[0] != 1: + raise ValueError( + f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`." + f"Please flatten variable-length inputs before processing.", + ) + if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1: + raise ValueError( + f"The number of initial states is expected to be equal to the number of input sequences, " + f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}.", + ) + if initial_state is not None: + assert initial_state.dtype == torch.float32, "initial_state must be in float32." + + A_log, dt_bias = None, None + if use_gate_in_kernel: + assert "A_log" in kwargs, "A_log must be provided when use_gate_in_kernel=True." + A_log, dt_bias = kwargs["A_log"], kwargs.get("dt_bias") + if safe_gate: + if lower_bound is None: + raise ValueError("`lower_bound` must be specified when `safe_gate=True` and `use_gate_in_kernel=True`.") + if not (-5 <= lower_bound < 0): + raise ValueError(f"`lower_bound` must be in the safe range [-5, 0), got {lower_bound}.") + + assert q.shape == k.shape, "q, k, g must have the same shape." + assert beta.shape == g.shape == q.shape[:3], "beta and gate must be of shape (batch size, seq len, num of head)." + assert v.shape == (*q.shape[:3], v.shape[-1]), "v must be of shape (batch size, seq len, num of head, head dim)." + if scale is None: + scale = k.shape[-1] ** -0.5 + o, final_state = HopperChunkGDNFunction.apply( + q, + k, + v, + g, + beta, + A_log, + dt_bias, + scale, + initial_state, + output_final_state, + use_qk_l2norm_in_kernel, + use_gate_in_kernel, + safe_gate, + lower_bound, + cu_seqlens, + chunk_indices, + ) + return o, final_state \ No newline at end of file diff --git a/cula/utils.py b/cula/utils.py index bd70730..1fc6c4f 100644 --- a/cula/utils.py +++ b/cula/utils.py @@ -110,6 +110,36 @@ def get_kda_fused_fwd(device: torch.device | str | int | None = None) -> Callabl ) +def get_gdn_fused_fwd(device: torch.device | str | int | None = None) -> Callable: + """Return the appropriate ``gdn_prefill`` implementation for *device*. + + - sm100/sm103 (Blackwell) → cula.gdn.gdn_prefill_blackwell (not yet available) + - sm90 (Hopper) → cula.gdn.gdn_prefill_hopper + + Args: + device: CUDA device to query. Defaults to the currently active device. + + Raises: + RuntimeError: If the device architecture is not supported. + """ + major, minor = get_device_sm_version(device) + if major == 10 and minor in (0, 3): + # TODO + raise NotImplementedError( + "The Blackwell implementation of fused prefill is not yet available. " + "Please use a sm90a (Hopper) device or wait for future updates." + ) + elif major == 9 and minor == 0: + from cula.gdn import gdn_prefill_hopper + + return gdn_prefill_hopper + else: + raise RuntimeError( + f"Unsupported CUDA compute capability sm_{major}{minor}. " + f"Only sm90a (Hopper) and Blackwell (SM100/SM103) are supported." + ) + + @cute.jit def print_tensor_2d(tensor: cute.Tensor): """ diff --git a/tests/test_gdn_fused_fwd.py b/tests/test_gdn_fused_fwd.py new file mode 100644 index 0000000..0674304 --- /dev/null +++ b/tests/test_gdn_fused_fwd.py @@ -0,0 +1,296 @@ +# Copyright 2025-2026 Ant Group Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +# Adapted from flash-linear-attention: https://github.com/fla-org/flash-linear-attention/blob/main/tests/ops/test_kda.py + + +import pytest +import torch +import torch.nn.functional as F +from fla.ops.gated_delta_rule import naive_recurrent_gated_delta_rule, chunk_gated_delta_rule +from fla.utils import assert_close, device + +from cula.utils import get_gdn_fused_fwd +from cula.gdn.gate import naive_gdn_gate + +pytestmark = pytest.mark.sm90_only + + +@pytest.mark.parametrize( + ( + "B", + "T", + "H", + "D", + "gate_logit_normalizer", + "mask_p", + "use_qk_l2norm_in_kernel", + "use_gate_in_kernel", + "safe_gate", + "dtype", + ), + [ + pytest.param( + *test, + id="B{}-T{}-H{}-D{}-gln{}-mask_p{}-l2norm{}-gate{}-safe_gate{}-{}".format(*test), + ) + for test in [ + (1, 63, 1, 128, 1, 0, False, False, True, torch.bfloat16), + (2, 500, 3, 128, 1, 0, False, False, True, torch.bfloat16), + (2, 1000, 3, 128, 1, 0.5, False, False, True, torch.bfloat16), + (3, 1024, 4, 128, 0.1, 0, False, False, True, torch.bfloat16), + (4, 1024, 4, 128, 1, 0, False, False, True, torch.bfloat16), + (4, 1024, 4, 128, 1, 0, True, False, True, torch.bfloat16), + (2, 1500, 4, 128, 10, 0, False, True, True, torch.bfloat16), + (4, 2048, 8, 128, 1, 0, False, True, True, torch.bfloat16), + ] + ], +) +def test_safe_gate_chunk( + B: int, + T: int, + H: int, + D: int, + gate_logit_normalizer: float, + mask_p: float, + use_qk_l2norm_in_kernel: bool, + use_gate_in_kernel: bool, + safe_gate: bool, + dtype: torch.dtype, +): + from cula.gdn.gate import naive_gdn_lowerbound_gate + + cula_gdn_fused_fwd = get_gdn_fused_fwd(device) + + torch.manual_seed(42) + q = torch.rand(B, T, H, D, dtype=dtype) + k = torch.rand(B, T, H, D, dtype=dtype) + v = torch.rand(B, T, H, D, dtype=dtype) + g = torch.randn(B, T, H, dtype=torch.float if not use_gate_in_kernel else dtype) + if use_gate_in_kernel: + A_log = torch.randn(H, dtype=torch.float) + dt_bias = torch.randn(H, dtype=torch.float) + else: + g = F.logsigmoid(g) / gate_logit_normalizer + g = g * (torch.rand_like(g) > mask_p) + if safe_gate: + lower_bound = -5.0 + if not use_gate_in_kernel: + g = g.clamp(-5, 0) + naive_gdn_gate_fn = naive_gdn_lowerbound_gate + else: + lower_bound = None + naive_gdn_gate_fn = naive_gdn_gate + + beta = torch.randn(B, T, H, dtype=torch.float32).sigmoid() + h0 = torch.randn(B, H, D, D, dtype=torch.float32) + if use_gate_in_kernel: + A_log, dt_bias = map(lambda x: x.to(device).requires_grad_(True), (A_log, dt_bias)) + q, k, v, g, beta, h0 = map(lambda x: x.to(device).requires_grad_(True), (q, k, v, g, beta, h0)) + + ref, ref_ht = naive_recurrent_gated_delta_rule( + q=F.normalize(q.clone(), p=2, dim=-1), + k=F.normalize(k.clone(), p=2, dim=-1), + v=v.clone(), + g=(naive_gdn_gate_fn(g, A_log, dt_bias) if use_gate_in_kernel else g.clone()), + beta=beta.clone(), + initial_state=h0.clone(), + output_final_state=True, + ) + + # Unlike KDA, GDN does not have fused gate preprocessing kernel. It only does chunk_local_cumsum, so the gate_fn must be applied outside + ref_fla, ref_ht_fla = chunk_gated_delta_rule( + q=F.normalize(q.clone(), p=2, dim=-1) if not use_qk_l2norm_in_kernel else q.clone(), + k=F.normalize(k.clone(), p=2, dim=-1) if not use_qk_l2norm_in_kernel else k.clone(), + v=v.clone(), + g=(naive_gdn_gate_fn(g, A_log, dt_bias) if use_gate_in_kernel else g.clone()), + beta=beta.clone(), + initial_state=h0.clone(), + output_final_state=True, + use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel, + ) + + tri, tri_ht = cula_gdn_fused_fwd( + q=F.normalize(q.clone(), p=2, dim=-1) if not use_qk_l2norm_in_kernel else q.clone(), + k=F.normalize(k.clone(), p=2, dim=-1) if not use_qk_l2norm_in_kernel else k.clone(), + v=v.clone(), + g=g.clone(), + beta=beta.clone(), + A_log=(A_log.clone() if use_gate_in_kernel else None), + dt_bias=(dt_bias.clone() if use_gate_in_kernel else None), + initial_state=h0.clone(), + output_final_state=True, + use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel, + use_gate_in_kernel=use_gate_in_kernel, + safe_gate=safe_gate, + lower_bound=lower_bound, + ) + + assert_close("o", ref, tri, 0.005) + assert_close("ht", ref_ht, tri_ht, 0.005) + assert_close("o", ref_fla, tri, 0.005) + assert_close("ht", ref_ht_fla, tri_ht, 0.005) + + +@pytest.mark.parametrize( + ("H", "D", "mask_p", "cu_seqlens", "dtype", "safe_gate", "use_gate_in_kernel"), + [ + pytest.param(*test, id="H{}-D{}-mask_p{}-cu_seqlens{}-{}-safe_gate{}-gate{}".format(*test)) + for test in [ + (4, 128, 0.1, [0, 15], torch.bfloat16, True, False), + (4, 128, 0.9, [0, 256, 500, 1000], torch.bfloat16, True, False), + (4, 128, 0.5, [0, 256, 500, 1000], torch.bfloat16, True, False), + (4, 128, 0, [0, 15, 100, 300, 1200, 2000], torch.bfloat16, True, False), + (4, 128, 0, [0, 100, 300, 1200, 3000, 4096], torch.bfloat16, True, False), + # ======Varlen test with simulated trace======= + ( + 32, + 128, + 0, + [0, 247, 699, 982, 1688, 1985, 2383, 3081, 3526, 3973, 4096, 4824, 5101, 5919, 6426, 7137, 7392, 7800, 8192], + torch.bfloat16, + True, + False, + ), + ( + 32, + 128, + 0, + [0, 652, 1255, 1600, 2083, 2345, 2756, 3172, 3767, 4096, 4891, 5236, 5543, 6255, 6480, 6947, 7616, 8192], + torch.bfloat16, + True, + False, + ), + ( + 32, + 128, + 0, + [0, 315, 973, 1283, 2162, 2459, 2678, 2998, 3781, 4096, 4503, 5459, 6318, 6669, 6979, 7583, 8192], + torch.bfloat16, + True, + False, + ), + ( + 32, + 128, + 0, + [0, 494, 1004, 1561, 1908, 2240, 2849, 3116, 4096, 4986, 5626, 6090, 6718, 7244, 7870, 8192], + torch.bfloat16, + True, + False, + ), + ] + ], +) +def test_safe_gate_chunk_varlen( + H: int, + D: int, + mask_p: float, + cu_seqlens: list[int], + dtype: torch.dtype, + safe_gate: bool, + use_gate_in_kernel: bool, +): + from cula.gdn.gate import naive_gdn_lowerbound_gate + + cula_gdn_fused_fwd = get_gdn_fused_fwd(device) + + torch.manual_seed(42) + cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32, device=device) + cu_seqlens_cpu = cu_seqlens.cpu() + T = cu_seqlens[-1] + N = len(cu_seqlens) - 1 + + q = torch.randn((1, T, H, D), dtype=dtype) + k = F.normalize(torch.randn(1, T, H, D, dtype=torch.float32), p=2, dim=-1).to(dtype) + v = torch.randn((1, T, H, D), dtype=dtype) + g = torch.randn(1, T, H, dtype=torch.float if not use_gate_in_kernel else dtype) + if use_gate_in_kernel: + A_log = torch.randn(H, dtype=torch.float) + dt_bias = torch.randn(H, dtype=torch.float) + else: + g = F.logsigmoid(g) + mask = torch.rand_like(g) > mask_p + g = g * mask + (~mask) * (-1000) + if safe_gate: + lower_bound = -5.0 + if not use_gate_in_kernel: + g = g.clamp(-5, 0) + naive_gdn_gate_fn = naive_gdn_lowerbound_gate + else: + lower_bound = None + naive_gdn_gate_fn = naive_gdn_gate + + beta = torch.randn(1, T, H, dtype=torch.float32).sigmoid() + h0 = torch.randn((N, H, D, D), dtype=torch.float32) + + if use_gate_in_kernel: + A_log, dt_bias = map(lambda x: x.to(device).requires_grad_(True), (A_log, dt_bias)) + q, k, v, g, beta, h0 = map(lambda x: x.to(device).requires_grad_(), (q, k, v, g, beta, h0)) + torch.randn_like(v) + torch.rand_like(h0) + + tri, tri_ht = cula_gdn_fused_fwd( + q=F.normalize(q.clone(), p=2, dim=-1), + k=k.clone(), + v=v.clone(), + g=g.clone(), + beta=beta.clone(), + A_log=(A_log.clone() if use_gate_in_kernel else None), + dt_bias=(dt_bias.clone() if use_gate_in_kernel else None), + initial_state=h0.clone(), + output_final_state=True, + use_gate_in_kernel=use_gate_in_kernel, + cu_seqlens=cu_seqlens, + cu_seqlens_cpu=cu_seqlens_cpu, + safe_gate=safe_gate, + lower_bound=lower_bound, + ) + + ref_fla, ref_ht_fla = chunk_gated_delta_rule( + q=F.normalize(q.clone(), p=2, dim=-1), + k=k.clone(), + v=v.clone(), + g=(naive_gdn_gate_fn(g.clone(), A_log, dt_bias) if use_gate_in_kernel else g.clone()), + beta=beta.clone(), + initial_state=h0.clone(), + output_final_state=True, + cu_seqlens=cu_seqlens, + cu_seqlens_cpu=cu_seqlens_cpu, + ) + + ref = [] + ref_ht = [] + for i in range(N): + g_slice = naive_gdn_gate_fn(g[:, cu_seqlens[i] : cu_seqlens[i + 1]], A_log, dt_bias) if use_gate_in_kernel else g[:, cu_seqlens[i] : cu_seqlens[i + 1]] + ref_i, ref_ht_i = naive_recurrent_gated_delta_rule( + q=F.normalize(q[:, cu_seqlens[i] : cu_seqlens[i + 1]], p=2, dim=-1), + k=k[:, cu_seqlens[i] : cu_seqlens[i + 1]], + v=v[:, cu_seqlens[i] : cu_seqlens[i + 1]], + beta=beta[:, cu_seqlens[i] : cu_seqlens[i + 1]], + g=g_slice, + initial_state=h0[i], + output_final_state=True, + ) + ref.append(ref_i) + ref_ht.append(ref_ht_i) + ref = torch.cat(ref, 1) + ref_ht = torch.cat(ref_ht, 0) + + assert_close("o", ref, tri, 0.005) + assert_close("ht", ref_ht, tri_ht, 0.005) + assert_close("o", ref_fla, tri, 0.005) + assert_close("ht", ref_ht_fla, tri_ht, 0.005) From 2aa1772cb2cb962ccf376a4ad0e1e3f72510f584 Mon Sep 17 00:00:00 2001 From: Kingsley Kim Date: Sun, 12 Apr 2026 17:24:27 -0400 Subject: [PATCH 2/3] fixed issues in gdn kernel, passes all tests in tests/gdn for sm90 fused fwd --- csrc/gdn/sm90/collective/common.hpp | 477 ++++ csrc/gdn/sm90/collective/load_predicated.hpp | 193 ++ csrc/gdn/sm90/collective/load_tma.hpp | 171 ++ csrc/gdn/sm90/collective/mainloop_gdn_fwd.hpp | 2079 +++++++++++++++++ csrc/gdn/sm90/collective/named_barriers.hpp | 44 + csrc/gdn/sm90/collective/store_tma.hpp | 313 +++ csrc/gdn/sm90/device/device_universal.hpp | 244 ++ csrc/gdn/sm90/gdn_fwd_sm90.cu | 144 ++ csrc/gdn/sm90/gdn_fwd_sm90_safe_gate.cu | 68 + csrc/gdn/sm90/kernel/builder_gdn_fwd.hpp | 80 + csrc/gdn/sm90/kernel/kernel_gdn_fwd.hpp | 630 +++++ csrc/gdn/sm90/kernel/options.hpp | 86 + csrc/gdn/sm90/kernel/tile_scheduler.hpp | 139 ++ csrc/gdn/sm90/prefill_kernel.hpp | 49 + csrc/gdn/sm90/prefill_kernel_gdn_fwd_sm90.cuh | 154 ++ csrc/gdn/sm90/utils/common.hpp | 59 + csrc/gdn/sm90/utils/debug.hpp | 149 ++ csrc/gdn/sm90/utils/math.hpp | 53 + csrc/gdn/sm90/utils/math_order_barrier.hpp | 116 + csrc/gdn/sm90/utils/type_traits.hpp | 66 + csrc/gdn/sm90/utils/unused.hpp | 54 + cula/gdn/hopper_fused_fwd.py | 10 - tests/test_gdn_fused_fwd.py | 8 +- 23 files changed, 5371 insertions(+), 15 deletions(-) create mode 100644 csrc/gdn/sm90/collective/common.hpp create mode 100644 csrc/gdn/sm90/collective/load_predicated.hpp create mode 100644 csrc/gdn/sm90/collective/load_tma.hpp create mode 100644 csrc/gdn/sm90/collective/mainloop_gdn_fwd.hpp create mode 100644 csrc/gdn/sm90/collective/named_barriers.hpp create mode 100644 csrc/gdn/sm90/collective/store_tma.hpp create mode 100644 csrc/gdn/sm90/device/device_universal.hpp create mode 100644 csrc/gdn/sm90/gdn_fwd_sm90.cu create mode 100644 csrc/gdn/sm90/gdn_fwd_sm90_safe_gate.cu create mode 100644 csrc/gdn/sm90/kernel/builder_gdn_fwd.hpp create mode 100644 csrc/gdn/sm90/kernel/kernel_gdn_fwd.hpp create mode 100644 csrc/gdn/sm90/kernel/options.hpp create mode 100644 csrc/gdn/sm90/kernel/tile_scheduler.hpp create mode 100644 csrc/gdn/sm90/prefill_kernel.hpp create mode 100644 csrc/gdn/sm90/prefill_kernel_gdn_fwd_sm90.cuh create mode 100644 csrc/gdn/sm90/utils/common.hpp create mode 100644 csrc/gdn/sm90/utils/debug.hpp create mode 100644 csrc/gdn/sm90/utils/math.hpp create mode 100644 csrc/gdn/sm90/utils/math_order_barrier.hpp create mode 100644 csrc/gdn/sm90/utils/type_traits.hpp create mode 100644 csrc/gdn/sm90/utils/unused.hpp diff --git a/csrc/gdn/sm90/collective/common.hpp b/csrc/gdn/sm90/collective/common.hpp new file mode 100644 index 0000000..735a285 --- /dev/null +++ b/csrc/gdn/sm90/collective/common.hpp @@ -0,0 +1,477 @@ +// Copyright 2025-2026 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +/* + * Copyright (c) 2025 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +namespace gdn::sm90::collective { + +using namespace cute; + +template +CUTE_DEVICE void +gemm_reset_zero_acc(Atom& atom, TA const& tA, TB const& tB, TC&& tC) { + constexpr int rA = decltype(rank(tA))::value; + constexpr int rB = decltype(rank(tB))::value; + constexpr int rC = decltype(rank(tC))::value; + if constexpr (rA == 2 && rB == 2 && rC == 1) { + CUTE_UNROLL + for (int k_block = 0; k_block < size<1>(tA); k_block++) { + cute::gemm(atom, tA(_, k_block), tB(_, k_block), tC); + atom.accumulate_ = GMMA::ScaleOut::One; + } + } else { + static_assert(rA == 3 && rB == 3 && rC == 3); + CUTE_UNROLL + for (int k_block = 0; k_block < size<2>(tA); k_block++) { + cute::gemm(atom, tA(_, _, k_block), tB(_, _, k_block), tC); + atom.accumulate_ = GMMA::ScaleOut::One; + } + } +} + +template +CUTE_DEVICE void +gemm_zero_acc(Atom& atom, TA const& tA, TB const& tB, TC&& tC) { + atom.accumulate_ = GMMA::ScaleOut::Zero; + gemm_reset_zero_acc(atom, tA, tB, tC); +} + +template < + template + class Primitive, + cute::GMMA::Major tA, + cute::GMMA::Major tB, + cute::GMMA::ScaleIn sA, + cute::GMMA::ScaleIn sB> +CUTE_DEVICE constexpr auto +convert_to_gmma_rs(cute::MMA_Atom> const& tiled_mma) { + using Atom = cute::MMA_Atom>; + using ElementA = typename Atom::ValTypeA; + using ElementB = typename Atom::ValTypeB; + using ElementC = typename Atom::ValTypeC; + using Shape_MNK = typename Atom::Shape_MNK; + using RS = decltype(cute::GMMA::rs_op_selector()); + return cute::MMA_Atom{}; +} + +template < + template + class Primitive, + cute::GMMA::ScaleIn sA, + cute::GMMA::ScaleIn sB> +CUTE_DEVICE constexpr auto +convert_to_gmma_rs(cute::MMA_Atom> const& tiled_mma) { + using Atom = cute::MMA_Atom>; + using ElementA = typename Atom::ValTypeA; + using ElementB = typename Atom::ValTypeB; + using ElementC = typename Atom::ValTypeC; + using Shape_MNK = typename Atom::Shape_MNK; + constexpr auto tA = cute::GMMA::Major::K; + constexpr auto tB = cute::GMMA::Major::K; + using RS = decltype(cute::GMMA::rs_op_selector()); + return cute::MMA_Atom{}; +} + +template +CUTE_DEVICE constexpr auto +convert_to_gmma_rs(cute::TiledMMA const& tiled_mma) { + return cute::TiledMMA{}; +} + +template +CUTE_DEVICE constexpr auto +convert_c_layout_to_a_layout(CLayout const& c, AValueShape const& a) { + return make_layout( + make_shape(a, shape<1>(c), make_shape(shape<2>(c), size<0>(c) / size(a))), + make_stride(stride<0>(c), stride<1>(c), make_stride(stride<2>(c), size<2>(a) * stride<0, 2>(c)))); +} + +template +CUTE_DEVICE constexpr auto +unstage_smem_layout(Layout const& layout, Stages stages = {}) { + return composition(layout, make_tuple(_, _, make_layout(stages))); +} + +template +CUTE_DEVICE auto +make_acc_into_op(Accumulator const& acc, OperandLayout_TV const& operand_layout_tv) { + Tensor operand = + make_fragment_like(convert_c_layout_to_a_layout(acc.layout(), shape<1>(operand_layout_tv))); + Tensor operand_as_acc = make_tensor(operand.data(), acc.layout()); + + cute::copy(acc, operand_as_acc); + + if constexpr (sizeof(Element) == 1) { + // 00 11 22 33 00 11 22 33 acc layout + // 00 00 11 11 22 22 33 33 operand layout + // BB AA AA BB AA BB BB AA conflict-free exchange pattern + // 16-bit exchange; so process two at a time potentially + int tid = threadIdx.x % 4; + auto values_u32 = recast(operand); + + CUTE_UNROLL + for (int n = 0; n < size<1>(values_u32); n++) { + CUTE_UNROLL + for (int k = 0; k < size<2>(values_u32); k++) { + CUTE_UNROLL + for (int ii = 0; ii < 8; ii += 4) { + uint32_t values_tmp_0 = values_u32(ii / 2 + 0, n, k); + uint32_t values_tmp_1 = values_u32(ii / 2 + 1, n, k); + + // step A: + // t 1 v 0 -> t 0 v 1 + // t 2 v 0 -> t 1 v 0 + // t 0 v 1 -> t 2 v 0 + // t 3 v 1 -> t 3 v 1 + + int v_to_send = tid == 1 || tid == 2 ? 0 : 1; + int v_to_recv = v_to_send; + int t_to_recv_from = (0x3021 >> (tid * 4)) & 0xF; + + uint32_t values_tmp_a = v_to_send == 0 ? values_tmp_0 : values_tmp_1; + + values_tmp_a = __shfl_sync(0xFFFFFFFF, values_tmp_a, t_to_recv_from, 4); + + // step B: + // t 0 v 0 -> t 0 v 0 + // t 3 v 0 -> t 1 v 1 + // t 1 v 1 -> t 2 v 1 + // t 2 v 1 -> t 3 v 0 + + v_to_send = 1 - v_to_send; + v_to_recv = 1 - v_to_recv; + t_to_recv_from = (0x2130 >> (tid * 4)) & 0xF; + + uint32_t values_tmp_b = v_to_send == 0 ? values_tmp_0 : values_tmp_1; + + values_tmp_b = __shfl_sync(0xFFFFFFFF, values_tmp_b, t_to_recv_from, 4); + + values_u32(ii / 2 + 0, n, k) = + __byte_perm(values_tmp_a, values_tmp_b, v_to_send == 0 ? 0x1054 : 0x5410); + values_u32(ii / 2 + 1, n, k) = + __byte_perm(values_tmp_a, values_tmp_b, v_to_send == 0 ? 0x3276 : 0x7632); + } + } + } + } + + return operand; +} + +// Convert float register values from BF16 MMA operand A layout to TF32 MMA operand A layout. +// +// Both SM80_16x8x8_F32BF16BF16F32_TN and SM80_16x8x8_F32TF32TF32F32_TN have the same +// per-thread fragment shape: ((_2,_2),_1,_4):((_1,_2),_0,_4) → 16 values per thread. +// But they map different (M, K) positions to each thread. +// +// BF16 LayoutA_TV: ((_4,_8),(_2,_2)):((_32,_1),(_16,_8)) +// t0 = tid % 4, t1 = tid / 4 (note: Shape<_4,_8> → colmajor → t0 = tid%4) +// v_flat = v0 + v1*2: v0 → K offset (stride 16), v1 → M offset (stride 8) +// Thread t0 holds K = {2*t0, 2*t0+1} (consecutive K per thread) +// +// TF32 LayoutA_TV: ((_4,_8),(_2,_2)):((_16,_1),(_8,_64)) +// v_flat = v0 + v1*2: v0 → M offset (stride 8), v1 → K offset (stride 64) +// Thread t0 holds K = {t0, t0+4} (stride-4 K per thread) +// +// Algorithm (two-phase shuffle): +// For each v0_tf32 (M selector in TF32), the source data is at v1_bf16 = v0_tf32 +// in BF16 layout. Both v1_tf32 outputs need the same BF16 value index but from +// different source threads. +// +// Phase 1: shuffle bf16_frag[idx0] (v0_bf16=0) from source thread +// Phase 2: shuffle bf16_frag[idx1] (v0_bf16=1) from source thread +// Then select phase1 or phase2 result based on t0%2. +// +// Supports BK=32 (NumKAtoms=4, 16 values) and BK=64 (NumKAtoms=8, 32 values). +// The shuffle pattern is identical per k-atom; only the loop count changes. +// +// @param frag_A Float tensor with NumKAtoms*4 values in BF16 MMA A layout, converted +// in-place to TF32 MMA A layout. Values remain as float; the TF32 MMA +// hardware will truncate mantissa bits automatically during execution. +// Caller uses recast(frag_A) to obtain a typed view if needed. +// @param local_thread_idx Thread index within the MMA tile (0..63) +template +CUTE_DEVICE void +convert_bf16_to_tf32_operandA_layout(FragA& frag_A, int local_thread_idx) { + static_assert(cute::is_tensor>::value); + static_assert(NumKAtoms == 4 || NumKAtoms == 8, "Only BK=32 (4 k-atoms) and BK=64 (8 k-atoms) supported"); + static_assert(decltype(size(frag_A))::value == NumKAtoms * 4); + // Fragment must hold float values (gated results are already in float). + // MMA hardware will truncate to tf32 precision automatically. + using ElemType = typename cute::remove_cvref_t::value_type; + static_assert(cute::is_same_v, "Fragment must be float; tf32 truncation is done by MMA hw"); + + int tid = local_thread_idx % 32; // lane within warp + int t0 = tid % 4; + bool sel_odd = (t0 & 1); // t0%2: selects v0_bf16=1 result + + // Source lane for v1_tf32=0: t0_src = t0/2, lane = t0_src + (tid & ~3) + // Source lane for v1_tf32=1: t0_src = t0/2+2, lane = (t0/2+2) + (tid & ~3) + int src_lane_lo = (t0 / 2) + (tid & ~3); + int src_lane_hi = (t0 / 2 + 2) + (tid & ~3); + + // Process NumKAtoms k-iterations, each with 4 values: [4j+0, 4j+1, 4j+2, 4j+3] + // BF16 fragment layout per k-iter: (v0_bf16=0,v1_bf16=0), (v0_bf16=1,v1_bf16=0), + // (v0_bf16=0,v1_bf16=1), (v0_bf16=1,v1_bf16=1) + // TF32 output layout per k-iter: (v0_tf32=0,v1_tf32=0), (v0_tf32=1,v1_tf32=0), + // (v0_tf32=0,v1_tf32=1), (v0_tf32=1,v1_tf32=1) + CUTE_UNROLL + for (int j = 0; j < NumKAtoms; j++) { + // Read all 4 input values for this k-iter before writing any output, + // to avoid read-after-write hazard (in-place update). + float in0 = frag_A(0 + 4 * j); // v0_bf16=0, v1_bf16=0 + float in1 = frag_A(1 + 4 * j); // v0_bf16=1, v1_bf16=0 + float in2 = frag_A(2 + 4 * j); // v0_bf16=0, v1_bf16=1 + float in3 = frag_A(3 + 4 * j); // v0_bf16=1, v1_bf16=1 + + // For v0_tf32=0: M is selected by v1_bf16=0, so source values are (in0, in1) + // For v0_tf32=1: M is selected by v1_bf16=1, so source values are (in2, in3) + float out_vals[4]; + CUTE_UNROLL + for (int v0_tf32 = 0; v0_tf32 < 2; v0_tf32++) { + float val0 = (v0_tf32 == 0) ? in0 : in2; // v0_bf16=0 at chosen v1_bf16 + float val1 = (v0_tf32 == 0) ? in1 : in3; // v0_bf16=1 at chosen v1_bf16 + + // Shuffle to get values from source threads + float recv0_lo = __shfl_sync(0xFFFFFFFF, val0, src_lane_lo); + float recv1_lo = __shfl_sync(0xFFFFFFFF, val1, src_lane_lo); + float recv0_hi = __shfl_sync(0xFFFFFFFF, val0, src_lane_hi); + float recv1_hi = __shfl_sync(0xFFFFFFFF, val1, src_lane_hi); + + // Select based on t0%2: even → v0_bf16=0, odd → v0_bf16=1 + out_vals[v0_tf32 + 0] = sel_odd ? recv1_lo : recv0_lo; // v1_tf32=0 + out_vals[v0_tf32 + 2] = sel_odd ? recv1_hi : recv0_hi; // v1_tf32=1 + } + + // Write all 4 output values + frag_A(0 + 4 * j) = out_vals[0]; + frag_A(1 + 4 * j) = out_vals[1]; + frag_A(2 + 4 * j) = out_vals[2]; + frag_A(3 + 4 * j) = out_vals[3]; + } +} + +// Convert float register values from BF16 MMA operand B layout to TF32 MMA operand B layout. +// +// BF16 LayoutB_TV: ((_4,_8),_2):((_16,_1),_8) +// Thread t0 holds K = {2*t0, 2*t0+1}. v selects K offset (consecutive). +// +// TF32 LayoutB_TV: ((_4,_8),_2):((_8,_1),_32) +// Thread t0 holds K = {t0, t0+4}. v selects K offset (stride-4). +// +// Same two-phase shuffle approach as operand A, but B has only 2 values per atom +// (no M dimension in the value index). +// +// Supports BK=32 (NumKAtoms=4, 8 values) and BK=64 (NumKAtoms=8, 16 values). +// +// @param frag_B Float tensor with NumKAtoms*2 values in BF16 MMA B layout, converted +// in-place. Values remain as float; TF32 MMA hardware truncates automatically. +// Caller uses recast(frag_B) to obtain a typed view if needed. +// @param local_thread_idx Thread index within the MMA tile (0..63) +template +CUTE_DEVICE void +convert_bf16_to_tf32_operandB_layout(FragB& frag_B, int local_thread_idx) { + static_assert(cute::is_tensor>::value); + static_assert(NumKAtoms == 4 || NumKAtoms == 8, "Only BK=32 (4 k-atoms) and BK=64 (8 k-atoms) supported"); + static_assert(decltype(size(frag_B))::value == NumKAtoms * 2); + // Fragment must hold float values; MMA hardware truncates to tf32 automatically. + using ElemType = typename cute::remove_cvref_t::value_type; + static_assert(cute::is_same_v, "Fragment must be float; tf32 truncation is done by MMA hw"); + + int tid = local_thread_idx % 32; + int t0 = tid % 4; + bool sel_odd = (t0 & 1); + + int src_lane_lo = (t0 / 2) + (tid & ~3); + int src_lane_hi = (t0 / 2 + 2) + (tid & ~3); + + // Process NumKAtoms k-iterations, each with 2 values: [2j, 2j+1] + CUTE_UNROLL + for (int j = 0; j < NumKAtoms; j++) { + int idx0 = 2 * j; // BF16 v=0, K = 2*t0 + int idx1 = 2 * j + 1; // BF16 v=1, K = 2*t0+1 + + float val0 = frag_B(idx0); + float val1 = frag_B(idx1); + + // v_tf32=0: need K=t0 from src_t0=t0/2 + float recv0_lo = __shfl_sync(0xFFFFFFFF, val0, src_lane_lo); + float recv1_lo = __shfl_sync(0xFFFFFFFF, val1, src_lane_lo); + // v_tf32=1: need K=t0+4 from src_t0=t0/2+2 + float recv0_hi = __shfl_sync(0xFFFFFFFF, val0, src_lane_hi); + float recv1_hi = __shfl_sync(0xFFFFFFFF, val1, src_lane_hi); + + frag_B(idx0) = sel_odd ? recv1_lo : recv0_lo; + frag_B(idx1) = sel_odd ? recv1_hi : recv0_hi; + } +} + +// Broadcast row 0 from a BF16 MMA operand A fragment and output directly +// into a BF16 MMA operand B fragment. +// +// Combines broadcast_row0 + extract_A_to_B into one step, avoiding the +// intermediate 16-float operand A broadcast tensor (saves 8 float registers). +// +// Since g_first is broadcast (all M rows identical), operand B only needs +// the K-dimension values. We shuffle v1=0 values from the t1=0 thread +// and output the 8-value B fragment directly. +// +// BF16 A layout: t0 = tid % 4, t1 = tid / 4. Row 0 at t1=0, v1=0. +// frag_A(4j+0) = (v0=0, v1=0), frag_A(4j+1) = (v0=1, v1=0) +// BF16 B layout: frag_B(2j+0) = v=0, frag_B(2j+1) = v=1 +// Both have K = {2*t0, 2*t0+1} at same positions. +// +// Supports BK=32 (NumKAtoms=4) and BK=64 (NumKAtoms=8). +// +// Cost: 2 shuffles per k-iter × NumKAtoms k-iters. +// Saves: NumKAtoms*4-float intermediate tensor (vs broadcast_row0 + extract). +// +// @param frag_A Input: alpha[m, k] in BF16 MMA A layout (NumKAtoms*4 values) +// @param frag_B_first Output: alpha[0, k] in BF16 MMA B layout (NumKAtoms*2 values) +// @param local_thread_idx Thread index within the MMA tile (0..63) +template +CUTE_DEVICE void +broadcast_row0_operandA_to_operandB_bf16_layout(FragA const& frag_A, FragB& frag_B_first, int local_thread_idx) { + static_assert(cute::is_tensor>::value); + static_assert(cute::is_tensor>::value); + static_assert(NumKAtoms == 4 || NumKAtoms == 8, "Only BK=32 (4 k-atoms) and BK=64 (8 k-atoms) supported"); + static_assert(decltype(size(frag_A))::value == NumKAtoms * 4); + static_assert(decltype(size(frag_B_first))::value == NumKAtoms * 2); + + int tid = local_thread_idx % 32; // lane within warp + // Row 0 is at t1=0. In BF16 A layout, t0 = tid % 4, t1 = tid / 4. + // Source lane: same t0, t1=0 → src = tid % 4. + int src_lane = tid % 4; + + CUTE_UNROLL + for (int j = 0; j < NumKAtoms; j++) { + // Shuffle v1=0 values from thread with t1=0 (row 0 holder) + // frag_A(4j+0) = (v0=0, v1=0) → K=2*t0 + // frag_A(4j+1) = (v0=1, v1=0) → K=2*t0+1 + auto val0 = __shfl_sync(0xFFFFFFFF, frag_A(4 * j + 0), src_lane); + auto val1 = __shfl_sync(0xFFFFFFFF, frag_A(4 * j + 1), src_lane); + + // Output directly into B layout: frag_B(2j) = K_lo, frag_B(2j+1) = K_hi + frag_B_first(2 * j + 0) = val0; + frag_B_first(2 * j + 1) = val1; + } +} + +// Broadcast row 0 across all M rows in a BF16 MMA operand A fragment. +// +// Given frag_A holding alpha[m, k] per thread, produces frag_A_first holding alpha[0, k] +// for all m (broadcast). This eliminates a redundant S2R load of g_first from shared memory. +// +// BF16 LayoutA_TV: ((_4,_8),(_2,_2)):((_32,_1),(_16,_8)) +// tid decomposition: t0 = tid % 4, t1 = tid / 4 +// m = t1 + v1*8 (v1 selects M row within thread) +// Row 0 is held by threads with t1=0, at v1=0 positions. +// +// Algorithm: +// For each k-iter, shuffle v1=0 values from thread (tid % 4) (same t0, t1=0), +// then replicate to v1=1 positions. +// +// Supports BK=32 (NumKAtoms=4) and BK=64 (NumKAtoms=8). +// +// Cost: 2 shuffles per k-iter × NumKAtoms k-iters. +// +// @param frag_A Input: alpha[m, k] in BF16 MMA A layout (NumKAtoms*4 values) +// @param frag_A_first Output: alpha[0, k] broadcast in BF16 MMA A layout (NumKAtoms*4 values) +// @param local_thread_idx Thread index within the MMA tile (0..63) +template +CUTE_DEVICE void +broadcast_row0_operandA_bf16_layout(FragA const& frag_A, FragAFirst& frag_A_first, int local_thread_idx) { + static_assert(cute::is_tensor>::value); + static_assert(cute::is_tensor>::value); + static_assert(NumKAtoms == 4 || NumKAtoms == 8, "Only BK=32 (4 k-atoms) and BK=64 (8 k-atoms) supported"); + static_assert(decltype(size(frag_A))::value == NumKAtoms * 4); + static_assert(decltype(size(frag_A_first))::value == NumKAtoms * 4); + + int tid = local_thread_idx % 32; // lane within warp + // Row 0 is at t1=0 → tid % 4 (keep same t0, set t1=0) + int src_lane = tid % 4; + + CUTE_UNROLL + for (int j = 0; j < NumKAtoms; j++) { + // v1=0 positions hold m=t1, v1=1 positions hold m=t1+8 + // We want m=0, which is at t1=0, v1=0 → indices 4j+0 and 4j+1 + auto val0 = __shfl_sync(0xFFFFFFFF, frag_A(4 * j + 0), src_lane); // alpha[0, 2*t0] from t1=0 + auto val1 = __shfl_sync(0xFFFFFFFF, frag_A(4 * j + 1), src_lane); // alpha[0, 2*t0+1] from t1=0 + + // Broadcast to both v1=0 and v1=1 (same value, different M rows) + frag_A_first(4 * j + 0) = val0; // v0=0, v1=0 + frag_A_first(4 * j + 1) = val1; // v0=1, v1=0 + frag_A_first(4 * j + 2) = val0; // v0=0, v1=1 (broadcast) + frag_A_first(4 * j + 3) = val1; // v0=1, v1=1 (broadcast) + } +} + +// Extract BF16 MMA operand B fragment from a BF16 MMA operand A fragment, +// for data that is **broadcast across M rows** (e.g., g_first = g[row=0, :]). +// +// When the source data is broadcast (identical for all M rows), the K-dimension +// mapping is the same in both A and B BF16 MMA layouts: +// A: thread t0 holds K = {2*t0, 2*t0+1} at v0={0,1}, with v1 selecting M row +// B: thread t0 holds K = {2*t0, 2*t0+1} at v={0,1} +// +// Since M rows are identical (broadcast), we can simply pick v1=0 from A: +// frag_B(2j + 0) = frag_A(4j + 0) // v0_bf16=0 → K=2*t0 +// frag_B(2j + 1) = frag_A(4j + 1) // v0_bf16=1 → K=2*t0+1 +// +// Supports BK=32 (NumKAtoms=4) and BK=64 (NumKAtoms=8). +// +// This avoids a redundant S2R load from shared memory. +// No warp shuffles needed — purely register-local extraction. +// +// @param frag_A Float tensor with NumKAtoms*4 values in BF16 MMA A layout (broadcast data) +// @param frag_B Float tensor with NumKAtoms*2 values in BF16 MMA B layout (output) +template +CUTE_DEVICE void +extract_broadcast_operandA_to_operandB_bf16_layout(FragA const& frag_A, FragB& frag_B) { + static_assert(cute::is_tensor>::value); + static_assert(cute::is_tensor>::value); + static_assert(NumKAtoms == 4 || NumKAtoms == 8, "Only BK=32 (4 k-atoms) and BK=64 (8 k-atoms) supported"); + static_assert(decltype(size(frag_A))::value == NumKAtoms * 4); + static_assert(decltype(size(frag_B))::value == NumKAtoms * 2); + + CUTE_UNROLL + for (int j = 0; j < NumKAtoms; j++) { + // A layout per k-iter: [4j+0]=(v0=0,v1=0), [4j+1]=(v0=1,v1=0), [4j+2]=(v0=0,v1=1), [4j+3]=(v0=1,v1=1) + // B layout per k-iter: [2j+0]=v=0, [2j+1]=v=1 + // For broadcast data, v1 doesn't matter, so pick v1=0: + frag_B(2 * j + 0) = frag_A(4 * j + 0); // K = 2*t0 + frag_B(2 * j + 1) = frag_A(4 * j + 1); // K = 2*t0+1 + } +} + +} // namespace gdn::sm90::collective diff --git a/csrc/gdn/sm90/collective/load_predicated.hpp b/csrc/gdn/sm90/collective/load_predicated.hpp new file mode 100644 index 0000000..9ebeb7d --- /dev/null +++ b/csrc/gdn/sm90/collective/load_predicated.hpp @@ -0,0 +1,193 @@ +// Copyright 2025-2026 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +/* + * Copyright (c) 2025 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include + +#include "gdn/sm90/utils/debug.hpp" +#include "gdn/sm90/utils/unused.hpp" + +namespace gdn::sm90::collective { + +using namespace cute; + +// Wraps a callable into a predicate "tensor" usable by CuTe's copy_if. +// copy_if calls pred(i) with a linear index; the wrapped function maps that index to bool. +template +struct FunctionPredTensor { + Fn fn_; + CUTE_HOST_DEVICE + FunctionPredTensor(Fn fn) : fn_(fn) { + } + template + CUTE_HOST_DEVICE bool + operator()(Idx const& i) const { + return fn_(i); + } +}; + +enum class LoadKindVector { + kAlpha, + kBeta, +}; + +CUTE_HOST_DEVICE constexpr char const* +to_string(LoadKindVector kind) { + if (kind == LoadKindVector::kAlpha) { + return "alpha"; + } else if (kind == LoadKindVector::kBeta) { + return "beta"; + } else { + return "unknown loadkind"; + } +} + +template < + LoadKindVector kKind, + class Pipeline, + class ElementSrc, + class GmemLayout, + class ElementDst, + class SmemLayout, + class VectorProcessor_ = Unused> +struct CollectiveLoadVector { + using SharedStorage = cute::array_aligned>; + using PipelineState = typename cutlass::PipelineState; + + using VectorProcessor = VectorProcessor_; + + static_assert(rank_v == 2 || rank_v == 3); + + static constexpr LoadKindVector kind = kKind; + static constexpr int VectorSize = size<0>(SmemLayout{}); + + CUTE_DEVICE + CollectiveLoadVector( + ElementSrc const* src, GmemLayout layout, ElementSrc oob_value, Pipeline& pipeline, SharedStorage& storage) + : src_(src), src_layout_(layout), src_oob_value_(oob_value), pipeline_(pipeline), storage_(storage) { + } + + template + CUTE_DEVICE auto + partition_SD(ProblemSize const& problem_size, TileShape const& tile_shape, WorkDesc const& work_desc) { + constexpr auto BlkSeqQ = decltype(get<0>(tile_shape))::value; + + Tensor g = [&] { + auto head_idx = work_desc.o_head_idx(); + DPRINTF0_W( + "slice view GMEM %s: seq_idx:%d head_idx:%d tok_offset:%lld\n", + to_string(kind), + work_desc.seq_idx, + head_idx, + work_desc.tok_offset); + Tensor m_varlen_head = make_tensor(make_gmem_ptr(src_), src_layout_); + + Tensor m_varlen = m_varlen_head(_, head_idx); // slice into current head_idx + Tensor m_offset = + domain_offset(make_coord(work_desc.tok_offset), m_varlen); // offset to start of the current sequence + Tensor g_full = flat_divide(m_offset, BlkSeqQ); // (blk, iter_blk) + return g_full; + }(); + // (blk, pipe) or (blk, pipe, N), N for feature rich preprocess, data will be stored at 0 + Tensor s = make_tensor(make_smem_ptr(storage_.data()), SmemLayout{}); + + auto thr_layout = Layout<_32>{}; + auto val_layout = Layout<_1>{}; + auto tiled_copy = make_tiled_copy(Copy_Atom, ElementDst>{}, thr_layout, val_layout); + auto thr_copy = tiled_copy.get_thread_slice(cutlass::canonical_lane_idx()); + + auto coord = thr_copy.partition_S(make_identity_tensor(Shape, _1>{})); + int seq_len = work_desc.chunk_len(); + auto len_of_last_blk = seq_len - (ceil_div(seq_len, BlkSeqQ) - 1) * BlkSeqQ; + + auto mask = FunctionPredTensor([coord, len_of_last_blk](auto frag_coord) { + auto coord_in_blk = get<0>(coord(frag_coord)); + return coord_in_blk < len_of_last_blk; + }); + + auto src = thr_copy.partition_S(g); // (cpy, iter_cpy, iter_blk) + auto dst = thr_copy.partition_D(s); // (cpy, iter_cpy, pipe) + + return make_tuple(src, dst, mask); + } + + template + CUTE_DEVICE void + step(SrcDst const& src_dst, int src_iter, PipelineState& dst_pipe, int num_iters, VectorProcessor processor = {}) { + auto src = get<0>(src_dst); + auto dst = get<1>(src_dst); + + auto regs = make_fragment_like(take<0, 2>(shape(dst))); + if constexpr (!IsTail) { + copy(src(_, _, src_iter), regs); + } else { + auto mask = get<2>(src_dst); + fill(regs, src_oob_value_); + copy_if(mask, src(_, _, src_iter), regs); + } + + int dst_pipe_idx = dst_pipe.index(); + + DPRINTF0_WG("%s pipeline.producer_acquire smem_pipe_write:%d\n", to_string(kind), dst_pipe_idx); + pipeline_.producer_acquire(dst_pipe); + cutlass::arch::fence_view_async_shared(); + + if constexpr (rank_v == 3) { + copy(regs, dst(_, _, _0{}, dst_pipe_idx)); + } else { + copy(regs, dst(_, _, dst_pipe_idx)); + } + + Tensor s = make_tensor(make_smem_ptr(storage_.data()), SmemLayout{}); + if constexpr (!std::is_same_v) { + if constexpr (rank_v == 3) { + processor(s(_, _, dst_pipe_idx)); + } else { + processor(s(_, dst_pipe_idx)); + } + } + + cutlass::arch::fence_view_async_shared(); + pipeline_.producer_commit(dst_pipe); + ++dst_pipe; + } + + private: + ElementSrc const* src_; + GmemLayout src_layout_; // in (packed_seq, H) coordinate + ElementSrc src_oob_value_; + Pipeline& pipeline_; + SharedStorage& storage_; +}; + +} // namespace gdn::sm90::collective diff --git a/csrc/gdn/sm90/collective/load_tma.hpp b/csrc/gdn/sm90/collective/load_tma.hpp new file mode 100644 index 0000000..c491a32 --- /dev/null +++ b/csrc/gdn/sm90/collective/load_tma.hpp @@ -0,0 +1,171 @@ +// Copyright 2025-2026 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +/* + * Copyright (c) 2025 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include + +#include "gdn/sm90/utils/debug.hpp" + +namespace gdn::sm90::collective { + +using namespace cute; + +enum class LoadKind { + kQ, + kK, + kV, + kAlpha, +}; + +CUTE_HOST_DEVICE constexpr char const* +to_string(LoadKind kind) { + if (kind == LoadKind::kQ) { + return "Q"; + } else if (kind == LoadKind::kK) { + return "K"; + } else if (kind == LoadKind::kV) { + return "V"; + } else if (kind == LoadKind::kAlpha) { + return "Alpha"; + } else { + return "unknown loadkind"; + } +} + +template +struct CollectiveLoadTma { + using SharedStorage = cute::array_aligned>; + using PipelineState = typename cutlass::PipelineState; + + static constexpr LoadKind kind = kKind; + TMA const& tma_load; + Pipeline& pipeline; + SharedStorage& storage; + + CUTE_DEVICE + CollectiveLoadTma(TMA const& tma_load, Pipeline& pipeline, SharedStorage& storage) + : tma_load(tma_load), pipeline(pipeline), storage(storage) { + } + + template + CUTE_DEVICE auto + partition_SD(ProblemSize const& problem_size, TileShape const& tile_shape, WorkDesc const& work_desc) { + constexpr auto BlkSeqQ = decltype(get<0>(tile_shape))::value; + constexpr auto BlkSeqKV = decltype(get<1>(tile_shape))::value; + constexpr auto HeadSize = decltype(get<2>(tile_shape))::value; + + Tensor g = [&] { + if constexpr (kind == LoadKind::kQ) { + DPRINTF0_W( + "slice view GMEM %s: seq_idx:%d head_idx:%d tok_offset:%lld\n", + to_string(kind), + work_desc.seq_idx, + work_desc.q_head_idx(), + work_desc.tok_offset); + Tensor m_varlen_head = tma_load.get_tma_tensor(make_shape( + problem_size.total_seqlen, + problem_size.head_size, + problem_size.num_heads)); // global view to the packed varlen sequence + Tensor m_varlen = m_varlen_head(_, _, work_desc.q_head_idx()); // slice into current head_idx + Tensor m_offset = domain_offset( + make_coord(work_desc.tok_offset, _0{}), + m_varlen); // offset to start of the current sequence + Tensor g_full = + local_tile(m_offset, make_tile(BlkSeqQ, HeadSize), make_coord(_, _0{})); // (blk, d, iter_blk) + return g_full; + } else if constexpr (kind == LoadKind::kAlpha) { // same as Q currently + DPRINTF0_W( + "slice view GMEM %s: seq_idx:%d head_idx:%d tok_offset:%lld\n", + to_string(kind), + work_desc.seq_idx, + work_desc.q_head_idx(), + work_desc.tok_offset); + Tensor m_varlen_head = tma_load.get_tma_tensor(make_shape( + problem_size.total_seqlen, + problem_size.head_size, + problem_size.num_heads)); // global view to the packed varlen sequence + Tensor m_varlen = m_varlen_head(_, _, work_desc.q_head_idx()); // slice into current head_idx + Tensor m_offset = domain_offset( + make_coord(work_desc.tok_offset, _0{}), + m_varlen); // offset to start of the current sequence + Tensor g_full = + local_tile(m_offset, make_tile(BlkSeqQ, HeadSize), make_coord(_, _0{})); // (blk, d, iter_blk) + return g_full; + } else { + auto head_idx = (kind == LoadKind::kK ? work_desc.k_head_idx() : work_desc.v_head_idx()); + DPRINTF0_W( + "slice view GMEM %s: seq_idx:%d head_idx:%d tok_offset:%lld\n", + to_string(kind), + work_desc.seq_idx, + head_idx, + work_desc.tok_offset); + Tensor m_varlen_head = tma_load.get_tma_tensor(make_shape( + problem_size.head_size, + problem_size.total_seqlen, + problem_size.num_heads)); // global view to the packed varlen sequence + Tensor m_varlen = m_varlen_head(_, _, head_idx); // slice into current head_idx + Tensor m_offset = domain_offset( + make_coord(_0{}, work_desc.tok_offset), + m_varlen); // offset to start of the current sequence + Tensor g_full = + local_tile(m_offset, make_tile(HeadSize, BlkSeqKV), make_coord(_0{}, _)); // (d, blk, iter_blk) + return g_full; + } + }(); + Tensor s = make_tensor(make_smem_ptr(storage.data()), SmemLayout{}); + + auto block_tma = tma_load.get_slice(_0{}); // do not support cluster + return make_tuple(block_tma.partition_S(g), block_tma.partition_D(s)); + } + + template + CUTE_DEVICE void + step(SrcDst const& src_dst, int src_iter, PipelineState& dst_pipe, uint32_t lane_predicate) { + if (lane_predicate == 1) { + DPRINTF_WG("%s pipeline.producer_acquire smem_pipe_write:%d\n", to_string(kind), dst_pipe.index()); + if constexpr (kAcquireBarrier) { + pipeline.producer_acquire(dst_pipe); + } + using BarrierType = typename Pipeline::ProducerBarrierType; + BarrierType* tma_barrier = pipeline.producer_get_barrier(dst_pipe); + + auto src = get<0>(src_dst); + auto dst = get<1>(src_dst); + + copy(tma_load.with(*tma_barrier), src(_, _, _, src_iter), dst(_, _, _, dst_pipe.index())); + ++dst_pipe; + } + } +}; + +} // namespace gdn::sm90::collective diff --git a/csrc/gdn/sm90/collective/mainloop_gdn_fwd.hpp b/csrc/gdn/sm90/collective/mainloop_gdn_fwd.hpp new file mode 100644 index 0000000..931c1a9 --- /dev/null +++ b/csrc/gdn/sm90/collective/mainloop_gdn_fwd.hpp @@ -0,0 +1,2079 @@ +// Copyright 2025-2026 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "cute/numeric/numeric_types.hpp" +#include "cute/util/type_traits.hpp" +#include +#include + +#include "kerutils/kerutils.cuh" + +#include "gdn/sm90/collective/common.hpp" +#include "gdn/sm90/collective/load_predicated.hpp" +#include "gdn/sm90/collective/load_tma.hpp" +#include "gdn/sm90/collective/named_barriers.hpp" +#include "gdn/sm90/collective/store_tma.hpp" +#include "gdn/sm90/kernel/options.hpp" +#include "gdn/sm90/utils/debug.hpp" +#include "gdn/sm90/utils/math_order_barrier.hpp" +#include "gdn/sm90/utils/unused.hpp" + +// #define INLINE_LAMBDA [[gnu::always_inline]] +#define INLINE_LAMBDA __attribute__((always_inline)) +// #define INLINE_LAMBDA [[msvc::forceinline]] + +#define WORKAROUND_WGMMA_PERFORMANCE_LOSS() \ + if (thread_idx > 8192) { \ + __syncwarp(); \ + } + +namespace gdn::sm90::collective { + +struct GdnNamedBarriers : FlatSharedNamedBarriers { + static constexpr int StateMath = FlatSharedNamedBarriers::NumBarriersUsed + 0; + static constexpr int AuxMath = FlatSharedNamedBarriers::NumBarriersUsed + 1; + static constexpr int StateMathWG0 = FlatSharedNamedBarriers::NumBarriersUsed + 2; + // NOTE: only for debug + // used for subchunk MMA with two groups, each group has 2 warps + // static constexpr int AuxMathWarp0 = FlatSharedNamedBarriers::NumBarriersUsed + 3; + // static constexpr int AuxMathWarp1 = FlatSharedNamedBarriers::NumBarriersUsed + 4; +}; + +using ku::alignment_for_swizzle; +using ku::select_layout; +using ku::select_tensor; +using namespace cute; +using gdn::sm90::kernel::find_option_t; +using gdn::sm90::kernel::Tag; + +template < + class Element_, + class ElementAccumulatorQK_, + class ElementAccumulatorKV_, + class TileShape_, // (seqlen_q, seqlen_kv, d) + class LayoutQ_, + class LayoutK_, + class LayoutV_, + class LayoutO_, // (seqlen_q/k, d, h) + class Options> +struct FlatMainloopTmaWarpSpecializedGdnFwd { + using Element = Element_; + using ElementAccumulatorQK = ElementAccumulatorQK_; + using ElementAccumulatorO = ElementAccumulatorQK; + using ElementAccumulatorKV = ElementAccumulatorKV_; + using ElementO = Element; + using ElementAlpha = float; + // TODO: support bf16 beta + using ElementBeta = float; + using ElementGatedMMA = cutlass::tfloat32_t; + + using TileShape = TileShape_; + + using LayoutQ = LayoutQ_; // (seqlen_q, d, h) + using LayoutK = LayoutK_; // (seqlen_k, d, h) + using LayoutV = LayoutV_; // (seqlen_k, d, h) + using LayoutO = LayoutO_; // (seqlen_k, d, h) + using LayoutAlpha = LayoutQ_; // (seqlen_q, d, h) + + // Options + static constexpr bool kIsPersistent = find_option_t::value; + + static constexpr bool kInitStateFromInput = find_option_t::value; + + static constexpr int NumLoadWarpGroups = 1; + static constexpr int NumStateMmaWarpGroups = 2; + static constexpr int NumAuxMmaWarpGroups = 1; + + static constexpr int StageCountQ = find_option_t, Options>::value; + static constexpr int StageCountK = find_option_t, Options>::value; + static constexpr int StageCountV = find_option_t, Options>::value; + + static constexpr int NeedsAlpha = find_option_t::value; + static constexpr int NeedsBeta = find_option_t::value; + static_assert(NeedsAlpha && NeedsBeta, "Alpha and Beta are both used in GDN."); + + static constexpr int SafeGate = true; // only support safe_gate=true + + static constexpr int NumLoadThreads = NumLoadWarpGroups * 128; + static constexpr int NumStateMmaThreads = NumStateMmaWarpGroups * 128; + static constexpr int NumAuxMmaThreads = NumAuxMmaWarpGroups * 128; + + static constexpr uint32_t OrderedBarrierId0 = uint32_t(cutlass::arch::ReservedNamedBarriers::StreamkBarrier0); + static constexpr uint32_t OrderedBarrierId1 = uint32_t(cutlass::arch::ReservedNamedBarriers::StreamkBarrier1); + + using OrderedMathBarriers = std::conditional_t< + NumStateMmaWarpGroups == 2, + OrderedNamedBarriers, + OrderedNamedBarriers>; + + using StagesQ = cutlass::gemm::collective::StageCount; + using StagesK = cutlass::gemm::collective::StageCount; + using StagesV = cutlass::gemm::collective::StageCount; + using StagesQ_K_Scaled = cutlass::gemm::collective::StageCount<2>; + using StagesO = cutlass::gemm::collective::StageCount<1>; + using ClusterShape = Shape<_1, _1, _1>; + + using StagesQK = cutlass::gemm::collective::StageCount<2>; + using StagesKK = cutlass::gemm::collective::StageCount<2>; + + using StagesAlpha = cutlass::gemm::collective::StageCount<2>; + using StagesBeta = cutlass::gemm::collective::StageCount<2>; + + static constexpr int Alignment = 16 / sizeof(Element); + + static constexpr auto BlkSeqQ = get<0>(TileShape{}); // Blk_Q + static constexpr auto BlkSeqKV = get<1>(TileShape{}); // Blk_K/V + static constexpr auto HeadSize = get<2>(TileShape{}); // D (Dq, Dk, Dv all equal) + static constexpr auto HeadSizeQK = HeadSize; + static constexpr auto HeadSizeV = HeadSize; + using HeadSizeHalf = _64; + using HeadSizeQuar = _32; + + using TileShapeQK = decltype(make_shape(BlkSeqQ, BlkSeqKV, HeadSizeQK)); + // used for element-wise in compute_aux prologue, to reduce register usage + using TileShapeQK_Half = decltype(make_shape(BlkSeqQ, BlkSeqKV, HeadSizeHalf{})); + using TileShapeQK_Quar = decltype(make_shape(BlkSeqQ, BlkSeqKV, HeadSizeQuar{})); + using TileShapeKK = decltype(make_shape(BlkSeqKV, BlkSeqKV, HeadSizeQK)); + using TileShapeKV = decltype(make_shape(HeadSizeV, HeadSizeQK, BlkSeqKV)); + static_assert(std::is_same_v); + + using TileShapeO2 = decltype(make_shape(HeadSizeV, BlkSeqQ, BlkSeqKV)); + using TileShapeO1 = decltype(make_shape(HeadSizeV, BlkSeqQ, HeadSizeQK)); + + static_assert(BlkSeqQ % 64 == 0); + static_assert(BlkSeqQ == 64 || BlkSeqQ == 128); + static_assert(BlkSeqQ == BlkSeqKV); + static constexpr bool IsQKCooperative = BlkSeqQ == 128; + static constexpr bool IsKKCooperative = IsQKCooperative; + + using DummyStages = cutlass::gemm::collective::StageCount<2>; + ; + using CollectiveMmaQK = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, + cutlass::arch::OpClassTensorOp, + Element, + LayoutQ, + Alignment, + Element, + LayoutK, + Alignment, + ElementAccumulatorQK, + TileShapeQK, + ClusterShape, + DummyStages, + std::conditional_t< + IsQKCooperative, + cutlass::gemm::KernelTmaWarpSpecializedCooperative, + cutlass::gemm::KernelTmaWarpSpecialized>>::CollectiveOp; + + // dummy TiledMmaQK RS for S2R/R2S layout consistency + using AtomLayoutQK = Layout, _1, _1>>; + using TiledMmaQK_RS = decltype(make_tiled_mma( + decltype(cute::GMMA::rs_op_selector()){}, AtomLayoutQK{})); + static_assert(size(TiledMmaQK_RS{}) == NumAuxMmaThreads); + using TiledMmaQK_RS_Quar = decltype(make_tiled_mma( + decltype(cute::GMMA::rs_op_selector()){}, + AtomLayoutQK{})); + static_assert(size(TiledMmaQK_RS_Quar{}) == NumAuxMmaThreads); + + using CollectiveMmaKV_G2S = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, + cutlass::arch::OpClassTensorOp, + Element, + decltype(select<1, 0, 2>(LayoutV{})), + Alignment, // direct TMA copy for GMEM -> SMEM + Element, + decltype(select<1, 0, 2>(LayoutK{})), + Alignment, + ElementAccumulatorKV, + TileShapeKV, + ClusterShape, + DummyStages, + cutlass::gemm::KernelTmaWarpSpecializedCooperative>::CollectiveOp; + + // using SmemLayoutAlphaAtom = GMMA::Layout_K_SW128_Atom; + // using SmemLayoutAlpha_SD = decltype(tile_to_shape( + // SmemLayoutAlphaAtom{}, + // make_shape( + // shape<1>(TileShapeQK{}), + // Int{}))); // (blk_kv), (64) + + // using GmemTiledCopyAlpha = cute::SM90_TMA_LOAD; + // using TMA_Alpha = decltype(make_tma_copy( + // GmemTiledCopyAlpha{}, + // make_tensor(make_gmem_ptr(static_cast(nullptr)), GmemLayoutAlpha{}), + // take<0, 2>(SmemLayoutAlpha_SD{}), + // select<1, 2>(TileShapeQK{}), + // size<0>(ClusterShape{}))); + + // raw layout for copy + using SmemLayoutQ_SD = + decltype(unstage_smem_layout(typename CollectiveMmaQK::SmemLayoutA{}, Int{})); + using SmemLayoutQ_K_Scaled_SD = + decltype(unstage_smem_layout(typename CollectiveMmaQK::SmemLayoutA{}, Int{})); + using SmemLayoutK_DS = + decltype(unstage_smem_layout(typename CollectiveMmaKV_G2S::SmemLayoutB{}, Int{})); + using SmemLayoutQ_K_Scaled_DS = + decltype(unstage_smem_layout(typename CollectiveMmaKV_G2S::SmemLayoutB{}, Int{})); + using SmemLayoutV_DS = + decltype(unstage_smem_layout(typename CollectiveMmaKV_G2S::SmemLayoutA{}, Int{})); + + // Layout for V^T + using RefLayoutV = decltype(make_layout(select<0, 2>(TileShapeKV{}), LayoutRight{})); + using CollectiveMmaKV = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, + cutlass::arch::OpClassTensorOp, + Element, + RefLayoutV, + Alignment, // needs a S2R transposition for MMA + Element, + decltype(select<1, 0, 2>(LayoutK{})), + Alignment, + ElementAccumulatorKV, + TileShapeKV, + ClusterShape, + DummyStages, + cutlass::gemm::KernelTmaWarpSpecializedCooperative>::CollectiveOp; + + using RefLayoutKV = decltype(make_layout(select<0, 1>(TileShapeKV{}), LayoutRight{})); // (dv, dk) + using CollectiveMmaO1 = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, + cutlass::arch::OpClassTensorOp, + Element, + RefLayoutKV, + Alignment, // NOTE: S (KV) as operand A + Element, + LayoutQ, + Alignment, + ElementAccumulatorO, + TileShapeO1, + ClusterShape, + DummyStages, + cutlass::gemm::KernelTmaWarpSpecializedCooperative>::CollectiveOp; + + // (blk_q,blk_k) to align with O2 mma, LayoutRight to align with QK mma output + using DesiredLayoutQK = decltype(make_layout(select<0, 1>(TileShapeQK{}), LayoutRight{})); + using CollectiveMmaO2 = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, + cutlass::arch::OpClassTensorOp, + Element, + RefLayoutV, + Alignment, // V^T + Element, + DesiredLayoutQK, + Alignment, + ElementAccumulatorO, + TileShapeO2, + ClusterShape, + DummyStages, + cutlass::gemm::KernelTmaWarpSpecializedCooperative>::CollectiveOp; + + using TiledMmaQK = typename CollectiveMmaQK::TiledMma; // Q@K^t + using TiledMmaKV = decltype(convert_to_gmma_rs(typename CollectiveMmaKV::TiledMma{})); + using TiledMmaO1 = decltype(convert_to_gmma_rs(typename CollectiveMmaO1::TiledMma{})); + using TiledMmaO2 = decltype(convert_to_gmma_rs(typename CollectiveMmaO2::TiledMma{})); + + static_assert(size(TiledMmaQK{}) == NumAuxMmaThreads); + + static_assert(size(TiledMmaKV{}) == NumStateMmaThreads); + static_assert(size(TiledMmaO1{}) == NumStateMmaThreads); + static_assert(size(TiledMmaO2{}) == NumStateMmaThreads); + + using CollectiveStoreO = CollectiveStoreTma< + TileShapeO1, + ClusterShape, + ElementO, + ElementAccumulatorO, + /*Smem*/ ElementO, + decltype(select<1, 0, 2>(LayoutO{})), // creates mn-major atom in store_tma.hpp + StagesO::value>; + + // layout for compute + using QKSmemLayoutQ = SmemLayoutQ_SD; + using QKSmemLayoutK = decltype(select_layout<1, 0, 2>(SmemLayoutK_DS{})); + using QKScaledSmemLayoutQ = SmemLayoutQ_K_Scaled_SD; + + using KVSmemLayoutK = SmemLayoutK_DS; + using KVSmemLayoutV = SmemLayoutV_DS; + using QKScaledSmemLayoutKt = SmemLayoutQ_K_Scaled_DS; + + + // layout for compute output + using SmemLayoutQK = decltype(tile_to_shape( + GMMA::Layout_K_INTER_Atom{}, + flatten(make_shape(select<0, 1>(TileShapeQK{}), Int{})), + Step<_1, _2, _3>{})); + using SmemLayoutO = typename CollectiveStoreO::SmemLayoutO; + + using SmemLayoutKK = decltype(tile_to_shape( + GMMA::Layout_K_INTER_Atom{}, + flatten(make_shape(select<0, 1>(TileShapeQK{}), Int{})), + Step<_1, _2, _3>{})); + + using InverseType = cutlass::half_t; + using CollectiveInverse = ku::CollectiveInverse; + + using ElementAccumulatorSK = float; + using TileShapeSK = decltype(make_shape(HeadSizeV, BlkSeqKV, HeadSizeQK)); + using CollectiveMmaSK = typename cutlass::gemm::collective::CollectiveBuilder< // basically the same as O1 + cutlass::arch::Sm90, + cutlass::arch::OpClassTensorOp, + Element, + RefLayoutKV, + Alignment, + Element, + LayoutK, + Alignment, + ElementAccumulatorSK, + TileShapeSK, + ClusterShape, + DummyStages, + cutlass::gemm::KernelTmaWarpSpecializedCooperative>::CollectiveOp; + + using ElementAccumulatorNewV = float; + using TileShapeNewV = decltype(make_shape(HeadSizeV, BlkSeqKV, BlkSeqKV)); + using RefLayoutSK = decltype(make_layout(select<0, 2>(TileShapeNewV{}), LayoutRight{})); // (dv, Blk) + using DesiredLayoutKK = decltype(make_layout(select<1, 2>(TileShapeNewV{}), LayoutRight{})); // + using CollectiveMmaNewV = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, + cutlass::arch::OpClassTensorOp, + Element, + RefLayoutSK, + Alignment, + Element, + DesiredLayoutKK, + Alignment, + ElementAccumulatorKV, + TileShapeNewV, + ClusterShape, + DummyStages, + cutlass::gemm::KernelTmaWarpSpecializedCooperative>::CollectiveOp; + + // FIXME: K@K^t are not exactly the same as Q@K^t, but similar enough (what does this mean??) + using TiledMmaKK = typename CollectiveMmaQK::TiledMma; // T = inv(I + strict_lower_triangular(K@K^t)) + using TiledMmaSK = decltype(convert_to_gmma_rs(typename CollectiveMmaSK::TiledMma{})); // ?? = -S@K^t + V^t + using TiledMmaNewV = decltype(convert_to_gmma_rs(typename CollectiveMmaNewV::TiledMma{})); // NewV = ??@T^t + + static_assert(size(TiledMmaKK{}) == NumAuxMmaThreads); + + using GmemStrideBeta = Stride; + using GmemLayoutBeta = Layout, GmemStrideBeta>; // (seq, head) + + using GmemShapeAlpha = Shape; // (seqlen_k, h) + using GmemStrideAlpha = Stride; // TODO: depends on gate cumsum output, so we won't hardset to 1 + using GmemLayoutAlpha = Layout; + + // only store the last Alpha value, either end of chunk or end of sequence + using SmemLayoutAlphaLast = decltype(make_layout(make_shape(_1{}, Int{}))); + using SmemLayoutBeta = decltype(make_layout(make_shape(BlkSeqQ, Int{}))); + using SmemLayoutAlpha = decltype(make_layout(make_shape(BlkSeqQ, Int{}))); + + using MainloopQPipeline = cutlass::PipelineTmaAsync; + using MainloopKPipeline = cutlass::PipelineTmaAsync; + using MainloopVPipeline = cutlass::PipelineTmaAsync; + using MainloopAlphaPipeline = std::conditional_t, Unused>; + using MainloopOPipeline = typename CollectiveStoreO::Pipeline; + + using MainloopQKPipeline = cutlass::PipelineAsync; + using MainloopKKPipeline = cutlass::PipelineAsync; + + using MainloopAlphaLastPipeline = + std::conditional_t, Unused>; + + using MainloopBetaPipeline = std::conditional_t, Unused>; + + using QPipelineState = typename cutlass::PipelineState; + using KPipelineState = typename cutlass::PipelineState; + using VPipelineState = typename cutlass::PipelineState; + using OPipelineState = typename CollectiveStoreO::PipelineState; + + using QKPipelineState = cutlass::PipelineState; + using KKPipelineState = cutlass::PipelineState; + + using AlphaLastPipelineState = + std::conditional_t, Unused>; + + using AlphaPipelineState = + std::conditional_t, Unused>; + using BetaPipelineState = + std::conditional_t, Unused>; + + using AlphaProcessor = Unused; + using BetaProcessor = Unused; + + static constexpr int LoadQBytes = size(QKSmemLayoutQ{}(_, _, _0{})) * sizeof(Element); + static constexpr int LoadKBytes = size(KVSmemLayoutK{}(_, _, _0{})) * sizeof(Element); + static constexpr int LoadVBytes = size(KVSmemLayoutV{}(_, _, _0{})) * sizeof(Element); + static constexpr int LoadAlphaBytes = size(SmemLayoutAlpha{}(_, _, _0{})) * sizeof(ElementAlpha); + static constexpr int StoreOBytes = CollectiveStoreO::TmaTransactionBytes; + + using SharedStorageO = typename CollectiveStoreO::SharedStorage; + + struct SharedStorage { + alignas( + alignment_for_swizzle(QKSmemLayoutQ{})) cute::array_aligned> smem_q; + alignas( + alignment_for_swizzle(KVSmemLayoutK{})) cute::array_aligned> smem_k; + alignas( + alignment_for_swizzle(KVSmemLayoutV{})) cute::array_aligned> smem_v; + alignas(alignment_for_swizzle( + SmemLayoutAlpha{})) cute::array_aligned> smem_alpha; + alignas( + alignment_for_swizzle(SmemLayoutQK{})) cute::array_aligned> smem_qk; + alignas(alignment_for_swizzle( + SmemLayoutKK{})) cute::array_aligned> smem_kk; + // smemq_k_scaled for exp(alpha) * Q and exp(alpha) * K in QS and KS, computed in Math WG2/3 + alignas(alignment_for_swizzle( + QKScaledSmemLayoutQ{})) cute::array_aligned> smem_q_k_scaled; + + SharedStorageO smem_o; + + cute::array_aligned> smem_beta; + // store last row in Alpha separately, used for S'=K^T NewV's epilogue and S+=decay(S') (one fused epilogue) + cute::array_aligned> smem_alpha_last; + }; + + using TMA_Q = typename CollectiveMmaQK::Params::TMA_A; + using TMA_K = typename CollectiveMmaKV_G2S::Params::TMA_B; + using TMA_V = typename CollectiveMmaKV_G2S::Params::TMA_A; + using TMA_O = typename CollectiveStoreO::Params::TMA_O; + + using LoadQ = CollectiveLoadTma; + using LoadK = CollectiveLoadTma; + using LoadV = CollectiveLoadTma; + // using LoadAlpha = + // CollectiveLoadTma; + using LoadAlpha = CollectiveLoadVector< + LoadKindVector::kAlpha, + MainloopAlphaPipeline, + ElementAlpha, + GmemLayoutAlpha, + ElementAlpha, + SmemLayoutAlpha, + AlphaProcessor + >; + + using LoadBeta = CollectiveLoadVector< + LoadKindVector::kBeta, + MainloopBetaPipeline, + ElementBeta, + GmemLayoutBeta, + ElementBeta, + SmemLayoutBeta, + BetaProcessor>; + + struct Arguments { // clang-format off + Element const* ptr_Q; LayoutQ dQ; + Element const* ptr_K; LayoutK dK; + Element const* ptr_V; LayoutV dV; + Element* ptr_O; LayoutO dO; + // FIXME: needs to be changed to not copy layout of Q + float const* ptr_Alpha; GmemStrideAlpha alpha_stride; + float* ptr_output_state; // layout fixed (kdim, vdim, num_heads, num_seqs):LayoutLeft{} + float const* ptr_input_state; + float scale; + ElementBeta const* beta_ptr; GmemStrideBeta beta_stride; + }; // clang-format on + + struct Params { + TMA_Q tma_load_q; + TMA_K tma_load_k; + TMA_V tma_load_v; + TMA_O tma_store_o; + void* tensormaps; + float scale; + + float* ptr_output_state; + float const* ptr_input_state; + + ElementAlpha const * alpha_ptr; + GmemLayoutAlpha alpha_layout; + ElementBeta const* beta_ptr; + GmemLayoutBeta beta_layout; + }; + + template + static bool + can_implement(ProblemShape const& problem_size, Arguments const& args) { + return true && (problem_size.head_size <= get<2>(TileShape{})) && ((problem_size.head_size % Alignment) == 0); + } + + template + static Params + to_underlying_arguments(ProblemShape const& problem_size, Arguments const& args, void* workspace) { + int64_t s = problem_size.total_seqlen; + int64_t t = problem_size.total_seqlen; + int32_t d = problem_size.head_size; + + auto params_qk = CollectiveMmaQK::to_underlying_arguments( + make_shape(s, t, d, problem_size.num_heads), + typename CollectiveMmaQK::Arguments{ + args.ptr_Q, args.dQ, args.ptr_K, args.dK, // never used, dummy + }, + /*workspace=*/nullptr); + + auto params_kv_k = CollectiveMmaKV_G2S::to_underlying_arguments( + make_shape(d, d, s, problem_size.num_heads), + typename CollectiveMmaKV_G2S::Arguments{ + args.ptr_V, + select<1, 0, 2>(args.dV), // not used + args.ptr_K, + select<1, 0, 2>(args.dK), // used as G2S for K + }, + /*workspace=*/nullptr); + auto params_kv_v = CollectiveMmaKV_G2S::to_underlying_arguments( + make_shape(d, d, s, problem_size.num_heads), + typename CollectiveMmaKV_G2S::Arguments{ + args.ptr_V, + select<1, 0, 2>(args.dV), // used as G2S for V + args.ptr_K, + select<1, 0, 2>(args.dK), // not used + }, + /*workspace=*/nullptr); + + auto params_o = CollectiveStoreO::to_underlying_arguments( + make_shape(d, s, d, problem_size.num_heads), // in O1 + // make_shape(d, s, s, problem_size.num_heads), // in O2 + typename CollectiveStoreO::Arguments{args.ptr_O, select<1, 0, 2>(args.dO), workspace}, + workspace); + + return Params{ + .tma_load_q = params_qk.tma_load_a, + .tma_load_k = params_kv_k.tma_load_b, + .tma_load_v = params_kv_v.tma_load_a, + .tma_store_o = params_o.tma_store_o, + .tensormaps = params_o.tensormaps, + .scale = args.scale, + + .ptr_output_state = args.ptr_output_state, + .ptr_input_state = args.ptr_input_state, + + // TODO: refactor all name to varname_vartype + .alpha_ptr = args.ptr_Alpha, + .alpha_layout = make_layout(make_shape(s, problem_size.num_heads), args.alpha_stride), + .beta_ptr = args.beta_ptr, + .beta_layout = make_layout(make_shape(s, problem_size.num_heads), args.beta_stride), + }; + } + + static size_t + get_workspace_size(Arguments const& args, int sm_count) { + return CollectiveStoreO::get_workspace_size(sm_count); + } + + template + static cutlass::Status + initialize_workspace( + ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream) { + return CollectiveStoreO::initialize_workspace(problem_shape, workspace, stream); + } + + CUTE_DEVICE static void + prefetch_tma_descriptors(Params const& params) { + cute::prefetch_tma_descriptor(params.tma_load_q.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_k.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_v.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_store_o.get_tma_descriptor()); + } + + template + CUTE_DEVICE void + load_qkv( + Params const& params, + ProblemShape const& problem_size, + LoadTileShape const& load_tile_shape, + WorkDesc const& work_desc, + MainloopQPipeline& q_pipeline, + QPipelineState& q_smem_pipe_write, + MainloopKPipeline& k_pipeline, + KPipelineState& k_smem_pipe_write, + MainloopVPipeline& v_pipeline, + VPipelineState& v_smem_pipe_write, + SharedStorage& storage) { + int32_t num_blocks = ceil_div(work_desc.seq_len, get<0>(TileShape{})); + uint32_t lane_predicate = cute::elect_one_sync(); + + auto q_collective_load = LoadQ(params.tma_load_q, q_pipeline, storage.smem_q); + auto k_collective_load = LoadK(params.tma_load_k, k_pipeline, storage.smem_k); + auto v_collective_load = LoadV(params.tma_load_v, v_pipeline, storage.smem_v); + + auto q_src_dst = q_collective_load.partition_SD(problem_size, load_tile_shape, work_desc); + auto k_src_dst = k_collective_load.partition_SD(problem_size, load_tile_shape, work_desc); + auto v_src_dst = v_collective_load.partition_SD(problem_size, load_tile_shape, work_desc); + + CUTE_NO_UNROLL + for (int blk = 0; blk < num_blocks; ++blk) { + q_collective_load.step(q_src_dst, blk, q_smem_pipe_write, lane_predicate); + k_collective_load.step(k_src_dst, blk, k_smem_pipe_write, lane_predicate); + v_collective_load.step(v_src_dst, blk, v_smem_pipe_write, lane_predicate); + } + } + + template + CUTE_DEVICE void + load_beta_and_alpha( + Params const& params, + ProblemShape const& problem_size, + TileShape const& tile_shape, + WorkDesc const& work_desc, + MainloopBetaPipeline& beta_pipeline, + BetaPipelineState& beta_smem_pipe_write, + MainloopAlphaPipeline& alpha_pipeline, + AlphaPipelineState& alpha_smem_pipe_write, + SharedStorage& storage) { + int32_t num_blocks = ceil_div(work_desc.seq_len, get<0>(TileShape{})); + + // fuse post inverse diag(beta) into diagonal of IKK + // auto collective_load = LoadBeta{params.beta_ptr, params.beta_layout, /*oob_value=*/1.0f, pipeline, + // storage.smem_beta}; + auto beta_collective_load = + LoadBeta{params.beta_ptr, params.beta_layout, /*oob_value=*/0.0f, beta_pipeline, storage.smem_beta}; + // oob fill value for alpha is -INFINITY because of exp2f(alpha) + auto alpha_collective_load = + LoadAlpha {params.alpha_ptr, params.alpha_layout, /*oob_value=*/0.0f, alpha_pipeline, storage.smem_alpha}; + auto beta_src_dst = beta_collective_load.partition_SD(problem_size, tile_shape, work_desc); + auto alpha_src_dst = alpha_collective_load.partition_SD(problem_size, tile_shape, work_desc); + + CUTE_NO_UNROLL + for (int blk = 0; blk < num_blocks - 1; ++blk) { + beta_collective_load.step(beta_src_dst, blk, beta_smem_pipe_write, num_blocks); + alpha_collective_load.step(alpha_src_dst, blk, alpha_smem_pipe_write, num_blocks); + } + beta_collective_load.step(beta_src_dst, num_blocks - 1, beta_smem_pipe_write, num_blocks); + alpha_collective_load.step(alpha_src_dst, num_blocks - 1, alpha_smem_pipe_write, num_blocks); + } + + template + CUTE_DEVICE void + extract_alpha_last( + Params const& params, + ProblemShape const& problem_size, + TileShape const& tile_shape, + WorkDesc const& work_desc, + MainloopAlphaPipeline& alpha_pipeline, + AlphaPipelineState& alpha_smem_pipe_read, + MainloopAlphaLastPipeline& alpha_last_pipeline, + AlphaLastPipelineState& alpha_last_smem_pipe_write, + SharedStorage& storage) { + int thread_idx = threadIdx.x % cutlass::NumThreadsPerWarp; + + Tensor sA = make_tensor(make_smem_ptr(storage.smem_alpha.data()), SmemLayoutAlpha{}); + Tensor sAlast = make_tensor(make_smem_ptr(storage.smem_alpha_last.data()), SmemLayoutAlphaLast{}); + + auto extract_loop_body = [&](int blk, auto is_final_block_) INLINE_LAMBDA { + constexpr bool is_final_block = decltype(is_final_block_)::value; + + int B = is_final_block ? valid_seq_len(work_desc, blk) : BlkSeqKV; + + auto sA_curr = sA(_, alpha_smem_pipe_read.index()); + Tensor sAlast_out = sAlast(_, alpha_last_smem_pipe_write.index()); + + alpha_pipeline.consumer_wait(alpha_smem_pipe_read); + alpha_last_pipeline.producer_acquire(alpha_last_smem_pipe_write); + + // one warp calls extract_alpha_last, so this is a guaranteed single write. + if (thread_idx == 0) { + sAlast_out(0) = sA_curr(B - 1); + } + __syncwarp(); + cutlass::arch::fence_view_async_shared(); + alpha_last_pipeline.producer_commit(alpha_last_smem_pipe_write); + ++alpha_last_smem_pipe_write; + alpha_pipeline.consumer_release(alpha_smem_pipe_read); + ++alpha_smem_pipe_read; + }; + + int32_t num_blocks = ceil_div(work_desc.seq_len, get<0>(TileShape{})); + CUTE_NO_UNROLL + for (int blk = 0; blk < num_blocks - 1; ++blk) { + extract_loop_body(blk, /*is_final_block_=*/cute::false_type{}); + } + extract_loop_body(num_blocks - 1, /*is_final_block_=*/cute::true_type{}); + } + + template + CUTE_DEVICE void + store( + TMA_O const& tma_store, + void* tensormaps, + ProblemSize const& problem_size, + StoreTileShape const& store_tile_shape, + WorkDesc const& work_desc, + MainloopOPipeline& pipeline, + PipelineState& smem_pipe_read, + SharedStorageO& storage) { + int32_t num_blocks = ceil_div(work_desc.seq_len, get<0>(TileShape{})); + uint32_t lane_predicate = cute::elect_one_sync(); + + auto collective_store = CollectiveStoreO{tma_store, pipeline, storage, tensormaps}; + auto src_dst = collective_store.partition_SD(problem_size, store_tile_shape, work_desc); + + CUTE_NO_UNROLL + for (int blk = 0; blk < num_blocks; ++blk) { + DPRINTF0_W( + "O collective_store.step smem_pipe_read:%d -> blk_idx:%d, num_blocks:%d\n", + smem_pipe_read.index(), + blk, + num_blocks); + collective_store.step(problem_size, work_desc, src_dst, smem_pipe_read, blk, num_blocks, lane_predicate); + } + } + + template + CUTE_DEVICE void + compute( + Params const& params, + ProblemShape const& problem_size, + WorkDesc const& work_desc, + MainloopQPipeline& q_pipeline, + QPipelineState& q_smem_pipe_read, + MainloopKPipeline& k_pipeline, + KPipelineState& k_smem_pipe_read, + MainloopVPipeline& v_pipeline, + VPipelineState& v_smem_pipe_read, + MainloopOPipeline& o_pipeline, + OPipelineState& o_smem_pipe_write, + MainloopQKPipeline& qk_pipeline, + QKPipelineState& qk_smem_pipe_read, + MainloopKKPipeline& kk_pipeline, + KKPipelineState& kk_smem_pipe_read, + MainloopAlphaPipeline& alpha_pipeline, + AlphaPipelineState& alpha_smem_pipe_read, + MainloopBetaPipeline& beta_pipeline, + BetaPipelineState& beta_smem_pipe_read, + MainloopAlphaLastPipeline& alpha_last_pipeline, + AlphaLastPipelineState& alpha_last_smem_pipe_read, + OrderedMathBarriers& math_barriers, + SharedStorage& storage) { + // MAKE NVCC HAPPY! + constexpr auto zero = Element{}; + + int32_t num_blocks = ceil_div(work_desc.seq_len, get<0>(TileShape{})); + DPRINTF0_WG("num_blocks: %d\n", num_blocks); + + int thread_idx = int(threadIdx.x) - NumLoadThreads; + int warpgroup_idx = thread_idx / cutlass::NumThreadsPerWarpGroup; + int thread_idx_in_wg = thread_idx % cutlass::NumThreadsPerWarpGroup; + + float scale = params.scale; + + Tensor Beta = make_tensor(make_smem_ptr(storage.smem_beta.data()), SmemLayoutBeta{}); + Tensor AlphaLast = make_tensor(make_smem_ptr(storage.smem_alpha_last.data()), SmemLayoutAlphaLast{}); + + Tensor sQqk = make_tensor(make_smem_ptr(storage.smem_q.data()), QKSmemLayoutQ{}); + Tensor sKqk = make_tensor(make_smem_ptr(storage.smem_k.data()), QKSmemLayoutK{}); + Tensor sAlpha = make_tensor(make_smem_ptr(storage.smem_alpha.data()), SmemLayoutAlpha{}); + Tensor sVkv = make_tensor(make_smem_ptr(storage.smem_v.data()), KVSmemLayoutV{}); + Tensor sQK = make_tensor(make_smem_ptr(storage.smem_qk.data()), SmemLayoutQK{}); + Tensor sO = make_tensor(make_smem_ptr(storage.smem_o.data()), SmemLayoutO{}); + + static_assert(sizeof(InverseType) == sizeof(Element)); + Tensor sKK_inv = make_tensor(make_smem_ptr(storage.smem_kk.data()), SmemLayoutKK{}); + Tensor sKK_opd = make_tensor(make_smem_ptr(reinterpret_cast(storage.smem_kk.data())), SmemLayoutKK{}); + + Tensor sQ_K_scaled = make_tensor(make_smem_ptr(storage.smem_q_k_scaled.data()), QKScaledSmemLayoutQ{}); + Tensor sQ_K_scaled_Kt = make_tensor(make_smem_ptr(storage.smem_q_k_scaled.data()), QKScaledSmemLayoutKt{}); + + /////////////////////////////////////////////////////////////////////////// + // Q@S, K@S, Q/K prologue + // each WG process 32 at a time, reduce peak register usage + // each WG process half head dim (64) at all + auto qk_tiled_mma_rs_quar = TiledMmaQK_RS_Quar{}; + auto qk_thr_mma_rs_quar = qk_tiled_mma_rs_quar.get_thread_slice(thread_idx_in_wg); + constexpr auto tiler_alpha = Shape<_64>{}; + constexpr auto tiler_qk = Shape<_64, Shape<_32, _1>>{}; + // used for Alpha S2R (float) + using CopyAlphaAtom = Copy_Atom, ElementAlpha>; + // used for Q/K S2R and R2S (fp16/bf16) + using CopyOpS2R = SM75_U32x4_LDSM_N; + using CopyOpR2S = SM90_U32x4_STSM_N; + + // reduce S2R loads to quarter, 32 at a time for both loads and stores + auto tiled_load_qk_quar = make_tiled_copy_A(Copy_Atom{}, qk_thr_mma_rs_quar); + auto thr_load_qk_quar = tiled_load_qk_quar.get_thread_slice(thread_idx_in_wg); + auto tiled_store_qk_quar = make_tiled_copy_A(Copy_Atom{}, qk_thr_mma_rs_quar); + auto thr_store_qk_quar = tiled_store_qk_quar.get_thread_slice(thread_idx_in_wg); + + auto cMq_quar = make_identity_tensor(select<0, 2>(TileShapeQK_Quar{})); // (QTok, HeadDim / 2) + auto tQcMq_quar = qk_thr_mma_rs_quar.partition_A(cMq_quar); // (idx) -> (tok_q, head_dim / 2) + + /////////////////////////////////////////////////////////////////////////// + // K@K (basically I + strict_lower_triangular(K K^T) + auto kk_tiled_mma = TiledMmaKK{}; + auto kk_thr_mma = kk_tiled_mma.get_thread_slice(thread_idx_in_wg); + Tensor tKKsK = kk_thr_mma.partition_B(sKqk); + Tensor tKKrA = kk_thr_mma.make_fragment_A(tKKsK); + auto cMqk = make_identity_tensor(select<0, 1>(TileShapeQK{})); // (QTok, KTok) + auto const& cMkk = cMqk; + auto tKKcMkk = kk_thr_mma.partition_C(cMkk); + + // S@K (-S K^T + V^T) - K and T + auto sk_tiled_mma = TiledMmaSK{}; + auto sk_thr_mma = sk_tiled_mma.get_thread_slice(thread_idx); + + // tSKrV adds to tSKrSK (acc), so we set up a copy + using SK_V_S2R = Copy_Atom; + auto tSKrV_tiled_copy = make_tiled_copy_C(SK_V_S2R{}, sk_tiled_mma); + auto tSKrV_thr_copy = tSKrV_tiled_copy.get_thread_slice(thread_idx); + + Tensor tSKsK = sk_thr_mma.partition_B(sQ_K_scaled); + Tensor tSKrK = sk_thr_mma.make_fragment_B(tSKsK); + + /////////////////////////////////////////////////////////////////////////// + // NewV = (S@K result) @ T^t + auto newv_tiled_mma = TiledMmaNewV{}; + auto newv_thr_mma = newv_tiled_mma.get_thread_slice(thread_idx); + + Tensor tNewVsB = newv_thr_mma.partition_B(sKK_opd); + Tensor tNewVrB = newv_thr_mma.make_fragment_B(tNewVsB); + + /////////////////////////////////////////////////////////////////////////// + // K@V + auto kv_tiled_mma = TiledMmaKV{}; // (V, Blk_k) @ (Blk_k, K) = (V, K) + auto kv_thr_mma = kv_tiled_mma.get_thread_slice(thread_idx); + + Tensor tKVrKV = partition_fragment_C(kv_thr_mma, select<0, 1>(TileShapeKV{})); + + // Tensor tKVrV = kv_thr_mma.partition_fragment_A(sVkv(_, _, _0{})); // mma src + // Tensor tKVrV_cv = tKVrV_thr_copy.retile_D(tKVrV); // copy view dst + // Tensor tKVsV = tKVrV_thr_copy.partition_S(sVkv); // copy view src + + auto const cV = make_identity_tensor(Shape, Int>{}); + Tensor tKVcV = kv_thr_mma.partition_A(cV); + auto const cS = make_identity_tensor(Shape, Int>{}); + Tensor tKVcS = kv_thr_mma.partition_C(cS); + + /////////////////////////////////////////////////////////////////////////// + // Q@K@V + auto o1_tiled_mma = TiledMmaO1{}; + auto o1_thr_mma = o1_tiled_mma.get_thread_slice(thread_idx); + auto o2_tiled_mma = TiledMmaO2{}; + auto o2_thr_mma = o2_tiled_mma.get_thread_slice(thread_idx); + + // A1 for Q@(KV) + // Tensor tOrKV = make_acc_into_op(tKVrKV, typename TiledMmaO1::LayoutA_TV{}); + // B1 for Q@(KV) + Tensor tOsQ = o1_thr_mma.partition_B(sQ_K_scaled); + Tensor tOrQ = o1_thr_mma.make_fragment_B(tOsQ); + + // A2 for QK@V + // Tensor tOsV = o2_thr_mma.partition_A(sVkv); + // Tensor tOrV = o2_thr_mma.make_fragment_A(tOsV); + // B2 for QK@V + Tensor tOsQK = o2_thr_mma.partition_B(sQK); + Tensor tOrQK = o2_thr_mma.make_fragment_B(tOsQK); + + using O_R2S = typename CollectiveStoreO::CopyAtomR2S; + auto tiled_copy_o = make_tiled_copy_C(O_R2S{}, o1_tiled_mma); + auto thr_copy_o = tiled_copy_o.get_thread_slice(thread_idx); + auto tOsO = thr_copy_o.partition_D(sO); + + auto const cO = make_identity_tensor(Shape, Int>{}); + Tensor tOcO = o1_thr_mma.partition_C(cO); + + auto const seq_idx = work_desc.seq_idx; + auto const q_head_idx = work_desc.q_head_idx(); + auto const k_head_idx = work_desc.k_head_idx(); + auto const v_head_idx = work_desc.v_head_idx(); + + auto sk_load_v = [&](int pipe_idx) INLINE_LAMBDA { + Tensor tSKrV = make_fragment_like(partition_fragment_C(sk_thr_mma, sVkv(_, _, _0{}))); // mma acc + Tensor tSKrV_cv = tSKrV_thr_copy.retile_D(tSKrV); // copy view dst + Tensor tSKsV = tSKrV_thr_copy.partition_S(sVkv); // copy view src + copy(tSKrV_tiled_copy, tSKsV(_, _, _, pipe_idx), tSKrV_cv); + return tSKrV; + }; + // kv_load here describes the 'kv cache', or the cached block state of size head_dim_k x head_dim_v + auto kv_load = [&](auto& tKVrKV) INLINE_LAMBDA { + DPRINTF0_WG("[%d,%d,%d,%d]>> load tKVgKV -> tKVrKV\n", seq_idx, q_head_idx, k_head_idx, v_head_idx); + int num_state_heads = problem_size.num_heads; + int state_head_idx = work_desc.o_head_idx(); + auto gKV = make_tensor( + make_gmem_ptr(params.ptr_input_state), + make_layout(make_shape(Int{}, Int{}, num_state_heads, problem_size.num_seqs)))( + _, _, state_head_idx, seq_idx); // (KDim, VDim), K-contiguous + // NOTE: load S in transposed GMEM + // because in GDN's equation, S = NewV^T @ K, while in KDA, S = K^T @ NewV + auto gKV_trans = make_tensor( + make_gmem_ptr(gKV.data()), + make_layout( + make_shape(get<1>(gKV.layout().shape()), get<0>(gKV.layout().shape())), + make_stride(get<1>(gKV.layout().stride()), get<0>(gKV.layout().stride())))); + + auto tiled_copy_kv = make_tiled_copy_C(Copy_Atom{}, kv_tiled_mma); + auto thr_copy_kv = tiled_copy_kv.get_thread_slice(thread_idx); + + auto tKVgKV = thr_copy_kv.partition_S(select_tensor<1, 0>(gKV_trans)); + copy(tiled_copy_kv, tKVgKV, tKVrKV); + }; + + auto kv_store = [&]() INLINE_LAMBDA { // tKVrKV is carried over whole mainloop + DPRINTF0_WG("[%d,%d,%d,%d]>> save tKVrKV -> tKVgKV\n", seq_idx, q_head_idx, k_head_idx, v_head_idx); + int num_state_heads = problem_size.num_heads; + int state_head_idx = work_desc.o_head_idx(); + auto gKV = make_tensor( + make_gmem_ptr(params.ptr_output_state), + make_layout(make_shape(Int{}, Int{}, num_state_heads, problem_size.num_seqs)))( + _, _, state_head_idx, seq_idx); // (KDim, VDim), K-contiguous + // NOTE: store S in transposed GMEM + // because in GDN's equation, S = NewV^T @ K, while in KDA, S = K^T @ NewV + auto gKV_trans = make_tensor( + make_gmem_ptr(gKV.data()), + make_layout( + make_shape(get<1>(gKV.layout().shape()), get<0>(gKV.layout().shape())), + make_stride(get<1>(gKV.layout().stride()), get<0>(gKV.layout().stride())))); + + auto tiled_copy_kv = make_tiled_copy_C(Copy_Atom{}, kv_tiled_mma); + auto thr_copy_kv = tiled_copy_kv.get_thread_slice(thread_idx); + + auto tKVgKV = thr_copy_kv.partition_D(select_tensor<1, 0>(gKV_trans)); + copy(tiled_copy_kv, tKVrKV, tKVgKV); + }; + + auto s_decay = [&](auto& tKVrKV, auto const& alpha_last_smem_pipe_read) INLINE_LAMBDA { + auto alpha_last_curr = AlphaLast(0, alpha_last_smem_pipe_read.index()); + for_each(make_int_sequence{}, [&](auto i) { + tKVrKV(i) *= exp2f(alpha_last_curr); + }); + }; + + auto o1_epi = [&](auto& tOrO1) INLINE_LAMBDA { + CUTE_UNROLL + for (int i = 0; i < size(tOrO1); ++i) { + tOrO1(i) = scale * tOrO1(i); + } + }; + + auto o_store = [&](auto tOrO) INLINE_LAMBDA { + auto tOrO_cvt = make_fragment_like(tOrO); + copy(tOrO, tOrO_cvt); + + DPRINTF0_WG("compute: o_pipeline.producer_wait: smem_pipe_write:%d\n", o_smem_pipe_write.index()); + o_pipeline.producer_acquire(o_smem_pipe_write); + Tensor tOrO_cvt_cv = thr_copy_o.retile_S(tOrO_cvt); + cutlass::arch::fence_view_async_shared(); + copy(tiled_copy_o, tOrO_cvt_cv, tOsO(_, _, _, o_smem_pipe_write.index())); + cutlass::arch::fence_view_async_shared(); + o_pipeline.producer_commit(o_smem_pipe_write); + ++o_smem_pipe_write; + }; + + auto kk_inv = [&](auto const& kk_smem_pipe_read) INLINE_LAMBDA { + auto sKK_inv_pipe_slice = sKK_inv(_, _, kk_smem_pipe_read.index()); + static_assert(sizeof(Element) == 2); + using CopyOpR2S = SM90_U32x4_STSM_N; + auto tiled_store_kk = make_tiled_copy_C(Copy_Atom{}, kk_tiled_mma); + auto thr_store_kk = tiled_store_kk.get_thread_slice(thread_idx); + auto tKKsKK = thr_store_kk.partition_D(sKK_inv_pipe_slice); + // TODO: use tKKcMkk? no more allocating fragments + auto tKKrKK = kk_thr_mma.partition_fragment_C(sKK_inv_pipe_slice); + auto tKKrKK_cv = thr_store_kk.retile_S(tKKrKK); + auto collective_inverse = CollectiveInverse(GdnNamedBarriers::StateMathWG0); + collective_inverse.compute(sKK_inv_pipe_slice); + // FIXME: we can ignore core matrices above diagonal + if constexpr (NeedsBeta || !std::is_same_v) { + cutlass::arch::NamedBarrier::arrive_and_wait( + cutlass::NumThreadsPerWarpGroup, GdnNamedBarriers::StateMathWG0); + using CopyOpS2R = SM75_U32x4_LDSM_N; + auto tiled_load_kk = make_tiled_copy_C(Copy_Atom{}, kk_tiled_mma); + auto thr_load_kk = tiled_load_kk.get_thread_slice(thread_idx); + auto tKKrKK_cpy = make_fragment_like(tKKrKK_cv); + auto tKKrKK_cvt = make_fragment_like(tKKrKK_cv); + auto tKKcMkk_cv = thr_load_kk.retile_D(tKKcMkk); + copy(tiled_load_kk, thr_load_kk.partition_S(sKK_inv_pipe_slice), tKKrKK_cpy); + cute::transform(tKKrKK_cpy, tKKcMkk_cv, tKKrKK_cvt, [&](auto val, auto coord) { + auto [_, t] = coord; + if constexpr (NeedsBeta) { + return Element(float(val) * Beta(t, beta_smem_pipe_read.index())); + } else { + return Element(val); + } + }); + copy(tiled_store_kk, tKKrKK_cvt, recast(tKKsKK)); + } + }; + + // this method syncs with the other pipelines involved through kernel.hpp + auto compute_loop_body = [&](int blk, auto is_first_block_, auto is_final_block_) INLINE_LAMBDA { + constexpr bool is_first_block = decltype(is_first_block_)::value; + constexpr bool is_final_block = decltype(is_final_block_)::value; + int B = is_final_block ? valid_seq_len(work_desc, blk) : BlkSeqKV; + + auto sQqk_curr = sQqk(_, _, q_smem_pipe_read.index()); + auto sKqk_curr = sKqk(_, _, k_smem_pipe_read.index()); + auto sQ_scaled_curr = sQ_K_scaled(_, _, _0{}); + auto sK_scaled_curr = sQ_K_scaled(_, _, _1{}); + auto sAlast_curr = AlphaLast(_, alpha_last_smem_pipe_read.index()); + auto sAlpha_curr = sAlpha(_, alpha_smem_pipe_read.index()); + auto sQqk_slice = flat_divide(sQqk_curr, tiler_qk); + auto sKqk_slice = flat_divide(sKqk_curr, tiler_qk); + auto sQ_scaled_slice = flat_divide(sQ_scaled_curr, tiler_qk); + auto sK_scaled_slice = flat_divide(sK_scaled_curr, tiler_qk); + auto sAlpha_slice = flat_divide(sAlpha_curr, tiler_alpha); + + if constexpr (NeedsAlpha) { + alpha_pipeline.consumer_wait(alpha_smem_pipe_read); + } + DPRINTF0_WG("compute: q_pipeline.consumer_wait: smem_pipe_read:%d\n", q_smem_pipe_read.index()); + q_pipeline.consumer_wait(q_smem_pipe_read); + DPRINTF0_WG("compute: k_pipeline.consumer_wait: smem_pipe_read:%d\n", k_smem_pipe_read.index()); + k_pipeline.consumer_wait(k_smem_pipe_read); + + // load alpha and exp2(alpha) only once + // and reuse these registers in exp(alpha) * Q/K prologue + if constexpr (!is_first_block) { + // make sure sQ_K_scaled is already consumed for previous K^@V + cutlass::arch::NamedBarrier::arrive_and_wait(NumStateMmaThreads, GdnNamedBarriers::StateMath); + // Each WG iterates over 2 slices of 32 elements each. + // WG0 (thread_idx < 128): wg_idx=0, processes alpha indices {0,1}, Q/K dim1=0 + // WG1 (thread_idx >= 128): wg_idx=1, processes alpha indices {2,3}, Q/K dim1=1 + { + int wg_idx = thread_idx / 128; // 0 or 1 + int alpha_base = wg_idx * 2; // 0 or 2 + + // Allocate Q/K register fragments once (reused across slices) + // Only shape/layout matters for partition_fragment_A, use compile-time indices + auto tQKrQ_wg = + qk_thr_mma_rs_quar.partition_fragment_A(sQqk_slice(_, _, _0{}, make_coord(_0{}, _0{}))); + auto tQKrK_wg = + qk_thr_mma_rs_quar.partition_fragment_A(sKqk_slice(_, _, _0{}, make_coord(_0{}, _0{}))); + auto tArA = make_fragment_like(tQKrQ_wg); + + auto sA_cur = sAlpha_slice(_, _0{}); + #pragma unroll + for (int v = 0; v < size(tArA); v++) { + tArA(v) = sA_cur(get<0>(tQcMq_quar(v))); + } + cute::transform(tArA, [](auto g) { return exp2f(g); }); + for (int s = 0; s < 2; ++s) { + + // S2R Q + auto sQqk_cur = sQqk_slice(_, _, _0{}, make_coord(s, wg_idx)); + auto tQKsQ_cur = thr_load_qk_quar.partition_S(sQqk_cur); + auto tQKrQ_cv = thr_load_qk_quar.retile_D(tQKrQ_wg); + copy(tiled_load_qk_quar, tQKsQ_cur, tQKrQ_cv); + + // element-wise exp(alpha) * Q + cute::transform(tQKrQ_wg, tArA, tQKrQ_wg, [&](auto q, auto alpha) { + Element dst = Element(alpha * float(q)); + return dst; + }); + + // R2S Q -> stage 0 + auto sQ_scaled_cur = sQ_scaled_slice(_, _, _0{}, make_coord(s, wg_idx)); + auto tQKsQ_out = thr_store_qk_quar.partition_D(sQ_scaled_cur); + auto tQKrQ_out_cv = thr_store_qk_quar.retile_S(tQKrQ_wg); + copy(tiled_store_qk_quar, tQKrQ_out_cv, tQKsQ_out); + + // S2R K + auto sKqk_cur = sKqk_slice(_, _, _0{}, make_coord(s, wg_idx)); + auto tQKsK_cur = thr_load_qk_quar.partition_S(sKqk_cur); + auto tQKrK_cv = thr_load_qk_quar.retile_D(tQKrK_wg); + copy(tiled_load_qk_quar, tQKsK_cur, tQKrK_cv); + + // element-wise exp(alpha) * K + cute::transform(tQKrK_wg, tArA, tQKrK_wg, [&](auto k, auto alpha) { + Element dst = Element(alpha * float(k)); + return dst; + }); + + // R2S K -> stage 1 + auto sK_scaled_cur = sK_scaled_slice(_, _, _0{}, make_coord(s, wg_idx)); + auto tQKsK_out = thr_store_qk_quar.partition_D(sK_scaled_cur); + auto tQKrK_out_cv = thr_store_qk_quar.retile_S(tQKrK_wg); + copy(tiled_store_qk_quar, tQKrK_out_cv, tQKsK_out); + } + } + cutlass::arch::NamedBarrier::arrive_and_wait(NumStateMmaThreads, GdnNamedBarriers::StateMath); + // fence to produce data for WGMMA async proxy + cutlass::arch::fence_view_async_shared(); + // if (blk <= 1 && thread_idx == 0) { + // printf("After Q/K prologue: exp(alpha) * Q at stage 0, exp(alpha) * K at stage 1\n"); + // cute::print_tensor(sQ_K_scaled_curr); + // } + } + + // 2.1 Q @ KV, NOTE: use old KV here + + DPRINTF0_WG("[%d,%d,%d,%d]** dispatch O WGMMA\n", seq_idx, q_head_idx, k_head_idx, v_head_idx); + auto tOrO = partition_fragment_C(o1_thr_mma, select<0, 1>(TileShapeO1{})); + if constexpr (is_first_block) { + DPRINTF0_WG("compute: q_pipeline.consumer_release: smem_pipe_read:%d\n", q_smem_pipe_read.index()); + q_pipeline.consumer_release(q_smem_pipe_read); + ++q_smem_pipe_read; + } else { + // change layout of S_t to A layout for O1 + Tensor tOrKV = make_acc_into_op(tKVrKV, typename TiledMmaO1::LayoutA_TV{}); + warpgroup_fence_operand(tOrKV); + warpgroup_fence_operand(tOrO); + // ======DEBUG======= + // if (blk <= 6 && thread_idx == 0) { + // printf("=======Before Q@S, block_idx: %d, thread_idx: %d=======\n", blk, thread_idx); + // cute::print_tensor(tOrKV); + // cute::print_tensor(sQ_K_scaled_slice); + // } + + math_barriers.ordered_or_wait(warpgroup_idx); + warpgroup_arrive(); + gemm_zero_acc(o1_thr_mma, tOrKV, tOrQ(_, _, _, 0), tOrO); + warpgroup_commit_batch(); // q@kv batch + math_barriers.notify_next_blocked(warpgroup_idx); + } + if constexpr (!is_first_block) { + warpgroup_wait<0>(); // q@kv batch + // ======DEBUG======= + // if (blk <= 1 && thread_idx == 0) { + // printf("\n"); + // printf("=======O_inter after Q@S, block_idx: %d, thread_idx: %d=======\n", blk, thread_idx); + // cute::print_tensor(tOrO); + // } + DPRINTF0_WG("compute: q_pipeline.consumer_release: smem_pipe_read:%d\n", q_smem_pipe_read.index()); + q_pipeline.consumer_release(q_smem_pipe_read); + ++q_smem_pipe_read; + o1_epi(tOrO); + } + + auto tSKrSK = partition_fragment_C(sk_thr_mma, sVkv(_, _, _0{})); + if constexpr (!is_first_block) { + auto tSKrS = make_acc_into_op(tKVrKV, typename TiledMmaSK::LayoutA_TV{}); + warpgroup_fence_operand(tSKrSK); + warpgroup_fence_operand(tSKrS); + math_barriers.ordered_or_wait(warpgroup_idx); + warpgroup_arrive(); + + // ======DEBUG======= + // if (blk <= 6 && thread_idx == 0) { + // printf("=======Before K@S, block_idx: %d, thread_idx: %d=======\n", blk, thread_idx); + // cute::print_tensor(tSKrS); + // } + + // SK: K_scaled is in stage 1 of sQ_K_scaled + gemm_zero_acc(sk_tiled_mma, tSKrS, tSKrK(_, _, _, 1), tSKrSK); + warpgroup_commit_batch(); + math_barriers.notify_next_blocked(warpgroup_idx); + warpgroup_wait<0>(); + } + // ======DEBUG======= + // if (blk <= 6 && thread_idx == 0) { + // printf("=======After K@S, block_idx: %d, thread_idx: %d=======\n", blk, thread_idx); + // cute::print_tensor(tSKrSK); + // } + + DPRINTF0_WG("compute: v_pipeline.consumer_wait: smem_pipe_read:%d\n", v_smem_pipe_read.index()); + v_pipeline.consumer_wait(v_smem_pipe_read); + auto tSKrV = sk_load_v(v_smem_pipe_read.index()); + if constexpr (!is_first_block) { + // sk_epi(tSKrSK, alpha_smem_pipe_read); + // V' = V - SK + transform(tSKrV, tSKrSK, tSKrV, [](auto v, auto sk) { return v - Element(sk); }); + } + + kk_pipeline.consumer_wait(kk_smem_pipe_read); + beta_pipeline.consumer_wait(beta_smem_pipe_read); + cutlass::arch::fence_view_async_shared(); + // KK inverse + if (warpgroup_idx == 0) { + kk_inv(kk_smem_pipe_read); + } + // wait for KK inverse ready + cutlass::arch::NamedBarrier::arrive_and_wait(NumStateMmaThreads, GdnNamedBarriers::StateMath); + + auto tNewVrA = make_acc_into_op(tSKrV, typename TiledMmaNewV::LayoutA_TV{}); + auto tNewVrC = partition_fragment_C(newv_thr_mma, select<0, 1>(TileShapeNewV{})); + warpgroup_fence_operand(tNewVrA); + warpgroup_fence_operand(tNewVrC); + math_barriers.ordered_or_wait(warpgroup_idx); + warpgroup_arrive(); + // if constexpr (is_final_block) { + // if (thread_idx == 0) { + // printf("\n"); + // printf("=======tNewVrA, tNewVrB before V'@T, block_idx: %d, thread_idx: %d=======\n", blk, + // thread_idx); printf("tNewVrA\n"); cute::print_tensor(tNewVrA); printf("sKK_opd\n"); + // cute::print_tensor(sKK_opd); + // printf("=======tNewVrA, tNewVrB before V'@T, block_idx: %d, thread_idx: %d=======\n", blk, + // thread_idx); printf("\n"); + // } + // } + // NewV = V'T + gemm_zero_acc(o1_thr_mma, tNewVrA, tNewVrB(_, _, _, kk_smem_pipe_read.index()), tNewVrC); + warpgroup_commit_batch(); // new_v batch + math_barriers.notify_next_blocked(warpgroup_idx); + warpgroup_wait<0>(); // new_v batch + // if constexpr (is_final_block) { + // if (thread_idx == 0) { + // printf("\n"); + // printf("=======tNewVrC after V'@T, block_idx: %d, thread_idx: %d=======\n", blk, thread_idx); + // cute::print_tensor(tNewVrC); + // } + // } + DPRINTF0_WG("compute: v_pipeline.consumer_release: smem_pipe_read:%d\n", v_smem_pipe_read.index()); + ++v_smem_pipe_read; // NOTE: if we delay this increment after consumer_release, race condition happens, + // why? + v_pipeline.consumer_release(v_smem_pipe_read); + + kk_pipeline.consumer_release(kk_smem_pipe_read); + ++kk_smem_pipe_read; + beta_pipeline.consumer_release(beta_smem_pipe_read); + ++beta_smem_pipe_read; + + ///////////////////////////////////////////////////////////////////////// + // 2. compute qkv + // 2.2 QK @ V, NOTE: use old KV here and QK is scaled + qk_pipeline.consumer_wait(qk_smem_pipe_read); + auto tOrV_or_tKVrV = make_acc_into_op(tNewVrC, typename TiledMmaKV::LayoutA_TV{}); + warpgroup_fence_operand(tOrV_or_tKVrV); + warpgroup_fence_operand(tOrO); + math_barriers.ordered_or_wait(warpgroup_idx); + warpgroup_arrive(); + // (V_new)^T @ QK + if constexpr (is_first_block) { + gemm_zero_acc(o2_tiled_mma, tOrV_or_tKVrV, tOrQK(_, _, _, qk_smem_pipe_read.index()), tOrO); + } else { + gemm(o2_tiled_mma, tOrV_or_tKVrV, tOrQK(_, _, _, qk_smem_pipe_read.index()), tOrO); + } + warpgroup_commit_batch(); // qk@v batch + math_barriers.notify_next_blocked(warpgroup_idx); + warpgroup_wait<0>(); // qk@v batch + // if (blk <= 6 && thread_idx == 0) { + // printf("\n"); + // printf("=======O_intra after NewV@QK, block_idx: %d, thread_idx: %d=======\n", blk, thread_idx); + // cute::print_tensor(tOrO); + // } + qk_pipeline.consumer_release(qk_smem_pipe_read); + ++qk_smem_pipe_read; + o_store(tOrO); + + ///////////////////////////////////////////////////////////////////////// + // 3. update KV + Tensor tKVsK = kv_thr_mma.partition_B(sQ_K_scaled_Kt); + Tensor tKVrK = kv_thr_mma.make_fragment_B(tKVsK); + + if constexpr (NeedsAlpha) { + alpha_last_pipeline.consumer_wait(alpha_last_smem_pipe_read); + cutlass::arch::fence_view_async_shared(); + } + if constexpr (!is_first_block) { + s_decay(tKVrKV, alpha_last_smem_pipe_read); + } + + + // synchronize 2 WGs before rewriting sQ_K_scaled + cutlass::arch::NamedBarrier::arrive_and_wait(NumStateMmaThreads, GdnNamedBarriers::StateMath); + // exp(alpha_last - alpha) * K + // Each WG iterates over 2 slices of 32 elements each. + // WG0 (thread_idx < 128): wg_idx=0, alpha_last indices {0,1}, K/output dim1=0 + // WG1 (thread_idx >= 128): wg_idx=1, alpha_last indices {2,3}, K/output dim1=1 + { + int wg_idx = thread_idx / 128; // 0 or 1 + // Allocate K/Alpha register fragments once (reused across slices) + auto tQKrK_wg = qk_thr_mma_rs_quar.partition_fragment_A(sKqk_slice(_, _, _0{}, make_coord(_0{}, _0{}))); + auto tArA_wg = make_fragment_like(tQKrK_wg); + + auto sA_cur = sAlpha_slice(_, _0{}); + // Dummy tensor to enable broadcast of alpha values across a row + #pragma unroll + for (int v = 0; v < size(tArA_wg); v++) { + tArA_wg(v) = sA_cur(get<0>(tQcMq_quar(v))); + } + auto alpha_last = sAlast_curr(0); + for (int s = 0; s < 2; ++s) { + // S2R K + auto sKqk_cur = sKqk_slice(_, _, _0{}, make_coord(s, wg_idx)); + auto tQKsK_cur = thr_load_qk_quar.partition_S(sKqk_cur); + auto tQKrK_cv = thr_load_qk_quar.retile_D(tQKrK_wg); + copy(tiled_load_qk_quar, tQKsK_cur, tQKrK_cv); + + // element-wise: exp(alpha_last - alpha) * K + for_each(make_int_sequence{}, [&](auto i) { + auto coord = tQcMq_quar(i); + auto [seq, _] = coord; + auto alpha = tArA_wg(i); + auto k = tQKrK_wg(i); + auto k_scaled = Element(exp2f(alpha_last - alpha) * float(k)); + tQKrK_wg(i) = k_scaled; + if constexpr (is_final_block) { + if (seq >= B) { + tQKrK_wg(i) = Element(0.0f); + } + } + }); + // R2S K -> stage 0 (reuse for KV update) + auto sQ_scaled_cur = sQ_scaled_slice(_, _, _0{}, make_coord(s, wg_idx)); + auto tQKsK_out = thr_store_qk_quar.partition_D(sQ_scaled_cur); + auto tQKrK_out_cv = thr_store_qk_quar.retile_S(tQKrK_wg); + copy(tiled_store_qk_quar, tQKrK_out_cv, tQKsK_out); + } + } + // wait for smemq_k_scaled ready + cutlass::arch::NamedBarrier::arrive_and_wait(NumStateMmaThreads, GdnNamedBarriers::StateMath); + // fence to produce data for WGMMA async proxy + cutlass::arch::fence_view_async_shared(); + + if constexpr (NeedsAlpha) { + alpha_last_pipeline.consumer_release(alpha_last_smem_pipe_read); + ++alpha_last_smem_pipe_read; + } + DPRINTF0_WG("[%d,%d,%d,%d]** dispatch KV WGMMA\n", seq_idx, q_head_idx, k_head_idx, v_head_idx); + warpgroup_fence_operand(tOrV_or_tKVrV); + warpgroup_fence_operand(tKVrKV); + math_barriers.ordered_or_wait(warpgroup_idx); + warpgroup_arrive(); + gemm(kv_tiled_mma, tOrV_or_tKVrV, tKVrK(_, _, _, 0), tKVrKV); + warpgroup_commit_batch(); // k@v batch + math_barriers.notify_next_blocked(warpgroup_idx); + warpgroup_wait<0>(); + + // if constexpr (is_final_block) { + // if (thread_idx == 0 && cute::block(0)) { + // printf("\n"); + // printf("=======After K^T@NewV, block_idx: %d, thread_idx: %d=======\n", blk, thread_idx); + // printf("tKVrKV\n"); + // cute::print_tensor(tKVrKV); + // } + // } + + if constexpr (NeedsAlpha) { + alpha_pipeline.consumer_release(alpha_smem_pipe_read); + ++alpha_smem_pipe_read; + } + + DPRINTF0_WG("compute: k_pipeline.consumer_release: smem_pipe_read:%d\n", k_smem_pipe_read.index()); + k_pipeline.consumer_release(k_smem_pipe_read); + ++k_smem_pipe_read; + + // if (blk <= 6 && thread_idx == 0 && cute::block(0)) { + // printf("\n"); + // printf("=======After S epilogue, block_idx: %d, thread_idx: %d, head_idx: %d=======\n", blk, + // thread_idx, q_head_idx); printf("tKVrKV\n"); cute::print_tensor(tKVrKV); + // } + }; + + if constexpr (!kInitStateFromInput) { + clear(tKVrKV); + compute_loop_body(0, /*is_first_block_=*/cute::true_type{}, /*is_final_block_=*/cute::false_type{}); + } else { + kv_load(tKVrKV); // GMEM -> Register, only once at the beginning + compute_loop_body(0, /*is_first_block_=*/cute::false_type{}, /*is_final_block_=*/cute::false_type{}); + } + CUTE_NO_UNROLL + for (int blk = 1; blk < num_blocks - 1; ++blk) { + compute_loop_body(blk, /*is_first_block_=*/cute::false_type{}, /*is_final_block_=*/cute::false_type{}); + } + if (num_blocks != 1) { + compute_loop_body( + num_blocks - 1, + /*is_first_block_=*/cute::false_type{}, + /*is_final_block_=*/cute::true_type{}); + } + kv_store(); + } + + template + CUTE_DEVICE void + compute_aux_safe( + Params const& params, + ProblemShape const& problem_size, + WorkDesc const& work_desc, + MainloopQPipeline& q_pipeline, + QPipelineState& q_smem_pipe_read, + MainloopKPipeline& k_pipeline, + KPipelineState& k_smem_pipe_read, + MainloopQKPipeline& qk_pipeline, + QKPipelineState& qk_smem_pipe_write, + MainloopKKPipeline& kk_pipeline, + KKPipelineState& kk_smem_pipe_write, + MainloopAlphaPipeline& alpha_pipeline, + AlphaPipelineState& alpha_smem_pipe_read, + MainloopBetaPipeline& beta_pipeline, + BetaPipelineState& beta_smem_pipe_read, + MainloopAlphaLastPipeline& alpha_last_pipeline, + AlphaLastPipelineState& alpha_last_smem_pipe_write, + SharedStorage& storage) { + using TileShape_SubChunk = Shape<_16, _16, _32>; + int thread_idx = threadIdx.x % cutlass::NumThreadsPerWarpGroup; + + float scale = params.scale; + + Tensor Beta = make_tensor(make_smem_ptr(storage.smem_beta.data()), SmemLayoutBeta{}); + + Tensor sQqk = make_tensor(make_smem_ptr(storage.smem_q.data()), QKSmemLayoutQ{}); + Tensor sKqk = make_tensor(make_smem_ptr(storage.smem_k.data()), QKSmemLayoutK{}); + + Tensor Alpha = make_tensor(make_smem_ptr(storage.smem_alpha.data()), SmemLayoutAlpha{}); + Tensor sAlast = make_tensor(make_smem_ptr(storage.smem_alpha_last.data()), SmemLayoutAlphaLast{}); + + Tensor sKkv = make_tensor(make_smem_ptr(storage.smem_k.data()), KVSmemLayoutK{}); + Tensor sVkv = make_tensor(make_smem_ptr(storage.smem_v.data()), KVSmemLayoutV{}); + Tensor sQK = make_tensor(make_smem_ptr(storage.smem_qk.data()), SmemLayoutQK{}); + Tensor sO = make_tensor(make_smem_ptr(storage.smem_o.data()), SmemLayoutO{}); + + static_assert(sizeof(InverseType) == sizeof(Element)); + Tensor sKK_inv = make_tensor(make_smem_ptr(storage.smem_kk.data()), SmemLayoutKK{}); + Tensor sKK_opd = make_tensor(make_smem_ptr(reinterpret_cast(storage.smem_kk.data())), SmemLayoutKK{}); + + constexpr int BK = 32; // should be same as TileShape_SubChunk + constexpr int NK = 128 / BK; + + /////////////////////////////////////////////////////////////////////////// + // Q@K + auto qk_tiled_mma_rs = TiledMmaQK_RS{}; + auto qk_thr_mma_rs = qk_tiled_mma_rs.get_thread_slice(thread_idx); + + auto cMqk = make_identity_tensor(select<0, 1>(TileShapeQK{})); // (QTok, KTok) + auto tQKcMqk = qk_thr_mma_rs.partition_C(cMqk); // (idx) -> (tok_q, tok_k) + auto cMq = make_identity_tensor(select<0, 2>(TileShapeQK{})); // (QTok, HeadDim) + auto tQcMq = qk_thr_mma_rs.partition_A(cMq); // (idx) -> (tok_q, head_dim) + + auto const seq_idx = work_desc.seq_idx; + auto const q_head_idx = work_desc.q_head_idx(); + auto const k_head_idx = work_desc.k_head_idx(); + auto const v_head_idx = work_desc.v_head_idx(); + + auto qk_kk_subchunk_mma_and_store = [&](int blk) INLINE_LAMBDA { + using CopyOp_R2S = SM90_U32x2_STSM_N; + using CopyAlphaAtom = Copy_Atom, ElementAlpha>; + // Q/K S2R: use BF16 MMA's LDSM tiled copy for efficient shared memory loads, + // then convert register layout to TF32 MMA layout via warp shuffles. + // This replaces the previous AutoVectorizingCopy<16> which caused 50% more smem traffic. + using CopyQKAtom_LDSM = Copy_Atom; + // TF32 MMA: float(Element) → tf32 → MMA, better precision than fp16/bf16 MMA + using MMA = SM80_16x8x8_F32TF32TF32F32_TN; + using TiledMma_SubChunk = + decltype(make_tiled_mma(MMA{}, Layout>{}, TileShape_SubChunk{})); + // BF16 MMA (same shape 16x8x8) used only for creating LDSM-compatible tiled copies + using MMA_BF16 = SM80_16x8x8_F32BF16BF16F32_TN; + using TiledMma_BF16_SubChunk = + decltype(make_tiled_mma(MMA_BF16{}, Layout>{}, TileShape_SubChunk{})); + + int local_thread_idx = thread_idx % 64; + auto tiledmma_subchunk = TiledMma_SubChunk{}; + auto thr_mma_subchunk = tiledmma_subchunk.get_thread_slice(local_thread_idx); + auto tiledmma_bf16_subchunk = TiledMma_BF16_SubChunk{}; + auto thr_mma_bf16_subchunk = tiledmma_bf16_subchunk.get_thread_slice(local_thread_idx); + + // Alpha S2R: load in BF16 MMA layout so gating happens before the layout shuffle, + // reducing register pressure (alpha can be freed before the shuffle). + // BF16-layout alpha copies for operand A and B (for element-wise gating) + auto alpha_Q_bf16_tiled_copy = make_tiled_copy_A(CopyAlphaAtom{}, tiledmma_bf16_subchunk); + auto alpha_Kt_bf16_tiled_copy = make_tiled_copy_B(CopyAlphaAtom{}, tiledmma_bf16_subchunk); + // Q/K S2R: LDSM copies using BF16 MMA layout for efficient ldmatrix loads + auto Q_tiled_copy = make_tiled_copy_A(CopyQKAtom_LDSM{}, tiledmma_bf16_subchunk); + auto Kt_tiled_copy = make_tiled_copy_B(CopyQKAtom_LDSM{}, tiledmma_bf16_subchunk); + // R2S copies for accumulators (C layout, same for all MMA types with same output) + // O_tiled_copy uses the bf16 copy, O_tiled_copy_kk uses the InverseType copy + auto O_tiled_copy = make_tiled_copy_C(Copy_Atom{}, tiledmma_subchunk); + auto O_tiled_copy_kk = make_tiled_copy_C(Copy_Atom{}, tiledmma_subchunk); + + auto alpha_Q_bf16_thr_copy = alpha_Q_bf16_tiled_copy.get_thread_slice(local_thread_idx); + auto alpha_Kt_bf16_thr_copy = alpha_Kt_bf16_tiled_copy.get_thread_slice(local_thread_idx); + auto Q_thr_copy = Q_tiled_copy.get_thread_slice(local_thread_idx); + auto Kt_thr_copy = Kt_tiled_copy.get_thread_slice(local_thread_idx); + auto O_thr_copy = O_tiled_copy.get_thread_slice(local_thread_idx); + auto O_thr_copy_kk = O_tiled_copy_kk.get_thread_slice(local_thread_idx); + + // index tensor + auto cMqk_subchunk = make_identity_tensor(select<0, 1>(TileShape_SubChunk{})); + auto cMqKA_subchunk = make_identity_tensor(select<0, 2>(TileShape_SubChunk{})); + auto cMqKB_subchunk = make_identity_tensor(select<1, 2>(TileShape_SubChunk{})); + auto tQKcMqk_subchunk = thr_mma_subchunk.partition_C(cMqk_subchunk); + auto tQKaMqk_subchunk = thr_mma_bf16_subchunk.partition_A(cMqKA_subchunk); + auto tQKbMqk_subchunk = thr_mma_bf16_subchunk.partition_B(cMqKB_subchunk); + + // do MMA at the granularity of 16x16x64 with two warps + constexpr auto tiler_subchunk_qk = Shape<_16, Shape<_32, _1>>{}; + constexpr auto tiler_subchunk_alpha = Shape<_16>{}; + constexpr auto tiler_subchunk_beta = Shape<_16>{}; + auto sQqk_curr = sQqk(_, _, q_smem_pipe_read.index()); + auto sKqk_curr = sKqk(_, _, k_smem_pipe_read.index()); + auto sAlpha_curr = Alpha(_, alpha_smem_pipe_read.index()); // (_64) + Tensor sBeta_curr = Beta(_, beta_smem_pipe_read.index()); + + // (_16,(_32,_1),_4,(_2,_2)):(_64,(_1,_0),_1024,(_32,_4096)) + auto sQqk_slice = flat_divide(sQqk_curr, tiler_subchunk_qk); + auto sKqk_slice = flat_divide(sKqk_curr, tiler_subchunk_qk); + // (_16, _4) + auto sAlpha_slice = flat_divide(sAlpha_curr, tiler_subchunk_alpha); + auto sBeta_slice = flat_divide(sBeta_curr, tiler_subchunk_beta); + + // Acc results + constexpr auto tiler_acc_qk_kk = Shape<_16, _16>{}; + static_assert(sizeof(Element) == 2); + auto sQK_curr = sQK(_, _, qk_smem_pipe_write.index()); + auto sQK_slice = flat_divide(sQK_curr, tiler_acc_qk_kk); + auto sKK_inv_curr = sKK_inv(_, _, kk_smem_pipe_write.index()); + auto sKK_inv_slice = flat_divide(sKK_inv_curr, tiler_acc_qk_kk); + + // used for make_fragment_like in Alpha and S2R (TF32 MMA layout) + Tensor sQqk_1_0 = sQqk_slice(_, _, _1{}, make_coord(_0{}, _0{})); + Tensor sKqk_1_0 = sKqk_slice(_, _, _1{}, make_coord(_0{}, _0{})); + Tensor tQKrQ_1_0 = thr_mma_subchunk.partition_fragment_A(sQqk_1_0); + Tensor tQKrKt_1_0 = thr_mma_subchunk.partition_fragment_B(sKqk_1_0); + auto tv_layout_mma_A = tQKrQ_1_0.layout(); + auto tv_layout_mma_B = tQKrKt_1_0.layout(); + + // BF16 MMA fragment layouts for LDSM-based S2R loads (same shape, different TV mapping) + Tensor tQKrQ_bf16_1_0 = thr_mma_bf16_subchunk.partition_fragment_A(sQqk_1_0); + Tensor tQKrKt_bf16_1_0 = thr_mma_bf16_subchunk.partition_fragment_B(sKqk_1_0); + auto tv_layout_bf16_mma_A = tQKrQ_bf16_1_0.layout(); + auto tv_layout_bf16_mma_B = tQKrKt_bf16_1_0.layout(); + + // S2R Q/K/G for operand A at row r, head dim slice j, and element-wise compute. + // Loads alpha once in BF16 MMA layout, derives g_first via warp shuffle (8 shuffles, + // replaces 1 S2R load), gates Q/K before the BF16→TF32 layout conversion. + // Also extracts g_first in operand B layout for free (broadcast → register copy). + // j0 = j % 2, j1 = j / 2: precomputed by caller to avoid redundant div/mod. + // returns (tQKrQ, tQKrK, tArAfirst_kt) = (Q * exp2(g - g_first), K * exp2(g - g_first), g_first in B + // layout) + auto s2r_compute_subchunk_operandA = [&](auto r_, int j, int j0, int j1) INLINE_LAMBDA { + // S2R g_r_j in BF16 MMA operand A layout (single load) + Tensor sAlpha_r = sAlpha_slice(_, r_); + Tensor tArA_r = make_fragment_like(tv_layout_bf16_mma_A); + // unrolled loop for explicit copy instead of using Cutlass copy + #pragma unroll + for (int v = 0 ; v < size(tArA_r); v++) { + tArA_r(v) = sAlpha_r(get<0>(tQKaMqk_subchunk(v))); + } + + + // Derive g_first (alpha[row=0, :]) from tArA_r_j via warp shuffle, + // directly into operand B layout (8 values instead of 16). + // g_first is broadcast (all M rows identical), so operand B only needs the + // v1=0 subset of operand A. We shuffle v1=0 values from t1=0 thread and + // output directly as operand B fragment, saving 8 float registers. + Tensor tArAfirst_r_j_kt = make_fragment_like(tv_layout_bf16_mma_B); + broadcast_row0_operandA_to_operandB_bf16_layout(tArA_r, tArAfirst_r_j_kt, local_thread_idx); + + // gqn_r_j = exp2(g_r_j - g_r_j_first[None, :]) in BF16 MMA A layout. + // g_first per k-iter is in tArAfirst_r_j_kt: frag_B(2j)=K_lo, frag_B(2j+1)=K_hi. + // In A layout: v1=0 indices (4j+0, 4j+1) have same K as v1=1 indices (4j+2, 4j+3), + // so g_first for index 4j+{0,2} = frag_B(2j), for 4j+{1,3} = frag_B(2j+1). + CUTE_UNROLL + for (int k = 0; k < 4; k++) { + auto gf_lo = tArAfirst_r_j_kt(2 * k); // g_first at K = 2*t0 + auto gf_hi = tArAfirst_r_j_kt(2 * k + 1); // g_first at K = 2*t0+1 + tArA_r(4 * k + 0) = exp2f(tArA_r(4 * k + 0) - gf_lo); // v0=0, v1=0 + tArA_r(4 * k + 1) = exp2f(tArA_r(4 * k + 1) - gf_hi); // v0=1, v1=0 + tArA_r(4 * k + 2) = exp2f(tArA_r(4 * k + 2) - gf_lo); // v0=0, v1=1 + tArA_r(4 * k + 3) = exp2f(tArA_r(4 * k + 3) - gf_hi); // v0=1, v1=1 + } + + Tensor sQqk_r_j = sQqk_slice(_, _, r_, make_coord(j0, j1)); + Tensor sKqk_r_j = sKqk_slice(_, _, r_, make_coord(j0, j1)); + + // --- Process Q --- + // S2R Q in BF16 MMA layout + Tensor tQKrQ_r_j_bf16 = make_fragment_like(tv_layout_bf16_mma_A); + Tensor tQKsQ_r_j = Q_thr_copy.partition_S(sQqk_r_j); + Tensor tQKrQ_r_j_bf16_cv = Q_thr_copy.retile_D(tQKrQ_r_j_bf16); + copy(Q_tiled_copy, tQKsQ_r_j, tQKrQ_r_j_bf16_cv); + // gate: Q * exp2(g - g_first) in BF16 MMA layout, producing float + Tensor tQKrQ_r_j_float = make_fragment_like(tv_layout_bf16_mma_A); + cute::transform( + tQKrQ_r_j_bf16, tArA_r, tQKrQ_r_j_float, [&](auto q, auto g) { return float(q) * g; }); + // convert BF16 MMA layout → TF32 MMA layout in-place via warp shuffles + convert_bf16_to_tf32_operandA_layout(tQKrQ_r_j_float, local_thread_idx); + // NOTE: triton tl.dot also lets MMA hardware for truncation + // recast float storage as tf32 view (zero cost, same 32-bit registers; MMA hw truncates) + auto tQKrQ_r_j = recast(tQKrQ_r_j_float); + + // --- Process K (sequential, after Q is done to reduce peak reg usage) --- + Tensor tQKrK_r_j_bf16 = make_fragment_like(tv_layout_bf16_mma_A); + Tensor tQKsK_r_j = Q_thr_copy.partition_S(sKqk_r_j); + Tensor tQKrK_r_j_bf16_cv = Q_thr_copy.retile_D(tQKrK_r_j_bf16); + copy(Q_tiled_copy, tQKsK_r_j, tQKrK_r_j_bf16_cv); + Tensor tQKrK_r_j_float = make_fragment_like(tv_layout_bf16_mma_A); + cute::transform( + tQKrK_r_j_bf16, tArA_r, tQKrK_r_j_float, [&](auto k, auto g) { return float(k) * g; }); + // convert BF16 MMA layout → TF32 MMA layout in-place via warp shuffles + convert_bf16_to_tf32_operandA_layout(tQKrK_r_j_float, local_thread_idx); + auto tQKrK_r_j = recast(tQKrK_r_j_float); + + return cute::make_tuple(tQKrQ_r_j, tQKrK_r_j, tArAfirst_r_j_kt); + }; + + // S2R K/G for operand B at column c, head dim slice j, and element-wise compute + // Loads alpha in BF16 MMA B layout, gates K in BF16 MMA B layout (before shuffle), + // then converts gated result to TF32 MMA B layout. + // tArAfirst_kt: pre-loaded g_first register tensor (BF16 MMA B layout) for computing gktn = exp2(g_first - + // g_c) returns tQKrKt = K_c * exp2(g_first - g_c) + auto s2r_compute_subchunk_operandB = + [&](auto c_, int j, int j0, int j1, auto const& tArAfirst_kt) INLINE_LAMBDA { + // S2R g_c_j in BF16 MMA operand B layout + Tensor sAlpha_c = sAlpha_slice(_, c_); + Tensor tArA_c = make_fragment_like(tv_layout_bf16_mma_B); + #pragma unroll + for (int v = 0; v < size(tArA_c); v++) { + tArA_c(v) = sAlpha_c(get<0>(tQKbMqk_subchunk(v))); + } + + // compute gktn_c_j = exp2(g_first - g_c_j) in BF16 MMA B layout + cute::transform( + tArA_c, tArAfirst_kt, tArA_c, [&](auto g, auto g_first) { return exp2f(g_first - g); }); + + // S2R k_c_j using BF16 LDSM + Tensor sKqk_c_j = sKqk_slice(_, _, c_, make_coord(j0, j1)); + Tensor tQKrKt_c_j_bf16 = make_fragment_like(tv_layout_bf16_mma_B); + Tensor tQKsKt_c_j = Kt_thr_copy.partition_S(sKqk_c_j); + Tensor tQKrKt_c_j_bf16_cv = Kt_thr_copy.retile_D(tQKrKt_c_j_bf16); + copy(Kt_tiled_copy, tQKsKt_c_j, tQKrKt_c_j_bf16_cv); + + // convert bf16 → float in BF16 MMA B layout + Tensor tQKrKt_c_j_float = make_fragment_like(tv_layout_bf16_mma_B); + // gate in BF16 MMA B layout (alpha and K are in the same layout) + cute::transform( + tQKrKt_c_j_bf16, tArA_c, tQKrKt_c_j_float, [&](auto k, auto g) { return float(k) * g; }); + + // convert BF16 MMA layout → TF32 MMA layout in-place via warp shuffles + convert_bf16_to_tf32_operandB_layout(tQKrKt_c_j_float, local_thread_idx); + auto tQKrKt_c_j = recast(tQKrKt_c_j_float); + + return tQKrKt_c_j; + }; + + // R2S (register to shared memory) store for subchunk accumulator results + // Stores both tQKrQK (QK accumulator, fp32 -> Element) and tKKrKK (KK accumulator, fp32 -> InverseType) + // into their respective shared memory tiles at position (r_, c_) + auto r2s_subchunk_acc = [&](auto r_, auto c_, auto const& tQKrQK, auto const& tKKrKK) INLINE_LAMBDA { + // R2S KK + Tensor sKK_inv_r_c = sKK_inv_slice(_, _, r_, c_); + Tensor tKKsKK_r_c = O_thr_copy_kk.partition_D(sKK_inv_r_c); + Tensor tKKrKK_cv = O_thr_copy_kk.retile_S(tKKrKK); + auto tKKrKK_cvt_cv = make_fragment_like(tKKrKK_cv); + cute::transform(tKKrKK_cv, tKKrKK_cvt_cv, [](auto v) { return InverseType(v); }); + copy(O_tiled_copy_kk, tKKrKK_cvt_cv, tKKsKK_r_c); + + // R2S QK + Tensor sQK_r_c = sQK_slice(_, _, r_, c_); + Tensor tQKsQK_r_c = O_thr_copy.partition_D(sQK_r_c); + Tensor tQKrQK_cv = O_thr_copy.retile_S(tQKrQK); + auto tQKrQK_cvt_cv = make_fragment_like(tQKrQK_cv); + cute::transform(tQKrQK_cv, tQKrQK_cvt_cv, [](auto v) { return Element(v); }); + copy(O_tiled_copy, tQKrQK_cvt_cv, tQKsQK_r_c); + }; + + // do tensor core GEMM with single 16x16x128 + // NOTE: should use safe_gate with lower_bound >= -5, otherwise overflow issues + auto gemm_tensor_core_1x16x16x128 = + [&](auto r_, auto c_, auto is_diagonal_, auto is_first_subchunk_) INLINE_LAMBDA { + constexpr bool is_first_subchunk = decltype(is_first_subchunk_)::value; + + // allocate acc_r_c [16, 16] + Tensor tQKrQK_r_c = partition_fragment_C(tiledmma_subchunk, select<0, 1>(TileShape_SubChunk{})); + Tensor tKKrKK_r_c = partition_fragment_C(tiledmma_subchunk, select<0, 1>(TileShape_SubChunk{})); + clear(tQKrQK_r_c); + clear(tKKrKK_r_c); + // wait for data ready + if constexpr (is_first_subchunk) { + q_pipeline.consumer_wait(q_smem_pipe_read); + k_pipeline.consumer_wait(k_smem_pipe_read); + if constexpr (NeedsAlpha) { + alpha_pipeline.consumer_wait(alpha_smem_pipe_read); + } + } + + // for loop head dim + CUTE_NO_UNROLL + for (int j = 0; j < NK; ++j) { + int j0 = j % 2, j1 = j / 2; + // S2R Q/K/G/g_first for operand A and element-wise compute + auto [tQKrQ_r_j, tQKrK_r_j, tArAfirst_r_j_kt] = s2r_compute_subchunk_operandA(r_, j, j0, j1); + // S2R K/G for operand B and element-wise compute + auto tQKrKt_c_j = s2r_compute_subchunk_operandB(c_, j, j0, j1, tArAfirst_r_j_kt); + + // q_r_j/k_r_j @ k_c_j, accumulate acc_r_c + gemm(tiledmma_subchunk, tQKrQ_r_j, tQKrKt_c_j, tQKrQK_r_c); + gemm(tiledmma_subchunk, tQKrK_r_j, tQKrKt_c_j, tKKrKK_r_c); + } + + // S2R beta_j (maybe resident in register?) + // epilogue: qk^t * scale + cute::transform(tQKrQK_r_c, [scale](auto v) { return v * scale; }); + // epilogue: kk^t * beta_r + if constexpr (is_first_subchunk) { + beta_pipeline.consumer_wait(beta_smem_pipe_read); + cutlass::arch::fence_view_async_shared(); + } + Tensor sBeta_r = sBeta_slice(_, r_); + for_each(make_int_sequence{}, [&](auto i) { + auto coord = tQKcMqk_subchunk(i); + auto [s, t] = coord; + tKKrKK_r_c(i) *= sBeta_r(s); + }); + + // R2S qk_r_c, kk_r_c, wait for current QK/KK free + if constexpr (is_first_subchunk) { + kk_pipeline.producer_acquire(kk_smem_pipe_write); + } + if constexpr (is_first_subchunk) { + qk_pipeline.producer_acquire(qk_smem_pipe_write); + } + r2s_subchunk_acc(r_, c_, tQKrQK_r_c, tKKrKK_r_c); + }; + + // zero fill for upper triangular of QK and KK, because smem is randomly initialized + auto zero_fill = [&](int row, int col) INLINE_LAMBDA { + auto sQK_r_c = sQK_slice(_, _, row, col); + auto sKK_r_c = sKK_inv_slice(_, _, row, col); + // allocate regs + Tensor tQKrQK_r_c = partition_fragment_C(tiledmma_subchunk, select<0, 1>(TileShape_SubChunk{})); + Tensor tQKsQK_r_c = O_thr_copy.partition_D(sQK_r_c); + Tensor tQKrQK_r_c_cv = O_thr_copy.retile_S(tQKrQK_r_c); + Tensor tKKsKK_r_c = O_thr_copy_kk.partition_D(sKK_r_c); + Tensor tKKrKK_r_c_cv = O_thr_copy_kk.retile_S(tQKrQK_r_c); + auto tQKrQK_r_c_cvt_cv = make_fragment_like(tQKrQK_r_c_cv); + auto tKKrKK_r_c_cvt_cv = make_fragment_like(tKKrKK_r_c_cv); + // zero fill + clear(tQKrQK_r_c_cvt_cv); + clear(tKKrKK_r_c_cvt_cv); + // R2S + copy(O_tiled_copy, tQKrQK_r_c_cvt_cv, tQKsQK_r_c); + copy(O_tiled_copy_kk, tKKrKK_r_c_cvt_cv, tKKsKK_r_c); + }; + + // g_i_j/q_i_j/k_i_j: the j-th head dim slice of the i-th subchunk + if (thread_idx < 64) { + // Q/K0@K0, Q/K3@K3, Q/K3@K0, Q/K3@K1, Q/K3@K2 + + // NOTE: tensor core MMA for safe gate with lower_bound >= -5 + gemm_tensor_core_1x16x16x128( + Int<0>{}, + Int<0>{}, + /*is_diagonal_=*/cute::true_type{}, + /*is_first_subchunk_=*/cute::true_type{}); + + // Q/K3@K0, Q/K3@K1, Q/K3@K2 + // allocate acc_3_0, acc_3_1, acc_3_2 [16, 16] + Tensor tQKrQK_3_0 = partition_fragment_C(tiledmma_subchunk, select<0, 1>(TileShape_SubChunk{})); + Tensor tKKrKK_3_0 = partition_fragment_C(tiledmma_subchunk, select<0, 1>(TileShape_SubChunk{})); + Tensor tQKrQK_3_1 = partition_fragment_C(tiledmma_subchunk, select<0, 1>(TileShape_SubChunk{})); + Tensor tKKrKK_3_1 = partition_fragment_C(tiledmma_subchunk, select<0, 1>(TileShape_SubChunk{})); + Tensor tQKrQK_3_2 = partition_fragment_C(tiledmma_subchunk, select<0, 1>(TileShape_SubChunk{})); + Tensor tKKrKK_3_2 = partition_fragment_C(tiledmma_subchunk, select<0, 1>(TileShape_SubChunk{})); + Tensor tQKrQK_3_3 = partition_fragment_C(tiledmma_subchunk, select<0, 1>(TileShape_SubChunk{})); + Tensor tKKrKK_3_3 = partition_fragment_C(tiledmma_subchunk, select<0, 1>(TileShape_SubChunk{})); + clear(tQKrQK_3_0); + clear(tKKrKK_3_0); + clear(tQKrQK_3_1); + clear(tKKrKK_3_1); + clear(tQKrQK_3_2); + clear(tKKrKK_3_2); + clear(tQKrQK_3_3); + clear(tKKrKK_3_3); + + // for loop head dim + CUTE_NO_UNROLL + for (int j = 0; j < NK; ++j) { + int j0 = j % 2, j1 = j / 2; + // S2R Q/K/G/g_first for operand A (row 3) and element-wise compute + auto [tQKrQ_3_j, tQKrK_3_j, tArAfirst_3_j_kt] = s2r_compute_subchunk_operandA(_3{}, j, j0, j1); + + // S2R K/G for operand B (col 0) and element-wise compute + auto tQKrKt_0_j = s2r_compute_subchunk_operandB(_0{}, j, j0, j1, tArAfirst_3_j_kt); + // q_3_j/k_3_j @ k_0_j, accumulate acc_3_0 + gemm(tiledmma_subchunk, tQKrQ_3_j, tQKrKt_0_j, tQKrQK_3_0); + gemm(tiledmma_subchunk, tQKrK_3_j, tQKrKt_0_j, tKKrKK_3_0); + + // S2R K/G for operand B (col 1) and element-wise compute + auto tQKrKt_1_j = s2r_compute_subchunk_operandB(_1{}, j, j0, j1, tArAfirst_3_j_kt); + // q_3_j/k_3_j @ k_1_j, accumulate acc_3_1 + gemm(tiledmma_subchunk, tQKrQ_3_j, tQKrKt_1_j, tQKrQK_3_1); + gemm(tiledmma_subchunk, tQKrK_3_j, tQKrKt_1_j, tKKrKK_3_1); + + // S2R K/G for operand B (col 2) and element-wise compute + auto tQKrKt_2_j = s2r_compute_subchunk_operandB(_2{}, j, j0, j1, tArAfirst_3_j_kt); + // q_3_j/k_3_j @ k_2_j, accumulate acc_3_2 + gemm(tiledmma_subchunk, tQKrQ_3_j, tQKrKt_2_j, tQKrQK_3_2); + gemm(tiledmma_subchunk, tQKrK_3_j, tQKrKt_2_j, tKKrKK_3_2); + + // S2R K/G for operand B (col 3) and element-wise compute + auto tQKrKt_3_j = s2r_compute_subchunk_operandB(_3{}, j, j0, j1, tArAfirst_3_j_kt); + // q_3_j/k_3_j @ k_3_j, accumulate acc_3_3 + gemm(tiledmma_subchunk, tQKrQ_3_j, tQKrKt_3_j, tQKrQK_3_3); + gemm(tiledmma_subchunk, tQKrK_3_j, tQKrKt_3_j, tKKrKK_3_3); + } + + // S2R beta (maybe resident in register?) + // epilogue: qk^t * scale + cute::transform(tQKrQK_3_0, [scale](auto v) { return v * scale; }); + cute::transform(tQKrQK_3_1, [scale](auto v) { return v * scale; }); + cute::transform(tQKrQK_3_2, [scale](auto v) { return v * scale; }); + cute::transform(tQKrQK_3_3, [scale](auto v) { return v * scale; }); + // epilogue: kk^t * beta_3 + Tensor sBeta_3 = sBeta_slice(_, _3{}); + for_each(make_int_sequence{}, [&](auto i) { + auto coord = tQKcMqk_subchunk(i); + auto [s, t] = coord; + auto b = sBeta_3(s); + tKKrKK_3_0(i) *= b; + tKKrKK_3_1(i) *= b; + tKKrKK_3_2(i) *= b; + tKKrKK_3_3(i) *= b; + }); + + // R2S qk_3_0, kk_3_0, wait for current QK/KK free + r2s_subchunk_acc(_3{}, _0{}, tQKrQK_3_0, tKKrKK_3_0); + // R2S qk_3_1, kk_3_1 + r2s_subchunk_acc(_3{}, _1{}, tQKrQK_3_1, tKKrKK_3_1); + // R2S qk_3_2, kk_3_2 + r2s_subchunk_acc(_3{}, _2{}, tQKrQK_3_2, tKKrKK_3_2); + // R2S qk_3_3, kk_3_3 + r2s_subchunk_acc(_3{}, _3{}, tQKrQK_3_3, tKKrKK_3_3); + } else { + // Q/K1@K0, Q/K2@K0, Q/K2@K1, Q/K2@K2, Q/K1@K1 + + // Q/K1@K0 + // allocate acc_1_0 [16, 16] + Tensor tQKrQK_1_0 = partition_fragment_C(tiledmma_subchunk, select<0, 1>(TileShape_SubChunk{})); + Tensor tKKrKK_1_0 = partition_fragment_C(tiledmma_subchunk, select<0, 1>(TileShape_SubChunk{})); + Tensor tQKrQK_1_1 = partition_fragment_C(tiledmma_subchunk, select<0, 1>(TileShape_SubChunk{})); + Tensor tKKrKK_1_1 = partition_fragment_C(tiledmma_subchunk, select<0, 1>(TileShape_SubChunk{})); + clear(tQKrQK_1_0); + clear(tKKrKK_1_0); + clear(tQKrQK_1_1); + clear(tKKrKK_1_1); + // wait for data ready + q_pipeline.consumer_wait(q_smem_pipe_read); + k_pipeline.consumer_wait(k_smem_pipe_read); + if constexpr (NeedsAlpha) { + alpha_pipeline.consumer_wait(alpha_smem_pipe_read); + } + + // for loop head dim + CUTE_NO_UNROLL + for (int j = 0; j < NK; ++j) { + int j0 = j % 2, j1 = j / 2; + // S2R Q/K/G/g_first for operand A (row 1) and element-wise compute + auto [tQKrQ_1_j, tQKrK_1_j, tArAfirst_1_j_kt] = s2r_compute_subchunk_operandA(_1{}, j, j0, j1); + + // S2R K/G for operand B (col 0) and element-wise compute + auto tQKrKt_0_j = s2r_compute_subchunk_operandB(_0{}, j, j0, j1, tArAfirst_1_j_kt); + // q_1_j/k_1_j @ k_0_j, accumulate acc_1_0 + gemm(tiledmma_subchunk, tQKrQ_1_j, tQKrKt_0_j, tQKrQK_1_0); + gemm(tiledmma_subchunk, tQKrK_1_j, tQKrKt_0_j, tKKrKK_1_0); + + // S2R K/G for operand B (col 1) and element-wise compute + auto tQKrKt_1_j = s2r_compute_subchunk_operandB(_1{}, j, j0, j1, tArAfirst_1_j_kt); + // q_1_j/k_1_j @ k_1_j, accumulate acc_1_1 + gemm(tiledmma_subchunk, tQKrQ_1_j, tQKrKt_1_j, tQKrQK_1_1); + gemm(tiledmma_subchunk, tQKrK_1_j, tQKrKt_1_j, tKKrKK_1_1); + } + + // S2R beta_j (maybe resident in register?) + // epilogue: qk^t * scale + cute::transform(tQKrQK_1_0, [scale](auto v) { return v * scale; }); + cute::transform(tQKrQK_1_1, [scale](auto v) { return v * scale; }); + // epilogue: kk^t * beta_1 + beta_pipeline.consumer_wait(beta_smem_pipe_read); + cutlass::arch::fence_view_async_shared(); + Tensor sBeta_1 = sBeta_slice(_, _1{}); + for_each(make_int_sequence{}, [&](auto i) { + auto coord = tQKcMqk_subchunk(i); + auto [s, t] = coord; + tKKrKK_1_0(i) *= sBeta_1(s); + tKKrKK_1_1(i) *= sBeta_1(s); + }); + + // R2S qk_1_0, kk_1_0, wait for current QK/KK free + kk_pipeline.producer_acquire(kk_smem_pipe_write); + qk_pipeline.producer_acquire(qk_smem_pipe_write); + + r2s_subchunk_acc(_1{}, _0{}, tQKrQK_1_0, tKKrKK_1_0); + // R2S qk_1_1, kk_1_1 + r2s_subchunk_acc(_1{}, _1{}, tQKrQK_1_1, tKKrKK_1_1); + + // Q/K2@K0, Q/K2@K1 + // allocate acc_2_0, acc_2_1 [16, 16] + Tensor tQKrQK_2_0 = partition_fragment_C(tiledmma_subchunk, select<0, 1>(TileShape_SubChunk{})); + Tensor tKKrKK_2_0 = partition_fragment_C(tiledmma_subchunk, select<0, 1>(TileShape_SubChunk{})); + Tensor tQKrQK_2_1 = partition_fragment_C(tiledmma_subchunk, select<0, 1>(TileShape_SubChunk{})); + Tensor tKKrKK_2_1 = partition_fragment_C(tiledmma_subchunk, select<0, 1>(TileShape_SubChunk{})); + Tensor tQKrQK_2_2 = partition_fragment_C(tiledmma_subchunk, select<0, 1>(TileShape_SubChunk{})); + Tensor tKKrKK_2_2 = partition_fragment_C(tiledmma_subchunk, select<0, 1>(TileShape_SubChunk{})); + clear(tQKrQK_2_0); + clear(tKKrKK_2_0); + clear(tQKrQK_2_1); + clear(tKKrKK_2_1); + clear(tQKrQK_2_2); + clear(tKKrKK_2_2); + + // for loop head dim + CUTE_NO_UNROLL + for (int j = 0; j < NK; ++j) { + int j0 = j % 2, j1 = j / 2; + // S2R Q/K/G/g_first for operand A (row 2) and element-wise compute + auto [tQKrQ_2_j, tQKrK_2_j, tArAfirst_2_j_kt] = s2r_compute_subchunk_operandA(_2{}, j, j0, j1); + + // S2R K/G for operand B (col 0) and element-wise compute + auto tQKrKt_0_j = s2r_compute_subchunk_operandB(_0{}, j, j0, j1, tArAfirst_2_j_kt); + // q_2_j/k_2_j @ k_0_j, accumulate acc_2_0 + gemm(tiledmma_subchunk, tQKrQ_2_j, tQKrKt_0_j, tQKrQK_2_0); + gemm(tiledmma_subchunk, tQKrK_2_j, tQKrKt_0_j, tKKrKK_2_0); + + // S2R K/G for operand B (col 1) and element-wise compute + auto tQKrKt_1_j = s2r_compute_subchunk_operandB(_1{}, j, j0, j1, tArAfirst_2_j_kt); + // q_2_j/k_2_j @ k_1_j, accumulate acc_2_1 + gemm(tiledmma_subchunk, tQKrQ_2_j, tQKrKt_1_j, tQKrQK_2_1); + gemm(tiledmma_subchunk, tQKrK_2_j, tQKrKt_1_j, tKKrKK_2_1); + + // S2R K/G for operand B (col 2) and element-wise compute + auto tQKrKt_2_j = s2r_compute_subchunk_operandB(_2{}, j, j0, j1, tArAfirst_2_j_kt); + // q_2_j/k_2_j @ k_2_j, accumulate acc_2_2 + gemm(tiledmma_subchunk, tQKrQ_2_j, tQKrKt_2_j, tQKrQK_2_2); + gemm(tiledmma_subchunk, tQKrK_2_j, tQKrKt_2_j, tKKrKK_2_2); + } + + // S2R beta (maybe resident in register?) + // epilogue: qk^t * scale + cute::transform(tQKrQK_2_0, [scale](auto v) { return v * scale; }); + cute::transform(tQKrQK_2_1, [scale](auto v) { return v * scale; }); + cute::transform(tQKrQK_2_2, [scale](auto v) { return v * scale; }); + // epilogue: kk^t * beta_2 + Tensor sBeta_2 = sBeta_slice(_, 2); + for_each(make_int_sequence{}, [&](auto i) { + auto coord = tQKcMqk_subchunk(i); + auto [s, t] = coord; + auto b = sBeta_2(s); + tKKrKK_2_0(i) *= b; + tKKrKK_2_1(i) *= b; + tKKrKK_2_2(i) *= b; + }); + + // R2S qk_2_0, kk_2_0, wait for current QK/KK free + r2s_subchunk_acc(_2{}, _0{}, tQKrQK_2_0, tKKrKK_2_0); + // R2S qk_2_1, kk_2_1 + r2s_subchunk_acc(_2{}, _1{}, tQKrQK_2_1, tKKrKK_2_1); + // R2S qk_2_2, kk_2_2 + r2s_subchunk_acc(_2{}, _2{}, tQKrQK_2_2, tKKrKK_2_2); + } + }; + + auto qk_and_kk_epi = [&](auto is_final_block_, auto B /*valid seqlen*/) INLINE_LAMBDA { + using CopyOpS2R_Chunk = SM75_U32x4_LDSM_N; + using CopyOpR2S_Chunk = SM90_U32x4_STSM_N; + // S2R QK/KK + auto sQK_curr = sQK(_, _, qk_smem_pipe_write.index()); + auto sKK_inv_curr = sKK_inv(_, _, kk_smem_pipe_write.index()); + Tensor tQKrQK_ref = partition_fragment_C(TiledMmaQK_RS{}, select<0, 1>(TileShapeQK{})); + auto tiled_load_qk = make_tiled_copy_C(Copy_Atom{}, qk_tiled_mma_rs); + auto thr_load_qk = tiled_load_qk.get_thread_slice(thread_idx); + auto tiled_load_kk = make_tiled_copy_C(Copy_Atom{}, qk_tiled_mma_rs); + auto thr_load_kk = tiled_load_kk.get_thread_slice(thread_idx); + + auto tQKrQK_cv = thr_load_qk.retile_D(tQKrQK_ref); + auto tQKrQK = make_fragment_like(tQKrQK_cv); + auto tKKrKK = make_fragment_like(tQKrQK_cv); + copy(tiled_load_qk, thr_load_qk.partition_S(sQK_curr), tQKrQK); + copy(tiled_load_kk, thr_load_kk.partition_S(sKK_inv_curr), tKKrKK); + + // triangular mask and boundary mask + constexpr bool is_final_block = decltype(is_final_block_)::value; + for_each(make_int_sequence{}, [&](auto i) { + auto coord = tQKcMqk(i); + auto [s, t] = coord; + bool pred = s >= t; + tQKrQK(i) = pred ? tQKrQK(i) : Element(0.0f); + tKKrKK(i) = + pred ? tKKrKK(i) : InverseType(0.0f); // diagonal is garbage filled, will process during inversion + if constexpr (is_final_block) { + bool pred = s < B && t < B; + tQKrQK(i) = pred ? tQKrQK(i) : Element(0.0f); + tKKrKK(i) = pred ? tKKrKK(i) : InverseType(0.0f); + } + }); + + // R2S QK/KK + auto tiled_store_qk = make_tiled_copy_C(Copy_Atom{}, qk_tiled_mma_rs); + auto thr_store_qk = tiled_store_qk.get_thread_slice(thread_idx); + auto tiled_store_kk = make_tiled_copy_C(Copy_Atom{}, qk_tiled_mma_rs); + auto thr_store_kk = tiled_store_kk.get_thread_slice(thread_idx); + + copy(tiled_store_qk, thr_store_qk.retile_S(tQKrQK), thr_store_qk.partition_D(sQK_curr)); + copy(tiled_store_kk, thr_store_kk.retile_S(tKKrKK), thr_store_kk.partition_D(sKK_inv_curr)); + }; + + auto compute_aux_loop_body = [&](int blk, auto is_final_block_) INLINE_LAMBDA { + constexpr bool is_final_block = decltype(is_final_block_)::value; + + int B = is_final_block ? valid_seq_len(work_desc, blk) : BlkSeqKV; + + // ====DEBUG===== + // maintain pipeline correctness while removing subchunk + // wait for data ready + // q_pipeline.consumer_wait(q_smem_pipe_read); + // k_pipeline.consumer_wait(k_smem_pipe_read); + // if constexpr (NeedsAlpha) { alpha_pipeline.consumer_wait(alpha_smem_pipe_read); } + // beta_pipeline.consumer_wait(beta_smem_pipe_read); + // cutlass::arch::fence_view_async_shared(); + + // qk_pipeline.producer_acquire(qk_smem_pipe_write); + // kk_pipeline.producer_acquire(kk_smem_pipe_write); + + // SubChunk MMA for QK^T and KK^T for numerical stability + // FIXME: use g_half as anchor in the diagonal subchunk to align with FLA for smaller numerical differences + qk_kk_subchunk_mma_and_store(blk); + // wait for QK/KK ready + cutlass::arch::NamedBarrier::arrive_and_wait(cutlass::NumThreadsPerWarpGroup, GdnNamedBarriers::AuxMath); + qk_and_kk_epi(is_final_block_, B); + // =====DEBUG====== + // cutlass::arch::NamedBarrier::arrive_and_wait(cutlass::NumThreadsPerWarpGroup, GdnNamedBarriers::AuxMath); + // if constexpr (is_final_block) { + // if (thread_idx == 0) { + // printf("sQK\n"); + // cute::print_tensor(sQK(_, _, qk_smem_pipe_write.index())); + // printf("sKK_inv\n"); + // cute::print_tensor(sKK_inv(_, _, kk_smem_pipe_write.index())); + // } + // } + + // QK/KK is ready to consume + cutlass::arch::fence_view_async_shared(); + qk_pipeline.producer_commit(qk_smem_pipe_write); + ++qk_smem_pipe_write; + kk_pipeline.producer_commit(kk_smem_pipe_write); + ++kk_smem_pipe_write; + + k_pipeline.consumer_release(k_smem_pipe_read); + ++k_smem_pipe_read; + q_pipeline.consumer_release(q_smem_pipe_read); + ++q_smem_pipe_read; + if constexpr (NeedsAlpha) { + alpha_pipeline.consumer_release(alpha_smem_pipe_read); + ++alpha_smem_pipe_read; + } + + if constexpr (NeedsBeta) { + beta_pipeline.consumer_release(beta_smem_pipe_read); + ++beta_smem_pipe_read; + } + }; + + int32_t num_blocks = ceil_div(work_desc.seq_len, get<0>(TileShape{})); + CUTE_NO_UNROLL + for (int blk = 0; blk < num_blocks - 1; ++blk) { + compute_aux_loop_body(blk, /*is_final_block_=*/cute::false_type{}); + } + compute_aux_loop_body(num_blocks - 1, /*is_final_block_=*/cute::true_type{}); + } + + template + CUTE_DEVICE int + valid_seq_len(WorkDesc work_desc, int blk_idx) { + int remain_len = work_desc.seq_len - BlkSeqKV * blk_idx; + return remain_len <= BlkSeqKV ? remain_len : BlkSeqKV; + } +}; + +} // namespace gdn::sm90::collective diff --git a/csrc/gdn/sm90/collective/named_barriers.hpp b/csrc/gdn/sm90/collective/named_barriers.hpp new file mode 100644 index 0000000..9523f44 --- /dev/null +++ b/csrc/gdn/sm90/collective/named_barriers.hpp @@ -0,0 +1,44 @@ +// Copyright 2025-2026 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +/* + * Copyright (c) 2025 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +namespace gdn::sm90::collective { + +struct FlatSharedNamedBarriers { + static constexpr int AllMmaThreadsSync = 0; + static constexpr int AllLdStThreadsSync = 1; + static constexpr int MmaCooperativeStore = 2; + + protected: + static constexpr int NumBarriersUsed = 4; +}; + +} // namespace gdn::sm90::collective diff --git a/csrc/gdn/sm90/collective/store_tma.hpp b/csrc/gdn/sm90/collective/store_tma.hpp new file mode 100644 index 0000000..8f2a876 --- /dev/null +++ b/csrc/gdn/sm90/collective/store_tma.hpp @@ -0,0 +1,313 @@ +// Copyright 2025-2026 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +/* + * Copyright (c) 2025 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +#include +#include +#include + +#include "kerutils/kerutils.cuh" + +#include "gdn/sm90/utils/debug.hpp" + +namespace gdn::sm90::collective { + +using ku::alignment_for_swizzle; +using namespace cute; + +/* +NOTE: what we need is as follows + + using DefaultOperation = cutlass::epilogue::fusion::LinearCombination; + using CollectiveStoreO = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShapeO1, ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulatorO, ElementAccumulatorO, + void, LayoutO, Alignment, // C, not exists + ElementO, decltype(select<1,0,2>(LayoutO{})), Alignment, // D + cutlass::epilogue::TmaWarpSpecializedCooperative, DefaultOperation>::CollectiveOp; + +but unfortunately the required type alias is only useful for our purpose is private so we roll out our own. +*/ + +CUTE_DEVICE uint32_t +smid() { +#ifdef __CUDA_ARCH__ + uint32_t virtual_smid; + asm("mov.u32 %0, %%smid;" : "=r"(virtual_smid)); + return virtual_smid; +#else + return 0; +#endif +} + +template < + typename TileShape_MNK_, + typename ClusterShape, + typename ElementO, + typename ElementAccumulator, + typename SmemElementO, + typename StrideO, + int Stages> +struct CollectiveStoreTma { + static_assert(size_v == 1); + using TileShape_MNK = TileShape_MNK_; + using TileShape_MN = + decltype(select<0, 1>(TileShape_MNK{})); // Collective work on TileShape_MN, it is also the OutputTile + using SizeM = decltype(get<0>(TileShape_MNK{})); // head_size + using SizeN = decltype(get<1>(TileShape_MNK{})); // seqlen + + constexpr static bool is_m_major_O = cutlass::epilogue::collective::detail::is_m_major(); + +#if 0 + // NOTE: the following derived layout is a bit slower than the manual one, will evaluate it later + using SmemLayoutAtom = decltype(cutlass::epilogue::collective::detail::sm90_get_epilogue_smem_swizzle_layout_atom< + StrideO, ElementO, TileShape_MN>()); +#else + static_assert(sizeof(SmemElementO) == 2); + using SmemLayoutAtom = GMMA::Layout_MN_SW32_Atom; +#endif + + using SmemLayoutO = decltype(tile_to_shape( + SmemLayoutAtom{}, + make_shape(SizeM{}, SizeN{}, Int{}), + cute::conditional_t, Step<_1, _2, _3>>{})); + + constexpr static uint32_t TmaTransactionBytes = + (size(take<0, 2>(SmemLayoutO{})) * static_cast(sizeof_bits::value)) / 8; + + using CopyOpR2S = decltype(cutlass::epilogue::collective::detail:: + sm90_get_smem_store_op_for_accumulator()); + using CopyAtomR2S = Copy_Atom; + + using CopyOpS2G = SM90_TMA_STORE; + + using SharedStorage = + cute::array_aligned, alignment_for_swizzle(SmemLayoutO{})>; + using Pipeline = cutlass::PipelineAsync; // NOT PipelineTmaStore! + using PipelineState = cutlass::PipelineState; + + struct Arguments { + ElementO* ptr_O; + StrideO dO; + void* workspace; + }; + + struct Params { + using TMA_O = decltype(make_tma_copy( + CopyOpS2G{}, + make_tensor(make_gmem_ptr(nullptr), repeat_like(StrideO{}, int32_t(0)), StrideO{}), + take<0, 2>(SmemLayoutO{}), + TileShape_MN{}, + _1{})); + + TMA_O tma_store_o; + uint32_t tma_transaction_bytes = TmaTransactionBytes; + void* tensormaps; + }; + using TMA = typename Params::TMA_O; + + CUTE_DEVICE + CollectiveStoreTma(TMA const& tma_store, Pipeline& pipeline, SharedStorage& storage, void* tensormaps) + : tma_store_(tma_store), pipeline_(pipeline), storage_(storage), tensormaps_(tensormaps) { + } + + template + static Params + to_underlying_arguments(ProblemSize const& problem_size, Arguments const& args, void* workspace) { + auto problem_size_MNKL = append<4>(problem_size, 1); + auto [M, N, K, L] = problem_size_MNKL; + + Tensor tensor_o = make_tensor(make_gmem_ptr(args.ptr_O), make_layout(make_shape(M, N, L), args.dO)); + TMA tma_store_o = make_tma_copy_C_sm90(CopyOpS2G{}, tensor_o, take<0, 2>(SmemLayoutO{}), TileShape_MN{}); + + return { + .tma_store_o = tma_store_o, + .tma_transaction_bytes = TmaTransactionBytes, + .tensormaps = workspace, + }; + } + + static size_t + get_workspace_size(/*Arguments const& args,*/ int sm_count) { + // only use additional TMA desc for output tail tiles + size_t num_bytes = sizeof(cute::TmaDescriptor) * sm_count; + DPRINTF("workspace num_bytes:%zu\n", num_bytes); + return num_bytes; + } + + template + static cutlass::Status + initialize_workspace( + ProblemShape const& problem_shape, + /*Arguments const& args,*/ void* workspace, + cudaStream_t stream) { + return cutlass::Status::kSuccess; + } + + CUTE_DEVICE static void + prefetch_tma_descriptors(Params const& params) { + cute::prefetch_tma_descriptor(params.tma_store_o.get_tma_descriptor()); + } + + template + CUTE_DEVICE auto + partition_SD(ProblemSize const& problem_size, TileShape const& tile_shape, WorkDesc const& work_desc) { + constexpr auto BlkSeqQ = decltype(get<0>(tile_shape))::value; + constexpr auto HeadSize = decltype(get<2>(tile_shape))::value; + + Tensor g = [&] { + DPRINTF0_W( + "slice view GMEM O: seq_idx:%d head_idx:%d tok_offset:%lld\n", + work_desc.seq_idx, + work_desc.o_head_idx(), + work_desc.tok_offset); + Tensor m_varlen_head = tma_store_.get_tma_tensor(make_shape( + problem_size.head_size, + problem_size.total_seqlen, + problem_size.num_heads)); // global view to the packed varlen sequence + Tensor m_varlen = m_varlen_head(_, _, work_desc.o_head_idx()); // slice into current head_idx + Tensor m_offset = domain_offset( + make_coord(_0{}, work_desc.tok_offset), + m_varlen); // offset to start of the current sequence + Tensor g_full = + local_tile(m_offset, make_tile(HeadSize, BlkSeqQ), make_coord(_0{}, _)); // (d, blk, iter_blk) + return g_full; + }(); + Tensor s = make_tensor(make_smem_ptr(storage_.data()), SmemLayoutO{}); + + auto block_tma = tma_store_.get_slice(_0{}); // do not support cluster + return make_tuple(block_tma.partition_S(s), block_tma.partition_D(g)); + } + + template + CUTE_DEVICE static bool + can_process(ProblemSize const& problem_size, WorkDesc const& work_desc, int blk, int num_blocks) { + if (blk < num_blocks - 1) { + // intermediate full tiles, always use TMA + return true; + } else if (work_desc.seq_len % SizeN{} == 0 || work_desc.seq_idx == problem_size.num_seqs - 1) { + // 1. last tile but full, also use TMA + // 2. last tile but last seq, oob can be handled by TMA + return true; + } else { + return false; + } + } + + template + CUTE_DEVICE void + step( + ProblemSize const& problem_size, + WorkDesc const& work_desc, + SrcDst const& src_dst, + PipelineState& src_pipe, + int dst_iter, + int num_iters, + uint32_t lane_predicate) { + auto src = get<0>(src_dst); + auto dst = get<1>(src_dst); + + if (dst_iter == 0) { + bool can_process_tail = can_process(problem_size, work_desc, num_iters - 1, num_iters); + if (!can_process_tail) { + create_tensormap_for_tail(work_desc, lane_predicate); + } + } + + DPRINTF0_W("Store O pipeline.producer_acquire smem_pipe_read:%d\n", src_pipe.index()); + if constexpr (kAcquireBarrier) { + pipeline_.consumer_wait(src_pipe); + } + + if (can_process(problem_size, work_desc, dst_iter, num_iters)) { + DPRINTF0_W("store src_pipe:%d -> blk:%d\n", src_pipe.index(), dst_iter); + if (lane_predicate == 1) { + copy(tma_store_, src(_, _, _, src_pipe.index()), dst(_, _, _, dst_iter)); + } + } else { + cute::TmaDescriptor* tensormap = acquire_tensormap_for_tail(); + DPRINTF0_W("store tail with tensormap:%p src_pipe:%d -> blk:%d\n", tensormap, src_pipe.index(), dst_iter); + if (lane_predicate == 1) { + copy(tma_store_.with(tensormap), src(_, _, _, src_pipe.index()), dst(_, _, _, dst_iter)); + } + } + + if constexpr (kAcquireBarrier) { + pipeline_.consumer_release(src_pipe); + } + ++src_pipe; + } + + template + CUTE_DEVICE void + create_tensormap_for_tail(WorkDesc const& work_desc, uint32_t lane_predicate) { + namespace ptx = cuda::ptx; + constexpr int num_of_16B = sizeof(cute::TmaDescriptor) / sizeof(uint128_t); + + cute::TmaDescriptor* tensormap = static_cast(tensormaps_) + smid(); + + auto lane_idx = cutlass::canonical_lane_idx(); + if (lane_idx < num_of_16B) { + auto src = reinterpret_cast(tma_store_.get_tma_descriptor()); + auto dst = reinterpret_cast(tensormap); + + dst[lane_idx] = src[lane_idx]; + } + __syncwarp(); + + if (lane_predicate == 1) { + uint32_t new_total_seqlen = work_desc.tok_offset + work_desc.seq_len; + ptx::tensormap_replace_global_dim(ptx::space_global, tensormap, /*ord=*/ptx::n32_t<1>{}, new_total_seqlen); + } + __syncwarp(); + + ptx::fence_proxy_tensormap_generic(ptx::sem_release, ptx::scope_cta); + } + + CUTE_DEVICE cute::TmaDescriptor* + acquire_tensormap_for_tail() { + namespace ptx = cuda::ptx; + cute::TmaDescriptor* tensormap = static_cast(tensormaps_) + smid(); + ptx::fence_proxy_tensormap_generic(ptx::sem_acquire, ptx::scope_cta, tensormap, /*size=*/ptx::n32_t<128>{}); + return tensormap; + } + + private: + TMA const& tma_store_; + Pipeline& pipeline_; + SharedStorage& storage_; + void* tensormaps_; +}; + +} // namespace gdn::sm90::collective diff --git a/csrc/gdn/sm90/device/device_universal.hpp b/csrc/gdn/sm90/device/device_universal.hpp new file mode 100644 index 0000000..18005ce --- /dev/null +++ b/csrc/gdn/sm90/device/device_universal.hpp @@ -0,0 +1,244 @@ +// Copyright 2025-2026 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +/* + * Copyright (c) 2025 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +// common +#include +#include + +#if !defined(__CUDACC_RTC__) +#include +#include +#endif // !defined(__CUDACC_RTC__) + +namespace cutlass::device { + +template +class Universal { + public: + using Kernel = Kernel_; + + static int const kThreadCount = Kernel::MaxThreadsPerBlock; + + /// Argument structure: User API + using Arguments = typename Kernel::Arguments; + /// Argument structure: Kernel API + using Params = typename Kernel::Params; + + private: + /// Kernel API parameters object + Params params_; + + public: + /// Access the Params structure + Params const& + params() const { + return params_; + } + + /// Determines whether the GEMM can execute the given problem. + static Status + can_implement(Arguments const& args) { + if (Kernel::can_implement(args)) { + return Status::kSuccess; + } else { + return Status::kInvalid; + } + } + + /// Gets the workspace size + static size_t + get_workspace_size(Arguments const& args) { + size_t workspace_bytes = 0; + workspace_bytes += Kernel::get_workspace_size(args); + return workspace_bytes; + } + + /// Computes the grid shape + static dim3 + get_grid_shape(Params const& params) { + return Kernel::get_grid_shape(params); + } + + /// Computes the maximum number of active blocks per multiprocessor + static int + maximum_active_blocks(int /* smem_capacity */ = -1) { + CUTLASS_TRACE_HOST("Universal::maximum_active_blocks()"); + int max_active_blocks = -1; + int smem_size = Kernel::SharedStorageSize; + + // first, account for dynamic smem capacity if needed + cudaError_t result; + if (smem_size >= (48 << 10)) { + CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size); + result = + cudaFuncSetAttribute(device_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error: " << cudaGetErrorString(result)); + return -1; + } + } + + // query occupancy after setting smem size + result = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &max_active_blocks, device_kernel, Kernel::MaxThreadsPerBlock, smem_size); + + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST( + " cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error: " << cudaGetErrorString(result)); + return -1; + } + + CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks); + return max_active_blocks; + } + + /// Initializes GEMM state from arguments. + Status + initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { + CUTLASS_TRACE_HOST( + "Universal::initialize() - workspace " << workspace << ", stream: " << (stream ? "non-null" : "null")); + + // Initialize the workspace + Status status = Kernel::initialize_workspace(args, workspace, stream); + if (status != Status::kSuccess) { + return status; + } + + // Initialize the Params structure + params_ = Kernel::to_underlying_arguments(args, workspace); + + // account for dynamic smem capacity if needed + int smem_size = Kernel::SharedStorageSize; + if (smem_size >= (48 << 10)) { + CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size); + cudaError_t result = + cudaFuncSetAttribute(device_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error: " << cudaGetErrorString(result)); + return Status::kErrorInternal; + } + } + + return Status::kSuccess; + } + + /// Update API is preserved in 3.0, but does not guarantee a lightweight update of params. + Status + update(Arguments const& args, void* workspace = nullptr) { + CUTLASS_TRACE_HOST("Universal()::update() - workspace: " << workspace); + + size_t workspace_bytes = get_workspace_size(args); + if (workspace_bytes > 0 && nullptr == workspace) { + return Status::kErrorWorkspaceNull; + } + + params_ = Kernel::to_underlying_arguments(args, workspace); + return Status::kSuccess; + } + + /// Primary run() entry point API that is static allowing users to create and manage their own params. + /// Supplied params struct must be construct by calling Kernel::to_underling_arguments() + static Status + run(Params& params, cudaStream_t stream = nullptr) { + CUTLASS_TRACE_HOST("Universal::run()"); + dim3 const block = Kernel::get_block_shape(); + dim3 const grid = get_grid_shape(params); + + // configure smem size and carveout + int smem_size = Kernel::SharedStorageSize; + + Status launch_result; + // Use extended launch API only for mainloops that use it + if constexpr (Kernel::ArchTag::kMinComputeCapability >= 90) { + dim3 cluster( + cute::size<0>(typename Kernel::ClusterShape{}), + cute::size<1>(typename Kernel::ClusterShape{}), + cute::size<2>(typename Kernel::ClusterShape{})); + void const* kernel = (void const*)device_kernel; + void* kernel_params[] = {¶ms}; + launch_result = ClusterLauncher::launch(grid, cluster, block, smem_size, stream, kernel, kernel_params); + } else { + launch_result = Status::kSuccess; + cutlass::arch::synclog_setup(); + device_kernel<<>>(params); + } + + cudaError_t result = cudaGetLastError(); + if (cudaSuccess == result && Status::kSuccess == launch_result) { + return Status::kSuccess; + } else { + CUTLASS_TRACE_HOST(" Kernel launch failed. Reason: " << result); + return Status::kErrorInternal; + } + } + + // + // Non-static launch overloads that first create and set the internal params struct of this kernel handle. + // + + /// Launches the kernel after first constructing Params internal state from supplied arguments. + Status + run(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { + Status status = initialize(args, workspace, stream); + if (Status::kSuccess == status) { + status = run(params_, stream); + } + return status; + } + + /// Launches the kernel after first constructing Params internal state from supplied arguments. + Status + operator()(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { + return run(args, workspace, stream); + } + + /// Overload that allows a user to re-launch the same kernel without updating internal params struct. + Status + run(cudaStream_t stream = nullptr) { + return run(params_, stream); + } + + /// Overload that allows a user to re-launch the same kernel without updating internal params struct. + Status + operator()(cudaStream_t stream = nullptr) { + return run(params_, stream); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::device + +//////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/gdn/sm90/gdn_fwd_sm90.cu b/csrc/gdn/sm90/gdn_fwd_sm90.cu new file mode 100644 index 0000000..f7bfcaf --- /dev/null +++ b/csrc/gdn/sm90/gdn_fwd_sm90.cu @@ -0,0 +1,144 @@ +// Copyright 2025-2026 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Dispatch function only — does NOT include the .cuh to avoid +// implicit instantiation of all kernel variants in one TU. +// Each SafeGate variant is explicitly instantiated in its own .cu file. + +#include +#include + +namespace gdn::sm90 { + +using namespace cute; + +// Forward declaration of the per-variant launcher (defined in .cuh, instantiated in separate TUs) +template < + bool NeedsBeta, + bool NeedsAlpha, + bool InitStateFromInput, + bool SafeGate, + typename ArchTag, + typename TO, + typename TQKV, + typename TState> +void +launch_gdn_fwd_prefill_kernel_gbai( + cudaStream_t stream, + TO* output, + TState* output_state, + TQKV const* q, + TQKV const* k, + TQKV const* v, + TState const* input_state, + float const* alpha, + float const* beta, + int32_t const* cu_seqlens, + uint8_t* workspace_buffer, + int32_t num_seqs, + int32_t num_heads, + int32_t head_size, + int64_t total_seqlen, + float scale, + int32_t sm_count); + +template < + typename ArchTag, // TODO: hide this + typename TO, + typename TQKV, + typename TState> +void +launch_gdn_fwd_prefill_kernel( + cudaStream_t stream, + TO* output, + TState* output_state, + TQKV const* q, + TQKV const* k, + TQKV const* v, + TState const* input_state, + float const* alpha, + float const* beta, + int32_t const* cu_seqlens, + uint8_t* workspace_buffer, + int32_t num_seqs, + int32_t num_heads, + int32_t head_size, + int64_t total_seqlen, + float scale, + bool safe_gate, + int32_t sm_count = 0) { + bool needs_beta = beta != nullptr; + bool needs_alpha = alpha != nullptr; + bool init_state = input_state != nullptr; + +#define LAUNCH(needs_beta, needs_alpha, init_state, safe_gate) \ + launch_gdn_fwd_prefill_kernel_gbai( \ + stream, \ + output, \ + output_state, \ + q, \ + k, \ + v, \ + input_state, \ + alpha, \ + beta, \ + cu_seqlens, \ + workspace_buffer, \ + num_seqs, \ + num_heads, \ + head_size, \ + total_seqlen, \ + scale, \ + sm_count); + if (init_state) { + if (needs_beta && needs_alpha && safe_gate) { + LAUNCH(true, true, true, true); + } else { + throw std::runtime_error("unreachable"); + } + } else { + if (needs_beta && needs_alpha && safe_gate) { + LAUNCH(true, true, false, true); + } else { + throw std::runtime_error("unreachable"); + } + } + +#undef LAUNCH +} + +using bf16 = cute::bfloat16_t; + +template void +launch_gdn_fwd_prefill_kernel( + cudaStream_t stream, + bf16* output, + float* state, + bf16 const* q, + bf16 const* k, + bf16 const* v, + float const* input_state, + float const* alpha, + float const* beta, + int32_t const* cu_seqlens, + uint8_t* workspace_buffer, + int32_t num_seqs, + int32_t num_heads, + int32_t head_size, + int64_t total_seqlen, + float scale, + bool safe_gate, + int32_t sm_count); + +} // namespace gdn::sm90 diff --git a/csrc/gdn/sm90/gdn_fwd_sm90_safe_gate.cu b/csrc/gdn/sm90/gdn_fwd_sm90_safe_gate.cu new file mode 100644 index 0000000..4725127 --- /dev/null +++ b/csrc/gdn/sm90/gdn_fwd_sm90_safe_gate.cu @@ -0,0 +1,68 @@ +// Copyright 2025-2026 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "gdn/sm90/prefill_kernel_gdn_fwd_sm90.cuh" +#include "gdn/sm90/utils/common.hpp" + +namespace gdn::sm90 { + +using namespace cute; +using bf16 = cute::bfloat16_t; + +// SafeGate=true, InitState=false +template void +launch_gdn_fwd_prefill_kernel_gbai( + cudaStream_t, + bf16*, + float*, + bf16 const*, + bf16 const*, + bf16 const*, + float const*, + float const*, + float const*, + int32_t const*, + uint8_t*, + int32_t, + int32_t, + int32_t, + int64_t, + float, + int32_t); + +// SafeGate=true, InitState=true +template void +launch_gdn_fwd_prefill_kernel_gbai( + cudaStream_t, + bf16*, + float*, + bf16 const*, + bf16 const*, + bf16 const*, + float const*, + float const*, + float const*, + int32_t const*, + uint8_t*, + int32_t, + int32_t, + int32_t, + int64_t, + float, + int32_t); + +} // namespace gdn::sm90 diff --git a/csrc/gdn/sm90/kernel/builder_gdn_fwd.hpp b/csrc/gdn/sm90/kernel/builder_gdn_fwd.hpp new file mode 100644 index 0000000..78f797e --- /dev/null +++ b/csrc/gdn/sm90/kernel/builder_gdn_fwd.hpp @@ -0,0 +1,80 @@ +// Copyright 2025-2026 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "gdn/sm90/collective/mainloop_gdn_fwd.hpp" +#include "gdn/sm90/kernel/kernel_gdn_fwd.hpp" +#include "gdn/sm90/kernel/options.hpp" +#include "gdn/sm90/kernel/tile_scheduler.hpp" +#include "gdn/sm90/utils/type_traits.hpp" + +namespace gdn::sm90::kernel { + +template < + class Element_, + class ElementAccumulatorQK_, + class ElementAccumulatorPV_, + class TileShape_, // BlkSeqQO, BlkSeqKV, HeadSize + class LayoutQ_, + class LayoutK_, + class LayoutV_, + class LayoutO_, + class DispatchPolicy, + class Options = DefaultOptions> +struct FlatBuilderGdnFwd; + +template < + class Element, + class ElementAccumulatorQK, + class ElementAccumulatorPV, + class TileShape, // BlkSeqQO, BlkSeqKV, HeadSize + class LayoutQ, + class LayoutK, + class LayoutV, + class LayoutO, + class Options> +struct FlatBuilderGdnFwd< + Element, + ElementAccumulatorQK, + ElementAccumulatorPV, + TileShape, + LayoutQ, + LayoutK, + LayoutV, + LayoutO, + cutlass::gemm::KernelTmaWarpSpecializedCooperative, + Options> { + using CollectiveMainloop = gdn::sm90::collective::FlatMainloopTmaWarpSpecializedGdnFwd< + Element, + ElementAccumulatorQK, + ElementAccumulatorPV, + TileShape, + LayoutQ, + LayoutK, + LayoutV, + LayoutO, + Options>; + + static constexpr bool kIsPersistent = find_option_t::value; + static_assert(!kIsPersistent, "not implemented"); + + using TileScheduler = gdn::sm90::kernel::IndividualTileScheduler; + // using TileScheduler = std::conditional_t; + + using Kernel = gdn::sm90::kernel::FlatKernelTmaWarpSpecializedGdnFwd; +}; + +} // namespace gdn::sm90::kernel diff --git a/csrc/gdn/sm90/kernel/kernel_gdn_fwd.hpp b/csrc/gdn/sm90/kernel/kernel_gdn_fwd.hpp new file mode 100644 index 0000000..5774bae --- /dev/null +++ b/csrc/gdn/sm90/kernel/kernel_gdn_fwd.hpp @@ -0,0 +1,630 @@ +// Copyright 2025-2026 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include + +#include "gdn/sm90/kernel/options.hpp" +#include "gdn/sm90/utils/common.hpp" +#include "gdn/sm90/utils/unused.hpp" + +namespace gdn::sm90::kernel { + +using namespace cute; + +template +constexpr T1 +round_down(T1 a, T2 b) { + return (a / b) * b; +} + +constexpr std::tuple +get_register_requirements( + uint32_t max_threads_per_block, + uint32_t min_blocks_per_multiprocessor, + uint32_t num_state_mma_warp_groups // state related mma +) { + uint32_t reg_alloc_granularity = 8; + +#if !defined(FLAT_DEBUG_PRINT) || !FLAT_DEBUG_PRINT + uint32_t load_registers = 40 - 2 * reg_alloc_granularity; +#else + uint32_t load_registers = 40; +#endif + // TODO: better tuning + uint32_t total_aux_load_budget = 176; + uint32_t aux_registers = total_aux_load_budget - load_registers; // (24 + X) or (40 + X) + + uint32_t total_registers = + round_down(64 * 1024 / min_blocks_per_multiprocessor, max_threads_per_block * reg_alloc_granularity) / + cutlass::NumThreadsPerWarpGroup; + uint32_t mma_registers = round_down( + (total_registers - load_registers - aux_registers) / num_state_mma_warp_groups, reg_alloc_granularity); + + // max reg is 255, 248 round to multiple of reg_alloc_granularity; + return {cute::min(248, load_registers), cute::min(248, mma_registers), cute::min(248, aux_registers)}; +} + +template +struct FlatKernelTmaWarpSpecializedGdnFwd { + using ArchTag = cutlass::arch::Sm90; + + static const int NumLoadWarpGroups = 1; + static constexpr int NumStateMmaWarpGroups = CollectiveMainloop::NumStateMmaWarpGroups; + static constexpr int NumAuxMmaWarpGroups = CollectiveMainloop::NumAuxMmaWarpGroups; + + static constexpr int NeedsAlpha = CollectiveMainloop::NeedsAlpha; + static constexpr int NeedsBeta = CollectiveMainloop::NeedsBeta; + static constexpr int SafeGate = CollectiveMainloop::SafeGate; + + using TileShape = typename CollectiveMainloop::TileShape; + using ClusterShape = typename CollectiveMainloop::ClusterShape; + + using MainloopQPipeline = typename CollectiveMainloop::MainloopQPipeline; + using MainloopKPipeline = typename CollectiveMainloop::MainloopKPipeline; + using MainloopVPipeline = typename CollectiveMainloop::MainloopVPipeline; + using MainloopOPipeline = typename CollectiveMainloop::MainloopOPipeline; + + using MainloopQKPipeline = typename CollectiveMainloop::MainloopQKPipeline; + using MainloopKKPipeline = typename CollectiveMainloop::MainloopKKPipeline; + + using MainloopAlphaLastPipeline = typename CollectiveMainloop::MainloopAlphaLastPipeline; + + using MainloopAlphaPipeline = typename CollectiveMainloop::MainloopAlphaPipeline; + using MainloopBetaPipeline = typename CollectiveMainloop::MainloopBetaPipeline; + + using OrderedMathBarriers = typename CollectiveMainloop::OrderedMathBarriers; + + static constexpr uint32_t StagesPerMathWarpGroup = 2; + + using MathWarpGroupOrderBarrier = cutlass::OrderedSequenceBarrier; + + struct TensorStorage { + typename CollectiveMainloop::SharedStorage mainloop; + }; + + struct SharedStorage { + TensorStorage tensors; + + using QPipelineStorage = typename MainloopQPipeline::SharedStorage; + using KPipelineStorage = typename MainloopKPipeline::SharedStorage; + using VPipelineStorage = typename MainloopVPipeline::SharedStorage; + using OPipelineStorage = typename MainloopOPipeline::SharedStorage; + + alignas(16) QPipelineStorage q_pipeline_storage; + alignas(16) KPipelineStorage k_pipeline_storage; + alignas(16) VPipelineStorage v_pipeline_storage; + alignas(16) OPipelineStorage o_pipeline_storage; + + using QKPipelineStorage = typename MainloopQKPipeline::SharedStorage; + using KKPipelineStorage = typename MainloopKKPipeline::SharedStorage; + + alignas(16) QKPipelineStorage qk_pipeline_storage; + alignas(16) KKPipelineStorage kk_pipeline_storage; + + using AlphaLastPipelineStorage = typename MainloopAlphaLastPipeline::SharedStorage; + + alignas(16) AlphaLastPipelineStorage alpha_last_pipeline_storage; + + using AlphaPipelineStorage = typename MainloopAlphaPipeline::SharedStorage; + using BetaPipelineStorage = typename MainloopBetaPipeline::SharedStorage; + alignas(16) AlphaPipelineStorage alpha_pipeline_storage; + alignas(16) BetaPipelineStorage beta_pipeline_storage; + + alignas(16) cutlass::arch::ClusterBarrier load_warp_barrier; + }; + static constexpr int SharedStorageSize = sizeof(SharedStorage); + struct VarlenProblemShape { + int32_t const* cu_seqlens; + int64_t total_seqlen; + int32_t num_seqs; + int32_t num_heads; // Q, K, V, O all share the same head count in GDN + int32_t head_size; // d + }; + using ProblemShape = VarlenProblemShape; + + struct Arguments { + ProblemShape problem_size; + typename CollectiveMainloop::Arguments mainloop; + cutlass::KernelHardwareInfo hw_info; + }; + + struct Params { + ProblemShape problem_size; + typename CollectiveMainloop::Params mainloop; + typename TileScheduler::Params scheduler; + }; + + using QPipelineParams = typename MainloopQPipeline::Params; + using QPipelineState = typename cutlass::PipelineState; + + using KPipelineParams = typename MainloopKPipeline::Params; + using KPipelineState = typename cutlass::PipelineState; + + using VPipelineParams = typename MainloopVPipeline::Params; + using VPipelineState = typename cutlass::PipelineState; + + using OPipelineParams = typename MainloopOPipeline::Params; + using OPipelineState = typename cutlass::PipelineState; + + using QKPipelineParams = typename MainloopQKPipeline::Params; + using QKPipelineState = typename cutlass::PipelineState; + + using KKPipelineParams = typename MainloopKKPipeline::Params; + using KKPipelineState = typename cutlass::PipelineState; + + using AlphaLastPipelineParams = std::conditional_t; + using AlphaLastPipelineState = + std::conditional_t, Unused>; + + using AlphaPipelineParams = std::conditional_t; + using AlphaPipelineState = + std::conditional_t, Unused>; + + using BetaPipelineParams = std::conditional_t; + using BetaPipelineState = + std::conditional_t, Unused>; + + static constexpr int MinBlocksPerMultiprocessor = 1; + static constexpr int MaxThreadsPerBlock = + (NumLoadWarpGroups + NumStateMmaWarpGroups + NumAuxMmaWarpGroups) * cutlass::NumThreadsPerWarpGroup; + + static constexpr auto RegisterRequirements = + get_register_requirements(MaxThreadsPerBlock, MinBlocksPerMultiprocessor, NumStateMmaWarpGroups); + static constexpr uint32_t LdStRegisterRequirement = get<0>(RegisterRequirements); + static constexpr uint32_t StateMmaRegisterRequirement = get<1>(RegisterRequirements); + static constexpr uint32_t AuxMmaRegisterRequirement = get<2>(RegisterRequirements); + + static size_t + get_workspace_size(Arguments const& args) { + return CollectiveMainloop::get_workspace_size(args.mainloop, args.hw_info.sm_count); + } + + static cutlass::Status + initialize_workspace(Arguments const& args, void* workspace, cudaStream_t stream) { + return CollectiveMainloop::initialize_workspace(args.problem_size, args.mainloop, workspace, stream); + } + + static bool + can_implement(Arguments const& args) { + return CollectiveMainloop::can_implement(args.problem_size, args.mainloop); + } + + static dim3 + get_grid_shape(Params const& params) { + return TileScheduler::get_grid_shape(params.scheduler); + } + + static dim3 + get_block_shape() { + dim3 block(MaxThreadsPerBlock, 1, 1); + return block; + } + + static Params + to_underlying_arguments(Arguments const& args, void* workspace) { + return Params{ + args.problem_size, + CollectiveMainloop::to_underlying_arguments(args.problem_size, args.mainloop, workspace), + TileScheduler::to_underlying_arguments(args.problem_size, args.hw_info, ClusterShape{}, TileShape{})}; + } + + CUTE_DEVICE void + operator()(const Params& params, char* smem) { + enum class WarpGroupRole { + LdSt = 0, + Math0 = 1, + Math1 = 2, + MathA = 3, // auxiliary math WG + }; + + // NOTE: CollectiveInverse will have more utilization on warp 0&1 + // so we put beta and alpha preprocessing on warp 2&3 + enum class LdStWarpRole { + LoadQKV = 0, + StoreO = 1, + LoadBetaAndAlpha = 2, + LoadAlphaLast = 3, + }; + + TileScheduler scheduler{params.scheduler}; + + // Shared memory. + auto& storage = *reinterpret_cast(smem); + + int lane_idx = cutlass::canonical_lane_idx(); + int warp_idx = cutlass::canonical_warp_idx_sync(); + int warp_idx_in_wg = warp_idx % cutlass::NumWarpsPerWarpGroup; + int warp_group_idx = cutlass::canonical_warp_group_idx(); + auto warp_group_role = WarpGroupRole(warp_group_idx); + auto ldst_warp_role = LdStWarpRole(warp_idx_in_wg); + + int lane_predicate = cute::elect_one_sync(); + uint32_t block_rank_in_cluster = cute::block_rank_in_cluster(); + + // Issue Tma Descriptor Prefetch from a single thread + if ((warp_idx == 0) && lane_predicate) { + CollectiveMainloop::prefetch_tma_descriptors(params.mainloop); + } + + constexpr int NumStateMathThreads = NumStateMmaWarpGroups * cutlass::NumThreadsPerWarpGroup; + constexpr int NumAuxMathThreads = NumAuxMmaWarpGroups * cutlass::NumThreadsPerWarpGroup; + + QPipelineParams q_pipeline_params; + q_pipeline_params.transaction_bytes = CollectiveMainloop::LoadQBytes; + q_pipeline_params.is_leader = lane_predicate && (ldst_warp_role == LdStWarpRole::LoadQKV); + q_pipeline_params.num_consumers = NumStateMathThreads + NumAuxMathThreads; + + KPipelineParams k_pipeline_params; + k_pipeline_params.transaction_bytes = CollectiveMainloop::LoadKBytes; + k_pipeline_params.is_leader = lane_predicate && (ldst_warp_role == LdStWarpRole::LoadQKV); + k_pipeline_params.num_consumers = NumStateMathThreads + NumAuxMathThreads; + + VPipelineParams v_pipeline_params; + v_pipeline_params.transaction_bytes = CollectiveMainloop::LoadVBytes; + v_pipeline_params.is_leader = lane_predicate && (ldst_warp_role == LdStWarpRole::LoadQKV); + v_pipeline_params.num_consumers = NumStateMathThreads; + + + + OPipelineParams o_pipeline_params; + o_pipeline_params.producer_arv_count = NumStateMathThreads; + o_pipeline_params.consumer_arv_count = cutlass::NumThreadsPerWarp; + + QKPipelineParams qk_pipeline_params; + qk_pipeline_params.producer_arv_count = NumAuxMathThreads; + qk_pipeline_params.consumer_arv_count = NumStateMathThreads; + + KKPipelineParams kk_pipeline_params; + kk_pipeline_params.producer_arv_count = NumAuxMathThreads; + kk_pipeline_params.consumer_arv_count = NumStateMathThreads; + + AlphaPipelineParams alpha_pipeline_params; + if constexpr (NeedsAlpha) { + alpha_pipeline_params.producer_arv_count = cutlass::NumThreadsPerWarp; + alpha_pipeline_params.consumer_arv_count = NumAuxMathThreads + NumStateMathThreads + cutlass::NumThreadsPerWarp; + } + + AlphaLastPipelineParams alpha_last_pipeline_params; + if constexpr (NeedsAlpha) { + alpha_last_pipeline_params.producer_arv_count = cutlass::NumThreadsPerWarp; + alpha_last_pipeline_params.consumer_arv_count = NumStateMathThreads; + } + + BetaPipelineParams beta_pipeline_params; + if constexpr (NeedsBeta) { + beta_pipeline_params.producer_arv_count = cutlass::NumThreadsPerWarp; + beta_pipeline_params.consumer_arv_count = NumAuxMathThreads + NumStateMathThreads; + } + + OrderedMathBarriers math_barriers; + + if (warp_group_role == WarpGroupRole::LdSt && ldst_warp_role == LdStWarpRole::LoadQKV) { + DPRINTF0_W("ldst_warp_role: LoadQKV\n"); + q_pipeline_params.role = MainloopQPipeline::ThreadCategory::Producer; + k_pipeline_params.role = MainloopKPipeline::ThreadCategory::Producer; + v_pipeline_params.role = MainloopVPipeline::ThreadCategory::Producer; + } + if (warp_group_role == WarpGroupRole::LdSt && ldst_warp_role == LdStWarpRole::StoreO) { + DPRINTF0_W("ldst_warp_role: StoreO\n"); + o_pipeline_params.role = MainloopOPipeline::ThreadCategory::Consumer; + } + if (warp_group_role == WarpGroupRole::LdSt && ldst_warp_role == LdStWarpRole::LoadBetaAndAlpha) { + if constexpr (NeedsBeta) { + beta_pipeline_params.role = MainloopBetaPipeline::ThreadCategory::Producer; + } + if constexpr (NeedsAlpha) { + alpha_pipeline_params.role = MainloopAlphaPipeline::ThreadCategory::Producer; + } + } + if (warp_group_role == WarpGroupRole::LdSt && ldst_warp_role == LdStWarpRole::LoadAlphaLast) { + // LoadAlpha warp consumes alpha_pipeline (reads last row) and produces alpha_last_pipeline + if constexpr (NeedsAlpha) { + alpha_pipeline_params.role = MainloopAlphaPipeline::ThreadCategory::Consumer; + } + alpha_last_pipeline_params.role = MainloopAlphaLastPipeline::ThreadCategory::Producer; + } + if (warp_group_role == WarpGroupRole::Math0 || warp_group_role == WarpGroupRole::Math1) { + DPRINTF0_WG("warp_group_role: MathX\n"); + q_pipeline_params.role = MainloopQPipeline::ThreadCategory::Consumer; + k_pipeline_params.role = MainloopKPipeline::ThreadCategory::Consumer; + v_pipeline_params.role = MainloopVPipeline::ThreadCategory::Consumer; + o_pipeline_params.role = MainloopOPipeline::ThreadCategory::Producer; + + qk_pipeline_params.role = MainloopQKPipeline::ThreadCategory::Consumer; + kk_pipeline_params.role = MainloopKKPipeline::ThreadCategory::Consumer; + + if constexpr (NeedsAlpha) { + alpha_pipeline_params.role = MainloopAlphaPipeline::ThreadCategory::Consumer; + alpha_last_pipeline_params.role = MainloopAlphaLastPipeline::ThreadCategory::Consumer; + } + if constexpr (NeedsBeta) { + beta_pipeline_params.role = MainloopBetaPipeline::ThreadCategory::Consumer; + } + + math_barriers.init(warp_group_idx - 1); + } + if (warp_group_role == WarpGroupRole::MathA) { + DPRINTF0_WG("warp_group_role: MathA\n"); + q_pipeline_params.role = MainloopQPipeline::ThreadCategory::Consumer; + k_pipeline_params.role = MainloopKPipeline::ThreadCategory::Consumer; + + qk_pipeline_params.role = MainloopQKPipeline::ThreadCategory::Producer; + kk_pipeline_params.role = MainloopKKPipeline::ThreadCategory::Producer; + + if constexpr (NeedsAlpha) { + alpha_pipeline_params.role = MainloopAlphaPipeline::ThreadCategory::Consumer; + } + if constexpr (NeedsBeta) { + beta_pipeline_params.role = MainloopBetaPipeline::ThreadCategory::Consumer; + } + } + + MainloopQPipeline q_pipeline(storage.q_pipeline_storage, q_pipeline_params, ClusterShape{}); + MainloopKPipeline k_pipeline(storage.k_pipeline_storage, k_pipeline_params, ClusterShape{}); + MainloopVPipeline v_pipeline(storage.v_pipeline_storage, v_pipeline_params, ClusterShape{}); + MainloopAlphaPipeline alpha_pipeline(storage.alpha_pipeline_storage, alpha_pipeline_params, cute::true_type{}); + MainloopOPipeline o_pipeline(storage.o_pipeline_storage, o_pipeline_params, /*InitBarriers=*/cute::true_type{}); + + MainloopAlphaLastPipeline alpha_last_pipeline( + storage.alpha_last_pipeline_storage, + alpha_last_pipeline_params, + /*InitBarriers=*/cute::true_type{}); + + MainloopQKPipeline qk_pipeline( + storage.qk_pipeline_storage, + qk_pipeline_params, + /*InitBarriers=*/cute::true_type{}); + MainloopKKPipeline kk_pipeline( + storage.kk_pipeline_storage, + kk_pipeline_params, + /*InitBarriers=*/cute::true_type{}); + + // MainloopAlphaPipeline alpha_pipeline(storage.alpha_pipeline_storage, alpha_pipeline_params, + // /*InitBarriers=*/cute::true_type{}); + MainloopBetaPipeline beta_pipeline( + storage.beta_pipeline_storage, + beta_pipeline_params, + /*InitBarriers=*/cute::true_type{}); + + QPipelineState q_smem_pipe_read; + QPipelineState q_smem_pipe_write = cutlass::make_producer_start_state(); + KPipelineState k_smem_pipe_read; + KPipelineState k_smem_pipe_write = cutlass::make_producer_start_state(); + VPipelineState v_smem_pipe_read; + VPipelineState v_smem_pipe_write = cutlass::make_producer_start_state(); + OPipelineState o_smem_pipe_read; + OPipelineState o_smem_pipe_write = cutlass::make_producer_start_state(); + + AlphaLastPipelineState alpha_last_smem_pipe_read; + AlphaLastPipelineState alpha_last_smem_pipe_write; + if constexpr (NeedsAlpha) { + alpha_last_smem_pipe_write = cutlass::make_producer_start_state(); + } + + QKPipelineState qk_smem_pipe_read; + QKPipelineState qk_smem_pipe_write = cutlass::make_producer_start_state(); + KKPipelineState kk_smem_pipe_read; + KKPipelineState kk_smem_pipe_write = cutlass::make_producer_start_state(); + + AlphaPipelineState alpha_smem_pipe_read; + AlphaPipelineState alpha_smem_pipe_write; + if constexpr (NeedsAlpha) { + alpha_smem_pipe_write = cutlass::make_producer_start_state(); + } + BetaPipelineState beta_smem_pipe_read; + BetaPipelineState beta_smem_pipe_write; + if constexpr (NeedsBeta) { + beta_smem_pipe_write = cutlass::make_producer_start_state(); + } + + // barrier sm or cluster level for initialization + if constexpr (size(ClusterShape{}) > 1) { + cute::cluster_arrive_relaxed(); + cute::cluster_wait(); + } else { + __syncthreads(); + } + DPRINTF0_WG("warpspecialized grid initialized\n"); + + CollectiveMainloop collective_mainloop; + + if (warp_group_role == WarpGroupRole::LdSt) { + DPRINTF0_WG("LsSt warp_group_idx:%d, RegisterRequirement:%d\n", warp_group_idx, LdStRegisterRequirement); + cutlass::arch::warpgroup_reg_dealloc(); + if (ldst_warp_role == LdStWarpRole::LoadQKV) { + auto work_desc = scheduler.get_next_work(params.scheduler, params.problem_size); + CUTE_NO_UNROLL + for (; work_desc.is_valid(params.scheduler); + work_desc = scheduler.get_next_work(params.scheduler, params.problem_size)) { + DPRINTF0_WG( + "LsSt working on LoadQ/K/V, seq_idx:%d, q/k/v_head_idx:(%d,%d,%d), seq_len:%lld)\n", + work_desc.seq_idx, + work_desc.q_head_idx(), + work_desc.k_head_idx(), + work_desc.v_head_idx(), + work_desc.seq_len); + auto tile_shape = typename CollectiveMainloop::TileShape{}; + collective_mainloop.load_qkv( + params.mainloop, + params.problem_size, + tile_shape, + work_desc, + q_pipeline, + q_smem_pipe_write, + k_pipeline, + k_smem_pipe_write, + v_pipeline, + v_smem_pipe_write, + storage.tensors.mainloop); + } + } else if (ldst_warp_role == LdStWarpRole::LoadBetaAndAlpha) { + if constexpr (NeedsBeta) { + auto work_desc = scheduler.get_next_work(params.scheduler, params.problem_size); + CUTE_NO_UNROLL + for (; work_desc.is_valid(params.scheduler); + work_desc = scheduler.get_next_work(params.scheduler, params.problem_size)) { + DPRINTF0_WG( + "LsSt working on LoadBeta, seq_idx:%d, sab_head_idx:%d, seq_len:%lld)\n", + work_desc.seq_idx, + work_desc.o_head_idx(), + work_desc.seq_len); + auto tile_shape = typename CollectiveMainloop::TileShape{}; + collective_mainloop.load_beta_and_alpha( + params.mainloop, + params.problem_size, + tile_shape, + work_desc, + beta_pipeline, + beta_smem_pipe_write, + alpha_pipeline, + alpha_smem_pipe_write, + storage.tensors.mainloop); + } + } + } else if (ldst_warp_role == LdStWarpRole::LoadAlphaLast) { + // produce the last row of Alpha + if constexpr (NeedsAlpha) { + auto work_desc = scheduler.get_next_work(params.scheduler, params.problem_size); + CUTE_NO_UNROLL + for (; work_desc.is_valid(params.scheduler); + work_desc = scheduler.get_next_work(params.scheduler, params.problem_size)) { + DPRINTF0_WG( + "LsSt working on LoadAlpha+ExtractLast, seq_idx:%d, sab_head_idx:%d, seq_len:%lld)\n", + work_desc.seq_idx, + work_desc.o_head_idx(), + work_desc.seq_len); + auto tile_shape = typename CollectiveMainloop::TileShape{}; + collective_mainloop.extract_alpha_last( + params.mainloop, + params.problem_size, + tile_shape, + work_desc, + alpha_pipeline, + alpha_smem_pipe_read, + alpha_last_pipeline, + alpha_last_smem_pipe_write, + storage.tensors.mainloop); + } + } + } else if (ldst_warp_role == LdStWarpRole::StoreO) { + auto work_desc = scheduler.get_next_work(params.scheduler, params.problem_size); + DPRINTF0_WG( + "LsSt working on StoreO, seq_idx:%d, o_head_idx:%d, seq_len:%lld)\n", + work_desc.seq_idx, + work_desc.o_head_idx(), + work_desc.seq_len); + auto tile_shape = typename CollectiveMainloop::TileShape{}; + collective_mainloop.store( + params.mainloop.tma_store_o, + params.mainloop.tensormaps, + params.problem_size, + tile_shape, + work_desc, + o_pipeline, + o_smem_pipe_read, + storage.tensors.mainloop.smem_o); + } + } else if (warp_group_role == WarpGroupRole::Math0 || warp_group_role == WarpGroupRole::Math1) { + DPRINTF0_WG( + "Compute[state]: warp_group_idx:%d, RegisterRequirement:%d\n", + warp_group_idx, + StateMmaRegisterRequirement); + cutlass::arch::warpgroup_reg_alloc(); + auto work_desc = scheduler.get_next_work(params.scheduler, params.problem_size); + CUTE_NO_UNROLL + for (; work_desc.is_valid(params.scheduler); + work_desc = scheduler.get_next_work(params.scheduler, params.problem_size)) { + DPRINTF0_WG( + "Compute[state]: seq_idx:%d, qk/v/o_head_idx:(%d,%d,%d,%d), seq_len:%lld)\n", + work_desc.seq_idx, + work_desc.q_head_idx(), + work_desc.k_head_idx(), + work_desc.v_head_idx(), + work_desc.o_head_idx(), + work_desc.seq_len); + collective_mainloop.compute( + params.mainloop, + params.problem_size, + work_desc, + q_pipeline, + q_smem_pipe_read, + k_pipeline, + k_smem_pipe_read, + v_pipeline, + v_smem_pipe_read, + o_pipeline, + o_smem_pipe_write, + qk_pipeline, + qk_smem_pipe_read, + kk_pipeline, + kk_smem_pipe_read, + alpha_pipeline, + alpha_smem_pipe_read, + beta_pipeline, + beta_smem_pipe_read, + alpha_last_pipeline, + alpha_last_smem_pipe_read, + math_barriers, + storage.tensors.mainloop); + } + } else if (warp_group_role == WarpGroupRole::MathA) { + DPRINTF0_WG( + "Compute[aux]: warp_group_idx:%d, RegisterRequirement:%d\n", warp_group_idx, AuxMmaRegisterRequirement); + cutlass::arch::warpgroup_reg_alloc(); + auto work_desc = scheduler.get_next_work(params.scheduler, params.problem_size); + CUTE_NO_UNROLL + for (; work_desc.is_valid(params.scheduler); + work_desc = scheduler.get_next_work(params.scheduler, params.problem_size)) { + DPRINTF0_WG( + "Compute[aux]: seq_idx:%d, qk/v/o_head_idx:(%d,%d,%d,%d), seq_len:%lld)\n", + work_desc.seq_idx, + work_desc.q_head_idx(), + work_desc.k_head_idx(), + work_desc.v_head_idx(), + work_desc.o_head_idx(), + work_desc.seq_len); + collective_mainloop.compute_aux_safe( + params.mainloop, + params.problem_size, + work_desc, + q_pipeline, + q_smem_pipe_read, + k_pipeline, + k_smem_pipe_read, + qk_pipeline, + qk_smem_pipe_write, + kk_pipeline, + kk_smem_pipe_write, + alpha_pipeline, + alpha_smem_pipe_read, + beta_pipeline, + beta_smem_pipe_read, + alpha_last_pipeline, + alpha_last_smem_pipe_write, + storage.tensors.mainloop); + } + } else { + DPRINTF0_WG("Unknown warp role, warp_group_idx:%d\n", warp_group_idx); + } + + __syncthreads(); + } +}; + +} // namespace gdn::sm90::kernel diff --git a/csrc/gdn/sm90/kernel/options.hpp b/csrc/gdn/sm90/kernel/options.hpp new file mode 100644 index 0000000..833c1c4 --- /dev/null +++ b/csrc/gdn/sm90/kernel/options.hpp @@ -0,0 +1,86 @@ +// Copyright 2025-2026 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include + +namespace gdn::sm90::kernel { + +template +struct Option { + static constexpr auto tag = kTag; + using option_value = Value; +}; + +using DefaultOptions = std::tuple<>; + +namespace detail { + +template +struct find_option_impl; + +template +struct find_option_impl { + using option_value = Default; +}; + +template +struct find_option_impl : find_option_impl {}; + +template +struct find_option_impl + : std::conditional_t> {}; + +template +struct find_option_impl> : find_option_impl {}; + +template +struct add_option_impl; + +template +struct add_option_impl> { + using options = std::tuple; +}; + +} // namespace detail + +template +using find_option_t = typename detail::find_option_impl>::option_value; + +template +using add_option_t = typename detail::add_option_impl, std::tuple>::options; + +template +constexpr auto +add_option(Option new_option, std::tuple options_tuple) { + return add_option_t(); +} + +enum class Tag { + kIsDeltaRule, + kIsPersistent, + kNumMmaWarpGroups, + kStagesQ, + kStagesK, + kStagesV, + kNeedsAlpha, // gated delta rule + kNeedsBeta, // delta rule + kInitStateFromInput, // if true, initialize state by reading global memory instead of zero initialization. + kSafeGate, // GDN +}; + +} // namespace gdn::sm90::kernel diff --git a/csrc/gdn/sm90/kernel/tile_scheduler.hpp b/csrc/gdn/sm90/kernel/tile_scheduler.hpp new file mode 100644 index 0000000..6cb9a8e --- /dev/null +++ b/csrc/gdn/sm90/kernel/tile_scheduler.hpp @@ -0,0 +1,139 @@ +// Copyright 2025-2026 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include + +namespace gdn::sm90::kernel { + +using namespace cute; + +struct WorkDesc { + // coord + int32_t seq_idx; + int32_t head_idx; + int64_t tok_offset; // offset to the start of the start + + // shape + int64_t seq_len; + + // update by mainloop + int32_t tile_idx = 0; + + template + CUTE_DEVICE bool + is_valid(Params const& params) { + return seq_idx >= 0 && seq_idx < params.num_seqs; + } + + CUTE_DEVICE int32_t + q_head_idx() const { + return head_idx; + } + CUTE_DEVICE int32_t + k_head_idx() const { + return head_idx; + } + CUTE_DEVICE int32_t + v_head_idx() const { + return head_idx; + } + CUTE_DEVICE int32_t + o_head_idx() const { + return head_idx; + } + + // compatible interface, for work without ChunkWiseParallel, chunk_len equals to seq_len + CUTE_DEVICE int32_t + chunk_len() const { + return seq_len; + } +}; + +struct IndividualTileScheduler { + struct Params { + dim3 grid; + int32_t num_seqs; + int32_t num_heads; + }; + + bool scheduled = false; // a once flag + + CUTE_DEVICE + IndividualTileScheduler(Params const& params) { + } + + template + static Params + to_underlying_arguments( + ProblemSize const& problem_size, + cutlass::KernelHardwareInfo const& hw_info, + ClusterShape const& cluster_shape, + TileShape const& tile_shape) { + dim3 grid(0, 1, 1); + grid.x = problem_size.num_seqs * problem_size.num_heads; + DPRINTF( + "to_underlying_arguments: grid:{.x:%d, .y:%d, .z:%d}, num_seqs:%d, num_heads:%d\n", + grid.x, + grid.y, + grid.z, + problem_size.num_seqs, + problem_size.num_heads); + return { + .grid = grid, + .num_seqs = problem_size.num_seqs, + .num_heads = problem_size.num_heads, + }; + } + + static dim3 + get_grid_shape(Params const& params) { + return params.grid; + } + + template + CUTE_DEVICE WorkDesc + get_next_work(Params params, ProblemSize const& problem_size) { + int32_t seq_idx = blockIdx.x / params.num_heads; + int32_t head_idx = blockIdx.x % params.num_heads; + + int32_t s = problem_size.cu_seqlens[seq_idx]; + int32_t e = problem_size.cu_seqlens[seq_idx + 1]; + int32_t seq_len = e - s; + + if (scheduled) { + seq_idx = -1; + } else { + scheduled = true; + DPRINTF0_W( + "get_next_work: this_work={seq_idx:%d head_idx:%d tok_offset:%lld seq_len:%lld}\n", + seq_idx, + head_idx, + s, + seq_len); + } + + return { + .seq_idx = seq_idx, + .head_idx = head_idx, + .tok_offset = s, + .seq_len = seq_len, + }; + } +}; + +} // namespace gdn::sm90::kernel diff --git a/csrc/gdn/sm90/prefill_kernel.hpp b/csrc/gdn/sm90/prefill_kernel.hpp new file mode 100644 index 0000000..bfcfd6f --- /dev/null +++ b/csrc/gdn/sm90/prefill_kernel.hpp @@ -0,0 +1,49 @@ +// Copyright 2025-2026 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include + +namespace gdn::sm90 { + +template < + typename ArchTag, // TODO: hide this + typename TO, + typename TQKV, + typename TState> +void +launch_gdn_fwd_prefill_kernel( + cudaStream_t stream, + TO* output, + TState* output_state, + TQKV const* q, + TQKV const* k, + TQKV const* v, + TState const* input_state, + float const* alpha, + float const* beta, + int32_t const* cu_seqlens, + uint8_t* workspace_buffer, + int32_t num_seqs, + int32_t num_heads, + int32_t head_size, + int64_t total_seqlen, + float scale, + bool safe_gate, + int32_t sm_count = 0); + +} // namespace gdn::sm90 diff --git a/csrc/gdn/sm90/prefill_kernel_gdn_fwd_sm90.cuh b/csrc/gdn/sm90/prefill_kernel_gdn_fwd_sm90.cuh new file mode 100644 index 0000000..111cb46 --- /dev/null +++ b/csrc/gdn/sm90/prefill_kernel_gdn_fwd_sm90.cuh @@ -0,0 +1,154 @@ +// Copyright 2025-2026 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include +#include +#include +#include + +#include "gdn/sm90/device/device_universal.hpp" +#include "gdn/sm90/kernel/builder_gdn_fwd.hpp" +#include "gdn/sm90/utils/common.hpp" + +namespace gdn::sm90 { + +using namespace cute; + +template < + bool NeedsBeta, + bool NeedsAlpha, + bool InitStateFromInput, + bool SafeGate, + typename ArchTag, + typename TO, + typename TQKV, + typename TState> +void +launch_gdn_fwd_prefill_kernel_gbai( + cudaStream_t stream, + TO* output, + TState* output_state, + TQKV const* q, + TQKV const* k, + TQKV const* v, + TState const* input_state, + float const* alpha, + float const* beta, + int32_t const* cu_seqlens, + uint8_t* workspace_buffer, + int32_t num_seqs, + int32_t num_heads, + int32_t head_size, + int64_t total_seqlen, + float scale, + int32_t sm_count) { +#if defined(CULA_SM90A_ENABLED) + constexpr bool HopperSupported = true; +#else + constexpr bool HopperSupported = false; +#endif + + if constexpr (HopperSupported) { + static_assert(std::is_same_v); + + using namespace gdn::sm90::kernel; + using T = map_to_cutlass_t; + + cutlass::KernelHardwareInfo hw_info; + hw_info.sm_count = sm_count; + + using SafeGateType = std::conditional_t; + using NeedsBetaType = std::conditional_t; + using NeedsAlphaType = std::conditional_t; + using InitStateType = std::conditional_t; + using Options = decltype(add_option( + Option{}, + add_option( + Option{}, + add_option( + Option{}, + add_option( + Option{}, + add_option(Option{}, DefaultOptions{})))))); + + using TileShape = Shape<_64, _64, _128>; + using Scheduler = cutlass::gemm::KernelTmaWarpSpecializedCooperative; + using Operation = cutlass::device::Universal, + /*LayoutK=*/cute::tuple, + /*LayoutV=*/cute::tuple, + /*LayoutO=*/cute::tuple, + Scheduler, + Options>::Kernel>; + using Arguments = typename Operation::Arguments; + + // NOTE: LayoutQ/K/V in (seq, head_size, (b,h)) coordinate semantics + + int32_t tok_stride = num_heads * head_size; + int32_t head_stride = head_size; + + Operation op; + Arguments arguments{ + .problem_size = + { + .cu_seqlens = cu_seqlens, + .total_seqlen = total_seqlen, + .num_seqs = num_seqs, + .num_heads = num_heads, + .head_size = head_size, + }, + .mainloop = + { + // clang-format off + .ptr_Q = (T*)q, .dQ = {tok_stride, _1{}, head_stride}, + .ptr_K = (T*)k, .dK = {tok_stride, _1{}, head_stride}, + .ptr_V = (T*)v, .dV = {tok_stride, _1{}, head_stride}, + .ptr_O = (T*)output, .dO = {tok_stride, _1{}, head_stride}, + .ptr_Alpha = alpha, .alpha_stride = {num_heads, _1{}}, + .ptr_output_state = (float*)output_state, + .ptr_input_state = (float*)input_state, + .scale = scale, + .beta_ptr = beta, .beta_stride = {num_heads, 1}, + }, // clang-format on + .hw_info = hw_info}; + + cutlass::Status status; + status = op.can_implement(arguments); + if (status != cutlass::Status::kSuccess) { + throw std::runtime_error("can_implement failed"); + } + + status = op.initialize(arguments, workspace_buffer, stream); + if (status != cutlass::Status::kSuccess) { + throw std::runtime_error("initialize failed"); + } + status = op.run(stream); + if (status != cutlass::Status::kSuccess) { + throw std::runtime_error("run failed"); + } + + } else { + throw std::runtime_error("hopper not supported"); + } +} + +}; // namespace gdn::sm90 diff --git a/csrc/gdn/sm90/utils/common.hpp b/csrc/gdn/sm90/utils/common.hpp new file mode 100644 index 0000000..f1b94f2 --- /dev/null +++ b/csrc/gdn/sm90/utils/common.hpp @@ -0,0 +1,59 @@ +// Copyright 2025-2026 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +/* + * Copyright (c) 2025 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include + +#include "gdn/sm90/utils/debug.hpp" + +#define FLAT_UNUSED_PARAMETER(x) (void)x + +#define CHECK(expr, msg) \ + do { \ + if (!(expr)) { \ + std::string buffer(1024, '\0'); \ + sprintf(buffer.data(), "Failed to check %s, %s at %s:%d\n", ##expr, msg __FILE__, __LINE__); \ + throw std::runtime_error(buffer.c_str()); \ + } \ + } while (0) + +#define CUDA_CHECK(expr) \ + do { \ + cudaError_t err = (expr); \ + if (err != cudaSuccess) { \ + std::string buffer(1024, '\0'); \ + sprintf( \ + buffer.data(), "CUDA Error: %s, Code: %d at %s:%d\n", cudaGetErrorName(err), err, __FILE__, __LINE__); \ + throw std::runtime_error(buffer.c_str()); \ + } \ + } while (0) diff --git a/csrc/gdn/sm90/utils/debug.hpp b/csrc/gdn/sm90/utils/debug.hpp new file mode 100644 index 0000000..779b937 --- /dev/null +++ b/csrc/gdn/sm90/utils/debug.hpp @@ -0,0 +1,149 @@ +// Copyright 2025-2026 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +/* + * Copyright (c) 2025 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +#if DEBUG_PIPE +#define PIPE_DEBUG_PRINTF(fmt, ...) \ + if (threadIdx.x == 0) \ + printf("%s:%d " fmt, __FILE__, __LINE__, ##__VA_ARGS__) +#else +#define PIPE_DEBUG_PRINTF(...) +#endif + +#ifndef FLAT_DEBUG_PRINT +#define FLAT_DEBUG_PRINT 0 +#endif + +#if FLAT_DEBUG_PRINT +#define IS_PRINT_BLOCK cute::block(0) +#define DPRINTF(fmt, ...) \ + if (IS_PRINT_BLOCK) \ + printf("%s:%d " fmt, __FILE__, __LINE__, ##__VA_ARGS__) +#define DPRINTF0(fmt, ...) \ + if (IS_PRINT_BLOCK && threadIdx.x == 0) \ + printf("%s:%d " fmt, __FILE__, __LINE__, ##__VA_ARGS__) +#define DPRINTF_W(fmt, ...) \ + if (IS_PRINT_BLOCK) \ + printf( \ + "%s:%d [WG%d][W%d][T%-3d] " fmt, \ + __FILE__, \ + __LINE__, \ + threadIdx.x / 128, \ + threadIdx.x / 32, \ + threadIdx.x, \ + ##__VA_ARGS__) +#define DPRINTF0_W(fmt, ...) \ + if (IS_PRINT_BLOCK && threadIdx.x % 32 == 0) \ + printf( \ + "%s:%d [WG%d][W%d][T%-3d] " fmt, \ + __FILE__, \ + __LINE__, \ + threadIdx.x / 128, \ + threadIdx.x / 32, \ + threadIdx.x, \ + ##__VA_ARGS__) +#define DPRINTF_WG(fmt, ...) \ + if (IS_PRINT_BLOCK) \ + printf( \ + "%s:%d [WG%d][W%d][T%-3d] " fmt, \ + __FILE__, \ + __LINE__, \ + threadIdx.x / 128, \ + threadIdx.x / 32, \ + threadIdx.x, \ + ##__VA_ARGS__) +#define DPRINTF0_WG(fmt, ...) \ + if (IS_PRINT_BLOCK && threadIdx.x % 128 == 0) \ + printf( \ + "%s:%d [WG%d][W%d][T%-3d] " fmt, \ + __FILE__, \ + __LINE__, \ + threadIdx.x / 128, \ + threadIdx.x / 32, \ + threadIdx.x, \ + ##__VA_ARGS__) +#else +#define DPRINTF(...) +#define DPRINTF0(...) +#define DPRINTF_W(...) +#define DPRINTF0_W(...) +#define DPRINTF_WG(...) +#define DPRINTF0_WG(...) +#endif + +#if FLAT_DEBUG_PRINT +#define DPRINT_TMA_DESC(tma_dess_addr) \ + do { \ + auto p = reinterpret_cast(tma_dess_addr); \ + DPRINTF( \ + "\n" \ + "%08X%08X %08X%08X %08X%08X %08X%08X\n" \ + "%08X%08X %08X%08X %08X%08X %08X%08X\n" \ + "%08X%08X %08X%08X %08X%08X %08X%08X\n" \ + "%08X%08X %08X%08X %08X%08X %08X%08X\n", \ + p[0], \ + p[1], \ + p[2], \ + p[3], \ + p[4], \ + p[5], \ + p[6], \ + p[7], \ + p[8], \ + p[9], \ + p[10], \ + p[11], \ + p[12], \ + p[13], \ + p[14], \ + p[15], \ + p[16], \ + p[17], \ + p[18], \ + p[19], \ + p[20], \ + p[21], \ + p[22], \ + p[23], \ + p[24], \ + p[25], \ + p[26], \ + p[27], \ + p[28], \ + p[29], \ + p[30], \ + p[31]); \ + } while (0) +#else +#define DPRINT_TMA_DESC(tma_dess_addr) +#endif diff --git a/csrc/gdn/sm90/utils/math.hpp b/csrc/gdn/sm90/utils/math.hpp new file mode 100644 index 0000000..d7a71ae --- /dev/null +++ b/csrc/gdn/sm90/utils/math.hpp @@ -0,0 +1,53 @@ +// Copyright 2025-2026 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +/* + * Copyright (c) 2025 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +namespace gdn::sm90 { + +namespace detail { + +template +CUTE_HOST_DEVICE constexpr T +ceil_log2(T n) { + return n <= 1 ? 0 : 1 + ceil_log2((n + 1) / 2); +} + +} // namespace detail + +template +CUTE_HOST_DEVICE constexpr T +next_power_of_two(T n) { + return static_cast(1) << detail::ceil_log2(n); +} + +} // namespace gdn::sm90 diff --git a/csrc/gdn/sm90/utils/math_order_barrier.hpp b/csrc/gdn/sm90/utils/math_order_barrier.hpp new file mode 100644 index 0000000..64e8d39 --- /dev/null +++ b/csrc/gdn/sm90/utils/math_order_barrier.hpp @@ -0,0 +1,116 @@ +// Copyright 2025-2026 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +/* + * Copyright (c) 2025 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +namespace gdn::sm90 { + +// cutlass' OrderedSequenceBarrier uses mbarrier +template < + bool UseReservedNB_, // treat nb_id as cutlass::ReservedNamedBarriers + uint32_t... WGIdToNBIdMapping // say 6,4 is passed, means wg0 use nb6 and wg1 use nb4 + > +struct OrderedNamedBarriers { + static constexpr bool UseReservedNB = UseReservedNB_; + static constexpr int NumWG = sizeof...(WGIdToNBIdMapping); + using NBId_t = std::conditional_t; + + CUTE_DEVICE + OrderedNamedBarriers() : mapping_{NBId_t(WGIdToNBIdMapping)...} { + } + + CUTE_DEVICE + void + init(int wg_idx) { // wg_idx in among all WG participants + for (int i = wg_idx; i > 0; --i) { + cutlass::arch::NamedBarrier::arrive(cutlass::NumThreadsPerWarpGroup * NumWG, mapping_[i - 1]); + } + // with 3 WGs, init to namedbarrier_id:(arrived_wg,expected_wg) + // 0:(2,3) + // 1:(1,3) + // 2:(0,3) + } + + CUTE_DEVICE + ~OrderedNamedBarriers() { + // FIXME: this will be a problem for persistent scheduler + } + + CUTE_DEVICE + void + ordered_or_wait(int wg_idx) { // wg_idx in participants + // during first call, before + // 0:(2,3) + // 1:(1,3) + // 2:(0,3) + cutlass::arch::NamedBarrier::sync(cutlass::NumThreadsPerWarpGroup * NumWG, mapping_[wg_idx]); + // after + // 0:(3,3) immediately unblock wg0, and named barrier automatically reset to (0,3) + // 1:(2,3) + // 2:(1,3) + } + + CUTE_DEVICE + void + notify_next_blocked(int wg_idx) { // wg_idx in participants + // allways call this after ordered_or_wait + // during first call, before + // 0:(0,3) + // 1:(2,3) + // 2:(1,3) + CUTE_UNROLL + for (int i = 1; i < NumWG; ++i) { + cutlass::arch::NamedBarrier::arrive( + cutlass::NumThreadsPerWarpGroup * NumWG, mapping_[(wg_idx + i) % NumWG]); + } + // after wg0 called this function + // 0:(0,3), wg0 has not reached on second ordered_or_wait() or (1,3) wg0 wait on second ordered_or_wait() call + // 1:(0,3), unblocked wg1's first ordered_or_wait() and reset nb1 + // 2:(2,3), still wait on first ordered_or_wait() call + // + // after wg1 called this function + // 0:(1,3), wg0 has not reached on second ordered_or_wait() or (2,3) wg0 wait on second ordered_or_wait() call + // 1:(0,3), wg1 has not reached on second ordered_or_wait() or (1,3) wg1 wait on second ordered_or_wait() call + // 2:(0,3), unblocked wg2's first ordered_or_wait() and reset nb2 + // + // after wg2 called this function + // 0:(2,3), wg0 has not reached on second ordered_or_wait() or (0,3) wg0 wait on second ordered_or_wait() call, + // unblocked 1:(1,3), wg1 has not reached on second ordered_or_wait() or (2,3) wg1 wait on second + // ordered_or_wait() call, still block 2:(0,3), unblock wg0 ordered_or_wait() and reset + // + } + + private: + cute::array mapping_; +}; +} // namespace gdn::sm90 diff --git a/csrc/gdn/sm90/utils/type_traits.hpp b/csrc/gdn/sm90/utils/type_traits.hpp new file mode 100644 index 0000000..489083d --- /dev/null +++ b/csrc/gdn/sm90/utils/type_traits.hpp @@ -0,0 +1,66 @@ +// Copyright 2025-2026 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +/* + * Copyright (c) 2025 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +#include + +namespace gdn::sm90 { + +// clang-format off +template struct map_to_cutlass; +template<> struct map_to_cutlass { using type = cutlass::half_t; }; +template<> struct map_to_cutlass { using type = cutlass::bfloat16_t; }; +template<> struct map_to_cutlass { using type = cutlass::half_t; }; +template<> struct map_to_cutlass { using type = cutlass::bfloat16_t; }; + +template using map_to_cutlass_t = typename map_to_cutlass::type; +// clang-format on + +template +struct first_non_void { + static_assert(sizeof...(Ts) > 0, "all voids is not allowed"); + using type = void; +}; + +template +struct first_non_void { + using type = T; +}; + +template +struct first_non_void : first_non_void {}; + +template +using first_non_void_t = typename first_non_void::type; + +} // namespace gdn::sm90 diff --git a/csrc/gdn/sm90/utils/unused.hpp b/csrc/gdn/sm90/utils/unused.hpp new file mode 100644 index 0000000..9764c82 --- /dev/null +++ b/csrc/gdn/sm90/utils/unused.hpp @@ -0,0 +1,54 @@ +// Copyright 2025-2026 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +/* + * Copyright (c) 2025 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +namespace gdn::sm90 { + +struct Unused { + using Params = Unused; + using SharedStorage = char; + static constexpr uint32_t Stages = 0; + + template + CUTE_HOST_DEVICE + Unused(Ts... vs) { + } + + template + CUTE_HOST_DEVICE Unused + operator=(T&& v) { + return Unused{}; + } +}; + +} // namespace gdn::sm90 diff --git a/cula/gdn/hopper_fused_fwd.py b/cula/gdn/hopper_fused_fwd.py index 087273e..2246380 100644 --- a/cula/gdn/hopper_fused_fwd.py +++ b/cula/gdn/hopper_fused_fwd.py @@ -74,10 +74,7 @@ def forward( chunk_indices=chunk_indices, lower_bound=lower_bound ) - torch.cuda.synchronize() # DEBUG: surface errors from gdn_gate_chunk_cumsum_lowerbound - else: - print("launching FLA chunk local cumsum") g = chunk_local_cumsum( g=g, chunk_size=chunk_size, @@ -85,7 +82,6 @@ def forward( cu_seqlens=cu_seqlens, chunk_indices=chunk_indices ) - torch.cuda.synchronize() # DEBUG: surface errors from chunk_local_cumsum q_rstd, k_rstd = None, None if use_qk_l2norm_in_kernel: @@ -106,7 +102,6 @@ def forward( # call the C++ kernel # Signature:gdn_fwd_prefill(output_, output_state_, q, k, v, input_state_, alpha_, beta_, cu_seqlens, workspace, scale, safe_gate) - print("launching prefill kernel cpp") o, final_state = cula_cuda.gdn_fwd_prefill( None, # output_ (auto-allocate) None, # output_state_ (auto-allocate) @@ -121,11 +116,6 @@ def forward( scale, safe_gate, ) - torch.cuda.synchronize() # DEBUG: surface errors from gdn_fwd_prefill kernel - print(f"DEBUG: o has nan={o.isnan().any().item()}, shape={o.shape}, dtype={o.dtype}") - print(f"DEBUG: final_state has nan={final_state.isnan().any().item()}, shape={final_state.shape}") - print(f"DEBUG: o[:3]={o.flatten()[:8].tolist()}") - o = rearrange(o, "(b t) h d -> b t h d", b = batch_size) return o.to(q.dtype), final_state diff --git a/tests/test_gdn_fused_fwd.py b/tests/test_gdn_fused_fwd.py index 0674304..363467f 100644 --- a/tests/test_gdn_fused_fwd.py +++ b/tests/test_gdn_fused_fwd.py @@ -100,12 +100,11 @@ def test_safe_gate_chunk( if use_gate_in_kernel: A_log, dt_bias = map(lambda x: x.to(device).requires_grad_(True), (A_log, dt_bias)) q, k, v, g, beta, h0 = map(lambda x: x.to(device).requires_grad_(True), (q, k, v, g, beta, h0)) - ref, ref_ht = naive_recurrent_gated_delta_rule( q=F.normalize(q.clone(), p=2, dim=-1), k=F.normalize(k.clone(), p=2, dim=-1), v=v.clone(), - g=(naive_gdn_gate_fn(g, A_log, dt_bias) if use_gate_in_kernel else g.clone()), + g=(naive_gdn_gate_fn(g.clone(), A_log, dt_bias) if use_gate_in_kernel else g.clone()), beta=beta.clone(), initial_state=h0.clone(), output_final_state=True, @@ -138,11 +137,10 @@ def test_safe_gate_chunk( safe_gate=safe_gate, lower_bound=lower_bound, ) - - assert_close("o", ref, tri, 0.005) assert_close("ht", ref_ht, tri_ht, 0.005) - assert_close("o", ref_fla, tri, 0.005) assert_close("ht", ref_ht_fla, tri_ht, 0.005) + assert_close("o", ref, tri, 0.005) + assert_close("o", ref_fla, tri, 0.005) @pytest.mark.parametrize( From 2ea7bf59858693640347b18cb23d3b4c77a8294f Mon Sep 17 00:00:00 2001 From: Kingsley Kim Date: Sun, 12 Apr 2026 18:05:56 -0400 Subject: [PATCH 3/3] precommit fixes --- csrc/gdn/sm90/changes_from_kda.md | 32 ++++++++ csrc/gdn/sm90/collective/load_predicated.hpp | 46 +++++++---- csrc/gdn/sm90/collective/mainloop_gdn_fwd.hpp | 79 ++++++++----------- csrc/gdn/sm90/gdn_config.hpp | 19 +++++ csrc/gdn/sm90/kernel/kernel_gdn_fwd.hpp | 5 +- cula/gdn/__init__.py | 5 +- cula/gdn/gate.py | 47 +++++------ cula/gdn/hopper_fused_fwd.py | 43 +++++----- tests/test_gdn_fused_fwd.py | 11 ++- 9 files changed, 169 insertions(+), 118 deletions(-) create mode 100644 csrc/gdn/sm90/changes_from_kda.md create mode 100644 csrc/gdn/sm90/gdn_config.hpp diff --git a/csrc/gdn/sm90/changes_from_kda.md b/csrc/gdn/sm90/changes_from_kda.md new file mode 100644 index 0000000..eebfc89 --- /dev/null +++ b/csrc/gdn/sm90/changes_from_kda.md @@ -0,0 +1,32 @@ +## Important changes from KDA -> GDN + +### Kept: + +- The KK_inv lambda defined in mainloop is kept because it serves the same purpose of applying the beta. It is a modular function, so we don't need to worry about the scale applications - this needs to be changed in other lambda. + +### Change: + +- In load_kv, the cached state should be loaded NON-transposed. This is different from the kda implementation, c.f. mainloop_kda_fwd.hpp:909. + - Thinking more on this, it's better to just align the state shape with how the KDA implementation has already carried it out - this means that the expected state shape should be made explicit higher up in the API chain, possibly in FLA. +- In Kimi linear, the K matrix is multipled with the cumulative alpha gating matrix everywhere - it is also per channel. In GDN the first I + (KK^T) matrix also involves multiplying by alpha, but it is on a per-sequence, per-head basis. Thus the K_scaled and Q_scaled need to be rewritten. +- Alpha (gate) has previous shape of (B, T, H, K) for Kimi Linear, with per-channel gating. GDN instead has per-head gating, so the shape becomes (B, T, H), and we instead load a vector of size blkSeqQ == blkSeqK into shared memory. This means the atoms and layouts related to Alpha must all be changed, as well as the application of the gate. + - Because Alpha is now not the same shape as the Blk_Q/K/V tiles, it is now the same shape as the beta. This means that we don't need to create auxilliary layouts for the TMA loads, and instead we port over Alpha's SMEM layout into a CollectiveLoadVector + - The load_qkv in mainloop_gdn_fwd.hpp doesn't load alpha anymore - this is transferred to the load_beta + - extract_alpha_last needs to be changed to a simple index into the last index of shared alpha tensor, while checking for end of sequence boundaries. It just copies once. + - Alpha params are now changed to pointers with gmemlayout instead of TMALoad type + - Another alpha change - during GDN's forward pass, the gating matrix applied to KK^T is computed as the difference between [i,j] coords in log space, then exp2f. However, KDA instead applies an elementwise mask that is pre computed. The final QK^T , also coputed in compute_aux_safe, doesn't multiply on the alpha gate matrix, so it is a normal tensor core multiplication instead. + - SharedStorage needs to be changed in kernel +- Compute_aux_safe changes + - In s2r_compute_subchunk_operandA, I tried to keep the changes as minimal as possible, so I kept the behavior of copying a tile of A, but I broadcast the alpha values, which are now a row equivalent, to all the 32 columns in the subchunk. This allows the previous broadcast_row_0 + exp2f(g - g_first) values to still work. POSSIBLE OPTIMIZATION: It might be faster to just use the same register + identity tensor + row indexing across threads that +- Compute_loop_body changes + - The alpha loading in + scaling needs to be changed, since the KDA implementation loads in a tile of 32 across the head dimension. I did the same change that i did in compute_aux_safe to create a dummy tensor shape that broadcasts. + - I also stopped using a CopyAlpaAhtom and instead do a manual unrolled loop when loading in the shared alpha values for QK scaling +- KV state shape: + - It looks like KDA implementation uses the same V^T * K_scaled, with the KV_state shape being d_V x d_K in the output. This is also equivalent to the FLA transpose flag being set to TRUE. +- Change in kernel_gdn_fwd.hpp: + - Because the load type for alpha is now a predicated vector, we need to also change the alphapipelineparam initializaiton, moved it down next to beta, since they are loaded together + +### Possible Optimizations: + +- Because GDN doesn't need to materialize an entire register tile to hold results, we can load in the rows directly from shared memory and not worry about copying through to registers before multiplying. This could allow more aggressive use of the register file, in exchange for added latency from accessing SMEM. To keep consistency with the previous KDA implementation, I just used 0-strides to broadcast along the row dimension. + diff --git a/csrc/gdn/sm90/collective/load_predicated.hpp b/csrc/gdn/sm90/collective/load_predicated.hpp index 9ebeb7d..53ac270 100644 --- a/csrc/gdn/sm90/collective/load_predicated.hpp +++ b/csrc/gdn/sm90/collective/load_predicated.hpp @@ -30,6 +30,7 @@ #pragma once +#include #include #include #include @@ -79,7 +80,8 @@ template < class GmemLayout, class ElementDst, class SmemLayout, - class VectorProcessor_ = Unused> + class VectorProcessor_ = Unused, + bool UseCPCopy = false> struct CollectiveLoadVector { using SharedStorage = cute::array_aligned>; using PipelineState = typename cutlass::PipelineState; @@ -144,17 +146,8 @@ struct CollectiveLoadVector { template CUTE_DEVICE void step(SrcDst const& src_dst, int src_iter, PipelineState& dst_pipe, int num_iters, VectorProcessor processor = {}) { - auto src = get<0>(src_dst); - auto dst = get<1>(src_dst); - - auto regs = make_fragment_like(take<0, 2>(shape(dst))); - if constexpr (!IsTail) { - copy(src(_, _, src_iter), regs); - } else { - auto mask = get<2>(src_dst); - fill(regs, src_oob_value_); - copy_if(mask, src(_, _, src_iter), regs); - } + auto& src = get<0>(src_dst); + auto& dst = get<1>(src_dst); int dst_pipe_idx = dst_pipe.index(); @@ -162,10 +155,33 @@ struct CollectiveLoadVector { pipeline_.producer_acquire(dst_pipe); cutlass::arch::fence_view_async_shared(); - if constexpr (rank_v == 3) { - copy(regs, dst(_, _, _0{}, dst_pipe_idx)); + auto dst_slice = [&]() -> decltype(auto) { + if constexpr (rank_v == 3) { + return dst(_, _, _0{}, dst_pipe_idx); + } else { + return dst(_, _, dst_pipe_idx); + } + }(); + + if constexpr (!IsTail && UseCPCopy) { + // Hot path: copy directly gmem→smem via cp.async, no register intermediate. + // This avoids the per-iteration register pressure from the fragment allocation. + using AsyncCopy = Copy_Atom, ElementSrc>; + copy(AsyncCopy{}, src(_, _, src_iter), dst_slice); + cp_async_fence(); + cp_async_wait<0>(); } else { - copy(regs, dst(_, _, dst_pipe_idx)); + // Either the tail path (called once, needs predicated fill) or UseCPCopy=false + // (register intermediate path). + auto regs = make_fragment_like(take<0, 2>(shape(dst))); + if constexpr (IsTail) { + auto mask = get<2>(src_dst); + fill(regs, src_oob_value_); + copy_if(mask, src(_, _, src_iter), regs); + } else { + copy(src(_, _, src_iter), regs); + } + copy(regs, dst_slice); } Tensor s = make_tensor(make_smem_ptr(storage_.data()), SmemLayout{}); diff --git a/csrc/gdn/sm90/collective/mainloop_gdn_fwd.hpp b/csrc/gdn/sm90/collective/mainloop_gdn_fwd.hpp index 931c1a9..02d4259 100644 --- a/csrc/gdn/sm90/collective/mainloop_gdn_fwd.hpp +++ b/csrc/gdn/sm90/collective/mainloop_gdn_fwd.hpp @@ -210,7 +210,7 @@ struct FlatMainloopTmaWarpSpecializedGdnFwd { // make_shape( // shape<1>(TileShapeQK{}), // Int{}))); // (blk_kv), (64) - + // using GmemTiledCopyAlpha = cute::SM90_TMA_LOAD; // using TMA_Alpha = decltype(make_tma_copy( // GmemTiledCopyAlpha{}, @@ -298,7 +298,7 @@ struct FlatMainloopTmaWarpSpecializedGdnFwd { ElementO, ElementAccumulatorO, /*Smem*/ ElementO, - decltype(select<1, 0, 2>(LayoutO{})), // creates mn-major atom in store_tma.hpp + decltype(select<1, 0, 2>(LayoutO{})), // creates mn-major atom in store_tma.hpp StagesO::value>; // layout for compute @@ -310,7 +310,6 @@ struct FlatMainloopTmaWarpSpecializedGdnFwd { using KVSmemLayoutV = SmemLayoutV_DS; using QKScaledSmemLayoutKt = SmemLayoutQ_K_Scaled_DS; - // layout for compute output using SmemLayoutQK = decltype(tile_to_shape( GMMA::Layout_K_INTER_Atom{}, @@ -372,8 +371,8 @@ struct FlatMainloopTmaWarpSpecializedGdnFwd { using GmemStrideBeta = Stride; using GmemLayoutBeta = Layout, GmemStrideBeta>; // (seq, head) - using GmemShapeAlpha = Shape; // (seqlen_k, h) - using GmemStrideAlpha = Stride; // TODO: depends on gate cumsum output, so we won't hardset to 1 + using GmemShapeAlpha = Shape; // (seqlen_k, h) + using GmemStrideAlpha = Stride; // TODO: depends on gate cumsum output, so we won't hardset to 1 using GmemLayoutAlpha = Layout; // only store the last Alpha value, either end of chunk or end of sequence @@ -455,16 +454,16 @@ struct FlatMainloopTmaWarpSpecializedGdnFwd { using LoadK = CollectiveLoadTma; using LoadV = CollectiveLoadTma; // using LoadAlpha = - // CollectiveLoadTma; + // CollectiveLoadTma; using LoadAlpha = CollectiveLoadVector< - LoadKindVector::kAlpha, - MainloopAlphaPipeline, + LoadKindVector::kAlpha, + MainloopAlphaPipeline, + ElementAlpha, + GmemLayoutAlpha, ElementAlpha, - GmemLayoutAlpha, - ElementAlpha, SmemLayoutAlpha, - AlphaProcessor - >; + AlphaProcessor, + /*UseCPCopy=*/std::is_same_v>; using LoadBeta = CollectiveLoadVector< LoadKindVector::kBeta, @@ -473,7 +472,8 @@ struct FlatMainloopTmaWarpSpecializedGdnFwd { GmemLayoutBeta, ElementBeta, SmemLayoutBeta, - BetaProcessor>; + BetaProcessor, + /*UseCPCopy=*/std::is_same_v>; struct Arguments { // clang-format off Element const* ptr_Q; LayoutQ dQ; @@ -499,7 +499,7 @@ struct FlatMainloopTmaWarpSpecializedGdnFwd { float* ptr_output_state; float const* ptr_input_state; - ElementAlpha const * alpha_ptr; + ElementAlpha const* alpha_ptr; GmemLayoutAlpha alpha_layout; ElementBeta const* beta_ptr; GmemLayoutBeta beta_layout; @@ -643,7 +643,7 @@ struct FlatMainloopTmaWarpSpecializedGdnFwd { LoadBeta{params.beta_ptr, params.beta_layout, /*oob_value=*/0.0f, beta_pipeline, storage.smem_beta}; // oob fill value for alpha is -INFINITY because of exp2f(alpha) auto alpha_collective_load = - LoadAlpha {params.alpha_ptr, params.alpha_layout, /*oob_value=*/0.0f, alpha_pipeline, storage.smem_alpha}; + LoadAlpha{params.alpha_ptr, params.alpha_layout, /*oob_value=*/0.0f, alpha_pipeline, storage.smem_alpha}; auto beta_src_dst = beta_collective_load.partition_SD(problem_size, tile_shape, work_desc); auto alpha_src_dst = alpha_collective_load.partition_SD(problem_size, tile_shape, work_desc); @@ -820,7 +820,7 @@ struct FlatMainloopTmaWarpSpecializedGdnFwd { auto const& cMkk = cMqk; auto tKKcMkk = kk_thr_mma.partition_C(cMkk); - // S@K (-S K^T + V^T) - K and T + // S@K (-S K^T + V^T) - K and T auto sk_tiled_mma = TiledMmaSK{}; auto sk_thr_mma = sk_tiled_mma.get_thread_slice(thread_idx); @@ -945,9 +945,7 @@ struct FlatMainloopTmaWarpSpecializedGdnFwd { auto s_decay = [&](auto& tKVrKV, auto const& alpha_last_smem_pipe_read) INLINE_LAMBDA { auto alpha_last_curr = AlphaLast(0, alpha_last_smem_pipe_read.index()); - for_each(make_int_sequence{}, [&](auto i) { - tKVrKV(i) *= exp2f(alpha_last_curr); - }); + for_each(make_int_sequence{}, [&](auto i) { tKVrKV(i) *= exp2f(alpha_last_curr); }); }; auto o1_epi = [&](auto& tOrO1) INLINE_LAMBDA { @@ -1052,14 +1050,12 @@ struct FlatMainloopTmaWarpSpecializedGdnFwd { qk_thr_mma_rs_quar.partition_fragment_A(sKqk_slice(_, _, _0{}, make_coord(_0{}, _0{}))); auto tArA = make_fragment_like(tQKrQ_wg); - auto sA_cur = sAlpha_slice(_, _0{}); - #pragma unroll +#pragma unroll for (int v = 0; v < size(tArA); v++) { - tArA(v) = sA_cur(get<0>(tQcMq_quar(v))); + tArA(v) = sAlpha_slice((get<0>(tQcMq_quar(v))), _0{}); } cute::transform(tArA, [](auto g) { return exp2f(g); }); for (int s = 0; s < 2; ++s) { - // S2R Q auto sQqk_cur = sQqk_slice(_, _, _0{}, make_coord(s, wg_idx)); auto tQKsQ_cur = thr_load_qk_quar.partition_S(sQqk_cur); @@ -1268,7 +1264,6 @@ struct FlatMainloopTmaWarpSpecializedGdnFwd { if constexpr (!is_first_block) { s_decay(tKVrKV, alpha_last_smem_pipe_read); } - // synchronize 2 WGs before rewriting sQ_K_scaled cutlass::arch::NamedBarrier::arrive_and_wait(NumStateMmaThreads, GdnNamedBarriers::StateMath); @@ -1281,12 +1276,10 @@ struct FlatMainloopTmaWarpSpecializedGdnFwd { // Allocate K/Alpha register fragments once (reused across slices) auto tQKrK_wg = qk_thr_mma_rs_quar.partition_fragment_A(sKqk_slice(_, _, _0{}, make_coord(_0{}, _0{}))); auto tArA_wg = make_fragment_like(tQKrK_wg); - - auto sA_cur = sAlpha_slice(_, _0{}); - // Dummy tensor to enable broadcast of alpha values across a row - #pragma unroll +// Dummy tensor to enable broadcast of alpha values across a row +#pragma unroll for (int v = 0; v < size(tArA_wg); v++) { - tArA_wg(v) = sA_cur(get<0>(tQcMq_quar(v))); + tArA_wg(v) = sAlpha_slice(get<0>(tQcMq_quar(v)), _0{}); } auto alpha_last = sAlast_curr(0); for (int s = 0; s < 2; ++s) { @@ -1302,7 +1295,7 @@ struct FlatMainloopTmaWarpSpecializedGdnFwd { auto [seq, _] = coord; auto alpha = tArA_wg(i); auto k = tQKrK_wg(i); - auto k_scaled = Element(exp2f(alpha_last - alpha) * float(k)); + auto k_scaled = Element(exp2f(alpha_last - alpha) * float(k)); tQKrK_wg(i) = k_scaled; if constexpr (is_final_block) { if (seq >= B) { @@ -1466,7 +1459,7 @@ struct FlatMainloopTmaWarpSpecializedGdnFwd { // Alpha S2R: load in BF16 MMA layout so gating happens before the layout shuffle, // reducing register pressure (alpha can be freed before the shuffle). - // BF16-layout alpha copies for operand A and B (for element-wise gating) + // BF16-layout alpha copies for operand A and B (for element-wise gating) auto alpha_Q_bf16_tiled_copy = make_tiled_copy_A(CopyAlphaAtom{}, tiledmma_bf16_subchunk); auto alpha_Kt_bf16_tiled_copy = make_tiled_copy_B(CopyAlphaAtom{}, tiledmma_bf16_subchunk); // Q/K S2R: LDSM copies using BF16 MMA layout for efficient ldmatrix loads @@ -1498,7 +1491,7 @@ struct FlatMainloopTmaWarpSpecializedGdnFwd { constexpr auto tiler_subchunk_beta = Shape<_16>{}; auto sQqk_curr = sQqk(_, _, q_smem_pipe_read.index()); auto sKqk_curr = sKqk(_, _, k_smem_pipe_read.index()); - auto sAlpha_curr = Alpha(_, alpha_smem_pipe_read.index()); // (_64) + auto sAlpha_curr = Alpha(_, alpha_smem_pipe_read.index()); // (_64) Tensor sBeta_curr = Beta(_, beta_smem_pipe_read.index()); // (_16,(_32,_1),_4,(_2,_2)):(_64,(_1,_0),_1024,(_32,_4096)) @@ -1539,14 +1532,12 @@ struct FlatMainloopTmaWarpSpecializedGdnFwd { // layout) auto s2r_compute_subchunk_operandA = [&](auto r_, int j, int j0, int j1) INLINE_LAMBDA { // S2R g_r_j in BF16 MMA operand A layout (single load) - Tensor sAlpha_r = sAlpha_slice(_, r_); Tensor tArA_r = make_fragment_like(tv_layout_bf16_mma_A); - // unrolled loop for explicit copy instead of using Cutlass copy - #pragma unroll - for (int v = 0 ; v < size(tArA_r); v++) { - tArA_r(v) = sAlpha_r(get<0>(tQKaMqk_subchunk(v))); +// unrolled loop for explicit copy instead of using Cutlass copy +#pragma unroll + for (int v = 0; v < size(tArA_r); v++) { + tArA_r(v) = sAlpha_slice(get<0>(tQKaMqk_subchunk(v)), r_); } - // Derive g_first (alpha[row=0, :]) from tArA_r_j via warp shuffle, // directly into operand B layout (8 values instead of 16). @@ -1562,8 +1553,8 @@ struct FlatMainloopTmaWarpSpecializedGdnFwd { // so g_first for index 4j+{0,2} = frag_B(2j), for 4j+{1,3} = frag_B(2j+1). CUTE_UNROLL for (int k = 0; k < 4; k++) { - auto gf_lo = tArAfirst_r_j_kt(2 * k); // g_first at K = 2*t0 - auto gf_hi = tArAfirst_r_j_kt(2 * k + 1); // g_first at K = 2*t0+1 + auto gf_lo = tArAfirst_r_j_kt(2 * k); // g_first at K = 2*t0 + auto gf_hi = tArAfirst_r_j_kt(2 * k + 1); // g_first at K = 2*t0+1 tArA_r(4 * k + 0) = exp2f(tArA_r(4 * k + 0) - gf_lo); // v0=0, v1=0 tArA_r(4 * k + 1) = exp2f(tArA_r(4 * k + 1) - gf_hi); // v0=1, v1=0 tArA_r(4 * k + 2) = exp2f(tArA_r(4 * k + 2) - gf_lo); // v0=0, v1=1 @@ -1581,8 +1572,7 @@ struct FlatMainloopTmaWarpSpecializedGdnFwd { copy(Q_tiled_copy, tQKsQ_r_j, tQKrQ_r_j_bf16_cv); // gate: Q * exp2(g - g_first) in BF16 MMA layout, producing float Tensor tQKrQ_r_j_float = make_fragment_like(tv_layout_bf16_mma_A); - cute::transform( - tQKrQ_r_j_bf16, tArA_r, tQKrQ_r_j_float, [&](auto q, auto g) { return float(q) * g; }); + cute::transform(tQKrQ_r_j_bf16, tArA_r, tQKrQ_r_j_float, [&](auto q, auto g) { return float(q) * g; }); // convert BF16 MMA layout → TF32 MMA layout in-place via warp shuffles convert_bf16_to_tf32_operandA_layout(tQKrQ_r_j_float, local_thread_idx); // NOTE: triton tl.dot also lets MMA hardware for truncation @@ -1595,8 +1585,7 @@ struct FlatMainloopTmaWarpSpecializedGdnFwd { Tensor tQKrK_r_j_bf16_cv = Q_thr_copy.retile_D(tQKrK_r_j_bf16); copy(Q_tiled_copy, tQKsK_r_j, tQKrK_r_j_bf16_cv); Tensor tQKrK_r_j_float = make_fragment_like(tv_layout_bf16_mma_A); - cute::transform( - tQKrK_r_j_bf16, tArA_r, tQKrK_r_j_float, [&](auto k, auto g) { return float(k) * g; }); + cute::transform(tQKrK_r_j_bf16, tArA_r, tQKrK_r_j_float, [&](auto k, auto g) { return float(k) * g; }); // convert BF16 MMA layout → TF32 MMA layout in-place via warp shuffles convert_bf16_to_tf32_operandA_layout(tQKrK_r_j_float, local_thread_idx); auto tQKrK_r_j = recast(tQKrK_r_j_float); @@ -1614,7 +1603,7 @@ struct FlatMainloopTmaWarpSpecializedGdnFwd { // S2R g_c_j in BF16 MMA operand B layout Tensor sAlpha_c = sAlpha_slice(_, c_); Tensor tArA_c = make_fragment_like(tv_layout_bf16_mma_B); - #pragma unroll +#pragma unroll for (int v = 0; v < size(tArA_c); v++) { tArA_c(v) = sAlpha_c(get<0>(tQKbMqk_subchunk(v))); } diff --git a/csrc/gdn/sm90/gdn_config.hpp b/csrc/gdn/sm90/gdn_config.hpp new file mode 100644 index 0000000..aa3a6db --- /dev/null +++ b/csrc/gdn/sm90/gdn_config.hpp @@ -0,0 +1,19 @@ +// Copyright 2025-2026 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "gdn/sm90/kernel/tile_scheduler.hpp" + +struct GDN_fwd_intra_params {}; \ No newline at end of file diff --git a/csrc/gdn/sm90/kernel/kernel_gdn_fwd.hpp b/csrc/gdn/sm90/kernel/kernel_gdn_fwd.hpp index 5774bae..291fdf6 100644 --- a/csrc/gdn/sm90/kernel/kernel_gdn_fwd.hpp +++ b/csrc/gdn/sm90/kernel/kernel_gdn_fwd.hpp @@ -280,8 +280,6 @@ struct FlatKernelTmaWarpSpecializedGdnFwd { v_pipeline_params.is_leader = lane_predicate && (ldst_warp_role == LdStWarpRole::LoadQKV); v_pipeline_params.num_consumers = NumStateMathThreads; - - OPipelineParams o_pipeline_params; o_pipeline_params.producer_arv_count = NumStateMathThreads; o_pipeline_params.consumer_arv_count = cutlass::NumThreadsPerWarp; @@ -297,7 +295,8 @@ struct FlatKernelTmaWarpSpecializedGdnFwd { AlphaPipelineParams alpha_pipeline_params; if constexpr (NeedsAlpha) { alpha_pipeline_params.producer_arv_count = cutlass::NumThreadsPerWarp; - alpha_pipeline_params.consumer_arv_count = NumAuxMathThreads + NumStateMathThreads + cutlass::NumThreadsPerWarp; + alpha_pipeline_params.consumer_arv_count = + NumAuxMathThreads + NumStateMathThreads + cutlass::NumThreadsPerWarp; } AlphaLastPipelineParams alpha_last_pipeline_params; diff --git a/cula/gdn/__init__.py b/cula/gdn/__init__.py index 1b69a56..e2c5eb6 100644 --- a/cula/gdn/__init__.py +++ b/cula/gdn/__init__.py @@ -14,7 +14,4 @@ from cula.gdn.hopper_fused_fwd import cula_gdn_prefill as gdn_prefill_hopper -__all__ = [ - "gdn_prefill_hopper" -] - +__all__ = ["gdn_prefill_hopper"] diff --git a/cula/gdn/gate.py b/cula/gdn/gate.py index 7ee1b41..9c4426e 100644 --- a/cula/gdn/gate.py +++ b/cula/gdn/gate.py @@ -2,15 +2,15 @@ import torch.nn.functional as F import triton import triton.language as tl - +from fla.ops.utils.index import prepare_chunk_indices from fla.ops.utils.op import exp from fla.ops.utils.softplus import softplus -from fla.ops.utils.index import prepare_chunk_indices from fla.utils import autotune_cache_kwargs, input_guard BT_LIST_AUTOTUNE = [32, 64, 128] NUM_WARPS_AUTOTUNE = [4, 8, 16, 32] + def naive_gdn_gate( g: torch.Tensor, A_log: torch.Tensor, @@ -33,43 +33,42 @@ def naive_gdn_gate( Returns: Output tensor of shape `[..., H]` . """ - H = g.shape[-1] g = g.float() if dt_bias is not None: g = g + dt_bias g = (-A_log.float().exp() * F.softplus(g.float())).to(output_dtype) return g + # naive gdn lowerbound method based off of fla.ops.kda.gate def naive_gdn_lowerbound_gate( g: torch.Tensor, A_log: torch.Tensor, dt_bias: torch.Tensor | None = None, lower_bound: float = -5.0, - output_dtype: torch.dtype = torch.float32 + output_dtype: torch.dtype = torch.float32, ) -> torch.Tensor: - num_heads = g.shape[-1] g = g.float() if dt_bias is not None: g = g + dt_bias g = lower_bound * F.sigmoid(A_log.exp() * g) return g.to(output_dtype) -@triton.heuristics({ - "HAS_BIAS": lambda args: args["dt_bias"] is not None, - 'HAS_SCALE': lambda args: args['scale'] is not None, - 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None, - 'USE_LOWER_BOUND': lambda args: args['lower_bound'] is not None, -}) + +@triton.heuristics( + { + "HAS_BIAS": lambda args: args["dt_bias"] is not None, + "HAS_SCALE": lambda args: args["scale"] is not None, + "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, + "USE_LOWER_BOUND": lambda args: args["lower_bound"] is not None, + } +) @triton.autotune( - configs=[ - triton.Config({}, num_warps=num_warps) - for num_warps in [2, 4, 8] - ], - key=['H', 'BT', 'IS_VARLEN', 'REVERSE'], + configs=[triton.Config({}, num_warps=num_warps) for num_warps in [2, 4, 8]], + key=["H", "BT", "IS_VARLEN", "REVERSE"], **autotune_cache_kwargs, ) -@triton.jit(do_not_specialize=['T']) +@triton.jit(do_not_specialize=["T"]) def gdn_gate_chunk_cumsum_scalar_kernel( s, A_log, @@ -97,9 +96,8 @@ def gdn_gate_chunk_cumsum_scalar_kernel( else: bos, eos = i_b * T, i_b * T + T - - p_s = tl.make_block_ptr(s + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) - p_o = tl.make_block_ptr(o + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) + p_s = tl.make_block_ptr(s + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) + p_o = tl.make_block_ptr(o + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) # [BT] b_s = tl.load(p_s, boundary_check=(0)).to(tl.float32) @@ -149,8 +147,11 @@ def gdn_gate_chunk_cumsum_lowerbound( NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) assert chunk_size == 2 ** (chunk_size.bit_length() - 1), "chunk_size must be a power of 2" - g_org, g = g, torch.empty_like(g, dtype = output_dtype or g.dtype) - def grid(meta): return (NT, B * H) + g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype) + + def grid(meta): + return (NT, B * H) + gdn_gate_chunk_cumsum_scalar_kernel[grid]( s=g_org, A_log=A_log, @@ -165,4 +166,4 @@ def grid(meta): return (NT, B * H) BT=BT, REVERSE=False, ) - return g \ No newline at end of file + return g diff --git a/cula/gdn/hopper_fused_fwd.py b/cula/gdn/hopper_fused_fwd.py index 2246380..ee514b8 100644 --- a/cula/gdn/hopper_fused_fwd.py +++ b/cula/gdn/hopper_fused_fwd.py @@ -14,18 +14,17 @@ import torch from einops import rearrange - from fla.modules.l2norm import l2norm_fwd from fla.ops.utils import chunk_local_cumsum from fla.ops.utils.constant import RCP_LN2 from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard import cula.cudac as cula_cuda -from cula.utils import _get_cache_buf, assert_hopper, get_device_sm_count, prepare_uniform_cu_seqlens from cula.gdn.gate import gdn_gate_chunk_cumsum_lowerbound +from cula.utils import _get_cache_buf, assert_hopper, get_device_sm_count, prepare_uniform_cu_seqlens -class HopperChunkGDNFunction(torch.autograd.Function): +class HopperChunkGDNFunction(torch.autograd.Function): @staticmethod @input_guard @autocast_custom_fwd @@ -40,13 +39,13 @@ def forward( dt_bias: torch.Tensor, scale: float, initial_state: torch.Tensor, - output_final_state : bool = False, - use_qk_l2norm_in_kernel : bool = False, - use_gate_in_kernel : bool = False, - safe_gate : bool = False, - lower_bound : float | None = None, - cu_seqlens : torch.IntTensor | None = None, - chunk_indices : torch.IntTensor | None = None, + output_final_state: bool = False, + use_qk_l2norm_in_kernel: bool = False, + use_gate_in_kernel: bool = False, + safe_gate: bool = False, + lower_bound: float | None = None, + cu_seqlens: torch.IntTensor | None = None, + chunk_indices: torch.IntTensor | None = None, ): chunk_size = 64 assert q.shape[-2] == v.shape[-2] == k.shape[-2], "Number of heads must be the same across q, k, v" @@ -55,11 +54,11 @@ def forward( if cu_seqlens is None: cu_seqlens = prepare_uniform_cu_seqlens(batch_size, seq_len, q.device, torch.int32) - + # after setting up cu_seqlens, set batch size to 1 if batch_size != 1: - q, k, v, g, beta = map(lambda x : rearrange(x, "b t ... -> 1 (b t) ..."), (q, k, v, g, beta)) - + q, k, v, g, beta = map(lambda x: rearrange(x, "b t ... -> 1 (b t) ..."), (q, k, v, g, beta)) + # compute gate inside kernel if use_gate_in_kernel: if safe_gate: @@ -72,15 +71,11 @@ def forward( scale=RCP_LN2, cu_seqlens=cu_seqlens, chunk_indices=chunk_indices, - lower_bound=lower_bound + lower_bound=lower_bound, ) else: g = chunk_local_cumsum( - g=g, - chunk_size=chunk_size, - scale=RCP_LN2, - cu_seqlens=cu_seqlens, - chunk_indices=chunk_indices + g=g, chunk_size=chunk_size, scale=RCP_LN2, cu_seqlens=cu_seqlens, chunk_indices=chunk_indices ) q_rstd, k_rstd = None, None @@ -116,15 +111,13 @@ def forward( scale, safe_gate, ) - o = rearrange(o, "(b t) h d -> b t h d", b = batch_size) + o = rearrange(o, "(b t) h d -> b t h d", b=batch_size) return o.to(q.dtype), final_state - + @staticmethod @input_guard @autocast_custom_bwd - def backward( - ctx, do, dht - ): + def backward(ctx, do, dht): raise NotImplementedError("Backward pass not implemented yet") @@ -238,4 +231,4 @@ def cula_gdn_prefill( cu_seqlens, chunk_indices, ) - return o, final_state \ No newline at end of file + return o, final_state diff --git a/tests/test_gdn_fused_fwd.py b/tests/test_gdn_fused_fwd.py index 363467f..999dfce 100644 --- a/tests/test_gdn_fused_fwd.py +++ b/tests/test_gdn_fused_fwd.py @@ -16,15 +16,16 @@ # Adapted from flash-linear-attention: https://github.com/fla-org/flash-linear-attention/blob/main/tests/ops/test_kda.py +"Comparing SM90 GDN Fused Prefill Kernel with FLA implementation" import pytest import torch import torch.nn.functional as F -from fla.ops.gated_delta_rule import naive_recurrent_gated_delta_rule, chunk_gated_delta_rule +from fla.ops.gated_delta_rule import chunk_gated_delta_rule, naive_recurrent_gated_delta_rule from fla.utils import assert_close, device -from cula.utils import get_gdn_fused_fwd from cula.gdn.gate import naive_gdn_gate +from cula.utils import get_gdn_fused_fwd pytestmark = pytest.mark.sm90_only @@ -273,7 +274,11 @@ def test_safe_gate_chunk_varlen( ref = [] ref_ht = [] for i in range(N): - g_slice = naive_gdn_gate_fn(g[:, cu_seqlens[i] : cu_seqlens[i + 1]], A_log, dt_bias) if use_gate_in_kernel else g[:, cu_seqlens[i] : cu_seqlens[i + 1]] + g_slice = ( + naive_gdn_gate_fn(g[:, cu_seqlens[i] : cu_seqlens[i + 1]], A_log, dt_bias) + if use_gate_in_kernel + else g[:, cu_seqlens[i] : cu_seqlens[i + 1]] + ) ref_i, ref_ht_i = naive_recurrent_gated_delta_rule( q=F.normalize(q[:, cu_seqlens[i] : cu_seqlens[i + 1]], p=2, dim=-1), k=k[:, cu_seqlens[i] : cu_seqlens[i + 1]],