
fix: clamp NaN/Inf in topk_softmax to prevent duplicate expert IDs#39391

Merged
vadiklyutiy merged 5 commits into vllm-project:main from jhaotingc:fix/topk-nan-clamp
Apr 21, 2026

Conversation

@jhaotingc (Contributor) commented Apr 9, 2026

Purpose

Fix #39244

Fix CUDA illegal memory access crash when serving MoE models (e.g., Qwen3.5-397B-A17B-FP8) with FlashInfer CUTLASS MoE and CUDA graphs at high concurrency on H200.

CUDA graph replay pads the batch to the nearest capture size. Padded tokens have degenerate hidden states that produce NaN gating logits. The topkGating kernel's softmax outputs all-NaN, and the argmax loop picks expert 0 for every top-k slot (IEEE 754: NaN > NaN is false), producing duplicate expert IDs [0,0,0,0,0,0,0,0]. These duplicates trigger an uninitialized-memory bug in FlashInfer's three-step MoE sort, causing finalizeMoeRoutingKernel to dereference wild pointers.

The fix clamps NaN/Inf values to 0 after softmax/sigmoid scoring in topkGating, before the argmax selection loop. With all-zero scores, argmax picks unique experts [0,1,2,...,k-1] via index tie-breaking. Zero performance overhead.
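A toy Python model of the selection loop (illustrative only, not the CUDA kernel) shows both behaviors:

nan = float("nan")

def naive_topk(scores, k):
    # Mimics the kernel's argmax loop: scan for the max, record its index,
    # then mask the winner with -10000 before the next round.
    scores = list(scores)
    ids = []
    for _ in range(k):
        best, best_id = scores[0], 0
        for i, s in enumerate(scores):
            if s > best:  # IEEE 754: any comparison with NaN is False
                best, best_id = s, i
        ids.append(best_id)
        scores[best_id] = -10000.0
    return ids

print(naive_topk([nan] * 8, 4))  # [0, 0, 0, 0]  -> duplicates
print(naive_topk([0.0] * 8, 4))  # [0, 1, 2, 3]  -> unique via index tie-breaking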

Test Plan

  • Kernel unit test: verify topk_softmax produces unique expert IDs for NaN/Inf/normal gating inputs

  • Kernel microbenchmark: compare eager + CUDA graph replay latency for normal vs NaN inputs (batch 1-512, 128/256 experts)

  • End-to-end: serve Qwen3.5-397B-A17B-FP8 (TP=4, EP=4, CUDA graphs, VLLM_USE_FLASHINFER_MOE_FP8=1) with 8 concurrent requests

  • Full sweep: sglang benchmark conc 1-512, ISL=1600, OSL=600, REPEAT=5 on 4x H200

  • tests/kernels/moe/test_fused_topk.py::test_fused_topk_nan_inf_clamp

Test Result

Kernel correctness (H200):

Input   | Before fix                      | After fix
normal  | unique IDs                      | unique IDs (unchanged)
all_nan | [0,0,0,0,0,0,0,0] (512/512 dup) | [0,1,2,3,4,5,6,7] (0 dup)
all_inf | [0,0,0,0,0,0,0,0] (512/512 dup) | [0,1,2,3,4,5,6,7] (0 dup)

Kernel perf (H200, CUDA graph replay, median of 1000 runs):

Batch | Experts | normal (us) | all_nan (us) | diff
1     | 128     | 8.29        | 8.22         | -0.8%
128   | 128     | 8.26        | 8.32         | +0.8%
512   | 128     | 8.51        | 8.64         | +1.5%
512   | 256     | 8.90        | 8.93         | +0.4%

All within noise. Zero measurable overhead.

End-to-end (4x H200, Qwen3.5-397B-A17B-FP8):

Test              | Before            | After
8 concurrent curl | 5/8 OK, 3/8 crash | 8/8 HTTP 200
sweep conc 1-512  | crash at conc 16+ | all pass

Unit test:

State                           | Kernel      | Test result
BEFORE                          | partial fix | 60 failed, 12 passed
AFTER (all three clamps active) | full fix    | 72 passed

Summary

vLLM crashes with CUDA error: an illegal memory access was encountered when serving Qwen3.5-397B-A17B-FP8 with VLLM_USE_FLASHINFER_MOE_FP8=1 and CUDA graphs enabled. The crash occurs at high concurrency (8+ requests) when the MoE batch size exceeds 256 tokens.

Root Cause

CUDA graph replay pads the batch to the nearest capture size (e.g., 300 real tokens padded to 512). Padded tokens have stale/degenerate hidden states that produce NaN gating logits in the MoE router. The topk_softmax CUDA kernel then produces duplicate expert IDs for NaN inputs (e.g., [0,0,0,0,0,0,0,0] for every padded token), because IEEE 754 NaN > NaN is always false, so the argmax never updates from expert 0, and the -10000 zeroing of the winner also fails (-10000 > NaN is false).

These duplicate expert IDs trigger a latent bug in FlashInfer's blockExpertPrefixSumKernel (three-step MoE sort path, used when num_tokens > 256): it uses break after the first expert match, so duplicate expert slots leave unpermuted_row_to_permuted_row entries uninitialized. finalizeMoeRoutingKernel then reads garbage values as row indices, causing wild pointer dereferences.
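A loose Python model of that failure shape (a sketch of the bug's structure, heavily simplified from FlashInfer's actual CUDA):

def build_row_map(expert_ids, num_experts):
    # Models unpermuted_row_to_permuted_row; None stands in for
    # "uninitialized memory" in the real kernel.
    row_map = [None] * len(expert_ids)
    permuted_row = 0
    for expert in range(num_experts):
        for row, e in enumerate(expert_ids):
            if e == expert:
                row_map[row] = permuted_row
                permuted_row += 1
                break  # stops after the FIRST match; duplicate slots are skipped
    return row_map

print(build_row_map([0, 1, 2, 3], num_experts=4))  # [0, 1, 2, 3]  -- all rows mapped
print(build_row_map([0, 0, 0, 0], num_experts=4))  # [0, None, None, None]  -- garbage reads downstream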

Chain of events

CUDA graph replay with padded tokens
  -> stale hidden states -> NaN gating logits
    -> topk_softmax produces [0,0,0,0,0,0,0,0] for padded tokens
      -> duplicate expert IDs enter cutlass_fused_moe (num_tokens > 256)
        -> blockExpertPrefixSumKernel skips duplicate slots (break)
          -> unpermuted_row_to_permuted_row has uninitialized entries
            -> finalizeMoeRoutingKernel reads garbage -> OOB -> CRASH

Why it only happens with CUDA graphs

In eager mode, there are no padded tokens -- the batch contains only real tokens with valid hidden states, the router produces unique expert IDs, and the three-step sort works correctly. The crash requires:

  1. Batch size > 256 (three-step sort path)
  2. Duplicate expert IDs (from NaN gating on padded tokens)

Both conditions only occur together during CUDA graph replay at high concurrency.

Fix

Clamp NaN/Inf values to 0 in topk_softmax after softmax/sigmoid scoring, before the argmax selection loop:

// csrc/moe/topk_softmax_kernels.cu, after line 443
#pragma unroll
for (int ii = 0; ii < VPT; ++ii) {
    if (isnan(row_chunk[ii]) || isinf(row_chunk[ii])) {
        row_chunk[ii] = 0.f;
    }
}

With all-zero scores, the argmax uses index tie-breaking to pick unique experts [0,1,2,...,k-1], preventing duplicates. Normal (non-NaN) inputs are unaffected -- the clamp is a no-op.
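For a quick sanity check from Python, something like the following should show unique IDs per token after the fix (a sketch assuming vLLM's fused_topk wrapper; the exact import path may vary by version):

import torch
from vllm.model_executor.layers.fused_moe import fused_topk

# All-NaN gating logits, as produced for padded tokens during graph replay.
hidden = torch.randn(4, 64, device="cuda", dtype=torch.float32)
logits = torch.full((4, 8), float("nan"), device="cuda", dtype=torch.float32)
topk_weights, topk_ids, _ = fused_topk(hidden, logits, 4, False)
# With the clamp, every token gets unique expert IDs ([0, 1, 2, 3]).
assert all(row.unique().numel() == 4 for row in topk_ids)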

Why this is the right fix location

The topk_softmax kernel (csrc/moe/topk_softmax_kernels.cu:266) is where the NaN propagates into duplicate expert IDs. Fixing it here:

  • Prevents the bad input from reaching ANY downstream MoE kernel (FlashInfer, Triton, etc.)
  • Zero performance overhead (see benchmarks below)
  • Handles all NaN sources (CUDA graph padding, numerical overflow, any future degenerate input)

Performance Impact

Benchmarked on H200, production MoE configs (128/256 experts, top_k=8). The fix adds isnan/isinf checks (single PTX predicate instructions) per element. The kernel is memory-bandwidth bound, so the extra comparisons are invisible:
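The tables below come from a dedicated harness; a rough eager-mode timing sketch along these lines (same fused_topk wrapper assumption as above) reproduces the comparison:

import torch
from vllm.model_executor.layers.fused_moe import fused_topk

def bench_us(logits, topk=8, iters=1000):
    hidden = torch.randn(logits.shape[0], 64, device="cuda")
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    fused_topk(hidden, logits, topk, False)  # warm-up
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        fused_topk(hidden, logits, topk, False)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters * 1000  # ms -> us per call

normal = torch.randn(512, 128, device="cuda")
print(bench_us(normal), bench_us(torch.full_like(normal, float("nan"))))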

Eager mode (us, median of 1000 runs)

Batch | Experts | normal | all_nan | diff
1     | 128     | 10.94  | 10.91   | -0.3%
8     | 128     | 10.72  | 10.72   | 0.0%
32    | 128     | 10.72  | 10.75   | +0.3%
128   | 128     | 10.72  | 10.66   | -0.6%
256   | 128     | 10.78  | 10.72   | -0.6%
512   | 128     | 10.66  | 10.69   | +0.3%
512   | 256     | 10.72  | 10.72   | 0.0%

CUDA graph replay mode (us, median of 1000 runs)

Batch | Experts | normal | all_nan | diff
1     | 128     | 8.29   | 8.22    | -0.8%
8     | 128     | 8.19   | 8.16    | -0.4%
32    | 128     | 8.29   | 8.19    | -1.2%
128   | 128     | 8.26   | 8.32    | +0.8%
256   | 128     | 8.32   | 8.32    | 0.0%
512   | 128     | 8.51   | 8.64    | +1.5%
512   | 256     | 8.90   | 8.93    | +0.4%

All differences are within noise (<2%). Zero measurable overhead.

Verification

Standalone (topk kernel)

Before fix:

all_nan: dup_tokens=512/512  topk_ids=[0,0,0,0,0,0,0,0]
all_inf: dup_tokens=512/512  topk_ids=[0,0,0,0,0,0,0,0]

After fix:

all_nan: dup_tokens=0/512  topk_ids=[0,1,2,3,4,5,6,7]
all_inf: dup_tokens=0/512  topk_ids=[0,1,2,3,4,5,6,7]



@gemini-code-assist (Bot) left a comment

Code Review

This pull request implements NaN and Inf clamping within the optimized topkGating kernel to prevent duplicate expert IDs, which avoids crashes in FlashInfer's MoE sort. While the change addresses the issue for the optimized path, the reviewer noted that the fallback paths in moeSoftmax and moeSigmoid remain vulnerable and should also be updated to ensure a complete fix for all expert configurations.

Comment on lines +452 to +456
for (int ii = 0; ii < VPT; ++ii) {
    if (isnan(row_chunk[ii]) || isinf(row_chunk[ii])) {
        row_chunk[ii] = 0.f;
    }
}
@gemini-code-assist (Bot), severity: high

The fix correctly addresses the issue for the optimized topkGating kernel. However, the same vulnerability exists in the fallback path used for models with a non-standard number of experts (those that are not a power of 2 or a multiple of 64). In topkGatingKernelLauncher (line 711), the default case calls moeSoftmax or moeSigmoid followed by moeTopK. These kernels currently lack NaN/Inf clamping, meaning they will still produce duplicate expert IDs for degenerate inputs, potentially leading to the same illegal memory access crash in downstream kernels like FlashInfer. To ensure a complete fix, NaN/Inf clamping should also be added to the output loops of moeSoftmax (line 125) and moeSigmoid (line 146).

@jhaotingc (Contributor, Author) replied:

Also included

@ZJY0516 (Member) left a comment:

please add a test for this

@ZJY0516 ZJY0516 requested a review from zyongye April 16, 2026 16:50
@ZJY0516 ZJY0516 linked an issue Apr 16, 2026 that may be closed by this pull request
When CUDA graph padding produces degenerate hidden states that result in
NaN gating logits, softmax outputs all-NaN. The argmax loop then picks
expert 0 for every top-k slot (since NaN > NaN is false per IEEE 754),
producing duplicate expert IDs like [0,0,0,0,0,0,0,0].

These duplicates cause FlashInfer's three-step MoE sort
(blockExpertPrefixSumKernel) to leave permutation entries uninitialized,
leading to wild pointer dereferences in finalizeMoeRoutingKernel and
CUDA illegal memory access crashes.

The fix clamps NaN/Inf values to 0 after softmax/sigmoid scoring, before
the argmax selection loop. With all-zero scores, the argmax uses index
tie-breaking to pick unique experts [0,1,2,...,k-1], preventing duplicates.

Normal (non-NaN) inputs are unaffected — the clamp is a no-op.

Tested: Qwen3.5-397B-A17B-FP8, TP=4 EP=4, CUDA graphs, 8 concurrent
requests — 8/8 HTTP 200 (previously 5/8 OK + crash).

Signed-off-by: Jhao-Ting Chen <jhaotingc@nvidia.com>
@jhaotingc jhaotingc force-pushed the fix/topk-nan-clamp branch from 73b8ea7 to 2d6669d Compare April 16, 2026 18:10
@vadiklyutiy vadiklyutiy added the verified Run pre-commit for new contributors without triggering other tests label Apr 16, 2026
@vadiklyutiy vadiklyutiy self-requested a review April 16, 2026 18:12
@vadiklyutiy (Collaborator) commented:

Padded tokens have degenerate hidden states that produce NaN gating logits.

Just wondering, what actual values do we use for padding?

@jhaotingc (Contributor, Author) replied:

Padded tokens have degenerate hidden states that produce NaN gating logits.

Just wondering, what actual values do we use for padding?

It's zeros (found by Claude):
https://github.com/vllm-project/vllm/blob/main/vllm/v1/worker/gpu_model_runner.py#L3281-L3284

Parametrized over {bf16, fp16, fp32} x {NaN, +Inf} x {softmax, sigmoid}
x {topk=3,4} x {num_experts=8,16} -- 48 cases exercising the patched
topkGatingSoftmax warp kernel path.

Each test case poisons 3 of 4 gating rows with NaN or +Inf and asserts:
- poisoned rows produce unique top-k expert IDs (no duplicates)
- poisoned rows produce finite weights
- the clean row still matches the torch.topk reference

Signed-off-by: Jhao-Ting Chen <jhaotingc@nvidia.com>
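Condensed, the parametrization described above looks roughly like this (a sketch; the sigmoid-scoring variant is omitted and helper names are guessed, the real test is tests/kernels/moe/test_fused_topk.py::test_fused_topk_nan_inf_clamp):

import pytest
import torch
from vllm.model_executor.layers.fused_moe import fused_topk

@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
@pytest.mark.parametrize("bad_value", [float("nan"), float("inf")])
@pytest.mark.parametrize("topk", [3, 4])
@pytest.mark.parametrize("num_experts", [8, 16])
def test_fused_topk_nan_inf_clamp(dtype, bad_value, topk, num_experts):
    gating_output = torch.randn(4, num_experts, dtype=dtype, device="cuda")
    gating_output[1:, :] = bad_value  # poison 3 of 4 rows
    hidden = torch.randn(4, 64, dtype=dtype, device="cuda")
    topk_weights, topk_ids, _ = fused_topk(hidden, gating_output, topk, False)
    for row in topk_ids[1:]:  # poisoned rows must yield unique expert IDs
        assert row.unique().numel() == topk
    assert torch.isfinite(topk_weights[1:]).all()  # and finite weights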
@mergify (Bot) commented Apr 16, 2026

Hi @jhaotingc, the pre-commit checks have failed. Please run:

uv pip install "pre-commit>=4.5.1"
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy failing?
mypy is run differently in CI. If the failure is related to this check, please use the following command to run it locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10

The original fix in 2d6669d only patched the topkGatingSoftmax warp
kernel, which handles num_experts that are a power of 2 or a multiple
of 64. For other num_experts values (e.g., 6), topk_softmax dispatches
to a fallback: moeSoftmax/moeSigmoid writes normalized scores into a
workspace, then moeTopK runs argmax. Neither kernel clamps NaN/Inf, so
the same crash mode (duplicate expert IDs -> uninitialized permutation
entries -> OOB in finalizeMoeRoutingKernel) is reachable for models
with non-power-of-2 expert counts.

Add a clamp-to-zero at the output of moeSoftmax and moeSigmoid so the
workspace moeTopK reads is always finite. With all-zero scores, argmax
tie-breaking plus the -10000 winner zeroing produce unique experts
[0, 1, ..., k-1], matching the warp-kernel fix.

Extend tests/kernels/moe/test_fused_topk.py::test_fused_topk_nan_inf_clamp
to parametrize num_experts over {6, 8, 16} so both paths are covered.

Verified on H200 cu13.0.2 torch2.11.0 with three rebuild cycles:
- All three clamps active: 72/72 pass
- Default-path clamps disabled, warp clamp active: 18/72 fail
  (all failures are num_experts=6, except sigmoid+Inf which saturates
  to 1.0 and coincidentally produces valid output via tie-breaking)
- All clamps disabled (baseline from commit 5f7fab8): prior cycle
  with warp-kernel-only test showed 36/48 fail

Signed-off-by: Jhao-Ting Chen <jhaotingc@nvidia.com>
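To exercise the fallback path from Python, an expert count outside the switch-case (e.g., 6) is enough (a sketch, same fused_topk wrapper assumption as earlier):

import torch
from vllm.model_executor.layers.fused_moe import fused_topk

# num_experts=6 is neither a power of 2 nor a multiple of 64, so this
# dispatches to moeSoftmax/moeSigmoid + moeTopK rather than the warp kernel.
gating = torch.full((4, 6), float("nan"), device="cuda")
hidden = torch.randn(4, 64, device="cuda")
_, topk_ids, _ = fused_topk(hidden, gating, 2, False)
assert all(row.unique().numel() == 2 for row in topk_ids)  # fallback also clamps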
@mergify (Bot) commented Apr 16, 2026

Hi @jhaotingc, the pre-commit checks have failed. Please run:

uv pip install "pre-commit>=4.5.1"
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy failing?
mypy is run differently in CI. If the failure is related to this check, please use the following command to run it locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10

Signed-off-by: Jhao-Ting Chen <jhaotingc@nvidia.com>
@jhaotingc (Contributor, Author) commented:

@ZJY0516 added tests

Comment thread: csrc/moe/topk_softmax_kernels.cu
@jhaotingc (Contributor, Author):

can we trigger pipeline 🤔? TY

  const int idx = thread_row_offset + ii;
  const float val = toFloat(input[idx]);
- const float softmax_val = expf(val - float_max) * normalizing_factor;
+ float softmax_val = expf(val - float_max) * normalizing_factor;
Member:

Do we really need to change this?

@jhaotingc (Contributor, Author):

https://github.com/vllm-project/vllm/blob/main/csrc/moe/topk_softmax_kernels.cu#L644-L646

Here's the logic that chooses either (1) the topkGatingSoftmax kernel or (2) the moeSoftmax/moeSigmoid kernels when calculating top-k.
If the expert count is not one of those listed in that switch-case, it falls through to the 2nd path.
That's why Gemini suggested fixing that path as well, though it's rare for the expert count to be such a number.
One of the unit tests (num_experts=6) covers the 2nd path.

  const int idx = thread_row_offset + ii;
  const float val = toFloat(input[idx]);
- const float sigmoid_val = 1.0f / (1.0f + __expf(-val));
+ float sigmoid_val = 1.0f / (1.0f + __expf(-val));
Member:

same

@tlrmchlsmth (Member) left a comment:
This seems like a reasonable approach. My concern would be around fragility in case some other topk softmax kernel is used that doesn't suppress NaNs.

The softmax values are small so another reasonable and less fragile approach would be to call torch.nan_to_num
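As a sketch, that alternative would sanitize the router logits in Python before any top-k kernel sees them (not what this PR merges, just the suggested shape):

import torch

def sanitize_gating(logits: torch.Tensor) -> torch.Tensor:
    # Map NaN and +/-Inf to 0 so every downstream topk kernel gets finite scores.
    return torch.nan_to_num(logits, nan=0.0, posinf=0.0, neginf=0.0)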

@ZJY0516 (Member) commented Apr 20, 2026

My only concern is whether the overhead from these additional checks is acceptable.

My concern would be around fragility in case some other topk softmax kernel is used that doesn't suppress NaNs.

I hope the fused_topk test will catch these in CI.

Btw I think we should merge this PR before the v0.20 release.

@ZJY0516 ZJY0516 added the ready ONLY add when PR is ready to merge/full CI is needed label Apr 20, 2026
@jhaotingc (Contributor, Author) replied:

My only concern is whether the overhead from these additional checks is acceptable.

My concern would be around fragility in case some other topk softmax kernel is used that doesn't suppress NaNs.

I hope the fused_topk test will catch these in CI.

Btw I think we should merge this PR before the v0.20 release.

Yeah, I added some kernel-level tests in the Performance Impact section; the overhead seems negligible.

@jhaotingc (Contributor, Author):

I hope the fused_topk test will catch these in CI

Following the code path, there are only these two paths (power-of-two/multiple-of-64 and the fallback), and both are covered by the unit test.
Also, it's rare for the number of experts to be a weird number 🤣

@jhaotingc (Contributor, Author):

(APIServer pid=19255) 2026-04-20 18:42:19,444	ERROR serialization.py:533 -- Failed to unpickle serialized exception
--
(APIServer pid=19255) Traceback (most recent call last):
(APIServer pid=19255)   File "/usr/local/lib/python3.12/dist-packages/ray/exceptions.py", line 50, in from_ray_exception
(APIServer pid=19255)     return pickle.loads(ray_exception.serialized_exception)
(APIServer pid=19255)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(APIServer pid=19255) TypeError: RaySystemError.__init__() missing 1 required positional argument: 'client_exc'

Unrelated error. Can we merge?

@vadiklyutiy (Collaborator):

This test doesn't run in CI (see .buildkite/). Is that intentional?

gating_output = torch.randn((num_tokens, num_experts), dtype=dtype, device="cuda")
gating_output[1:, :] = bad_value

topk_weights, topk_ids, _ = fused_topk(
@vadiklyutiy (Collaborator):

Could you add a test for fused_topk_bias as well?

@vadiklyutiy (Collaborator) left a comment:

After fixing

  • add to CI
  • add fused_topk_bias

will look good to me

@vadiklyutiy (Collaborator):

After fixing

  • add to CI
  • add fused_topk_bias

will look good to me

Merging it without fixing the comments above because it's needed in v0.20.
Created #40457 as a reminder to fix it. @jhaotingc, could you please open a new PR to fix it?

@vadiklyutiy vadiklyutiy merged commit 28c2221 into vllm-project:main Apr 21, 2026
146 checks passed
haic0 pushed a commit to haic0/vllm that referenced this pull request Apr 21, 2026
…llm-project#39391)

Signed-off-by: Jhao-Ting Chen <jhaotingc@nvidia.com>
Signed-off-by: root <root@gbt350-odcdh5-wbb3.png-odc.dcgpu>
khluu pushed a commit that referenced this pull request Apr 22, 2026
…39391)

Signed-off-by: Jhao-Ting Chen <jhaotingc@nvidia.com>
(cherry picked from commit 28c2221)
Copilot AI pushed a commit to hongbolv/vllm that referenced this pull request Apr 22, 2026
…llm-project#39391)

Signed-off-by: Jhao-Ting Chen <jhaotingc@nvidia.com>
Co-authored-by: hongbolv <33214277+hongbolv@users.noreply.github.com>
baonudesifeizhai pushed a commit to baonudesifeizhai/vllm that referenced this pull request Apr 23, 2026
yzong-rh pushed a commit to yzong-rh/vllm that referenced this pull request Apr 23, 2026
…llm-project#39391)

Signed-off-by: Jhao-Ting Chen <jhaotingc@nvidia.com>
Signed-off-by: Yifan <yzong@redhat.com>
avinashsingh77 pushed a commit to avinashsingh77/vllm that referenced this pull request Apr 27, 2026
…llm-project#39391)

Signed-off-by: Jhao-Ting Chen <jhaotingc@nvidia.com>
Signed-off-by: Avinash Singh <avinashsingh.rcoem@gmail.com>
Lafunamor pushed a commit to Lafunamor/vllm that referenced this pull request May 1, 2026
…llm-project#39391)

Signed-off-by: Jhao-Ting Chen <jhaotingc@nvidia.com>
Signed-off-by: Adrian <info@zzit.ch>
Copilot AI pushed a commit to hongbolv/vllm that referenced this pull request May 7, 2026
…llm-project#39391)

Signed-off-by: Jhao-Ting Chen <jhaotingc@nvidia.com>
Co-authored-by: hongbolv <33214277+hongbolv@users.noreply.github.com>
mystous pushed a commit to mystous/vllm_hybrid that referenced this pull request May 10, 2026
…llm-project#39391)

Signed-off-by: Jhao-Ting Chen <jhaotingc@nvidia.com>
(cherry picked from commit 28c2221)
weifang231 pushed a commit to weifang231/eb-vllm that referenced this pull request May 13, 2026
my-other-github-account pushed a commit to my-other-github-account/vllm that referenced this pull request May 15, 2026
my-other-github-account pushed a commit to my-other-github-account/vllm that referenced this pull request May 15, 2026
…llm-project#39391)

Signed-off-by: Jhao-Ting Chen <jhaotingc@nvidia.com>
(cherry picked from commit fafa76f)
my-other-github-account pushed a commit to my-other-github-account/vllm that referenced this pull request May 15, 2026
my-other-github-account pushed a commit to my-other-github-account/vllm that referenced this pull request May 15, 2026
…llm-project#39391)

Signed-off-by: Jhao-Ting Chen <jhaotingc@nvidia.com>
(cherry picked from commit 6d1e61a)

Labels

ready: ONLY add when PR is ready to merge/full CI is needed
verified: Run pre-commit for new contributors without triggering other tests


Development

Successfully merging this pull request may close these issues.

[Bug]: CUDA illegal memory access with FlashInfer MoE FP8 on Qwen3.5-397B (num_tokens > 256)
[Bug]: qwen 3.5 crash with mtp
