From 5083750fb7ea7a8a9bf55349dc6c36ee3a95d115 Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Fri, 13 Jun 2025 14:37:22 -0700 Subject: [PATCH 1/2] Remove an unnecessary guard NVFuserTest has it already --- tests/cpp/test_multidevice_lower_communication.cpp | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/cpp/test_multidevice_lower_communication.cpp b/tests/cpp/test_multidevice_lower_communication.cpp index 9e6d94f31d4..cd0b44eaeef 100644 --- a/tests/cpp/test_multidevice_lower_communication.cpp +++ b/tests/cpp/test_multidevice_lower_communication.cpp @@ -65,7 +65,6 @@ class LowerGatherTest public testing::WithParamInterface> {}; TEST_P(LowerGatherTest, ) { - EnableOptionsGuard opt_guard; const auto& [meshes, enable_host_ir_lowering] = GetParam(); const auto& [in_mesh, out_mesh] = meshes; @@ -138,7 +137,6 @@ class LowerScatterTest public testing::WithParamInterface> {}; TEST_P(LowerScatterTest, ) { - EnableOptionsGuard opt_guard; const auto& [meshes, enable_host_ir_lowering] = GetParam(); const auto& [in_mesh, out_mesh] = meshes; @@ -189,7 +187,6 @@ class LowerSendRecvTest public testing::WithParamInterface> {}; TEST_P(LowerSendRecvTest, ) { - EnableOptionsGuard opt_guard; const auto& [meshes, enable_host_ir_lowering] = GetParam(); const auto& [in_mesh, out_mesh] = meshes; @@ -255,7 +252,6 @@ void LowerCollectiveTest::SetUp() { // available. Therefore, we call it after the isBackendAvailable check. communicator_->setDefaultBackend(backend_type); - EnableOptionsGuard enable_options_guard; if (enable_host_ir_lowering) { EnableOptionsGuard::getCurOptions().set(EnableOption::HostIrLowering); } From 1f7ef88fc6c483a781a39db6971b2e5a0bb80b51 Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Fri, 13 Jun 2025 23:02:07 -0700 Subject: [PATCH 2/2] Fix the test --- csrc/host_ir/lower_to_communication.cpp | 26 +++++++++---------- csrc/runtime/allocations.cpp | 2 +- csrc/runtime/fusion_kernel_runtime.cpp | 20 +++++++------- .../test_multidevice_lower_communication.cpp | 3 ++- 4 files changed, 27 insertions(+), 24 deletions(-) diff --git a/csrc/host_ir/lower_to_communication.cpp b/csrc/host_ir/lower_to_communication.cpp index 70cbe085b07..a6931010625 100644 --- a/csrc/host_ir/lower_to_communication.cpp +++ b/csrc/host_ir/lower_to_communication.cpp @@ -301,22 +301,22 @@ bool isLocalSizeOne(IterDomain* id) { } // namespace -CommunicationInfo getCommunicationInfo(Expr* expr) { +CommunicationInfo getCommunicationInfo(Expr* e) { NVF_ERROR( - isResharding(expr), - "getCommunicationInfo should only be called when `expr` is known to be a " - "communication. So `expr` should be resharding. Given: ", - expr); + isResharding(e), + "getCommunicationInfo should only be called when `e` is known to be a " + "communication. So `e` should be resharding. Given: ", + e); NVF_ERROR( - expr->isA() || expr->isA(), - "getCommunicationInfo should only be called when `expr` is known to be a " - "communication. So `expr` should be either a LoadStoreOp or a " + e->isA() || e->isA(), + "getCommunicationInfo should only be called when `e` is known to be a " + "communication. So `e` should be either a LoadStoreOp or a " "ReductionOp. Given: ", - expr); + e); - auto* producer = expr->inputs().at(0)->as(); - auto* consumer = expr->outputs().at(0)->as(); + auto* producer = e->inputs().at(0)->as(); + auto* consumer = e->outputs().at(0)->as(); std::optional communication_info = std::nullopt; // Fill `communication_info` instead of returning the result, so we can catch @@ -355,7 +355,7 @@ CommunicationInfo getCommunicationInfo(Expr* expr) { const bool c_sharded = c_loop_did != nullptr && consumer_mesh.size() > 1; const bool same_mesh = producer_mesh == consumer_mesh; - if (expr->isA()) { + if (e->isA()) { if (p_sharded && !c_sharded) { IterDomain* p_logical_id = getLogicalFromLoopId(producer, p_loop_did); CommunicationType type = same_mesh ? CommunicationType::Allgather @@ -375,7 +375,7 @@ CommunicationInfo getCommunicationInfo(Expr* expr) { CommunicationType::SendRecv, p_logical_id, c_logical_id); } } else { - NVF_ERROR(expr->isA()); + NVF_ERROR(e->isA()); if (!p_sharded) { // Not a reduction based communication. continue; diff --git a/csrc/runtime/allocations.cpp b/csrc/runtime/allocations.cpp index f947d7847f5..2bc6474e9c6 100644 --- a/csrc/runtime/allocations.cpp +++ b/csrc/runtime/allocations.cpp @@ -167,7 +167,7 @@ std::pair, std::vector> inferShape( inferred_val.hasValue(), "Could not launch kernel as program could not infer ", symbolic_size->toInlineString(), - "(", + " (", symbolic_size->toString(), ") for the buffer ", tv->toString()); diff --git a/csrc/runtime/fusion_kernel_runtime.cpp b/csrc/runtime/fusion_kernel_runtime.cpp index aa9b98b276f..dc01e96dc9b 100644 --- a/csrc/runtime/fusion_kernel_runtime.cpp +++ b/csrc/runtime/fusion_kernel_runtime.cpp @@ -483,24 +483,26 @@ void FusionKernelRuntime::compileFusionParallel(KernelArgumentHolder args) { switch (group_to_run->schedulerType()) { case SchedulerType::Communication: { auto deviceid = Communicator::getInstance().deviceId(); - NVF_ERROR( - group_to_run->exprs().size() == 1, - "Communication segments must contain only one Expr"); - for (auto* expr : convertSingleOpToCommunication( + NVF_ERROR_EQ( + group_to_run->exprs().size(), + 1, + "Communication segments must contain only one Expr."); + for (auto* e : convertSingleOpToCommunication( ir_cloner.clone(group_to_run->exprs().at(0)), deviceid)) { NVF_ERROR( - expr->isA(), - "Exprs in a Communication group should be Communication"); + e->isA(), + "Exprs in a Communication group should be Communication: ", + e); // Allocate the recv buffers of communications - auto* communication = expr->as(); + auto* communication = e->as(); TensorView* tv = communication->out(); if (tv->getDeviceMesh().has(deviceid)) { auto* allocate = IrBuilder::create(tv, MemoryType::Global); hic->pushBackTopLevelExprs(allocate); } - hic->pushBackTopLevelExprs(expr); - auto wait = IrBuilder::create(expr->as()); + hic->pushBackTopLevelExprs(communication); + auto wait = IrBuilder::create(communication); hic->pushBackTopLevelExprs(wait); } } break; diff --git a/tests/cpp/test_multidevice_lower_communication.cpp b/tests/cpp/test_multidevice_lower_communication.cpp index cd0b44eaeef..3d32334edb3 100644 --- a/tests/cpp/test_multidevice_lower_communication.cpp +++ b/tests/cpp/test_multidevice_lower_communication.cpp @@ -774,7 +774,8 @@ INSTANTIATE_TEST_SUITE_P( LowerCollectiveTest, ::testing::Combine( testing::Values(CommunicatorBackend::kNccl, CommunicatorBackend::kUcc), - testing::Bool()), + // Can't do testing::Bool() yet due to #4230 + testing::Values(false)), ([](const testing::TestParamInfo>& info) -> std::string { const auto& [backend_type, enable_host_ir_lowering] = info.param;