Conversation
Signed-off-by: Zhanda <zhandazhu@gmail.com>
ℹ️ File Consistency Check (based on commit f45b990, PR #1578): ✅ DTensor Policy Worker Synchronization Check. Both DTensor policy worker files were modified in this PR.
Please ensure that the changes are consistent between both files where applicable. This check ensures that related file implementations remain synchronized across the codebase. If you believe this warning is incorrect or the files should intentionally differ, please add a comment explaining the reasoning.
📝 Walkthrough

This PR integrates top-k/top-p generation sampling filtering throughout NeMo RL's training pipeline.
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
Actionable comments posted: 8
♻️ Duplicate comments (2)
nemo_rl/models/policy/dtensor_policy_worker.py (2)
496-512: Shared temperature and top‑k/top‑p helpers look good

`_apply_temperature_scaling` and `_apply_top_k_top_p_filtering` here are identical in spirit to the v2 worker and correctly guard all work on `self.sampling_params` and `need_*_filtering`. Using them as the only sampling‑aware touchpoints in the worker simplifies reasoning about behavior. No functional issues.
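For illustration, the guard pattern these helpers follow can be sketched in pure Python (a simplified stand-in, not the repo's torch code; the class and method names here are hypothetical, and list-based logits stand in for tensors):

```python
class SamplingHelpers:
    """Simplified stand-in for the worker's sampling-aware helpers."""

    def __init__(self, temperature=1.0, top_k=None, top_p=1.0):
        self.temperature = temperature
        self.top_k = top_k
        self.top_p = top_p

    def apply_temperature_scaling(self, logits):
        # No-op when temperature is 1.0, mirroring the guard described above.
        if self.temperature == 1.0:
            return logits
        return [x / self.temperature for x in logits]

    def apply_top_k_top_p_filtering(self, logits):
        # Only do work when filtering is actually requested
        # (top-p handling is omitted in this sketch for brevity).
        need_top_k = self.top_k is not None and self.top_k != -1
        need_top_p = self.top_p is not None and self.top_p < 1.0
        if not (need_top_k or need_top_p):
            return logits
        if need_top_k:
            threshold = sorted(logits, reverse=True)[self.top_k - 1]
            logits = [x if x >= threshold else float("-inf") for x in logits]
        return logits


helpers = SamplingHelpers(temperature=2.0, top_k=2)
scaled = helpers.apply_temperature_scaling([4.0, 2.0, 0.0])
filtered = helpers.apply_top_k_top_p_filtering(scaled)
print(scaled)    # [2.0, 1.0, 0.0]
print(filtered)  # [2.0, 1.0, -inf]
```

The point of the pattern is that the default configuration (temperature 1.0, no top-k/top-p) incurs no extra work at all.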
1089-1091: DTensor get_logprobs: sampling integration mirrors v2 and looks correct

The `get_logprobs` changes closely mirror `DTensorPolicyWorkerV2`:

- Temperature scaling via `_apply_temperature_scaling(logits)` before any TP/CP reshaping.
- Passing `sampling_params` to `get_logprobs_from_vocab_parallel_logits` for both CP and non‑CP DTensor paths.
- Using `_apply_top_k_top_p_filtering` only in the non‑DTensor full‑vocab path (optionally chunked along the sequence axis), followed by `log_softmax` and standard gathering of next‑token logprobs.

This is consistent with the distributed sampling utilities; just as with v2, please confirm `get_logprobs_from_vocab_parallel_logits` doesn't also apply temperature scaling internally when `sampling_params` is present.

Also applies to: 1129-1135, 1140-1145, 1147-1176
🧹 Nitpick comments (10)
tests/unit/models/generation/test_vllm_generation.py (2)
2667-2674: Consider adding `strict=True` to `zip()` for defensive programming.

Adding `strict=True` will raise a `ValueError` if the iterables have different lengths, which helps catch bugs early if there's a length mismatch between `input_lengths` and `unpadded_sequence_lengths`.

```diff
 padding_mask = torch.zeros_like(generation_results["logprobs"], dtype=torch.bool)
 for i, (input_len, total_valid_len) in enumerate(
     zip(
         test_input_data.get("input_lengths"),
         generation_results["unpadded_sequence_lengths"],
+        strict=True,
     )
 ):
```
2728-2735: Consider adding `strict=True` to `zip()` here as well.

Same recommendation as above for consistency and defensive programming.

```diff
 for idx, (input_len, total_len) in enumerate(
     zip(
         test_input_data.get("input_lengths"),
         generation_results["unpadded_sequence_lengths"],
+        strict=True,
     )
 ):
```

nemo_rl/models/policy/dtensor_policy_worker_v2.py (3)
177-185: Sampling params initialization and config defaults

Wiring `TrainingSamplingParams` from `cfg["generation"]` is coherent with the rest of the PR and keeps all sampling-related state in one object. One thing to watch: using `generation_cfg.get("top_p", 1.0)` / `get("temperature", 1.0)` introduces non-None defaults in code instead of treating these as required or YAML-sourced defaults, which slightly diverges from the stated "YAML is the single source of truth" guideline. If you want stricter config behavior, consider requiring these keys in YAML (or centralizing defaulting in a single helper rather than repeating literals here and in other workers).
755-758: Passing sampling_params into loss_fn_ is correct but can be API‑breaking

Using `_apply_temperature_scaling` before loss computation and threading `sampling_params=self.sampling_params` into `loss_fn_` is consistent with the distributed logprob design and aligns training with inference sampling behavior.

However, this unconditionally adds a new keyword argument to every `LossFunction` used with this worker. Any custom loss not updated to accept `sampling_params` (or generic `**kwargs`) will now error with an unexpected keyword argument. Please confirm all concrete `LossFunction` implementations used in this worker are updated accordingly, or add a small adapter that only passes `sampling_params` when the loss advertises support.

Also applies to: 816-822
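One possible shape for such an adapter (a hypothetical helper, not code from this PR; it probes the loss signature with `inspect` and forwards `sampling_params` only when the callee can accept it):

```python
import inspect


def call_loss_fn(loss_fn, *args, sampling_params=None, **kwargs):
    """Pass sampling_params only if the loss function can accept it.

    Avoids 'unexpected keyword argument' errors for custom losses that
    predate the sampling_params contract.
    """
    sig = inspect.signature(loss_fn)
    accepts_var_kwargs = any(
        p.kind is inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values()
    )
    if "sampling_params" in sig.parameters or accepts_var_kwargs:
        kwargs["sampling_params"] = sampling_params
    return loss_fn(*args, **kwargs)


def legacy_loss(logits, mask):  # no sampling_params in its signature
    return sum(x * m for x, m in zip(logits, mask))


def new_loss(logits, mask, sampling_params=None):
    return sum(x * m for x, m in zip(logits, mask))


print(call_loss_fn(legacy_loss, [1.0, 2.0], [1, 0], sampling_params="sp"))  # 1.0
print(call_loss_fn(new_loss, [1.0, 2.0], [1, 0], sampling_params="sp"))     # 1.0
```

The trade-off of such an adapter is that signature probing is implicit; documenting `sampling_params` as part of the contract is the simpler alternative.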
1330-1336: Temperature scaling in score() may be surprising for reward models
`score()` now always routes logits through `_apply_temperature_scaling`, which is driven by `generation` sampling config. For policy models this is arguably fine (you might think of scores under a temperature‑adjusted distribution), but for reward models (when `self._is_reward_model` is true) you generally don't want reward logits to depend on generation sampling hyperparameters.

If reward models are ever configured with a non‑default `generation.temperature`, their scores will now be scaled; if that's not intended, consider guarding temperature scaling with `if not self._is_reward_model` or making it opt‑in for `score()`.

nemo_rl/models/policy/megatron_policy_worker.py (1)
478-486: Sampling params wiring mirrors DTensor workers; same config‑default caveat

The sampling‑params construction here matches the DTensor worker path and ensures Megatron training/logprob/top‑k flows see the same `top_k`/`top_p`/`temperature` as generation.

As in the DTensor worker, using `generation_cfg.get("top_p", 1.0)` / `get("temperature", 1.0)` introduces code‑level defaults rather than treating YAML as the single source of truth for non‑None defaults. If you want tighter config validation, consider either:

- Requiring these keys in `PolicyConfig` (and indexing directly), or
- Centralizing defaulting into a small helper (e.g., `build_training_sampling_params(cfg)`) used by all workers instead of repeating literals.

nemo_rl/models/policy/dtensor_policy_worker.py (3)
216-225: Sampling params init matches v2 worker; consider centralizing + config defaults

This is the same pattern as in `DTensorPolicyWorkerV2`: constructing `TrainingSamplingParams` from `cfg["generation"]` and defaulting `top_p`/`temperature` to `1.0` in code. Behavior‑wise this is fine, but it duplicates logic across workers and again introduces non‑None defaults outside YAML.

Consider extracting a small helper (e.g., `build_training_sampling_params(policy_cfg)`) in `policy.utils` and using it in both DTensor workers and the Megatron worker, and/or shifting these defaults fully into configuration to keep with the "YAML as single source of truth" guideline.
840-846: LossFunction kwargs: sampling_params threading may affect custom losses

As in the v2 worker, passing `sampling_params=self.sampling_params` into `loss_fn_` is the right direction for built‑in losses, but it does change the expected call signature. Any external or custom `LossFunction` not updated to accept `sampling_params` (or generic `**kwargs`) will now fail.

If you expect third‑party losses here, you may want an adapter layer or a capability flag on `LossFunction` to make this more graceful; otherwise, documenting that `sampling_params` is now part of the `LossFunction.__call__` contract should be enough.
1348-1350: score() temperature scaling and reward models

Same concern as in the v2 worker: `score()` now always applies `_apply_temperature_scaling`, tying reward/final scores to generation sampling temperature when `self.sampling_params` is set. If reward models or scalar scoring flows should remain independent of generation sampling hyperparameters, consider skipping temperature scaling here for those cases (e.g., `if not self._is_reward_model` or a dedicated flag).

nemo_rl/models/policy/utils.py (1)
255-281: Type annotation mismatch with callers.

The signature declares `top_p: float`, but callers in `loss_functions.py` may pass `None` when `sampling_params` is `None`. While `need_top_p_filtering` handles `None`, the type annotation suggests `float` is expected.

Consider updating the type hint for consistency:

```diff
 def apply_top_k_top_p(
     logits: torch.Tensor,
     top_k: int | None,
-    top_p: float,
+    top_p: float | None,
 ) -> tuple[torch.Tensor, torch.Tensor | None]:
```

Or ensure callers always pass `1.0` as the default (preferred, as suggested in the loss_functions.py comments).
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (11)

- nemo_rl/algorithms/loss_functions.py (16 hunks)
- nemo_rl/distributed/model_utils.py (15 hunks)
- nemo_rl/models/generation/vllm/vllm_generation.py (1 hunk)
- nemo_rl/models/megatron/common.py (5 hunks)
- nemo_rl/models/policy/dtensor_policy_worker.py (7 hunks)
- nemo_rl/models/policy/dtensor_policy_worker_v2.py (7 hunks)
- nemo_rl/models/policy/megatron_policy_worker.py (7 hunks)
- nemo_rl/models/policy/utils.py (5 hunks)
- tests/unit/distributed/test_model_utils.py (3 hunks)
- tests/unit/models/generation/test_vllm_generation.py (1 hunk)
- tests/unit/models/generation/test_vllm_logprobs_mode.py (1 hunk)
🧰 Additional context used
📓 Path-based instructions (4)
**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
**/*.py: Conform code to Python 3.12+
Indent code with 4 spaces. Do not use tabs
Use snake_case for file names
Use PascalCase for class names
Use snake_case for function and method names
Use snake_case for local variables
Prefix variable names that start with a number with 'k' (e.g., k_99th_percentile)
Use upper snake_case with 'G' prefix for global variables (e.g., G_MY_GLOBAL)
Use upper snake_case for constants
Avoid shadowing variables declared in an outer scope
Initialize all externally visible members of a class in the constructor
Prefer docstrings over comments for interfaces that may be used outside a file
Reserve comments for code within a function or interfaces that are local to a file
If a piece of code is commented out, include a comment describing its usage and why it's commented out. Remove debug comments before merging
Use Google style docstrings for classes and functions in Python, which can be parsed by Sphinx
Avoid using reflection when functionality can be easily achieved without reflection
When using try-except blocks, limit the except clause to the smallest set of specific errors possible
When using try-except blocks for duck-typing, keep the body of the try as small as possible and use the else block for logic
YAML is the single source of truth for configuration defaults. Do not set non-None defaults in code for configuration values
For required configuration attributes, access config directly and expect presence (e.g., policy_cfg['precision']) without hidden defaults
Use typing.NotRequired to mark optional attributes in TypedDict for configuration
When adding a new config key to a TypedDict subclass, document the key's purpose, valid values/types, and recommended default, and reflect the default in exemplar YAMLs under examples/configs/*.yaml
Follow the Google Python Style Guide for Python code
Files:
- nemo_rl/models/generation/vllm/vllm_generation.py
- nemo_rl/models/megatron/common.py
- tests/unit/models/generation/test_vllm_logprobs_mode.py
- tests/unit/distributed/test_model_utils.py
- tests/unit/models/generation/test_vllm_generation.py
- nemo_rl/models/policy/utils.py
- nemo_rl/models/policy/megatron_policy_worker.py
- nemo_rl/models/policy/dtensor_policy_worker_v2.py
- nemo_rl/algorithms/loss_functions.py
- nemo_rl/models/policy/dtensor_policy_worker.py
- nemo_rl/distributed/model_utils.py
nemo_rl/**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
For any source file under nemo_rl/*.py that defines a class or function decorated with @ray.remote, add a coverage pragma (# pragma: no cover) because these run in separate Ray processes
Files:
- nemo_rl/models/generation/vllm/vllm_generation.py
- nemo_rl/models/megatron/common.py
- nemo_rl/models/policy/utils.py
- nemo_rl/models/policy/megatron_policy_worker.py
- nemo_rl/models/policy/dtensor_policy_worker_v2.py
- nemo_rl/algorithms/loss_functions.py
- nemo_rl/models/policy/dtensor_policy_worker.py
- nemo_rl/distributed/model_utils.py
!(**/tests/**|**/test_*.py|**/test_*.sh)
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
Add the NVIDIA copyright header to all Python files and shell scripts (excluding tests). The header should include the current year
Files:
- nemo_rl/models/generation/vllm/vllm_generation.py
- nemo_rl/models/megatron/common.py
- tests/unit/models/generation/test_vllm_logprobs_mode.py
- tests/unit/distributed/test_model_utils.py
- tests/unit/models/generation/test_vllm_generation.py
- nemo_rl/models/policy/utils.py
- nemo_rl/models/policy/megatron_policy_worker.py
- nemo_rl/models/policy/dtensor_policy_worker_v2.py
- nemo_rl/algorithms/loss_functions.py
- nemo_rl/models/policy/dtensor_policy_worker.py
- nemo_rl/distributed/model_utils.py
**/*.{py,sh}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
The NVIDIA copyright header should appear at the top of all Python files and shell scripts (excluding tests)
Files:
- nemo_rl/models/generation/vllm/vllm_generation.py
- nemo_rl/models/megatron/common.py
- tests/unit/models/generation/test_vllm_logprobs_mode.py
- tests/unit/distributed/test_model_utils.py
- tests/unit/models/generation/test_vllm_generation.py
- nemo_rl/models/policy/utils.py
- nemo_rl/models/policy/megatron_policy_worker.py
- nemo_rl/models/policy/dtensor_policy_worker_v2.py
- nemo_rl/algorithms/loss_functions.py
- nemo_rl/models/policy/dtensor_policy_worker.py
- nemo_rl/distributed/model_utils.py
🧠 Learnings (2)
📚 Learning: 2025-09-19T03:00:58.662Z
Learnt from: shuo-nvidia
Repo: NVIDIA-NeMo/RL PR: 1006
File: examples/configs/recipes/llm/distillation-qwen3-32b-to-1.7b-base-1n8g-fsdp2tp1.v1.yaml:85-101
Timestamp: 2025-09-19T03:00:58.662Z
Learning: In distillation and GRPO configurations, max_new_tokens is intentionally set to the full context window (max_total_sequence_length) for consistency across the codebase. Overflow cases when prompt + generation tokens exceed max_model_len are handled by safeguards implemented in vllm_worker.py.
Applied to files:
nemo_rl/models/generation/vllm/vllm_generation.py
📚 Learning: 2025-09-19T03:19:35.875Z
Learnt from: shuo-nvidia
Repo: NVIDIA-NeMo/RL PR: 1006
File: tests/unit/algorithms/test_loss_functions.py:1580-1619
Timestamp: 2025-09-19T03:19:35.875Z
Learning: The DistillationLossFn in nemo_rl/algorithms/loss_functions.py does not have k truncation logic - it processes whatever topk size is provided without capping it to vocabulary size or other limits. Large k values in tests will create correspondingly large GPU tensors.
Applied to files:
nemo_rl/models/generation/vllm/vllm_generation.pynemo_rl/algorithms/loss_functions.py
🧬 Code graph analysis (6)

nemo_rl/models/megatron/common.py (1)
- nemo_rl/models/policy/utils.py (1): `TrainingSamplingParams` (96-110)

tests/unit/models/generation/test_vllm_logprobs_mode.py (1)
- nemo_rl/models/policy/utils.py (1): `apply_top_k_top_p` (255-281)

tests/unit/distributed/test_model_utils.py (5)
- nemo_rl/distributed/model_utils.py (9): `DistributedLogprob` (66-147), `DistributedLogprobWithSampling` (265-392), `backward` (108-147), `backward` (217-262), `backward` (358-392), `backward` (526-608), `backward` (670-734), `backward` (1185-1214), `backward` (1468-1498)
- nemo_rl/models/policy/utils.py (2): `apply_top_k_top_p` (255-281), `backward` (243-252)
- nemo_rl/distributed/virtual_cluster.py (2): `PY_EXECUTABLES` (43-59), `RayVirtualCluster` (186-505)
- nemo_rl/distributed/named_sharding.py (3): `NamedSharding` (19-222), `layout` (99-101), `names` (84-86)
- nemo_rl/distributed/worker_groups.py (2): `RayWorkerBuilder` (131-301), `RayWorkerGroup` (304-1031)

nemo_rl/algorithms/loss_functions.py (2)
- nemo_rl/models/policy/utils.py (4): `apply_top_k_top_p` (255-281), `TrainingSamplingParams` (96-110), `need_top_k_filtering` (85-87), `need_top_p_filtering` (90-92)
- nemo_rl/distributed/model_utils.py (3): `from_parallel_logits_to_logprobs` (851-952), `gather_logits_at_global_indices` (1341-1424), `get_logprobs_from_vocab_parallel_logits` (1217-1267)

nemo_rl/models/policy/dtensor_policy_worker.py (1)
- nemo_rl/models/policy/utils.py (4): `TrainingSamplingParams` (96-110), `apply_top_k_top_p` (255-281), `need_top_k_filtering` (85-87), `need_top_p_filtering` (90-92)

nemo_rl/distributed/model_utils.py (1)
- nemo_rl/models/policy/utils.py (3): `TrainingSamplingParams` (96-110), `need_top_k_filtering` (85-87), `need_top_p_filtering` (90-92)
🪛 Ruff (0.14.6)
tests/unit/distributed/test_model_utils.py
1058-1058: Do not assert False (python -O removes these calls), raise AssertionError()
Replace assert False
(B011)
tests/unit/models/generation/test_vllm_generation.py
2603-2603: Avoid specifying long messages outside the exception class
(TRY003)
2668-2671: zip() without an explicit strict= parameter
Add explicit value for parameter strict=
(B905)
2729-2732: zip() without an explicit strict= parameter
Add explicit value for parameter strict=
(B905)
nemo_rl/algorithms/loss_functions.py
1014-1016: Avoid specifying long messages outside the exception class
(TRY003)
nemo_rl/distributed/model_utils.py
301-304: Avoid specifying long messages outside the exception class
(TRY003)
432-435: Avoid specifying long messages outside the exception class
(TRY003)
444-447: Avoid specifying long messages outside the exception class
(TRY003)
1523-1526: Avoid specifying long messages outside the exception class
(TRY003)
1532-1535: Avoid specifying long messages outside the exception class
(TRY003)
1573-1576: Avoid specifying long messages outside the exception class
(TRY003)
1582-1584: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: Coverage (e2e)
- GitHub Check: Coverage (doc-test)
- GitHub Check: Post submodule check comment / Comment on PR
🔇 Additional comments (34)
tests/unit/models/generation/test_vllm_logprobs_mode.py (1)
258-258: LGTM!

The tuple unpacking correctly matches the updated `apply_top_k_top_p` API that now returns `(filtered_logits, keep_mask)`. Discarding the mask with `_` is appropriate here since the test only needs the filtered logits for comparison with vLLM's upstream implementation.

tests/unit/distributed/test_model_utils.py (6)
23-25: LGTM!

New imports correctly align with the sampling-enabled distributed logprob utilities introduced in this PR.

39-41: LGTM!

Import of `apply_top_k_top_p` from the policy utils module is appropriate for testing the filtering functionality.
966-1092: Well-structured test for top-k/top-p filtering.

The `test_top_k_top_p_filtering_forward_backward` method thoroughly validates:
- Top-k filtering preserves correct tokens
- Top-p filtering respects cumulative probability threshold
- No-filtering case passes through unchanged
- Filtered logits produce valid probabilities
- Backward pass correctly masks gradients
Good coverage of the filtering logic and gradient propagation.
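A minimal pure-Python analogue of what the filtering checks amount to (illustrative only; the real test operates on torch tensors, and `top_k_top_p_filter` is a stand-in name, not the repo's function):

```python
import math


def top_k_top_p_filter(logits, top_k=None, top_p=1.0):
    """Return (filtered_logits, keep_mask): kept entries unchanged, rest -inf."""
    keep = [True] * len(logits)
    if top_k is not None and top_k != -1:
        kth = sorted(logits, reverse=True)[top_k - 1]
        keep = [k and (x >= kth) for k, x in zip(keep, logits)]
    if top_p is not None and top_p < 1.0:
        # Softmax, then keep the smallest prefix of tokens (by descending
        # probability) whose cumulative mass reaches top_p.
        z = max(logits)
        exps = [math.exp(x - z) for x in logits]
        total = sum(exps)
        probs = [e / total for e in exps]
        order = sorted(range(len(logits)), key=lambda i: -probs[i])
        cum = 0.0
        top_p_keep = [False] * len(logits)
        for i in order:
            if cum < top_p:
                top_p_keep[i] = True
            cum += probs[i]
        keep = [a and b for a, b in zip(keep, top_p_keep)]
    filtered = [x if k else float("-inf") for x, k in zip(logits, keep)]
    return filtered, keep


filtered, keep = top_k_top_p_filter([3.0, 1.0, 2.0, 0.0], top_k=2)
print(keep)  # [True, False, True, False]
```

Filtered positions get `-inf` so a subsequent `log_softmax` renormalizes over only the kept tokens, which is exactly why the test checks that filtered logits still produce valid probabilities.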
1093-1200: Comprehensive distributed logprob test with sampling.

The `test_distributed_logprob_with_sampling` method correctly validates:
- Forward pass matches expected computation with top-k/top-p filtering
- Backward pass gradients match expected values
- Both chunked and non-chunked variants are tested
The test appropriately computes the expected values using full logits and validates the distributed implementation produces equivalent results.
1228-1277: Good parametrized test coverage.

The test matrix covers:

- `tp_size: [1, 2]` - validates single-GPU and distributed cases
- `top_k`, `top_p`: various combinations including no filtering, top-k only, top-p only, and combined

This provides good coverage for the filtering functionality across different parallelism configurations.
1280-1344: Comprehensive distributed sampling test.

Good test coverage for `DistributedLogprobWithSampling` and `ChunkedDistributedLogprobWithSampling` with:
- Multiple sampling parameter combinations
- Chunked and non-chunked modes
- Validation of both forward and backward passes
tests/unit/models/generation/test_vllm_generation.py (3)
2504-2523: Excellent test for validating sampling parameter consistency.

This test comprehensively validates that policy workers produce logprobs matching vLLM when using the same sampling parameters. The parametrization covers various configurations:
- Temperature scaling only
- Top-p only
- Top-k only
- Combined sampling
Testing both `dtensor` and `megatron` policy types with TP=1 and TP=2 provides good coverage.
2739-2773: Well-designed KL validation approach.

Using `reference_policy_kl_penalty=1.0` with `ClippedPGLossFn` is a clever way to validate that the train path computes logprobs correctly with sampling parameters. If the logprobs match vLLM's generation logprobs, the KL should be near zero. The assertion threshold of `1e-4` is appropriately strict.
2789-2799: Good validation of KL penalty across microbatches.

The loop checking each KL penalty value in the list ensures all microbatches have consistent logprob computation. The detailed error message helps debugging if the assertion fails.
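The underlying sanity check can be illustrated with a scalar per-token difference estimate (a sketch, not the actual loss code; the logprob values below are made up): if the training-path logprobs equal vLLM's generation logprobs, the per-token KL term vanishes.

```python
def kl_penalty_k1(train_logprobs, gen_logprobs):
    # Simple per-token difference estimator; zero when the two paths agree.
    return [t - g for t, g in zip(train_logprobs, gen_logprobs)]


gen = [-0.10536, -2.30259, -0.51083]
train_matching = list(gen)          # training path reproduces generation exactly
train_off = [-0.10536, -2.0, -0.51083]  # one token's logprob disagrees

print(max(abs(x) for x in kl_penalty_k1(train_matching, gen)))       # 0.0
print(max(abs(x) for x in kl_penalty_k1(train_off, gen)) > 1e-4)     # True
```

This is why a `1e-4` threshold is meaningful: any systematic mismatch between the train-path and generation-path sampling transforms shows up as a nonzero difference on some token.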
nemo_rl/models/megatron/common.py (4)
34-34: LGTM!

Clean import of the new `TrainingSamplingParams` dataclass that encapsulates sampling configuration.
264-278: Good API improvement.

Replacing the generic `policy_cfg: Optional[dict]` with the specific `TrainingSamplingParams` dataclass improves type safety and makes the function signature clearer. The updated docstring correctly documents the parameter's purpose.
368-369: LGTM: Correct temperature scaling placement.

Temperature scaling is correctly applied to the logits before loss computation. The in-place division (`div_`) is appropriate here since we want to modify the output tensor that will be used for subsequent loss calculation. The guard condition correctly skips scaling when temperature is 1.0.
390-390: Correct propagation of sampling_params to loss function.

The sampling_params are correctly passed to the loss function wrapper, enabling top-k/top-p filtering during distributed logprob computation in the loss function.
nemo_rl/models/policy/dtensor_policy_worker_v2.py (3)
473-489: Helper methods for temperature and top‑k/top‑p filtering look correct

Centralizing temperature scaling in `_apply_temperature_scaling` and local full‑vocab filtering in `_apply_top_k_top_p_filtering` is clean. In‑place `div_` on logits is fine here (including DTensor), and `apply_top_k_top_p` is only used when either `need_top_k_filtering` or `need_top_p_filtering` is true, so there is no extra overhead in the default case. No issues from a correctness or style perspective.
1071-1073: Logprob path: sampling integration across DTensor and non‑DTensor looks sound

The changes in `get_logprobs` correctly:

- Apply temperature scaling once at the logits level via `_apply_temperature_scaling`.
- Pass `sampling_params` into `get_logprobs_from_vocab_parallel_logits` for TP/CP/DTensor paths.
- For the non‑DTensor path, apply `_apply_top_k_top_p_filtering` on full‑vocab logits (or chunked along the sequence dimension only), then compute `log_softmax` and gather token logprobs.

This respects the requirement that top‑k/top‑p operate on full vocabulary logits while keeping chunking limited to the sequence axis, so semantics are preserved. I don't see numerical or shape issues here; just ensure `get_logprobs_from_vocab_parallel_logits` itself does not also apply temperature scaling to avoid double‑dividing.

Also applies to: 1111-1118, 1121-1128, 1131-1160
1484-1486: Temperature scaling before top‑k logits is consistent with training

Applying `_apply_temperature_scaling` before running distributed/local top‑k keeps this utility aligned with the temperature‑scaled training/logprob behavior without changing the meaning of `k` or the indices returned. This looks correct and requires no additional changes.

nemo_rl/models/policy/megatron_policy_worker.py (4)
128-133: Importing TrainingSamplingParams here is appropriate

Bringing `TrainingSamplingParams` into the Megatron worker keeps the sampling configuration plumbing symmetric with the DTensor workers and centralizes sampling‑related types in `policy.utils`. No concerns.
993-997: forward_step_arbitrary_loss now depends on sampling_params

Passing `sampling_params=self.sampling_params` into `forward_step_arbitrary_loss` is the right way to thread sampling behavior into the Megatron training step and keep it aligned with the DTensor implementation.

This does, however, hard‑require that `forward_step_arbitrary_loss` (and any intermediate wrappers it calls) accept a `sampling_params` keyword. Please double‑check that all call sites and the function signature in `nemo_rl.models.megatron.common` are updated accordingly so this doesn't break older code paths.
1316-1321: Megatron logprob path: temperature and sampling propagation look coherent

Inside `get_logprobs`'s `forward_step_fn`:

- Applying `output_tensor.div_(self.sampling_params.temperature)` when a non‑unit temperature is set matches the DTensor worker behavior and ensures TP‑sharded logits are already temperature‑scaled.
- Passing `sampling_params=self.sampling_params` down to both `from_parallel_logits_to_logprobs_packed_sequences` and `from_parallel_logits_to_logprobs` lets those helpers apply top‑k/top‑p masking on full‑vocab probabilities while this function remains agnostic to TP/CP details.

This wiring looks correct; just ensure the `from_parallel_logits_to_logprobs*` helpers themselves do not additionally apply temperature scaling when `sampling_params` is provided, otherwise you'd effectively divide by temperature twice here.

Also applies to: 1327-1351
1602-1608: Temperature scaling in Megatron get_topk_logits

Scaling `output_tensor` by `1 / temperature` before calling `distributed_vocab_topk` mirrors the training/logprob behavior and keeps top‑k logits consistent with the effective sampling distribution. Since top‑k operates along the vocab dimension only, this is safe and does not alter the semantics of `k` or the returned indices.

nemo_rl/models/policy/dtensor_policy_worker.py (1)
1520-1522: Temperature scaling before top‑k logits (DTensor worker)

Scaling logits via `_apply_temperature_scaling` before the top‑k computation is consistent with both the Megatron and v2 DTensor workers and matches the intended "train under the same effective sampling distribution" behavior. Implementation is straightforward and correct.

nemo_rl/algorithms/loss_functions.py (2)
27-36: LGTM on new imports.

The imports for `apply_top_k_top_p`, `TrainingSamplingParams`, `need_top_k_filtering`, and `need_top_p_filtering` are appropriately added to support the new sampling functionality.
279-282: Correct NaN prevention with masking.

Good defensive fix to handle cases where top-k/p filtering masks out padding tokens, resulting in -inf logprobs that would otherwise produce NaN when multiplied by zero masks.
nemo_rl/models/policy/utils.py (4)
85-92: LGTM on filtering check helpers.

Clean and simple helper functions that correctly identify when filtering should be applied. The logic properly handles `None`, `-1` for top-k, and `1.0` for top-p as "no filtering" cases.
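Based on the semantics described, the helpers presumably reduce to something like the following sketch (illustrative, not the repo's exact code):

```python
def need_top_k_filtering(top_k) -> bool:
    # None and -1 both mean "no top-k filtering".
    return top_k is not None and top_k != -1


def need_top_p_filtering(top_p) -> bool:
    # None and 1.0 both mean "no top-p filtering".
    return top_p is not None and top_p < 1.0


print(need_top_k_filtering(-1), need_top_k_filtering(50))    # False True
print(need_top_p_filtering(1.0), need_top_p_filtering(0.9))  # False True
```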
95-110: LGTM on TrainingSamplingParams dataclass.

Well-documented dataclass with appropriate defaults that disable filtering by default. The design avoids a vLLM dependency while maintaining API consistency.
113-152: LGTM on top-k only filtering function.

Efficient implementation that avoids full vocabulary sorting when only top-k is needed. The threshold-based masking is correct, and the keep_mask return enables proper gradient handling.
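The threshold trick is: find the k-th largest logit, then mask everything below it, avoiding a full sort of the vocabulary. A stand-in sketch (the real code uses `torch.topk`; note that ties at the threshold keep more than k entries):

```python
import heapq


def top_k_threshold_mask(logits, k):
    """Keep entries >= the k-th largest value; others get -inf."""
    threshold = heapq.nlargest(k, logits)[-1]  # k-th largest value
    keep_mask = [x >= threshold for x in logits]
    filtered = [x if keep else float("-inf") for x, keep in zip(logits, keep_mask)]
    return filtered, keep_mask


filtered, mask = top_k_threshold_mask([0.1, 2.5, -1.0, 1.7], k=2)
print(mask)      # [False, True, False, True]
print(filtered)  # [-inf, 2.5, -inf, 1.7]
```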
224-252: LGTM on ApplyTopKTopP autograd function structure.

The autograd function correctly saves the keep_mask during forward and applies it during backward to zero out gradients for filtered tokens. However, the correctness depends on fixing the keep_mask ordering issue in `_apply_top_k_top_p_fn`.

nemo_rl/distributed/model_utils.py (6)
20-25: LGTM on new imports.

The imports correctly bring in the sampling utilities needed for the distributed implementation.
265-355: LGTM on DistributedLogprobWithSampling forward pass.

The implementation correctly:
- Validates divisibility constraints
- Converts from vocab-parallel to batch-sequence-parallel layout via all-to-all
- Applies top-k/top-p filtering on the materialized full vocabulary
- Computes log probabilities and gathers results across ranks
The layout transformation is necessary because top-p filtering requires access to the full vocabulary distribution.
357-392: LGTM on backward pass structure.

The backward pass correctly:
- Extracts the local portion of gradients
- Computes the standard cross-entropy gradient (one_hot - softmax)
- Applies the keep_mask to zero gradients for filtered tokens
- Converts back to vocab-parallel layout
However, this depends on the keep_mask bug fix in
_apply_top_k_top_p_fn.
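In scalar form, the gradient described here is d log_softmax(logits)[target] / d logit_j = one_hot_j - softmax_j, with filtered positions zeroed by the keep mask. A pure-Python sketch (not the distributed torch code):

```python
import math


def logprob_grad(logits, target, keep_mask):
    """Gradient of log_softmax over kept logits w.r.t. each logit, masked."""
    z = max(x for x, k in zip(logits, keep_mask) if k)  # for numerical stability
    exps = [math.exp(x - z) if k else 0.0 for x, k in zip(logits, keep_mask)]
    total = sum(exps)
    softmax = [e / total for e in exps]
    grad = [
        ((1.0 if j == target else 0.0) - s) if k else 0.0
        for j, (s, k) in enumerate(zip(softmax, keep_mask))
    ]
    return grad


grad = logprob_grad([2.0, 1.0, 0.0], target=0, keep_mask=[True, True, False])
print(abs(sum(grad)) < 1e-9)  # True: gradients sum to zero over kept tokens
print(grad[2])                # 0.0: the filtered token receives no gradient
```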
509-522: Memory consideration: saving full logits for rematerialization.

The chunked version saves the full `vocab_parallel_logits` tensor for backward rematerialization. While this is correct for the gradient computation, it's worth noting this may use significant memory for large vocabulary sizes. The non-chunked version saves just the softmax output, which is smaller when vocab > batch*seq.

This is an acceptable trade-off for chunked processing, where the goal is to reduce peak memory during computation.
1501-1601: LGTM on all-to-all layout transformation helpers.

Well-documented functions with clear einops-style comments explaining the transformations:

- `vp2sq`: `(BS, V_local) -> (BS_local, V)` - materializes full vocabulary
- `sq2vp`: `(BS_local, V) -> (BS, V_local)` - inverse operation

The divisibility checks and error messages are appropriate.
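The layout transforms can be sanity-checked without torch by simulating each rank's shard as nested lists (a toy sketch; the real helpers use an all-to-all collective, and the function names below mirror the helpers only conceptually):

```python
def vp2sq(shards):
    """(per-rank: BS x V_local) -> (per-rank: BS_local x V).

    Each rank starts with all BS rows but only a vocab slice; afterwards each
    rank holds a slice of rows with the full vocabulary materialized.
    """
    world = len(shards)
    bs = len(shards[0])
    assert bs % world == 0, "BS must be divisible by world size"
    bs_local = bs // world
    out = []
    for rank in range(world):
        rows = []
        for r in range(rank * bs_local, (rank + 1) * bs_local):
            full_row = []
            for src in range(world):  # concatenate vocab slices from all ranks
                full_row.extend(shards[src][r])
            rows.append(full_row)
        out.append(rows)
    return out


def sq2vp(shards):
    """Inverse: (per-rank: BS_local x V) -> (per-rank: BS x V_local)."""
    world = len(shards)
    v = len(shards[0][0])
    v_local = v // world
    out = []
    for rank in range(world):
        rows = []
        for src in range(world):
            for row in shards[src]:
                rows.append(row[rank * v_local : (rank + 1) * v_local])
        out.append(rows)
    return out


# 2 ranks, BS=2, V_local=2 (V=4): round-trip restores the original layout.
vp = [[[0, 1], [4, 5]], [[2, 3], [6, 7]]]
print(vp2sq(vp))               # [[[0, 1, 2, 3]], [[4, 5, 6, 7]]]
print(sq2vp(vp2sq(vp)) == vp)  # True
```

The round-trip property is what makes the approach safe: filtering happens in the full-vocab layout, and the result maps back to the original vocab-parallel sharding losslessly.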
794-840: LGTM on sampling path integration.

The routing logic correctly:
- Checks if top-k or top-p filtering is needed
- Routes to the appropriate sampling-enabled or standard autograd function
- Maintains backward compatibility when no sampling is configured
The TODO comment about optimizing top-k separately is a reasonable future enhancement.
Signed-off-by: Zhanda <zhandazhu@gmail.com>
Signed-off-by: Zhanda <zhandazhu@gmail.com>
Signed-off-by: Zhanda <zhandazhu@gmail.com>
terrykong
left a comment
awesome work @zhandaz !
@gshennvm can you review?
@zhandaz can you include convergence plots? maybe at least small qwen out to a few hundred steps to show no collapse and low gen_kl_error? i think we should have at least one nightly test that tests T!=1 & non trivial top-p | top-k, maybe a variant of the qwen30b recipe with just those things changed so that it's comparable?
Sure. I will add that in this PR as well.
Signed-off-by: Zhanda <zhandazhu@gmail.com>
Signed-off-by: Zhanda <zhandazhu@gmail.com>
(not cp loss_function.py common.py megatron_policy_worker.py) Signed-off-by: Zhanda <zhandazhu@gmail.com>
(not cp loss_function.py common.py megatron_policy_worker.py) Signed-off-by: Zhanda <zhandazhu@gmail.com>
(not cp loss_function.py common.py megatron_policy_worker.py) Signed-off-by: Zhanda <zhandazhu@gmail.com>
(not cp loss_function.py common.py megatron_policy_worker.py) Signed-off-by: Zhanda <zhandazhu@gmail.com>
What does this PR do?

This PR implements top-k and top-p sampling in training to ensure consistency between training and inference sampling parameters.

- Since top-p filtering needs the full vocabulary distribution, vocab-parallel logits are first converted to a full-vocab layout via `all-to-all` communication, filtering is applied on the full vocabulary, and the result is converted back.

Detailed Changes

- `nemo_rl/models/policy/utils.py`: added `TrainingSamplingParams` and implemented `apply_top_k_top_p` with proper autograd support.
- `nemo_rl/distributed/model_utils.py`: implemented `DistributedLogprobWithSampling` and `ChunkedDistributedLogprobWithSampling`, and integrated them into all logprob computation paths (TP, CP, DTensor, packed sequences).
- `nemo_rl/algorithms/loss_functions.py`: updated `ClippedPGLossFn`, `NLLLoss`, and `DPOLossFn` to accept `sampling_params`; the policy workers now call these functions with `sampling_params`.

Tests
- `test_top_k_top_p_filtering_forward_backward`: tests top-k and top-p functionality, including gradient masking correctness.
- `test_distributed_logprob_with_sampling`: tests the TP distributed implementation.
- `test_vllm_policy_logprob_agreement_with_sampling`: tests that policy worker logprobs match vLLM with sampling parameters.

We have also tested that with `temperature != 1` the logprobs match. See `test_vllm_policy_logprob_agreement_with_sampling`.
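As a toy end-to-end illustration of the training-side pipeline described above (temperature scaling, full-vocab filtering, then log-softmax and token gather), ignoring all sharding and distribution:

```python
import math


def training_logprob(logits, target, temperature=1.0, top_k=None):
    """Toy single-token logprob under the same effective sampling distribution."""
    scaled = [x / temperature for x in logits]
    if top_k is not None and top_k != -1:
        threshold = sorted(scaled, reverse=True)[top_k - 1]
        scaled = [x if x >= threshold else float("-inf") for x in scaled]
    # log-softmax over the kept tokens only
    z = max(scaled)
    log_z = z + math.log(sum(math.exp(x - z) for x in scaled if x != float("-inf")))
    return scaled[target] - log_z


# With top_k=2, the target must be one of the two highest-logit tokens
# to receive a finite logprob.
lp = training_logprob([2.0, 1.0, 0.0], target=0, temperature=1.0, top_k=2)
masked = training_logprob([2.0, 1.0, 0.0], target=2, temperature=1.0, top_k=2)
print(round(lp, 4))             # -0.3133
print(masked == float("-inf"))  # True
```

This mirrors why inference-time sampling parameters must be replicated at training time: the logprob of the sampled token depends on the renormalized, filtered distribution, not the raw softmax.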