# KDA_prefill_sm89 #47

**Open** — Azir9 wants to merge 1 commit into `inclusionAI:main` from `Azir9:feature-kda-operator`.
## README.md

# KDA prefill (CUDA)

Fused-style **KDA (Kimi Delta Attention) prefill** for float32: two kernels (intra-chunk KKT / W–U, then inter-chunk recurrence). Targets **sm_89+** (CMake default: `89`). Layout matches [cuLA](https://github.com/inclusionAI/cuLA) KDA tensor conventions: `q/k/g` `[B,H,T,K]`, `v/o` `[B,H,T,V]`, `beta` `[B,H,T]`, chunk buffers `w/u` `[B,H,num_chunks,C,K|V]`.

Standalone **CMake** build plus a **ctypes** benchmark (`src/main.py`). The tensor layout follows cuLA's KDA conventions to ease a future port into `csrc/kda/sm90/`.

## Build

```bash
cmake -S . -B build -DCMAKE_CUDA_ARCHITECTURES=89
cmake --build build
```

Artifacts: `libkda_prefill.a` and `libkda_prefill_runtime.so` (C ABI: `kda_prefill_f32`).

## Python benchmark (optional FLA)

Requires PyTorch and, for the FLA comparison, [flash-linear-attention](https://github.com/sustcsonglin/flash-linear-attention) (`fla`).

```bash
python src/main.py --suite fla-benchmark --B 2 --H 8 --T 4096 --K 64 --V 64 --C 32
```

Available suites:

- `local-benchmark` — CUDA only
- `fla-only-benchmark` — FLA only
- `fla-benchmark` — both, plus a quick accuracy check against `chunk_kda`

## Tuning env (optional)

| Variable | Role |
|----------|------|
| `KDA_INTRA_THREADS` | Block size for the intra-chunk kernel |
| `KDA_INTER_THREADS` | Block size for the inter-chunk kernel |
| `KDA_INTER_SHARDS` | V-way sharding for the inter-chunk kernel (must divide `V`) |

## Benchmark vs FLA (flash-linear-attention)

**Hardware:** NVIDIA **RTX 4070**, **sm_89** (Ada), built with `CMAKE_CUDA_ARCHITECTURES=89`.

**Setup:** `K=V=64`, chunk `C=32`. The local path runs **float32**, while FLA's `chunk_kda` in `src/main.py` runs **bf16** — the timings are not apples-to-apples on dtype, but they serve as a rough baseline.

The latest machine-logged sweep is under [`analysis/ncu/20260409-193656-large-shape-compare/`](analysis/ncu/20260409-193656-large-shape-compare/README.md): summary table in that folder's README, raw CSV in [`benchmark_results.csv`](analysis/ncu/20260409-193656-large-shape-compare/benchmark_results.csv), console captures in [`benchmarks/`](analysis/ncu/20260409-193656-large-shape-compare/benchmarks/).

| Shape (B,H,T) | iters | Local ms | FLA ms | Local / FLA |
|---------------|------:|---------:|-------:|------------:|
| (2,8,4096) | 10 | 1.8884 | 1.1918 | 1.58× |
| (2,8,8192) | 10 | 3.6517 | 3.1311 | 1.17× |
| (4,8,4096) | 10 | 3.8848 | 3.0597 | 1.27× |
| (2,16,4096) | 10 | 3.8868 | 3.1079 | 1.25× |
| (2,8,16384) | 5 | 7.2716 | 6.5675 | 1.11× |

Quick accuracy line from the same harness (local fp32 vs FLA bf16): `max_abs ≈ 4.09`, `mean_abs ≈ 0.80` on the small validation shape inside `src/main.py` (see the folder README).

## Profiling artifacts

`analysis/ncu/20260409-193656-large-shape-compare/` holds the FLA comparison CSV, console logs under `benchmarks/`, and a short README. These artifacts are not required for builds.

## Supported shapes

Instantiated in `src/kda.cu`: `(K,V,C)` in `(64,64,64)`, `(64,64,32)`, `(128,128,64)`, `(128,128,32)`.
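The chunked layout described in the README can be sketched numerically. The following Python sketch does the shape bookkeeping only (the function name is illustrative, not part of the project's API):

```python
import math

def kda_prefill_shapes(B, H, T, K, V, C):
    """Derive the chunk count and buffer shapes for the KDA prefill layout.

    Mirrors the README's conventions: q/k/g are [B,H,T,K], v/o are [B,H,T,V],
    beta is [B,H,T], and the per-chunk w/u buffers are [B,H,num_chunks,C,K]
    and [B,H,num_chunks,C,V].
    """
    num_chunks = math.ceil(T / C)  # ceil-divide T into chunks of length C
    return {
        "q": (B, H, T, K),
        "v": (B, H, T, V),
        "beta": (B, H, T),
        "w": (B, H, num_chunks, C, K),
        "u": (B, H, num_chunks, C, V),
    }

# Benchmark shape from the README: B=2, H=8, T=4096, K=V=64, C=32.
shapes = kda_prefill_shapes(2, 8, 4096, 64, 64, 32)
print(shapes["w"])  # (2, 8, 128, 32, 64)
```

Note that `num_chunks` is a ceiling division, so a sequence length that is not a multiple of `C` still gets a (partially filled) final chunk.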
## The added CUDA launch header

```cpp
#pragma once
#include <cuda_runtime.h>

#include "kda_prefill_io.hpp"

#include <algorithm>
#include <cstdint>
#include <cstdlib>
#include <stdexcept>
#include <string>

#include "kda_kernel.hpp"
#include "kda_params.hpp"

namespace kda {
namespace api {

inline void check_cuda(cudaError_t err, const char* msg) {
  if (err != cudaSuccess) {
    throw std::runtime_error(std::string(msg) + ": " + cudaGetErrorString(err));
  }
}

inline void check_sm89_or_newer(const cudaDeviceProp& prop) {
  const int sm = prop.major * 10 + prop.minor;
  if (sm < 89) {
    throw std::runtime_error(
        "KDA prefill requires an sm89-or-newer GPU, but current device is sm_" +
        std::to_string(prop.major) + std::to_string(prop.minor));
  }
}

template <typename T>
inline void normalize_prefill_io(KdaPrefillIO<T>& io) {
  if (io.chunk_size <= 0) {
    throw std::runtime_error("prefill io chunk_size must be positive");
  }
  if (io.num_chunks <= 0) {
    io.num_chunks = (io.seq_len + io.chunk_size - 1) / io.chunk_size;
  }
}

template <typename T>
inline void validate_prefill_io(const KdaPrefillIO<T>& io) {
  if (!io.q_ptr || !io.k_ptr || !io.v_ptr || !io.g_ptr || !io.beta_ptr ||
      !io.w_ptr || !io.u_ptr || !io.o_ptr) {
    throw std::runtime_error("prefill io contains null pointers");
  }
  if (io.batch_size <= 0 || io.num_heads <= 0 || io.seq_len <= 0 ||
      io.head_dim <= 0 || io.value_dim <= 0 || io.chunk_size <= 0) {
    throw std::runtime_error("prefill io dimensions must be positive");
  }
}

// Split V across inter-chunk blocks to bound per-block shared memory.
inline int choose_inter_v_shards(int value_dim,
                                 int base_inter_blocks,
                                 int target_inter_blocks) {
  if (value_dim <= 0 || base_inter_blocks <= 0) {
    return 1;
  }

  int desired_shards =
      (target_inter_blocks + base_inter_blocks - 1) / base_inter_blocks;
  desired_shards = std::max(1, std::min(desired_shards, value_dim));

  for (int shards = desired_shards; shards <= value_dim; ++shards) {
    if (value_dim % shards == 0) {
      return shards;
    }
  }
  return desired_shards;
}

inline int read_env_threads_override(const char* name, int fallback) {
  const char* value = std::getenv(name);
  if (value == nullptr || *value == '\0') {
    return fallback;
  }
  const int parsed = std::atoi(value);
  return parsed > 0 ? parsed : fallback;
}

inline int read_env_positive_override(const char* name, int fallback) {
  const char* value = std::getenv(name);
  if (value == nullptr || *value == '\0') {
    return fallback;
  }
  const int parsed = std::atoi(value);
  return parsed > 0 ? parsed : fallback;
}

template <typename Params>
inline size_t intra_chunk_smem_bytes() {
  using T = typename Params::element_type;
  return sizeof(T) *
             (Params::C * (Params::K_DIM + 1) + Params::C * (Params::V_DIM + 1) +
              Params::C * (Params::K_DIM + 1) + Params::C) +
         sizeof(float) * (Params::C * (Params::C + 1) + Params::C);
}

template <typename Params>
inline size_t inter_chunk_smem_bytes(const Params& params) {
  using T = typename Params::element_type;
  const size_t wmma_scratch =
      (Params::K_DIM == 64 && Params::V_DIM == 64 && Params::C == 32 &&
       params.inter_v_tile == 16)
          ? sizeof(float) * Params::K_DIM * params.inter_v_tile
          : 0;
  return sizeof(float) * (Params::K_DIM * (params.inter_v_tile + 1)) +
         sizeof(T) * (Params::C * (Params::K_DIM + 1) +
                      Params::C * (Params::K_DIM + 1) +
                      Params::C * (params.inter_v_tile + 1)) +
         sizeof(float) * (Params::C + Params::C * (Params::K_DIM + 1)) +
         wmma_scratch;
}

template <typename Params>
inline size_t fused_prefill_smem_bytes(const Params& params) {
  using T = typename Params::element_type;
  const size_t k_stride = Params::K_DIM + 1;
  const size_t v_stride = params.inter_v_tile + 1;
  const size_t m_stride = Params::C + 1;
  return sizeof(T) *
             (2 * Params::C * k_stride + 2 * Params::C * k_stride +
              2 * Params::C * k_stride + 2 * Params::C * v_stride +
              2 * Params::C + Params::C * k_stride + Params::C * v_stride) +
         sizeof(float) *
             (Params::K_DIM * v_stride + Params::C * m_stride +
              Params::C * k_stride + Params::C);
}

template <typename Params>
inline void launch_kda_prefill_kernel(Params params,
                                      cudaStream_t stream = 0,
                                      int intra_threads = 256,
                                      int default_inter_threads = 256) {
  if (params.seq_len <= 0 || params.batch_size <= 0 || params.num_heads <= 0) {
    return;
  }

  int dev = 0;
  check_cuda(cudaGetDevice(&dev), "get current device failed");
  cudaDeviceProp prop{};
  check_cuda(cudaGetDeviceProperties(&prop, dev), "get device properties failed");
  check_sm89_or_newer(prop);

  const int base_inter_blocks = params.batch_size * params.num_heads;
  const int target_inter_blocks =
      std::max(prop.multiProcessorCount * 2, base_inter_blocks);
  int shards =
      choose_inter_v_shards(Params::V_DIM, base_inter_blocks, target_inter_blocks);
  const char* inter_shards_env = std::getenv("KDA_INTER_SHARDS");
  if (inter_shards_env != nullptr && *inter_shards_env != '\0') {
    const int requested_shards = std::atoi(inter_shards_env);
    if (requested_shards > 0 && Params::V_DIM % requested_shards == 0) {
      shards = requested_shards;
    }
  } else if (Params::K_DIM == 64 && Params::V_DIM == 64 && Params::C == 32) {
    shards = 4;
  }

  params.inter_v_shards = shards;
  params.inter_v_tile = (Params::V_DIM + shards - 1) / shards;

  const int intra_smem = static_cast<int>(intra_chunk_smem_bytes<Params>());
  const int inter_smem = static_cast<int>(inter_chunk_smem_bytes(params));
  constexpr bool use_fused_kernel = false;
  const int fused_smem =
      use_fused_kernel ? static_cast<int>(fused_prefill_smem_bytes(params)) : 0;

  cudaError_t attr_err = cudaSuccess;
  if constexpr (use_fused_kernel) {
    attr_err = cudaFuncSetAttribute(
        kernel::prefill::kda_fused_prefill_kernel<Params>,
        cudaFuncAttributeMaxDynamicSharedMemorySize,
        fused_smem);
    if (attr_err != cudaSuccess && fused_smem <= 48 * 1024) {
      check_cuda(attr_err, "set fused shared memory attribute failed");
    }
  } else {
    attr_err = cudaFuncSetAttribute(
        kernel::prefill::kda_intra_chunk_kernel<Params>,
        cudaFuncAttributeMaxDynamicSharedMemorySize,
        intra_smem);
    if (attr_err != cudaSuccess && intra_smem <= 48 * 1024) {
      check_cuda(attr_err, "set intra shared memory attribute failed");
    }
    attr_err = cudaFuncSetAttribute(
        kernel::prefill::kda_inter_chunk_rnn_kernel<Params>,
        cudaFuncAttributeMaxDynamicSharedMemorySize,
        inter_smem);
    if (attr_err != cudaSuccess && inter_smem <= 48 * 1024) {
      check_cuda(attr_err, "set inter shared memory attribute failed");
    }
  }

  if constexpr (use_fused_kernel) {
    dim3 grid_fused(params.num_heads, params.batch_size, params.inter_v_shards);
    kernel::prefill::kda_fused_prefill_kernel<Params>
        <<<grid_fused, default_inter_threads, fused_smem, stream>>>(params);
    check_cuda(cudaGetLastError(), "launch kda_fused_prefill_kernel failed");
    return;
  }

  dim3 grid_intra(params.num_chunks, params.num_heads, params.batch_size);
  intra_threads = read_env_threads_override("KDA_INTRA_THREADS", intra_threads);
  kernel::prefill::kda_intra_chunk_kernel<Params>
      <<<grid_intra, intra_threads, intra_smem, stream>>>(params);
  check_cuda(cudaGetLastError(), "launch kda_intra_chunk_kernel failed");

  int inter_threads =
      (Params::K_DIM == 64 && Params::V_DIM == 64 && Params::C == 32)
          ? default_inter_threads
          : intra_threads;
  inter_threads = read_env_threads_override("KDA_INTER_THREADS", inter_threads);
  dim3 grid_inter(params.num_heads, params.batch_size, params.inter_v_shards);
  kernel::prefill::kda_inter_chunk_rnn_kernel<Params>
      <<<grid_inter, inter_threads, inter_smem, stream>>>(params);
  check_cuda(cudaGetLastError(), "launch kda_inter_chunk_rnn_kernel failed");
}

cudaError_t launch_kda_prefill_f32(const KdaPrefillIO<float>& io,
                                   cudaStream_t stream = 0);

inline cudaError_t launch_kda_prefill(KdaPrefillIO<float> io,
                                      cudaStream_t stream = 0) {
  normalize_prefill_io(io);
  validate_prefill_io(io);
  return launch_kda_prefill_f32(io, stream);
}

}  // namespace api
}  // namespace kda
```
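To make the shard-selection heuristic concrete, here is a Python transcription of `choose_inter_v_shards` from the header above (an illustrative sketch, not part of the project):

```python
def choose_inter_v_shards(value_dim: int, base_inter_blocks: int,
                          target_inter_blocks: int) -> int:
    """Python transcription of the C++ choose_inter_v_shards heuristic.

    Picks the smallest shard count >= ceil(target/base) that evenly divides
    value_dim, so the inter-chunk grid reaches roughly target_inter_blocks
    blocks without a ragged split of V.
    """
    if value_dim <= 0 or base_inter_blocks <= 0:
        return 1
    desired = (target_inter_blocks + base_inter_blocks - 1) // base_inter_blocks
    desired = max(1, min(desired, value_dim))
    # Smallest divisor of value_dim at or above the desired shard count.
    for shards in range(desired, value_dim + 1):
        if value_dim % shards == 0:
            return shards
    return desired

# Example: B*H = 16 base blocks, target 92 blocks (2x an assumed 46-SM GPU):
# desired = ceil(92/16) = 6, and the smallest divisor of 64 >= 6 is 8.
print(choose_inter_v_shards(64, 16, 92))  # 8
```

When `B*H` already meets the target (e.g. large batches), `desired` collapses to 1 and no V sharding happens; the `KDA_INTER_SHARDS` env var can still force a specific divisor of `V`.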
## Review comment

The logic for checking `cudaFuncSetAttribute` failure seems inconsistent: it only throws if the requested shared memory is at most 48 KB. If a kernel requires more than 48 KB (which is common for fused kernels on sm89) and the attribute setting fails, the launch will likely fail with an "out of shared memory" error anyway. It would be better to check for success regardless of the size, or to emit a more descriptive error message when the hardware limit is exceeded.
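For scale, the non-fused path stays well inside that 48 KB static limit. The arithmetic below transcribes `intra_chunk_smem_bytes` from the header for the float32 `(K,V,C)=(64,64,32)` instantiation (an illustrative check only; `elem_bytes=4` assumes `element_type` is `float`):

```python
def intra_chunk_smem_bytes(K, V, C, elem_bytes=4):
    """Transcribes intra_chunk_smem_bytes<Params> from the header for fp32."""
    # element-typed tiles: two padded CxK tiles, one padded CxV tile, a C-vector
    elem_part = elem_bytes * (C * (K + 1) + C * (V + 1) + C * (K + 1) + C)
    # float accumulators: a padded CxC matrix plus a C-vector
    float_part = 4 * (C * (C + 1) + C)
    return elem_part + float_part

smem = intra_chunk_smem_bytes(64, 64, 32)
print(smem, smem <= 48 * 1024)  # 29440 True
```

So for this configuration the intra-chunk kernel needs about 28.75 KB of dynamic shared memory, and the `cudaFuncSetAttribute` opt-in is not strictly required; the review concern bites mainly for the fused path, whose footprint can exceed the static limit.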