
Add SM80 (Ampere/A100) dense MLA decode kernel #183

Open

bzantium wants to merge 1 commit into deepseek-ai:main from bzantium:sm80-dense-decode-port

Conversation

@bzantium bzantium commented May 1, 2026

Summary

Implements a CUDA kernel for the dense MLA decode path on SM80 GPUs (Ampere / A100). The current README excludes Ampere; this PR enables A100 deployment without forcing migration to Hopper.

  • Drop-in: dense_decode_fwd API and the combine kernel are unchanged. Only the decode kernel itself is new.
  • Correctness verified against a PyTorch eager BMM reference across BF16/FP16, multi-batch, multi-KV-head, causal mask, and split-K.
  • Build flag: FLASH_MLA_DISABLE_SM100=1 FLASH_MLA_DISABLE_SM90=1 pip install -v . for A100-only nodes.

Why

The SM90 kernel relies on TMA, WGMMA, thread block clusters, and mbarrier async barriers, all of which are sm90+ only. SM80 must instead use cp.async + mma.m16n8k16 + ldmatrix + cutlass::arch::NamedBarrier, so this is a separate kernel rather than an SM90 fallback path.
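
For context, here is a minimal sketch of the two SM80 building blocks named above, written as inline-PTX wrappers; the wrapper names and operand packing are illustrative assumptions, not excerpts from the PR's kernel:

```cuda
#include <cstdint>

// Illustrative SM80 primitives (not the PR's actual code): a 16-byte cp.async.cg
// copy into shared memory and a single bf16 mma.m16n8k16 tile multiply-accumulate.
__device__ __forceinline__ void cp_async_cg_16B(void* smem_dst, const void* gmem_src) {
    // Convert the generic shared-memory pointer to a 32-bit shared-state-space address.
    uint32_t dst = static_cast<uint32_t>(__cvta_generic_to_shared(smem_dst));
    // .cg bypasses L1 and caches only in L2; 16 bytes is the only size .cg supports.
    asm volatile("cp.async.cg.shared.global [%0], [%1], 16;\n" :: "r"(dst), "l"(gmem_src));
}

__device__ __forceinline__ void mma_m16n8k16_bf16(float (&d)[4],
                                                  const uint32_t (&a)[4],
                                                  const uint32_t (&b)[2],
                                                  const float (&c)[4]) {
    // One 16x8x16 bf16 MMA with fp32 accumulation; the A/B fragments arrive as
    // packed 32-bit registers, typically filled by ldmatrix from shared memory.
    asm volatile(
        "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
        "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
        : "=f"(d[0]), "=f"(d[1]), "=f"(d[2]), "=f"(d[3])
        : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]),
          "r"(b[0]), "r"(b[1]),
          "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
}
```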

Design

  • 256 threads / CTA = 4 warpgroups × 1 warp × 32 lanes
  • BLOCK_M=16 (one warp covers M via mma.m16n8k16); each wg owns HEAD_DIM_V/4 = 128 V columns
  • All wgs compute QK^T independently (no cross-wg P transfer); compute is not the bottleneck
  • 2× sK SMEM buffer + cross-block cp.async prefetch
  • XOR swizzle (Swizzle<3,3,3>-equivalent) for SMEM bank-conflict-free ldmatrix (see the sketch after this list)
  • cp.async.cg (L1 bypass) for K loading; cp.async.ca for Q (reused across K iterations)
  • SMEM 162 KB / CTA = 18 KB sQ + 2 × 72 KB sK, fits within the 164 KB opt-in cap

Detailed design rationale and the SMEM budget tradeoffs are in docs/sm80-dense-decode-design.md.
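
As an illustration of the swizzle bullet above, a minimal sketch of a Swizzle<3,3,3>-style XOR over shared-memory byte offsets; the helper name and exact bit ranges are assumptions about the layout, not code taken from the kernel:

```cuda
#include <cstdint>

// Hypothetical Swizzle<3,3,3>-style XOR swizzle over shared-memory byte offsets
// (8-byte granularity); the kernel's real layout constants may differ.
__device__ __forceinline__ uint32_t swizzle_333(uint32_t byte_offset) {
    uint32_t unit = byte_offset >> 3;           // work in 8-byte units
    unit ^= (unit >> 3) & 0x38u;                // XOR bits [8:6] into bits [5:3]
    return (unit << 3) | (byte_offset & 0x7u);  // back to a byte offset
}
```

The intended effect is that the rows touched by one ldmatrix load map to distinct shared-memory bank groups instead of stacking on the same banks.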

Performance

A100-SXM4-80GB, peak HBM bandwidth 2039 GB/s. Median of 3 runs (the cluster GPU is shared, so single-run variance is high).

| Config       | Latency | KV BW    | A100 peak % | vs torch eager |
|--------------|---------|----------|-------------|----------------|
| b=64 sk=4096 | 0.82 ms | 490 GB/s | 24%         | 122x           |
| b=16 sk=4096 | 0.26 ms | 287 GB/s | 14%         | 84x            |
| b=1 sk=65536 | 0.28 ms | 276 GB/s | 14%         | 68x            |
| b=1 sk=16384 | 0.13 ms | 149 GB/s | 7%          | 38x            |

Full sweep + optimization milestones in docs/sm80-benchmark-2026-05-01.md.

Limitations / future work

  • 24% of peak is below the SM90 path's ~80% on H800. The biggest gap is the lack of TMA + WGMMA, which cannot be addressed on SM80.
  • SM80+SM90 combined builds currently fail because the sm90 sources use __launch_bounds__(N, M, K) (the third argument is the sm90 cluster size). Those sources are left untouched to keep this PR minimal; I'm happy to follow up with portable launch_bounds macros (sketched after this list) if maintainers prefer combined builds.
  • A CuTe-based rewrite was attempted and hit cute's "Stride Divisibility Condition" for the (head_dim=576 non-pow2, BLOCK_M=16, dynamic stride) combo — this version uses raw PTX. CuTe path is a candidate for follow-up if maintainers prefer that style.
  • No FP8 KV (sm89/sm90+ only). No sparse decode/prefill on SM80.
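
One possible shape for such a portable macro (the macro name is hypothetical and this is not part of this PR): forward the sm90-only cluster-size argument to __launch_bounds__ only when the device pass targets sm_90 or newer.

```cuda
// Hypothetical portable wrapper (not part of this PR): only pass the sm90-only
// cluster-size argument when compiling a device pass for sm_90 or newer.
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
#define FLASH_MLA_LAUNCH_BOUNDS(THREADS, MIN_BLOCKS, CLUSTER) \
    __launch_bounds__(THREADS, MIN_BLOCKS, CLUSTER)
#else
#define FLASH_MLA_LAUNCH_BOUNDS(THREADS, MIN_BLOCKS, CLUSTER) \
    __launch_bounds__(THREADS, MIN_BLOCKS)
#endif
```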

Modifications outside csrc/sm80/

| File | Change |
| --- | --- |
| csrc/api/api.cpp, common.h, dense_decode.h | SM80 arch dispatch + Arch::is_sm80() helper; an SM80-only build exposes only dense_decode_fwd |
| csrc/smxx/decode/combine/combine.cu | #if __CUDA_ARCH__ >= 900 guard around cudaGridDependencySynchronize(), since PDL is sm90+ (see the sketch below) |
| csrc/smxx/decode/get_decoding_sched_meta/get_decoding_sched_meta.cu | drop the sm90-only third argument from __launch_bounds__ |
| setup.py | FLASH_MLA_DISABLE_SM80 flag, sm80 source list, plus include paths from pip-installed nvidia-* wheels (system CUDA can lack cusparse.h) |
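
For clarity, the combine.cu change amounts to a guard of roughly this shape (a sketch with a hypothetical helper name; the actual diff may differ):

```cuda
// cudaGridDependencySynchronize() is a programmatic dependent launch (PDL)
// intrinsic that only exists on sm_90+, so it is compiled out when the combine
// kernel is built for sm_80.
__device__ __forceinline__ void grid_dependency_sync_if_supported() {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
    cudaGridDependencySynchronize();  // wait for the dependent prior grid (PDL)
#endif
    // On sm_80 this is a no-op; ordinary stream ordering applies instead.
}
```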

Test plan

  • benchmark/bench_sm80_decode.py --check — correctness vs PyTorch eager (BF16/FP16 noise level, < 1e-3 max-abs-diff)
  • Sweep across (batch, sq, hq, hk, sk) configs (b=1..128, sk=256..65536, hk=1..4) — all PASS
  • BF16 and FP16 instantiations
  • is_causal=true for multi-token Q
  • Split-K via combine kernel (num_sm_parts > 1)
  • benchmark/profile_decode_step.py — DeepSeek-V3-shape attention block profile

Happy to make any code style / structure changes maintainers want before merge. If upstream prefers to keep the support matrix at SM90/SM100 only, this can also live as a maintained fork — see docs/sm80-distribution-strategy.md for that path.

Implements a CUDA kernel for the dense MLA decode path on SM80 GPUs.
Upstream currently supports SM90/SM100 only; this enables A100 deployment
without forcing migration to Hopper hardware.

Kernel design:
- BLOCK_M=16, 4 warpgroups x 1 warp x V-quarter split (HEAD_DIM_V/4 cols/wg)
- mma.m16n8k16 (BF16/FP16 with FP32 accumulator)
- cp.async with double-buffered sK + cross-block prefetch
- XOR swizzle (Swizzle<3,3,3>-style) for zero-bank-conflict SMEM access
- cp.async.cg (L1 bypass) for K loading
- SMEM 162 KB / CTA: 18 KB sQ + 2 x 72 KB sK, fits within 164 KB cap

Functional coverage:
- BF16 + FP16
- Multi-batch, multi-KV-head, causal mask
- Split-K via the existing combine kernel (no changes to combine path)
- Drop-in API compatibility: dense_decode_fwd signature unchanged

Performance (A100-SXM4-80GB, 2039 GB/s peak HBM):
- Peak: 490 GB/s on b=64 sk=4096 (24 percent of HBM peak)
- Long-seq: 276 GB/s on b=1 sk=65536
- 9-117x speedup vs PyTorch eager BMM reference across the sweep

Build:
  FLASH_MLA_DISABLE_SM100=1 FLASH_MLA_DISABLE_SM90=1 pip install -v .
SM80-only is the supported configuration; SM80+SM90 combined builds need
__launch_bounds__ portability fixes in upstream sm90 sources (deferred).

Tests:
- benchmark/bench_sm80_decode.py --check  (correctness vs torch eager)
- benchmark/profile_decode_step.py        (DeepSeek-V3-shape step profile)

Modifications outside csrc/sm80/:
- csrc/api/{api.cpp,common.h,dense_decode.h}: SM80 arch dispatch
- csrc/smxx/decode/combine/combine.cu: __CUDA_ARCH__>=900 guard for the
  PDL device intrinsic so the combine kernel compiles for sm_80
- csrc/smxx/decode/get_decoding_sched_meta/get_decoding_sched_meta.cu:
  drop the sm90-only third arg from __launch_bounds__
- setup.py: SM80 build flag and source list, plus include paths from
  pip-installed nvidia-* wheels (system CUDA may lack cusparse headers)
@bzantium bzantium force-pushed the sm80-dense-decode-port branch from de2151d to 4f5ac5b on May 1, 2026 at 02:12