Skip to content

Remove tests with inner sharded dimensions when using MultiDeviceExecutor#4470

Merged
Priya2698 merged 3 commits intomainfrom
pm/remove_inner_sharding_tests
May 18, 2025
Merged

Remove tests with inner sharded dimensions when using MultiDeviceExecutor#4470
Priya2698 merged 3 commits intomainfrom
pm/remove_inner_sharding_tests

Conversation

@Priya2698
Copy link
Collaborator

Prep PR for Issue #3900.

I am modifying the reorderShardedAxisPass to set allocation domain consistent with the memory layout requirements of ProcessGroup NCCL and UCC, without changing the logical shape (see PR #4170 for example).

MultiDeviceExecutor does not respect allocation domain, hence, removing these tests. Issue #4453.

@github-actions
Copy link

github-actions bot commented May 16, 2025

Review updated until commit 3be85a4

Description

  • Removed tests with inner sharded dimensions for MultiDeviceExecutor

  • Added stride comparison in validation

  • Updated test instantiations to use only one device


Changes walkthrough 📝

Relevant files
Enhancement
test_multidevice_pipeline.cpp
Enhance tests and validation for MultiDeviceExecutor         

tests/cpp/test_multidevice_pipeline.cpp

  • Included additional headers for optimization passes
  • Added OptimizationPassGuard for ReorderShardedAxisPass
  • Modified constructor to initialize optimization_guard_
  • Added stride comparison in validate method
  • Updated test instantiations to use only one device
  • +15/-5   

    PR Reviewer Guide 🔍

    Here are some key observations to aid the review process:

    🧪 No relevant tests
    ⚡ Recommended focus areas for review

    Test Removal

    The PR removes tests with inner sharded dimensions when using MultiDeviceExecutor. It is important to ensure that the removed tests are no longer relevant or that equivalent tests are added to cover the same scenarios.

    DeviceMesh mesh1({1});
    DeviceMesh mesh2({0, 1, 2, 3});
    DeviceMesh mesh3({0, 2, 3});
    DeviceMesh mesh4({1, 0, 2});
    DeviceMesh mesh5({1, 0});
    auto all_meshes = testing::Values(mesh0, mesh1, mesh2, mesh3, mesh4, mesh5);
    auto all_nontrivial_meshes = testing::Values(mesh2, mesh3, mesh4, mesh5);
    
    } // namespace
    
    INSTANTIATE_TEST_SUITE_P(
        Gather,
        PipelineTestTwoStages,
        testing::Combine(
            testing::Values(CommunicatorBackend::kNccl, CommunicatorBackend::kUcc),
            all_meshes,
            all_meshes,
            testing::Values(true),
            testing::Values(false),
            testing::Values(false),
            testing::Values(0),
            testing::Bool()));
    
    INSTANTIATE_TEST_SUITE_P(
        Scatter,
        PipelineTestTwoStages,
        testing::Combine(
            testing::Values(CommunicatorBackend::kNccl, CommunicatorBackend::kUcc),
            all_meshes,
            all_meshes,
            testing::Values(false),
            testing::Values(true),
            testing::Values(false),
            testing::Values(0),
            testing::Bool()));
    
    INSTANTIATE_TEST_SUITE_P(
        Bcast,
        PipelineTestTwoStages,
        testing::Combine(
            testing::Values(CommunicatorBackend::kNccl, CommunicatorBackend::kUcc),
            all_meshes,
            all_meshes,
            testing::Values(false),
            testing::Values(false),
            testing::Values(false),
            testing::Values(0, 1),
            testing::Bool()));
    
    INSTANTIATE_TEST_SUITE_P(
        Bcast_sharded,
        PipelineTestTwoStages,
        testing::Combine(
            // TODO(#2794): add back CommunicatorBackend::kUcc
            testing::Values(CommunicatorBackend::kNccl),
            testing::Values(mesh3, mesh4),
            testing::Values(mesh3, mesh4),
            testing::Values(true),
            testing::Values(true),
            testing::Values(false),
            testing::Values(0, 1),
            testing::Bool()));
    
    INSTANTIATE_TEST_SUITE_P(
        Bcast_sharded_same_mesh,
        PipelineTestTwoStages,
        testing::Combine(
            testing::Values(CommunicatorBackend::kNccl, CommunicatorBackend::kUcc),
            testing::Values(mesh0, mesh1),
            testing::Values(mesh_null), // the same mesh is used for all tensors
            testing::Values(true),
            testing::Values(true),
            testing::Values(false),
            testing::Values(0, 1),
            testing::Bool()));
    
    INSTANTIATE_TEST_SUITE_P(
        Reduce,
        PipelineTestTwoStages,
        testing::Combine(
    Constructor Change

    The constructor for PipelineTest now initializes optimization_guard_ with false. This change should be validated to ensure it does not affect the behavior of the tests.

    PipelineTest::PipelineTest() : optimization_guard_(false) {
      fusion = std::make_unique<Fusion>();
    Stride Comparison

    A new stride comparison is added in the validate method. This should be verified to ensure it correctly checks the expected behavior and does not introduce any unintended side effects.

    EXPECT_EQ(ref_output.strides(), obtained_output.strides())
        << "Strides are not equal: Ref: " << ref_output.strides()
        << " Output: " << obtained_output.strides() << std::endl;
    

    shardTensor(ref_unsharded_outputs[i].as<at::Tensor>(), output_tv);
    auto obtained_output = outputs[i].as<at::Tensor>();

    EXPECT_EQ(ref_output.strides(), obtained_output.strides()) << "Strides are not equal: Ref: " << ref_output.strides() << " Output: " << obtained_output.strides() << std::endl;
    Copy link
    Collaborator Author

    Choose a reason for hiding this comment

    The reason will be displayed to describe this comment to others. Learn more.

    torch::allclose checks for sizes but not strides, so leaving this verification here for future.

    @Priya2698
    Copy link
    Collaborator Author

    !test

    @Priya2698
    Copy link
    Collaborator Author

    !build

    @Priya2698 Priya2698 merged commit d4968a5 into main May 18, 2025
    16 checks passed
    @Priya2698 Priya2698 deleted the pm/remove_inner_sharding_tests branch May 18, 2025 02:52
    Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

    Labels

    None yet

    Projects

    None yet

    Development

    Successfully merging this pull request may close these issues.

    2 participants