fix(hip): correct bf16 rowsum MMA and custom-mask MFMA row indexing on CDNA3 (#214)
Conversation
Pull request overview
Fixes two correctness bugs in the HIP/ROCm FlashAttention v2 path on CDNA3 (MI300x/gfx942) that could silently produce wrong results for bf16 rowsum and for custom attention masks under GQA.
Changes:
- Fix bf16 rowsum MFMA dispatch and bf16 “ones” operand packing in `m16k16_rowsum_f16f16f32`.
- Fix packed query row indexing for the CDNA3 MFMA C/D register layout in the custom-mask/GQA path (2 call sites).
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| `include/gpu_iface/backend/hip/mma_hip.h` | Adds bf16-specific MFMA intrinsic dispatch and correct bf16 packing for the rowsum path. |
| `include/flashinfer/attention/generic/prefill.cuh` | Updates packed query index computation to contiguous-band mapping for correct per-head mask/transform application under GQA. |
diptorupd left a comment
Thank you very much for the changes. These two bugs have indeed been pending for a very long time and I am glad you fixed the issues.
Please address the last two review comments and we should be good to merge the changes.
I have tested your changes locally with my suggested changes and everything works as expected.
…n CDNA3

Two bugs in the HIP/ROCm attention kernels produce silently wrong results on AMD MI300x (gfx942) when using bfloat16 and/or custom attention masks with GQA.

Bug 1 -- bf16 rowsum MMA type confusion (mma_hip.h): m16k16_rowsum_f16f16f32() unconditionally calls the fp16 MFMA intrinsic (__builtin_amdgcn_mfma_f32_16x16x16f16) and constructs the B operand as _Float16(1.0), even when DType is __hip_bfloat16. The hardware reinterprets the bf16(1.0) bit pattern 0x3F80 as fp16(1.875), inflating the softmax denominator by 1.875x per tile accumulation.

Fix: dispatch to __builtin_amdgcn_mfma_f32_16x16x16bf16_1k for bfloat16 with correctly packed bf16(1.0) B operands (0x3F803F80).

Bug 2 -- MFMA C/D register row mapping in the custom-mask path (prefill.cuh): logits_transform() and logits_mask() compute the packed query index using an interleaved pattern (lane_idx/TPR + LIS*j) that does not match the CDNA3 MFMA C/D register layout, which maps each thread to 4 contiguous rows. This causes the custom mask and logits transform to be applied to the wrong GQA head for 3 out of every 4 packed query rows. The kNone (no mask) path is unaffected.

Fix: use contiguous-band indexing (lane_idx/TPR)*NAPTR + j in both logits_transform() and logits_mask().

Tested on AMD Instinct MI300x with BatchPrefillWithPagedKVCacheWrapper using bfloat16, custom masks, and GQA (group_size > 1). Both BatchDecode and BatchPrefill now match PyTorch SDPA reference outputs.

Made-with: Cursor
Add regression tests that would have caught the two kernel bugs fixed in the previous commit:

1. test_batch_prefill_paged_kv_bf16_correctness: runs BatchPrefill with bfloat16 and compares against a naive PyTorch attention reference. The existing HIP tests only used fp16, so the bf16 rowsum bug (m16k16_rowsum using the wrong MFMA intrinsic) was never exercised.

2. test_batch_prefill_custom_mask_gqa_correctness: runs BatchPrefill with a packed custom mask and GQA (group_size > 1) and compares against a naive PyTorch reference. The existing tests never used custom masks, so the MFMA row-mapping bug in logits_mask/logits_transform was never exercised.

3. test_custom_mask_all_true_matches_causal: verifies that an all-True custom mask (kCustom path) produces identical output to causal=True (kNone path). This is a direct regression test for Bug 2, which caused these two paths to diverge.

Made-with: Cursor
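The "naive PyTorch attention reference" the commits mention can be sketched on the host with NumPy. This is an illustrative sketch, not the actual test code: the function name `naive_attention`, the tensor layouts, and the GQA head mapping `h // group_size` are assumptions chosen to mirror the description above (boolean mask, True = attend, group_size > 1).

```python
import numpy as np

def naive_attention(q, q_heads_k, v, mask):
    """Reference attention with a boolean custom mask and GQA head grouping.

    q:    [qo_len, num_qo_heads, d]
    k, v: [kv_len, num_kv_heads, d]  (passed as q_heads_k / v)
    mask: [qo_len, kv_len] boolean, True = attend.
    GQA maps query head h to kv head h // (num_qo_heads // num_kv_heads).
    """
    k = q_heads_k
    qo_len, num_qo_heads, d = q.shape
    kv_len, num_kv_heads, _ = k.shape
    group_size = num_qo_heads // num_kv_heads
    out = np.empty(q.shape, dtype=np.float32)
    for h in range(num_qo_heads):
        kh = h // group_size  # shared kv head for this query-head group
        scores = (q[:, h].astype(np.float32) @ k[:, kh].astype(np.float32).T) / np.sqrt(d)
        scores = np.where(mask, scores, -np.inf)          # apply the custom mask
        scores -= scores.max(axis=-1, keepdims=True)      # numerically stable softmax
        w = np.exp(scores)
        w /= w.sum(axis=-1, keepdims=True)
        out[:, h] = w @ v[:, kh].astype(np.float32)
    return out

# Example: all-True mask, GQA with group_size = 2 (8 query heads, 4 kv heads)
rng = np.random.default_rng(0)
q = rng.standard_normal((4, 8, 16), dtype=np.float32)
k = rng.standard_normal((6, 4, 16), dtype=np.float32)
v = rng.standard_normal((6, 4, 16), dtype=np.float32)
o = naive_attention(q, k, v, np.ones((4, 6), dtype=bool))
print(o.shape)  # (4, 8, 16)
```

A kernel-under-test output would then be compared against `o` with a tolerance appropriate for the accumulation dtype.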
Test results after these changes:
Summary
Two bugs in the HIP/ROCm FA2 attention kernels produce silently wrong results on AMD MI300x (gfx942) when using bfloat16 and/or custom attention masks with GQA.
Bug 1: bf16 rowsum MMA type confusion (`mma_hip.h`)

`m16k16_rowsum_f16f16f32()` unconditionally calls the fp16 MFMA intrinsic (`__builtin_amdgcn_mfma_f32_16x16x16f16`) and constructs the B operand as `_Float16(1.0)`, even when `DType` is `__hip_bfloat16`. The hardware reinterprets the bf16(1.0) bit pattern `0x3F80` as fp16(1.875), inflating the softmax denominator by 1.875x per tile accumulation.

Symptom: The draft token acceptance rate in speculative decoding collapsed to <10% (vs ~65% expected), and BatchPrefill outputs are systematically underscaled vs the PyTorch SDPA reference.

Fix: Dispatch to `__builtin_amdgcn_mfma_f32_16x16x16bf16_1k` for bfloat16, with correctly packed bf16(1.0) B operands (`0x3F803F80`).
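The bit-level confusion can be reproduced on the host. This sketch (helper names are illustrative, not from the codebase) shows that bf16(1.0) encodes as `0x3F80`, that the same 16 bits decoded as IEEE binary16 give 1.875, and how two bf16 ones pack into the 32-bit `0x3F803F80` operand:

```python
import struct

def bf16_bits(x: float) -> int:
    """bf16 is the top 16 bits of the fp32 encoding."""
    return struct.unpack(">I", struct.pack(">f", x))[0] >> 16

def fp16_value(bits: int) -> float:
    """Decode an IEEE-754 binary16 bit pattern (normal numbers only)."""
    sign = -1.0 if bits >> 15 else 1.0
    exp = (bits >> 10) & 0x1F
    frac = bits & 0x3FF
    return sign * (1.0 + frac / 1024.0) * 2.0 ** (exp - 15)

ones = bf16_bits(1.0)
print(hex(ones))          # 0x3f80
print(fp16_value(ones))   # 1.875 -- what the fp16 MFMA intrinsic sees
packed = (ones << 16) | ones  # two bf16(1.0) lanes in one 32-bit register
print(hex(packed))        # 0x3f803f80
```

This is exactly the 1.875x inflation factor described above: every tile accumulation sums rows against a "ones" vector that the fp16 intrinsic reads as 1.875.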
Note: `mma_sync_m16n16k16_row_col_f16f16f32` in the same file already dispatches correctly between fp16 and bf16 -- only `m16k16_rowsum` was missed.

Bug 2: MFMA C/D register row mapping in the custom-mask path (`prefill.cuh`)

`logits_transform()` and `logits_mask()` compute the packed query index using an interleaved pattern (`lane_idx / TPR + LIS * j`) that does not match the CDNA3 MFMA C/D register layout. On MI300x, the `mfma_f32_16x16x16` instruction maps each thread to 4 contiguous rows (thread `c` owns rows `4c`, `4c+1`, `4c+2`, `4c+3`), not interleaved rows.

This causes the custom mask and logits transform to be applied to the wrong GQA head for 3 out of every 4 packed query rows when `group_size > 1`. The `kNone` (no mask) path is unaffected because it doesn't use `group_size.divmod()`.

Diagnostic observations:
- `kNone` mode: cosine similarity to the SDPA reference ~ 1.0 (correct)
- `kCustom` with an all-True mask: cosine similarity to the SDPA reference ~ 0.41 (broken)
- Heads with `r == 0` correct; heads with `r >= 1` wrong
- `group_size == 1` (no GQA): all heads correct

Fix: Use contiguous-band indexing `(lane_idx / TPR) * NAPTR + j` in both `logits_transform()` and `logits_mask()`. Note that `write_lse()` and `write_o_reg_gmem()` in the same file already use the correct contiguous mapping.

Test plan
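As an off-device sanity check, the two index formulas can be compared on the host. The constant values here are assumptions mirroring the names used in the fix (TPR = 16 lanes per tile row, NAPTR = 4 accumulator rows per thread, LIS = 4 interleave stride, for a 16x16 MFMA tile on a 64-lane wavefront):

```python
TPR, NAPTR, LIS = 16, 4, 4  # assumed values for the 16x16 MFMA C/D layout

def interleaved_rows(lane_idx):
    """Buggy mapping: lane_idx/TPR + LIS*j (rows strided by 4)."""
    return [lane_idx // TPR + LIS * j for j in range(NAPTR)]

def contiguous_rows(lane_idx):
    """Fixed mapping: (lane_idx/TPR)*NAPTR + j (4 contiguous rows)."""
    return [(lane_idx // TPR) * NAPTR + j for j in range(NAPTR)]

for lane in (0, 16, 32, 48):  # one lane from each group of 16
    c = lane // TPR
    print(lane, interleaved_rows(lane), contiguous_rows(lane))
    # CDNA3 layout: thread group c owns rows 4c..4c+3
    assert contiguous_rows(lane) == [4 * c + j for j in range(4)]
```

Both mappings cover all 16 rows, which is why the `kNone` path (no per-row head lookup) still works; only the contiguous one assigns each thread the rows it actually holds, so per-row `group_size.divmod()` lookups hit the right GQA head.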
Files changed
- `include/gpu_iface/backend/hip/mma_hip.h` -- bf16 rowsum fix
- `include/flashinfer/attention/generic/prefill.cuh` -- MFMA row-mapping fix (2 locations)