diff --git a/CMakeLists.txt b/CMakeLists.txt index 49ad3a7ee3f..ab8d4c427fa 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -150,6 +150,7 @@ list(APPEND NVFUSER_SRCS ${NVFUSER_SRCS_DIR}/host_ir/executor.cpp ${NVFUSER_SRCS_DIR}/host_ir/host_ir.cpp ${NVFUSER_SRCS_DIR}/host_ir/lower.cpp + ${NVFUSER_SRCS_DIR}/host_ir/lower_to_communication.cpp ${NVFUSER_SRCS_DIR}/id_model/circular_buffer_indexing.cpp ${NVFUSER_SRCS_DIR}/id_model/contiguity.cpp ${NVFUSER_SRCS_DIR}/id_model/id_model.cpp @@ -216,6 +217,7 @@ list(APPEND NVFUSER_SRCS ${NVFUSER_SRCS_DIR}/preseg_passes/remove_empty.cpp ${NVFUSER_SRCS_DIR}/preseg_passes/reorder_sharded_axis.cpp ${NVFUSER_SRCS_DIR}/preseg_passes/segment_inplace_update.cpp + ${NVFUSER_SRCS_DIR}/host_ir/pass/convert_op_to_communication.cpp ${NVFUSER_SRCS_DIR}/host_ir/pass/stream_parallel_type.cpp ${NVFUSER_SRCS_DIR}/host_ir/pass/insert_deallocations.cpp ${NVFUSER_SRCS_DIR}/preseg_passes/translate_no_reduction_matmul_to_mul_squeeze.cpp diff --git a/csrc/host_ir/executor.cpp b/csrc/host_ir/executor.cpp index 39a5691ca97..c73822320fb 100644 --- a/csrc/host_ir/executor.cpp +++ b/csrc/host_ir/executor.cpp @@ -12,6 +12,8 @@ #include #include #include +#include +#include #include #include #include @@ -73,9 +75,8 @@ void HostIrExecutor::compile(Fusion* fusion) { std::vector exprs = fusion->exprs(); DeviceIdxType my_device_idx = communicator_ ? 
communicator_->deviceId() : 0; for (Expr* e : exprs) { - HostIrLower lower; - std::vector communications = - lower.lower(cloner.clone(e), my_device_idx); + std::vector communications = convertSingleOpToCommunication( + cloner.clone(e), my_device_idx, HostIrLowerParams()); for (auto* communication : communications) { host_ir_container_->pushBackTopLevelExprs(communication); } diff --git a/csrc/host_ir/lower.cpp b/csrc/host_ir/lower.cpp index ca9bb80ae4e..e628cdb4f73 100644 --- a/csrc/host_ir/lower.cpp +++ b/csrc/host_ir/lower.cpp @@ -7,6 +7,8 @@ // clang-format on #include #include +#include +#include #include #include #include @@ -25,369 +27,6 @@ namespace nvfuser { -namespace { - -// TODO: handle `c10d::RedOpType::reduceOp::AVG` and -// `c10d::RedOpType::reduceOp::PREMUL_SUM` -inline c10d::ReduceOp::RedOpType getC10dReduceOpType(BinaryOpType op) { - switch (op) { - case BinaryOpType::Add: - return c10d::ReduceOp::RedOpType::SUM; - case BinaryOpType::Mul: - return c10d::ReduceOp::RedOpType::PRODUCT; - case BinaryOpType::Min: - return c10d::ReduceOp::RedOpType::MIN; - case BinaryOpType::Max: - return c10d::ReduceOp::RedOpType::MAX; - case BinaryOpType::BitwiseAnd: - return c10d::ReduceOp::RedOpType::BAND; - case BinaryOpType::BitwiseOr: - return c10d::ReduceOp::RedOpType::BOR; - case BinaryOpType::BitwiseXor: - return c10d::ReduceOp::RedOpType::BXOR; - default: - NVF_THROW("unsupported reduction operation"); - return c10d::ReduceOp::RedOpType::UNUSED; - } -} - -// Adds one or zero Scatter communication to the vector 'comms' -void lowerToScatter( - TensorView* input_tv, - TensorView* output_tv, - const HostIrLowerParams& params, - std::vector& comms) { - // we arbitrarily choose the first device of the sender mesh to be the root - const DeviceMesh& receiver_mesh = output_tv->getDeviceMesh(); - NVF_ERROR( - receiver_mesh.rank() == 1, - "Gather only supported on a 1D mesh. 
Given ", - receiver_mesh); - auto root = input_tv->getDeviceMesh().at(0); - Team team = receiver_mesh.vector(); - if (!receiver_mesh.has(root)) { - team.push_back(root); - } - comms.push_back(IrBuilder::create( - CommunicationType::Scatter, - output_tv, - input_tv, - team, - root, - c10d::ReduceOp::RedOpType::UNUSED, - /*scatter_axis=*/-1, - params.communicator_backend)); -} - -/* -Adds zero or multiple Gather communications to the vector 'comms' - -Note that since the root of a Gather collective is a destination, we possibly -need multiple Gathers if the tensor is replicated in the receiver mesh. -*/ -void lowerToGather( - TensorView* input_tv, - TensorView* output_tv, - const HostIrLowerParams& params, - std::vector& comms) { - // we create as many 'Gathers' as there are devices in the receiver mesh - const DeviceMesh& sender_mesh = input_tv->getDeviceMesh(); - NVF_ERROR( - sender_mesh.rank() == 1, - "Currently only lower Gather on a 1D mesh. Given ", - sender_mesh); - for (auto root : output_tv->getDeviceMesh().vector()) { - Team team = sender_mesh.vector(); - if (!sender_mesh.has(root)) { - team.push_back(root); - } - comms.push_back(IrBuilder::create( - CommunicationType::Gather, - output_tv, - input_tv, - team, - root, - c10d::ReduceOp::RedOpType::UNUSED, - /*scatter_axis=*/-1, - params.communicator_backend)); - } -} - -// Add one or zero Allgather communication to the vector 'comms' -void lowerToAllgather( - TensorView* input_tv, - TensorView* output_tv, - const HostIrLowerParams& params, - std::vector& comms, - DeviceIdxType my_device_idx) { - const DeviceMesh& mesh = input_tv->getDeviceMesh(); - Team team = mesh.getSlice(my_device_idx, ParallelType::DIDx); - comms.push_back(IrBuilder::create( - CommunicationType::Allgather, - output_tv, - input_tv, - team, - /*root=*/-1, - c10d::ReduceOp::RedOpType::UNUSED, - /*scatter_axis=*/-1, - params.communicator_backend)); -} - -// Adds one or zero Broadcast communication to the vector 'comms' -void lowerToBroadcast( 
- TensorView* input_tv, - TensorView* output_tv, - DeviceIdxType root, - const HostIrLowerParams& params, - std::vector& comms) { - const DeviceMesh& mesh = output_tv->getDeviceMesh(); - NVF_ERROR( - mesh.rank() == 1, "Broadcast only supported a 1D mesh. Given ", mesh); - Team team = mesh.vector(); - if (!mesh.has(root)) { - team.push_back(root); - } - comms.push_back(IrBuilder::create( - CommunicationType::Broadcast, - output_tv, - input_tv, - team, - root, - c10d::ReduceOp::RedOpType::UNUSED, - /*scatter_axis=*/-1, - params.communicator_backend)); -} - -// Adds several Broadcast or SendRecv communications to the vector 'comms' -// For now, we assume that this function is called only if -// the input and output have the same sharding. Later we could support more -// general cases. -void lowerToBroadcastOrSendRecv( - TensorView* input_tv, - TensorView* output_tv, - const HostIrLowerParams& params, - std::vector& comms) { - const DeviceMesh& sender_mesh = input_tv->getDeviceMesh(); - const DeviceMesh& receiver_mesh = output_tv->getDeviceMesh(); - NVF_ERROR( - sender_mesh.rank() == 1, - "Broadcast only supported a 1D mesh. Given ", - sender_mesh); - NVF_ERROR( - receiver_mesh.rank() == 1, - "Broadcast only supported a 1D mesh. 
Given ", - receiver_mesh); - if (isSharded(input_tv) && sender_mesh.size() > 1) { - // if the inputs and ouputs are parallelized, - // we create as many Broadcast as that will be handled in parallel - NVF_ERROR( - sender_mesh.size() == receiver_mesh.size(), - "the receiver and sender meshes have different sizes: ", - sender_mesh.size(), - " vs ", - receiver_mesh.size()); - for (auto i : arange(sender_mesh.size())) { - const DeviceIdxType sender = sender_mesh.at(i); - const DeviceIdxType receiver = receiver_mesh.at(i); - comms.push_back(IrBuilder::create( - CommunicationType::SendRecv, - output_tv, - input_tv, - Team({sender, receiver}), - /*root=*/sender, - c10d::ReduceOp::RedOpType::UNUSED, - /*scatter_axis=*/-1, - params.communicator_backend)); - } - } else { - // Either of the following two cases is happening. - // 1. `sender_mesh` contains only one device. In this case, we broadcast - // from that device. - // 2. `sender_mesh` contains multiple devices but the input is not sharded. - // In this case, we arbitrarily choose the first device of the sender mesh - // to be the root. - lowerToBroadcast( - input_tv, - output_tv, - /*root=*/sender_mesh.at(0), - params, - comms); - } -} - -void lowerToReduce( - TensorView* input_tv, - TensorView* output_tv, - BinaryOpType op_type, - const HostIrLowerParams& params, - std::vector& comms) { - const DeviceMesh& receiver_mesh = output_tv->getDeviceMesh(); - const DeviceMesh& sender_mesh = input_tv->getDeviceMesh(); - NVF_ERROR( - sender_mesh.rank() == 1, - "Reduce only supported a 1D mesh. Given ", - sender_mesh); - NVF_ERROR( - receiver_mesh.rank() == 1, - "Reduce only supported a 1D mesh. 
Given ", - receiver_mesh); - const auto reduce_op_type = getC10dReduceOpType(op_type); - // we create as many Reduces as there are devices in the receiver mesh - for (auto root : receiver_mesh.vector()) { - Team team = sender_mesh.vector(); - if (!sender_mesh.has(root)) { - team.push_back(root); - } - comms.push_back(IrBuilder::create( - CommunicationType::Reduce, - output_tv, - input_tv, - team, - root, - reduce_op_type, - /*scatter_axis=*/-1, - params.communicator_backend)); - } -} - -void lowerToAllreduce( - TensorView* input_tv, - TensorView* output_tv, - BinaryOpType op_type, - const HostIrLowerParams& params, - std::vector& comms, - DeviceIdxType my_device_idx) { - const DeviceMesh& mesh = input_tv->getDeviceMesh(); - Team team = mesh.getSlice(my_device_idx, ParallelType::DIDx); - comms.push_back(IrBuilder::create( - CommunicationType::Allreduce, - output_tv, - input_tv, - team, - /*root=*/-1, - getC10dReduceOpType(op_type), - /*scatter_axis=*/-1, - params.communicator_backend)); -} - -void lowerToReduceScatter( - TensorView* input_tv, - TensorView* output_tv, - BinaryOpType op_type, - const HostIrLowerParams& params, - std::vector& comms, - DeviceIdxType my_device_idx) { - const DeviceMesh& mesh = input_tv->getDeviceMesh(); - Team team = mesh.getSlice(my_device_idx, ParallelType::DIDx); - auto reduction_axis = output_tv->getReductionAxis().value(); - auto scattered_axis = getShardedLogicalAxis(output_tv, ParallelType::DIDx); - // The output tensor is sharded on scattered_axis and needs to be mapped - // back onto the input. The input has an reduced axis, so the scattered axis - // is adjusted to account for this. Ex: [DIDx(i0), i1] -> [r0, DIDx(i1)] The - // scattered_axis is axis=0 on the output and maps to axis=1 on the input. 
- if (reduction_axis <= scattered_axis) { - scattered_axis++; - } - - comms.push_back(IrBuilder::create( - CommunicationType::ReduceScatter, - output_tv, - input_tv, - /*team=*/team, - /*root=*/-1, - getC10dReduceOpType(op_type), - scattered_axis, - params.communicator_backend)); -} - -} // namespace - -/* -TODO: -*) Propose several lowering paths for each given communication - and provide a logic to decide which path to take -*) Leverage replication in the source to create several communications handled - in parallel. The idea would be to evenly split the destinations accross the - sources -*) Leverage the topology to ensure that the senders and recerivers are close -*/ -std::vector HostIrLower::lower(Expr* c, DeviceIdxType my_device_idx) { - FusionGuard fg(c->fusion()); - - if (c->isOneOf()) { - return lowerToCollectiveBasedPipelinedGemmComm(c); - } - - std::vector comms; - NVF_ERROR( - c->inputs().size() == 1 && c->input(0)->isA() && - c->outputs().size() == 1 && c->output(0)->isA(), - "Input/Output must be single TensorView: ", - c); - auto* input_tv = c->input(0)->as(); - auto* output_tv = c->output(0)->as(); - - input_tv->setMemoryType(MemoryType::Global); - output_tv->setMemoryType(MemoryType::Global); - - const DeviceMesh& sender_mesh = input_tv->getDeviceMesh(); - const DeviceMesh& receiver_mesh = output_tv->getDeviceMesh(); - const bool same_mesh = sender_mesh == receiver_mesh; - - // Stores whether the I/O has its first axis parallelized on DIDx - const bool is_input_sharded = isSharded(input_tv) && sender_mesh.size() > 1; - const bool is_output_sharded = - isSharded(output_tv) && receiver_mesh.size() > 1; - - NVF_ERROR( - HostIrLower::canLower(c), - "Lowering expression ", - c->toString(), - " to communication is not supported"); - NVF_ERROR( - !isInnerResharding(c), - "Resharding on an inner axis is not lowerable ", - c->toString()); - bool is_reduction = c->isA(); - - if (is_reduction) { - BinaryOpType op_type = c->as()->getReductionOpType(); - 
NVF_ERROR( - is_input_sharded || sender_mesh.size() == 1, - "the comm input must be sharded in case of reduce.", - "Insert a `set` before the reduction to reshard") - if (is_output_sharded) { - NVF_ERROR( - same_mesh, - "ReduceScatter operation must have the same sender and receiver device mesh. " - "Insert a Set operation before or after the reduction to reshard ot another device mesh"); - lowerToReduceScatter( - input_tv, output_tv, op_type, params_, comms, my_device_idx); - } else { - if (same_mesh) { - lowerToAllreduce( - input_tv, output_tv, op_type, params_, comms, my_device_idx); - } else { - lowerToReduce(input_tv, output_tv, op_type, params_, comms); - } - } - } else { - if (!is_input_sharded && is_output_sharded) { - lowerToScatter(input_tv, output_tv, params_, comms); - } else if (is_input_sharded && !is_output_sharded) { - if (same_mesh) { - lowerToAllgather(input_tv, output_tv, params_, comms, my_device_idx); - } else { - lowerToGather(input_tv, output_tv, params_, comms); - } - } else { - lowerToBroadcastOrSendRecv(input_tv, output_tv, params_, comms); - } - } - - return comms; -} - bool HostIrLower::canLower(Expr* expr, bool ignore_inner_resharding) { if (!isResharding(expr)) { return true; @@ -448,173 +87,6 @@ bool HostIrLower::canLower(Expr* expr, bool ignore_inner_resharding) { return false; } -std::vector HostIrLower::lowerToCollectiveBasedPipelinedGemmComm( - Expr* expr) { - NVF_ERROR( - (expr->isOneOf()), - "Expect a MatmulOp or a LinearOp, but got", - expr); - TensorView* tva = nullptr; - TensorView* tvb = nullptr; - TensorView* tv_bias = nullptr; - TensorView* tv_out = nullptr; - if (auto* matmul = dynamic_cast(expr)) { - tva = matmul->inA(); - tvb = matmul->inB(); - tv_out = matmul->out(); - } else { - auto* linear = expr->as(); - tva = linear->inA()->as(); - tvb = linear->inB()->as(); - tv_bias = (linear->has_bias() ? 
linear->bias()->as() : nullptr); - tv_out = linear->out()->as(); - NVF_ERROR( - !(linear->has_bias() && isSharded(tv_bias)), - "The bias ", - tv_bias, - " is expected to not be sharded"); - } - - NVF_ERROR( - !isSharded(tvb), "The B operand ", tvb, " is expected to not be sharded"); - NVF_ERROR( - !isSharded(tv_out), - "The output ", - tv_out, - " is expected to not be sharded"); - NVF_ERROR( - tv_out->axis(0)->getParallelType() == ParallelType::Stream, - "The output ", - tv_out, - " is expected to be stream-parallelized on axis 0"); - const int64_t sharded_axis_index = - getShardedLogicalAxis(tva, ParallelType::DIDx); - IterDomain* stream_axis = tva->axis(0); - NVF_ERROR( - stream_axis->getParallelType() == ParallelType::Serial && - sharded_axis_index == 1, - "The operand A ", - tva, - " is expected to be sharded on the dimension 1"); - - auto hic = FusionGuard::getCurFusion()->as(); - - auto* get_current_stream = IrBuilder::create(); - hir::Stream* original_stream = get_current_stream->stream(); - - TensorView* tva_allgathered = - ops::newValLike(tva, tva->dtype())->as(); - tva_allgathered->axis(sharded_axis_index)->parallelize(ParallelType::Serial); - tva_allgathered->setMemoryType(MemoryType::Global); - auto* allocate_tva_allgathered = - IrBuilder::create(tva_allgathered, MemoryType::Global); - - tv_out->setMemoryType(MemoryType::Global); - auto* allocate_tv_out = - IrBuilder::create(tv_out, MemoryType::Global); - - auto* j = - IrBuilder::create(DataType::Index); // running index of the for-loop - auto* start = hic->zeroVal(); - auto* stop = stream_axis->extent(); - auto* step = hic->oneVal(); - auto* for_loop_initial_sync = IrBuilder::create( - stream_axis, - /*index=*/j, - start, - stop, - step, - /*vectorize=*/false, - /*vectorize_shift=*/nullptr, - /*unroll_required=*/false, - CircularBufferLoopStage::NotApplicable, - /*circular_buffer_loop_stage_depth=*/0); - - auto* number_of_streams = - IrBuilder::create("numberOfStreams", DataType::Int); - auto* 
stream_index = mod(j, number_of_streams); - auto* stream = IrBuilder::create(stream_index); - auto* set_stream = IrBuilder::create(stream); - auto* initial_sync_stream = - IrBuilder::create(original_stream); - - // the initial sync of the streams with the user's stream is done in a - // separate for-loop for performance reasons with comms/compute overlap - std::vector loop_body_initial_sync = {set_stream, initial_sync_stream}; - for (Expr* expr : loop_body_initial_sync) { - for_loop_initial_sync->body().push_back(expr); - } - - auto* for_loop = IrBuilder::create( - stream_axis, - /*index=*/j, - start, - stop, - step, - /*vectorize=*/false, - /*vectorize_shift=*/nullptr, - /*unroll_required=*/false, - CircularBufferLoopStage::NotApplicable, - /*circular_buffer_loop_stage_depth=*/0); - - TensorView* tva_j = select(tva, 0, j); - TensorView* tva_allgathered_j = select(tva_allgathered, 0, j); - TensorView* tv_out_j = select(tv_out, 0, j); - - NVF_ERROR( - tva->hasDeviceMesh(), - "The matmul's input ", - tva, - "is expected to have a DeviceMesh"); - for (auto tv : {tva_j, tva_allgathered_j, tv_out_j}) { - tv->setDeviceMesh(tva->getDeviceMesh()); - } - - auto* communication = IrBuilder::create( - CommunicationType::Allgather, - /*out=*/tva_allgathered_j, - /*in=*/tva_j, - /*team=*/tva->getDeviceMesh().vector(), - /*root=*/-1, - /*red_op=*/RedOpType::UNUSED, - /*scattered_axis=*/-1, - params_.communicator_backend); - auto* wait = IrBuilder::create(communication); - - Expr* compute = nullptr; - if (expr->isA()) { - compute = IrBuilder::create(tv_out_j, tva_allgathered_j, tvb); - } else { - compute = - IrBuilder::create(tv_out_j, tva_allgathered_j, tvb, tv_bias); - } - - auto* set_back_original_stream = - IrBuilder::create(original_stream); - auto* sync_stream = IrBuilder::create(stream); - - std::vector loop_body = { - set_stream, - tva_j->definition(), - tva_allgathered_j->definition(), - communication, - wait, - tv_out_j->definition(), - compute, - 
set_back_original_stream, - sync_stream}; - for (Expr* expr : loop_body) { - for_loop->body().push_back(expr); - } - - return { - get_current_stream, - allocate_tva_allgathered, - allocate_tv_out, - for_loop_initial_sync, - for_loop}; -} - bool HostIrLower::isLowerableAsStandaloneHostOp(Expr* expr) { if (expr->isOneOf< MatmulOp, @@ -743,33 +215,9 @@ std::unique_ptr HostIrLower::lower( tv->setMemoryType(MemoryType::Global); } - std::vector new_top_level_exprs; - for (auto top_level_expr : hic->topLevelExprs()) { - if (!isResharding(top_level_expr)) { - new_top_level_exprs.push_back(top_level_expr); - continue; - } - for (auto* expr : HostIrLower::lower(top_level_expr, my_device_index)) { - // Allocate the recv buffers of communications - if (expr->isA()) { - auto* communication = expr->as(); - TensorView* tv = communication->out(); - if (tv->getDeviceMesh().has(my_device_index)) { - auto* allocate = - IrBuilder::create(tv, MemoryType::Global); - new_top_level_exprs.push_back(allocate); - } - } - new_top_level_exprs.push_back(expr); - if (expr->isA()) { - auto wait = IrBuilder::create(expr->as()); - new_top_level_exprs.push_back(wait); - } - } - } - hic->resetTopLevelExprs(new_top_level_exprs); + hir_pass::StreamParallelType().runPass(hic.get()); - preseg_passes::OptimizationPass::runPass(hic.get()); + hir_pass::ConvertOpToCommunication(params_).runPass(hic.get()); return hic; } diff --git a/csrc/host_ir/lower.h b/csrc/host_ir/lower.h index 5b1ecadece8..8df156d4512 100644 --- a/csrc/host_ir/lower.h +++ b/csrc/host_ir/lower.h @@ -43,8 +43,16 @@ class HostIrLower { SegmentedGroup* group2); private: - std::vector lowerToCollectiveBasedPipelinedGemmComm(Expr* expr); const HostIrLowerParams params_; }; +namespace hir_pass { + +std::vector convertSingleOpToCommunication( + Expr* c, + DeviceIdxType my_device_idx, + const HostIrLowerParams& params); + +} // namespace hir_pass + } // namespace nvfuser diff --git a/csrc/host_ir/lower_to_communication.cpp 
b/csrc/host_ir/lower_to_communication.cpp new file mode 100644 index 00000000000..49e796a679e --- /dev/null +++ b/csrc/host_ir/lower_to_communication.cpp @@ -0,0 +1,550 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2024-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace nvfuser { + +namespace { + +// TODO: handle `c10d::RedOpType::reduceOp::AVG` and +// `c10d::RedOpType::reduceOp::PREMUL_SUM` +inline c10d::ReduceOp::RedOpType getC10dReduceOpType(BinaryOpType op) { + switch (op) { + case BinaryOpType::Add: + return c10d::ReduceOp::RedOpType::SUM; + case BinaryOpType::Mul: + return c10d::ReduceOp::RedOpType::PRODUCT; + case BinaryOpType::Min: + return c10d::ReduceOp::RedOpType::MIN; + case BinaryOpType::Max: + return c10d::ReduceOp::RedOpType::MAX; + case BinaryOpType::BitwiseAnd: + return c10d::ReduceOp::RedOpType::BAND; + case BinaryOpType::BitwiseOr: + return c10d::ReduceOp::RedOpType::BOR; + case BinaryOpType::BitwiseXor: + return c10d::ReduceOp::RedOpType::BXOR; + default: + NVF_THROW("unsupported reduction operation"); + return c10d::ReduceOp::RedOpType::UNUSED; + } +} + +// Adds one or zero Scatter communication to the vector 'comms' +void lowerToScatter( + TensorView* input_tv, + TensorView* output_tv, + const HostIrLowerParams& params, + std::vector& comms) { + // we arbitrarily choose the first device of the sender mesh to be the root + const DeviceMesh& receiver_mesh = output_tv->getDeviceMesh(); + NVF_ERROR( + receiver_mesh.rank() == 1, + "Gather only supported on a 1D mesh. 
Given ", + receiver_mesh); + auto root = input_tv->getDeviceMesh().at(0); + Team team = receiver_mesh.vector(); + if (!receiver_mesh.has(root)) { + team.push_back(root); + } + comms.push_back(IrBuilder::create( + CommunicationType::Scatter, + output_tv, + input_tv, + team, + root, + c10d::ReduceOp::RedOpType::UNUSED, + /*scatter_axis=*/-1, + params.communicator_backend)); +} + +/* +Adds zero or multiple Gather communications to the vector 'comms' + +Note that since the root of a Gather collective is a destination, we possibly +need multiple Gathers if the tensor is replicated in the receiver mesh. +*/ +void lowerToGather( + TensorView* input_tv, + TensorView* output_tv, + const HostIrLowerParams& params, + std::vector& comms) { + // we create as many 'Gathers' as there are devices in the receiver mesh + const DeviceMesh& sender_mesh = input_tv->getDeviceMesh(); + NVF_ERROR( + sender_mesh.rank() == 1, + "Currently only lower Gather on a 1D mesh. Given ", + sender_mesh); + for (auto root : output_tv->getDeviceMesh().vector()) { + Team team = sender_mesh.vector(); + if (!sender_mesh.has(root)) { + team.push_back(root); + } + comms.push_back(IrBuilder::create( + CommunicationType::Gather, + output_tv, + input_tv, + team, + root, + c10d::ReduceOp::RedOpType::UNUSED, + /*scatter_axis=*/-1, + params.communicator_backend)); + } +} + +// Add one or zero Allgather communication to the vector 'comms' +void lowerToAllgather( + TensorView* input_tv, + TensorView* output_tv, + const HostIrLowerParams& params, + std::vector& comms, + DeviceIdxType my_device_idx) { + const DeviceMesh& mesh = input_tv->getDeviceMesh(); + Team team = mesh.getSlice(my_device_idx, ParallelType::DIDx); + comms.push_back(IrBuilder::create( + CommunicationType::Allgather, + output_tv, + input_tv, + team, + /*root=*/-1, + c10d::ReduceOp::RedOpType::UNUSED, + /*scatter_axis=*/-1, + params.communicator_backend)); +} + +// Adds one or zero Broadcast communication to the vector 'comms' +void lowerToBroadcast( 
+ TensorView* input_tv, + TensorView* output_tv, + DeviceIdxType root, + const HostIrLowerParams& params, + std::vector& comms) { + const DeviceMesh& mesh = output_tv->getDeviceMesh(); + NVF_ERROR( + mesh.rank() == 1, "Broadcast only supported a 1D mesh. Given ", mesh); + Team team = mesh.vector(); + if (!mesh.has(root)) { + team.push_back(root); + } + comms.push_back(IrBuilder::create( + CommunicationType::Broadcast, + output_tv, + input_tv, + team, + root, + c10d::ReduceOp::RedOpType::UNUSED, + /*scatter_axis=*/-1, + params.communicator_backend)); +} + +// Adds several Broadcast or SendRecv communications to the vector 'comms' +// For now, we assume that this function is called only if +// the input and output have the same sharding. Later we could support more +// general cases. +void lowerToBroadcastOrSendRecv( + TensorView* input_tv, + TensorView* output_tv, + const HostIrLowerParams& params, + std::vector& comms) { + const DeviceMesh& sender_mesh = input_tv->getDeviceMesh(); + const DeviceMesh& receiver_mesh = output_tv->getDeviceMesh(); + NVF_ERROR( + sender_mesh.rank() == 1, + "Broadcast only supported a 1D mesh. Given ", + sender_mesh); + NVF_ERROR( + receiver_mesh.rank() == 1, + "Broadcast only supported a 1D mesh. 
Given ", + receiver_mesh); + if (isSharded(input_tv) && sender_mesh.size() > 1) { + // if the inputs and ouputs are parallelized, + // we create as many Broadcast as that will be handled in parallel + NVF_ERROR( + sender_mesh.size() == receiver_mesh.size(), + "the receiver and sender meshes have different sizes: ", + sender_mesh.size(), + " vs ", + receiver_mesh.size()); + for (auto i : c10::irange(sender_mesh.size())) { + const DeviceIdxType sender = sender_mesh.at(i); + const DeviceIdxType receiver = receiver_mesh.at(i); + comms.push_back(IrBuilder::create( + CommunicationType::SendRecv, + output_tv, + input_tv, + Team({sender, receiver}), + /*root=*/sender, + c10d::ReduceOp::RedOpType::UNUSED, + /*scatter_axis=*/-1, + params.communicator_backend)); + } + } else { + // Either of the following two cases is happening. + // 1. `sender_mesh` contains only one device. In this case, we broadcast + // from that device. + // 2. `sender_mesh` contains multiple devices but the input is not sharded. + // In this case, we arbitrarily choose the first device of the sender mesh + // to be the root. + lowerToBroadcast( + input_tv, + output_tv, + /*root=*/sender_mesh.at(0), + params, + comms); + } +} + +void lowerToReduce( + TensorView* input_tv, + TensorView* output_tv, + BinaryOpType op_type, + const HostIrLowerParams& params, + std::vector& comms) { + const DeviceMesh& receiver_mesh = output_tv->getDeviceMesh(); + const DeviceMesh& sender_mesh = input_tv->getDeviceMesh(); + NVF_ERROR( + sender_mesh.rank() == 1, + "Reduce only supported a 1D mesh. Given ", + sender_mesh); + NVF_ERROR( + receiver_mesh.rank() == 1, + "Reduce only supported a 1D mesh. 
Given ", + receiver_mesh); + const auto reduce_op_type = getC10dReduceOpType(op_type); + // we create as many Reduces as there are devices in the receiver mesh + for (auto root : receiver_mesh.vector()) { + Team team = sender_mesh.vector(); + if (!sender_mesh.has(root)) { + team.push_back(root); + } + comms.push_back(IrBuilder::create( + CommunicationType::Reduce, + output_tv, + input_tv, + team, + root, + reduce_op_type, + /*scatter_axis=*/-1, + params.communicator_backend)); + } +} + +void lowerToAllreduce( + TensorView* input_tv, + TensorView* output_tv, + BinaryOpType op_type, + const HostIrLowerParams& params, + std::vector& comms, + DeviceIdxType my_device_idx) { + const DeviceMesh& mesh = input_tv->getDeviceMesh(); + Team team = mesh.getSlice(my_device_idx, ParallelType::DIDx); + comms.push_back(IrBuilder::create( + CommunicationType::Allreduce, + output_tv, + input_tv, + team, + /*root=*/-1, + getC10dReduceOpType(op_type), + /*scatter_axis=*/-1, + params.communicator_backend)); +} + +void lowerToReduceScatter( + TensorView* input_tv, + TensorView* output_tv, + BinaryOpType op_type, + const HostIrLowerParams& params, + std::vector& comms, + DeviceIdxType my_device_idx) { + const DeviceMesh& mesh = input_tv->getDeviceMesh(); + Team team = mesh.getSlice(my_device_idx, ParallelType::DIDx); + auto reduction_axis = output_tv->getReductionAxis().value(); + auto scattered_axis = getShardedLogicalAxis(output_tv, ParallelType::DIDx); + // The output tensor is sharded on scattered_axis and needs to be mapped + // back onto the input. The input has an reduced axis, so the scattered axis + // is adjusted to account for this. Ex: [DIDx(i0), i1] -> [r0, DIDx(i1)] The + // scattered_axis is axis=0 on the output and maps to axis=1 on the input. 
+ if (reduction_axis <= scattered_axis) { + scattered_axis++; + } + + comms.push_back(IrBuilder::create( + CommunicationType::ReduceScatter, + output_tv, + input_tv, + /*team=*/team, + /*root=*/-1, + getC10dReduceOpType(op_type), + scattered_axis, + params.communicator_backend)); +} + +std::vector lowerToCollectiveBasedPipelinedGemmComm( + Expr* expr, + const HostIrLowerParams& params) { + NVF_ERROR( + (expr->isOneOf()), + "Expect a MatmulOp or a LinearOp, but got", + expr); + TensorView* tva = nullptr; + TensorView* tvb = nullptr; + TensorView* tv_bias = nullptr; + TensorView* tv_out = nullptr; + if (auto* matmul = dynamic_cast(expr)) { + tva = matmul->inA(); + tvb = matmul->inB(); + tv_out = matmul->out(); + } else { + auto* linear = expr->as(); + tva = linear->inA()->as(); + tvb = linear->inB()->as(); + tv_bias = (linear->has_bias() ? linear->bias()->as() : nullptr); + tv_out = linear->out()->as(); + NVF_ERROR( + !(linear->has_bias() && isSharded(tv_bias)), + "The bias ", + tv_bias, + " is expected to not be sharded"); + } + + NVF_ERROR( + !isSharded(tvb), "The B operand ", tvb, " is expected to not be sharded"); + NVF_ERROR( + !isSharded(tv_out), + "The output ", + tv_out, + " is expected to not be sharded"); + NVF_ERROR( + tv_out->axis(0)->getParallelType() == ParallelType::Stream, + "The output ", + tv_out, + " is expected to be stream-parallelized on axis 0"); + const int64_t sharded_axis_index = + getShardedLogicalAxis(tva, ParallelType::DIDx); + IterDomain* stream_axis = tva->axis(0); + NVF_ERROR( + stream_axis->getParallelType() == ParallelType::Serial && + sharded_axis_index == 1, + "The operand A ", + tva, + " is expected to be sharded on the dimension 1"); + + auto hic = FusionGuard::getCurFusion()->as(); + + auto* get_current_stream = IrBuilder::create(); + hir::Stream* original_stream = get_current_stream->stream(); + + TensorView* tva_allgathered = + ops::newValLike(tva, tva->dtype())->as(); + 
tva_allgathered->axis(sharded_axis_index)->parallelize(ParallelType::Serial); + tva_allgathered->setMemoryType(MemoryType::Global); + auto* allocate_tva_allgathered = + IrBuilder::create(tva_allgathered, MemoryType::Global); + + tv_out->setMemoryType(MemoryType::Global); + auto* allocate_tv_out = + IrBuilder::create(tv_out, MemoryType::Global); + + auto* j = + IrBuilder::create(DataType::Index); // running index of the for-loop + auto* start = hic->zeroVal(); + auto* stop = stream_axis->extent(); + auto* step = hic->oneVal(); + auto* for_loop_initial_sync = IrBuilder::create( + stream_axis, + /*index=*/j, + start, + stop, + step, + /*vectorize=*/false, + /*vectorize_shift=*/nullptr, + /*unroll_required=*/false, + CircularBufferLoopStage::NotApplicable, + /*circular_buffer_loop_stage_depth=*/0); + + auto* number_of_streams = + IrBuilder::create("numberOfStreams", DataType::Int); + auto* stream_index = mod(j, number_of_streams); + auto* stream = IrBuilder::create(stream_index); + auto* set_stream = IrBuilder::create(stream); + auto* initial_sync_stream = + IrBuilder::create(original_stream); + + // the initial sync of the streams with the user's stream is done in a + // separate for-loop for performance reasons with comms/compute overlap + std::vector loop_body_initial_sync = {set_stream, initial_sync_stream}; + for (Expr* expr : loop_body_initial_sync) { + for_loop_initial_sync->body().push_back(expr); + } + + auto* for_loop = IrBuilder::create( + stream_axis, + /*index=*/j, + start, + stop, + step, + /*vectorize=*/false, + /*vectorize_shift=*/nullptr, + /*unroll_required=*/false, + CircularBufferLoopStage::NotApplicable, + /*circular_buffer_loop_stage_depth=*/0); + + TensorView* tva_j = select(tva, 0, j); + TensorView* tva_allgathered_j = select(tva_allgathered, 0, j); + TensorView* tv_out_j = select(tv_out, 0, j); + + NVF_ERROR( + tva->hasDeviceMesh(), + "The matmul's input ", + tva, + "is expected to have a DeviceMesh"); + for (auto tv : {tva_j, 
tva_allgathered_j, tv_out_j}) { + tv->setDeviceMesh(tva->getDeviceMesh()); + } + + auto* communication = IrBuilder::create( + CommunicationType::Allgather, + /*out=*/tva_allgathered_j, + /*in=*/tva_j, + /*team=*/tva->getDeviceMesh().vector(), + /*root=*/-1, + /*red_op=*/RedOpType::UNUSED, + /*scattered_axis=*/-1, + params.communicator_backend); + auto* wait = IrBuilder::create(communication); + + Expr* compute = nullptr; + if (expr->isA()) { + compute = IrBuilder::create(tv_out_j, tva_allgathered_j, tvb); + } else { + compute = + IrBuilder::create(tv_out_j, tva_allgathered_j, tvb, tv_bias); + } + + auto* set_back_original_stream = + IrBuilder::create(original_stream); + auto* sync_stream = IrBuilder::create(stream); + + std::vector loop_body = { + set_stream, + tva_j->definition(), + tva_allgathered_j->definition(), + communication, + wait, + tv_out_j->definition(), + compute, + set_back_original_stream, + sync_stream}; + for (Expr* expr : loop_body) { + for_loop->body().push_back(expr); + } + + return { + get_current_stream, + allocate_tva_allgathered, + allocate_tv_out, + for_loop_initial_sync, + for_loop}; +} + +} // namespace + +std::vector convertSingleOpToCommunication( + Expr* c, + DeviceIdxType my_device_idx, + const HostIrLowerParams& params) { + FusionGuard fg(c->fusion()); + + if (c->isOneOf()) { + return lowerToCollectiveBasedPipelinedGemmComm(c, params); + } + + std::vector comms; + NVF_ERROR( + c->inputs().size() == 1 && c->input(0)->isA() && + c->outputs().size() == 1 && c->output(0)->isA(), + "Input/Output must be single TensorView: ", + c); + auto* input_tv = c->input(0)->as(); + auto* output_tv = c->output(0)->as(); + + input_tv->setMemoryType(MemoryType::Global); + output_tv->setMemoryType(MemoryType::Global); + + const DeviceMesh& sender_mesh = input_tv->getDeviceMesh(); + const DeviceMesh& receiver_mesh = output_tv->getDeviceMesh(); + const bool same_mesh = sender_mesh == receiver_mesh; + + // Stores whether the I/O has its first axis 
parallelized on DIDx + const bool is_input_sharded = isSharded(input_tv) && sender_mesh.size() > 1; + const bool is_output_sharded = + isSharded(output_tv) && receiver_mesh.size() > 1; + + NVF_ERROR( + HostIrLower::canLower(c), + "Lowering expression ", + c->toString(), + " to communication is not supported"); + NVF_ERROR( + !isInnerResharding(c), + "Resharding on an inner axis is not lowerable ", + c->toString()); + bool is_reduction = c->isA(); + + if (is_reduction) { + BinaryOpType op_type = c->as()->getReductionOpType(); + NVF_ERROR( + is_input_sharded || sender_mesh.size() == 1, + "the comm input must be sharded in case of reduce.", + "Insert a `set` before the reduction to reshard") + if (is_output_sharded) { + NVF_ERROR( + same_mesh, + "ReduceScatter operation must have the same sender and receiver device mesh. " + "Insert a Set operation before or after the reduction to reshard ot another device mesh"); + lowerToReduceScatter( + input_tv, output_tv, op_type, params, comms, my_device_idx); + } else { + if (same_mesh) { + lowerToAllreduce( + input_tv, output_tv, op_type, params, comms, my_device_idx); + } else { + lowerToReduce(input_tv, output_tv, op_type, params, comms); + } + } + } else { + if (!is_input_sharded && is_output_sharded) { + lowerToScatter(input_tv, output_tv, params, comms); + } else if (is_input_sharded && !is_output_sharded) { + if (same_mesh) { + lowerToAllgather(input_tv, output_tv, params, comms, my_device_idx); + } else { + lowerToGather(input_tv, output_tv, params, comms); + } + } else { + lowerToBroadcastOrSendRecv(input_tv, output_tv, params, comms); + } + } + + return comms; +} + +} // namespace nvfuser diff --git a/csrc/host_ir/lower_to_communication.h b/csrc/host_ir/lower_to_communication.h new file mode 100644 index 00000000000..1edef39b180 --- /dev/null +++ b/csrc/host_ir/lower_to_communication.h @@ -0,0 +1,20 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2024-present NVIDIA CORPORATION & AFFILIATES. 
+ * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#pragma once + +#include +#include + +namespace nvfuser { + +std::vector convertSingleOpToCommunication( + Expr* c, + DeviceIdxType my_device_idx, + const HostIrLowerParams& params); + +} // namespace nvfuser diff --git a/csrc/host_ir/pass/convert_op_to_communication.cpp b/csrc/host_ir/pass/convert_op_to_communication.cpp new file mode 100644 index 00000000000..daa33ecae9a --- /dev/null +++ b/csrc/host_ir/pass/convert_op_to_communication.cpp @@ -0,0 +1,75 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace nvfuser::hir_pass { + +void ConvertOpToCommunication::passImplementation(Fusion* fusion) { + FusionGuard fg(fusion); + hir::HostIrContainer* hic = dynamic_cast(fusion); + NVF_CHECK(hic, "Expected HostIrContainer"); + DeviceIdxType my_device_index = Communicator::getInstance().deviceId(); + + auto handle_top_level_expr = [&](Expr* top_level_expr, + std::vector& new_top_level_exprs) { + if (!isResharding(top_level_expr)) { + return new_top_level_exprs.push_back(top_level_expr); + } + for (auto* expr : nvfuser::convertSingleOpToCommunication( + top_level_expr, my_device_index, params_)) { + // Allocate the recv buffers of communications + if (expr->isA()) { + auto* communication = expr->as(); + TensorView* tv = communication->out(); + if (tv->getDeviceMesh().has(my_device_index) && + hic->alias().count(tv) == 0) { + auto* allocate = + IrBuilder::create(tv, MemoryType::Global); + new_top_level_exprs.push_back(allocate); + } + } + new_top_level_exprs.push_back(expr); + if (expr->isA()) { + auto wait = IrBuilder::create(expr->as()); + new_top_level_exprs.push_back(wait); + } + } + }; + + std::vector 
new_top_level_exprs; + for (auto top_level_expr : hic->topLevelExprs()) { + if (top_level_expr->isA()) { + auto* for_loop = top_level_expr->as(); + std::vector new_for_loop_body; + for (auto* expr : for_loop->body().exprs()) { + handle_top_level_expr(expr, new_for_loop_body); + } + for_loop->body().clear(); + for (auto* expr : new_for_loop_body) { + for_loop->body().push_back(expr); + } + new_top_level_exprs.push_back(for_loop); + } else { + handle_top_level_expr(top_level_expr, new_top_level_exprs); + } + } + hic->resetTopLevelExprs(new_top_level_exprs); +} + +} // namespace nvfuser::hir_pass diff --git a/csrc/host_ir/pass/convert_op_to_communication.h b/csrc/host_ir/pass/convert_op_to_communication.h new file mode 100644 index 00000000000..5a7496b3cec --- /dev/null +++ b/csrc/host_ir/pass/convert_op_to_communication.h @@ -0,0 +1,36 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#pragma once + +#include +#include +#include +#include + +namespace nvfuser::hir_pass { + +class ConvertOpToCommunication + : public OptimizationPass { + friend class OptimizationPass; + + public: + ConvertOpToCommunication( + const HostIrLowerParams& params = HostIrLowerParams()) + : params_(params) {} + + protected: + void passImplementation(Fusion* fusion); + static constexpr std::string_view name() { + return "ConvertOpToCommunication"; + } + + private: + HostIrLowerParams params_; +}; + +} // namespace nvfuser::hir_pass diff --git a/csrc/host_ir/pass/insert_deallocations.cpp b/csrc/host_ir/pass/insert_deallocations.cpp index 3a62dc97032..4f27f62e692 100644 --- a/csrc/host_ir/pass/insert_deallocations.cpp +++ b/csrc/host_ir/pass/insert_deallocations.cpp @@ -8,13 +8,17 @@ #include -namespace nvfuser::hir { +namespace nvfuser::hir_pass { + +void InsertDeallocations::passImplementation(Fusion* fusion) { + FusionGuard fg(fusion); 
+ hir::HostIrContainer* hic = dynamic_cast(fusion); + NVF_CHECK(hic, "Expected HostIrContainer"); -void insertDeallocations(HostIrContainer* hic) { const std::vector& top_level_exprs = hic->topLevelExprs(); std::for_each(top_level_exprs.begin(), top_level_exprs.end(), [](Expr* expr) { NVF_ERROR( - !expr->isA(), + !expr->isA(), "Expected hostir container to not have deallocate, but found one anyways"); }); std::unordered_map last_use; @@ -35,9 +39,9 @@ void insertDeallocations(HostIrContainer* hic) { } std::sort(last_use_by_index.begin(), last_use_by_index.end()); for (auto&& [i, tv] : last_use_by_index | std::views::reverse) { - auto* deallocate = IrBuilder::create(tv); + auto* deallocate = IrBuilder::create(tv); hic->insertExprAfter(i, deallocate); } } -} // namespace nvfuser::hir +} // namespace nvfuser::hir_pass diff --git a/csrc/host_ir/pass/insert_deallocations.h b/csrc/host_ir/pass/insert_deallocations.h index 221441a4d23..a7bd8219f02 100644 --- a/csrc/host_ir/pass/insert_deallocations.h +++ b/csrc/host_ir/pass/insert_deallocations.h @@ -8,11 +8,20 @@ #pragma once #include +#include -namespace nvfuser::hir { +namespace nvfuser::hir_pass { /* For each input in every expression in the container, find the index of its * last use and insert a deallocate directly after */ -void insertDeallocations(HostIrContainer* hic); +class InsertDeallocations : public OptimizationPass { + friend class OptimizationPass; -} // namespace nvfuser::hir + protected: + void passImplementation(Fusion* fusion); + static constexpr std::string_view name() { + return "InsertDeallocations"; + } +}; + +} // namespace nvfuser::hir_pass diff --git a/csrc/host_ir/pass/optimization_pass.h b/csrc/host_ir/pass/optimization_pass.h new file mode 100644 index 00000000000..19d4c8abc7a --- /dev/null +++ b/csrc/host_ir/pass/optimization_pass.h @@ -0,0 +1,88 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2024-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#pragma once +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace nvfuser::hir_pass { + +//! Base class to unify host IR optimization pass APIs. +//! OptimizationPass can be turned on/off programmatically with the `setEnabled` +//! API. There's a helper template OptimizationPassGuard to temporarily switch the +//! enablement within the context. Note that we are using the curiously recurring +//! template pattern here to ensure that static objects are unique for each +//! DerivedClass. +//! +//! A specific host IR optimization pass needs to be created like: +//! +//! class Pass0 : public OptimizationPass { +//! friend class OptimizationPass; +//! +//! protected: +//! void passImplementation(Fusion* fusion); +//! }; +template +class OptimizationPass { + public: + static void setEnabled(bool enabled) { + flag_.store(enabled); + } + + static bool getEnabled() { + return flag_.load(); + } + + void runPass(Fusion* fusion) { + if (!flag_.load()) { + return; + } + + FUSER_PERF_SCOPE(DerivedClass::name().data()); + static_cast(this)->passImplementation(fusion); + + if (isDebugDumpEnabled(DebugDumpOption::HostIrLoweringLogging)) { + debug() << "Fusion after pass: " << DerivedClass::name() << std::endl; + if (fusion->isA()) { + fusion->as()->print(debug()); + } else { + fusion->printMath(); + } + debug() << "========================================" << std::endl; + } + } + + protected: + static inline std::atomic flag_{true}; +}; + +//! OptimizationPassGuard is used to temporarily switch enable/disable on a +//! certain pass. Original status will be restored at destruction.
+template +class OptimizationPassGuard { + public: + OptimizationPassGuard(bool enabled) : prev_status_(OptPass::getEnabled()) { + if (prev_status_ != enabled) { + OptPass::setEnabled(enabled); + } + } + ~OptimizationPassGuard() { + OptPass::setEnabled(prev_status_); + } + + protected: + bool prev_status_ = false; +}; + +} // namespace nvfuser::hir_pass diff --git a/csrc/host_ir/pass/stream_parallel_type.cpp b/csrc/host_ir/pass/stream_parallel_type.cpp index 550494554cf..eb764665ebd 100644 --- a/csrc/host_ir/pass/stream_parallel_type.cpp +++ b/csrc/host_ir/pass/stream_parallel_type.cpp @@ -18,7 +18,7 @@ #include #include -namespace nvfuser::hir { +namespace nvfuser::hir_pass { namespace { @@ -421,7 +421,7 @@ std::vector addStreamManagement(std::vector top_level_exprs) { // linear structure of the HostIrContainer::topLevelExpr to greedily merge the // adjacent compatible stream for-loop bodies. Ideally we should look at the dag // and use the segmenter. -void StreamParallelType::runPass(Fusion* fusion) { +void StreamParallelType::passImplementation(Fusion* fusion) { // Verify that input tensors don't have stream axes NVF_CHECK( std::all_of( @@ -459,4 +459,4 @@ void StreamParallelType::runPass(Fusion* fusion) { hic->resetTopLevelExprs(top_level_exprs); } -} // namespace nvfuser::hir +} // namespace nvfuser::hir_pass diff --git a/csrc/host_ir/pass/stream_parallel_type.h b/csrc/host_ir/pass/stream_parallel_type.h index 8b5f138ad7e..c98b2088915 100644 --- a/csrc/host_ir/pass/stream_parallel_type.h +++ b/csrc/host_ir/pass/stream_parallel_type.h @@ -8,9 +8,9 @@ #pragma once #include -#include +#include -namespace nvfuser::hir { +namespace nvfuser::hir_pass { // A pass used in HostIrLower that takes a HostIrContainer as input, reads the // TensorView's ParallelType::Stream, and modify the the HostIrContainer's top @@ -22,15 +22,14 @@ namespace nvfuser::hir { // An illustration of the pass can be found in the tests // `test_host_ir_stream_lowering.cpp` // with the option 
`NVFUSER_DUMP=host_ir`. -class StreamParallelType - : public preseg_passes::OptimizationPass { - friend class preseg_passes::OptimizationPass; +class StreamParallelType : public OptimizationPass { + friend class OptimizationPass; protected: - static void runPass(Fusion* fusion); + void passImplementation(Fusion* fusion); static constexpr std::string_view name() { return "StreamParallelType"; } }; -} // namespace nvfuser::hir +} // namespace nvfuser::hir_pass diff --git a/csrc/options.cpp b/csrc/options.cpp index 33610d5b8fc..1c8bbf3861b 100644 --- a/csrc/options.cpp +++ b/csrc/options.cpp @@ -120,6 +120,7 @@ std::unordered_map> Options< {"fusion_ir_presched", DebugDumpOption::FusionIrPresched}, {"fusion_ir_preseg", DebugDumpOption::FusionIrPreseg}, {"global_zeroed_memory", DebugDumpOption::GlobalZeroedMemory}, + {"host_ir_lowering_logging", DebugDumpOption::HostIrLoweringLogging}, {"host_ir", DebugDumpOption::HostIr}, {"index_type", DebugDumpOption::IndexType}, {"indexing_verbose", DebugDumpOption::IndexingVerbose}, diff --git a/csrc/options.h b/csrc/options.h index b050a0a0199..4c466b652d4 100644 --- a/csrc/options.h +++ b/csrc/options.h @@ -65,6 +65,7 @@ enum class DebugDumpOption { PerfDebugVerbose, //! When running kernels, print verbose information //! associated with what's running PreSegmenterLogging, + HostIrLoweringLogging, //! Dump the Host IR after each lowering pass PythonDefinition, //! Python Frontend Fusion Definition. PythonDefinitionSegments, //! Python Frontend Fusion Definition of segments. PythonFrontendDebug, //! Python Frontend debug information. 
diff --git a/csrc/runtime/fusion_kernel_runtime.cpp b/csrc/runtime/fusion_kernel_runtime.cpp index fbc475af9e1..7010af2950e 100644 --- a/csrc/runtime/fusion_kernel_runtime.cpp +++ b/csrc/runtime/fusion_kernel_runtime.cpp @@ -11,6 +11,8 @@ #include #include #include +#include +#include #include #include #include @@ -517,9 +519,10 @@ void FusionKernelRuntime::compileFusionParallel(KernelArgumentHolder args) { NVF_ERROR( group_to_run->exprs().size() == 1, "Communication segments must contain only one Expr"); - HostIrLower lower; - for (auto* expr : lower.lower( - ir_cloner.clone(group_to_run->exprs().at(0)), deviceid)) { + for (auto* expr : convertSingleOpToCommunication( + ir_cloner.clone(group_to_run->exprs().at(0)), + deviceid, + HostIrLowerParams())) { NVF_ERROR( expr->isA(), "Exprs in a Communication group should be Communication"); @@ -552,7 +555,7 @@ void FusionKernelRuntime::compileFusionParallel(KernelArgumentHolder args) { hic->addOutput(ir_cloner.clone(out)); } - insertDeallocations(hic.get()); + hir_pass::InsertDeallocations().runPass(hic.get()); hie_ = std::make_unique( std::move(hic), &Communicator::getInstance()); diff --git a/python/python_frontend/fusion_definition.cpp b/python/python_frontend/fusion_definition.cpp index e45f7181d5e..9b144cfa1ea 100644 --- a/python/python_frontend/fusion_definition.cpp +++ b/python/python_frontend/fusion_definition.cpp @@ -455,7 +455,7 @@ std::pair> FusionDefinition:: params.lower.communicator_backend = backend_type_; // Disable StreamParallelType pass temporarily as proper stream lowering // gets implemented - preseg_passes::OptimizationPassGuard guard( + hir_pass::OptimizationPassGuard guard( false); scheds->multi_device_executor = std::make_unique( std::make_unique(*scheds->preschedFusion()), diff --git a/tests/cpp/test_host_ir_stream_lowering.cpp b/tests/cpp/test_host_ir_stream_lowering.cpp index 8150a58d5b7..3f914bf5620 100644 --- a/tests/cpp/test_host_ir_stream_lowering.cpp +++ 
b/tests/cpp/test_host_ir_stream_lowering.cpp @@ -36,8 +36,7 @@ TEST_F(HirLowerStreamTest, InputsAreNotStreamParallelized) { hic->addInput(tv); tv->axis(0)->parallelize(ParallelType::Stream); - EXPECT_ANY_THROW( - preseg_passes::OptimizationPass::runPass(hic.get())); + EXPECT_ANY_THROW(hir_pass::StreamParallelType().runPass(hic.get())); } TEST_F(HirLowerStreamTest, Split) { @@ -51,8 +50,7 @@ TEST_F(HirLowerStreamTest, Split) { tv1->split(0, 2); tv1->axis(0)->parallelize(ParallelType::Stream); - EXPECT_ANY_THROW( - preseg_passes::OptimizationPass::runPass(hic.get())); + EXPECT_ANY_THROW(hir_pass::StreamParallelType().runPass(hic.get())); } TEST_F(HirLowerStreamTest, Merge) { @@ -66,8 +64,7 @@ TEST_F(HirLowerStreamTest, Merge) { tv1->merge(0, 1); tv1->axis(0)->parallelize(ParallelType::Stream); - EXPECT_ANY_THROW( - preseg_passes::OptimizationPass::runPass(hic.get())); + EXPECT_ANY_THROW(hir_pass::StreamParallelType().runPass(hic.get())); } TEST_F(HirLowerStreamTest, SingleSetOp) { @@ -82,7 +79,7 @@ TEST_F(HirLowerStreamTest, SingleSetOp) { tv1->setMemoryType(MemoryType::Global); tv1->axis(0)->parallelize(ParallelType::Stream); - preseg_passes::OptimizationPass::runPass(hic.get()); + hir_pass::StreamParallelType().runPass(hic.get()); EXPECT_EQ(hic->topLevelExprs().size(), 4); EXPECT_TRUE(hic->topLevelExprs().at(0)->isA()); @@ -113,7 +110,7 @@ TEST_F(HirLowerStreamTest, SingleSetOpNonOutermost) { tv1->setMemoryType(MemoryType::Global); tv1->axis(1)->parallelize(ParallelType::Stream); - preseg_passes::OptimizationPass::runPass(hic.get()); + hir_pass::StreamParallelType().runPass(hic.get()); EXPECT_EQ(hic->topLevelExprs().size(), 4); EXPECT_TRUE(hic->topLevelExprs().at(0)->isA()); @@ -147,7 +144,7 @@ TEST_F(HirLowerStreamTest, SingleBinaryOp) { tv2->setMemoryType(MemoryType::Global); tv2->axis(0)->parallelize(ParallelType::Stream); - preseg_passes::OptimizationPass::runPass(hic.get()); + hir_pass::StreamParallelType().runPass(hic.get()); 
EXPECT_EQ(hic->topLevelExprs().size(), 4); EXPECT_TRUE(hic->topLevelExprs().at(0)->isA()); @@ -160,7 +157,6 @@ TEST_F(HirLowerStreamTest, SingleBinaryOp) { auto options = at::TensorOptions().device(at::kCUDA, 0); at::Tensor tv0_input = at::rand({4, 4}, options); at::Tensor tv1_input = at::rand({4, 4}, options); - // std::unordered_map inputs = {{tv0, input}}; auto output = hie.runWithInput({{tv0, tv0_input}, {tv1, tv1_input}})[0] .as(); auto expected_output = tv0_input + tv1_input; @@ -184,7 +180,7 @@ TEST_F(HirLowerStreamTest, TwoSetOps) { tv1->axis(0)->parallelize(ParallelType::Stream); tv2->axis(0)->parallelize(ParallelType::Stream); - preseg_passes::OptimizationPass::runPass(hic.get()); + hir_pass::StreamParallelType().runPass(hic.get()); EXPECT_EQ(hic->topLevelExprs().size(), 5); EXPECT_TRUE(hic->topLevelExprs().at(0)->isA()); @@ -223,7 +219,7 @@ TEST_F(HirLowerStreamTest, ThreeSetOpsWithDisjointsForLoops) { tv1->axis(0)->parallelize(ParallelType::Stream); tv3->axis(0)->parallelize(ParallelType::Stream); - preseg_passes::OptimizationPass::runPass(hic.get()); + hir_pass::StreamParallelType().runPass(hic.get()); EXPECT_EQ(hic->topLevelExprs().size(), 9); EXPECT_TRUE(hic->topLevelExprs().at(0)->isA()); @@ -259,8 +255,7 @@ TEST_F(HirLowerStreamTest, ReductionUnsupported) { tv1->setMemoryType(MemoryType::Global); tv1->axis(0)->parallelize(ParallelType::Stream); - EXPECT_ANY_THROW( - preseg_passes::OptimizationPass::runPass(hic.get())); + EXPECT_ANY_THROW(hir_pass::StreamParallelType().runPass(hic.get())); } TEST_F(HirLowerStreamTest, Reduction) { @@ -275,7 +270,7 @@ TEST_F(HirLowerStreamTest, Reduction) { tv1->setMemoryType(MemoryType::Global); tv1->axis(0)->parallelize(ParallelType::Stream); - preseg_passes::OptimizationPass::runPass(hic.get()); + hir_pass::StreamParallelType().runPass(hic.get()); EXPECT_EQ(hic->topLevelExprs().size(), 4); EXPECT_TRUE(hic->topLevelExprs().at(0)->isA()); @@ -310,7 +305,7 @@ TEST_F(HirLowerStreamTest, Matmul_M) { 
c->setMemoryType(MemoryType::Global); c->axis(0)->parallelize(ParallelType::Stream); - preseg_passes::OptimizationPass::runPass(hic.get()); + hir_pass::StreamParallelType().runPass(hic.get()); EXPECT_EQ(hic->topLevelExprs().size(), 4); EXPECT_TRUE(hic->topLevelExprs().at(0)->isA()); @@ -348,7 +343,7 @@ TEST_F(HirLowerStreamTest, BatchedMatmul) { c->setMemoryType(MemoryType::Global); c->axis(0)->parallelize(ParallelType::Stream); - preseg_passes::OptimizationPass::runPass(hic.get()); + hir_pass::StreamParallelType().runPass(hic.get()); EXPECT_EQ(hic->topLevelExprs().size(), 4); EXPECT_TRUE(hic->topLevelExprs().at(0)->isA()); @@ -386,7 +381,7 @@ TEST_F(HirLowerStreamTest, Matmul_N) { c->setMemoryType(MemoryType::Global); c->axis(1)->parallelize(ParallelType::Stream); - preseg_passes::OptimizationPass::runPass(hic.get()); + hir_pass::StreamParallelType().runPass(hic.get()); EXPECT_EQ(hic->topLevelExprs().size(), 4); EXPECT_TRUE(hic->topLevelExprs().at(0)->isA()); @@ -424,8 +419,7 @@ TEST_F(HirLowerStreamTest, Matmul_K) { c->setMemoryType(MemoryType::Global); c->axis(-1)->parallelize(ParallelType::Stream); - EXPECT_ANY_THROW( - preseg_passes::OptimizationPass::runPass(hic.get())); + EXPECT_ANY_THROW(hir_pass::StreamParallelType().runPass(hic.get())); } // We don's support PostOnStream because it does not support well pre-allocated @@ -472,8 +466,7 @@ TEST_F(HirLowerStreamTest, DoNotSupportPostOnStream) { output->axis(-1)->parallelize(ParallelType::Stream); - EXPECT_ANY_THROW( - preseg_passes::OptimizationPass::runPass(hic.get())); + EXPECT_ANY_THROW(hir_pass::StreamParallelType().runPass(hic.get())); } } // namespace hir diff --git a/tests/cpp/test_multidevice_host_ir.cpp b/tests/cpp/test_multidevice_host_ir.cpp index db53f7f114d..4732d12057f 100644 --- a/tests/cpp/test_multidevice_host_ir.cpp +++ b/tests/cpp/test_multidevice_host_ir.cpp @@ -12,6 +12,7 @@ #include #include #include +#include #include namespace nvfuser { @@ -363,9 +364,11 @@ TEST_F(P2PCommHostIrTest, 
CoalescedRingPairwiseExchange) { using OverlapDistributedMatmulTest = MultiDeviceTest; TEST_F(OverlapDistributedMatmulTest, AG_matmul) { - // Disable StreamParallelType pass temporarily as proper stream lowering gets - // implemented - preseg_passes::OptimizationPassGuard guard(false); + // Disable the StreamParallelType and ReorderShardedAxisPass passes temporarily as + // proper stream lowering gets implemented + hir_pass::OptimizationPassGuard guard(false); + preseg_passes::OptimizationPassGuard + guard2(false); constexpr int64_t M = 32768; constexpr int64_t K = 32768; @@ -422,8 +425,11 @@ TEST_F(OverlapDistributedMatmulTest, AG_matmul) { } TEST_F(OverlapDistributedMatmulTest, AG_linear) { - // Disable StreamParallelType pass tempor - preseg_passes::OptimizationPassGuard guard(false); + // Disable the StreamParallelType and ReorderShardedAxisPass passes temporarily as + // proper stream lowering gets implemented + hir_pass::OptimizationPassGuard guard(false); + preseg_passes::OptimizationPassGuard + guard2(false); constexpr int64_t M = 32768; constexpr int64_t K = 32768;