
feat: Support top-p and top-k #1578

Open
zhandaz wants to merge 9 commits into main from zhanda/top-p-k-with-tests

Conversation

@zhandaz
Contributor

@zhandaz zhandaz commented Nov 27, 2025

What does this PR do ?

This PR implements top-k and top-p sampling during training to ensure consistency between training and inference when sampling parameters are used.

  1. Why are top-p and top-k tricky? Under tensor parallelism, the vocabulary is sharded across GPUs (vocab parallel). However, top-p and top-k require probabilities over the entire vocabulary, not individual shards. Naively, this means an all-gather of the full logits, which incurs communication overhead and large memory consumption.
  2. Solution: We convert from a vocab-parallel to a batch-sequence-parallel layout via all-to-all communication, apply filtering on the full vocabulary, then convert back.
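The layout conversion in step 2 can be sketched with a toy, single-process simulation (plain lists stand in for sharded tensors; the function name follows the PR's all_to_all_vp2sq helper, but the body here is an illustration, not the NCCL-backed implementation):

```python
# Toy simulation of the vocab-parallel -> sequence-parallel exchange.
# Each "rank" starts with every sequence position but only a shard of the
# vocabulary; after the exchange it holds a slice of the positions with
# the full vocabulary, so top-k/top-p can be applied locally.

def all_to_all_vp2sq(per_rank_shards, tp_size, seq_len):
    # per_rank_shards[r][s] = vocab shard (list) on rank r for position s
    chunk = seq_len // tp_size
    out = []
    for r in range(tp_size):  # destination rank
        positions = range(r * chunk, (r + 1) * chunk)
        # concatenate every rank's vocab shard for the positions owned by r
        out.append([
            [x for src in range(tp_size) for x in per_rank_shards[src][s]]
            for s in positions
        ])
    return out

tp_size, seq_len, vocab_shard = 2, 4, 3
# rank r's shard for position s: values encode (rank, position, vocab index)
shards = [[[100 * r + 10 * s + v for v in range(vocab_shard)]
           for s in range(seq_len)] for r in range(tp_size)]
seq_parallel = all_to_all_vp2sq(shards, tp_size, seq_len)
# rank 0 now owns positions 0-1 with the full 6-entry vocabulary
print(seq_parallel[0][0])  # [0, 1, 2, 100, 101, 102]
```

The reverse transformation (sq2vp) would invert this split/concatenate to restore the vocab-parallel layout after filtering.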

Detailed Changes

  1. In nemo_rl/models/policy/utils.py: added TrainingSamplingParams and implemented apply_top_k_top_p with proper autograd support.
  2. In nemo_rl/distributed/model_utils.py: implemented DistributedLogprobWithSampling and ChunkedDistributedLogprobWithSampling, and integrated into all logprob computation paths (TP, CP, DTensor, packed sequences).
  3. Loss functions and policy workers: updated ClippedPGLossFn, NLLLoss, and DPOLossFn to accept sampling_params; updated the policy workers to call these functions with sampling_params.
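For reference, the selection rule behind apply_top_k_top_p can be sketched in pure Python. This follows the standard top-k/nucleus definition and the (filtered_logits, keep_mask) return shape described in this PR; the real implementation is tensorized with autograd support:

```python
import math

def top_k_top_p_keep_mask(logits, top_k=None, top_p=1.0):
    """Toy version of the top-k/top-p selection rule (not the PR's code).

    Returns (filtered_logits, keep_mask); filtered positions get -inf,
    matching the filter-then-log_softmax flow described above.
    """
    order = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)
    keep = set(order)
    if top_k is not None and 0 < top_k < len(logits):
        keep &= set(order[:top_k])           # keep only the k largest logits
    if top_p < 1.0:
        z = sum(math.exp(x) for x in logits)
        cum, nucleus = 0.0, set()
        for i in order:
            nucleus.add(i)                    # always keep at least the top token
            cum += math.exp(logits[i]) / z
            if cum >= top_p:                  # stop once mass reaches top_p
                break
        keep &= nucleus
    mask = [i in keep for i in range(len(logits))]
    filtered = [x if m else float("-inf") for x, m in zip(logits, mask)]
    return filtered, mask

filtered, mask = top_k_top_p_keep_mask([2.0, 1.0, 0.0, -1.0], top_k=2)
print(mask)  # [True, True, False, False]
```

The keep_mask is what allows the backward pass to zero gradients of filtered-out vocabulary entries.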

Tests

  • test_top_k_top_p_filtering_forward_backward: Tests top_k and top_p filtering and gradient-masking correctness.
  • test_distributed_logprob_with_sampling: Tests TP distributed implementation.
  • test_vllm_policy_logprob_agreement_with_sampling: Tests that policy worker logprobs match vLLM when sampling parameters are used.

We have also tested that with temperature != 1, the logprobs match. See test_vllm_policy_logprob_agreement_with_sampling.

Issues

List issues that this PR closes (syntax):

Usage

  • You can potentially add a usage example below
# Add a code snippet demonstrating how to use this

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you run the unit tests and functional tests locally? Visit our Testing Guide for how to run tests
  • Did you add or update any necessary documentation? Visit our Document Development Guide for how to write, build and test the docs.

Additional Information

  • ...

Summary by CodeRabbit

Release Notes

  • New Features

    • Enhanced generation with top-k/top-p filtering support for improved sampling control
    • Improved temperature scaling during model generation
    • Better distributed sampling computation paths
  • Bug Fixes

    • Fixed NaN issues in gradient computations during sampling operations
    • Adjusted sampling parameter validation thresholds for better compatibility
  • Tests

    • Added comprehensive tests for sampling functionality and distributed computation paths


@zhandaz zhandaz requested review from a team as code owners November 27, 2025 20:20
Signed-off-by: Zhanda <zhandazhu@gmail.com>
@zhandaz zhandaz force-pushed the zhanda/top-p-k-with-tests branch from f45b990 to d584519 Compare November 27, 2025 20:20
@github-actions

ℹ️ File Consistency Check

Check based on commit: f45b990 (PR #1578 from zhanda/top-p-k-with-tests)

✅ DTensor Policy Worker Synchronization Check

Both DTensor policy worker files were modified in this PR:

  • nemo_rl/models/policy/dtensor_policy_worker.py
  • nemo_rl/models/policy/dtensor_policy_worker_v2.py

Please ensure that the changes are consistent between both files where applicable.


This check ensures that related file implementations remain synchronized across the codebase. If you believe this warning is incorrect or the files should intentionally differ, please add a comment explaining the reasoning.

@github-actions

ℹ️ File Consistency Check

Check based on commit: d584519 (PR #1578 from zhanda/top-p-k-with-tests)

✅ DTensor Policy Worker Synchronization Check

Both DTensor policy worker files were modified in this PR:

  • nemo_rl/models/policy/dtensor_policy_worker.py
  • nemo_rl/models/policy/dtensor_policy_worker_v2.py

Please ensure that the changes are consistent between both files where applicable.


This check ensures that related file implementations remain synchronized across the codebase. If you believe this warning is incorrect or the files should intentionally differ, please add a comment explaining the reasoning.

@coderabbitai
Contributor

coderabbitai Bot commented Nov 27, 2025

📝 Walkthrough

Walkthrough

This PR integrates top-k/top-p generation sampling filtering throughout NeMo RL's training pipeline. It introduces TrainingSamplingParams to encapsulate sampling configuration, adds distributed sampling autograd functions with gradient-aware mask handling, propagates sampling parameters through loss functions and policy workers, replaces policy_cfg temperature handling with centralized sampling parameters, and implements conditional top-k/top-p filtering in both distributed and non-distributed contexts.

Changes

Cohort / File(s) Summary
Core Sampling Utilities
nemo_rl/models/policy/utils.py
Introduces TrainingSamplingParams dataclass and helper functions (need_top_k_filtering, need_top_p_filtering). Adds ApplyTopKTopP autograd function with gradient masking support and overhauls apply_top_k_top_p to handle combined top-k/top-p filtering logic with explicit keep_mask tracking.
Distributed Logprob Computation
nemo_rl/distributed/model_utils.py
Adds DistributedLogprobWithSampling and ChunkedDistributedLogprobWithSampling autograd functions for sampled distributed logprob computation. Introduces all_to_all_vp2sq and all_to_all_sq2vp layout transformation helpers. Extends public functions (dtensor_from_parallel_logits_to_logprobs, from_parallel_logits_to_logprobs, from_parallel_logits_to_logprobs_packed_sequences, get_logprobs_from_vocab_parallel_logits) with sampling_params parameter and conditional routing to sampling-enabled paths.
Loss Functions
nemo_rl/algorithms/loss_functions.py
Adds sampling_params: TrainingSamplingParams | None parameter to public methods: ClippedPGLossFn.__call__, NLLLoss.__call__, DistillationLossFn.__call__, DistillationLossFn._dpo_loss, DPOLossFn.__call__, DPOLossFn._dpo_loss, SequencePackingLossWrapper.__call__. Implements masking safeguards to replace -inf with zeros after logprob computation. Raises ValueError for unsupported distillation+sampling combinations.
Megatron Integration
nemo_rl/models/megatron/common.py
Replaces policy_cfg parameter with sampling_params: Optional[TrainingSamplingParams] in forward_step_arbitrary_loss. Refactors temperature scaling to use sampling_params.temperature and propagates sampling_params through model invocation and loss wrapper construction.
Policy Workers — DTensor
nemo_rl/models/policy/dtensor_policy_worker.py
Initializes self.sampling_params from generation config. Adds _apply_top_k_top_p_filtering method for local filtering. Propagates sampling_params through train (to loss computation) and get_logprobs (to vocab-parallel logprob pathways). Applies temperature scaling via sampling_params.
Policy Workers — DTensor v2
nemo_rl/models/policy/dtensor_policy_worker_v2.py
Integrates TrainingSamplingParams and adds _apply_top_k_top_p_filtering method. Replaces generation-based temperature with sampling_params.temperature. Propagates sampling_params through train (to loss function) and get_logprobs (to distributed/non-distributed paths). Applies conditional filtering on chunk logits when TP filtering is not globally applied.
Policy Workers — Megatron
nemo_rl/models/policy/megatron_policy_worker.py
Initializes self.sampling_params from generation config. Passes sampling_params to forward_step_arbitrary_loss and loss computation. Refactors temperature scaling logic to use sampling_params.temperature across training and logprob collection paths. Propagates sampling_params through from_parallel_logits_to_logprobs and related logprob calls. Exports TrainingSamplingParams.
Generation Configuration
nemo_rl/models/generation/vllm/vllm_generation.py
Adjusts global validation thresholds: TOP_K_THRESHOLD from 8000 to 1 (relaxes top-k validation) and TOP_P_THRESHOLD from 0.99 to 0.1 (tightens minimum top_p requirement).
Test Coverage
tests/unit/distributed/test_model_utils.py
Adds SamplingParamsTestActor class and new tests for top-k/top-p filtering validation, distributed logprob with sampling (including chunked variants), and gradient verification across multiple TP/CP configurations using Ray-based worker groups.
Test Updates
tests/unit/models/generation/test_vllm_generation.py
Adds test_vllm_policy_logprob_agreement_with_sampling function to validate logprob alignment between policy workers and vLLM across sampling configurations (temperature, top_p, top_k) for both dtensor and megatron policies at TP sizes 1 and 2. Includes training-step validation with KL penalties.
Test Alignment
tests/unit/models/generation/test_vllm_logprobs_mode.py
Updates apply_top_k_top_p call site to unpack the returned tuple (result, keep_mask) instead of treating it as a single return value.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

  • High logic density: New autograd functions with gradient masking in model_utils.py and complex combined top-k/top-p filtering logic in utils.py require careful verification of correctness.
  • Distributed system complexity: DistributedLogprobWithSampling and ChunkedDistributedLogprobWithSampling involve TP group operations, all-to-all transformations, and backward pass reconstruction; bidirectional data flow and reshaping constraints demand thorough inspection.
  • Widespread parameter propagation: sampling_params threading across 7+ files (loss functions, policy workers, model utilities) increases the surface area; each integration point should be verified for consistency.
  • Mask handling edge cases: Masking safeguards (-inf → 0 replacements) in loss computation and gradient masking in autograd functions are critical for numerical stability.
  • Test completeness: New test coverage spans distributed and non-distributed paths, chunked variants, and sampling configurations; assertions and gradient checks should be validated.
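The "-inf → 0" safeguard called out above can be illustrated with a toy reduction (pure Python; the real safeguard operates on tensors inside the loss functions):

```python
import math

def masked_mean_logprob(logprobs, token_mask):
    """Toy illustration of the -inf -> 0 masking safeguard.

    A token filtered out by top-k/top-p has logprob -inf; multiplying
    -inf by a 0.0 mask entry yields NaN, which would poison the whole
    reduction. Zeroing -inf entries first keeps masked positions inert.
    """
    safe = [0.0 if math.isinf(lp) else lp for lp in logprobs]
    total = sum(lp * m for lp, m in zip(safe, token_mask))
    denom = sum(token_mask) or 1.0
    return total / denom

print(float("-inf") * 0.0)  # nan: why the safeguard is needed
print(masked_mean_logprob([-0.5, float("-inf"), -1.5], [1.0, 0.0, 1.0]))  # -1.0
```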

Specific areas requiring attention:

  • ApplyTopKTopP autograd function backward pass and mask propagation logic
  • DistributedLogprobWithSampling.forward and DistributedLogprobWithSampling.backward implementations, especially all-to-all layout transformations
  • Masking safeguard consistency between loss functions and distributed logprob computation
  • Temperature scaling refactoring across three policy worker implementations
  • test_vllm_policy_logprob_agreement_with_sampling correctness (appears duplicated in diff; verify final state)

Possibly related PRs

Suggested labels

CI:L1

Suggested reviewers

  • terrykong
  • yuki-97
  • ashors1

Pre-merge checks and finishing touches

❌ Failed checks (2 warnings)
Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 74.14% which is insufficient. The required threshold is 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.
Test Results For Major Changes ⚠️ Warning PR introduces major training changes with new tests but provides no quantitative results, metrics, or threshold values in documentation. Document test pass/fail outcomes, actual agreement metrics (max/mean abs diff, multiplicative error), assertion tolerances, convergence metrics, and performance measurements.
✅ Passed checks (2 passed)
Check name Status Explanation
Title check ✅ Passed The title 'feat: Support top-p and top-k' directly matches the main objective: implementing top-k and top-p sampling during training.
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.

@zhandaz zhandaz added the CI:L1 Run doctests, unit tests, and functional tests label Nov 27, 2025
Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 8

♻️ Duplicate comments (2)
nemo_rl/models/policy/dtensor_policy_worker.py (2)

496-512: Shared temperature and top‑k/top‑p helpers look good

_apply_temperature_scaling and _apply_top_k_top_p_filtering here are identical in spirit to the v2 worker and correctly guard all work on self.sampling_params and need_*_filtering. Using them as the only sampling‑aware touchpoints in the worker simplifies reasoning about behavior. No functional issues.


1089-1091: DTensor get_logprobs: sampling integration mirrors v2 and looks correct

The get_logprobs changes closely mirror DTensorPolicyWorkerV2:

  • Temperature scaling via _apply_temperature_scaling(logits) before any TP/CP reshaping.
  • Passing sampling_params to get_logprobs_from_vocab_parallel_logits for both CP and non‑CP DTensor paths.
  • Using _apply_top_k_top_p_filtering only in the non‑DTensor full‑vocab path (optionally chunked along the sequence axis), followed by log_softmax and standard gathering of next‑token logprobs.

This is consistent with the distributed sampling utilities; just as with v2, please confirm get_logprobs_from_vocab_parallel_logits doesn’t also apply temperature scaling internally when sampling_params is present.

Also applies to: 1129-1135, 1140-1145, 1147-1176

🧹 Nitpick comments (10)
tests/unit/models/generation/test_vllm_generation.py (2)

2667-2674: Consider adding strict=True to zip() for defensive programming.

Adding strict=True will raise a ValueError if the iterables have different lengths, which helps catch bugs early if there's a length mismatch between input_lengths and unpadded_sequence_lengths.

     padding_mask = torch.zeros_like(generation_results["logprobs"], dtype=torch.bool)
     for i, (input_len, total_valid_len) in enumerate(
         zip(
             test_input_data.get("input_lengths"),
             generation_results["unpadded_sequence_lengths"],
+            strict=True,
         )
     ):

2728-2735: Consider adding strict=True to zip() here as well.

Same recommendation as above for consistency and defensive programming.

     for idx, (input_len, total_len) in enumerate(
         zip(
             test_input_data.get("input_lengths"),
             generation_results["unpadded_sequence_lengths"],
+            strict=True,
         )
     ):
nemo_rl/models/policy/dtensor_policy_worker_v2.py (3)

177-185: Sampling params initialization and config defaults

Wiring TrainingSamplingParams from cfg["generation"] is coherent with the rest of the PR and keeps all sampling-related state in one object. One thing to watch: using generation_cfg.get("top_p", 1.0) / get("temperature", 1.0) introduces non-None defaults in code instead of treating these as required or YAML-sourced defaults, which slightly diverges from the stated “YAML is the single source of truth” guideline. If you want stricter config behavior, consider requiring these keys in YAML (or centralizing defaulting in a single helper rather than repeating literals here and in other workers).


755-758: Passing sampling_params into loss_fn_ is correct but can be API‑breaking

Using _apply_temperature_scaling before loss computation and threading sampling_params=self.sampling_params into loss_fn_ is consistent with the distributed logprob design and aligns training with inference sampling behavior.

However, this unconditionally adds a new keyword argument to every LossFunction used with this worker. Any custom loss not updated to accept sampling_params (or generic **kwargs) will now error with an unexpected keyword argument. Please confirm all concrete LossFunction implementations used in this worker are updated accordingly, or add a small adapter that only passes sampling_params when the loss advertises support.

Also applies to: 816-822
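One possible shape for such an adapter (a sketch; call_loss_fn and the signature inspection are assumptions, not PR code):

```python
import inspect

def call_loss_fn(loss_fn, *args, sampling_params=None, **kwargs):
    """Forward sampling_params only when the loss function accepts it.

    Losses that declare a sampling_params parameter (or take **kwargs)
    receive it; older custom losses keep working unchanged.
    """
    sig = inspect.signature(loss_fn)
    accepts = "sampling_params" in sig.parameters or any(
        p.kind is inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values()
    )
    if accepts and sampling_params is not None:
        kwargs["sampling_params"] = sampling_params
    return loss_fn(*args, **kwargs)

def legacy_loss(logprobs):                     # no sampling_params parameter
    return -sum(logprobs)

def new_loss(logprobs, sampling_params=None):  # sampling-aware loss
    return -sum(logprobs), sampling_params

print(call_loss_fn(legacy_loss, [-0.1, -0.2], sampling_params="sp"))
print(call_loss_fn(new_loss, [-0.1, -0.2], sampling_params="sp"))
```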


1330-1336: Temperature scaling in score() may be surprising for reward models

score() now always routes logits through _apply_temperature_scaling, which is driven by generation sampling config. For policy models this is arguably fine (you might think of scores under a temperature‑adjusted distribution), but for reward models (self._is_reward_model is true) you generally don’t want reward logits to depend on generation sampling hyperparameters.

If reward models are ever configured with a non‑default generation.temperature, their scores will now be scaled; if that’s not intended, consider guarding temperature scaling with if not self._is_reward_model or making it opt‑in for score().

nemo_rl/models/policy/megatron_policy_worker.py (1)

478-486: Sampling params wiring mirrors DTensor workers; same config‑default caveat

The sampling‑params construction here matches the DTensor worker path and ensures Megatron training/logprob/top‑k flows see the same top_k/top_p/temperature as generation.

As in the DTensor worker, using generation_cfg.get("top_p", 1.0) / get("temperature", 1.0) introduces code‑level defaults rather than treating YAML as the single source of truth for non‑None defaults. If you want tighter config validation, consider either:

  • Requiring these keys in PolicyConfig (and indexing directly), or
  • Centralizing defaulting into a small helper (e.g., build_training_sampling_params(cfg)) used by all workers instead of repeating literals.
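A minimal sketch of the suggested helper (the name comes from this comment; the body and plain-dict return are assumptions, standing in for the real TrainingSamplingParams construction):

```python
def build_training_sampling_params(policy_cfg):
    """Build sampling parameters from the generation config in one place.

    Centralizing the defaults here avoids repeating the same literals in
    the two DTensor workers and the Megatron worker.
    """
    gen = policy_cfg["generation"]
    return {
        "top_k": gen.get("top_k"),          # None means no top-k filtering
        "top_p": gen.get("top_p", 1.0),
        "temperature": gen.get("temperature", 1.0),
    }

params = build_training_sampling_params({"generation": {"top_k": 40}})
print(params)  # {'top_k': 40, 'top_p': 1.0, 'temperature': 1.0}
```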
nemo_rl/models/policy/dtensor_policy_worker.py (3)

216-225: Sampling params init matches v2 worker; consider centralizing + config defaults

This is the same pattern as in DTensorPolicyWorkerV2: constructing TrainingSamplingParams from cfg["generation"] and defaulting top_p / temperature to 1.0 in code. Behavior‑wise this is fine, but it duplicates logic across workers and again introduces non‑None defaults outside YAML.

Consider extracting a small helper (e.g., build_training_sampling_params(policy_cfg)) in policy.utils and using it in both DTensor workers and the Megatron worker, and/or shifting these defaults fully into configuration to keep with the “YAML as single source of truth” guideline.


840-846: LossFunction kwargs: sampling_params threading may affect custom losses

As in the v2 worker, passing sampling_params=self.sampling_params into loss_fn_ is the right direction for built‑in losses, but it does change the expected call signature. Any external or custom LossFunction not updated to accept sampling_params (or generic **kwargs) will now fail.

If you expect third‑party losses here, you may want an adapter layer or a capability flag on LossFunction to make this more graceful; otherwise, documenting that sampling_params is now part of the LossFunction.__call__ contract should be enough.


1348-1350: score() temperature scaling and reward models

Same concern as in the v2 worker: score() now always applies _apply_temperature_scaling, tying reward/final scores to generation sampling temperature when self.sampling_params is set. If reward models or scalar scoring flows should remain independent of generation sampling hyperparameters, consider skipping temperature scaling here for those cases (e.g., if not self._is_reward_model or a dedicated flag).

nemo_rl/models/policy/utils.py (1)

255-281: Type annotation mismatch with callers.

The signature declares top_p: float but callers in loss_functions.py may pass None when sampling_params is None. While need_top_p_filtering handles None, the type annotation suggests float is expected.

Consider updating the type hint for consistency:

 def apply_top_k_top_p(
     logits: torch.Tensor,
     top_k: int | None,
-    top_p: float,
+    top_p: float | None,
 ) -> tuple[torch.Tensor, torch.Tensor | None]:

Or ensure callers always pass 1.0 as the default (preferred, as suggested in loss_functions.py comments).

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between b772e48 and d584519.

📒 Files selected for processing (11)
  • nemo_rl/algorithms/loss_functions.py (16 hunks)
  • nemo_rl/distributed/model_utils.py (15 hunks)
  • nemo_rl/models/generation/vllm/vllm_generation.py (1 hunks)
  • nemo_rl/models/megatron/common.py (5 hunks)
  • nemo_rl/models/policy/dtensor_policy_worker.py (7 hunks)
  • nemo_rl/models/policy/dtensor_policy_worker_v2.py (7 hunks)
  • nemo_rl/models/policy/megatron_policy_worker.py (7 hunks)
  • nemo_rl/models/policy/utils.py (5 hunks)
  • tests/unit/distributed/test_model_utils.py (3 hunks)
  • tests/unit/models/generation/test_vllm_generation.py (1 hunks)
  • tests/unit/models/generation/test_vllm_logprobs_mode.py (1 hunks)
🧰 Additional context used
📓 Path-based instructions (4)
**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.py: Conform code to Python 3.12+
Indent code with 4 spaces. Do not use tabs
Use snake_case for file names
Use PascalCase for class names
Use snake_case for function and method names
Use snake_case for local variables
Prefix variable names that start with a number with 'k' (e.g., k_99th_percentile)
Use upper snake_case with 'G' prefix for global variables (e.g., G_MY_GLOBAL)
Use upper snake_case for constants
Avoid shadowing variables declared in an outer scope
Initialize all externally visible members of a class in the constructor
Prefer docstrings over comments for interfaces that may be used outside a file
Reserve comments for code within a function or interfaces that are local to a file
If a piece of code is commented out, include a comment describing its usage and why it's commented out. Remove debug comments before merging
Use Google style docstrings for classes and functions in Python, which can be parsed by Sphinx
Avoid using reflection when functionality can be easily achieved without reflection
When using try-except blocks, limit the except clause to the smallest set of specific errors possible
When using try-except blocks for duck-typing, keep the body of the try as small as possible and use the else block for logic
YAML is the single source of truth for configuration defaults. Do not set non-None defaults in code for configuration values
For required configuration attributes, access config directly and expect presence (e.g., policy_cfg['precision']) without hidden defaults
Use typing.NotRequired to mark optional attributes in TypedDict for configuration
When adding a new config key to a TypedDict subclass, document the key's purpose, valid values/types, and recommended default, and reflect the default in exemplar YAMLs under examples/configs/*.yaml
Follow the Google Python Style Guide for Python code

Files:

  • nemo_rl/models/generation/vllm/vllm_generation.py
  • nemo_rl/models/megatron/common.py
  • tests/unit/models/generation/test_vllm_logprobs_mode.py
  • tests/unit/distributed/test_model_utils.py
  • tests/unit/models/generation/test_vllm_generation.py
  • nemo_rl/models/policy/utils.py
  • nemo_rl/models/policy/megatron_policy_worker.py
  • nemo_rl/models/policy/dtensor_policy_worker_v2.py
  • nemo_rl/algorithms/loss_functions.py
  • nemo_rl/models/policy/dtensor_policy_worker.py
  • nemo_rl/distributed/model_utils.py
nemo_rl/**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

For any source file under nemo_rl/*.py that defines a class or function decorated with @ray.remote, add a coverage pragma (# pragma: no cover) because these run in separate Ray processes

Files:

  • nemo_rl/models/generation/vllm/vllm_generation.py
  • nemo_rl/models/megatron/common.py
  • nemo_rl/models/policy/utils.py
  • nemo_rl/models/policy/megatron_policy_worker.py
  • nemo_rl/models/policy/dtensor_policy_worker_v2.py
  • nemo_rl/algorithms/loss_functions.py
  • nemo_rl/models/policy/dtensor_policy_worker.py
  • nemo_rl/distributed/model_utils.py
!(**/tests/**|**/test_*.py|**/test_*.sh)

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Add the NVIDIA copyright header to all Python files and shell scripts (excluding tests). The header should include the current year

Files:

  • nemo_rl/models/generation/vllm/vllm_generation.py
  • nemo_rl/models/megatron/common.py
  • tests/unit/models/generation/test_vllm_logprobs_mode.py
  • tests/unit/distributed/test_model_utils.py
  • tests/unit/models/generation/test_vllm_generation.py
  • nemo_rl/models/policy/utils.py
  • nemo_rl/models/policy/megatron_policy_worker.py
  • nemo_rl/models/policy/dtensor_policy_worker_v2.py
  • nemo_rl/algorithms/loss_functions.py
  • nemo_rl/models/policy/dtensor_policy_worker.py
  • nemo_rl/distributed/model_utils.py
**/*.{py,sh}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

The NVIDIA copyright header should appear at the top of all Python files and shell scripts (excluding tests)

Files:

  • nemo_rl/models/generation/vllm/vllm_generation.py
  • nemo_rl/models/megatron/common.py
  • tests/unit/models/generation/test_vllm_logprobs_mode.py
  • tests/unit/distributed/test_model_utils.py
  • tests/unit/models/generation/test_vllm_generation.py
  • nemo_rl/models/policy/utils.py
  • nemo_rl/models/policy/megatron_policy_worker.py
  • nemo_rl/models/policy/dtensor_policy_worker_v2.py
  • nemo_rl/algorithms/loss_functions.py
  • nemo_rl/models/policy/dtensor_policy_worker.py
  • nemo_rl/distributed/model_utils.py
🧠 Learnings (2)
📚 Learning: 2025-09-19T03:00:58.662Z
Learnt from: shuo-nvidia
Repo: NVIDIA-NeMo/RL PR: 1006
File: examples/configs/recipes/llm/distillation-qwen3-32b-to-1.7b-base-1n8g-fsdp2tp1.v1.yaml:85-101
Timestamp: 2025-09-19T03:00:58.662Z
Learning: In distillation and GRPO configurations, max_new_tokens is intentionally set to the full context window (max_total_sequence_length) for consistency across the codebase. Overflow cases when prompt + generation tokens exceed max_model_len are handled by safeguards implemented in vllm_worker.py.

Applied to files:

  • nemo_rl/models/generation/vllm/vllm_generation.py
📚 Learning: 2025-09-19T03:19:35.875Z
Learnt from: shuo-nvidia
Repo: NVIDIA-NeMo/RL PR: 1006
File: tests/unit/algorithms/test_loss_functions.py:1580-1619
Timestamp: 2025-09-19T03:19:35.875Z
Learning: The DistillationLossFn in nemo_rl/algorithms/loss_functions.py does not have k truncation logic - it processes whatever topk size is provided without capping it to vocabulary size or other limits. Large k values in tests will create correspondingly large GPU tensors.

Applied to files:

  • nemo_rl/models/generation/vllm/vllm_generation.py
  • nemo_rl/algorithms/loss_functions.py
🧬 Code graph analysis (6)
nemo_rl/models/megatron/common.py (1)
nemo_rl/models/policy/utils.py (1)
  • TrainingSamplingParams (96-110)
tests/unit/models/generation/test_vllm_logprobs_mode.py (1)
nemo_rl/models/policy/utils.py (1)
  • apply_top_k_top_p (255-281)
tests/unit/distributed/test_model_utils.py (5)
nemo_rl/distributed/model_utils.py (9)
  • DistributedLogprob (66-147)
  • DistributedLogprobWithSampling (265-392)
  • backward (108-147)
  • backward (217-262)
  • backward (358-392)
  • backward (526-608)
  • backward (670-734)
  • backward (1185-1214)
  • backward (1468-1498)
nemo_rl/models/policy/utils.py (2)
  • apply_top_k_top_p (255-281)
  • backward (243-252)
nemo_rl/distributed/virtual_cluster.py (2)
  • PY_EXECUTABLES (43-59)
  • RayVirtualCluster (186-505)
nemo_rl/distributed/named_sharding.py (3)
  • NamedSharding (19-222)
  • layout (99-101)
  • names (84-86)
nemo_rl/distributed/worker_groups.py (2)
  • RayWorkerBuilder (131-301)
  • RayWorkerGroup (304-1031)
nemo_rl/algorithms/loss_functions.py (2)
nemo_rl/models/policy/utils.py (4)
  • apply_top_k_top_p (255-281)
  • TrainingSamplingParams (96-110)
  • need_top_k_filtering (85-87)
  • need_top_p_filtering (90-92)
nemo_rl/distributed/model_utils.py (3)
  • from_parallel_logits_to_logprobs (851-952)
  • gather_logits_at_global_indices (1341-1424)
  • get_logprobs_from_vocab_parallel_logits (1217-1267)
nemo_rl/models/policy/dtensor_policy_worker.py (1)
nemo_rl/models/policy/utils.py (4)
  • TrainingSamplingParams (96-110)
  • apply_top_k_top_p (255-281)
  • need_top_k_filtering (85-87)
  • need_top_p_filtering (90-92)
nemo_rl/distributed/model_utils.py (1)
nemo_rl/models/policy/utils.py (3)
  • TrainingSamplingParams (96-110)
  • need_top_k_filtering (85-87)
  • need_top_p_filtering (90-92)
🪛 Ruff (0.14.6)
tests/unit/distributed/test_model_utils.py

1058-1058: Do not assert False (python -O removes these calls), raise AssertionError()

Replace assert False

(B011)

tests/unit/models/generation/test_vllm_generation.py

2603-2603: Avoid specifying long messages outside the exception class

(TRY003)


2668-2671: zip() without an explicit strict= parameter

Add explicit value for parameter strict=

(B905)


2729-2732: zip() without an explicit strict= parameter

Add explicit value for parameter strict=

(B905)

nemo_rl/algorithms/loss_functions.py

1014-1016: Avoid specifying long messages outside the exception class

(TRY003)

nemo_rl/distributed/model_utils.py

301-304: Avoid specifying long messages outside the exception class

(TRY003)


432-435: Avoid specifying long messages outside the exception class

(TRY003)


444-447: Avoid specifying long messages outside the exception class

(TRY003)


1523-1526: Avoid specifying long messages outside the exception class

(TRY003)


1532-1535: Avoid specifying long messages outside the exception class

(TRY003)


1573-1576: Avoid specifying long messages outside the exception class

(TRY003)


1582-1584: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: Coverage (e2e)
  • GitHub Check: Coverage (doc-test)
  • GitHub Check: Post submodule check comment / Comment on PR
🔇 Additional comments (34)
tests/unit/models/generation/test_vllm_logprobs_mode.py (1)

258-258: LGTM!

The tuple unpacking correctly matches the updated apply_top_k_top_p API that now returns (filtered_logits, keep_mask). Discarding the mask with _ is appropriate here since the test only needs the filtered logits for comparison with vLLM's upstream implementation.

tests/unit/distributed/test_model_utils.py (6)

23-25: LGTM!

New imports correctly align with the sampling-enabled distributed logprob utilities introduced in this PR.


39-41: LGTM!

Import of apply_top_k_top_p from the policy utils module is appropriate for testing the filtering functionality.


966-1092: Well-structured test for top-k/top-p filtering.

The test_top_k_top_p_filtering_forward_backward method thoroughly validates:

  1. Top-k filtering preserves correct tokens
  2. Top-p filtering respects cumulative probability threshold
  3. No-filtering case passes through unchanged
  4. Filtered logits produce valid probabilities
  5. Backward pass correctly masks gradients

Good coverage of the filtering logic and gradient propagation.


1093-1200: Comprehensive distributed logprob test with sampling.

The test_distributed_logprob_with_sampling method correctly validates:

  1. Forward pass matches expected computation with top-k/top-p filtering
  2. Backward pass gradients match expected values
  3. Both chunked and non-chunked variants are tested

The test appropriately computes the expected values using full logits and validates the distributed implementation produces equivalent results.


1228-1277: Good parametrized test coverage.

The test matrix covers:

  • tp_size: [1, 2] - validates single-GPU and distributed cases
  • top_k, top_p: Various combinations including no filtering, top-k only, top-p only, and combined

This provides good coverage for the filtering functionality across different parallelism configurations.


1280-1344: Comprehensive distributed sampling test.

Good test coverage for DistributedLogprobWithSampling and ChunkedDistributedLogprobWithSampling with:

  • Multiple sampling parameter combinations
  • Chunked and non-chunked modes
  • Validation of both forward and backward passes
tests/unit/models/generation/test_vllm_generation.py (3)

2504-2523: Excellent test for validating sampling parameter consistency.

This test comprehensively validates that policy workers produce logprobs matching vLLM when using the same sampling parameters. The parametrization covers various configurations:

  • Temperature scaling only
  • Top-p only
  • Top-k only
  • Combined sampling

Testing both dtensor and megatron policy types with TP=1 and TP=2 provides good coverage.


2739-2773: Well-designed KL validation approach.

Using reference_policy_kl_penalty=1.0 with ClippedPGLossFn is a clever way to validate that the train path computes logprobs correctly with sampling parameters. If the logprobs match vLLM's generation logprobs, the KL should be near zero. The assertion threshold of 1e-4 is appropriately strict.
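For intuition on why the threshold is reasonable: one common per-token KL penalty estimator (the "k3" form used in several RLHF codebases; whether ClippedPGLossFn uses exactly this form is an assumption here) is non-negative and exactly zero when the two logprobs agree:

```python
import math

def kl_penalty_k3(logprob_cur: float, logprob_ref: float) -> float:
    # k3 estimator: exp(d) - d - 1 with d = ref - cur. Non-negative,
    # ~ d^2 / 2 for small d, and exactly 0 when the logprobs match.
    d = logprob_ref - logprob_cur
    return math.exp(d) - d - 1.0
```

Under this estimator, a penalty below 1e-4 corresponds to per-token logprob disagreement on the order of 1e-2 or less.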


2789-2799: Good validation of KL penalty across microbatches.

The loop checking each KL penalty value in the list ensures all microbatches have consistent logprob computation. The detailed error message aids debugging if the assertion fails.

nemo_rl/models/megatron/common.py (4)

34-34: LGTM!

Clean import of the new TrainingSamplingParams dataclass that encapsulates sampling configuration.


264-278: Good API improvement.

Replacing the generic policy_cfg: Optional[dict] with the specific TrainingSamplingParams dataclass improves type safety and makes the function signature clearer. The updated docstring correctly documents the parameter's purpose.


368-369: LGTM: Correct temperature scaling placement.

Temperature scaling is correctly applied to the logits before loss computation. The in-place division (div_) is appropriate here since we want to modify the output tensor that will be used for subsequent loss calculation. The guard condition correctly skips scaling when temperature is 1.0.
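The pattern is simply the following (a sketch; the actual guard lives in forward_step_arbitrary_loss):

```python
import torch

def scale_logits_(logits: torch.Tensor, temperature: float) -> torch.Tensor:
    # In-place temperature scaling on the output tensor, skipping the
    # common temperature == 1.0 case to avoid a needless kernel launch.
    if temperature != 1.0:
        logits.div_(temperature)
    return logits
```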


390-390: Correct propagation of sampling_params to loss function.

The sampling_params are correctly passed to the loss function wrapper, enabling top-k/top-p filtering during distributed logprob computation in the loss function.

nemo_rl/models/policy/dtensor_policy_worker_v2.py (3)

473-489: Helper methods for temperature and top‑k/top‑p filtering look correct

Centralizing temperature scaling in _apply_temperature_scaling and local full‑vocab filtering in _apply_top_k_top_p_filtering is clean. In‑place div_ on logits is fine here (including DTensor) and apply_top_k_top_p is only used when either need_top_k_filtering or need_top_p_filtering is true, so no extra overhead in the default case. No issues from a correctness or style perspective.


1071-1073: Logprob path: sampling integration across DTensor and non‑DTensor looks sound

The changes in get_logprobs correctly:

  • Apply temperature scaling once at the logits level via _apply_temperature_scaling.
  • Pass sampling_params into get_logprobs_from_vocab_parallel_logits for TP/CP/DTensor paths.
  • For the non‑DTensor path, apply _apply_top_k_top_p_filtering on full‑vocab logits (or chunked along the sequence dimension only), then compute log_softmax and gather token logprobs.

This respects the requirement that top‑k/top‑p operate on full vocabulary logits while keeping chunking limited to the sequence axis, so semantics are preserved. I don’t see numerical or shape issues here; just ensure get_logprobs_from_vocab_parallel_logits itself does not also apply temperature scaling to avoid double‑dividing.

Also applies to: 1111-1118, 1121-1128, 1131-1160


1484-1486: Temperature scaling before top‑k logits is consistent with training

Applying _apply_temperature_scaling before running distributed/local top‑k keeps this utility aligned with the temperature‑scaled training/logprob behavior without changing the meaning of k or the indices returned. This looks correct and requires no additional changes.

nemo_rl/models/policy/megatron_policy_worker.py (4)

128-133: Importing TrainingSamplingParams here is appropriate

Bringing TrainingSamplingParams into the Megatron worker keeps the sampling configuration plumbing symmetric with the DTensor workers and centralizes sampling‑related types in policy.utils. No concerns.


993-997: forward_step_arbitrary_loss now depends on sampling_params

Passing sampling_params=self.sampling_params into forward_step_arbitrary_loss is the right way to thread sampling behavior into the Megatron training step and keep it aligned with the DTensor implementation.

This does, however, hard‑require that forward_step_arbitrary_loss (and any intermediate wrappers it calls) accept a sampling_params keyword. Please double‑check that all call sites and the function signature in nemo_rl.models.megatron.common are updated accordingly so this doesn’t break older code paths.


1316-1321: Megatron logprob path: temperature and sampling propagation look coherent

Inside get_logprobs’s forward_step_fn:

  • Applying output_tensor.div_(self.sampling_params.temperature) when a non‑unit temperature is set matches the DTensor worker behavior and ensures TP‑sharded logits are already temperature‑scaled.
  • Passing sampling_params=self.sampling_params down to both from_parallel_logits_to_logprobs_packed_sequences and from_parallel_logits_to_logprobs lets those helpers apply top‑k/top‑p masking on full‑vocab probabilities while this function remains agnostic to TP/CP details.

This wiring looks correct; just ensure the from_parallel_logits_to_logprobs* helpers themselves do not additionally apply temperature scaling when sampling_params is provided, otherwise you’d effectively divide by temperature twice here.

Also applies to: 1327-1351


1602-1608: Temperature scaling in Megatron get_topk_logits

Scaling output_tensor by 1 / temperature before calling distributed_vocab_topk mirrors the training/logprob behavior and keeps top‑k logits consistent with the effective sampling distribution. Since top‑k operates along the vocab dimension only, this is safe and does not alter the semantics of k or the returned indices.

nemo_rl/models/policy/dtensor_policy_worker.py (1)

1520-1522: Temperature scaling before top‑k logits (DTensor worker)

Scaling logits via _apply_temperature_scaling before the top‑k computation is consistent with both the Megatron and v2 DTensor workers and matches the intended “train under the same effective sampling distribution” behavior. Implementation is straightforward and correct.

nemo_rl/algorithms/loss_functions.py (2)

27-36: LGTM on new imports.

The imports for apply_top_k_top_p, TrainingSamplingParams, need_top_k_filtering, and need_top_p_filtering are appropriately added to support the new sampling functionality.


279-282: Correct NaN prevention with masking.

Good defensive fix to handle cases where top-k/p filtering masks out padding tokens, resulting in -inf logprobs that would otherwise produce NaN when multiplied by zero masks.
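The failure mode is easy to reproduce; below is a hypothetical minimal fix illustrating the idea (the PR's exact fix may differ):

```python
import torch

# After top-k/top-p filtering, a padding position can end up with logprob
# -inf; multiplying by its zero token mask then yields NaN instead of 0.
logprobs = torch.tensor([-1.2, float("-inf")])
token_mask = torch.tensor([1.0, 0.0])

naive = logprobs * token_mask  # second element is -inf * 0 == nan
safe = torch.where(token_mask > 0, logprobs, torch.zeros_like(logprobs))
masked = safe * token_mask     # nan-free
```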

nemo_rl/models/policy/utils.py (4)

85-92: LGTM on filtering check helpers.

Clean and simple helper functions that correctly identify when filtering should be applied. The logic properly handles None, -1 for top-k, and 1.0 for top-p as "no filtering" cases.
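Under that convention the helpers reduce to a couple of lines (a sketch; the names follow the review, the bodies are assumed):

```python
from typing import Optional

def need_top_k_filtering(top_k: Optional[int]) -> bool:
    # None or -1 both mean "no top-k filtering".
    return top_k is not None and top_k != -1

def need_top_p_filtering(top_p: Optional[float]) -> bool:
    # None or 1.0 keep the full distribution.
    return top_p is not None and top_p < 1.0
```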


95-110: LGTM on TrainingSamplingParams dataclass.

Well-documented dataclass with appropriate defaults that disable filtering by default. The design avoids vLLM dependency while maintaining API consistency.
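A sketch of what such a dataclass looks like (field names follow the PR description; the exact defaults and docstrings are assumptions):

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class TrainingSamplingParams:
    """Inference-style sampling parameters for training-side logprob
    computation, kept as a plain dataclass so training code needs no
    vLLM import."""
    temperature: float = 1.0       # 1.0 -> no scaling
    top_k: Optional[int] = None    # None / -1 -> no top-k filtering
    top_p: Optional[float] = None  # None / 1.0 -> no top-p filtering
```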


113-152: LGTM on top-k only filtering function.

Efficient implementation that avoids full vocabulary sorting when only top-k is needed. The threshold-based masking is correct and the keep_mask return enables proper gradient handling.
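The threshold trick avoids sorting the whole vocabulary; a sketch of the idea:

```python
import torch

def top_k_only(logits: torch.Tensor, k: int):
    # Take the k-th largest logit as a threshold and mask everything
    # strictly below it. Ties at the threshold may keep more than k tokens.
    kth = torch.topk(logits, k, dim=-1).values[..., -1:]
    keep_mask = logits >= kth
    return logits.masked_fill(~keep_mask, float("-inf")), keep_mask
```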


224-252: LGTM on ApplyTopKTopP autograd function structure.

The autograd function correctly saves the keep_mask during forward and applies it during backward to zero out gradients for filtered tokens. However, the correctness depends on fixing the keep_mask ordering issue in _apply_top_k_top_p_fn.
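The save-mask-in-forward / zero-grad-in-backward pattern, as a standalone sketch (not the PR's implementation):

```python
import torch

class ApplyTopKTopPSketch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, logits, keep_mask):
        # Stash the boolean mask; filtered positions become -inf.
        ctx.save_for_backward(keep_mask)
        return logits.masked_fill(~keep_mask, float("-inf"))

    @staticmethod
    def backward(ctx, grad_output):
        # Filtered tokens get exactly zero gradient; the mask itself
        # is non-differentiable, hence the None.
        (keep_mask,) = ctx.saved_tensors
        return grad_output * keep_mask, None
```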

nemo_rl/distributed/model_utils.py (6)

20-25: LGTM on new imports.

The imports correctly bring in the sampling utilities needed for the distributed implementation.


265-355: LGTM on DistributedLogprobWithSampling forward pass.

The implementation correctly:

  1. Validates divisibility constraints
  2. Converts from vocab-parallel to batch-sequence-parallel layout via all-to-all
  3. Applies top-k/top-p filtering on the materialized full vocabulary
  4. Computes log probabilities and gathers results across ranks

The layout transformation is necessary because top-p filtering requires access to the full vocabulary distribution.
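The layout change can be visualized without any communication; a single-process illustration of the shard bookkeeping (the real code uses all_to_all; W, BS, V here are tiny hypothetical sizes):

```python
import torch

W = 2             # tensor-parallel world size
BS, V = 4, 6      # flattened batch*seq length and full vocab size
full = torch.arange(BS * V, dtype=torch.float32).reshape(BS, V)

# Vocab-parallel layout: rank r holds all rows but only columns
# [r*V/W, (r+1)*V/W) of the logits.
vp_shards = list(full.chunk(W, dim=1))           # W tensors of shape (BS, V/W)

# After the all-to-all (vp2sq): rank r holds rows [r*BS/W, (r+1)*BS/W)
# with the FULL vocabulary, so top-p can see the whole distribution.
sq_shards = [
    torch.cat([s[r * BS // W:(r + 1) * BS // W] for s in vp_shards], dim=1)
    for r in range(W)
]                                                # W tensors of shape (BS/W, V)

# Round trip: stacking the sequence shards recovers the full matrix.
assert torch.equal(torch.cat(sq_shards, dim=0), full)
```

This also makes the divisibility constraints obvious: BS must divide evenly by W for the row split, and V by W for the column split.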


357-392: LGTM on backward pass structure.

The backward pass correctly:

  1. Extracts the local portion of gradients
  2. Computes the standard cross-entropy gradient (one_hot - softmax)
  3. Applies the keep_mask to zero gradients for filtered tokens
  4. Converts back to vocab-parallel layout

However, this depends on the keep_mask bug fix in _apply_top_k_top_p_fn.
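The gradient identity in step 2 above is the standard cross-entropy result and is easy to check numerically (for the negative logprob the sign flips to softmax - one_hot):

```python
import torch

# loss = -log_softmax(x)[t]  =>  d loss / d x = softmax(x) - one_hot(t),
# i.e. (one_hot - softmax) is the gradient of the token logprob itself.
x = torch.tensor([0.5, 1.5, -0.3], requires_grad=True)
t = 1
loss = -torch.log_softmax(x, dim=-1)[t]
loss.backward()

expected = torch.softmax(x.detach(), dim=-1)
expected[t] -= 1.0
assert torch.allclose(x.grad, expected, atol=1e-6)
```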


509-522: Memory consideration: saving full logits for rematerialization.

The chunked version saves the full vocab_parallel_logits tensor for backward rematerialization. While this is correct for the gradient computation, it may use significant memory for large vocabulary sizes. The non-chunked version saves just the softmax output, which is smaller when vocab > batch*seq.

This is an acceptable trade-off for chunked processing where the goal is to reduce peak memory during computation.


1501-1601: LGTM on all-to-all layout transformation helpers.

Well-documented functions with clear einops-style comments explaining the transformations:

  • vp2sq: (BS, V_local) -> (BS_local, V) - materializes full vocabulary
  • sq2vp: (BS_local, V) -> (BS, V_local) - inverse operation

The divisibility checks and error messages are appropriate.


794-840: LGTM on sampling path integration.

The routing logic correctly:

  1. Checks if top-k or top-p filtering is needed
  2. Routes to the appropriate sampling-enabled or standard autograd function
  3. Maintains backward compatibility when no sampling is configured

The TODO comment about optimizing top-k separately is a reasonable future enhancement.

Signed-off-by: Zhanda <zhandazhu@gmail.com>
@github-actions

ℹ️ File Consistency Check

Check based on commit: f46d960 (PR #1578 from zhanda/top-p-k-with-tests)

✅ DTensor Policy Worker Synchronization Check

Both DTensor policy worker files were modified in this PR:

  • nemo_rl/models/policy/dtensor_policy_worker.py
  • nemo_rl/models/policy/dtensor_policy_worker_v2.py

Please ensure that the changes are consistent between both files where applicable.


This check ensures that related file implementations remain synchronized across the codebase. If you believe this warning is incorrect or the files should intentionally differ, please add a comment explaining the reasoning.

@zhandaz added and re-removed the CI:L1 label (Run doctests, unit tests, and functional tests) on Nov 28, 2025
Signed-off-by: Zhanda <zhandazhu@gmail.com>
@github-actions

github-actions Bot commented Dec 1, 2025

ℹ️ File Consistency Check (commit 452b2d9, PR #1578): same DTensor Policy Worker Synchronization notice as above.

@zhandaz added and re-removed the CI:L1 label (Run doctests, unit tests, and functional tests) on Dec 1, 2025
Signed-off-by: Zhanda <zhandazhu@gmail.com>
@zhandaz added and re-removed the CI:L1 label (Run doctests, unit tests, and functional tests) on Dec 1, 2025
@github-actions

github-actions Bot commented Dec 1, 2025

ℹ️ File Consistency Check (commit d76ed60, PR #1578): same DTensor Policy Worker Synchronization notice as above.

@terrykong (Collaborator) left a comment:

awesome work @zhandaz !

@gshennvm can you review?

@zhandaz can you include convergence plots? maybe at least small qwen out to a few hundred steps to show no collapse and low gen_kl_error? i think we should have at least one nightly test that tests T!=1 & non trivial top-p | top-k, maybe a variant of the qwen30b recipe with just those things changed so that it's comparable?

@terrykong linked an issue on Dec 4, 2025 that may be closed by this pull request
@zhandaz (author) commented Dec 4, 2025:

Sure. I will add that in this PR as well.

zhandaz added a commit that referenced this pull request Feb 12, 2026
@zhandaz self-assigned this Feb 12, 2026
@zhandaz mentioned this pull request Feb 12, 2026
zhandaz added a commit that referenced this pull request Feb 12, 2026
Signed-off-by: Zhanda <zhandazhu@gmail.com>
yuki-97 pushed a commit that referenced this pull request Mar 3, 2026
(not cp loss_function.py common.py megatron_policy_worker.py)

Signed-off-by: Zhanda <zhandazhu@gmail.com>
yuki-97 pushed further commits referencing this pull request on Mar 3, Mar 8, and Mar 10, 2026, each with the same cherry-pick note and sign-off.

Labels

CI:L1 Run doctests, unit tests, and functional tests


Development

Successfully merging this pull request may close these issues.

Top-p/Top-k Sampling Params handling in VLLM v1
