Problem Description
Issues Found in aiter/aiter/mla.py
Issue 1: Missing final_lse Computation in Non-Persistent Mode (Stage 2)
Location: _fwd_kernel_stage2_asm kernel (lines 18-97) and mla_decode_fwd function
Description:
In non-persistent mode, the Triton stage2 kernel (_fwd_kernel_stage2_asm) performs the reduction across KV splits but does not compute or store the final_lse values. The final_lse tensor is allocated at line 216 but is never populated in the stage2 kernel.
Impact:
The final_lse tensor returned from mla_decode_fwd contains NaN values when in non-persistent mode, which can cause issues in downstream operations that rely on the log-sum-exp values.
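For context, a tiny standalone snippet illustrating why an allocated-but-never-written tensor is a problem (this is not the aiter code, just `torch.empty` behavior):

```python
import torch

# final_lse is allocated with torch.empty(...) and, in non-persistent mode,
# never written by the stage2 kernel, so it holds whatever bytes were already
# in that memory (possibly NaN/Inf).
final_lse = torch.empty((4, 16), dtype=torch.float32)

print(torch.isfinite(final_lse).all())  # may print tensor(False)
```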
Code Reference:
- `final_lse` is allocated: aiter/aiter/mla.py:216
- Stage2 kernel is called: aiter/aiter/mla.py:248-268
Suggested Fix:
The `_fwd_kernel_stage2_asm` kernel should be updated to:
- Accept `final_lse` as a parameter
- Store the computed `e_max + tl.log(e_sum)` to the `final_lse` tensor after the reduction loop
- Ensure proper stride calculations for writing to `final_lse` (see the sketch after this list)
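For illustration, a minimal Triton sketch of a stage-2 split-reduction kernel that also writes the combined log-sum-exp, following the standard online-softmax combine pattern. The tensor layouts, parameter names, and strides below are assumptions made for the sketch and do not reflect the actual `_fwd_kernel_stage2_asm` signature:

```python
import triton
import triton.language as tl


@triton.jit
def _stage2_reduce_sketch(
    Logits,        # assumed [total_s, num_kv_splits, nhead, v_head_dim] partial outputs
    Lse,           # assumed [total_s, num_kv_splits, nhead] per-split log-sum-exp
    Out,           # [total_s, nhead, v_head_dim] combined output
    FinalLse,      # [total_s, nhead] combined log-sum-exp (the tensor left unwritten today)
    num_kv_splits,
    stride_l_s, stride_l_k, stride_l_h,
    stride_lse_s, stride_lse_k,
    stride_o_s, stride_o_h,
    stride_fl_s,
    BLOCK_DV: tl.constexpr,
):
    cur_s = tl.program_id(0)   # token index
    cur_h = tl.program_id(1)   # head index
    offs_d = tl.arange(0, BLOCK_DV)

    e_max = -float("inf")
    e_sum = 0.0
    acc = tl.zeros([BLOCK_DV], dtype=tl.float32)

    # Online-softmax combination of the per-split partial results.
    for split in range(num_kv_splits):
        lse = tl.load(Lse + cur_s * stride_lse_s + split * stride_lse_k + cur_h)
        v = tl.load(
            Logits + cur_s * stride_l_s + split * stride_l_k + cur_h * stride_l_h + offs_d
        )
        n_max = tl.maximum(e_max, lse)
        old_scale = tl.exp(e_max - n_max)
        cur_scale = tl.exp(lse - n_max)
        acc = acc * old_scale + v * cur_scale
        e_sum = e_sum * old_scale + cur_scale
        e_max = n_max

    tl.store(Out + cur_s * stride_o_s + cur_h * stride_o_h + offs_d, acc / e_sum)
    # The piece missing from the non-persistent path: persist the combined LSE.
    tl.store(FinalLse + cur_s * stride_fl_s + cur_h, e_max + tl.log(e_sum))
```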
Issue 2: Inconsistent Tensor Reshaping for final_lse in Persistent Mode
Location: mla_decode_fwd function, lines 269-336
Description:
In persistent mode, when nhead falls within the condition at line 275 (nhead in range 32-512, step 16), the code transforms the tensors by:
- Recalculating `total_s = ori_total_s * (ori_nhead // 16)` (line 278)
- Setting `nhead = 16` (line 279)
- Reshaping the `q` and `o` tensors using `.view()` (lines 280-281)
The final_lse tensor is then allocated with these modified dimensions at line 296:
```python
final_lse = torch.empty((total_s, nhead), dtype=dtypes.fp32, device=device)
```
However, at the end of the function (lines 328-334), only `q` and `o` are reshaped back to their original dimensions. The `final_lse` tensor is not reshaped back to match the original `(ori_total_s, ori_nhead)` shape.
Impact:
The returned final_lse tensor has incorrect dimensions when io_transformed is True. It returns shape (ori_total_s * (ori_nhead // 16), 16) instead of the expected (ori_total_s, ori_nhead).
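To make the shape mismatch concrete, a hand-worked example with illustrative sizes (the values 8 and 128 are hypothetical, not from the report):

```python
# Hypothetical sizes for illustration only.
ori_total_s, ori_nhead = 8, 128

# Transformation path (lines 278-279):
total_s = ori_total_s * (ori_nhead // 16)   # 8 * 8 = 64
nhead = 16

# final_lse is allocated with the transformed shape (line 296):
#   final_lse.shape == (64, 16)
# but it is never view()-ed back, so callers receive (64, 16)
# instead of the expected (ori_total_s, ori_nhead) == (8, 128).
```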
Code Reference:
- Transformation logic: aiter/aiter/mla.py:275-282
- `final_lse` allocation with modified dims: aiter/aiter/mla.py:296
- `q` and `o` reshape back: aiter/aiter/mla.py:333-334
- Missing: `final_lse` reshape
Suggested Fix:
Add a line to reshape final_lse back to original dimensions alongside q and o (after line 334):
```python
if io_transformed:
    if persistent_mode:
        logits = logits.view(-1, 1, ori_nhead, v_head_dim)
    else:
        logits = logits.view(ori_total_s, num_kv_splits, ori_nhead, v_head_dim)
    q = q.view(ori_total_s, ori_nhead, -1)
    o = o.view(ori_total_s, ori_nhead, -1)
    final_lse = final_lse.view(ori_total_s, ori_nhead)  # ADD THIS LINE
```

Issue 3: Inconsistent nhead Condition Logic Between Test and Implementation
Location: aiter/op_tests/test_mla_persistent.py:391-392 vs aiter/aiter/mla.py:275
Description:
The condition for testing persistent mode with bfloat16 dtype in the test file does not match the implementation logic in mla.py:
Test file (lines 391-392):
```python
if (dtype == torch.bfloat16 and kvtype == torch.bfloat16) and (
    nhead == 16 or (nhead in range(32, 128, 16) and decode_qlen == 1)
):
```

Implementation file (line 275):

```python
elif nhead in range(32, 512 + 1, 16) and persistent_mode and max_seqlen_q == 1:
```

Discrepancies:
- Range mismatch: The test uses `range(32, 128, 16)`, which includes [32, 48, 64, 80, 96, 112], while the implementation uses `range(32, 512 + 1, 16)`, which extends up to 512 (see the quick check below)
- Upper bound: The test stops at 112 (< 128), while the implementation goes up to 512
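A quick membership check makes the gap explicit (standalone snippet; `test_range` and `impl_range` are just illustrative names):

```python
test_range = range(32, 128, 16)        # condition in test_mla_persistent.py:392
impl_range = range(32, 512 + 1, 16)    # condition in mla.py:275

print(list(test_range))     # [32, 48, 64, 80, 96, 112]
print(128 in test_range)    # False -> nhead=128 never takes this branch in the test
print(128 in impl_range)    # True  -> but it is covered by the implementation condition
```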
Impact:
- The test file configuration at line 422 includes `(128, 2)` in `list_nhead`: `list_nhead = [(16, 1), (16, 2), (16, 4), (48, 1), (128, 2)]`
- This test case (`nhead=128`, `mtp=2`) will NOT be exercised by `test_absorb_decode_bf16()` because:
  - It fails the test condition: `nhead=128` is not in `range(32, 128, 16)`
  - Even if it were in range, `mtp=2` fails the `mtp == 1` condition
- However, `nhead=128` with `mtp=2` also doesn't match the implementation condition at line 275 (which requires `max_seqlen_q == 1`); it would instead be checked against the `nhead == 16 or (nhead == 128 and kv_buffer.dtype == dtypes.fp8)` condition at line 272
Issue 4: Unclear Support for mtp=2 in Persistent Mode Tests
Location: aiter/op_tests/test_mla_persistent.py:422
Description:
The test configuration includes (128, 2) where mtp=2 (meaning max_seqlen_q=2), but it's unclear if mtp > 1 is actually supported in persistent mode for nhead=128:
```python
list_nhead = [(16, 1), (16, 2), (16, 4), (48, 1), (128, 2)]
```

According to the implementation logic:
- Lines 272-274: Native support exists for `nhead == 16`, or for `nhead == 128` with fp8 dtype
- Line 275: Special handling for `nhead in range(32, 512 + 1, 16)` only when `max_seqlen_q == 1`
Issue: The test includes `(128, 2)`, but this configuration:
- Doesn't match the line 272 conditions unless the dtype is fp8
- Doesn't match line 275 because `mtp=2` means `max_seqlen_q=2`, not 1 (hand-traced in the sketch after this list)
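A hand-traced walk through the dispatch for this test case (standalone sketch with stand-in flags, not the actual mla.py code):

```python
# Stand-in values for the (nhead=128, mtp=2) bfloat16 test case.
nhead, max_seqlen_q = 128, 2
persistent_mode = True
kv_is_fp8 = False   # stands in for kv_buffer.dtype == dtypes.fp8

if nhead == 16 or (nhead == 128 and kv_is_fp8):
    path = "native path (lines 272-274)"
elif nhead in range(32, 512 + 1, 16) and persistent_mode and max_seqlen_q == 1:
    path = "nhead=16 simulation path (lines 275-282)"
else:
    path = "assert False -> not supported"

print(path)   # -> "assert False -> not supported"
```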
Questions:
- Is `mtp > 1` supported for `nhead=128` in persistent mode?
- Should this test case be removed, or is there missing implementation support?
Issue 5: Test Failure for nhead=128, mtp=1 with bfloat16 dtype
Location: aiter/op_tests/test_mla_persistent.py and aiter/aiter/mla.py
Description:
When attempting to test nhead=128 with mtp=1 (i.e., max_seqlen_q=1) in persistent mode with bfloat16 dtype, the test_absorb_decode_bf16() unit test fails.
Root Cause Analysis:
Looking at the implementation logic in mla.py:272-284:
```python
if nhead == 16 or (nhead == 128 and kv_buffer.dtype == dtypes.fp8):
    # Natively support cases
    pass
elif nhead in range(32, 512 + 1, 16) and persistent_mode and max_seqlen_q == 1:
    # we use nhead=16 to simulate such cases by customized metadata
    ...
else:
    assert False, f"{nhead=} and {max_seqlen_q=} not supported"
```

The issue: `nhead=128` with bfloat16 dtype and persistent mode:
- Does NOT match the line 272 condition (which requires fp8 dtype: `kv_buffer.dtype == dtypes.fp8`)
- Does match the line 275 condition (`128 in range(32, 512 + 1, 16)`, `persistent_mode=True`, and `max_seqlen_q=1`)
- Therefore, it executes the transformation path (lines 278-282)
Test Findings:
Testing was performed with nhead=128, mtp=1 configuration with bfloat16 dtype:
- Without any fixes:
  - The code successfully goes through the transformation path (lines 275-282)
  - The transformation reshapes `total_s` and `nhead` to simulate the case
  - Result: the `out_ref` vs `out_asm` comparison FAILS; the output tensors do not match the reference implementation
- With the `final_lse` reshape fix applied (from Issue 2):
  - Added `final_lse = final_lse.view(ori_total_s, ori_nhead)` after line 334
  - Result: BOTH comparisons fail: `out_ref` vs `out_asm` still FAILS, and `attn_lse` vs `lse_ref` also FAILS
  - The reshape fix appears to expose additional correctness issues
Impact:
The transformation logic (lines 275-282) appears to have deeper correctness issues beyond just the missing final_lse reshape. The test failures indicate that:
- The transformation approach may not correctly handle `nhead=128` with bfloat16 dtype
- The metadata generation or kernel execution may have bugs when using transformed dimensions
- Both the output values and the LSE (log-sum-exp) values are computed incorrectly
Status: Further investigation is needed into:
- Metadata generation for the transformed nhead=128 case
- Stride calculations in the persistent kernel path
- Whether the transformation logic is fundamentally incompatible with nhead=128
Recommended Action:
- Short-term: Explicitly exclude `nhead=128` with bfloat16 dtype in persistent mode (update the line 272 condition or add an explicit check; see the sketch after this list) until the issue is resolved
- Long-term: Investigate and fix the transformation logic, or implement native support for `nhead=128` with bfloat16:
  - Debug the metadata generation for this configuration
  - Verify stride calculations in the reshaped tensor path
  - Consider whether nhead=128 should have native support similar to nhead=16
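As a sketch of the short-term guard only, written as a hypothetical standalone helper (the name `check_persistent_nhead_support` and its parameters are inventions for illustration; the real check would live next to the dispatch at lines 272-284 and use `nhead`, `kv_buffer`, `persistent_mode`, and `dtypes` directly):

```python
def check_persistent_nhead_support(nhead: int, kv_dtype, fp8_dtype, persistent_mode: bool) -> None:
    """Hypothetical guard: reject nhead=128 with a non-fp8 KV cache in persistent
    mode until the nhead=16 simulation path is fixed (see Issue 5)."""
    if persistent_mode and nhead == 128 and kv_dtype != fp8_dtype:
        raise AssertionError(
            f"nhead={nhead} with KV dtype {kv_dtype} is not yet supported in "
            "persistent mode; the transformation path currently produces incorrect results"
        )
```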
Operating System
Ubuntu
CPU
GPU
MI300X
ROCm Version
ROCm7
ROCm Component
No response
Steps to Reproduce
No response
(Optional for Linux users) Output of /opt/rocm/bin/rocminfo --support
No response
Additional Information
No response