From e0abd84bb5265f2c3ac2054ba908895a8412516b Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Tue, 1 Apr 2025 22:27:47 -0700 Subject: [PATCH 01/21] allgather loop split, contig + noncontig --- csrc/scheduler/pointwise.cpp | 2 +- tests/cpp/test_multidevice_communications.cpp | 91 +++++++++++++++++++ 2 files changed, 92 insertions(+), 1 deletion(-) diff --git a/csrc/scheduler/pointwise.cpp b/csrc/scheduler/pointwise.cpp index a9d72e29540..5dc8a98803e 100644 --- a/csrc/scheduler/pointwise.cpp +++ b/csrc/scheduler/pointwise.cpp @@ -347,7 +347,7 @@ std::unique_ptr getPointwiseHeuristics( auto& view_disjoint_sets = broadcast_info.get().view_disjoint_set_ids; auto& broadcast_byte_multiples = broadcast_info.get().broadcast_multiples; - NVF_ERROR(broadcast_byte_multiples.size() == ref_loop.size()); + NVF_ERROR(broadcast_byte_multiples.size() == largest_out->getLogicalDomain().size()); int64_t dtype_sum = 0; for (auto inp : ir_utils::filterByType(fusion->inputs())) { diff --git a/tests/cpp/test_multidevice_communications.cpp b/tests/cpp/test_multidevice_communications.cpp index af0c0719aa7..0bcdae8d012 100644 --- a/tests/cpp/test_multidevice_communications.cpp +++ b/tests/cpp/test_multidevice_communications.cpp @@ -17,6 +17,8 @@ #include #include #include +#include + #include @@ -409,6 +411,95 @@ TEST_P(CommunicationTest, ReduceScatter) { } } +TEST_P(CommunicationTest, AllgatherLoopSplit_NonContiguous) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + const auto d = communicator_->size(); + + TensorView* tv0 = makeConcreteTensor({5, d*3}); + tv0->outer_split(1, d); + tv0->axis(1)->parallelize(ParallelType::DIDx); + reorderDIDToFront(tv0); + + TensorView* tv1 = permute(tv0, {{1, 0}}); + tv1->outer_split(0, d); + tv1->axis(0)->parallelize(ParallelType::DIDx); + + TensorView* tv2 = set(tv1); + tv2->outer_split(0, d); + tv2->axis(0)->parallelize(ParallelType::Serial); + + TensorView* tv3 = permute(tv2, {{0, 1}}); + tv3->outer_split(1, d); + tv3->axis(1)->parallelize(ParallelType::Serial); + tv3->reorder({{1, 0}, {2, 1}, {0, 2}}); + + for (auto tv : {tv0, tv1, tv2, tv3}) { + tv->setDeviceMesh(full_mesh_); + tv->setAllocationDomain(tv->getLoopDomain(), true); + } + + fusion->addInput(tv0); + fusion->addOutput(tv3); + + at::Tensor unsharded_in_tensor = at::randn({5, d*3}, tensor_options); + at::Tensor in_tensor = shardTensor(unsharded_in_tensor, tv0); + + FusionExecutorCache executor_cache(std::move(fusion)); + at::Tensor out_tensor = + executor_cache.runFusionWithInputs({in_tensor})[0].as(); + testValidate( + executor_cache.fusion(), + {out_tensor}, + {in_tensor}, + {unsharded_in_tensor}, + __LINE__, + __FILE__); +} + +TEST_P(CommunicationTest, AllgatherLoopSplit_Contiguous) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + const auto d = communicator_->size(); + + TensorView* tv0 = makeConcreteTensor({d*3, 5}); + TensorView* tv1 = set(tv0); + TensorView* tv2 = permute(tv1, {{1, 0}}); + fusion->addInput(tv0); + fusion->addOutput(tv2); + + tv0->outer_split(0, d); + TransformPropagator propagator(tv0); + SetSelector selector({tv1, tv2}); + MaxLogicalDomainInfoSpanningTree(tv0, &selector).traverse(&propagator); + + tv0->setDeviceMesh(full_mesh_); + shardAllLike(tv0, {tv1, tv2}); + + tv0->axis(0)->parallelize(ParallelType::DIDx); + + for (auto tv : fusion->allTvs()) { + tv->setAllocationDomain(tv->getLoopDomain(), true); + } + + at::Tensor unsharded_in_tensor = at::randn({d*3, 5}, tensor_options); + at::Tensor in_tensor = 
shardTensor(unsharded_in_tensor, tv0); + + FusionExecutorCache executor_cache(std::move(fusion)); + at::Tensor out_tensor = + executor_cache.runFusionWithInputs({in_tensor})[0].as(); + + testValidate( + executor_cache.fusion(), + {out_tensor}, + {in_tensor}, + {unsharded_in_tensor.transpose(0, 1)}, + __LINE__, + __FILE__); +} + INSTANTIATE_TEST_SUITE_P( , CommunicationTest, From 91757dd90151d0aeab607fd3a74fa84f74f31de5 Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Tue, 1 Apr 2025 23:43:04 -0700 Subject: [PATCH 02/21] no devices logical domain --- csrc/scheduler/pointwise.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/scheduler/pointwise.cpp b/csrc/scheduler/pointwise.cpp index 5dc8a98803e..df29061e432 100644 --- a/csrc/scheduler/pointwise.cpp +++ b/csrc/scheduler/pointwise.cpp @@ -347,7 +347,7 @@ std::unique_ptr getPointwiseHeuristics( auto& view_disjoint_sets = broadcast_info.get().view_disjoint_set_ids; auto& broadcast_byte_multiples = broadcast_info.get().broadcast_multiples; - NVF_ERROR(broadcast_byte_multiples.size() == largest_out->getLogicalDomain().size()); + NVF_ERROR(broadcast_byte_multiples.size() == TensorDomain::noDevices(largest_out->getLogicalDomain()).size(), "Broadcast byte multiples size mismatch: ", broadcast_byte_multiples.size(), " != ", largest_out->getLogicalDomain(), "Loop domain:", ref_loop); int64_t dtype_sum = 0; for (auto inp : ir_utils::filterByType(fusion->inputs())) { From 2c29ba43707ae035b3438651420e8a7b116e01ce Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Tue, 1 Apr 2025 23:48:53 -0700 Subject: [PATCH 03/21] check non-device, non-reduction logical shape --- csrc/scheduler/pointwise.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/scheduler/pointwise.cpp b/csrc/scheduler/pointwise.cpp index df29061e432..2cecfa5af49 100644 --- a/csrc/scheduler/pointwise.cpp +++ b/csrc/scheduler/pointwise.cpp @@ -347,7 +347,7 @@ std::unique_ptr getPointwiseHeuristics( auto& view_disjoint_sets = broadcast_info.get().view_disjoint_set_ids; auto& broadcast_byte_multiples = broadcast_info.get().broadcast_multiples; - NVF_ERROR(broadcast_byte_multiples.size() == TensorDomain::noDevices(largest_out->getLogicalDomain()).size(), "Broadcast byte multiples size mismatch: ", broadcast_byte_multiples.size(), " != ", largest_out->getLogicalDomain(), "Loop domain:", ref_loop); + NVF_ERROR(broadcast_byte_multiples.size() == TensorDomain::noDevices(TensorDomain::noReductions(largest_out->getLogicalDomain())).size(), "Broadcast byte multiples size mismatch: ", broadcast_byte_multiples.size(), " != ", largest_out->getLogicalDomain(), "Loop domain:", ref_loop); int64_t dtype_sum = 0; for (auto inp : ir_utils::filterByType(fusion->inputs())) { From c57a731248befa971f04c277b7671109e2a31b00 Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Wed, 2 Apr 2025 20:22:08 -0700 Subject: [PATCH 04/21] fix scatter for loop split --- csrc/multidevice/communication.cpp | 23 +++++--- tests/cpp/test_multidevice_communications.cpp | 56 +++++++++++++++++++ 2 files changed, 70 insertions(+), 9 deletions(-) diff --git a/csrc/multidevice/communication.cpp b/csrc/multidevice/communication.cpp index 041e13eb80c..aaeca7c9970 100644 --- a/csrc/multidevice/communication.cpp +++ b/csrc/multidevice/communication.cpp @@ -378,24 +378,27 @@ c10::intrusive_ptr postScatter( c10d::Backend* backend, at::Tensor input_tensor, at::Tensor output_tensor) { - if (my_device_index == communication->root() && - 
!communication->out()->getDeviceMesh().has(communication->root())) { - output_tensor = at::empty_like(input_tensor.slice(0, 0, 1)); - } - std::vector output_tensors({output_tensor}); - + + auto output_device_mesh = communication->out()->getDeviceMesh(); + bool output_has_root = output_device_mesh.has(communication->root()); auto root_relative_index = communication->getRootRelativeIndex(); + std::vector> input_tensors; + if (my_device_index == communication->root()) { + auto splits = at::tensor_split(input_tensor, output_device_mesh.size(), /*dim=*/0); + if (!output_has_root) { + output_tensor = at::empty_like(splits.at(0)); + } input_tensors.resize(1); int64_t j = 0; for (auto i : arange(communication->team().size())) { if (root_relative_index == static_cast(i) && - !communication->out()->getDeviceMesh().has(communication->root())) { + !output_has_root) { input_tensors.front().push_back(output_tensor); continue; } - input_tensors.front().push_back(input_tensor.slice(0, j, j + 1)); + input_tensors.front().push_back(splits.at(j)); j++; } @@ -403,8 +406,10 @@ c10::intrusive_ptr postScatter( assertBuffersHaveSameSize(input_tensors[0], output_tensors); } + std::vector output_tensors({output_tensor}); + return backend->scatter( - output_tensors, input_tensors, {.rootRank = root_relative_index}); + {output_tensor}, input_tensors, {.rootRank = root_relative_index}); } c10::intrusive_ptr postReduce( diff --git a/tests/cpp/test_multidevice_communications.cpp b/tests/cpp/test_multidevice_communications.cpp index 0bcdae8d012..7b4c65e57b9 100644 --- a/tests/cpp/test_multidevice_communications.cpp +++ b/tests/cpp/test_multidevice_communications.cpp @@ -500,6 +500,62 @@ TEST_P(CommunicationTest, AllgatherLoopSplit_Contiguous) { __FILE__); } +TEST_P(CommunicationTest, Scatter_NonContiguous) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + const auto d = communicator_->size(); + + DeviceMesh mesh_zero({0}); + + TensorView* tv0 = makeConcreteTensor({5, d*3}); + TensorView* tv1 = permute(tv0, {{1, 0}}); + TensorView* tv2 = set(tv1); + TensorView* tv3 = permute(tv2, {{1, 0}}); + + tv0->setDeviceMesh(mesh_zero); + tv1->setDeviceMesh(mesh_zero); + tv2->setDeviceMesh(full_mesh_); + tv3->setDeviceMesh(full_mesh_); + + tv0->outer_split(1, d); + tv0->axis(1)->parallelize(ParallelType::Serial); + + tv1->outer_split(0, d); + tv1->axis(0)->parallelize(ParallelType::Serial); + + tv2->outer_split(0, d); + tv2->axis(0)->parallelize(ParallelType::DIDx); + + tv3->outer_split(1, d); + tv3->axis(1)->parallelize(ParallelType::DIDx); + tv3->reorder({{1, 0}, {2, 1}, {0, 2}}); + + fusion->addInput(tv0); + fusion->addOutput(tv3); + + for (auto tv : {tv0, tv1, tv2, tv3}) { + tv->setAllocationDomain(tv->getLoopDomain(), true); + debug() << "tv: " << tv->toString() << std::endl; + debug() << "Logical domain: " << tv->getLogicalDomain() << std::endl; + debug() << "Allocation domain: " << tv->getAllocationDomain() << std::endl; + } + + at::Tensor unsharded_in_tensor = at::randn({5, d*3}, tensor_options); + at::Tensor expected_output = shardTensor(unsharded_in_tensor, 1, full_mesh_); + FusionExecutorCache executor_cache(std::move(fusion)); + at::Tensor out_tensor = + executor_cache.runFusionWithInputs({unsharded_in_tensor})[0].as(); + + testValidate( + executor_cache.fusion(), + {out_tensor}, + {unsharded_in_tensor}, + {expected_output}, + __LINE__, + __FILE__); +} + INSTANTIATE_TEST_SUITE_P( , CommunicationTest, From 096ac03853e8af7d30ed810704f4571878eb0f64 Mon Sep 17 00:00:00 2001 From: root 
<26priya11@gmail.com>
Date: Wed, 2 Apr 2025 22:40:31 -0700
Subject: [PATCH 05/21] update postScatter, add tests for ReduceScatter

---
 csrc/multidevice/communication.cpp            |   6 +-
 tests/cpp/test_multidevice_communications.cpp | 116 +++++++++---------
 2 files changed, 63 insertions(+), 59 deletions(-)

diff --git a/csrc/multidevice/communication.cpp b/csrc/multidevice/communication.cpp
index aaeca7c9970..7b63adf8aaf 100644
--- a/csrc/multidevice/communication.cpp
+++ b/csrc/multidevice/communication.cpp
@@ -384,11 +384,12 @@ c10::intrusive_ptr<c10d::Work> postScatter(
   auto root_relative_index = communication->getRootRelativeIndex();
 
   std::vector<std::vector<at::Tensor>> input_tensors;
+  std::vector<at::Tensor> output_tensors({output_tensor});
   if (my_device_index == communication->root()) {
     auto splits = at::tensor_split(input_tensor, output_device_mesh.size(), /*dim=*/0);
     if (!output_has_root) {
-      output_tensor = at::empty_like(splits.at(0));
+      output_tensors[0] = at::empty_like(splits.at(0));
     }
     input_tensors.resize(1);
     int64_t j = 0;
     for (auto i : arange(communication->team().size())) {
       if (root_relative_index == static_cast<int64_t>(i) &&
           !output_has_root) {
         input_tensors.front().push_back(output_tensor);
         continue;
       }
       input_tensors.front().push_back(splits.at(j));
       j++;
@@ -406,10 +407,9 @@ c10::intrusive_ptr<c10d::Work> postScatter(
     assertBuffersHaveSameSize(input_tensors[0], output_tensors);
   }
 
-  std::vector<at::Tensor> output_tensors({output_tensor});
 
   return backend->scatter(
       output_tensors, input_tensors, {.rootRank = root_relative_index});
 }
 
 c10::intrusive_ptr<c10d::Work> postReduce(
diff --git a/tests/cpp/test_multidevice_communications.cpp b/tests/cpp/test_multidevice_communications.cpp
index 7b4c65e57b9..c938cc66298 100644
--- a/tests/cpp/test_multidevice_communications.cpp
+++ b/tests/cpp/test_multidevice_communications.cpp
@@ -412,6 +412,10 @@ TEST_P(CommunicationTest, ReduceScatter) {
   }
 }
 
 TEST_P(CommunicationTest, AllgatherLoopSplit_NonContiguous) {
+  // NCCL and UCC do not support non-contiguous tensors, so we add permute
+  // operations to make the tensor contiguous. Note that modifying the
+  // allocation domain so that the gather axis is outermost is not
+  // sufficient; the logical shape has to change as well.
auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); @@ -458,99 +462,99 @@ TEST_P(CommunicationTest, AllgatherLoopSplit_NonContiguous) { __FILE__); } -TEST_P(CommunicationTest, AllgatherLoopSplit_Contiguous) { +TEST_P(CommunicationTest, ScatterLoopSplit_NonContiguous) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); const auto d = communicator_->size(); + + DeviceMesh mesh_zero({0}); - TensorView* tv0 = makeConcreteTensor({d*3, 5}); - TensorView* tv1 = set(tv0); - TensorView* tv2 = permute(tv1, {{1, 0}}); - fusion->addInput(tv0); - fusion->addOutput(tv2); + TensorView* tv0 = makeConcreteTensor({5, d*3}); + // TensorView* tv1 = permute(tv0, {{1, 0}}); + TensorView* tv2 = set(tv0); + // TensorView* tv3 = permute(tv2, {{1, 0}}); - tv0->outer_split(0, d); - TransformPropagator propagator(tv0); - SetSelector selector({tv1, tv2}); - MaxLogicalDomainInfoSpanningTree(tv0, &selector).traverse(&propagator); - - tv0->setDeviceMesh(full_mesh_); - shardAllLike(tv0, {tv1, tv2}); + tv0->setDeviceMesh(mesh_zero); + // tv1->setDeviceMesh(mesh_zero); + tv2->setDeviceMesh(full_mesh_); + // tv3->setDeviceMesh(full_mesh_); - tv0->axis(0)->parallelize(ParallelType::DIDx); + tv0->outer_split(1, d); + tv0->axis(1)->parallelize(ParallelType::Serial); + + // tv1->outer_split(0, d); + // tv1->axis(0)->parallelize(ParallelType::Serial); + + // tv2->outer_split(0, d); + // tv2->axis(0)->parallelize(ParallelType::DIDx); + + tv2->outer_split(1, d); + tv2->axis(1)->parallelize(ParallelType::DIDx); + // tv3->reorder({{1, 0}, {2, 1}, {0, 2}}); + + fusion->addInput(tv0); + fusion->addOutput(tv2); - for (auto tv : fusion->allTvs()) { + for (auto tv : {tv0, tv2}) { tv->setAllocationDomain(tv->getLoopDomain(), true); + debug() << "tv: " << tv->toString() << std::endl; + debug() << "Logical domain: " << tv->getLogicalDomain() << std::endl; + debug() << "Allocation domain: " << tv->getAllocationDomain() << std::endl; } - at::Tensor unsharded_in_tensor = at::randn({d*3, 5}, tensor_options); - at::Tensor in_tensor = shardTensor(unsharded_in_tensor, tv0); - + at::Tensor unsharded_in_tensor = at::randn({5, d*3}, tensor_options); + at::Tensor expected_output = shardTensor(unsharded_in_tensor, 1, full_mesh_); FusionExecutorCache executor_cache(std::move(fusion)); at::Tensor out_tensor = - executor_cache.runFusionWithInputs({in_tensor})[0].as(); + executor_cache.runFusionWithInputs({unsharded_in_tensor})[0].as(); testValidate( executor_cache.fusion(), {out_tensor}, - {in_tensor}, - {unsharded_in_tensor.transpose(0, 1)}, + {unsharded_in_tensor}, + {expected_output}, __LINE__, - __FILE__); + __FILE__); } -TEST_P(CommunicationTest, Scatter_NonContiguous) { +TEST_P(CommunicationTest, ReduceScatterLoopSplit_NonContiguous) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); const auto d = communicator_->size(); - - DeviceMesh mesh_zero({0}); - TensorView* tv0 = makeConcreteTensor({5, d*3}); - TensorView* tv1 = permute(tv0, {{1, 0}}); - TensorView* tv2 = set(tv1); - TensorView* tv3 = permute(tv2, {{1, 0}}); - - tv0->setDeviceMesh(mesh_zero); - tv1->setDeviceMesh(mesh_zero); - tv2->setDeviceMesh(full_mesh_); - tv3->setDeviceMesh(full_mesh_); + TensorView* tv0 = makeConcreteTensor({5, d*3, d*7}); + TensorView* tv1 = sum(tv0, {1}); + fusion->addInput(tv0); + fusion->addOutput(tv1); + tv0->outer_split(1, d); - tv0->axis(1)->parallelize(ParallelType::Serial); - - tv1->outer_split(0, d); - tv1->axis(0)->parallelize(ParallelType::Serial); - - tv2->outer_split(0, d); - 
tv2->axis(0)->parallelize(ParallelType::DIDx); + tv0->axis(1)->parallelize(ParallelType::DIDx); - tv3->outer_split(1, d); - tv3->axis(1)->parallelize(ParallelType::DIDx); - tv3->reorder({{1, 0}, {2, 1}, {0, 2}}); + tv1->outer_split(1, d); + TensorView* tv2 = tv1->rFactor({2}); + tv2->axis(1)->parallelize(ParallelType::DIDx); - fusion->addInput(tv0); - fusion->addOutput(tv3); + tv1->outer_split(2, d); + tv1->axis(2)->parallelize(ParallelType::DIDx); - for (auto tv : {tv0, tv1, tv2, tv3}) { + for (auto tv : {tv0, tv1, tv2}) { + tv->setDeviceMesh(full_mesh_); tv->setAllocationDomain(tv->getLoopDomain(), true); - debug() << "tv: " << tv->toString() << std::endl; - debug() << "Logical domain: " << tv->getLogicalDomain() << std::endl; - debug() << "Allocation domain: " << tv->getAllocationDomain() << std::endl; } - - at::Tensor unsharded_in_tensor = at::randn({5, d*3}, tensor_options); - at::Tensor expected_output = shardTensor(unsharded_in_tensor, 1, full_mesh_); + + at::Tensor unsharded_in_tensor = at::randn({5, d*3, d*7}, tensor_options); + at::Tensor in_tensor = shardTensor(unsharded_in_tensor, 1, full_mesh_); + at::Tensor expected_output = shardTensor(unsharded_in_tensor.sum(1), -1, full_mesh_); FusionExecutorCache executor_cache(std::move(fusion)); at::Tensor out_tensor = - executor_cache.runFusionWithInputs({unsharded_in_tensor})[0].as(); - + executor_cache.runFusionWithInputs({in_tensor})[0].as(); testValidate( executor_cache.fusion(), {out_tensor}, - {unsharded_in_tensor}, + {in_tensor}, {expected_output}, __LINE__, __FILE__); From ea2320708883eebe2d01c6f11a027298771ef770 Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Thu, 3 Apr 2025 16:55:40 -0700 Subject: [PATCH 06/21] another approach for noncontig tensors --- csrc/multidevice/communication.cpp | 4 +- csrc/multidevice/utils.cpp | 13 +- tests/cpp/test_multidevice_communications.cpp | 143 +++--------------- 3 files changed, 32 insertions(+), 128 deletions(-) diff --git a/csrc/multidevice/communication.cpp b/csrc/multidevice/communication.cpp index 7b63adf8aaf..6e9d84baafb 100644 --- a/csrc/multidevice/communication.cpp +++ b/csrc/multidevice/communication.cpp @@ -9,6 +9,7 @@ #include #include #include +#include #if defined(NVFUSER_DISTRIBUTED) && defined(USE_C10D_NCCL) #include #endif @@ -385,9 +386,10 @@ c10::intrusive_ptr postScatter( std::vector> input_tensors; std::vector output_tensors({output_tensor}); + int64_t scattered_axis = getShardedLogicalAxis(communication->out(), ParallelType::DIDx); if (my_device_index == communication->root()) { - auto splits = at::tensor_split(input_tensor, output_device_mesh.size(), /*dim=*/0); + auto splits = at::tensor_split(input_tensor, output_device_mesh.size(), /*dim=*/scattered_axis); if (!output_has_root) { output_tensors[0] = at::empty_like(splits.at(0)); } diff --git a/csrc/multidevice/utils.cpp b/csrc/multidevice/utils.cpp index 5df7acc7b05..56b28f8e567 100644 --- a/csrc/multidevice/utils.cpp +++ b/csrc/multidevice/utils.cpp @@ -170,21 +170,22 @@ int64_t getShardedLogicalAxis( const TensorView* tv, const ParallelType parallel_type) { std::unordered_map parallel_type_to_id = - mapDeviceParallelTypeToId(tv->getMaybeAllocationDomain()); - IterDomain* alloc_id = getOrDefault(parallel_type_to_id, parallel_type); - if (alloc_id == nullptr) { + mapDeviceParallelTypeToId(tv->getLoopDomain()); + + IterDomain* loop_id = getOrDefault(parallel_type_to_id, parallel_type); + if (loop_id == nullptr) { return -1; } std::unordered_map logical_id_to_axis = 
mapIterDomainToTensorAxis(tv->getLogicalDomain()); - IterDomain* id = alloc_id; + IterDomain* id = loop_id; while (logical_id_to_axis.count(id) == 0) { Expr* def = id->definition(); NVF_ERROR( def != nullptr, "Failed to find a non-reduction logical IterDomain that produces ", - alloc_id); + loop_id); if (auto* split = dynamic_cast(def)) { // Returning just which tensor axis is sharded isn't sufficient to let // shardTensor, a user of this function, know how to shard the tensor. @@ -266,7 +267,7 @@ at::Tensor shardTensor( auto extent = tensor.size(axis); auto nslices = mesh.size(); NVF_CHECK( - extent % nslices == 0, "Sharded axis must be evenly divisble by mesh"); + extent % nslices == 0, "Sharded axis must be evenly divisble by mesh: ", extent, " % ", nslices); auto stride = extent / nslices; // TODO: returning slice 0 temporarily when device is not in the mesh. i = (i < 0) ? 0 : i; diff --git a/tests/cpp/test_multidevice_communications.cpp b/tests/cpp/test_multidevice_communications.cpp index c938cc66298..a00a3eb93b5 100644 --- a/tests/cpp/test_multidevice_communications.cpp +++ b/tests/cpp/test_multidevice_communications.cpp @@ -109,7 +109,10 @@ TEST_P(CommunicationTest, Allgather) { FusionGuard fg(&container); auto* in = makeContigTensor(2); in->setDeviceMesh(full_mesh_); + in->axis(0)->parallelize(ParallelType::DIDx); auto* out = ops::newValLike(in, in->dtype())->as(); + out->axis(0)->parallelize(ParallelType::Serial); + auto communication = IrBuilder::create( CommunicationType::Allgather, out, in, all_ranks_); @@ -411,153 +414,51 @@ TEST_P(CommunicationTest, ReduceScatter) { } } -TEST_P(CommunicationTest, AllgatherLoopSplit_NonContiguous) { - // NCCL and UCC do not support non-contiguous tensors. - // Therefore, we need to add permute operations to make the tensor contiguous. - // Note, that, modifying the allocation domain such that the gather axis is outermost - // is not sufficient, requiring logical shape change. +TEST_P(CommunicationTest, AllgatherLoopSplit) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); + // ProcessGroupNCCL requires the gathered axis to be outermost. + // We change the allocation of tensorviews to reflect this. + // We do not modify the logical shape of the tensorview. + // When posting communication, we permute the tensor to match the ProcessGroupNCCL contiguity requirements. + // This would still require one copy on each device if the input tensor is in a different layout. 
const auto d = communicator_->size(); TensorView* tv0 = makeConcreteTensor({5, d*3}); tv0->outer_split(1, d); tv0->axis(1)->parallelize(ParallelType::DIDx); - reorderDIDToFront(tv0); - - TensorView* tv1 = permute(tv0, {{1, 0}}); - tv1->outer_split(0, d); - tv1->axis(0)->parallelize(ParallelType::DIDx); + tv0->reorder({{1, 0}, {2, 1}, {0, 2}}); + // tv0: Logical = [5, d*3], Loop/Allocation = [DIDx(d), 3, 5] - TensorView* tv2 = set(tv1); - tv2->outer_split(0, d); - tv2->axis(0)->parallelize(ParallelType::Serial); - - TensorView* tv3 = permute(tv2, {{0, 1}}); - tv3->outer_split(1, d); - tv3->axis(1)->parallelize(ParallelType::Serial); - tv3->reorder({{1, 0}, {2, 1}, {0, 2}}); + TensorView* tv1 = set(tv0); + tv1->outer_split(1, d); + tv1->axis(1)->parallelize(ParallelType::Serial); + tv1->reorder({{1, 0}, {2, 1}, {0, 2}}); + // tv1: Logical = [5, d*3], Loop/Allocation = [Serial(d), 3, 5] - for (auto tv : {tv0, tv1, tv2, tv3}) { + for (auto tv : {tv0, tv1}) { tv->setDeviceMesh(full_mesh_); tv->setAllocationDomain(tv->getLoopDomain(), true); } - fusion->addInput(tv0); - fusion->addOutput(tv3); - - at::Tensor unsharded_in_tensor = at::randn({5, d*3}, tensor_options); - at::Tensor in_tensor = shardTensor(unsharded_in_tensor, tv0); - - FusionExecutorCache executor_cache(std::move(fusion)); - at::Tensor out_tensor = - executor_cache.runFusionWithInputs({in_tensor})[0].as(); - testValidate( - executor_cache.fusion(), - {out_tensor}, - {in_tensor}, - {unsharded_in_tensor}, - __LINE__, - __FILE__); -} - -TEST_P(CommunicationTest, ScatterLoopSplit_NonContiguous) { - auto fusion = std::make_unique(); - FusionGuard fg(fusion.get()); - - const auto d = communicator_->size(); - - DeviceMesh mesh_zero({0}); - - TensorView* tv0 = makeConcreteTensor({5, d*3}); - // TensorView* tv1 = permute(tv0, {{1, 0}}); - TensorView* tv2 = set(tv0); - // TensorView* tv3 = permute(tv2, {{1, 0}}); - - tv0->setDeviceMesh(mesh_zero); - // tv1->setDeviceMesh(mesh_zero); - tv2->setDeviceMesh(full_mesh_); - // tv3->setDeviceMesh(full_mesh_); - - tv0->outer_split(1, d); - tv0->axis(1)->parallelize(ParallelType::Serial); - - // tv1->outer_split(0, d); - // tv1->axis(0)->parallelize(ParallelType::Serial); - - // tv2->outer_split(0, d); - // tv2->axis(0)->parallelize(ParallelType::DIDx); - - tv2->outer_split(1, d); - tv2->axis(1)->parallelize(ParallelType::DIDx); - // tv3->reorder({{1, 0}, {2, 1}, {0, 2}}); - - fusion->addInput(tv0); - fusion->addOutput(tv2); - - for (auto tv : {tv0, tv2}) { - tv->setAllocationDomain(tv->getLoopDomain(), true); - debug() << "tv: " << tv->toString() << std::endl; - debug() << "Logical domain: " << tv->getLogicalDomain() << std::endl; - debug() << "Allocation domain: " << tv->getAllocationDomain() << std::endl; - } - - at::Tensor unsharded_in_tensor = at::randn({5, d*3}, tensor_options); - at::Tensor expected_output = shardTensor(unsharded_in_tensor, 1, full_mesh_); - FusionExecutorCache executor_cache(std::move(fusion)); - at::Tensor out_tensor = - executor_cache.runFusionWithInputs({unsharded_in_tensor})[0].as(); - - testValidate( - executor_cache.fusion(), - {out_tensor}, - {unsharded_in_tensor}, - {expected_output}, - __LINE__, - __FILE__); -} - -TEST_P(CommunicationTest, ReduceScatterLoopSplit_NonContiguous) { - auto fusion = std::make_unique(); - FusionGuard fg(fusion.get()); - - const auto d = communicator_->size(); - - TensorView* tv0 = makeConcreteTensor({5, d*3, d*7}); - TensorView* tv1 = sum(tv0, {1}); - fusion->addInput(tv0); fusion->addOutput(tv1); - - tv0->outer_split(1, d); - 
tv0->axis(1)->parallelize(ParallelType::DIDx); - tv1->outer_split(1, d); - TensorView* tv2 = tv1->rFactor({2}); - tv2->axis(1)->parallelize(ParallelType::DIDx); - - tv1->outer_split(2, d); - tv1->axis(2)->parallelize(ParallelType::DIDx); - - for (auto tv : {tv0, tv1, tv2}) { - tv->setDeviceMesh(full_mesh_); - tv->setAllocationDomain(tv->getLoopDomain(), true); - } + at::Tensor unsharded_in_tensor = at::randn({d*3, 5}, tensor_options); + at::Tensor in_tensor = shardTensor(unsharded_in_tensor, 0, full_mesh_).transpose(0, 1); - at::Tensor unsharded_in_tensor = at::randn({5, d*3, d*7}, tensor_options); - at::Tensor in_tensor = shardTensor(unsharded_in_tensor, 1, full_mesh_); - at::Tensor expected_output = shardTensor(unsharded_in_tensor.sum(1), -1, full_mesh_); FusionExecutorCache executor_cache(std::move(fusion)); at::Tensor out_tensor = executor_cache.runFusionWithInputs({in_tensor})[0].as(); + testValidate( executor_cache.fusion(), {out_tensor}, {in_tensor}, - {expected_output}, + {unsharded_in_tensor.transpose(0, 1)}, __LINE__, - __FILE__); + __FILE__); } INSTANTIATE_TEST_SUITE_P( From de893376d50be9a36bfd8ad9e01091001edebd7a Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Thu, 3 Apr 2025 17:00:00 -0700 Subject: [PATCH 07/21] move scatter, pointwise changes to other PR --- csrc/multidevice/communication.cpp | 21 ++++++++------------- csrc/scheduler/pointwise.cpp | 2 +- 2 files changed, 9 insertions(+), 14 deletions(-) diff --git a/csrc/multidevice/communication.cpp b/csrc/multidevice/communication.cpp index 6e9d84baafb..9faf1abdcc3 100644 --- a/csrc/multidevice/communication.cpp +++ b/csrc/multidevice/communication.cpp @@ -380,28 +380,24 @@ c10::intrusive_ptr postScatter( at::Tensor input_tensor, at::Tensor output_tensor) { - auto output_device_mesh = communication->out()->getDeviceMesh(); - bool output_has_root = output_device_mesh.has(communication->root()); - auto root_relative_index = communication->getRootRelativeIndex(); + if (my_device_index == communication->root() && + !communication->out()->getDeviceMesh().has(communication->root())) { + output_tensor = at::empty_like(input_tensor.slice(0, 0, 1)); + } + std::vector output_tensors({output_tensor}); + auto root_relative_index = communication->getRootRelativeIndex(); std::vector> input_tensors; - std::vector output_tensors({output_tensor}); - int64_t scattered_axis = getShardedLogicalAxis(communication->out(), ParallelType::DIDx); - if (my_device_index == communication->root()) { - auto splits = at::tensor_split(input_tensor, output_device_mesh.size(), /*dim=*/scattered_axis); - if (!output_has_root) { - output_tensors[0] = at::empty_like(splits.at(0)); - } input_tensors.resize(1); int64_t j = 0; for (auto i : arange(communication->team().size())) { if (root_relative_index == static_cast(i) && - !output_has_root) { + !communication->out()->getDeviceMesh().has(communication->root())) { input_tensors.front().push_back(output_tensor); continue; } - input_tensors.front().push_back(splits.at(j)); + input_tensors.front().push_back(input_tensor.slice(0, j, j + 1)); j++; } @@ -409,7 +405,6 @@ c10::intrusive_ptr postScatter( assertBuffersHaveSameSize(input_tensors[0], output_tensors); } - return backend->scatter( output_tensors, input_tensors, {.rootRank = root_relative_index}); } diff --git a/csrc/scheduler/pointwise.cpp b/csrc/scheduler/pointwise.cpp index 2cecfa5af49..a9d72e29540 100644 --- a/csrc/scheduler/pointwise.cpp +++ b/csrc/scheduler/pointwise.cpp @@ -347,7 +347,7 @@ std::unique_ptr getPointwiseHeuristics( auto& 
view_disjoint_sets = broadcast_info.get().view_disjoint_set_ids;
   auto& broadcast_byte_multiples = broadcast_info.get().broadcast_multiples;
 
-  NVF_ERROR(broadcast_byte_multiples.size() == TensorDomain::noDevices(TensorDomain::noReductions(largest_out->getLogicalDomain())).size(), "Broadcast byte multiples size mismatch: ", broadcast_byte_multiples.size(), " != ", largest_out->getLogicalDomain(), "Loop domain:", ref_loop);
+  NVF_ERROR(broadcast_byte_multiples.size() == ref_loop.size());
 
   int64_t dtype_sum = 0;
   for (auto inp : ir_utils::filterByType<TensorView>(fusion->inputs())) {

From 2008e60a15a6a5ebc81d85c59943cf4aea5df1a8 Mon Sep 17 00:00:00 2001
From: root <26priya11@gmail.com>
Date: Thu, 3 Apr 2025 17:03:03 -0700
Subject: [PATCH 08/21] undo extraneous change

---
 csrc/multidevice/communication.cpp | 1 -
 1 file changed, 1 deletion(-)

diff --git a/csrc/multidevice/communication.cpp b/csrc/multidevice/communication.cpp
index 9faf1abdcc3..134640339ef 100644
--- a/csrc/multidevice/communication.cpp
+++ b/csrc/multidevice/communication.cpp
@@ -379,7 +379,6 @@ c10::intrusive_ptr<c10d::Work> postScatter(
     c10d::Backend* backend,
     at::Tensor input_tensor,
     at::Tensor output_tensor) {
-
   if (my_device_index == communication->root() &&
       !communication->out()->getDeviceMesh().has(communication->root())) {
     output_tensor = at::empty_like(input_tensor.slice(0, 0, 1));

From c817d9c9ac7d5e592f6fa513eb9bbedc09f678c8 Mon Sep 17 00:00:00 2001
From: root <26priya11@gmail.com>
Date: Thu, 3 Apr 2025 18:42:19 -0700
Subject: [PATCH 09/21] avoid using getShardedLogicalAxis

---
 csrc/multidevice/communication.cpp |  1 -
 csrc/multidevice/utils.cpp         | 13 ++++++-------
 2 files changed, 6 insertions(+), 8 deletions(-)

diff --git a/csrc/multidevice/communication.cpp b/csrc/multidevice/communication.cpp
index 134640339ef..041e13eb80c 100644
--- a/csrc/multidevice/communication.cpp
+++ b/csrc/multidevice/communication.cpp
@@ -9,7 +9,6 @@
 #include
 #include
 #include
-#include
 #if defined(NVFUSER_DISTRIBUTED) && defined(USE_C10D_NCCL)
 #include
 #endif
diff --git a/csrc/multidevice/utils.cpp b/csrc/multidevice/utils.cpp
index 56b28f8e567..5df7acc7b05 100644
--- a/csrc/multidevice/utils.cpp
+++ b/csrc/multidevice/utils.cpp
@@ -170,22 +170,21 @@ int64_t getShardedLogicalAxis(
     const TensorView* tv,
     const ParallelType parallel_type) {
   std::unordered_map<ParallelType, IterDomain*> parallel_type_to_id =
-      mapDeviceParallelTypeToId(tv->getLoopDomain());
-
-  IterDomain* loop_id = getOrDefault(parallel_type_to_id, parallel_type);
-  if (loop_id == nullptr) {
+      mapDeviceParallelTypeToId(tv->getMaybeAllocationDomain());
+  IterDomain* alloc_id = getOrDefault(parallel_type_to_id, parallel_type);
+  if (alloc_id == nullptr) {
     return -1;
   }
 
   std::unordered_map<IterDomain*, int64_t> logical_id_to_axis =
       mapIterDomainToTensorAxis(tv->getLogicalDomain());
-  IterDomain* id = loop_id;
+  IterDomain* id = alloc_id;
   while (logical_id_to_axis.count(id) == 0) {
     Expr* def = id->definition();
     NVF_ERROR(
         def != nullptr,
         "Failed to find a non-reduction logical IterDomain that produces ",
-        loop_id);
+        alloc_id);
     if (auto* split = dynamic_cast<Split*>(def)) {
       // Returning just which tensor axis is sharded isn't sufficient to let
       // shardTensor, a user of this function, know how to shard the tensor.
@@ -267,7 +266,7 @@ at::Tensor shardTensor(
   auto extent = tensor.size(axis);
   auto nslices = mesh.size();
   NVF_CHECK(
-      extent % nslices == 0, "Sharded axis must be evenly divisble by mesh: ", extent, " % ", nslices);
+      extent % nslices == 0, "Sharded axis must be evenly divisble by mesh");
   auto stride = extent / nslices;
   // TODO: returning slice 0 temporarily when device is not in the mesh.
   i = (i < 0) ? 0 : i;
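Note on the revert above: whether getShardedLogicalAxis() walks the loop domain or the allocation domain, its consumer shardTensor() only needs the sharded logical axis plus an even-split guarantee. For reference, a minimal self-contained sketch of that per-device slicing arithmetic in plain ATen; slice_for_device is an illustrative name rather than an nvFuser symbol, and an even split is assumed:

#include <ATen/ATen.h>

// Slice out one device's shard along the sharded logical axis,
// mirroring the extent/nslices/stride math in shardTensor() above.
at::Tensor slice_for_device(
    const at::Tensor& unsharded,
    int64_t axis,        // sharded logical axis
    int64_t device_pos,  // this device's position in the mesh
    int64_t mesh_size) { // number of devices in the mesh
  const int64_t extent = unsharded.size(axis);
  TORCH_CHECK(extent % mesh_size == 0, "sharded axis must divide evenly");
  const int64_t stride = extent / mesh_size;
  // Device i keeps the contiguous range [i * stride, (i + 1) * stride).
  return unsharded.slice(axis, device_pos * stride, (device_pos + 1) * stride);
}

Each rank keeps one contiguous extent/mesh_size range along the sharded axis, which is exactly the slice shardTensor() returns for a given device index.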
From 63be7ba85f7d935ef951222d558b593ef9e36dcf Mon Sep 17 00:00:00 2001
From: root <26priya11@gmail.com>
Date: Thu, 3 Apr 2025 18:44:47 -0700
Subject: [PATCH 10/21] undo adding sharding to communication test

---
 tests/cpp/test_multidevice_communications.cpp | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/tests/cpp/test_multidevice_communications.cpp b/tests/cpp/test_multidevice_communications.cpp
index a00a3eb93b5..d32de56e716 100644
--- a/tests/cpp/test_multidevice_communications.cpp
+++ b/tests/cpp/test_multidevice_communications.cpp
@@ -109,10 +109,7 @@ TEST_P(CommunicationTest, Allgather) {
   FusionGuard fg(&container);
   auto* in = makeContigTensor(2);
   in->setDeviceMesh(full_mesh_);
-  in->axis(0)->parallelize(ParallelType::DIDx);
   auto* out = ops::newValLike(in, in->dtype())->as<TensorView>();
-  out->axis(0)->parallelize(ParallelType::Serial);
-
   auto communication = IrBuilder::create<Communication>(
       CommunicationType::Allgather, out, in, all_ranks_);

From 0702cd7a3e4bd4a6e11209dc26e9c2e36031794b Mon Sep 17 00:00:00 2001
From: root <26priya11@gmail.com>
Date: Tue, 8 Apr 2025 13:19:59 -0700
Subject: [PATCH 11/21] pm/reorder

---
 csrc/host_ir/executor.cpp | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/csrc/host_ir/executor.cpp b/csrc/host_ir/executor.cpp
index a4cdb3d2717..b7774769e47 100644
--- a/csrc/host_ir/executor.cpp
+++ b/csrc/host_ir/executor.cpp
@@ -166,7 +166,8 @@ KernelArgumentHolder HostIrExecutor::run(
           communicator_->deviceId(),
           backend,
           in_tensor,
-          out_tensor);
+          out_tensor,
+          expr_eval);
       if (work != nullptr) {
         work->wait();
       }
@@ -503,7 +504,8 @@ void HostIrEvaluator::handle(Communication* communication) {
       communicator_->deviceId(),
       backend,
       input_tensor,
-      output_tensor);
+      output_tensor,
+      expr_evaluator_);
 }
 
 void HostIrEvaluator::handle(P2PCommunication* communication) {

From 02135c687c750c2696db5af992c764012616667f Mon Sep 17 00:00:00 2001
From: root <26priya11@gmail.com>
Date: Tue, 8 Apr 2025 13:49:06 -0700
Subject: [PATCH 12/21] lintrunner

---
 tests/cpp/test_multidevice_communications.cpp | 17 ++++++---------
 1 file changed, 8 insertions(+), 9 deletions(-)

diff --git a/tests/cpp/test_multidevice_communications.cpp b/tests/cpp/test_multidevice_communications.cpp
index d32de56e716..e68ca75ea74 100644
--- a/tests/cpp/test_multidevice_communications.cpp
+++ b/tests/cpp/test_multidevice_communications.cpp
@@ -17,8 +17,6 @@
 #include
 #include
 #include
-#include
-
 #include
@@ -418,11 +416,11 @@ TEST_P(CommunicationTest, AllgatherLoopSplit) {
   // ProcessGroupNCCL requires the gathered axis to be outermost.
   // We change the allocation of tensorviews to reflect this.
   // We do not modify the logical shape of the tensorview.
-  // When posting communication, we permute the tensor to match the ProcessGroupNCCL contiguity requirements.
-  // This would still require one copy on each device if the input tensor is in a different layout.
+  // This would still require one copy on each device if the input tensor is in
+  // a different layout.
const auto d = communicator_->size(); - TensorView* tv0 = makeConcreteTensor({5, d*3}); + TensorView* tv0 = makeConcreteTensor({5, d * 3}); tv0->outer_split(1, d); tv0->axis(1)->parallelize(ParallelType::DIDx); tv0->reorder({{1, 0}, {2, 1}, {0, 2}}); @@ -442,9 +440,10 @@ TEST_P(CommunicationTest, AllgatherLoopSplit) { fusion->addInput(tv0); fusion->addOutput(tv1); - at::Tensor unsharded_in_tensor = at::randn({d*3, 5}, tensor_options); - at::Tensor in_tensor = shardTensor(unsharded_in_tensor, 0, full_mesh_).transpose(0, 1); - + at::Tensor unsharded_in_tensor = at::randn({d * 3, 5}, tensor_options); + at::Tensor in_tensor = + shardTensor(unsharded_in_tensor, 0, full_mesh_).transpose(0, 1); + FusionExecutorCache executor_cache(std::move(fusion)); at::Tensor out_tensor = executor_cache.runFusionWithInputs({in_tensor})[0].as(); @@ -455,7 +454,7 @@ TEST_P(CommunicationTest, AllgatherLoopSplit) { {in_tensor}, {unsharded_in_tensor.transpose(0, 1)}, __LINE__, - __FILE__); + __FILE__); } INSTANTIATE_TEST_SUITE_P( From 025c414a53543f21eba3a962f9a3973b90949a33 Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Thu, 3 Apr 2025 18:55:53 -0700 Subject: [PATCH 13/21] avoid using getShardedLogicalAxis --- csrc/multidevice/communication.cpp | 27 ++++++++---- tests/cpp/test_multidevice_communications.cpp | 43 +++++++++++++++++++ 2 files changed, 62 insertions(+), 8 deletions(-) diff --git a/csrc/multidevice/communication.cpp b/csrc/multidevice/communication.cpp index 041e13eb80c..d02d3c91a72 100644 --- a/csrc/multidevice/communication.cpp +++ b/csrc/multidevice/communication.cpp @@ -378,24 +378,35 @@ c10::intrusive_ptr postScatter( c10d::Backend* backend, at::Tensor input_tensor, at::Tensor output_tensor) { - if (my_device_index == communication->root() && - !communication->out()->getDeviceMesh().has(communication->root())) { - output_tensor = at::empty_like(input_tensor.slice(0, 0, 1)); - } - std::vector output_tensors({output_tensor}); - + + auto output_device_mesh = communication->out()->getDeviceMesh(); + bool output_has_root = output_device_mesh.has(communication->root()); auto root_relative_index = communication->getRootRelativeIndex(); + std::vector> input_tensors; + std::vector output_tensors({output_tensor}); + + // Presegmentation should ensure outermost allocation of scattered axis required for correct results. + // Scatter does not require the input_tensor.is_contiguous() to be true so we do not permute the input tensor. + + // Get contiguity permutation to find the scattered axis. 
+ auto dims = getContiguityPermutation(input_tensor); + auto scattered_axis = dims.at(0); + if (my_device_index == communication->root()) { + auto splits = at::tensor_split(input_tensor, output_device_mesh.size(), /*dim=*/scattered_axis); + if (!output_has_root) { + output_tensors[0] = at::empty_like(splits.at(0)); + } input_tensors.resize(1); int64_t j = 0; for (auto i : arange(communication->team().size())) { if (root_relative_index == static_cast(i) && - !communication->out()->getDeviceMesh().has(communication->root())) { + !output_has_root) { input_tensors.front().push_back(output_tensor); continue; } - input_tensors.front().push_back(input_tensor.slice(0, j, j + 1)); + input_tensors.front().push_back(splits.at(j)); j++; } diff --git a/tests/cpp/test_multidevice_communications.cpp b/tests/cpp/test_multidevice_communications.cpp index e68ca75ea74..93f372c735d 100644 --- a/tests/cpp/test_multidevice_communications.cpp +++ b/tests/cpp/test_multidevice_communications.cpp @@ -457,6 +457,49 @@ TEST_P(CommunicationTest, AllgatherLoopSplit) { __FILE__); } +TEST_P(CommunicationTest, ScatterLoopSplit) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + const auto d = communicator_->size(); + + DeviceMesh mesh_zero({0}); + TensorView* tv0 = makeConcreteTensor({5, d*3}); + TensorView* tv1 = set(tv0); + + tv0->setDeviceMesh(mesh_zero); + tv0->outer_split(1, d); + tv0->axis(1)->parallelize(ParallelType::Serial); + tv0->reorder({{1, 0}, {2, 1}, {0, 2}}); + + tv1->setDeviceMesh(full_mesh_); + tv1->outer_split(1, d); + tv1->axis(1)->parallelize(ParallelType::DIDx); + tv1->reorder({{1, 0}, {2, 1}, {0, 2}}); + + fusion->addInput(tv0); + fusion->addOutput(tv1); + + for (auto tv : {tv0, tv1}) { + tv->setAllocationDomain(tv->getLoopDomain(), true); + } + + at::Tensor unsharded_in_tensor = at::randn({d*3, 5}, tensor_options); + + at::Tensor expected_output = shardTensor(unsharded_in_tensor, 0, full_mesh_).transpose(0, 1); + + FusionExecutorCache executor_cache(std::move(fusion)); + at::Tensor out_tensor = + executor_cache.runFusionWithInputs({unsharded_in_tensor.transpose(0, 1)})[0].as(); + + testValidate( + executor_cache.fusion(), + {out_tensor}, + {unsharded_in_tensor.transpose(0, 1)}, + {expected_output}, + __LINE__, + __FILE__); +} + INSTANTIATE_TEST_SUITE_P( , CommunicationTest, From 0426039cf825b79d28c6823546a55f09d37fb74f Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Fri, 4 Apr 2025 12:13:27 -0700 Subject: [PATCH 14/21] only get the scattered axis on root --- csrc/multidevice/communication.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/csrc/multidevice/communication.cpp b/csrc/multidevice/communication.cpp index d02d3c91a72..30d7e802586 100644 --- a/csrc/multidevice/communication.cpp +++ b/csrc/multidevice/communication.cpp @@ -385,15 +385,15 @@ c10::intrusive_ptr postScatter( std::vector> input_tensors; std::vector output_tensors({output_tensor}); - + + if (my_device_index == communication->root()) { // Presegmentation should ensure outermost allocation of scattered axis required for correct results. // Scatter does not require the input_tensor.is_contiguous() to be true so we do not permute the input tensor. // Get contiguity permutation to find the scattered axis. 
auto dims = getContiguityPermutation(input_tensor); - auto scattered_axis = dims.at(0); - - if (my_device_index == communication->root()) { + int64_t scattered_axis = dims.at(0); + auto splits = at::tensor_split(input_tensor, output_device_mesh.size(), /*dim=*/scattered_axis); if (!output_has_root) { output_tensors[0] = at::empty_like(splits.at(0)); From dc37a7db6b614e1a3ddc0505220d03c8b3392bb8 Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Mon, 7 Apr 2025 16:17:30 -0700 Subject: [PATCH 15/21] flatten inputs outputs in scatter --- csrc/multidevice/communication.cpp | 17 +++++------------ 1 file changed, 5 insertions(+), 12 deletions(-) diff --git a/csrc/multidevice/communication.cpp b/csrc/multidevice/communication.cpp index 30d7e802586..a1dc55be5a5 100644 --- a/csrc/multidevice/communication.cpp +++ b/csrc/multidevice/communication.cpp @@ -384,17 +384,10 @@ c10::intrusive_ptr postScatter( auto root_relative_index = communication->getRootRelativeIndex(); std::vector> input_tensors; - std::vector output_tensors({output_tensor}); + std::vector output_tensors({output_tensor.as_strided({output_tensor.numel()}, {1})}); if (my_device_index == communication->root()) { - // Presegmentation should ensure outermost allocation of scattered axis required for correct results. - // Scatter does not require the input_tensor.is_contiguous() to be true so we do not permute the input tensor. - - // Get contiguity permutation to find the scattered axis. - auto dims = getContiguityPermutation(input_tensor); - int64_t scattered_axis = dims.at(0); - - auto splits = at::tensor_split(input_tensor, output_device_mesh.size(), /*dim=*/scattered_axis); + auto splits = at::tensor_split(input_tensor.as_strided({input_tensor.numel()}, {1}), output_device_mesh.size(), /*dim=*/0); if (!output_has_root) { output_tensors[0] = at::empty_like(splits.at(0)); } @@ -408,10 +401,10 @@ c10::intrusive_ptr postScatter( } input_tensors.front().push_back(splits.at(j)); j++; - } + } - assertBufferCount(input_tensors[0], communication->team().size()); - assertBuffersHaveSameSize(input_tensors[0], output_tensors); + assertBufferCount(input_tensors[0], communication->team().size()); + assertBuffersHaveSameSize(input_tensors[0], output_tensors); } return backend->scatter( From 2c3bc1e3bbe047f741a8d85d05daeaf72ca5083c Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Fri, 16 May 2025 15:36:34 -0700 Subject: [PATCH 16/21] undo change --- csrc/host_ir/executor.cpp | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/csrc/host_ir/executor.cpp b/csrc/host_ir/executor.cpp index b7774769e47..a4cdb3d2717 100644 --- a/csrc/host_ir/executor.cpp +++ b/csrc/host_ir/executor.cpp @@ -166,8 +166,7 @@ KernelArgumentHolder HostIrExecutor::run( communicator_->deviceId(), backend, in_tensor, - out_tensor, - expr_eval); + out_tensor); if (work != nullptr) { work->wait(); } @@ -504,8 +503,7 @@ void HostIrEvaluator::handle(Communication* communication) { communicator_->deviceId(), backend, input_tensor, - output_tensor, - expr_evaluator_); + output_tensor); } void HostIrEvaluator::handle(P2PCommunication* communication) { From 0e64863c3cfec91660c6a3c2344aaf1b2ac22b95 Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Fri, 16 May 2025 16:54:54 -0700 Subject: [PATCH 17/21] scatter --- csrc/multidevice/communication.cpp | 28 ++++++----- tests/cpp/test_multidevice_communications.cpp | 34 ++++++++------ .../test_multidevice_lower_communication.cpp | 46 +++++++++++++++++++ 3 files changed, 84 
insertions(+), 24 deletions(-) diff --git a/csrc/multidevice/communication.cpp b/csrc/multidevice/communication.cpp index a1dc55be5a5..3a534bd5ad8 100644 --- a/csrc/multidevice/communication.cpp +++ b/csrc/multidevice/communication.cpp @@ -378,33 +378,39 @@ c10::intrusive_ptr postScatter( c10d::Backend* backend, at::Tensor input_tensor, at::Tensor output_tensor) { - + NVF_ERROR( + isTvContiguous(communication->in()), "Input tensor is not contiguous"); + NVF_ERROR( + isTvContiguous(communication->out()), "Output tensor is not contiguous"); + auto output_device_mesh = communication->out()->getDeviceMesh(); bool output_has_root = output_device_mesh.has(communication->root()); auto root_relative_index = communication->getRootRelativeIndex(); std::vector> input_tensors; - std::vector output_tensors({output_tensor.as_strided({output_tensor.numel()}, {1})}); - + + std::vector output_tensors( + {output_tensor.as_strided({output_tensor.numel()}, {1})}); + if (my_device_index == communication->root()) { - auto splits = at::tensor_split(input_tensor.as_strided({input_tensor.numel()}, {1}), output_device_mesh.size(), /*dim=*/0); - if (!output_has_root) { - output_tensors[0] = at::empty_like(splits.at(0)); - } + auto splits = at::tensor_split( + input_tensor.as_strided({input_tensor.numel()}, {1}), + output_device_mesh.size(), + /*dim=*/0); input_tensors.resize(1); int64_t j = 0; for (auto i : arange(communication->team().size())) { if (root_relative_index == static_cast(i) && !output_has_root) { - input_tensors.front().push_back(output_tensor); + input_tensors.front().push_back(at::empty_like(splits.at(0))); continue; } input_tensors.front().push_back(splits.at(j)); j++; - } + } - assertBufferCount(input_tensors[0], communication->team().size()); - assertBuffersHaveSameSize(input_tensors[0], output_tensors); + assertBufferCount(input_tensors[0], communication->team().size()); + assertBuffersHaveSameSize(input_tensors[0], output_tensors); } return backend->scatter( diff --git a/tests/cpp/test_multidevice_communications.cpp b/tests/cpp/test_multidevice_communications.cpp index 93f372c735d..18f8439372f 100644 --- a/tests/cpp/test_multidevice_communications.cpp +++ b/tests/cpp/test_multidevice_communications.cpp @@ -461,40 +461,48 @@ TEST_P(CommunicationTest, ScatterLoopSplit) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); const auto d = communicator_->size(); - + DeviceMesh mesh_zero({0}); - TensorView* tv0 = makeConcreteTensor({5, d*3}); + TensorView* tv0 = makeConcreteTensor({5, d * 3}); TensorView* tv1 = set(tv0); - + tv0->setDeviceMesh(mesh_zero); tv0->outer_split(1, d); tv0->axis(1)->parallelize(ParallelType::Serial); - tv0->reorder({{1, 0}, {2, 1}, {0, 2}}); + tv0->reorder( + {{1, 0}, + {2, 1}, + {0, 2}}); // tv0: Logical = [5, d*3], Loop/Allocation = [Serial(d), 3, 5] tv1->setDeviceMesh(full_mesh_); tv1->outer_split(1, d); tv1->axis(1)->parallelize(ParallelType::DIDx); - tv1->reorder({{1, 0}, {2, 1}, {0, 2}}); - + tv1->reorder( + {{1, 0}, + {2, 1}, + {0, 2}}); // tv1: Logical = [5, d*3], Loop/Allocation = [DIDx(d), 3, 5] + fusion->addInput(tv0); fusion->addOutput(tv1); - + for (auto tv : {tv0, tv1}) { tv->setAllocationDomain(tv->getLoopDomain(), true); } - at::Tensor unsharded_in_tensor = at::randn({d*3, 5}, tensor_options); + at::Tensor unsharded_in_tensor = + at::randn({d * 3, 5}, tensor_options).transpose(0, 1); + + at::Tensor expected_output = shardTensor(unsharded_in_tensor, 1, full_mesh_); - at::Tensor expected_output = shardTensor(unsharded_in_tensor, 0, 
full_mesh_).transpose(0, 1);
 
   FusionExecutorCache executor_cache(std::move(fusion));
   at::Tensor out_tensor =
-      executor_cache.runFusionWithInputs({unsharded_in_tensor.transpose(0, 1)})[0].as<at::Tensor>();
-
+      executor_cache.runFusionWithInputs({unsharded_in_tensor})[0]
+          .as<at::Tensor>();
+
   testValidate(
       executor_cache.fusion(),
       {out_tensor},
-      {unsharded_in_tensor.transpose(0, 1)},
+      {unsharded_in_tensor},
       {expected_output},
       __LINE__,
       __FILE__);
 }
diff --git a/tests/cpp/test_multidevice_lower_communication.cpp b/tests/cpp/test_multidevice_lower_communication.cpp
index a9725615380..27de92cc958 100644
--- a/tests/cpp/test_multidevice_lower_communication.cpp
+++ b/tests/cpp/test_multidevice_lower_communication.cpp
@@ -585,6 +585,52 @@ TEST_P(LowerCollectiveTest, AllgatherLoopSplit_Noncontig) {
       __FILE__);
 }
 
+TEST_P(LowerCollectiveTest, ScatterLoopSplit) {
+  auto fusion = std::make_unique<Fusion>();
+  FusionGuard fg(fusion.get());
+  const auto d = communicator_->size();
+  auto full_mesh = DeviceMesh::createForNumDevices(d);
+
+  DeviceMesh mesh_zero({0});
+  TensorView* tv0 = makeConcreteTensor({5, d * 3});
+  TensorView* tv1 = set(tv0);
+
+  tv0->setDeviceMesh(mesh_zero);
+  tv0->outer_split(1, d);
+  tv0->axis(1)->parallelize(ParallelType::Serial);
+  tv0->reorder({{1, 0}, {2, 1}, {0, 2}});
+
+  tv1->setDeviceMesh(full_mesh);
+  tv1->outer_split(1, d);
+  tv1->axis(1)->parallelize(ParallelType::DIDx);
+  tv1->reorder({{1, 0}, {2, 1}, {0, 2}});
+
+  fusion->addInput(tv0);
+  fusion->addOutput(tv1);
+
+  for (auto tv : {tv0, tv1}) {
+    tv->setAllocationDomain(tv->getLoopDomain(), true);
+  }
+
+  at::Tensor unsharded_in_tensor =
+      at::randn({d * 3, 5}, tensor_options).transpose(0, 1);
+
+  at::Tensor expected_output = shardTensor(unsharded_in_tensor, 1, full_mesh);
+
+  FusionExecutorCache executor_cache(std::move(fusion));
+  at::Tensor out_tensor =
+      executor_cache.runFusionWithInputs({unsharded_in_tensor})[0]
+          .as<at::Tensor>();
+
+  testValidate(
+      executor_cache.fusion(),
+      {out_tensor},
+      {unsharded_in_tensor},
+      {expected_output},
+      __LINE__,
+      __FILE__);
+}
+
 INSTANTIATE_TEST_SUITE_P(
     HostIrLowering,
     LowerCollectiveTest,
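Note on patch 17 above: once postScatter() asserts that both TensorViews are contiguous, the scatter buffers can be built without any axis bookkeeping by viewing each tensor as 1-D and splitting the flat input evenly. A standalone ATen sketch of that buffer construction, assuming the scattered axis is outermost in the allocation domain; make_scatter_inputs is an illustrative name, not nvFuser code:

#include <ATen/ATen.h>
#include <vector>

// Flatten a contiguous tensor to a 1-D view (no copy) and split it into
// one equal chunk per device, as the as_strided/tensor_split calls above do.
std::vector<at::Tensor> make_scatter_inputs(
    const at::Tensor& input,
    int64_t mesh_size) {
  TORCH_CHECK(input.is_contiguous(), "scatter requires contiguous buffers");
  at::Tensor flat = input.as_strided({input.numel()}, {1});
  return at::tensor_split(flat, mesh_size, /*dim=*/0);
}

Because the scattered axis is allocated outermost, the even 1-D split covers the same element ranges as a split along that axis, so the root needs no permute or extra copy.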
From b64449ad152babbe11a06a706ebc002e0f9ff765 Mon Sep 17 00:00:00 2001
From: root <26priya11@gmail.com>
Date: Fri, 16 May 2025 16:56:32 -0700
Subject: [PATCH 18/21] remove tests from old file

---
 tests/cpp/test_multidevice_communications.cpp | 99 -------------------
 1 file changed, 99 deletions(-)

diff --git a/tests/cpp/test_multidevice_communications.cpp b/tests/cpp/test_multidevice_communications.cpp
index 18f8439372f..af0c0719aa7 100644
--- a/tests/cpp/test_multidevice_communications.cpp
+++ b/tests/cpp/test_multidevice_communications.cpp
@@ -409,105 +409,6 @@ TEST_P(CommunicationTest, ReduceScatter) {
   }
 }
 
-TEST_P(CommunicationTest, AllgatherLoopSplit) {
-  auto fusion = std::make_unique<Fusion>();
-  FusionGuard fg(fusion.get());
-
-  // ProcessGroupNCCL requires the gathered axis to be outermost.
-  // We change the allocation of tensorviews to reflect this.
-  // We do not modify the logical shape of the tensorview.
-  // This would still require one copy on each device if the input tensor is in
-  // a different layout.
-  const auto d = communicator_->size();
-
-  TensorView* tv0 = makeConcreteTensor({5, d * 3});
-  tv0->outer_split(1, d);
-  tv0->axis(1)->parallelize(ParallelType::DIDx);
-  tv0->reorder({{1, 0}, {2, 1}, {0, 2}});
-  // tv0: Logical = [5, d*3], Loop/Allocation = [DIDx(d), 3, 5]
-
-  TensorView* tv1 = set(tv0);
-  tv1->outer_split(1, d);
-  tv1->axis(1)->parallelize(ParallelType::Serial);
-  tv1->reorder({{1, 0}, {2, 1}, {0, 2}});
-  // tv1: Logical = [5, d*3], Loop/Allocation = [Serial(d), 3, 5]
-
-  for (auto tv : {tv0, tv1}) {
-    tv->setDeviceMesh(full_mesh_);
-    tv->setAllocationDomain(tv->getLoopDomain(), true);
-  }
-
-  fusion->addInput(tv0);
-  fusion->addOutput(tv1);
-
-  at::Tensor unsharded_in_tensor = at::randn({d * 3, 5}, tensor_options);
-  at::Tensor in_tensor =
-      shardTensor(unsharded_in_tensor, 0, full_mesh_).transpose(0, 1);
-
-  FusionExecutorCache executor_cache(std::move(fusion));
-  at::Tensor out_tensor =
-      executor_cache.runFusionWithInputs({in_tensor})[0].as<at::Tensor>();
-
-  testValidate(
-      executor_cache.fusion(),
-      {out_tensor},
-      {in_tensor},
-      {unsharded_in_tensor.transpose(0, 1)},
-      __LINE__,
-      __FILE__);
-}
-
-TEST_P(CommunicationTest, ScatterLoopSplit) {
-  auto fusion = std::make_unique<Fusion>();
-  FusionGuard fg(fusion.get());
-  const auto d = communicator_->size();
-
-  DeviceMesh mesh_zero({0});
-  TensorView* tv0 = makeConcreteTensor({5, d * 3});
-  TensorView* tv1 = set(tv0);
-
-  tv0->setDeviceMesh(mesh_zero);
-  tv0->outer_split(1, d);
-  tv0->axis(1)->parallelize(ParallelType::Serial);
-  tv0->reorder(
-      {{1, 0},
-       {2, 1},
-       {0, 2}}); // tv0: Logical = [5, d*3], Loop/Allocation = [Serial(d), 3, 5]
-
-  tv1->setDeviceMesh(full_mesh_);
-  tv1->outer_split(1, d);
-  tv1->axis(1)->parallelize(ParallelType::DIDx);
-  tv1->reorder(
-      {{1, 0},
-       {2, 1},
-       {0, 2}}); // tv1: Logical = [5, d*3], Loop/Allocation = [DIDx(d), 3, 5]
-
-  fusion->addInput(tv0);
-  fusion->addOutput(tv1);
-
-  for (auto tv : {tv0, tv1}) {
-    tv->setAllocationDomain(tv->getLoopDomain(), true);
-  }
-
-  at::Tensor unsharded_in_tensor =
-      at::randn({d * 3, 5}, tensor_options).transpose(0, 1);
-
-  at::Tensor expected_output = shardTensor(unsharded_in_tensor, 1, full_mesh_);
-
-  FusionExecutorCache executor_cache(std::move(fusion));
-  at::Tensor out_tensor =
-      executor_cache.runFusionWithInputs({unsharded_in_tensor})[0]
-          .as<at::Tensor>();
-
-  testValidate(
-      executor_cache.fusion(),
-      {out_tensor},
-      {unsharded_in_tensor},
-      {expected_output},
-      __LINE__,
-      __FILE__);
-}
-
 INSTANTIATE_TEST_SUITE_P(
     ,
     CommunicationTest,

From 79b2a3dbdec473927ccb833b737cb081fe637105 Mon Sep 17 00:00:00 2001
From: root <26priya11@gmail.com>
Date: Fri, 16 May 2025 19:13:33 -0700
Subject: [PATCH 19/21] stride only when defined

---
 csrc/multidevice/communication.cpp | 12 +++++++++---
 1 file changed, 9 insertions(+), 3 deletions(-)

diff --git a/csrc/multidevice/communication.cpp b/csrc/multidevice/communication.cpp
index 3a534bd5ad8..db440bb1c14 100644
--- a/csrc/multidevice/communication.cpp
+++ b/csrc/multidevice/communication.cpp
@@ -389,14 +389,18 @@ c10::intrusive_ptr<c10d::Work> postScatter(
 
   std::vector<std::vector<at::Tensor>> input_tensors;
 
-  std::vector<at::Tensor> output_tensors(
-      {output_tensor.as_strided({output_tensor.numel()}, {1})});
+  if (output_tensor.defined()) {
+    output_tensor = output_tensor.as_strided({output_tensor.numel()}, {1});
+  }
 
   if (my_device_index == communication->root()) {
     auto splits = at::tensor_split(
         input_tensor.as_strided({input_tensor.numel()}, {1}),
         output_device_mesh.size(),
         /*dim=*/0);
+    if (!output_has_root) {
+      output_tensor = at::empty_like(splits.at(0));
+    }
input_tensors.resize(1); int64_t j = 0; for (auto i : arange(communication->team().size())) { @@ -410,9 +414,11 @@ c10::intrusive_ptr postScatter( } assertBufferCount(input_tensors[0], communication->team().size()); - assertBuffersHaveSameSize(input_tensors[0], output_tensors); + assertBuffersHaveSameSize(input_tensors[0], {output_tensor}); } + std::vector output_tensors({output_tensor}); + return backend->scatter( output_tensors, input_tensors, {.rootRank = root_relative_index}); } From 8ba56a5ea45c066ddfa0083c11d9738cfb5244cc Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Mon, 19 May 2025 15:44:51 -0700 Subject: [PATCH 20/21] remove non-root scatter --- csrc/host_ir/lower_to_communication.cpp | 17 ++++++++---- csrc/multidevice/communication.cpp | 37 +++++++++++-------------- tests/cpp/test_multidevice_pipeline.cpp | 4 +-- 3 files changed, 30 insertions(+), 28 deletions(-) diff --git a/csrc/host_ir/lower_to_communication.cpp b/csrc/host_ir/lower_to_communication.cpp index 49e796a679e..8416ef6cdd1 100644 --- a/csrc/host_ir/lower_to_communication.cpp +++ b/csrc/host_ir/lower_to_communication.cpp @@ -54,17 +54,24 @@ void lowerToScatter( TensorView* output_tv, const HostIrLowerParams& params, std::vector& comms) { - // we arbitrarily choose the first device of the sender mesh to be the root const DeviceMesh& receiver_mesh = output_tv->getDeviceMesh(); NVF_ERROR( receiver_mesh.rank() == 1, "Gather only supported on a 1D mesh. Given ", receiver_mesh); - auto root = input_tv->getDeviceMesh().at(0); + + // Find a common device between input and receiver meshes to be the root + auto it = std::ranges::find_if( + input_tv->getDeviceMesh().vector(), + [&receiver_mesh](DeviceIdxType device) { + return receiver_mesh.has(device); + }); + NVF_ERROR( + it != input_tv->getDeviceMesh().vector().end(), + "No common device found between input and receiver meshes"); + DeviceIdxType root = *it; + Team team = receiver_mesh.vector(); - if (!receiver_mesh.has(root)) { - team.push_back(root); - } comms.push_back(IrBuilder::create( CommunicationType::Scatter, output_tv, diff --git a/csrc/multidevice/communication.cpp b/csrc/multidevice/communication.cpp index db440bb1c14..1bee622d4b4 100644 --- a/csrc/multidevice/communication.cpp +++ b/csrc/multidevice/communication.cpp @@ -384,43 +384,38 @@ c10::intrusive_ptr postScatter( isTvContiguous(communication->out()), "Output tensor is not contiguous"); auto output_device_mesh = communication->out()->getDeviceMesh(); - bool output_has_root = output_device_mesh.has(communication->root()); - auto root_relative_index = communication->getRootRelativeIndex(); + NVF_ERROR( + output_device_mesh.has(communication->root()), + "communication->root() ", + communication->root(), + " is not in the output device mesh ", + output_device_mesh, + "."); std::vector> input_tensors; - if (output_tensor.defined()) { - output_tensor = output_tensor.as_strided({output_tensor.numel()}, {1}); - } + output_tensor = output_tensor.as_strided({output_tensor.numel()}, {1}); + std::vector output_tensors({output_tensor}); if (my_device_index == communication->root()) { auto splits = at::tensor_split( input_tensor.as_strided({input_tensor.numel()}, {1}), output_device_mesh.size(), /*dim=*/0); - if (!output_has_root) { - output_tensor = at::empty_like(splits.at(0)); - } + input_tensors.resize(1); - int64_t j = 0; - for (auto i : arange(communication->team().size())) { - if (root_relative_index == static_cast(i) && - !output_has_root) { - 
input_tensors.front().push_back(at::empty_like(splits.at(0))); - continue; - } - input_tensors.front().push_back(splits.at(j)); - j++; + for (const auto& split : splits) { + input_tensors.front().push_back(split); } - assertBufferCount(input_tensors[0], communication->team().size()); + assertBufferCount(input_tensors[0], output_device_mesh.size()); assertBuffersHaveSameSize(input_tensors[0], {output_tensor}); } - std::vector output_tensors({output_tensor}); - return backend->scatter( - output_tensors, input_tensors, {.rootRank = root_relative_index}); + output_tensors, + input_tensors, + {.rootRank = communication->getRootRelativeIndex()}); } c10::intrusive_ptr postReduce( diff --git a/tests/cpp/test_multidevice_pipeline.cpp b/tests/cpp/test_multidevice_pipeline.cpp index 12dfed5dd43..05f1620f4c7 100644 --- a/tests/cpp/test_multidevice_pipeline.cpp +++ b/tests/cpp/test_multidevice_pipeline.cpp @@ -355,8 +355,8 @@ INSTANTIATE_TEST_SUITE_P( PipelineTestTwoStages, testing::Combine( testing::Values(CommunicatorBackend::kNccl, CommunicatorBackend::kUcc), - all_meshes, - all_meshes, + testing::Values(mesh0, mesh1), + testing::Values(mesh2, mesh4, mesh5), testing::Values(false), testing::Values(true), testing::Values(false), From 846fbee5571f9d46199174796ed1951bbafb67a3 Mon Sep 17 00:00:00 2001 From: root <26priya11@gmail.com> Date: Mon, 19 May 2025 17:26:09 -0700 Subject: [PATCH 21/21] minor changes --- csrc/multidevice/communication.cpp | 2 +- tests/cpp/test_multidevice_lower_communication.cpp | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/csrc/multidevice/communication.cpp b/csrc/multidevice/communication.cpp index 1bee622d4b4..3a2abebb264 100644 --- a/csrc/multidevice/communication.cpp +++ b/csrc/multidevice/communication.cpp @@ -409,7 +409,7 @@ c10::intrusive_ptr postScatter( } assertBufferCount(input_tensors[0], output_device_mesh.size()); - assertBuffersHaveSameSize(input_tensors[0], {output_tensor}); + assertBuffersHaveSameSize(input_tensors[0], output_tensors); } return backend->scatter( diff --git a/tests/cpp/test_multidevice_lower_communication.cpp b/tests/cpp/test_multidevice_lower_communication.cpp index 27de92cc958..622d8d2450a 100644 --- a/tests/cpp/test_multidevice_lower_communication.cpp +++ b/tests/cpp/test_multidevice_lower_communication.cpp @@ -598,12 +598,12 @@ TEST_P(LowerCollectiveTest, ScatterLoopSplit) { tv0->setDeviceMesh(mesh_zero); tv0->outer_split(1, d); tv0->axis(1)->parallelize(ParallelType::Serial); - tv0->reorder({{1, 0}, {2, 1}, {0, 2}}); + tv0->reorder({2, 0, 1}); tv1->setDeviceMesh(full_mesh); tv1->outer_split(1, d); tv1->axis(1)->parallelize(ParallelType::DIDx); - tv1->reorder({{1, 0}, {2, 1}, {0, 2}}); + tv1->reorder({2, 0, 1}); fusion->addInput(tv0); fusion->addOutput(tv1);
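Note on the final state of the series: after patch 20, lowerToScatter() picks a root that the input and output meshes share, and postScatter() requires the root to be in the output mesh, asserts contiguity, flattens both buffers, and hands the backend one chunk per device. A self-contained sketch of that call shape against the c10d API; scatter_flat and its parameters are illustrative, and in nvFuser the corresponding values come from the Communication node and the Communicator:

#include <ATen/ATen.h>
#include <torch/csrc/distributed/c10d/Backend.hpp>
#include <vector>

c10::intrusive_ptr<c10d::Work> scatter_flat(
    c10d::Backend* backend,
    at::Tensor input,    // populated only on the root rank
    at::Tensor output,   // allocated on every rank of the output mesh
    int64_t my_rank,     // this rank's index within the team
    int64_t root_rank,   // root's index within the team
    int64_t team_size) {
  // Every rank, including the root, receives into a flat 1-D view.
  std::vector<at::Tensor> outputs{output.as_strided({output.numel()}, {1})};
  std::vector<std::vector<at::Tensor>> inputs;
  if (my_rank == root_rank) {
    // Only the root supplies send buffers: one equal chunk per rank.
    inputs.resize(1);
    for (const auto& split : at::tensor_split(
             input.as_strided({input.numel()}, {1}), team_size, /*dim=*/0)) {
      inputs.front().push_back(split);
    }
  }
  c10d::ScatterOptions options;
  options.rootRank = root_rank;
  return backend->scatter(outputs, inputs, options);
}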