
feat: Support Ray Compiled Graph for SFT#1612

Open
katec846 wants to merge 14 commits into NVIDIA-NeMo:main from katec846:ray_overhead_exp

Conversation

Contributor

@katec846 katec846 commented Dec 9, 2025

What does this PR do ?

Adds Ray Compiled Graph support for SFT

TODO:

  • Rebase the code
  • Verify the results after rebasing (correctness and performance)
  • Verify it runs on dt_policy_worker without errors
  • Control RCG and warmup via runtime flags instead of environment variables
  • Remove per-worker computation time tracking
  • Add tests for Ray compiled graph

Ray Compiled Graph (RCG) reduces task orchestration overhead to sub-millisecond levels (<50 μs per call vs ~1 ms for standard Ray remote calls)

Key Features

Ray Compiled Graph Support

Implements Ray's compiled DAG execution for distributed training with support for:

  • Pipeline Parallelism (PP), Tensor Parallelism (TP), Context Parallelism (CP), and Data Parallelism (DP)
  • Wrapper architecture to work around Ray's ActorMethod signature hiding
  • Automatic DAG compilation and teardown
  • Multi-DP shard support with independent DAG execution per shard

Enable with:
++policy.ray_compiled_graph.enabled=True

Technical Changes

Architecture Changes

1. NeMoRayWorkerWrapper - Ray Actor Wrapper Pattern

Problem: Ray's @ray.remote decorator wraps all methods in ActorMethod objects with generic signatures, preventing Ray Compiled Graph from inspecting actual method parameters (Ray issue #26283).

Solution:

  • Introduced NeMoRayWorkerWrapper class that receives @ray.remote decoration instead of worker classes
  • Workers are instantiated inside the wrapper: self.worker = worker_class(*args, **kwargs)
  • Explicit train_compiled(train_input: dict) method with inspectable signature for Compiled Graph
  • Generic execute_method(method_name, *args, **kwargs) for standard Ray remote calls

Changes:

  • Removed @ray.remote decorator from all worker classes (MegatronPolicyWorker, DTensorPolicyWorker, DTensorPolicyWorkerV2)
  • RayWorkerBuilder now applies @ray.remote(NeMoRayWorkerWrapper) at runtime
  • All worker method calls changed to: worker.execute_method.remote(method_name, *args, **kwargs)
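
The wrapper pattern above can be sketched, Ray-free, as a minimal illustration. The class and method names follow the PR description, but the bodies are stand-ins (in the real code, `ray.remote(NeMoRayWorkerWrapper)` is applied at runtime by `RayWorkerBuilder`):

```python
class NeMoRayWorkerWrapper:
    """Sketch of the actor wrapper; decorated with ray.remote at runtime."""

    def __init__(self, worker_class, *args, **kwargs):
        # The real worker is constructed inside the actor process.
        self.worker = worker_class(*args, **kwargs)

    def train_compiled(self, train_input: dict):
        # Explicit, inspectable signature that Ray Compiled Graph can bind to.
        return self.worker.train(**train_input)

    def execute_method(self, method_name, *args, **kwargs):
        # Generic dispatch used for standard (non-compiled) remote calls.
        return getattr(self.worker, method_name)(*args, **kwargs)


class _ToyWorker:
    """Hypothetical worker used only to exercise the wrapper."""

    def __init__(self, scale):
        self.scale = scale

    def train(self, data):
        return {"loss": sum(data) * self.scale}

    def get_pid(self):
        return 1234


wrapper = NeMoRayWorkerWrapper(_ToyWorker, scale=2)
print(wrapper.train_compiled({"data": [1, 2, 3]}))  # {'loss': 12}
print(wrapper.execute_method("get_pid"))            # 1234
```

Because only the wrapper carries the `@ray.remote` decoration, the underlying worker's method signatures stay inspectable, which is what the compiled graph needs.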

2. Ray Compiled Graph Infrastructure (ray_compiled_graph.py)

New Classes:

CompiledGraphExecutor

  • Manages DAG execution for a single DP shard
  • Organizes workers into PP×TP×CP topology
  • Builds and compiles static DAG with all pipeline stages as nodes
  • Note: Megatron handles inter-stage PP communication; Ray orchestrates execution timing

MultiDPCompiledGraphExecutor

  • Manages multiple independent DAGs (one per DP shard)
  • Each DP shard gets its own compiled DAG with different data
  • Coordinates parallel execution across DP replicas
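
The "each DP shard gets its own DAG with different data" idea boils down to slicing the global batch per shard. A minimal, hypothetical helper (not the PR's actual code):

```python
def split_across_dp(batch, dp_size):
    """Split a list-like batch into dp_size contiguous shards, one per DP replica."""
    per_shard = len(batch) // dp_size
    return [batch[i * per_shard:(i + 1) * per_shard] for i in range(dp_size)]


shards = split_across_dp(list(range(8)), dp_size=2)
print(shards)  # [[0, 1, 2, 3], [4, 5, 6, 7]]
```

Each shard's slice would then be fed to that shard's independently compiled DAG, so the DAGs run in parallel on disjoint data.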

CompiledGraphWorkerGroup

  • Drop-in replacement wrapper for RayWorkerGroup
  • Automatically uses compiled graphs when enabled via config
  • Falls back to standard Ray execution when disabled or on compilation errors
  • Handles DAG compilation, per-method caching, and resource cleanup

Helper Functions:

  • get_compiled_graph_config(): Extracts RCG configuration from policy config
  • should_use_compiled_graph(): Determines if RCG should be enabled
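
An illustrative reimplementation of what these two helpers likely do, assuming the config shape shown in the Configuration Support section below; the actual functions live in nemo_rl/distributed/ray_compiled_graph.py and may differ:

```python
def get_compiled_graph_config(policy_config: dict) -> dict:
    # Extract the ray_compiled_graph sub-config; empty dict if absent.
    return policy_config.get("ray_compiled_graph") or {}


def should_use_compiled_graph(policy_config: dict) -> bool:
    # RCG is opt-in: only enabled when the flag is explicitly true.
    return bool(get_compiled_graph_config(policy_config).get("enabled", False))


cfg = {"ray_compiled_graph": {"enabled": True, "warmup_seq_len": 8192}}
print(should_use_compiled_graph(cfg))  # True
print(should_use_compiled_graph({}))   # False
```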

3. Configuration Support

New RayCompiledGraphConfig in PolicyConfig:
ray_compiled_graph:
  enabled: true                 # Enable/disable RCG
  warmup_seq_len: 8192          # Optional: warmup sequence length
  warmup_gbs: 128               # Optional: warmup global batch size
  overlap_communication: false  # Experimental: overlap GPU compute and communication

4. Automatic Warmup

_warmup_compiled_graph() in sft.py:

  • Automatically runs before training when RCG is enabled
  • Creates fake data with maximum sequence length to pre-compile the DAG
  • Prevents recompilation during actual training
  • Configurable warmup parameters via warmup_seq_len and warmup_gbs
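
The warmup idea can be sketched as building one fake batch at the maximum sequence length so the DAG compiles once before real training. Key names here (input_ids, input_lengths) are assumptions for illustration, not necessarily the fields _warmup_compiled_graph constructs:

```python
def build_warmup_batch(warmup_seq_len=8192, warmup_gbs=128, pad_token_id=0):
    """Build a synthetic batch at max shape so the compiled DAG never
    encounters a larger input (and recompiles) during real training."""
    return {
        "input_ids": [[pad_token_id] * warmup_seq_len for _ in range(warmup_gbs)],
        "input_lengths": [warmup_seq_len] * warmup_gbs,
    }


batch = build_warmup_batch(warmup_seq_len=16, warmup_gbs=2)
print(len(batch["input_ids"]), len(batch["input_ids"][0]))  # 2 16
```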

5. Policy Integration

lm_policy.py changes:

  • Wraps RayWorkerGroup with CompiledGraphWorkerGroup when RCG is enabled
  • Transparent to existing code - same interface as RayWorkerGroup
  • Graceful shutdown handling for compiled DAG resources

Additional Improvements

  • Improved Shutdown: Graceful compiled DAG teardown with suppressed Ray logging


Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you run the unit tests and functional tests locally? Visit our Testing Guide for how to run tests
  • Did you add or update any necessary documentation? Visit our Document Development Guide for how to write, build and test the docs.


Summary by CodeRabbit

Release Notes

  • New Features

    • Ray Compiled Graph support for distributed training with configurable warmup and communication settings.
    • New ray_compiled_graph configuration option to control optimization behavior.
  • Improvements

    • Automatic compiled graph warmup before training initialization when the feature is enabled.


@github-actions

ℹ️ File Consistency Check

Check based on commit: 3d4324c (PR #1612 from ray_overhead_exp)

✅ DTensor Policy Worker Synchronization Check

Both DTensor policy worker files were modified in this PR:

  • nemo_rl/models/policy/workers/dtensor_policy_worker.py
  • nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py

Please ensure that the changes are consistent between both files where applicable.


This check ensures that related file implementations remain synchronized across the codebase. If you believe this warning is incorrect or the files should intentionally differ, please add a comment explaining the reasoning.


@katec846 katec846 marked this pull request as ready for review December 12, 2025 20:03
@katec846 katec846 requested review from a team as code owners December 12, 2025 20:03
@coderabbitai
Contributor

coderabbitai Bot commented Dec 12, 2025

📝 Walkthrough

This PR implements Ray Compiled Graph (RCG) support for distributed training. It introduces a new compiled graph infrastructure module, wraps workers with a uniform method invocation interface, removes @ray.remote decorators from policy worker classes for runtime application, and integrates optional RCG warmup into the SFT training pipeline.

Changes

Cohort / File(s) Summary
Ray Compiled Graph Infrastructure
nemo_rl/distributed/ray_compiled_graph.py
New module implementing RCG support with CompiledGraphExecutor (single DP), MultiDPCompiledGraphExecutor (multi-DP coordination), CompiledGraphWorkerGroup (high-level wrapper around RayWorkerGroup), and helper functions (should_use_compiled_graph, get_compiled_graph_config) to manage DAG compilation, method dispatch, and fallback to standard Ray execution.
Worker Wrapper Abstraction
nemo_rl/distributed/worker_groups.py
Introduces NeMoRayWorkerWrapper to encapsulate workers with train_compiled and execute_method, enabling uniform RPC dispatch. Refactors all worker creation and method invocations (run_single_worker_single_data, run_all_workers_*) to route through the wrapper's execute_method.remote.
Policy Worker Decorator Removal
nemo_rl/models/policy/workers/dtensor_policy_worker.py, nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py, nemo_rl/models/policy/workers/megatron_policy_worker.py
Removes @ray.remote class decorators and adds class attribute _default_options (or _ray_remote_runtime_env) to store runtime environment for runtime-applied remote decoration. Adds worker_start_time timing instrumentation in train methods.
SFT Training Warmup Integration
nemo_rl/algorithms/sft.py
Introduces _warmup_compiled_graph helper that constructs synthetic inputs, runs a single training step with eval_mode=True to trigger RCG compilation, and logs results. Integrates warmup invocation at training start when RCG is enabled.
Configuration Extensions
nemo_rl/models/policy/__init__.py
Adds RayCompiledGraphConfig TypedDict with enabled, warmup_seq_len, warmup_gbs, overlap_communication fields; extends PolicyConfig with ray_compiled_graph field.
Policy Integration
nemo_rl/models/policy/lm_policy.py
Wraps RayWorkerGroup with CompiledGraphWorkerGroup when RCG is enabled. Adds try/except guarding in __del__ for worker group shutdown. Logs enablement status at initialization.
Environment Configuration
ray.sub
Sets RAY_LOG_SYNC_FREQUENCY default to 30 seconds and exports it in head/worker command blocks.
Test Updates
tests/unit/distributed/test_worker_groups.py, tests/unit/models/policy/test_dtensor_worker.py, tests/unit/models/policy/test_dtensor_worker_v2.py
Replaces direct remote method calls (e.g., get_pid.remote, return_state_dict.remote) with generic execute_method.remote("<method_name>", ...) invocations to align with wrapper-based architecture.

Sequence Diagram

sequenceDiagram
    participant SFT as SFT Trainer
    participant RCGWrap as CompiledGraphWorkerGroup
    participant CGExec as CompiledGraphExecutor
    participant DAG as Ray Compiled DAG
    participant Worker as Policy Worker

    SFT->>RCGWrap: Initialize with config (RCG enabled)
    RCGWrap->>RCGWrap: Wrap RayWorkerGroup

    Note over SFT,DAG: Warmup Phase (total_steps == 0)
    SFT->>RCGWrap: _warmup_compiled_graph(synthetic_data)
    RCGWrap->>CGExec: Create executor for 'train' method
    CGExec->>DAG: Build & compile DAG with workers
    DAG->>Worker: Execute train_compiled(synthetic_batch)
    Worker-->>DAG: Return training results
    DAG-->>CGExec: Results via compiled kernel
    CGExec-->>RCGWrap: Warmup complete, DAG ready

    Note over SFT,DAG: Training Phase
    SFT->>RCGWrap: run_all_workers_sharded_data(train, real_data)
    RCGWrap->>CGExec: Use compiled executor (cached)
    CGExec->>DAG: Execute train_compiled(real_batch)
    DAG->>Worker: Run via compiled kernel
    Worker-->>DAG: Training outputs
    DAG-->>CGExec: Fused result
    CGExec-->>RCGWrap: Return to caller
    RCGWrap-->>SFT: Training step complete

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Areas requiring extra attention:

  • nemo_rl/distributed/ray_compiled_graph.py: New 400+ line module with dense logic for DAG construction, DP shard alignment, per-method executor management, and fallback handling. Requires careful validation of compilation triggers, error paths, and interaction with Ray runtime.
  • nemo_rl/distributed/worker_groups.py: Pervasive refactoring affecting all worker method invocations via the wrapper. Cross-verify that sharded_data slicing, common_kwargs merging, and return wrapping maintain correctness across run_* variants.
  • @ray.remote decorator removal across worker classes: Verify that runtime application of remote decoration via RayWorkerBuilder produces identical behavior to decorator-time application, especially for runtime_env propagation and actor initialization.
  • SFT warmup integration: Ensure synthetic data construction (max seq length, batch shape) is representative and that warmup exceptions don't break training initialization.
  • Test coverage: Confirm that test updates via execute_method properly invoke the new wrapper and that all method names resolve correctly.

Possibly related PRs

Suggested labels

CI:L1, CI

Suggested reviewers

  • terrykong
  • yaoyu-33
  • parthchadha

Pre-merge checks and finishing touches

❌ Failed checks (2 warnings)
Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 70.49% which is insufficient. The required threshold is 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.
Test Results For Major Changes ⚠️ Warning PR implements Ray Compiled Graph support with major architectural changes but incomplete test coverage and missing performance validation for sub-millisecond orchestration improvements. Add comprehensive test coverage including unit tests, integration tests, performance benchmarks showing before/after orchestration overhead, convergence validation, and test result documentation.
✅ Passed checks (2 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The PR title 'feat: Support Ray Compiled Graph for SFT' accurately captures the main feature addition of Ray Compiled Graph support for Supervised Fine-Tuning, matching the primary objective of reducing Ray task orchestration overhead.


Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 9

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
nemo_rl/models/policy/workers/dtensor_policy_worker.py (1)

510-523: Remove unused timing code (or plumb it into metrics/logging).

worker_start_time is currently unused, so this is dead code and will fail lint (F841).

@@
     def train(
@@
     ) -> dict[str, Any]:
         """Train the policy on a batch of data with a given loss function."""
-        import time
-
-        worker_start_time = time.time()
-
         if gbs is None:
             gbs = self.cfg["train_global_batch_size"]
nemo_rl/distributed/worker_groups.py (1)

886-1018: run_all_workers_sharded_data() only shards the first axis in in_sharded_axes due to early break — limits to single-axis despite multi-axis wording in docstring.
The method accepts a list of axes but exits the loop after processing the first match (line ~956), silently ignoring subsequent axes. All current usage is single-axis only. Either assert single-axis at the start to prevent misconfiguration, or implement proper nested indexing if multi-axis is required.

Example fix to enforce single-axis:

     def run_all_workers_sharded_data(
@@
         if in_sharded_axes is None:
             in_sharded_axes = []
+        if len(in_sharded_axes) > 1:
+            raise ValueError(
+                f"Only single-axis sharding is supported (got {in_sharded_axes})."
+            )
🧹 Nitpick comments (6)
ray.sub (1)

319-321: Export on head too (consistency) and quote the export.
You export RAY_LOG_SYNC_FREQUENCY in worker containers, but the head command relies on outer interpolation; exporting in both places keeps behavior consistent and avoids surprises if the heredoc content changes.

-export RAY_LOG_SYNC_FREQUENCY=$RAY_LOG_SYNC_FREQUENCY
+export RAY_LOG_SYNC_FREQUENCY="${RAY_LOG_SYNC_FREQUENCY}"
nemo_rl/models/policy/workers/megatron_policy_worker.py (1)

422-429: Annotate _default_options as ClassVar (and consider preventing mutation).
This matches Ruff RUF012 and avoids accidental per-class shared mutation surprises.

-from typing import Any, Iterator, Optional, TypeVar, cast
+from typing import Any, ClassVar, Iterator, Optional, TypeVar, cast
@@
 class MegatronPolicyWorker(AbstractPolicyWorker, ColocatablePolicyInterface):
     # Default options to use when applying ray.remote() at runtime
-    _default_options = {
+    _default_options: ClassVar[dict[str, Any]] = {
         "runtime_env": get_runtime_env_for_policy_worker("megatron_policy_worker")
     }
nemo_rl/models/policy/workers/dtensor_policy_worker.py (1)

134-141: Annotate _default_options as ClassVar (and keep it immutable-ish).

Ruff is right here: make the class attribute a ClassVar to avoid accidental instance-level expectations and typing noise.

@@
-from typing import Any, Generator, Iterable, Optional, Set, Union, cast
+from typing import Any, ClassVar, Generator, Iterable, Optional, Set, Union, cast
@@
 class DTensorPolicyWorker(AbstractPolicyWorker, ColocatablePolicyInterface):
     # Default options to use when applying ray.remote() at runtime
-    _default_options = {
+    _default_options: ClassVar[dict[str, Any]] = {
         "runtime_env": get_runtime_env_for_policy_worker("dtensor_policy_worker")
     }
nemo_rl/models/policy/lm_policy.py (1)

853-859: Avoid fully silent __del__ failures (optional).

I get why you’re suppressing teardown errors during interpreter exit, but a totally silent except Exception: pass can mask real cleanup regressions. Consider a minimal opt-in debug log (e.g., behind an env var) instead of unconditional silence.

nemo_rl/distributed/worker_groups.py (1)

141-158: Ray class “unwrapping” uses private attributes — verify against Ray 2.49.2.
_ray_actor_class / __ray_metadata__ feel fragile across Ray versions; if they’re wrong, wrapper will instantiate the wrong thing (or fail) in hard-to-debug ways.

  • Please verify on Ray 2.49.2 that a remotely-decorated class actually exposes _ray_actor_class and/or __ray_metadata__.modified_class in the forms expected here, and add a small unit test that passes a @ray.remote class into NeMoRayWorkerWrapper (even if production no longer does).
nemo_rl/distributed/ray_compiled_graph.py (1)

81-123: A few low-friction cleanups to align with repo style/lint.
Mostly Ruff-driven / maintainability.

  • Prefer logger.exception(...) inside except Exception blocks where you re-raise or want stack traces (TRY400).
  • Replace the Unicode × in docstrings/comments with x (RUF002/RUF003) if Ruff is enforced.
  • Consider narrowing teardown except Exception to Ray-specific teardown errors if feasible (BLE001); if not, add a short comment why broad is required here.
  • If you’re on Python 3.12+ and Ruff enforces it, add strict= to zip(...) where applicable (B905).

Also applies to: 227-237, 260-293, 368-402, 501-507, 622-635

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between e3cfb11 and d60c225.

📒 Files selected for processing (12)
  • nemo_rl/algorithms/sft.py (2 hunks)
  • nemo_rl/distributed/ray_compiled_graph.py (1 hunks)
  • nemo_rl/distributed/worker_groups.py (8 hunks)
  • nemo_rl/models/policy/__init__.py (2 hunks)
  • nemo_rl/models/policy/lm_policy.py (5 hunks)
  • nemo_rl/models/policy/workers/dtensor_policy_worker.py (2 hunks)
  • nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py (1 hunks)
  • nemo_rl/models/policy/workers/megatron_policy_worker.py (2 hunks)
  • ray.sub (2 hunks)
  • tests/unit/distributed/test_worker_groups.py (30 hunks)
  • tests/unit/models/policy/test_dtensor_worker.py (1 hunks)
  • tests/unit/models/policy/test_dtensor_worker_v2.py (2 hunks)
🧰 Additional context used
📓 Path-based instructions (4)
**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.py: Conform code to Python 3.12+
Indent code with 4 spaces. Do not use tabs
Use snake_case for file names
Use PascalCase for class names
Use snake_case for function and method names
Use snake_case for local variables
Prefix variable names that start with a number with 'k' (e.g., k_99th_percentile)
Use upper snake_case with 'G' prefix for global variables (e.g., G_MY_GLOBAL)
Use upper snake_case for constants
Avoid shadowing variables declared in an outer scope
Initialize all externally visible members of a class in the constructor
Prefer docstrings over comments for interfaces that may be used outside a file
Reserve comments for code within a function or interfaces that are local to a file
If a piece of code is commented out, include a comment describing its usage and why it's commented out. Remove debug comments before merging
Use Google style docstrings for classes and functions in Python, which can be parsed by Sphinx
Avoid using reflection when functionality can be easily achieved without reflection
When using try-except blocks, limit the except clause to the smallest set of specific errors possible
When using try-except blocks for duck-typing, keep the body of the try as small as possible and use the else block for logic
YAML is the single source of truth for configuration defaults. Do not set non-None defaults in code for configuration values
For required configuration attributes, access config directly and expect presence (e.g., policy_cfg['precision']) without hidden defaults
Use typing.NotRequired to mark optional attributes in TypedDict for configuration
When adding a new config key to a TypedDict subclass, document the key's purpose, valid values/types, and recommended default, and reflect the default in exemplar YAMLs under examples/configs/*.yaml
Follow the Google Python Style Guide for Python code

Files:

  • nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py
  • nemo_rl/models/policy/__init__.py
  • nemo_rl/algorithms/sft.py
  • tests/unit/distributed/test_worker_groups.py
  • tests/unit/models/policy/test_dtensor_worker.py
  • nemo_rl/models/policy/workers/dtensor_policy_worker.py
  • nemo_rl/distributed/ray_compiled_graph.py
  • nemo_rl/models/policy/workers/megatron_policy_worker.py
  • nemo_rl/models/policy/lm_policy.py
  • tests/unit/models/policy/test_dtensor_worker_v2.py
  • nemo_rl/distributed/worker_groups.py
nemo_rl/**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

For any source file under nemo_rl/*.py that defines a class or function decorated with @ray.remote, add a coverage pragma (# pragma: no cover) because these run in separate Ray processes

Files:

  • nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py
  • nemo_rl/models/policy/__init__.py
  • nemo_rl/algorithms/sft.py
  • nemo_rl/models/policy/workers/dtensor_policy_worker.py
  • nemo_rl/distributed/ray_compiled_graph.py
  • nemo_rl/models/policy/workers/megatron_policy_worker.py
  • nemo_rl/models/policy/lm_policy.py
  • nemo_rl/distributed/worker_groups.py
!(**/tests/**|**/test_*.py|**/test_*.sh)

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Add the NVIDIA copyright header to all Python files and shell scripts (excluding tests). The header should include the current year

Files:

  • nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py
  • nemo_rl/models/policy/__init__.py
  • ray.sub
  • nemo_rl/algorithms/sft.py
  • tests/unit/distributed/test_worker_groups.py
  • tests/unit/models/policy/test_dtensor_worker.py
  • nemo_rl/models/policy/workers/dtensor_policy_worker.py
  • nemo_rl/distributed/ray_compiled_graph.py
  • nemo_rl/models/policy/workers/megatron_policy_worker.py
  • nemo_rl/models/policy/lm_policy.py
  • tests/unit/models/policy/test_dtensor_worker_v2.py
  • nemo_rl/distributed/worker_groups.py
**/*.{py,sh}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

The NVIDIA copyright header should appear at the top of all Python files and shell scripts (excluding tests)

Files:

  • nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py
  • nemo_rl/models/policy/__init__.py
  • nemo_rl/algorithms/sft.py
  • tests/unit/distributed/test_worker_groups.py
  • tests/unit/models/policy/test_dtensor_worker.py
  • nemo_rl/models/policy/workers/dtensor_policy_worker.py
  • nemo_rl/distributed/ray_compiled_graph.py
  • nemo_rl/models/policy/workers/megatron_policy_worker.py
  • nemo_rl/models/policy/lm_policy.py
  • tests/unit/models/policy/test_dtensor_worker_v2.py
  • nemo_rl/distributed/worker_groups.py
🧠 Learnings (4)
📚 Learning: 2025-11-24T17:24:41.976Z
Learnt from: CR
Repo: NVIDIA-NeMo/RL PR: 0
File: CODING_GUIDELINES.md:0-0
Timestamp: 2025-11-24T17:24:41.976Z
Learning: Applies to nemo_rl/**/*.py : For any source file under nemo_rl/*.py that defines a class or function decorated with ray.remote, add a coverage pragma (# pragma: no cover) because these run in separate Ray processes

Applied to files:

  • nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py
  • nemo_rl/distributed/ray_compiled_graph.py
  • nemo_rl/models/policy/workers/megatron_policy_worker.py
  • nemo_rl/distributed/worker_groups.py
📚 Learning: 2025-11-24T17:24:41.976Z
Learnt from: CR
Repo: NVIDIA-NeMo/RL PR: 0
File: CODING_GUIDELINES.md:0-0
Timestamp: 2025-11-24T17:24:41.976Z
Learning: Applies to **/*.py : When adding a new config key to a TypedDict subclass, document the key's purpose, valid values/types, and recommended default, and reflect the default in exemplar YAMLs under examples/configs/*.yaml

Applied to files:

  • nemo_rl/models/policy/__init__.py
📚 Learning: 2025-09-19T02:44:38.451Z
Learnt from: shuo-nvidia
Repo: NVIDIA-NeMo/RL PR: 1006
File: examples/configs/recipes/llm/distillation-qwen3-32b-to-1.7b-base-1n8g-fsdp2tp1.v1.yaml:73-84
Timestamp: 2025-09-19T02:44:38.451Z
Learning: The scheduler configuration format with a separate "milestones: [20]" entry (not wrapped under name/kwargs) is a valid and established pattern used across GRPO, DPO, and distillation configs in the NeMo RL codebase. This format specifies transition points between different schedulers (e.g., LinearLR for warmup steps, then ConstantLR).

Applied to files:

  • nemo_rl/models/policy/__init__.py
📚 Learning: 2025-11-06T22:30:22.860Z
Learnt from: ZhiyuLi-Nvidia
Repo: NVIDIA-NeMo/RL PR: 1477
File: nemo_rl/models/generation/vllm/vllm_backend.py:163-168
Timestamp: 2025-11-06T22:30:22.860Z
Learning: For Ray actor methods in the vLLM generation worker code (vllm_backend.py, vllm_worker.py, vllm_worker_async.py), error handling should use print/traceback + return False pattern rather than raising exceptions, following the Ray RPC practice where exceptions may not propagate well across process boundaries.

Applied to files:

  • tests/unit/distributed/test_worker_groups.py
  • nemo_rl/distributed/worker_groups.py
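The print/traceback + return-False pattern described in this learning can be sketched as a small decorator; the decorator and function names below are hypothetical, not the repo's actual helpers.

```python
import traceback


def safe_actor_method(fn):
    """Sketch of the pattern for Ray actor methods: log the traceback and
    return False instead of raising, since exceptions may not propagate
    cleanly across Ray process boundaries."""

    def wrapper(*args, **kwargs):
        try:
            return fn(*args, **kwargs)
        except Exception:
            print(f"{fn.__name__} failed:")
            traceback.print_exc()
            return False

    return wrapper


@safe_actor_method
def flaky_update(weights):
    # Stand-in for an actor method such as a weight update.
    if weights is None:
        raise ValueError("missing weights")
    return True
```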
🧬 Code graph analysis (9)
nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py (1)
nemo_rl/models/policy/utils.py (1)
  • get_runtime_env_for_policy_worker (286-296)
nemo_rl/algorithms/sft.py (3)
nemo_rl/models/policy/interfaces.py (2)
  • PolicyInterface (50-157)
  • train (100-118)
nemo_rl/distributed/batched_data_dict.py (2)
  • BatchedDataDict (75-860)
  • size (814-823)
nemo_rl/distributed/ray_compiled_graph.py (1)
  • get_compiled_graph_config (418-442)
tests/unit/distributed/test_worker_groups.py (1)
nemo_rl/distributed/worker_groups.py (3)
  • execute_method (186-201)
  • workers (699-700)
  • run_all_workers_single_data (838-884)
tests/unit/models/policy/test_dtensor_worker.py (1)
nemo_rl/distributed/worker_groups.py (2)
  • workers (699-700)
  • execute_method (186-201)
nemo_rl/models/policy/workers/dtensor_policy_worker.py (4)
nemo_rl/models/policy/workers/base_policy_worker.py (1)
  • AbstractPolicyWorker (25-154)
nemo_rl/models/policy/interfaces.py (1)
  • ColocatablePolicyInterface (160-193)
nemo_rl/models/policy/utils.py (1)
  • get_runtime_env_for_policy_worker (286-296)
nemo_rl/utils/timer.py (1)
  • time (110-123)
nemo_rl/distributed/ray_compiled_graph.py (2)
nemo_rl/distributed/named_sharding.py (4)
  • NamedSharding (19-222)
  • get_axis_size (209-211)
  • get_ranks (155-199)
  • get_worker_coords (103-122)
nemo_rl/distributed/worker_groups.py (1)
  • train_compiled (161-184)
nemo_rl/models/policy/workers/megatron_policy_worker.py (4)
nemo_rl/models/policy/workers/base_policy_worker.py (1)
  • AbstractPolicyWorker (25-154)
nemo_rl/models/policy/interfaces.py (1)
  • ColocatablePolicyInterface (160-193)
nemo_rl/models/policy/utils.py (1)
  • get_runtime_env_for_policy_worker (286-296)
nemo_rl/utils/timer.py (1)
  • time (110-123)
tests/unit/models/policy/test_dtensor_worker_v2.py (1)
nemo_rl/distributed/worker_groups.py (1)
  • execute_method (186-201)
nemo_rl/distributed/worker_groups.py (2)
nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py (1)
  • train (490-866)
nemo_rl/models/policy/interfaces.py (1)
  • train (100-118)
🪛 Ruff (0.14.8)
nemo_rl/algorithms/sft.py

476-476: Do not catch blind exception: Exception

(BLE001)


479-479: String contains ambiguous ℹ (INFORMATION SOURCE). Did you mean i (LATIN SMALL LETTER I)?

(RUF001)

nemo_rl/models/policy/workers/dtensor_policy_worker.py

138-140: Mutable class attributes should be annotated with typing.ClassVar

(RUF012)


522-522: Local variable worker_start_time is assigned to but never used

Remove assignment to unused variable worker_start_time

(F841)

nemo_rl/distributed/ray_compiled_graph.py

117-121: Use logging.exception instead of logging.error

Replace with exception

(TRY400)


131-131: Docstring contains ambiguous × (MULTIPLICATION SIGN). Did you mean x (LATIN SMALL LETTER X)?

(RUF002)


143-143: Comment contains ambiguous × (MULTIPLICATION SIGN). Did you mean x (LATIN SMALL LETTER X)?

(RUF003)


160-163: Prefer TypeError exception for invalid type

(TRY004)


160-163: Avoid specifying long messages outside the exception class

(TRY003)


200-200: Comment contains ambiguous × (MULTIPLICATION SIGN). Did you mean x (LATIN SMALL LETTER X)?

(RUF003)


203-203: Loop control variable idx not used within loop body

Rename unused idx to _idx

(B007)


284-284: Do not catch blind exception: Exception

(BLE001)


291-291: zip() without an explicit strict= parameter

Add explicit value for parameter strict=

(B905)


393-393: Do not catch blind exception: Exception

(BLE001)


400-400: zip() without an explicit strict= parameter

Add explicit value for parameter strict=

(B905)


501-501: Do not catch blind exception: Exception

(BLE001)


502-504: Use logging.exception instead of logging.error

Replace with exception

(TRY400)


528-528: Docstring contains ambiguous × (MULTIPLICATION SIGN). Did you mean x (LATIN SMALL LETTER X)?

(RUF002)


528-528: Docstring contains ambiguous × (MULTIPLICATION SIGN). Did you mean x (LATIN SMALL LETTER X)?

(RUF002)


609-609: Comment contains ambiguous × (MULTIPLICATION SIGN). Did you mean x (LATIN SMALL LETTER X)?

(RUF003)


609-609: Comment contains ambiguous × (MULTIPLICATION SIGN). Did you mean x (LATIN SMALL LETTER X)?

(RUF003)


622-622: Do not catch blind exception: Exception

(BLE001)


623-623: Use logging.exception instead of logging.error

Replace with exception

(TRY400)


708-710: Avoid specifying long messages outside the exception class

(TRY003)


720-720: Avoid specifying long messages outside the exception class

(TRY003)


747-747: Loop control variable dp_rank not used within loop body

(B007)


824-824: Do not catch blind exception: Exception

(BLE001)


832-832: zip() without an explicit strict= parameter

Add explicit value for parameter strict=

(B905)

nemo_rl/models/policy/workers/megatron_policy_worker.py

426-428: Mutable class attributes should be annotated with typing.ClassVar

(RUF012)


914-914: Local variable worker_start_time is assigned to but never used

Remove assignment to unused variable worker_start_time

(F841)

nemo_rl/models/policy/lm_policy.py

856-859: try-except-pass detected, consider logging the exception

(S110)


856-856: Do not catch blind exception: Exception

(BLE001)

🔇 Additional comments (9)
tests/unit/models/policy/test_dtensor_worker_v2.py (1)

205-207: LGTM: test now exercises the wrapper-based execute_method path.
This matches the new worker invocation model and should better reflect production calling patterns.

Also applies to: 229-231

tests/unit/models/policy/test_dtensor_worker.py (1)

604-606: LGTM: state_dict retrieval now goes through execute_method, consistent with wrapper dispatch.

tests/unit/distributed/test_worker_groups.py (1)

276-285: Execute-method migration looks correct; consider guarding against string-typo drift.

The switch to execute_method.remote("<method>") matches the new wrapper-based RPC and keeps tests aligned with RayWorkerGroup behavior. Only concern is the stringly-typed method names—if you want extra safety, consider a tiny helper/constants for frequently used method names in this test module.

Also applies to: 338-360, 387-406, 483-488, 514-567, 569-640, 968-1014, 1228-1252

nemo_rl/models/policy/__init__.py (1)

174-190: Config typing is good; please add YAML exemplar + document recommended defaults.

You added PolicyConfig.ray_compiled_graph and RayCompiledGraphConfig, but I don’t see corresponding exemplar YAML updates in this PR context—per guidelines, please reflect the new key + recommended defaults under examples/configs/*.yaml. Based on learnings, when adding a new config key to a TypedDict, document it and reflect it in exemplar YAMLs.

Also applies to: 192-224

nemo_rl/algorithms/sft.py (1)

542-560: Warmup trigger placement is fine; please verify get_compiled_graph_config defaults vs “YAML is source of truth”.

The integration point (enabled + total_steps == 0) is sensible. One thing to double-check: if get_compiled_graph_config(...) injects defaults in code, make sure that doesn’t conflict with the repo guideline that YAML is the single source of truth for config defaults.

nemo_rl/models/policy/lm_policy.py (1)

33-37: Worker-group wrapping is clean; please verify compiled-graph gating against Ray 2.49.2.

The conditional wrap into CompiledGraphWorkerGroup keeps the rest of Policy code unchanged and is a good integration point. Please sanity-check should_use_compiled_graph/get_compiled_graph_config behavior against the Ray 2.49.2 compiled graph API and your rollout/training call patterns.

Also applies to: 186-219

nemo_rl/distributed/worker_groups.py (2)

284-290: Wrapper-based execute_method.remote(...) plumbing looks consistent.
The call-site conversions appear coherent and should keep worker invocation uniform for both standard Ray and RCG paths.

Also applies to: 734-737, 819-829, 878-883, 999-1013


141-160: Harden train_compiled() input contract (clear error + less KeyError pain in compiled graphs).
Today a missing "data"/"loss_fn" will surface as a remote KeyError with little context.

 class NeMoRayWorkerWrapper:
@@
     def train_compiled(
         self,
         train_input: dict[str, Any],
     ) -> dict[str, Any]:
@@
-        result = self.worker.train(
-            data=train_input["data"],
-            loss_fn=train_input["loss_fn"],
+        try:
+            data = train_input["data"]
+            loss_fn = train_input["loss_fn"]
+        except KeyError as e:
+            raise KeyError(
+                "train_input must include keys: 'data' and 'loss_fn' "
+                f"(got keys: {list(train_input.keys())})"
+            ) from e
+
+        result = self.worker.train(
+            data=data,
+            loss_fn=loss_fn,
             eval_mode=train_input.get("eval_mode", False),
             gbs=train_input.get("gbs"),
             mbs=train_input.get("mbs"),
         )
         return result

Also applies to: 161-185
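The wrapper pattern under discussion (worker constructed inside a plain class, generic dispatch by method name, plus an explicit `train_compiled` signature that a compiled graph can bind) can be sketched without Ray. This is a minimal illustration with a hypothetical dummy worker, not the repo's `NeMoRayWorkerWrapper`:

```python
class WorkerWrapper:
    """Minimal sketch: this class is what would receive @ray.remote, so the
    wrapped worker's methods keep inspectable signatures."""

    def __init__(self, worker_class, *args, **kwargs):
        # The real worker is built inside the (would-be) actor process.
        self.worker = worker_class(*args, **kwargs)

    def execute_method(self, method_name: str, *args, **kwargs):
        # Generic dispatch used for standard Ray remote calls.
        return getattr(self.worker, method_name)(*args, **kwargs)

    def train_compiled(self, train_input: dict) -> dict:
        # Explicit, inspectable signature for the compiled graph; validate
        # required keys up front so a missing key gives a clear error.
        missing = {"data", "loss_fn"} - train_input.keys()
        if missing:
            raise KeyError(f"train_input missing keys: {sorted(missing)}")
        return self.worker.train(
            data=train_input["data"], loss_fn=train_input["loss_fn"]
        )


class DummyWorker:
    def train(self, data, loss_fn):
        return {"loss": loss_fn(data)}
```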

⛔ Skipped due to learnings
Learnt from: CR
Repo: NVIDIA-NeMo/RL PR: 0
File: CODING_GUIDELINES.md:0-0
Timestamp: 2025-11-24T17:24:41.976Z
Learning: Applies to nemo_rl/**/*.py : For any source file under nemo_rl/*.py that defines a class or function decorated with ray.remote, add a coverage pragma (# pragma: no cover) because these run in separate Ray processes
Learnt from: ZhiyuLi-Nvidia
Repo: NVIDIA-NeMo/RL PR: 1477
File: nemo_rl/models/generation/vllm/vllm_backend.py:163-168
Timestamp: 2025-11-06T22:30:22.860Z
Learning: For Ray actor methods in the vLLM generation worker code (vllm_backend.py, vllm_worker.py, vllm_worker_async.py), error handling should use print/traceback + return False pattern rather than raising exceptions, following the Ray RPC practice where exceptions may not propagate well across process boundaries.
nemo_rl/distributed/ray_compiled_graph.py (1)

104-112: Axis naming + sharding semantics need verification (possible mismatch with NamedSharding.names).
This code hard-codes "pipeline_parallel", "tensor_parallel", "context_parallel", "data_parallel" and largely ignores in_sharded_axes/replicate_on_axes during compiled execution.

  • Please confirm these axis names exactly match NamedSharding.names in this repo (otherwise get_axis_size(...) / get_ranks(...) will throw).
  • Please confirm callers pass kwargs["data"] in the DP-sharded dict/list format this expects; otherwise compiled vs non-compiled paths won’t be equivalent.

Also applies to: 141-167, 545-562
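One way to surface the axis-name mismatch early (rather than inside `get_axis_size(...)` / `get_ranks(...)`) is a fail-fast check at compile time. A sketch, assuming the four hard-coded axis names flagged above; adjust `REQUIRED_AXES` if the repo's `NamedSharding` uses different names:

```python
REQUIRED_AXES = (
    "pipeline_parallel",
    "tensor_parallel",
    "context_parallel",
    "data_parallel",
)


def validate_axis_names(sharding_names):
    """Raise a clear error before DAG compilation when the NamedSharding
    axis names don't match what the compiled-graph path expects."""
    missing = [axis for axis in REQUIRED_AXES if axis not in sharding_names]
    if missing:
        raise ValueError(
            f"NamedSharding is missing axes required by the compiled graph: {missing}"
        )
    return True
```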

Comment thread nemo_rl/algorithms/sft.py
Comment on lines +350 to +481
def _warmup_compiled_graph(
policy: PolicyInterface,
loss_fn,
tokenizer: PreTrainedTokenizerBase,
master_config: dict,
rcg_config: dict,
) -> None:
"""Warmup Ray Compiled Graph with maximum sequence length.

This creates fake data with max_total_sequence_length to ensure the
compiled graph is built with the worst-case input shape, avoiding
recompilation during actual training.

Args:
policy: The policy to warmup
loss_fn: Loss function to use
tokenizer: Tokenizer for creating fake tokens
master_config: Master configuration dict
rcg_config: Ray Compiled Graph configuration dict
"""
import torch

from nemo_rl.distributed.batched_data_dict import BatchedDataDict

# Get configuration
max_seq_len = master_config["policy"]["max_total_sequence_length"]
gbs = master_config["policy"]["train_global_batch_size"]
mbs = master_config["policy"]["train_micro_batch_size"]

# Get warmup config from rcg_config (use defaults if not specified)
warmup_seq_len = rcg_config["warmup_seq_len"] or max_seq_len
warmup_gbs = rcg_config["warmup_gbs"] or gbs

print(f" 🔧 Warmup config: SEQ_LEN={warmup_seq_len}, GBS={warmup_gbs}, MBS={mbs}")
print(f" 📦 Creating fake data with shape: ({warmup_gbs}, {warmup_seq_len})")

# Create fake data with max sequence length
# Use valid token IDs from the tokenizer's vocabulary
vocab_size = len(tokenizer)
fake_input_ids = torch.randint(
low=0,
high=min(vocab_size, 32000), # Use reasonable token range
size=(warmup_gbs, warmup_seq_len),
dtype=torch.long,
)

# Create attention mask (all ones = no padding)
fake_attention_mask = torch.ones((warmup_gbs, warmup_seq_len), dtype=torch.long)

# Create position IDs
fake_position_ids = (
torch.arange(warmup_seq_len, dtype=torch.long)
.unsqueeze(0)
.expand(warmup_gbs, -1)
)

# Create labels (same as input_ids for SFT)
fake_labels = fake_input_ids.clone()

# Create loss mask (all ones = compute loss on all tokens)
fake_loss_mask = torch.ones((warmup_gbs, warmup_seq_len), dtype=torch.float32)

# All sequences have the same length (max_seq_len)
fake_input_lengths = torch.full((warmup_gbs,), warmup_seq_len, dtype=torch.long)

# Create sample mask (all sequences are valid, no padding/dummy sequences)
fake_sample_mask = torch.ones((warmup_gbs,), dtype=torch.float32)

# Create token mask (all tokens contribute to loss, used for token-level loss)
fake_token_mask = torch.ones((warmup_gbs, warmup_seq_len), dtype=torch.float32)

# Create microbatch indices and lengths for sequence packing
# For warmup, all sequences have same length, so we create simple placeholders
num_microbatches = warmup_gbs // mbs
fake_micro_batch_indices = []
fake_micro_batch_lengths = []

for mb_idx in range(num_microbatches):
start_idx = mb_idx * mbs
end_idx = start_idx + mbs
# Each microbatch contains mbs sequences, each of length warmup_seq_len
fake_micro_batch_indices.append(list(range(start_idx, end_idx)))
fake_micro_batch_lengths.append([warmup_seq_len] * mbs)

# Create BatchedDataDict with fake data
warmup_data = BatchedDataDict(
{
"input_ids": fake_input_ids,
"attention_mask": fake_attention_mask,
"position_ids": fake_position_ids,
"labels": fake_labels,
"loss_mask": fake_loss_mask,
"input_lengths": fake_input_lengths,
"sample_mask": fake_sample_mask,
"token_mask": fake_token_mask,
"micro_batch_indices": fake_micro_batch_indices,
"micro_batch_lengths": fake_micro_batch_lengths,
}
)

# Store warmup data in Ray object store for efficient sharing across workers
print(" 🚀 Running warmup training step...")

# Run one training step to trigger graph compilation
# Use eval_mode=True to skip optimizer step (we don't care about gradients)
try:
start_time = torch.cuda.Event(enable_timing=True)
end_time = torch.cuda.Event(enable_timing=True)

start_time.record()
_ = policy.train(
data=warmup_data,
loss_fn=loss_fn,
eval_mode=True, # Skip optimizer step
gbs=warmup_gbs,
mbs=mbs,
)
end_time.record()

# Wait for completion
torch.cuda.synchronize()
warmup_time = start_time.elapsed_time(end_time) / 1000.0 # Convert to seconds

print(f" ✅ Warmup step completed in {warmup_time:.2f}s")
print(" 💾 Compiled graph is now cached and ready for training")

except Exception as e:
print(f" ⚠️ Warmup failed: {e}")
print(
" ℹ️ Continuing with normal training (graph will compile on first real step)"
)

Contributor

@coderabbitai coderabbitai Bot Dec 12, 2025

⚠️ Potential issue | 🟠 Major

Warmup should match real SFT inputs; avoid driver-side CUDA timing; validate batch divisibility.

Right now warmup builds extra keys (attention_mask, position_ids, labels, loss_mask, microbatch lists) that SFT training doesn’t pass, and times via torch.cuda.Event in the driver (which may not have CUDA). I’d recommend: (1) construct warmup data with the same keys as real SFT (input_ids, input_lengths, token_mask, sample_mask + multimodal if relevant), (2) use time.perf_counter() on the driver, and (3) assert warmup_gbs % mbs == 0 (and ideally divisibility vs DP if required by your sharding path).

@@
 def _warmup_compiled_graph(
@@
 ) -> None:
@@
-    import torch
+    import time
@@
-    # Create fake data with max sequence length
-    # Use valid token IDs from the tokenizer's vocabulary
+    # Create fake data matching real SFT train() inputs as closely as possible.
     vocab_size = len(tokenizer)
     fake_input_ids = torch.randint(
@@
-    # Create attention mask (all ones = no padding)
-    fake_attention_mask = torch.ones((warmup_gbs, warmup_seq_len), dtype=torch.long)
-
-    # Create position IDs
-    fake_position_ids = (
-        torch.arange(warmup_seq_len, dtype=torch.long)
-        .unsqueeze(0)
-        .expand(warmup_gbs, -1)
-    )
-
-    # Create labels (same as input_ids for SFT)
-    fake_labels = fake_input_ids.clone()
-
-    # Create loss mask (all ones = compute loss on all tokens)
-    fake_loss_mask = torch.ones((warmup_gbs, warmup_seq_len), dtype=torch.float32)
+    if warmup_gbs % mbs != 0:
+        raise ValueError(f"warmup_gbs ({warmup_gbs}) must be divisible by mbs ({mbs})")
@@
-    # Create BatchedDataDict with fake data
     warmup_data = BatchedDataDict(
         {
             "input_ids": fake_input_ids,
-            "attention_mask": fake_attention_mask,
-            "position_ids": fake_position_ids,
-            "labels": fake_labels,
-            "loss_mask": fake_loss_mask,
             "input_lengths": fake_input_lengths,
             "sample_mask": fake_sample_mask,
             "token_mask": fake_token_mask,
-            "micro_batch_indices": fake_micro_batch_indices,
-            "micro_batch_lengths": fake_micro_batch_lengths,
         }
     )
@@
-    try:
-        start_time = torch.cuda.Event(enable_timing=True)
-        end_time = torch.cuda.Event(enable_timing=True)
-
-        start_time.record()
-        _ = policy.train(
+    try:
+        t0 = time.perf_counter()
+        _ = policy.train(
             data=warmup_data,
             loss_fn=loss_fn,
             eval_mode=True,  # Skip optimizer step
             gbs=warmup_gbs,
             mbs=mbs,
         )
-        end_time.record()
-
-        # Wait for completion
-        torch.cuda.synchronize()
-        warmup_time = start_time.elapsed_time(end_time) / 1000.0  # Convert to seconds
+        warmup_time = time.perf_counter() - t0

Committable suggestion skipped: line range outside the PR's diff.

🧰 Tools
🪛 Ruff (0.14.8)

476-476: Do not catch blind exception: Exception

(BLE001)


479-479: String contains ambiguous ℹ (INFORMATION SOURCE). Did you mean i (LATIN SMALL LETTER I)?

(RUF001)

Contributor Author
Modifying warmup_data would add overhead to training, so only the CUDA timing is changed.

Contributor

@katec846, understood! I'll focus only on the CUDA timing issue. Here's the updated suggestion that replaces torch.cuda.Event with time.perf_counter() to avoid driver-side CUDA dependencies:

def _warmup_compiled_graph(
    policy: PolicyInterface,
    loss_fn,
    tokenizer: PreTrainedTokenizerBase,
    master_config: dict,
    rcg_config: dict,
) -> None:
    """Warmup Ray Compiled Graph with maximum sequence length.

    This creates fake data with max_total_sequence_length to ensure the
    compiled graph is built with the worst-case input shape, avoiding
    recompilation during actual training.

    Args:
        policy: The policy to warmup
        loss_fn: Loss function to use
        tokenizer: Tokenizer for creating fake tokens
        master_config: Master configuration dict
        rcg_config: Ray Compiled Graph configuration dict
    """
    import time
    import torch

    from nemo_rl.distributed.batched_data_dict import BatchedDataDict

    # Get configuration
    max_seq_len = master_config["policy"]["max_total_sequence_length"]
    gbs = master_config["policy"]["train_global_batch_size"]
    mbs = master_config["policy"]["train_micro_batch_size"]

    # Get warmup config from rcg_config (use defaults if not specified)
    warmup_seq_len = rcg_config["warmup_seq_len"] or max_seq_len
    warmup_gbs = rcg_config["warmup_gbs"] or gbs

    print(f"  🔧 Warmup config: SEQ_LEN={warmup_seq_len}, GBS={warmup_gbs}, MBS={mbs}")
    print(f"  📦 Creating fake data with shape: ({warmup_gbs}, {warmup_seq_len})")

    # Create fake data with max sequence length
    # Use valid token IDs from the tokenizer's vocabulary
    vocab_size = len(tokenizer)
    fake_input_ids = torch.randint(
        low=0,
        high=min(vocab_size, 32000),  # Use reasonable token range
        size=(warmup_gbs, warmup_seq_len),
        dtype=torch.long,
    )

    # Create attention mask (all ones = no padding)
    fake_attention_mask = torch.ones((warmup_gbs, warmup_seq_len), dtype=torch.long)

    # Create position IDs
    fake_position_ids = (
        torch.arange(warmup_seq_len, dtype=torch.long)
        .unsqueeze(0)
        .expand(warmup_gbs, -1)
    )

    # Create labels (same as input_ids for SFT)
    fake_labels = fake_input_ids.clone()

    # Create loss mask (all ones = compute loss on all tokens)
    fake_loss_mask = torch.ones((warmup_gbs, warmup_seq_len), dtype=torch.float32)

    # All sequences have the same length (max_seq_len)
    fake_input_lengths = torch.full((warmup_gbs,), warmup_seq_len, dtype=torch.long)

    # Create sample mask (all sequences are valid, no padding/dummy sequences)
    fake_sample_mask = torch.ones((warmup_gbs,), dtype=torch.float32)

    # Create token mask (all tokens contribute to loss, used for token-level loss)
    fake_token_mask = torch.ones((warmup_gbs, warmup_seq_len), dtype=torch.float32)

    # Create microbatch indices and lengths for sequence packing
    # For warmup, all sequences have same length, so we create simple placeholders
    num_microbatches = warmup_gbs // mbs
    fake_micro_batch_indices = []
    fake_micro_batch_lengths = []

    for mb_idx in range(num_microbatches):
        start_idx = mb_idx * mbs
        end_idx = start_idx + mbs
        # Each microbatch contains mbs sequences, each of length warmup_seq_len
        fake_micro_batch_indices.append(list(range(start_idx, end_idx)))
        fake_micro_batch_lengths.append([warmup_seq_len] * mbs)

    # Create BatchedDataDict with fake data
    warmup_data = BatchedDataDict(
        {
            "input_ids": fake_input_ids,
            "attention_mask": fake_attention_mask,
            "position_ids": fake_position_ids,
            "labels": fake_labels,
            "loss_mask": fake_loss_mask,
            "input_lengths": fake_input_lengths,
            "sample_mask": fake_sample_mask,
            "token_mask": fake_token_mask,
            "micro_batch_indices": fake_micro_batch_indices,
            "micro_batch_lengths": fake_micro_batch_lengths,
        }
    )

    # Store warmup data in Ray object store for efficient sharing across workers
    print("  🚀 Running warmup training step...")

    # Run one training step to trigger graph compilation
    # Use eval_mode=True to skip optimizer step (we don't care about gradients)
    try:
        t0 = time.perf_counter()
        _ = policy.train(
            data=warmup_data,
            loss_fn=loss_fn,
            eval_mode=True,  # Skip optimizer step
            gbs=warmup_gbs,
            mbs=mbs,
        )
        warmup_time = time.perf_counter() - t0

        print(f"  ✅ Warmup step completed in {warmup_time:.2f}s")
        print("  💾 Compiled graph is now cached and ready for training")

    except Exception as e:
        print(f"  ⚠️  Warmup failed: {e}")
        print(
            "  ℹ️  Continuing with normal training (graph will compile on first real step)"
        )

This change uses time.perf_counter() which works reliably on the driver without requiring CUDA availability.

Comment thread nemo_rl/distributed/ray_compiled_graph.py
Comment thread nemo_rl/distributed/ray_compiled_graph.py Outdated
Comment thread nemo_rl/distributed/ray_compiled_graph.py
Comment on lines +693 to +701
# Wrap in MultiWorkerFuture
from nemo_rl.distributed.worker_groups import MultiWorkerFuture

return MultiWorkerFuture(
futures=refs if isinstance(refs, list) else [refs],
return_from_workers=return_from_workers,
called_workers=list(range(len(refs))), # All workers were called
)
else:
Contributor

⚠️ Potential issue | 🟠 Major

len(refs) can break when refs isn’t a list — normalize once.
Both DP=1 and DP>1 branches can construct called_workers using len(refs) even when refs may be a single ObjectRef.

-            return MultiWorkerFuture(
-                futures=refs if isinstance(refs, list) else [refs],
+            refs_list = refs if isinstance(refs, list) else [refs]
+            return MultiWorkerFuture(
+                futures=refs_list,
                 return_from_workers=return_from_workers,
-                called_workers=list(range(len(refs))),  # All workers were called
+                called_workers=list(range(len(refs_list))),  # All workers were called
             )
@@
-            return MultiWorkerFuture(
-                futures=refs if isinstance(refs, list) else [refs],
+            refs_list = refs if isinstance(refs, list) else [refs]
+            return MultiWorkerFuture(
+                futures=refs_list,
                 return_from_workers=return_from_workers,
-                called_workers=list(range(len(refs))),  # All workers were called
+                called_workers=list(range(len(refs_list))),  # All workers were called
             )

Also applies to: 770-777

🤖 Prompt for AI Agents
nemo_rl/distributed/ray_compiled_graph.py around lines 693-701 (and similarly
770-777): refs may be a single ObjectRef or a list; normalize it once to avoid
calling len() on a non-list. Replace direct use of refs with a local refs_list =
refs if isinstance(refs, list) else [refs], use futures=refs_list,
called_workers=list(range(len(refs_list))) and return_from_workers unchanged;
apply the same normalization in the DP>1 branch around lines 770-777.
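The normalization the prompt asks for can live in one small helper shared by both branches; the name `normalize_refs` is hypothetical, and the returned pair follows the `futures` / `called_workers` fields of `MultiWorkerFuture` shown above:

```python
def normalize_refs(refs):
    """Normalize once so len() is never called on a single ObjectRef.

    Returns (refs_list, called_workers) for constructing a MultiWorkerFuture,
    for both the DP=1 and DP>1 branches.
    """
    refs_list = refs if isinstance(refs, list) else [refs]
    return refs_list, list(range(len(refs_list)))
```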

Comment on lines 209 to 210
@ray.remote
class IsolatedWorkerInitializer:
Contributor

⚠️ Potential issue | 🟡 Minor

Add Ray-process coverage pragma for @ray.remote per repo guidelines.
This file defines a @ray.remote class under nemo_rl/**, so it should be excluded from coverage.

Based on coding guidelines, ...

 class RayWorkerBuilder:
-    @ray.remote
+    @ray.remote  # pragma: no cover
     class IsolatedWorkerInitializer:
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change

-    @ray.remote
-    class IsolatedWorkerInitializer:
+    @ray.remote  # pragma: no cover
+    class IsolatedWorkerInitializer:
🤖 Prompt for AI Agents
In nemo_rl/distributed/worker_groups.py around lines 209 to 210, the @ray.remote
class needs to be excluded from test coverage; append the repository's
Ray-process coverage pragma to the decorator/class declaration (e.g., add the
coverage pragma comment such as "# pragma: no cover (ray-process)" on the same
line or immediately above the @ray.remote line) so the Ray worker process code
is omitted from coverage reports.

Comment on lines +95 to +102
# NOTE: @ray.remote decorator removed to fix Ray bug #26283 with compiled graphs.
# Ray.remote() is now applied at runtime in RayWorkerBuilder (like vLLM does).
class DTensorPolicyWorkerV2(AbstractPolicyWorker, ColocatablePolicyInterface):
# Runtime environment to use when applying ray.remote() at runtime
_ray_remote_runtime_env = get_runtime_env_for_policy_worker(
"dtensor_policy_worker_v2"
)

Contributor

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Search for references to _ray_remote_runtime_env and _default_options
rg -n -S '_ray_remote_runtime_env|_default_options' nemo_rl/ --type py

Repository: NVIDIA-NeMo/RL

Length of output: 469


🏁 Script executed:

#!/bin/bash
# Search for RayWorkerBuilder, wrapper classes, and ray.remote application logic
rg -n 'RayWorkerBuilder|NeMoRayWorkerWrapper|apply.*ray\.remote|ray\.remote\(' nemo_rl/ --type py

Repository: NVIDIA-NeMo/RL

Length of output: 3605


🏁 Script executed:

#!/bin/bash
# Find all policy worker files to see naming patterns
fd -e py 'policy.*worker' nemo_rl/models/policy/workers/ | head -20

Repository: NVIDIA-NeMo/RL

Length of output: 278


🏁 Script executed:

#!/bin/bash
# Examine the get_runtime_env_for_policy_worker function
rg -n 'def get_runtime_env_for_policy_worker' nemo_rl/ -A 10 --type py

Repository: NVIDIA-NeMo/RL

Length of output: 869


Use _default_options instead of _ray_remote_runtime_env for naming consistency.

RayWorkerBuilder (nemo_rl/distributed/worker_groups.py:249) explicitly looks for _default_options using getattr(worker_class, "_default_options", {}). DTensorPolicyWorker and MegatronPolicyWorker both use this convention, but DTensorPolicyWorkerV2 uses _ray_remote_runtime_env instead. This naming mismatch means the runtime environment configuration will be silently dropped when the builder applies ray.remote() at runtime.

Change line 99 to:

_default_options = get_runtime_env_for_policy_worker(
    "dtensor_policy_worker_v2"
)
🤖 Prompt for AI Agents
In nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py around lines 95 to
102, the class attribute is named _ray_remote_runtime_env but RayWorkerBuilder
expects _default_options, so change the attribute name to _default_options and
assign it the result of
get_runtime_env_for_policy_worker("dtensor_policy_worker_v2"); ensure no other
references to the old name remain and update any imports or usages if necessary
so the runtime env is picked up when Ray.remote() is applied at runtime.

Comment thread nemo_rl/models/policy/workers/megatron_policy_worker.py Outdated
Comment thread ray.sub
Signed-off-by: Kate Cheng <yunhsuanc@nvidia.com>
@github-actions
Copy link
Copy Markdown

ℹ️ File Consistency Check

Check based on commit: 7a7178a (PR #1612 from ray_overhead_exp)

✅ DTensor Policy Worker Synchronization Check

Both DTensor policy worker files were modified in this PR:

  • nemo_rl/models/policy/workers/dtensor_policy_worker.py
  • nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py

Please ensure that the changes are consistent between both files where applicable.


This check ensures that related file implementations remain synchronized across the codebase. If you believe this warning is incorrect or the files should intentionally differ, please add a comment explaining the reasoning.

@katec846 katec846 requested a review from guyueh1 December 12, 2025 20:14
Signed-off-by: Kate Cheng <yunhsuanc@nvidia.com>
@github-actions
Copy link
Copy Markdown

ℹ️ File Consistency Check

Check based on commit: 47d5615 (PR #1612 from ray_overhead_exp)

✅ DTensor Policy Worker Synchronization Check

Both DTensor policy worker files were modified in this PR:

  • nemo_rl/models/policy/workers/dtensor_policy_worker.py
  • nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py

Please ensure that the changes are consistent between both files where applicable.


This check ensures that related file implementations remain synchronized across the codebase. If you believe this warning is incorrect or the files should intentionally differ, please add a comment explaining the reasoning.

Signed-off-by: Kate Cheng <yunhsuanc@nvidia.com>
@katec846 katec846 changed the title feat: Support Ray Compiled Graph and NCCL-optimized data transfer for SFT feat: Support Ray Compiled Graph for SFT Dec 18, 2025
@guyueh1 guyueh1 linked an issue Jan 6, 2026 that may be closed by this pull request
@guyueh1 guyueh1 added the Performance Related to improving performance label Mar 5, 2026
@anwithk anwithk added this to the v0.6 Release milestone Mar 20, 2026

Labels

Performance Related to improving performance

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Reduce Ray overhead in SFT with Ray Compiled Graph

3 participants