feat: Support Ray Compiled Graph for SFT #1612
katec846 wants to merge 14 commits into NVIDIA-NeMo:main
ℹ️ File Consistency Check (based on commit 3d4324c, PR #1612)

✅ DTensor Policy Worker Synchronization Check: both DTensor policy worker files were modified in this PR. Please ensure that the changes are consistent between both files where applicable. This check ensures that related file implementations remain synchronized across the codebase. If you believe this warning is incorrect or the files should intentionally differ, please add a comment explaining the reasoning.
📝 Walkthrough

This PR implements Ray Compiled Graph (RCG) support for distributed training. It introduces a new compiled graph infrastructure module, wraps workers with a uniform method invocation interface, removes …

Changes
Sequence Diagram

```mermaid
sequenceDiagram
    participant SFT as SFT Trainer
    participant RCGWrap as CompiledGraphWorkerGroup
    participant CGExec as CompiledGraphExecutor
    participant DAG as Ray Compiled DAG
    participant Worker as Policy Worker
    SFT->>RCGWrap: Initialize with config (RCG enabled)
    RCGWrap->>RCGWrap: Wrap RayWorkerGroup
    Note over SFT,DAG: Warmup Phase (total_steps == 0)
    SFT->>RCGWrap: _warmup_compiled_graph(synthetic_data)
    RCGWrap->>CGExec: Create executor for 'train' method
    CGExec->>DAG: Build & compile DAG with workers
    DAG->>Worker: Execute train_compiled(synthetic_batch)
    Worker-->>DAG: Return training results
    DAG-->>CGExec: Results via compiled kernel
    CGExec-->>RCGWrap: Warmup complete, DAG ready
    Note over SFT,DAG: Training Phase
    SFT->>RCGWrap: run_all_workers_sharded_data(train, real_data)
    RCGWrap->>CGExec: Use compiled executor (cached)
    CGExec->>DAG: Execute train_compiled(real_batch)
    DAG->>Worker: Run via compiled kernel
    Worker-->>DAG: Training outputs
    DAG-->>CGExec: Fused result
    CGExec-->>RCGWrap: Return to caller
    RCGWrap-->>SFT: Training step complete
```
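The warmup-then-reuse flow in the sequence diagram boils down to a per-method executor cache: the first call to a method builds a compiled executor, and every later call reuses it. The sketch below illustrates only that caching pattern in plain Python; `CompiledGraphWorkerGroup`, `CompiledExecutor`, and `FakeWorker` here are illustrative stand-ins, not the PR's actual Ray-backed classes.

```python
from typing import Any, Callable


class CompiledExecutor:
    """Stand-in for a Ray Compiled DAG: compiled once, executed many times."""

    def __init__(self, method: Callable[..., Any]) -> None:
        self.method = method
        self.compile_count = 1  # compilation happens once, at construction

    def execute(self, *args: Any, **kwargs: Any) -> Any:
        return self.method(*args, **kwargs)


class CompiledGraphWorkerGroup:
    """Caches one executor per method name, mirroring the warmup/reuse flow."""

    def __init__(self, worker: Any) -> None:
        self.worker = worker
        self._executors: dict[str, CompiledExecutor] = {}

    def run(self, method_name: str, *args: Any, **kwargs: Any) -> Any:
        # Build-and-cache on first use (the "warmup" call), reuse afterwards.
        if method_name not in self._executors:
            self._executors[method_name] = CompiledExecutor(
                getattr(self.worker, method_name)
            )
        return self._executors[method_name].execute(*args, **kwargs)


class FakeWorker:
    def train(self, batch: list[int]) -> int:
        return sum(batch)


group = CompiledGraphWorkerGroup(FakeWorker())
warmup = group.run("train", [0, 0])      # first call builds the executor
result = group.run("train", [1, 2, 3])   # reuses the cached executor
```

With the real implementation, the executor construction step is where the Ray DAG is built and compiled, which is exactly why the PR front-loads it with a synthetic batch.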
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes. Areas requiring extra attention:
Pre-merge checks and finishing touches: ❌ Failed checks (2 warnings)
✅ Passed checks (2 passed)
Actionable comments posted: 9
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
nemo_rl/models/policy/workers/dtensor_policy_worker.py (1)
510-523: Remove unused timing code (or plumb it into metrics/logging).
`worker_start_time` is currently unused, so this is dead code and will fail lint (F841).

```diff
@@ def train(
@@ ) -> dict[str, Any]:
     """Train the policy on a batch of data with a given loss function."""
-    import time
-
-    worker_start_time = time.time()
-
     if gbs is None:
         gbs = self.cfg["train_global_batch_size"]
```

nemo_rl/distributed/worker_groups.py (1)
886-1018: `run_all_workers_sharded_data()` only shards the first axis in `in_sharded_axes` due to an early `break`, limiting it to single-axis despite the multi-axis wording in the docstring.

The method accepts a list of axes but exits the loop after processing the first match (line ~956), silently ignoring subsequent axes. All current usage is single-axis only. Either assert single-axis at the start to prevent misconfiguration, or implement proper nested indexing if multi-axis is required. Example fix to enforce single-axis:

```diff
 def run_all_workers_sharded_data(
@@
     if in_sharded_axes is None:
         in_sharded_axes = []
+    if len(in_sharded_axes) > 1:
+        raise ValueError(
+            f"Only single-axis sharding is supported (got {in_sharded_axes})."
+        )
```
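To make the single-axis contract concrete, here is a hedged, self-contained sketch of sharding a batch across the workers of one named axis with the suggested guard; `shard_for_workers` and its signature are invented for illustration and do not match the real `run_all_workers_sharded_data()` API.

```python
from typing import Optional


def shard_for_workers(
    data: list,
    in_sharded_axes: Optional[list[str]],
    axis_sizes: dict[str, int],
) -> list[list]:
    """Shard `data` across the workers of a single named axis.

    Raises for more than one axis, mirroring the guard suggested above.
    """
    if in_sharded_axes is None:
        in_sharded_axes = []
    if len(in_sharded_axes) > 1:
        raise ValueError(
            f"Only single-axis sharding is supported (got {in_sharded_axes})."
        )
    if not in_sharded_axes:
        return [data]  # no sharding: replicate the whole batch
    n = axis_sizes[in_sharded_axes[0]]
    # Round-robin split: worker i gets elements i, i+n, i+2n, ...
    return [data[i::n] for i in range(n)]


shards = shard_for_workers(list(range(8)), ["data_parallel"], {"data_parallel": 4})

try:
    shard_for_workers(
        list(range(8)),
        ["data_parallel", "tensor_parallel"],
        {"data_parallel": 4, "tensor_parallel": 2},
    )
    multi_axis_rejected = False
except ValueError:
    multi_axis_rejected = True
```

The point of the explicit guard is that a misconfigured multi-axis call fails loudly at the entry point instead of silently sharding only the first axis.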
🧹 Nitpick comments (6)
ray.sub (1)
319-321: Export on head too (consistency) and quote the export.
You export `RAY_LOG_SYNC_FREQUENCY` in worker containers, but the head command relies on outer interpolation; exporting in both places keeps behavior consistent and avoids surprises if the heredoc content changes.

```diff
-export RAY_LOG_SYNC_FREQUENCY=$RAY_LOG_SYNC_FREQUENCY
+export RAY_LOG_SYNC_FREQUENCY="${RAY_LOG_SYNC_FREQUENCY}"
```

nemo_rl/models/policy/workers/megatron_policy_worker.py (1)
422-429: Annotate `_default_options` as `ClassVar` (and consider preventing mutation).

This matches Ruff RUF012 and avoids accidental per-class shared mutation surprises.

```diff
-from typing import Any, Iterator, Optional, TypeVar, cast
+from typing import Any, ClassVar, Iterator, Optional, TypeVar, cast
@@
 class MegatronPolicyWorker(AbstractPolicyWorker, ColocatablePolicyInterface):
     # Default options to use when applying ray.remote() at runtime
-    _default_options = {
+    _default_options: ClassVar[dict[str, Any]] = {
         "runtime_env": get_runtime_env_for_policy_worker("megatron_policy_worker")
     }
```

nemo_rl/models/policy/workers/dtensor_policy_worker.py (1)
134-141: Annotate `_default_options` as `ClassVar` (and keep it immutable-ish).

Ruff is right here: make the class attribute a `ClassVar` to avoid accidental instance-level expectations and typing noise.

```diff
-from typing import Any, Generator, Iterable, Optional, Set, Union, cast
+from typing import Any, ClassVar, Generator, Iterable, Optional, Set, Union, cast
@@
 class DTensorPolicyWorker(AbstractPolicyWorker, ColocatablePolicyInterface):
     # Default options to use when applying ray.remote() at runtime
-    _default_options = {
+    _default_options: ClassVar[dict[str, Any]] = {
         "runtime_env": get_runtime_env_for_policy_worker("dtensor_policy_worker")
     }
```

nemo_rl/models/policy/lm_policy.py (1)
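For readers unfamiliar with RUF012, this tiny runnable sketch shows the annotation in isolation; `PolicyWorkerSketch` and its `runtime_env` contents are placeholders, not the real worker class.

```python
from typing import Any, ClassVar


class PolicyWorkerSketch:
    # ClassVar makes explicit that this dict is shared class-level state,
    # satisfying Ruff RUF012 and static type checkers alike.
    _default_options: ClassVar[dict[str, Any]] = {
        "runtime_env": {"env_vars": {"EXAMPLE": "1"}}
    }


options = PolicyWorkerSketch._default_options
```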
853-859: Avoid fully silent `__del__` failures (optional).

I get why you’re suppressing teardown errors during interpreter exit, but a totally silent `except Exception: pass` can mask real cleanup regressions. Consider a minimal opt-in debug log (e.g., behind an env var) instead of unconditional silence.

nemo_rl/distributed/worker_groups.py (1)
141-158: Ray class “unwrapping” uses private attributes — verify against Ray 2.49.2.

`_ray_actor_class` / `__ray_metadata__` feel fragile across Ray versions; if they’re wrong, the wrapper will instantiate the wrong thing (or fail) in hard-to-debug ways.

- Please verify on Ray 2.49.2 that a remotely-decorated class actually exposes `_ray_actor_class` and/or `__ray_metadata__.modified_class` in the forms expected here, and add a small unit test that passes a `@ray.remote` class into `NeMoRayWorkerWrapper` (even if production no longer does).

nemo_rl/distributed/ray_compiled_graph.py (1)
81-123: A few low-friction cleanups to align with repo style/lint.

Mostly Ruff-driven / maintainability:

- Prefer `logger.exception(...)` inside `except Exception` blocks where you re-raise or want stack traces (TRY400).
- Replace the Unicode `×` in docstrings/comments with `x` (RUF002/RUF003) if Ruff is enforced.
- Consider narrowing teardown `except Exception` to Ray-specific teardown errors if feasible (BLE001); if not, add a short comment why broad is required here.
- If you’re on Python 3.12+ and Ruff enforces it, add `strict=` to `zip(...)` where applicable (B905).

Also applies to: 227-237, 260-293, 368-402, 501-507, 622-635
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (12)
- nemo_rl/algorithms/sft.py (2 hunks)
- nemo_rl/distributed/ray_compiled_graph.py (1 hunk)
- nemo_rl/distributed/worker_groups.py (8 hunks)
- nemo_rl/models/policy/__init__.py (2 hunks)
- nemo_rl/models/policy/lm_policy.py (5 hunks)
- nemo_rl/models/policy/workers/dtensor_policy_worker.py (2 hunks)
- nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py (1 hunk)
- nemo_rl/models/policy/workers/megatron_policy_worker.py (2 hunks)
- ray.sub (2 hunks)
- tests/unit/distributed/test_worker_groups.py (30 hunks)
- tests/unit/models/policy/test_dtensor_worker.py (1 hunk)
- tests/unit/models/policy/test_dtensor_worker_v2.py (2 hunks)
🧰 Additional context used
📓 Path-based instructions (4)
**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
**/*.py: Conform code to Python 3.12+
Indent code with 4 spaces. Do not use tabs
Use snake_case for file names
Use PascalCase for class names
Use snake_case for function and method names
Use snake_case for local variables
Prefix variable names that start with a number with 'k' (e.g., k_99th_percentile)
Use upper snake_case with 'G' prefix for global variables (e.g., G_MY_GLOBAL)
Use upper snake_case for constants
Avoid shadowing variables declared in an outer scope
Initialize all externally visible members of a class in the constructor
Prefer docstrings over comments for interfaces that may be used outside a file
Reserve comments for code within a function or interfaces that are local to a file
If a piece of code is commented out, include a comment describing its usage and why it's commented out. Remove debug comments before merging
Use Google style docstrings for classes and functions in Python, which can be parsed by Sphinx
Avoid using reflection when functionality can be easily achieved without reflection
When using try-except blocks, limit the except clause to the smallest set of specific errors possible
When using try-except blocks for duck-typing, keep the body of the try as small as possible and use the else block for logic
YAML is the single source of truth for configuration defaults. Do not set non-None defaults in code for configuration values
For required configuration attributes, access config directly and expect presence (e.g., policy_cfg['precision']) without hidden defaults
Use typing.NotRequired to mark optional attributes in TypedDict for configuration
When adding a new config key to a TypedDict subclass, document the key's purpose, valid values/types, and recommended default, and reflect the default in exemplar YAMLs under examples/configs/*.yaml
Follow the Google Python Style Guide for Python code
Files:
- nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py
- nemo_rl/models/policy/__init__.py
- nemo_rl/algorithms/sft.py
- tests/unit/distributed/test_worker_groups.py
- tests/unit/models/policy/test_dtensor_worker.py
- nemo_rl/models/policy/workers/dtensor_policy_worker.py
- nemo_rl/distributed/ray_compiled_graph.py
- nemo_rl/models/policy/workers/megatron_policy_worker.py
- nemo_rl/models/policy/lm_policy.py
- tests/unit/models/policy/test_dtensor_worker_v2.py
- nemo_rl/distributed/worker_groups.py
nemo_rl/**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
For any source file under nemo_rl/*.py that defines a class or function decorated with @ray.remote, add a coverage pragma (# pragma: no cover) because these run in separate Ray processes
Files:
- nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py
- nemo_rl/models/policy/__init__.py
- nemo_rl/algorithms/sft.py
- nemo_rl/models/policy/workers/dtensor_policy_worker.py
- nemo_rl/distributed/ray_compiled_graph.py
- nemo_rl/models/policy/workers/megatron_policy_worker.py
- nemo_rl/models/policy/lm_policy.py
- nemo_rl/distributed/worker_groups.py
!(**/tests/**|**/test_*.py|**/test_*.sh)
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
Add the NVIDIA copyright header to all Python files and shell scripts (excluding tests). The header should include the current year
Files:
- nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py
- nemo_rl/models/policy/__init__.py
- ray.sub
- nemo_rl/algorithms/sft.py
- tests/unit/distributed/test_worker_groups.py
- tests/unit/models/policy/test_dtensor_worker.py
- nemo_rl/models/policy/workers/dtensor_policy_worker.py
- nemo_rl/distributed/ray_compiled_graph.py
- nemo_rl/models/policy/workers/megatron_policy_worker.py
- nemo_rl/models/policy/lm_policy.py
- tests/unit/models/policy/test_dtensor_worker_v2.py
- nemo_rl/distributed/worker_groups.py
**/*.{py,sh}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
The NVIDIA copyright header should appear at the top of all Python files and shell scripts (excluding tests)
Files:
- nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py
- nemo_rl/models/policy/__init__.py
- nemo_rl/algorithms/sft.py
- tests/unit/distributed/test_worker_groups.py
- tests/unit/models/policy/test_dtensor_worker.py
- nemo_rl/models/policy/workers/dtensor_policy_worker.py
- nemo_rl/distributed/ray_compiled_graph.py
- nemo_rl/models/policy/workers/megatron_policy_worker.py
- nemo_rl/models/policy/lm_policy.py
- tests/unit/models/policy/test_dtensor_worker_v2.py
- nemo_rl/distributed/worker_groups.py
🧠 Learnings (4)
📚 Learning: 2025-11-24T17:24:41.976Z
Learnt from: CR
Repo: NVIDIA-NeMo/RL PR: 0
File: CODING_GUIDELINES.md:0-0
Timestamp: 2025-11-24T17:24:41.976Z
Learning: Applies to nemo_rl/**/*.py : For any source file under nemo_rl/*.py that defines a class or function decorated with ray.remote, add a coverage pragma (# pragma: no cover) because these run in separate Ray processes
Applied to files:
- nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py
- nemo_rl/distributed/ray_compiled_graph.py
- nemo_rl/models/policy/workers/megatron_policy_worker.py
- nemo_rl/distributed/worker_groups.py
📚 Learning: 2025-11-24T17:24:41.976Z
Learnt from: CR
Repo: NVIDIA-NeMo/RL PR: 0
File: CODING_GUIDELINES.md:0-0
Timestamp: 2025-11-24T17:24:41.976Z
Learning: Applies to **/*.py : When adding a new config key to a TypedDict subclass, document the key's purpose, valid values/types, and recommended default, and reflect the default in exemplar YAMLs under examples/configs/*.yaml
Applied to files:
nemo_rl/models/policy/__init__.py
📚 Learning: 2025-09-19T02:44:38.451Z
Learnt from: shuo-nvidia
Repo: NVIDIA-NeMo/RL PR: 1006
File: examples/configs/recipes/llm/distillation-qwen3-32b-to-1.7b-base-1n8g-fsdp2tp1.v1.yaml:73-84
Timestamp: 2025-09-19T02:44:38.451Z
Learning: The scheduler configuration format with a separate "milestones: [20]" entry (not wrapped under name/kwargs) is a valid and established pattern used across GRPO, DPO, and distillation configs in the NeMo RL codebase. This format specifies transition points between different schedulers (e.g., LinearLR for warmup steps, then ConstantLR).
Applied to files:
nemo_rl/models/policy/__init__.py
📚 Learning: 2025-11-06T22:30:22.860Z
Learnt from: ZhiyuLi-Nvidia
Repo: NVIDIA-NeMo/RL PR: 1477
File: nemo_rl/models/generation/vllm/vllm_backend.py:163-168
Timestamp: 2025-11-06T22:30:22.860Z
Learning: For Ray actor methods in the vLLM generation worker code (vllm_backend.py, vllm_worker.py, vllm_worker_async.py), error handling should use print/traceback + return False pattern rather than raising exceptions, following the Ray RPC practice where exceptions may not propagate well across process boundaries.
Applied to files:
- tests/unit/distributed/test_worker_groups.py
- nemo_rl/distributed/worker_groups.py
🧬 Code graph analysis (9)
nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py (1)
nemo_rl/models/policy/utils.py (1)
- get_runtime_env_for_policy_worker (286-296)
nemo_rl/algorithms/sft.py (3)
nemo_rl/models/policy/interfaces.py (2)
- PolicyInterface (50-157)
- train (100-118)

nemo_rl/distributed/batched_data_dict.py (2)
- BatchedDataDict (75-860)
- size (814-823)

nemo_rl/distributed/ray_compiled_graph.py (1)
- get_compiled_graph_config (418-442)
tests/unit/distributed/test_worker_groups.py (1)
nemo_rl/distributed/worker_groups.py (3)
- execute_method (186-201)
- workers (699-700)
- run_all_workers_single_data (838-884)
tests/unit/models/policy/test_dtensor_worker.py (1)
nemo_rl/distributed/worker_groups.py (2)
- workers (699-700)
- execute_method (186-201)
nemo_rl/models/policy/workers/dtensor_policy_worker.py (4)
nemo_rl/models/policy/workers/base_policy_worker.py (1)
- AbstractPolicyWorker (25-154)

nemo_rl/models/policy/interfaces.py (1)
- ColocatablePolicyInterface (160-193)

nemo_rl/models/policy/utils.py (1)
- get_runtime_env_for_policy_worker (286-296)

nemo_rl/utils/timer.py (1)
- time (110-123)
nemo_rl/distributed/ray_compiled_graph.py (2)
nemo_rl/distributed/named_sharding.py (4)
- NamedSharding (19-222)
- get_axis_size (209-211)
- get_ranks (155-199)
- get_worker_coords (103-122)

nemo_rl/distributed/worker_groups.py (1)
- train_compiled (161-184)
nemo_rl/models/policy/workers/megatron_policy_worker.py (4)
nemo_rl/models/policy/workers/base_policy_worker.py (1)
- AbstractPolicyWorker (25-154)

nemo_rl/models/policy/interfaces.py (1)
- ColocatablePolicyInterface (160-193)

nemo_rl/models/policy/utils.py (1)
- get_runtime_env_for_policy_worker (286-296)

nemo_rl/utils/timer.py (1)
- time (110-123)
tests/unit/models/policy/test_dtensor_worker_v2.py (1)
nemo_rl/distributed/worker_groups.py (1)
- execute_method (186-201)
nemo_rl/distributed/worker_groups.py (2)
nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py (1)
- train (490-866)

nemo_rl/models/policy/interfaces.py (1)
- train (100-118)
🪛 Ruff (0.14.8)
nemo_rl/algorithms/sft.py
476-476: Do not catch blind exception: Exception
(BLE001)
479-479: String contains ambiguous ℹ (INFORMATION SOURCE). Did you mean i (LATIN SMALL LETTER I)?
(RUF001)
nemo_rl/models/policy/workers/dtensor_policy_worker.py
138-140: Mutable class attributes should be annotated with typing.ClassVar
(RUF012)
522-522: Local variable worker_start_time is assigned to but never used
Remove assignment to unused variable worker_start_time
(F841)
nemo_rl/distributed/ray_compiled_graph.py
117-121: Use logging.exception instead of logging.error
Replace with exception
(TRY400)
131-131: Docstring contains ambiguous × (MULTIPLICATION SIGN). Did you mean x (LATIN SMALL LETTER X)?
(RUF002)
143-143: Comment contains ambiguous × (MULTIPLICATION SIGN). Did you mean x (LATIN SMALL LETTER X)?
(RUF003)
160-163: Prefer TypeError exception for invalid type
(TRY004)
160-163: Avoid specifying long messages outside the exception class
(TRY003)
200-200: Comment contains ambiguous × (MULTIPLICATION SIGN). Did you mean x (LATIN SMALL LETTER X)?
(RUF003)
203-203: Loop control variable idx not used within loop body
Rename unused idx to _idx
(B007)
284-284: Do not catch blind exception: Exception
(BLE001)
291-291: zip() without an explicit strict= parameter
Add explicit value for parameter strict=
(B905)
393-393: Do not catch blind exception: Exception
(BLE001)
400-400: zip() without an explicit strict= parameter
Add explicit value for parameter strict=
(B905)
501-501: Do not catch blind exception: Exception
(BLE001)
502-504: Use logging.exception instead of logging.error
Replace with exception
(TRY400)
528-528: Docstring contains ambiguous × (MULTIPLICATION SIGN). Did you mean x (LATIN SMALL LETTER X)?
(RUF002)
528-528: Docstring contains ambiguous × (MULTIPLICATION SIGN). Did you mean x (LATIN SMALL LETTER X)?
(RUF002)
609-609: Comment contains ambiguous × (MULTIPLICATION SIGN). Did you mean x (LATIN SMALL LETTER X)?
(RUF003)
609-609: Comment contains ambiguous × (MULTIPLICATION SIGN). Did you mean x (LATIN SMALL LETTER X)?
(RUF003)
622-622: Do not catch blind exception: Exception
(BLE001)
623-623: Use logging.exception instead of logging.error
Replace with exception
(TRY400)
708-710: Avoid specifying long messages outside the exception class
(TRY003)
720-720: Avoid specifying long messages outside the exception class
(TRY003)
747-747: Loop control variable dp_rank not used within loop body
(B007)
824-824: Do not catch blind exception: Exception
(BLE001)
832-832: zip() without an explicit strict= parameter
Add explicit value for parameter strict=
(B905)
nemo_rl/models/policy/workers/megatron_policy_worker.py
426-428: Mutable class attributes should be annotated with typing.ClassVar
(RUF012)
914-914: Local variable worker_start_time is assigned to but never used
Remove assignment to unused variable worker_start_time
(F841)
nemo_rl/models/policy/lm_policy.py
856-859: try-except-pass detected, consider logging the exception
(S110)
856-856: Do not catch blind exception: Exception
(BLE001)
🔇 Additional comments (9)
tests/unit/models/policy/test_dtensor_worker_v2.py (1)
205-207: LGTM: test now exercises the wrapper-based `execute_method` path.

This matches the new worker invocation model and should better reflect production calling patterns.

Also applies to: 229-231
tests/unit/models/policy/test_dtensor_worker.py (1)
604-606: LGTM: state_dict retrieval now goes through `execute_method`, consistent with wrapper dispatch.

tests/unit/distributed/test_worker_groups.py (1)
276-285: Execute-method migration looks correct; consider guarding against string-typo drift.

The switch to `execute_method.remote("<method>")` matches the new wrapper-based RPC and keeps tests aligned with `RayWorkerGroup` behavior. Only concern is the stringly-typed method names: if you want extra safety, consider a tiny helper/constants for frequently used method names in this test module.

Also applies to: 338-360, 387-406, 483-488, 514-567, 569-640, 968-1014, 1228-1252
nemo_rl/models/policy/__init__.py (1)
174-190: Config typing is good; please add YAML exemplar + document recommended defaults.

You added `PolicyConfig.ray_compiled_graph` and `RayCompiledGraphConfig`, but I don’t see corresponding exemplar YAML updates in this PR context. Per guidelines, please reflect the new key + recommended defaults under `examples/configs/*.yaml`. Based on learnings, when adding a new config key to a `TypedDict`, document it and reflect it in exemplar YAMLs.

Also applies to: 192-224
nemo_rl/algorithms/sft.py (1)
542-560: Warmup trigger placement is fine; please verify `get_compiled_graph_config` defaults vs “YAML is source of truth”.

The integration point (enabled + `total_steps == 0`) is sensible. One thing to double-check: if `get_compiled_graph_config(...)` injects defaults in code, make sure that doesn’t conflict with the repo guideline that YAML is the single source of truth for config defaults.

nemo_rl/models/policy/lm_policy.py (1)
33-37: Worker-group wrapping is clean; please verify compiled-graph gating against Ray 2.49.2.

The conditional wrap into `CompiledGraphWorkerGroup` keeps the rest of `Policy` code unchanged and is a good integration point. Please sanity-check `should_use_compiled_graph` / `get_compiled_graph_config` behavior against the Ray 2.49.2 compiled graph API and your rollout/training call patterns.

Also applies to: 186-219
nemo_rl/distributed/worker_groups.py (2)
284-290: Wrapper-based `execute_method.remote(...)` plumbing looks consistent.

The call-site conversions appear coherent and should keep worker invocation uniform for both standard Ray and RCG paths.

Also applies to: 734-737, 819-829, 878-883, 999-1013
141-160: Harden `train_compiled()` input contract (clear error + less KeyError pain in compiled graphs).

Today a missing `"data"` / `"loss_fn"` will surface as a remote KeyError with little context.

```diff
 class NeMoRayWorkerWrapper:
@@
     def train_compiled(
         self,
         train_input: dict[str, Any],
     ) -> dict[str, Any]:
@@
-        result = self.worker.train(
-            data=train_input["data"],
-            loss_fn=train_input["loss_fn"],
+        try:
+            data = train_input["data"]
+            loss_fn = train_input["loss_fn"]
+        except KeyError as e:
+            raise KeyError(
+                "train_input must include keys: 'data' and 'loss_fn' "
+                f"(got keys: {list(train_input.keys())})"
+            ) from e
+
+        result = self.worker.train(
+            data=data,
+            loss_fn=loss_fn,
             eval_mode=train_input.get("eval_mode", False),
             gbs=train_input.get("gbs"),
             mbs=train_input.get("mbs"),
         )
         return result
```

Also applies to: 161-185
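The input-contract check being suggested can be exercised in isolation. This is a hedged sketch assuming `train_input` is a plain dict; `validate_train_input` is a hypothetical helper, not code from the PR.

```python
from typing import Any


def validate_train_input(train_input: dict[str, Any]) -> tuple[Any, Any]:
    """Extract the required keys, failing with a descriptive error if absent."""
    try:
        data = train_input["data"]
        loss_fn = train_input["loss_fn"]
    except KeyError as e:
        raise KeyError(
            "train_input must include keys: 'data' and 'loss_fn' "
            f"(got keys: {list(train_input.keys())})"
        ) from e
    return data, loss_fn


# Happy path: both required keys present.
data, loss_fn = validate_train_input({"data": [1, 2], "loss_fn": sum})

# Error path: missing 'loss_fn' raises a KeyError naming both required keys.
try:
    validate_train_input({"data": [1, 2]})
    missing_key_raised = False
except KeyError:
    missing_key_raised = True
```

Failing on the driver side like this is cheaper to debug than a bare KeyError raised inside a compiled remote task.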
⛔ Skipped due to learnings
- Learnt from: CR, Repo: NVIDIA-NeMo/RL, PR: 0, File: CODING_GUIDELINES.md:0-0, Timestamp: 2025-11-24T17:24:41.976Z. Learning: Applies to nemo_rl/**/*.py: For any source file under nemo_rl/*.py that defines a class or function decorated with ray.remote, add a coverage pragma (# pragma: no cover) because these run in separate Ray processes.
- Learnt from: ZhiyuLi-Nvidia, Repo: NVIDIA-NeMo/RL, PR: 1477, File: nemo_rl/models/generation/vllm/vllm_backend.py:163-168, Timestamp: 2025-11-06T22:30:22.860Z. Learning: For Ray actor methods in the vLLM generation worker code (vllm_backend.py, vllm_worker.py, vllm_worker_async.py), error handling should use print/traceback + return False pattern rather than raising exceptions, following the Ray RPC practice where exceptions may not propagate well across process boundaries.

nemo_rl/distributed/ray_compiled_graph.py (1)
104-112: Axis naming + sharding semantics need verification (possible mismatch with `NamedSharding.names`).

This code hard-codes `"pipeline_parallel"`, `"tensor_parallel"`, `"context_parallel"`, `"data_parallel"` and largely ignores `in_sharded_axes` / `replicate_on_axes` during compiled execution.

- Please confirm these axis names exactly match `NamedSharding.names` in this repo (otherwise `get_axis_size(...)` / `get_ranks(...)` will throw).
- Please confirm callers pass `kwargs["data"]` in the DP-sharded dict/list format this expects; otherwise compiled vs non-compiled paths won’t be equivalent.

Also applies to: 141-167, 545-562
```python
def _warmup_compiled_graph(
    policy: PolicyInterface,
    loss_fn,
    tokenizer: PreTrainedTokenizerBase,
    master_config: dict,
    rcg_config: dict,
) -> None:
    """Warmup Ray Compiled Graph with maximum sequence length.

    This creates fake data with max_total_sequence_length to ensure the
    compiled graph is built with the worst-case input shape, avoiding
    recompilation during actual training.

    Args:
        policy: The policy to warmup
        loss_fn: Loss function to use
        tokenizer: Tokenizer for creating fake tokens
        master_config: Master configuration dict
        rcg_config: Ray Compiled Graph configuration dict
    """
    import torch

    from nemo_rl.distributed.batched_data_dict import BatchedDataDict

    # Get configuration
    max_seq_len = master_config["policy"]["max_total_sequence_length"]
    gbs = master_config["policy"]["train_global_batch_size"]
    mbs = master_config["policy"]["train_micro_batch_size"]

    # Get warmup config from rcg_config (use defaults if not specified)
    warmup_seq_len = rcg_config["warmup_seq_len"] or max_seq_len
    warmup_gbs = rcg_config["warmup_gbs"] or gbs

    print(f" 🔧 Warmup config: SEQ_LEN={warmup_seq_len}, GBS={warmup_gbs}, MBS={mbs}")
    print(f" 📦 Creating fake data with shape: ({warmup_gbs}, {warmup_seq_len})")

    # Create fake data with max sequence length
    # Use valid token IDs from the tokenizer's vocabulary
    vocab_size = len(tokenizer)
    fake_input_ids = torch.randint(
        low=0,
        high=min(vocab_size, 32000),  # Use reasonable token range
        size=(warmup_gbs, warmup_seq_len),
        dtype=torch.long,
    )

    # Create attention mask (all ones = no padding)
    fake_attention_mask = torch.ones((warmup_gbs, warmup_seq_len), dtype=torch.long)

    # Create position IDs
    fake_position_ids = (
        torch.arange(warmup_seq_len, dtype=torch.long)
        .unsqueeze(0)
        .expand(warmup_gbs, -1)
    )

    # Create labels (same as input_ids for SFT)
    fake_labels = fake_input_ids.clone()

    # Create loss mask (all ones = compute loss on all tokens)
    fake_loss_mask = torch.ones((warmup_gbs, warmup_seq_len), dtype=torch.float32)

    # All sequences have the same length (max_seq_len)
    fake_input_lengths = torch.full((warmup_gbs,), warmup_seq_len, dtype=torch.long)

    # Create sample mask (all sequences are valid, no padding/dummy sequences)
    fake_sample_mask = torch.ones((warmup_gbs,), dtype=torch.float32)

    # Create token mask (all tokens contribute to loss, used for token-level loss)
    fake_token_mask = torch.ones((warmup_gbs, warmup_seq_len), dtype=torch.float32)

    # Create microbatch indices and lengths for sequence packing
    # For warmup, all sequences have same length, so we create simple placeholders
    num_microbatches = warmup_gbs // mbs
    fake_micro_batch_indices = []
    fake_micro_batch_lengths = []

    for mb_idx in range(num_microbatches):
        start_idx = mb_idx * mbs
        end_idx = start_idx + mbs
        # Each microbatch contains mbs sequences, each of length warmup_seq_len
        fake_micro_batch_indices.append(list(range(start_idx, end_idx)))
        fake_micro_batch_lengths.append([warmup_seq_len] * mbs)

    # Create BatchedDataDict with fake data
    warmup_data = BatchedDataDict(
        {
            "input_ids": fake_input_ids,
            "attention_mask": fake_attention_mask,
            "position_ids": fake_position_ids,
            "labels": fake_labels,
            "loss_mask": fake_loss_mask,
            "input_lengths": fake_input_lengths,
            "sample_mask": fake_sample_mask,
            "token_mask": fake_token_mask,
            "micro_batch_indices": fake_micro_batch_indices,
            "micro_batch_lengths": fake_micro_batch_lengths,
        }
    )

    # Store warmup data in Ray object store for efficient sharing across workers
    print(" 🚀 Running warmup training step...")

    # Run one training step to trigger graph compilation
    # Use eval_mode=True to skip optimizer step (we don't care about gradients)
    try:
        start_time = torch.cuda.Event(enable_timing=True)
        end_time = torch.cuda.Event(enable_timing=True)

        start_time.record()
        _ = policy.train(
            data=warmup_data,
            loss_fn=loss_fn,
            eval_mode=True,  # Skip optimizer step
            gbs=warmup_gbs,
            mbs=mbs,
        )
        end_time.record()

        # Wait for completion
        torch.cuda.synchronize()
        warmup_time = start_time.elapsed_time(end_time) / 1000.0  # Convert to seconds

        print(f" ✅ Warmup step completed in {warmup_time:.2f}s")
        print(" 💾 Compiled graph is now cached and ready for training")

    except Exception as e:
        print(f" ⚠️ Warmup failed: {e}")
        print(
            " ℹ️ Continuing with normal training (graph will compile on first real step)"
        )
```
Warmup should match real SFT inputs; avoid driver-side CUDA timing; validate batch divisibility.
Right now warmup builds extra keys (attention_mask, position_ids, labels, loss_mask, microbatch lists) that SFT training doesn’t pass, and times via torch.cuda.Event in the driver (which may not have CUDA). I’d recommend: (1) construct warmup data with the same keys as real SFT (input_ids, input_lengths, token_mask, sample_mask + multimodal if relevant), (2) use time.perf_counter() on the driver, and (3) assert warmup_gbs % mbs == 0 (and ideally divisibility vs DP if required by your sharding path).
```diff
@@
 def _warmup_compiled_graph(
@@
 ) -> None:
@@
-    import torch
+    import time
@@
-    # Create fake data with max sequence length
-    # Use valid token IDs from the tokenizer's vocabulary
+    # Create fake data matching real SFT train() inputs as closely as possible.
     vocab_size = len(tokenizer)
     fake_input_ids = torch.randint(
@@
-    # Create attention mask (all ones = no padding)
-    fake_attention_mask = torch.ones((warmup_gbs, warmup_seq_len), dtype=torch.long)
-
-    # Create position IDs
-    fake_position_ids = (
-        torch.arange(warmup_seq_len, dtype=torch.long)
-        .unsqueeze(0)
-        .expand(warmup_gbs, -1)
-    )
-
-    # Create labels (same as input_ids for SFT)
-    fake_labels = fake_input_ids.clone()
-
-    # Create loss mask (all ones = compute loss on all tokens)
-    fake_loss_mask = torch.ones((warmup_gbs, warmup_seq_len), dtype=torch.float32)
+    if warmup_gbs % mbs != 0:
+        raise ValueError(f"warmup_gbs ({warmup_gbs}) must be divisible by mbs ({mbs})")
@@
-    # Create BatchedDataDict with fake data
     warmup_data = BatchedDataDict(
         {
             "input_ids": fake_input_ids,
-            "attention_mask": fake_attention_mask,
-            "position_ids": fake_position_ids,
-            "labels": fake_labels,
-            "loss_mask": fake_loss_mask,
             "input_lengths": fake_input_lengths,
             "sample_mask": fake_sample_mask,
             "token_mask": fake_token_mask,
-            "micro_batch_indices": fake_micro_batch_indices,
-            "micro_batch_lengths": fake_micro_batch_lengths,
         }
     )
@@
-    try:
-        start_time = torch.cuda.Event(enable_timing=True)
-        end_time = torch.cuda.Event(enable_timing=True)
-
-        start_time.record()
-        _ = policy.train(
+    try:
+        t0 = time.perf_counter()
+        _ = policy.train(
             data=warmup_data,
             loss_fn=loss_fn,
             eval_mode=True,  # Skip optimizer step
             gbs=warmup_gbs,
             mbs=mbs,
         )
-        end_time.record()
-
-        # Wait for completion
-        torch.cuda.synchronize()
-        warmup_time = start_time.elapsed_time(end_time) / 1000.0  # Convert to seconds
+        warmup_time = time.perf_counter() - t0
```

Committable suggestion skipped: line range outside the PR's diff.
🧰 Tools
🪛 Ruff (0.14.8)
476-476: Do not catch blind exception: Exception
(BLE001)
479-479: String contains ambiguous ℹ (INFORMATION SOURCE). Did you mean i (LATIN SMALL LETTER I)?
(RUF001)
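The timing concern above can be seen in isolation: `time.perf_counter()` gives driver-side wall-clock timing with no CUDA dependency, provided the timed call blocks until the work is done. A minimal sketch; the `timed_call` helper is ours for illustration, not from the PR:

```python
import time


def timed_call(fn, *args, **kwargs):
    # Wall-clock timing on the driver; works without CUDA. Note this only
    # measures end-to-end time if fn blocks until the work is complete,
    # which policy.train() does here.
    t0 = time.perf_counter()
    result = fn(*args, **kwargs)
    elapsed = time.perf_counter() - t0
    return result, elapsed


result, elapsed = timed_call(sum, range(1_000_000))
print(f"sum computed in {elapsed:.4f}s")
```

By contrast, `torch.cuda.Event` timing requires a CUDA context in the calling process and only measures GPU-stream time, neither of which holds on a CPU-only driver.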
Modifying the warmup_data would add overhead in training.
Only modify the CUDA timing.
@katec846, understood! I'll focus only on the CUDA timing issue. Here's the updated suggestion that replaces torch.cuda.Event with time.perf_counter() to avoid driver-side CUDA dependencies:
```python
def _warmup_compiled_graph(
    policy: PolicyInterface,
    loss_fn,
    tokenizer: PreTrainedTokenizerBase,
    master_config: dict,
    rcg_config: dict,
) -> None:
    """Warmup Ray Compiled Graph with maximum sequence length.

    This creates fake data with max_total_sequence_length to ensure the
    compiled graph is built with the worst-case input shape, avoiding
    recompilation during actual training.

    Args:
        policy: The policy to warmup
        loss_fn: Loss function to use
        tokenizer: Tokenizer for creating fake tokens
        master_config: Master configuration dict
        rcg_config: Ray Compiled Graph configuration dict
    """
    import time

    import torch

    from nemo_rl.distributed.batched_data_dict import BatchedDataDict

    # Get configuration
    max_seq_len = master_config["policy"]["max_total_sequence_length"]
    gbs = master_config["policy"]["train_global_batch_size"]
    mbs = master_config["policy"]["train_micro_batch_size"]

    # Get warmup config from rcg_config (use defaults if not specified)
    warmup_seq_len = rcg_config["warmup_seq_len"] or max_seq_len
    warmup_gbs = rcg_config["warmup_gbs"] or gbs

    print(f" 🔧 Warmup config: SEQ_LEN={warmup_seq_len}, GBS={warmup_gbs}, MBS={mbs}")
    print(f" 📦 Creating fake data with shape: ({warmup_gbs}, {warmup_seq_len})")

    # Create fake data with max sequence length
    # Use valid token IDs from the tokenizer's vocabulary
    vocab_size = len(tokenizer)
    fake_input_ids = torch.randint(
        low=0,
        high=min(vocab_size, 32000),  # Use reasonable token range
        size=(warmup_gbs, warmup_seq_len),
        dtype=torch.long,
    )

    # Create attention mask (all ones = no padding)
    fake_attention_mask = torch.ones((warmup_gbs, warmup_seq_len), dtype=torch.long)

    # Create position IDs
    fake_position_ids = (
        torch.arange(warmup_seq_len, dtype=torch.long)
        .unsqueeze(0)
        .expand(warmup_gbs, -1)
    )

    # Create labels (same as input_ids for SFT)
    fake_labels = fake_input_ids.clone()

    # Create loss mask (all ones = compute loss on all tokens)
    fake_loss_mask = torch.ones((warmup_gbs, warmup_seq_len), dtype=torch.float32)

    # All sequences have the same length (max_seq_len)
    fake_input_lengths = torch.full((warmup_gbs,), warmup_seq_len, dtype=torch.long)

    # Create sample mask (all sequences are valid, no padding/dummy sequences)
    fake_sample_mask = torch.ones((warmup_gbs,), dtype=torch.float32)

    # Create token mask (all tokens contribute to loss, used for token-level loss)
    fake_token_mask = torch.ones((warmup_gbs, warmup_seq_len), dtype=torch.float32)

    # Create microbatch indices and lengths for sequence packing
    # For warmup, all sequences have same length, so we create simple placeholders
    num_microbatches = warmup_gbs // mbs
    fake_micro_batch_indices = []
    fake_micro_batch_lengths = []

    for mb_idx in range(num_microbatches):
        start_idx = mb_idx * mbs
        end_idx = start_idx + mbs
        # Each microbatch contains mbs sequences, each of length warmup_seq_len
        fake_micro_batch_indices.append(list(range(start_idx, end_idx)))
        fake_micro_batch_lengths.append([warmup_seq_len] * mbs)

    # Create BatchedDataDict with fake data
    warmup_data = BatchedDataDict(
        {
            "input_ids": fake_input_ids,
            "attention_mask": fake_attention_mask,
            "position_ids": fake_position_ids,
            "labels": fake_labels,
            "loss_mask": fake_loss_mask,
            "input_lengths": fake_input_lengths,
            "sample_mask": fake_sample_mask,
            "token_mask": fake_token_mask,
            "micro_batch_indices": fake_micro_batch_indices,
            "micro_batch_lengths": fake_micro_batch_lengths,
        }
    )

    # Store warmup data in Ray object store for efficient sharing across workers
    print(" 🚀 Running warmup training step...")

    # Run one training step to trigger graph compilation
    # Use eval_mode=True to skip optimizer step (we don't care about gradients)
    try:
        t0 = time.perf_counter()
        _ = policy.train(
            data=warmup_data,
            loss_fn=loss_fn,
            eval_mode=True,  # Skip optimizer step
            gbs=warmup_gbs,
            mbs=mbs,
        )
        warmup_time = time.perf_counter() - t0

        print(f" ✅ Warmup step completed in {warmup_time:.2f}s")
        print(" 💾 Compiled graph is now cached and ready for training")
    except Exception as e:
        print(f" ⚠️ Warmup failed: {e}")
        print(
            " ℹ️ Continuing with normal training (graph will compile on first real step)"
        )
```

This change uses `time.perf_counter()`, which works reliably on the driver without requiring CUDA availability.
```python
        # Wrap in MultiWorkerFuture
        from nemo_rl.distributed.worker_groups import MultiWorkerFuture

        return MultiWorkerFuture(
            futures=refs if isinstance(refs, list) else [refs],
            return_from_workers=return_from_workers,
            called_workers=list(range(len(refs))),  # All workers were called
        )
    else:
```
`len(refs)` can break when `refs` isn't a list; normalize once.

Both the DP=1 and DP>1 branches can construct `called_workers` using `len(refs)` even when `refs` may be a single `ObjectRef`.

```diff
-        return MultiWorkerFuture(
-            futures=refs if isinstance(refs, list) else [refs],
+        refs_list = refs if isinstance(refs, list) else [refs]
+        return MultiWorkerFuture(
+            futures=refs_list,
             return_from_workers=return_from_workers,
-            called_workers=list(range(len(refs))),  # All workers were called
+            called_workers=list(range(len(refs_list))),  # All workers were called
         )
@@
-        return MultiWorkerFuture(
-            futures=refs if isinstance(refs, list) else [refs],
+        refs_list = refs if isinstance(refs, list) else [refs]
+        return MultiWorkerFuture(
+            futures=refs_list,
             return_from_workers=return_from_workers,
-            called_workers=list(range(len(refs))),  # All workers were called
+            called_workers=list(range(len(refs_list))),  # All workers were called
         )
```

Also applies to: 770-777
🤖 Prompt for AI Agents
nemo_rl/distributed/ray_compiled_graph.py around lines 693-701 (and similarly
770-777): refs may be a single ObjectRef or a list; normalize it once to avoid
calling len() on a non-list. Replace direct use of refs with a local refs_list =
refs if isinstance(refs, list) else [refs], use futures=refs_list,
called_workers=list(range(len(refs_list))) and return_from_workers unchanged;
apply the same normalization in the DP>1 branch around lines 770-777.
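The normalize-once pattern suggested above can be illustrated in isolation. A minimal sketch, assuming a simplified stand-in for `MultiWorkerFuture` (the real class in `worker_groups.py` has more fields), with plain strings standing in for Ray `ObjectRef`s:

```python
from dataclasses import dataclass, field


@dataclass
class MultiWorkerFuture:
    # Simplified stand-in for the real MultiWorkerFuture.
    futures: list
    called_workers: list = field(default_factory=list)


def wrap_refs(refs):
    # Normalize a single ref into a one-element list BEFORE any len()
    # or iteration, so both branches below see the same shape.
    refs_list = refs if isinstance(refs, list) else [refs]
    return MultiWorkerFuture(
        futures=refs_list,
        called_workers=list(range(len(refs_list))),
    )


single = wrap_refs("ref-0")           # single ref: no TypeError from len()
many = wrap_refs(["ref-0", "ref-1"])
print(single.called_workers)  # [0]
print(many.called_workers)    # [0, 1]
```

Without the normalization, `len(refs)` on a bare `ObjectRef` raises `TypeError`, which is exactly the failure mode the review flags.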
```python
@ray.remote
class IsolatedWorkerInitializer:
```
Add Ray-process coverage pragma for @ray.remote per repo guidelines.
This file defines a @ray.remote class under nemo_rl/**, so it should be excluded from coverage.
Based on coding guidelines, ...
```diff
 class RayWorkerBuilder:
-    @ray.remote
+    @ray.remote  # pragma: no cover
     class IsolatedWorkerInitializer:
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```diff
-@ray.remote
+@ray.remote  # pragma: no cover
 class IsolatedWorkerInitializer:
```
🤖 Prompt for AI Agents
In nemo_rl/distributed/worker_groups.py around lines 209 to 210, the @ray.remote
class needs to be excluded from test coverage; append the repository's
Ray-process coverage pragma to the decorator/class declaration (e.g., add the
coverage pragma comment such as "# pragma: no cover (ray-process)" on the same
line or immediately above the @ray.remote line) so the Ray worker process code
is omitted from coverage reports.
```python
# NOTE: @ray.remote decorator removed to fix Ray bug #26283 with compiled graphs.
# Ray.remote() is now applied at runtime in RayWorkerBuilder (like vLLM does).
class DTensorPolicyWorkerV2(AbstractPolicyWorker, ColocatablePolicyInterface):
    # Runtime environment to use when applying ray.remote() at runtime
    _ray_remote_runtime_env = get_runtime_env_for_policy_worker(
        "dtensor_policy_worker_v2"
    )
```
|
Use _default_options instead of _ray_remote_runtime_env for naming consistency.
RayWorkerBuilder (nemo_rl/distributed/worker_groups.py:249) explicitly looks for _default_options using getattr(worker_class, "_default_options", {}). DTensorPolicyWorker and MegatronPolicyWorker both use this convention, but DTensorPolicyWorkerV2 uses _ray_remote_runtime_env instead. This naming mismatch means the runtime environment configuration will be silently dropped when the builder applies ray.remote() at runtime.
Change line 99 to:
```python
_default_options = get_runtime_env_for_policy_worker(
    "dtensor_policy_worker_v2"
)
```

🤖 Prompt for AI Agents
In nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py around lines 95 to
102, the class attribute is named _ray_remote_runtime_env but RayWorkerBuilder
expects _default_options, so change the attribute name to _default_options and
assign it the result of
get_runtime_env_for_policy_worker("dtensor_policy_worker_v2"); ensure no other
references to the old name remain and update any imports or usages if necessary
so the runtime env is picked up when Ray.remote() is applied at runtime.
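The failure mode described here is easy to reproduce in miniature: `getattr` with a default silently ignores a class attribute stored under any other name. A sketch under the assumption that the builder consults only `_default_options`; the worker classes and the simplified builder below are hypothetical stand-ins:

```python
class RayWorkerBuilder:
    # Simplified stand-in: the real builder does more, but the lookup
    # pattern is the same as worker_groups.py line 249.
    def options_for(self, worker_class):
        # Only the exact attribute name `_default_options` is consulted;
        # anything else falls through to the empty default.
        return getattr(worker_class, "_default_options", {})


class GoodWorker:
    _default_options = {"runtime_env": {"env_vars": {"X": "1"}}}


class MismatchedWorker:
    # Same config, wrong attribute name: silently dropped by the builder.
    _ray_remote_runtime_env = {"runtime_env": {"env_vars": {"X": "1"}}}


builder = RayWorkerBuilder()
print(builder.options_for(GoodWorker))        # picked up
print(builder.options_for(MismatchedWorker))  # {} — config silently lost
```

Because the fallback is `{}` rather than an error, nothing fails loudly; the worker just starts without its intended runtime environment, which is why the rename matters.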
Signed-off-by: Kate Cheng <yunhsuanc@nvidia.com>
d60c225 to
7a7178a
Compare
Signed-off-by: Kate Cheng <yunhsuanc@nvidia.com>
Signed-off-by: Kate Cheng <yunhsuanc@nvidia.com>
## What does this PR do?

Adds Ray Compiled Graph support for SFT.

TODO:

Ray Compiled Graph (RCG): sub-millisecond task orchestration overhead (<50μs vs ~1ms for standard Ray calls).

## Key Features

### Ray Compiled Graph Support

Implements Ray's compiled DAG execution for distributed training, with support for:

Enable with:

```
++policy.sequence_packing.enabled=True
```
## Technical Changes

### Architecture Changes

### 1. `NeMoRayWorkerWrapper`: Ray Actor Wrapper Pattern

**Problem:** Ray's `@ray.remote` decorator wraps all methods in `ActorMethod` objects with generic signatures, preventing Ray Compiled Graph from inspecting actual method parameters (Ray issue #26283).

**Solution:** a `NeMoRayWorkerWrapper` class that receives the `@ray.remote` decoration instead of the worker classes:

- Holds the wrapped worker: `self.worker = worker_class(*args, **kwargs)`
- Exposes a `train_compiled(train_input: dict)` method with an inspectable signature for Compiled Graph
- Exposes `execute_method(method_name, *args, **kwargs)` for standard Ray remote calls

**Changes:**

- Removed the `@ray.remote` decorator from all worker classes (`MegatronPolicyWorker`, `DTensorPolicyWorker`, `DTensorPolicyWorkerV2`)
- `RayWorkerBuilder` now applies `ray.remote(NeMoRayWorkerWrapper)` at runtime
- Standard remote calls go through `worker.execute_method.remote(method_name, *args, **kwargs)`
ray_compiled_graph.py)New Classes:
CompiledGraphExecutorMultiDPCompiledGraphExecutorCompiledGraphWorkerGroupRayWorkerGroupHelper Functions:
get_compiled_graph_config(): Extracts RCG configuration from policy configshould_use_compiled_graph(): Determines if RCG should be enabled3. Configuration Support
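The two helpers can be pictured roughly as follows. This is a hedged sketch, not the actual implementation in `ray_compiled_graph.py`, whose config handling and defaults may differ:

```python
def get_compiled_graph_config(policy_config: dict) -> dict:
    # Pull the RCG sub-config from the policy config, falling back to a
    # disabled default when the key is absent or empty.
    return policy_config.get("ray_compiled_graph") or {"enabled": False}


def should_use_compiled_graph(policy_config: dict) -> bool:
    # RCG is used only when explicitly enabled in the policy config.
    return bool(get_compiled_graph_config(policy_config).get("enabled", False))


cfg = {"ray_compiled_graph": {"enabled": True, "warmup_gbs": 128}}
print(should_use_compiled_graph(cfg))  # True
print(should_use_compiled_graph({}))   # False
```

Keeping the enable check behind one predicate means callers like `lm_policy.py` can branch on a single function rather than re-reading raw config keys.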
New
RayCompiledGraphConfiginPolicyConfig:ray_compiled_graph:
enabled: true # Enable/disable RCG
warmup_seq_len: 8192 # Optional: warmup sequence length
warmup_gbs: 128 # Optional: warmup global batch size
overlap_communication: false # Experimental: overlap GPU compute and comm### 4. Automatic Warmup
_warmup_compiled_graph()insft.py:warmup_seq_lenandwarmup_gbs5. Policy Integration
lm_policy.pychanges:RayWorkerGroupwithCompiledGraphWorkerGroupwhen RCG is enabledRayWorkerGroupAdditional Improvements
## Issues

List issues that this PR closes (syntax):

## Usage

```python
# Add a code snippet demonstrating how to use this
```

## Before your PR is "Ready for review"

Pre checks:

## Additional Information
## Summary by CodeRabbit

### Release Notes

**New Features**

- Added a `ray_compiled_graph` configuration option to control optimization behavior.

**Improvements**