Hybrid Context Parallel Feature #2282

Merged

parthmannan merged 98 commits into NVIDIA:main from parthmannan:pmannan/hybrid_dp_cp_main
Jan 17, 2026
Conversation

@parthmannan parthmannan commented Nov 17, 2025

What does this PR do ?

Dev MR for details - #2054

Design document discussed in MCore sync meeting - https://docs.google.com/document/d/1MnIPQ_VbpDNp-adtvcEv-SYx6A8rtt3-fDdxbcdrmk0/edit?usp=sharing

The first issue this MR addresses is workload imbalance between DP ranks when using packed sequences (for example, in SFT). While packing sequences helps reduce variability in total sequence length, it does not guarantee equal workload: attention compute is quadratic in sequence length, so a single 1k-token sequence requires 2x the attention compute of a packed sequence of two 512-token samples. The problem gets much worse with very large sequences and/or large variation between sequence lengths.
This MR schedules a variable number of microbatches per rank across the DPxCP group to ensure a balanced workload.
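To make the imbalance concrete, here is a minimal sketch (illustrative only, not code from this MR) of how quadratic attention cost makes equal token counts unequal work:

```python
# Illustrative sketch: attention FLOPs scale quadratically with per-sample
# sequence length, so packing equal token counts does not equalize compute.
def attention_cost(sample_lengths):
    """Relative attention cost of a packed sequence: sum of len^2 per sample."""
    return sum(n * n for n in sample_lengths)

# A single 1024-token sample vs. a pack of two 512-token samples.
single = attention_cost([1024])      # 1024^2 = 1_048_576
packed = attention_cost([512, 512])  # 2 * 512^2 = 524_288
print(single // packed)              # -> 2: same total tokens, 2x the compute
```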

The second issue this MR addresses is redundant CP communication. The context parallel size is chosen based on the full packed sequence length (usually the max sequence length across all samples). For example, if a 1k sequence requires CP2, we apply CP2 to a packed sequence of 2x512 as well. In reality, we can easily partition the 2x512 packed sequence across 2 GPUs by separating the two samples, with no CP at all. This MR introduces dynamic context parallelism, where each sample is individually scheduled with a dynamically sized CP group.
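The per-sample sizing idea can be sketched as follows; the function name and the per-GPU token-budget heuristic are hypothetical illustrations, not the MR's actual API:

```python
# Hypothetical sketch: pick a per-sample CP size as the smallest power of 2
# such that each CP shard of the sample fits under a per-GPU token budget.
def dynamic_cp_size(seq_len, tokens_per_gpu, max_cp):
    cp = 1
    while seq_len / cp > tokens_per_gpu and cp < max_cp:
        cp *= 2  # the feature currently restricts CP sizes to powers of 2
    return cp

# With a 512-token-per-GPU budget: a 1k sample needs CP2, but each
# 512-token sample in a 2x512 pack needs no CP at all.
print(dynamic_cp_size(1024, 512, 16))  # -> 2
print(dynamic_cp_size(512, 512, 16))   # -> 1
```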

To achieve the above, we introduce a balanced scheduler and a dataloader wrapper.
The dataloader wrapper collects the metadata that informs the scheduler of the sequence length of each sample across the entire global batch. It also breaks packed sequences up into individual samples, since samples are scheduled individually. Given this metadata, the balanced scheduler assigns each sample to ranks (across the DPxCP group) together with a dynamic CP group size. To avoid deadlocks, the schedule is divided into groups (which replace the notion of microbatches). Within each group, each rank belongs to a fixed CP group, but ranks may run different numbers of samples so that all ranks have balanced compute.
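The balancing idea can be sketched as a greedy assignment of samples, heaviest first with cost proportional to length squared, to the least-loaded rank. Names and details here are illustrative, not the MR's actual scheduler:

```python
# Illustrative sketch of balanced scheduling (not the MR's implementation):
# greedily place each sample on the rank with the lowest accumulated cost.
import heapq

def balance(sample_lengths, num_ranks):
    # Min-heap of (accumulated cost, rank id, assigned samples).
    loads = [(0, r, []) for r in range(num_ranks)]
    heapq.heapify(loads)
    for n in sorted(sample_lengths, reverse=True):  # heaviest first
        load, r, samples = heapq.heappop(loads)
        samples.append(n)
        heapq.heappush(loads, (load + n * n, r, samples))
    return sorted(loads, key=lambda t: t[1])  # order by rank id

# Ranks may receive different numbers of samples but similar total cost.
for load, rank, samples in balance([1024, 512, 512, 256, 256, 256, 256], 2):
    print(rank, samples, load)
```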

Screenshot 2025-10-08 at 3 21 39 PM

We have run performance and correctness evaluations of the feature. Using an SFT packed dataset with a max sequence length of 128k and a LLaMa3 8B dummy model, we see a 3x performance improvement with this feature. While there is room to improve the baseline itself, the speedup should remain in the 2-3x range.

This is what a 128k sequence length run with CP16 looks like without this feature. The GPU is bound by CP communication.
Screenshot 2025-10-08 at 3 28 38 PM

This is what the same 128k CP16 run looks like with this feature. The GPU is bound by attention compute, since all redundant communication has been removed.
Screenshot 2025-10-08 at 3 30 26 PM

Feature correctness (@xiaoyao0115)
hybrid_cp_loss_curve

This is the first milestone of this feature, and there are many improvements we want to make in future releases.

  1. The feature does not yet support pipeline parallelism or FSDP. We plan to add PP support next.
  2. The feature is limited to creating dynamic CP groups whose sizes are powers of 2. We plan to add fully dynamic support via changes in TransformerEngine DPA.
  3. The feature does not support CUDA graphs.
  4. The feature works best with FlashAttention rather than cuDNN FusedAttention, because the changing sequence lengths and CP sizes force cuDNN to recompile its graph, erasing the performance gains. We will advocate for dynamic-shape support in cuDNN FusedAttention.

Contribution process

flowchart LR
    A[Pre-checks] --> B[PR Tests]
    subgraph Code Review/Approval
        C1[Expert Review] --> C2[Final Review]
    end
    B --> C1
    C2 --> D[Merge]

Pre-checks

  • I want this PR in a versioned release and have added the appropriate Milestone (e.g., Core 0.8)
  • I have added relevant unit tests
  • I have added relevant functional tests
  • I have added proper typing to my code Typing guidelines
  • I have added relevant documentation
  • I have run the autoformatter.sh on my PR

Code review

The following process is enforced via the CODEOWNERS file for changes into megatron/core. For changes outside of megatron/core, it is up to the PR author whether or not to tag the Final Reviewer team.

For MRs into `main` branch

(Step 1): Add PR label Expert Review

(Step 2): Collect the expert reviewers reviews

  1. Attach the Expert Review label when your PR is ready for review.
  2. GitHub auto-assigns expert reviewers based on your changes. They will get notified and pick up your PR soon.

⚠️ Only proceed to the next step once all reviewers have approved, merge conflicts are resolved, and the CI is passing.
The Final Review may be declined if these requirements are not fulfilled.

(Step 3): Final Review

  1. Add Final Review label
  2. GitHub auto-assigns final reviewers based on your changes. They will get notified and pick up your PR soon.

(Optional Step 4): Cherry-pick into release branch

If this PR also needs to be merged into core_r* release branches, after this PR has been merged, select Cherry-pick to open a new PR into the release branch.

For MRs into `dev` branch

The proposed review process for the `dev` branch is under active discussion.

MRs are mergeable after one approval by either eharper@nvidia.com or zijiey@nvidia.com.

Merging your PR

Any member of core-adlr and core-nemo will be able to merge your PR.

parthmannan and others added 30 commits July 14, 2025 19:08
…ia.com:12051/ADLR/megatron-lm into pmannan/hetero_cp_test_sft
@parthmannan

/ok to test 1ebcb02

@parthmannan

Functional test suite has been run and passing on Gitlab in https://gitlab-master.nvidia.com/ADLR/megatron-lm/-/pipelines/41910205

@parthmannan parthmannan added this pull request to the merge queue Jan 16, 2026
@Phlip79 Phlip79 removed this pull request from the merge queue due to a manual request Jan 16, 2026
@parthmannan

/ok to test 641b04d

@parthmannan parthmannan enabled auto-merge January 17, 2026 00:31
@parthmannan parthmannan added this pull request to the merge queue Jan 17, 2026
Merged via the queue into NVIDIA:main with commit 98d8c56 Jan 17, 2026
69 of 73 checks passed
@parthmannan parthmannan deleted the pmannan/hybrid_dp_cp_main branch January 17, 2026 02:06
chtruong814 pushed a commit to chtruong814/Megatron-LM that referenced this pull request Jan 17, 2026
This reverts commit 98d8c56.

Signed-off-by: Charlie Truong <chtruong@nvidia.com>
@asolergi-nv asolergi-nv mentioned this pull request Feb 5, 2026
daiyaanarfeen pushed a commit to daiyaanarfeen/Megatron-LM that referenced this pull request Feb 23, 2026
Signed-off-by: tailaim <tailaim@nvidia.com>
Signed-off-by: Parth Mannan <pmannan@nvidia.com>
Co-authored-by: Mcore Bot <mcore-bot@nvidia.com>
Co-authored-by: tailaim <tailaim@nvidia.com>
Co-authored-by: kunlunl <kunlunl@nvidia.com>
Co-authored-by: Kunlun Li <94586211+kunlunl@users.noreply.github.com>

Labels

  • complexity: high
  • dev2main: mbridge dev to main (this PR is needed in main for mbridge)
  • Final Review (apply this label to indicate that your PR is ready for final review)
