Add SM80 (Ampere/A100) dense MLA decode kernel #183
Open
bzantium wants to merge 1 commit into deepseek-ai:main from
Conversation
Implements a CUDA kernel for the dense MLA decode path on SM80 GPUs.
Upstream currently supports SM90/SM100 only; this enables A100 deployment
without forcing migration to Hopper hardware.
Kernel design:
- BLOCK_M=16, 4 warpgroups x 1 warp x V-quarter split (HEAD_DIM_V/4 cols/wg)
- mma.m16n8k16 (BF16/FP16 with FP32 accumulator)
- cp.async with double-buffered sK + cross-block prefetch
- XOR swizzle (Swizzle<3,3,3>-style) for zero-bank-conflict SMEM access
- cp.async.cg (L1 bypass) for K loading
- SMEM 162 KB / CTA: 18 KB sQ + 2 x 72 KB sK, fits within 164 KB cap
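The SMEM budget quoted above can be sanity-checked with quick arithmetic. This is a sketch: the 164 KB opt-in per-CTA cap is the documented sm_80 limit, and the per-buffer sK shape (64 KV rows x 576 x 2 B) is my assumption, not taken from the kernel source.

```python
# Sanity check of the quoted per-CTA shared-memory budget.
# sm_80 allows up to 164 KB of opt-in dynamic SMEM per CTA.
KB = 1024
s_q = 16 * 576 * 2        # sQ: BLOCK_M=16 rows x 576 head dim x 2 B (bf16)
s_k = 72 * KB             # one sK buffer (e.g. 64 KV rows x 576 x 2 B -- assumption)
total = s_q + 2 * s_k     # double-buffered sK
assert s_q == 18 * KB
assert total == 162 * KB
assert total <= 164 * KB  # fits the sm_80 opt-in cap
print(f"{total // KB} KB / 164 KB")
```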
Functional coverage:
- BF16 + FP16
- Multi-batch, multi-KV-head, causal mask
- Split-K via the existing combine kernel (no changes to combine path)
- Drop-in API compatibility: dense_decode_fwd signature unchanged
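The split-K path above can be sketched numerically: each split computes attention over its own KV slice plus a log-sum-exp, and the combine step merges the partials with log-sum-exp weights. A minimal NumPy model of that merge (function names are hypothetical, not the repo's code):

```python
import numpy as np

def attention_ref(q, k, v):
    # Naive single-pass attention: softmax(q k^T) v, plus the log-sum-exp.
    s = q @ k.T
    m = s.max(axis=-1, keepdims=True)
    p = np.exp(s - m)
    z = p.sum(axis=-1, keepdims=True)
    return (p / z) @ v, m + np.log(z)   # output, lse

def split_k_decode(q, k, v, num_splits):
    # Each split attends to its KV slice independently; partial outputs
    # are then merged with log-sum-exp weights, mirroring what a
    # split-K combine kernel does.
    outs, lses = [], []
    for ks, vs in zip(np.array_split(k, num_splits),
                      np.array_split(v, num_splits)):
        o, lse = attention_ref(q, ks, vs)
        outs.append(o)
        lses.append(lse)
    lses = np.concatenate(lses, axis=-1)                 # [M, num_splits]
    m = lses.max(axis=-1, keepdims=True)
    lse = m + np.log(np.exp(lses - m).sum(-1, keepdims=True))
    w = np.exp(lses - lse)                               # per-split weights
    return sum(w[:, i:i + 1] * outs[i] for i in range(num_splits))

rng = np.random.default_rng(0)
q = rng.standard_normal((4, 32))
k = rng.standard_normal((256, 32))
v = rng.standard_normal((256, 32))
ref, _ = attention_ref(q, k, v)
assert np.allclose(split_k_decode(q, k, v, num_splits=4), ref)
print("split-K combine matches single-pass attention")
```

Because the merge is exact, each partial can be produced independently by the decode kernel with no change to the combine path.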
Performance (A100-SXM4-80GB, 2039 GB/s peak HBM):
- Peak: 490 GB/s on b=64 sk=4096 (24 percent of HBM peak)
- Long-seq: 276 GB/s on b=1 sk=65536
- 9-117x speedup vs PyTorch eager BMM reference across the sweep
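As a rough plausibility check of the peak number (my assumptions, not from the source: the step streams the whole bf16 KV cache exactly once, head_dim = 576, one KV head):

```python
# Back-of-envelope for the b=64, sk=4096 peak point.
b, sk, head_dim, elem_bytes = 64, 4096, 576, 2
kv_bytes = b * sk * head_dim * elem_bytes     # ~302 MB read per decode step
step_ms = kv_bytes / 490e9 * 1e3              # at the measured 490 GB/s
frac_of_peak = 490 / 2039                     # vs the 2039 GB/s HBM peak
print(f"{kv_bytes/1e6:.0f} MB/step, ~{step_ms:.2f} ms, {frac_of_peak:.0%} of peak")
```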
Build:
FLASH_MLA_DISABLE_SM100=1 FLASH_MLA_DISABLE_SM90=1 pip install -v .
SM80-only is the supported configuration; SM80+SM90 combined builds need
__launch_bounds__ portability fixes in upstream sm90 sources (deferred).
Tests:
- benchmark/bench_sm80_decode.py --check (correctness vs torch eager)
- benchmark/profile_decode_step.py (DeepSeek-V3-shape step profile)
Modifications outside csrc/sm80/:
- csrc/api/{api.cpp,common.h,dense_decode.h}: SM80 arch dispatch
- csrc/smxx/decode/combine/combine.cu: __CUDA_ARCH__>=900 guard for the
PDL device intrinsic so the combine kernel compiles for sm_80
- csrc/smxx/decode/get_decoding_sched_meta/get_decoding_sched_meta.cu:
drop the sm90-only third arg from __launch_bounds__
- setup.py: SM80 build flag and source list, plus include paths from
pip-installed nvidia-* wheels (system CUDA may lack cusparse headers)
Summary
Implements a CUDA kernel for the dense MLA decode path on SM80 GPUs (Ampere / A100). The current README excludes Ampere; this PR enables A100 deployment without forcing migration to Hopper.
The `dense_decode_fwd` API and the `combine` kernel are unchanged; only the decode kernel itself is new. Build with `FLASH_MLA_DISABLE_SM100=1 FLASH_MLA_DISABLE_SM90=1 pip install -v .` for A100-only nodes.

Why
The SM90 kernel relies on TMA, WGMMA, thread block clusters, and
`mbarrier` async barriers, all of which are sm90+ only. SM80 must use `cp.async` + `mma.m16n8k16` + `ldmatrix` + `cutlass::arch::NamedBarrier`. This is a separate kernel rather than an SM90 fallback path.

Design
- `BLOCK_M=16` (one warp covers M via mma.m16n8k16); each warpgroup owns `HEAD_DIM_V/4 = 128` V columns
- Double-buffered sK with cross-block `cp.async` prefetch
- XOR swizzle (`Swizzle<3,3,3>`-equivalent) for SMEM bank-conflict-free `ldmatrix`
- `cp.async.cg` (L1 bypass) for K loading; `cp.async.ca` for Q (reused across K iterations)

Detailed design rationale and the SMEM budget tradeoffs are in
`docs/sm80-dense-decode-design.md`.
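The XOR swizzle in the design list above can be illustrated with a toy host-side model. This is a sketch, not the kernel's code; modeling bank conflicts at 16-byte-chunk granularity (8 chunk "bank groups" covering the 32 four-byte banks) is my assumption.

```python
# Toy model of the Swizzle<3,3,3>-style XOR pattern used for sK.
# An ldmatrix-style access touches one 16 B chunk in each of 8
# consecutive rows; the swizzle XORs the chunk column with the low
# 3 bits of the row so those accesses spread across bank groups.
CHUNKS_PER_ROW = 8  # 8 x 16 B = one 128 B row segment

def bank_group(row: int, col_chunk: int, swizzled: bool) -> int:
    if swizzled:
        col_chunk ^= row & 7          # the XOR swizzle
    return (row * CHUNKS_PER_ROW + col_chunk) % 8

def groups_touched(col_chunk: int, swizzled: bool) -> set:
    return {bank_group(r, col_chunk, swizzled) for r in range(8)}

for j in range(CHUNKS_PER_ROW):
    # Unswizzled: all 8 rows hit one bank group (8-way conflict).
    assert len(groups_touched(j, swizzled=False)) == 1
    # Swizzled: the 8 accesses spread across all 8 groups.
    assert groups_touched(j, swizzled=True) == set(range(8))
print("swizzled ldmatrix accesses are bank-conflict-free in this model")
```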
Performance

A100-SXM4-80GB, peak HBM 2039 GB/s. Median of 3 runs (the cluster GPU is shared, so single-run variance is high).
Full sweep + optimization milestones in
`docs/sm80-benchmark-2026-05-01.md`.

Limitations / future work
- Combined SM80+SM90 builds: upstream sm90 sources use `__launch_bounds__(N, M, K)` (the third arg is sm90-only cluster size). I left those untouched so this PR stays minimal; happy to follow up with portable launch_bounds macros if maintainers prefer combined builds.
- Raw PTX rather than CuTe: the (head_dim=576 non-pow2, BLOCK_M=16, dynamic stride) combo is why this version uses raw PTX. A CuTe path is a candidate for follow-up if maintainers prefer that style.

Modifications outside `csrc/sm80/`

- `csrc/api/api.cpp`, `common.h`, `dense_decode.h`: `Arch::is_sm80()` helper. The SM80 build only exposes `dense_decode_fwd`.
- `csrc/smxx/decode/combine/combine.cu`: `#if __CUDA_ARCH__ >= 900` guard around `cudaGridDependencySynchronize()` (PDL is sm90+).
- `csrc/smxx/decode/get_decoding_sched_meta/get_decoding_sched_meta.cu`: drop the sm90-only third arg from `__launch_bounds__`.
- `setup.py`: `FLASH_MLA_DISABLE_SM80` flag, sm80 source list, plus include paths from pip-installed `nvidia-*` wheels (system CUDA can lack `cusparse.h`).

Test plan
- `benchmark/bench_sm80_decode.py --check`: correctness vs PyTorch eager (BF16/FP16 noise level, < 1e-3 max-abs-diff) across `(batch, sq, hq, hk, sk)` configs (b=1..128, sk=256..65536, hk=1..4), all PASS
- `is_causal=true` for multi-token Q
- Split-K paths (`num_sm_parts > 1`)
- `benchmark/profile_decode_step.py`: DeepSeek-V3-shape attention block profile

Happy to make any code style / structure changes maintainers want before merge. If upstream prefers to keep the support matrix at SM90/SM100 only, this can also live as a maintained fork; see
`docs/sm80-distribution-strategy.md` for that path.