fix(generation): handle CUDA multinomial limit in beam search sampling #45369

Closed
sharziki wants to merge 1 commit into huggingface:main from sharziki:fix/45245-beam-search-multinomial-limit

@sharziki
Contributor

Summary

Fixes #45245: torch.multinomial crashes with RuntimeError: number of categories cannot exceed 2^24 when num_beams * vocab_size > 16,777,216 during beam search with do_sample=True.

Root cause: In _get_top_k_continuations(), the accumulated log-probs are flattened to shape (batch_size, num_beams * vocab_size) and passed directly to torch.multinomial. With large beam counts (e.g. 128) and large vocabularies (e.g. 164K), this exceeds PyTorch's CUDA limit of 2^24 categories.
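The arithmetic behind this can be checked with a short sketch. The helper names below are hypothetical and not part of transformers, and "164K" is read here as 164,000 purely for illustration.

```python
# Limit quoted in the error message: 2^24 = 16,777,216 categories.
CUDA_MULTINOMIAL_LIMIT = 2**24


def exceeds_multinomial_limit(num_beams, vocab_size, limit=CUDA_MULTINOMIAL_LIMIT):
    """Return True if the flattened (num_beams * vocab_size) dimension is over the limit."""
    return num_beams * vocab_size > limit


def max_beams_within_limit(vocab_size, limit=CUDA_MULTINOMIAL_LIMIT):
    """Largest num_beams whose flattened dimension still fits within the limit."""
    return limit // vocab_size


vocab_size = 164_000  # assumption: "164K" taken as 164,000
print(128 * vocab_size)                            # flattened size with 128 beams
print(exceeds_multinomial_limit(128, vocab_size))  # whether that size is over the limit
print(max_beams_within_limit(vocab_size))          # largest beam count that still fits
```

Under that assumption, 128 beams give about 21M categories, and the largest beam count that stays within the limit is a little over 100.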

Fix: When the flattened dimension exceeds 2^24, pre-filter to the top 2^24 candidates using torch.topk (which has no such limit), then sample from the filtered set. The sampled indices are then mapped back to the original flattened space. This closely approximates the original sampling distribution: in the example above, 16.7M of the ~21M candidates are retained, and because they are the highest-scoring ones, virtually all probability mass is covered.
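A minimal sketch of the pre-filter idea follows. It is not the diff from this PR: the function name sample_with_category_limit and its max_categories parameter are hypothetical, and the limit is made configurable only so the over-limit branch can be exercised on small CPU tensors.

```python
import torch

CUDA_MULTINOMIAL_LIMIT = 2**24  # category limit quoted in the error message


def sample_with_category_limit(probs, num_samples, max_categories=CUDA_MULTINOMIAL_LIMIT):
    """Sample indices from probs of shape (batch, n), never passing more than
    max_categories categories to torch.multinomial."""
    if probs.shape[-1] > max_categories:
        # Keep only the max_categories most probable candidates in each row.
        top_probs, top_indices = torch.topk(probs, k=max_categories, dim=-1)
        # Sample positions within the filtered set; multinomial accepts
        # unnormalized non-negative weights, so no renormalization is needed.
        sampled = torch.multinomial(top_probs, num_samples=num_samples)
        # Map positions in the filtered set back to the original flattened indices.
        return torch.gather(top_indices, dim=-1, index=sampled)
    # Within the limit: sample directly, as before.
    return torch.multinomial(probs, num_samples=num_samples)


# Tiny illustration: 4 beams x 25 tokens, with the limit lowered to 10.
num_beams, vocab_size = 4, 25
probs = torch.softmax(torch.randn(2, num_beams * vocab_size), dim=-1)
flat_indices = sample_with_category_limit(probs, num_samples=4, max_categories=10)
beam_indices = flat_indices // vocab_size   # which beam each sample came from
token_indices = flat_indices % vocab_size   # which token within that beam
```

Splitting the flat index into a beam index and a token index at the end is an assumption about how the flattened (batch_size, num_beams * vocab_size) layout is typically decoded, not something stated in the PR.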

The fix is 7 net new lines. No new files, no new dependencies, no behavioral change for users within the limit.

Coordination

Test plan

  • Verify model.generate(num_beams=128, do_sample=True) no longer crashes with large-vocab models
  • Verify normal beam search (num_beams < 2^24/vocab_size) is unaffected (takes the else branch)
  • ruff check src/transformers/generation/utils.py passes
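The second item in the plan could be checked on CPU roughly as follows. This is only a sketch: sample_flat is a hypothetical stand-in for the patched sampling step, not the transformers code, and it relies on seeded CPU sampling being reproducible.

```python
import torch


def sample_flat(probs, num_samples, limit=2**24):
    # Hypothetical stand-in for the patched sampling step.
    if probs.shape[-1] > limit:
        top_probs, top_idx = torch.topk(probs, k=limit, dim=-1)
        picked = torch.multinomial(top_probs, num_samples=num_samples)
        return top_idx.gather(-1, picked)
    return torch.multinomial(probs, num_samples=num_samples)


def check_under_limit_path_unchanged():
    """With the same seed, the stand-in should match a plain multinomial call
    whenever the flattened dimension is within the limit."""
    probs = torch.softmax(torch.randn(2, 4 * 25), dim=-1)  # 4 beams x 25 tokens
    torch.manual_seed(0)
    expected = torch.multinomial(probs, num_samples=4)
    torch.manual_seed(0)
    actual = sample_flat(probs, num_samples=4)  # 100 <= 2**24, so the else branch runs
    return torch.equal(expected, actual)


print(check_under_limit_path_unchanged())
```

The first item, which needs a large-vocabulary model on a CUDA device, is not covered by this sketch.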

🤖 Generated with Claude Code

torch.multinomial on CUDA requires the last dimension to be <= 2^24.
With large num_beams * vocab_size (e.g. 128 * 164K = 21M), this limit
is exceeded, causing a RuntimeError. Pre-filter to the top 2^24
candidates via torch.topk before sampling when necessary.

Fixes huggingface#45245

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@github-actions
Contributor

View the CircleCI Test Summary for this PR:

https://huggingface.co/spaces/transformers-community/circle-ci-viz?pr=45369&sha=838fbd

@Rocketknight1
Member

Hi @sharziki, as commented in the issue I don't think we need extra code paths to solve what is a very rare edge case. If you're doing torch.multinomial over 16 million values then something has gone terribly wrong 😅

Development

Successfully merging this pull request may close these issues.

RuntimeError: number of categories cannot exceed 2^24
