[Bug] pa_mqa_logits: mask OOB stores on OutLogits_buffer (fix MI355X DSV32 MTP MAF)#2866
Conversation
Pull request overview
Note
Copilot was unable to run its full agentic suite in this review.
Fixes intermittent GPU memory access faults in Gluon pa_mqa_logits preshuffle Triton kernels by preventing out-of-bounds writes to OutLogits_buffer on MI355X (gfx950) full-context decode paths.
Changes:
- Add missing upper-bound lane masks (`< max_model_len`) to 10 `buffer_store(ptr=OutLogits_buffer, ...)` call sites.
- Apply the same fix to both fixed-context and varctx preshuffle kernels.
```
    context_idx
    + gl.arange(0, ChunkKPerStage, layout=gl.SliceLayout(0, mfma_layout))
)
< max_model_len,
```
These new masks only enforce an upper bound. If context_idx (or derived column offsets) can ever be negative in this code path, offset < max_model_len will still be true and may allow negative column offsets (underflow) to be stored. To make the store guard symmetric and consistent with other masked stores in this file, consider AND-ing in the appropriate lower-bound predicate (e.g., >= 0 or >= split_context_start, whichever matches the logical indexing at this site).
Current mask (upper bound only):

```
    context_idx
    + gl.arange(0, ChunkKPerStage, layout=gl.SliceLayout(0, mfma_layout))
)
< max_model_len,
```

Suggested change (add the symmetric lower bound):

```
(
    (
        context_idx
        + gl.arange(
            0, ChunkKPerStage, layout=gl.SliceLayout(0, mfma_layout)
        )
    )
    >= 0
)
& (
    (
        context_idx
        + gl.arange(
            0, ChunkKPerStage, layout=gl.SliceLayout(0, mfma_layout)
        )
    )
    < max_model_len
),
```
```
mask=(
    context_idx
    + ChunkK
    + gl.arange(
        0, ChunkKPerStage, layout=gl.SliceLayout(0, mfma_layout)
    )
)
< max_model_len,
```
The column-offset expression used for the mask is fairly long and is likely duplicated in the corresponding offsets= argument to buffer_store (same arithmetic + same gl.arange(...)). Consider computing this once into a local col_offsets (or similar) and reusing it for both offsets and mask. This reduces the risk of future divergence between the two expressions and typically helps readability (and can also avoid recomputing identical IR).
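A minimal numpy sketch of the refactor being suggested, computing the column-offset vector once and reusing it for both the offsets and the mask. This is a stand-in only: `np.arange` plays the role of `gl.arange`, array indexing plays the role of `buffer_store`, and the helper `store_chunk` is hypothetical, not a function in the kernel.

```python
import numpy as np

def store_chunk(out_row, values, context_idx, chunk_k_per_stage, max_model_len):
    # Compute the column-offset vector once...
    col_offsets = context_idx + np.arange(chunk_k_per_stage)
    # ...then reuse it for both the OOB mask and the store offsets,
    # so the two expressions cannot drift apart in future edits.
    mask = col_offsets < max_model_len
    out_row[col_offsets[mask]] = values[mask]

# Toy reproduction of the row-end case: the last chunk straddles the row end,
# so only the in-row half of the lanes is written.
out = np.full(2048, np.nan, dtype=np.float32)
vals = np.ones(128, dtype=np.float32)
store_chunk(out, vals, context_idx=1984, chunk_k_per_stage=128, max_model_len=2048)
```

In the actual kernel the same pattern would hoist the `context_idx + gl.arange(...)` expression into a local and pass it to both `offsets=` and `mask=`.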
Cross-reference: the vLLM-side companion draft PR is up at vllm-project/vllm#40643. That PR (1) adopts the new `deepgemm_fp8_paged_mqa_logits` API on the DSv3.2 decode path, (2) caches the output buffer across the 61 decode layers, and (3) over-allocates the logits row stride by +256 float32 columns as defense-in-depth against the OOB store fixed here. Once this PR lands and a new aiter wheel is released, the +256 padding in vLLM can be set to 0 with zero functional change.
Force-pushed from b946416 to e769aa2.
The gluon `_gluon_deepgemm_fp8_paged_mqa_logits_preshuffle` and `_gluon_deepgemm_fp8_paged_mqa_logits_preshuffle_varctx` kernels have 10 `buffer_store(ptr=OutLogits_buffer, ...)` call sites that are missing the upper-bound mask present on their sibling stores. When `context_length == max_model_len` (the last-token position in a long-context decode step), `split_context_length` is rounded UP to a `KVBlockSize` multiple at line 427 and the final prefix/suffix store then writes up to `ChunkKPerStage` float32 elements past the logical row end. With `stride_out_batch == max_model_len`, those writes cross into the next row / the next allocation, causing intermittent HIP memory-access faults on gfx950 during DeepSeek V3.2 MTP decoding.

This change adds `mask=<offset> < max_model_len` to every unmasked `buffer_store` on `OutLogits_buffer` in both preshuffle kernels, matching the pattern of their already-masked neighbours. The existing `tl.where(..., -inf)` masking of the *values* is preserved; the only behavioural change is that out-of-row lanes no longer emit buffer stores.

Hardware overhead is negligible: `buffer_store` with a predicate is the same SMEM descriptor path as the unmasked variant, just with a VCC mask setup.

Repro + end-to-end fix evidence: see PR description.

Signed-off-by: Markus Hartikainen <markus.hartikainen@amd.com>
Force-pushed from e769aa2 to 275ceff.
Summary

Fix 10 unmasked `buffer_store(ptr=OutLogits_buffer, ...)` sites in the Gluon `pa_mqa_logits` preshuffle kernels that can overshoot the allocated `out_logits` row on DeepSeek V3.2 MTP decodes, producing intermittent HIP memory access faults on gfx950 (MI355X).

All 10 sites receive the minimal guard `mask=<row-offset> < max_model_len`, matching the pattern of the adjacent (already-masked) stores in the same kernels.
Root cause

Where the bug lives

`aiter/ops/triton/gluon/pa_mqa_logits.py`, functions `_gluon_deepgemm_fp8_paged_mqa_logits_preshuffle` and `_gluon_deepgemm_fp8_paged_mqa_logits_preshuffle_varctx`.

Each kernel writes per-(batch, next_n) rows into `OutLogits_buffer` of shape `(batch_size * next_n, max_model_len)` with `stride_out_batch == max_model_len`.

Each `buffer_store(ptr=OutLogits_buffer, offsets=row*stride + col_expr)` is valid only when `col_expr < max_model_len`; otherwise the write crosses into the next row or past the tensor entirely.
Why `col_expr` can exceed `max_model_len`

`split_context_length` is rounded up to a `KVBlockSize` multiple when `LoadBlockIndiceForEachStage` is true (see `aiter/ops/triton/gluon/pa_mqa_logits.py`):

Combined with `split_context_start + split_context_length <= max_model_len` when `context_length < max_model_len`, this works. But when `context_length == max_model_len` (a full-context decode step) the rounded `split_context_length` cannot grow and the same invariant still holds at the split boundary — however the intra-split chunk index `context_idx + ChunkKPerStage` advances past the logical end. The 10 unmasked stores all write at offsets of the form `context_idx + {ChunkK | ChunkKPerStage} + arange(ChunkKPerStage)`, which at the last iteration of the prefix loop equals `split_context_start + split_context_length`, i.e. exactly at or just past the row end.

Concretely, for MI355X + DSV32 MTP at `max_model_len=2048`, `KVBlockSize=64`, `ChunkK=256`, `ChunkKPerStage=128`: the overshoot is up to `ChunkKPerStage = 128` float32 elements = 512 B past the row. With `stride_out_batch = 2048`, those 512 B land in the next row if one exists, or in unrelated memory for the last `(pid_batch, pid_next_n)` program — which is what faults.
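A back-of-envelope check of that overshoot arithmetic, using the MI355X/DSV32 parameters above. This is pure arithmetic in plain Python; the variable names mirror the kernel's, but the loop structure is collapsed to the single worst-case iteration.

```python
max_model_len  = 2048
KVBlockSize    = 64
ChunkKPerStage = 128

context_length = max_model_len  # full-context decode step

# split_context_length is rounded UP to a KVBlockSize multiple
split_context_length = -(-context_length // KVBlockSize) * KVBlockSize

# On the last prefix iteration the unmasked store covers
# context_idx + ChunkKPerStage + arange(ChunkKPerStage); its highest lane:
context_idx = split_context_length - ChunkKPerStage
worst_lane  = context_idx + ChunkKPerStage + (ChunkKPerStage - 1)

overshoot_elems = worst_lane - (max_model_len - 1)
overshoot_bytes = overshoot_elems * 4  # float32
print(overshoot_elems, overshoot_bytes)  # 128 elements, 512 bytes past the row
```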
Why the sibling stores already have the fix

Look at the existing masked store at line 274:

and at line 651 (in `_preshuffle`):

The pattern is clearly "mask the store when the offset is out of the logical range". The author just missed the symmetric upper bound on 10 sites. This PR adds it.
List of fixed call sites

In `_gluon_deepgemm_fp8_paged_mqa_logits_preshuffle` (5 sites):

- store with a `<= context_length-next_n+...` value mask but no store mask
- store with `tl.where(mask, ..., -inf)` applied to values but store unmasked
- the `+ ChunkKPerStage` store
- the `+ ChunkK` store, `tl.where(...)` applied, store unmasked
- the `+ ChunkKPerStage` store, `tl.where(...)` applied, store unmasked

In `_gluon_deepgemm_fp8_paged_mqa_logits_preshuffle_varctx` (5 sites): the analogous stores in the varctx kernel.

After this PR, all 18 `buffer_store(ptr=OutLogits_buffer, ...)` sites in these two kernels are masked.
Why the mask choice `< max_model_len` (and not `<= context_length-...`)

There are two equally-correct bounds:

1. `offset < max_model_len`. Matches the allocation. The minimum mask that guarantees no OOB write.
2. `offset <= context_length - next_n + pid_next_n`. Stricter; also masks the in-row lanes past the logical sequence end, where 5 of the 10 sites already store `-inf` via a preceding `tl.where(mask, logits, -inf)`.

I chose (1) because:

- Where a preceding `tl.where` is applied, the values beyond the logical end are already `-inf`; writing them is harmless because `large_context_topk` respects `seq_lens` and ignores them. So the only new behaviour with (1) is "don't write OOB bytes past the allocation" — which is exactly the bug fix and nothing more.
- `max_model_len` is already a kernel argument; `context_length` is also loaded, but using `max_model_len` makes the mask a compile-time stride check against the tensor, which is trivial to verify correct at code-review time.
The stricter mask (2) is also acceptable and is arguably cleaner
semantically; happy to switch if reviewers prefer.
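To make the two layers of masking concrete, here is a small numpy stand-in (toy sizes; `gl.arange`/`buffer_store` replaced by array indexing, and `logical_end`/`chunk` are illustrative values, not kernel names). The value mask sets out-of-sequence lanes to `-inf`, while the store mask stops out-of-row lanes from writing at all:

```python
import numpy as np

max_model_len = 16
logical_end   = 13   # stands in for context_length - next_n + pid_next_n
context_idx   = 12
chunk         = 8    # lanes per store

col    = context_idx + np.arange(chunk)   # 12..19: straddles the row end
logits = np.arange(chunk, dtype=np.float32)

# Value mask (already in the kernel): out-of-sequence lanes become -inf
vals = np.where(col <= logical_end, logits, -np.inf)

# Store mask (this PR): lanes past the allocated row emit no write at all
store_mask = col < max_model_len
row = np.zeros(max_model_len, dtype=np.float32)
row[col[store_mask]] = vals[store_mask]
```

With choice (1), in-row lanes past the logical end still get written (as `-inf`, which the top-k ignores); only the physically out-of-row lanes are suppressed.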
Evidence

Before this PR (baseline)

DSV32-MTP1 decode with concurrency=4 on MI355X reproduces the MAF deterministically within ~60 s of first trial, with `AMD_SERIALIZE_KERNEL=3 AMD_LOG_LEVEL=4 HIP_LAUNCH_BLOCKING=1`:

(Log excerpts in `DSV32_MTP_ANALYSIS.md` V5.1 entry.)

After this PR (equivalent vLLM-side workaround)
We validated the fix via an equivalent workaround in vLLM (`_get_paged_logits_buffer` over-allocates the output tensor by +256 columns so that the OOB stores land in padding). This is equivalent in effect to the in-kernel mask because both prevent the OOB address access; we use it to prove the kernel's unmasked stores are indeed the only problematic writes. 20/20 c=4 MTP=1 trials passed, each running ~30 s of sustained decode with realistic prompts:

Zero memory access faults, zero engine wedges, stable spec-decode acceptance rate across all 20 runs. (Raw per-trial JSON + logs available on request.)
Why this validates the kernel patch

The vLLM workaround works by growing the buffer row stride to `cols+256`, so that the same in-kernel offsets that were writing past `max_model_len` now land in the +256 padding — i.e. they're still OOB with respect to the logical row but in-bounds with respect to the allocation. This PR instead keeps the allocation as-is but masks the stores that exceed the logical row. Both prevent the HIP MAF; this PR is the proper in-kernel fix.
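The equivalence can be sketched in a few lines of numpy (toy stand-in: `PAD=256` mirrors the vLLM workaround, `overshoot=128` is the worst case computed earlier, and unmasked stores are modeled as plain array writes):

```python
import numpy as np

max_model_len = 2048
PAD       = 256   # vLLM workaround: over-allocate by 256 columns
overshoot = 128   # worst-case lanes past the row end (ChunkKPerStage)

# Offsets straddling the row end, as emitted by the last prefix iteration
col = max_model_len - overshoot + np.arange(2 * overshoot)

# Workaround: bigger allocation, unmasked store lands in padding
padded = np.zeros(max_model_len + PAD, dtype=np.float32)
padded[col] = 1.0   # still OOB logically, but in-bounds physically

# This PR: original allocation, masked store
exact = np.zeros(max_model_len, dtype=np.float32)
m = col < max_model_len
exact[col[m]] = 1.0

# The logical row contents are identical either way
assert (padded[:max_model_len] == exact).all()
```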
Hardware overhead

`buffer_store` with a lane-predicate mask compiles to the same `BUFFER_STORE_*` instruction as the unmasked variant on gfx950, with VCC set from the mask expression. The mask expression is purely a cheap integer comparison of already-computed quantities. Expected overhead: a few VALU cycles per store, negligible relative to the full kernel cost.
Tests

No aiter unit tests cover the `max_model_len == context_length` corner case for preshuffle; I'll happily add a regression test in a follow-up if reviewers want one. The patch is structurally equivalent to the already-masked sibling stores in the same kernels.
Context

This is part of enabling DeepSeek V3.2 MTP (speculative decoding with `num_speculative_tokens=1`) on MI355X. See `vllm/v1/attention/ops/rocm_aiter_mla_sparse.py::_get_paged_logits_buffer` for the vLLM caller. A companion vLLM draft PR documents the same investigation and links to this one as the preferred fix; it keeps a temporary buffer over-allocation as defense-in-depth until this PR is merged and rolled into an aiter release.