fix: Reinitialize model parallel after import#1317
Conversation
Signed-off-by: Yubo Gao <yubog@nvidia.com>
📝 Walkthrough
Adds a safeguard in MegatronPolicyWorker initialization: after loading a HuggingFace model, if Megatron model parallel is already initialized, it logs a reinitialization message and destroys the existing parallel state before continuing initialization.
Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant W as MegatronPolicyWorker
    participant HF as HF Model Loader
    participant PS as Megatron ParallelState
    W->>HF: load_model_from_hf(...)
    HF-->>W: model
    W->>PS: is_model_parallel_initialized?
    alt Already initialized
        W->>PS: destroy_model_parallel()
        Note right of PS: Reset state to avoid mismatched DP/CP sizes
    else Not initialized
        Note over W,PS: Proceed without reset
    end
    W->>W: continue initialization (set up MP/DP/CP)
```
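The guard described in the walkthrough can be sketched as follows. Since running real Megatron requires a distributed setup, the snippet below uses a minimal stand-in object for `megatron.core.parallel_state`; the two method names it mimics (`model_parallel_is_initialized()` and `destroy_model_parallel()`) are the real ones, while everything else (the stand-in class, the `load_model_then_reset` wrapper) is illustrative:

```python
class _ParallelStateStandIn:
    """Stand-in for megatron.core.parallel_state, so the guard is runnable here."""

    def __init__(self):
        self._initialized = False

    def initialize_model_parallel(self):
        self._initialized = True

    def model_parallel_is_initialized(self):
        return self._initialized

    def destroy_model_parallel(self):
        self._initialized = False


parallel_state = _ParallelStateStandIn()


def load_model_then_reset():
    # ... load_model_from_hf(...) would run here; the import path may leave
    # model parallel initialized with default settings (e.g. CP=1) ...
    parallel_state.initialize_model_parallel()

    # The safeguard: if the import left model parallel initialized, tear it
    # down so training re-creates it with the configured TP/PP/CP sizes.
    if parallel_state.model_parallel_is_initialized():
        print("Reinitializing model parallel after loading model state.")
        parallel_state.destroy_model_parallel()
```

After `load_model_then_reset()` returns, the stand-in reports the parallel state as uninitialized, which is the precondition the subsequent worker initialization relies on.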
Estimated code review effort: 🎯 2 (Simple) | ⏱️ ~10 minutes
Possibly related PRs
Pre-merge checks and finishing touches❌ Failed checks (1 warning)
✅ Passed checks (5 passed)
Actionable comments posted: 0
🧹 Nitpick comments (1)
nemo_rl/models/policy/megatron_policy_worker.py (1)
510-512: Consider using the existing `destroy_parallel_state()` helper for more complete cleanup.
The file already defines a comprehensive `destroy_parallel_state()` function (lines 328-426) that handles:
- Destroying both torch distributed and model parallel state
- Resetting async call queues in multiple modules (NeMo, Megatron, base strategy)
- Error handling with try-except blocks
The current implementation only destroys model parallel state without resetting async call queues. According to the existing helper's docstring, resetting async call tracking ensures "all ranks start with consistent call_idx values for async checkpointing" after the distributed context is recreated.
Apply this diff to use the more comprehensive helper:
```diff
-        if parallel_state.model_parallel_is_initialized():
-            print("Reinitializing model parallel after loading model state.")
-            parallel_state.destroy_model_parallel()
+        if parallel_state.model_parallel_is_initialized():
+            print("Reinitializing model parallel after loading model state.")
+            destroy_parallel_state()
```
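A teardown helper along the lines the reviewer describes could look roughly like the sketch below. The parameters, the `RuntimeError` choice, and the queue handling are illustrative assumptions, not the actual NeMo-RL implementation; dependencies are injected here purely so the sketch is self-contained, whereas the real helper operates on imported modules:

```python
def destroy_parallel_state_sketch(parallel_state, torch_distributed, async_queues):
    """Tear down distributed state and reset async checkpoint tracking.

    `parallel_state`, `torch_distributed`, and `async_queues` are injected
    for illustration; the real helper works on module-level state.
    """
    # Destroy the model-parallel groups first, then the torch process group.
    try:
        if parallel_state.model_parallel_is_initialized():
            parallel_state.destroy_model_parallel()
    except RuntimeError:
        pass  # best-effort cleanup: state may be only partially initialized
    try:
        if torch_distributed.is_initialized():
            torch_distributed.destroy_process_group()
    except RuntimeError:
        pass
    # Reset async call queues so all ranks restart with consistent
    # call_idx values for async checkpointing.
    for queue in async_queues:
        queue.clear()
```

The point of the reviewer's suggestion is exactly this extra last step: destroying only the model-parallel groups leaves the async call tracking populated, while the comprehensive helper clears it too.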
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
`nemo_rl/models/policy/megatron_policy_worker.py` (1 hunks)
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
**/*.py: Follow the Google Python Style Guide for all Python code
Target Python 3.12+ for all Python code in NeMo-RL
Indent Python code with 4 spaces; do not use tabs
Python filenames should be snake_case (e.g., some_file.py)
Class names should be PascalCase
Function and method names should be snake_case
Local variable names should be snake_case; if starting with a number, prefix with k (e.g., k_99th_percentile)
Global variables should be UPPER_SNAKE_CASE and prefixed with G_ (e.g., G_MY_GLOBAL)
Constants should be UPPER_SNAKE_CASE
Avoid shadowing variables declared in an outer scope
Initialize all externally visible members of a class in the constructor
For public interfaces used outside a file, prefer docstrings over comments
Use comments mainly for code within a function or interfaces local to a file
Commented-out code must include a nearby comment explaining usage and why it is commented out; otherwise remove before merging
Use Google-style docstrings for classes and functions (Sphinx-parseable)
Avoid using reflection when functionality can be easily achieved without it
Limit except clauses to the smallest specific set of exceptions possible
For duck-typing via try/except, keep the try body minimal and use else for main logic
Add the NVIDIA copyright header (with current year) at the top of all Python files, excluding tests/ and test-only scripts
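The duck-typing guideline above (minimal try body, main logic in else) can be illustrated with a short sketch; the function and its argument names are hypothetical, not taken from the codebase:

```python
def total_length(items):
    # Keep the try body minimal: only the duck-typed call goes inside it.
    try:
        iterator = iter(items)
    except TypeError:
        # Not iterable: treat the single value as a one-element sequence.
        return len(str(items))
    else:
        # Main logic lives in else, so TypeErrors raised here are not
        # silently misread as "items was not iterable".
        return sum(len(str(item)) for item in iterator)
```

Putting the summation inside the try block would catch unrelated `TypeError`s from the loop body, which is exactly what the guideline warns against.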
Files:
nemo_rl/models/policy/megatron_policy_worker.py
nemo_rl/**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
nemo_rl/**/*.py: Do not set non-None configuration defaults in code; YAML is the single source of truth for defaults
Access required config attributes directly (e.g., policy_cfg["precision"]) and assume presence; do not introduce hidden defaults
Express configuration optionality via TypedDict using typing.NotRequired
When adding a new config key to a TypedDict subclass, document the key’s purpose, valid values/types, and recommended default in code
For any class or function decorated with @ray.remote, add '# pragma: no cover' on the class/def line (and on remote functions)
Files:
nemo_rl/models/policy/megatron_policy_worker.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: Lint check
- GitHub Check: Post submodule check comment / Comment on PR
- GitHub Check: Post automodel integration comment / Comment on PR
Signed-off-by: Yubo Gao <yubog@nvidia.com> (cherry picked from commit d726c38)
Signed-off-by: yuanhangs <yuanhangs@nvidia.com>
What does this PR do ?
This PR re-initializes the model parallel state after parallel model loading. This resolves the issue where the parallelism used during model loading differs from that used in training, which manifests as the error reported in #1300.
In particular, the error occurred because the distributed model import process (`import_model_from_hf_name()`) initializes the model parallel state with default values (CP=1), leading to an incorrect data parallel size calculation. When training begins, the system finds the parallel state already initialized with the wrong configuration and skips reinitialization, so the training job launches with an incorrect DP size. The reinitialization adds a one-time performance cost when initially loading the model; subsequent training runs that load from checkpoints don't hit this path and hence don't incur this cost.
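The DP mismatch follows from simple arithmetic: Megatron derives the data-parallel size from the world size divided by the product of the model-parallel sizes, so a stale CP=1 state yields a different DP than the configured CP. A small sketch with illustrative numbers:

```python
def data_parallel_size(world_size, tp, pp, cp):
    # Megatron-style decomposition: DP = world_size / (TP * PP * CP).
    assert world_size % (tp * pp * cp) == 0, "world size must divide evenly"
    return world_size // (tp * pp * cp)


# Import path initialized parallel state with the default CP=1:
dp_at_import = data_parallel_size(16, tp=2, pp=1, cp=1)    # 8
# The training config actually specifies CP=2:
dp_for_training = data_parallel_size(16, tp=2, pp=1, cp=2)  # 4
```

If training reuses the already-initialized state, it runs with DP=8 groups while the rest of the configuration assumes DP=4, hence the destroy-and-reinitialize safeguard.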
Issues
This PR resolves #1300.