DID loop split for allgather for non-outermost sharded axis.#4170

Merged
Priya2698 merged 19 commits into main from pm/reorder on Apr 11, 2025
Conversation

@Priya2698
Collaborator

@Priya2698 Priya2698 commented Apr 2, 2025

Adds support for allgather if the sharded axis is not outermost.
ProcessGroupNCCL and UCC require the sharded axis to be outermost in the allocation. We do not change the logical shape; instead, within postAllgather, we permute the tensors to meet the requirements of NCCL and UCC.

This will be added to the reorderShardedAxis preseg pass to correctly set the loop and allocation domains for the Allgather communication. Additionally, a set operation is needed to change the allocation of the input if its sharded axis is not the outermost allocated axis.
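As a standalone illustration of the layout trick (a hypothetical helper, not nvFuser code): with the gathered axis of extent d allocated outermost, a buffer allocated as [d, m, n] holds each rank's m*n chunk contiguously, while the unchanged logical shape [n, d*m] is just a permuted read of the same storage. The index arithmetic looks like:

```cpp
#include <cassert>

// Hypothetical helper (illustration only, not nvFuser code): offset into a
// gather buffer allocated as [d, m, n] (gathered axis outermost) for logical
// element (i, j) of the logical [n, d*m] tensor, where the second logical
// axis was outer-split into [d, m].
int allocOffset(int m, int n, int i, int j) {
  int r = j / m;    // which rank's chunk holds logical column j
  int col = j % m;  // column inside that chunk
  return (r * m + col) * n + i;  // allocation order [d, m, n], i innermost
}
```

Because rank r's m*n chunk sits contiguously at offset r*m*n, NCCL/UCC can gather directly into the buffer; the logical view only permutes how the same storage is read.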

@Priya2698
Collaborator Author

!test

@github-actions

github-actions bot commented Apr 2, 2025

Review updated until commit 1c8204c

Description

  • Added validation of tensor sizes and strides against tensorviews.

  • Ensured input and output tensors are contiguous for Allgather operations.

  • Updated tests to include noncontiguous tensors and multiple backends.


Changes walkthrough 📝

Relevant files

Enhancement

  executor.cpp: Add tensor validation in executor (csrc/host_ir/executor.cpp, +27/-0)
  • Added validateTensors function to validate tensor sizes and strides.
  • Called validateTensors in HostIrExecutor::run and HostIrEvaluator::handle.

  communication.cpp: Ensure tensor contiguity in Allgather (csrc/multidevice/communication.cpp, +30/-4)
  • Added isTvContiguous function to check TensorView contiguity.
  • Flattened input and output tensors in postAllgather to ensure contiguity.

Tests

  test_multidevice_communications.cpp: Include additional headers in tests (tests/cpp/test_multidevice_communications.cpp, +2/-0)
  • Included ops/all_ops.h and validator.h for additional operations and validation.

  test_multidevice_host_ir.cpp: Set tensor contiguity in tests (tests/cpp/test_multidevice_host_ir.cpp, +9/-0)
  • Set contiguity for communication input and output tensors in tests.

  test_multidevice_lower_communication.cpp: Update and add tests for Allgather (tests/cpp/test_multidevice_lower_communication.cpp, +81/-67)
  • Refactored LowerCollectiveTest to parameterize by backend and enable HostIrLowering.
  • Added AllgatherLoopSplit_Noncontig test for noncontiguous tensors.

    PR Reviewer Guide 🔍

    Here are some key observations to aid the review process:

    🧪 PR contains tests
    ⚡ Recommended focus areas for review

    Performance Impact

    The addition of validateTensors calls in multiple places may introduce performance overhead. Ensure that this validation is necessary and does not degrade performance.

    namespace {
    // Validates the sizes and strides of the input and output tensors
    // against the tensorviews
    void validateTensors(
        const std::vector<at::Tensor>& tensors,
        const std::vector<TensorView*>& tvs,
        const ExpressionEvaluator& expr_eval) {
      NVF_ERROR(tensors.size() == tvs.size());
      for (const auto& [tensor, tv] : zip(tensors, tvs)) {
        if (tensor.defined()) {
          inferAndValidateAllocationSizesAndStrides(tensor, tv, expr_eval);
        }
      }
    }
    Contiguity Check

    The isTvContiguous function checks if all axes are contiguous, which might be too strict for some use cases. Consider if a more flexible contiguity check is needed.

    bool isTvContiguous(const TensorView* tv) {
      // Reduction and broadcast axis do not have a contiguity value.
      return std::all_of(
          tv->getContiguity().begin(),
          tv->getContiguity().end(),
          [](std::optional<bool> c) { return c.value_or(true); });
    }
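The predicate can be exercised in isolation; this is a self-contained sketch of the same logic (allContiguous is an illustrative name, not the nvFuser helper itself):

```cpp
#include <algorithm>
#include <cassert>
#include <optional>
#include <vector>

// Standalone sketch of the same predicate: axes without a contiguity value
// (reduction/broadcast) do not constrain the result, via value_or(true).
bool allContiguous(const std::vector<std::optional<bool>>& contiguity) {
  return std::all_of(
      contiguity.begin(), contiguity.end(), [](std::optional<bool> c) {
        return c.value_or(true);
      });
}
```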
    Test Coverage

    The new test AllgatherLoopSplit_Noncontig is a good addition, but ensure that it covers all edge cases and does not introduce false positives.

    TEST_P(LowerCollectiveTest, AllgatherLoopSplit_Noncontig) {
      auto fusion = std::make_unique<Fusion>();
      FusionGuard fg(fusion.get());
    
      // ProcessGroupNCCL requires the gathered axis to be outermost.
      // We change the allocation of tensorviews to reflect this.
      // We do not modify the logical shape of the tensorview.
      // This would still require one copy on each device if the input tensor is in
      // a different layout.
      const auto d = communicator_->size();
      auto mesh = DeviceMesh::createForNumDevices(d);
    
      TensorView* tv0 = makeConcreteTensor({5, d * 3});
      tv0->outer_split(1, d);
      tv0->axis(1)->parallelize(ParallelType::DIDx);
      tv0->reorder({{1, 0}, {2, 1}, {0, 2}});
      // tv0: Logical = [5, d*3], Loop/Allocation = [DIDx(d), 3, 5]
    
      TensorView* tv1 = set(tv0);
      tv1->outer_split(1, d);
      tv1->axis(1)->parallelize(ParallelType::Serial);
      tv1->reorder({{1, 0}, {2, 1}, {0, 2}});
      // tv1: Logical = [5, d*3], Loop/Allocation = [Serial(d), 3, 5]
    
      for (auto tv : {tv0, tv1}) {
        tv->setDeviceMesh(mesh);
        tv->setAllocationDomain(tv->getLoopDomain(), true);
      }
    
      fusion->addInput(tv0);
      fusion->addOutput(tv1);
    
      at::Tensor unsharded_in_tensor = at::randn({d * 3, 5}, tensor_options);
      at::Tensor in_tensor =
          shardTensor(unsharded_in_tensor, 0, mesh).transpose(0, 1);
    
      FusionExecutorCache executor_cache(std::move(fusion));
      at::Tensor out_tensor =
          executor_cache.runFusionWithInputs({in_tensor})[0].as<at::Tensor>();
    
      testValidate(
          executor_cache.fusion(),
          {out_tensor},
          {in_tensor},
          {unsharded_in_tensor.transpose(0, 1)},
          __LINE__,
          __FILE__);
    }
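To sanity-check the layouts this test expects, the contiguous row-major strides of an allocation can be computed by hand; a small standalone sketch (contiguousStrides is an illustrative name, not nvFuser code):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Row-major (innermost-last) strides for a contiguous allocation; useful to
// check layouts such as allocation [d, 3, 5] with d = 2 against what the
// executor validates.
std::vector<int64_t> contiguousStrides(const std::vector<int64_t>& sizes) {
  std::vector<int64_t> strides(sizes.size(), 1);
  for (int64_t i = static_cast<int64_t>(sizes.size()) - 2; i >= 0; --i) {
    strides[i] = strides[i + 1] * sizes[i + 1];
  }
  return strides;
}
```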

    @Priya2698
    Collaborator Author

    !test

    @Priya2698 Priya2698 changed the title from "allgather loop split, contig + noncontig" to "DID loop split for allgather for non-outermost sharded axis." Apr 4, 2025
    @Priya2698
    Collaborator Author

    !test

    @Priya2698 Priya2698 marked this pull request as ready for review April 4, 2025 19:09
    @Priya2698 Priya2698 requested review from cowanmeg and wujingyue April 4, 2025 19:09
    @wujingyue
    Collaborator

    LGTM otherwise. Thanks for the change!

    @Priya2698
    Collaborator Author

    !test

    @Priya2698 Priya2698 requested a review from wujingyue April 8, 2025 02:04
    @wujingyue wujingyue left a comment (Collaborator)

    LGTM with comments

    // Presegmentation pass `makeReshardingContiguous` ensures that the tvs are contiguous
    // and HostIrExecutor validates the tensor against the tv allocation domain.

    auto flattened_output_tensor = output_tensor.as_strided({output_tensor.numel()}, {1});
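The as_strided({numel}, {1}) flattening above is only sound when the tensor is dense and row-major; a standalone sketch of that precondition (isRowMajorContiguous is an illustrative name, not nvFuser code):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// A flat {numel}, {1} view is only valid when sizes/strides describe a dense
// row-major layout; this standalone check mirrors that precondition.
bool isRowMajorContiguous(
    const std::vector<int64_t>& sizes,
    const std::vector<int64_t>& strides) {
  int64_t expected = 1;
  for (int64_t i = static_cast<int64_t>(sizes.size()) - 1; i >= 0; --i) {
    if (sizes[i] != 1 && strides[i] != expected) {
      return false;
    }
    expected *= sizes[i];
  }
  return true;
}
```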
    Collaborator

    Also check contiguity of communication->in() and out()?

    Collaborator Author

    It is already enforced by makeReshardingContiguous pass so I am not duplicating it here.

    Collaborator

    Nit: makeReshardingContiguous is a bit too far and many changes could happen in between. For example, makeReshardingContiguous runs before segmentation and postSingleCommunication is at runtime. makeReshardingContiguous works on fusion IR e.g. set and reduce, and postSingleCommunication works on host IR e.g. Communication.

    Collaborator Author

    I am running into some test failures with a contiguity check here for the manual tests in test_multidevice_host_ir.cpp. Since these tests do not set an allocation domain, we have the contiguity set to false. How cumbersome is it to require manual tests also to have the allocation domain set correctly?
    CC: @samnordmann

    Collaborator

    How cumbersome is it to require manual tests also to have the allocation domain set correctly?

    IIUC, it should be set correctly, so let's set it correctly. The change can't be too large because test_multidevice_host_ir.cpp is <500 lines and that file has <20 calls to make*Tensor, many of which are Contig already.

    Collaborator

    hey Priya, I am not sure how cumbersome that is -- but if needed, feel free to do it, and please let me know how it looks.
    Let me know also if you need help.

    @Priya2698
    Collaborator Author

    !test

    auto out_tensor = output_args[out_idx].as<at::Tensor>();

    c10::intrusive_ptr<c10d::Work> work = postSingleCommunication(
    c10::intrusive_ptr<c10d::Work> work = validateAndPostSingleCommunication(
    Collaborator

    I'll skip validation for HostIrExecutor. It's done elsewhere already.

    Input:

    inferAndValidateAllocationSizesAndStrides(input, tv, ee);

    Output:

    inferAndValidateAllocationSizesAndStrides(tensor, tv, expr_eval);

    Collaborator Author

    Outputs are validated when they are not provided, i.e., when output_args is empty. If outputs are provided, there is no validation. Inputs can be skipped, like you said.

    I am looking at the callgraph to see what is the case where output_args comes pre-allocated.

    Collaborator

    what is the case where output_args comes pre-allocated

    I looked for this before. I found only unit tests that explicitly call KernelExecutor, not via FusionExecutorCache.

    Collaborator Author

    Right. FusionExecutorCache provides an empty output_args.
    I'll leave the validation in for just the output tensor and skip it for inputs.

    @cowanmeg cowanmeg left a comment (Collaborator)

    LGTM! Thanks!

    Comment on lines +100 to +102
    for (const auto i : c10::irange(tensors.size())) {
    const auto& tensor = tensors.at(i);
    const auto& tv = tvs.at(i);
    Collaborator

    FWIW, you can zip instead:

    for (auto&& [id, new_id] : zip(self_logical, new_self_logical)) {

    Welcome to C++20!
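For readers unfamiliar with the pattern, a minimal standalone sketch of such a zip helper (zipped is an illustrative name; nvFuser ships its own zip utility, and std::views::zip only arrives in C++23):

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Minimal zip-style pairing for two vectors, truncating to the shorter one.
// Enables structured-binding iteration: for (auto&& [a, b] : zipped(x, y)).
template <typename A, typename B>
std::vector<std::pair<A, B>> zipped(
    const std::vector<A>& a,
    const std::vector<B>& b) {
  std::vector<std::pair<A, B>> out;
  const size_t n = a.size() < b.size() ? a.size() : b.size();
  out.reserve(n);
  for (size_t i = 0; i < n; ++i) {
    out.emplace_back(a[i], b[i]);
  }
  return out;
}
```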

    }
    // getBackendForTeam throws an error if the requested backend type isn't
    // available. Therefore, we call it after the isBackendAvailable check.
    communicator_->setDefaultBackend(backend_type);
    Collaborator Author

    I discovered that while this allows me to set the backend when using FusionExecutorCache, communication->backend() is still different from the backend parameter passed to postSingleCommunication.
    getBackendForTeam, when not given a backend, returns the default backend set for the communicator.
    For executions running through HostIrEvaluator, this approach will not work.

    We should change this so that the communicator and the Communication object use the same backend value, for easier verification. I'll attempt this in a future PR.

    @Priya2698
    Collaborator Author

    !test

    @Priya2698 Priya2698 merged commit 35f4aed into main Apr 11, 2025
    53 checks passed
    @Priya2698 Priya2698 deleted the pm/reorder branch April 11, 2025 17:32
    Priya2698 added a commit that referenced this pull request May 18, 2025
    …utor (#4470)
    
    Prep PR for Issue #3900.
    
    I am modifying the `reorderShardedAxisPass` to set allocation domain
    consistent with the memory layout requirements of ProcessGroup NCCL and
    UCC, without changing the logical shape (see PR #4170 for example).
    
    MultiDeviceExecutor does not respect allocation domain, hence, removing
    these tests. Issue #4453.
