diff --git a/csrc/host_ir/lower_to_communication.cpp b/csrc/host_ir/lower_to_communication.cpp index 241ec261c66..70cbe085b07 100644 --- a/csrc/host_ir/lower_to_communication.cpp +++ b/csrc/host_ir/lower_to_communication.cpp @@ -301,7 +301,7 @@ bool isLocalSizeOne(IterDomain* id) { } // namespace -std::optional getCommunicationInfo(Expr* expr) { +CommunicationInfo getCommunicationInfo(Expr* expr) { NVF_ERROR( isResharding(expr), "getCommunicationInfo should only be called when `expr` is known to be a " @@ -319,9 +319,11 @@ std::optional getCommunicationInfo(Expr* expr) { auto* consumer = expr->outputs().at(0)->as(); std::optional communication_info = std::nullopt; - auto create_communication_info = [&](CommunicationType type, - IterDomain* p_sharded_id, - IterDomain* c_sharded_id) { + // Fill `communication_info` instead of returning the result, so we can catch + // errors when more than one DIDs have sharding changes. + auto fill_communication_info = [&](CommunicationType type, + IterDomain* p_sharded_id, + IterDomain* c_sharded_id) { NVF_ERROR( !communication_info.has_value(), "Expected at most one sharding change"); @@ -358,18 +360,18 @@ std::optional getCommunicationInfo(Expr* expr) { IterDomain* p_logical_id = getLogicalFromLoopId(producer, p_loop_did); CommunicationType type = same_mesh ? CommunicationType::Allgather : CommunicationType::Gather; - create_communication_info(type, p_logical_id, p2c_map.at(p_logical_id)); + fill_communication_info(type, p_logical_id, p2c_map.at(p_logical_id)); } if (!p_sharded && c_sharded) { IterDomain* c_logical_id = getLogicalFromLoopId(consumer, c_loop_did); - create_communication_info( + fill_communication_info( CommunicationType::Scatter, c2p_map.at(c_logical_id), c_logical_id); } if (p_sharded && c_sharded) { IterDomain* p_logical_id = getLogicalFromLoopId(producer, p_loop_did); IterDomain* c_logical_id = getLogicalFromLoopId(consumer, c_loop_did); // TODO(#4604): This is problematic for 2D sharding. - create_communication_info( + fill_communication_info( CommunicationType::SendRecv, p_logical_id, c_logical_id); } } else { @@ -383,7 +385,7 @@ std::optional getCommunicationInfo(Expr* expr) { IterDomain* p_logical_id = getLogicalFromLoopId(producer, p_loop_did); CommunicationType type = same_mesh ? CommunicationType::Allreduce : CommunicationType::Reduce; - create_communication_info(type, p_logical_id, p2c_map.at(p_logical_id)); + fill_communication_info(type, p_logical_id, p2c_map.at(p_logical_id)); continue; } @@ -400,7 +402,7 @@ std::optional getCommunicationInfo(Expr* expr) { if (!c_it->second->isReduction()) { continue; } - create_communication_info( + fill_communication_info( CommunicationType::ReduceScatter, c2p_map.at(c_logical_id), c_logical_id); @@ -408,9 +410,9 @@ std::optional getCommunicationInfo(Expr* expr) { } if (!communication_info.has_value()) { - create_communication_info(CommunicationType::Broadcast, nullptr, nullptr); + fill_communication_info(CommunicationType::Broadcast, nullptr, nullptr); } - return communication_info; + return *communication_info; } namespace { @@ -472,22 +474,16 @@ bool isCommunicationLayoutCompliant(Expr* expr) { auto* producer = expr->inputs().at(0)->as(); auto* consumer = expr->outputs().at(0)->as(); - // TODO(#4552): the caller should make sure Expr is a communication so - // getCommunicationInfo always returns a valid CommunicationInfo. Retry after - // #4552 is merged. - auto communication_info = getCommunicationInfo(expr); - if (!communication_info.has_value()) { - return true; - } + CommunicationInfo communication_info = getCommunicationInfo(expr); Layout p_layout = getCommunicationLayout( - producer, communication_info->type, communication_info->p_sharded_id); + producer, communication_info.type, communication_info.p_sharded_id); if (!isCompliantWith(*canonicalizeLayout(producer), p_layout)) { return false; } Layout c_layout = getCommunicationLayout( - consumer, communication_info->type, communication_info->c_sharded_id); + consumer, communication_info.type, communication_info.c_sharded_id); if (!isCompliantWith(*canonicalizeLayout(consumer), c_layout)) { return false; } diff --git a/csrc/host_ir/lower_to_communication.h b/csrc/host_ir/lower_to_communication.h index ab75b9dd726..774e7211a11 100644 --- a/csrc/host_ir/lower_to_communication.h +++ b/csrc/host_ir/lower_to_communication.h @@ -37,7 +37,7 @@ bool isCommunicationLayoutCompliant(Expr* expr); // info: type and sharded IDs. We assume that the expr has been decomposed and // represented a single communication. If multiple communications are present or // 2D sharding, this function will raise an error. -std::optional getCommunicationInfo(Expr* expr); +CommunicationInfo getCommunicationInfo(Expr* expr); // Given the input/output TensorView of a communication, returns its layout // required by the communication backend (e.g. NCCL or UCC). `sharded_id` is the diff --git a/csrc/preseg_passes/insert_reshardings.cpp b/csrc/preseg_passes/insert_reshardings.cpp index 0dbfe15f88d..4d0103f2428 100644 --- a/csrc/preseg_passes/insert_reshardings.cpp +++ b/csrc/preseg_passes/insert_reshardings.cpp @@ -398,11 +398,7 @@ void InsertReshardingsPass::runPass(Fusion* fusion) { // Validate for (Expr* e : fusion->exprs()) { if (isResharding(e)) { - NVF_ERROR( - getCommunicationInfo(e).has_value(), - "After decomposition, any resharding expression is expected to be a " - "lowerable communication: ", - e); + getCommunicationInfo(e); } } } diff --git a/csrc/preseg_passes/reorder_sharded_axis.cpp b/csrc/preseg_passes/reorder_sharded_axis.cpp index e174b501e67..b96a8668cab 100644 --- a/csrc/preseg_passes/reorder_sharded_axis.cpp +++ b/csrc/preseg_passes/reorder_sharded_axis.cpp @@ -22,12 +22,11 @@ namespace nvfuser::preseg_passes { namespace { -void makeCommunicationLayoutCompliant( - Expr* expr, - CommunicationInfo communication_info) { +void makeCommunicationLayoutCompliant(Expr* expr) { auto* input = expr->inputs().at(0)->as(); auto* output = expr->outputs().at(0)->as(); + CommunicationInfo communication_info = getCommunicationInfo(expr); IterDomain* p_sharded_id = communication_info.p_sharded_id; IterDomain* c_sharded_id = communication_info.c_sharded_id; @@ -81,17 +80,7 @@ void ReorderShardedAxisPass::runPass(Fusion* fusion) { continue; } - auto communication_info = getCommunicationInfo(expr); - // Should really be simply NVF_ERROR(communication_info.has_value()); - // - // I'll try to do that after #4552 is merged. Some of the `mesh.size() > 1` - // check in getCommunicationInfo and convertSingleOpToCommuniation will also - // need to go away for this simplification. - if (!communication_info.has_value()) { - continue; - } - - makeCommunicationLayoutCompliant(expr, *communication_info); + makeCommunicationLayoutCompliant(expr); } if (isDebugDumpEnabled(DebugDumpOption::PreSegmenterLogging)) {