From 9bfc51986a0d22783e50da3e59758c525607801b Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Fri, 30 May 2025 16:46:26 -0700 Subject: [PATCH 1/8] Remove dependencies on host_ir/lower from FusionExecutorCache --- csrc/host_ir/executor.cpp | 16 ++---- csrc/host_ir/lower_to_communication.cpp | 50 +++++++++---------- csrc/host_ir/lower_to_communication.h | 6 +-- .../pass/convert_op_to_communication.cpp | 2 +- csrc/multidevice/utils.cpp | 1 - csrc/preseg_passes/reorder_sharded_axis.cpp | 3 +- csrc/runtime/fusion_kernel_runtime.cpp | 5 +- csrc/scheduler/communication.cpp | 9 ---- csrc/scheduler/no_op.cpp | 1 - 9 files changed, 35 insertions(+), 58 deletions(-) diff --git a/csrc/host_ir/executor.cpp b/csrc/host_ir/executor.cpp index bebc1326d41..01863910a50 100644 --- a/csrc/host_ir/executor.cpp +++ b/csrc/host_ir/executor.cpp @@ -17,7 +17,6 @@ #include #include #include -#include #include #include #include @@ -46,16 +45,9 @@ HostIrExecutor::HostIrExecutor( bool HostIrExecutor::supported(Fusion* fusion) { FUSER_PERF_SCOPE("HostIrExecutor::supported"); std::vector exprs = fusion->exprs(); - if (std::any_of(exprs.begin(), exprs.end(), [](Expr* e) { - return isResharding(e) && HostIrLower::canLower(e); - })) { + if (std::any_of(exprs.begin(), exprs.end(), isResharding)) { NVF_ERROR( - std::all_of( - exprs.begin(), - exprs.end(), - [](Expr* e) { - return isResharding(e) && HostIrLower::canLower(e); - }), + std::all_of(exprs.begin(), exprs.end(), isResharding), "Could not execute fusion as all expressions in a host IR container must be communication based at this point."); return true; } @@ -81,8 +73,8 @@ void HostIrExecutor::compile(Fusion* fusion) { std::vector exprs = fusion->exprs(); DeviceIdxType my_device_idx = communicator_ ? communicator_->deviceId() : 0; for (Expr* e : exprs) { - std::vector communications = convertSingleOpToCommunication( - cloner.clone(e), my_device_idx, HostIrLowerParams()); + std::vector communications = + convertSingleOpToCommunication(cloner.clone(e), my_device_idx); for (auto* communication : communications) { host_ir_container_->pushBackTopLevelExprs(communication); } diff --git a/csrc/host_ir/lower_to_communication.cpp b/csrc/host_ir/lower_to_communication.cpp index 40e9162ec6f..781f73a3150 100644 --- a/csrc/host_ir/lower_to_communication.cpp +++ b/csrc/host_ir/lower_to_communication.cpp @@ -52,7 +52,7 @@ inline c10d::ReduceOp::RedOpType getC10dReduceOpType(BinaryOpType op) { void lowerToScatter( TensorView* input_tv, TensorView* output_tv, - const HostIrLowerParams& params, + const CommunicatorBackend backend, std::vector& comms) { const DeviceMesh& receiver_mesh = output_tv->getDeviceMesh(); NVF_ERROR( @@ -79,7 +79,7 @@ void lowerToScatter( team, root, c10d::ReduceOp::RedOpType::UNUSED, - params.communicator_backend)); + backend)); } /* @@ -91,7 +91,7 @@ need multiple Gathers if the tensor is replicated in the receiver mesh. void lowerToGather( TensorView* input_tv, TensorView* output_tv, - const HostIrLowerParams& params, + const CommunicatorBackend backend, std::vector& comms) { // we create as many 'Gathers' as there are devices in the receiver mesh const DeviceMesh& sender_mesh = input_tv->getDeviceMesh(); @@ -111,7 +111,7 @@ void lowerToGather( team, root, c10d::ReduceOp::RedOpType::UNUSED, - params.communicator_backend)); + backend)); } } @@ -119,7 +119,7 @@ void lowerToGather( void lowerToAllgather( TensorView* input_tv, TensorView* output_tv, - const HostIrLowerParams& params, + const CommunicatorBackend backend, std::vector& comms, DeviceIdxType my_device_idx) { const DeviceMesh& mesh = input_tv->getDeviceMesh(); @@ -131,7 +131,7 @@ void lowerToAllgather( team, /*root=*/-1, c10d::ReduceOp::RedOpType::UNUSED, - params.communicator_backend)); + backend)); } // Adds one or zero Broadcast communication to the vector 'comms' @@ -139,7 +139,7 @@ void lowerToBroadcast( TensorView* input_tv, TensorView* output_tv, DeviceIdxType root, - const HostIrLowerParams& params, + const CommunicatorBackend backend, std::vector& comms) { const DeviceMesh& mesh = output_tv->getDeviceMesh(); NVF_ERROR( @@ -155,7 +155,7 @@ void lowerToBroadcast( team, root, c10d::ReduceOp::RedOpType::UNUSED, - params.communicator_backend)); + backend)); } // Adds several Broadcast or SendRecv communications to the vector 'comms' @@ -165,7 +165,7 @@ void lowerToBroadcast( void lowerToBroadcastOrSendRecv( TensorView* input_tv, TensorView* output_tv, - const HostIrLowerParams& params, + const CommunicatorBackend backend, std::vector& comms) { const DeviceMesh& sender_mesh = input_tv->getDeviceMesh(); const DeviceMesh& receiver_mesh = output_tv->getDeviceMesh(); @@ -196,7 +196,7 @@ void lowerToBroadcastOrSendRecv( Team({sender, receiver}), /*root=*/sender, c10d::ReduceOp::RedOpType::UNUSED, - params.communicator_backend)); + backend)); } } else { // Either of the following two cases is happening. @@ -209,7 +209,7 @@ void lowerToBroadcastOrSendRecv( input_tv, output_tv, /*root=*/sender_mesh.at(0), - params, + backend, comms); } } @@ -218,7 +218,7 @@ void lowerToReduce( TensorView* input_tv, TensorView* output_tv, BinaryOpType op_type, - const HostIrLowerParams& params, + const CommunicatorBackend backend, std::vector& comms) { const DeviceMesh& receiver_mesh = output_tv->getDeviceMesh(); const DeviceMesh& sender_mesh = input_tv->getDeviceMesh(); @@ -244,7 +244,7 @@ void lowerToReduce( team, root, reduce_op_type, - params.communicator_backend)); + backend)); } } @@ -252,7 +252,7 @@ void lowerToAllreduce( TensorView* input_tv, TensorView* output_tv, BinaryOpType op_type, - const HostIrLowerParams& params, + const CommunicatorBackend backend, std::vector& comms, DeviceIdxType my_device_idx) { const DeviceMesh& mesh = input_tv->getDeviceMesh(); @@ -264,14 +264,14 @@ void lowerToAllreduce( team, /*root=*/-1, getC10dReduceOpType(op_type), - params.communicator_backend)); + backend)); } void lowerToReduceScatter( TensorView* input_tv, TensorView* output_tv, BinaryOpType op_type, - const HostIrLowerParams& params, + const CommunicatorBackend backend, std::vector& comms, DeviceIdxType my_device_idx) { const DeviceMesh& mesh = input_tv->getDeviceMesh(); @@ -284,7 +284,7 @@ void lowerToReduceScatter( /*team=*/team, /*root=*/-1, getC10dReduceOpType(op_type), - params.communicator_backend)); + backend)); } } // namespace @@ -292,7 +292,7 @@ void lowerToReduceScatter( std::vector convertSingleOpToCommunication( Expr* c, DeviceIdxType my_device_idx, - const HostIrLowerParams& params) { + const CommunicatorBackend backend) { FusionGuard fg(c->fusion()); std::vector comms; @@ -339,26 +339,26 @@ std::vector convertSingleOpToCommunication( "ReduceScatter operation must have the same sender and receiver device mesh. " "Insert a Set operation before or after the reduction to reshard ot another device mesh"); lowerToReduceScatter( - input_tv, output_tv, op_type, params, comms, my_device_idx); + input_tv, output_tv, op_type, backend, comms, my_device_idx); } else { if (same_mesh) { lowerToAllreduce( - input_tv, output_tv, op_type, params, comms, my_device_idx); + input_tv, output_tv, op_type, backend, comms, my_device_idx); } else { - lowerToReduce(input_tv, output_tv, op_type, params, comms); + lowerToReduce(input_tv, output_tv, op_type, backend, comms); } } } else { if (!is_input_sharded && is_output_sharded) { - lowerToScatter(input_tv, output_tv, params, comms); + lowerToScatter(input_tv, output_tv, backend, comms); } else if (is_input_sharded && !is_output_sharded) { if (same_mesh) { - lowerToAllgather(input_tv, output_tv, params, comms, my_device_idx); + lowerToAllgather(input_tv, output_tv, backend, comms, my_device_idx); } else { - lowerToGather(input_tv, output_tv, params, comms); + lowerToGather(input_tv, output_tv, backend, comms); } } else { - lowerToBroadcastOrSendRecv(input_tv, output_tv, params, comms); + lowerToBroadcastOrSendRecv(input_tv, output_tv, backend, comms); } } diff --git a/csrc/host_ir/lower_to_communication.h b/csrc/host_ir/lower_to_communication.h index 1edef39b180..4ec448e671c 100644 --- a/csrc/host_ir/lower_to_communication.h +++ b/csrc/host_ir/lower_to_communication.h @@ -7,14 +7,14 @@ // clang-format on #pragma once -#include -#include +#include +#include namespace nvfuser { std::vector convertSingleOpToCommunication( Expr* c, DeviceIdxType my_device_idx, - const HostIrLowerParams& params); + const CommunicatorBackend backend = CommunicatorBackend::kNccl); } // namespace nvfuser diff --git a/csrc/host_ir/pass/convert_op_to_communication.cpp b/csrc/host_ir/pass/convert_op_to_communication.cpp index 06822b8295f..a91b17b966a 100644 --- a/csrc/host_ir/pass/convert_op_to_communication.cpp +++ b/csrc/host_ir/pass/convert_op_to_communication.cpp @@ -32,7 +32,7 @@ void ConvertOpToCommunication::passImplementation(Fusion* fusion) { return new_top_level_exprs.push_back(top_level_expr); } for (auto* expr : nvfuser::convertSingleOpToCommunication( - top_level_expr, my_device_index, params_)) { + top_level_expr, my_device_index, params_.communicator_backend)) { // Allocate the recv buffers of communications if (expr->isA()) { auto* communication = expr->as(); diff --git a/csrc/multidevice/utils.cpp b/csrc/multidevice/utils.cpp index 9b42a636f77..9df89e973bd 100644 --- a/csrc/multidevice/utils.cpp +++ b/csrc/multidevice/utils.cpp @@ -8,7 +8,6 @@ #include #include -#include #include #include #include diff --git a/csrc/preseg_passes/reorder_sharded_axis.cpp b/csrc/preseg_passes/reorder_sharded_axis.cpp index 75d2b27ca72..c1b5b188fd4 100644 --- a/csrc/preseg_passes/reorder_sharded_axis.cpp +++ b/csrc/preseg_passes/reorder_sharded_axis.cpp @@ -9,7 +9,6 @@ #include #include -#include #include #include #include @@ -25,7 +24,7 @@ void ReorderShardedAxisPass::runPass(Fusion* fusion) { const std::vector& exprs = fusion->exprs(); for (auto it = std::rbegin(exprs); it != std::rend(exprs); it++) { Expr* expr = *it; - if (HostIrLower::canLower(expr)) { + if (!isResharding(expr)) { continue; } NVF_ERROR( diff --git a/csrc/runtime/fusion_kernel_runtime.cpp b/csrc/runtime/fusion_kernel_runtime.cpp index 4ac71985453..5029c2ecae1 100644 --- a/csrc/runtime/fusion_kernel_runtime.cpp +++ b/csrc/runtime/fusion_kernel_runtime.cpp @@ -10,7 +10,6 @@ #include #include #include -#include #include #include #include @@ -483,9 +482,7 @@ void FusionKernelRuntime::compileFusionParallel(KernelArgumentHolder args) { group_to_run->exprs().size() == 1, "Communication segments must contain only one Expr"); for (auto* expr : convertSingleOpToCommunication( - ir_cloner.clone(group_to_run->exprs().at(0)), - deviceid, - HostIrLowerParams())) { + ir_cloner.clone(group_to_run->exprs().at(0)), deviceid)) { NVF_ERROR( expr->isA(), "Exprs in a Communication group should be Communication"); diff --git a/csrc/scheduler/communication.cpp b/csrc/scheduler/communication.cpp index 1703a729ac8..e60af1cad45 100644 --- a/csrc/scheduler/communication.cpp +++ b/csrc/scheduler/communication.cpp @@ -7,7 +7,6 @@ // clang-format on #include -#include #include #include #include @@ -38,14 +37,6 @@ bool CommunicationScheduler::canScheduleCompileTime(Fusion* fusion) { return false; } - if (!HostIrLower::canLower(e)) { - scheduler_debug_utils::canScheduleRejectReason( - schedulerType(), - "Failed to lower the expression to host IR: ", - e->toString()); - return false; - } - return true; } diff --git a/csrc/scheduler/no_op.cpp b/csrc/scheduler/no_op.cpp index 59233d901db..664abdfb758 100644 --- a/csrc/scheduler/no_op.cpp +++ b/csrc/scheduler/no_op.cpp @@ -6,7 +6,6 @@ */ // clang-format on -#include #include #include #include From 6c4fb733c36d8ccf3475bec72610195dcee92877 Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Wed, 4 Jun 2025 21:43:04 -0700 Subject: [PATCH 2/8] Remove one more dependency --- csrc/host_ir/lower_to_communication.cpp | 5 ----- 1 file changed, 5 deletions(-) diff --git a/csrc/host_ir/lower_to_communication.cpp b/csrc/host_ir/lower_to_communication.cpp index 82cfcefac55..ac344c78319 100644 --- a/csrc/host_ir/lower_to_communication.cpp +++ b/csrc/host_ir/lower_to_communication.cpp @@ -499,11 +499,6 @@ std::vector convertSingleOpToCommunication( const bool is_output_sharded = isSharded(output_tv) && receiver_mesh.size() > 1; - NVF_ERROR( - HostIrLower::canLower(c), - "Lowering expression ", - c->toString(), - " to communication is not supported"); NVF_ERROR( isCommunicationLayoutCompliant(c), "Resharding on an inner axis is not lowerable ", From aef4377ffa1970623b630f3b794bf530e01eb8a4 Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Wed, 4 Jun 2025 21:59:28 -0700 Subject: [PATCH 3/8] Fix lint --- csrc/host_ir/lower_to_communication.h | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/csrc/host_ir/lower_to_communication.h b/csrc/host_ir/lower_to_communication.h index f2e53cf6f74..a663c982f7c 100644 --- a/csrc/host_ir/lower_to_communication.h +++ b/csrc/host_ir/lower_to_communication.h @@ -7,7 +7,12 @@ // clang-format on #pragma once +#include + #include +#include +#include +#include #include namespace nvfuser { From e075c22f05ad7f422cc2b9e7b4ea3e2ddff5a4bd Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Fri, 6 Jun 2025 10:43:24 -0700 Subject: [PATCH 4/8] WIP --- csrc/host_ir/lower_to_communication.cpp | 27 +++++++++++++++++------ csrc/preseg_passes/insert_reshardings.cpp | 12 ++++++++++ 2 files changed, 32 insertions(+), 7 deletions(-) diff --git a/csrc/host_ir/lower_to_communication.cpp b/csrc/host_ir/lower_to_communication.cpp index ac344c78319..5bca6544e19 100644 --- a/csrc/host_ir/lower_to_communication.cpp +++ b/csrc/host_ir/lower_to_communication.cpp @@ -176,7 +176,7 @@ void lowerToBroadcastOrSendRecv( receiver_mesh); if (isSharded(input_tv) && sender_mesh.size() > 1) { // if the inputs and ouputs are parallelized, - // we create as many Broadcast as that will be handled in parallel + // we create as many SendRecvs as that will be handled in parallel NVF_ERROR( sender_mesh.size() == receiver_mesh.size(), "the receiver and sender meshes have different sizes: ", @@ -352,6 +352,13 @@ bool isAllocationOrderCompliant(TensorView* tv, IterDomain* sharded_id) { } std::optional getCommunicationInfo(Expr* expr) { + NVF_ERROR( + expr->isA() || expr->isA(), + "getCommunicationInfo should only be called when `expr` is known to be a " + "communication. So `expr` should be either a LoadStoreOp or a " + "ReductionOp. Given: ", + expr->toString()); + auto* producer = expr->inputs().at(0)->as(); auto* consumer = expr->outputs().at(0)->as(); bool has_sharding_change = false; @@ -396,16 +403,20 @@ std::optional getCommunicationInfo(Expr* expr) { CommunicationType type = same_mesh ? CommunicationType::Allgather : CommunicationType::Gather; create_communication_info(type, p_logical_id, p2c_map.at(p_logical_id)); - continue; } if (!p_sharded && c_sharded) { - // Scatter IterDomain* c_logical_id = getLogicalFromLoopId(consumer, c_loop_did); create_communication_info( CommunicationType::Scatter, c2p_map.at(c_logical_id), c_logical_id); - continue; } - } else if (expr->isA()) { + if (p_sharded && c_sharded) { + IterDomain* p_logical_id = getLogicalFromLoopId(producer, p_loop_did); + IterDomain* c_logical_id = getLogicalFromLoopId(consumer, c_loop_did); + create_communication_info( + CommunicationType::SendRecv, p_logical_id, c_logical_id); + } + } else { + NVF_ERROR(expr->isA()); if (!p_sharded) { // Not a reduction based communication. continue; @@ -436,10 +447,12 @@ std::optional getCommunicationInfo(Expr* expr) { CommunicationType::ReduceScatter, c2p_map.at(c_logical_id), c_logical_id); - } else { - NVF_THROW("Unsupported expression: ", expr->toString()); } } + + if (!communication_info.has_value()) { + create_communication_info(CommunicationType::Broadcast, nullptr, nullptr); + } return communication_info; } diff --git a/csrc/preseg_passes/insert_reshardings.cpp b/csrc/preseg_passes/insert_reshardings.cpp index aed98114922..8199f02fbb3 100644 --- a/csrc/preseg_passes/insert_reshardings.cpp +++ b/csrc/preseg_passes/insert_reshardings.cpp @@ -10,6 +10,7 @@ #include #include #include +#include #include #include #include @@ -335,6 +336,17 @@ void InsertReshardingsPass::runPass(Fusion* fusion) { // insertReshardingSetsBefore is used. insertReshardingSetsAfter(fusion); insertReshardingSetsBefore(fusion); + + // Validate + for (Expr* e : fusion->exprs()) { + if (isResharding(e)) { + NVF_ERROR( + getCommunicationInfo(e).has_value(), + "After decomposition, any resharding expression is expected to be a " + "lowerable communication: ", + e); + } + } } } // namespace nvfuser::preseg_passes From 02bf666098b1c46eb97918985e8f940b52a781fd Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Sat, 7 Jun 2025 15:35:57 -0700 Subject: [PATCH 5/8] Fix callers --- csrc/host_ir/lower_to_communication.cpp | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/csrc/host_ir/lower_to_communication.cpp b/csrc/host_ir/lower_to_communication.cpp index 3590c51a109..aa85e912b19 100644 --- a/csrc/host_ir/lower_to_communication.cpp +++ b/csrc/host_ir/lower_to_communication.cpp @@ -421,6 +421,16 @@ Layout getCommunicationLayout( const CommunicationType type, IterDomain* sharded_id) { const Layout layout = canonicalizeLayout(tv)->contiguous(); + // Reduction axis in reduce/allreduce does not have to be outermost in + // allocation domain. Nonetheless, `tv` still needs to be contiguous and + // therefore .contiguous() at the beginning of this function. + if (type == CommunicationType::Reduce || + type == CommunicationType::Allreduce || + type == CommunicationType::Broadcast || + type == CommunicationType::SendRecv) { + return layout; + } + const int64_t sharded_id_pos = posInDomain(layout.allocation_domain(), sharded_id); NVF_ERROR( @@ -430,14 +440,6 @@ Layout getCommunicationLayout( ") not found in the allocation domain of the tensor view: ", tv); - // Reduction axis in reduce/allreduce does not have to be outermost in - // allocation domain. Nonetheless, `tv` still needs to be contiguous and - // therefore .contiguous() at the beginning of this function. - if (type == CommunicationType::Reduce || - type == CommunicationType::Allreduce) { - return layout; - } - if (isLocalSizeOne(sharded_id)) { // Parallelized dimension, broadcast, and reduction do not affect // allocation. From 705c1ff435bf55f3c01b4858f2aa360a11835611 Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Sat, 7 Jun 2025 18:05:18 -0700 Subject: [PATCH 6/8] Comments --- csrc/host_ir/lower_to_communication.cpp | 33 ++++++++++++++----------- csrc/host_ir/lower_to_communication.h | 10 +++----- 2 files changed, 22 insertions(+), 21 deletions(-) diff --git a/csrc/host_ir/lower_to_communication.cpp b/csrc/host_ir/lower_to_communication.cpp index aa85e912b19..ded53637045 100644 --- a/csrc/host_ir/lower_to_communication.cpp +++ b/csrc/host_ir/lower_to_communication.cpp @@ -421,9 +421,9 @@ Layout getCommunicationLayout( const CommunicationType type, IterDomain* sharded_id) { const Layout layout = canonicalizeLayout(tv)->contiguous(); - // Reduction axis in reduce/allreduce does not have to be outermost in - // allocation domain. Nonetheless, `tv` still needs to be contiguous and - // therefore .contiguous() at the beginning of this function. + // For the following communication types, the sharded_id does not have to be + // outermost in allocation domain. Nonetheless, `tv` still needs to be + // contiguous and therefore .contiguous() at the beginning of this function. if (type == CommunicationType::Reduce || type == CommunicationType::Allreduce || type == CommunicationType::Broadcast || @@ -489,19 +489,19 @@ bool isCommunicationLayoutCompliant(Expr* expr) { } std::vector convertSingleOpToCommunication( - Expr* c, + Expr* e, DeviceIdxType my_device_idx, const CommunicatorBackend backend) { - FusionGuard fg(c->fusion()); + FusionGuard fg(e->fusion()); std::vector comms; NVF_ERROR( - c->inputs().size() == 1 && c->input(0)->isA() && - c->outputs().size() == 1 && c->output(0)->isA(), + e->inputs().size() == 1 && e->input(0)->isA() && + e->outputs().size() == 1 && e->output(0)->isA(), "Input/Output must be single TensorView: ", - c); - auto* input_tv = c->input(0)->as(); - auto* output_tv = c->output(0)->as(); + e); + auto* input_tv = e->input(0)->as(); + auto* output_tv = e->output(0)->as(); input_tv->setMemoryType(MemoryType::Global); output_tv->setMemoryType(MemoryType::Global); @@ -516,13 +516,12 @@ std::vector convertSingleOpToCommunication( isSharded(output_tv) && receiver_mesh.size() > 1; NVF_ERROR( - isCommunicationLayoutCompliant(c), + isCommunicationLayoutCompliant(e), "Resharding on an inner axis is not lowerable ", - c->toString()); - bool is_reduction = c->isA(); + e->toString()); - if (is_reduction) { - BinaryOpType op_type = c->as()->getReductionOpType(); + if (auto* reduce = dynamic_cast(e)) { + BinaryOpType op_type = reduce->getReductionOpType(); NVF_ERROR( is_input_sharded || sender_mesh.size() == 1, "the comm input must be sharded in case of reduce.", @@ -545,6 +544,10 @@ std::vector convertSingleOpToCommunication( } } } else { + NVF_ERROR( + e->isA(), + "Expected a LoadStoreOp or a ReductionOp, but got: ", + e); if (!is_input_sharded && is_output_sharded) { lowerToScatter(input_tv, output_tv, backend, comms); } else if (is_input_sharded && !is_output_sharded) { diff --git a/csrc/host_ir/lower_to_communication.h b/csrc/host_ir/lower_to_communication.h index 9f6c4586c89..ab75b9dd726 100644 --- a/csrc/host_ir/lower_to_communication.h +++ b/csrc/host_ir/lower_to_communication.h @@ -33,12 +33,10 @@ struct CommunicationInfo { // Composite expressions that are communication + compute are not supported. bool isCommunicationLayoutCompliant(Expr* expr); -// Returns the communication info for the -// (All)Gather/Scatter/ReduceScatter/(All)Reduce communication that may require -// copying the input/output and reordering the allocation domain. -// We assume that the expr has been decomposed and represented a single -// communication. If multiple communications are present, this function will -// raise an error. +// Given an Expr that's known to be a communication, returns the communication +// info: type and sharded IDs. We assume that the expr has been decomposed and +// represented a single communication. If multiple communications are present or +// 2D sharding, this function will raise an error. std::optional getCommunicationInfo(Expr* expr); // Given the input/output TensorView of a communication, returns its layout From b50210f45d23ceffeabdb04f450bb92423d84e46 Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Sat, 7 Jun 2025 18:26:05 -0700 Subject: [PATCH 7/8] Cleanups after #4552 --- csrc/host_ir/lower_to_communication.cpp | 16 +++++----------- csrc/host_ir/lower_to_communication.h | 2 +- csrc/preseg_passes/insert_reshardings.cpp | 6 +----- csrc/preseg_passes/reorder_sharded_axis.cpp | 17 +++-------------- 4 files changed, 10 insertions(+), 31 deletions(-) diff --git a/csrc/host_ir/lower_to_communication.cpp b/csrc/host_ir/lower_to_communication.cpp index ded53637045..674e1aaa6f6 100644 --- a/csrc/host_ir/lower_to_communication.cpp +++ b/csrc/host_ir/lower_to_communication.cpp @@ -301,7 +301,7 @@ bool isLocalSizeOne(IterDomain* id) { } // namespace -std::optional getCommunicationInfo(Expr* expr) { +CommunicationInfo getCommunicationInfo(Expr* expr) { NVF_ERROR( expr->isA() || expr->isA(), "getCommunicationInfo should only be called when `expr` is known to be a " @@ -403,7 +403,7 @@ std::optional getCommunicationInfo(Expr* expr) { if (!communication_info.has_value()) { create_communication_info(CommunicationType::Broadcast, nullptr, nullptr); } - return communication_info; + return *communication_info; } namespace { @@ -465,22 +465,16 @@ bool isCommunicationLayoutCompliant(Expr* expr) { auto* producer = expr->inputs().at(0)->as(); auto* consumer = expr->outputs().at(0)->as(); - // TODO(#4552): the caller should make sure Expr is a communication so - // getCommunicationInfo always returns a valid CommunicationInfo. Retry after - // #4552 is merged. - auto communication_info = getCommunicationInfo(expr); - if (!communication_info.has_value()) { - return true; - } + CommunicationInfo communication_info = getCommunicationInfo(expr); Layout p_layout = getCommunicationLayout( - producer, communication_info->type, communication_info->p_sharded_id); + producer, communication_info.type, communication_info.p_sharded_id); if (!isCompliantWith(*canonicalizeLayout(producer), p_layout)) { return false; } Layout c_layout = getCommunicationLayout( - consumer, communication_info->type, communication_info->c_sharded_id); + consumer, communication_info.type, communication_info.c_sharded_id); if (!isCompliantWith(*canonicalizeLayout(consumer), c_layout)) { return false; } diff --git a/csrc/host_ir/lower_to_communication.h b/csrc/host_ir/lower_to_communication.h index ab75b9dd726..774e7211a11 100644 --- a/csrc/host_ir/lower_to_communication.h +++ b/csrc/host_ir/lower_to_communication.h @@ -37,7 +37,7 @@ bool isCommunicationLayoutCompliant(Expr* expr); // info: type and sharded IDs. We assume that the expr has been decomposed and // represented a single communication. If multiple communications are present or // 2D sharding, this function will raise an error. -std::optional getCommunicationInfo(Expr* expr); +CommunicationInfo getCommunicationInfo(Expr* expr); // Given the input/output TensorView of a communication, returns its layout // required by the communication backend (e.g. NCCL or UCC). `sharded_id` is the diff --git a/csrc/preseg_passes/insert_reshardings.cpp b/csrc/preseg_passes/insert_reshardings.cpp index 8199f02fbb3..48726aceda7 100644 --- a/csrc/preseg_passes/insert_reshardings.cpp +++ b/csrc/preseg_passes/insert_reshardings.cpp @@ -340,11 +340,7 @@ void InsertReshardingsPass::runPass(Fusion* fusion) { // Validate for (Expr* e : fusion->exprs()) { if (isResharding(e)) { - NVF_ERROR( - getCommunicationInfo(e).has_value(), - "After decomposition, any resharding expression is expected to be a " - "lowerable communication: ", - e); + getCommunicationInfo(e); } } } diff --git a/csrc/preseg_passes/reorder_sharded_axis.cpp b/csrc/preseg_passes/reorder_sharded_axis.cpp index e174b501e67..b96a8668cab 100644 --- a/csrc/preseg_passes/reorder_sharded_axis.cpp +++ b/csrc/preseg_passes/reorder_sharded_axis.cpp @@ -22,12 +22,11 @@ namespace nvfuser::preseg_passes { namespace { -void makeCommunicationLayoutCompliant( - Expr* expr, - CommunicationInfo communication_info) { +void makeCommunicationLayoutCompliant(Expr* expr) { auto* input = expr->inputs().at(0)->as(); auto* output = expr->outputs().at(0)->as(); + CommunicationInfo communication_info = getCommunicationInfo(expr); IterDomain* p_sharded_id = communication_info.p_sharded_id; IterDomain* c_sharded_id = communication_info.c_sharded_id; @@ -81,17 +80,7 @@ void ReorderShardedAxisPass::runPass(Fusion* fusion) { continue; } - auto communication_info = getCommunicationInfo(expr); - // Should really be simply NVF_ERROR(communication_info.has_value()); - // - // I'll try to do that after #4552 is merged. Some of the `mesh.size() > 1` - // check in getCommunicationInfo and convertSingleOpToCommuniation will also - // need to go away for this simplification. - if (!communication_info.has_value()) { - continue; - } - - makeCommunicationLayoutCompliant(expr, *communication_info); + makeCommunicationLayoutCompliant(expr); } if (isDebugDumpEnabled(DebugDumpOption::PreSegmenterLogging)) { From 3f357427e5cc5962603ab3728a8d4991e687ef9f Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Sat, 7 Jun 2025 18:30:34 -0700 Subject: [PATCH 8/8] Comment --- csrc/host_ir/lower_to_communication.cpp | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/csrc/host_ir/lower_to_communication.cpp b/csrc/host_ir/lower_to_communication.cpp index 674e1aaa6f6..99ffc9a24f3 100644 --- a/csrc/host_ir/lower_to_communication.cpp +++ b/csrc/host_ir/lower_to_communication.cpp @@ -313,9 +313,11 @@ CommunicationInfo getCommunicationInfo(Expr* expr) { auto* consumer = expr->outputs().at(0)->as(); std::optional communication_info = std::nullopt; - auto create_communication_info = [&](CommunicationType type, - IterDomain* p_sharded_id, - IterDomain* c_sharded_id) { + // Fill `communication_info` instead of returning the result, so we can catch + // errors when more than one DIDs have sharding changes. + auto fill_communication_info = [&](CommunicationType type, + IterDomain* p_sharded_id, + IterDomain* c_sharded_id) { NVF_ERROR( !communication_info.has_value(), "Expected at most one sharding change"); @@ -352,17 +354,17 @@ CommunicationInfo getCommunicationInfo(Expr* expr) { IterDomain* p_logical_id = getLogicalFromLoopId(producer, p_loop_did); CommunicationType type = same_mesh ? CommunicationType::Allgather : CommunicationType::Gather; - create_communication_info(type, p_logical_id, p2c_map.at(p_logical_id)); + fill_communication_info(type, p_logical_id, p2c_map.at(p_logical_id)); } if (!p_sharded && c_sharded) { IterDomain* c_logical_id = getLogicalFromLoopId(consumer, c_loop_did); - create_communication_info( + fill_communication_info( CommunicationType::Scatter, c2p_map.at(c_logical_id), c_logical_id); } if (p_sharded && c_sharded) { IterDomain* p_logical_id = getLogicalFromLoopId(producer, p_loop_did); IterDomain* c_logical_id = getLogicalFromLoopId(consumer, c_loop_did); - create_communication_info( + fill_communication_info( CommunicationType::SendRecv, p_logical_id, c_logical_id); } } else { @@ -376,7 +378,7 @@ CommunicationInfo getCommunicationInfo(Expr* expr) { IterDomain* p_logical_id = getLogicalFromLoopId(producer, p_loop_did); CommunicationType type = same_mesh ? CommunicationType::Allreduce : CommunicationType::Reduce; - create_communication_info(type, p_logical_id, p2c_map.at(p_logical_id)); + fill_communication_info(type, p_logical_id, p2c_map.at(p_logical_id)); continue; } @@ -393,7 +395,7 @@ CommunicationInfo getCommunicationInfo(Expr* expr) { if (!c_it->second->isReduction()) { continue; } - create_communication_info( + fill_communication_info( CommunicationType::ReduceScatter, c2p_map.at(c_logical_id), c_logical_id); @@ -401,7 +403,7 @@ CommunicationInfo getCommunicationInfo(Expr* expr) { } if (!communication_info.has_value()) { - create_communication_info(CommunicationType::Broadcast, nullptr, nullptr); + fill_communication_info(CommunicationType::Broadcast, nullptr, nullptr); } return *communication_info; }