
[Bug] pa_mqa_logits: mask OOB stores on OutLogits_buffer (fix MI355X DSV32 MTP MAF)#2866

Open
maeehart wants to merge 1 commit into ROCm:main from maeehart:fix/mqa-logits-oob-store-mask

Conversation

@maeehart

Summary

Fix 10 unmasked buffer_store(ptr=OutLogits_buffer, ...) sites in the Gluon
pa_mqa_logits preshuffle kernels that can overshoot the allocated
out_logits row on DeepSeek V3.2 MTP decodes, producing intermittent HIP
memory access faults on gfx950 (MI355X).

All 10 sites receive the minimal guard mask=<row-offset> < max_model_len,
matching the pattern of the adjacent (already-masked) stores in the same
kernels.

Root cause

Where the bug lives

aiter/ops/triton/gluon/pa_mqa_logits.py — functions
_gluon_deepgemm_fp8_paged_mqa_logits_preshuffle and
_gluon_deepgemm_fp8_paged_mqa_logits_preshuffle_varctx.

Each kernel writes per-(batch, next_n) rows into OutLogits_buffer of shape
(batch_size * next_n, max_model_len) with stride_out_batch == max_model_len.
Each buffer_store(ptr=OutLogits_buffer, offsets=row*stride + col_expr) is
valid only when col_expr < max_model_len; otherwise the write crosses into
the next row or past the tensor entirely.
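
To make the bound concrete, here is a host-side sketch of the flat-offset
arithmetic each store performs (plain Python, not the kernel; names mirror
the kernel arguments):

def out_logits_offset(row, col_expr, max_model_len):
    stride_out_batch = max_model_len           # row stride == row width
    return row * stride_out_batch + col_expr   # flat element offset

def store_is_safe(row, col_expr, num_rows, max_model_len):
    # The column must stay inside the logical row. If it does not, the flat
    # offset lands in the next row, or past the allocation entirely when
    # `row` belongs to the last (pid_batch, pid_next_n) program.
    in_row = 0 <= col_expr < max_model_len
    in_alloc = out_logits_offset(row, col_expr, max_model_len) < num_rows * max_model_len
    return in_row and in_alloc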

Why col_expr can exceed max_model_len

split_context_length is rounded up to a KVBlockSize multiple when
LoadBlockIndiceForEachStage is true (see
aiter/ops/triton/gluon/pa_mqa_logits.py):

if LoadBlockIndiceForEachStage:
    split_context_block = tl.cdiv(split_context_length, KVBlockSize)
    split_context_length = split_context_block * KVBlockSize   # ↑ round-up

When context_length < max_model_len, the invariant
split_context_start + split_context_length <= max_model_len leaves
headroom after the round-up, so every store stays inside the row. But when
context_length == max_model_len (a full-context decode step), the rounded
split_context_length runs exactly to the row end: the invariant still
holds at the split boundary, yet the intra-split chunk index
context_idx + ChunkKPerStage advances past the logical end. The 10
unmasked stores all write at offsets of the form
context_idx + {ChunkK | ChunkKPerStage} + arange(ChunkKPerStage), whose
base at the last iteration of the prefix loop equals
split_context_start + split_context_length; the store therefore begins
exactly at the row end and its lanes run past it.

Concretely, for MI355X + DSV32 MTP at max_model_len=2048,
KVBlockSize=64, ChunkK=256, ChunkKPerStage=128: the overshoot is up to
ChunkKPerStage = 128 float32 elements = 512 B past the row. With
stride_out_batch = 2048, those 512 B land in the next row if one exists,
or in unrelated memory for the last (pid_batch, pid_next_n) program —
which is what faults.
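
These figures can be re-derived in a few lines of plain Python (an
illustrative recomputation, assuming the last prefix iteration leaves
context_idx one ChunkKPerStage short of the split end):

# Recompute the MI355X + DSV32 MTP overshoot figures from above.
max_model_len, KVBlockSize, ChunkKPerStage = 2048, 64, 128

split_context_start = 0
split_context_length = max_model_len                            # full-context decode
split_context_block = -(-split_context_length // KVBlockSize)   # tl.cdiv
split_context_length = split_context_block * KVBlockSize        # still 2048 here

context_idx = split_context_start + split_context_length - ChunkKPerStage
first_oob_col = context_idx + ChunkKPerStage                    # == 2048 == row end
overshoot_bytes = ChunkKPerStage * 4                            # 128 fp32 = 512 B
print(first_oob_col, overshoot_bytes)                           # -> 2048 512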

Why the sibling stores already have the fix

Look at the existing masked store at line 274:

gl.amd.cdna3.buffer_store(
    logits,
    ptr=OutLogits_buffer,
    offsets=... + (context_idx + gl.arange(0, ChunkK, ...)),
    mask=context_idx + gl.arange(0, ChunkK, ...) >= 0,  # lower-bound
)

and at line 651 (in _preshuffle):

gl.amd.cdna3.buffer_store(
    logits,
    ptr=OutLogits_buffer,
    offsets=... + (context_idx + ChunkKPerStage + gl.arange(...)),
    mask=context_idx + ChunkKPerStage + gl.arange(...) >= split_context_start,
)

The pattern is clearly "mask the store when the offset is out of the
logical range". The author just missed the symmetric upper bound on 10
sites. This PR adds it.
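
For reference, each patched site takes the same shape as its masked
siblings, with the upper bound added (representative sketch in the excerpt
style above; the exact offset expression differs per site):

gl.amd.cdna3.buffer_store(
    logits,
    ptr=OutLogits_buffer,
    offsets=... + (context_idx + ChunkKPerStage + gl.arange(...)),
    mask=(context_idx + ChunkKPerStage + gl.arange(...)) < max_model_len,  # new
)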

List of fixed call sites

| File line (pre-patch) | Function | Sibling/type |
| --- | --- | --- |
| 723 | _gluon_deepgemm_fp8_paged_mqa_logits_preshuffle | prefix loop, first store; partner at L759 already masked (<= context_length - next_n + ...) |
| 916 | _gluon_deepgemm_fp8_paged_mqa_logits_preshuffle | tl.where(mask, ..., -inf) applied to values but store unmasked |
| 988 | _gluon_deepgemm_fp8_paged_mqa_logits_preshuffle | chunk-loop + ChunkKPerStage store |
| 1062 | _gluon_deepgemm_fp8_paged_mqa_logits_preshuffle | chunk-loop + ChunkK store, tl.where(...) applied, store unmasked |
| 1097 | _gluon_deepgemm_fp8_paged_mqa_logits_preshuffle | final + ChunkKPerStage, tl.where(...) applied, store unmasked |
| 1534 | _gluon_deepgemm_fp8_paged_mqa_logits_preshuffle_varctx | mirror of L723 in the varctx kernel |
| 1727 | _gluon_deepgemm_fp8_paged_mqa_logits_preshuffle_varctx | mirror of L916 |
| 1799 | _gluon_deepgemm_fp8_paged_mqa_logits_preshuffle_varctx | mirror of L988 |
| 1873 | _gluon_deepgemm_fp8_paged_mqa_logits_preshuffle_varctx | mirror of L1062 |
| 1908 | _gluon_deepgemm_fp8_paged_mqa_logits_preshuffle_varctx | mirror of L1097 |

After this PR, all 18 buffer_store(ptr=OutLogits_buffer, ...) sites in
these two kernels are masked.

Why the mask choice < max_model_len (and not <= context_length-...)

There are two bounds, either of which prevents the OOB write:

  1. Buffer bound: offset < max_model_len. Matches the allocation. The
    minimum mask that guarantees no OOB write.
  2. Logical bound: offset <= context_length - next_n + pid_next_n.
    Stricter; also masks the in-row lanes past the logical sequence end,
    where 5 of the 10 sites already store -inf via a preceding
    tl.where(mask, logits, -inf).
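
Written out against a store's column offset col, the two candidates are
(a sketch reusing the kernel's argument names):

mask_buffer  = col < max_model_len                          # bound (1)
mask_logical = col <= context_length - next_n + pid_next_n  # bound (2)
# Since context_length <= max_model_len, (2) implies (1): the logical mask
# is strictly tighter, and (1) is the weakest predicate that still keeps
# every lane inside the allocation.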

I chose (1) because:

  • Minimal behavioural diff. For the 5 sites with tl.where, the values
    beyond the logical end are already -inf; writing them is harmless
    because large_context_topk respects seq_lens and ignores them. So the
    only new behaviour with (1) is "don't write OOB bytes past the
    allocation" — which is exactly the bug fix and nothing more.
  • No new kernel state. max_model_len is already a kernel argument;
    context_length is also loaded, but using max_model_len makes the mask
    a compile-time stride check against the tensor, which is trivial to
    verify correct at code-review time.
  • Easier to reason about for reviewers inspecting this file a year from
    now.

The stricter mask (2) is also acceptable and is arguably cleaner
semantically; happy to switch if reviewers prefer.

Evidence

Before this PR (baseline)

DSV32-MTP1 decode with concurrency=4 on MI355X reproduces the MAF
deterministically within ~60 s of the first trial, with
AMD_SERIALIZE_KERNEL=3 AMD_LOG_LEVEL=4 HIP_LAUNCH_BLOCKING=1:

Memory access fault by GPU node-X (Agent handle: 0x...) on address 0x...
Reason: Unknown.
...last kernel launched on all three faulting workers:
_gluon_deepgemm_fp8_paged_mqa_logits_preshuffle

(Log excerpts in DSV32_MTP_ANALYSIS.md V5.1 entry.)

After this PR (equivalent vLLM-side workaround)

We validated the fix via an equivalent workaround in vLLM
(_get_paged_logits_buffer over-allocates the output tensor by +256
columns so that the OOB stores land in padding). This is equivalent in
effect to the in-kernel mask because both prevent the OOB address access;
we use it to prove the kernel's unmasked stores are indeed the only
problematic writes. 20/20 c=4 MTP=1 trials passed, each running ~30 s
of sustained decode with realistic prompts:

| Trial | Status | p50 latency | Avg draft acceptance |
| --- | --- | --- | --- |
| 1 | OK | 30.4 s | 0.84 |
| 2 | OK | 30.1 s | 0.83 |
| ... | OK | ... | ... |
| 20 | OK | 30.2 s | 0.85 |

Zero memory access faults, zero engine wedges, stable spec-decode acceptance
rate across all 20 runs. (Raw per-trial JSON + logs available on request.)

Why this validates the kernel patch

The vLLM workaround works by growing the buffer row stride to cols+256, so
the same in-kernel offsets that were writing past max_model_len now land
in the +256 padding: still out of bounds with respect to the logical row,
but in bounds with respect to the allocation. This PR instead keeps the
allocation as-is and masks the stores whose column offset reaches
max_model_len. Both prevent the HIP MAF; this PR is the proper in-kernel
fix.
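
For reference, here is a schematic of what the over-allocation does
(hypothetical sketch, not the actual vLLM code; see
_get_paged_logits_buffer for the real implementation):

import torch

PAD_COLS = 256  # > ChunkKPerStage (128), so any overshoot lands in padding

def padded_logits_buffer(batch_size, next_n, max_model_len, device="cuda"):
    buf = torch.empty(
        (batch_size * next_n, max_model_len + PAD_COLS),
        dtype=torch.float32, device=device,
    )
    # The kernel is handed a max_model_len-wide view whose row stride is the
    # padded width: OOB column offsets now resolve into the padding columns
    # instead of the next row or unmapped memory.
    return buf[:, :max_model_len]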

Hardware overhead

buffer_store with a lane-predicate mask compiles to the same BUFFER_STORE_*
instruction as the unmasked variant on gfx950, with VCC set from the mask
expression. The mask expression is purely a cheap integer comparison of
already-computed quantities. Expected overhead: a few VALU cycles per
store, negligible relative to the full kernel cost.

Tests

No aiter unit tests cover the max_model_len == context_length corner
case for preshuffle; I'll happily add a regression test in a follow-up if
reviewers want one. The patch is structurally equivalent to the
already-masked sibling stores in the same kernels.
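
As a starting point, the corner-case arithmetic itself is cheap to pin
down host-side; a possible seed for such a test (plain Python, no GPU
required):

def test_full_context_last_store_needs_upper_bound():
    max_model_len, KVBlockSize, ChunkKPerStage = 2048, 64, 128
    split_len = -(-max_model_len // KVBlockSize) * KVBlockSize  # tl.cdiv round-up
    context_idx = split_len - ChunkKPerStage                    # last prefix step
    cols = [context_idx + ChunkKPerStage + lane for lane in range(ChunkKPerStage)]
    # Unmasked, every lane of this store writes at or past the row end ...
    assert all(c >= max_model_len for c in cols)
    # ... and the new `col < max_model_len` guard disables all of them.
    assert sum(c < max_model_len for c in cols) == 0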

Context

This is part of enabling DeepSeek V3.2 MTP (speculative decoding with
num_speculative_tokens=1) on MI355X. See
vllm/v1/attention/ops/rocm_aiter_mla_sparse.py::_get_paged_logits_buffer
for the vLLM caller. A companion vLLM draft PR documents the same
investigation and links to this one as the preferred fix; it keeps a
temporary buffer over-allocation as defense-in-depth until this PR is
merged and rolled into an aiter release.

maeehart requested review from a team and Copilot · April 22, 2026 18:42
@github-actions
Contributor

🏷️ CI Guide

Runs automatically on every PR:

  • ✅ Pre-checks (submodule verification, code formatting)
  • ✅ Aiter op tests (gfx942 + gfx950)
  • ✅ Triton tests (only when aiter/ops/triton/** or related paths are changed)

Extended tests (opt-in via labels):

| Label | Tests |
| --- | --- |
| ci:triton-355 | Run Triton tests on MI355 in addition to MI325 |
| ci:sglang | SGLang integration tests |
| ci:atom | ATOM benchmark (DeepSeek-R1 + GPT-OSS) |
| ci:vllm | vLLM benchmark |
| ci:all | All of the above |

Add labels via the sidebar or gh pr edit 2866 --add-label <label>


Copilot AI left a comment


Pull request overview

Note: Copilot was unable to run its full agentic suite in this review.

Fixes intermittent GPU memory access faults in Gluon pa_mqa_logits preshuffle Triton kernels by preventing out-of-bounds writes to OutLogits_buffer on MI355X (gfx950) full-context decode paths.

Changes:

  • Add missing upper-bound lane masks (< max_model_len) to 10 buffer_store(ptr=OutLogits_buffer, ...) call sites.
  • Apply the same fix to both fixed-context and varctx preshuffle kernels.


Comment on lines +937 to +940
    context_idx
    + gl.arange(0, ChunkKPerStage, layout=gl.SliceLayout(0, mfma_layout))
)
< max_model_len,

(This hunk follows a mask=( opener in the surrounding code, so the parentheses balance in context.)

Copilot AI Apr 22, 2026


These new masks only enforce an upper bound. If context_idx (or derived column offsets) can ever be negative in this code path, offset < max_model_len will still be true and may allow negative column offsets (underflow) to be stored. To make the store guard symmetric and consistent with other masked stores in this file, consider AND-ing in the appropriate lower-bound predicate (e.g., >= 0 or >= split_context_start, whichever matches the logical indexing at this site).

Suggested change (the hunk again follows the mask=( opener, so parentheses balance in context):

-     context_idx
-     + gl.arange(0, ChunkKPerStage, layout=gl.SliceLayout(0, mfma_layout))
- )
- < max_model_len,
+     (
+         context_idx
+         + gl.arange(
+             0, ChunkKPerStage, layout=gl.SliceLayout(0, mfma_layout)
+         )
+     )
+     >= 0
+ )
+ & (
+     (
+         context_idx
+         + gl.arange(
+             0, ChunkKPerStage, layout=gl.SliceLayout(0, mfma_layout)
+         )
+     )
+     < max_model_len
+ ),

Comment on lines +738 to +745
mask=(
    context_idx
    + ChunkK
    + gl.arange(
        0, ChunkKPerStage, layout=gl.SliceLayout(0, mfma_layout)
    )
)
< max_model_len,

Copilot AI Apr 22, 2026


The column-offset expression used for the mask is fairly long and is likely duplicated in the corresponding offsets= argument to buffer_store (same arithmetic + same gl.arange(...)). Consider computing this once into a local col_offsets (or similar) and reusing it for both offsets and mask. This reduces the risk of future divergence between the two expressions and typically helps readability (and can also avoid recomputing identical IR).
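
A sketch of that hoist on the excerpt above (illustrative only; row_offset
stands in for the store's existing row-offset term):

col_offsets = (
    context_idx
    + ChunkK
    + gl.arange(0, ChunkKPerStage, layout=gl.SliceLayout(0, mfma_layout))
)
gl.amd.cdna3.buffer_store(
    logits,
    ptr=OutLogits_buffer,
    offsets=row_offset + col_offsets,
    mask=col_offsets < max_model_len,  # same tensor reused, so the two cannot diverge
)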

@maeehart
Author

Cross-reference: vLLM-side companion draft PR is up at vllm-project/vllm#40643. That PR (1) adopts the new `deepgemm_fp8_paged_mqa_logits` API on the DSv3.2 decode path, (2) caches the output buffer across the 61 decode layers, and (3) over-allocates the logits row stride by +256 float32 columns as defense-in-depth against the OOB store fixed here. Once this PR lands and a new aiter wheel is released, the +256 padding in vLLM can be set to 0 with zero functional change.

azaidy requested a review from cagrikymk · April 22, 2026 19:13
valarLip requested a review from sjfeng1999 · April 23, 2026 03:30
The gluon `_gluon_deepgemm_fp8_paged_mqa_logits_preshuffle` and
`_gluon_deepgemm_fp8_paged_mqa_logits_preshuffle_varctx` kernels have 10
`buffer_store(ptr=OutLogits_buffer, ...)` call sites that are missing the
upper-bound mask present on their sibling stores.  When
`context_length == max_model_len` (the last-token position in a long-
context decode step), `split_context_length` is rounded UP to a
`KVBlockSize` multiple at line 427 and the final prefix/suffix store then
writes up to `ChunkKPerStage` float32 elements past the logical row end.
With `stride_out_batch == max_model_len`, those writes cross into the
next row / the next allocation, causing intermittent HIP memory-access
faults on gfx950 during DeepSeek V3.2 MTP decoding.

This change adds `mask=<offset> < max_model_len` to every unmasked
`buffer_store` on `OutLogits_buffer` in both preshuffle kernels, matching
the pattern of their already-masked neighbours.  The existing
`tl.where(..., -inf)` masking of the *values* is preserved; the only
behavioural change is that out-of-row lanes no longer emit buffer
stores.  Hardware overhead is negligible: `buffer_store` with a predicate
takes the same buffer-descriptor (SRD) path as the unmasked variant, just
with a VCC mask setup.

Repro + end-to-end fix evidence: see PR description.

Signed-off-by: Markus Hartikainen <markus.hartikainen@amd.com>