fix(hip): correct bf16 rowsum MMA and custom-mask MFMA row indexing on CDNA3 #214

Merged
diptorupd merged 2 commits into ROCm:amd-integration from subhajitdchow:fix/hip-bf16-rowsum-and-custom-mask-mfma
Apr 14, 2026
Conversation

@subhajitdchow

Summary

Two bugs in the HIP/ROCm FA2 attention kernels produce silently wrong results on AMD MI300x (gfx942) when using bfloat16 and/or custom attention masks with GQA.

Bug 1: bf16 rowsum MMA type confusion (mma_hip.h)

m16k16_rowsum_f16f16f32() unconditionally calls the fp16 MFMA intrinsic (__builtin_amdgcn_mfma_f32_16x16x16f16) and constructs the B operand as _Float16(1.0), even when DType is __hip_bfloat16. The hardware reinterprets the bf16(1.0) bit pattern 0x3F80 as fp16(1.875): under the fp16 encoding, 0x3F80 has exponent field 0b01111 (= 2^0) and mantissa 0b1110000000 (= 0.875), giving 1.875. This inflates the softmax denominator by 1.875x per tile accumulation.
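To see the confusion concretely, here is a small host-side demo (plain C++, decoding the bit pattern by hand; illustrative only, not code from this PR):

```cpp
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstring>

// Decode a 16-bit pattern as IEEE binary16 (fp16): 1 sign, 5 exponent (bias 15), 10 mantissa bits.
static double decode_fp16(uint16_t b) {
  int sign = (b >> 15) & 1, exp = (b >> 10) & 0x1F, man = b & 0x3FF;
  double frac = (exp ? 1.0 : 0.0) + man / 1024.0;  // implicit leading 1 for normals
  return (sign ? -1.0 : 1.0) * frac * std::pow(2.0, (exp ? exp : 1) - 15);
}

// Decode as bfloat16: the pattern is just the top 16 bits of an IEEE float32.
static float decode_bf16(uint16_t b) {
  uint32_t bits = uint32_t(b) << 16;
  float f;
  std::memcpy(&f, &bits, sizeof f);
  return f;
}

int main() {
  printf("0x3F80 as bf16 = %g, as fp16 = %g\n", decode_bf16(0x3F80), decode_fp16(0x3F80));
  // Prints: 0x3F80 as bf16 = 1, as fp16 = 1.875
}
```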

Symptom: Draft token acceptance rate in speculative decoding collapsed to <10% (vs ~65% expected). BatchPrefill outputs are systematically underscaled vs PyTorch SDPA reference.

Fix: Dispatch to __builtin_amdgcn_mfma_f32_16x16x16bf16_1k for bfloat16 with correctly packed bf16(1.0) B operands (0x3F803F80).

Note: mma_sync_m16n16k16_row_col_f16f16f32 in the same file already correctly dispatches between fp16 and bf16 -- only m16k16_rowsum was missed.
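A minimal sketch of the corrected dispatch, assuming simplified vector typedefs and an illustrative signature (the actual m16k16_rowsum_f16f16f32 in mma_hip.h differs):

```cpp
#include <hip/hip_runtime.h>
#include <hip/hip_bf16.h>
#include <hip/hip_fp16.h>
#include <type_traits>

typedef float floatx4 __attribute__((ext_vector_type(4)));
typedef _Float16 halfx4 __attribute__((ext_vector_type(4)));
typedef short shortx4 __attribute__((ext_vector_type(4)));

// Row-sum via MFMA against an all-ones B operand: d += a_frag * ones.
// The point of the fix: DType must select both the intrinsic and the
// bit pattern used for 1.0.
template <typename DType>
__device__ void rowsum_sketch(floatx4& d, const DType* a_frag) {
  if constexpr (std::is_same_v<DType, __hip_bfloat16>) {
    // bf16(1.0) == 0x3F80; the *bf16_1k intrinsic takes packed 16-bit lanes,
    // so each 32-bit register holds the 0x3F803F80 pattern from the PR text.
    const shortx4 ones = {0x3F80, 0x3F80, 0x3F80, 0x3F80};
    const shortx4 a = *reinterpret_cast<const shortx4*>(a_frag);
    d = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(a, ones, d, 0, 0, 0);
  } else {
    // fp16(1.0) == 0x3C00, i.e. _Float16(1.0).
    const halfx4 ones = {_Float16(1.0f), _Float16(1.0f), _Float16(1.0f), _Float16(1.0f)};
    const halfx4 a = *reinterpret_cast<const halfx4*>(a_frag);
    d = __builtin_amdgcn_mfma_f32_16x16x16f16(a, ones, d, 0, 0, 0);
  }
}
```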

Bug 2: MFMA C/D register row-mapping in custom mask path (prefill.cuh)

logits_transform() and logits_mask() compute the packed query index using an interleaved pattern (lane_idx/TPR + LIS*j) that does not match the CDNA3 MFMA C/D register layout. On MI300x, the mfma_f32_16x16x16 instruction maps each thread to 4 contiguous rows (thread c owns rows 4c, 4c+1, 4c+2, 4c+3), not interleaved.
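For illustration, here is how that layout enumerates (a host-side sketch assuming a 64-lane wavefront with lane = 16*c + column, consistent with the description above; not code from the PR):

```cpp
#include <cstdio>

// Enumerate, for each lane of a 64-wide wavefront, which D elements of a
// 16x16 MFMA accumulator it holds on CDNA3.
int main() {
  for (int lane = 0; lane < 64; ++lane) {
    int c = lane / 16;        // 4-row band owned by this lane
    int col = lane % 16;      // accumulator column
    for (int j = 0; j < 4; ++j) {
      int row = 4 * c + j;    // contiguous: 4c, 4c+1, 4c+2, 4c+3
      printf("lane %2d acc[%d] -> D[%2d][%2d]\n", lane, j, row, col);
    }
  }
}
```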

This causes the custom mask and logits transform to be applied to the wrong GQA head for 3 out of every 4 packed query rows when group_size > 1. The kNone (no mask) path is unaffected because it doesn't use group_size.divmod().

Diagnostic observations:

  • kNone mode: cosine similarity to SDPA reference ~ 1.0 (correct)
  • kCustom with all-True mask: cosine similarity to SDPA reference ~ 0.41 (broken)
  • Per-head: GQA heads whose within-group remainder r == 0 are correct; heads with r >= 1 are wrong
  • group_size == 1 (no GQA): all heads correct

Fix: Use contiguous-band indexing (lane_idx/TPR) * NAPTR + j in both logits_transform() and logits_mask(). Note that write_lse() and write_o_reg_gmem() in the same file already use the correct contiguous mapping.
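As a hedged sketch of the index change (TPR, NAPTR, and LIS are the constants named above; the values and the helper are illustrative, not the exact prefill.cuh code):

```cpp
#include <hip/hip_runtime.h>
#include <cstdint>

constexpr uint32_t TPR = 16;   // threads per accumulator row (assumed value)
constexpr uint32_t NAPTR = 4;  // accumulator rows per thread on CDNA3

__device__ uint32_t packed_q_row(uint32_t lane_idx, uint32_t j) {
  // Before (interleaved, wrong for the CDNA3 C/D layout):
  //   lane_idx / TPR + LIS * j
  // After (contiguous band, matching write_lse()/write_o_reg_gmem()):
  return (lane_idx / TPR) * NAPTR + j;  // j in [0, NAPTR)
}
// group_size.divmod(packed_q_row(lane_idx, j)) now resolves the correct
// GQA head for every j, not just j == 0.
```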

Test plan

  • BatchPrefill bf16 outputs match PyTorch SDPA reference (max abs diff < 0.05)
  • BatchDecode still passes (was already correct)
  • Custom mask with GQA (group_size > 1): per-head outputs correct for all r values
  • All-True custom mask matches kNone output
  • Tested on AMD Instinct MI300x (gfx942), ROCm 7.1.1, PyTorch 2.8

Files changed

  • include/gpu_iface/backend/hip/mma_hip.h -- bf16 rowsum fix
  • include/flashinfer/attention/generic/prefill.cuh -- MFMA row-mapping fix (2 locations)

@subhajitdchow subhajitdchow requested a review from diptorupd on April 13, 2026 at 20:21
@subhajitdchow subhajitdchow force-pushed the fix/hip-bf16-rowsum-and-custom-mask-mfma branch from 82dad85 to cdb682b on April 13, 2026 at 20:25
@diptorupd diptorupd requested a review from Copilot on April 13, 2026 at 20:57

Copilot AI left a comment


Pull request overview

Fixes two correctness bugs in the HIP/ROCm FlashAttention v2 path on CDNA3 (MI300x/gfx942) that could silently produce wrong results for bf16 rowsum and for custom attention masks under GQA.

Changes:

  • Fix bf16 rowsum MFMA dispatch and bf16 “ones” operand packing in m16k16_rowsum_f16f16f32.
  • Fix packed query row indexing for CDNA3 MFMA C/D register layout in the custom-mask/GQA path (2 call sites).

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 4 comments.

  • include/gpu_iface/backend/hip/mma_hip.h: Adds bf16-specific MFMA intrinsic dispatch and correct bf16 packing for the rowsum path.
  • include/flashinfer/attention/generic/prefill.cuh: Updates packed query index computation to contiguous-band mapping for correct per-head mask/transform application under GQA.


@subhajitdchow subhajitdchow force-pushed the fix/hip-bf16-rowsum-and-custom-mask-mfma branch from 3c907ac to 1addc8f on April 14, 2026 at 04:20

@diptorupd diptorupd left a comment


Thank you very much for the changes. These two bugs have indeed been pending for a very long time and I am glad you fixed the issues.

Please address the last two review comments and we should be good to merge the changes.

I have tested your changes locally with my suggested changes and everything works as expected.

fix(hip): correct bf16 rowsum MMA and custom-mask MFMA row indexing on CDNA3

Two bugs in the HIP/ROCm attention kernels produce silently wrong results
on AMD MI300x (gfx942) when using bfloat16 and/or custom attention masks
with GQA.

Bug 1 -- bf16 rowsum MMA type confusion (mma_hip.h):
  m16k16_rowsum_f16f16f32() unconditionally calls the fp16 MFMA intrinsic
  (__builtin_amdgcn_mfma_f32_16x16x16f16) and constructs the B operand as
  _Float16(1.0), even when DType is __hip_bfloat16. The hardware reinterprets
  the bf16(1.0) bit pattern 0x3F80 as fp16(1.875), inflating the softmax
  denominator by 1.875x per tile accumulation.

  Fix: dispatch to __builtin_amdgcn_mfma_f32_16x16x16bf16_1k for bfloat16
  with correctly packed bf16(1.0) B operands (0x3F803F80).

Bug 2 -- MFMA C/D register row-mapping in custom mask path (prefill.cuh):
  logits_transform() and logits_mask() compute the packed query index using
  an interleaved pattern (lane_idx/TPR + LIS*j) that does not match the
  CDNA3 MFMA C/D register layout, which maps each thread to 4 contiguous
  rows. This causes the custom mask and logits transform to be applied to
  the wrong GQA head for 3 out of every 4 packed query rows. The kNone
  (no mask) path is unaffected.

  Fix: use contiguous-band indexing (lane_idx/TPR)*NAPTR + j in both
  logits_transform() and logits_mask().

Tested on AMD Instinct MI300x with BatchPrefillWithPagedKVCacheWrapper
using bfloat16, custom masks, and GQA (group_size > 1). Both
BatchDecode and BatchPrefill now match PyTorch SDPA reference outputs.

Made-with: Cursor
Add regression tests that would have caught the two kernel bugs fixed in
the previous commit:

1. test_batch_prefill_paged_kv_bf16_correctness: runs BatchPrefill with
   bfloat16 and compares against a naive PyTorch attention reference.
   The existing HIP tests only used fp16, so the bf16 rowsum bug
   (m16k16_rowsum using the wrong MFMA intrinsic) was never exercised.

2. test_batch_prefill_custom_mask_gqa_correctness: runs BatchPrefill
   with a packed custom mask and GQA (group_size > 1) and compares
   against a naive PyTorch reference.  The existing tests never used
   custom masks, so the MFMA row-mapping bug in logits_mask/
   logits_transform was never exercised.

3. test_custom_mask_all_true_matches_causal: verifies that an all-True
   custom mask (kCustom path) produces identical output to causal=True
   (kNone path).  This is a direct regression test for Bug 2, which
   caused these two paths to diverge.

Made-with: Cursor
@subhajitdchow subhajitdchow force-pushed the fix/hip-bf16-rowsum-and-custom-mask-mfma branch from 1addc8f to 78d34d2 on April 14, 2026 at 18:02
@diptorupd

Test results after these changes:

============================================================================== tests coverage ===============================================================================
_____________________________________________________________ coverage: platform linux, python 3.12.13-final-0 ______________________________________________________________

Name                                      Stmts   Miss Branch BrPart   Cover   Missing
--------------------------------------------------------------------------------------
flashinfer/__init__.py                      137     78      4      2  43.26%   30-171, 266
flashinfer/activation.py                     62     19     22     11  64.29%   33, 49, 56, 63-67, 97, 99, 101, 141, 143, 145, 181, 183, 185, 213-222
flashinfer/aiter_utils.py                     9      2      2      1  72.73%   15-16, 18->exit
flashinfer/aot_hip.py                       129     66     50     10  44.13%   37, 39, 157-168, 207->209, 220-226, 233-242, 246, 254, 261-263, 267-272, 276-277, 293-301, 305-356, 367
flashinfer/compilation_context_hip.py        32      7      6      3  73.68%   54->60, 56->60, 81-85, 103, 115
flashinfer/decode_rocm.py                   271     72    102     34  68.90%   116, 124-193, 284, 306-324, 468->470, 471, 473, 474->476, 476->478, 482, 485-519, 549-552, 554, 686-694, 717, 721, 725, 730, 745, 775-777, 883, 890, 919, 920->923, 939, 943, 981, 1040-1046, 1145->1147, 1147->1152, 1173, 1174->1177, 1178, 1180, 1181->1183, 1183->1186, 1192, 1222, 1276, 1289-1292, 1311-1317, 1330
flashinfer/device_utils.py                   26     15     12      0  28.95%   50-54, 65-69, 79-83
flashinfer/get_include_paths.py              10      0      0      0 100.00%
flashinfer/hip_utils.py                     168     38     64      3  77.16%   62-63, 73-91, 101-115, 125-143, 158, 172-174
flashinfer/jit/__init__.py                   96     62      6      2  35.29%   25-103, 147
flashinfer/jit/activation.py                 23      1      2      1  92.00%   58
flashinfer/jit/attention/__init__.py         40     25      4      2  38.64%   20-54, 58->exit
flashinfer/jit/attention/modules_hip.py     187     23     34      9  83.71%   64, 159, 189, 287-292, 410-415, 554, 609-612, 731, 802-805
flashinfer/jit/core.py                      254     88     72     17  59.20%   17-19, 20->24, 42-43, 55-56, 75, 82, 89, 95, 102-130, 132->156, 152-153, 157-160, 179-182, 200, 204-210, 221-226, 230-231, 262-264, 267-274, 286, 312, 324, 335-336, 339, 343, 366-398, 400->425, 421, 426, 430, 469, 475
flashinfer/jit/cpp_ext_hip.py                83     16     22      7  74.29%   66-67, 75->78, 83->86, 102-111, 114, 151, 192, 205-209
flashinfer/jit/env.py                       103     65     26      5  33.33%   48-50, 73-174, 176->exit, 187-203, 218-226, 234-244
flashinfer/jit/utils.py                      16      0      4      0 100.00%
flashinfer/logits_processor/__init__.py      18      0      0      0 100.00%
flashinfer/norm.py                           63     15     16      6  73.42%   62, 78, 90, 124, 136, 170, 186, 198, 234, 246, 273-275, 285-286
flashinfer/page.py                           54      9      4      1  82.76%   49, 88-93, 154, 274
flashinfer/prefill_rocm.py                  638    168    258     71  68.19%   59->70, 66, 78, 83-97, 124, 314, 429-432, 461, 494, 556, 615, 675, 750, 765-905, 923-943, 1144->1146, 1146->1148, 1148->1150, 1152, 1157, 1172-1180, 1193, 1196->1198, 1249, 1447-1449, 1464, 1491, 1495, 1499, 1503, 1508, 1512, 1558-1560, 1718->1721, 1747, 1756, 1765, 1775, 1783, 1791, 1806-1820, 1847, 1850, 1858->1876, 1878->1900, 2010-2018, 2109->2111, 2111->2116, 2117, 2129, 2140, 2141->2143, 2144, 2146, 2147->2149, 2149->2151, 2157, 2179, 2204, 2248-2251, 2272-2280, 2284, 2290-2298, 2458-2460, 2466->2471, 2485-2495, 2531-2533, 2662->2664, 2665->2668, 2673, 2677, 2680, 2693-2722, 2727-2730, 2744, 2747, 2873->2875, 2885, 2886->2888, 2888->2890, 2890->2892, 2898, 2906, 2910-2916, 2919, 2943
flashinfer/quantization.py                   32     13      0      0  59.38%   37, 42, 125-136
flashinfer/rope.py                          186     36     78     21  72.35%   1179, 1235, 1266, 1354, 1362, 1366-1369, 1374->1383, 1376->1375, 1380, 1406->1409, 1536, 1550, 1559-1564, 1567->1569, 1569->1573, 1575-1582, 1586, 1592-1606, 1613, 1618, 1623, 1638
flashinfer/sampling.py                      253     41     56     16  77.02%   82, 116-117, 152-153, 195-196, 236-238, 318-320, 349, 376, 403, 457-459, 488, 493, 498, 551, 554->557, 613-614, 677-678, 759-760, 842-843, 921-922, 1037-1038, 1048, 1152-1153, 1163, 1448, 1452
flashinfer/utils.py                         466    286    178     25  33.70%   45, 91, 94-103, 109, 112-121, 126-139, 152-163, 168, 173, 181-184, 203, 215-217, 265, 273, 292-316, 323, 327, 349, 354-369, 380, 402-415, 418, 427-428, 432, 437-441, 474-480, 511-521, 564, 570-572, 582-584, 595-597, 604-605, 609-610, 614-615, 619-620, 624-625, 629-630, 634, 645, 649, 653, 660, 683, 689, 705, 710, 715, 746-788, 817-820, 831-860, 866-880, 885-888, 943-971, 1074-1249
--------------------------------------------------------------------------------------
TOTAL                                      3356   1145   1022    247  61.26%
============================================================== 29487 passed, 2876 skipped in 551.73s (0:09:11) ==============================================================


@diptorupd diptorupd left a comment


LGTM!

@diptorupd diptorupd merged commit a7a2f0b into ROCm:amd-integration Apr 14, 2026
1 of 2 checks passed
