Signed-off-by: Terry Kong <terryk@nvidia.com>
📝 Walkthrough
Adds early-exit prints and returns on timeout or max-steps across training loops (distillation, DPO, RM, SFT). GRPO gains timeout-based checkpointing in sync/async paths and a new async parameter. Unit tests added/expanded to cover early exits, epochs/steps stopping, and timeout behavior.
Changes
Sequence Diagram(s)
sequenceDiagram
autonumber
participant Trainer
participant TimeoutChecker as TimeoutChecker
participant Policy
participant Checkpointer
rect rgb(245,245,255)
note over Trainer,Policy: Generic training loop with early-exit
Trainer->>Policy: train(step)
Policy-->>Trainer: step complete
Trainer->>TimeoutChecker: check_save(iteration)
alt Timeout reached
TimeoutChecker-->>Trainer: should_exit = true
note right of Trainer: Print "Timeout reached, stopping early"
Trainer-->>Trainer: return
else Max steps reached
note right of Trainer: Print "Max steps reached, stopping early"
Trainer-->>Trainer: return
else Continue
TimeoutChecker-->>Trainer: should_exit = false
Trainer-->>Trainer: next step
end
end
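The control flow in the diagram above can be sketched in plain Python. This is a minimal sketch, not the NeMo-RL implementation: `SketchTimeoutChecker`, `train_loop`, and the `train_step` callback are hypothetical stand-ins, `check_save` is simplified to take no arguments, and the max-steps message text is an assumption.

```python
import time
from typing import Callable


class SketchTimeoutChecker:
    """Hypothetical stand-in for a timeout checker; not the NeMo-RL class."""

    def __init__(
        self,
        timeout_seconds: float,
        clock: Callable[[], float] = time.monotonic,
    ) -> None:
        self._clock = clock
        self._deadline = clock() + timeout_seconds

    def check_save(self) -> bool:
        """Return True once the time budget has been used up."""
        return self._clock() >= self._deadline


def train_loop(
    train_step: Callable[[int], None],
    checker: SketchTimeoutChecker,
    max_num_steps: int,
) -> int:
    """Run steps until the timeout fires or max_num_steps is reached.

    Returns the number of completed steps.
    """
    completed = 0
    while True:
        train_step(completed)
        completed += 1
        if checker.check_save():
            print("Timeout has been reached, stopping training early", flush=True)
            return completed
        if completed >= max_num_steps:
            # Message wording is assumed for illustration.
            print("Max steps reached, stopping training early", flush=True)
            return completed
```

Injecting the clock keeps the sketch testable without sleeping: a fake clock that advances by one per call makes the timeout fire after a known number of steps.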
sequenceDiagram
autonumber
participant AsyncTrainer as async_grpo_train
participant TimeoutChecker as TimeoutChecker
participant Policy
participant Checkpointer
rect rgb(240,255,240)
note over AsyncTrainer,Checkpointer: GRPO async with timeout-based checkpointing
AsyncTrainer->>TimeoutChecker: start(iteration)
loop each step
AsyncTrainer->>Policy: train_async_step()
Policy-->>AsyncTrainer: step done
AsyncTrainer->>TimeoutChecker: mark(iteration)
AsyncTrainer->>TimeoutChecker: check_save(iteration)
alt should_save_by_step or should_save_by_timeout
note right of AsyncTrainer: Prepare state
AsyncTrainer->>Checkpointer: save_checkpoint()
Checkpointer-->>AsyncTrainer: saved
note right of AsyncTrainer: Early exit after save
AsyncTrainer-->>AsyncTrainer: return
else continue
AsyncTrainer-->>AsyncTrainer: next step
end
end
end
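The checkpointing branch in the second diagram can be illustrated with a small loop. This is a sketch under stated assumptions: `run_with_checkpoints`, `save_period`, `timeout_at_step`, and `save_fn` are hypothetical names, the wall-clock check is replaced by a step threshold so the behavior is deterministic, and only a timeout-triggered save ends training (the reading suggested by the grpo.py review comment further down).

```python
from typing import Callable


def run_with_checkpoints(
    num_steps: int,
    save_period: int,
    timeout_at_step: int,
    save_fn: Callable[[int], None],
) -> int:
    """Save on a step interval or on timeout; exit right after a timeout save.

    Returns the number of completed steps.
    """
    for step in range(num_steps):
        # One training step would run here.
        should_save_by_step = (step + 1) % save_period == 0
        # Stand-in for a wall-clock check against the time budget.
        should_save_by_timeout = (step + 1) >= timeout_at_step
        if should_save_by_step or should_save_by_timeout:
            save_fn(step + 1)
        if should_save_by_timeout:
            print("Timeout has been reached, stopping training early", flush=True)
            return step + 1
    return num_steps


saved_steps: list[int] = []
finished_at = run_with_checkpoints(
    num_steps=10, save_period=4, timeout_at_step=6, save_fn=saved_steps.append
)
```

With these inputs the periodic save at step 4 does not stop training, while the timeout at step 6 triggers one more save and then returns.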
Estimated code review effort
🎯 4 (Complex) | ⏱️ ~60 minutes
Possibly related PRs
Suggested labels
Suggested reviewers
Pre-merge checks and finishing touches
❌ Failed checks (2 warnings)
✅ Passed checks (2 passed)
✨ Finishing touches
🧪 Generate unit tests (beta)
terrykong
left a comment
@yfw @parthchadha to review
Actionable comments posted: 1
🧹 Nitpick comments (5)
tests/unit/algorithms/test_rm.py (1)
178-229: Timeout-driven early-exit test is solid; minor optional enhancements
- Good: side_effect drives exit at 8th check; verifies message and no further epochs.
- Optional: also assert start_iterations()/mark_iteration() were invoked to ensure loop timing hooks are exercised.
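The `side_effect` pattern described in the bullets above can be tried in isolation with `unittest.mock`. This is a minimal sketch: `run_until_timeout` is a hypothetical loop written for illustration, not the `rm_train` function, and it assumes one `check_save` call per training step.

```python
from unittest.mock import MagicMock


def run_until_timeout(policy, checker, max_steps: int) -> None:
    """Hypothetical loop: train, mark the iteration, stop when check_save() is True."""
    for _ in range(max_steps):
        policy.train()
        checker.mark_iteration()
        if checker.check_save():
            return


policy = MagicMock()
checker = MagicMock()
# A list side_effect yields one value per call: False seven times, then True.
checker.check_save.side_effect = [False] * 7 + [True]

checker.start_iterations()
run_until_timeout(policy, checker, max_steps=20)

# Training stops on the 8th step, and the timing hooks were exercised.
assert policy.train.call_count == 8
checker.start_iterations.assert_called_once()
assert checker.mark_iteration.call_count == 8
```

Because `MagicMock` creates child mocks on attribute access, the call-count assertions work even without assigning `start_iterations` and `mark_iteration` explicitly.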
 with patch("nemo_rl.algorithms.rm.TimeoutChecker") as mock_timeout_class:
     mock_timeout_instance = MagicMock()
     # Create a side_effect that returns False 7 times, then True
     check_results = [False] * 7 + [True]
     mock_timeout_instance.check_save.side_effect = check_results
+    mock_timeout_instance.start_iterations = MagicMock()
+    mock_timeout_instance.mark_iteration = MagicMock()
     mock_timeout_class.return_value = mock_timeout_instance
@@
     # Verify training stopped at 8 steps (when check_save returned True)
     assert mock_components["policy"].train.call_count == 8
+    mock_timeout_instance.start_iterations.assert_called_once()
+    assert mock_timeout_instance.mark_iteration.call_count == 8
tests/unit/algorithms/test_distillation.py (1)
215-270: Strengthen step-continuation assertion
Printed step lines begin with separators (e.g., "===== Step …"), so startswith("Step ") won’t match. Check substring presence instead to ensure no steps beyond 8 occur after timeout.
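A short check shows why the `startswith` form passes vacuously. The log lines here are hypothetical: the exact separator and the "8/10" counter format are assumptions made for illustration.

```python
# Hypothetical captured output; the separator format is assumed.
captured_lines = [
    "===== Step 8/10 =====",
    "Timeout has been reached, stopping training early",
]


def old_check(line: str) -> bool:
    """Original assertion condition, based on a prefix match."""
    return not line.startswith("Step ") or "Step 8" in line


def new_check(line: str) -> bool:
    """Suggested assertion condition, based on substring presence."""
    return (" Step " not in line) or (" Step 8/" in line)


# A line that would indicate training continued past the timeout.
leaked_line = "===== Step 9/10 ====="

# The prefix match never fires, so the old check accepts the leaked line.
assert old_check(leaked_line) is True
# The substring check rejects it while still accepting the expected output.
assert new_check(leaked_line) is False
assert all(new_check(line) for line in captured_lines)
```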
-    for line in remaining_lines:
-        # Distillation doesn't have epochs, but check for step markers
-        assert not line.startswith("Step ") or "Step 8" in line, (
-            f"Training continued after timeout: {line}"
-        )
+    for line in remaining_lines:
+        # Verify no further step markers after timeout
+        assert (" Step " not in line) or (" Step 8/" in line), (
+            f"Training continued after timeout: {line}"
+        )
nemo_rl/algorithms/grpo.py (1)
1697-1718: Consider exiting immediately after timeout-triggered checkpoint
When should_save_by_timeout is True and checkpointing is enabled, you still perform post-checkpoint logging and only exit later. To minimize work under a hard “must save by” constraint, return immediately after finalize_checkpoint when timeout triggered.
-            if master_config["checkpointing"]["enabled"] and (
-                should_save_by_step or should_save_by_timeout
-            ):
+            if master_config["checkpointing"]["enabled"] and (
+                should_save_by_step or should_save_by_timeout
+            ):
                 policy.prepare_for_training()
@@
                 with timer.time("checkpointing"):
                     print(f"Saving checkpoint for step {step + 1}...")
                     checkpoint_path = checkpointer.init_tmp_checkpoint(
                         step + 1, grpo_save_state, master_config
                     )
                     policy.save_checkpoint(
@@
                     checkpointer.finalize_checkpoint(checkpoint_path)
                 policy.offload_after_refit()
+                if should_save_by_timeout:
+                    print("Timeout has been reached, stopping training early", flush=True)
+                    return
tests/unit/algorithms/test_dpo.py (2)
102-217: Fixture quality LGTM; optional lint nit
Comprehensive mocks and sharding annotations. If Ruff flags ARG001 for local iterator helpers, prefix the unused parameter with “_” to silence.
-    def train_iter(self):
+    def train_iter(_self):
         return iter([mock_batch] * 10)
@@
-    def val_iter(self):
+    def val_iter(_self):
         return iter([mock_batch] * 10)
268-317: Strengthen timeout continuation check (mirror distillation test fix)
Use substring detection for “ Step ” rather than startswith to catch the actual printed format.
-    for line in remaining_lines:
-        assert "Epoch" not in line or "Epoch 1/10" in line, (
-            f"Training continued to next epoch after timeout: {line}"
-        )
+    for line in remaining_lines:
+        # No new epochs after timeout
+        assert ("Epoch" not in line) or ("Epoch 1/10" in line), (
+            f"Training continued to next epoch after timeout: {line}"
+        )
+        # No further step markers after timeout
+        assert (" Step " not in line) or (" Step 8/" in line), (
+            f"Training continued after timeout: {line}"
+        )
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (10)
- nemo_rl/algorithms/distillation.py (1 hunks)
- nemo_rl/algorithms/dpo.py (1 hunks)
- nemo_rl/algorithms/grpo.py (4 hunks)
- nemo_rl/algorithms/rm.py (1 hunks)
- nemo_rl/algorithms/sft.py (1 hunks)
- tests/unit/algorithms/test_distillation.py (2 hunks)
- tests/unit/algorithms/test_dpo.py (2 hunks)
- tests/unit/algorithms/test_grpo.py (2 hunks)
- tests/unit/algorithms/test_rm.py (2 hunks)
- tests/unit/algorithms/test_sft.py (2 hunks)
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
**/*.py: Follow the Google Python Style Guide for all Python code
Target Python 3.12+ for all Python code in NeMo-RL
Indent Python code with 4 spaces; do not use tabs
Python filenames should be snake_case (e.g., some_file.py)
Class names should be PascalCase
Function and method names should be snake_case
Local variable names should be snake_case; if starting with a number, prefix with k (e.g., k_99th_percentile)
Global variables should be UPPER_SNAKE_CASE and prefixed with G_ (e.g., G_MY_GLOBAL)
Constants should be UPPER_SNAKE_CASE
Avoid shadowing variables declared in an outer scope
Initialize all externally visible members of a class in the constructor
For public interfaces used outside a file, prefer docstrings over comments
Use comments mainly for code within a function or interfaces local to a file
Commented-out code must include a nearby comment explaining usage and why it is commented out; otherwise remove before merging
Use Google-style docstrings for classes and functions (Sphinx-parseable)
Avoid using reflection when functionality can be easily achieved without it
Limit except clauses to the smallest specific set of exceptions possible
For duck-typing via try/except, keep the try body minimal and use else for main logic
Add the NVIDIA copyright header (with current year) at the top of all Python files, excluding tests/ and test-only scripts
Files:
- tests/unit/algorithms/test_distillation.py
- nemo_rl/algorithms/rm.py
- tests/unit/algorithms/test_sft.py
- nemo_rl/algorithms/dpo.py
- nemo_rl/algorithms/sft.py
- nemo_rl/algorithms/distillation.py
- tests/unit/algorithms/test_grpo.py
- tests/unit/algorithms/test_rm.py
- nemo_rl/algorithms/grpo.py
- tests/unit/algorithms/test_dpo.py
nemo_rl/**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
nemo_rl/**/*.py: Do not set non-None configuration defaults in code; YAML is the single source of truth for defaults
Access required config attributes directly (e.g., policy_cfg["precision"]) and assume presence; do not introduce hidden defaults
Express configuration optionality via TypedDict using typing.NotRequired
When adding a new config key to a TypedDict subclass, document the key’s purpose, valid values/types, and recommended default in code
For any class or function decorated with @ray.remote, add '# pragma: no cover' on the class/def line (and on remote functions)
Files:
- nemo_rl/algorithms/rm.py
- nemo_rl/algorithms/dpo.py
- nemo_rl/algorithms/sft.py
- nemo_rl/algorithms/distillation.py
- nemo_rl/algorithms/grpo.py
🧬 Code graph analysis (6)
tests/unit/algorithms/test_distillation.py (2)
  nemo_rl/algorithms/distillation.py (2)
    - _default_distillation_save_state (96-102)
    - distillation_train (468-849)
  nemo_rl/utils/timer.py (1)
    - check_save (284-308)
tests/unit/algorithms/test_sft.py (3)
  tests/unit/algorithms/test_distillation.py (2)
    - test_exit_on_timeout (215-269)
    - mock_components (34-186)
  nemo_rl/algorithms/sft.py (2)
    - _default_sft_save_state (56-63)
    - sft_train (347-612)
  nemo_rl/utils/timer.py (1)
    - check_save (284-308)
tests/unit/algorithms/test_grpo.py (3)
  nemo_rl/algorithms/grpo.py (3)
    - _default_grpo_save_state (113-121)
    - async_grpo_train (1159-1839)
    - grpo_train (576-1038)
  nemo_rl/algorithms/loss_functions.py (1)
    - ClippedPGLossFn (66-359)
  nemo_rl/distributed/batched_data_dict.py (1)
    - BatchedDataDict (75-860)
tests/unit/algorithms/test_rm.py (2)
  nemo_rl/algorithms/rm.py (2)
    - _default_rm_save_state (52-59)
    - rm_train (420-662)
  nemo_rl/utils/timer.py (1)
    - check_save (284-308)
nemo_rl/algorithms/grpo.py (1)
  nemo_rl/utils/timer.py (4)
    - TimeoutChecker (264-321)
    - start_iterations (310-311)
    - mark_iteration (313-321)
    - check_save (284-308)
tests/unit/algorithms/test_dpo.py (3)
  nemo_rl/algorithms/dpo.py (3)
    - _default_dpo_save_state (51-58)
    - add_ref_logprobs_to_data (270-303)
    - dpo_train (486-741)
  nemo_rl/algorithms/loss_functions.py (1)
    - PreferenceLoss (449-542)
  nemo_rl/models/policy/dtensor_policy_worker.py (2)
    - train (529-895)
    - get_reference_policy_logprobs (1654-1669)
🪛 Ruff (0.14.0)
tests/unit/algorithms/test_grpo.py
380-380: Unused function argument: self
(ARG001)
388-388: Unused function argument: self
(ARG001)
tests/unit/algorithms/test_dpo.py
155-155: Unused function argument: self
(ARG001)
163-163: Unused function argument: self
(ARG001)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: Lint check
- GitHub Check: Post submodule check comment / Comment on PR
- GitHub Check: Post automodel integration comment / Comment on PR
🔇 Additional comments (9)
nemo_rl/algorithms/dpo.py (1)
731-737: Early-exit messaging and return look good
Consistent message, flush, and immediate return. Matches the timeout/max-steps semantics introduced elsewhere.
nemo_rl/algorithms/distillation.py (1)
841-849: Early exit on timeout/max-steps is correct
Clear message + flush and immediate return. Aligned with other training loops.
nemo_rl/algorithms/rm.py (1)
649-659: RM early-exit behavior LGTM
Timeout and max-steps exits now print and return immediately. Works with max_num_steps == -1 sentinel too.
tests/unit/algorithms/test_rm.py (1)
15-15: Import addition is appropriate
Needed for TimeoutChecker patching in new test.
tests/unit/algorithms/test_distillation.py (1)
15-15: Import addition is appropriate
Required for TimeoutChecker patch.
nemo_rl/algorithms/grpo.py (3)
1202-1208: Async path now tracks timeout: good
TimeoutChecker init + start covers async as well. Consistent with sync training.
1810-1818: Async early-exit messaging and return look good
Matches sync behavior; flush ensures visibility.
1159-1173: API change verified: max_trajectory_age_steps is passed at all async_grpo_train call sites or covered by the default.
tests/unit/algorithms/test_dpo.py (1)
15-28: Imports/extensions are appropriate
Brings in DPO helpers and PreferenceLoss needed for new tests.
Signed-off-by: Terry Kong <terryk@nvidia.com> Signed-off-by: NeMo Bot <nemo-bot@nvidia.com>
Signed-off-by: Terry Kong <terryk@nvidia.com>
fix: grpo early exit edge case (NVIDIA-NeMo#1361) See merge request jiaqiz/nemo-rl-internal!3
Signed-off-by: Terry Kong <terryk@nvidia.com>
Signed-off-by: Terry Kong <terryk@nvidia.com> Signed-off-by: Lawrence Lane <llane@nvidia.com>
Signed-off-by: Terry Kong <terryk@nvidia.com>
Signed-off-by: Terry Kong <terryk@nvidia.com> Signed-off-by: yuanhangs <yuanhangs@nvidia.com>
What does this PR do?
Add a one line overview of what this PR aims to accomplish.
Issues
List issues that this PR closes (syntax):
Usage
# Add a code snippet demonstrating how to use this
Before your PR is "Ready for review"
Pre checks:
Additional Information
Summary by CodeRabbit