mla: drop max_split_per_batch=16 cap to match vLLM #611

Open
peizhang56 wants to merge 1 commit into main from mla-drop-max-split-per-batch-cap

Conversation

@peizhang56

Summary

ATOM was passing max_split_per_batch=16 to aiter's get_mla_metadata_v1
at three call sites (one in atom/plugin/attention.py, two in
atom/model_ops/attentions/aiter_mla.py). aiter then computes the work
split as min(num_clusters, max_split_per_batch * bs), which severely
under-utilizes the GPU at small batch sizes with large KV caches.

vLLM's AiterMLAMetadataBuilder._build_decode (in
vllm/v1/attention/backends/mla/rocm_aiter_mla.py) omits the parameter
entirely, letting it default to -1 so the kernel uses all
num_clusters splits. This PR aligns ATOM with vLLM.
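For reference, a minimal sketch of the change at one ATOM call site. Only
get_mla_metadata_v1 and the max_split_per_batch keyword come from the PR
itself; the elided arguments are placeholders, not aiter's actual signature:

    # Before: explicit cap (other arguments elided; not the real signature)
    metadata = get_mla_metadata_v1(..., max_split_per_batch=16)

    # After: drop the keyword so it defaults to -1 and the kernel uses all
    # num_clusters splits, matching vLLM's _build_decode.
    metadata = get_mla_metadata_v1(...)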

Why

  • The FP8 MLA decode-stage1 kernel mla_a8w8_qh16_qseqlen1_gqaratio16_ps
    was running ~4x slower in ATOM than in vLLM on the same workload.
  • Root cause: min(num_clusters, 16 * bs) caps the split count far below
    the available CUs at small batch sizes; for example, bs=2 on a 256-CU
    GPU yields only 32 splits against 256 available CUs (see the sketch
    after this list).
  • Buffer safety: get_mla_metadata_info_v1 pre-sizes the reduce/partial
    buffers for ~2 * num_clusters tiles, so passing -1 is within the
    pre-allocated capacity.
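
A quick back-of-the-envelope sketch in Python, using the numbers from the
bullets above (num_clusters=256 is an assumption standing in for a 256-CU
GPU with one work cluster per CU):

    # Split count with the old cap vs. without it (illustrative numbers).
    num_clusters = 256            # assume one work cluster per CU
    bs = 2                        # small decode batch

    capped   = min(num_clusters, 16 * bs)   # old behavior: 32 splits
    uncapped = num_clusters                 # max_split_per_batch=-1: 256 splits
    print(capped, uncapped)                 # 32 256 -> 8x more parallel tiles

    # Buffer safety: get_mla_metadata_info_v1 pre-sizes the reduce/partial
    # buffers for ~2 * num_clusters tiles, so even the uncapped split count
    # stays within the pre-allocated capacity.
    assert uncapped <= 2 * num_clusters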

Test plan

The aiter persistent-MLA op test reproduces the difference directly.
-ms is the same max_split_per_batch knob exposed at the test layer:

python op_tests/test_mla_persistent.py -d fp8 -kvd fp8 -n 16,1 \
    -k 512 -qr 64 -vh 512 -blk 1 -b 4 -c 100000 -ms 16

vs.

python op_tests/test_mla_persistent.py -d fp8 -kvd fp8 -n 16,1 \
    -k 512 -qr 64 -vh 512 -blk 1 -b 4 -c 100000 -ms -1

The -ms -1 run reports a substantially lower MLA decode kernel time.
Mirror runs through ATOM's serving stack on the same GPU show the same
speedup on the mla_a8w8_qh16_qseqlen1_gqaratio16_ps invocation.

  • Re-run the aiter op test on gfx950 with both -ms 16 and -ms -1
  • Re-run an ATOM kimi-2.5 long-context decode trace and confirm the
    mla_a8w8_qh16_qseqlen1_gqaratio16_ps kernel time matches vLLM
  • Smoke-test ATOM short-context / small-batch decode for regressions

Notes

Other max_split_per_batch references are intentionally left untouched:
atom/model_ops/attentions/aiter_attention.py already passes -1, and
the SGLang backend in
atom/plugin/sglang/attention_backend/sgl_attn_backend.py keeps its own
configurable knob.

Historical context: the cap was introduced in #47 ("limit
max_split_per_batch to 16") with no recorded rationale, and propagated
to the plugin via #304. vLLM's MLA backend never set the cap.

@peizhang56 (Author)

This helped with Kimi 2.5 MLA when running large contexts on MI355.