Skip to content

fix: grpo early exit edge case#1361

Merged
terrykong merged 2 commits intomainfrom
tk/exit-early-grpo
Oct 15, 2025
Merged

fix: grpo early exit edge case#1361
terrykong merged 2 commits intomainfrom
tk/exit-early-grpo

Conversation

@terrykong
Copy link
Copy Markdown
Collaborator

@terrykong terrykong commented Oct 15, 2025

What does this PR do ?

Add a one line overview of what this PR aims to accomplish.

Issues

List issues that this PR closes (syntax):

Usage

  • You can potentially add a usage example below
# Add a code snippet demonstrating how to use this

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you run the unit tests and functional tests locally? Visit our Testing Guide for how to run tests
  • Did you add or update any necessary documentation? Visit our Document Development Guide for how to write, build and test the docs.

Additional Information

  • ...

Summary by CodeRabbit

  • New Features
    • Added clear early-termination messages when training stops due to timeout or reaching the maximum steps across Distillation, DPO, GRPO, RM, and SFT.
    • Introduced timeout-based checkpointing in GRPO, including both synchronous and asynchronous training paths.
    • Added a new parameter to asynchronous GRPO to control maximum trajectory age.
  • Tests
    • Expanded unit tests to cover exits on max steps, max epochs, and timeout across all affected training modes, including async GRPO.

Signed-off-by: Terry Kong <terryk@nvidia.com>
@terrykong terrykong requested a review from yfw October 15, 2025 03:27
@terrykong terrykong requested review from a team as code owners October 15, 2025 03:27
@coderabbitai
Copy link
Copy Markdown
Contributor

coderabbitai Bot commented Oct 15, 2025

📝 Walkthrough

Walkthrough

Adds early-exit prints and returns on timeout or max-steps across training loops (distillation, DPO, RM, SFT). GRPO gains timeout-based checkpointing in sync/async paths and a new async parameter. Unit tests added/expanded to cover early exits, epochs/steps stopping, and timeout behavior.

Changes

Cohort / File(s) Summary of changes
Algorithms: early-exit prints
nemo_rl/algorithms/distillation.py, nemo_rl/algorithms/dpo.py, nemo_rl/algorithms/rm.py, nemo_rl/algorithms/sft.py
Add explicit prints on timeout and max-steps before early return; replace loop breaks with returns where applicable; no public API changes.
GRPO: timeout checkpointing + async param
nemo_rl/algorithms/grpo.py
Introduce timeout-based checkpointing and early-exit in sync and async training; track iteration-based timeout; add parameter max_trajectory_age_steps: int = 1 to async_grpo_train.
Tests: distillation
tests/unit/algorithms/test_distillation.py
Add timeout exit test using patched TimeoutChecker; verify step count and printed timeout message.
Tests: DPO
tests/unit/algorithms/test_dpo.py
Add fixtures and tests for max-steps, max-epochs, timeout exits; extend coverage for reference logprobs augmentation; import and validate _default_dpo_save_state, add_ref_logprobs_to_data, dpo_train.
Tests: GRPO (sync/async)
tests/unit/algorithms/test_grpo.py
Add comprehensive sync/async fixtures; tests for max-steps, max-epochs, timeout exits; validate cleanup/outputs and call counts.
Tests: RM
tests/unit/algorithms/test_rm.py
Add timeout exit test; assert printed message and no further epochs after timeout.
Tests: SFT
tests/unit/algorithms/test_sft.py
Add timeout exit test; patch TimeoutChecker; verify printed message and halt after timeout.

Sequence Diagram(s)

sequenceDiagram
  autonumber
  participant Trainer
  participant TimeoutChecker as TimeoutChecker
  participant Policy
  participant Checkpointer

  rect rgb(245,245,255)
    note over Trainer,Policy: Generic training loop with early-exit
    Trainer->>Policy: train(step)
    Policy-->>Trainer: step complete
    Trainer->>TimeoutChecker: check_save(iteration)
    alt Timeout reached
      TimeoutChecker-->>Trainer: should_exit = true
      note right of Trainer: Print "Timeout reached, stopping early"
      Trainer-->>Trainer: return
    else Max steps reached
      note right of Trainer: Print "Max steps reached, stopping early"
      Trainer-->>Trainer: return
    else Continue
      TimeoutChecker-->>Trainer: should_exit = false
      Trainer-->>Trainer: next step
    end
  end
Loading
sequenceDiagram
  autonumber
  participant AsyncTrainer as async_grpo_train
  participant TimeoutChecker as TimeoutChecker
  participant Policy
  participant Checkpointer

  rect rgb(240,255,240)
    note over AsyncTrainer,Checkpointer: GRPO async with timeout-based checkpointing
    AsyncTrainer->>TimeoutChecker: start(iteration)
    loop each step
      AsyncTrainer->>Policy: train_async_step()
      Policy-->>AsyncTrainer: step done
      AsyncTrainer->>TimeoutChecker: mark(iteration)
      AsyncTrainer->>TimeoutChecker: check_save(iteration)
      alt should_save_by_step or should_save_by_timeout
        note right of AsyncTrainer: Prepare state
        AsyncTrainer->>Checkpointer: save_checkpoint()
        Checkpointer-->>AsyncTrainer: saved
        note right of AsyncTrainer: Early exit after save
        AsyncTrainer-->>AsyncTrainer: return
      else continue
        AsyncTrainer-->>AsyncTrainer: next step
      end
    end
  end
Loading

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related PRs

Suggested labels

CI:L0, r0.4.0

Suggested reviewers

  • guyueh1
  • chtruong814

Pre-merge checks and finishing touches

❌ Failed checks (2 warnings)
Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 67.86% which is insufficient. The required threshold is 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.
Test Results For Major Changes ⚠️ Warning The PR introduces substantive changes across multiple training algorithms by adding new early-exit and checkpointing logic, and although unit tests have been added, the PR description contains no summary of test outcomes, numerical regression checks, or performance comparisons to demonstrate that convergence and throughput remain unaffected. Please update the PR description to include a concise summary of test results (e.g., success rates, coverage) and, where applicable, before-and-after performance or convergence metrics to show that these early-exit changes do not regress existing behavior.
✅ Passed checks (2 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title Check ✅ Passed The title mentions a GRPO-specific fix which is indeed part of the PR but the changes extend to multiple training algorithms (distillation, DPO, RM, SFT) and their early-exit behavior, so it only partially reflects the full scope of updates.
✨ Finishing touches
  • 📝 Generate docstrings
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Post copyable unit tests in a comment
  • Commit unit tests in branch tk/exit-early-grpo

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Copy Markdown
Collaborator Author

@terrykong terrykong left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@yfw @parthchadha to review

@terrykong terrykong requested a review from parthchadha October 15, 2025 03:30
@terrykong terrykong added CI:L1 Run doctests, unit tests, and functional tests r0.4.0 and removed CI:L1 Run doctests, unit tests, and functional tests labels Oct 15, 2025
@terrykong terrykong added the CI:L1 Run doctests, unit tests, and functional tests label Oct 15, 2025
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🧹 Nitpick comments (5)
tests/unit/algorithms/test_rm.py (1)

178-229: Timeout-driven early-exit test is solid; minor optional enhancements

  • Good: side_effect drives exit at 8th check; verifies message and no further epochs.
  • Optional: also assert start_iterations()/mark_iteration() were invoked to ensure loop timing hooks are exercised.
         with patch("nemo_rl.algorithms.rm.TimeoutChecker") as mock_timeout_class:
             mock_timeout_instance = MagicMock()
             # Create a side_effect that returns False 7 times, then True
             check_results = [False] * 7 + [True]
             mock_timeout_instance.check_save.side_effect = check_results
+            mock_timeout_instance.start_iterations = MagicMock()
+            mock_timeout_instance.mark_iteration = MagicMock()
             mock_timeout_class.return_value = mock_timeout_instance
@@
         # Verify training stopped at 8 steps (when check_save returned True)
         assert mock_components["policy"].train.call_count == 8
+        mock_timeout_instance.start_iterations.assert_called_once()
+        assert mock_timeout_instance.mark_iteration.call_count == 8
tests/unit/algorithms/test_distillation.py (1)

215-270: Strengthen step-continuation assertion

Printed step lines begin with separators (e.g., "===== Step …"), so startswith("Step ") won’t match. Check substring presence instead to ensure no steps beyond 8 occur after timeout.

-        for line in remaining_lines:
-            # Distillation doesn't have epochs, but check for step markers
-            assert not line.startswith("Step ") or "Step 8" in line, (
-                f"Training continued after timeout: {line}"
-            )
+        for line in remaining_lines:
+            # Verify no further step markers after timeout
+            assert (" Step " not in line) or (" Step 8/" in line), (
+                f"Training continued after timeout: {line}"
+            )
nemo_rl/algorithms/grpo.py (1)

1697-1718: Consider exiting immediately after timeout-triggered checkpoint

When should_save_by_timeout is True and checkpointing is enabled, you still perform post-checkpoint logging and only exit later. To minimize work under a hard “must save by” constraint, return immediately after finalize_checkpoint when timeout triggered.

-                if master_config["checkpointing"]["enabled"] and (
-                    should_save_by_step or should_save_by_timeout
-                ):
+                if master_config["checkpointing"]["enabled"] and (
+                    should_save_by_step or should_save_by_timeout
+                ):
                     policy.prepare_for_training()
@@
                     with timer.time("checkpointing"):
                         print(f"Saving checkpoint for step {step + 1}...")
                         checkpoint_path = checkpointer.init_tmp_checkpoint(
                             step + 1, grpo_save_state, master_config
                         )
                         policy.save_checkpoint(
@@
                         checkpointer.finalize_checkpoint(checkpoint_path)
                     policy.offload_after_refit()
+                    if should_save_by_timeout:
+                        print("Timeout has been reached, stopping training early", flush=True)
+                        return
tests/unit/algorithms/test_dpo.py (2)

102-217: Fixture quality LGTM; optional lint nit

Comprehensive mocks and sharding annotations. If Ruff flags ARG001 for local iterator helpers, prefix the unused parameter with “_” to silence.

-    def train_iter(self):
+    def train_iter(_self):
         return iter([mock_batch] * 10)
@@
-    def val_iter(self):
+    def val_iter(_self):
         return iter([mock_batch] * 10)

268-317: Strengthen timeout continuation check (mirror distillation test fix)

Use substring detection for “ Step ” rather than startswith to catch the actual printed format.

-        for line in remaining_lines:
-            assert "Epoch" not in line or "Epoch 1/10" in line, (
-                f"Training continued to next epoch after timeout: {line}"
-            )
+        for line in remaining_lines:
+            # No new epochs after timeout
+            assert ("Epoch" not in line) or ("Epoch 1/10" in line), (
+                f"Training continued to next epoch after timeout: {line}"
+            )
+            # No further step markers after timeout
+            assert (" Step " not in line) or (" Step 8/" in line), (
+                f"Training continued after timeout: {line}"
+            )
📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 5c67023 and 5b73fef.

📒 Files selected for processing (10)
  • nemo_rl/algorithms/distillation.py (1 hunks)
  • nemo_rl/algorithms/dpo.py (1 hunks)
  • nemo_rl/algorithms/grpo.py (4 hunks)
  • nemo_rl/algorithms/rm.py (1 hunks)
  • nemo_rl/algorithms/sft.py (1 hunks)
  • tests/unit/algorithms/test_distillation.py (2 hunks)
  • tests/unit/algorithms/test_dpo.py (2 hunks)
  • tests/unit/algorithms/test_grpo.py (2 hunks)
  • tests/unit/algorithms/test_rm.py (2 hunks)
  • tests/unit/algorithms/test_sft.py (2 hunks)
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.py: Follow the Google Python Style Guide for all Python code
Target Python 3.12+ for all Python code in NeMo-RL
Indent Python code with 4 spaces; do not use tabs
Python filenames should be snake_case (e.g., some_file.py)
Class names should be PascalCase
Function and method names should be snake_case
Local variable names should be snake_case; if starting with a number, prefix with k (e.g., k_99th_percentile)
Global variables should be UPPER_SNAKE_CASE and prefixed with G_ (e.g., G_MY_GLOBAL)
Constants should be UPPER_SNAKE_CASE
Avoid shadowing variables declared in an outer scope
Initialize all externally visible members of a class in the constructor
For public interfaces used outside a file, prefer docstrings over comments
Use comments mainly for code within a function or interfaces local to a file
Commented-out code must include a nearby comment explaining usage and why it is commented out; otherwise remove before merging
Use Google-style docstrings for classes and functions (Sphinx-parseable)
Avoid using reflection when functionality can be easily achieved without it
Limit except clauses to the smallest specific set of exceptions possible
For duck-typing via try/except, keep the try body minimal and use else for main logic
Add the NVIDIA copyright header (with current year) at the top of all Python files, excluding tests/ and test-only scripts

Files:

  • tests/unit/algorithms/test_distillation.py
  • nemo_rl/algorithms/rm.py
  • tests/unit/algorithms/test_sft.py
  • nemo_rl/algorithms/dpo.py
  • nemo_rl/algorithms/sft.py
  • nemo_rl/algorithms/distillation.py
  • tests/unit/algorithms/test_grpo.py
  • tests/unit/algorithms/test_rm.py
  • nemo_rl/algorithms/grpo.py
  • tests/unit/algorithms/test_dpo.py
nemo_rl/**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

nemo_rl/**/*.py: Do not set non-None configuration defaults in code; YAML is the single source of truth for defaults
Access required config attributes directly (e.g., policy_cfg["precision"]) and assume presence; do not introduce hidden defaults
Express configuration optionality via TypedDict using typing.NotRequired
When adding a new config key to a TypedDict subclass, document the key’s purpose, valid values/types, and recommended default in code
For any class or function decorated with @ray.remote, add '# pragma: no cover' on the class/def line (and on remote functions)

Files:

  • nemo_rl/algorithms/rm.py
  • nemo_rl/algorithms/dpo.py
  • nemo_rl/algorithms/sft.py
  • nemo_rl/algorithms/distillation.py
  • nemo_rl/algorithms/grpo.py
🧬 Code graph analysis (6)
tests/unit/algorithms/test_distillation.py (2)
nemo_rl/algorithms/distillation.py (2)
  • _default_distillation_save_state (96-102)
  • distillation_train (468-849)
nemo_rl/utils/timer.py (1)
  • check_save (284-308)
tests/unit/algorithms/test_sft.py (3)
tests/unit/algorithms/test_distillation.py (2)
  • test_exit_on_timeout (215-269)
  • mock_components (34-186)
nemo_rl/algorithms/sft.py (2)
  • _default_sft_save_state (56-63)
  • sft_train (347-612)
nemo_rl/utils/timer.py (1)
  • check_save (284-308)
tests/unit/algorithms/test_grpo.py (3)
nemo_rl/algorithms/grpo.py (3)
  • _default_grpo_save_state (113-121)
  • async_grpo_train (1159-1839)
  • grpo_train (576-1038)
nemo_rl/algorithms/loss_functions.py (1)
  • ClippedPGLossFn (66-359)
nemo_rl/distributed/batched_data_dict.py (1)
  • BatchedDataDict (75-860)
tests/unit/algorithms/test_rm.py (2)
nemo_rl/algorithms/rm.py (2)
  • _default_rm_save_state (52-59)
  • rm_train (420-662)
nemo_rl/utils/timer.py (1)
  • check_save (284-308)
nemo_rl/algorithms/grpo.py (1)
nemo_rl/utils/timer.py (4)
  • TimeoutChecker (264-321)
  • start_iterations (310-311)
  • mark_iteration (313-321)
  • check_save (284-308)
tests/unit/algorithms/test_dpo.py (3)
nemo_rl/algorithms/dpo.py (3)
  • _default_dpo_save_state (51-58)
  • add_ref_logprobs_to_data (270-303)
  • dpo_train (486-741)
nemo_rl/algorithms/loss_functions.py (1)
  • PreferenceLoss (449-542)
nemo_rl/models/policy/dtensor_policy_worker.py (2)
  • train (529-895)
  • get_reference_policy_logprobs (1654-1669)
🪛 Ruff (0.14.0)
tests/unit/algorithms/test_grpo.py

380-380: Unused function argument: self

(ARG001)


388-388: Unused function argument: self

(ARG001)

tests/unit/algorithms/test_dpo.py

155-155: Unused function argument: self

(ARG001)


163-163: Unused function argument: self

(ARG001)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: Lint check
  • GitHub Check: Post submodule check comment / Comment on PR
  • GitHub Check: Post automodel integration comment / Comment on PR
🔇 Additional comments (9)
nemo_rl/algorithms/dpo.py (1)

731-737: Early-exit messaging and return look good

Consistent message, flush, and immediate return. Matches the timeout/max-steps semantics introduced elsewhere.

nemo_rl/algorithms/distillation.py (1)

841-849: Early exit on timeout/max-steps is correct

Clear message + flush and immediate return. Aligned with other training loops.

nemo_rl/algorithms/rm.py (1)

649-659: RM early-exit behavior LGTM

Timeout and max-steps exits now print and return immediately. Works with max_num_steps == -1 sentinel too.

tests/unit/algorithms/test_rm.py (1)

15-15: Import addition is appropriate

Needed for TimeoutChecker patching in new test.

tests/unit/algorithms/test_distillation.py (1)

15-15: Import addition is appropriate

Required for TimeoutChecker patch.

nemo_rl/algorithms/grpo.py (3)

1202-1208: Async path now tracks timeout — good

TimeoutChecker init + start covers async as well. Consistent with sync training.


1810-1818: Async early-exit messaging and return look good

Matches sync behavior; flush ensures visibility.


1159-1173: API change verified: max_trajectory_age_steps is passed at all async_grpo_train call sites or covered by the default.

tests/unit/algorithms/test_dpo.py (1)

15-28: Imports/extensions are appropriate

Brings in DPO helpers and PreferenceLoss needed for new tests.

Comment thread tests/unit/algorithms/test_grpo.py Outdated
parthchadha
parthchadha previously approved these changes Oct 15, 2025
yfw
yfw previously approved these changes Oct 15, 2025
Signed-off-by: Terry Kong <terryk@nvidia.com>
@terrykong terrykong dismissed stale reviews from yfw and parthchadha via 7c672f5 October 15, 2025 06:09
@terrykong terrykong added CI:L1 Run doctests, unit tests, and functional tests and removed CI:L1 Run doctests, unit tests, and functional tests labels Oct 15, 2025
@terrykong terrykong enabled auto-merge (squash) October 15, 2025 06:10
@terrykong terrykong merged commit 0a769cc into main Oct 15, 2025
40 of 41 checks passed
@terrykong terrykong deleted the tk/exit-early-grpo branch October 15, 2025 09:17
chtruong814 pushed a commit that referenced this pull request Oct 15, 2025
Signed-off-by: Terry Kong <terryk@nvidia.com>
Signed-off-by: NeMo Bot <nemo-bot@nvidia.com>
odelalleau pushed a commit to odelalleau/NeMo-RL that referenced this pull request Oct 21, 2025
Signed-off-by: Terry Kong <terryk@nvidia.com>
odelalleau pushed a commit to odelalleau/NeMo-RL that referenced this pull request Oct 21, 2025
fix: grpo early exit edge case (NVIDIA-NeMo#1361)

See merge request jiaqiz/nemo-rl-internal!3
terrykong added a commit that referenced this pull request Nov 2, 2025
Signed-off-by: Terry Kong <terryk@nvidia.com>
lbliii pushed a commit that referenced this pull request Nov 3, 2025
Signed-off-by: Terry Kong <terryk@nvidia.com>
Signed-off-by: Lawrence Lane <llane@nvidia.com>
PrinsYin pushed a commit to PrinsYin/RL that referenced this pull request Nov 30, 2025
Signed-off-by: Terry Kong <terryk@nvidia.com>
yuanhangsu1986 pushed a commit to yuanhangsu1986/RL-Nemontron-Edge-Omni that referenced this pull request Feb 21, 2026
Signed-off-by: Terry Kong <terryk@nvidia.com>
Signed-off-by: yuanhangs <yuanhangs@nvidia.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CI:L1 Run doctests, unit tests, and functional tests r0.4.0

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants