Skip to content

fix(hip): clamp max_num_kv_chunks to avoid SIGFPE in single decode on CPX devices#218

Merged
demandal25 merged 2 commits intoROCm:amd-integrationfrom
demandal25:fix/single-decode-sigfpe-cpx
Apr 29, 2026
Merged

fix(hip): clamp max_num_kv_chunks to avoid SIGFPE in single decode on CPX devices#218
demandal25 merged 2 commits intoROCm:amd-integrationfrom
demandal25:fix/single-decode-sigfpe-cpx

Conversation

@demandal25
Copy link
Copy Markdown
Collaborator

@demandal25 demandal25 commented Apr 29, 2026

Summary

  • tests/attention/test_logits_cap.py::test_single_decode_logits_soft_cap SIGFPEs at (seq_len=257, num_heads=32, head_dim=256, soft_cap=1.0) on MI308X CPX (20 CUs).
  • Root cause: in SingleDecodeWithKVCacheDispatched's partition-KV path, max_num_kv_chunks = max_grid_size / num_kv_heads underflows to 0 when num_kv_heads > max_grid_size (e.g. 20 CUs × 1 block/SM = 20 < 32 kv-heads). The next line then calls ceil_div(seq_len, 0) → SIGFPE in the host launch code.
  • Fix: clamp max_num_kv_chunks to >= 1. With the clamp, the kernel falls back to one CTA per kv-head (no further KV split) — the correct behavior when the device can't fit all kv-heads simultaneously.

The existing guard at line 700 only catches num_blocks_per_sm == 0; it does not cover this divisor-underflow case.

Test plan

  • Minimal repro (seq_len ∈ {256, 257, 320, 729, 33001}, head_dim=256, num_heads=32) returns valid output instead of crashing.
  • pytest tests/attention/test_logits_cap.py — all 450 cases pass (was crashing on first head_dim=256 / num_heads=32 / seq_len>256 case).
  • Run full ROCm test suite in CI to confirm no regressions in non-CPX paths.

🤖 Generated with Claude Code

In the partition-KV path of SingleDecodeWithKVCacheDispatched, max_num_kv_chunks
was computed as `max_grid_size / num_kv_heads` without a lower bound. When
num_kv_heads exceeds max_grid_size — e.g. MI308X CPX exposes 20 CUs while a
shape uses 32 kv-heads — the integer division underflows to 0 and the
subsequent `ceil_div(seq_len, 0)` raises SIGFPE in the kernel-launch host code.

The existing guard only catches `num_blocks_per_sm == 0`, not this divisor
underflow. Clamp the result to >=1 so the path falls back to one CTA per
kv-head (no further KV split), which is the correct behavior when the device
cannot fit all kv-heads simultaneously.

Reproduces with `tests/attention/test_logits_cap.py::test_single_decode_logits_soft_cap`
at (seq_len=257, num_heads=32, head_dim=256, soft_cap=1.0) on a 20-CU device.
After the fix, all 450 tests in test_logits_cap.py pass.

Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
Copilot AI review requested due to automatic review settings April 29, 2026 04:45
Copy link
Copy Markdown

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR fixes a crash in the HIP single-decode “partition-KV” launch path on MI308X CPX devices by preventing a host-side divide-by-zero when the device can’t schedule one CTA per KV head concurrently.

Changes:

  • Clamp max_num_kv_chunks to be at least 1 to avoid ceil_div(seq_len, 0) and resulting SIGFPE.
  • Add an explanatory comment describing why the clamp is needed on CPX configurations where num_kv_heads > max_grid_size.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment thread include/flashinfer/attention/generic/decode.cuh Outdated
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
Signed-off-by: Debasis Mandal <Debasis.Mandal@amd.com>
Copilot AI review requested due to automatic review settings April 29, 2026 16:20
Copy link
Copy Markdown

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR fixes a crash (SIGFPE) in the HIP generic single-decode “partition-KV” launch path on devices where num_kv_heads > max_grid_size, by preventing max_num_kv_chunks from truncating to zero and later becoming a divisor in ceil_div.

Changes:

  • Clamp max_num_kv_chunks to be at least 1 when computing the partitioning factor.
  • Add an inline comment explaining the MI308X CPX scenario and why the fallback behavior is correct.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

@demandal25 demandal25 merged commit e882a22 into ROCm:amd-integration Apr 29, 2026
5 of 6 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants