
fix: Reinitialize model parallel after import #1317

Merged

terrykong merged 1 commit into main from ybgao/oct8-fix-parallelism on Oct 9, 2025

Conversation

@ybgao-nvidia (Contributor) commented Oct 8, 2025

What does this PR do?

This PR re-initializes the model parallel state after parallel model loading. It resolves the issue where the parallelism configuration used during model loading differs from the one used in training, which manifests as the error in #1300.

In particular, the error occurred because the distributed model import process (import_model_from_hf_name()) initializes the model parallel state with default values (CP=1), leading to incorrect data parallel size calculations. When training begins, the system finds the parallel state already initialized with the wrong configuration and skips reinitialization, resulting in a training job launched with an incorrect value for DP.

The reinitialization incurs a one-time performance cost when the model is first loaded. Subsequent training runs that load from checkpoints bypass the import path and hence don't incur this cost.
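In outline, the safeguard looks like the following; this is a minimal sketch built from the check quoted in the review diff further down the page, with surrounding comments added here for illustration:

    from megatron.core import parallel_state

    # After import_model_from_hf_name() has loaded the checkpoint, the model
    # parallel state may still reflect the import-time defaults (e.g., CP=1).
    if parallel_state.model_parallel_is_initialized():
        print("Reinitializing model parallel after loading model state.")
        parallel_state.destroy_model_parallel()
    # Training initialization then recreates the parallel state with the
    # configured TP/PP/CP sizes, so the derived DP size is correct.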

Issues

This PR resolves #1300.

Usage

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you run the unit tests and functional tests locally? Visit our Testing Guide for how to run tests
  • Did you add or update any necessary documentation? Visit our Document Development Guide for how to write, build and test the docs.

Additional Information

  • ...

Summary by CodeRabbit

  • Bug Fixes
    • Prevents reinitialization conflicts in model-parallel setups when loading models from Hugging Face by safely resetting any existing parallel state. This reduces crashes, hangs, and inconsistent behavior in distributed runs and provides clearer log messaging during initialization.

Signed-off-by: Yubo Gao <yubog@nvidia.com>
@ybgao-nvidia self-assigned this Oct 8, 2025
@ybgao-nvidia requested a review from a team as a code owner October 8, 2025 20:55
@ybgao-nvidia changed the title from "reinitialize model parallel after import" to "Reinitialize model parallel after import" Oct 8, 2025
@ybgao-nvidia changed the title from "Reinitialize model parallel after import" to "fix: Reinitialize model parallel after import" Oct 8, 2025
@ybgao-nvidia requested review from terrykong and yfw October 8, 2025 20:56
@coderabbitai (Bot, Contributor) commented Oct 8, 2025

📝 Walkthrough

Adds a safeguard in MegatronPolicyWorker initialization: after loading a HuggingFace model, if Megatron model-parallel is already initialized, it logs a reinitialization message and destroys the existing parallel state before continuing initialization.

Changes

Cohort: Megatron policy worker initialization
File(s): nemo_rl/models/policy/megatron_policy_worker.py
Summary: Inserted a post-load check in __init__ to detect a pre-initialized model-parallel state and destroy it, resetting parallel_state before proceeding. No public API changes.

Sequence Diagram(s)

sequenceDiagram
    participant W as MegatronPolicyWorker
    participant HF as HF Model Loader
    participant PS as Megatron ParallelState

    W->>HF: load_model_from_hf(...)
    HF-->>W: model
    W->>PS: is_model_parallel_initialized?
    alt Already initialized
        W->>PS: destroy_model_parallel()
        Note right of PS: Reset state to avoid mismatched DP/CP sizes
    else Not initialized
        Note over W,PS: Proceed without reset
    end
    W->>W: continue initialization (set up MP/DP/CP)

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~10 minutes

Possibly related PRs

Suggested labels

CI:L1, external

Suggested reviewers

  • terrykong
  • yaoyu-33

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)

  • Docstring Coverage: ⚠️ Warning. Docstring coverage is 0.00%, which is below the required threshold of 80.00%. Resolution: run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (5 passed)

  • Description Check: ✅ Passed. Check skipped because CodeRabbit's high-level summary is enabled.
  • Title Check: ✅ Passed. The pull request title succinctly and accurately describes the primary change of reinitializing the model parallel state after import, and aligns with the code modifications without extraneous details.
  • Linked Issues Check: ✅ Passed. The change adds a safeguard that destroys any preexisting parallel state after loading a model from Hugging Face, which directly addresses the reported index error by ensuring CP and DP sizes are correctly initialized per issue #1300.
  • Out of Scope Changes Check: ✅ Passed. All modifications are contained to the model import initialization in MegatronPolicyWorker and relate directly to reinitializing the parallel state, with no unrelated or extraneous code changes introduced.
  • Test Results For Major Changes: ✅ Passed. The PR introduces a targeted bug fix that resets the model-parallel state after importing a model, a small control-flow adjustment rather than a major feature or refactor. Since the change is minor, the custom check passes without requiring test evidence in the PR description.

@coderabbitai (Bot) left a review comment

Actionable comments posted: 0

🧹 Nitpick comments (1)
nemo_rl/models/policy/megatron_policy_worker.py (1)

510-512: Consider using the existing destroy_parallel_state() helper for more complete cleanup.

The file already defines a comprehensive destroy_parallel_state() function (lines 328-426) that handles:

  • Destroying both torch distributed and model parallel state
  • Resetting async call queues in multiple modules (NeMo, Megatron, base strategy)
  • Error handling with try-except blocks

The current implementation only destroys model parallel state without resetting async call queues. According to the existing helper's docstring, resetting async call tracking ensures "all ranks start with consistent call_idx values for async checkpointing" after the distributed context is recreated.

Apply this diff to use the more comprehensive helper:

-            if parallel_state.model_parallel_is_initialized():
-                print("Reinitializing model parallel after loading model state.")
-                parallel_state.destroy_model_parallel()
+            if parallel_state.model_parallel_is_initialized():
+                print("Reinitializing model parallel after loading model state.")
+                destroy_parallel_state()
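For readers unfamiliar with the helper, the following is a rough sketch of the teardown pattern the bullets above describe; the name destroy_parallel_state_sketch and its body are illustrative assumptions, not the actual helper at lines 328-426, and the async-queue resets are only noted in a comment:

    import torch.distributed as dist
    from megatron.core import parallel_state

    def destroy_parallel_state_sketch() -> None:
        """Illustrative teardown only; not the real destroy_parallel_state()."""
        # Destroy the Megatron model-parallel groups first.
        try:
            if parallel_state.model_parallel_is_initialized():
                parallel_state.destroy_model_parallel()
        except RuntimeError:
            pass  # best-effort cleanup, mirroring the helper's try-except style
        # Then tear down the torch distributed process group itself.
        try:
            if dist.is_initialized():
                dist.destroy_process_group()
        except RuntimeError:
            pass
        # The real helper additionally resets async checkpointing call queues
        # in NeMo, Megatron, and the base strategy so that all ranks restart
        # with consistent call_idx values.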
📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 57046a4 and 0b41352.

📒 Files selected for processing (1)
  • nemo_rl/models/policy/megatron_policy_worker.py (1 hunks)
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.py:

  • Follow the Google Python Style Guide for all Python code
  • Target Python 3.12+ for all Python code in NeMo-RL
  • Indent Python code with 4 spaces; do not use tabs
  • Python filenames should be snake_case (e.g., some_file.py)
  • Class names should be PascalCase
  • Function and method names should be snake_case
  • Local variable names should be snake_case; if starting with a number, prefix with k (e.g., k_99th_percentile)
  • Global variables should be UPPER_SNAKE_CASE and prefixed with G_ (e.g., G_MY_GLOBAL)
  • Constants should be UPPER_SNAKE_CASE
  • Avoid shadowing variables declared in an outer scope
  • Initialize all externally visible members of a class in the constructor
  • For public interfaces used outside a file, prefer docstrings over comments
  • Use comments mainly for code within a function or interfaces local to a file
  • Commented-out code must include a nearby comment explaining usage and why it is commented out; otherwise remove before merging
  • Use Google-style docstrings for classes and functions (Sphinx-parseable)
  • Avoid using reflection when functionality can be easily achieved without it
  • Limit except clauses to the smallest specific set of exceptions possible
  • For duck-typing via try/except, keep the try body minimal and use else for main logic
  • Add the NVIDIA copyright header (with current year) at the top of all Python files, excluding tests/ and test-only scripts

Files:

  • nemo_rl/models/policy/megatron_policy_worker.py
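As a combined illustration of several of these rules (Google-style docstring, snake_case naming, a minimal try body with else, and a narrow except clause), here is a hypothetical function; the name and logic are invented for this example:

    import json

    def load_run_config(path: str) -> dict[str, object]:
        """Load a run configuration from a JSON file.

        Args:
            path: Filesystem path to the JSON config file.

        Returns:
            The parsed configuration as a dictionary.

        Raises:
            ValueError: If the file is not valid JSON or does not contain an object.
        """
        with open(path) as f:
            try:
                cfg = json.load(f)  # keep the try body minimal
            except json.JSONDecodeError as e:
                raise ValueError(f"{path} is not valid JSON") from e
            else:
                if not isinstance(cfg, dict):
                    raise ValueError(f"{path} must contain a JSON object")
                return cfg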
nemo_rl/**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

nemo_rl/**/*.py:

  • Do not set non-None configuration defaults in code; YAML is the single source of truth for defaults
  • Access required config attributes directly (e.g., policy_cfg["precision"]) and assume presence; do not introduce hidden defaults
  • Express configuration optionality via TypedDict using typing.NotRequired
  • When adding a new config key to a TypedDict subclass, document the key's purpose, valid values/types, and recommended default in code
  • For any class or function decorated with @ray.remote, add '# pragma: no cover' on the class/def line (and on remote functions)

Files:

  • nemo_rl/models/policy/megatron_policy_worker.py
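A hypothetical sketch of how the config rules above compose; PolicyConfig and its keys are invented for illustration and are not taken from nemo_rl:

    from typing import NotRequired, TypedDict

    class PolicyConfig(TypedDict):
        # Required key: always present in the YAML config and accessed directly.
        precision: str
        # Optional key, expressed via NotRequired instead of a code default.
        # context_parallel_size: number of context-parallel ranks (positive int);
        # the recommended default (1) lives in YAML, the single source of truth.
        context_parallel_size: NotRequired[int]

    def describe(cfg: PolicyConfig) -> str:
        # Required attributes are read directly, assuming presence (no .get() fallbacks).
        return f"precision={cfg['precision']}"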
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: Lint check
  • GitHub Check: Post submodule check comment / Comment on PR
  • GitHub Check: Post automodel integration comment / Comment on PR

@yfw (Contributor) left a comment

Thanks for fixing!

@ybgao-nvidia ybgao-nvidia added the CI:L1 Run doctests, unit tests, and functional tests label Oct 8, 2025
@terrykong terrykong enabled auto-merge (squash) October 8, 2025 22:04
@terrykong terrykong merged commit d726c38 into main Oct 9, 2025
94 of 103 checks passed
@terrykong terrykong deleted the ybgao/oct8-fix-parallelism branch October 9, 2025 01:27
@zpqiu zpqiu added the r0.4.0 label Oct 27, 2025
zpqiu pushed a commit that referenced this pull request Oct 27, 2025
Signed-off-by: Yubo Gao <yubog@nvidia.com>
(cherry picked from commit d726c38)
PrinsYin pushed a commit to PrinsYin/RL that referenced this pull request Nov 30, 2025
Signed-off-by: Yubo Gao <yubog@nvidia.com>
yuanhangsu1986 pushed a commit to yuanhangsu1986/RL-Nemontron-Edge-Omni that referenced this pull request Feb 21, 2026
Signed-off-by: Yubo Gao <yubog@nvidia.com>
Signed-off-by: yuanhangs <yuanhangs@nvidia.com>

Labels

CI:L1 Run doctests, unit tests, and functional tests r0.4.0


Development

Successfully merging this pull request may close these issues.

Error when using CP with SFT in Megatron backend

4 participants