DID loop split for scatter #4191

Merged: Priya2698 merged 21 commits into main from pm/scatter on May 20, 2025
Conversation

@Priya2698 (Collaborator) commented Apr 4, 2025:

Adds support for scatter when using loop split for sharding.
Prepares for #3900

As with allgather, scatter requires the scattered axis to be outermost in the allocation domain to produce correct results.

@Priya2698 (Collaborator, Author):

!test

github-actions bot commented Apr 4, 2025:

Review updated until commit 846fbee

Description

  • Added support for scatter in loop split for sharding.

  • Ensured correct root selection for scatter operations.

  • Added tests for ScatterLoopSplit.

  • Improved error handling for non-contiguous tensors.


Changes walkthrough 📝

Relevant files

Enhancement

  csrc/host_ir/lower_to_communication.cpp: Improved root selection for scatter (+12/-5)
    • Selected a common device between the input and receiver meshes as the root.
    • Removed the arbitrary root selection logic.

  csrc/multidevice/communication.cpp: Enhanced scatter communication logic (+28/-17)
    • Added error checks for contiguous tensors.
    • Ensured the root is in the output device mesh.
    • Used as_strided for the output tensor.
    • Split the input tensor correctly for scatter.

Tests

  tests/cpp/test_multidevice_lower_communication.cpp: Added scatter loop split test (+46/-0)
    • Added a test case for ScatterLoopSplit.

  tests/cpp/test_multidevice_pipeline.cpp: Updated scatter test meshes (+2/-2)
    • Updated mesh values for scatter tests.

PR Reviewer Guide 🔍

Here are some key observations to aid the review process:

🧪 PR contains tests
⚡ Recommended focus areas for review

Root Device Validation

Ensure that the logic for finding a common device between the input and receiver meshes is robust and handles edge cases, such as when no common device exists.
    auto it = std::ranges::find_if(
        input_tv->getDeviceMesh().vector(),
        [&receiver_mesh](DeviceIdxType device) {
          return receiver_mesh.has(device);
        });
    NVF_ERROR(
        it != input_tv->getDeviceMesh().vector().end(),
        "No common device found between input and receiver meshes");
    DeviceIdxType root = *it;
    
Tensor Contiguity Check

Verify that the checks for tensor contiguity are necessary and do not introduce unnecessary overhead or restrict valid use cases.

    NVF_ERROR(
        isTvContiguous(communication->in()), "Input tensor is not contiguous");
    NVF_ERROR(
        isTvContiguous(communication->out()), "Output tensor is not contiguous");
Test Coverage

Ensure that the new test case covers all relevant scenarios and edge cases for the scatter operation with loop split.

    TEST_P(LowerCollectiveTest, ScatterLoopSplit) {
      auto fusion = std::make_unique<Fusion>();
      FusionGuard fg(fusion.get());
      const auto d = communicator_->size();
      auto full_mesh = DeviceMesh::createForNumDevices(d);
    
      DeviceMesh mesh_zero({0});
      TensorView* tv0 = makeConcreteTensor({5, d * 3});
      TensorView* tv1 = set(tv0);
    
      tv0->setDeviceMesh(mesh_zero);
      tv0->outer_split(1, d);
      tv0->axis(1)->parallelize(ParallelType::Serial);
      tv0->reorder({2, 0, 1});
    
      tv1->setDeviceMesh(full_mesh);
      tv1->outer_split(1, d);
      tv1->axis(1)->parallelize(ParallelType::DIDx);
      tv1->reorder({2, 0, 1});
    
      fusion->addInput(tv0);
      fusion->addOutput(tv1);
    
      for (auto tv : {tv0, tv1}) {
        tv->setAllocationDomain(tv->getLoopDomain(), true);
      }
    
      at::Tensor unsharded_in_tensor =
          at::randn({d * 3, 5}, tensor_options).transpose(0, 1);
    
      at::Tensor expected_output = shardTensor(unsharded_in_tensor, 1, full_mesh);
    
      FusionExecutorCache executor_cache(std::move(fusion));
      at::Tensor out_tensor =
          executor_cache.runFusionWithInputs({unsharded_in_tensor})[0]
              .as<at::Tensor>();
    
      testValidate(
          executor_cache.fusion(),
          {out_tensor},
          {unsharded_in_tensor},
          {expected_output},
          __LINE__,
          __FILE__);
    }
    
    INSTANTIATE_TEST_SUITE_P(

@Priya2698 (Collaborator, Author):

!test

(1 similar comment)
@Priya2698 requested review from cowanmeg and wujingyue April 4, 2025 19:15
@Priya2698 marked this pull request as ready for review April 4, 2025 19:15

@Priya2698 (Collaborator, Author):

The Scatter/PipelineTestTwoStages.Communication/36 test is failing on CI. I am debugging this.

@wujingyue (Collaborator) left a comment:

I'll review this after #4170

@Priya2698 (Collaborator, Author):

!test

(2 similar comments)

A reviewer commented on the diff:

    !communication->out()->getDeviceMesh().has(communication->root())) {
    input_tensors.front().push_back(output_tensor);
    !output_has_root) {
    input_tensors.front().push_back(at::empty_like(splits.at(0)));

do we necessarily need to allocate a new buffer for inputs here?

@Priya2698 (Collaborator, Author):

!test

@Priya2698 (Collaborator, Author) commented on the diff:

    testing::Values(CommunicatorBackend::kNccl, CommunicatorBackend::kUcc),
    all_meshes,
    all_meshes,
    testing::Values(mesh0, mesh1),

Changing this to avoid cases where the root is not in the output device mesh.

@Priya2698 (Collaborator, Author):

!test

@Priya2698 merged commit ee0d3f7 into main May 20, 2025. 53 checks passed.
@Priya2698 deleted the pm/scatter branch May 20, 2025 16:26.

3 participants