From 5083750fb7ea7a8a9bf55349dc6c36ee3a95d115 Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Fri, 13 Jun 2025 14:37:22 -0700 Subject: [PATCH 1/9] Remove an unnecessary guard NVFuserTest has it already --- tests/cpp/test_multidevice_lower_communication.cpp | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/cpp/test_multidevice_lower_communication.cpp b/tests/cpp/test_multidevice_lower_communication.cpp index 9e6d94f31d4..cd0b44eaeef 100644 --- a/tests/cpp/test_multidevice_lower_communication.cpp +++ b/tests/cpp/test_multidevice_lower_communication.cpp @@ -65,7 +65,6 @@ class LowerGatherTest public testing::WithParamInterface> {}; TEST_P(LowerGatherTest, ) { - EnableOptionsGuard opt_guard; const auto& [meshes, enable_host_ir_lowering] = GetParam(); const auto& [in_mesh, out_mesh] = meshes; @@ -138,7 +137,6 @@ class LowerScatterTest public testing::WithParamInterface> {}; TEST_P(LowerScatterTest, ) { - EnableOptionsGuard opt_guard; const auto& [meshes, enable_host_ir_lowering] = GetParam(); const auto& [in_mesh, out_mesh] = meshes; @@ -189,7 +187,6 @@ class LowerSendRecvTest public testing::WithParamInterface> {}; TEST_P(LowerSendRecvTest, ) { - EnableOptionsGuard opt_guard; const auto& [meshes, enable_host_ir_lowering] = GetParam(); const auto& [in_mesh, out_mesh] = meshes; @@ -255,7 +252,6 @@ void LowerCollectiveTest::SetUp() { // available. Therefore, we call it after the isBackendAvailable check. communicator_->setDefaultBackend(backend_type); - EnableOptionsGuard enable_options_guard; if (enable_host_ir_lowering) { EnableOptionsGuard::getCurOptions().set(EnableOption::HostIrLowering); } From 1f7ef88fc6c483a781a39db6971b2e5a0bb80b51 Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Fri, 13 Jun 2025 23:02:07 -0700 Subject: [PATCH 2/9] Fix the test --- csrc/host_ir/lower_to_communication.cpp | 26 +++++++++---------- csrc/runtime/allocations.cpp | 2 +- csrc/runtime/fusion_kernel_runtime.cpp | 20 +++++++------- .../test_multidevice_lower_communication.cpp | 3 ++- 4 files changed, 27 insertions(+), 24 deletions(-) diff --git a/csrc/host_ir/lower_to_communication.cpp b/csrc/host_ir/lower_to_communication.cpp index 70cbe085b07..a6931010625 100644 --- a/csrc/host_ir/lower_to_communication.cpp +++ b/csrc/host_ir/lower_to_communication.cpp @@ -301,22 +301,22 @@ bool isLocalSizeOne(IterDomain* id) { } // namespace -CommunicationInfo getCommunicationInfo(Expr* expr) { +CommunicationInfo getCommunicationInfo(Expr* e) { NVF_ERROR( - isResharding(expr), - "getCommunicationInfo should only be called when `expr` is known to be a " - "communication. So `expr` should be resharding. Given: ", - expr); + isResharding(e), + "getCommunicationInfo should only be called when `e` is known to be a " + "communication. So `e` should be resharding. Given: ", + e); NVF_ERROR( - expr->isA() || expr->isA(), - "getCommunicationInfo should only be called when `expr` is known to be a " - "communication. So `expr` should be either a LoadStoreOp or a " + e->isA() || e->isA(), + "getCommunicationInfo should only be called when `e` is known to be a " + "communication. So `e` should be either a LoadStoreOp or a " "ReductionOp. Given: ", - expr); + e); - auto* producer = expr->inputs().at(0)->as(); - auto* consumer = expr->outputs().at(0)->as(); + auto* producer = e->inputs().at(0)->as(); + auto* consumer = e->outputs().at(0)->as(); std::optional communication_info = std::nullopt; // Fill `communication_info` instead of returning the result, so we can catch @@ -355,7 +355,7 @@ CommunicationInfo getCommunicationInfo(Expr* expr) { const bool c_sharded = c_loop_did != nullptr && consumer_mesh.size() > 1; const bool same_mesh = producer_mesh == consumer_mesh; - if (expr->isA()) { + if (e->isA()) { if (p_sharded && !c_sharded) { IterDomain* p_logical_id = getLogicalFromLoopId(producer, p_loop_did); CommunicationType type = same_mesh ? CommunicationType::Allgather @@ -375,7 +375,7 @@ CommunicationInfo getCommunicationInfo(Expr* expr) { CommunicationType::SendRecv, p_logical_id, c_logical_id); } } else { - NVF_ERROR(expr->isA()); + NVF_ERROR(e->isA()); if (!p_sharded) { // Not a reduction based communication. continue; diff --git a/csrc/runtime/allocations.cpp b/csrc/runtime/allocations.cpp index f947d7847f5..2bc6474e9c6 100644 --- a/csrc/runtime/allocations.cpp +++ b/csrc/runtime/allocations.cpp @@ -167,7 +167,7 @@ std::pair, std::vector> inferShape( inferred_val.hasValue(), "Could not launch kernel as program could not infer ", symbolic_size->toInlineString(), - "(", + " (", symbolic_size->toString(), ") for the buffer ", tv->toString()); diff --git a/csrc/runtime/fusion_kernel_runtime.cpp b/csrc/runtime/fusion_kernel_runtime.cpp index aa9b98b276f..dc01e96dc9b 100644 --- a/csrc/runtime/fusion_kernel_runtime.cpp +++ b/csrc/runtime/fusion_kernel_runtime.cpp @@ -483,24 +483,26 @@ void FusionKernelRuntime::compileFusionParallel(KernelArgumentHolder args) { switch (group_to_run->schedulerType()) { case SchedulerType::Communication: { auto deviceid = Communicator::getInstance().deviceId(); - NVF_ERROR( - group_to_run->exprs().size() == 1, - "Communication segments must contain only one Expr"); - for (auto* expr : convertSingleOpToCommunication( + NVF_ERROR_EQ( + group_to_run->exprs().size(), + 1, + "Communication segments must contain only one Expr."); + for (auto* e : convertSingleOpToCommunication( ir_cloner.clone(group_to_run->exprs().at(0)), deviceid)) { NVF_ERROR( - expr->isA(), - "Exprs in a Communication group should be Communication"); + e->isA(), + "Exprs in a Communication group should be Communication: ", + e); // Allocate the recv buffers of communications - auto* communication = expr->as(); + auto* communication = e->as(); TensorView* tv = communication->out(); if (tv->getDeviceMesh().has(deviceid)) { auto* allocate = IrBuilder::create(tv, MemoryType::Global); hic->pushBackTopLevelExprs(allocate); } - hic->pushBackTopLevelExprs(expr); - auto wait = IrBuilder::create(expr->as()); + hic->pushBackTopLevelExprs(communication); + auto wait = IrBuilder::create(communication); hic->pushBackTopLevelExprs(wait); } } break; diff --git a/tests/cpp/test_multidevice_lower_communication.cpp b/tests/cpp/test_multidevice_lower_communication.cpp index cd0b44eaeef..3d32334edb3 100644 --- a/tests/cpp/test_multidevice_lower_communication.cpp +++ b/tests/cpp/test_multidevice_lower_communication.cpp @@ -774,7 +774,8 @@ INSTANTIATE_TEST_SUITE_P( LowerCollectiveTest, ::testing::Combine( testing::Values(CommunicatorBackend::kNccl, CommunicatorBackend::kUcc), - testing::Bool()), + // Can't do testing::Bool() yet due to #4230 + testing::Values(false)), ([](const testing::TestParamInfo>& info) -> std::string { const auto& [backend_type, enable_host_ir_lowering] = info.param; From 91daa0f2cacc97b1aa1742985c29e16874e05605 Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Wed, 20 Nov 2024 17:34:49 -0800 Subject: [PATCH 3/9] Make ReshardingTest.Insert a MultiDeviceTest. This makes the test more realistic and gives better coverage. --- tests/cpp/multidevice.h | 18 +++ .../test_multidevice_lower_communication.cpp | 27 +--- tests/cpp/test_multidevice_sharding.cpp | 59 +++++++++ tests/cpp/test_resharding.cpp | 120 +----------------- 4 files changed, 87 insertions(+), 137 deletions(-) diff --git a/tests/cpp/multidevice.h b/tests/cpp/multidevice.h index 207c206fe6c..046ca21fd8a 100644 --- a/tests/cpp/multidevice.h +++ b/tests/cpp/multidevice.h @@ -51,4 +51,22 @@ class MultiDeviceTest : public NVFuserTest { bool disable_skip; }; +// This macro is supposed to be used in a test case of a MultiDeviceTest or its +// `SetUp` method, which have access to GTEST_SKIP and communicator_. It's not +// made a function because that function wouldn't be able to skip the test by +// calling GTEST_SKIP. +#define SKIP_IF_NOT_ENOUGH_DEVICES(fusion) \ + do { \ + const auto num_devices = communicator_->size(); \ + for (auto* tv : fusion->allTvs()) { \ + const DeviceMesh& mesh = tv->getDeviceMesh(); \ + for (const auto device_id : mesh.vector()) { \ + if (device_id >= num_devices) { \ + GTEST_SKIP() << tv->toString() << ") requires more than " \ + << num_devices << " devices."; \ + } \ + } \ + } \ + } while (0) + } // namespace nvfuser diff --git a/tests/cpp/test_multidevice_lower_communication.cpp b/tests/cpp/test_multidevice_lower_communication.cpp index 3d32334edb3..ef4d1c88c1b 100644 --- a/tests/cpp/test_multidevice_lower_communication.cpp +++ b/tests/cpp/test_multidevice_lower_communication.cpp @@ -41,21 +41,6 @@ void assertIsCompiledToHostIrContainer( } } // namespace -// This is made a macro instead of a function, because GTEST_SKIP can only be -// used in individual test cases or `SetUp` methods. -#define SKIP_IF_NOT_ENOUGH_DEVICES(in_mesh, out_mesh) \ - do { \ - const auto num_devices = communicator_->size(); \ - for (const auto& mesh : {in_mesh, out_mesh}) { \ - for (const auto device_id : mesh.vector()) { \ - if (device_id >= num_devices) { \ - GTEST_SKIP() << "Mesh (" << mesh << ") requires more than " \ - << num_devices << " devices."; \ - } \ - } \ - } \ - } while (0) - using InOutMesh = std::pair; static constexpr int kTensorSize = 4; @@ -72,8 +57,6 @@ TEST_P(LowerGatherTest, ) { EnableOptionsGuard::getCurOptions().set(EnableOption::HostIrLowering); } - SKIP_IF_NOT_ENOUGH_DEVICES(in_mesh, out_mesh); - auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); @@ -86,6 +69,8 @@ TEST_P(LowerGatherTest, ) { out->setDeviceMesh(out_mesh); in->axis(0)->parallelize(ParallelType::DIDx); + SKIP_IF_NOT_ENOUGH_DEVICES(fusion); + const auto device_id = communicator_->deviceId(); at::Tensor unsharded_tensor = at::randn({in_mesh.size(), kTensorSize}, tensor_options); @@ -144,8 +129,6 @@ TEST_P(LowerScatterTest, ) { EnableOptionsGuard::getCurOptions().set(EnableOption::HostIrLowering); } - SKIP_IF_NOT_ENOUGH_DEVICES(in_mesh, out_mesh); - auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); @@ -158,6 +141,8 @@ TEST_P(LowerScatterTest, ) { out->setDeviceMesh(out_mesh); out->axis(0)->parallelize(ParallelType::DIDx); + SKIP_IF_NOT_ENOUGH_DEVICES(fusion); + const auto device_id = communicator_->deviceId(); at::Tensor unsharded_tensor = at::randn({out_mesh.size(), kTensorSize}, tensor_options); @@ -194,8 +179,6 @@ TEST_P(LowerSendRecvTest, ) { EnableOptionsGuard::getCurOptions().set(EnableOption::HostIrLowering); } - SKIP_IF_NOT_ENOUGH_DEVICES(in_mesh, out_mesh); - auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); @@ -210,6 +193,8 @@ TEST_P(LowerSendRecvTest, ) { in->axis(0)->parallelize(ParallelType::DIDx); out->axis(0)->parallelize(ParallelType::DIDx); + SKIP_IF_NOT_ENOUGH_DEVICES(fusion); + const auto device_id = communicator_->deviceId(); at::Tensor unsharded_tensor = at::randn({in_mesh.size(), kTensorSize}, tensor_options); diff --git a/tests/cpp/test_multidevice_sharding.cpp b/tests/cpp/test_multidevice_sharding.cpp index f0a74cfcf00..854a2866c53 100644 --- a/tests/cpp/test_multidevice_sharding.cpp +++ b/tests/cpp/test_multidevice_sharding.cpp @@ -735,6 +735,65 @@ TEST_F(MultiDeviceTest, ReorderDIDToFront) { __FILE__); } +using InsertReshardingTestParams = std::tuple; + +class InsertReshardingTest + : public MultiDeviceTest, + public testing::WithParamInterface {}; + +TEST_P(InsertReshardingTest, Execute) { + auto [mesh, is_tv0_tv5_sharded, is_tv1_tv4_sharded, is_tv2_sharded] = + GetParam(); + constexpr int64_t kShardedAxis = 1; + + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + TensorView* tv0 = makeContigTensor(3); + TensorView* tv1 = mul(tv0, tv0); + TensorView* tv2 = add(tv0, tv1); + TensorView* tv3 = sum(tv2, {kShardedAxis}); + TensorView* tv4 = broadcast(tv3, {false, true, false}); + TensorView* tv5 = mul(tv2, tv4); + + fusion->addInput(tv0); + fusion->addOutput(tv1); + fusion->addOutput(tv5); + + for (auto* tv : {tv0, tv1, tv2, tv3, tv4, tv5}) { + tv->setDeviceMesh(mesh); + } + + if (is_tv0_tv5_sharded) { + tv0->axis(kShardedAxis)->parallelize(ParallelType::DIDx); + // tv3->axis(kShardedAxis) is a reduction, so don't shard it. + tv5->axis(kShardedAxis)->parallelize(ParallelType::DIDx); + } + if (is_tv1_tv4_sharded) { + tv1->axis(kShardedAxis)->parallelize(ParallelType::DIDx); + tv4->axis(kShardedAxis)->parallelize(ParallelType::DIDx); + } + if (is_tv2_sharded) { + tv2->axis(kShardedAxis)->parallelize(ParallelType::DIDx); + } + + SKIP_IF_NOT_ENOUGH_DEVICES(fusion); + + FusionExecutorCache executor_cache(std::move(fusion)); + executor_cache.runFusionWithInputs({at::randn( + {2, is_tv0_tv5_sharded ? 1 : mesh.size(), 5}, tensor_options)}); +} + +INSTANTIATE_TEST_SUITE_P( + , + InsertReshardingTest, + ::testing::Combine( + ::testing::ValuesIn( + std::vector({{0}, {0, 1}, {0, 1, 2, 3}})), + ::testing::Bool(), + ::testing::Bool(), + ::testing::Bool())); + TEST_F(MultiDeviceTest, TransformPropagatorSplitReshape) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); diff --git a/tests/cpp/test_resharding.cpp b/tests/cpp/test_resharding.cpp index 84a2494126d..0bb8ee98cbe 100644 --- a/tests/cpp/test_resharding.cpp +++ b/tests/cpp/test_resharding.cpp @@ -5,6 +5,9 @@ * SPDX-License-Identifier: BSD-3-Clause */ // clang-format on +#include +#include + #include #include @@ -22,9 +25,6 @@ #include #include -#include -#include - namespace nvfuser { using testing::Each; @@ -32,50 +32,7 @@ using testing::IsEmpty; using testing::IsFalse; using testing::ResultOf; -using ReshardingTestParams = - std::tuple; - -class ReshardingTest : public NVFuserFixtureParamTest { - protected: - void SetUp() override { - fusion_ = std::make_unique(); - fg_ = std::make_unique(fusion_.get()); - } - void validate() { - // TODO(wujingyue): after preseg passes are integrated to - // FusionExecutorCache, simplify validation by using - // FusionExecutorCache::getMostRecentKernelRuntime()->fusionSegments()->groups(). - for (auto expr : fusion_->exprs()) { - EXPECT_TRUE(HostIrLower::canLower(expr)) << "on expr: " << expr; - } - - SegmentCandidateFinderOptions options{ - .run_translate_welford = false, - .run_combine_reductions = false, - .run_herrmann_merge = true, - .run_final_merge = true, - .custom_should_merge_groups = &HostIrLower::shouldMergeSegmentedGroups}; - - auto segmented_fusion = SegmentCandidateFinder::segment( - std::move(fusion_), KernelArgumentHolder(), options, true); - - for (SegmentedGroup* group : segmented_fusion->groups()) { - // TODO: use EXPECT_THAT. - EXPECT_TRUE( - std::none_of( - group->exprs().begin(), - group->exprs().end(), - [](auto expr) { return isResharding(expr); }) || - (group->exprs().size() == 1 && isResharding(group->exprs().at(0)))); - } - // checks that the segments are disjoints and that the graph of segment is - // acyclic - segmented_fusion->validateDisjoint(); - } - - std::unique_ptr fusion_; - std::unique_ptr fg_; -}; +using ReshardingTest = NVFuserTest; TEST_F(ReshardingTest, SplitingView) { const int b = 2, s = 3, h = 96, e = 128; @@ -608,75 +565,6 @@ TEST_F(ReshardingTest, InsertShardedAxisReordering) { } } -TEST_P(ReshardingTest, Insert) { - if (!distributedEnabled()) { // Test only works with distributed - GTEST_SKIP() << "Requires distributed API"; - } - auto - [mesh0, - mesh1, - mesh2, - is_tv0_tv5_sharded, - is_tv1_tv4_sharded, - is_tv2_sharded] = GetParam(); - constexpr int64_t kShardedAxis = 1; - - TensorView* tv0 = makeContigTensor(3); - TensorView* tv1 = binaryOp(BinaryOpType::Mul, tv0, tv0); - TensorView* tv2 = binaryOp(BinaryOpType::Add, tv0, tv1); - TensorView* tv3 = sum(tv2, {kShardedAxis}); - TensorView* tv4 = broadcast(tv3, {false, true, false}); - TensorView* tv5 = binaryOp(BinaryOpType::Mul, tv2, tv4); - - tv0->setDeviceMesh(mesh0); - tv1->setDeviceMesh(mesh1); - tv2->setDeviceMesh(mesh2); - tv3->setDeviceMesh(mesh0); - tv4->setDeviceMesh(mesh1); - tv5->setDeviceMesh(mesh0); - fusion_->addInput(tv0); - fusion_->addOutput(tv1); - fusion_->addOutput(tv5); - - if (is_tv0_tv5_sharded) { - tv0->axis(kShardedAxis)->parallelize(ParallelType::DIDx); - // tv3->axis(kShardedAxis) is a reduction, so don't shard it. - tv5->axis(kShardedAxis)->parallelize(ParallelType::DIDx); - } - if (is_tv1_tv4_sharded) { - tv1->axis(kShardedAxis)->parallelize(ParallelType::DIDx); - tv4->axis(kShardedAxis)->parallelize(ParallelType::DIDx); - } - if (is_tv2_sharded) { - tv2->axis(kShardedAxis)->parallelize(ParallelType::DIDx); - } - - preseg_passes::OptimizationPass< - preseg_passes::InsertReshardingsPass>::runPass(fusion_.get()); - preseg_passes::OptimizationPass< - preseg_passes::ReorderShardedAxisPass>::runPass(fusion_.get()); - validate(); -} - -namespace { - -DeviceMesh Mesh0({0}); -DeviceMesh Mesh1({1, 2}); -DeviceMesh Mesh2({0, 1, 2, 3}); - -} // namespace - -INSTANTIATE_TEST_SUITE_P( - , - ReshardingTest, - testing::Combine( - testing::Values(Mesh0, Mesh2), - testing::Values(Mesh1, Mesh2), - testing::Values(Mesh2), - testing::Bool(), - testing::Bool(), - testing::Bool())); - using ReshardingSelectOpTest = NVFuserTest; TEST_F(ReshardingSelectOpTest, NonResharding) { From 14ac2185091cd225356a7b5681385a146a1b2fe8 Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Fri, 13 Jun 2025 14:31:38 -0700 Subject: [PATCH 4/9] Unnecessary dependency --- tests/cpp/test_resharding.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/cpp/test_resharding.cpp b/tests/cpp/test_resharding.cpp index 0bb8ee98cbe..05ef7ef3f6a 100644 --- a/tests/cpp/test_resharding.cpp +++ b/tests/cpp/test_resharding.cpp @@ -13,7 +13,6 @@ #include #include -#include #include #include #include From a3b6344309828570c41f7e26f2f713f5bc642275 Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Sun, 15 Jun 2025 20:53:25 -0700 Subject: [PATCH 5/9] Test passing --- tests/cpp/test_multidevice_sharding.cpp | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/tests/cpp/test_multidevice_sharding.cpp b/tests/cpp/test_multidevice_sharding.cpp index 854a2866c53..9341083f477 100644 --- a/tests/cpp/test_multidevice_sharding.cpp +++ b/tests/cpp/test_multidevice_sharding.cpp @@ -735,15 +735,14 @@ TEST_F(MultiDeviceTest, ReorderDIDToFront) { __FILE__); } -using InsertReshardingTestParams = std::tuple; +using InsertReshardingTestParams = std::tuple; class InsertReshardingTest : public MultiDeviceTest, public testing::WithParamInterface {}; TEST_P(InsertReshardingTest, Execute) { - auto [mesh, is_tv0_tv5_sharded, is_tv1_tv4_sharded, is_tv2_sharded] = - GetParam(); + auto [is_tv0_tv5_sharded, is_tv1_tv4_sharded, is_tv2_sharded] = GetParam(); constexpr int64_t kShardedAxis = 1; auto fusion = std::make_unique(); @@ -760,13 +759,14 @@ TEST_P(InsertReshardingTest, Execute) { fusion->addOutput(tv1); fusion->addOutput(tv5); + auto mesh = DeviceMesh::createForNumDevices(communicator_->size()); for (auto* tv : {tv0, tv1, tv2, tv3, tv4, tv5}) { tv->setDeviceMesh(mesh); } if (is_tv0_tv5_sharded) { tv0->axis(kShardedAxis)->parallelize(ParallelType::DIDx); - // tv3->axis(kShardedAxis) is a reduction, so don't shard it. + tv3->axis(kShardedAxis)->parallelize(ParallelType::DIDx); tv5->axis(kShardedAxis)->parallelize(ParallelType::DIDx); } if (is_tv1_tv4_sharded) { @@ -788,8 +788,6 @@ INSTANTIATE_TEST_SUITE_P( , InsertReshardingTest, ::testing::Combine( - ::testing::ValuesIn( - std::vector({{0}, {0, 1}, {0, 1, 2, 3}})), ::testing::Bool(), ::testing::Bool(), ::testing::Bool())); From 5672034656d8cdb83b93fbb15edad334b4e31e72 Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Sun, 15 Jun 2025 21:17:37 -0700 Subject: [PATCH 6/9] Shard tv3 the same way as tv2 --- tests/cpp/test_multidevice_sharding.cpp | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/tests/cpp/test_multidevice_sharding.cpp b/tests/cpp/test_multidevice_sharding.cpp index 9341083f477..a409b91d0fc 100644 --- a/tests/cpp/test_multidevice_sharding.cpp +++ b/tests/cpp/test_multidevice_sharding.cpp @@ -743,7 +743,6 @@ class InsertReshardingTest TEST_P(InsertReshardingTest, Execute) { auto [is_tv0_tv5_sharded, is_tv1_tv4_sharded, is_tv2_sharded] = GetParam(); - constexpr int64_t kShardedAxis = 1; auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); @@ -751,7 +750,7 @@ TEST_P(InsertReshardingTest, Execute) { TensorView* tv0 = makeContigTensor(3); TensorView* tv1 = mul(tv0, tv0); TensorView* tv2 = add(tv0, tv1); - TensorView* tv3 = sum(tv2, {kShardedAxis}); + TensorView* tv3 = sum(tv2, {1}); TensorView* tv4 = broadcast(tv3, {false, true, false}); TensorView* tv5 = mul(tv2, tv4); @@ -765,20 +764,18 @@ TEST_P(InsertReshardingTest, Execute) { } if (is_tv0_tv5_sharded) { - tv0->axis(kShardedAxis)->parallelize(ParallelType::DIDx); - tv3->axis(kShardedAxis)->parallelize(ParallelType::DIDx); - tv5->axis(kShardedAxis)->parallelize(ParallelType::DIDx); + tv0->axis(1)->parallelize(ParallelType::DIDx); + tv5->axis(1)->parallelize(ParallelType::DIDx); } if (is_tv1_tv4_sharded) { - tv1->axis(kShardedAxis)->parallelize(ParallelType::DIDx); - tv4->axis(kShardedAxis)->parallelize(ParallelType::DIDx); + tv1->axis(1)->parallelize(ParallelType::DIDx); + tv4->axis(1)->parallelize(ParallelType::DIDx); } - if (is_tv2_sharded) { - tv2->axis(kShardedAxis)->parallelize(ParallelType::DIDx); + if (is_tv2_tv3_sharded) { + tv2->axis(1)->parallelize(ParallelType::DIDx); + tv3->axis(1)->parallelize(ParallelType::DIDx); } - SKIP_IF_NOT_ENOUGH_DEVICES(fusion); - FusionExecutorCache executor_cache(std::move(fusion)); executor_cache.runFusionWithInputs({at::randn( {2, is_tv0_tv5_sharded ? 1 : mesh.size(), 5}, tensor_options)}); From 57f1ba4e168eb451b4d783924b8973f84a4f20ef Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Sun, 15 Jun 2025 23:16:18 -0700 Subject: [PATCH 7/9] Verify output --- tests/cpp/test_multidevice_sharding.cpp | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/tests/cpp/test_multidevice_sharding.cpp b/tests/cpp/test_multidevice_sharding.cpp index a409b91d0fc..ced39bd2ef6 100644 --- a/tests/cpp/test_multidevice_sharding.cpp +++ b/tests/cpp/test_multidevice_sharding.cpp @@ -742,7 +742,8 @@ class InsertReshardingTest public testing::WithParamInterface {}; TEST_P(InsertReshardingTest, Execute) { - auto [is_tv0_tv5_sharded, is_tv1_tv4_sharded, is_tv2_sharded] = GetParam(); + auto [is_tv0_tv5_sharded, is_tv1_tv4_sharded, is_tv2_tv3_sharded] = + GetParam(); auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); @@ -755,7 +756,6 @@ TEST_P(InsertReshardingTest, Execute) { TensorView* tv5 = mul(tv2, tv4); fusion->addInput(tv0); - fusion->addOutput(tv1); fusion->addOutput(tv5); auto mesh = DeviceMesh::createForNumDevices(communicator_->size()); @@ -776,9 +776,19 @@ TEST_P(InsertReshardingTest, Execute) { tv3->axis(1)->parallelize(ParallelType::DIDx); } + at::Tensor t0 = at::randint(3, {2, mesh.size(), 5}, tensor_options); + at::Tensor t1 = t0 * t0; + at::Tensor t2 = t0 + t1; + at::Tensor t5 = t2 * t2.sum({1}, /*keepdim=*/true); + FusionExecutorCache executor_cache(std::move(fusion)); - executor_cache.runFusionWithInputs({at::randn( - {2, is_tv0_tv5_sharded ? 1 : mesh.size(), 5}, tensor_options)}); + if (is_tv0_tv5_sharded) { + t0 = shardTensor(t0, 1, mesh); + t5 = shardTensor(t5, 1, mesh); + } + + auto outs = executor_cache.runFusionWithInputs({t0}); + testValidate(executor_cache.fusion(), outs, {t0}, {t5}, __LINE__, __FILE__); } INSTANTIATE_TEST_SUITE_P( From 615367271c5b2bbb6e3d1897d8d2299a7ca0cb5a Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Sun, 15 Jun 2025 23:21:37 -0700 Subject: [PATCH 8/9] Comment --- tests/cpp/test_multidevice_sharding.cpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/cpp/test_multidevice_sharding.cpp b/tests/cpp/test_multidevice_sharding.cpp index ced39bd2ef6..dbdf29707a9 100644 --- a/tests/cpp/test_multidevice_sharding.cpp +++ b/tests/cpp/test_multidevice_sharding.cpp @@ -756,6 +756,8 @@ TEST_P(InsertReshardingTest, Execute) { TensorView* tv5 = mul(tv2, tv4); fusion->addInput(tv0); + // Due to #4642, we can't add tv1 as output. + // fusion->addOutput(tv1); fusion->addOutput(tv5); auto mesh = DeviceMesh::createForNumDevices(communicator_->size()); @@ -777,8 +779,7 @@ TEST_P(InsertReshardingTest, Execute) { } at::Tensor t0 = at::randint(3, {2, mesh.size(), 5}, tensor_options); - at::Tensor t1 = t0 * t0; - at::Tensor t2 = t0 + t1; + at::Tensor t2 = t0 + t0 * t0; at::Tensor t5 = t2 * t2.sum({1}, /*keepdim=*/true); FusionExecutorCache executor_cache(std::move(fusion)); From 913c09d1ac7ecff13359e28fdaf896083e546dd2 Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Sun, 15 Jun 2025 23:44:08 -0700 Subject: [PATCH 9/9] Remove HostIrLower::canLower as it's no longer used --- csrc/host_ir/lower.cpp | 38 -------------------------------------- csrc/host_ir/lower.h | 5 ----- 2 files changed, 43 deletions(-) diff --git a/csrc/host_ir/lower.cpp b/csrc/host_ir/lower.cpp index b188a14b47c..0c9ee7b061a 100644 --- a/csrc/host_ir/lower.cpp +++ b/csrc/host_ir/lower.cpp @@ -27,44 +27,6 @@ namespace nvfuser { -bool HostIrLower::canLower(Expr* expr, bool ignore_inner_resharding) { - if (!isResharding(expr)) { - return true; - } - if (!ir_utils::isTvOp(expr)) { - return false; - } - if (auto* reduction = dynamic_cast(expr)) { - if (!ignore_inner_resharding && !isCommunicationLayoutCompliant(expr)) { - return false; - } - auto in = reduction->in()->as(); - auto out = reduction->out()->as(); - // get the reduced axis - std::vector reduction_axis; - std::copy_if( - out->getLogicalDomain().begin(), - out->getLogicalDomain().end(), - std::back_inserter(reduction_axis), - [](IterDomain* id) { return id->isReduction(); }); - // check whether the reduction involves only one axis - if (reduction_axis.size() != 1) { - return false; - } - // We check whether the reduced axis is sharded on the input - const auto c2p_map = - PairwiseLogicalDomainMap(in, out).mapConsumerToProducer(); - auto c2p_map_it = c2p_map.find(reduction_axis.at(0)); - return c2p_map_it != c2p_map.end() && c2p_map_it->second->isDeviceDim(); - } else if (auto* ldst = dynamic_cast(expr)) { - if (!ignore_inner_resharding && !isCommunicationLayoutCompliant(expr)) { - return false; - } - return ldst->as()->opType() == LoadStoreOpType::Set; - } - return false; -} - bool HostIrLower::isLowerableAsStandaloneHostOp(Expr* expr) { if (expr->isOneOf< MatmulOp, diff --git a/csrc/host_ir/lower.h b/csrc/host_ir/lower.h index 8df156d4512..8bc4f37eca2 100644 --- a/csrc/host_ir/lower.h +++ b/csrc/host_ir/lower.h @@ -24,11 +24,6 @@ class HostIrLower { explicit HostIrLower(const HostIrLowerParams& params = HostIrLowerParams()) : params_(params) {} - // The flag `ignore_inner_resharding` is useful because the preseg passes - // `InsertReshardingsPass` and `ReorderShardedAxisPass` want different - // behaviors - static bool canLower(Expr* expr, bool ignore_inner_resharding = false); - // Lower a sharded Expr into a series of Communication. std::vector lower(Expr* c, DeviceIdxType my_device_index);