fix(hip): clamp max_num_kv_chunks to avoid SIGFPE in single decode on CPX devices#218
Merged
demandal25 merged 2 commits intoROCm:amd-integrationfrom Apr 29, 2026
Conversation
In the partition-KV path of SingleDecodeWithKVCacheDispatched, max_num_kv_chunks was computed as `max_grid_size / num_kv_heads` without a lower bound. When num_kv_heads exceeds max_grid_size — e.g. MI308X CPX exposes 20 CUs while a shape uses 32 kv-heads — the integer division underflows to 0 and the subsequent `ceil_div(seq_len, 0)` raises SIGFPE in the kernel-launch host code. The existing guard only catches `num_blocks_per_sm == 0`, not this divisor underflow. Clamp the result to >=1 so the path falls back to one CTA per kv-head (no further KV split), which is the correct behavior when the device cannot fit all kv-heads simultaneously. Reproduces with `tests/attention/test_logits_cap.py::test_single_decode_logits_soft_cap` at (seq_len=257, num_heads=32, head_dim=256, soft_cap=1.0) on a 20-CU device. After the fix, all 450 tests in test_logits_cap.py pass. Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
There was a problem hiding this comment.
Pull request overview
This PR fixes a crash in the HIP single-decode “partition-KV” launch path on MI308X CPX devices by preventing a host-side divide-by-zero when the device can’t schedule one CTA per KV head concurrently.
Changes:
- Clamp
max_num_kv_chunksto be at least 1 to avoidceil_div(seq_len, 0)and resulting SIGFPE. - Add an explanatory comment describing why the clamp is needed on CPX configurations where
num_kv_heads > max_grid_size.
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> Signed-off-by: Debasis Mandal <Debasis.Mandal@amd.com>
There was a problem hiding this comment.
Pull request overview
This PR fixes a crash (SIGFPE) in the HIP generic single-decode “partition-KV” launch path on devices where num_kv_heads > max_grid_size, by preventing max_num_kv_chunks from truncating to zero and later becoming a divisor in ceil_div.
Changes:
- Clamp
max_num_kv_chunksto be at least1when computing the partitioning factor. - Add an inline comment explaining the MI308X CPX scenario and why the fallback behavior is correct.
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
tests/attention/test_logits_cap.py::test_single_decode_logits_soft_capSIGFPEs at(seq_len=257, num_heads=32, head_dim=256, soft_cap=1.0)on MI308X CPX (20 CUs).SingleDecodeWithKVCacheDispatched's partition-KV path,max_num_kv_chunks = max_grid_size / num_kv_headsunderflows to 0 whennum_kv_heads > max_grid_size(e.g. 20 CUs × 1 block/SM = 20 < 32 kv-heads). The next line then callsceil_div(seq_len, 0)→ SIGFPE in the host launch code.max_num_kv_chunksto>= 1. With the clamp, the kernel falls back to one CTA per kv-head (no further KV split) — the correct behavior when the device can't fit all kv-heads simultaneously.The existing guard at line 700 only catches
num_blocks_per_sm == 0; it does not cover this divisor-underflow case.Test plan
seq_len ∈ {256, 257, 320, 729, 33001},head_dim=256,num_heads=32) returns valid output instead of crashing.pytest tests/attention/test_logits_cap.py— all 450 cases pass (was crashing on first head_dim=256 / num_heads=32 / seq_len>256 case).🤖 Generated with Claude Code