[MLA] Fix nhead=32 non-persistent decode crash on gfx950#2983

Open
frida-andersson wants to merge 1 commit into ROCm:main from frida-andersson:fix/mla-nhead32-nonpersistent-crash

Conversation

@frida-andersson
Contributor

Summary

Fixes a GPU memory access fault when running MLA decode with nhead=32 (DeepSeek-V3.2 at TP4) in non-persistent mode on MI355X (gfx950).

Root Cause

Commit c849fd5 ("Add bf16 MLA decode kernel for gqa_ratio=64, qseqlen=1 (non-persistent)" #2729) zeroed ptr_RP and out_16_nosplit for all non-persistent dispatch. However, the legacy QH16 ASM kernel (MLA_A16W16_1TG_4W_32mx1_16nx1_Coex0_Msk1_QH16.co) used for nhead=32 still writes directly to the output buffer via ptr_RP when kv_split==1. Dereferencing nullptr causes:

Memory access fault by GPU node-X on address 0xNNNNNN. Reason: Write access to a read-only page.

This crashes during CUDA graph capture (decode, FULL).
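A minimal Python sketch of the guard the fix introduces (names and the nhead-based condition here are illustrative; the actual C++ change in asm_mla.cu keys off gqa_ratio * max_seqlen_q):

```python
# Sketch of the dispatch guard this PR adds (illustrative names only;
# the real change lives in csrc/py_itfs_cu/asm_mla.cu).
NULLPTR = 0  # stand-in for C++ nullptr

def nonpersistent_args(out_ptr, nhead, qseqlen=1):
    # Assumption: the legacy QH16/QH64 ASM kernels are the nhead 32/64,
    # qseqlen=1 variants; they write the final output through ptr_RP
    # when kv_split == 1, so the real pointer must be kept for them.
    if nhead in (32, 64) and qseqlen == 1:
        return {"ptr_RP": out_ptr, "out_16_nosplit": 1}
    # Newer kernels ignore these fields, so nullptr/0 stay correct there.
    return {"ptr_RP": NULLPTR, "out_16_nosplit": 0}
```

Before the patch the `NULLPTR` branch was taken unconditionally, which is exactly the dereference the fault message above reports.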

Fix

C++ (csrc/py_itfs_cu/asm_mla.cu):

  • Conditionally restore ptr_RP and out_16_nosplit in the non-persistent path for legacy kernels (gqa_ratio * max_seqlen_q <= 64), while keeping nullptr for newer kernels (e.g. gqa_ratio=64).

Python (aiter/mla.py):

  • Restore the bf16 nhead in [32, 64] early-return after stage1 when num_kv_splits==1. Without this, stage2 overwrites the kernel's direct output with garbage from the uninitialized split buffer.

Both changes match the behavior from v0.1.11 for the affected code paths.
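The restored Python-side early-return can be sketched as follows (hypothetical simplified control flow, not the real aiter/mla.py signature):

```python
def mla_decode_sketch(nhead, num_kv_splits, dtype="bf16"):
    """Illustrative control flow only -- not the actual aiter API."""
    run_stage1 = lambda: "stage1-direct-output"
    run_stage2 = lambda: "stage2-combined-output"

    out = run_stage1()
    # Restored early-return: for bf16 with nhead in [32, 64] and a
    # single kv split, the ASM kernel already wrote the final output
    # via ptr_RP, so stage2 must be skipped -- otherwise it would
    # overwrite the result with the uninitialized split buffer.
    if dtype == "bf16" and nhead in (32, 64) and num_kv_splits == 1:
        return out
    return run_stage2()
```

With more than one kv split, stage2 still runs and combines the partial results as before.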

Test

  • MI355X (gfx950), TP4, DeepSeek-V3.2
  • No crash during CUDA graph capture
  • GSM8K accuracy correct (0.94+)

Commit c849fd5 ("Add bf16 MLA decode kernel for gqa_ratio=64,
qseqlen=1 (non-persistent)") zeroed ptr_RP and out_16_nosplit for all
non-persistent dispatch. The legacy QH16 ASM kernel used for nhead=32
(MLA_A16W16_1TG_4W_32mx1_16nx1_Coex0_Msk1_QH16.co) still writes
directly to the output buffer via ptr_RP when kv_split==1.
Dereferencing nullptr causes a GPU memory access fault during CUDA
graph capture on MI355X (gfx950) with DeepSeek-V3.2 at TP4.

Fix:
- Conditionally restore ptr_RP and out_16_nosplit in the non-persistent
  path for legacy kernels (gqa_ratio * max_seqlen_q <= 64) while
  keeping nullptr for newer kernels (e.g. gqa_ratio=64).
- Restore the bf16 nhead in [32,64] early-return after stage1 when
  num_kv_splits==1 to prevent stage2 from overwriting the kernel's
  direct output.

Tested on MI355X TP4 with deepseek-ai/DeepSeek-V3.2 (nhead=32):
- No crash during CUDA graph capture
- Correct GSM8K accuracy

Made-with: Cursor
@frida-andersson frida-andersson requested a review from a team April 30, 2026 14:38
@github-actions
Contributor

🏷️ CI Guide

Runs automatically on every PR:

  • ✅ Pre-checks (submodule verification, code formatting)
  • ✅ Aiter op tests (gfx942 + gfx950)
  • ✅ Triton tests on MI35X (only when aiter/ops/triton/** or related paths are changed)

Extended tests (opt-in via labels):

  • ci:triton-300x: Run an additional Triton test job on MI300X in PRs (the main branch always runs both MI35X and MI300X)
  • ci:sglang: SGLang integration tests
  • ci:atom: ATOM benchmark (DeepSeek-R1 + GPT-OSS)
  • ci:vllm: vLLM benchmark
  • ci:all: All of the above

Add labels via the sidebar or gh pr edit 2983 --add-label <label>

args.out_16_nosplit = 0;
args.ptr_RP = nullptr;
// Legacy QH16 ASM kernels (nhead=32/64, qseqlen=1) write directly to
// output via ptr_RP when kv_split==1. Passing nullptr causes a GPU
// memory access fault.

It really sounds to me like we are missing a test for this scenario in aiter. There should be a more transparent way of distinguishing kernels that write these buffers than by the number of QHs.

@ChuanLi1101

ChuanLi1101 commented May 3, 2026

Independently verified on MI355X (gfx950) inside rocm/atom-dev:vllm-v0.19.0-nightly_20260422 (ROCm 7.2.2, torch 2.10.0+rocm7.2.2.git40d237bf, aiter HEAD 9522c04).

Repro path — exactly the failure mode this PR describes:

  • nhead=32, bf16/bf16, num_kv_splits=1, decode_qlen=1, non-persistent (no work_meta_data)
  • BEFORE the patch: Memory access fault by GPU node-2 ... address (nil) on the first call into mla_a16w16_qh16_m32x1_n16x1_coex0_mask1. Process aborts (exit 134).
  • AFTER the patch (after rebuilding module_mla_asm so the asm_mla.cu change is actually picked up): 12/12 cases pass for nhead in {32, 64} x ctx in {256, 1024} x bs in {1, 4, 16}, all numerically match the torch reference at atol=rtol=1e-2, latency 13-67 us. Same kernel as before, just with a non-null ptr_RP.
pr2983 nhead=32 ctx=256  bs=1  splits=1:    21.02 us  passed
pr2983 nhead=32 ctx=256  bs=4  splits=1:    21.82 us  passed
pr2983 nhead=32 ctx=256  bs=16 splits=1:    22.81 us  passed
pr2983 nhead=32 ctx=1024 bs=1  splits=1:    65.40 us  passed
pr2983 nhead=32 ctx=1024 bs=4  splits=1:    65.70 us  passed
pr2983 nhead=32 ctx=1024 bs=16 splits=1:    67.53 us  passed
pr2983 nhead=64 ctx=256  bs=1  splits=1:    13.28 us  passed
pr2983 nhead=64 ctx=256  bs=4  splits=1:    13.95 us  passed
pr2983 nhead=64 ctx=256  bs=16 splits=1:    15.11 us  passed
pr2983 nhead=64 ctx=1024 bs=1  splits=1:    35.15 us  passed
pr2983 nhead=64 ctx=1024 bs=4  splits=1:    36.00 us  passed
pr2983 nhead=64 ctx=1024 bs=16 splits=1:    38.28 us  passed
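The 12-case sweep above can be generated with a simple grid like this (hypothetical harness; the real check compares each case against a torch reference at atol=rtol=1e-2):

```python
import itertools

def case_grid():
    # nhead in {32, 64} x ctx in {256, 1024} x bs in {1, 4, 16},
    # all with num_kv_splits=1 -- the configuration this PR fixes.
    for nhead, ctx, bs in itertools.product((32, 64), (256, 1024), (1, 4, 16)):
        yield {"nhead": nhead, "ctx": ctx, "bs": bs, "num_kv_splits": 1}

cases = list(case_grid())
```

Each yielded case would be run through the decode path and compared against the reference implementation.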

The diff is narrowly scoped: 19 lines / 2 files, only the non-persistent host wrapper (asm_mla.cu) and the matching nhead in [32, 64] branch in the python entrypoint. No kernel change, no impact on the persistent / sparse / unified-attention paths.

Could a maintainer apply the ready label so review CI can run? @sunway513 — given the v0.1.13 / vLLM 0.21 freeze, this looks like a low-risk addition to the #3005 bulk merge.

Note for anyone reproducing locally: AITER_REBUILD=1 alone won't pick up the asm_mla.cu change — mla_decode_stage1_asm_fwd uses the ctypes ffi path which only checks for the .so's existence. To force a rebuild after applying the patch:

rm -f aiter/jit/module_mla_asm.so
rm -rf aiter/jit/build/module_mla_asm
python op_tests/test_mla_nhead32_regression.py

sunway513 added a commit that referenced this pull request May 3, 2026
Reverting cherry-pick of #2983 from this bulk merge. The MLA nhead=32
non-persistent decode fix causes deterministic test_mla k_cache and
mla_decode-absorb precision failures on CI MI35X runners (Shard 1 & 2).

#2983 should go through its own PR with proper CI validation by the
original author (frida-andersson).
