[CuTeDSL] Flash Attention v2 for SM120 (Blackwell GeForce) #3030
blake-snc wants to merge 10 commits into NVIDIA:main
Conversation
Add a CuTe DSL Flash Attention v2 forward pass kernel targeting SM120 (RTX 5090, GB10 / DGX Spark). SM120 lacks tcgen05 MMA support, so this implementation uses SM80-compatible tensor core instructions (mma.sync.aligned.m16n8k16) with CpAsync for global-to-shared memory transfers — the same proven approach as the Ampere FA2 example, tuned for SM120's 101 KB shared memory capacity.

Features:
- FP16 and BF16 support
- Online softmax fusion (Flash Attention v2 algorithm)
- Causal masking support
- Configurable tile sizes (m_block_size, n_block_size)
- Register pipeline for smem-to-register overlap
- Predicated loads for boundary handling

Tested on NVIDIA GB10 (SM121a / DGX Spark) with multiple configs:
- head_dim=64/128, seqlen up to 2048, batch_size up to 4
- Both causal and non-causal modes
- Asymmetric Q/K sequence lengths

All verified against PyTorch scaled_dot_product_attention reference.

Closes NVIDIA#2956

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
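For readers new to the algorithm, the online-softmax fusion named above can be modeled in a few lines of NumPy. This is an illustrative sketch, not the kernel's code — tile shapes are arbitrary and the 1/sqrt(D) score scale and masking are omitted for brevity:

```python
import numpy as np

def online_softmax_attention(q, k, v, n_block=2):
    """Illustrative FA2 recurrence: process K/V in tiles with a running
    row-max m and normalizer l, rescaling the O accumulator when m grows.
    (The 1/sqrt(D) score scale is omitted for brevity.)"""
    S, D = q.shape
    m = np.full(S, -np.inf)           # running row max
    l = np.zeros(S)                   # running softmax denominator
    o = np.zeros((S, D))              # unnormalized output accumulator
    for j in range(0, k.shape[0], n_block):
        s = q @ k[j:j + n_block].T                # scores for this KV tile
        m_new = np.maximum(m, s.max(axis=1))
        scale = np.exp(m - m_new)                 # rescale prior state
        p = np.exp(s - m_new[:, None])            # tile probabilities
        l = l * scale + p.sum(axis=1)
        o = o * scale[:, None] + p @ v[j:j + n_block]
        m = m_new
    return o / l[:, None]

q, k, v = np.random.randn(4, 8), np.random.randn(6, 8), np.random.randn(6, 8)
ref = np.exp(q @ k.T - (q @ k.T).max(1, keepdims=True))
ref = (ref / ref.sum(1, keepdims=True)) @ v
assert np.allclose(online_softmax_attention(q, k, v), ref)
```

The point of the recurrence is that no S×S score matrix is ever materialized — only one KV tile's scores live at a time, which is what lets the kernel keep everything in registers and shared memory.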
Thanks. Are you using TMA? In fact, DGX Spark should be Flash Attention 3 without warpgroups from Hopper.
@johnnynunez Thanks for taking a look! To answer your questions: no TMA yet — we're using CpAsync (`cp.async`) for the global-to-shared transfers.

Regarding FA3 without warpgroups: from what I know, FA3's core performance gains come from three techniques that are deeply tied to async WGMMA. Producer-consumer warp specialization needs WGMMA + TMA overlap. Pingpong scheduling needs warpgroups with barrier synchronization. And intra-warpgroup GEMM-softmax overlap relies on WGMMA executing asynchronously, which SM120's synchronous `mma.sync` does not do.

One FA3 improvement that does seem portable is FP8 block quantization, where SM120's block-scaled MMA would be a natural fit. That could be a good follow-up kernel.
How does this compare to the existing SM80 kernel here: https://github.com/Dao-AILab/flash-attention/blob/c4d8b0630eb81cf88206e0cc9e9bff4e7806d88f/flash_attn/cute/flash_fwd.py#L52
In fact, if you execute https://github.com/NVIDIA/cutlass/blob/main/examples/python/CuTeDSL/ampere/flash_attention_v2.py, it runs on DGX Spark. The thing is, the results shared at https://github.com/gau-nernst/learn-cuda/tree/main/02c_matmul_sm120 are related to TMA.
@drisspg Good question — the Dao-AILab SM80 path uses the same fundamental approach: `cp.async` loads feeding SM80 `mma.sync` tensor cores. This PR is more narrowly scoped — an SM120-tuned standalone example for the CUTLASS repo addressing #2956, with tile sizes configured for SM120's 101 KB shared memory. The main opportunity for differentiation would be adding TMA (`cp.async.bulk`).
Oh whoops, sorry — you can ignore my comment. I thought this PR was in the FA repo, my bad.
Add FlashAttentionForwardSm120Tma alongside the existing CpAsync implementation, using TMA (cp.async.bulk) loads with a dedicated DMA warp for compute/load overlap:

- 4D TMA descriptors (seq, dim, head, batch) for multi-batch support
- TMA-compatible Swizzle(B, 4, 3) pattern (M=4 required for TMA hardware)
- Warp specialization: 1 DMA warp + N MMA warps (default 4)
- PipelineTmaAsync with mbarrier-based producer/consumer synchronization
- Multi-stage KV double-buffering (configurable kv_stages)
- Separate K and V pipelines for independent scheduling
- SM80-compatible MMA (mma.sync.aligned.m16n8k16) unchanged

Validated on NVIDIA GB10 (SM121a / DGX Spark) against PyTorch SDPA:
B=1..4, S=128..1024, H=1..8, D=64/128, causal/non-causal

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Updated with a TMA variant (`FlashAttentionForwardSm120Tma`) alongside the existing CpAsync kernel.

Key finding during development: TMA on SM120 requires the `Swizzle(B, 4, 3)` SMEM pattern (M=4 is mandated by the TMA hardware); the CpAsync version's `Swizzle(B, 3, 3)` is not valid for TMA.

Usage: pass `--use_tma` to select the TMA kernel.

Verified on GB10 (SM121a) with B=1..4, S=128..1024, H=1..8, D=64/128, causal/non-causal.
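For intuition, the DMA-warp/MMA-warp split behaves like a bounded producer/consumer queue. Here is a CPU-side Python analogy — illustrative only; the real kernel synchronizes with PipelineTmaAsync mbarriers, not OS threads:

```python
from queue import Queue
from threading import Thread

def dma_warp(tiles, pipe):
    # Producer role: put() blocks once kv_stages tiles are in flight,
    # modeling the mbarrier-bounded multi-stage KV buffer.
    for t in tiles:
        pipe.put(t)
    pipe.put(None)  # sentinel: no more tiles

def mma_warps(pipe, consume):
    # Consumer role: wait for a full stage, compute on it, free the slot.
    while (t := pipe.get()) is not None:
        consume(t)

pipe = Queue(maxsize=2)  # kv_stages = 2
out = []
Thread(target=dma_warp, args=(range(8), pipe)).start()
mma_warps(pipe, out.append)
print(out)  # tiles consumed in order; loads ran ahead, never > 2 deep
```

The load/compute overlap comes purely from the producer being allowed to run ahead by the pipeline depth.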
super cool! I'm going to report internally |
Thanks for the PR — did you happen to collect any performance data vs the existing FAv2 non-TMA kernel?
@blake-snc I love your work... do you want to adapt it to https://github.com/Dao-AILab/flash-attention? If not, I will try.
@johnnynunez Thanks, I really appreciate that! I'd love to take that on — I'll put up a PR to Dao-AILab/flash-attention with an SM120 path. They already have the CuTe DSL flash_fwd.py for SM80, so the structure is there to build on. I'm also collecting TMA vs CpAsync benchmark numbers on GB10 for @IonThruster's question (I hadn't thought to gather those) and will post them a bit later!
@IonThruster Great question — here are the benchmark results comparing the CpAsync (non-TMA) kernel vs. the TMA kernel from this PR, measured on DGX Spark (GB10 / SM121a).

Benchmark Configuration
Results
Note: configs with SeqLen=8192 and (B=4, S=4096, D=128, causal=yes) failed with OOM on the GB10's unified memory.

Summary
Next Steps

We're investigating further optimizations to the TMA path:
Happy to share Nsight Compute profiles or run additional configs if useful.

Contributed by Second Nature Computing
SM120's FP8 MMA uses `mma.sync.aligned.kind::f8f6f4.m16n8k32` (SM120_16x8x32_TN in mma_sm120.hpp), which differs from SM89's FP8 instruction and is not yet exposed in the CuTe Python DSL. Added a NOTE documenting this for future FP8 FA enablement. Also fixed run/run_tma to capture and print avg execution time. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
FP8 Flash Attention Update: SM120 FP8 MMA Validated via Inline PTX

We investigated adding FP8 flash attention using SM120's `mma.sync.aligned.kind::f8f6f4.m16n8k32` instruction.

Key findings:
Will update with the full FP8 FA kernel and benchmarks once complete.

Contributed by Second Nature Computing
FP8 Flash Attention Progress Update

Following up on the earlier FP8 investigation — the proof-of-concept FP8 flash attention kernel is now producing correct output across all test configurations.

What works:
Current kernel architecture (single-warp POC): the current kernel uses a minimal single-warp (32 threads) design with M=16, N=32 tiles and GMEM O accumulation. This validates correctness of the FP8 MMA register layout and the full softmax→conversion→GEMM pipeline, but is not performance-optimized — O accumulation through global memory is the dominant bottleneck.

Next step: register-tiled multi-warp kernel. Now working on a performance-optimized FP8 kernel matching the BF16 kernel's architecture (4 warps, M=128/N=64 tiles, register O accumulation).

Contributed by Second Nature Computing
FP8 Flash Attention — Performance Update

Added `FP8FlashAttentionSm120Opt`, the performance-optimized FP8 kernel.

Optimizations applied
Benchmark on DGX Spark (NVIDIA GB10, SM121a)

FP8 kernel:
The FP8 kernel reaches 0.60–1.38x of the BF16 kernel's performance, peaking at 42 TFLOPS for D=64 and 35 TFLOPS for D=128. The B=4 S=1024 D=128 case beats BF16 by 38%.

Remaining performance gaps

The main sources of the remaining gap vs BF16 are detailed in the follow-up updates further down this thread.
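For context on how TFLOPS figures like these are typically derived — a hedged sketch assuming the common 4·B·H·S²·D FLOP convention for the attention forward pass and an illustrative shape (the PR's exact accounting isn't shown in this thread):

```python
def attention_tflops(B, H, S, D, seconds, causal=False):
    # Common convention: 4*B*H*S^2*D FLOPs for the forward pass
    # (two GEMMs at 2*S^2*D multiply-add FLOPs each); causal halves it.
    flops = 4 * B * H * S * S * D
    if causal:
        flops //= 2
    return flops / seconds / 1e12

# Illustrative: at this shape, 42 TFLOPS corresponds to ~409 us per pass.
print(round(attention_tflops(B=4, H=16, S=1024, D=64, seconds=409e-6)))  # 42
```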
Files added
FP8 Flash Attention using mma.sync.aligned.kind::f8f6f4.m16n8k32 with CpAsync pipelining and bank-conflict-free SMEM layout.

- POC kernel: 1 warp, basic correctness validation
- Optimized kernel: 4 warps, register O accumulation, vectorized 4x4 byte transpose via prmt.b32, CpAsync double-buffered K/V pipeline, +16 byte SMEM row padding for bank conflict elimination
- FP8 GEMM helper with inline PTX MMA (workaround for NVIDIA#3044)
- Benchmark script comparing FP8 vs BF16 kernels

Performance on DGX Spark (SM121a):
- FP8 peaks at 42.4 TFLOPS (D=64) and 35.3 TFLOPS (D=128)
- FP8 outperforms BF16 by up to 1.38x at larger batch sizes

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
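The prmt.b32-based 4×4 byte transpose mentioned in this commit is compact enough to model in software. An illustrative Python model with standard interleave selectors — my own sketch, not code from the kernel:

```python
def prmt(a: int, b: int, sel: int) -> int:
    """Software model of PTX prmt.b32: each selector nibble picks one byte
    from the 8-byte pool {a, b}. (The real instruction's sign-replicate
    mode, selector bit 3, is ignored here.)"""
    pool = [(a >> (8 * i)) & 0xFF for i in range(4)] + \
           [(b >> (8 * i)) & 0xFF for i in range(4)]
    out = 0
    for i in range(4):
        out |= pool[(sel >> (4 * i)) & 0x7] << (8 * i)
    return out

def transpose4x4(r0, r1, r2, r3):
    """4x4 byte transpose: rows of 4 FP8 values -> columns, in 8 prmt ops."""
    lo01 = prmt(r0, r1, 0x5140)  # r0.b0, r1.b0, r0.b1, r1.b1
    hi01 = prmt(r0, r1, 0x7362)  # r0.b2, r1.b2, r0.b3, r1.b3
    lo23 = prmt(r2, r3, 0x5140)
    hi23 = prmt(r2, r3, 0x7362)
    c0 = prmt(lo01, lo23, 0x5410)  # column 0: r0.b0, r1.b0, r2.b0, r3.b0
    c1 = prmt(lo01, lo23, 0x7632)
    c2 = prmt(hi01, hi23, 0x5410)
    c3 = prmt(hi01, hi23, 0x7632)
    return c0, c1, c2, c3

print([hex(c) for c in transpose4x4(0x03020100, 0x13121110,
                                    0x23222120, 0x33323130)])
# ['0x30201000', '0x31211101', '0x32221202', '0x33231303']
```

Eight prmt instructions replace the byte-granularity shuffle that ldmatrix.b8 would provide on newer architectures.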
@johnnynunez Thanks for sharing the TMA results! We've seen gau-nernst's work — impressive numbers (94.4% of peak on RTX 5090 with Ampere-era instructions). Our TMA variant (`FlashAttentionForwardSm120Tma`) is already in this PR.

Also — we already have the Dao-AILab/flash-attention adaptation up at #2268 (just rebased onto latest main). Happy to coordinate if you want to help expand it (e.g. varlen, backward pass).
…nal tiles

The TMA consumer loop was passing in_mask_steps=True for every KV tile when is_causal=True, applying expensive per-element causal masking (identity tensor creation + column comparisons) to all tiles, including those fully below the diagonal. This caused up to 40% regression vs CpAsync on causal workloads.

Fix: add a runtime check to only apply masking for the mask_steps tiles near the causal diagonal (n_block >= n_block_max - ceil_div(m_block_size, n_block_size)), matching the CpAsync variant's two-loop approach. Tiles fully below the diagonal use in_mask_steps=False and skip the masking.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
TMA Causal Masking Fix + Updated Benchmarks

Bug fix (commit 106e24b): the TMA variant was applying per-element causal masking to every KV tile instead of only the tiles near the diagonal — see the scheduling sketch after the table below.
| Config | TMA/CpAsync (before fix) | TMA/CpAsync (after fix) |
|---|---|---|
| B=1 S=512 D=64 causal | 0.69x | 1.00x |
| B=1 S=512 D=128 causal | 0.94x | 0.93x |
| B=1 S=1024 D=64 causal | 0.88x | 0.85x |
| B=1 S=1024 D=128 causal | 0.99x | 0.95x |
The biggest improvement is at short sequences where the masked-tile fraction is highest. Updated full benchmark results are in the PR description.
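The two-loop structure from the fix, modeled in plain Python — a sketch of the block scheduling only (names follow the commit message; the per-element masking itself is elided):

```python
import math

def kv_loop(m_block_size, n_block_size, n_block_max, process):
    """Fast loop over fully-unmasked KV tiles, then a masked loop over
    only the tiles that can straddle the causal diagonal."""
    mask_steps = math.ceil(m_block_size / n_block_size)
    split = max(n_block_max - mask_steps, 0)
    for n_block in range(split):                 # fast loop: no masking
        process(n_block, in_mask_steps=False)
    for n_block in range(split, n_block_max):    # diagonal tiles: apply mask
        process(n_block, in_mask_steps=True)

kv_loop(128, 64, 4, lambda n, in_mask_steps: print(n, in_mask_steps))
# tiles 0,1 run unmasked; the last ceil(128/64)=2 tiles apply the mask
```

Since the masked fraction is mask_steps / n_block_max, the win is largest at short sequences — exactly the pattern in the table above.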
Two correctness fixes to FlashAttentionForwardSm120Tma:
1. Non-causal OOB masking: the last K tile is now masked when seqlen_k
is not divisible by n_block_size. TMA zero-fills OOB positions during
load, but softmax must treat them as -inf (not 0) to avoid corrupting
the normalization. The consumer loop now passes in_mask_steps=True for
n_block == n_block_max - 1 in the non-causal path, and
_softmax_rescale_O handles both causal and non-causal masking when
in_mask_steps=True.
2. SMEM capacity check: the previous estimate used 3*1024 bytes of
alignment overhead, which over-counted by ~2 KB. The mbar region
(< 200 B) rounds up to 1024 B before sQ; sQ and sKV are typically
multiples of 1024 B for standard tile configs, so they need no further
padding. Updated can_implement() to use the actual layout arithmetic,
allowing the default config (m=128, n=64, d=128, kv_stages=2, bf16)
which uses 97.0 KB of SM120's 99.0 KB SMEM budget.
Validated on SM121a with 8 configs:
- TMA: default non-causal, default causal, seqlen_k non-divisible
(non-causal and causal), head_dim=64 — all PASS
- CpAsync: same 4 configs — all PASS (no regressions)
Contributed by Second Nature Computing (https://joinsecondnature.com)
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
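The SMEM arithmetic in fix 2 is easy to sanity-check. A sketch mirroring the commit's numbers — the helper name and flat rounding are my simplifications, not the kernel's exact can_implement() code:

```python
def tma_smem_bytes(m, n, d, kv_stages, dtype_bytes=2, mbar_region=1024):
    """Mbar region rounds up to 1 KB before sQ; sQ and sKV are already
    1 KB multiples for standard tile configs, so no further padding."""
    sQ = m * d * dtype_bytes                   # Q tile, resident once
    sKV = kv_stages * 2 * n * d * dtype_bytes  # K and V tiles per stage
    return mbar_region + sQ + sKV

# Default config from the commit: m=128, n=64, d=128, kv_stages=2, bf16
print(tma_smem_bytes(128, 64, 128, 2) / 1024)  # 97.0 (vs 99.0 KB budget)
```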
Update: two additional correctness fixes (commit 6e193de)

1. Non-causal OOB masking

The TMA variant was missing out-of-bounds masking for the last K tile when seqlen_k is not divisible by n_block_size. TMA zero-fills OOB positions during the load, but softmax must treat them as -inf rather than 0, or the normalization is corrupted. Fix: the consumer loop now passes in_mask_steps=True for n_block == n_block_max - 1 in the non-causal path, and _softmax_rescale_O handles both causal and non-causal masking when in_mask_steps=True. Previously this was only visible with seqlen_k values not divisible by the tile size.

2. SMEM capacity check in can_implement()

The previous estimate over-counted alignment overhead by ~2 KB; the check now uses the actual layout arithmetic, which admits the default config (m=128, n=64, d=128, kv_stages=2, bf16) at 97.0 KB of SM120's 99.0 KB SMEM budget.
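A tiny NumPy demonstration of why TMA's zero-fill must be turned into -inf before softmax (illustrative values only):

```python
import numpy as np

# A zero-filled OOB score is a *valid* logit (weight exp(0) = 1), so the
# normalization leaks weight to key positions that don't exist:
scores = np.array([2.0, 1.0, 0.0, 0.0])        # last two cols are OOB fill
bad = np.exp(scores) / np.exp(scores).sum()     # OOB columns get ~8% each

masked = scores.copy()
masked[2:] = -np.inf                            # the in_mask_steps=True path
good = np.exp(masked - masked[:2].max())
good /= good.sum()
print(bad.round(3))   # [0.61  0.225 0.083 0.083]  <- corrupted
print(good.round(3))  # [0.731 0.269 0.    0.   ]  <- correct
```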
Add BSD-3-Clause license header to benchmark_fp8_vs_bf16.py (was missing entirely). Normalize license header format in fp8_flash_attention.py to match the canonical CUTLASS style used across all other example files. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
cc @Junkai-Wu |
Ping — any NVIDIA maintainers available to review? This adds CuTe DSL Flash Attention v2 for SM120 (Blackwell GeForce). Happy to address any feedback.
Review comment on `class FlashAttentionForwardSm120`:
How does this non-TMA version compare to the Ampere version? It looks very similar and the Ampere kernel should be runnable on SM120/121.
Re: @depaulmillz's comment on the non-TMA version vs Ampere: you're right — the CpAsync version (`FlashAttentionForwardSm120`) is structurally very similar to the Ampere example, which does run on SM120/121; ours mainly retunes tile sizes for SM120's shared memory. The main value-add of this file is the TMA variant (`FlashAttentionForwardSm120Tma`).
This PR has been labeled |
Still relevant on current main. Note: commenting to reset the inactivity label.
Hi @blake-snc, I've run the `benchmark_fp8_vs_bf16.py` benchmark and hit some failures.
There may be environment differences or I may have missed some critical steps. Could you share your benchmark setup/steps to repro the result? Thanks!

Environment setup:
Repro command:

Result on NVIDIA RTX PRO 6000 (SM120):
Result on NVIDIA GB10 (SM121):
Reference failure log example:
@Aneureka my Spark is occupied with a training run through tomorrow, but I will look at this as soon as that device frees up. I have a few thoughts, but need to confirm! Thanks for the quick follow-up!
Runtime confirmation on cutlass-dsl 4.4.2 / NVIDIA GB10 (sm_121a) today against current main:
@Aneureka thanks for the careful repro. I ran your exact command on my SM121a today and confirmed both findings — FP8 causal=True FAILing across the matrix and FP8 trailing BF16 in the non-causal cases. Both reproduce on cutlass-dsl 4.4.2 / cutlass main. Working on a fix now; will follow up on this PR with results and an updated benchmark.
q/k/v have shape (B, S, H, D) contiguous. Reshaping to (B*H, S, D) via plain .view() succeeds without erroring but interleaves seq and head; for H=1 it is the identity, but for H>=2 the per-head reference becomes a wrong layout and the bmm reference is computed against the wrong data. With that, every causal config in benchmark_fp8_vs_bf16.py reported FAIL on the kernel even though the kernel produces the correct (B, S, H, D) output via direct head_idx indexing.

Permuting (B, S, H, D) -> (B, H, S, D), making it contiguous, and then viewing as (B*H, S, D) gives the correct per-head batched layout for bmm. The output is permuted back the same way before comparison.

After the fix the full 16-config matrix in benchmark_fp8_vs_bf16.py passes (max_diff <= 0.005), including all causal=True cases that previously reported FAIL. The kernel itself is unchanged.

Reported by @Aneureka in PR comment 4322113xxx.

Signed-off-by: Blake Ledden <blake@secondnaturecomputing.com>
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
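The bug is easy to reproduce in isolation (illustrative shapes):

```python
import torch

B, S, H, D = 2, 8, 2, 4
q = torch.randn(B, S, H, D)  # contiguous (B, S, H, D), as in the benchmark

# Wrong: view() succeeds (numel matches, tensor is contiguous) but each row
# of the result interleaves seq and head — identity only when H == 1.
q_wrong = q.view(B * H, S, D)

# Right: move heads next to batch first, then flatten for bmm.
q_right = q.permute(0, 2, 1, 3).contiguous().view(B * H, S, D)

print(torch.equal(q_wrong, q_right))  # False for H >= 2
```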
Optimized FP8 SM120 attention combining the BF16 TMA kernel's infrastructure with the FP8 inline-PTX MMA path.

Q/K/V all loaded via inline-PTX ldmatrix.x4(.trans).shared.b16 from swizzled SMEM with manual swizzle bit-math in the address path (cutlass-dsl 4.4.2's Pointer.llvm_ptr does not apply layout swizzle to inline-PTX operands). Register-resident P via warp shuffles eliminates the sP_f32 SMEM round-trip. M=128 N=64 with n_warp_groups=2 (each MMA warp handles two m16 row groups).

Bench (DGX Spark / SM121a, --iters 50, median of 3 runs):
- 1.09x sum-time vs FP8FlashAttentionSm120Opt across the 16-config matrix
- Peak 1.51x on B=4 S=1024 D=128 (833us -> 551us)
- Strong wins on D=128 large-batch shapes
- Trails BF16 TMA on sum-time (0.88x) — the f8f6f4 m16n8k32 B operand layout requires per-thread bridge shuffles that don't fully amortize on small-S D=64 shapes

Engineering notes:
- Path A bridge bug: select-then-shuffle reads the source's post-selection value (which differs by the source's q). Fix is shuffle-then-select per receiver. Validated via byte-pattern probe.
- Manual swizzle: Swizzle(B,4,3) is `addr ^ (((addr>>7) & ((1<<B)-1)) << 4)`
- Bridge cost differs by operand: Q (4 shuffles), K (8 shuffles + 4 prmt, no selp because ni is constexpr), V (16 shuffles + 8 selp + 4 prmt because PV q is per-thread)
- 26 correctness configs (B/S/H/D x causal x non-divisible Sk) PASS at full accuracy with zero_rows=0 per-row coverage check

benchmark_fp8_vs_bf16.py extended to compare 3 kernels: FP8 Opt, FP8 TMA, BF16 TMA, with TMA/Opt and TMA/BF16 ratio columns.

Contributed by Second Nature Computing (https://joinsecondnature.com)
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Hi @Aneureka, follow-up here on the FP8 perf gap and the test failure you flagged. Both are addressed.

Test failure — reference bug, not kernel bug
Perf gap — new FP8 TMA variant
| Comparison | Sum-time | Was (FP8 Opt baseline before this work) |
|---|---|---|
| FP8 TMA / FP8 Opt | 1.09× | 1.00× |
| FP8 TMA / BF16 TMA | 0.88× | 0.84× |
Where FP8 TMA wins decisively (production-shape side):
| B | S | D | Causal | TMA/Opt median |
|---|---|---|---|---|
| 4 | 1024 | 128 | no | 1.51× |
| 4 | 512 | 128 | no | 1.19× |
| 1 | 2048 | 128 | no | 1.11× |
| 1 | 1024 | 128 | no | 1.11× |
| 4 | 1024 | 64 | no | 1.03× |
Honest framing on the BF16 gap
This variant doesn't fully match BF16 TMA on sum-time (0.88× median). The remaining gap traces to a structural cost: the f8f6f4 m16n8k32 B operand has a gid/tip swap vs standard m16n8k32, so ldmatrix output doesn't directly produce the B-operand register layout. A per-thread bridge (shuffle + selp + prmt) is needed — about 16 shuffles per ldmatrix call in the PV path. On D=128 large-batch the bridge amortizes and FP8 wins; on small-S D=64 the bridge cost dominates the throughput advantage and BF16 still wins.
A few specific structural choices, in case they're useful for your reference / for others reading this thread (a small software model of the swizzle follows this list):

- The gid/tip-swapped FP8 layout was empirically validated with a byte-pattern probe before integration. Happy to upstream the probe as a test utility if useful.
- A subtle bug in early integrations: `select-then-shuffle` for the bridge reads the source thread's post-selection value, which differs based on the source's `q`. The correct pattern is `shfl(r0, src)` and `shfl(r2, src)` separately, then `selp.b32` locally per receiver.
- cutlass-dsl 4.4.2's `Pointer.llvm_ptr` doesn't apply the layout's swizzle to inline-PTX address operands, so the ldmatrix path computes the swizzled byte address manually: `addr ^ (((addr >> 7) & ((1<<B)-1)) << 4)`. If a future cutlass-dsl exposes a TiledMma for FP8 m16n8k32 (CuTe DSL: FP8 MMA segfaults on SM120 — MmaAtomSM80Type missing kind::f8f6f4 lowering #3044), the high-level `make_tiled_copy_B` path would replace this and the bridge altogether.
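A software model of that swizzle bit-math (my own sketch for illustration):

```python
def swizzle_b43(byte_addr: int, b: int) -> int:
    """Software model of Swizzle(B, 4, 3) as quoted above: XOR `b` bits
    taken from bit 7 upward into bit 4 upward — the 16-byte lane within a
    row is permuted by the 128-byte row index, avoiding bank conflicts."""
    return byte_addr ^ (((byte_addr >> 7) & ((1 << b) - 1)) << 4)

# Consecutive 128-byte rows land in different 16-byte lanes (b = 3):
for addr in (0x000, 0x080, 0x100, 0x180):
    print(hex(addr), "->", hex(swizzle_b43(addr, 3)))
# 0x0 -> 0x0, 0x80 -> 0x90, 0x100 -> 0x120, 0x180 -> 0x1b0
```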
Reproducing locally:
python examples/python/CuTeDSL/blackwell_geforce/benchmark_fp8_vs_bf16.py --iters 50

Validation: 26 correctness configs (B/S/H/D × causal × non-divisible Sk shapes), max_diff ≤ 0.007 typical.
Future work avenues that would close more of the BF16 gap are tracked internally — primarily instruction-scheduling tuning (profiler-driven) and the high-level atom path once #3044 lands.
Contributed by Second Nature Computing — tested on DGX Spark hardware
Default n_block_size 64 -> 32. Sum-time over the 16-config Aneureka matrix improves from 0.88x to 0.91x of BF16 TMA, and per-shape N picking in benchmark_fp8_vs_bf16.py reaches 0.93x. Several configs now beat BF16 TMA outright (B=1 S=512 D=64 non-causal: 1.04x; B=1 S=2048 D=64: 1.06x; B=4 S=512 D=64: 1.03-1.05x).

Why N=32 wins broadly: the smaller tile reduces fixed per-K-iter overhead, which helps causal (work halves per CTA) and small-S shapes. N=64 still wins on B*H >= 64 non-causal D=128, where more work-per-iter amortizes the f8f6f4 bridge cost — the bench picks per shape.

BF16 output mode: pass a torch.bfloat16 mO and the epilogue uses cvt.rn.bf16x2.f32 + packed Int32 store, matching the production FP8-attention pattern (FP8 in, BF16 out). FP32 path unchanged via cutlass.const_expr branch on mO.element_type.

Removed the dead sVt allocation (10 KB SMEM saved); Path A made the transpose-staging buffer obsolete.

Validated: 26 F32-out configs and 18 BF16-out configs pass at full accuracy on DGX Spark (SM121a), max_diff 0.0003-0.007 typical.

Contributed by Second Nature Computing (https://joinsecondnature.com)
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Quick perf-tuning follow-up on top of the FP8 TMA variant. Tightened the BF16 TMA gap from 0.88× to 0.91× sum-time on the 16-config matrix (per-shape N picking gets to 0.93×), and several configs now beat BF16 outright.

What landed in the latest commit:
| Change | Effect |
|---|---|
| `n_block_size` default 64 → 32 | +3pp BF16 sum-time (0.88× → 0.91×). N=32 reduces fixed per-K-iter overhead — wins on causal (work halves per CTA) and small-shape D=64 cases. N=64 still wins on B*H ≥ 64 non-causal D=128, so benchmark_fp8_vs_bf16.py picks per shape (+2pp more, → 0.93×). |
| BF16 output mode | `cutlass.const_expr` branch on `mO.element_type` — pass a torch.bfloat16 tensor and the epilogue uses cvt.rn.bf16x2.f32 + packed Int32 store. Matches the production FP8-attention pattern (FP8 in, BF16 out). FP32 path unchanged. |
| `sVt` removal | Dead since Path A; saved 10 KB SMEM. |
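For reference, the semantics of the cvt.rn.bf16x2.f32 + packed-store epilogue can be modeled in software — an illustrative sketch of the rounding and packing, not the kernel's PTX:

```python
import numpy as np

def f32_to_bf16_bits(x: float) -> int:
    """Round-to-nearest-even f32 -> bf16 (the 'rn' in cvt.rn.bf16x2.f32),
    via the standard add-then-truncate bit trick."""
    bits = int(np.frombuffer(np.float32(x).tobytes(), dtype=np.uint32)[0])
    rounding = ((bits >> 16) & 1) + 0x7FFF   # RNE tie-breaking term
    return ((bits + rounding) >> 16) & 0xFFFF

def pack_bf16x2(lo: float, hi: float) -> int:
    """Two f32 accumulators -> one 32-bit word for a single packed store."""
    return f32_to_bf16_bits(lo) | (f32_to_bf16_bits(hi) << 16)

print(hex(pack_bf16x2(1.0, -2.5)))  # 0xc0203f80
```

Packing two outputs per 32-bit store halves the epilogue's store count relative to element-wise bf16 writes.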
Configs where FP8 TMA now beats BF16 TMA
| B | S | D | TMA / BF16 |
|---|---|---|---|
| 1 | 512 | 64 | 1.04× |
| 1 | 2048 | 64 | 1.06× |
| 4 | 512 | 64 | 1.03–1.05× |
| 1 | 1024 | 64 | 1.00× |
Things I tried that didn't pan out (so the path is documented)
- `kv_stages=3` — wash. Diagnostic showed the consumer wasn't TMA-bound, so adding pipeline depth had no effect.
- Software prefetching V `ldmatrix` one iter ahead — wash. ptxas was already hoisting the load under the bridge arithmetic.
- F32 epilogue stores via Int64 recast (vec2 store) — mixed. Wins on D=64 small-S, regressions on D=128 large-S — net wash on sum-time.
- N=128 — 30–40% slower from register-pressure spill. The register budget already accommodates `s_regs + acc_O + Q register cache` tightly at N=64.
- Diagnostic: removed all epilogue stores — only moved 0.88× → 0.92×. So even a perfect store optimization caps at ~4pp, which is why the BF16-output / vec-store experiments couldn't move the needle alone.
Honest framing on the remaining gap
The remaining ~7% to BF16 TMA is structural: the f8f6f4 register-layout gid/tip swap means each ldmatrix output needs a per-thread bridge (PV's V bridge in particular: 16 shuffles + 8 selp + 4 prmt per call). Eliminating this without doubling the ldmatrix count needs a TiledMma atom for FP8 m16n8k32, which is tracked as #3044 — re-validated as still segfaulting on cutlass-dsl 4.4.2 (current latest) on 2026-04-25, with the fix targeted for 4.5 per @shubaoyu2. Once that lands, the high-level make_tiled_copy_B path replaces both the manual swizzle bit-math and the bridge.
Reproducing:
python examples/python/CuTeDSL/blackwell_geforce/benchmark_fp8_vs_bf16.py --iters 50

Contributed by Second Nature Computing — tested on DGX Spark hardware
Thanks for the follow-up! Confirmed that the mismatches are now gone; the remaining issue is that FP8 TMA perf is consistently outperformed by BF16 by ~20% (on my end, with the same repro steps). We should investigate the perf gap and narrow it to under 10% or, ideally, 5%, and leave a comment for future improvement. I'm also open to merging it first if other reviewers agree, as it is functionally ready, and then following up on the perf later.

Latest result on DGX Spark:
Summary
Adds CuTe DSL Flash Attention v2 forward pass kernels for SM120 (RTX 5090, GB10 / DGX Spark), addressing the lack of high-performance FA kernels for this architecture.
Three implementations included:
BF16 Kernels (flash_attention_v2.py)

CpAsync variant (FlashAttentionForwardSm120): all threads perform both loads and compute using cp.async — the Ampere-era approach, tuned for SM120's 101 KB shared memory.

TMA variant (FlashAttentionForwardSm120Tma): uses TMA (cp.async.bulk) with warp specialization — 1 dedicated DMA warp handles TMA loads while N MMA warps compute, enabling load/compute overlap via multi-stage KV pipelining with PipelineTmaAsync mbarrier synchronization.

Both use SM80-compatible tensor core instructions (mma.sync.aligned.m16n8k16) since SM120 lacks tcgen05 MMA. Supports FP16/BF16, causal/non-causal, configurable tile sizes, asymmetric Q/K lengths, online softmax fusion, and register pipelining.

FP8 Kernel (fp8_flash_attention.py)

FP8FlashAttentionSm120Opt: uses SM120's native FP8 MMA instruction (mma.sync.aligned.kind::f8f6f4.m16n8k32) via inline PTX, with a CpAsync double-buffered K/V pipeline, vectorized 4×4 byte transpose via prmt.b32, and bank-conflict-free SMEM layout (+16 byte row padding).

FP8 kernel features:
- Inline PTX MMA as a workaround for the MmaAtomSM80Type segfault with FP8 types (CuTe DSL: FP8 MMA segfaults on SM120 — MmaAtomSM80Type missing kind::f8f6f4 lowering #3044)
- prmt.b32 4×4 byte shuffle (no ldmatrix.b8 on SM120)

TMA variant details:
- 4D TMA descriptors (seq, dim, head, batch) for native multi-batch support
- Swizzle(B, 4, 3) pattern (M=4 required by TMA hardware; the CpAsync version uses Swizzle(B, 3, 3), which is not valid for TMA)
- kv_stages for double-buffering (default 2, falls back to 1 for large head_dim)

Benchmark Results
FP8 vs BF16 Performance on DGX Spark (SM121a)
FP8 kernel: FP8FlashAttentionSm120Opt (CpAsync, bank-conflict-free, 4 warps, M=64, N=32)
BF16 kernel: FlashAttentionForwardSm120 (CpAsync, tiled MMA, M=128, N=64)

Key findings:
BF16 TMA vs CpAsync Performance on DGX Spark (SM121a)
TMA kernel: FlashAttentionForwardSm120Tma (warp-specialized, 3 MMA + 1 DMA warp, PipelineTmaAsync)
CpAsync kernel: FlashAttentionForwardSm120 (all threads load+compute, cp.async)
Both: M=128, N=64, 16 heads, 20 warmup iters, 100 timed iters
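A minimal sketch of the warmup-then-time methodology used here — my own illustration (the actual benchmark script's harness may differ):

```python
import time
import torch

def bench(launch, warmup=20, iters=100):
    """Warmup iterations absorb the CuTe DSL's one-time JIT compile and
    warm caches before anything is timed; `launch` is a zero-arg call."""
    for _ in range(warmup):
        launch()
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(iters):
        launch()
    torch.cuda.synchronize()
    return (time.perf_counter() - t0) / iters  # avg seconds per launch
```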
Geometric mean speedup: 0.95x · Min: 0.62x · Max: 1.25x
Key findings:
cp.async.bulk + warp specialization on SM120 — a pattern that scales better with larger tile sizes and multi-stage pipelining

BF16 TMA vs CpAsync — Updated (post causal masking fix)
The initial results above had two issues:
JIT compilation artifacts: The CuTe DSL JIT-compiles kernels on first invocation, which inflated/deflated some configs (e.g. the 1.25x outlier at B=1 S=1024 D=64 was a JIT artifact). The updated benchmark pre-warms all kernel variants before timing.
Causal masking bug (fixed in 106e24b): the TMA variant applied expensive per-element causal masking to all KV tiles instead of only the ceil_div(m_block, n_block) tiles near the diagonal. The CpAsync variant correctly used a two-loop structure (masked loop + fast loop). This caused up to 40% regression on causal configs (e.g. B=1 S=512 D=64 causal: 0.69x → 1.00x after fix).

Updated results with JIT pre-warming + causal fix:
Geometric mean speedup: 0.93x · Min: 0.77x · Max: 1.10x
What changed vs initial results:
Updated key findings:
cp.async.bulk + warp specialization patterns in the CuTe DSL

Test Results
Validated on NVIDIA GB10 (SM121a / DGX Spark) hardware at Second Nature Computing against PyTorch scaled_dot_product_attention:
BF16 TMA variant (
--use_tma)FP8 kernel
Tolerance:
atol=1e-02, rtol=1e-04Motivation
There are currently few high-performance flash attention kernels available for SM120. The existing Blackwell FMHA (
blackwell/fmha.py) targets SM100 and uses tcgen05 MMA + TMEM, which SM120 does not support. This implementation fills that gap by:f8f6f4MMA instructions for up to 2× arithmetic throughput over BF16Usage
Closes #2956
Contributed by Second Nature Computing — tested on DGX Spark hardware