[Issue][MLA]: Issues Found in aiter/aiter/mla.py #1420

@vllmellm

Problem Description

Issues Found in aiter/aiter/mla.py

Issue 1: Missing final_lse Computation in Non-Persistent Mode (Stage 2)

Location: _fwd_kernel_stage2_asm kernel (lines 18-97) and mla_decode_fwd function

Description:
In non-persistent mode, the Triton stage2 kernel (_fwd_kernel_stage2_asm) performs the reduction across KV splits but does not compute or store the final_lse values. The final_lse tensor is allocated at line 216 but is never populated in the stage2 kernel.

Impact:
The final_lse tensor returned from mla_decode_fwd contains NaN values in non-persistent mode, which can cause issues in downstream operations that rely on the log-sum-exp values.

Code Reference:

  • final_lse is allocated: aiter/aiter/mla.py:216
  • Stage2 kernel is called: aiter/aiter/mla.py:248-268

Suggested Fix:
The _fwd_kernel_stage2_asm kernel should be updated to (see the sketch after this list):

  1. Accept final_lse as a parameter
  2. Store the computed (e_max + tl.log(e_sum)) to the final_lse tensor after the reduction loop
  3. Ensure proper stride calculations for writing to final_lse
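
For illustration only, here is a minimal Triton sketch of a split-KV stage-2 reduction that also writes final_lse. This is not the real _fwd_kernel_stage2_asm: the parameter names, layouts and strides are assumptions and masking is omitted, but it shows where the (e_max + tl.log(e_sum)) store would sit relative to the existing reduction loop.

import triton
import triton.language as tl


@triton.jit
def _stage2_reduce_with_lse(          # illustrative name, not the real kernel
    Logits, Lse, Out, Final_lse,      # per-split partials / lse, final out, final lse
    stride_logits_s, stride_logits_split, stride_logits_h,
    stride_lse_s, stride_lse_split, stride_lse_h,
    stride_out_s, stride_out_h,
    stride_flse_s, stride_flse_h,
    num_kv_splits,
    BLOCK_DV: tl.constexpr,           # assumed equal to v_head_dim (no masking)
):
    cur_seq = tl.program_id(0)
    cur_head = tl.program_id(1)
    offs_dv = tl.arange(0, BLOCK_DV)

    e_max = -float("inf")
    e_sum = 0.0
    acc = tl.zeros([BLOCK_DV], dtype=tl.float32)

    # Online-softmax merge of the per-split partial results (same structure
    # as the existing stage-2 reduction).
    for split in range(0, num_kv_splits):
        part_lse = tl.load(Lse + cur_seq * stride_lse_s
                           + split * stride_lse_split + cur_head * stride_lse_h)
        part_out = tl.load(Logits + cur_seq * stride_logits_s
                           + split * stride_logits_split
                           + cur_head * stride_logits_h + offs_dv)
        new_max = tl.maximum(part_lse, e_max)
        old_scale = tl.exp(e_max - new_max)
        new_scale = tl.exp(part_lse - new_max)
        acc = acc * old_scale + part_out * new_scale
        e_sum = e_sum * old_scale + new_scale
        e_max = new_max

    tl.store(Out + cur_seq * stride_out_s + cur_head * stride_out_h + offs_dv,
             acc / e_sum)
    # The piece Issue 1 reports as missing: persist the combined log-sum-exp.
    tl.store(Final_lse + cur_seq * stride_flse_s + cur_head * stride_flse_h,
             e_max + tl.log(e_sum))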

Issue 2: Inconsistent Tensor Reshaping for final_lse in Persistent Mode

Location: mla_decode_fwd function, lines 269-336

Description:
In persistent mode, when nhead satisfies the condition at line 275 (nhead in range(32, 512 + 1, 16)), the code transforms the tensors by:

  • Recalculating total_s = ori_total_s * (ori_nhead // 16) (line 278)
  • Setting nhead = 16 (line 279)
  • Reshaping q and o tensors using .view() (lines 280-281)

The final_lse tensor is then allocated with these modified dimensions at line 296:

final_lse = torch.empty((total_s, nhead), dtype=dtypes.fp32, device=device)

However, at the end of the function (lines 328-334), only q and o are reshaped back to their original dimensions. The final_lse tensor is not reshaped back to match the original (ori_total_s, ori_nhead) shape.

Impact:
The returned final_lse tensor has incorrect dimensions when io_transformed is True. It returns shape (ori_total_s * (ori_nhead // 16), 16) instead of the expected (ori_total_s, ori_nhead).
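
To make the shape bookkeeping concrete, here is a toy reproduction of the transformation and the resulting mismatch (the sequence count and head dim below are made up for illustration; the real code operates on q, o and final_lse inside mla_decode_fwd):

import torch

# Illustrative shapes only; ori_nhead=128 is one of the values that takes the
# transformation path.
ori_total_s, ori_nhead, v_head_dim = 4, 128, 512
q = torch.zeros(ori_total_s, ori_nhead, v_head_dim)

total_s = ori_total_s * (ori_nhead // 16)   # 4 * 8 = 32          (line 278)
nhead = 16                                  #                      (line 279)
q = q.view(total_s, nhead, -1)              # (32, 16, 512)        (line 280)

final_lse = torch.empty((total_s, nhead), dtype=torch.float32)   # line 296

# q (and o) are viewed back at the end, but final_lse is not:
q = q.view(ori_total_s, ori_nhead, -1)      # (4, 128, 512)        (line 333)
print(final_lse.shape)                      # torch.Size([32, 16]), not (4, 128)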

Code Reference:

  • Transformation logic: aiter/aiter/mla.py:275-282
  • final_lse allocation with modified dims: aiter/aiter/mla.py:296
  • q and o reshape back: aiter/aiter/mla.py:333-334
  • Missing final_lse reshape

Suggested Fix:
Add a line to reshape final_lse back to its original dimensions alongside q and o (after line 334). The view is safe because this path is only taken when ori_nhead is a multiple of 16, so total_s * nhead == ori_total_s * ori_nhead and final_lse already holds the right number of elements:

if io_transformed:
    if persistent_mode:
        logits = logits.view(-1, 1, ori_nhead, v_head_dim)
    else:
        logits = logits.view(ori_total_s, num_kv_splits, ori_nhead, v_head_dim)
    q = q.view(ori_total_s, ori_nhead, -1)
    o = o.view(ori_total_s, ori_nhead, -1)
    final_lse = final_lse.view(ori_total_s, ori_nhead)  # ADD THIS LINE

Issue 3: Inconsistent nhead Condition Logic Between Test and Implementation

Location: aiter/op_tests/test_mla_persistent.py:391-392 vs aiter/aiter/mla.py:275

Description:
The condition for testing persistent mode with bfloat16 dtype in the test file does not match the implementation logic in mla.py:

Test file (lines 391-392):

if (dtype == torch.bfloat16 and kvtype == torch.bfloat16) and (
    nhead == 16 or (nhead in range(32, 128, 16) and decode_qlen == 1)
):

Implementation file (line 275):

elif nhead in range(32, 512 + 1, 16) and persistent_mode and max_seqlen_q == 1:

Discrepancies:

  1. Range mismatch: Test uses range(32, 128, 16), i.e. [32, 48, 64, 80, 96, 112] (upper bound 112), while implementation uses range(32, 512 + 1, 16), which extends up to 512 (verified below)
  2. Coverage gap: nhead values from 128 through 512 (step 16) are accepted by the implementation's transformation path but are never exercised by the bf16 test condition
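
The mismatch is easy to check directly (plain Python, nothing aiter-specific):

test_range = set(range(32, 128, 16))        # condition in test_mla_persistent.py
impl_range = set(range(32, 512 + 1, 16))    # condition in mla.py line 275

print(sorted(test_range))                   # [32, 48, 64, 80, 96, 112]
print(sorted(impl_range - test_range)[:5])  # [128, 144, 160, 176, 192] ...
print(128 in test_range, 128 in impl_range) # False True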

Impact:

  • The test file configuration at line 422 includes (128, 2) in list_nhead:
    list_nhead = [(16, 1), (16, 2), (16, 4), (48, 1), (128, 2)]
  • This test case (nhead=128, mtp=2) will NOT execute test_absorb_decode_bf16() because:
    • It fails the test condition: nhead=128 is not in range(32, 128, 16)
    • Even if it were in range, mtp=2 fails the decode_qlen == 1 part of the condition
  • Even if the test did run, nhead=128 with bfloat16 and mtp=2 satisfies neither the line 272 condition (which requires kv_buffer.dtype == dtypes.fp8) nor the line 275 condition (which requires max_seqlen_q == 1), so it would fall through to the assert

Issue 4: Unclear Support for mtp=2 in Persistent Mode Tests

Location: aiter/op_tests/test_mla_persistent.py:422

Description:
The test configuration includes (128, 2) where mtp=2 (meaning max_seqlen_q=2), but it's unclear if mtp > 1 is actually supported in persistent mode for nhead=128:

list_nhead = [(16, 1), (16, 2), (16, 4), (48, 1), (128, 2)]

According to the implementation logic:

  • Lines 272-274: Native support exists for nhead == 16 or nhead == 128 with fp8 dtype
  • Line 275: Special handling for nhead in range(32, 512+1, 16) only when max_seqlen_q == 1

Issue: The test includes (128, 2), but this configuration (see the evaluation sketch after this list):

  1. Doesn't match line 272 conditions unless dtype is fp8
  2. Doesn't match line 275 because mtp=2 means max_seqlen_q=2, not 1
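
A quick, hypothetical evaluation of this case against the two quoted conditions (names mirror the snippets above; this is not the actual dispatch code):

nhead, max_seqlen_q, persistent_mode = 128, 2, True   # mtp=2 -> max_seqlen_q=2
kv_is_fp8 = False                                     # kv_buffer.dtype is bfloat16

line_272 = nhead == 16 or (nhead == 128 and kv_is_fp8)
line_275 = (nhead in range(32, 512 + 1, 16)
            and persistent_mode and max_seqlen_q == 1)

print(line_272, line_275)   # False False -> falls through to the assert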

Questions:

  • Is mtp > 1 supported for nhead=128 in persistent mode?
  • Should this test case be removed or is there missing implementation support?

Issue 5: Test Failure for nhead=128, mtp=1 with bfloat16 dtype

Location: aiter/op_tests/test_mla_persistent.py and aiter/aiter/mla.py

Description:
When attempting to test nhead=128 with mtp=1 (i.e., max_seqlen_q=1) in persistent mode with bfloat16 dtype, the test_absorb_decode_bf16() unit test fails.

Root Cause Analysis:
Looking at the implementation logic in mla.py:272-284:

if nhead == 16 or (nhead == 128 and kv_buffer.dtype == dtypes.fp8):
    # Natively support cases
    pass
elif nhead in range(32, 512 + 1, 16) and persistent_mode and max_seqlen_q == 1:
    # we use nhead=16 to simulate such cases by customized metadata
    ...
else:
    assert False, f"{nhead=} and {max_seqlen_q=} not supported"

The issue: nhead=128 with bfloat16 dtype in persistent mode (see the quick check after this list):

  • Does NOT match line 272 condition (requires fp8 dtype: kv_buffer.dtype == dtypes.fp8)
  • Does match line 275 condition (128 in range(32, 512+1, 16) and persistent_mode=True and max_seqlen_q=1)
  • Therefore, it executes through the transformation path (lines 278-282)
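
The same kind of quick check (again with names mirroring the quoted code, not the actual dispatch) confirms that this configuration reaches the transformation branch:

nhead, max_seqlen_q, persistent_mode = 128, 1, True   # mtp=1 -> max_seqlen_q=1
kv_is_fp8 = False                                     # kv_buffer.dtype is bfloat16

line_272 = nhead == 16 or (nhead == 128 and kv_is_fp8)
line_275 = (nhead in range(32, 512 + 1, 16)
            and persistent_mode and max_seqlen_q == 1)

print(line_272, line_275)   # False True -> transformation path (lines 278-282)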

Test Findings:
Testing was performed with nhead=128, mtp=1 configuration with bfloat16 dtype:

  1. Without any fixes:

    • The code successfully goes through the transformation path (lines 275-282)
    • The transformation recomputes total_s and nhead and reshapes q and o to simulate the nhead=16 case
    • Result: out_ref vs out_asm comparison FAILS
    • The output tensors do not match the reference implementation
  2. With final_lse reshape fix applied (from Issue 2):

    • Added final_lse = final_lse.view(ori_total_s, ori_nhead) after line 334
    • Result: BOTH comparisons fail:
      • out_ref vs out_asm still FAILS
      • attn_lse vs lse_ref also FAILS
    • The reshape fix appears to expose additional correctness issues

Impact:
The transformation logic (lines 275-282) appears to have deeper correctness issues beyond just the missing final_lse reshape. The test failures indicate that:

  1. The transformation approach may not correctly handle nhead=128 with bfloat16 dtype
  2. The metadata generation or kernel execution may have bugs when using transformed dimensions
  3. Both the output values and the LSE (log-sum-exp) values are computed incorrectly

Status:
⚠️ Under Investigation - Root cause not yet determined. Possible areas to investigate:

  • Metadata generation for transformed nhead=128 case
  • Stride calculations in the persistent kernel path
  • Whether the transformation logic is fundamentally incompatible with nhead=128

Recommended Action:

  1. Short-term: Explicitly exclude nhead=128 with bfloat16 dtype in persistent mode (update the line 272 condition or add an explicit check; see the sketch after this list) until the issue is resolved
  2. Long-term: Investigate and fix the transformation logic or implement native support for nhead=128 with bfloat16:
    • Debug the metadata generation for this configuration
    • Verify stride calculations in the reshaped tensor path
    • Consider whether nhead=128 should have native support similar to nhead=16
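
As one possible shape for the short-term exclusion in item 1, here is a minimal sketch of a guard placed inside mla_decode_fwd, assuming nhead, persistent_mode, max_seqlen_q, kv_buffer and dtypes are in scope as in the quoted code (the exact condition and message are up to the maintainers):

# Sketch only: an early guard before the dispatch at mla.py:272, matching the
# assert style already used there. It uses "!= dtypes.fp8" so it does not
# depend on any other dtype names.
assert not (
    nhead == 128
    and persistent_mode
    and max_seqlen_q == 1
    and kv_buffer.dtype != dtypes.fp8
), f"{nhead=} with a non-fp8 kv cache is not yet supported in persistent mode (see issue #1420)"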

Operating System

Ubuntu

CPU

GPU

MI300X

ROCm Version

ROCm7

ROCm Component

No response

Steps to Reproduce

No response

(Optional for Linux users) Output of /opt/rocm/bin/rocminfo --support

No response

Additional Information

No response
