[MLA] Fix nhead=32 non-persistent decode crash on gfx950#2983
[MLA] Fix nhead=32 non-persistent decode crash on gfx950#2983frida-andersson wants to merge 1 commit intoROCm:mainfrom
Conversation
Commit c849fd5 ("Add bf16 MLA decode kernel for gqa_ratio=64, qseqlen=1 (non-persistent)") zeroed ptr_RP and out_16_nosplit for all non-persistent dispatch. The legacy QH16 ASM kernel used for nhead=32 (MLA_A16W16_1TG_4W_32mx1_16nx1_Coex0_Msk1_QH16.co) still writes directly to the output buffer via ptr_RP when kv_split==1. Dereferencing nullptr causes a GPU memory access fault during CUDA graph capture on MI355X (gfx950) with DeepSeek-V3.2 at TP4. Fix: - Conditionally restore ptr_RP and out_16_nosplit in the non-persistent path for legacy kernels (gqa_ratio * max_seqlen_q <= 64) while keeping nullptr for newer kernels (e.g. gqa_ratio=64). - Restore the bf16 nhead in [32,64] early-return after stage1 when num_kv_splits==1 to prevent stage2 from overwriting the kernel's direct output. Tested on MI355X TP4 with deepseek-ai/DeepSeek-V3.2 (nhead=32): - No crash during CUDA graph capture - Correct GSM8K accuracy Made-with: Cursor
🏷️ CI GuideRuns automatically on every PR:
Extended tests (opt-in via labels):
|
| args.out_16_nosplit = 0; | ||
| args.ptr_RP = nullptr; | ||
| // Legacy QH16 ASM kernels (nhead=32/64, qseqlen=1) write directly to | ||
| // output via ptr_RP when kv_split==1. Passing nullptr causes GPU |
There was a problem hiding this comment.
It really sounds to me like we are missing a test for this scenario in aiter. There should be a more transparent way of distinguishing kernels that write these but by the number of QHs.
|
Independently verified on MI355X (gfx950) inside Repro path — exactly the failure mode this PR describes:
The diff is narrowly scoped: 19 lines / 2 files, only the non-persistent host wrapper ( Could a maintainer apply the Note for anyone reproducing locally: rm -f aiter/jit/module_mla_asm.so
rm -rf aiter/jit/build/module_mla_asm
python op_tests/test_mla_nhead32_regression.py |
Reverting cherry-pick of #2983 from this bulk merge. The MLA nhead=32 non-persistent decode fix causes deterministic test_mla k_cache and mla_decode-absorb precision failures on CI MI35X runners (Shard 1 & 2). #2983 should go through its own PR with proper CI validation by the original author (frida-andersson).
Summary
Fixes a GPU memory access fault when running MLA decode with
nhead=32(DeepSeek-V3.2 at TP4) in non-persistent mode on MI355X (gfx950).Root Cause
Commit c849fd5 ("Add bf16 MLA decode kernel for gqa_ratio=64, qseqlen=1 (non-persistent)" #2729) zeroed
ptr_RPandout_16_nosplitfor all non-persistent dispatch. However, the legacy QH16 ASM kernel (MLA_A16W16_1TG_4W_32mx1_16nx1_Coex0_Msk1_QH16.co) used for nhead=32 still writes directly to the output buffer viaptr_RPwhenkv_split==1. Dereferencing nullptr causes:This crashes during CUDA graph capture (decode, FULL).
Fix
C++ (
csrc/py_itfs_cu/asm_mla.cu):ptr_RPandout_16_nosplitfor legacy kernels (gqa_ratio * max_seqlen_q <= 64) while keeping them as nullptr/0 for newer kernels (e.g. the gqa_ratio=64 kernel added in Add bf16 MLA decode kernel for gqa_ratio=64, qseqlen=1 (non-persistent) #2729).Python (
aiter/mla.py):nhead in [32, 64]early-return after stage1 whennum_kv_splits==1. Without this, stage2 overwrites the kernel's direct output with garbage from the uninitialized split buffer.Both changes match the behavior from v0.1.11 for the affected code paths.
Test