
Fix getCommunicationInfo for multi-dimensional meshes#5969

Merged
wujingyue merged 3 commits into main from wjy/comm on Feb 21, 2026
Conversation

@wujingyue (Collaborator) commented Feb 18, 2026

Fixes #4604

@wujingyue (Collaborator, Author) commented:

!test

github-actions bot commented Feb 18, 2026

Review updated until commit 807a8e6

Description

  • Fix getCommunicationInfo to properly handle multi-dimensional device meshes

  • Add early return check for haveDifferentShardings to avoid redundant processing

  • Improve error handling and messaging in communication lowering functions

  • Add test coverage for 2D mesh all-gather operations

Changes walkthrough

Relevant files

Bug fix: csrc/host_ir/lower_to_communication.cpp (+26/-34)
Fix communication info handling for multi-dimensional meshes

  • Added include for ir/interface_nodes.h
  • Improved error handling in lowerToGather with NVF_ERROR_EQ
  • Added early return check for haveDifferentShardings in
    getCommunicationInfoForParallelType
  • Enhanced handling of different meshes in communication operations
  • Simplified error handling for reduction operations
  • Removed the TODO comment about problematic 2D sharding

Documentation: csrc/multidevice/resharding.cpp (+7/-0)
Document multi-dimensional sharding limitations

  • Added a comment explaining the problematic multi-dimensional sharding case
  • Provided an example showing when the current code fails for
    multi-dimensional meshes

Tests: tests/python/multidevice/test_communication.py (+31/-0)
Add test coverage for 2D mesh all-gather operations

  • Added a test_allgather_2d function to test all-gather on 2D device meshes
  • Tests multi-dimensional mesh configuration with outer_split and
    parallelization
  • Validates proper sharding and communication for 2D mesh operations
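The 2D-mesh allgather scenario the new test targets can be sketched independent of nvFuser. The mesh layout, shapes, and helper names below are illustrative assumptions, not the actual test code:

```python
# Illustrative model of a 2D-mesh allgather (hypothetical helpers, not nvFuser API).

mesh = [[0, 1], [2, 3]]  # 2x2 device mesh: rows index DIDy, columns index DIDx

def coords(device):
    """Locate a device's (DIDy, DIDx) coordinate in the 2D mesh."""
    for y, row in enumerate(mesh):
        for x, d in enumerate(row):
            if d == device:
                return (y, x)
    raise ValueError(f"device {device} not in mesh")

def allgather_group(device):
    """An allgather on DIDx communicates within the device's mesh row."""
    y, _ = coords(device)
    return mesh[y]

# A [4, 4] tensor sharded [iDIDy{2}, iDIDx{2}] leaves each device a [2, 2]
# shard; allgathering along DIDx makes each device's local shape [2, 4].
sharded_shape = [4 // len(mesh), 4 // len(mesh[0])]
gathered_shape = [4 // len(mesh), 4]

print(coords(3))           # (1, 1)
print(allgather_group(3))  # [2, 3]
```

The key property the fix preserves: communication happens only within each DIDx group (a mesh row), while DIDy placement is untouched.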

PR Reviewer Guide

Here are some key observations to aid the review process:

🧪 PR contains tests

⚡ Recommended focus areas for review

Logic Flow Changes

The early return check if (!haveDifferentShardings(producer, consumer, {pt})) at lines 348-350 is a key change that affects control flow. It should be validated carefully to ensure it doesn't break existing behavior for 1D meshes while properly handling multi-dimensional cases.

    if (!haveDifferentShardings(producer, consumer, {pt})) {
      return std::nullopt;
    }
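The guard's effect can be modeled with a toy per-parallel-type sharding map. The dict-based representation and function names below are assumptions for illustration, not nvFuser's IR:

```python
# Toy model of the per-parallel-type early return. Shardings are represented
# as {parallel_type: split_factor_or_None}; this representation is hypothetical.

def have_different_shardings(producer, consumer, pts):
    """True if any of the given parallel types is sharded differently."""
    return any(producer.get(pt) != consumer.get(pt) for pt in pts)

def communication_info_for_pt(producer, consumer, pt):
    # Early return: no sharding change on this parallel type means no
    # communication is needed for it, even on a multi-dimensional mesh.
    if not have_different_shardings(producer, consumer, [pt]):
        return None
    return f"needs communication on {pt}"

producer = {"DIDy": 2, "DIDx": 2}     # [iDIDy{2}, iDIDx{2}]
consumer = {"DIDy": 2, "DIDx": None}  # allgathered on DIDx, DIDy unchanged

print(communication_info_for_pt(producer, consumer, "DIDy"))  # None
print(communication_info_for_pt(producer, consumer, "DIDx"))  # needs communication on DIDx
```

Without the guard, DIDy (unchanged on both sides) could fall through into the classification logic and be misreported as resharding.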
Known Limitation

The comment at lines 129-135 acknowledges that the current code is still problematic for multi-dimensional sharding cases. While this PR fixes the immediate issue, the broader limitation should be tracked and addressed in future work.

//
// This code is problematic for multi-dimensional sharding.
// ```
// x: [iDIDy{2}, iDIDx{2}] on mesh [[0, 1], [2, 3]]
// y = set(x): [iDIDy{2}, i{2}] on mesh [[0], [2]]
// ```
// should be treated as non-resharding on DIDy.
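The limitation in that comment can be made concrete with a small model: comparing whole meshes flags a DIDy "change" even though every surviving device keeps its DIDy row. The helper names here are hypothetical; the per-device check is one possible fix, not the implemented one:

```python
# Illustration of the documented limitation (hypothetical helpers).

x_mesh = [[0, 1], [2, 3]]  # x: [iDIDy{2}, iDIDx{2}]
y_mesh = [[0], [2]]        # y = set(x): [iDIDy{2}, i{2}]

def whole_mesh_differs(a, b):
    # The current early return: any mesh difference counts as resharding.
    return a != b

def didy_position(mesh2d, device):
    """Row index (DIDy coordinate) of a device in a 2D mesh, or None."""
    for y, row in enumerate(mesh2d):
        if device in row:
            return y
    return None

# Every device present in the output mesh keeps its DIDy row, so no data
# moves along DIDy and the set should be non-resharding for that parallel type.
same_didy = all(
    didy_position(x_mesh, d) == didy_position(y_mesh, d)
    for row in y_mesh for d in row
)
print(whole_mesh_differs(x_mesh, y_mesh))  # True (current behavior: flagged)
print(same_didy)                           # True (actually unchanged on DIDy)
```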

@wujingyue requested a review from Priya2698 on February 18, 2026 07:08
@wujingyue marked this pull request as ready for review on February 18, 2026 07:09

@wujingyue (Collaborator, Author) commented:

!test

1 similar comment

@wujingyue (Collaborator, Author) commented:

!test

greptile-apps bot (Contributor) commented Feb 18, 2026

Greptile Summary

This PR fixes getCommunicationInfo for multi-dimensional device meshes by adding a per-parallel-type haveDifferentShardings guard at the top of getCommunicationInfoForParallelType. Previously, the function could misclassify communication types when a 2D mesh had sharding changes on only one dimension (e.g., an allgather on DIDx while DIDy stays the same), because it conflated "both loop IDs present" with "resharding needed." The fix ensures that parallel types with identical shardings are skipped early, and the broadcast fallback for different-mesh/unsharded cases is moved from the caller into the per-PT function for better locality.

• Adds a haveDifferentShardings(producer, consumer, {pt}) check at the start of getCommunicationInfoForParallelType to skip parallel types with no sharding change
• Moves the broadcast-for-different-meshes logic from getCommunicationInfo into the per-PT handler
• Replaces soft return std::nullopt with NVF_ERROR assertions where the haveDifferentShardings guard guarantees the condition should never occur
• Removes the now-redundant "check if p_logical_id is reduced in the output" block for ReduceScatter
• Documents a known remaining limitation in haveDifferentShardings for multi-dimensional meshes with different mesh shapes
• Adds a test_allgather_2d regression test exercising the 2D-mesh allgather scenario from issue #4604 (Fix convertSingleOpToCommunication for 2D sharding)

Confidence Score: 4/5

• This PR is safe to merge; the logic changes are well-reasoned and a regression test covers the fixed scenario.
• The core fix is sound: the early per-PT haveDifferentShardings guard correctly filters out parallel types with no sharding change, preventing misclassification for multi-dimensional meshes. The assertion tightening follows logically from the new guard. One point is docked because the documented limitation in haveDifferentShardings for different-mesh multi-dimensional sharding remains unfixed (though it is pre-existing and out of scope for this PR).
• csrc/host_ir/lower_to_communication.cpp contains the core logic changes; the new assertions replacing soft returns should be verified against edge cases in multi-device CI.

Important Files Changed

Filename | Overview
csrc/host_ir/lower_to_communication.cpp | Core fix: adds an early per-parallel-type haveDifferentShardings check in getCommunicationInfoForParallelType, moves the broadcast fallback into the per-PT function, removes now-redundant checks, and tightens assertions. Logic is sound and well-reasoned.
csrc/multidevice/resharding.cpp | Documents a known limitation of the different-mesh early-return path in haveDifferentShardings for multi-dimensional sharding. Comment-only change; no logic altered.
tests/python/multidevice/test_communication.py | Adds a test_allgather_2d regression test for the 2D-mesh allgather scenario that exercises the fix. The test follows existing patterns and properly skips when the device count is not divisible by tp_size.

Flowchart

```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["getCommunicationInfo(Expr* e)"] --> B["Loop over kParallelTypeDIDs\n(DIDx, DIDy, DIDz)"]
    B --> C["getCommunicationInfoForParallelType\n(producer, consumer, pt)"]
    C --> D{"haveDifferentShardings\n(producer, consumer, {pt})"}
    D -- "No sharding change" --> E["return nullopt\n(skip this PT)"]
    D -- "Has sharding change" --> F{"consumer def type?"}
    F -- "LoadStoreOp" --> G{"p_loop_id / c_loop_id?"}
    G -- "both null\n(different meshes)" --> H["Broadcast"]
    G -- "p only" --> I{"same mesh?"}
    I -- "yes" --> J["Allgather"]
    I -- "no" --> K["Gather"]
    G -- "c only" --> L["Scatter"]
    G -- "both present" --> M{"same logical ID?"}
    M -- "yes" --> N["SendRecv"]
    M -- "no" --> O["AllToAll"]
    F -- "ReductionOp/SqueezeOp" --> P{"c_loop_id?"}
    P -- "null" --> Q{"same mesh?"}
    Q -- "yes" --> R["Allreduce"]
    Q -- "no" --> S["Reduce"]
    P -- "present" --> T["ReduceScatter"]
```

Last reviewed commit: 807a8e6
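The flowchart's classification branches can be expressed as a small decision function. This is a model of the chart, not the actual C++; the parameter names are simplified stand-ins for nvFuser's loop IDs and meshes:

```python
# Decision-tree model of the flowchart's communication classification.

def classify_load_store(p_loop_id, c_loop_id, same_mesh, same_logical_id=True):
    """Mirror of the LoadStoreOp branch in the flowchart."""
    if p_loop_id is None and c_loop_id is None:
        return "Broadcast"  # different meshes, unsharded on this PT
    if p_loop_id is not None and c_loop_id is None:
        return "Allgather" if same_mesh else "Gather"
    if p_loop_id is None:
        return "Scatter"
    return "SendRecv" if same_logical_id else "AllToAll"

def classify_reduction(c_loop_id, same_mesh):
    """Mirror of the ReductionOp/SqueezeOp branch."""
    if c_loop_id is None:
        return "Allreduce" if same_mesh else "Reduce"
    return "ReduceScatter"

print(classify_load_store("DIDx", None, same_mesh=True))  # Allgather
print(classify_reduction(None, same_mesh=True))           # Allreduce
```

Note that the early haveDifferentShardings guard runs before either branch, so an unchanged parallel type never reaches these classifiers.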

greptile-apps bot left a comment

3 files reviewed, 1 comment

greptile-apps bot (Contributor) commented Feb 18, 2026

Additional Comments (1)

csrc/multidevice/resharding.cpp
Known limitation affects callers of this function

The new code in getCommunicationInfo relies on calling haveDifferentShardings(producer, consumer, {pt}) per parallel type to skip unchanged dimensions. However, the early return at line 140 means that when the meshes differ (e.g., [[0, 1], [2, 3]] vs [[0], [2]]), haveDifferentShardings(..., {DIDy}) will return true even though the DIDy sharding is unchanged, as the comment correctly notes.

This means getCommunicationInfo could produce incorrect results for multi-dimensional resharding across different meshes (e.g., a Gather on one DID while the other DID remains unchanged but the meshes differ). The same-mesh allgather case in the new test works because the meshes are identical. Is the cross-mesh multi-dimensional sharding case intentionally deferred as a future fix, or should it be addressed as part of this PR?

@wujingyue (Collaborator, Author) commented:

!test

greptile-apps bot left a comment

3 files reviewed, 1 comment

greptile-apps bot (Contributor) commented Feb 20, 2026

Additional Comments (1)

csrc/host_ir/lower_to_communication.cpp

The removed code checked p2c.at(p_logical_id)->isReduction() before returning ReduceScatter. Verify that this assumption always holds when both p_loop_id and c_loop_id exist in a ReductionOp/SqueezeOp.

@wujingyue (Collaborator, Author) commented:

!test

greptile-apps bot left a comment

3 files reviewed, no comments

// x: [iDIDy{2}, iDIDx{2}] on mesh [[0, 1], [2, 3]]
// y = set(x): [iDIDy{2}, i{2}] on mesh [[0], [2]]
// ```
// should be treated as non-resharding on DIDy.
Collaborator commented:

Instead of the device mesh, we should be comparing the slices corresponding to the parallel types here.

Collaborator Author replied:

Does it work? The following case is also non-resharding:

x: [iDIDy{2}, iDIDx{2}] on mesh [[0, 1], [2, 3]]
y = set(x): [iDIDy{2}, i{2}] on mesh [[0], [3]]

@wujingyue merged commit b5207d5 into main on Feb 21, 2026
56 checks passed
@wujingyue deleted the wjy/comm branch on February 21, 2026 04:37