Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 136 additions & 0 deletions csrc/api/gdn_sm90.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
// Copyright 2025-2026 Ant Group Co., Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <ATen/cuda/CUDAContext.h>
#include <cute/numeric/numeric_types.hpp>
#include <cutlass/arch/arch.h>
#include <torch/extension.h>

#include "gdn/sm90/prefill_kernel.hpp"

using OptionalTensor = std::optional<torch::Tensor>;

std::tuple<torch::Tensor, torch::Tensor>
gdn_fwd_prefill(
OptionalTensor output_,
OptionalTensor output_state_,
torch::Tensor const& q,
torch::Tensor const& k,
torch::Tensor const& v,
OptionalTensor input_state_,
OptionalTensor alpha_,
OptionalTensor beta_,
torch::Tensor const& cu_seqlens,
torch::Tensor workspace_buffer,
float scale,
bool safe_gate) {
// Q, K, V: [packed_seq, H, D] (already packed by Python layer)
auto packed_seq = q.size(0);
auto num_heads = q.size(1);
auto head_size = q.size(2);
auto num_seqs = cu_seqlens.size(0) - 1;

// GDN constraint: all head counts must be the same
TORCH_CHECK(num_heads == k.size(1), "GDN requires num_q_heads == num_k_heads, got ", num_heads, " vs ", k.size(1));
TORCH_CHECK(num_heads == v.size(1), "GDN requires num_q_heads == num_v_heads, got ", num_heads, " vs ", v.size(1));
TORCH_CHECK(head_size == v.size(2), "GDN requires Q and V head dim to match, got ", head_size, " vs ", v.size(2));

// Allocate output if not provided
torch::Tensor output = output_.has_value() ? output_.value()
: torch::empty(
{packed_seq, num_heads, head_size},
torch::TensorOptions().dtype(q.dtype()).device(q.device()));

// Allocate output state if not provided
torch::Tensor output_state = output_state_.has_value()
? output_state_.value()
: torch::zeros(
{num_seqs, num_heads, head_size, head_size},
torch::TensorOptions().dtype(torch::kFloat32).device(q.device()));

// Validate dtypes
TORCH_CHECK(q.dtype() == torch::kBFloat16, "q must be bfloat16");
TORCH_CHECK(k.dtype() == torch::kBFloat16, "k must be bfloat16");
TORCH_CHECK(v.dtype() == torch::kBFloat16, "v must be bfloat16");
TORCH_CHECK(cu_seqlens.dtype() == torch::kInt32, "cu_seqlens must be int32");

// Validate contiguity
TORCH_CHECK(q.is_contiguous(), "q must be contiguous");
TORCH_CHECK(k.is_contiguous(), "k must be contiguous");
TORCH_CHECK(v.is_contiguous(), "v must be contiguous");
TORCH_CHECK(output.is_contiguous(), "output must be contiguous");
TORCH_CHECK(output_state.is_contiguous(), "output_state must be contiguous");
TORCH_CHECK(cu_seqlens.is_contiguous(), "cu_seqlens must be contiguous");
TORCH_CHECK(workspace_buffer.is_contiguous(), "workspace_buffer must be contiguous");

// Extract optional pointers
float const* alpha_ptr = nullptr;
float const* beta_ptr = nullptr;
float const* input_state_ptr = nullptr;

if (alpha_.has_value()) {
auto& alpha = alpha_.value();
TORCH_CHECK(alpha.dtype() == torch::kFloat32, "alpha must be float32");
TORCH_CHECK(alpha.is_contiguous(), "alpha must be contiguous");
TORCH_CHECK(
alpha.size(0) == packed_seq && alpha.size(1) == num_heads, "alpha shape must be [packed_seq, num_heads]");
alpha_ptr = alpha.data_ptr<float>();
}
if (beta_.has_value()) {
auto& beta = beta_.value();
TORCH_CHECK(beta.dtype() == torch::kFloat32, "beta must be float32");
TORCH_CHECK(beta.is_contiguous(), "beta must be contiguous");
TORCH_CHECK(
beta.size(0) == packed_seq && beta.size(1) == num_heads, "beta shape must be [packed_seq, num_heads]");
beta_ptr = beta.data_ptr<float>();
}
if (input_state_.has_value()) {
auto& input_state = input_state_.value();
TORCH_CHECK(input_state.dtype() == torch::kFloat32, "input_state must be float32");
TORCH_CHECK(input_state.is_contiguous(), "input_state must be contiguous");
input_state_ptr = input_state.data_ptr<float>();
}

// Auto-compute scale if 0
if (scale == 0.0f) {
scale = 1.0f / std::sqrt(static_cast<float>(head_size));
}

auto stream = at::cuda::getCurrentCUDAStream();
auto sm_count = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;

using bf16 = cute::bfloat16_t;
using Sm90 = cutlass::arch::Sm90;
gdn::sm90::launch_gdn_fwd_prefill_kernel<Sm90, bf16, bf16, float>(
stream,
reinterpret_cast<bf16*>(output.data_ptr()),
output_state.data_ptr<float>(),
reinterpret_cast<bf16 const*>(q.data_ptr()),
reinterpret_cast<bf16 const*>(k.data_ptr()),
reinterpret_cast<bf16 const*>(v.data_ptr()),
input_state_ptr,
alpha_ptr,
beta_ptr,
cu_seqlens.data_ptr<int32_t>(),
workspace_buffer.data_ptr<uint8_t>(),
static_cast<int32_t>(num_seqs),
static_cast<int32_t>(num_heads),
static_cast<int32_t>(head_size),
static_cast<int64_t>(packed_seq),
scale,
safe_gate,
static_cast<int32_t>(sm_count));

return {output, output_state};
}
16 changes: 16 additions & 0 deletions csrc/api/pybind.cu
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,21 @@ kda_fwd_prefill(
torch::Tensor workspace_buffer,
float scale,
bool safe_gate);

std::tuple<torch::Tensor, torch::Tensor>
gdn_fwd_prefill(
std::optional<torch::Tensor> output_,
std::optional<torch::Tensor> output_state_,
torch::Tensor const& q,
torch::Tensor const& k,
torch::Tensor const& v,
std::optional<torch::Tensor> input_state_,
std::optional<torch::Tensor> alpha_,
std::optional<torch::Tensor> beta_,
torch::Tensor const& cu_seqlens,
torch::Tensor workspace_buffer,
float scale,
bool safe_gate);
#endif

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
Expand All @@ -75,5 +90,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
#endif
#if defined(CULA_SM90A_ENABLED)
m.def("kda_fwd_prefill", &kda_fwd_prefill);
m.def("gdn_fwd_prefill", &gdn_fwd_prefill);
#endif
}
32 changes: 32 additions & 0 deletions csrc/gdn/sm90/changes_from_kda.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
## Important changes from KDA -> GDN

### Kept:

- The KK_inv lambda defined in mainloop is kept because it serves the same purpose of applying the beta. It is a modular function, so we don't need to worry about the scale applications - this needs to be changed in other lambda.

### Change:

- In load_kv, the cached state should be loaded NON-transposed. This is different from the kda implementation, c.f. mainloop_kda_fwd.hpp:909.
- Thinking more on this, it's better to just align the state shape with how the KDA implementation has already carried it out - this means that the expected state shape should be made explicit higher up in the API chain, possibly in FLA.
- In Kimi linear, the K matrix is multipled with the cumulative alpha gating matrix everywhere - it is also per channel. In GDN the first I + (KK^T) matrix also involves multiplying by alpha, but it is on a per-sequence, per-head basis. Thus the K_scaled and Q_scaled need to be rewritten.
- Alpha (gate) has previous shape of (B, T, H, K) for Kimi Linear, with per-channel gating. GDN instead has per-head gating, so the shape becomes (B, T, H), and we instead load a vector of size blkSeqQ == blkSeqK into shared memory. This means the atoms and layouts related to Alpha must all be changed, as well as the application of the gate.
- Because Alpha is now not the same shape as the Blk_Q/K/V tiles, it is now the same shape as the beta. This means that we don't need to create auxilliary layouts for the TMA loads, and instead we port over Alpha's SMEM layout into a CollectiveLoadVector
- The load_qkv in mainloop_gdn_fwd.hpp doesn't load alpha anymore - this is transferred to the load_beta
- extract_alpha_last needs to be changed to a simple index into the last index of shared alpha tensor, while checking for end of sequence boundaries. It just copies once.
- Alpha params are now changed to pointers with gmemlayout instead of TMALoad type
- Another alpha change - during GDN's forward pass, the gating matrix applied to KK^T is computed as the difference between [i,j] coords in log space, then exp2f. However, KDA instead applies an elementwise mask that is pre computed. The final QK^T , also coputed in compute_aux_safe, doesn't multiply on the alpha gate matrix, so it is a normal tensor core multiplication instead.
- SharedStorage needs to be changed in kernel
- Compute_aux_safe changes
- In s2r_compute_subchunk_operandA, I tried to keep the changes as minimal as possible, so I kept the behavior of copying a tile of A, but I broadcast the alpha values, which are now a row equivalent, to all the 32 columns in the subchunk. This allows the previous broadcast_row_0 + exp2f(g - g_first) values to still work. POSSIBLE OPTIMIZATION: It might be faster to just use the same register + identity tensor + row indexing across threads that
- Compute_loop_body changes
- The alpha loading in + scaling needs to be changed, since the KDA implementation loads in a tile of 32 across the head dimension. I did the same change that i did in compute_aux_safe to create a dummy tensor shape that broadcasts.
- I also stopped using a CopyAlpaAhtom and instead do a manual unrolled loop when loading in the shared alpha values for QK scaling
- KV state shape:
- It looks like KDA implementation uses the same V^T * K_scaled, with the KV_state shape being d_V x d_K in the output. This is also equivalent to the FLA transpose flag being set to TRUE.
- Change in kernel_gdn_fwd.hpp:
- Because the load type for alpha is now a predicated vector, we need to also change the alphapipelineparam initializaiton, moved it down next to beta, since they are loaded together

### Possible Optimizations:

- Because GDN doesn't need to materialize an entire register tile to hold results, we can load in the rows directly from shared memory and not worry about copying through to registers before multiplying. This could allow more aggressive use of the register file, in exchange for added latency from accessing SMEM. To keep consistency with the previous KDA implementation, I just used 0-strides to broadcast along the row dimension.

Loading