From a177cb4a015c49e7dfe7d376f237ecd643bda422 Mon Sep 17 00:00:00 2001 From: snordmann Date: Wed, 26 Mar 2025 06:08:43 -0700 Subject: [PATCH 01/68] add ParallelType::Stream lowering pass in host Ir for single device fusions --- CMakeLists.txt | 2 + csrc/host_ir/executor.h | 4 + csrc/host_ir/lower.cpp | 8 + csrc/ir/internal_nodes.h | 4 + csrc/multidevice/executor.h | 4 + csrc/ops/indexing.cpp | 10 +- csrc/ops/indexing.h | 6 +- csrc/ops/utils.cpp | 27 +- csrc/ops/utils.h | 14 +- csrc/preseg_passes/stream_parallel_type.cpp | 347 +++++++++ csrc/preseg_passes/stream_parallel_type.h | 26 + tests/cpp/test_host_ir_stream_lowering.cpp | 823 ++++++++++++++++++++ tests/cpp/test_multidevice_host_ir.cpp | 10 + 13 files changed, 1271 insertions(+), 14 deletions(-) create mode 100644 csrc/preseg_passes/stream_parallel_type.cpp create mode 100644 csrc/preseg_passes/stream_parallel_type.h create mode 100644 tests/cpp/test_host_ir_stream_lowering.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index b9865da34a7..3f2750b59b2 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -212,6 +212,7 @@ list(APPEND NVFUSER_SRCS ${NVFUSER_SRCS_DIR}/preseg_passes/remove_empty.cpp ${NVFUSER_SRCS_DIR}/preseg_passes/reorder_sharded_axis.cpp ${NVFUSER_SRCS_DIR}/preseg_passes/segment_inplace_update.cpp + ${NVFUSER_SRCS_DIR}/preseg_passes/stream_parallel_type.cpp ${NVFUSER_SRCS_DIR}/preseg_passes/translate_no_reduction_matmul_to_mul_squeeze.cpp ${NVFUSER_SRCS_DIR}/preseg_passes/translate_repeat_to_expand.cpp ${NVFUSER_SRCS_DIR}/rng.cpp @@ -731,6 +732,7 @@ if(BUILD_TEST) list(APPEND HOSTIR_TEST_SRCS ${NVFUSER_ROOT}/tests/cpp/test_host_irs.cpp ${NVFUSER_ROOT}/tests/cpp/test_host_ir_integration.cpp + ${NVFUSER_ROOT}/tests/cpp/test_host_ir_stream_lowering.cpp ) add_test(test_host_ir "${HOSTIR_TEST_SRCS}" "") list(APPEND TEST_BINARIES test_host_ir) diff --git a/csrc/host_ir/executor.h b/csrc/host_ir/executor.h index dfe84fba068..89ac5119681 100644 --- a/csrc/host_ir/executor.h +++ b/csrc/host_ir/executor.h @@ -97,6 +97,10 @@ class HostIrEvaluator final : public OptOutDispatch { return container_->outputs(); } + auto* container() const { + return container_.get(); + } + std::ostream& print(std::ostream& os) const { return container_->print(os); }; diff --git a/csrc/host_ir/lower.cpp b/csrc/host_ir/lower.cpp index 308e1399872..1a74d9a9f01 100644 --- a/csrc/host_ir/lower.cpp +++ b/csrc/host_ir/lower.cpp @@ -19,6 +19,7 @@ #include #include #include +#include #include #include @@ -718,6 +719,10 @@ std::unique_ptr HostIrLower::lower( hic->addOutput(ir_cloner.clone(output)); } + for (auto tv : hic->allTvs()) { + tv->setMemoryType(MemoryType::Global); + } + std::vector new_top_level_exprs; for (auto top_level_expr : hic->topLevelExprs()) { if (!isResharding(top_level_expr)) { @@ -744,6 +749,9 @@ std::unique_ptr HostIrLower::lower( } hic->resetTopLevelExprs(new_top_level_exprs); + preseg_passes::OptimizationPass::runPass( + hic.get()); + return hic; } diff --git a/csrc/ir/internal_nodes.h b/csrc/ir/internal_nodes.h index 91d3ca4ec39..1a2bb1634bb 100644 --- a/csrc/ir/internal_nodes.h +++ b/csrc/ir/internal_nodes.h @@ -2477,6 +2477,10 @@ class ForLoop final : public Expr { return input(0); } + IterDomain* iterDomain() const { + return input(1)->as(); + } + Val* indexOrStartIfTrivial() const { return isTrivial() ? start() : index(); } diff --git a/csrc/multidevice/executor.h b/csrc/multidevice/executor.h index c1cc3e31cfe..7dd08a87f0a 100644 --- a/csrc/multidevice/executor.h +++ b/csrc/multidevice/executor.h @@ -103,6 +103,10 @@ class MultiDeviceExecutor { return host_ir_executor_->getFusionExecutorCaches(); }; + auto* hostIrEvaluator() const { + return host_ir_executor_.get(); + } + private: // holds the Communicator to be used for execution Communicator& comm_; diff --git a/csrc/ops/indexing.cpp b/csrc/ops/indexing.cpp index 5ff75065ff2..80c0ff84b85 100644 --- a/csrc/ops/indexing.cpp +++ b/csrc/ops/indexing.cpp @@ -19,8 +19,14 @@ namespace nvfuser { -TensorView* select(TensorView* tv, int64_t dim, Val* index) { - auto dom = TensorDomain::noReductions(tv->getLogicalDomain()); +TensorView* select( + TensorView* tv, + int64_t dim, + Val* index, + bool keep_reduction_axis) { + auto dom = keep_reduction_axis + ? tv->getLogicalDomain() + : TensorDomain::noReductions(tv->getLogicalDomain()); NVF_CHECK(!dom.empty(), "select can not be applied to 0d tensor."); std::vector new_root; diff --git a/csrc/ops/indexing.h b/csrc/ops/indexing.h index c8152c33f82..7a219c534a3 100644 --- a/csrc/ops/indexing.h +++ b/csrc/ops/indexing.h @@ -15,7 +15,11 @@ namespace nvfuser { -NVF_API TensorView* select(TensorView* tv, int64_t dim, Val* index); +NVF_API TensorView* select( + TensorView* tv, + int64_t dim, + Val* index, + bool keep_reduction_axis = false); // torch.index_select NVF_API TensorView* indexSelect( diff --git a/csrc/ops/utils.cpp b/csrc/ops/utils.cpp index 8d3870d1a84..5d32c22e212 100644 --- a/csrc/ops/utils.cpp +++ b/csrc/ops/utils.cpp @@ -432,7 +432,9 @@ IterDomain* newOutputIterDomain( #pragma GCC diagnostic pop #endif -std::vector newOutputDomain(const std::vector& vals) { +std::vector newOutputDomain( + const std::vector& vals, + bool keep_reduction_axis) { std::vector tvs; for (auto val : vals) { if (auto* tv = dynamic_cast(val)) { @@ -443,14 +445,20 @@ std::vector newOutputDomain(const std::vector& vals) { !tvs.empty(), "Tried to create new output TensorView but received empty list."); - std::vector out_domain( - TensorDomain::noReductions(tvs[0]->getLogicalDomain()).size(), nullptr); + auto getLogicalDomain = + [keep_reduction_axis](TensorView* tv) -> std::vector { + return keep_reduction_axis + ? tv->getLogicalDomain() + : TensorDomain::noReductions(tv->getLogicalDomain()); + }; + + std::vector out_domain(getLogicalDomain(tvs[0]).size(), nullptr); for (const auto dim_i : arange(out_domain.size())) { std::vector input_ids; input_ids.reserve(tvs.size()); for (auto* tv : tvs) { - auto dom = TensorDomain::noReductions(tv->getLogicalDomain()); + auto dom = getLogicalDomain(tv); input_ids.emplace_back(dom[dim_i]); } out_domain[dim_i] = newOutputIterDomain(input_ids); @@ -458,8 +466,11 @@ std::vector newOutputDomain(const std::vector& vals) { return out_domain; } -TensorView* newOutputTV(const std::vector& vals, DataType dtype) { - auto out_domain = newOutputDomain(vals); +TensorView* newOutputTV( + const std::vector& vals, + DataType dtype, + bool keep_reduction_axis) { + auto out_domain = newOutputDomain(vals, keep_reduction_axis); auto* new_out = IrBuilder::create( IrBuilder::create( out_domain, TensorDomain::getContiguityFilledWith(out_domain, true)), @@ -502,12 +513,12 @@ std::vector maybeBroadcast(const std::vector& vals) { return out_vals; } -Val* newValLike(Val* val, DataType dtype) { +Val* newValLike(Val* val, DataType dtype, bool keep_reduction_axis) { NVF_CHECK( dtype != DataType::Null, "Invalid datatype provided for new value."); if (val->isA()) { - return newOutputTV({val}, dtype); + return newOutputTV({val}, dtype, keep_reduction_axis); } return newScalar(ValType::Others, dtype); diff --git a/csrc/ops/utils.h b/csrc/ops/utils.h index 94d6391cf45..1a2abda03fc 100644 --- a/csrc/ops/utils.h +++ b/csrc/ops/utils.h @@ -99,13 +99,21 @@ IterDomain* newOutputIterDomain( // output tensorview, e.g., for BinaryOp. `vals` can contain scalars, e.g, when // creating the output TensorView for `tv0+scalar`. This is for convenience and // scalars will be ignored. -std::vector newOutputDomain(const std::vector& vals); +std::vector newOutputDomain( + const std::vector& vals, + bool keep_reduction_axis = false); -TensorView* newOutputTV(const std::vector& vals, DataType dtype); +TensorView* newOutputTV( + const std::vector& vals, + DataType dtype, + bool keep_reduction_axis = false); std::vector maybeBroadcast(const std::vector& vals); -NVF_API Val* newValLike(Val* val, DataType dtype); +NVF_API Val* newValLike( + Val* val, + DataType dtype, + bool keep_reduction_axis = false); // returns the minimum init value for reduction: // -inf for floating type; diff --git a/csrc/preseg_passes/stream_parallel_type.cpp b/csrc/preseg_passes/stream_parallel_type.cpp new file mode 100644 index 00000000000..5a814c9a59a --- /dev/null +++ b/csrc/preseg_passes/stream_parallel_type.cpp @@ -0,0 +1,347 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace nvfuser::preseg_passes { + +// returns the first stream axis in the domain, or nullptr if there is none. +// Throws if two axis are stream parallelized +IterDomain* getStreamAxis(const std::vector& domain) { + IterDomain* ret = nullptr; + for (auto id : domain) { + if (id->getParallelType() == ParallelType::Stream) { + NVF_CHECK( + ret == nullptr, + "Expected at most one stream axis in the domain, but found ", + id, + " and ", + ret); + ret = id; + } + } + return ret; +} + + + +// TODO: ideally we should look at the dag and use the segmenter. Here we take +// advantage of the linear structure of HostIrContainer::topLevelExprs to +// greedily merge the adjacent compatible stream for-loop bodies +void StreamParallelType::runPass(Fusion* fusion) { + // check that there are no stream axes in the inputs + NVF_CHECK( + std::all_of( + fusion->inputs().begin(), + fusion->inputs().end(), + [](Val* input) { + auto input_tv = dynamic_cast(input); + return input_tv == nullptr || + getStreamAxis(input_tv->getLoopDomain()) == nullptr; + }), + "Expected no stream axis in the TensorView inputs."); + + FusionGuard fg(fusion); // set as current container to register the newly + // created for-loops + hir::HostIrContainer* hic = dynamic_cast(fusion); + NVF_CHECK(hic, "Expected HostIrContainer"); + // needed ? + IdModel id_model(fusion); + id_model.buildAlmostExactGraph(); + + std::vector new_top_level_exprs; + // Step 1. Find the segments of expressions that can be merged into a single + // stream for-loop At the end of this step, new_top_level_exprs contains a + // list of expressions including newly created for-loops that will represent + // the stream parallelization, and the relevant expressions grouped inside the + // for-loops bodies. + for (auto expr : hic->topLevelExprs()) { + // we only support exprs having at most 1 output for now + if (expr->outputs().size() == 0) { + new_top_level_exprs.push_back(expr); + continue; + } + NVF_CHECK( + expr->outputs().size() == 1, + "Each expr should have at most one output."); + TensorView* output = expr->output(0)->as(); + // retrieves the Loop IterDomain that is stream parallelized, if any + IterDomain* stream_axis = getStreamAxis(output->getLoopDomain()); + if (stream_axis == nullptr) { + // if the consumer is not stream parallelized, it means the expr need not + // be inside a stream for-loop + new_top_level_exprs.push_back(expr); + continue; + } + NVF_ERROR( + HostIrLower::isLoweredAsStandaloneHostOp(expr), + "Stream parallel type not supported for expr ", + expr); + // find the corresponding stream axis but in the Logical (and not Loop + // Domain) + auto it_logical_stream_axis = std::find( + output->getLogicalDomain().begin(), + output->getLogicalDomain().end(), + stream_axis); + // for now we do not support split/merge stream axis + NVF_ERROR( + it_logical_stream_axis != output->getLogicalDomain().end(), + "Cannot stream parallelize on a split/merge axis ", + stream_axis); + // we don't support reducing or broadcasting a stream axis + NVF_CHECK( + stream_axis->getIterType() == IterType::Iteration, + "Stream axis ", + stream_axis, + " should be an iteration axis."); + // check if the current expr can be merged with the previous stream for-loop + // We consider the previous expression to check whether the expr should + // create a new stream for-loop or be integrated into the previous one + if (!new_top_level_exprs.empty() && + new_top_level_exprs.back()->isA() && + id_model.idGraph(IdMappingMode::ALMOSTEXACT) + .disjointValSets() + .strictAreMapped( + stream_axis, + new_top_level_exprs.back()->as()->iterDomain())) { + // merge with previous for-loop + new_top_level_exprs.back()->as()->body().push_back(expr); + } else { + // create a new for-loop + auto* j = IrBuilder::create( + DataType::Index); // running index of the for-loop + auto* start = hic->zeroVal(); + auto* stop = stream_axis->extent(); + auto* step = hic->oneVal(); + auto* for_loop = IrBuilder::create( + stream_axis, + /*index=*/j, + start, + stop, + step, + /*vectorize=*/false, + /*vectorize_shift=*/nullptr, + /*unroll_required=*/false, + CircularBufferLoopStage::NotApplicable, + /*circular_buffer_loop_stage_depth=*/0); + for_loop->body().push_back(expr); + // replace the current expr by the for-loop containing it + new_top_level_exprs.push_back(for_loop); + } + } + + // Step 2. Setup each for loop's body by Slicing the tensors. + std::vector top_level_exprs = std::move(new_top_level_exprs); + new_top_level_exprs.clear(); + for (auto top_level_expr : top_level_exprs) { + // TODO: change in place? consr issue + if (!top_level_expr->isA()) { + new_top_level_exprs.push_back(top_level_expr); + continue; + } + auto* for_loop = top_level_expr->as(); + // this will contain the new body of the current for-loop + std::vector new_loop_body; + + std::vector current_loop_body = for_loop->body().exprs(); + for (auto it_expr = current_loop_body.begin(); + it_expr != current_loop_body.end(); + ++it_expr) { + Expr* expr = *it_expr; + for (auto* input : ir_utils::filterByType(expr->inputs())) { + int64_t input_stream_id_logical_index = -1; + for (auto id : input->getLoopDomain()) { + if (id_model.idGraph(IdMappingMode::ALMOSTEXACT) + .disjointValSets() + .strictAreMapped(for_loop->iterDomain(), id)) { + NVF_CHECK( + input_stream_id_logical_index == -1, + "Expected at most one axis mapping to the stream axis ", + for_loop->iterDomain(), + " in the tensor ", + input, + " loop's domain ", + input->getLoopDomain()); + auto it_input_stream_id_logical = std::find( + input->getLogicalDomain().begin(), + input->getLogicalDomain().end(), + id); + NVF_CHECK( + it_input_stream_id_logical != input->getLogicalDomain().end(), + "Expected to find ", + id, + " in ", + input, + "'s logical domain ", + input->getLogicalDomain()); + input_stream_id_logical_index = std::distance( + input->getLogicalDomain().begin(), it_input_stream_id_logical); + } + } + if (input_stream_id_logical_index == -1) { + continue; + } + TensorView* input_j = select( + input, + input_stream_id_logical_index, + for_loop->index(), + /*keep_reduction_axis=*/true); + new_loop_body.push_back(input_j->definition()); + for (auto it_running_expr = current_loop_body.begin(); + it_running_expr != current_loop_body.end(); + ++it_running_expr) { + Expr* running_expr = *it_running_expr; + for (auto* running_input : + ir_utils::filterByType(running_expr->inputs())) { + if (running_input == input) { + *it_running_expr = ir_utils::replaceValInExprInputs( + running_expr, input, input_j); + } + } + } + } + + for (auto* output : ir_utils::filterByType(expr->outputs())) { + int64_t output_stream_id_logical_index = -1; + for (auto id : output->getLoopDomain()) { + if (id_model.idGraph(IdMappingMode::ALMOSTEXACT) + .disjointValSets() + .strictAreMapped(for_loop->iterDomain(), id)) { + NVF_CHECK( + output_stream_id_logical_index == -1, + "Expected at most one axis mapping to the stream axis ", + for_loop->iterDomain(), + " in the tensor ", + output, + " loop's domain ", + output->getLoopDomain()); + auto it_output_stream_id_logical = std::find( + output->getLogicalDomain().begin(), + output->getLogicalDomain().end(), + id); + NVF_CHECK( + it_output_stream_id_logical != output->getLogicalDomain().end(), + "Expected to find ", + id, + " in ", + output, + "'s logical domain ", + output->getLogicalDomain()); + output_stream_id_logical_index = std::distance( + output->getLogicalDomain().begin(), + it_output_stream_id_logical); + } + } + if (output_stream_id_logical_index == -1) { + continue; + } + TensorView* output_j = select( + output, + output_stream_id_logical_index, + for_loop->index(), + /*keep_reduction_axis=*/true); + new_top_level_exprs.push_back( + IrBuilder::create(output, MemoryType::Global)); + new_loop_body.push_back(output_j->definition()); + for (auto it_running_expr = current_loop_body.begin(); + it_running_expr != current_loop_body.end(); + ++it_running_expr) { + Expr* running_expr = *it_running_expr; + for (auto* running_output : + ir_utils::filterByType(running_expr->outputs())) { + if (running_output == output) { + TensorView* output_j_alias = + ops::newValLike( + output_j, output_j->dtype(), /*keep_reduction_axis=*/true) + ->as(); + hic->markAlias(output_j, output_j_alias); + *it_running_expr = ir_utils::transferDefinitionToNewOutputs( + running_expr, {output_j_alias}); + if (Communication* comm = dynamic_cast( + output_j_alias->definition()); + comm && comm->type() == CommunicationType::Allgather) { + std::cout << "HERE, with expr:" << *it_running_expr + << std::endl; + } + } + } + } + } + new_loop_body.push_back(*it_expr); + } + // reseting the for-loop body + for_loop->body().clear(); + for (auto* expr : new_loop_body) { + for_loop->body().push_back(expr); + } + new_top_level_exprs.push_back(top_level_expr); + } + + // Step 3. Finalize the for-loop bodies by adding the stream setup and + // synchronization + for (auto* top_level_expr : new_top_level_exprs) { + if (!top_level_expr->isA()) { + continue; + } + auto* for_loop = top_level_expr->as(); + std::vector new_loop_body; + + // Get the current stream to later synchronize subsequent new streams + auto* get_current_stream = IrBuilder::create(); + hir::Stream* original_stream = get_current_stream->stream(); + new_loop_body.push_back(get_current_stream); + + // set the stream to the one corresponding to the current for-loop index + auto* j = for_loop->index(); + auto* number_of_streams = + IrBuilder::create("numberOfStreams", DataType::Int); + auto* stream_index = mod(j, number_of_streams); + auto* stream = IrBuilder::create(stream_index); + auto* set_stream = IrBuilder::create(stream); + new_loop_body.push_back(set_stream); + + // sync the new stream with the original stream + auto* initial_sync_stream = + IrBuilder::create(original_stream); + new_loop_body.push_back(initial_sync_stream); + + // add the actual exprs to the for-loop body + for (auto* expr : for_loop->body().exprs()) { + new_loop_body.push_back(expr); + } + + // set back the original stream + auto* set_back_original_stream = + IrBuilder::create(original_stream); + new_loop_body.push_back(set_back_original_stream); + // synchronize original stream with the for-loop's streams + auto* sync_stream = IrBuilder::create(stream); + new_loop_body.push_back(sync_stream); + + // reset the for-loop's body to the one we constructed. + for_loop->body().clear(); + for (auto* expr : new_loop_body) { + for_loop->body().push_back(expr); + } + } + + // reset hic topLevelExprs to new_top_level_exprs + hic->resetTopLevelExprs(new_top_level_exprs); +} + +} // namespace nvfuser::preseg_passes diff --git a/csrc/preseg_passes/stream_parallel_type.h b/csrc/preseg_passes/stream_parallel_type.h new file mode 100644 index 00000000000..a9600809e21 --- /dev/null +++ b/csrc/preseg_passes/stream_parallel_type.h @@ -0,0 +1,26 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#pragma once + +#include +#include + +namespace nvfuser::preseg_passes { + +// A pass used in HostIrLower that takes a HostIrContainer as input, reads the TensorView's ParallelType::Stream, and modify the the HostIrContainer's top level expressions with the corresponding Host For Loops, which bodies contain stream assignement, selecting on tensor's axis, and the exprs on those sliced tensors. After this pass, the ParallelType::Stream is removed from the TensorView's axis. +class StreamParallelType : public OptimizationPass { + friend class OptimizationPass; + + protected: + static void runPass(Fusion* fusion); + static constexpr std::string_view name() { + return "StreamParallelType"; + } +}; + +} // namespace nvfuser::preseg_passes diff --git a/tests/cpp/test_host_ir_stream_lowering.cpp b/tests/cpp/test_host_ir_stream_lowering.cpp new file mode 100644 index 00000000000..9f3d9f432ed --- /dev/null +++ b/tests/cpp/test_host_ir_stream_lowering.cpp @@ -0,0 +1,823 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace nvfuser { + +namespace hir { + +using HirLowerStreamTest = NVFuserTest; + +TEST_F(HirLowerStreamTest, InputsAreNotStreamParallelized) { + auto hic = std::make_unique(); + FusionGuard fg(hic.get()); + TensorView* tv = makeContigTensor(2); + hic->addInput(tv); + tv->axis(0)->parallelize(ParallelType::Stream); + + EXPECT_ANY_THROW(preseg_passes::OptimizationPass< + preseg_passes::StreamParallelType>::runPass(hic.get())); +} + +TEST_F(HirLowerStreamTest, Split) { + auto hic = std::make_unique(); + FusionGuard fg(hic.get()); + TensorView* tv0 = makeContigTensor(2); + TensorView* tv1 = set(tv0); + hic->addInput(tv0); + hic->addOutput(tv1); + hic->pushBackTopLevelExprs(tv1->definition()); + tv1->split(0, 2); + tv1->axis(0)->parallelize(ParallelType::Stream); + + EXPECT_ANY_THROW(preseg_passes::OptimizationPass< + preseg_passes::StreamParallelType>::runPass(hic.get())); +} + +TEST_F(HirLowerStreamTest, Merge) { + auto hic = std::make_unique(); + FusionGuard fg(hic.get()); + TensorView* tv0 = makeContigTensor(2); + TensorView* tv1 = set(tv0); + hic->addInput(tv0); + hic->addOutput(tv1); + hic->pushBackTopLevelExprs(tv1->definition()); + tv1->merge(0, 1); + tv1->axis(0)->parallelize(ParallelType::Stream); + + EXPECT_ANY_THROW(preseg_passes::OptimizationPass< + preseg_passes::StreamParallelType>::runPass(hic.get())); +} + +TEST_F(HirLowerStreamTest, SingleSetOp) { + auto hic = std::make_unique(); + FusionGuard fg(hic.get()); + TensorView* tv0 = makeContigTensor(2); + TensorView* tv1 = set(tv0); + hic->addInput(tv0); + hic->addOutput(tv1); + hic->pushBackTopLevelExprs(tv1->definition()); + tv0->setMemoryType(MemoryType::Global); + tv1->setMemoryType(MemoryType::Global); + tv1->axis(0)->parallelize(ParallelType::Stream); + + preseg_passes::OptimizationPass::runPass( + hic.get()); + + EXPECT_EQ(hic->topLevelExprs().size(), 2); + EXPECT_TRUE(hic->topLevelExprs().at(0)->isA()); + EXPECT_TRUE(hic->topLevelExprs().at(1)->isA()); + + HostIrEvaluator hie(std::move(hic)); + + auto options = at::TensorOptions().device(at::kCUDA, 0); + at::Tensor input = at::rand({4, 8}, options); + auto output = hie.runWithInput({{tv0, input}})[0].as(); + + torch::cuda::synchronize(); + EXPECT_TRUE(output.equal(input)) + << "Output: " << output << " Expected: " << input; +} + +TEST_F(HirLowerStreamTest, SingleSetOpNonOutermost) { + auto hic = std::make_unique(); + FusionGuard fg(hic.get()); + TensorView* tv0 = makeContigTensor(2); + TensorView* tv1 = set(tv0); + hic->addInput(tv0); + hic->addOutput(tv1); + hic->pushBackTopLevelExprs(tv1->definition()); + tv0->setMemoryType(MemoryType::Global); + tv1->setMemoryType(MemoryType::Global); + tv1->axis(1)->parallelize(ParallelType::Stream); + + preseg_passes::OptimizationPass::runPass( + hic.get()); + + EXPECT_EQ(hic->topLevelExprs().size(), 2); + EXPECT_TRUE(hic->topLevelExprs().at(0)->isA()); + EXPECT_TRUE(hic->topLevelExprs().at(1)->isA()); + + HostIrEvaluator hie(std::move(hic)); + + auto options = at::TensorOptions().device(at::kCUDA, 0); + at::Tensor input = at::rand({4, 8}, options); + auto output = hie.runWithInput({{tv0, input}})[0].as(); + + torch::cuda::synchronize(); + EXPECT_TRUE(output.equal(input)) + << "Output: " << output << " Expected: " << input; +} + +TEST_F(HirLowerStreamTest, SingleBinaryOp) { + auto hic = std::make_unique(); + FusionGuard fg(hic.get()); + TensorView* tv0 = makeContigTensor(2); + TensorView* tv1 = makeContigTensor(2); + TensorView* tv2 = add(tv0, tv1); + hic->addInput(tv0); + hic->addInput(tv1); + hic->addOutput(tv2); + hic->pushBackTopLevelExprs(tv2->definition()); + tv0->setMemoryType(MemoryType::Global); + tv1->setMemoryType(MemoryType::Global); + tv2->setMemoryType(MemoryType::Global); + tv2->axis(0)->parallelize(ParallelType::Stream); + + preseg_passes::OptimizationPass::runPass( + hic.get()); + + EXPECT_EQ(hic->topLevelExprs().size(), 2); + EXPECT_TRUE(hic->topLevelExprs().at(0)->isA()); + EXPECT_TRUE(hic->topLevelExprs().at(1)->isA()); + + HostIrEvaluator hie(std::move(hic)); + + auto options = at::TensorOptions().device(at::kCUDA, 0); + at::Tensor tv0_input = at::rand({4, 4}, options); + at::Tensor tv1_input = at::rand({4, 4}, options); + // std::unordered_map inputs = {{tv0, input}}; + auto output = hie.runWithInput({{tv0, tv0_input}, {tv1, tv1_input}})[0] + .as(); + auto expected_output = tv0_input + tv1_input; + EXPECT_TRUE(output.equal(expected_output)) + << "Output: " << output << "Expected: " << expected_output; +} + +TEST_F(HirLowerStreamTest, TwoSetOps) { + auto hic = std::make_unique(); + FusionGuard fg(hic.get()); + TensorView* tv0 = makeContigTensor(2); + TensorView* tv1 = set(tv0); + TensorView* tv2 = set(tv1); + hic->addInput(tv0); + hic->addOutput(tv2); + hic->pushBackTopLevelExprs(tv1->definition()); + hic->pushBackTopLevelExprs(tv2->definition()); + tv0->setMemoryType(MemoryType::Global); + tv1->setMemoryType(MemoryType::Global); + tv2->setMemoryType(MemoryType::Global); + tv1->axis(0)->parallelize(ParallelType::Stream); + tv2->axis(0)->parallelize(ParallelType::Stream); + + preseg_passes::OptimizationPass::runPass( + hic.get()); + + EXPECT_EQ(hic->topLevelExprs().size(), 3); + EXPECT_TRUE(hic->topLevelExprs().at(0)->isA()); + EXPECT_TRUE(hic->topLevelExprs().at(1)->isA()); + EXPECT_TRUE(hic->topLevelExprs().at(2)->isA()); + + HostIrEvaluator hie(std::move(hic)); + + auto options = at::TensorOptions().device(at::kCUDA, 0); + at::Tensor input = at::rand({4, 8}, options); + auto output = hie.runWithInput({{tv0, input}})[0].as(); + + torch::cuda::synchronize(); + EXPECT_TRUE(output.equal(input)) + << "Output: " << output << " Expected: " << input; +} + +TEST_F(HirLowerStreamTest, ThreeSetOpsWithDisjointsForLoops) { + auto hic = std::make_unique(); + FusionGuard fg(hic.get()); + TensorView* tv0 = makeContigTensor(2); + TensorView* tv1 = set(tv0); + TensorView* tv2 = set(tv1); + TensorView* tv3 = set(tv2); + hic->addInput(tv0); + hic->addOutput(tv3); + hic->pushBackTopLevelExprs(tv1->definition()); + hic->pushBackTopLevelExprs(tv2->definition()); + hic->pushBackTopLevelExprs(tv3->definition()); + tv0->setMemoryType(MemoryType::Global); + tv1->setMemoryType(MemoryType::Global); + tv2->setMemoryType(MemoryType::Global); + tv3->setMemoryType(MemoryType::Global); + tv1->axis(0)->parallelize(ParallelType::Stream); + tv3->axis(0)->parallelize(ParallelType::Stream); + + preseg_passes::OptimizationPass::runPass( + hic.get()); + + EXPECT_EQ(hic->topLevelExprs().size(), 5); + EXPECT_TRUE(hic->topLevelExprs().at(0)->isA()); + EXPECT_TRUE(hic->topLevelExprs().at(1)->isA()); + EXPECT_TRUE(hic->topLevelExprs().at(2)->isA()); + EXPECT_TRUE(hic->topLevelExprs().at(3)->isA()); + EXPECT_TRUE(hic->topLevelExprs().at(4)->isA()); + + HostIrEvaluator hie(std::move(hic)); + + auto options = at::TensorOptions().device(at::kCUDA, 0); + at::Tensor input = at::rand({4, 8}, options); + auto output = hie.runWithInput({{tv0, input}})[0].as(); + + torch::cuda::synchronize(); + EXPECT_TRUE(output.equal(input)) + << "Output: " << output << " Expected: " << input; +} + +TEST_F(HirLowerStreamTest, ReductionUnsupported) { + auto hic = std::make_unique(); + FusionGuard fg(hic.get()); + TensorView* tv0 = makeContigTensor(2); + TensorView* tv1 = sum(tv0, {0}); + hic->addInput(tv0); + hic->addOutput(tv1); + hic->pushBackTopLevelExprs(tv1->definition()); + tv0->setMemoryType(MemoryType::Global); + tv1->setMemoryType(MemoryType::Global); + tv1->axis(0)->parallelize(ParallelType::Stream); + + EXPECT_ANY_THROW(preseg_passes::OptimizationPass< + preseg_passes::StreamParallelType>::runPass(hic.get())); +} + +TEST_F(HirLowerStreamTest, Reduction) { + auto hic = std::make_unique(); + FusionGuard fg(hic.get()); + TensorView* tv0 = makeContigTensor(3); + TensorView* tv1 = sum(tv0, {2}); + hic->addInput(tv0); + hic->addOutput(tv1); + hic->pushBackTopLevelExprs(tv1->definition()); + tv0->setMemoryType(MemoryType::Global); + tv1->setMemoryType(MemoryType::Global); + tv1->axis(0)->parallelize(ParallelType::Stream); + + preseg_passes::OptimizationPass::runPass( + hic.get()); + + EXPECT_EQ(hic->topLevelExprs().size(), 2); + EXPECT_TRUE(hic->topLevelExprs().at(0)->isA()); + EXPECT_TRUE(hic->topLevelExprs().at(1)->isA()); + + HostIrEvaluator hie(std::move(hic)); + + auto options = at::TensorOptions().device(at::kCUDA, 0); + at::Tensor input = at::rand({4, 8, 2}, options); + auto output = hie.runWithInput({{tv0, input}})[0].as(); + + torch::cuda::synchronize(); + auto expected_output = input.sum(2); + EXPECT_TRUE(output.equal(expected_output)) + << "Output: " << output << " Expected: " << expected_output; +} + +TEST_F(HirLowerStreamTest, Matmul_M) { + auto hic = std::make_unique(); + FusionGuard fg(hic.get()); + TensorView* a = makeContigTensor(2); + TensorView* b = makeContigTensor(2); + TensorView* c = matmul(a, b); + hic->addInput(a); + hic->addInput(b); + hic->addOutput(c); + hic->pushBackTopLevelExprs(c->definition()); + a->setMemoryType(MemoryType::Global); + b->setMemoryType(MemoryType::Global); + c->setMemoryType(MemoryType::Global); + c->axis(0)->parallelize(ParallelType::Stream); + + preseg_passes::OptimizationPass::runPass( + hic.get()); + + EXPECT_EQ(hic->topLevelExprs().size(), 2); + EXPECT_TRUE(hic->topLevelExprs().at(0)->isA()); + EXPECT_TRUE(hic->topLevelExprs().at(1)->isA()); + + HostIrEvaluator hie(std::move(hic)); + + constexpr int64_t M = 8, K = 4, N = 2; + auto options = at::TensorOptions().device(at::kCUDA, 0); + at::Tensor a_aten = at::rand({M, K}, options); + at::Tensor b_aten = at::rand({K, N}, options); + auto output = + hie.runWithInput({{a, a_aten}, {b, b_aten}})[0].as(); + + torch::cuda::synchronize(); + auto expected_output = at::matmul(a_aten, b_aten); + EXPECT_TRUE(torch::allclose(output, expected_output, 1e-2, 1e-2)) + << "Output: " << output << " Expected: " << expected_output; +} + +TEST_F(HirLowerStreamTest, BatchedMatmul) { + auto hic = std::make_unique(); + FusionGuard fg(hic.get()); + TensorView* a = makeContigTensor(3); + TensorView* b = makeContigTensor(2); + TensorView* c = matmul(a, b); + hic->addInput(a); + hic->addInput(b); + hic->addOutput(c); + hic->pushBackTopLevelExprs(c->definition()); + a->setMemoryType(MemoryType::Global); + b->setMemoryType(MemoryType::Global); + c->setMemoryType(MemoryType::Global); + c->axis(0)->parallelize(ParallelType::Stream); + + preseg_passes::OptimizationPass::runPass( + hic.get()); + + EXPECT_EQ(hic->topLevelExprs().size(), 2); + EXPECT_TRUE(hic->topLevelExprs().at(0)->isA()); + EXPECT_TRUE(hic->topLevelExprs().at(1)->isA()); + + HostIrEvaluator hie(std::move(hic)); + + constexpr int64_t B = 16, M = 8, K = 4, N = 2; + auto options = at::TensorOptions().device(at::kCUDA, 0); + at::Tensor a_aten = at::rand({B, M, K}, options); + at::Tensor b_aten = at::rand({K, N}, options); + auto output = + hie.runWithInput({{a, a_aten}, {b, b_aten}})[0].as(); + + torch::cuda::synchronize(); + auto expected_output = at::matmul(a_aten, b_aten); + EXPECT_TRUE(torch::allclose(output, expected_output, 1e-2, 1e-2)) + << "Output: " << output << " Expected: " << expected_output; +} + +TEST_F(HirLowerStreamTest, Matmul_N) { + auto hic = std::make_unique(); + FusionGuard fg(hic.get()); + TensorView* a = makeContigTensor(2); + TensorView* b = makeContigTensor(2); + TensorView* c = matmul(a, b); + hic->addInput(a); + hic->addInput(b); + hic->addOutput(c); + hic->pushBackTopLevelExprs(c->definition()); + a->setMemoryType(MemoryType::Global); + b->setMemoryType(MemoryType::Global); + c->setMemoryType(MemoryType::Global); + c->axis(1)->parallelize(ParallelType::Stream); + + preseg_passes::OptimizationPass::runPass( + hic.get()); + + EXPECT_EQ(hic->topLevelExprs().size(), 2); + EXPECT_TRUE(hic->topLevelExprs().at(0)->isA()); + EXPECT_TRUE(hic->topLevelExprs().at(1)->isA()); + + HostIrEvaluator hie(std::move(hic)); + + constexpr int64_t M = 8, K = 4, N = 2; + auto options = at::TensorOptions().device(at::kCUDA, 0); + at::Tensor a_aten = at::rand({M, K}, options); + at::Tensor b_aten = at::rand({K, N}, options); + auto output = + hie.runWithInput({{a, a_aten}, {b, b_aten}})[0].as(); + + torch::cuda::synchronize(); + auto expected_output = at::matmul(a_aten, b_aten); + EXPECT_TRUE(torch::allclose(output, expected_output, 1e-2, 1e-2)) + << "Output: " << output << " Expected: " << expected_output; +} + +TEST_F(HirLowerStreamTest, Matmul_K) { + auto hic = std::make_unique(); + FusionGuard fg(hic.get()); + TensorView* a = makeContigTensor(2); + TensorView* b = makeContigTensor(2); + TensorView* c = matmul(a, b); + hic->addInput(a); + hic->addInput(b); + hic->addOutput(c); + hic->pushBackTopLevelExprs(c->definition()); + a->setMemoryType(MemoryType::Global); + b->setMemoryType(MemoryType::Global); + c->setMemoryType(MemoryType::Global); + c->axis(-1)->parallelize(ParallelType::Stream); + + EXPECT_ANY_THROW(preseg_passes::OptimizationPass< + preseg_passes::StreamParallelType>::runPass(hic.get())); +} + +// We don's support PostOnStream because it does not support well pre-allocated +// outputs. There is no strong motivation to support PostOnStream +TEST_F(HirLowerStreamTest, DoNotSupportPostOnStream) { + const std::vector input_sizes = {4, 8, 32}; + const std::vector output_sizes = { + input_sizes.at(1), input_sizes.at(2)}; + + auto get_fusion = [input_sizes]() -> std::unique_ptr { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + auto tv0 = makeConcreteTensor(input_sizes); + auto tv1 = add(tv0, tv0); + auto tv2 = sum(tv1, {0}); + fusion->addInput(tv0); + fusion->addOutput(tv2); + return fusion; + }; + + auto hic = std::make_unique(); + FusionGuard fg(hic.get()); + + auto host_unit = IrBuilder::create(get_fusion()); + + IrCloner ir_cloner(hic.get()); + TensorView* input = + ir_cloner.clone(host_unit->fusion_to_execute()->inputs().at(0)) + ->as(); + TensorView* output = + ir_cloner.clone(host_unit->fusion_to_execute()->outputs().at(0)) + ->as(); + + std::vector inputs = {input}; + std::vector outputs = {output}; + auto post_on_stream = + IrBuilder::create(host_unit, inputs, outputs); + + hic->pushBackTopLevelExprs(post_on_stream); + + hic->addInput(input); + hic->addOutput(output); + + output->axis(-1)->parallelize(ParallelType::Stream); + + EXPECT_ANY_THROW(preseg_passes::OptimizationPass< + preseg_passes::StreamParallelType>::runPass(hic.get())); +} + +} // namespace hir + +using MultiDeviceExecutorLowerStreamTest = NVFuserTest; + +TEST_F(MultiDeviceExecutorLowerStreamTest, InputsAreNotStreamParallelized) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + TensorView* tv = makeContigTensor(2); + fusion->addInput(tv); + tv->axis(0)->parallelize(ParallelType::Stream); + + EXPECT_ANY_THROW( + MultiDeviceExecutor(std::move(fusion), Communicator::getInstance())); +} + +TEST_F(MultiDeviceExecutorLowerStreamTest, Split) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + TensorView* tv0 = makeContigTensor(2); + TensorView* tv1 = set(tv0); + fusion->addInput(tv0); + fusion->addOutput(tv1); + tv1->split(0, 2); + tv1->axis(0)->parallelize(ParallelType::Stream); + + EXPECT_ANY_THROW( + MultiDeviceExecutor(std::move(fusion), Communicator::getInstance())); +} + +TEST_F(MultiDeviceExecutorLowerStreamTest, Merge) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + TensorView* tv0 = makeContigTensor(2); + TensorView* tv1 = set(tv0); + fusion->addInput(tv0); + fusion->addOutput(tv1); + tv1->merge(0, 1); + tv1->axis(0)->parallelize(ParallelType::Stream); + + EXPECT_ANY_THROW( + MultiDeviceExecutor(std::move(fusion), Communicator::getInstance())); +} + +TEST_F(MultiDeviceExecutorLowerStreamTest, SingleSetOp) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + TensorView* tv0 = makeContigTensor(2); + TensorView* tv1 = set(tv0); + fusion->addInput(tv0); + fusion->addOutput(tv1); + tv1->axis(0)->parallelize(ParallelType::Stream); + + MultiDeviceExecutor executor(std::move(fusion), Communicator::getInstance()); + + hir::HostIrContainer* container = executor.hostIrEvaluator()->container(); + EXPECT_EQ(container->topLevelExprs().size(), 2); + EXPECT_TRUE(container->topLevelExprs().at(0)->isA()); + EXPECT_TRUE(container->topLevelExprs().at(1)->isA()); + + auto options = at::TensorOptions().device(at::kCUDA, 0); + at::Tensor input = at::rand({4, 8}, options); + auto output = + executor.runWithInput(KernelArgumentHolder({input}))[0].as(); + + torch::cuda::synchronize(); + EXPECT_TRUE(output.equal(input)) + << "Output: " << output << " Expected: " << input; +} + +TEST_F(MultiDeviceExecutorLowerStreamTest, SingleSetOpNonOutermost) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + TensorView* tv0 = makeContigTensor(2); + TensorView* tv1 = set(tv0); + fusion->addInput(tv0); + fusion->addOutput(tv1); + tv1->axis(1)->parallelize(ParallelType::Stream); + + MultiDeviceExecutor executor(std::move(fusion), Communicator::getInstance()); + + hir::HostIrContainer* container = executor.hostIrEvaluator()->container(); + EXPECT_EQ(container->topLevelExprs().size(), 2); + EXPECT_TRUE(container->topLevelExprs().at(0)->isA()); + EXPECT_TRUE(container->topLevelExprs().at(1)->isA()); + + auto options = at::TensorOptions().device(at::kCUDA, 0); + at::Tensor input = at::rand({4, 8}, options); + auto output = + executor.runWithInput(KernelArgumentHolder({input}))[0].as(); + + torch::cuda::synchronize(); + EXPECT_TRUE(output.equal(input)) + << "Output: " << output << " Expected: " << input; +} + +TEST_F(MultiDeviceExecutorLowerStreamTest, SingleBinaryOp) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + TensorView* tv0 = makeContigTensor(2); + TensorView* tv1 = makeContigTensor(2); + TensorView* tv2 = add(tv0, tv1); + fusion->addInput(tv0); + fusion->addInput(tv1); + fusion->addOutput(tv2); + tv2->axis(0)->parallelize(ParallelType::Stream); + + MultiDeviceExecutor executor(std::move(fusion), Communicator::getInstance()); + + hir::HostIrContainer* container = executor.hostIrEvaluator()->container(); + EXPECT_EQ(container->topLevelExprs().size(), 2); + EXPECT_TRUE(container->topLevelExprs().at(0)->isA()); + EXPECT_TRUE(container->topLevelExprs().at(1)->isA()); + + auto options = at::TensorOptions().device(at::kCUDA, 0); + + at::Tensor tv0_input = at::rand({4, 4}, options); + at::Tensor tv1_input = at::rand({4, 4}, options); + auto output = + executor.runWithInput(KernelArgumentHolder({tv0_input, tv1_input}))[0] + .as(); + auto expected_output = tv0_input + tv1_input; + EXPECT_TRUE(output.equal(expected_output)) + << "Output: " << output << "Expected: " << expected_output; +} + +TEST_F(MultiDeviceExecutorLowerStreamTest, TwoSetOps) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + TensorView* tv0 = makeContigTensor(2); + TensorView* tv1 = set(tv0); + TensorView* tv2 = set(tv1); + fusion->addInput(tv0); + fusion->addOutput(tv2); + tv1->axis(0)->parallelize(ParallelType::Stream); + tv2->axis(0)->parallelize(ParallelType::Stream); + + MultiDeviceExecutor executor(std::move(fusion), Communicator::getInstance()); + + hir::HostIrContainer* container = executor.hostIrEvaluator()->container(); + EXPECT_EQ(container->topLevelExprs().size(), 3); + EXPECT_TRUE(container->topLevelExprs().at(0)->isA()); + EXPECT_TRUE(container->topLevelExprs().at(1)->isA()); + EXPECT_TRUE(container->topLevelExprs().at(2)->isA()); + + auto options = at::TensorOptions().device(at::kCUDA, 0); + at::Tensor input = at::rand({4, 8}, options); + auto output = + executor.runWithInput(KernelArgumentHolder({input}))[0].as(); + + torch::cuda::synchronize(); + EXPECT_TRUE(output.equal(input)) + << "Output: " << output << " Expected: " << input; +} + +TEST_F(MultiDeviceExecutorLowerStreamTest, ThreeSetOpsWithDisjointsForLoops) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + TensorView* tv0 = makeContigTensor(2); + TensorView* tv1 = set(tv0); + TensorView* tv2 = set(tv1); + TensorView* tv3 = set(tv2); + fusion->addInput(tv0); + fusion->addOutput(tv3); + tv1->axis(0)->parallelize(ParallelType::Stream); + tv3->axis(0)->parallelize(ParallelType::Stream); + + MultiDeviceExecutor executor(std::move(fusion), Communicator::getInstance()); + + hir::HostIrContainer* container = executor.hostIrEvaluator()->container(); + EXPECT_EQ(container->topLevelExprs().size(), 5); + EXPECT_TRUE(container->topLevelExprs().at(0)->isA()); + EXPECT_TRUE(container->topLevelExprs().at(1)->isA()); + EXPECT_TRUE(container->topLevelExprs().at(2)->isA()); + EXPECT_TRUE(container->topLevelExprs().at(3)->isA()); + EXPECT_TRUE(container->topLevelExprs().at(4)->isA()); + + auto options = at::TensorOptions().device(at::kCUDA, 0); + at::Tensor input = at::rand({4, 8}, options); + auto output = + executor.runWithInput(KernelArgumentHolder({input}))[0].as(); + + torch::cuda::synchronize(); + EXPECT_TRUE(output.equal(input)) + << "Output: " << output << " Expected: " << input; +} + +TEST_F(MultiDeviceExecutorLowerStreamTest, ReductionUnsupported) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + TensorView* tv0 = makeContigTensor(2); + TensorView* tv1 = sum(tv0, {0}); + fusion->addInput(tv0); + fusion->addOutput(tv1); + tv1->axis(0)->parallelize(ParallelType::Stream); + + EXPECT_ANY_THROW( + MultiDeviceExecutor(std::move(fusion), Communicator::getInstance())); +} + +TEST_F(MultiDeviceExecutorLowerStreamTest, Reduction) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + TensorView* tv0 = makeContigTensor(3); + TensorView* tv1 = sum(tv0, {2}); + fusion->addInput(tv0); + fusion->addOutput(tv1); + tv1->axis(0)->parallelize(ParallelType::Stream); + + MultiDeviceExecutor executor(std::move(fusion), Communicator::getInstance()); + + hir::HostIrContainer* container = executor.hostIrEvaluator()->container(); + EXPECT_EQ(container->topLevelExprs().size(), 2); + EXPECT_TRUE(container->topLevelExprs().at(0)->isA()); + EXPECT_TRUE(container->topLevelExprs().at(1)->isA()); + + auto options = at::TensorOptions().device(at::kCUDA, 0); + at::Tensor input = at::rand({4, 8, 2}, options); + auto output = + executor.runWithInput(KernelArgumentHolder({input}))[0].as(); + + torch::cuda::synchronize(); + auto expected_output = input.sum(2); + EXPECT_TRUE(output.equal(expected_output)) + << "Output: " << output << " Expected: " << expected_output; +} + +TEST_F(MultiDeviceExecutorLowerStreamTest, Matmul_M) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + TensorView* a = makeContigTensor(2); + TensorView* b = makeContigTensor(2); + TensorView* c = matmul(a, b); + fusion->addInput(a); + fusion->addInput(b); + fusion->addOutput(c); + c->axis(0)->parallelize(ParallelType::Stream); + + MultiDeviceExecutor executor(std::move(fusion), Communicator::getInstance()); + + hir::HostIrContainer* container = executor.hostIrEvaluator()->container(); + EXPECT_EQ(container->topLevelExprs().size(), 2); + EXPECT_TRUE(container->topLevelExprs().at(0)->isA()); + EXPECT_TRUE(container->topLevelExprs().at(1)->isA()); + + constexpr int64_t M = 8, K = 4, N = 2; + auto options = at::TensorOptions().device(at::kCUDA, 0); + at::Tensor a_aten = at::rand({M, K}, options); + at::Tensor b_aten = at::rand({K, N}, options); + auto output = executor.runWithInput(KernelArgumentHolder({a_aten, b_aten}))[0] + .as(); + + torch::cuda::synchronize(); + auto expected_output = at::matmul(a_aten, b_aten); + EXPECT_TRUE(torch::allclose(output, expected_output, 1e-2, 1e-2)) + << "Output: " << output << " Expected: " << expected_output; +} + +TEST_F(MultiDeviceExecutorLowerStreamTest, BatchedMatmul) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + TensorView* a = makeContigTensor(3); + TensorView* b = makeContigTensor(2); + TensorView* c = matmul(a, b); + fusion->addInput(a); + fusion->addInput(b); + fusion->addOutput(c); + c->axis(0)->parallelize(ParallelType::Stream); + + MultiDeviceExecutor executor(std::move(fusion), Communicator::getInstance()); + + hir::HostIrContainer* container = executor.hostIrEvaluator()->container(); + EXPECT_EQ(container->topLevelExprs().size(), 2); + EXPECT_TRUE(container->topLevelExprs().at(0)->isA()); + EXPECT_TRUE(container->topLevelExprs().at(1)->isA()); + + constexpr int64_t B = 16, M = 8, K = 4, N = 2; + auto options = at::TensorOptions().device(at::kCUDA, 0); + at::Tensor a_aten = at::rand({B, M, K}, options); + at::Tensor b_aten = at::rand({K, N}, options); + auto output = executor.runWithInput(KernelArgumentHolder({a_aten, b_aten}))[0] + .as(); + + torch::cuda::synchronize(); + auto expected_output = at::matmul(a_aten, b_aten); + EXPECT_TRUE(torch::allclose(output, expected_output, 1e-2, 1e-2)) + << "Output: " << output << " Expected: " << expected_output; +} + +TEST_F(MultiDeviceExecutorLowerStreamTest, Matmul_N) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + TensorView* a = makeContigTensor(2); + TensorView* b = makeContigTensor(2); + TensorView* c = matmul(a, b); + fusion->addInput(a); + fusion->addInput(b); + fusion->addOutput(c); + c->axis(1)->parallelize(ParallelType::Stream); + + MultiDeviceExecutor executor(std::move(fusion), Communicator::getInstance()); + + hir::HostIrContainer* container = executor.hostIrEvaluator()->container(); + EXPECT_EQ(container->topLevelExprs().size(), 2); + EXPECT_TRUE(container->topLevelExprs().at(0)->isA()); + EXPECT_TRUE(container->topLevelExprs().at(1)->isA()); + + constexpr int64_t M = 8, K = 4, N = 2; + auto options = at::TensorOptions().device(at::kCUDA, 0); + at::Tensor a_aten = at::rand({M, K}, options); + at::Tensor b_aten = at::rand({K, N}, options); + auto output = executor.runWithInput(KernelArgumentHolder({a_aten, b_aten}))[0] + .as(); + + torch::cuda::synchronize(); + auto expected_output = at::matmul(a_aten, b_aten); + EXPECT_TRUE(torch::allclose(output, expected_output, 1e-2, 1e-2)) + << "Output: " << output << " Expected: " << expected_output; +} + +TEST_F(MultiDeviceExecutorLowerStreamTest, Matmul_K) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + TensorView* a = makeContigTensor(2); + TensorView* b = makeContigTensor(2); + TensorView* c = matmul(a, b); + fusion->addInput(a); + fusion->addInput(b); + fusion->addOutput(c); + c->axis(-1)->parallelize(ParallelType::Stream); + + EXPECT_ANY_THROW( + MultiDeviceExecutor(std::move(fusion), Communicator::getInstance())); +} + +// We only support Stream parallel type on ops that support pre-allocated +// output, which means they need a special handle in HostIrEvaluator and they +// need to be lowered as a Host Ir Op in the TopLevelExpression, no a +// PostOnStream(HostUnit(.)) See HostIrLower::isLoweredAsStandaloneHostOp and +// the test HirLowerStreamTest.DoNotSupportPostOnStream +TEST_F(MultiDeviceExecutorLowerStreamTest, DoNotSupportPostOnStream) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + TensorView* tv0 = makeContigTensor(2); + TensorView* tv1 = + abs(tv0); // arbitrary example of an unsupported op. There is no deep + // reason why we not support it -- if needed we could widen the + // support. But I wanna make sure that an unsupported op do not + // silently fails + fusion->addInput(tv0); + fusion->addOutput(tv1); + tv1->axis(0)->parallelize(ParallelType::Stream); + + EXPECT_ANY_THROW( + MultiDeviceExecutor(std::move(fusion), Communicator::getInstance())); +} + +} // namespace nvfuser diff --git a/tests/cpp/test_multidevice_host_ir.cpp b/tests/cpp/test_multidevice_host_ir.cpp index 88286d6e4c0..e9db27cfd7f 100644 --- a/tests/cpp/test_multidevice_host_ir.cpp +++ b/tests/cpp/test_multidevice_host_ir.cpp @@ -11,6 +11,7 @@ #include #include #include +#include #include namespace nvfuser { @@ -362,6 +363,11 @@ TEST_F(P2PCommHostIrTest, CoalescedRingPairwiseExchange) { using OverlapDistributedMatmulTest = MultiDeviceTest; TEST_F(OverlapDistributedMatmulTest, AG_matmul) { + // Disable StreamParallelType pass temporarily as proper stream lowering gets + // implemented + preseg_passes::OptimizationPassGuard guard( + false); + constexpr int64_t M = 32768; constexpr int64_t K = 32768; constexpr int64_t N = 1024; @@ -417,6 +423,10 @@ TEST_F(OverlapDistributedMatmulTest, AG_matmul) { } TEST_F(OverlapDistributedMatmulTest, AG_linear) { + // Disable StreamParallelType pass tempor + preseg_passes::OptimizationPassGuard guard( + false); + constexpr int64_t M = 32768; constexpr int64_t K = 32768; constexpr int64_t N = 1024; From e8869419dd152f5d6f71f505830e35b97cb9f274 Mon Sep 17 00:00:00 2001 From: snordmann Date: Wed, 26 Mar 2025 06:29:12 -0700 Subject: [PATCH 02/68] improve comments --- csrc/preseg_passes/stream_parallel_type.cpp | 159 ++++++++++++-------- csrc/preseg_passes/stream_parallel_type.h | 11 +- 2 files changed, 109 insertions(+), 61 deletions(-) diff --git a/csrc/preseg_passes/stream_parallel_type.cpp b/csrc/preseg_passes/stream_parallel_type.cpp index 5a814c9a59a..12fe2f6a285 100644 --- a/csrc/preseg_passes/stream_parallel_type.cpp +++ b/csrc/preseg_passes/stream_parallel_type.cpp @@ -20,8 +20,9 @@ namespace nvfuser::preseg_passes { -// returns the first stream axis in the domain, or nullptr if there is none. -// Throws if two axis are stream parallelized +// Helper function to find the first stream-parallelized axis in a domain. +// This function throws if multiple stream-parallelized axes are found (only one +// is allowed) IterDomain* getStreamAxis(const std::vector& domain) { IterDomain* ret = nullptr; for (auto id : domain) { @@ -38,13 +39,27 @@ IterDomain* getStreamAxis(const std::vector& domain) { return ret; } - - -// TODO: ideally we should look at the dag and use the segmenter. Here we take -// advantage of the linear structure of HostIrContainer::topLevelExprs to -// greedily merge the adjacent compatible stream for-loop bodies +// StreamParallelType pass implementation. +// This pass handles stream parallelization of operations in a fusion. +// It works by: +// 1. Identifying stream-parallelized axes in tensor operations +// 2. Grouping compatible operations into stream-parallel for-loops +// 3. Setting up proper stream synchronization and management +// +// The pass ensures that: +// - Input tensors don't have stream axes +// - Only one stream axis exists per tensor +// - Stream axes are properly synchronized +// - Operations are correctly grouped into stream-parallel regions +// - The resulting HostIrContainer's top level expression is valid for execution +// and does not contain any stream axes +// +// TODO: Here, we assume that the fusion input is a HostIrContainer and use the +// linear structure of the HostIrContainer::topLevelExpr to greedily merge the +// adjacent compatible stream for-loop bodies. Ideally we should look at the dag +// and use the segmenter. void StreamParallelType::runPass(Fusion* fusion) { - // check that there are no stream axes in the inputs + // Verify that input tensors don't have stream axes NVF_CHECK( std::all_of( fusion->inputs().begin(), @@ -56,62 +71,71 @@ void StreamParallelType::runPass(Fusion* fusion) { }), "Expected no stream axis in the TensorView inputs."); - FusionGuard fg(fusion); // set as current container to register the newly - // created for-loops + // Set up the fusion environment and build the ID model + FusionGuard fg(fusion); hir::HostIrContainer* hic = dynamic_cast(fusion); NVF_CHECK(hic, "Expected HostIrContainer"); - // needed ? + IdModel id_model(fusion); id_model.buildAlmostExactGraph(); std::vector new_top_level_exprs; - // Step 1. Find the segments of expressions that can be merged into a single - // stream for-loop At the end of this step, new_top_level_exprs contains a - // list of expressions including newly created for-loops that will represent - // the stream parallelization, and the relevant expressions grouped inside the - // for-loops bodies. + + // Step 1: Group expressions into stream-parallel regions + // This step identifies which expressions can be merged into single stream + // for-loops + // + // After this step, new_top_level_exprs contains a + // list of expressions including newly created for-loops representing + // the stream parallelization containing and the relevant expressions for (auto expr : hic->topLevelExprs()) { - // we only support exprs having at most 1 output for now + // Skip expressions with no outputs if (expr->outputs().size() == 0) { new_top_level_exprs.push_back(expr); continue; } + + // Verify single output constraint NVF_CHECK( expr->outputs().size() == 1, "Each expr should have at most one output."); + + // Get the output tensor and check for stream parallelization TensorView* output = expr->output(0)->as(); - // retrieves the Loop IterDomain that is stream parallelized, if any IterDomain* stream_axis = getStreamAxis(output->getLoopDomain()); + + // If no stream axis, keep expression as is if (stream_axis == nullptr) { - // if the consumer is not stream parallelized, it means the expr need not - // be inside a stream for-loop new_top_level_exprs.push_back(expr); continue; } + + // Verify expression can be handled as a standalone host operation NVF_ERROR( HostIrLower::isLoweredAsStandaloneHostOp(expr), "Stream parallel type not supported for expr ", expr); - // find the corresponding stream axis but in the Logical (and not Loop - // Domain) + + // Find the stream axis in the logical (and not loop) domain auto it_logical_stream_axis = std::find( output->getLogicalDomain().begin(), output->getLogicalDomain().end(), stream_axis); - // for now we do not support split/merge stream axis + + // Verify stream axis is not split/merged NVF_ERROR( it_logical_stream_axis != output->getLogicalDomain().end(), "Cannot stream parallelize on a split/merge axis ", stream_axis); - // we don't support reducing or broadcasting a stream axis + + // Verify stream axis is an iteration axis (not reduction/broadcast) NVF_CHECK( stream_axis->getIterType() == IterType::Iteration, "Stream axis ", stream_axis, " should be an iteration axis."); - // check if the current expr can be merged with the previous stream for-loop - // We consider the previous expression to check whether the expr should - // create a new stream for-loop or be integrated into the previous one + + // Check if expression can be merged with previous stream for-loop if (!new_top_level_exprs.empty() && new_top_level_exprs.back()->isA() && id_model.idGraph(IdMappingMode::ALMOSTEXACT) @@ -119,21 +143,16 @@ void StreamParallelType::runPass(Fusion* fusion) { .strictAreMapped( stream_axis, new_top_level_exprs.back()->as()->iterDomain())) { - // merge with previous for-loop + // Merge with existing for-loop new_top_level_exprs.back()->as()->body().push_back(expr); } else { - // create a new for-loop - auto* j = IrBuilder::create( - DataType::Index); // running index of the for-loop - auto* start = hic->zeroVal(); - auto* stop = stream_axis->extent(); - auto* step = hic->oneVal(); + // Create new for-loop for stream parallelization auto* for_loop = IrBuilder::create( stream_axis, - /*index=*/j, - start, - stop, - step, + /*index=*/IrBuilder::create(DataType::Index), + /*start=*/hic->zeroVal(), + /*stop=*/stream_axis->extent(), + /*step=*/hic->oneVal(), /*vectorize=*/false, /*vectorize_shift=*/nullptr, /*unroll_required=*/false, @@ -145,30 +164,36 @@ void StreamParallelType::runPass(Fusion* fusion) { } } - // Step 2. Setup each for loop's body by Slicing the tensors. + // Step 2: Process each for-loop's body by slicing tensors + // This step handles the actual tensor slicing for stream parallelization std::vector top_level_exprs = std::move(new_top_level_exprs); new_top_level_exprs.clear(); + for (auto top_level_expr : top_level_exprs) { - // TODO: change in place? consr issue if (!top_level_expr->isA()) { new_top_level_exprs.push_back(top_level_expr); continue; } + auto* for_loop = top_level_expr->as(); - // this will contain the new body of the current for-loop std::vector new_loop_body; + // Process each expression in the loop body std::vector current_loop_body = for_loop->body().exprs(); for (auto it_expr = current_loop_body.begin(); it_expr != current_loop_body.end(); ++it_expr) { Expr* expr = *it_expr; + + // Process input tensors for (auto* input : ir_utils::filterByType(expr->inputs())) { + // Find stream axis index in input tensor int64_t input_stream_id_logical_index = -1; for (auto id : input->getLoopDomain()) { if (id_model.idGraph(IdMappingMode::ALMOSTEXACT) .disjointValSets() .strictAreMapped(for_loop->iterDomain(), id)) { + // Verify only one stream axis exists NVF_CHECK( input_stream_id_logical_index == -1, "Expected at most one axis mapping to the stream axis ", @@ -177,6 +202,8 @@ void StreamParallelType::runPass(Fusion* fusion) { input, " loop's domain ", input->getLoopDomain()); + + // Find stream axis in logical domain auto it_input_stream_id_logical = std::find( input->getLogicalDomain().begin(), input->getLogicalDomain().end(), @@ -193,15 +220,21 @@ void StreamParallelType::runPass(Fusion* fusion) { input->getLogicalDomain().begin(), it_input_stream_id_logical); } } + + // Skip if no stream axis found if (input_stream_id_logical_index == -1) { continue; } + + // Create sliced tensor for current stream iteration TensorView* input_j = select( input, input_stream_id_logical_index, for_loop->index(), /*keep_reduction_axis=*/true); new_loop_body.push_back(input_j->definition()); + + // Update all expressions using this input for (auto it_running_expr = current_loop_body.begin(); it_running_expr != current_loop_body.end(); ++it_running_expr) { @@ -216,12 +249,15 @@ void StreamParallelType::runPass(Fusion* fusion) { } } + // Process output tensors for (auto* output : ir_utils::filterByType(expr->outputs())) { + // Find stream axis index in output tensor int64_t output_stream_id_logical_index = -1; for (auto id : output->getLoopDomain()) { if (id_model.idGraph(IdMappingMode::ALMOSTEXACT) .disjointValSets() .strictAreMapped(for_loop->iterDomain(), id)) { + // Verify only one stream axis exists NVF_CHECK( output_stream_id_logical_index == -1, "Expected at most one axis mapping to the stream axis ", @@ -230,6 +266,8 @@ void StreamParallelType::runPass(Fusion* fusion) { output, " loop's domain ", output->getLoopDomain()); + + // Find stream axis in logical domain auto it_output_stream_id_logical = std::find( output->getLogicalDomain().begin(), output->getLogicalDomain().end(), @@ -247,17 +285,25 @@ void StreamParallelType::runPass(Fusion* fusion) { it_output_stream_id_logical); } } + + // Skip if no stream axis found if (output_stream_id_logical_index == -1) { continue; } + + // Create sliced tensor for current stream iteration TensorView* output_j = select( output, output_stream_id_logical_index, for_loop->index(), /*keep_reduction_axis=*/true); + + // Allocate memory for the output tensor new_top_level_exprs.push_back( IrBuilder::create(output, MemoryType::Global)); new_loop_body.push_back(output_j->definition()); + + // Update all expressions using this output for (auto it_running_expr = current_loop_body.begin(); it_running_expr != current_loop_body.end(); ++it_running_expr) { @@ -265,6 +311,7 @@ void StreamParallelType::runPass(Fusion* fusion) { for (auto* running_output : ir_utils::filterByType(running_expr->outputs())) { if (running_output == output) { + // Create alias for the sliced output TensorView* output_j_alias = ops::newValLike( output_j, output_j->dtype(), /*keep_reduction_axis=*/true) @@ -272,19 +319,14 @@ void StreamParallelType::runPass(Fusion* fusion) { hic->markAlias(output_j, output_j_alias); *it_running_expr = ir_utils::transferDefinitionToNewOutputs( running_expr, {output_j_alias}); - if (Communication* comm = dynamic_cast( - output_j_alias->definition()); - comm && comm->type() == CommunicationType::Allgather) { - std::cout << "HERE, with expr:" << *it_running_expr - << std::endl; - } } } } } new_loop_body.push_back(*it_expr); } - // reseting the for-loop body + + // Update for-loop body with processed expressions for_loop->body().clear(); for (auto* expr : new_loop_body) { for_loop->body().push_back(expr); @@ -292,8 +334,7 @@ void StreamParallelType::runPass(Fusion* fusion) { new_top_level_exprs.push_back(top_level_expr); } - // Step 3. Finalize the for-loop bodies by adding the stream setup and - // synchronization + // Step 3: Add stream management and synchronization for (auto* top_level_expr : new_top_level_exprs) { if (!top_level_expr->isA()) { continue; @@ -301,46 +342,44 @@ void StreamParallelType::runPass(Fusion* fusion) { auto* for_loop = top_level_expr->as(); std::vector new_loop_body; - // Get the current stream to later synchronize subsequent new streams + // Get current stream for later synchronization auto* get_current_stream = IrBuilder::create(); hir::Stream* original_stream = get_current_stream->stream(); new_loop_body.push_back(get_current_stream); - // set the stream to the one corresponding to the current for-loop index - auto* j = for_loop->index(); + // Set up stream for current iteration auto* number_of_streams = IrBuilder::create("numberOfStreams", DataType::Int); - auto* stream_index = mod(j, number_of_streams); + auto* stream_index = mod(for_loop->index(), number_of_streams); auto* stream = IrBuilder::create(stream_index); auto* set_stream = IrBuilder::create(stream); new_loop_body.push_back(set_stream); - // sync the new stream with the original stream + // Synchronize with original stream auto* initial_sync_stream = IrBuilder::create(original_stream); new_loop_body.push_back(initial_sync_stream); - // add the actual exprs to the for-loop body + // Add the actual computation expressions for (auto* expr : for_loop->body().exprs()) { new_loop_body.push_back(expr); } - // set back the original stream + // Restore original stream and synchronize auto* set_back_original_stream = IrBuilder::create(original_stream); new_loop_body.push_back(set_back_original_stream); - // synchronize original stream with the for-loop's streams auto* sync_stream = IrBuilder::create(stream); new_loop_body.push_back(sync_stream); - // reset the for-loop's body to the one we constructed. + // Update for-loop body with stream management for_loop->body().clear(); for (auto* expr : new_loop_body) { for_loop->body().push_back(expr); } } - // reset hic topLevelExprs to new_top_level_exprs + // Update the container's top-level expressions hic->resetTopLevelExprs(new_top_level_exprs); } diff --git a/csrc/preseg_passes/stream_parallel_type.h b/csrc/preseg_passes/stream_parallel_type.h index a9600809e21..9c0c39efe87 100644 --- a/csrc/preseg_passes/stream_parallel_type.h +++ b/csrc/preseg_passes/stream_parallel_type.h @@ -12,7 +12,16 @@ namespace nvfuser::preseg_passes { -// A pass used in HostIrLower that takes a HostIrContainer as input, reads the TensorView's ParallelType::Stream, and modify the the HostIrContainer's top level expressions with the corresponding Host For Loops, which bodies contain stream assignement, selecting on tensor's axis, and the exprs on those sliced tensors. After this pass, the ParallelType::Stream is removed from the TensorView's axis. +// A pass used in HostIrLower that takes a HostIrContainer as input, reads the +// TensorView's ParallelType::Stream, and modify the the HostIrContainer's top +// level expressions with the corresponding Host For Loops, which bodies contain +// stream assignement, selecting on tensor's axis, and the exprs on those sliced +// tensors. After this pass, the ParallelType::Stream is removed from the +// TensorView's axis. +// +// An illustration of the pass can be found in the tests +// `test_host_ir_stream_lowering.cpp` +// with the option `NVFUSER_DUMP=host_ir`. class StreamParallelType : public OptimizationPass { friend class OptimizationPass; From b6c54f2e47246535f6b2ea915bf236dabe54f087 Mon Sep 17 00:00:00 2001 From: snordmann Date: Wed, 16 Apr 2025 03:48:24 -0700 Subject: [PATCH 03/68] fix rebase --- csrc/preseg_passes/stream_parallel_type.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/preseg_passes/stream_parallel_type.cpp b/csrc/preseg_passes/stream_parallel_type.cpp index 12fe2f6a285..82f1b3d0e67 100644 --- a/csrc/preseg_passes/stream_parallel_type.cpp +++ b/csrc/preseg_passes/stream_parallel_type.cpp @@ -112,7 +112,7 @@ void StreamParallelType::runPass(Fusion* fusion) { // Verify expression can be handled as a standalone host operation NVF_ERROR( - HostIrLower::isLoweredAsStandaloneHostOp(expr), + HostIrLower::isLowerableAsStandaloneHostOp(expr), "Stream parallel type not supported for expr ", expr); From 32a8d552befc8deb4f4fff1deb3d88558bb88074 Mon Sep 17 00:00:00 2001 From: snordmann Date: Wed, 16 Apr 2025 08:23:11 -0700 Subject: [PATCH 04/68] temporarily disable stream pass also in the python test --- csrc/python_frontend/fusion_definition.cpp | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/csrc/python_frontend/fusion_definition.cpp b/csrc/python_frontend/fusion_definition.cpp index c48abc9dbdc..ad7b8baf2d6 100644 --- a/csrc/python_frontend/fusion_definition.cpp +++ b/csrc/python_frontend/fusion_definition.cpp @@ -11,6 +11,7 @@ #include #include #include +#include #include #include #include @@ -452,6 +453,10 @@ std::pair> FusionDefinition:: if (scheds->multi_device_executor == nullptr) { MultiDeviceExecutorParams params; params.lower.communicator_backend = backend_type_; + // Disable StreamParallelType pass temporarily as proper stream lowering gets + // implemented + preseg_passes::OptimizationPassGuard guard( + false); scheds->multi_device_executor = std::make_unique( std::make_unique(*scheds->preschedFusion()), Communicator::getInstance(), From afbd020b9e2e6b9461bb335532d44bb884ebc533 Mon Sep 17 00:00:00 2001 From: snordmann Date: Wed, 16 Apr 2025 08:38:34 -0700 Subject: [PATCH 05/68] lint --- csrc/python_frontend/fusion_definition.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/csrc/python_frontend/fusion_definition.cpp b/csrc/python_frontend/fusion_definition.cpp index ad7b8baf2d6..950b2bd148a 100644 --- a/csrc/python_frontend/fusion_definition.cpp +++ b/csrc/python_frontend/fusion_definition.cpp @@ -453,10 +453,10 @@ std::pair> FusionDefinition:: if (scheds->multi_device_executor == nullptr) { MultiDeviceExecutorParams params; params.lower.communicator_backend = backend_type_; - // Disable StreamParallelType pass temporarily as proper stream lowering gets - // implemented - preseg_passes::OptimizationPassGuard guard( - false); + // Disable StreamParallelType pass temporarily as proper stream lowering + // gets implemented + preseg_passes::OptimizationPassGuard + guard(false); scheds->multi_device_executor = std::make_unique( std::make_unique(*scheds->preschedFusion()), Communicator::getInstance(), From 165bd1bab236119a85e9e4dd5843424887960470 Mon Sep 17 00:00:00 2001 From: snordmann Date: Wed, 16 Apr 2025 08:57:55 -0700 Subject: [PATCH 06/68] move stream_parallel_type to host_ir/pass folder --- CMakeLists.txt | 2 +- csrc/host_ir/lower.cpp | 2 +- csrc/{preseg_passes => host_ir/pass}/stream_parallel_type.cpp | 2 +- csrc/{preseg_passes => host_ir/pass}/stream_parallel_type.h | 0 csrc/python_frontend/fusion_definition.cpp | 2 +- tests/cpp/test_host_ir_stream_lowering.cpp | 2 +- tests/cpp/test_multidevice_host_ir.cpp | 2 +- 7 files changed, 6 insertions(+), 6 deletions(-) rename csrc/{preseg_passes => host_ir/pass}/stream_parallel_type.cpp (99%) rename csrc/{preseg_passes => host_ir/pass}/stream_parallel_type.h (100%) diff --git a/CMakeLists.txt b/CMakeLists.txt index 96a447055ea..dcf94d4a3a7 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -212,7 +212,7 @@ list(APPEND NVFUSER_SRCS ${NVFUSER_SRCS_DIR}/preseg_passes/remove_empty.cpp ${NVFUSER_SRCS_DIR}/preseg_passes/reorder_sharded_axis.cpp ${NVFUSER_SRCS_DIR}/preseg_passes/segment_inplace_update.cpp - ${NVFUSER_SRCS_DIR}/preseg_passes/stream_parallel_type.cpp + ${NVFUSER_SRCS_DIR}/host_ir/pass/stream_parallel_type.cpp ${NVFUSER_SRCS_DIR}/preseg_passes/translate_no_reduction_matmul_to_mul_squeeze.cpp ${NVFUSER_SRCS_DIR}/preseg_passes/translate_repeat_to_expand.cpp ${NVFUSER_SRCS_DIR}/rng.cpp diff --git a/csrc/host_ir/lower.cpp b/csrc/host_ir/lower.cpp index 1a74d9a9f01..c36fae09e0a 100644 --- a/csrc/host_ir/lower.cpp +++ b/csrc/host_ir/lower.cpp @@ -7,6 +7,7 @@ // clang-format on #include #include +#include #include #include #include @@ -19,7 +20,6 @@ #include #include #include -#include #include #include diff --git a/csrc/preseg_passes/stream_parallel_type.cpp b/csrc/host_ir/pass/stream_parallel_type.cpp similarity index 99% rename from csrc/preseg_passes/stream_parallel_type.cpp rename to csrc/host_ir/pass/stream_parallel_type.cpp index 82f1b3d0e67..f1419e3c626 100644 --- a/csrc/preseg_passes/stream_parallel_type.cpp +++ b/csrc/host_ir/pass/stream_parallel_type.cpp @@ -8,6 +8,7 @@ #include #include +#include #include #include #include @@ -16,7 +17,6 @@ #include #include #include -#include namespace nvfuser::preseg_passes { diff --git a/csrc/preseg_passes/stream_parallel_type.h b/csrc/host_ir/pass/stream_parallel_type.h similarity index 100% rename from csrc/preseg_passes/stream_parallel_type.h rename to csrc/host_ir/pass/stream_parallel_type.h diff --git a/csrc/python_frontend/fusion_definition.cpp b/csrc/python_frontend/fusion_definition.cpp index 950b2bd148a..d6e552032b1 100644 --- a/csrc/python_frontend/fusion_definition.cpp +++ b/csrc/python_frontend/fusion_definition.cpp @@ -7,11 +7,11 @@ // clang-format on #include #include +#include #include #include #include #include -#include #include #include #include diff --git a/tests/cpp/test_host_ir_stream_lowering.cpp b/tests/cpp/test_host_ir_stream_lowering.cpp index 9f3d9f432ed..f6d74caea87 100644 --- a/tests/cpp/test_host_ir_stream_lowering.cpp +++ b/tests/cpp/test_host_ir_stream_lowering.cpp @@ -11,12 +11,12 @@ #include #include #include +#include #include #include #include #include #include -#include #include #include diff --git a/tests/cpp/test_multidevice_host_ir.cpp b/tests/cpp/test_multidevice_host_ir.cpp index e9db27cfd7f..7b233bc47db 100644 --- a/tests/cpp/test_multidevice_host_ir.cpp +++ b/tests/cpp/test_multidevice_host_ir.cpp @@ -9,9 +9,9 @@ #include #include #include +#include #include #include -#include #include namespace nvfuser { From 59ba13cfc646ec5474afbc5c7d06dd63a4e4b9eb Mon Sep 17 00:00:00 2001 From: Jacob Hinkle <1454944+jacobhinkle@users.noreply.github.com> Date: Wed, 16 Apr 2025 12:59:34 -0400 Subject: [PATCH 07/68] Print all ID expressions in tv->printTransforms (#4258) Currently, when we call `tv->printTransforms()`, we print root->logical then logical->loop expressions. In case the allocation domain is not on the path from logical to loop, those allocation expressions were not printed at all. This PR uses `tv->domain()->allExprs()` instead to print all the ID expressions. For an example, see this test: ```c++ TEST_F(AllocationDomainTest, PrintTransforms) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); // iS0{i0} auto* root0 = IterDomainBuilder( fusion->zeroVal(), IrBuilder::create(DataType::Index)).build(); // iS1{ceilDiv(i0, 5)}, iS2{5} = split(iS0) IterDomain *logical0, *logical1; std::tie(logical0, logical1) = IterDomain::split( root0, IrBuilder::create(5L, DataType::Index), /*inner_split=*/true); // iS3{ceilDiv(i0, 7)}, iS4{7} = split(iS0) IterDomain *alloc0, *alloc1; std::tie(alloc0, alloc1) = IterDomain::split( root0, IrBuilder::create(7L, DataType::Index), /*inner_split=*/true); // iS5{ ceilDiv(i0, 7) * 7 } = merge(iS3, iS4) IterDomain* loop0 = IterDomain::merge(alloc0, alloc1); std::vector root{root0}; std::vector logical{logical0, logical1}; std::vector allocation{alloc0, alloc1}; std::vector loop{loop0}; auto* td = IrBuilder::create(root, logical, allocation, loop); auto* tv = IrBuilder::create(td, DataType::Float); std::cout << "tv: " << tv->toString() << std::endl; tv->printTransforms(); } ``` On TOT, this prints ``` [ RUN ] AllocationDomainTest.PrintTransforms tv: T0_l_float[iS5{( ( ceilDiv(i0, 7) ) * 7 )}] root domain : (iS0{i0}) Split: iS0{i0} by factor 5 -> iS1{( ceilDiv(i0, 5) )}, iS2{5} logical domain : (iS1{( ceilDiv(i0, 5) )}, iS2{5}) allocation domain : (iS3{( ceilDiv(i0, 7) )}, iS4{7}) contiguity: f f loop domain : (iS5{( ( ceilDiv(i0, 7) ) * 7 )}) [ OK ] AllocationDomainTest.PrintTransforms (1 ms) ``` From this we cannot tell how the loop or allocation domains relate to the root or logical domains. After this PR it prints: ``` [ RUN ] AllocationDomainTest.PrintTransforms tv: T0_l_float[iS5{( ( ceilDiv(i0, 7) ) * 7 )}] root domain : (iS0{i0}) Split: iS0{i0} by factor 5 -> iS1{( ceilDiv(i0, 5) )}, iS2{5} logical domain : (iS1{( ceilDiv(i0, 5) )}, iS2{5}) allocation domain : (iS3{( ceilDiv(i0, 7) )}, iS4{7}) contiguity: f f Split: iS0{i0} by factor 7 -> iS3{( ceilDiv(i0, 7) )}, iS4{7} Merge: iS3{( ceilDiv(i0, 7) )} and iS4{7} -> iS5{( ( ceilDiv(i0, 7) ) * 7 )} Split: iS0{i0} by factor 5 -> iS1{( ceilDiv(i0, 5) )}, iS2{5} loop domain : (iS5{( ( ceilDiv(i0, 7) ) * 7 )}) [ OK ] AllocationDomainTest.PrintTransforms (1 ms) ``` Note that currently this redundantly prints root->logical transforms since we print root->logical earlier. We could just remove that section, meaning effectively that this PR would move those expressions to the same section where loop and allocation transforms are printed. --- csrc/ir/iostream.cpp | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/csrc/ir/iostream.cpp b/csrc/ir/iostream.cpp index 2b7e0c8f57a..59e01acbeb5 100644 --- a/csrc/ir/iostream.cpp +++ b/csrc/ir/iostream.cpp @@ -152,15 +152,10 @@ void IrTransformPrinter::printTransforms(const TensorView* tv) { os() << " contiguity: " << tv->domain()->getContiguityString() << "\n"; - const auto& from = tv->getLogicalDomain(); - const auto& loop = tv->getLoopDomain(); - const auto all_exp = DependencyCheck::getAllExprsBetween( - {from.begin(), from.end()}, {loop.begin(), loop.end()}); - - for (const auto exp : all_exp) { + for (const auto exp : tv->domain()->allExprs()) { os() << " " << exp->toString(); } - os() << " loop domain : (" << toDelimitedString(loop) << ")\n"; + os() << " loop domain : (" << toDelimitedString(tv->getLoopDomain()) << ")\n"; } std::ostream& operator<<(std::ostream& os, const Statement* stmt) { From 85a9463201ff0ab12fbb49638f1629ef929c643f Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Wed, 16 Apr 2025 13:04:51 -0700 Subject: [PATCH 08/68] InsertReshardingsPass decomposes matmul/linear+ReduceScatter. (#4239) As a follow-up to #4209, handle ReduceScatter as well. --- csrc/multidevice/utils.cpp | 17 ++++----- csrc/preseg_passes/insert_reshardings.cpp | 39 ++++++++++++++++++- tests/python/multidevice/fixtures.py | 15 +++++++- tests/python/multidevice/test_dtensor.py | 14 ++++--- tests/python/multidevice/test_matmul.py | 46 +++++++++++++++++++++++ tests/python/multidevice/test_overlap.py | 3 -- 6 files changed, 113 insertions(+), 21 deletions(-) diff --git a/csrc/multidevice/utils.cpp b/csrc/multidevice/utils.cpp index 475d852028b..ac6d7152592 100644 --- a/csrc/multidevice/utils.cpp +++ b/csrc/multidevice/utils.cpp @@ -131,6 +131,13 @@ std::unordered_map mapDeviceParallelTypeToId( continue; } + // rDIDx{i0}, usually a product of an Allreduce or a ReduceScatter, is + // treated as replicated. This way `iDIDx{i0} => rDIDx{i0}` is considered + // resharding. + if (id->isReduction()) { + continue; + } + NVF_ERROR( parallel_type_to_id.try_emplace(parallel_type, id).second, "Found multiple loop IterDomains with the same parallel type (", @@ -564,16 +571,6 @@ bool haveDifferentShardings( return false; } - // iDIDx{i0} => rDIDx{i0} triggers an allreduce even though the two `i0`s - // are equivalent. - if (c_id->isReduction()) { - NVF_ERROR( - !p_id->isReduction(), - "Reduction IterDomains in the producer's logical shouldn't be mapped: ", - p_id); - return false; - } - return simplifyExpr( SimplifyingIrBuilder::eqExpr(p_index, c_index), /*variables=*/{}, diff --git a/csrc/preseg_passes/insert_reshardings.cpp b/csrc/preseg_passes/insert_reshardings.cpp index eab98af208e..264facb743d 100644 --- a/csrc/preseg_passes/insert_reshardings.cpp +++ b/csrc/preseg_passes/insert_reshardings.cpp @@ -148,6 +148,8 @@ void rFactorLoopSplits(Fusion* fusion) { std::vector rfactor_axes; rfactor_axes.reserve(tv->nDims()); + std::unordered_set reduced_parallel_types; + for (auto&& [i, loop_id] : enumerate(tv->getLoopDomain())) { if (!loop_id->isReduction()) { // rFactor only applies to reduction dimensions. @@ -162,14 +164,47 @@ void rFactorLoopSplits(Fusion* fusion) { continue; } - if (!loop_id->isParallelized()) { + const ParallelType parallel_type = loop_id->getParallelType(); + if (parallel_type == ParallelType::Serial) { // rFactor non-parallelized IDs so they get reduced locally. rfactor_axes.push_back(i); + } else { + reduced_parallel_types.insert(parallel_type); } } if (!rfactor_axes.empty()) { - tv->rFactor(rfactor_axes); + TensorView* local = tv->rFactor(rfactor_axes); + // Before rFactor: + // + // [i{m} i{n} r{k}] + // / \ / \. + // iDIDx{d} i{n/d} rDIDx{d} r{k/d} + // + // After rFactor: + // + // r{k} + // / \. + // [i{m} i{n} iDIDx{d} r{k/d}] + // / \. + // iDIDx{d} i{n/d} + // + // | + // | reduce + // v + // + // [i{m} i{n} rDIDx{d}] + // / \. + // iDIDx{d} i{n/d} + // + // The TensorView returned by rFactor has two iDIDx, which is disallowed. + // The following code unparallelizes the first iDIDx{d}. + for (IterDomain* loop_id : local->getLoopDomain()) { + if (!loop_id->isRFactorProduct() && + reduced_parallel_types.count(loop_id->getParallelType())) { + loop_id->parallelize(ParallelType::Serial); + } + } } } } diff --git a/tests/python/multidevice/fixtures.py b/tests/python/multidevice/fixtures.py index f14a884fbb9..71da16c2f14 100644 --- a/tests/python/multidevice/fixtures.py +++ b/tests/python/multidevice/fixtures.py @@ -46,8 +46,21 @@ def shard_tensor( return mesh.shard_tensor(t, dim, self.rank).cuda(self.rank) -@pytest.fixture(scope="session") +@pytest.fixture def multidevice_test(): + # Reset the cache here to work around a bug in FusionDefintion.execute. + # FusionDefinition._finalize_definition maps the same `definition` to the + # same FusionSchedules and therefore the same FusionExecutorCache. This was + # correct until multiple FusionDefinitions started to have the same + # `definition` but different `multidevice_schedule`s. This seems to be a + # known issue beacuse a similar workaround for single-GPU schedules is done + # here: + # https://github.com/NVIDIA/Fuser/blob/f44f1913c26f8325099ab6fe46d678cbea435658/tests/python/test_schedule_ops.py#L115. + # + # I couldn't think of an easy way to fix this issue properly. Also, that + # FusionCache is obsolete makes me less motivated to do so. + nvfuser.FusionCache.reset() + fixture = MultideviceTest() yield fixture # Sync all ranks after each test for isolation. diff --git a/tests/python/multidevice/test_dtensor.py b/tests/python/multidevice/test_dtensor.py index 0cdb52cda27..f1dfc38f62f 100644 --- a/tests/python/multidevice/test_dtensor.py +++ b/tests/python/multidevice/test_dtensor.py @@ -19,14 +19,18 @@ multidevice_test = fixtures.multidevice_test +# Set up the default process group for torch APIs like +# dist.device_mesh.init_device_mesh. @pytest.fixture(scope="module") -def setup_process_group(multidevice_test): +def setup_process_group(): + communicator = nvfuser.Communicator.instance() + # The default port as used by https://github.com/pytorch/pytorch/blob/45a8b5682eb69d865cbf68c7f2f689b56b4efd53/torch/csrc/distributed/c10d/TCPStore.hpp#L51. dist.init_process_group( backend="nccl", init_method="tcp://localhost:29500", - world_size=multidevice_test.size, - rank=multidevice_test.rank, + world_size=communicator.size(), + rank=communicator.rank(), ) yield dist.destroy_process_group() @@ -94,7 +98,7 @@ def __call__(self, in_dtensors: Iterable[DTensor]) -> list[DTensor]: @pytest.mark.mpi -def test_plus_one(setup_process_group): +def test_plus_one(setup_process_group, multidevice_test): def define_fusion(fd: FusionDefinition): inp = fd.define_tensor((-1, -1), contiguity=False, dtype=DataType.Float) one = fd.define_scalar(1.0, dtype=DataType.Float) @@ -118,7 +122,7 @@ def define_fusion(fd: FusionDefinition): @pytest.mark.mpi -def test_linear(setup_process_group): +def test_linear(setup_process_group, multidevice_test): @dataclass class LinearConfig: def __init__(self, num_devices: int, batch: int, sequence: int, hidden: int): diff --git a/tests/python/multidevice/test_matmul.py b/tests/python/multidevice/test_matmul.py index c09e4142e26..3002c4e6caf 100644 --- a/tests/python/multidevice/test_matmul.py +++ b/tests/python/multidevice/test_matmul.py @@ -161,6 +161,52 @@ def multidevice_schedule(self): torch.testing.assert_close(out.cpu(), unsharded_out, rtol=1.3e-6, atol=1e-3) +@pytest.mark.mpi +def test_linear_reduce_scatter(multidevice_test): + d = multidevice_test.size + mesh = nvfuser.DeviceMesh(range(d)) + e = 768 + + class Model(FusionDefinition): + def definition(self): + self.inp = self.define_tensor([-1, -1, d * e]) + self.weight = self.define_tensor([e, d * e]) + self.out = self.ops.linear(self.inp, self.weight, None) + self.add_output(self.out) + + def multidevice_schedule(self): + for t in [self.inp, self.weight, self.out]: + self.sched._set_device_mesh(t, mesh) + self.sched.split(t, -1, d, False) + self.sched.parallelize(t, -2, nvfuser.ParallelType.mesh_x) + self.sched.set_allocation_as_loop(t) + + # Scatter + self.sched.split(self.out, 1, d, False) + self.sched.parallelize(self.out, 1, nvfuser.ParallelType.mesh_x) + + torch.cuda.set_device(multidevice_test.local_rank) + + b, s = 2, 1024 + unsharded_inp = torch.randn(b, s, d * e) + unsharded_weight = torch.randn(e, d * e) + + inp = multidevice_test.shard_tensor(unsharded_inp, -1, mesh) + weight = multidevice_test.shard_tensor(unsharded_weight, -1, mesh) + + fd = Model() + (out,), _ = fd.execute([inp, weight]) + + unsharded_out = torch.nn.functional.linear(unsharded_inp, unsharded_weight, None) + # rtol is the same as the default for fp32. atol is slightly increased. + torch.testing.assert_close( + out, + multidevice_test.shard_tensor(unsharded_out, 1, mesh), + rtol=1.3e-6, + atol=1e-3, + ) + + @pytest.mark.mpi def test_matmul_allreduce(multidevice_test): d, b, s, e = multidevice_test.size, 1, 4, 8 diff --git a/tests/python/multidevice/test_overlap.py b/tests/python/multidevice/test_overlap.py index 460a1f0edd8..0ad770e022c 100644 --- a/tests/python/multidevice/test_overlap.py +++ b/tests/python/multidevice/test_overlap.py @@ -88,9 +88,6 @@ def test_overlap_allgather_matmul_stream_outermost( ins = [x, weight, bias] out_ref = torch.nn.functional.linear(x_unsharded, weight.cpu(), bias.cpu()) - # Resetting the cache here is necessary to workaround a bug that would need a proper fix. If not avoiding the cache, there is an issue for the second test that is being run. More specifically, the second time we define the fusion, we hit the cache in https://github.com/NVIDIA/Fuser/blob/6ff60e2a320733a2f49de57007d6bb45000107cd/csrc/python_frontend/fusion_definition.cpp#L95 . Later, when we call _set_device_mesh, we get a "thro out of range" here https://github.com/NVIDIA/Fuser/blob/6ff60e2a320733a2f49de57007d6bb45000107cd/csrc/python_frontend/schedule_bindings.cpp#L60 because the FusionDefinition has not run so it doesn't contain any state. - nvfuser.FusionCache.reset() - fd = OverlapAGMatmulStreamOutermost(m, k, n, s, d, backend_type) # warmup From 34fa83b0afbfab92c3ee35ec6aace8f22d975455 Mon Sep 17 00:00:00 2001 From: Ryan Spring Date: Wed, 16 Apr 2025 15:35:34 -0700 Subject: [PATCH 09/68] Create kir::Continue for persistent grid short-circuit (#4260) This PR adds support for `continue` in the generated CUDA kernels. **Why?** For persistent kernels, we need to efficiently skip OOB tiles to avoid unnecessary work. Required for: https://github.com/NVIDIA/Fuser/pull/4243 --- csrc/codegen.cpp | 4 ++++ csrc/device_lower/pass/index.cpp | 5 +++++ csrc/device_lower/pass/index.h | 1 + csrc/dispatch.h | 1 + csrc/kernel_ir.cpp | 19 +++++++++++++++++++ csrc/kernel_ir.h | 17 +++++++++++++++++ 6 files changed, 47 insertions(+) diff --git a/csrc/codegen.cpp b/csrc/codegen.cpp index 92e47b3d01e..0882684a11d 100644 --- a/csrc/codegen.cpp +++ b/csrc/codegen.cpp @@ -3749,6 +3749,10 @@ class CudaKernelGenerator : private kir::ConstIrVisitor { indent() << "NVFUSER_UPDATE_MAGIC_ZERO;\n"; } + void handle(const kir::Continue* cont) final { + indent() << "continue;\n"; + } + void handle(const kir::Return* ret) final { indent() << "return;\n"; } diff --git a/csrc/device_lower/pass/index.cpp b/csrc/device_lower/pass/index.cpp index 25605ddd0d4..8bfe1c242e8 100644 --- a/csrc/device_lower/pass/index.cpp +++ b/csrc/device_lower/pass/index.cpp @@ -2805,6 +2805,11 @@ void IndexLowering::handle(const kir::SetMaxNReg* maxnreg) { pushBack(const_cast(maxnreg)); // NOLINT } +void IndexLowering::handle(const kir::Continue* cont) { + // TODO(kir): remove the need for const_cast + pushBack(const_cast(cont)); // NOLINT +} + void IndexLowering::handle(const kir::Return* ret) { // TODO(kir): remove the need for const_cast pushBack(const_cast(ret)); // NOLINT diff --git a/csrc/device_lower/pass/index.h b/csrc/device_lower/pass/index.h index 7e1699821c2..2bc0e973709 100644 --- a/csrc/device_lower/pass/index.h +++ b/csrc/device_lower/pass/index.h @@ -77,6 +77,7 @@ class IndexLowering : private OptOutConstDispatch { void handle(const kir::FenceAsyncProxy*) final; void handle(const kir::WgMmaFence*) final; void handle(const kir::SetMaxNReg*) final; + void handle(const kir::Continue*) final; void handle(const kir::Return*) final; void handle(const kir::MBarrierInit*) final; void handle(const kir::MBarrierInvalidate*) final; diff --git a/csrc/dispatch.h b/csrc/dispatch.h index 218ccd8267a..1ca0a2c460d 100644 --- a/csrc/dispatch.h +++ b/csrc/dispatch.h @@ -124,6 +124,7 @@ class Val; f(FenceAsyncProxy); \ f(WgMmaFence); \ f(SetMaxNReg); \ + f(Continue); \ f(Return); \ f(MBarrierInit); \ f(MBarrierInvalidate); \ diff --git a/csrc/kernel_ir.cpp b/csrc/kernel_ir.cpp index b2d30c93b02..546e4c970a3 100644 --- a/csrc/kernel_ir.cpp +++ b/csrc/kernel_ir.cpp @@ -670,6 +670,25 @@ std::string SetMaxNReg::toInlineString(int indent_size) const { NVFUSER_DEFINE_CLONE_AND_CREATE(SetMaxNReg) +Continue::Continue(IrBuilderPasskey passkey) : Expr(passkey) { + NVF_ERROR(passkey.ir_container_ != nullptr); + NVF_ERROR( + passkey.ir_container_->isA(), + "IR type only valid for Kernel container."); +} + +std::string Continue::toString(int indent_size) const { + std::stringstream ss; + indent(ss, indent_size) << "continue\n"; + return ss.str(); +} + +std::string Continue::toInlineString(int indent_size) const { + NVF_CHECK(false, "Continue can not be printed inline"); +} + +NVFUSER_DEFINE_CLONE_AND_CREATE(Continue) + Return::Return(IrBuilderPasskey passkey) : Expr(passkey) { NVF_ERROR(passkey.ir_container_ != nullptr); NVF_ERROR( diff --git a/csrc/kernel_ir.h b/csrc/kernel_ir.h index 5b08664774d..2cc5df4deed 100644 --- a/csrc/kernel_ir.h +++ b/csrc/kernel_ir.h @@ -42,6 +42,7 @@ class GridSync; class FenceAsyncProxy; class WgMmaFence; class SetMaxNReg; +class Continue; class Return; class MBarrierInit; class MBarrierInvalidate; @@ -613,6 +614,22 @@ class SetMaxNReg final : public Expr { } }; +class Continue final : public Expr { + public: + using Expr::Expr; + + explicit Continue(IrBuilderPasskey passkey); + + NVFUSER_DECLARE_CLONE_AND_CREATE + + const char* getOpString() const override { + return "Continue"; + } + + std::string toString(int indent_size = 0) const override; + std::string toInlineString(int indent_size = 0) const override; +}; + class Return final : public Expr { public: using Expr::Expr; From 5ecf7fec7a52ccba815390df3752c859980838ab Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Wed, 16 Apr 2025 17:10:08 -0700 Subject: [PATCH 10/68] Remove several uses of NVFUSER_DISTRIBUTED (#4255) --- csrc/multidevice/c10d_mock.h | 30 +++++++++++++++++++++++++++++- csrc/multidevice/ipc_handle.cpp | 4 ---- tests/cpp/test_multidevice_ipc.cpp | 15 +++------------ 3 files changed, 32 insertions(+), 17 deletions(-) diff --git a/csrc/multidevice/c10d_mock.h b/csrc/multidevice/c10d_mock.h index b4ac0152ada..3befb61323f 100644 --- a/csrc/multidevice/c10d_mock.h +++ b/csrc/multidevice/c10d_mock.h @@ -5,6 +5,19 @@ * SPDX-License-Identifier: BSD-3-Clause */ // clang-format on + +// This file provides a mock implementation of c10d that builds but doesn't +// function. +// +// nvFuser is sometimes built on a pytorch without c10d. When that +// happens, c10d isn't linked, NVFUSER_DISTRIBUTED is undefined and the +// multi-GPU component of nvFuser is expected to be disabled. +// +// Instead of adding `#ifdef NVFUSER_DISTRIBUTED` in too many places, this file +// provides a buildable mock implementation of c10d to keep nvFuser code less +// divergent. This implementation won't run because tests and user code are +// guarded by Communicator::is_available. + #pragma once #include @@ -170,6 +183,21 @@ struct TCPStoreOptions { static constexpr uint16_t kDefaultPort = 0; }; -class TCPStore : public torch::CustomClassHolder {}; +class TCPStore : public torch::CustomClassHolder { + public: + std::vector get(const std::string&) { + return {}; + } + + void set(const std::string&, const std::vector&) {} + + bool check(const std::vector&) { + return false; + } + + bool deleteKey(const std::string&) { + return false; + } +}; } // namespace c10d diff --git a/csrc/multidevice/ipc_handle.cpp b/csrc/multidevice/ipc_handle.cpp index dd96b5a72e8..6bb700dc2de 100644 --- a/csrc/multidevice/ipc_handle.cpp +++ b/csrc/multidevice/ipc_handle.cpp @@ -95,7 +95,6 @@ std::string IpcHandleCache::getTcpStoreKey( void IpcHandleCache::exchangeHandles( const std::vector& communications) { -#ifdef NVFUSER_DISTRIBUTED Communicator* communicator = &Communicator::getInstance(); const int64_t my_rank = communicator->deviceId(); @@ -152,9 +151,6 @@ void IpcHandleCache::exchangeHandles( insert(communication, std::move(ipc_handles)); } -#else // NVFUSER_DISTRIBUTED - NVF_ERROR(false, "NVFUSER_DISTRIBUTED is not defined"); -#endif // NVFUSER_DISTRIBUTED } } // namespace nvfuser diff --git a/tests/cpp/test_multidevice_ipc.cpp b/tests/cpp/test_multidevice_ipc.cpp index 30daf6db145..ba574c0f676 100644 --- a/tests/cpp/test_multidevice_ipc.cpp +++ b/tests/cpp/test_multidevice_ipc.cpp @@ -34,7 +34,7 @@ TEST_F(IpcTest, IpcMemHandle) { if (communicator_->size() == 1) { GTEST_SKIP() << "Skipping test for single device"; } -#ifdef NVFUSER_DISTRIBUTED + // Allocate and setup GPU buffers constexpr size_t kBufferSize = sizeof(int64_t); const int64_t num_devices = communicator_->size(); @@ -75,16 +75,13 @@ TEST_F(IpcTest, IpcMemHandle) { // Clean up NVFUSER_CUDA_RT_SAFE_CALL(cudaIpcCloseMemHandle(peer_d_ptr)); NVFUSER_CUDA_RT_SAFE_CALL(cudaFree(d_ptr)); -#else // NVFUSER_DISTRIBUTED - GTEST_SKIP() << "NVFUSER_DISTRIBUTED is not defined"; -#endif // NVFUSER_DISTRIBUTED } TEST_F(IpcTest, IpcMemHandlePtrArithmeticAtReceiver) { if (communicator_->size() == 1) { GTEST_SKIP() << "Skipping test for single device"; } -#ifdef NVFUSER_DISTRIBUTED + // TL;DR: We can do pointer arithmetic on the importer side. IOW, the pointer // can be used as a regular pointer on the importer side. @@ -131,16 +128,13 @@ TEST_F(IpcTest, IpcMemHandlePtrArithmeticAtReceiver) { // Clean up NVFUSER_CUDA_RT_SAFE_CALL(cudaIpcCloseMemHandle(peer_d_ptr)); NVFUSER_CUDA_RT_SAFE_CALL(cudaFree(d_ptr)); -#else // NVFUSER_DISTRIBUTED - GTEST_SKIP() << "NVFUSER_DISTRIBUTED is not defined"; -#endif // NVFUSER_DISTRIBUTED } TEST_F(IpcTest, IpcMemHandlePtrArithmeticAtSender) { if (communicator_->size() == 1) { GTEST_SKIP() << "Skipping test for single device"; } -#ifdef NVFUSER_DISTRIBUTED + // TL;DR: We CANNOT do pointer arithmetic on the exporter side! The IPC handle // points to the beginning of the allocated buffer. @@ -189,9 +183,6 @@ TEST_F(IpcTest, IpcMemHandlePtrArithmeticAtSender) { // Clean up NVFUSER_CUDA_RT_SAFE_CALL(cudaIpcCloseMemHandle(peer_d_ptr)); NVFUSER_CUDA_RT_SAFE_CALL(cudaFree(d_ptr)); -#else // NVFUSER_DISTRIBUTED - GTEST_SKIP() << "NVFUSER_DISTRIBUTED is not defined"; -#endif // NVFUSER_DISTRIBUTED } // cuStreamWriteValue32 and cuStreamWaitValue32 are CUDA driver API used in the From 14849978b4cd5c0471b0434a993e43d0a566df67 Mon Sep 17 00:00:00 2001 From: Liqiang Lu <116412316+liqiangxl@users.noreply.github.com> Date: Thu, 17 Apr 2025 08:59:19 -0400 Subject: [PATCH 11/68] warp specializied tma persistent kernel, step-2, use TMA load (#4240) This PR follows https://github.com/NVIDIA/Fuser/pull/4215. It is step-2a of implementing warp specializied tma persistent kernel described in the design doc. Changes: (1) change from `cpAsync` to `1D TMA load` (2) Revise scheduler to handle special transformations and inlining of TMA loaded tensors. (3) Revise batch size in test to avoid error due to the missing of predicate for 1D TMA load, will change back after appropriate predicate is added. --------- Co-authored-by: jjsjann123 Co-authored-by: Ryan Spring --- csrc/scheduler/normalization_inner_outer.cpp | 113 ++++++++++++++++-- csrc/scheduler/normalization_utils.cpp | 14 ++- .../test_combined_inner_outer_reduction.cpp | 4 +- 3 files changed, 113 insertions(+), 18 deletions(-) diff --git a/csrc/scheduler/normalization_inner_outer.cpp b/csrc/scheduler/normalization_inner_outer.cpp index fcb7260501d..4e9f2f93bfe 100644 --- a/csrc/scheduler/normalization_inner_outer.cpp +++ b/csrc/scheduler/normalization_inner_outer.cpp @@ -6,6 +6,7 @@ */ // clang-format on #include +#include #include #include #include @@ -213,7 +214,11 @@ PersistentBufferStorageParams getPersistentBufferStorageParams( // reload from gmem for each iteration. // Note: in current use cases (layer norm bwd and RMS norm bwd), there are // outer broadcast tvs and always project to inputs. + // Warp specialized persistent kernel always cache inputs in shared memory, + // should project to inputs. const auto& outer_broadcast_tvs = getOuterBroadcastTvs(fusion, reduction_tvs); + bool skip_check_buffer_size = !outer_broadcast_tvs.empty() || + isOptionEnabled(EnableOption::WarpSpecializedNormalization); normalization_scheduler_utils::BufferProjectionStrategy project_strategy = normalization_scheduler_utils::isProjectBufferToInputs( fusion, @@ -223,7 +228,7 @@ PersistentBufferStorageParams getPersistentBufferStorageParams( persistent_buffer_size_info, InnerOuterPersistentKernelScheduler::schedulerType(), /*can_use_smem_persistent=*/true, - outer_broadcast_tvs.empty()); + !skip_check_buffer_size); buffer_params.project_to_input = (project_strategy == @@ -1036,12 +1041,7 @@ std::unique_ptr innerOuterWarpSpecializedTmaHeuristic( iop.bdimy, LaunchParams::UNINITIALIZED_VAL); - if (!rparams->smem_persistent_buffers.empty()) { - rparams->tag = - "InnerOuter Register and Shared Memory Persistent Heuristic.\n"; - } else { - rparams->tag = "InnerOuter Register Persistent Heuristic.\n"; - } + rparams->tag = "TMA Warp Specialized Persistent Heuristic.\n"; if (isDebugDumpEnabled(DebugDumpOption::SchedulerDebug)) { debug() << "\n===== Combined InnerOuter Reduction Stats ========\n" @@ -1628,20 +1628,97 @@ void scheduleTmaWarpSpecializedInnerOuter( fusion->addOutput(output); } + // Collect tvs loaded with TMA, they require special scheduling. + std::vector tma_load_tvs; + if (rparams->tma_warp_specialized) { + for (auto tv : smem_consumers) { + auto smem_tv = ir_utils::getSoleProducerTv(tv); + if (std::find(tma_load_tvs.begin(), tma_load_tvs.end(), smem_tv) == + tma_load_tvs.end()) { + tma_load_tvs.emplace_back(smem_tv); + } + } + } + const bool is_unroll_or_vectorization = rparams->isUnrolled(); const bool is_vectorize = rparams->vectorize_inner_reduction || rparams->vectorize_iter_dom; const bool is_outer_grid_persistence = rparams->persistent_kernel && rparams->cross_grid_inner_reduction && !rparams->fastest_dim; - // Propagate inner reduction. There is a cutoff at boundaryNodesSet, so this - // propagation will not propagate to the final outer reduction. - reduction_scheduler_utils::propagateTransformation( - inner_reference_tv, boundaryNodesSet); + // Propagate transformations for inner reduction. + // Two steps are used since tma tvs are scheduled differently. + // Step-1, propagate iteration domain in inner reduction. + // Step-2, propagate reduction domain in inner reduction. + if (rparams->tma_warp_specialized) { + // Find the axis that splits the reduction domain and iteration domain. + int first_redu_axis = -1; + int n_dims = (int)inner_reference_tv->nDims(); + for (auto i = 0; i < n_dims; i++) { + if (inner_reference_tv->axis(i)->isReduction() || + inner_reference_tv->axis(i)->isRFactorProduct()) { + first_redu_axis = i; + break; + } + } + + // Step-1, propagate iteration domain in inner reduction. + // outer_reference_tvs are excluded since they are already scheduled + // with a different pattern for the final step of outer reduciton. + if (first_redu_axis > 0) { + TransformPropagator propagator(inner_reference_tv, first_redu_axis - 1); + std::vector all_tvs_except = ir_utils::allTvsExcept( + fusion, {outer_reference_tvs.begin(), outer_reference_tvs.end()}); + SetSelector selector({all_tvs_except.begin(), all_tvs_except.end()}); + MaxLogicalDomainInfoSpanningTree(inner_reference_tv, &selector) + .traverse(&propagator); + } + + // Step-2, propagate reduction domain in inner reduction. + // (a) Tvs in boundaryNodesSet are excluded since they should follow outer + // reduction pattern. + // (b) TMA tvs are excluded since they require special scheduling. + // (3) Excluding tma tvs breaks the propagation path from inner reduction tv + // to cached_gmem which stores the results of the first-stage of outer + // reduction. The solution is adding a dummy output to link them. The same + // trick is used when projecting persistent buffers to inputs. + auto inner_reduction_input = + ir_utils::getSoleProducerTv(inner_reference_tv); + for (auto tv : cached_gmem) { + // T1(smem) --> T2 (l) --> T3 = OuterRedu(T2) --> T4(cached_gmem) + // outer_reduction_input: T2 + // partial_outer_redu_tv: T3 + auto partial_outer_redu_tv = ir_utils::getSoleProducerTv(tv); + auto outer_reduction_input = + ir_utils::getSoleProducerTv(partial_outer_redu_tv); + auto dummy_output = add(inner_reduction_input, outer_reduction_input); + fusion->addOutput(dummy_output); + dummy_outputs.emplace_back(dummy_output); + } + + // Tvs requiring special scheduling + std::unordered_set special_tvs{ + tma_load_tvs.begin(), tma_load_tvs.end()}; + for (auto tv : boundaryNodesSet) { + if (special_tvs.count(tv) == 0) { + special_tvs.emplace(tv); + } + } + TransformPropagator propagator(inner_reference_tv); + std::vector all_tvs_except_cache = ir_utils::allTvsExcept( + fusion, {special_tvs.begin(), special_tvs.end()}); + SetSelector selector( + {all_tvs_except_cache.begin(), all_tvs_except_cache.end()}); + MaxLogicalDomainInfoSpanningTree(inner_reference_tv, &selector) + .traverse(&propagator); + } else { + reduction_scheduler_utils::propagateTransformation( + inner_reference_tv, boundaryNodesSet); + } reduction_scheduler_utils::propagateRFactor( inner_reference_tv, inner_reduction_tvs[0], inner_reduction_tvs); - // Don't allow parallelization propagation goes through boundaryNodesSet + // parallelization propagation const auto& selected_tvs_inner = scheduler_utils::getAllTvsFrom(inner_reduction_tvs, boundaryNodesSet); const auto& unroll_vectorizable_cached_tvs = @@ -1684,6 +1761,18 @@ void scheduleTmaWarpSpecializedInnerOuter( {selected_tvs_outer.begin(), selected_tvs_outer.end()}); } + // Up to this point, the outer dimension of the TMA tv is scheduled + // the same way as the inner reduction tv. However, the inner dimension + // has not been scheduled yet. Since 1D TMA allows unrestricted load size, + // we can simply parallelize the entire inner dimension using bulk. + // Example: 2D tensor, [BIDy, S, | Bulk] + // Example: 1D tensor, [Bulk] + if (rparams->tma_warp_specialized) { + for (auto tv : tma_load_tvs) { + tv->axis(-1)->parallelize(ParallelType::Bulk); + } + } + // special vectorization of temp gmem, vectorization_factor_tmp_gmem_write // is guaranteed to be smaller or equal to input vectorization factor. if (rparams->vectorization_factor_tmp_gmem_write > 1) { diff --git a/csrc/scheduler/normalization_utils.cpp b/csrc/scheduler/normalization_utils.cpp index a92418998b8..07c64a4f9c2 100644 --- a/csrc/scheduler/normalization_utils.cpp +++ b/csrc/scheduler/normalization_utils.cpp @@ -1319,11 +1319,15 @@ std::vector movePersistentBufferToSmem( } if (use_smem) { tv->setMemoryType(MemoryType::Shared); - // When loading from global memory (gmem), use CpAsync with a short data - // path of gmem -> smem to reduce temporary register usage. Otherwise, the - // data path from gmem to shared memory (smem) follows this sequence: gmem - // -> L1 cache -> register -> smem. - if (supportCpAsync(tv) && is_cached_input) { + // Use 1D TMA, CpAsyncBulk + if (rparams->tma_warp_specialized && is_cached_input) { + tv->definition()->as()->setOpType( + LoadStoreOpType::CpAsyncBulk); + } else if (supportCpAsync(tv) && is_cached_input) { + // When loading from global memory (gmem), use CpAsync with a short data + // path of gmem -> smem to reduce temporary register usage. Otherwise, + // the data path from gmem to shared memory (smem) follows this + // sequence: gmem -> L1 cache -> register -> smem. tv->definition()->as()->setOpType( LoadStoreOpType::CpAsync); tv->definition()->as()->setCacheOp(CacheOp::Unspecified); diff --git a/tests/cpp/test_combined_inner_outer_reduction.cpp b/tests/cpp/test_combined_inner_outer_reduction.cpp index 18a6c439930..4c5a1591628 100644 --- a/tests/cpp/test_combined_inner_outer_reduction.cpp +++ b/tests/cpp/test_combined_inner_outer_reduction.cpp @@ -1158,13 +1158,15 @@ TEST_P(TmaWarpSpecializedTest, RMSNormBwd) { __LINE__, __FILE__); } +// batch size is revised to 132*148 which is divisible by sm count on H100 & +// B200 will change back to 32 & 2048 after predicate for 1D TMA is added. INSTANTIATE_TEST_SUITE_P( , TmaWarpSpecializedTest, ::testing::Combine( testing::Values(true, false), testing::Values(DataType::Float, DataType::BFloat16), - testing::Values(32, 2048), + testing::Values(132 * 148), ::testing::Range((int64_t)1024, (int64_t)8193, (int64_t)1024)), [](const testing::TestParamInfo& info) -> std::string { From 9b9cd8fdbb3a7fe90694eef59c2ea1dd93cc8347 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle <1454944+jacobhinkle@users.noreply.github.com> Date: Thu, 17 Apr 2025 14:22:49 -0400 Subject: [PATCH 12/68] Fix scheduling of split-K with smem_epilogue on Hopper (#4257) Introduces `cacheBefore` to match `cacheAfter` utility, which just propagates entries in `graph_` corresponding to new IDs in the cached tensors. Also avoids re-scheduling tensors if they are split-K sum tensors. There is a current limitation for 32-bit outputs where we skip stmatrix but our current vectorized stores encounter 2-way bank conflicts. This is probably not that important to perf and can be fixed in scheduling of that store in another PR. Fixes #4159 --- csrc/scheduler/hopper_multi_matmul.cpp | 49 +++------------ csrc/scheduler/hopper_multi_matmul.h | 8 --- csrc/scheduler/multi_matmul.cpp | 54 +++++++++++++++++ csrc/scheduler/multi_matmul.h | 14 +++++ tests/cpp/test_matmul.cpp | 82 ++++++++++++++++++++++++++ 5 files changed, 157 insertions(+), 50 deletions(-) diff --git a/csrc/scheduler/hopper_multi_matmul.cpp b/csrc/scheduler/hopper_multi_matmul.cpp index ec3b509ba32..df83886bdf7 100644 --- a/csrc/scheduler/hopper_multi_matmul.cpp +++ b/csrc/scheduler/hopper_multi_matmul.cpp @@ -237,41 +237,6 @@ void HopperMultipleMatmulScheduler::reorderBlockTileTraversal( } } -TensorView* HopperMultipleMatmulScheduler::cacheAfter( - TensorView* orig, - LoadStoreOpType op_type, - CacheOp cache_op, - bool propagate_allocation_domain) { - const std::vector orig_alloc = orig->getMaybeAllocationDomain(); - - TensorView* c = - orig->cacheAfter(op_type, cache_op, propagate_allocation_domain); - - if (propagate_allocation_domain) { - const std::vector cache_alloc = c->getMaybeAllocationDomain(); - NVF_ERROR(orig_alloc.size() == cache_alloc.size()); - for (size_t i : arange(orig_alloc.size())) { - ValGroup vg = graph_->toGroup(orig_alloc[i]); - graph_->initializeVal(cache_alloc[i], vg); - } - } - - const std::vector orig_logical = - TensorDomain::noReductions(orig->getLogicalDomain()); - const std::vector cache_logical = c->getLogicalDomain(); - // in split-K we do rFactor which gives us a full = sum(partial) - // where partial has root domain that matches the logical domain of the - // original tensor. The logical domain contains Iteration transforms of the - // Reduction axis in the original mma output. - NVF_ERROR(orig_logical.size() == cache_logical.size()); - for (size_t i : arange(orig_logical.size())) { - ValGroup vg = graph_->toGroup(orig_logical[i]); - graph_->initializeVal(cache_logical[i], vg); - } - - return c; -} - std::vector> HopperMultipleMatmulScheduler:: blockTileTensors(const std::vector& tvs) { if (canonical_dim_ordering_.empty()) { @@ -623,19 +588,19 @@ void HopperMultipleMatmulScheduler::scheduleEpilogue() { NVF_ERROR(d->definition() && d->definition()->isA()); TensorView* dc = d->definition()->input(0)->as(); - // NOTE: cacheBefore does not work with blockTileTensors - // cacheInputsAndOutputs creates a cache_before for each output. - // Apply cacheAfter to the existing cache tensor for output. // The chain of operations storing data to global memory: // registers -> (stmatrix) -> smem -> (tma_store) -> gmem - TensorView* d_smem = cacheAfter(dc, LoadStoreOpType::Set); + TensorView* d_smem = cacheBefore(d, LoadStoreOpType::Set); std::vector tvs_to_schedule{d, d_smem}; - bool dc_in_mma_results = + bool dc_is_mma_result = std::find(mma_results_.begin(), mma_results_.end(), dc) != mma_results_.end(); + bool dc_is_splitk_sum = params_->splitk_factor > 1 && + std::find(splitk_sums_.begin(), splitk_sums_.end(), dc) != + splitk_sums_.end(); - if (!dc_in_mma_results) { + if (!dc_is_mma_result && !dc_is_splitk_sum) { // Skip scheduling dc if it is an mma_result. This can happen if we are // not casting back to half-precision in the output tvs_to_schedule.push_back(dc); @@ -666,7 +631,7 @@ void HopperMultipleMatmulScheduler::scheduleEpilogue() { // Should not propagate if the dc is a mma output as the mma output has // already been scheduled. - if (!dc_in_mma_results) { + if (!dc_is_mma_result && !dc_is_splitk_sum) { auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation( dc->getLoopDomain()); dc->setLoopDomain(s.as()); diff --git a/csrc/scheduler/hopper_multi_matmul.h b/csrc/scheduler/hopper_multi_matmul.h index 854d0705234..a46d046f2e9 100644 --- a/csrc/scheduler/hopper_multi_matmul.h +++ b/csrc/scheduler/hopper_multi_matmul.h @@ -124,14 +124,6 @@ class HopperMultipleMatmulScheduler : public MultipleMatmulScheduler { TensorView* tv, std::vector& outer_dim_roles); - //! This calls orig->cacheAfter() and also updates the broadcast graph to - //! reflect the new IterDomain mappings - TensorView* cacheAfter( - TensorView* orig, - LoadStoreOpType op_type = LoadStoreOpType::Set, - CacheOp cache_op = CacheOp::AllLevels, - bool propagate_allocation_domain = false); - //! Do block tiling for a collection of TensorViews. The tensors should be //! unscheduled before this method is called. //! 1) Axes will be ordered according to canonicalDimOrdering, and then axes diff --git a/csrc/scheduler/multi_matmul.cpp b/csrc/scheduler/multi_matmul.cpp index 377629d99e6..1976bdf6bef 100644 --- a/csrc/scheduler/multi_matmul.cpp +++ b/csrc/scheduler/multi_matmul.cpp @@ -227,4 +227,58 @@ void MultipleMatmulScheduler::cacheInputsAndOutputs(bool skip_intermediates) { } } +TensorView* MultipleMatmulScheduler::cacheBefore( + TensorView* orig, + LoadStoreOpType op_type) { + TensorView* c = orig->cacheBefore(op_type); + + const std::vector orig_logical = + TensorDomain::noReductions(orig->getLogicalDomain()); + const std::vector cache_logical = c->getLogicalDomain(); + NVF_ERROR(orig_logical.size() == cache_logical.size()); + for (size_t i : arange(orig_logical.size())) { + // The domain of orig gets transferred to c and a new domain is applied to + // orig + ValGroup vg = graph_->toGroup(cache_logical[i]); + graph_->initializeVal(orig_logical[i], vg); + } + + return c; +} + +TensorView* MultipleMatmulScheduler::cacheAfter( + TensorView* orig, + LoadStoreOpType op_type, + CacheOp cache_op, + bool propagate_allocation_domain) { + const std::vector orig_alloc = orig->getMaybeAllocationDomain(); + + TensorView* c = + orig->cacheAfter(op_type, cache_op, propagate_allocation_domain); + + if (propagate_allocation_domain) { + const std::vector cache_alloc = c->getMaybeAllocationDomain(); + NVF_ERROR(orig_alloc.size() == cache_alloc.size()); + for (size_t i : arange(orig_alloc.size())) { + ValGroup vg = graph_->toGroup(orig_alloc[i]); + graph_->initializeVal(cache_alloc[i], vg); + } + } + + const std::vector orig_logical = + TensorDomain::noReductions(orig->getLogicalDomain()); + const std::vector cache_logical = c->getLogicalDomain(); + // in split-K we do rFactor which gives us a full = sum(partial) + // where partial has root domain that matches the logical domain of the + // original tensor. The logical domain contains Iteration transforms of the + // Reduction axis in the original mma output. + NVF_ERROR(orig_logical.size() == cache_logical.size()); + for (size_t i : arange(orig_logical.size())) { + ValGroup vg = graph_->toGroup(orig_logical[i]); + graph_->initializeVal(cache_logical[i], vg); + } + + return c; +} + } // namespace nvfuser diff --git a/csrc/scheduler/multi_matmul.h b/csrc/scheduler/multi_matmul.h index 7bcb86f0ead..8f9d200bba7 100644 --- a/csrc/scheduler/multi_matmul.h +++ b/csrc/scheduler/multi_matmul.h @@ -60,6 +60,20 @@ class MultipleMatmulScheduler { TensorView* operand, int64_t vec_size) = 0; + //! This calls orig->cacheBefore() and also updates the broadcast graph to + //! reflect the new IterDomain mappings + TensorView* cacheBefore( + TensorView* orig, + LoadStoreOpType op_type = LoadStoreOpType::Set); + + //! This calls orig->cacheAfter() and also updates the broadcast graph to + //! reflect the new IterDomain mappings + TensorView* cacheAfter( + TensorView* orig, + LoadStoreOpType op_type = LoadStoreOpType::Set, + CacheOp cache_op = CacheOp::AllLevels, + bool propagate_allocation_domain = false); + protected: Fusion* fusion_; const MatmulParams* params_; diff --git a/tests/cpp/test_matmul.cpp b/tests/cpp/test_matmul.cpp index 14a3ef84d62..8a7e6c3782b 100644 --- a/tests/cpp/test_matmul.cpp +++ b/tests/cpp/test_matmul.cpp @@ -3656,6 +3656,32 @@ class HopperMatmulTest : public HopperBase { } }; +// 2 math group, non-persistent, non-warp specialized, no CGA +// TODO: This could be in HopperMatmulTest::SetUp() instead +MatmulParams defaultHopperParams() { + MatMulTileOptions gemm_tile; + gemm_tile.cta_tile = GemmTile(128, 256, 64); + gemm_tile.warp_tile = GemmTile(64, 256, 64); + MatmulParams mparams; + mparams.supported_vec_size = {8, 8, 8}; + mparams.mma_macro = MmaMacro::Hopper_64_256_16; + mparams.tile_sizes = gemm_tile; + mparams.circular_buffering_strategy = + MatmulParams::CircularBufferingStrategy::Pipelined; + mparams.tiling_strategy = MatmulParams::TilingStrategy::OneTilePerCTA; + mparams.cta_order = MatmulParams::TileRasterizationOrder::ColumnMajor; + mparams.async_gmem_load_operands = true; + mparams.circular_buffer_options.circular_buffer_smem_write = true; + mparams.circular_buffer_options.circular_buffer_smem_read = false; + mparams.circular_buffer_options.smem_circular_buffer_stage = 4; + mparams.circular_buffer_options.smem_circular_buffer_prefetch_gap = 1; + mparams.splitk_factor = 1; + mparams.use_smem_epilogue = true; + mparams.cluster_dims = {1, 1, 1}; + mparams.promote_prologue_smem_reuse = true; + return mparams; +} + TEST_F(HopperMatmulTest, HSH_NT_128BSwizzle) { Fusion fusion; FusionGuard fg(&fusion); @@ -5360,4 +5386,60 @@ TEST_F(HopperMatmulTest, HSH_NT_SingleMathGroupSyncCheck) { cg_outputs[0].as(), out_ref, 1e-6 * K, 1e-6 * K)); } +// See https://github.com/NVIDIA/Fuser/issues/4159 +TEST_F(HopperMatmulTest, HSS_NT_SplitKTMAStore) { + Fusion fusion; + FusionGuard fg(&fusion); + + constexpr int64_t M = 2048, N = 2048, K = 8192; + const auto dtype = DataType::Half; + + auto tv0 = makeContigConcreteTensor({-1, -1, 1}, dtype); // K, M + auto tv1 = makeContigConcreteTensor({-1, 1, -1}, dtype); // K, N + fusion.addInput(tv0); + fusion.addInput(tv1); + + auto tv2 = fusedMultiplySum(tv0, tv1, {0}); + + // Reorder the accumulator as [M, N, K] + // [K, M, N] -> [M, N, K] + tv2->reorder({{-3, -1}}); + tv2->commitLeafToLogical(); + + fusion.addOutput(tv2); + + auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA); + auto t0 = at::randn({K, M, 1}, options); + auto t1 = at::randn({K, 1, N}, options); + auto out_ref = + at::matmul(t0.squeeze().t().to(at::kFloat), t1.squeeze().to(at::kFloat)); + + MatmulParams mparams = defaultHopperParams(); + mparams.use_smem_epilogue = true; + mparams.splitk_factor = 2; + + SchedulerEntry::makeSchedulerInstance(SchedulerType::Matmul) + ->schedule(&fusion, &mparams); + + KernelExecutor ke; + ke.compile(&fusion, {t0, t1}); + // TODO: Either enable stmatrix for 32-bit outputs or fix current 2-way bank + // conflict by scheduling the vectorized store properly + auto bank_conflicts = getBankConflictInfo(ke.compiledKernel()->kernel()); + EXPECT_EQ(bank_conflicts.size(), 1); + for (const auto& [expr, conflict_ways] : bank_conflicts) { + int64_t input_ways, output_ways; + std::tie(input_ways, output_ways) = conflict_ways; + EXPECT_EQ(input_ways, 0); + EXPECT_EQ(output_ways, 2); + } + auto cg_outputs = ke.run({t0, t1}); + ASSERT_FALSE(PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse( + ke.compiledKernel()->kernel())); + + // Relax tolerance for larger sum due to large K + NVF_CHECK(at::allclose( + cg_outputs[0].as(), out_ref, 1e-6 * K, 1e-6 * K)); +} + } // namespace nvfuser From 1bc13d8c7fe6502ca06f91e80b3b9c4140784577 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle <1454944+jacobhinkle@users.noreply.github.com> Date: Thu, 17 Apr 2025 15:13:05 -0400 Subject: [PATCH 13/68] Add NVFUSER_DUMP=sass_to_file option (#4263) This adds the `NVFUSER_DUMP=sass_to_file` option which operates similar to `NVFUSER_DUMP=ptx` and `NVFUSER_DUMP=cuda_to_file`. We currently now have five printing options: - `cuda_kernel` - `cuda_to_file` - `ptx` (actually prints to .ptx file) - `sass` - `sass_to_file` It would probably make sense to make these option names more uniform in a follow-up PR by renaming `cuda_kernel`->`cuda` and `ptx`->`ptx_to_file`. We could also potentially add a new `ptx` dump option that prints to screen, though I would skip this temporarily at least because it could cause unexpected behavior by diverting from file dump to screen, e.g. in CI jobs. --- csrc/options.cpp | 1 + csrc/options.h | 5 +++-- csrc/runtime/compiled_kernel.cpp | 7 +++++++ csrc/runtime/executor_utils.h | 2 ++ 4 files changed, 13 insertions(+), 2 deletions(-) diff --git a/csrc/options.cpp b/csrc/options.cpp index fadf8cef0bb..391919cc825 100644 --- a/csrc/options.cpp +++ b/csrc/options.cpp @@ -139,6 +139,7 @@ std::unordered_map> Options< {"python_definition_segments", DebugDumpOption::PythonDefinitionSegments}, {"python_frontend_debug", DebugDumpOption::PythonFrontendDebug}, {"sass", DebugDumpOption::Sass}, + {"sass_to_file", DebugDumpOption::SassToFile}, {"segmented_fusion", DebugDumpOption::FusionSegments}, {"segmenter_logging", DebugDumpOption::FusionSegmenterLog}, {"scheduler_params", DebugDumpOption::SchedulerDebug}, diff --git a/csrc/options.h b/csrc/options.h index 8e61d2d14d7..c782a8cea14 100644 --- a/csrc/options.h +++ b/csrc/options.h @@ -70,7 +70,8 @@ enum class DebugDumpOption { TransformPropagator, //! When running TransformPropagator, print propagation //! path and replay result Cubin, //! Dump compiled CUBIN - Sass, // Dump disassembled SASS + Sass, //! Dump disassembled SASS + SassToFile, //!< Dump disassembled SASS to File Ptx, //! Dump compiled PTX BankConflictInfo, //! Dump bank confliction info SyncMap, //! RAW dependency info @@ -79,7 +80,7 @@ enum class DebugDumpOption { ExprSort, //! Print merging decisions on expression sorting ExprSortVerbose, //! Print verbose debug info on expression sorting LoopRotation, //! Print loop rotation log - Occupancy, // Dump occupancy + Occupancy, //! Dump occupancy IndexType, //! Print the index type of the launched kernel PredicateElimination, //! Print the predicate elimination information IndexingVerbose, //! Print verbose debug info on indexing diff --git a/csrc/runtime/compiled_kernel.cpp b/csrc/runtime/compiled_kernel.cpp index 9679987881a..9defcbaab9c 100644 --- a/csrc/runtime/compiled_kernel.cpp +++ b/csrc/runtime/compiled_kernel.cpp @@ -719,6 +719,13 @@ std::unique_ptr compileSource( compiled_kernel->cubin_filename = dumpCompiledCodeToFile(compiled_kernel->cubin, func_name, ".cubin"); } + if (isDebugDumpEnabled(DebugDumpOption::SassToFile)) { + std::string sass_str = + disassembleBinary(compiled_kernel->cubin, "-fun 1 -c"); + compiled_kernel->sass = {sass_str.begin(), sass_str.end()}; + compiled_kernel->sass_filename = + dumpCompiledCodeToFile(compiled_kernel->sass, func_name, ".sass"); + } } if (!compile_to_sass || isDebugDumpEnabled(DebugDumpOption::Ptx)) { diff --git a/csrc/runtime/executor_utils.h b/csrc/runtime/executor_utils.h index 9b04b82e85d..843f4ead896 100644 --- a/csrc/runtime/executor_utils.h +++ b/csrc/runtime/executor_utils.h @@ -43,6 +43,8 @@ struct CudaExecutable : public NonCopyable { std::string cubin_filename; std::string kernel_name; std::string compile_args; + std::vector sass; + std::string sass_filename; long block_size = -1; int register_spills = -1; }; From c477a3fa9000045bbbce152d1b90e98c1f4aee48 Mon Sep 17 00:00:00 2001 From: Liqiang Lu <116412316+liqiangxl@users.noreply.github.com> Date: Thu, 17 Apr 2025 15:58:47 -0400 Subject: [PATCH 14/68] disable TmaWarpSpecializedTes, it needs predicate (#4267) This test may fail as the trick using a divisible batch size also depends on circular buffer stages. Diable it and will re-enable after predicate is added. --- tests/cpp/test_combined_inner_outer_reduction.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/cpp/test_combined_inner_outer_reduction.cpp b/tests/cpp/test_combined_inner_outer_reduction.cpp index 4c5a1591628..88825ae9237 100644 --- a/tests/cpp/test_combined_inner_outer_reduction.cpp +++ b/tests/cpp/test_combined_inner_outer_reduction.cpp @@ -1164,7 +1164,7 @@ INSTANTIATE_TEST_SUITE_P( , TmaWarpSpecializedTest, ::testing::Combine( - testing::Values(true, false), + testing::Values(false), // tmp disable tma warp specialized testing::Values(DataType::Float, DataType::BFloat16), testing::Values(132 * 148), ::testing::Range((int64_t)1024, (int64_t)8193, (int64_t)1024)), From f3b22abc7251c3538dd71aae99f6235a6ecd42d4 Mon Sep 17 00:00:00 2001 From: Ryan Spring Date: Thu, 17 Apr 2025 17:35:44 -0700 Subject: [PATCH 15/68] Create separate AsyncGroup helpers for fence, commit, and wait operations (#4271) For better pipelining of TMA store and `wgmma`, we should treat the `fence`, `commit`, and `wait` phases of `AsyncGroup` separately. * Replace `getSyncExprs` with `getAsyncCommit` and `getAsyncWait`. * Create `getAsyncFence` for `WgMmaFence` and `FenceAsyncProxy`. * No change to the CUDA kernel in this PR. ### Current ToT for `HopperMatmulTest/MLPGemmPersistentBroadcastInputs.NumWarpGroups/2` ```cuda #pragma unroll for(nvfuser_index_t i57 = 0; i57 < 16; ++i57) { if ((b42 && (i43 < (-(16 * i57))))) { stmatrix4((uint32_t)((toSmem(T8) + ((((nvfuser_index_t)threadIdx.y) * 32768) + (((i57 / 4) * 8192) + ((i12 * 128) + (((((((nvfuser_index_t)threadIdx.x) % 32) / 16) + ((i57 % 4) * 2)) ^ (i12 % 8)) * 16)))))), (*reinterpret_cast*>(&T7[(8 * i57)]))); } } block_sync::sync(dim3(128, 2, 1)); #pragma unroll for(nvfuser_index_t i58 = 0; i58 < 4; ++i58) { fenceAsyncProxy(); if (((Hopper::electSync(4294967295U) && b16) && b19)) { Hopper::cpAsyncBulkTensorTileS2G((Hopper::CpAsyncBulkTensorTileS2GIndex<2>{ ptr14, (Array{(int32_t)((i39 + (64 * i58))), i41}) }), (i13 + (8192 * i58))); } } block_sync::sync(dim3(128, 2, 1)); cpAsyncBulkCommitGroup(); cpAsyncBulkWaitGroup<0LL>(); ``` Notice that we want WAR sync before `stmatrix` to overlap TMA store with `wgmma`. The `commit` and `wait` phase for TMA store are separate. ### Optimized `HopperMatmulTest/MLPGemmPersistentBroadcastInputs.NumWarpGroups/2` ```cuda cpAsyncBulkWaitGroup<0LL>(); #pragma unroll for(nvfuser_index_t i57 = 0; i57 < 16; ++i57) { if ((b42 && (i43 < (-(16 * i57))))) { stmatrix4((uint32_t)((toSmem(T8) + ((((nvfuser_index_t)threadIdx.y) * 32768) + (((i57 / 4) * 8192) + ((i12 * 128) + (((((((nvfuser_index_t)threadIdx.x) % 32) / 16) + ((i57 % 4) * 2)) ^ (i12 % 8)) * 16)))))), (*reinterpret_cast*>(&T7[(8 * i57)]))); } } block_sync::sync(dim3(128, 2, 1)); fenceAsyncProxy(); #pragma unroll for(nvfuser_index_t i58 = 0; i58 < 4; ++i58) { if (((Hopper::electSync(4294967295U) && b16) && b19)) { Hopper::cpAsyncBulkTensorTileS2G((Hopper::CpAsyncBulkTensorTileS2GIndex<2>{ ptr14, (Array{(int32_t)((i39 + (64 * i58))), i41}) }), (i13 + (8192 * i58))); cpAsyncBulkCommitGroup(); } } ``` --- csrc/device_lower/pass/insert_syncs.cpp | 50 +++++++++++++++++++------ csrc/device_lower/utils.cpp | 15 -------- csrc/device_lower/utils.h | 10 ----- 3 files changed, 39 insertions(+), 36 deletions(-) diff --git a/csrc/device_lower/pass/insert_syncs.cpp b/csrc/device_lower/pass/insert_syncs.cpp index 3a6b11c7ca1..cad4ac8c9ce 100644 --- a/csrc/device_lower/pass/insert_syncs.cpp +++ b/csrc/device_lower/pass/insert_syncs.cpp @@ -54,6 +54,30 @@ std::optional isOptionalComputeSync( } } +// Commit a series of operations to an async group. +// Create wgmma.fence for AsyncOpType::WgMma +// Otherwise, create fence.proxy.async +Expr* getAsyncFence(AsyncOpType async_type) { + if (async_type == AsyncOpType::WgMma) { + return IrBuilder::create(); + } + return IrBuilder::create(); +} + +// Commit a series of operations to an async group. +// Create wgmma.commit_group.sync.aligned for AsyncOpType::WgMma +// Create cpAsyncBulkCommitGroup for AsyncOpType::CpAsyncBulk +Expr* getAsyncCommit(AsyncOpType async_type) { + return IrBuilder::create(async_type); +} + +// Wait for a number of async groups to finish. +// Create wgmma.wait_group.sync.aligned for AsyncOpType::WgMma +// Create cpAsyncBulkWaitGroup for AsyncOpType::CpAsyncBulk +Expr* getAsyncWait(AsyncOpType async_type, int64_t keep_stages = 0) { + return IrBuilder::create(async_type, keep_stages); +} + // Tensor memory is similar to shared memory because they are both // shared between threads in a block. In that sense, we can consider // tensor memory as special type of shared memory. In this file, we use @@ -449,8 +473,7 @@ class ReadAfterWriteSyncs : public kir::ExprMutator { // except when these are accumulator register accesses across multiple // wgmma.mma_async instructions of the same shape. In the latter case, // an ordering guarantee is provided by default. - auto wgmma_fence = IrBuilder::create(); - registerInsertBefore(expr, wgmma_fence, scope); + registerInsertBefore(expr, getAsyncFence(AsyncOpType::WgMma), scope); if (!lower_utils::allMmaInputsGuardedByMBarrier(mma)) { // fence.proxy.async makes sure that writes to operands in the generic // proxy are visible to the async proxy @@ -522,10 +545,11 @@ class ReadAfterWriteSyncs : public kir::ExprMutator { } } for (const auto& [async_type, ops] : input_async_ops) { - auto sync_exprs = lower_utils::getSyncExprs( - async_type, - /*keep_stages=*/0, - /*requires_commit=*/async_type != AsyncOpType::WgMma); + std::vector sync_exprs; + if (async_type != AsyncOpType::WgMma) { + sync_exprs.push_back(getAsyncCommit(async_type)); + } + sync_exprs.push_back(getAsyncWait(async_type, /*keep_stages=*/0)); for (auto sync_expr : sync_exprs) { insertSyncExpr(ops, expr, sync_expr, nullptr); } @@ -843,8 +867,9 @@ class ReadAfterWriteSyncs : public kir::ExprMutator { // kernel. for (auto expr : async_exprs_writing_fusion_output_) { auto async_type = ir_utils::getAsyncOpType(expr); - auto sync_exprs = - lower_utils::getSyncExprs(async_type, /*keep_stages=*/0); + std::vector sync_exprs{ + getAsyncCommit(async_type), + getAsyncWait(async_type, /*keep_stages=*/0)}; exprs_.insert(exprs_.end(), sync_exprs.begin(), sync_exprs.end()); } @@ -1225,8 +1250,9 @@ class WarAsyncWaitInserter : private kir::ExprMutator { active_compute_for_loop_->iter_domain()); int64_t pending_ops = opt.stage - opt.prefetch - 1; - auto sync_exprs = - lower_utils::getSyncExprs(AsyncOpType::WgMma, pending_ops); + std::vector sync_exprs{ + getAsyncCommit(AsyncOpType::WgMma), + getAsyncWait(AsyncOpType::WgMma, /*keep_stages=*/pending_ops)}; size_t num_exprs = for_loop->body().exprs().size(); NVF_ERROR(num_exprs > 1); NVF_ERROR(for_loop->body().exprs().back()->isA()); @@ -1308,7 +1334,9 @@ class WarAsyncWaitInserter : private kir::ExprMutator { // Actually insert these wait expressions. for (auto [type, pending_ops] : types_and_pending_ops_to_protect) { - auto sync_exprs = lower_utils::getSyncExprs(type, pending_ops); + std::vector sync_exprs{ + getAsyncCommit(type), + getAsyncWait(type, /*keep_stages=*/pending_ops)}; NVF_ERROR(!for_loop->body().exprs().empty()); // Default position is last expression in for loop diff --git a/csrc/device_lower/utils.cpp b/csrc/device_lower/utils.cpp index 936cfc11513..4bd53083c6f 100644 --- a/csrc/device_lower/utils.cpp +++ b/csrc/device_lower/utils.cpp @@ -2099,21 +2099,6 @@ bool allMmaInputsGuardedByMBarrier(const MmaOp* mma) { ir_utils::isCpAsyncBulkLoad(ir_utils::getTv(mma->inB())->definition()); } -std::vector getSyncExprs( - AsyncOpType async_type, - int64_t keep_stages, - bool requires_commit) { - std::vector sync_exprs; - sync_exprs.reserve(2); - if (requires_commit) { - auto commit = IrBuilder::create(async_type); - sync_exprs.push_back(commit); - } - auto wait = IrBuilder::create(async_type, keep_stages); - sync_exprs.push_back(wait); - return sync_exprs; -} - } // namespace lower_utils } // namespace nvfuser diff --git a/csrc/device_lower/utils.h b/csrc/device_lower/utils.h index b45c7e2e3f6..7abc0ab6bfc 100644 --- a/csrc/device_lower/utils.h +++ b/csrc/device_lower/utils.h @@ -378,16 +378,6 @@ struct IterDomainDependencySorter { // Check if all the inputs of the given MmaOp is guarded by mbarrier bool allMmaInputsGuardedByMBarrier(const MmaOp* mma); -// Create a list of expressions that will be used to wait for async operations. -// For example, if op_type is AsyncOpType::WgMma, then the returned expressions -// will be: -// wgmma.commit_group.sync.aligned -// wgmma.wait_group.sync.aligned -std::vector getSyncExprs( - AsyncOpType async_type, - int64_t keep_stages = 0, - bool requires_commit = true); - } // namespace lower_utils } // namespace nvfuser From c1d8423d0fada31ebd7433d8bbc58103d3aa6dde Mon Sep 17 00:00:00 2001 From: Ryan Spring Date: Thu, 17 Apr 2025 19:35:58 -0700 Subject: [PATCH 16/68] Rename LoadWarp to AsyncWarp (#4270) * Blackwell UTCMMA is an mbarrier async operation that is not a TMA load operation. * It will be included in the separate warp group and run in parallel with TMA loads for the operands. * This new name is more consistent with what the warp group does. --- csrc/codegen.cpp | 4 +-- .../analysis/predicate_elimination.cpp | 2 +- csrc/device_lower/pass/circular_buffer.cpp | 14 ++++----- csrc/device_lower/pass/insert_syncs.cpp | 30 +++++++++---------- csrc/ir/interface_nodes.h | 4 +-- csrc/kernel_ir.h | 2 +- csrc/parallel_dimension_map.h | 4 +-- csrc/predicate_compute.cpp | 26 ++++++++-------- csrc/scheduler/hopper_multi_matmul.cpp | 4 +-- csrc/type.cpp | 4 +-- csrc/type.h | 6 ++-- 11 files changed, 50 insertions(+), 50 deletions(-) diff --git a/csrc/codegen.cpp b/csrc/codegen.cpp index 0882684a11d..3bdaa49aea8 100644 --- a/csrc/codegen.cpp +++ b/csrc/codegen.cpp @@ -1283,7 +1283,7 @@ class CudaKernelGenerator : private kir::ConstIrVisitor { tidy->value().as() + tidz->value().as(); NVF_ERROR( num_threads == 128, - "Expected 128 threads in LoadWarp, but found ", + "Expected 128 threads in AsyncWarp, but found ", num_threads); NVF_ERROR(pdim_map.hasWarpSpecialization()); ss << "dim3(" << genInlineOrOne(tidx) << ", " << genInlineOrOne(tidy) @@ -3557,7 +3557,7 @@ class CudaKernelGenerator : private kir::ConstIrVisitor { indent() << "block_sync::sync();\n"; } else if (isAligned()) { indent() << "__syncthreads();\n"; - } else if (sync->isLoadWarpSync()) { + } else if (sync->isAsyncWarpSync()) { ArgumentBuilder template_args; template_args.arg(isAligned()); ArgumentBuilder func_args; diff --git a/csrc/device_lower/analysis/predicate_elimination.cpp b/csrc/device_lower/analysis/predicate_elimination.cpp index eb230c38321..db77253bed7 100644 --- a/csrc/device_lower/analysis/predicate_elimination.cpp +++ b/csrc/device_lower/analysis/predicate_elimination.cpp @@ -75,7 +75,7 @@ namespace { bool isComputeWarp(TensorView* consumer, IterDomain* id_in_consumer) { // TODO: This function can not find all the expressions in the compute // warp. For example, if we have: - // if (load warp) { + // if (async warp) { // T1 = T0; // } else { // T2 = T1; diff --git a/csrc/device_lower/pass/circular_buffer.cpp b/csrc/device_lower/pass/circular_buffer.cpp index 83626d653d1..cee3b5cbae5 100644 --- a/csrc/device_lower/pass/circular_buffer.cpp +++ b/csrc/device_lower/pass/circular_buffer.cpp @@ -106,7 +106,7 @@ class CircularBufferLoopCloner : public kir::IrVisitor { SimplifyingIrBuilder::create(opt.prefetch, DataType::Index)); break; } - case CircularBufferLoopStage::LoadWarp: + case CircularBufferLoopStage::AsyncWarp: case CircularBufferLoopStage::ComputeWarp: { break; } @@ -1462,24 +1462,24 @@ class CircularBufferInserter : private kir::ExprMutator { .num_registers.value(); GpuLower::current()->decIncRegisterUsage() = std::make_pair(decrease_num_registers, increase_num_registers); - // Decrease registers in load warp group - kir::SetMaxNReg* dec_reg_load_warp = IrBuilder::create( + // Decrease registers in async warp group + kir::SetMaxNReg* dec_reg_async_warp = IrBuilder::create( IrBuilder::create(decrease_num_registers, DataType::Index), /*increase_registers=*/false); - warp_dispatch_ite->thenBody().push_back(dec_reg_load_warp); + warp_dispatch_ite->thenBody().push_back(dec_reg_async_warp); // Increase registers in compute warp group - kir::SetMaxNReg* inc_reg_load_warp = IrBuilder::create( + kir::SetMaxNReg* inc_reg_async_warp = IrBuilder::create( IrBuilder::create(increase_num_registers, DataType::Index), /*increase_registers*/ true); - warp_dispatch_ite->elseBody().push_back(inc_reg_load_warp); + warp_dispatch_ite->elseBody().push_back(inc_reg_async_warp); } // Load loop: ForLoop* load_loop = CloneTmaCircularBufferLoopAndInsertSync::clone( circular_buffer_loop, loads, - CircularBufferLoopStage::LoadWarp, + CircularBufferLoopStage::AsyncWarp, insertion_position); warp_dispatch_ite->thenBody().push_back(load_loop); diff --git a/csrc/device_lower/pass/insert_syncs.cpp b/csrc/device_lower/pass/insert_syncs.cpp index cad4ac8c9ce..e3223da7904 100644 --- a/csrc/device_lower/pass/insert_syncs.cpp +++ b/csrc/device_lower/pass/insert_syncs.cpp @@ -20,10 +20,10 @@ namespace nvfuser { namespace { -// Determine if any for loop is a LoadWarp circular buffering stage -bool isWithinLoadWarp(const std::vector for_loops) { +// Determine if any for loop is a AsyncWarp circular buffering stage +bool isWithinAsyncWarp(const std::vector for_loops) { return std::any_of(for_loops.begin(), for_loops.end(), [](ForLoop* fl) { - return fl->circularBufferLoopStage() == CircularBufferLoopStage::LoadWarp; + return fl->circularBufferLoopStage() == CircularBufferLoopStage::AsyncWarp; }); } @@ -36,16 +36,16 @@ bool isWithinComputeWarp(const std::vector for_loops) { } // Return true if any for loop is ComputeWarp. -// Return false if any for loop is LoadWarp. +// Return false if any for loop is AsyncWarp. // Return std:nullopt if none of the for loops are a warp specialized stage. std::optional isOptionalComputeSync( const std::vector for_loops) { - bool contains_load_warp = isWithinLoadWarp(for_loops); + bool contains_async_warp = isWithinAsyncWarp(for_loops); bool contains_compute_warp = isWithinComputeWarp(for_loops); NVF_ERROR( - !contains_load_warp || !contains_compute_warp, - "The list of for-loops contains both LoadWarp and ComputeWarp stages."); - if (isWithinLoadWarp(for_loops)) { + !contains_async_warp || !contains_compute_warp, + "The list of for-loops contains both AsyncWarp and ComputeWarp stages."); + if (isWithinAsyncWarp(for_loops)) { return false; } else if (isWithinComputeWarp(for_loops)) { return true; @@ -988,7 +988,7 @@ class WarAsyncWaitInserter : private kir::ExprMutator { //! Warp Specialization creates an If-Then-Else to separate load and compute //! operations. Therefore, the async_inputs_in_current_scope_ will not contain //! the async inputs for the corresponding async expression. Track async - //! inputs separately when we encounter them in load warp. + //! inputs separately when we encounter them in async warp. std::unordered_set warp_specialized_async_inputs_in_current_scope_; //! Track async exprs separately when we encounter them in compute warp. @@ -1052,7 +1052,7 @@ class WarAsyncWaitInserter : private kir::ExprMutator { return; } - // Gather all async inputs in LoadWarp + // Gather all async inputs in AsyncWarp TensorView* out_tv = ir_utils::getTvOutput(expr); NVF_ERROR(out_tv != nullptr); auto circular_buffer_loop = @@ -1060,7 +1060,7 @@ class WarAsyncWaitInserter : private kir::ExprMutator { out_tv, for_loops_); if (circular_buffer_loop != nullptr && circular_buffer_loop->circularBufferLoopStage() == - CircularBufferLoopStage::LoadWarp) { + CircularBufferLoopStage::AsyncWarp) { auto use_async_ops = getUseAsyncOpTypes(out_tv); if (!use_async_ops.empty()) { warp_specialized_async_inputs_in_current_scope_.emplace(out_tv); @@ -1200,10 +1200,10 @@ class WarAsyncWaitInserter : private kir::ExprMutator { // Special logic is required for warp specialized circular buffering because // the TMA loads and wgmma ops are separated by an IfThenElse. // kir::ExprMutator traverses the fusion in depth-wise order, so TMA loads in - // the LoadWarp are detected before the wgmma expressions in the ComputeWarp. + // the AsyncWarp are detected before the wgmma expressions in the ComputeWarp. // // This function inserts wgmma.commit_group and wgmma.wait_group expressions - // before the mbarrier::arrive, which allows load warp to launch next TMA + // before the mbarrier::arrive, which allows async warp to launch next TMA // load. First, we commit all the wgmma expressions issued in this iteration // of the for-loop. Then, we wait for some number of wgmma expressions based // on number of circular buffer stages and number of prefetch stages. @@ -1223,7 +1223,7 @@ class WarAsyncWaitInserter : private kir::ExprMutator { NVF_ERROR( warp_specialized_async_exprs_to_protect_.empty() || !warp_specialized_async_inputs_in_current_scope_.empty(), - "Expected TMA loads in LoadWarp for WgMma operations were detected in ComputeWarp."); + "Expected TMA loads in AsyncWarp for WgMma operations were detected in ComputeWarp."); // short-circuit: no wgmma expressions to protect in computeWarp. if (warp_specialized_async_exprs_to_protect_.empty()) { @@ -1231,7 +1231,7 @@ class WarAsyncWaitInserter : private kir::ExprMutator { return; } - // Establish all tma loads in LoadWarp are used by WgMma operations in + // Establish all tma loads in AsyncWarp are used by WgMma operations in // ComputeWarp. for (Expr* expr : warp_specialized_async_exprs_to_protect_) { if (ir_utils::isCpAsyncBulkStore(expr)) { diff --git a/csrc/ir/interface_nodes.h b/csrc/ir/interface_nodes.h index cffd277eeb2..0445339012d 100644 --- a/csrc/ir/interface_nodes.h +++ b/csrc/ir/interface_nodes.h @@ -192,7 +192,7 @@ class TVDomainGuard; // if (threadIdx.y == blockDim.y - 1) { // // If we use warp specialization on TIDy, then the blockDim.y of the // // kernel will be (whatever_value_inferred_from_schedule + 1), and the -// // last threadIdx.y will be used as load warp +// // last threadIdx.y will be used as async warp // for i in range(data.size): // wait buffer[i % stage] to be empty // load data[i] to buffer[i % stage] @@ -256,7 +256,7 @@ struct WarpSpecialized { validate_num_registers(num_registers.value().second); NVF_ERROR( num_registers.value().first <= num_registers.value().second, - "The number of registers for load warp group must be <= to the number", + "The number of registers for async warp group must be <= to the number", " of registers for the compute warp groups."); } diff --git a/csrc/kernel_ir.h b/csrc/kernel_ir.h index 2cc5df4deed..a255b389cbb 100644 --- a/csrc/kernel_ir.h +++ b/csrc/kernel_ir.h @@ -516,7 +516,7 @@ class BlockSync final : public Expr { return attribute>(1).value_or(false); } - bool isLoadWarpSync() const { + bool isAsyncWarpSync() const { auto optional_compute_or_load_sync = attribute>(1); return optional_compute_or_load_sync.has_value() && !optional_compute_or_load_sync.value(); diff --git a/csrc/parallel_dimension_map.h b/csrc/parallel_dimension_map.h index 35a85f9383b..ceb5218ec24 100644 --- a/csrc/parallel_dimension_map.h +++ b/csrc/parallel_dimension_map.h @@ -55,7 +55,7 @@ class ParallelDimensionMap { //! for loading circular buffer tensors. Val* getRawLoad(ParallelType pt) const; - //! The padded val ensures that CTA has 128 threads for the LoadWarp. This + //! The padded val ensures that CTA has 128 threads for the AsyncWarp. This //! function returns the padded val for the warp specialized ParallelType. int64_t getWarpSpecializationPaddedVal(ParallelType pt) const; @@ -89,7 +89,7 @@ class ParallelDimensionMap { //! If we are doing warp specialization on pt, then we need to increase //! the parallel dimension size of pt by one, where the extra one is used - //! as the load warp. In this case, pt becomes non-exact. + //! as the async warp. In this case, pt becomes non-exact. void adjustMappingsForWarpSpecialization(); private: diff --git a/csrc/predicate_compute.cpp b/csrc/predicate_compute.cpp index 49b48fbea1e..86c8d362f4c 100644 --- a/csrc/predicate_compute.cpp +++ b/csrc/predicate_compute.cpp @@ -624,33 +624,33 @@ Val* createMultipleExpressionElectSync( const auto& pdim_map = GpuLower::current()->parallelDimensionMap(); // Determine if warp specialized tma load expression. - ParallelType load_warp_on = ParallelType::Serial; - auto load_warp_loop_it = + ParallelType async_warp_on = ParallelType::Serial; + auto async_warp_loop_it = std::find_if(loops.begin(), loops.end(), [](ForLoop* fl) { return fl->circularBufferLoopStage() == - CircularBufferLoopStage::LoadWarp; + CircularBufferLoopStage::AsyncWarp; }); bool is_register_sharing = false; - if (load_warp_loop_it != loops.end()) { + if (async_warp_loop_it != loops.end()) { auto circular_buffer_type = std::get( GpuLower::current() ->circularBufferInfo() - .getCircularBufferOptionsFor((*load_warp_loop_it)->iter_domain()) + .getCircularBufferOptionsFor((*async_warp_loop_it)->iter_domain()) .type); - load_warp_on = circular_buffer_type.on; + async_warp_on = circular_buffer_type.on; is_register_sharing = circular_buffer_type.num_registers.has_value(); } // Short-circuit: register sharing is not used, don't need to pad a full warp - // group. If we are in a load warp, then the warp-dispatching IfThenElse - // already selects on `load_warp_on`, so we should not generate + // group. If we are in a async warp, then the warp-dispatching IfThenElse + // already selects on `async_warp_on`, so we should not generate // predicates for it here. if (!is_register_sharing) { - Val* conditional = load_warp_on == ParallelType::TIDx + Val* conditional = async_warp_on == ParallelType::TIDx ? pred->fusion()->trueVal() : createElectSyncPredicate(); for (auto pt : {ParallelType::TIDy, ParallelType::TIDz}) { - if (pdim_map.has(pt) && load_warp_on != pt) { + if (pdim_map.has(pt) && async_warp_on != pt) { conditional = SimplifyingIrBuilder::logicalAndExpr( conditional, IrBuilder::eqExpr(NamedScalar::getParallelIndex(pt), zero)); @@ -661,20 +661,20 @@ Val* createMultipleExpressionElectSync( // If not specialized on TIDx, load branch has full size of bdimx, // we can use the first warp, otherwise should use the last warp. - bool use_first_warp = load_warp_on != ParallelType::TIDx; + bool use_first_warp = async_warp_on != ParallelType::TIDx; Val* conditional = createElectSyncPredicate(use_first_warp); for (auto pt : {ParallelType::TIDy, ParallelType::TIDz}) { if (!pdim_map.has(pt)) { continue; } - if (load_warp_on != pt) { + if (async_warp_on != pt) { // Not specialized on pt, use the first thread. conditional = SimplifyingIrBuilder::logicalAndExpr( conditional, IrBuilder::eqExpr(NamedScalar::getParallelIndex(pt), zero)); } else { // Specialized on pt, use the last thread. - Val* raw = GpuLower::current()->parallelDimensionMap().get(load_warp_on); + Val* raw = GpuLower::current()->parallelDimensionMap().get(async_warp_on); conditional = SimplifyingIrBuilder::logicalAndExpr( conditional, IrBuilder::eqExpr( diff --git a/csrc/scheduler/hopper_multi_matmul.cpp b/csrc/scheduler/hopper_multi_matmul.cpp index df83886bdf7..b24c353c37d 100644 --- a/csrc/scheduler/hopper_multi_matmul.cpp +++ b/csrc/scheduler/hopper_multi_matmul.cpp @@ -755,12 +755,12 @@ void HopperMultipleMatmulScheduler::setUpCircularBuffering() { // register properly in that case. cb_type = (CircularBufferType)WarpSpecialized(ParallelType::TIDy); } else { - constexpr int64_t num_registers_load_warp = 40; + constexpr int64_t num_registers_async_warp = 40; constexpr int64_t num_registers_compute_warp = 232; cb_type = (CircularBufferType)WarpSpecialized( ParallelType::TIDy, std::make_pair( - num_registers_load_warp, num_registers_compute_warp)); + num_registers_async_warp, num_registers_compute_warp)); } break; } diff --git a/csrc/type.cpp b/csrc/type.cpp index d1a5b2abd80..c3664fc6805 100644 --- a/csrc/type.cpp +++ b/csrc/type.cpp @@ -1629,8 +1629,8 @@ std::ostream& operator<<( case CircularBufferLoopStage::Epilog: os << "{CircularBufferEpilog}"; break; - case CircularBufferLoopStage::LoadWarp: - os << "{LoadWarp}"; + case CircularBufferLoopStage::AsyncWarp: + os << "{AsyncWarp}"; break; case CircularBufferLoopStage::ComputeWarp: os << "{ComputeWarp}"; diff --git a/csrc/type.h b/csrc/type.h index 6a032aa5495..cb8561b25b8 100644 --- a/csrc/type.h +++ b/csrc/type.h @@ -819,7 +819,7 @@ enum class CircularBufferLoopStage { Prolog = 0, Main, Epilog, - LoadWarp, + AsyncWarp, ComputeWarp, EndOfStages, // A special placeholder used to iterate over all stages NotApplicable @@ -831,7 +831,7 @@ enum class CircularBufferLoopStage { inline bool hasCircularBufferLoad(CircularBufferLoopStage stage) { return stage == CircularBufferLoopStage::Prolog || stage == CircularBufferLoopStage::Main || - stage == CircularBufferLoopStage::LoadWarp; + stage == CircularBufferLoopStage::AsyncWarp; } // The consuming expressions of circular buffer are cloned for these circular @@ -851,7 +851,7 @@ inline bool hasCircularBufferConsume(CircularBufferLoopStage stage) { // somewhere (*may or may not be in this loop*) inline bool mayHaveWarHazard(CircularBufferLoopStage stage) { return stage == CircularBufferLoopStage::Main || - stage == CircularBufferLoopStage::LoadWarp || + stage == CircularBufferLoopStage::AsyncWarp || stage == CircularBufferLoopStage::ComputeWarp; } From 68511631fabe8d900229bea75dbca6e320f2e60e Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Thu, 17 Apr 2025 20:39:34 -0700 Subject: [PATCH 17/68] Remove stale exprs (#4268) Follow-up to #3776 This PR removes duplicated cast exprs from `SegmentedGroup`. There's nothing wrong leaving them there as they won't be picked up anyway, but `stablyOrderedExprs()`, which is used when printing `SegmentedFusion`, complains. Not sure why this could cause codegen diffs. --- csrc/fusion_segmenter.cpp | 2 ++ tests/cpp/test_segmentation.cpp | 24 ++++++++++++++++++------ 2 files changed, 20 insertions(+), 6 deletions(-) diff --git a/csrc/fusion_segmenter.cpp b/csrc/fusion_segmenter.cpp index cdb241b3517..ada881996b7 100644 --- a/csrc/fusion_segmenter.cpp +++ b/csrc/fusion_segmenter.cpp @@ -4375,6 +4375,8 @@ void SegmentCandidateFinder::revertPrivatizedUpcast(SegmentedGroup* group) { maybe_deduplicate_edge(consumer_edge_to_update); } + std::erase(group->exprs_, uop); + // Note that it should not be necessary to do anything with // group->output_vals since the inserted upcast ops should never produce // fusion outputs. diff --git a/tests/cpp/test_segmentation.cpp b/tests/cpp/test_segmentation.cpp index 27ae0dbefbd..24db8d7793a 100644 --- a/tests/cpp/test_segmentation.cpp +++ b/tests/cpp/test_segmentation.cpp @@ -750,19 +750,31 @@ TEST_F(NVFuserTest, RevertPrivatizedUpcast) { fusion.addInput(tv0); auto tv1 = segment_set(tv0); - auto tv2 = castOp(DataType::Float, tv1); - auto tv3 = sum(tv2, {1}); - fusion.addOutput(tv3); + auto tv2 = set(tv1); + auto tv3 = castOp(DataType::Float, tv2); - auto tv4 = sum(tv2, {1}); + auto tv4 = sum(tv3, {1}); fusion.addOutput(tv4); + auto tv5 = sum(tv3, {1}); + fusion.addOutput(tv5); + auto options = at::TensorOptions().dtype(at::kBFloat16).device(at::kCUDA, 0); auto t0 = at::randn({16, 32}, options); FusionExecutorCache executor_cache(std::move(fusion_ptr)); - auto outputs = executor_cache.runFusionWithInputs({t0}); + KernelArgumentHolder outputs; + + // Make sure NVFUSER_DUMP=segmented_fusion works + { + DebugDumpOptionsGuard options_guard; + DebugDumpOptionsGuard::getCurOptions().set(DebugDumpOption::FusionSegments); + std::ostringstream tmp_buf; + DebugStreamGuard debug_stream_guard(tmp_buf); + outputs = executor_cache.runFusionWithInputs({t0}); + } + testValidate(&fusion, outputs, {t0}, __LINE__, __FILE__); // There must be two segments, one with ExprEvalExecutor and another @@ -787,7 +799,7 @@ TEST_F(NVFuserTest, RevertPrivatizedUpcast) { continue; } - EXPECT_EQ(uop->in()->as()->view()->name(), 1); + EXPECT_EQ(uop->in()->as()->view()->name(), 2); ++num_upcast_ops; } From 5494b0ab837d62474affe5d355e4d4ac45f1b87f Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Thu, 17 Apr 2025 22:56:43 -0700 Subject: [PATCH 18/68] Unskip the DeepSeek test (#4273) --- tests/python/test_deepseek_v3.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/python/test_deepseek_v3.py b/tests/python/test_deepseek_v3.py index 1c9f2acb6f7..d13b853e706 100644 --- a/tests/python/test_deepseek_v3.py +++ b/tests/python/test_deepseek_v3.py @@ -2,7 +2,6 @@ # All rights reserved. # SPDX-License-Identifier: BSD-3-Clause -import pytest import transformers import torch from contextlib import contextmanager @@ -25,7 +24,10 @@ def default_tensor_type(dtype=torch.float32, device="cpu"): torch.set_default_device(prev_device) -@pytest.mark.skip(reason="flaky on CI due to download timeout: http://nv/eCm") +# This test timed out once when downloading +# "/deepseek-ai/DeepSeek-V3/resolve/main/configuration_deepseek.py" (cf. +# http://nv/eCm). I consider this a one-off, but please let me know if this +# error becomes consistent. def test_transformer_layer(): config = transformers.AutoConfig.from_pretrained( "deepseek-ai/deepseek-v3", trust_remote_code=True From 1181eac5f93fdc1ff91a9f6182246950e2825008 Mon Sep 17 00:00:00 2001 From: snordmann Date: Fri, 18 Apr 2025 04:41:13 -0700 Subject: [PATCH 19/68] minor improvements and cleanup --- csrc/host_ir/pass/stream_parallel_type.cpp | 143 ++++++++++++--------- csrc/preseg_passes/optimization_pass.h | 2 - csrc/type.cpp | 2 +- 3 files changed, 84 insertions(+), 63 deletions(-) diff --git a/csrc/host_ir/pass/stream_parallel_type.cpp b/csrc/host_ir/pass/stream_parallel_type.cpp index f1419e3c626..e72fb2ac5ee 100644 --- a/csrc/host_ir/pass/stream_parallel_type.cpp +++ b/csrc/host_ir/pass/stream_parallel_type.cpp @@ -39,55 +39,12 @@ IterDomain* getStreamAxis(const std::vector& domain) { return ret; } -// StreamParallelType pass implementation. -// This pass handles stream parallelization of operations in a fusion. -// It works by: -// 1. Identifying stream-parallelized axes in tensor operations -// 2. Grouping compatible operations into stream-parallel for-loops -// 3. Setting up proper stream synchronization and management -// -// The pass ensures that: -// - Input tensors don't have stream axes -// - Only one stream axis exists per tensor -// - Stream axes are properly synchronized -// - Operations are correctly grouped into stream-parallel regions -// - The resulting HostIrContainer's top level expression is valid for execution -// and does not contain any stream axes -// -// TODO: Here, we assume that the fusion input is a HostIrContainer and use the -// linear structure of the HostIrContainer::topLevelExpr to greedily merge the -// adjacent compatible stream for-loop bodies. Ideally we should look at the dag -// and use the segmenter. -void StreamParallelType::runPass(Fusion* fusion) { - // Verify that input tensors don't have stream axes - NVF_CHECK( - std::all_of( - fusion->inputs().begin(), - fusion->inputs().end(), - [](Val* input) { - auto input_tv = dynamic_cast(input); - return input_tv == nullptr || - getStreamAxis(input_tv->getLoopDomain()) == nullptr; - }), - "Expected no stream axis in the TensorView inputs."); - - // Set up the fusion environment and build the ID model - FusionGuard fg(fusion); - hir::HostIrContainer* hic = dynamic_cast(fusion); - NVF_CHECK(hic, "Expected HostIrContainer"); - - IdModel id_model(fusion); - id_model.buildAlmostExactGraph(); - +// Step 1: Group expressions into stream-parallel regions +std::vector groupStreamParallelRegions( + hir::HostIrContainer* hic, + const IdModel& id_model) { std::vector new_top_level_exprs; - // Step 1: Group expressions into stream-parallel regions - // This step identifies which expressions can be merged into single stream - // for-loops - // - // After this step, new_top_level_exprs contains a - // list of expressions including newly created for-loops representing - // the stream parallelization containing and the relevant expressions for (auto expr : hic->topLevelExprs()) { // Skip expressions with no outputs if (expr->outputs().size() == 0) { @@ -130,15 +87,16 @@ void StreamParallelType::runPass(Fusion* fusion) { // Verify stream axis is an iteration axis (not reduction/broadcast) NVF_CHECK( - stream_axis->getIterType() == IterType::Iteration, + stream_axis->getIterType() == IterType::Iteration || + stream_axis->getIterType() == IterType::Broadcast, "Stream axis ", stream_axis, - " should be an iteration axis."); + " should be an iteration or broadcast axis."); // Check if expression can be merged with previous stream for-loop if (!new_top_level_exprs.empty() && new_top_level_exprs.back()->isA() && - id_model.idGraph(IdMappingMode::ALMOSTEXACT) + id_model.idGraph(IdMappingMode::BROADCAST) .disjointValSets() .strictAreMapped( stream_axis, @@ -149,7 +107,7 @@ void StreamParallelType::runPass(Fusion* fusion) { // Create new for-loop for stream parallelization auto* for_loop = IrBuilder::create( stream_axis, - /*index=*/IrBuilder::create(DataType::Index), + /*index=*/NamedScalar::getParallelIndex(ParallelType::Stream), /*start=*/hic->zeroVal(), /*stop=*/stream_axis->extent(), /*step=*/hic->oneVal(), @@ -164,10 +122,15 @@ void StreamParallelType::runPass(Fusion* fusion) { } } - // Step 2: Process each for-loop's body by slicing tensors - // This step handles the actual tensor slicing for stream parallelization - std::vector top_level_exprs = std::move(new_top_level_exprs); - new_top_level_exprs.clear(); + return new_top_level_exprs; +} + +// Step 2: Process for-loop bodies by slicing tensors +std::vector processForLoopBodies( + hir::HostIrContainer* hic, + const IdModel& id_model, + std::vector top_level_exprs) { + std::vector new_top_level_exprs; for (auto top_level_expr : top_level_exprs) { if (!top_level_expr->isA()) { @@ -190,7 +153,7 @@ void StreamParallelType::runPass(Fusion* fusion) { // Find stream axis index in input tensor int64_t input_stream_id_logical_index = -1; for (auto id : input->getLoopDomain()) { - if (id_model.idGraph(IdMappingMode::ALMOSTEXACT) + if (id_model.idGraph(IdMappingMode::BROADCAST) .disjointValSets() .strictAreMapped(for_loop->iterDomain(), id)) { // Verify only one stream axis exists @@ -254,7 +217,7 @@ void StreamParallelType::runPass(Fusion* fusion) { // Find stream axis index in output tensor int64_t output_stream_id_logical_index = -1; for (auto id : output->getLoopDomain()) { - if (id_model.idGraph(IdMappingMode::ALMOSTEXACT) + if (id_model.idGraph(IdMappingMode::BROADCAST) .disjointValSets() .strictAreMapped(for_loop->iterDomain(), id)) { // Verify only one stream axis exists @@ -334,9 +297,16 @@ void StreamParallelType::runPass(Fusion* fusion) { new_top_level_exprs.push_back(top_level_expr); } - // Step 3: Add stream management and synchronization - for (auto* top_level_expr : new_top_level_exprs) { + return new_top_level_exprs; +} + +// Step 3: Add stream management and synchronization +std::vector addStreamManagement(std::vector top_level_exprs) { + std::vector new_top_level_exprs; + + for (auto* top_level_expr : top_level_exprs) { if (!top_level_expr->isA()) { + new_top_level_exprs.push_back(top_level_expr); continue; } auto* for_loop = top_level_expr->as(); @@ -377,10 +347,63 @@ void StreamParallelType::runPass(Fusion* fusion) { for (auto* expr : new_loop_body) { for_loop->body().push_back(expr); } + new_top_level_exprs.push_back(top_level_expr); } + return new_top_level_exprs; +} + +// StreamParallelType pass implementation. +// This pass handles stream parallelization of operations in a fusion. +// It works by: +// 1. Identifying stream-parallelized axes in tensor operations +// 2. Grouping compatible operations into stream-parallel for-loops +// 3. Setting up proper stream synchronization and management +// +// The pass ensures that: +// - Input tensors don't have stream axes +// - Only one stream axis exists per tensor +// - Stream axes are properly synchronized +// - Operations are correctly grouped into stream-parallel regions +// - The resulting HostIrContainer's top level expression is valid for execution +// and does not contain any stream axes +// +// TODO: Here, we assume that the fusion input is a HostIrContainer and use the +// linear structure of the HostIrContainer::topLevelExpr to greedily merge the +// adjacent compatible stream for-loop bodies. Ideally we should look at the dag +// and use the segmenter. +void StreamParallelType::runPass(Fusion* fusion) { + // Verify that input tensors don't have stream axes + NVF_CHECK( + std::all_of( + fusion->inputs().begin(), + fusion->inputs().end(), + [](Val* input) { + auto input_tv = dynamic_cast(input); + return input_tv == nullptr || + getStreamAxis(input_tv->getLoopDomain()) == nullptr; + }), + "Expected no stream axis in the TensorView inputs."); + + // Set up the fusion environment and build the ID model + FusionGuard fg(fusion); + hir::HostIrContainer* hic = dynamic_cast(fusion); + NVF_CHECK(hic, "Expected HostIrContainer"); + + IdModel id_model(fusion); + id_model.buildBroadcastGraph(); + + // Step 1: Group expressions into stream-parallel regions + std::vector top_level_exprs = groupStreamParallelRegions(hic, id_model); + + // Step 2: Process for-loop bodies by slicing tensors + top_level_exprs = processForLoopBodies(hic, id_model, std::move(top_level_exprs)); + + // Step 3: Add stream management and synchronization + top_level_exprs = addStreamManagement(std::move(top_level_exprs)); + // Update the container's top-level expressions - hic->resetTopLevelExprs(new_top_level_exprs); + hic->resetTopLevelExprs(top_level_exprs); } } // namespace nvfuser::preseg_passes diff --git a/csrc/preseg_passes/optimization_pass.h b/csrc/preseg_passes/optimization_pass.h index 53d8a8acd3c..359a4a42742 100644 --- a/csrc/preseg_passes/optimization_pass.h +++ b/csrc/preseg_passes/optimization_pass.h @@ -18,8 +18,6 @@ namespace nvfuser::preseg_passes { -using FusionPass = std::function; - //! [experimental API] //! Base class to unify optimization pass APIs. //! OptimizationPass can be turned on/off programmatically with the `setEnabled` diff --git a/csrc/type.cpp b/csrc/type.cpp index d1a5b2abd80..e4a89372d52 100644 --- a/csrc/type.cpp +++ b/csrc/type.cpp @@ -729,7 +729,7 @@ static const char* parallel_type2string(ParallelType t) { case ParallelType::TIDx: return "threadIdx.x"; case ParallelType::Stream: - return "Stream"; + return "StreamIdx"; case ParallelType::Vectorize: return "V"; case ParallelType::Unroll: From cad9bce67e39678d3a972bdaa8098e16d84f0206 Mon Sep 17 00:00:00 2001 From: snordmann Date: Fri, 18 Apr 2025 05:17:31 -0700 Subject: [PATCH 20/68] further refactor of stream pass --- csrc/host_ir/pass/stream_parallel_type.cpp | 192 ++++++++++----------- 1 file changed, 91 insertions(+), 101 deletions(-) diff --git a/csrc/host_ir/pass/stream_parallel_type.cpp b/csrc/host_ir/pass/stream_parallel_type.cpp index e72fb2ac5ee..d40c9cff147 100644 --- a/csrc/host_ir/pass/stream_parallel_type.cpp +++ b/csrc/host_ir/pass/stream_parallel_type.cpp @@ -20,9 +20,8 @@ namespace nvfuser::preseg_passes { -// Helper function to find the first stream-parallelized axis in a domain. -// This function throws if multiple stream-parallelized axes are found (only one -// is allowed) +namespace { + IterDomain* getStreamAxis(const std::vector& domain) { IterDomain* ret = nullptr; for (auto id : domain) { @@ -39,6 +38,83 @@ IterDomain* getStreamAxis(const std::vector& domain) { return ret; } +void validateStreamAxis(IterDomain* stream_axis, const TensorView* tv) { + // Find the stream axis in the logical domain + auto it_logical_stream_axis = std::find( + tv->getLogicalDomain().begin(), + tv->getLogicalDomain().end(), + stream_axis); + + // Verify stream axis is not split/merged + NVF_ERROR( + it_logical_stream_axis != tv->getLogicalDomain().end(), + "Cannot stream parallelize on a split/merge axis ", + stream_axis); + + // Verify stream axis is an iteration or broadcast axis + NVF_CHECK( + stream_axis->getIterType() == IterType::Iteration || + stream_axis->getIterType() == IterType::Broadcast, + "Stream axis ", + stream_axis, + " should be an iteration or broadcast axis."); +} + +bool areIdsMapped(const IdModel& id_model, IterDomain* id1, IterDomain* id2) { + return id_model.idGraph(IdMappingMode::BROADCAST) + .disjointValSets() + .strictAreMapped(id1, id2); +} + +bool canMergeWithPreviousForLoop( + const std::vector& new_top_level_exprs, + IterDomain* stream_axis, + const IdModel& id_model) { + return !new_top_level_exprs.empty() && + new_top_level_exprs.back()->isA() && + areIdsMapped( + id_model, + stream_axis, + new_top_level_exprs.back()->as()->iterDomain()); +} + +int64_t findStreamAxisIndex( + const TensorView* tv, + IterDomain* stream_axis, + const IdModel& id_model) { + int64_t stream_id_logical_index = -1; + for (auto id : tv->getLoopDomain()) { + if (areIdsMapped(id_model, stream_axis, id)) { + // Verify only one stream axis exists + NVF_CHECK( + stream_id_logical_index == -1, + "Expected at most one axis mapping to the stream axis ", + stream_axis, + " in the tensor ", + tv, + " loop's domain ", + tv->getLoopDomain()); + + // Find stream axis in logical domain + auto it_stream_id_logical = std::find( + tv->getLogicalDomain().begin(), + tv->getLogicalDomain().end(), + id); + NVF_CHECK( + it_stream_id_logical != tv->getLogicalDomain().end(), + "Expected to find ", + id, + " in ", + tv, + "'s logical domain ", + tv->getLogicalDomain()); + stream_id_logical_index = std::distance( + tv->getLogicalDomain().begin(), it_stream_id_logical); + } + } + return stream_id_logical_index; +} + // Step 1: Group expressions into stream-parallel regions std::vector groupStreamParallelRegions( hir::HostIrContainer* hic, @@ -73,34 +149,11 @@ std::vector groupStreamParallelRegions( "Stream parallel type not supported for expr ", expr); - // Find the stream axis in the logical (and not loop) domain - auto it_logical_stream_axis = std::find( - output->getLogicalDomain().begin(), - output->getLogicalDomain().end(), - stream_axis); - - // Verify stream axis is not split/merged - NVF_ERROR( - it_logical_stream_axis != output->getLogicalDomain().end(), - "Cannot stream parallelize on a split/merge axis ", - stream_axis); - - // Verify stream axis is an iteration axis (not reduction/broadcast) - NVF_CHECK( - stream_axis->getIterType() == IterType::Iteration || - stream_axis->getIterType() == IterType::Broadcast, - "Stream axis ", - stream_axis, - " should be an iteration or broadcast axis."); + // Validate stream axis + validateStreamAxis(stream_axis, output); // Check if expression can be merged with previous stream for-loop - if (!new_top_level_exprs.empty() && - new_top_level_exprs.back()->isA() && - id_model.idGraph(IdMappingMode::BROADCAST) - .disjointValSets() - .strictAreMapped( - stream_axis, - new_top_level_exprs.back()->as()->iterDomain())) { + if (canMergeWithPreviousForLoop(new_top_level_exprs, stream_axis, id_model)) { // Merge with existing for-loop new_top_level_exprs.back()->as()->body().push_back(expr); } else { @@ -117,7 +170,6 @@ std::vector groupStreamParallelRegions( CircularBufferLoopStage::NotApplicable, /*circular_buffer_loop_stage_depth=*/0); for_loop->body().push_back(expr); - // replace the current expr by the for-loop containing it new_top_level_exprs.push_back(for_loop); } } @@ -150,41 +202,9 @@ std::vector processForLoopBodies( // Process input tensors for (auto* input : ir_utils::filterByType(expr->inputs())) { - // Find stream axis index in input tensor - int64_t input_stream_id_logical_index = -1; - for (auto id : input->getLoopDomain()) { - if (id_model.idGraph(IdMappingMode::BROADCAST) - .disjointValSets() - .strictAreMapped(for_loop->iterDomain(), id)) { - // Verify only one stream axis exists - NVF_CHECK( - input_stream_id_logical_index == -1, - "Expected at most one axis mapping to the stream axis ", - for_loop->iterDomain(), - " in the tensor ", - input, - " loop's domain ", - input->getLoopDomain()); - - // Find stream axis in logical domain - auto it_input_stream_id_logical = std::find( - input->getLogicalDomain().begin(), - input->getLogicalDomain().end(), - id); - NVF_CHECK( - it_input_stream_id_logical != input->getLogicalDomain().end(), - "Expected to find ", - id, - " in ", - input, - "'s logical domain ", - input->getLogicalDomain()); - input_stream_id_logical_index = std::distance( - input->getLogicalDomain().begin(), it_input_stream_id_logical); - } - } + int64_t input_stream_id_logical_index = findStreamAxisIndex( + input, for_loop->iterDomain(), id_model); - // Skip if no stream axis found if (input_stream_id_logical_index == -1) { continue; } @@ -214,42 +234,9 @@ std::vector processForLoopBodies( // Process output tensors for (auto* output : ir_utils::filterByType(expr->outputs())) { - // Find stream axis index in output tensor - int64_t output_stream_id_logical_index = -1; - for (auto id : output->getLoopDomain()) { - if (id_model.idGraph(IdMappingMode::BROADCAST) - .disjointValSets() - .strictAreMapped(for_loop->iterDomain(), id)) { - // Verify only one stream axis exists - NVF_CHECK( - output_stream_id_logical_index == -1, - "Expected at most one axis mapping to the stream axis ", - for_loop->iterDomain(), - " in the tensor ", - output, - " loop's domain ", - output->getLoopDomain()); - - // Find stream axis in logical domain - auto it_output_stream_id_logical = std::find( - output->getLogicalDomain().begin(), - output->getLogicalDomain().end(), - id); - NVF_CHECK( - it_output_stream_id_logical != output->getLogicalDomain().end(), - "Expected to find ", - id, - " in ", - output, - "'s logical domain ", - output->getLogicalDomain()); - output_stream_id_logical_index = std::distance( - output->getLogicalDomain().begin(), - it_output_stream_id_logical); - } - } + int64_t output_stream_id_logical_index = findStreamAxisIndex( + output, for_loop->iterDomain(), id_model); - // Skip if no stream axis found if (output_stream_id_logical_index == -1) { continue; } @@ -276,14 +263,14 @@ std::vector processForLoopBodies( if (running_output == output) { // Create alias for the sliced output TensorView* output_j_alias = - ops::newValLike( - output_j, output_j->dtype(), /*keep_reduction_axis=*/true) + ops::newValLike(output_j, output_j->dtype(), true) ->as(); hic->markAlias(output_j, output_j_alias); *it_running_expr = ir_utils::transferDefinitionToNewOutputs( running_expr, {output_j_alias}); } } + } } new_loop_body.push_back(*it_expr); @@ -353,6 +340,8 @@ std::vector addStreamManagement(std::vector top_level_exprs) { return new_top_level_exprs; } +} // anonymous namespace + // StreamParallelType pass implementation. // This pass handles stream parallelization of operations in a fusion. // It works by: @@ -407,3 +396,4 @@ void StreamParallelType::runPass(Fusion* fusion) { } } // namespace nvfuser::preseg_passes + From 7ae7c52f7340c85458b8b48d3d971532e183e36b Mon Sep 17 00:00:00 2001 From: snordmann Date: Fri, 18 Apr 2025 05:38:36 -0700 Subject: [PATCH 21/68] improve comments clarity --- csrc/host_ir/pass/stream_parallel_type.cpp | 65 ++++++++++++++-------- 1 file changed, 41 insertions(+), 24 deletions(-) diff --git a/csrc/host_ir/pass/stream_parallel_type.cpp b/csrc/host_ir/pass/stream_parallel_type.cpp index d40c9cff147..dd19f4f6b5d 100644 --- a/csrc/host_ir/pass/stream_parallel_type.cpp +++ b/csrc/host_ir/pass/stream_parallel_type.cpp @@ -22,6 +22,7 @@ namespace nvfuser::preseg_passes { namespace { +// Finds the stream axis in a tensor's domain. There should be at most one stream axis. IterDomain* getStreamAxis(const std::vector& domain) { IterDomain* ret = nullptr; for (auto id : domain) { @@ -38,6 +39,7 @@ IterDomain* getStreamAxis(const std::vector& domain) { return ret; } +// Validates that a stream axis is valid in a tensor void validateStreamAxis(IterDomain* stream_axis, const TensorView* tv) { // Find the stream axis in the logical domain auto it_logical_stream_axis = std::find( @@ -60,12 +62,14 @@ void validateStreamAxis(IterDomain* stream_axis, const TensorView* tv) { " should be an iteration or broadcast axis."); } +// Checks if two iteration domains are mapped in the ID model bool areIdsMapped(const IdModel& id_model, IterDomain* id1, IterDomain* id2) { return id_model.idGraph(IdMappingMode::BROADCAST) .disjointValSets() .strictAreMapped(id1, id2); } +// Determines if a stream-parallel for-loop can be merged with the previous one bool canMergeWithPreviousForLoop( const std::vector& new_top_level_exprs, IterDomain* stream_axis, @@ -78,6 +82,7 @@ bool canMergeWithPreviousForLoop( new_top_level_exprs.back()->as()->iterDomain()); } +// Finds where a stream axis appears in a tensor's logical domain int64_t findStreamAxisIndex( const TensorView* tv, IterDomain* stream_axis, @@ -121,6 +126,7 @@ std::vector groupStreamParallelRegions( const IdModel& id_model) { std::vector new_top_level_exprs; + // Process each top-level expression for (auto expr : hic->topLevelExprs()) { // Skip expressions with no outputs if (expr->outputs().size() == 0) { @@ -128,7 +134,7 @@ std::vector groupStreamParallelRegions( continue; } - // Verify single output constraint + // Each expression should have exactly one output NVF_CHECK( expr->outputs().size() == 1, "Each expr should have at most one output."); @@ -137,13 +143,13 @@ std::vector groupStreamParallelRegions( TensorView* output = expr->output(0)->as(); IterDomain* stream_axis = getStreamAxis(output->getLoopDomain()); - // If no stream axis, keep expression as is + // If no stream axis found, keep the expression as is if (stream_axis == nullptr) { new_top_level_exprs.push_back(expr); continue; } - // Verify expression can be handled as a standalone host operation + // Verify that the expression can be handled as a standalone host operation NVF_ERROR( HostIrLower::isLowerableAsStandaloneHostOp(expr), "Stream parallel type not supported for expr ", @@ -152,12 +158,12 @@ std::vector groupStreamParallelRegions( // Validate stream axis validateStreamAxis(stream_axis, output); - // Check if expression can be merged with previous stream for-loop + // Check if we can merge this expression with the previous for-loop if (canMergeWithPreviousForLoop(new_top_level_exprs, stream_axis, id_model)) { - // Merge with existing for-loop + // Merge with existing for-loop by adding the expression to its body new_top_level_exprs.back()->as()->body().push_back(expr); } else { - // Create new for-loop for stream parallelization + // Create a new for-loop for stream parallelization auto* for_loop = IrBuilder::create( stream_axis, /*index=*/NamedScalar::getParallelIndex(ParallelType::Stream), @@ -169,6 +175,7 @@ std::vector groupStreamParallelRegions( /*unroll_required=*/false, CircularBufferLoopStage::NotApplicable, /*circular_buffer_loop_stage_depth=*/0); + // Add the expression to the new for-loop's body for_loop->body().push_back(expr); new_top_level_exprs.push_back(for_loop); } @@ -184,7 +191,9 @@ std::vector processForLoopBodies( std::vector top_level_exprs) { std::vector new_top_level_exprs; + // Process each top-level expression for (auto top_level_expr : top_level_exprs) { + // Skip non-for-loop expressions if (!top_level_expr->isA()) { new_top_level_exprs.push_back(top_level_expr); continue; @@ -192,24 +201,26 @@ std::vector processForLoopBodies( auto* for_loop = top_level_expr->as(); std::vector new_loop_body; + std::vector current_loop_body = for_loop->body().exprs(); // Process each expression in the loop body - std::vector current_loop_body = for_loop->body().exprs(); for (auto it_expr = current_loop_body.begin(); it_expr != current_loop_body.end(); ++it_expr) { Expr* expr = *it_expr; - // Process input tensors + // Process input tensors that might have stream axes for (auto* input : ir_utils::filterByType(expr->inputs())) { + // Find if this input has a stream axis int64_t input_stream_id_logical_index = findStreamAxisIndex( input, for_loop->iterDomain(), id_model); + // Skip if no stream axis found if (input_stream_id_logical_index == -1) { continue; } - // Create sliced tensor for current stream iteration + // Create a sliced version of the input tensor for this stream iterdomain TensorView* input_j = select( input, input_stream_id_logical_index, @@ -217,7 +228,7 @@ std::vector processForLoopBodies( /*keep_reduction_axis=*/true); new_loop_body.push_back(input_j->definition()); - // Update all expressions using this input + // Update all expressions that use this input to use the sliced version for (auto it_running_expr = current_loop_body.begin(); it_running_expr != current_loop_body.end(); ++it_running_expr) { @@ -232,28 +243,30 @@ std::vector processForLoopBodies( } } - // Process output tensors + // Process output tensors that might have stream axes for (auto* output : ir_utils::filterByType(expr->outputs())) { + // Find if this output has a stream axis int64_t output_stream_id_logical_index = findStreamAxisIndex( output, for_loop->iterDomain(), id_model); + // Skip if no stream axis found if (output_stream_id_logical_index == -1) { continue; } - // Create sliced tensor for current stream iteration + // Create a sliced version of the output tensor for this stream axis TensorView* output_j = select( output, output_stream_id_logical_index, for_loop->index(), /*keep_reduction_axis=*/true); - // Allocate memory for the output tensor + // Allocate memory for the output tensor, and place the allocation IR before the for-loop, at the top level new_top_level_exprs.push_back( IrBuilder::create(output, MemoryType::Global)); new_loop_body.push_back(output_j->definition()); - // Update all expressions using this output + // Update all expressions that use this output to use the sliced version for (auto it_running_expr = current_loop_body.begin(); it_running_expr != current_loop_body.end(); ++it_running_expr) { @@ -261,7 +274,8 @@ std::vector processForLoopBodies( for (auto* running_output : ir_utils::filterByType(running_expr->outputs())) { if (running_output == output) { - // Create alias for the sliced output + // Create an alias for the sliced output to maintain the original tensor's properties + // Alias is needed here to avoid that transferDefinitionToNewOutputs throws. Indeed, HIC does not make the SSA assumption, but the util functions we use (such as transferDefinitionToNewOutputs) do, therefore we need to create an alias for the sliced output to not create loops in the dag. TensorView* output_j_alias = ops::newValLike(output_j, output_j->dtype(), true) ->as(); @@ -270,13 +284,14 @@ std::vector processForLoopBodies( running_expr, {output_j_alias}); } } - } } + + // Add the original expression to the new loop body new_loop_body.push_back(*it_expr); } - // Update for-loop body with processed expressions + // Update the for-loop body with all the processed expressions for_loop->body().clear(); for (auto* expr : new_loop_body) { for_loop->body().push_back(expr); @@ -291,20 +306,23 @@ std::vector processForLoopBodies( std::vector addStreamManagement(std::vector top_level_exprs) { std::vector new_top_level_exprs; + // Process each top-level expression for (auto* top_level_expr : top_level_exprs) { + // Skip non-for-loop expressions if (!top_level_expr->isA()) { new_top_level_exprs.push_back(top_level_expr); continue; } + auto* for_loop = top_level_expr->as(); std::vector new_loop_body; - // Get current stream for later synchronization + // Get the current stream before entering the loop auto* get_current_stream = IrBuilder::create(); hir::Stream* original_stream = get_current_stream->stream(); new_loop_body.push_back(get_current_stream); - // Set up stream for current iteration + // Set up a new stream for this iteration based on the loop index auto* number_of_streams = IrBuilder::create("numberOfStreams", DataType::Int); auto* stream_index = mod(for_loop->index(), number_of_streams); @@ -312,24 +330,24 @@ std::vector addStreamManagement(std::vector top_level_exprs) { auto* set_stream = IrBuilder::create(stream); new_loop_body.push_back(set_stream); - // Synchronize with original stream + // Synchronize with the original stream before starting computation auto* initial_sync_stream = IrBuilder::create(original_stream); new_loop_body.push_back(initial_sync_stream); - // Add the actual computation expressions + // Add all the expressions to the loop body for (auto* expr : for_loop->body().exprs()) { new_loop_body.push_back(expr); } - // Restore original stream and synchronize + // Restore the original stream and synchronize with the iteration's stream auto* set_back_original_stream = IrBuilder::create(original_stream); new_loop_body.push_back(set_back_original_stream); auto* sync_stream = IrBuilder::create(stream); new_loop_body.push_back(sync_stream); - // Update for-loop body with stream management + // Update the for-loop body with the new expressions for_loop->body().clear(); for (auto* expr : new_loop_body) { for_loop->body().push_back(expr); @@ -396,4 +414,3 @@ void StreamParallelType::runPass(Fusion* fusion) { } } // namespace nvfuser::preseg_passes - From 6dd673f4b39f0b9db7809be9110d6f998da1bd8e Mon Sep 17 00:00:00 2001 From: snordmann Date: Fri, 18 Apr 2025 05:49:22 -0700 Subject: [PATCH 22/68] more comments --- csrc/host_ir/pass/stream_parallel_type.cpp | 42 +++++++++++++--------- csrc/ops/indexing.h | 3 ++ 2 files changed, 28 insertions(+), 17 deletions(-) diff --git a/csrc/host_ir/pass/stream_parallel_type.cpp b/csrc/host_ir/pass/stream_parallel_type.cpp index dd19f4f6b5d..63ebc9fc42c 100644 --- a/csrc/host_ir/pass/stream_parallel_type.cpp +++ b/csrc/host_ir/pass/stream_parallel_type.cpp @@ -22,7 +22,8 @@ namespace nvfuser::preseg_passes { namespace { -// Finds the stream axis in a tensor's domain. There should be at most one stream axis. +// Finds the stream axis in a tensor's domain. There should be at most one +// stream axis. IterDomain* getStreamAxis(const std::vector& domain) { IterDomain* ret = nullptr; for (auto id : domain) { @@ -102,9 +103,7 @@ int64_t findStreamAxisIndex( // Find stream axis in logical domain auto it_stream_id_logical = std::find( - tv->getLogicalDomain().begin(), - tv->getLogicalDomain().end(), - id); + tv->getLogicalDomain().begin(), tv->getLogicalDomain().end(), id); NVF_CHECK( it_stream_id_logical != tv->getLogicalDomain().end(), "Expected to find ", @@ -113,8 +112,8 @@ int64_t findStreamAxisIndex( tv, "'s logical domain ", tv->getLogicalDomain()); - stream_id_logical_index = std::distance( - tv->getLogicalDomain().begin(), it_stream_id_logical); + stream_id_logical_index = + std::distance(tv->getLogicalDomain().begin(), it_stream_id_logical); } } return stream_id_logical_index; @@ -159,7 +158,8 @@ std::vector groupStreamParallelRegions( validateStreamAxis(stream_axis, output); // Check if we can merge this expression with the previous for-loop - if (canMergeWithPreviousForLoop(new_top_level_exprs, stream_axis, id_model)) { + if (canMergeWithPreviousForLoop( + new_top_level_exprs, stream_axis, id_model)) { // Merge with existing for-loop by adding the expression to its body new_top_level_exprs.back()->as()->body().push_back(expr); } else { @@ -212,15 +212,16 @@ std::vector processForLoopBodies( // Process input tensors that might have stream axes for (auto* input : ir_utils::filterByType(expr->inputs())) { // Find if this input has a stream axis - int64_t input_stream_id_logical_index = findStreamAxisIndex( - input, for_loop->iterDomain(), id_model); + int64_t input_stream_id_logical_index = + findStreamAxisIndex(input, for_loop->iterDomain(), id_model); // Skip if no stream axis found if (input_stream_id_logical_index == -1) { continue; } - // Create a sliced version of the input tensor for this stream iterdomain + // Create a sliced version of the input tensor for this stream + // iterdomain TensorView* input_j = select( input, input_stream_id_logical_index, @@ -246,8 +247,8 @@ std::vector processForLoopBodies( // Process output tensors that might have stream axes for (auto* output : ir_utils::filterByType(expr->outputs())) { // Find if this output has a stream axis - int64_t output_stream_id_logical_index = findStreamAxisIndex( - output, for_loop->iterDomain(), id_model); + int64_t output_stream_id_logical_index = + findStreamAxisIndex(output, for_loop->iterDomain(), id_model); // Skip if no stream axis found if (output_stream_id_logical_index == -1) { @@ -261,7 +262,8 @@ std::vector processForLoopBodies( for_loop->index(), /*keep_reduction_axis=*/true); - // Allocate memory for the output tensor, and place the allocation IR before the for-loop, at the top level + // Allocate memory for the output tensor, and place the allocation IR + // before the for-loop, at the top level new_top_level_exprs.push_back( IrBuilder::create(output, MemoryType::Global)); new_loop_body.push_back(output_j->definition()); @@ -274,8 +276,12 @@ std::vector processForLoopBodies( for (auto* running_output : ir_utils::filterByType(running_expr->outputs())) { if (running_output == output) { - // Create an alias for the sliced output to maintain the original tensor's properties - // Alias is needed here to avoid that transferDefinitionToNewOutputs throws. Indeed, HIC does not make the SSA assumption, but the util functions we use (such as transferDefinitionToNewOutputs) do, therefore we need to create an alias for the sliced output to not create loops in the dag. + // Create an alias for the sliced output to maintain the original + // tensor's properties Alias is needed here to avoid that + // transferDefinitionToNewOutputs throws. Indeed, HIC does not + // make the SSA assumption, but the util functions we use (such as + // transferDefinitionToNewOutputs) do, therefore we need to create + // an alias for the sliced output to not create loops in the dag. TensorView* output_j_alias = ops::newValLike(output_j, output_j->dtype(), true) ->as(); @@ -401,10 +407,12 @@ void StreamParallelType::runPass(Fusion* fusion) { id_model.buildBroadcastGraph(); // Step 1: Group expressions into stream-parallel regions - std::vector top_level_exprs = groupStreamParallelRegions(hic, id_model); + std::vector top_level_exprs = + groupStreamParallelRegions(hic, id_model); // Step 2: Process for-loop bodies by slicing tensors - top_level_exprs = processForLoopBodies(hic, id_model, std::move(top_level_exprs)); + top_level_exprs = + processForLoopBodies(hic, id_model, std::move(top_level_exprs)); // Step 3: Add stream management and synchronization top_level_exprs = addStreamManagement(std::move(top_level_exprs)); diff --git a/csrc/ops/indexing.h b/csrc/ops/indexing.h index 7a219c534a3..5e0410d95d5 100644 --- a/csrc/ops/indexing.h +++ b/csrc/ops/indexing.h @@ -15,6 +15,9 @@ namespace nvfuser { +// When keep_reduction_axis is true, all reduction axis are kept in the +// SelectOp's consumer. This is used in the context of HostIr where SelectOp is +// used to index into Stream-parallelized axes. NVF_API TensorView* select( TensorView* tv, int64_t dim, From ac7e09a7ceca9ec697d744187043955bcaf9c89c Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Fri, 18 Apr 2025 09:53:38 -0700 Subject: [PATCH 23/68] Adding IndexPutAccumulateOp (#4063) Things done in this PR is to support embedding backward, which requires `torch.index_put_(..., accumulate=True)`. Stacked PRs: - [x] #4063 <-- This PR - [ ] #4066 What this PR does: * Added fusion IR node IndexPutAccumulateOp. 1. IR signature ```TensorView* IndexPutAccumulateOp( TensorView* out, TensorView* acc, TensorView* index, TensorView* value)``` 2. Allow expression evaluation execution via `at::index_put` * Added c++ API ```TensorView* indexPutAccumulate(TensorView* acc_tv, TensorView* index_tv, TensorView* value_tv``` There are two things worth noting: 1. The signature of the op requires a `acc` input, which is supposed to be used by codegen later as an IO buffer. (We'll be atomic add into the buffer and output it directly as the same tensor). Currently the reference implementation doesn't do that yet and uses an out-of-place version instead. We'll modify that once we have proper support on outputs aliasing outputs. 2. We currently reject `IndexPutAccumulateOp` from schedulers, we'll remove it when we have proper codegen support. --- CMakeLists.txt | 1 + csrc/device_lower/utils.cpp | 1 + csrc/dispatch.h | 1 + csrc/ir/internal_nodes.h | 56 ++++++++++++++ csrc/ir/nodes.cpp | 47 +++++++++++ csrc/logical_domain_map.cpp | 49 ++++++++---- csrc/ops/indexing.cpp | 38 +++++++++ csrc/ops/indexing.h | 6 ++ csrc/scheduler/expr_eval_sched.cpp | 8 +- csrc/scheduler/registry.cpp | 7 +- tests/cpp/test_index_put.cpp | 120 +++++++++++++++++++++++++++++ 11 files changed, 316 insertions(+), 18 deletions(-) create mode 100644 tests/cpp/test_index_put.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 226e1acc396..8ea2e6a31db 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -578,6 +578,7 @@ list(APPEND JIT_TEST_SRCS ${NVFUSER_ROOT}/tests/cpp/test_indexing.cpp ${NVFUSER_ROOT}/tests/cpp/test_indexing_advanced.cpp ${NVFUSER_ROOT}/tests/cpp/test_index_select.cpp + ${NVFUSER_ROOT}/tests/cpp/test_index_put.cpp ${NVFUSER_ROOT}/tests/cpp/test_inlining.cpp ${NVFUSER_ROOT}/tests/cpp/test_interval_analysis.cpp ${NVFUSER_ROOT}/tests/cpp/test_iter_visitor.cpp diff --git a/csrc/device_lower/utils.cpp b/csrc/device_lower/utils.cpp index 4bd53083c6f..ac4981ccafa 100644 --- a/csrc/device_lower/utils.cpp +++ b/csrc/device_lower/utils.cpp @@ -139,6 +139,7 @@ bool isTvOp(const Expr* expr) { TensorConstruct, SelectOp, IndexSelectOp, + IndexPutAccumulateOp, GatherOp, ScatterOp, RNGOp, diff --git a/csrc/dispatch.h b/csrc/dispatch.h index 1ca0a2c460d..f1f4153d1d2 100644 --- a/csrc/dispatch.h +++ b/csrc/dispatch.h @@ -84,6 +84,7 @@ class Val; f(TensorConstruct); \ f(SelectOp); \ f(IndexSelectOp); \ + f(IndexPutAccumulateOp); \ f(GatherOp); \ f(ScatterOp); \ f(RNGOp); \ diff --git a/csrc/ir/internal_nodes.h b/csrc/ir/internal_nodes.h index 91d3ca4ec39..a77545d63cd 100644 --- a/csrc/ir/internal_nodes.h +++ b/csrc/ir/internal_nodes.h @@ -129,6 +129,62 @@ class IndexSelectOp : public Expr { } }; +class IndexPutAccumulateOp : public Expr { + public: + using Expr::Expr; + + // [ Note -- IndexPutAccumulateOp semantics ] + // + // logical ID groups of IndexPutAccumulateOp + // args: + // acc [ ID_indexed_g0, ID_g0 ] + // index [ ID_indexing_g1, ID_broadcast ] + // value [ ID_indexing_g1, ID_g0 ] + // output: + // out [ ID_indexed_g0, ID_g0 ] + // + // Note that: + // 1. indexed ID for `out` and `acc` share the same extent. + // 2. indexed ID for `index` and `value` share the same extent. + IndexPutAccumulateOp( + IrBuilderPasskey, + Val* out, + Val* acc, + Val* index, + Val* value); + + NVFUSER_DECLARE_CLONE_AND_CREATE + + const char* getOpString() const override { + return "IndexPutAccumulateOp"; + } + + std::string toString(int indent_size = 0) const override; + std::string toInlineString(int indent_size = 0) const override; + std::vector evaluate( + const ExpressionEvaluator& ee, + const std::vector& inputs) const override; + + TensorView* accumulateTv() const { + return input(0)->as(); + } + + TensorView* indexTv() const { + return input(1)->as(); + } + + TensorView* valueTv() const { + return input(2)->as(); + } + + // return ID_indexing_g1 from value + IterDomain* getIndexingIDOfValue() const; + + // return ID_indexing_g1 from index, for IndexPutAccumulate, there's only one + // indexing ID, while the remaining ID is broadcast + IterDomain* getIndexingID() const; +}; + class NVF_API GatherOp : public Expr { public: using Expr::Expr; diff --git a/csrc/ir/nodes.cpp b/csrc/ir/nodes.cpp index de40861cb6c..577168d14cf 100644 --- a/csrc/ir/nodes.cpp +++ b/csrc/ir/nodes.cpp @@ -177,6 +177,53 @@ std::vector IndexSelectOp::evaluate( NVFUSER_DEFINE_CLONE_AND_CREATE(IndexSelectOp) +IndexPutAccumulateOp::IndexPutAccumulateOp( + IrBuilderPasskey passkey, + Val* out, + Val* acc, + Val* index, + Val* value) + : Expr(passkey) { + addInput(acc); + addInput(index); + addInput(value); + addOutput(out); +} + +std::string IndexPutAccumulateOp::toString(int indent_size) const { + std::stringstream ss; + indent(ss, indent_size) << output(0)->toString() << "\n"; + indent_size++; + indent(ss, indent_size) << " = indexPutAccumulate( "; + ss << input(0)->toString() << ", " << input(1)->toString() << ", " + << input(2)->toString() << " )\n"; + return ss.str(); +} + +std::string IndexPutAccumulateOp::toInlineString(int indent_size) const { + NVF_CHECK(false, "Tensor op can not be printed inline"); +} + +IterDomain* IndexPutAccumulateOp::getIndexingIDOfValue() const { + return TensorDomain::noReductions(valueTv()->getLogicalDomain()).front(); +} + +IterDomain* IndexPutAccumulateOp::getIndexingID() const { + return TensorDomain::noReductions(indexTv()->getLogicalDomain()).front(); +} + +std::vector IndexPutAccumulateOp::evaluate( + const ExpressionEvaluator& ee, + const std::vector& inputs) const { + return {at::index_put( + /*self=*/inputs.at(0).as(), + /*indices=*/{inputs.at(1).as()}, + /*values=*/inputs.at(2).as(), + /*accumulate=*/true)}; +} + +NVFUSER_DEFINE_CLONE_AND_CREATE(IndexPutAccumulateOp) + GatherOp::GatherOp( IrBuilderPasskey passkey, Val* out, diff --git a/csrc/logical_domain_map.cpp b/csrc/logical_domain_map.cpp index 9092fe44ffa..8e1cc09b527 100644 --- a/csrc/logical_domain_map.cpp +++ b/csrc/logical_domain_map.cpp @@ -74,32 +74,47 @@ PairwiseLogicalDomainMap::PairwiseLogicalDomainMap( namespace { -// Returns a producer ID that is indirectly accessed. A bool is also -// returned indicating there's a corresponding consumer ID. For -// example, select doesn't have a consumer ID, whereas index_select -// does. -std::pair getIndexedDomainInfo( +// Returns producer IDs that don't map identically to consumer. A bool is +// returned indicating whether corresponding consumer IDs exists. For example, +// select doesn't have a consumer ID, whereas index_select does. +std::pair, bool> getNonMappingDomainInfo( const TensorView* producer_tv, const TensorView* consumer_tv) { - IterDomain* indexed_id = nullptr; + std::unordered_set non_mapping_ids; bool has_consumer_id = false; if (auto sop = dynamic_cast(consumer_tv->definition())) { - indexed_id = sop->getIndexedID(); + // indexed ID is indirectly accessed + non_mapping_ids.insert(sop->getIndexedID()); has_consumer_id = false; } else if ( auto sop = dynamic_cast(consumer_tv->definition())) { + // indexed ID is indirectly accessed if (producer_tv == sop->lookupTv()) { - indexed_id = sop->getIndexedID(); + non_mapping_ids.insert(sop->getIndexedID()); has_consumer_id = true; } } else if (auto gop = dynamic_cast(consumer_tv->definition())) { + // indexed ID is indirectly accessed if (producer_tv == gop->lookupTv()) { - indexed_id = gop->getIndexedID(); + non_mapping_ids.insert(gop->getIndexedID()); + has_consumer_id = true; + } + } else if ( + auto iaop = + dynamic_cast(consumer_tv->definition())) { + // see [ Note -- IndexPutAccumulateOp semantics ] + if (producer_tv == iaop->indexTv()) { + // Indexing ID of index tv do not map to output. + non_mapping_ids.insert(iaop->getIndexingID()); + has_consumer_id = true; + } else if (producer_tv == iaop->valueTv()) { + // indexing ID of value tv do not map to output. + non_mapping_ids.insert(iaop->getIndexingIDOfValue()); has_consumer_id = true; } } - return std::make_pair(indexed_id, has_consumer_id); + return std::make_pair(non_mapping_ids, has_consumer_id); } } // namespace @@ -120,8 +135,8 @@ std::unordered_map PairwiseLogicalDomainMap::map( squeeze_flags = sop->getSqueezeDimFlags(); } - auto [indexed_producer_id, has_consumer_of_indexed_id] = - getIndexedDomainInfo(producer_tv_, consumer_tv_); + auto [non_mapping_producer_id, has_consumer_of_indexed_id] = + getNonMappingDomainInfo(producer_tv_, consumer_tv_); std::unordered_map dom_map; const auto producer_logical = TensorDomain::noReductions(producer->logical()); @@ -339,13 +354,14 @@ std::unordered_map PairwiseLogicalDomainMap::map( IterDomain* consumer_id = consumer_root.at(itc); // Conditions to check: - // 1. Indirectly accessed IDs (e.g., select) + // 1. Non mapping IDs (e.g., select) // 2. IDs that may have different extents (e.g., non indexed // domains of torchGather) // 3. Squeeze and unsqueeze - // Condition 1: when the producer ID is the dim of a select-like op - if (producer_id == indexed_producer_id) { + // Condition 1: when the producer ID is the dim of a select-like op, or when + // it doesn't map to the output IDs, like indexing IDs of indexPutAccumulate + if (non_mapping_producer_id.count(producer_id) != 0) { // If there's no corresponding consumer, skip the indexed producer if (!has_consumer_of_indexed_id) { itp++; @@ -362,7 +378,8 @@ std::unordered_map PairwiseLogicalDomainMap::map( // Condition 2: Different extents if (auto gop = dynamic_cast(consumer_tv_->definition()); gop != nullptr && !gop->exactSizes() && - producer_tv_ == gop->lookupTv() && producer_id != indexed_producer_id && + producer_tv_ == gop->lookupTv() && + non_mapping_producer_id.count(producer_id) == 0 && !map_different_extents_) { itp++; itc++; diff --git a/csrc/ops/indexing.cpp b/csrc/ops/indexing.cpp index 5ff75065ff2..f05601fd6a6 100644 --- a/csrc/ops/indexing.cpp +++ b/csrc/ops/indexing.cpp @@ -99,6 +99,44 @@ TensorView* indexSelect( return out; } +// This is a restricted version of torch.index_put(..., accumulate=true) +TensorView* indexPutAccumulate( + TensorView* acc_tv, + TensorView* index_tv, + TensorView* value_tv) { + DataType dtype = acc_tv->getDataType().value(); + NVF_CHECK( + dtype != DataType::Null, "Invalid datatype provided for new value."); + + // broadcast index_tv if applicable + if (index_tv->nDims() == 1) { + index_tv = unsqueeze(index_tv, -1); + } + + std::vector acc_domain = + TensorDomain::noReductions(acc_tv->getLogicalDomain()); + std::vector index_domain = + TensorDomain::noReductions(index_tv->getLogicalDomain()); + std::vector value_domain = + TensorDomain::noReductions(value_tv->getLogicalDomain()); + + NVF_CHECK(acc_domain.size() == 2); + NVF_CHECK(index_domain.size() == 2); + NVF_CHECK(index_domain.at(1)->isBroadcast()); + NVF_CHECK(value_domain.size() == 2); + // IndexPutAccumulateOp semantics + // + // Producers: + // accumulate [ vocab, hidden ] + // broadcast_index [ seq, broadcast ] + // value [ seq, hidden ] + // Consumers: + // output [ vocab, hidden ] + TensorView* out = ops::newValLike(acc_tv, dtype)->as(); + IrBuilder::create(out, acc_tv, index_tv, value_tv); + return out; +} + // torch.gather TensorView* gather(TensorView* inp, int64_t dim, TensorView* index) { auto inp_domain = TensorDomain::noReductions(inp->getLogicalDomain()); diff --git a/csrc/ops/indexing.h b/csrc/ops/indexing.h index c8152c33f82..96eceb515b5 100644 --- a/csrc/ops/indexing.h +++ b/csrc/ops/indexing.h @@ -23,6 +23,12 @@ NVF_API TensorView* indexSelect( int64_t dim, TensorView* index); +// This is a restricted version of torch.index_put(..., accumulate=true) +TensorView* indexPutAccumulate( + TensorView* acc_tv, + TensorView* index_tv, + TensorView* value_tv); + // torch.gather NVF_API TensorView* gather(TensorView* input, int64_t dim, TensorView* index); diff --git a/csrc/scheduler/expr_eval_sched.cpp b/csrc/scheduler/expr_eval_sched.cpp index 684beae07fa..5a35c7884a7 100644 --- a/csrc/scheduler/expr_eval_sched.cpp +++ b/csrc/scheduler/expr_eval_sched.cpp @@ -51,7 +51,13 @@ bool ExprEvalScheduler::canScheduleCompileTime(Fusion* fusion) { return false; } - if (exprs.front()->isOneOf()) { + // TODO: remove IndexPutAccumulateOp + if (exprs.front() + ->isOneOf< + SdpaFwdOp, + SdpaBwdOp, + EmbeddingFwdOp, + IndexPutAccumulateOp>()) { return true; } diff --git a/csrc/scheduler/registry.cpp b/csrc/scheduler/registry.cpp index 33316c9480e..6e3261008e6 100644 --- a/csrc/scheduler/registry.cpp +++ b/csrc/scheduler/registry.cpp @@ -35,7 +35,12 @@ bool checkCanSchedule(Fusion* fusion, SchedulerType scheduler_type) { // These ops are are only accepted in `ExprEval` // scheduler, all other schedulers should reject them. - if (ir_utils::hasOpsOfType(fusion)) { + // TODO: remove IndexPutAccumulateOp + if (ir_utils::hasOpsOfType< + SdpaFwdOp, + SdpaBwdOp, + EmbeddingFwdOp, + IndexPutAccumulateOp>(fusion)) { scheduler_debug_utils::canScheduleRejectReason( scheduler_type, "Has unsupported ops"); return false; diff --git a/tests/cpp/test_index_put.cpp b/tests/cpp/test_index_put.cpp new file mode 100644 index 00000000000..024ef910398 --- /dev/null +++ b/tests/cpp/test_index_put.cpp @@ -0,0 +1,120 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#include +#include + +#include +#include +#include +#include + +namespace nvfuser { + +struct SizeParams { + int64_t vocab_size; + int64_t hidden_size; + int64_t seq_size; +}; + +std::vector generateSizeOneParams() { + int64_t vocab_size = 1024; + int64_t hidden_size = 3584; + int64_t seq_size = 3000; + std::vector params; + for (bool size_one_vocab : {true, false}) { + for (bool size_one_hidden : {true, false}) { + for (bool size_one_seq : {true, false}) { + int64_t vocab = size_one_vocab ? 1 : vocab_size; + int64_t hidden = size_one_hidden ? 1 : hidden_size; + int64_t seq = size_one_seq ? 1 : seq_size; + params.push_back({vocab, hidden, seq}); + } + } + } + return params; +} + +class IndexPut : public NVFuserFixtureParamTest { + protected: + void SetUp() override { + EnableOptionsGuard::getCurOptions().set(EnableOption::IdModel, {"all"}); + NVFuserTest::SetUp(); + } +}; + +INSTANTIATE_TEST_SUITE_P( + , + IndexPut, + ::testing::ValuesIn(generateSizeOneParams())); + +// Note: The semantics doesn't support broadcast on operands, adding `size 1` +// check just to ensure the ID mapping is done correctly. +TEST_P(IndexPut, AccumulateOpWithBroadcastIDs) { + auto fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr.get(); + FusionGuard fg(&fusion); + + auto [vocab, hidden, seq] = GetParam(); + + std::vector shape1({seq, hidden}); + std::vector shape2({seq, 1}); + + auto tv_value = makeSymbolicTensor(shape1); + fusion.addInput(tv_value); + auto tv_index = makeSymbolicTensor(shape2, DataType::Int); + fusion.addInput(tv_index); + auto s_vocab = IrBuilder::create(vocab, DataType::Index); + std::vector buffer_size = { + s_vocab, tv_value->axis(-1)->extent()}; + auto buf = zeros(buffer_size, DataType::Float, true); + // TODO: this should be an inplace. handle it when we have codegen support + auto out = indexPutAccumulate(buf, tv_index, tv_value); + fusion.addOutput(out); + + // check PairwiseLogicalDomainMap check if tv0 and tv1 map pairwise on + // position according to `expect_to_map` + auto map_logical = [](const std::vector& expect_to_map, + TensorView* tv0, + TensorView* tv1) { + std::unordered_map pairwise_map = + PairwiseLogicalDomainMap(tv0, tv1).mapProducerToConsumer(); + for (auto index : arange(expect_to_map.size())) { + IterDomain* id0 = tv0->getLogicalDomain().at(index); + IterDomain* id1 = tv1->getLogicalDomain().at(index); + EXPECT_EQ( + pairwise_map.find(id0) != pairwise_map.end() && + pairwise_map[id0] == id1, + expect_to_map[index]); + } + }; + + // see [ Note -- IndexPutAccumulateOp semantics ] + // args: + // buf [ ID_indexed_g0, ID_g0 ] + // tv_index [ ID_indexing_g1, ID_broadcast ] + // tv_value [ ID_indexing_g1, ID_g0 ] + // output: + // out [ ID_indexed_g0, ID_g0 ] + map_logical({true, true}, buf, out); + // depends on the size of ID_g0, it would map to ID_broadcast when hidden is + // size-1 dimension + map_logical({false, hidden == 1}, tv_index, out); + map_logical({false, true}, tv_value, out); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto options_i = at::TensorOptions().dtype(at::kLong).device(at::kCUDA, 0); + auto t_value = at::randn(shape1, options); + auto t_index = at::randint(0, vocab, shape2, options_i); + + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + auto outputs = executor_cache.runFusionWithInputs({t_value, t_index}); + + testValidate(&fusion, outputs, {t_value, t_index}, __LINE__, __FILE__); +} + +} // namespace nvfuser From f24dc1331ce4e7fd4a46cf410f2736075c0b5d4f Mon Sep 17 00:00:00 2001 From: "Gao, Xiang" Date: Fri, 18 Apr 2025 10:52:56 -0700 Subject: [PATCH 24/68] Minor fix on inline_ptx.cpp (#4278) --- csrc/device_lower/pass/inline_ptx.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/device_lower/pass/inline_ptx.cpp b/csrc/device_lower/pass/inline_ptx.cpp index 704475a0152..7ac4aebc577 100644 --- a/csrc/device_lower/pass/inline_ptx.cpp +++ b/csrc/device_lower/pass/inline_ptx.cpp @@ -378,7 +378,7 @@ class LowerToInlinePtx : public kir::ExprMutator { Val* enable_input_d = getUseInputAcc(mma); // Do MMA - registerInsertBefore( + registerReplace( mma, IrBuilder::create( "tcgen05.mma.cta_group::1.kind::f16", From 7bc8c17da13e1dd7c4e23b71ff3a485db441bcce Mon Sep 17 00:00:00 2001 From: "Gao, Xiang" Date: Fri, 18 Apr 2025 11:05:57 -0700 Subject: [PATCH 25/68] Rename `ldstMBarrierMap` -> `mbarrierMap` (#4277) We will use this to track the mbarrier for Blackwell MmaOp --------- Co-authored-by: Ryan Spring --- csrc/device_lower/lower2device.h | 12 ++++++------ csrc/device_lower/pass/allocation.cpp | 4 ++-- csrc/device_lower/pass/circular_buffer.cpp | 20 ++++++++++---------- csrc/device_lower/pass/index.cpp | 2 +- 4 files changed, 19 insertions(+), 19 deletions(-) diff --git a/csrc/device_lower/lower2device.h b/csrc/device_lower/lower2device.h index 3e60d13621c..883c53f1ba9 100644 --- a/csrc/device_lower/lower2device.h +++ b/csrc/device_lower/lower2device.h @@ -227,12 +227,12 @@ class GpuLower : public NonCopyable { return profile_; } - std::unordered_map& ldstMBarrierMap() { - return ldst_mbarrier_map_; + std::unordered_map& mbarrierMap() { + return mbarrier_map_; } - const std::unordered_map& ldstMBarrierMap() const { - return ldst_mbarrier_map_; + const std::unordered_map& mbarrierMap() const { + return mbarrier_map_; } bool isNvFuserZeroEnabled() { @@ -432,8 +432,8 @@ class GpuLower : public NonCopyable { // precomputed values std::vector all_known_vals_; - // Keep track of the mbarrier used for each load/store operation - std::unordered_map ldst_mbarrier_map_; + // Keep track of the mbarrier used for each load/store and blackwell utcmma + std::unordered_map mbarrier_map_; // Information about tensor memory usage TensorMemoryInfo tmem_info_; diff --git a/csrc/device_lower/pass/allocation.cpp b/csrc/device_lower/pass/allocation.cpp index 47e841b92ef..7f92d174cd4 100644 --- a/csrc/device_lower/pass/allocation.cpp +++ b/csrc/device_lower/pass/allocation.cpp @@ -1376,7 +1376,7 @@ class AllocationInserter : public kir::ExprMutator { registerInsertBefore(expr, sync_init, expr_scope); registerInsertAfter(expr, mbarrier_inval, expr_scope); registerInsertAfter(expr, sync_inval, expr_scope); - GpuLower::current()->ldstMBarrierMap()[expr] = mbarrier; + GpuLower::current()->mbarrierMap()[expr] = mbarrier; } } @@ -1484,7 +1484,7 @@ class AllocationInserter : public kir::ExprMutator { continue; } // Map LoadStoreOp expression to ir nodes created in this pass - GpuLower::current()->ldstMBarrierMap()[tv->definition()] = mbarrier; + GpuLower::current()->mbarrierMap()[tv->definition()] = mbarrier; } } } diff --git a/csrc/device_lower/pass/circular_buffer.cpp b/csrc/device_lower/pass/circular_buffer.cpp index cee3b5cbae5..c970283daa2 100644 --- a/csrc/device_lower/pass/circular_buffer.cpp +++ b/csrc/device_lower/pass/circular_buffer.cpp @@ -618,7 +618,7 @@ class CloneTmaCircularBufferLoopAndInsertSync return; } - const auto& ldst_mbarrier_map = GpuLower::current()->ldstMBarrierMap(); + const auto& ldst_mbarrier_map = GpuLower::current()->mbarrierMap(); for (auto tv : ir_utils::filterByType(expr->inputs())) { // short-circuit: The TensorView input for current expression is not @@ -657,7 +657,7 @@ class CloneTmaCircularBufferLoopAndInsertSync !hasCircularBufferLoad()) { return; } - const auto& ldst_mbarrier_map = GpuLower::current()->ldstMBarrierMap(); + const auto& ldst_mbarrier_map = GpuLower::current()->mbarrierMap(); for (auto tv : ir_utils::filterByType(expr->outputs())) { // short-circuit: The current expression is not a circular buffer load, so @@ -694,7 +694,7 @@ class CloneTmaCircularBufferLoopAndInsertSync return; } - const auto& ldst_mbarrier_map = GpuLower::current()->ldstMBarrierMap(); + const auto& ldst_mbarrier_map = GpuLower::current()->mbarrierMap(); // remove expr from war_mbarriers_to_uses_ auto input_tvs = ir_utils::filterByType(expr->inputs()); for (auto tv : input_tvs) { @@ -799,7 +799,7 @@ class CloneTmaCircularBufferLoopAndInsertSync // expressions std::unordered_map getAllMbarriersToWait() { - const auto& ldst_mbarrier_map = GpuLower::current()->ldstMBarrierMap(); + const auto& ldst_mbarrier_map = GpuLower::current()->mbarrierMap(); std::unordered_map wait_exprs; for (auto tv : circular_buffer_load_tvs_) { LoadStoreOp* ldst = dynamic_cast(tv->definition()); @@ -820,7 +820,7 @@ class CloneTmaCircularBufferLoopAndInsertSync // buffer tensor tracked by this mbarrier. std::unordered_map> getAllWarMbarriersToUses() { - const auto& ldst_mbarrier_map = GpuLower::current()->ldstMBarrierMap(); + const auto& ldst_mbarrier_map = GpuLower::current()->mbarrierMap(); std::unordered_map> mbarrier_to_uses; auto exprs = ir_utils::flattenScopedExprs(circular_buffer_loop_->body().exprs()); @@ -949,7 +949,7 @@ class CloneTmaCircularBufferLoopAndInsertSync NVF_ERROR(ldst != nullptr); // Get mbarrier for this circular buffer stage. - TensorView* all_mbarriers = GpuLower::current()->ldstMBarrierMap().at(ldst); + TensorView* all_mbarriers = GpuLower::current()->mbarrierMap().at(ldst); kir::TensorIndex* stage_mbarrier = IrBuilder::create(all_mbarriers, currentLoadStage()); @@ -985,7 +985,7 @@ class CloneTmaCircularBufferLoopAndInsertSync NVF_ERROR(ldst != nullptr); // Get mbarrier for this circular buffer stage. - TensorView* all_mbarriers = GpuLower::current()->ldstMBarrierMap().at(ldst); + TensorView* all_mbarriers = GpuLower::current()->mbarrierMap().at(ldst); kir::TensorIndex* stage_mbarrier = IrBuilder::create( all_mbarriers, currentComputeStage()); @@ -1007,7 +1007,7 @@ class CloneTmaCircularBufferLoopAndInsertSync .stage; // Get mbarrier for this circular buffer stage. - TensorView* all_mbarriers = GpuLower::current()->ldstMBarrierMap().at(ldst); + TensorView* all_mbarriers = GpuLower::current()->mbarrierMap().at(ldst); kir::TensorIndex* stage_mbarrier = IrBuilder::create( all_mbarriers, SimplifyingIrBuilder::addExpr(currentLoadStage(), stage_depth)); @@ -1333,8 +1333,8 @@ class CircularBufferInserter : private kir::ExprMutator { for (auto tv : circular_buffer_tvs) { auto ldst = dynamic_cast(tv->definition()); NVF_ERROR(ldst != nullptr); - auto it = GpuLower::current()->ldstMBarrierMap().find(ldst); - if (it == GpuLower::current()->ldstMBarrierMap().end()) { + auto it = GpuLower::current()->mbarrierMap().find(ldst); + if (it == GpuLower::current()->mbarrierMap().end()) { continue; } mbarriers.pushBack(it->second); diff --git a/csrc/device_lower/pass/index.cpp b/csrc/device_lower/pass/index.cpp index 8bfe1c242e8..c84bf7a5b8b 100644 --- a/csrc/device_lower/pass/index.cpp +++ b/csrc/device_lower/pass/index.cpp @@ -1522,7 +1522,7 @@ void IndexLowering::handleCpAsyncBulkLoad(const LoadStoreOp* ldst) { GpuLower::current()->propagateExprInfo(ldst, back()); } else { - TensorView* mbarrier = GpuLower::current()->ldstMBarrierMap().at(ldst); + TensorView* mbarrier = GpuLower::current()->mbarrierMap().at(ldst); Val* mbarrier_index = lower_utils::u32IndexScalarSmemTv(mbarrier); // gmem indexing and expect_bytes for mbarrier From ed687366cf717837c8ea3e40f56542fec48e1616 Mon Sep 17 00:00:00 2001 From: Priya Mishra <52657555+Priya2698@users.noreply.github.com> Date: Fri, 18 Apr 2025 14:51:12 -0700 Subject: [PATCH 26/68] `shardAllLike` accepts a list of parallel types (#4254) By default, `shardAllLike` propagates all DID parallel types. This allows us to control exactly which parallel types to propagate, essential for #3838. --------- Co-authored-by: Jingyue Wu --- csrc/multidevice/utils.cpp | 13 ++++++++----- csrc/multidevice/utils.h | 11 +++++++++-- csrc/preseg_passes/insert_reshardings.cpp | 20 ++++++++++++++++++-- 3 files changed, 35 insertions(+), 9 deletions(-) diff --git a/csrc/multidevice/utils.cpp b/csrc/multidevice/utils.cpp index ac6d7152592..bce292d1377 100644 --- a/csrc/multidevice/utils.cpp +++ b/csrc/multidevice/utils.cpp @@ -630,14 +630,17 @@ bool isInnerResharding(Expr* expr) { return false; } -void shardAllLike(TensorView* ref, std::vector tvs) { +void shardAllLike( + TensorView* ref, + const std::vector& tvs, + const std::unordered_set& parallel_types) { + if (tvs.empty()) { + return; + } for (auto tv : tvs) { tv->setDeviceMesh(ref->getDeviceMesh()); } - if (!tvs.empty()) { - scheduler_utils::parallelizeAllLike( - ref, tvs, {ParallelType::DIDx, ParallelType::Serial}); - } + scheduler_utils::parallelizeAllLike(ref, tvs, parallel_types); } void shardBetween( diff --git a/csrc/multidevice/utils.h b/csrc/multidevice/utils.h index 34c510ccb2e..4134c943fac 100644 --- a/csrc/multidevice/utils.h +++ b/csrc/multidevice/utils.h @@ -57,8 +57,15 @@ bool haveDifferentShardings( // Returns whether a resharding expr reshards an inner axis bool isInnerResharding(Expr* expr); -// Shards all tensors in tvs like reference -void shardAllLike(TensorView* ref, std::vector tvs); +// Shards all tensors in tvs like reference. +// Accepts a set of parallel types to shard on. +// If empty, all DID parallel types are used. +void shardAllLike( + TensorView* ref, + const std::vector& tvs, + const std::unordered_set& parallel_types = { + kParallelTypeDIDs.begin(), + kParallelTypeDIDs.end()}); // Shards all TVs between from and to AND between TVs created inside a fusion // and to. This is required for (1) expressions like rng_uniform that create a diff --git a/csrc/preseg_passes/insert_reshardings.cpp b/csrc/preseg_passes/insert_reshardings.cpp index 264facb743d..52adef04b30 100644 --- a/csrc/preseg_passes/insert_reshardings.cpp +++ b/csrc/preseg_passes/insert_reshardings.cpp @@ -29,6 +29,20 @@ bool shouldReshardAfter(Expr* expr) { return expr->inputs().size() == 1 && expr->outputs().size() == 1; } +std::unordered_set getParallelTypesForResharding() { + // Consider a reshard case: + // input [DIDx(i0), i1] -> op -> output [i0, DIDx(i1)] + // This is decomposed into: + // input [DIDx(i0), i1] -> op -> output [DIDx(i0), i1] -> set -> + // new_output [i0, DIDx(i1)] ParallelType::Serial is required here so the + // output is sharded as [DIDx(i0), i1] instead of [DIDx(i0), DIDx(i1)] + // when sharding using input as the reference. + std::unordered_set parallel_types{ + kParallelTypeDIDs.begin(), kParallelTypeDIDs.end()}; + parallel_types.insert(ParallelType::Serial); + return parallel_types; +} + void insertReshardingSetsBefore(Fusion* fusion) { // Remove this after we refactor this as a pre-segmenter pass. FusionGuard fg(fusion); @@ -70,7 +84,8 @@ void insertReshardingSetsBefore(Fusion* fusion) { new_inputs.push_back(new_input); expr = ir_utils::replaceValInExprInputs(expr, input, new_input); } - shardAllLike(output, new_inputs); + + shardAllLike(output, new_inputs, getParallelTypesForResharding()); } } @@ -110,7 +125,8 @@ void insertReshardingSetsAfter(Fusion* fusion) { // Update shardings new_output takes output's sharding, // output takes input's sharding shardAllLike(output, {new_output}); - shardAllLike(input, {output}); + + shardAllLike(input, {output}, getParallelTypesForResharding()); } } } From c969903a8d357159df2dcaf27913f60aece20ae7 Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Fri, 18 Apr 2025 21:10:01 -0700 Subject: [PATCH 27/68] Tensor-parallelize the DeepSeek V3 transformer layer (#4062) cc @syed-ahmed --- tests/python/multidevice/test_deepseek_v3.py | 143 +++++++++++++++++++ tests/python/test_deepseek_v3.py | 60 -------- 2 files changed, 143 insertions(+), 60 deletions(-) create mode 100644 tests/python/multidevice/test_deepseek_v3.py delete mode 100644 tests/python/test_deepseek_v3.py diff --git a/tests/python/multidevice/test_deepseek_v3.py b/tests/python/multidevice/test_deepseek_v3.py new file mode 100644 index 00000000000..0bbfe4d1a75 --- /dev/null +++ b/tests/python/multidevice/test_deepseek_v3.py @@ -0,0 +1,143 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-present NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +import nvfuser +import pytest +import transformers +import torch +import torch.distributed as dist +from contextlib import contextmanager +from torch.distributed.tensor import DTensor +from torch.distributed.tensor.parallel import ( + parallelize_module, + RowwiseParallel, + ColwiseParallel, +) + + +# Set up the default process group for torch APIs like +# dist.device_mesh.init_device_mesh. +@pytest.fixture(scope="module") +def setup_process_group(): + communicator = nvfuser.Communicator.instance() + + # The default port as used by https://github.com/pytorch/pytorch/blob/45a8b5682eb69d865cbf68c7f2f689b56b4efd53/torch/csrc/distributed/c10d/TCPStore.hpp#L51. + dist.init_process_group( + backend="nccl", + init_method="tcp://localhost:29500", + world_size=communicator.size(), + rank=communicator.rank(), + ) + yield + dist.destroy_process_group() + + +@contextmanager +def default_tensor_type(dtype=torch.float32, device="cpu"): + # Save + prev_dtype = torch.get_default_dtype() + prev_device = torch.get_default_device() + + # Set + torch.set_default_dtype(dtype) + torch.set_default_device(device) + + yield + + # Restore + torch.set_default_dtype(prev_dtype) + torch.set_default_device(prev_device) + + +# This test timed out once when downloading +# "/deepseek-ai/DeepSeek-V3/resolve/main/configuration_deepseek.py" (cf. +# http://nv/eCm). I consider this a one-off, but please let me know if this +# error becomes consistent. +@pytest.mark.mpi +def test_transformer_layer(setup_process_group): + config = transformers.AutoConfig.from_pretrained( + "deepseek-ai/deepseek-v3", trust_remote_code=True + ) + + # Create only one layer which is sufficient for the test. + config.num_hidden_layers = 1 + # Without this, the first and only layer will have a dense MLP instead of MoE. + config.first_k_dense_replace = 0 + # Disable quantization so the test can run on A100 and is made easier for nvFuser. + delattr(config, "quantization_config") + + d = dist.get_world_size() + rank = dist.get_rank() + torch.cuda.set_device(rank) + # This ensures the input tokens are identically replicated on all ranks. + # Otherwise, some ranks may skip an expert because they have no tokens to + # send, while other ranks don't. This will cause a deadlock because a NCCL + # collective is expected to be called by all ranks in the process group. + torch.manual_seed(0) + + mesh = dist.device_mesh.init_device_mesh("cuda", [d]) + + with default_tensor_type(dtype=config.torch_dtype, device="cuda"): + model = transformers.AutoModel.from_config(config, trust_remote_code=True) + # Training is unavailable (cf. https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/main/modeling_deepseek.py#L439) + model.eval() + + transformer_layer = model.layers[0] + + # By default, RowwiseParallel and ColwiseParallel output a local tensor + # and therefore num_heads needs to be adjusted to accomodate the local + # size. Alternatively, I could RowwiseParallel(use_local_output=False) + # so the linear layer outputs a DTensor, which can be viewed using the + # original num_heads. This requires all activations, parameters, and + # buffers to be DTensor; otherwise aten ops would complain "got mixed + # torch.Tensor and DTensor". Doing so is challenging because + # DeepseekV3RotaryEmbedding creates cos_cached and sin_cached during + # the first forward call (cf. + # https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/main/modeling_deepseek.py#L143-L144). + transformer_layer.self_attn.num_heads //= d + + # Create the parallel plan + parallel_plan = { + "self_attn.q_b_proj": ColwiseParallel(), + "self_attn.kv_b_proj": ColwiseParallel(), + "self_attn.o_proj": RowwiseParallel(), + } + + for expert in range(config.n_routed_experts): + parallel_plan[f"mlp.experts.{expert}.gate_proj"] = ColwiseParallel() + parallel_plan[f"mlp.experts.{expert}.up_proj"] = ColwiseParallel() + parallel_plan[f"mlp.experts.{expert}.down_proj"] = RowwiseParallel() + + parallel_plan["mlp.shared_experts.gate_proj"] = ColwiseParallel() + parallel_plan["mlp.shared_experts.up_proj"] = ColwiseParallel() + parallel_plan["mlp.shared_experts.down_proj"] = RowwiseParallel() + + transformer_layer = parallelize_module( + transformer_layer, + mesh, + parallel_plan, + ) + + # Sanity-check parameters are indeed distributed + distributed_params: list[str] = [ + name + for name, parameter in transformer_layer.named_parameters() + if isinstance(parameter.data, DTensor) + ] + assert len(distributed_params) == 3 + (config.n_routed_experts + 1) * 3 + + batch_size = 1 + seq_len = 2048 + inp = torch.randn(batch_size, seq_len, config.hidden_size) + mask = transformers.modeling_attn_mask_utils._prepare_4d_causal_attention_mask( + None, [batch_size, seq_len], inp, past_key_values_length=0 + ) + (out,) = transformer_layer(inp, attention_mask=mask) + # Finish all computation and communication. Otherwise, + # destroy_process_group may deadlock. + torch.cuda.synchronize() + + assert out.size() == (batch_size, seq_len, config.hidden_size) + assert out.dtype == config.torch_dtype + assert out.is_cuda diff --git a/tests/python/test_deepseek_v3.py b/tests/python/test_deepseek_v3.py deleted file mode 100644 index d13b853e706..00000000000 --- a/tests/python/test_deepseek_v3.py +++ /dev/null @@ -1,60 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025-present NVIDIA CORPORATION & AFFILIATES. -# All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause - -import transformers -import torch -from contextlib import contextmanager - - -@contextmanager -def default_tensor_type(dtype=torch.float32, device="cpu"): - # Save - prev_dtype = torch.get_default_dtype() - prev_device = torch.get_default_device() - - # Set - torch.set_default_dtype(dtype) - torch.set_default_device(device) - - yield - - # Restore - torch.set_default_dtype(prev_dtype) - torch.set_default_device(prev_device) - - -# This test timed out once when downloading -# "/deepseek-ai/DeepSeek-V3/resolve/main/configuration_deepseek.py" (cf. -# http://nv/eCm). I consider this a one-off, but please let me know if this -# error becomes consistent. -def test_transformer_layer(): - config = transformers.AutoConfig.from_pretrained( - "deepseek-ai/deepseek-v3", trust_remote_code=True - ) - - # Create only one layer which is sufficient for the test. - config.num_hidden_layers = 1 - # Without this, the first and only layer will have a dense MLP instead of MoE. - config.first_k_dense_replace = 0 - # Disable quantization so the test can run on A100 and is made easier for nvFuser. - delattr(config, "quantization_config") - - with default_tensor_type(dtype=config.torch_dtype, device="cuda"): - model = transformers.AutoModel.from_config(config, trust_remote_code=True) - # Training is unavailable (cf. https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/main/modeling_deepseek.py#L439) - model.eval() - - transformer_layer = model.layers[0] - - batch_size = 1 - seq_len = 2048 - inp = torch.randn(batch_size, seq_len, config.hidden_size) - mask = transformers.modeling_attn_mask_utils._prepare_4d_causal_attention_mask( - None, [batch_size, seq_len], inp, past_key_values_length=0 - ) - (out,) = transformer_layer(inp, attention_mask=mask) - - assert out.size() == (batch_size, seq_len, config.hidden_size) - assert out.dtype == config.torch_dtype - assert out.is_cuda From bb5b38cdadf498a007df2f78359503e1ed57a809 Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Fri, 18 Apr 2025 21:39:26 -0700 Subject: [PATCH 28/68] Disable two flaky tests to keep CI green (#4283) --- tests/cpp/test_multidevice_communications.cpp | 2 +- tests/cpp/test_multidevice_host_ir.cpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/cpp/test_multidevice_communications.cpp b/tests/cpp/test_multidevice_communications.cpp index af0c0719aa7..1b6ce59801c 100644 --- a/tests/cpp/test_multidevice_communications.cpp +++ b/tests/cpp/test_multidevice_communications.cpp @@ -417,7 +417,7 @@ INSTANTIATE_TEST_SUITE_P( using P2PCommunicationTest = MultiDeviceTest; -TEST_F(P2PCommunicationTest, CudaComm) { +TEST_F(P2PCommunicationTest, DISABLED_CudaComm) { static constexpr int kTensorSize = 8; static constexpr int kNumRepetitions = 32; diff --git a/tests/cpp/test_multidevice_host_ir.cpp b/tests/cpp/test_multidevice_host_ir.cpp index 88286d6e4c0..0b6efbd15a4 100644 --- a/tests/cpp/test_multidevice_host_ir.cpp +++ b/tests/cpp/test_multidevice_host_ir.cpp @@ -478,7 +478,7 @@ TEST_F(OverlapDistributedMatmulTest, AG_linear) { EXPECT_TRUE(torch::allclose(out_ref, out_at, 1e-1, 1e-1)); } -TEST_F(MultiDeviceTest, ShareIpcMemHandles) { +TEST_F(MultiDeviceTest, DISABLED_ShareIpcMemHandles) { static constexpr int kTensorSize = 4; static constexpr int kNumRepetitions = 10; From 39aec1668e1666f88a11e763103a1428c4beacc4 Mon Sep 17 00:00:00 2001 From: "Gao, Xiang" Date: Sat, 19 Apr 2025 15:01:21 -0700 Subject: [PATCH 29/68] Use `tcgen05` as namespace for TMem ld/st (#4279) For better consistency --- csrc/kernel_ir.cpp | 4 ++-- tests/cpp/test_tmem.cpp | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/csrc/kernel_ir.cpp b/csrc/kernel_ir.cpp index 546e4c970a3..d14c007515b 100644 --- a/csrc/kernel_ir.cpp +++ b/csrc/kernel_ir.cpp @@ -457,7 +457,7 @@ std::string Asm::utility() const { std::regex ld_pattern(R"(tcgen05\.ld\.sync\.aligned\.([^.]+)\.x\d+\.b32)"); std::smatch match; if (std::regex_match(code, match, ld_pattern)) { - std::string result = "tmem::load"; + std::string result = "tcgen05::load"; result.append(match[1]); return result; } @@ -466,7 +466,7 @@ std::string Asm::utility() const { std::regex st_pattern(R"(tcgen05\.st\.sync\.aligned\.([^.]+)\.x\d+\.b32)"); std::smatch match; if (std::regex_match(code, match, st_pattern)) { - std::string result = "tmem::store"; + std::string result = "tcgen05::store"; result.append(match[1]); return result; } diff --git a/tests/cpp/test_tmem.cpp b/tests/cpp/test_tmem.cpp index 547b0de840f..ab2b7aebfcb 100644 --- a/tests/cpp/test_tmem.cpp +++ b/tests/cpp/test_tmem.cpp @@ -290,7 +290,7 @@ TEST_F(TMemTestCompileOnly, SetTMemDimSepPosNonTMem) { // But in the TMem load/store's loop domain, Ix (the ID parallelized on TIDx) // have extent 32. Then we will generate code like: // if (threadIdx.x < 32) { -// tmem::load +// tcgen05::load // } // For threadIdx.y == 0, it is correct. But for threadIdx.y == 1, it is wrong // because we are using the thread id 33-65 for the load, which is not a warp. @@ -342,7 +342,7 @@ TEST_F(TMemTestCompileOnly, WrongStride) { // map is [TIDy, TIDx] = [2, 33], but in the TMem load/store's loop domain, // we have Iy{1}, Ix{32}. the generated code will be like: // if (threadIdx.x < 32 && threadIdx.y < 1) { -// tmem::load +// tcgen05::load // } // This is valid because we are using a whole warp for the load. TEST_F(TMemTest, InexactParallelType) { From 857d1df833b70b38ca31e3640016b436c972503b Mon Sep 17 00:00:00 2001 From: "Gao, Xiang" Date: Sat, 19 Apr 2025 18:05:42 -0700 Subject: [PATCH 30/68] Use mbarrier to sync Blackwell MMA (#4276) Stacked on https://github.com/NVIDIA/Fuser/pull/4277 Replaces `nanosleep` with mbarrier.
Show example generated code: ```CUDA // Codegen generated code namespace tcgen05 { __device__ __inline__ void alloc(uint32_t in0, uint32_t in1) { asm volatile("tcgen05.alloc.cta_group::1.sync.aligned.shared::cta.b32 [%0], %1;\n"::"r"(in0), "r"(in1)); } __device__ __inline__ void relinquishAllocPermit() { asm volatile("tcgen05.relinquish_alloc_permit.cta_group::1.sync.aligned;\n"); } __device__ __inline__ void mma_f16(uint32_t in0, uint64_t in1, uint64_t in2, uint32_t in3, bool in4) { asm volatile( "{\n" " .reg .pred p0; \n" " setp.ne.b32 p0, %4, 0;\n" " tcgen05.mma.cta_group::1.kind::f16 [%0], %1, %2, %3, p0;\n" "}\n" : :"r"(in0), "l"(in1), "l"(in2), "r"(in3), "r"((uint32_t)(in4)) ); } __device__ __inline__ void commit(uint32_t in0) { asm volatile("tcgen05.commit.cta_group::1.mbarrier::arrive::one.shared::cluster.b64 [%0];\n"::"r"(in0)); } __device__ __inline__ void waitLoad() { asm volatile("tcgen05.wait::ld.sync.aligned;\n"); } __device__ __inline__ void dealloc(uint32_t in0, uint32_t in1) { asm volatile("tcgen05.dealloc.cta_group::1.sync.aligned.b32 %0, %1;\n"::"r"(in0), "r"(in1)); } } // namespace tcgen05 namespace tmem { __device__ __inline__ void load32x32b(Array& out0, uint32_t in0) { asm( "tcgen05.ld.sync.aligned.32x32b.x1.b32 {%0}, [%1];\n" :"=f"(out0[0]) :"r"(in0) ); } } // namespace tmem __global__ void nvfuser_none_f0_c0_r0_g0(Tensor<__half, 3, 3> T0, Tensor<__half, 3, 3> T1, Tensor T6) { alignas(16) extern __shared__ char array[]; const unsigned smem_offset = 0; nvfuser_index_t i0; i0 = ((nvfuser_index_t)threadIdx.x) % 8; nvfuser_index_t i1; i1 = (((nvfuser_index_t)threadIdx.x) / 32) % 2; nvfuser_index_t i2; i2 = (T1.alloc_stride[1LL] * i1) + (T1.alloc_stride[2LL] * i0); nvfuser_index_t i3; i3 = 8 * T1.alloc_stride[1LL]; nvfuser_index_t i4; i4 = 2 * T1.alloc_stride[1LL]; nvfuser_index_t i5; i5 = ((nvfuser_index_t)threadIdx.x) / 64; nvfuser_index_t i6; i6 = 8 * T1.alloc_stride[2LL]; nvfuser_index_t i7; i7 = ((nvfuser_index_t)threadIdx.x) / 8; nvfuser_index_t i8; i8 = i7 % 4; nvfuser_index_t i9; i9 = i5 % 2; nvfuser_index_t i10; i10 = i7 % 8; nvfuser_index_t i11; i11 = ((T0.alloc_stride[0LL] * i10) + ((8 * T0.alloc_stride[1LL]) * i9)) + (T0.alloc_stride[1LL] * i0); nvfuser_index_t i12; i12 = 8 * T0.alloc_stride[0LL]; __half* T3 = reinterpret_cast<__half*>(array + smem_offset + 4608); __half* T2 = reinterpret_cast<__half*>(array + smem_offset + 128); uint32_t i13; i13 = toSmem(T2); uint32_t i14; i14 = toSmem(T3); uint16_t i15; i15 = (uint16_t)((32LL * (((nvfuser_index_t)threadIdx.x) / 32LL))); nvfuser_index_t i16; i16 = 16 * ((nvfuser_index_t)threadIdx.x); bool b17; b17 = ((nvfuser_index_t)threadIdx.x) < 32LL; bool b18; b18 = ((8 * i9) + i0) < 16; nvfuser_index_t i19; i19 = -128 + i10; uint32_t* T7 = reinterpret_cast(array + smem_offset + 0); if (b17) { tcgen05::alloc((uint32_t)(toSmem(T7)), 32U); } if (b17) { tcgen05::relinquishAllocPermit(); } __syncthreads(); #pragma unroll for(nvfuser_index_t i20 = 0; i20 < 4; ++i20) { T3[(((nvfuser_index_t)threadIdx.x) + (128 * i20))] = 0.000000000e+00f; } #pragma unroll for(nvfuser_index_t i20 = 0; i20 < 4; ++i20) { nvfuser_index_t i21; i21 = (i5 + (2 * i20)) % 4; nvfuser_index_t i22; i22 = i20 / 2; nvfuser_index_t i23; i23 = i8 ^ i21; nvfuser_index_t i24; i24 = i0 + (8 * i23); if ((((((i1 + (8 * i22)) + (2 * i21)) < 16) && (i24 >= 0)) && (i24 < 16))) { T3[(((nvfuser_index_t)threadIdx.x) + (128 * i20))] = T1[(((i2 + (i3 * i22)) + (i4 * i21)) + (i6 * i23))]; } } #pragma unroll for(nvfuser_index_t i25 = 0; i25 < 16; ++i25) { T2[(((nvfuser_index_t)threadIdx.x) + (128 * i25))] = 0.000000000e+00f; } #pragma unroll for(nvfuser_index_t i25 = 0; i25 < 16; ++i25) { if ((b18 && (i19 < (-(8 * i25))))) { T2[(((nvfuser_index_t)threadIdx.x) + (128 * i25))] = T0[(i11 + (i12 * i25))]; } } TMemTensor T4(T7[0], 0, (uint16_t)(0)); uint64_t* T8 = reinterpret_cast(array + smem_offset + 16); mbarrier::init(toSmem(T8), 1U); __syncthreads(); if ((Hopper::electSync(4294967295U) && (((nvfuser_index_t)threadIdx.x) < 32ULL))) { tcgen05::mma_f16((uint32_t)(T4 + Array{0, 0}), (70437464178688ULL | ((262143ULL & (uint64_t)(i13)) >> 4ULL)), (9223442543037906944ULL | ((262143ULL & (uint64_t)(i14)) >> 4ULL)), 134545424U, false); tcgen05::commit(toSmem(T8)); } mbarrier::waitParity(toSmem((&T8[0])), 0U); __syncthreads(); mbarrier::inval(toSmem(T8)); Array T5; #pragma unroll for(nvfuser_index_t i26 = 0; i26 < 16; ++i26) { tmem::load32x32b((*reinterpret_cast*>(&T5[i26])), (uint32_t)(T4 + (Array{i15, (uint16_t)(i26)}))); tcgen05::waitLoad(); } __syncthreads(); if (b17) { tcgen05::dealloc(T7[0], 32U); } #pragma unroll for(nvfuser_index_t i27 = 0; i27 < 16; ++i27) { T6[(i16 + i27)] = T5[i27]; } } ```
--- csrc/device_lower/pass/allocation.cpp | 8 +++++--- csrc/device_lower/pass/index.cpp | 8 ++++++++ csrc/device_lower/pass/insert_syncs.cpp | 16 +++++----------- csrc/kernel_ir.cpp | 2 ++ 4 files changed, 20 insertions(+), 14 deletions(-) diff --git a/csrc/device_lower/pass/allocation.cpp b/csrc/device_lower/pass/allocation.cpp index 7f92d174cd4..8a869073be3 100644 --- a/csrc/device_lower/pass/allocation.cpp +++ b/csrc/device_lower/pass/allocation.cpp @@ -1347,7 +1347,8 @@ class AllocationInserter : public kir::ExprMutator { // circular buffering pass. // * Assume that the tma load is in ComputeWarp if it is not circular // buffered. - if (ir_utils::isCpAsyncBulkLoad(expr) && circular_buffer_depth == 1) { + if ((ir_utils::isCpAsyncBulkLoad(expr) && circular_buffer_depth == 1) || + (expr->isA() && expr->as()->isBlackwell())) { // create and allocate a memory barrier TensorView* mbarrier = TensorViewBuilder() .shape(std::vector{}) @@ -1359,8 +1360,9 @@ class AllocationInserter : public kir::ExprMutator { mbarrier, simplifyExpr(SimplifyingIrBuilder::maybeCastExpr( DataType::UInt32, - lower_utils::getNumThreadsInTensorView( - expr->output(0)->as())))); + expr->isA() ? expr->fusion()->oneVal() + : lower_utils::getNumThreadsInTensorView( + expr->output(0)->as())))); auto sync_init = IrBuilder::create( /*war_sync=*/false, /*optional_compute_or_load_sync=*/true); auto mbarrier_inval = diff --git a/csrc/device_lower/pass/index.cpp b/csrc/device_lower/pass/index.cpp index c84bf7a5b8b..ba0082e9dbc 100644 --- a/csrc/device_lower/pass/index.cpp +++ b/csrc/device_lower/pass/index.cpp @@ -2697,6 +2697,14 @@ void IndexLowering::handle(const MmaOp* mma) { auto mma_indexed = IrBuilder::create(out, a, b, mma->init(), mma->macro()); pushBack(mma_indexed); + if (mma->isBlackwell()) { + pushBack(IrBuilder::create( + "tcgen05.commit.cta_group::1.mbarrier::arrive::one.shared::cluster.b64", + std::vector{}, + std::vector{lower_utils::u32IndexScalarSmemTv( + GpuLower::current()->mbarrierMap().at(mma))}, + kir::Asm::Options{/*volatile=*/true})); + } GpuLower::current()->propagateExprInfo(mma, back()); } diff --git a/csrc/device_lower/pass/insert_syncs.cpp b/csrc/device_lower/pass/insert_syncs.cpp index e3223da7904..09fa9df9666 100644 --- a/csrc/device_lower/pass/insert_syncs.cpp +++ b/csrc/device_lower/pass/insert_syncs.cpp @@ -486,19 +486,13 @@ class ReadAfterWriteSyncs : public kir::ExprMutator { // async mma pipeline has not been flushed yet. flush_async_mma_pipeline_ = false; } else if (mma->isBlackwell()) { - // TODO: This is clearly a wrong way to sync, but as an intermediate - // step to enable incremental development, we use nanosleep to sync the - // mma. We should replace this with a correct sync method. - registerInsertBefore(expr, IrBuilder::create()); registerInsertAfter( expr, - IrBuilder::create( - "nanosleep.u32", - std::vector{}, - std::vector{ - IrBuilder::create(4000000000, DataType::UInt32)}, - kir::Asm::Options{/*volatile=*/true})); - registerInsertAfter(expr, IrBuilder::create()); + IrBuilder::create( + IrBuilder::create( + GpuLower::current()->mbarrierMap().at(expr), + expr->fusion()->zeroVal()), + expr->fusion()->zeroVal(DataType::UInt32))); } } else if (ir_utils::isCpAsyncBulkStore(expr)) { // Add a fence before TMA store so that writes in the generic proxy is diff --git a/csrc/kernel_ir.cpp b/csrc/kernel_ir.cpp index d14c007515b..efc60b809e6 100644 --- a/csrc/kernel_ir.cpp +++ b/csrc/kernel_ir.cpp @@ -432,6 +432,8 @@ std::string Asm::utility() const { "tcgen05::relinquishAllocPermit"}, {"tcgen05.dealloc.cta_group::1.sync.aligned.b32", "tcgen05::dealloc"}, {"tcgen05.mma.cta_group::1.kind::f16", "tcgen05::mma_f16"}, + {"tcgen05.commit.cta_group::1.mbarrier::arrive::one.shared::cluster.b64", + "tcgen05::commit"}, {"wgmma.fence.sync.aligned", "wgmma::fence"}, {"fence.proxy.async", "fenceAsyncProxy"}, {"wgmma.commit_group.sync.aligned", "wgmma::commit"}, From 5dac8bd9a80717b6ca2e918e8f323efe86cf0479 Mon Sep 17 00:00:00 2001 From: Christian Sarofeen Date: Mon, 21 Apr 2025 09:35:30 -0400 Subject: [PATCH 31/68] Add segmentation helper functions for edge processing (#4222) Trying to make DAG modifications a little easier in Segmentation. Primarily by adding APIs for removing edges and connecting groups (adding an edge). One still tricky thing is being careful when iterating over the producer or consumer edges of a group and updating edges at the same time. Standard concerns of potentially invalidating a vector you're iterating over while modifying it remain. - Adds: - void removeEdge(SegmentedEdge* edge); - void connectGroups(SegmentedGroup* from, SegmentedGroup* to, Val* val); - Updates: - void disconnectGroup(SegmentedGroup* group); - Removes some dead code --- csrc/fusion_segmenter.cpp | 400 +++++++++++++++----------------------- csrc/fusion_segmenter.h | 32 +-- 2 files changed, 178 insertions(+), 254 deletions(-) diff --git a/csrc/fusion_segmenter.cpp b/csrc/fusion_segmenter.cpp index ada881996b7..bf7693c6028 100644 --- a/csrc/fusion_segmenter.cpp +++ b/csrc/fusion_segmenter.cpp @@ -276,43 +276,6 @@ std::vector SegmentedGroup:: return merge_candidates; } -void SegmentedGroup::clearTraversalInfo() { - level_ = -1; - merge_with_ = nullptr; - merge_through_ = nullptr; - merged_ = false; -} - -std::vector SegmentedGroup::edgesToVals( - const std::vector& se_v) { - std::vector ret_v; - ret_v.reserve(se_v.size()); - - std::transform( - se_v.cbegin(), - se_v.cend(), - std::back_inserter(ret_v), - [](SegmentedEdge* se) { return se->val; }); - return ret_v; -} - -template -void insertUniquePredicated( - std::vector& v, - const std::vector& e, - PREDICATE pred) { - VectorOfUniqueEntries to_add; - for (auto edge : e) { - to_add.pushBack(edge->val); - } - - std::copy_if( - to_add.vector().begin(), - to_add.vector().end(), - std::back_inserter(v), - [pred](Val* val) { return pred(val); }); -} - // TODO: Reevaluate what's being done in finalize void SegmentedGroup::finalize() { // Make sure all inputs and outputs of the group are now in input and output @@ -604,7 +567,7 @@ void SegmentedFusion::deserialize(const serde::SegmentedFusion* buffer) { // Construct segmented groups first because they are necessary for the // segmented edge's constructor - // NOTE: Use regular for-loop to avoid unused variable ‘idx’ error + // NOTE: Use regular for-loop to avoid unused variable 'idx' error for (size_t idx = 0; idx < buffer->groups()->size(); ++idx) { newGroup(); } @@ -683,25 +646,57 @@ SegmentedEdge* SegmentedFusion::Impl::makeEdge( return edges_.back().get(); } +void SegmentedFusion::removeEdge(SegmentedEdge* edge) { + NVF_ERROR(edge != nullptr, "Edge is nullptr"); + // Validate edge exists in all expected locations + SegmentedGroup* producer = edge->from; + SegmentedGroup* consumer = edge->to; + auto& producer_consumer_edges = producer->consumer_edges; + auto& consumer_producer_edges = consumer->producer_edges; + + // Remove edge from producer's consumer edges + auto producer_edge_it = std::find( + producer_consumer_edges.begin(), producer_consumer_edges.end(), edge); + NVF_ERROR( + producer_edge_it != producer_consumer_edges.end(), + "Edge not found in producer's consumer edges"); + producer_consumer_edges.erase(producer_edge_it); + + // Remove edge from consumer's producer edges + auto consumer_edge_it = std::find( + consumer_producer_edges.begin(), consumer_producer_edges.end(), edge); + NVF_ERROR( + consumer_edge_it != consumer_producer_edges.end(), + "Edge not found in consumer's producer edges"); + consumer_producer_edges.erase(consumer_edge_it); + + // Remove edge from global edge list + auto edge_it = std::find(edges_.begin(), edges_.end(), edge); + NVF_ERROR(edge_it != edges_.end(), "Edge not found in global edge list"); + edges_.erase(edge_it); +} + void SegmentedFusion::Impl::cleanUnused() { std::unordered_set g_used( owning_fusion_->groups().begin(), owning_fusion_->groups().end()); std::unordered_set e_used( owning_fusion_->edges().begin(), owning_fusion_->edges().end()); - groups_.erase( - std::remove_if( - groups_.begin(), - groups_.end(), - [&g_used](auto& g) { return g_used.count(g.get()) == 0; }), - groups_.end()); - + // Remove any edges that are no longer in use edges_.erase( std::remove_if( edges_.begin(), edges_.end(), [&e_used](auto& e) { return e_used.count(e.get()) == 0; }), edges_.end()); + + // Remove any groups that are no longer in use + groups_.erase( + std::remove_if( + groups_.begin(), + groups_.end(), + [&g_used](auto& g) { return g_used.count(g.get()) == 0; }), + groups_.end()); } //! Return mapping from SegmentedGroup to integer id @@ -2016,55 +2011,28 @@ void SegmentCandidateFinder::resetLevels() { } // Disconect group from neighbors, and return edges that were disconnected -std::unordered_set SegmentCandidateFinder::disconnectGroup( - SegmentedGroup* group) { - std::unordered_set removed_edges( +void SegmentCandidateFinder::disconnectGroup(SegmentedGroup* group) { + // Remove producer edges + std::vector producer_edges( group->producer_edges.begin(), group->producer_edges.end()); - - for (auto edge : group->producer_edges) { - auto from = edge->from; - auto& from_edges = from->consumer_edges; - auto from_edge_it = std::find(from_edges.begin(), from_edges.end(), edge); - NVF_ERROR( - from_edge_it != from_edges.end(), "Could not find edge to remove."); - from_edges.erase(from_edge_it); + for (auto edge : producer_edges) { + segmented_fusion_->removeEdge(edge); } - for (auto edge : group->consumer_edges) { - removed_edges.insert(edge); - auto to = edge->to; - auto& to_edges = to->producer_edges; - auto to_edge_it = std::find(to_edges.begin(), to_edges.end(), edge); - NVF_ERROR(to_edge_it != to_edges.end(), "Could not find edge to remove."); - to_edges.erase(to_edge_it); + // Remove consumer edges + std::vector consumer_edges( + group->consumer_edges.begin(), group->consumer_edges.end()); + for (auto edge : consumer_edges) { + segmented_fusion_->removeEdge(edge); } - - group->producer_edges.clear(); - group->consumer_edges.clear(); - - return removed_edges; } void SegmentCandidateFinder::eraseGroups( std::unordered_set& groups_to_erase) { - std::unordered_set edges_to_erase; for (auto group : groups_to_erase) { - auto disconnected_edges = disconnectGroup(group); - edges_to_erase.insert(disconnected_edges.begin(), disconnected_edges.end()); + disconnectGroup(group); } - edges().erase( - std::remove_if( - edges().begin(), - edges().end(), - [&edges_to_erase](SegmentedEdge* edge) { - if (edges_to_erase.find(edge) != edges_to_erase.end()) { - return true; - }; - return false; - }), - edges().end()); - groups().erase( std::remove_if( groups().begin(), @@ -2078,6 +2046,30 @@ void SegmentCandidateFinder::eraseGroups( groups().end()); } +std::vector SegmentedFusion::getEdgesBetween( + const SegmentedGroup* producer, + const SegmentedGroup* consumer) const { + std::vector edges_between; + + // Look through producer's consumer edges + for (SegmentedEdge* edge : producer->consumer_edges) { + if (edge->to == consumer) { + edges_between.push_back(edge); + } + } + + return edges_between; +} + +void SegmentedFusion::connectGroups( + SegmentedGroup* producer, + SegmentedGroup* consumer, + Val* val) { + SegmentedEdge* new_edge = newEdge(producer, consumer, val); + producer->consumer_edges.push_back(new_edge); + consumer->producer_edges.push_back(new_edge); +} + SegmentedGroup* SegmentCandidateFinder::mergeNodes() { FUSER_PERF_SCOPE("SegmentCandidateFinder::mergeNodes"); SegmentedGroup* last_merged = nullptr; @@ -2093,90 +2085,65 @@ SegmentedGroup* SegmentCandidateFinder::mergeNodes() { // Make the new joined node auto joined_group = segmented_fusion_->newGroup(); + // Merge input and output vals joined_group->input_vals_ = group1->input_vals_.computeUnion(group2->input_vals_); - joined_group->output_vals_ = group1->output_vals_.computeUnion(group2->output_vals_); + // Merge expressions joined_group->exprs_ = group1->exprs_; joined_group->exprs_.insert( joined_group->exprs_.end(), group2->exprs_.begin(), group2->exprs_.end()); + // Get all edges that will connect to the new joined group auto producer_edges = getMergedProducerEdges(group1, group2); - // Connect joined group to resulting neighbors - for (auto edge : producer_edges) { - auto from = edge->from; - auto val = edge->val; + auto consumer_edges = getMergedConsumerEdges(group1, group2); - auto new_edge = segmented_fusion_->newEdge(from, joined_group, val); - joined_group->producer_edges.push_back(new_edge); - from->consumer_edges.push_back(new_edge); + // Connect all producer edges to the new joined group + for (auto edge : producer_edges) { + segmented_fusion_->connectGroups(edge->from, joined_group, edge->val); } - auto consumer_edges = getMergedConsumerEdges(group1, group2); - + // Connect all consumer edges from the new joined group for (auto edge : consumer_edges) { - auto to = edge->to; - auto val = edge->val; - - auto new_edge = segmented_fusion_->newEdge(joined_group, to, val); - joined_group->consumer_edges.push_back(new_edge); - edge->to->producer_edges.push_back(new_edge); + segmented_fusion_->connectGroups(joined_group, edge->to, edge->val); } - // Disconnect the merged groups before deriveSchedulerType, which - // may temporarily inject type cast and can get confused if stale - // edges exist + // Now that all new connections are made, disconnect the old groups, this + // invalidates producer_edges and consumer_edges for (auto merged_group : {group1, group2}) { - auto disconnected_edges = disconnectGroup(merged_group); - clean_up_edges_.insert( - disconnected_edges.begin(), disconnected_edges.end()); + disconnectGroup(merged_group); } + // Set scheduler type for the new group joined_group->setSchedulerType(deriveSchedulerType(joined_group)); - // Need to maintain the group dependency data if it has been intialized - // by previous merging + + // Update group dependency data if initialized if (group_dependency_) { group_dependency_->as()->mergeGroups( group1, group2, joined_group); } + last_merged = joined_group; } to_merge_.clear(); - edges().erase( - std::remove_if( - edges().begin(), - edges().end(), - [this](SegmentedEdge* edge) { - if (this->clean_up_edges_.find(edge) != - this->clean_up_edges_.end()) { - return true; - }; - return false; - }), - edges().end()); - + // Clean up merged groups groups().erase( std::remove_if( groups().begin(), groups().end(), [this](SegmentedGroup* group) { - if (this->clean_up_groups_.find(group) != - this->clean_up_groups_.end()) { - return true; - }; - return false; + return this->clean_up_groups_.find(group) != + this->clean_up_groups_.end(); }), groups().end()); - clean_up_edges_.clear(); clean_up_groups_.clear(); - return last_merged; } @@ -2187,77 +2154,55 @@ SegmentedGroup* SegmentCandidateFinder::mergeAllGivenGroups( const std::vector& groups_to_merge) { NVF_ERROR( !groups_to_merge.empty(), - "fusion segment :(mergeAllGivenGroups) tried to merge no groups") + "fusion segment :(mergeAllGivenGroups) tried to merge no groups"); // Make a set to detect internal edges std::unordered_set group_set( groups_to_merge.begin(), groups_to_merge.end()); - // Sets to de-duplicate multiple uses of - // edge values and re-computations of exprs - std::unordered_set used_edge_vals_set; - std::unordered_set exprs_set; - // Create new group auto joined_group = segmented_fusion_->newGroup(); - // Populate edges, exprs, global vals - // from each of the groups + // Track unique vals and exprs to avoid duplicates + std::unordered_set used_edge_vals_set; + std::unordered_set exprs_set; + + // Merge inputs and outputs from all groups for (auto group : groups_to_merge) { - // Populate complete fusion inputs to the group joined_group->input_vals_.pushBack(group->input_vals_); joined_group->output_vals_.pushBack(group->output_vals_); + } - // Populate producer edges to the group - for (auto edge : group->producer_edges) { - if ( - // Check this is not internal edge - !group_set.count(edge->from) && - // Check this val has been added or not - !used_edge_vals_set.count(edge->val)) { - used_edge_vals_set.insert(edge->val); - auto new_producer_edge = - segmented_fusion_->newEdge(edge->from, joined_group, edge->val); - joined_group->producer_edges.push_back(new_producer_edge); - edge->from->consumer_edges.push_back(new_producer_edge); - } - } + // Get all edges that will connect to the new joined group + auto all_edges = getAllEdges(groups_to_merge); - // Populate consumer edges from the group - for (auto edge : group->consumer_edges) { - if ( - // Check this is not internal edge - !group_set.count(edge->to)) { - auto new_consumer_edge = - segmented_fusion_->newEdge(joined_group, edge->to, edge->val); - joined_group->consumer_edges.push_back(new_consumer_edge); - edge->to->producer_edges.push_back(new_consumer_edge); - } + // Connect all external edges to the new joined group + for (auto edge : all_edges) { + if (group_set.count(edge->from)) { + // This is a consumer edge from the merged group + segmented_fusion_->connectGroups(joined_group, edge->to, edge->val); + } else { + // This is a producer edge to the merged group + segmented_fusion_->connectGroups(edge->from, joined_group, edge->val); } + } + + // Disconnect all original groups before connecting the new one, this + // invalidates all_edges + for (auto group : groups_to_merge) { + disconnectGroup(group); + } - // Populate exprs + // Merge all expressions from the groups + for (auto group : groups_to_merge) { for (auto expr : group->exprs_) { - if (!exprs_set.count(expr)) { + if (exprs_set.insert(expr).second) { joined_group->exprs_.push_back(expr); - exprs_set.insert(expr); } } } - // Clean up original groups from segmented fusion - for (auto group : groups_to_merge) { - auto disconnected_edges = disconnectGroup(group); - clean_up_edges_.insert( - disconnected_edges.begin(), disconnected_edges.end()); - } - - edges().erase( - std::remove_if( - edges().begin(), - edges().end(), - [this](SegmentedEdge* edge) { return clean_up_edges_.count(edge); }), - edges().end()); - + // Clean up original groups groups().erase( std::remove_if( groups().begin(), @@ -2267,8 +2212,6 @@ SegmentedGroup* SegmentCandidateFinder::mergeAllGivenGroups( }), groups().end()); - clean_up_edges_.clear(); - joined_group->setSchedulerType(deriveSchedulerType(joined_group)); return joined_group; } @@ -2324,8 +2267,9 @@ class FusionSegmentGuard : public NonCopyable { num_original_exprs_ = fusion_->exprs().size(); original_tvs_ = fusion_->allTvs(); #endif // NDEBUG - lowered_edges_ = segmented_fusion_->castInputOutputToLowerPrecision( - segmented_fusion_->edges()); + lowered_precision_edges_ = + segmented_fusion_->castInputOutputToLowerPrecision( + segmented_fusion_->edges()); } // Insert cast and narrow the fusion to a merged group of a and b @@ -2349,7 +2293,7 @@ class FusionSegmentGuard : public NonCopyable { consumer_edges.begin(), consumer_edges.end(), std::back_inserter(all_edges)); - lowered_edges_ = + lowered_precision_edges_ = segmented_fusion_->castInputOutputToLowerPrecision(all_edges, {a, b}); auto new_inputs = getAllInputs(a, b); @@ -2373,8 +2317,9 @@ class FusionSegmentGuard : public NonCopyable { // Cast inputs and outputs of a merged group consisting of // segmented_groups. auto all_edges = getAllEdges(segmented_groups); - lowered_edges_ = segmented_fusion_->castInputOutputToLowerPrecision( - all_edges, segmented_groups); + lowered_precision_edges_ = + segmented_fusion_->castInputOutputToLowerPrecision( + all_edges, segmented_groups); auto new_inputs = allInputsIfTrueElseOutputs(segmented_groups, true); auto new_outputs = allInputsIfTrueElseOutputs(segmented_groups, false); @@ -2393,8 +2338,9 @@ class FusionSegmentGuard : public NonCopyable { restoreOriginalSegment(); // Revert the cast - if (segmented_fusion_ != nullptr && !lowered_edges_.empty()) { - segmented_fusion_->revertInputOutputPrecisionChanges(lowered_edges_); + if (segmented_fusion_ != nullptr && !lowered_precision_edges_.empty()) { + segmented_fusion_->revertInputOutputPrecisionChanges( + lowered_precision_edges_); } #ifndef NDEBUG @@ -2473,7 +2419,7 @@ class FusionSegmentGuard : public NonCopyable { Fusion* const fusion_ = nullptr; std::vector old_inputs_; std::vector old_outputs_; - std::vector lowered_edges_; + std::vector lowered_precision_edges_; #ifndef NDEBUG size_t num_original_exprs_ = 0; std::vector original_tvs_; @@ -3749,27 +3695,6 @@ class MergeUpAndDownCast { namespace { -//! Returns true if group1 and group2 are an immediate producer-consumer pair. -bool areDirectlyConnected(SegmentedGroup* group1, SegmentedGroup* group2) { - // Check if group1 is a immediate consumer of group2 - if (std::any_of( - group1->producer_edges.begin(), - group1->producer_edges.end(), - [group2](SegmentedEdge* edge) { return edge->from == group2; })) { - return true; - } - - // Check if group1 is a immediate producer of group2 - if (std::any_of( - group1->consumer_edges.begin(), - group1->consumer_edges.end(), - [group2](SegmentedEdge* edge) { return edge->to == group2; })) { - return true; - } - - return false; -} - //! Allow the segmentation algorithm to prefer certain exprs to merge class PreferredMergeCandidatePicker { public: @@ -3879,7 +3804,8 @@ bool SegmentCandidateFinder::codeGenSupportedMerge( SegmentedGroup* group2) { FUSER_PERF_SCOPE("SegmentCandidateFinder::codeGenSupportedMerge"); NVF_ERROR( - areDirectlyConnected(group1, group2), + !segmented_fusion_->getEdgesBetween(group1, group2).empty() || + !segmented_fusion_->getEdgesBetween(group2, group1).empty(), "only support testing immediate producer-consumer groups"); // The segmemter should ideally be redesigned to be more flexible and // decoupled from the schedulers, but for now, we just return @@ -3979,9 +3905,7 @@ void SegmentCandidateFinder::buildInitialSegments() { if (isFusionInput(inp)) { expr_group->input_vals_.pushBack(inp); auto aux_group = input2group_.at(inp); - auto new_edge = segmented_fusion_->newEdge(aux_group, expr_group, inp); - expr_group->producer_edges.push_back(new_edge); - aux_group->consumer_edges.push_back(new_edge); + segmented_fusion_->connectGroups(aux_group, expr_group, inp); continue; } @@ -3999,9 +3923,7 @@ void SegmentCandidateFinder::buildInitialSegments() { } auto def_group = expr2group.at(inp->definition()); - auto new_edge = segmented_fusion_->newEdge(def_group, expr_group, inp); - expr_group->producer_edges.push_back(new_edge); - def_group->consumer_edges.push_back(new_edge); + segmented_fusion_->connectGroups(def_group, expr_group, inp); } for (auto out : expr->outputs()) { if (out->isFusionOutput()) { @@ -4727,26 +4649,29 @@ void SegmentCandidateFinder::resolveNonscalarForwardedInput( SegmentedGroup* aux_group = input2group_.at(forwarded_input); NVF_ERROR(aux_group->producer_edges.empty()); - // use unordered_set to avoid duplicated group in consumers. - // duplicated entry in consumer would make use call - // codeGenSupportedMerge(input_group, consumer) twice. Where the second time - // the connection has already been severed by mergeNodes(). GroupSet consumers; for (SegmentedEdge* edge : aux_group->consumer_edges) { consumers.pushBack(edge->to); } - aux_group->consumer_edges.clear(); for (SegmentedGroup* consumer : consumers) { SegmentedGroup* input_group = createInputGroup(forwarded_input); - - for (SegmentedEdge*& edge : consumer->producer_edges) { + std::vector edges_to_remove; + std::vector producer_edge_copy = consumer->producer_edges; + // Use a copy to iterate over edges as connect group can invalidate the + // original iterator + for (SegmentedEdge* edge : producer_edge_copy) { if (edge->from == aux_group && edge->val == forwarded_input) { - edge->from = input_group; - input_group->consumer_edges.push_back(edge); + // Create new edges before removing old ones + segmented_fusion_->connectGroups( + input_group, consumer, forwarded_input); + // Now safe to remove old edges + edges_to_remove.push_back(edge); } } - + for (auto edge_to_remove : edges_to_remove) { + segmented_fusion_->removeEdge(edge_to_remove); + } consumer->input_vals_.erase(forwarded_input); if (codeGenSupportedMerge(input_group, consumer)) { @@ -4765,21 +4690,18 @@ void SegmentCandidateFinder::removeScalarEdges() { // translation. // we will not need them after scalar // resolution - auto remove_scalar_edges_from_vec = [](std::vector& edges) { - edges.erase( - std::remove_if( - edges.begin(), - edges.end(), - [](SegmentedEdge* segmented_edge) { - return segmented_edge->val->isScalar(); - }), - edges.end()); - }; - remove_scalar_edges_from_vec(edges()); - for (auto group : groups()) { - remove_scalar_edges_from_vec(group->producer_edges); - remove_scalar_edges_from_vec(group->consumer_edges); + // Collect all scalar edges first since removeEdge modifies the edge lists + std::vector scalar_edges; + for (auto edge : edges()) { + if (edge->val->isScalar()) { + scalar_edges.push_back(edge); + } + } + + // Remove each scalar edge through the proper API + for (auto edge : scalar_edges) { + segmented_fusion_->removeEdge(edge); } } diff --git a/csrc/fusion_segmenter.h b/csrc/fusion_segmenter.h index 6d6306ded2d..ace8a0bfb02 100644 --- a/csrc/fusion_segmenter.h +++ b/csrc/fusion_segmenter.h @@ -151,8 +151,10 @@ class SegmentedGroup { std::optional> getMaybeHeuristicParams( SchedulerRuntimeInfo& runtime_info); - //! Query if this is a group for a fusion input - bool isFusionInputGroup() const; + //! Get the SegmentedFusion this group belongs to + const SegmentedFusion* segmentedFusion() const { + return segmented_fusion_; + } public: //! "Ancestor nodes", towards inputs of segmentedDAG @@ -192,9 +194,6 @@ class SegmentedGroup { //! Theorem 4.2 int level_ = -1; - //! traversal marker, has this node already been processed - bool visited_ = false; - //! Did we select another group to merge with SegmentedGroup* merge_with_ = nullptr; @@ -205,13 +204,6 @@ class SegmentedGroup { bool merged_ = false; private: - //! Utility to convert edge vector to value vector - std::vector edgesToVals(const std::vector& se_v); - - //! Reset method to call at begining of each - //! merge node iteration - void clearTraversalInfo(); - //! To be called at the very end of segment fusion //! no more segment merging should be done beyond void finalize(); @@ -333,6 +325,13 @@ class SegmentedFusion { //! API for adding edges SegmentedEdge* newEdge(SegmentedGroup* from, SegmentedGroup* to, Val* val); + //! Remove an edge from the segmented fusion graph and update all affected + //! groups The edge object will be deleted and should not be used after this + //! call + void removeEdge(SegmentedEdge* edge); + + void connectGroups(SegmentedGroup* from, SegmentedGroup* to, Val* val); + HeuristicDataCache* getCachedHeuristicDataFor(SegmentedGroup* group); //! Lower FP precision of inputs and outputs specified by the given @@ -363,6 +362,11 @@ class SegmentedFusion { //! Grab edges with val std::vector getEdgesByVal(Val* val) const; + //! Get edges between two groups + std::vector getEdgesBetween( + const SegmentedGroup* from, + const SegmentedGroup* to) const; + //! Serialize SegmentedFusion using flatbuffers flatbuffers::Offset serialize( flatbuffers::FlatBufferBuilder& builder) const; @@ -400,7 +404,6 @@ class SegmentedFusion { SegmentedGroup* makeGroup(); SegmentedGroup* makeGroup(Expr*); - SegmentedGroup* makeFusionInputGroup(); SegmentedEdge* makeEdge(SegmentedGroup* from, SegmentedGroup* to, Val* val); void cleanUnused(); std::unordered_map groups_map() const; @@ -580,7 +583,7 @@ class SegmentCandidateFinder { SegmentedGroup* group, std::vector candidates = {}); - std::unordered_set disconnectGroup(SegmentedGroup* group); + void disconnectGroup(SegmentedGroup* group); std::vector& groups() { NVF_ERROR( @@ -701,7 +704,6 @@ class SegmentCandidateFinder { SegmentCandidateFinderOptions options_; std::unordered_set clean_up_groups_; - std::unordered_set clean_up_edges_; std::vector to_merge_; From da72caeb2bcbde061d24cfaee1f8e0cc916125cf Mon Sep 17 00:00:00 2001 From: Liqiang Lu <116412316+liqiangxl@users.noreply.github.com> Date: Mon, 21 Apr 2025 13:44:52 -0400 Subject: [PATCH 32/68] Add separate files for mutil-wave and tma approaches (#4265) Separate mutil-wave and tma approaches to different files. **Newly added files:** (1) `normalization_inner_outer_multi_wave.h/cpp`, heuristics and scheduling for multi-wave approach (2) `normalization_inner_outer_tma_ws.h/cpp`, heuristics and scheduling for TMA warp-specialized approach (3) `normalization_inner_outer_utils.h/cpp`, shared utils **Revised files:** (1) `normalization_inner_outer.h/cpp`, it servers as interface of the inner outer persistent scheduler and get heuristics and schedule the fusion using `multi-wave` or `TMA warp-specialized ` approach. --- CMakeLists.txt | 3 + csrc/scheduler/normalization_inner_outer.cpp | 1714 +---------------- .../normalization_inner_outer_multi_wave.cpp | 717 +++++++ .../normalization_inner_outer_multi_wave.h | 31 + .../normalization_inner_outer_tma_ws.cpp | 647 +++++++ .../normalization_inner_outer_tma_ws.h | 31 + .../normalization_inner_outer_utils.cpp | 301 +++ .../normalization_inner_outer_utils.h | 98 + 8 files changed, 1849 insertions(+), 1693 deletions(-) create mode 100644 csrc/scheduler/normalization_inner_outer_multi_wave.cpp create mode 100644 csrc/scheduler/normalization_inner_outer_multi_wave.h create mode 100644 csrc/scheduler/normalization_inner_outer_tma_ws.cpp create mode 100644 csrc/scheduler/normalization_inner_outer_tma_ws.h create mode 100644 csrc/scheduler/normalization_inner_outer_utils.cpp create mode 100644 csrc/scheduler/normalization_inner_outer_utils.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 8ea2e6a31db..e83f7a13def 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -239,6 +239,9 @@ list(APPEND NVFUSER_SRCS ${NVFUSER_SRCS_DIR}/scheduler/communication.cpp ${NVFUSER_SRCS_DIR}/scheduler/normalization_inner.cpp ${NVFUSER_SRCS_DIR}/scheduler/normalization_inner_outer.cpp + ${NVFUSER_SRCS_DIR}/scheduler/normalization_inner_outer_utils.cpp + ${NVFUSER_SRCS_DIR}/scheduler/normalization_inner_outer_tma_ws.cpp + ${NVFUSER_SRCS_DIR}/scheduler/normalization_inner_outer_multi_wave.cpp ${NVFUSER_SRCS_DIR}/scheduler/normalization_outer.cpp ${NVFUSER_SRCS_DIR}/scheduler/normalization_utils.cpp ${NVFUSER_SRCS_DIR}/scheduler/pointwise.cpp diff --git a/csrc/scheduler/normalization_inner_outer.cpp b/csrc/scheduler/normalization_inner_outer.cpp index 4e9f2f93bfe..8264f046382 100644 --- a/csrc/scheduler/normalization_inner_outer.cpp +++ b/csrc/scheduler/normalization_inner_outer.cpp @@ -6,1065 +6,18 @@ */ // clang-format on #include -#include #include -#include +#include +#include +#include #include -#include #include #include -#include -#include #include namespace nvfuser { namespace { - -// The roundup is due to the fact that the shared memory buffer is allocated -// as: ceilDiv(dim_size / vectorize_factor, threads_per_block). -// Let after_vect = dim_size / vectorize_factor; -// n_batch = ceilDiv(after_vect, threads_per_block); -// Then the shared memory buffer size is n_batch * vectorize_factor * -// threads_per_block * data_type_size. This function returns the maximum -// possible shared memory buffer size considering all possible block sizes. -int64_t roundUpSharedMemory( - int64_t tv_buffer_size, - int64_t data_type_size, - int64_t vectorize_factor, - int64_t threads_per_block_min, - int64_t threads_per_block_max, - int64_t threads_per_block_step) { - int64_t dim_size = tv_buffer_size / data_type_size; - int64_t after_vect = dim_size / vectorize_factor; - int64_t max_smem = 0; - for (int64_t threads_per_block = threads_per_block_min; - threads_per_block <= threads_per_block_max; - threads_per_block += threads_per_block_step) { - int64_t n_batch = ceilDiv(after_vect, threads_per_block); - max_smem = std::max( - max_smem, - n_batch * vectorize_factor * threads_per_block * data_type_size); - } - return max_smem; -} - -// Return the broadcast tvs that are broadcast to the iteration dimensions of -// the inner reduction tv. These tvs are reused in the loop over the iteration -// dimension. This reuse reduced the number loads from gmem and this tensor -// is likely the first candidate to be moved to shared memory when the register -// space runs low. -std::vector getOuterBroadcastTvs( - Fusion* fusion, - const std::vector& reduction_tvs) { - // set reference broadcast mask using the first inner reduction tv - std::vector ref_broadcast_mask; - for (auto tv : reduction_tvs) { - if (scheduler_utils::isFastestDimReduction(tv)) { - const auto& logical = tv->getLogicalDomain(); - ref_broadcast_mask.reserve(logical.size()); - for (const auto i : arange(logical.size())) { - ref_broadcast_mask.push_back(!logical.at(i)->isReduction()); - } - break; - } - } - NVF_ERROR(!ref_broadcast_mask.empty(), "ref_broadcast_mask is empty!"); - - // find the broadcast tensor whose broadcast mask is same to the reference - std::vector outer_broadcast_tvs; - for (auto tv : fusion->allTvs()) { - if (std::any_of( - tv->getLoopDomain().begin(), - tv->getLoopDomain().end(), - [](IterDomain* id) { return id->isBroadcast(); })) { - if (auto bcast = dynamic_cast(tv->definition())) { - if (bcast->getBroadcastDimFlags() == ref_broadcast_mask) { - outer_broadcast_tvs.emplace_back(tv); - } - } - } - } - return outer_broadcast_tvs; -} - -// Size of buffers storing intermediate outer reduction results -// TODO: check if we can directly start with [buffer_size = 1] -int64_t partialOuterReductionBufferSize( - const std::vector& reduction_tvs, - SchedulerRuntimeInfo& runtime_info) { - int64_t partial_reduction_buffer_size = 0; - for (auto buffer : reduction_tvs) { - if (scheduler_utils::isFastestDimReduction(buffer)) { - continue; - } - int64_t buffer_size = -1; - for (auto id : buffer->getLogicalDomain()) { - if (id->isReduction() || id->isBroadcast()) { - continue; - } - auto id_size = runtime_info.expressionEvaluator().evaluate(id->extent()); - NVF_ERROR(id_size.hasValue(), "Could not infer persistent buffer size."); - if (buffer_size == -1) { - buffer_size = id_size.as(); - } else { - buffer_size *= id_size.as(); - } - } - buffer_size = (buffer_size == -1) ? 0 - : buffer_size * - (int64_t)dataTypeSize(buffer->getDataType().value(), - runtime_info.getIndexType()); - partial_reduction_buffer_size += buffer_size; - } - return partial_reduction_buffer_size; -} - -// Decide where to store persistent buffers. -// By default, they reside in registers. -// If register space runs low but there's ample shared memory, -// move one or more buffers to shared memory until the register space is -// sufficient. -struct PersistentBufferStorageParams { - // representing buffers that are stored in shared memory, other buffers are - // stored in registers. - std::vector smem_persistent_buffers; - - // Total number of bytes occupied by all persistent buffers stored in shared - // memory. - int64_t smem_buffer_size = -1; - - // Total number of bytes occupied by all persistent buffers stored in - // registers. - int64_t regs_buffer_size = -1; - - // Additional shared memory usage per block that is not associated with - // persistent buffers. This includes memory for driver overhead and workspace - // for reductions. - int64_t smem_overhead = -1; - - // Flag indicating whether there are sufficient registers and shared memory - // available to accommodate all persistent buffers as required for efficient - // execution. - bool has_enough_regs_and_smem = false; - - // Flag indicating whether the persistent buffers are recomputed using inputs. - bool project_to_input = false; -}; - -// Prioritize keeping buffers used by outer broadcast tensors to shared memory -// because: -// (1) They are reused in every iteration of the outer loop, has lower IO. -// (2) Load occurs before the outer loop. Temporary register usage won't -// increase register pressure since the loop is the high-pressure region. -std::vector sortProjectableBufferInputs( - const std::vector& projectable_buffer_inputs, - const std::vector& outer_broadcast_tvs) { - // mark whether the buffer is used by outer broadcast tensors - std::unordered_map is_used_by_outer_bcast; - for (auto buffer : projectable_buffer_inputs) { - is_used_by_outer_bcast[buffer] = std::any_of( - outer_broadcast_tvs.begin(), - outer_broadcast_tvs.end(), - [&buffer](TensorView* tv) { - return DependencyCheck::isDependencyOf(buffer, tv); - }); - } - - // sort based on [is_used_by_outer_bcast] - std::vector sorted_buffer = projectable_buffer_inputs; - std::sort( - sorted_buffer.begin(), - sorted_buffer.end(), - [&](TensorView* a, TensorView* b) { - return !is_used_by_outer_bcast[a] && is_used_by_outer_bcast[b]; - }); - return sorted_buffer; -} - -PersistentBufferStorageParams getPersistentBufferStorageParams( - Fusion* fusion, - SchedulerRuntimeInfo& runtime_info, - HeuristicDataCache* data_cache, - const std::vector& reduction_tvs, - const int64_t vectorize_factor, - const int64_t threads_per_block_min, - const int64_t threads_per_block_max) { - FUSER_PERF_SCOPE( - "normalization_inner_outer::getPersistentBufferStorageParams"); - - PersistentBufferStorageParams buffer_params; - - auto persistent_buffer_info_entry = - HeuristicDataCacheEntry( - data_cache, [&fusion]() { - return std::make_unique( - scheduler_utils::persistentBuffers(fusion)); - }); - - auto& persistent_buffer_info = persistent_buffer_info_entry.get(); - - auto persistent_buffer_size_info = scheduler_utils::persistentBufferSize( - fusion, runtime_info, persistent_buffer_info, data_cache); - - // Project to inputs when there is at least one outer broadcast tensor or - // projected persistent buffer size is smaller. When projecting to inputs, the - // outer broadcast tensor is reused in the loop over the iteration dimension, - // test shows it is faster than the non-projected version which requires - // reload from gmem for each iteration. - // Note: in current use cases (layer norm bwd and RMS norm bwd), there are - // outer broadcast tvs and always project to inputs. - // Warp specialized persistent kernel always cache inputs in shared memory, - // should project to inputs. - const auto& outer_broadcast_tvs = getOuterBroadcastTvs(fusion, reduction_tvs); - bool skip_check_buffer_size = !outer_broadcast_tvs.empty() || - isOptionEnabled(EnableOption::WarpSpecializedNormalization); - normalization_scheduler_utils::BufferProjectionStrategy project_strategy = - normalization_scheduler_utils::isProjectBufferToInputs( - fusion, - runtime_info, - reduction_tvs, - persistent_buffer_info, - persistent_buffer_size_info, - InnerOuterPersistentKernelScheduler::schedulerType(), - /*can_use_smem_persistent=*/true, - !skip_check_buffer_size); - - buffer_params.project_to_input = - (project_strategy == - normalization_scheduler_utils::BufferProjectionStrategy:: - ProjectToInputs); - - const auto dev_prop = at::cuda::getCurrentDeviceProperties(); - int64_t smem_overhead = scheduler_utils::getSharedMemoryOverheadPerBlock( - fusion, reduction_tvs, threads_per_block_max); - int64_t available_smem = - (int64_t)dev_prop->sharedMemPerMultiprocessor - smem_overhead; - int64_t available_regs = scheduler_utils::register_file_size_56k; - buffer_params.smem_overhead = smem_overhead; - - // (1) Use both register and shared memory. - // Start with all the cached input buffers in shared memory, they are loaded - // from global memory uses async copy which bypasses L1 cache. Outer reduction - // buffers are used to accumulate partial results of the outer reduction. They - // are not loaded from global memory and requires frequent read/write. So, - // they are always stored in registers. - // TODO: We may also move outer reduction buffers to shared - // memory to avoid segmentation when there are many outer reductions and - // hardware has larger shared memory, but these applications are rare, so this - // is not considered here. - auto buffers = buffer_params.project_to_input - ? persistent_buffer_info.projectable_buffer_inputs - : persistent_buffer_info.persistent_buffers; - - // Add buffers that are inputs to the fusion. They are not included in - // projectable_buffer_inputs since they are not projectable. - if (buffer_params.project_to_input) { - for (auto tv : persistent_buffer_info.persistent_buffers) { - if (tv->isFusionInput()) { - buffers.push_back(tv); - } - } - } - - // Needs to use rounded shared memory size to avoid over usage. - // key : buffer tv. - // val : register size and rounded shared memory size - std::unordered_map> - required_size_regs_smem_map; - int64_t total_smem_buffer_size = 0; - for (auto buffer : buffers) { - int64_t buffer_size_regs = scheduler_utils::getPersistentBufferSizeOfTensor( - buffer, runtime_info, persistent_buffer_info); - int64_t buffer_size_smem = roundUpSharedMemory( - buffer_size_regs, - dataTypeSize(buffer->getDataType().value()), - vectorize_factor, - threads_per_block_min, - threads_per_block_max, - dev_prop->warpSize); - required_size_regs_smem_map[buffer] = - std::make_pair(buffer_size_regs, buffer_size_smem); - total_smem_buffer_size += buffer_size_smem; - } - buffer_params.smem_buffer_size = total_smem_buffer_size; - buffer_params.regs_buffer_size = - partialOuterReductionBufferSize(reduction_tvs, runtime_info); - if (buffer_params.regs_buffer_size <= available_regs && - buffer_params.smem_buffer_size <= available_smem) { - buffer_params.smem_persistent_buffers = buffers; - buffer_params.has_enough_regs_and_smem = true; - return buffer_params; - } - - // Moving outer reduction buffer to shared memory is not considered yet, - // set to false if the outer reduction buffer size exceeds the register size. - if (buffer_params.regs_buffer_size > available_regs) { - buffer_params.has_enough_regs_and_smem = false; - return buffer_params; - } - - // (2) Now, shared memory is overused, move some buffers to registers. - // (2.1) Sort the candidate persistent buffers. No need to sort since the - // sorting is based on whether the buffer is used by outer broadcast tensors. - if (!outer_broadcast_tvs.empty()) { - buffers = sortProjectableBufferInputs(buffers, outer_broadcast_tvs); - } - // (2.2) Before this loop, all cached input buffers are in shared memory. Move - // buffer from shared memory to register. - int64_t n_regs_buffer = -1; - const int n_buffers = (int)buffers.size(); - for (int i = 0; i < n_buffers; i++) { - auto current_tv = buffers[i]; - auto [buffer_size_regs, buffer_size_smem] = - required_size_regs_smem_map.at(current_tv); - buffer_params.regs_buffer_size += buffer_size_regs; - buffer_params.smem_buffer_size -= buffer_size_smem; - - // The first-i buffers to are moved from shared memory to register - // If both the register buffer size and shared memory buffer size are within - // the allowable limit, we found a good configuration. - if (buffer_params.regs_buffer_size <= available_regs && - buffer_params.smem_buffer_size <= available_smem) { - n_regs_buffer = i + 1; - break; - } - // Register buffer size exceeds the limit, can't move more to registers. - // Break the loop. - if (buffer_params.regs_buffer_size > available_regs) { - break; - } - } - - // n_regs_buffer > 0 indicats a good configuration is found. - // The first n_regs_buffer buffers are stored in registers and last [n_buffers - // - n_regs_buffer] are stored in shared memory. - if (n_regs_buffer > 0) { - buffer_params.has_enough_regs_and_smem = true; - auto n_smem_buffer = n_buffers - n_regs_buffer; - buffer_params.smem_persistent_buffers.reserve(n_smem_buffer); - for (int i = 0; i < n_smem_buffer; i++) { - buffer_params.smem_persistent_buffers.emplace_back( - buffers[n_buffers - 1 - i]); - } - } else { - buffer_params.has_enough_regs_and_smem = false; - } - return buffer_params; -} - -// The innerOuterPersistentHeuristic is tuned for layer_norm backward on A100 -// ======= Method if hidden_size > 1024 ======= -// (1) Inner reduction is one reduction per block. Reduction domain is -// parallelized by TIDx and TIDy, Iteration domain is parallelized by BIDy. -// (2) Outer reduction is done in two-steps. The first step is partial -// reduction, reduction domain is parallelized by BIDy, iteration domain is -// parallelized by TIDx and TIDy. The partial results are written to gmem -// followed by a grid sync. The second step is block reduction, the reduction -// domain is parallelized by TIDy, the iteration domain is parallelized by TIDx -// and BIDy. -// ======= Method if hidden_size <= 1024 ======= -// (1) Inner reduction is multi-reductions per blocks. Reduction domain is -// parallelized by TIDx, Iteration domain is parallelized by BIDy and TIDy. -// (2) Outer reduction is same to cases where hidden_size > 1024 except the -// second step where in this case, the reduction domain is parallelized by TIDx -// and the iteration domain is parallelized by TIDy and BIDy. This switch -// between TIDx and TIDy is because: -// (a) We can do warp reduction with TIDx -// (b) TIDx*BIDy is usually much larger than hidden_size, e.g. 128*216 = 1024*27 -// this means without switch only 1/27 of the threads is used. -std::unique_ptr innerOuterPersistentHeuristic( - const int64_t outer_dim_numel, - const int64_t inner_dim_numel, - const int64_t regs_buffer_size, - const int64_t smem_buffer_size, - const int64_t smem_overhead, - const size_t tmp_gmem_dtype_size, - const size_t vectorize_factor, - const int64_t hp_threads_per_block_min, - const int64_t hp_threads_per_block_max, - const bool project_to_input, - const PrimDataType index_type) { - auto rparams = std::make_unique( - InnerOuterPersistentKernelScheduler::schedulerType()); - rparams->project_persistent_buffers = project_to_input; - rparams->cparams.index_type = index_type; - const auto dev_prop = at::cuda::getCurrentDeviceProperties(); - const int64_t device_multiprocessor_count = - (int64_t)dev_prop->multiProcessorCount; - // Parameters for inner reduction: - // Reduction dim: inner_vect, inner_batch, bdimx and bdimy - // Iteration dim: gdimy - - // Parameters for outer reduction: - // Reduction dim: bdimy - // Iteration dim: vectorization_factor_outer, bdimx, gdimy - struct InnerOuterParams { - int64_t inner_vect = -1; - int64_t inner_batch = -1; - int64_t bdimx = -1; - int64_t bdimy = -1; - int64_t bdimz = -1; - int64_t gdimy = -1; - int64_t tmp_gmem_write_vect = -1; - int64_t vectorization_factor_outer = -1; - int64_t threads_per_block = -1; - // derived metrics for sorting - int64_t warps_per_sm = -1; - int64_t required_register_per_thread = -1; - int64_t available_register_per_thread = -1; - - void verify() { - NVF_ERROR(inner_vect != -1, "inner_vect is not set."); - NVF_ERROR(inner_batch != -1, "inner_batch is not set."); - NVF_ERROR(bdimx != -1, "bdimx is not set."); - NVF_ERROR(bdimy != -1, "bdimy is not set."); - NVF_ERROR(gdimy != -1, "gdimy is not set."); - NVF_ERROR(tmp_gmem_write_vect != -1, "tmp_gmem_write_vect is not set."); - NVF_ERROR( - vectorization_factor_outer != -1, - "vectorization_factor_outer is not set."); - } - std::string toString() const { - std::stringstream ss; - ss << "inner_vect: " << inner_vect << ", inner_batch: " << inner_batch - << ", bdimx: " << bdimx << ", bdimy: " << bdimy << ", bdimz: " << bdimz - << ", gdimy: " << gdimy - << ", tmp_gmem_write_vect: " << tmp_gmem_write_vect - << ", vectorization_factor_outer: " << vectorization_factor_outer - << ", threads_per_block: " << threads_per_block - << ", warps_per_sm: " << warps_per_sm - << ", required_register_per_thread: " << required_register_per_thread - << ", available_register_per_thread: " - << available_register_per_thread; - return ss.str(); - } - }; - - // Set a minimum workload for each thread to take advantage of low - // intra-threads communication cost. - // Tuned for layer_norm backward on A100, still works fine on H100. - auto getMinimumBatch = [&]() -> int64_t { - if (inner_dim_numel >= 3072l) { - if (outer_dim_numel <= 2048l && inner_dim_numel == 3072l) { - return 3l; - } else { - return 4l; - } - } else if (inner_dim_numel >= 2048l) { - return 2l; - } - return 1l; - }; - - // Estimate register usage per thread based on buffer size. - // Assuming a constant register overhead for non-buffer related usage, - // and all the register buffers are stored in registers. - auto getEstimatedRegisterUsage = [&](int64_t batch_mul_vect) { - int64_t persistent_buffer_size = - regs_buffer_size / inner_dim_numel * batch_mul_vect; - int64_t estimated_register_count = - persistent_buffer_size / scheduler_utils::bytes_per_register + - scheduler_utils::register_overhead; - return std::min( - estimated_register_count, scheduler_utils::max_registers_per_thread); - }; - - // Estimate max blocks per sm based on register and shared memory usage. - auto getBlocksPerSM = [&](const int64_t threads_per_sm, - const int64_t threads_per_block, - const int64_t warp_size) { - // check register limitation on blocks per sm - constexpr int64_t warp_allocation_granularity = 4; - const int64_t allocated_warps_per_block = - ceilDiv( - ceilDiv(threads_per_block, warp_size), - warp_allocation_granularity) * - warp_allocation_granularity; - int64_t max_blocks_per_sm_regs = scheduler_utils::safeDiv( - threads_per_sm / warp_size, allocated_warps_per_block); - // check shared memory limitation on blocks per sm - int64_t max_blocks_per_sm_smem = - (int64_t)dev_prop->sharedMemPerMultiprocessor / - (smem_overhead + smem_buffer_size); - return std::min(max_blocks_per_sm_regs, max_blocks_per_sm_smem); - }; - - // In the inner reduction part of the kernel, gdimy is used to parallelize the - // outer dimension. The kernel is a cooperative kernel, so the number of - // blocks should be as large as possible to achieve a high occupancy unless - // outer dim is too small which may lead large workload for the final outer - // reduction. So, gdimy is drvied from the number of blocks per sm and limited - // to ensure at least 8 rows per block. - // TODO: re-evaluate this 8 rows per block requirement. - auto getGdimy = [&](int64_t inner_vect, - int64_t threads_per_block, - int64_t inner_batch) { - int64_t reg_per_thread = - getEstimatedRegisterUsage(inner_vect * inner_batch); - int64_t threads_per_sm = getThreadsPerSMGivenRegPerThread(reg_per_thread); - int64_t blocks_per_sm = - getBlocksPerSM(threads_per_sm, threads_per_block, dev_prop->warpSize); - int64_t gdimy = blocks_per_sm * device_multiprocessor_count; - const int64_t outer_iter_min = 8; - const int64_t gdimy_max = scheduler_utils::roundUpToN( - ceilDiv(outer_dim_numel, outer_iter_min), device_multiprocessor_count); - while (gdimy > gdimy_max && blocks_per_sm > 1) { - blocks_per_sm -= 1; - gdimy = blocks_per_sm * device_multiprocessor_count; - } - return gdimy; - }; - - // The inner reduction part of the kernel also does a partial outer reduction - // and stores the partial results in tmp gmem and then reloaded to finish the - // outer reduciton. This function set the vectorization factor for write and - // and read of the partial outer reduction result. - // For write to tmp gmem, follows vectorization factor of inner reduction - // but don't exceed 16 bytes. - // For read from tmp gmem, since the paralelization is changed, a different - // vectorization factor is used to optimize the - // number of reaductions per thread. - auto getOuterReductionBufferVectFactor = [&](int64_t inner_vect) { - constexpr int64_t max_gmem_vect_access_bytes = 16; - const int64_t max_tmp_gmem_vect_factor = std::min( - max_gmem_vect_access_bytes / (int64_t)tmp_gmem_dtype_size, inner_vect); - int64_t tmp_gmem_write_vect = max_tmp_gmem_vect_factor; - const int64_t workload_per_thread = inner_dim_numel >= 4096 ? 4l : 2l; - int64_t vectorization_factor_outer = - std::min(workload_per_thread, max_tmp_gmem_vect_factor); - return std::make_pair(tmp_gmem_write_vect, vectorization_factor_outer); - }; - - // In the outer reduction part of the kernel, inner and outer dims are - // parallelized as: - // --- inner dim: vect, bdimx, gdimy ---- - // --- outer dim: bdimy ----------------- - // This function splits the threads_per_block into bdimx and bdimy using: - // bdimx = ceilDiv(inner_dim_numel / vect, gdimy) - // bdimy = threads_per_block / bdimx - auto getBdimxBdimy = [&](int64_t threads_per_block, - int64_t vectorization_factor_outer, - int64_t gdimy) { - // For widely used hidden sizes, threads_per_block has factor of 8, roundup - // to increase the probability of bdimx * bdimy == threads_per_block. - int64_t bdimx = scheduler_utils::roundUpPow2Or8( - ceilDiv(inner_dim_numel / vectorization_factor_outer, gdimy)); - // if still not divisible, e.g. threads_per_block = 256, bdimx = 40. - // increase bdimx to make it divisible. Under worst case, bdimx equals to - // threads_per_block. - while (threads_per_block % bdimx) { - bdimx = std::min(bdimx + 8, threads_per_block); - } - // Set OuterParams Reduction dim: bdimy. - int64_t bdimy = threads_per_block / bdimx; - NVF_ERROR( - bdimy * bdimx == threads_per_block, - " threads_per_block must be divisible by bdimx and bdimy."); - return std::make_pair(bdimx, bdimy); - }; - - // Get the heuristics given vectorization factor and threads per block - auto getHeuristicsGivenVectThreads = [&](int64_t vect_factor, - int64_t threads_per_block) { - InnerOuterParams iop; - // (1) inner reduction - // Reduction dim: inner_batch, threads_per_block, vect_factor - // Iteration dim: gdimy - iop.inner_vect = vect_factor; - iop.threads_per_block = threads_per_block; - iop.inner_batch = - ceilDiv(inner_dim_numel / iop.inner_vect, iop.threads_per_block); - iop.gdimy = - getGdimy(iop.inner_vect, iop.threads_per_block, iop.inner_batch); - // (2) outer reduction - // Iteration dim: gdimy, bdimx, vectorization_factor_outer - // Reduction dim: bdimy - std::tie(iop.tmp_gmem_write_vect, iop.vectorization_factor_outer) = - getOuterReductionBufferVectFactor(iop.inner_vect); - auto [bdimx, bdimy] = getBdimxBdimy( - threads_per_block, iop.vectorization_factor_outer, iop.gdimy); - iop.bdimx = bdimx; - iop.bdimy = bdimy; - // (3) Derived metrics warps_per_sm and register usage for sorting - iop.warps_per_sm = ceilDiv(iop.threads_per_block, dev_prop->warpSize) * - iop.gdimy / device_multiprocessor_count; - iop.available_register_per_thread = - getRegPerThreadGivenThreadsPerSM(dev_prop->warpSize * iop.warps_per_sm); - iop.required_register_per_thread = - getEstimatedRegisterUsage(iop.inner_vect * iop.inner_batch); - return iop; - }; - - // Use the maximum vectorization factor - const int64_t vect_factor = (int64_t)vectorize_factor; - - // Set a reasonable range for threads per block based on the number of - // elements in the inner dimension after vectorization. - // Start from 128 or a smaller number if inner dim is small. - const int64_t after_vect = inner_dim_numel / vect_factor; - const int64_t batch_min = getMinimumBatch(); - int64_t threads_per_block_min = hp_threads_per_block_min; - threads_per_block_min = std::min(threads_per_block_min, after_vect); - threads_per_block_min = scheduler_utils::roundUpPow2(threads_per_block_min); - - // star max threads per block from min threads per block - int64_t threads_per_block_max = threads_per_block_min; - // increase to cover the whole inner dim - threads_per_block_max = - std::max(threads_per_block_max, ceilDiv(after_vect, batch_min)); - // round up to power of 2 - threads_per_block_max = scheduler_utils::roundUpPow2(threads_per_block_max); - // don't go beyond the maximum threads per block - threads_per_block_max = - std::min(threads_per_block_max, hp_threads_per_block_max); - - // Store all the possible heuristics based on different threads per block. - // Vectorizaton is fixed at the maximum value. - std::vector iop_candidates; - for (auto threads_per_block = threads_per_block_max; - threads_per_block >= threads_per_block_min; - threads_per_block /= 2) { - iop_candidates.emplace_back( - getHeuristicsGivenVectThreads(vect_factor, threads_per_block)); - } - - // Sort the heuristics based on the register usage and occupancy. - std::stable_sort( - iop_candidates.begin(), - iop_candidates.end(), - [](const InnerOuterParams& a, const InnerOuterParams& b) { - // If a thread can use more registers than required, there is a high - // chance that it can avoid register spilling and compiler can optimize - // for better instruction level parallelism. - int64_t extra_regs_a = - a.available_register_per_thread - a.required_register_per_thread; - int64_t extra_regs_b = - b.available_register_per_thread - b.required_register_per_thread; - if (extra_regs_a > 0 && extra_regs_b < 0) { - return true; - } else if (extra_regs_a < 0 && extra_regs_b > 0) { - return false; - } - // High occupancy provides better threads level parallelism. - // 25% is sufficient since ILP is high due to persistent batch sizes - // which is equivalent to unrolling inner dim. - if (a.warps_per_sm != b.warps_per_sm && - (a.warps_per_sm < 16 || b.warps_per_sm < 16)) { - return a.warps_per_sm > b.warps_per_sm; - } - // Tie breaker, smaller threads_per_block to reduce communication - // overhead - return a.threads_per_block < b.threads_per_block; - }); - - // Pick the best heuristic - auto iop = iop_candidates.front(); - - // Special case, when inner_dim_numel <= 1024, bdimx is usually small - // after divide by inner_vect and inner_batch. In this case, bdimy is used to - // parallelize outer_dim instead of inner_dim. This pattern is named multi - // reductions per block (mrpb). - if (inner_dim_numel <= 1024) { - rparams->multiple_reds_per_blk = true; - rparams->tidx_for_outer_reduction = true; - - // Step-1, InnerParams, Reduction dim: inner_vect(reuse), - // inner_batch(reuse), bdimx - iop.bdimx = ceilDiv(inner_dim_numel, iop.inner_vect * iop.inner_batch); - - // Step-2, InnerParams, Iteration dim: gdimy, bdimy (in next step) - iop.gdimy = getGdimy(iop.inner_vect, iop.bdimx, iop.inner_batch); - - // Step-3, OuterParams, Iteration dim: vectorization_factor_outer(reuse), - // bdimy, gdimy (in previous step). - // WAR for https://github.com/NVIDIA/Fuser/issues/3428 - iop.bdimy = 1; - - // Step-4, OuterParams, Reduction dim: bdimx (already done) - iop.warps_per_sm = ceilDiv(iop.bdimx * iop.bdimy, dev_prop->warpSize) * - iop.gdimy / device_multiprocessor_count; - iop.available_register_per_thread = - getRegPerThreadGivenThreadsPerSM(dev_prop->warpSize * iop.warps_per_sm); - - if (iop.bdimx % dev_prop->warpSize == 0) { - rparams->pad_inner_reduction_to_warp = true; - rparams->pad_outer_reduction_to_warp = true; - } - rparams->block_dim_iter_dom = ParallelType::TIDy; - rparams->combined_split_grid_inner_dim = - iop.vectorization_factor_outer * iop.bdimy * iop.gdimy < - inner_dim_numel; - } else { - rparams->block_dim_inner_reduction_extra = ParallelType::TIDy; - rparams->combined_split_grid_inner_dim = - iop.vectorization_factor_outer * iop.bdimx * iop.gdimy < - inner_dim_numel; - rparams->static_bdimx = true; - rparams->static_bdimy = true; - iop.bdimz = ceilDiv( - ceilDiv( - ceilDiv(inner_dim_numel / iop.inner_vect, iop.bdimx), iop.bdimy), - iop.inner_batch); - NVF_ERROR(iop.bdimz == 1, "bdimz must be 1."); - } - - // check all the parameters in InnerOuterParams are set. - iop.verify(); - - rparams->persistent_kernel = true; - rparams->fastest_dim = true; - rparams->combined_inner_outer = true; - // tmp_gmem is the intermediate result of outer reduction, its dtype is float, - // so the maximum vectorization factor is 4. - rparams->vectorization_factor_outer = iop.vectorization_factor_outer; - rparams->vectorization_factor_tmp_gmem_write = iop.tmp_gmem_write_vect; - rparams->cparams.maxrregcount = iop.available_register_per_thread; - rparams->unroll_factor_inner_reduction = iop.inner_vect; - rparams->batches_per_block_inner_reduction = iop.inner_batch; - rparams->block_dim_inner_reduction = ParallelType::TIDx; - rparams->vectorize_inner_reduction = iop.inner_vect > 1; - rparams->split_grid_dim_iter_dom_outer = true; - rparams->grid_dim_iter_dom = ParallelType::BIDy; - - rparams->lparams = LaunchParams( - LaunchParams::UNINITIALIZED_VAL, - iop.gdimy, - LaunchParams::UNINITIALIZED_VAL, - iop.bdimx, - iop.bdimy, - LaunchParams::UNINITIALIZED_VAL); - - if (!rparams->smem_persistent_buffers.empty()) { - rparams->tag = - "InnerOuter Register and Shared Memory Persistent Heuristic.\n"; - } else { - rparams->tag = "InnerOuter Register Persistent Heuristic.\n"; - } - - if (isDebugDumpEnabled(DebugDumpOption::SchedulerDebug)) { - debug() << "\n===== Combined InnerOuter Reduction Stats ========\n" - << "outer_dim_numel: " << outer_dim_numel << "\n" - << "inner_dim_numel: " << inner_dim_numel << "\n" - << "regs_buffer_size: " << regs_buffer_size << "\n" - << "smem_buffer_size: " << smem_buffer_size << "\n" - << "smem_overhead: " << smem_overhead << "\n" - << "vectorize_factor_input: " << iop.inner_vect << "\n" - << "vectorization_factor_tmp_gmem_write: " - << iop.tmp_gmem_write_vect << "\n" - << "vectorization_factor_outer: " << iop.vectorization_factor_outer - << "\n" - << "multiple_reds_per_blk: " << rparams->multiple_reds_per_blk - << "\n" - << "warps_per_sm: " << iop.warps_per_sm << "\n" - << "gdimy: " << iop.gdimy << "\n" - << "block(" << (iop.bdimx) << ", " << iop.bdimy << ", " << 1 << ")"; - debug() << rparams->toString() << std::endl; - } - return rparams; -} - -std::unique_ptr innerOuterWarpSpecializedTmaHeuristic( - const int64_t outer_dim_numel, - const int64_t inner_dim_numel, - const int64_t regs_buffer_size, - const int64_t smem_buffer_size, - const int64_t smem_overhead, - const size_t tmp_gmem_dtype_size, - const size_t vectorize_factor, - const int64_t hp_threads_per_block_min, - const int64_t hp_threads_per_block_max, - const bool project_to_input, - const PrimDataType index_type) { - auto rparams = std::make_unique( - InnerOuterPersistentKernelScheduler::schedulerType()); - rparams->project_persistent_buffers = project_to_input; - rparams->cparams.index_type = index_type; - const auto dev_prop = at::cuda::getCurrentDeviceProperties(); - const int64_t device_multiprocessor_count = - (int64_t)dev_prop->multiProcessorCount; - // Parameters for inner reduction: - // Reduction dim: inner_vect, inner_batch, bdimx and bdimy - // Iteration dim: gdimy - - // Parameters for outer reduction: - // Reduction dim: bdimy - // Iteration dim: vectorization_factor_outer, bdimx, gdimy - struct InnerOuterParams { - int64_t inner_vect = -1; - int64_t inner_batch = -1; - int64_t bdimx = -1; - int64_t bdimy = -1; - int64_t bdimz = -1; - int64_t gdimy = -1; - int64_t tmp_gmem_write_vect = -1; - int64_t vectorization_factor_outer = -1; - int64_t threads_per_block = -1; - // derived metrics for sorting - int64_t warps_per_sm = -1; - int64_t required_register_per_thread = -1; - int64_t available_register_per_thread = -1; - - void verify() { - NVF_ERROR(inner_vect != -1, "inner_vect is not set."); - NVF_ERROR(inner_batch != -1, "inner_batch is not set."); - NVF_ERROR(bdimx != -1, "bdimx is not set."); - NVF_ERROR(bdimy != -1, "bdimy is not set."); - NVF_ERROR(gdimy != -1, "gdimy is not set."); - NVF_ERROR(tmp_gmem_write_vect != -1, "tmp_gmem_write_vect is not set."); - NVF_ERROR( - vectorization_factor_outer != -1, - "vectorization_factor_outer is not set."); - } - std::string toString() const { - std::stringstream ss; - ss << "inner_vect: " << inner_vect << ", inner_batch: " << inner_batch - << ", bdimx: " << bdimx << ", bdimy: " << bdimy << ", bdimz: " << bdimz - << ", gdimy: " << gdimy - << ", tmp_gmem_write_vect: " << tmp_gmem_write_vect - << ", vectorization_factor_outer: " << vectorization_factor_outer - << ", threads_per_block: " << threads_per_block - << ", warps_per_sm: " << warps_per_sm - << ", required_register_per_thread: " << required_register_per_thread - << ", available_register_per_thread: " - << available_register_per_thread; - return ss.str(); - } - }; - - // Set a minimum workload for each thread to take advantage of low - // intra-threads communication cost. - // Tuned for layer_norm backward on A100, still works fine on H100. - auto get_minimum_batch = [&]() -> int64_t { - if (inner_dim_numel >= 3072l) { - if (outer_dim_numel <= 2048l && inner_dim_numel == 3072l) { - return 3l; - } else { - return 4l; - } - } else if (inner_dim_numel >= 2048l) { - return 2l; - } - return 1l; - }; - - // Estimate register usage per thread based on buffer size. - // Assuming a constant register overhead for non-buffer related usage, - // and all the register buffers are stored in registers. - auto get_estimated_register_usage = [&](int64_t batch_mul_vect) { - int64_t persistent_buffer_size = - regs_buffer_size / inner_dim_numel * batch_mul_vect; - int64_t estimated_register_count = - persistent_buffer_size / scheduler_utils::bytes_per_register + - scheduler_utils::register_overhead; - return std::min( - estimated_register_count, scheduler_utils::max_registers_per_thread); - }; - - // The inner reduction part of the kernel also does a partial outer reduction - // and stores the partial results in tmp gmem and then reloaded to finish the - // outer reduciton. This function set the vectorization factor for write and - // and read of the partial outer reduction result. - // For write to tmp gmem, follows vectorization factor of inner reduction - // but don't exceed 16 bytes. - // For read from tmp gmem, since the paralelization is changed, a different - // vectorization factor is used to optimize the - // number of reaductions per thread. - auto get_outer_reduction_buffer_vect_factor = [&](int64_t inner_vect) { - constexpr int64_t max_gmem_vect_access_bytes = 16; - const int64_t max_tmp_gmem_vect_factor = std::min( - max_gmem_vect_access_bytes / (int64_t)tmp_gmem_dtype_size, inner_vect); - int64_t tmp_gmem_write_vect = max_tmp_gmem_vect_factor; - const int64_t workload_per_thread = inner_dim_numel >= 4096 ? 4l : 2l; - int64_t vectorization_factor_outer = - std::min(workload_per_thread, max_tmp_gmem_vect_factor); - return std::make_pair(tmp_gmem_write_vect, vectorization_factor_outer); - }; - - // In the outer reduction part of the kernel, inner and outer dims are - // parallelized as: - // --- inner dim: vect, bdimx, gdimy ---- - // --- outer dim: bdimy ----------------- - // This function splits the threads_per_block into bdimx and bdimy using: - // bdimx = ceilDiv(inner_dim_numel / vect, gdimy) - // bdimy = threads_per_block / bdimx - auto get_bdimx_bdimy = [&](int64_t threads_per_block, - int64_t vectorization_factor_outer, - int64_t gdimy) { - // For widely used hidden sizes, threads_per_block has factor of 8, roundup - // to increase the probability of bdimx * bdimy == threads_per_block. - int64_t bdimx = scheduler_utils::roundUpPow2Or8( - ceilDiv(inner_dim_numel / vectorization_factor_outer, gdimy)); - // if still not divisible, e.g. threads_per_block = 256, bdimx = 40. - // increase bdimx to make it divisible. Under worst case, bdimx equals to - // threads_per_block. - while (threads_per_block % bdimx) { - bdimx = std::min(bdimx + 8, threads_per_block); - } - // Set OuterParams Reduction dim: bdimy. - int64_t bdimy = threads_per_block / bdimx; - NVF_ERROR( - bdimy * bdimx == threads_per_block, - " threads_per_block must be divisible by bdimx and bdimy."); - return std::make_pair(bdimx, bdimy); - }; - - // Get the heuristics given vectorization factor and threads per block - auto get_heuristics_given_vect_threads = [&](int64_t vect_factor, - int64_t threads_per_block) { - InnerOuterParams iop; - // (1) inner reduction - // Reduction dim: inner_batch, threads_per_block, vect_factor - // Iteration dim: gdimy - iop.inner_vect = vect_factor; - iop.threads_per_block = threads_per_block; - iop.inner_batch = - ceilDiv(inner_dim_numel / iop.inner_vect, iop.threads_per_block); - iop.gdimy = device_multiprocessor_count; - - // (2) outer reduction - // Iteration dim: gdimy, bdimx, vectorization_factor_outer - // Reduction dim: bdimy - std::tie(iop.tmp_gmem_write_vect, iop.vectorization_factor_outer) = - get_outer_reduction_buffer_vect_factor(iop.inner_vect); - auto [bdimx, bdimy] = get_bdimx_bdimy( - threads_per_block, iop.vectorization_factor_outer, iop.gdimy); - iop.bdimx = bdimx; - iop.bdimy = bdimy; - // (3) Derived metrics warps_per_sm and register usage for sorting - iop.warps_per_sm = ceilDiv(iop.threads_per_block, dev_prop->warpSize) * - iop.gdimy / device_multiprocessor_count; - iop.available_register_per_thread = - getRegPerThreadGivenThreadsPerSM(dev_prop->warpSize * iop.warps_per_sm); - iop.required_register_per_thread = - get_estimated_register_usage(iop.inner_vect * iop.inner_batch); - return iop; - }; - - // Use the maximum vectorization factor - const int64_t vect_factor = (int64_t)vectorize_factor; - - // Set a reasonable range for threads per block based on the number of - // elements in the inner dimension after vectorization. - // Start from 128 or a smaller number if inner dim is small. - const int64_t after_vect = inner_dim_numel / vect_factor; - const int64_t batch_min = get_minimum_batch(); - int64_t threads_per_block_min = hp_threads_per_block_min; - threads_per_block_min = std::min(threads_per_block_min, after_vect); - threads_per_block_min = scheduler_utils::roundUpPow2(threads_per_block_min); - - // star max threads per block from min threads per block - int64_t threads_per_block_max = threads_per_block_min; - // increase to cover the whole inner dim - threads_per_block_max = - std::max(threads_per_block_max, ceilDiv(after_vect, batch_min)); - // round up to power of 2 - threads_per_block_max = scheduler_utils::roundUpPow2(threads_per_block_max); - // don't go beyond the maximum threads per block - threads_per_block_max = - std::min(threads_per_block_max, hp_threads_per_block_max); - - // Store all the possible heuristics based on different threads per block. - // Vectorizaton is fixed at the maximum value. - std::vector iop_candidates; - for (auto threads_per_block = threads_per_block_max; - threads_per_block >= threads_per_block_min; - threads_per_block /= 2) { - iop_candidates.emplace_back( - get_heuristics_given_vect_threads(vect_factor, threads_per_block)); - } - - // Sort the heuristics based on the register usage and occupancy. - std::stable_sort( - iop_candidates.begin(), - iop_candidates.end(), - [](const InnerOuterParams& a, const InnerOuterParams& b) { - // If a thread can use more registers than required, there is a high - // chance that it can avoid register spilling and compiler can optimize - // for better instruction level parallelism. - int64_t extra_regs_a = - a.available_register_per_thread - a.required_register_per_thread; - int64_t extra_regs_b = - b.available_register_per_thread - b.required_register_per_thread; - if (extra_regs_a > 0 && extra_regs_b < 0) { - return true; - } else if (extra_regs_a < 0 && extra_regs_b > 0) { - return false; - } - // High occupancy provides better threads level parallelism. - // 25% is sufficient since ILP is high due to persistent batch sizes - // which is equivalent to unrolling inner dim. - if (a.warps_per_sm != b.warps_per_sm && - (a.warps_per_sm < 16 || b.warps_per_sm < 16)) { - return a.warps_per_sm > b.warps_per_sm; - } - // Tie breaker, smaller threads_per_block to reduce communication - // overhead - return a.threads_per_block < b.threads_per_block; - }); - - // Pick the best heuristic - auto iop = iop_candidates.front(); - rparams->block_dim_inner_reduction_extra = ParallelType::TIDy; - rparams->combined_split_grid_inner_dim = - iop.vectorization_factor_outer * iop.bdimx * iop.gdimy < inner_dim_numel; - rparams->static_bdimx = true; - rparams->static_bdimy = true; - iop.bdimz = ceilDiv( - ceilDiv(ceilDiv(inner_dim_numel / iop.inner_vect, iop.bdimx), iop.bdimy), - iop.inner_batch); - NVF_ERROR(iop.bdimz == 1, "bdimz must be 1."); - - // check all the parameters in InnerOuterParams are set. - iop.verify(); - - rparams->persistent_kernel = true; - rparams->fastest_dim = true; - rparams->combined_inner_outer = true; - // tmp_gmem is the intermediate result of outer reduction, its dtype is float, - // so the maximum vectorization factor is 4. - rparams->vectorization_factor_outer = iop.vectorization_factor_outer; - rparams->vectorization_factor_tmp_gmem_write = iop.tmp_gmem_write_vect; - rparams->cparams.maxrregcount = iop.available_register_per_thread; - rparams->unroll_factor_inner_reduction = iop.inner_vect; - rparams->batches_per_block_inner_reduction = iop.inner_batch; - rparams->block_dim_inner_reduction = ParallelType::TIDx; - rparams->vectorize_inner_reduction = iop.inner_vect > 1; - rparams->split_grid_dim_iter_dom_outer = true; - rparams->grid_dim_iter_dom = ParallelType::BIDy; - - rparams->lparams = LaunchParams( - LaunchParams::UNINITIALIZED_VAL, - iop.gdimy, - LaunchParams::UNINITIALIZED_VAL, - iop.bdimx, - iop.bdimy, - LaunchParams::UNINITIALIZED_VAL); - - rparams->tag = "TMA Warp Specialized Persistent Heuristic.\n"; - - if (isDebugDumpEnabled(DebugDumpOption::SchedulerDebug)) { - debug() << "\n===== Combined InnerOuter Reduction Stats ========\n" - << "outer_dim_numel: " << outer_dim_numel << "\n" - << "inner_dim_numel: " << inner_dim_numel << "\n" - << "regs_buffer_size: " << regs_buffer_size << "\n" - << "smem_buffer_size: " << smem_buffer_size << "\n" - << "smem_overhead: " << smem_overhead << "\n" - << "vectorize_factor_input: " << iop.inner_vect << "\n" - << "vectorization_factor_tmp_gmem_write: " - << iop.tmp_gmem_write_vect << "\n" - << "vectorization_factor_outer: " << iop.vectorization_factor_outer - << "\n" - << "multiple_reds_per_blk: " << rparams->multiple_reds_per_blk - << "\n" - << "warps_per_sm: " << iop.warps_per_sm << "\n" - << "gdimy: " << iop.gdimy << "\n" - << "block(" << (iop.bdimx) << ", " << iop.bdimy << ", " << 1 << ")"; - debug() << rparams->toString() << std::endl; - } - return rparams; -} - std::unique_ptr getInnerOuterPersistentHeuristics( Fusion* fusion, SchedulerRuntimeInfo& runtime_info, @@ -1143,7 +96,7 @@ std::unique_ptr getInnerOuterPersistentHeuristics( NVF_ERROR( !persistent_buffer_info.persistent_buffers.empty(), "Persistent scheduler requires persistent buffers."); - auto buffer_params = getPersistentBufferStorageParams( + auto buffer_params = inner_outer_utils::getPersistentBufferStorageParams( fusion, runtime_info, data_cache, @@ -1152,14 +105,15 @@ std::unique_ptr getInnerOuterPersistentHeuristics( hp.threads_per_block_min, hp.threads_per_block_max); - std::unique_ptr rparams; - + auto rparams = std::make_unique( + InnerOuterPersistentKernelScheduler::schedulerType()); // Ultimately, we want the heuristic to decide between using the // warp-specialized version or the multi-wave version. The enable option is a // temporary configuration to facilitate testing during development without // disrupting existing behavior. if (isOptionEnabled(EnableOption::WarpSpecializedNormalization)) { - rparams = innerOuterWarpSpecializedTmaHeuristic( + inner_outer_tma_warp_specialized::getHeuristics( + rparams.get(), properties.total_iteration_numel, properties.total_reduction_numel, buffer_params.regs_buffer_size, @@ -1171,9 +125,9 @@ std::unique_ptr getInnerOuterPersistentHeuristics( hp.threads_per_block_max, buffer_params.project_to_input, runtime_info.getIndexType()); - rparams->tma_warp_specialized = true; } else { - rparams = innerOuterPersistentHeuristic( + inner_outer_multi_wave::getHeuristics( + rparams.get(), properties.total_iteration_numel, properties.total_reduction_numel, buffer_params.regs_buffer_size, @@ -1194,633 +148,6 @@ std::unique_ptr getInnerOuterPersistentHeuristics( return rparams; } -void scheduleReductionCombinedOuter( - Fusion* fusion, - const ReductionParams* rparams, - const std::vector& outer_reduction_tvs, - std::vector& cached_gmem, - std::vector& cached_gmem_reload, - std::vector& outer_reference_tvs, - std::unordered_set& boundaryNodesSet) { - auto mergeReductionOrIterDomains = [](TensorView* tv, bool mergeReduction) { - int prev_i = -1; - for (int i = static_cast(tv->nDims()) - 1; i >= 0; i--) { - if (mergeReduction == tv->axis(i)->isReduction()) { - if (prev_i == -1) { - prev_i = i; - } else { - tv->merge(i, prev_i); - prev_i = i; - } - } - } - }; - for (auto& outer_reduction_tv : outer_reduction_tvs) { - // Similar to the inner reduction, we need to reorder the outer reduction tv - // when there are view operations. - if (!ir_utils::getViewOps(fusion).empty()) { - // Reorder reference_tv after propagating the view operation. This will - // reorder for better merging. - outer_reduction_tv->reorder( - scheduler_utils::domainReorderAsLogicalMap(outer_reduction_tv)); - } - - // merge tensorview to [reduction, iteraiton] domains - mergeReductionOrIterDomains(outer_reduction_tv, true); - mergeReductionOrIterDomains(outer_reduction_tv, false); - if (rparams->multiple_reds_per_blk) { - outer_reduction_tv->split( - 0, NamedScalar::getParallelDim(rparams->block_dim_iter_dom)); - outer_reduction_tv->split( - 0, NamedScalar::getParallelDim(rparams->grid_dim_iter_dom), false); - } else { - outer_reduction_tv->split(0, rparams->lparams.gdimy()); - } - - if (rparams->multiple_reds_per_blk) { - outer_reduction_tv->rFactor({1}); - } - TensorView* partialResult = rparams->multiple_reds_per_blk - ? outer_reduction_tv->rFactor({1}) - : outer_reduction_tv->rFactor({0}); - partialResult->cacheBefore(); - partialResult->setMemoryType(MemoryType::Global); - TensorView* partialResultReload = partialResult->cacheAfter(); - - boundaryNodesSet.insert(partialResultReload); - cached_gmem.emplace_back(partialResult); - cached_gmem_reload.emplace_back(partialResultReload); - - if (rparams->multiple_reds_per_blk) { - if (rparams->tidx_for_outer_reduction) { - outer_reduction_tv->split( - 0, NamedScalar::getParallelDim(ParallelType::TIDx)); - outer_reduction_tv->axis(1)->parallelize(ParallelType::TIDx); - // to use warp reduction - if (rparams->pad_outer_reduction_to_warp) { - outer_reduction_tv->axis(1)->padToMultipleOfWarp(); - } - } else { - outer_reduction_tv->split( - 0, NamedScalar::getParallelDim(ParallelType::TIDy)); - outer_reduction_tv->axis(1)->parallelize(ParallelType::TIDy); - } - // iteration domain - int axisID = -1; - if (rparams->vectorization_factor_outer > 1) { - outer_reduction_tv->split(axisID, rparams->vectorization_factor_outer); - outer_reduction_tv->axis(axisID--)->parallelize( - ParallelType::Vectorize); - } - - if (rparams->tidx_for_outer_reduction) { - outer_reduction_tv->split( - axisID, NamedScalar::getParallelDim(ParallelType::TIDy)); - outer_reduction_tv->axis(axisID--)->parallelize(ParallelType::TIDy); - } else { - outer_reduction_tv->split( - axisID, NamedScalar::getParallelDim(ParallelType::TIDx)); - outer_reduction_tv->axis(axisID--)->parallelize(ParallelType::TIDx); - } - if (rparams->combined_split_grid_inner_dim) { - outer_reduction_tv->split( - axisID, NamedScalar::getParallelDim(ParallelType::BIDy)); - } - outer_reduction_tv->axis(axisID--)->parallelize(ParallelType::BIDy); - - } else { - // reduction domain - outer_reduction_tv->split(0, rparams->lparams.bdimy()); - outer_reduction_tv->axis(1)->parallelize(ParallelType::TIDy); - - // iteration domain - int axisID = -1; - if (rparams->vectorization_factor_outer > 1) { - outer_reduction_tv->split(axisID, rparams->vectorization_factor_outer); - outer_reduction_tv->axis(axisID--)->parallelize( - ParallelType::Vectorize); - } - - if (rparams->lparams.bdimx() > 1) { - outer_reduction_tv->split(axisID, rparams->lparams.bdimx()); - outer_reduction_tv->axis(axisID--)->parallelize(ParallelType::TIDx); - } - - if (rparams->combined_split_grid_inner_dim) { - outer_reduction_tv->split( - axisID, NamedScalar::getParallelDim(ParallelType::BIDy)); - } - - outer_reduction_tv->axis(axisID--)->parallelize(ParallelType::BIDy); - } - auto outer_reference_tv = - reduction_scheduler_utils::sortAndRFactor(outer_reduction_tv); - outer_reference_tvs.emplace_back(outer_reference_tv); - } -} - -// fusion is the input IR that will be modified by this function -void scheduleInnerOuterPersistentKernel( - Fusion* fusion, - const ReductionParams* rparams) { - FusionGuard fg(fusion); - - // Grab the reduction, input, and output tensor views. dummy_outputs are - // helper tensors for persistent buffer projection. - std::vector dummy_outputs, cached_inputs, reduction_tvs, - smem_consumers; - std::vector> cached_outputs; - normalization_scheduler_utils::beforeSchedule( - fusion, - rparams, - dummy_outputs, - cached_inputs, - reduction_tvs, - smem_consumers, - cached_outputs); - - // split reduction_tvs into inner and outer reduction_tvs - std::vector inner_reduction_tvs, outer_reduction_tvs; - for (auto tv : reduction_tvs) { - if (scheduler_utils::isFastestDimReduction(tv)) { - inner_reduction_tvs.emplace_back(tv); - } else { - outer_reduction_tvs.emplace_back(tv); - } - } - NVF_ERROR( - !inner_reduction_tvs.empty(), - "schedulePersistentKernelInnerOuter is called but no inner reduction is found."); - NVF_ERROR( - !outer_reduction_tvs.empty(), - "schedulePersistentKernelInnerOuter is called but no outer reduction is found."); - - // schedule inner reduction, only schedule the first inner reduction tv, - // then will be propagated to other inner reduction tvs. - TensorView* inner_reference_tv = - normalization_scheduler_utils::scheduleReductionGeneral( - fusion, - rparams, - inner_reduction_tvs, - InnerOuterPersistentKernelScheduler::schedulerType()); - - // schedule outer reduction, schedule all the outer reduction tvs since we - // need to store the intermediate results. - std::vector cached_gmem; - std::vector cached_gmem_reload; - std::vector outer_reference_tvs; - std::unordered_set boundaryNodesSet; - scheduleReductionCombinedOuter( - fusion, - rparams, - outer_reduction_tvs, - cached_gmem, - cached_gmem_reload, - outer_reference_tvs, - boundaryNodesSet); - - // Propagate inner reduction and outer reductions - for (auto output : dummy_outputs) { - fusion->addOutput(output); - } - - const bool is_unroll_or_vectorization = rparams->isUnrolled(); - const bool is_vectorize = - rparams->vectorize_inner_reduction || rparams->vectorize_iter_dom; - const bool is_outer_grid_persistence = rparams->persistent_kernel && - rparams->cross_grid_inner_reduction && !rparams->fastest_dim; - - // Propagate inner reduction. There is a cutoff at boundaryNodesSet, so this - // propagation will not propagate to the final outer reduction. - reduction_scheduler_utils::propagateTransformation( - inner_reference_tv, boundaryNodesSet); - reduction_scheduler_utils::propagateRFactor( - inner_reference_tv, inner_reduction_tvs[0], inner_reduction_tvs); - - // Don't allow parallelization propagation goes through boundaryNodesSet - const auto& selected_tvs_inner = - scheduler_utils::getAllTvsFrom(inner_reduction_tvs, boundaryNodesSet); - const auto& unroll_vectorizable_cached_tvs = - reduction_scheduler_utils::getCachedTvsToUnrollOrVectorize( - inner_reference_tv, is_vectorize, cached_inputs, cached_outputs); - reduction_scheduler_utils::propagateParallelization( - inner_reduction_tvs[0], - inner_reference_tv, - is_unroll_or_vectorization, - is_outer_grid_persistence, - inner_reduction_tvs, - unroll_vectorizable_cached_tvs, - {selected_tvs_inner.begin(), selected_tvs_inner.end()}); - - // Propagate outer reduction. Each outer reduction is connected with its - // cached_gmem and output, since we added all the cached_gmem to the - // boundaryNodesSet, the transformation from one outer reduction can't - // propagate to other outer reductions due to the cutoff at - // boundaryNodesSet. Thus, we need a loop to initiate the propagation from - // each outer reduction. Don't allow parallelization propagation goes - // through cached_gmem, see issue 246. - for (long unsigned int i = 0; i < outer_reference_tvs.size(); i++) { - const auto& selected_tvs_outer = scheduler_utils::getAllTvsFrom( - {outer_reduction_tvs[i]}, {cached_gmem[i]}); - reduction_scheduler_utils::propagateTransformation( - outer_reference_tvs[i], boundaryNodesSet); - const auto& unroll_vectorizable_cached_tvs = - reduction_scheduler_utils::getCachedTvsToUnrollOrVectorize( - outer_reference_tvs[i], - is_vectorize, - cached_inputs, - cached_outputs); - reduction_scheduler_utils::propagateParallelization( - outer_reduction_tvs[i], - outer_reference_tvs[i], - is_unroll_or_vectorization, - is_outer_grid_persistence, - outer_reduction_tvs, - unroll_vectorizable_cached_tvs, - {selected_tvs_outer.begin(), selected_tvs_outer.end()}); - } - - // special vectorization of temp gmem, vectorization_factor_tmp_gmem_write - // is guaranteed to be smaller or equal to input vectorization factor. - if (rparams->vectorization_factor_tmp_gmem_write > 1) { - for (auto tv : cached_gmem) { - NVF_ERROR( - rparams->vectorization_factor_tmp_gmem_write <= - rparams->unroll_factor_inner_reduction, - "vectorization factor of temp gmem write should be smaller than that of inner reduction.") - if (rparams->vectorization_factor_tmp_gmem_write < - rparams->unroll_factor_inner_reduction) { - tv->split(-1, rparams->vectorization_factor_tmp_gmem_write); - } - tv->axis(-1)->parallelize(ParallelType::Vectorize); - } - } - // vectorization propagate through propagateParallelization only works for - // input and output tensors. propagate vectorization to cached_gmem_reload - // directly from output tv using parallelizeAllLike. must propagate - // seperaely for different tvs as outer reductions are transformed - // seperately. - if (rparams->vectorization_factor_outer > 1) { - for (auto tv : cached_gmem_reload) { - auto output_tvs = ir_utils::outputTvsOf(tv); - NVF_ERROR( - !output_tvs.empty(), - "cached_gmem_reload should have at least one output tensor.") - scheduler_utils::parallelizeAllLike( - output_tvs[0], - -1, - {cached_gmem_reload.begin(), cached_gmem_reload.end()}, - {ParallelType::Vectorize}); - } - } - - // Needs special handling of vectorized loading from shared memory due to - // potential different data types of inputs and shared memory tensor. - if (is_vectorize) { - reduction_scheduler_utils::sharedMemoryConsumerVectorization( - smem_consumers, rparams->unroll_factor_inner_reduction); - } - - // Remove dummy outputs as they can inadvertently affect CA positions - for (auto output : dummy_outputs) { - fusion->removeOutput(output); - } - inlineMost(); -} - -void scheduleTmaWarpSpecializedOuter( - Fusion* fusion, - const ReductionParams* rparams, - const std::vector& outer_reduction_tvs, - std::vector& cached_gmem, - std::vector& cached_gmem_reload, - std::vector& outer_reference_tvs, - std::unordered_set& boundaryNodesSet) { - auto mergeReductionOrIterDomains = [](TensorView* tv, bool mergeReduction) { - int prev_i = -1; - for (int i = static_cast(tv->nDims()) - 1; i >= 0; i--) { - if (mergeReduction == tv->axis(i)->isReduction()) { - if (prev_i == -1) { - prev_i = i; - } else { - tv->merge(i, prev_i); - prev_i = i; - } - } - } - }; - for (auto& outer_reduction_tv : outer_reduction_tvs) { - // Similar to the inner reduction, we need to reorder the outer reduction tv - // when there are view operations. - if (!ir_utils::getViewOps(fusion).empty()) { - // Reorder reference_tv after propagating the view operation. This will - // reorder for better merging. - outer_reduction_tv->reorder( - scheduler_utils::domainReorderAsLogicalMap(outer_reduction_tv)); - } - - // merge tensorview to [reduction, iteraiton] domains - mergeReductionOrIterDomains(outer_reduction_tv, true); - mergeReductionOrIterDomains(outer_reduction_tv, false); - - // First-stage of outer reduction - outer_reduction_tv->split(0, rparams->lparams.gdimy()); - - TensorView* partialResult = outer_reduction_tv->rFactor({0}); - partialResult->cacheBefore(); - partialResult->setMemoryType(MemoryType::Global); - TensorView* partialResultReload = partialResult->cacheAfter(); - - boundaryNodesSet.insert(partialResultReload); - cached_gmem.emplace_back(partialResult); - cached_gmem_reload.emplace_back(partialResultReload); - - // Second-stage of outer reduction - // reduction domain, [I1/TIDy, TIDy] - outer_reduction_tv->split(0, rparams->lparams.bdimy()); - outer_reduction_tv->axis(1)->parallelize(ParallelType::TIDy); - // iteration domain, [BIDy, TIDx, Vect] - int axisID = -1; - if (rparams->vectorization_factor_outer > 1) { - outer_reduction_tv->split(axisID, rparams->vectorization_factor_outer); - outer_reduction_tv->axis(axisID--)->parallelize(ParallelType::Vectorize); - } - - if (rparams->lparams.bdimx() > 1) { - outer_reduction_tv->split(axisID, rparams->lparams.bdimx()); - outer_reduction_tv->axis(axisID--)->parallelize(ParallelType::TIDx); - } - - if (rparams->combined_split_grid_inner_dim) { - outer_reduction_tv->split( - axisID, NamedScalar::getParallelDim(ParallelType::BIDy)); - } - - outer_reduction_tv->axis(axisID--)->parallelize(ParallelType::BIDy); - - auto outer_reference_tv = - reduction_scheduler_utils::sortAndRFactor(outer_reduction_tv); - outer_reference_tvs.emplace_back(outer_reference_tv); - } -} - -void scheduleTmaWarpSpecializedInnerOuter( - Fusion* fusion, - const ReductionParams* rparams) { - FusionGuard fg(fusion); - - // Grab the reduction, input, and output tensor views. dummy_outputs are - // helper tensors for persistent buffer projection. - std::vector dummy_outputs, cached_inputs, reduction_tvs, - smem_consumers; - std::vector> cached_outputs; - normalization_scheduler_utils::beforeSchedule( - fusion, - rparams, - dummy_outputs, - cached_inputs, - reduction_tvs, - smem_consumers, - cached_outputs); - - // split reduction_tvs into inner and outer reduction_tvs - std::vector inner_reduction_tvs, outer_reduction_tvs; - for (auto tv : reduction_tvs) { - if (scheduler_utils::isFastestDimReduction(tv)) { - inner_reduction_tvs.emplace_back(tv); - } else { - outer_reduction_tvs.emplace_back(tv); - } - } - NVF_ERROR( - !inner_reduction_tvs.empty(), - "schedulePersistentKernelInnerOuter is called but no inner reduction is found."); - NVF_ERROR( - !outer_reduction_tvs.empty(), - "schedulePersistentKernelInnerOuter is called but no outer reduction is found."); - - // schedule inner reduction, only schedule the first inner reduction tv, - // then will be propagated to other inner reduction tvs. - TensorView* inner_reference_tv = - normalization_scheduler_utils::scheduleReductionGeneral( - fusion, - rparams, - inner_reduction_tvs, - InnerOuterPersistentKernelScheduler::schedulerType()); - - // schedule outer reduction, schedule all the outer reduction tvs since we - // need to store the intermediate results. - std::vector cached_gmem; - std::vector cached_gmem_reload; - std::vector outer_reference_tvs; - std::unordered_set boundaryNodesSet; - scheduleTmaWarpSpecializedOuter( - fusion, - rparams, - outer_reduction_tvs, - cached_gmem, - cached_gmem_reload, - outer_reference_tvs, - boundaryNodesSet); - - // Propagate inner reduction and outer reductions - for (auto output : dummy_outputs) { - fusion->addOutput(output); - } - - // Collect tvs loaded with TMA, they require special scheduling. - std::vector tma_load_tvs; - if (rparams->tma_warp_specialized) { - for (auto tv : smem_consumers) { - auto smem_tv = ir_utils::getSoleProducerTv(tv); - if (std::find(tma_load_tvs.begin(), tma_load_tvs.end(), smem_tv) == - tma_load_tvs.end()) { - tma_load_tvs.emplace_back(smem_tv); - } - } - } - - const bool is_unroll_or_vectorization = rparams->isUnrolled(); - const bool is_vectorize = - rparams->vectorize_inner_reduction || rparams->vectorize_iter_dom; - const bool is_outer_grid_persistence = rparams->persistent_kernel && - rparams->cross_grid_inner_reduction && !rparams->fastest_dim; - - // Propagate transformations for inner reduction. - // Two steps are used since tma tvs are scheduled differently. - // Step-1, propagate iteration domain in inner reduction. - // Step-2, propagate reduction domain in inner reduction. - if (rparams->tma_warp_specialized) { - // Find the axis that splits the reduction domain and iteration domain. - int first_redu_axis = -1; - int n_dims = (int)inner_reference_tv->nDims(); - for (auto i = 0; i < n_dims; i++) { - if (inner_reference_tv->axis(i)->isReduction() || - inner_reference_tv->axis(i)->isRFactorProduct()) { - first_redu_axis = i; - break; - } - } - - // Step-1, propagate iteration domain in inner reduction. - // outer_reference_tvs are excluded since they are already scheduled - // with a different pattern for the final step of outer reduciton. - if (first_redu_axis > 0) { - TransformPropagator propagator(inner_reference_tv, first_redu_axis - 1); - std::vector all_tvs_except = ir_utils::allTvsExcept( - fusion, {outer_reference_tvs.begin(), outer_reference_tvs.end()}); - SetSelector selector({all_tvs_except.begin(), all_tvs_except.end()}); - MaxLogicalDomainInfoSpanningTree(inner_reference_tv, &selector) - .traverse(&propagator); - } - - // Step-2, propagate reduction domain in inner reduction. - // (a) Tvs in boundaryNodesSet are excluded since they should follow outer - // reduction pattern. - // (b) TMA tvs are excluded since they require special scheduling. - // (3) Excluding tma tvs breaks the propagation path from inner reduction tv - // to cached_gmem which stores the results of the first-stage of outer - // reduction. The solution is adding a dummy output to link them. The same - // trick is used when projecting persistent buffers to inputs. - auto inner_reduction_input = - ir_utils::getSoleProducerTv(inner_reference_tv); - for (auto tv : cached_gmem) { - // T1(smem) --> T2 (l) --> T3 = OuterRedu(T2) --> T4(cached_gmem) - // outer_reduction_input: T2 - // partial_outer_redu_tv: T3 - auto partial_outer_redu_tv = ir_utils::getSoleProducerTv(tv); - auto outer_reduction_input = - ir_utils::getSoleProducerTv(partial_outer_redu_tv); - auto dummy_output = add(inner_reduction_input, outer_reduction_input); - fusion->addOutput(dummy_output); - dummy_outputs.emplace_back(dummy_output); - } - - // Tvs requiring special scheduling - std::unordered_set special_tvs{ - tma_load_tvs.begin(), tma_load_tvs.end()}; - for (auto tv : boundaryNodesSet) { - if (special_tvs.count(tv) == 0) { - special_tvs.emplace(tv); - } - } - TransformPropagator propagator(inner_reference_tv); - std::vector all_tvs_except_cache = ir_utils::allTvsExcept( - fusion, {special_tvs.begin(), special_tvs.end()}); - SetSelector selector( - {all_tvs_except_cache.begin(), all_tvs_except_cache.end()}); - MaxLogicalDomainInfoSpanningTree(inner_reference_tv, &selector) - .traverse(&propagator); - } else { - reduction_scheduler_utils::propagateTransformation( - inner_reference_tv, boundaryNodesSet); - } - reduction_scheduler_utils::propagateRFactor( - inner_reference_tv, inner_reduction_tvs[0], inner_reduction_tvs); - - // parallelization propagation - const auto& selected_tvs_inner = - scheduler_utils::getAllTvsFrom(inner_reduction_tvs, boundaryNodesSet); - const auto& unroll_vectorizable_cached_tvs = - reduction_scheduler_utils::getCachedTvsToUnrollOrVectorize( - inner_reference_tv, is_vectorize, cached_inputs, cached_outputs); - reduction_scheduler_utils::propagateParallelization( - inner_reduction_tvs[0], - inner_reference_tv, - is_unroll_or_vectorization, - is_outer_grid_persistence, - inner_reduction_tvs, - unroll_vectorizable_cached_tvs, - {selected_tvs_inner.begin(), selected_tvs_inner.end()}); - - // Propagate outer reduction. Each outer reduction is connected with its - // cached_gmem and output, since we added all the cached_gmem to the - // boundaryNodesSet, the transformation from one outer reduction can't - // propagate to other outer reductions due to the cutoff at - // boundaryNodesSet. Thus, we need a loop to initiate the propagation from - // each outer reduction. Don't allow parallelization propagation goes - // through cached_gmem, see issue 246. - for (long unsigned int i = 0; i < outer_reference_tvs.size(); i++) { - const auto& selected_tvs_outer = scheduler_utils::getAllTvsFrom( - {outer_reduction_tvs[i]}, {cached_gmem[i]}); - reduction_scheduler_utils::propagateTransformation( - outer_reference_tvs[i], boundaryNodesSet); - const auto& unroll_vectorizable_cached_tvs = - reduction_scheduler_utils::getCachedTvsToUnrollOrVectorize( - outer_reference_tvs[i], - is_vectorize, - cached_inputs, - cached_outputs); - reduction_scheduler_utils::propagateParallelization( - outer_reduction_tvs[i], - outer_reference_tvs[i], - is_unroll_or_vectorization, - is_outer_grid_persistence, - outer_reduction_tvs, - unroll_vectorizable_cached_tvs, - {selected_tvs_outer.begin(), selected_tvs_outer.end()}); - } - - // Up to this point, the outer dimension of the TMA tv is scheduled - // the same way as the inner reduction tv. However, the inner dimension - // has not been scheduled yet. Since 1D TMA allows unrestricted load size, - // we can simply parallelize the entire inner dimension using bulk. - // Example: 2D tensor, [BIDy, S, | Bulk] - // Example: 1D tensor, [Bulk] - if (rparams->tma_warp_specialized) { - for (auto tv : tma_load_tvs) { - tv->axis(-1)->parallelize(ParallelType::Bulk); - } - } - - // special vectorization of temp gmem, vectorization_factor_tmp_gmem_write - // is guaranteed to be smaller or equal to input vectorization factor. - if (rparams->vectorization_factor_tmp_gmem_write > 1) { - for (auto tv : cached_gmem) { - NVF_ERROR( - rparams->vectorization_factor_tmp_gmem_write <= - rparams->unroll_factor_inner_reduction, - "vectorization factor of temp gmem write should be smaller than that of inner reduction.") - if (rparams->vectorization_factor_tmp_gmem_write < - rparams->unroll_factor_inner_reduction) { - tv->split(-1, rparams->vectorization_factor_tmp_gmem_write); - } - tv->axis(-1)->parallelize(ParallelType::Vectorize); - } - } - // vectorization propagate through propagateParallelization only works for - // input and output tensors. propagate vectorization to cached_gmem_reload - // directly from output tv using parallelizeAllLike. must propagate - // seperaely for different tvs as outer reductions are transformed - // seperately. - if (rparams->vectorization_factor_outer > 1) { - for (auto tv : cached_gmem_reload) { - auto output_tvs = ir_utils::outputTvsOf(tv); - NVF_ERROR( - !output_tvs.empty(), - "cached_gmem_reload should have at least one output tensor.") - scheduler_utils::parallelizeAllLike( - output_tvs[0], - -1, - {cached_gmem_reload.begin(), cached_gmem_reload.end()}, - {ParallelType::Vectorize}); - } - } - - // Needs special handling of vectorized loading from shared memory due to - // potential different data types of inputs and shared memory tensor. - if (is_vectorize) { - reduction_scheduler_utils::sharedMemoryConsumerVectorization( - smem_consumers, rparams->unroll_factor_inner_reduction); - } - - // Remove dummy outputs as they can inadvertently affect CA positions - for (auto output : dummy_outputs) { - fusion->removeOutput(output); - } - inlineMost(); -} - } // namespace bool InnerOuterPersistentKernelScheduler::canScheduleCompileTime( @@ -2027,14 +354,15 @@ bool InnerOuterPersistentKernelScheduler::canScheduleRunTime( scheduler_hyperparameters_entry.get(); // check if there is enough register and shared memory for persistence - const auto buffer_params = getPersistentBufferStorageParams( - fusion, - runtime_info, - data_cache, - reduction_tvs, - hp.vectorize_factor, - hp.threads_per_block_min, - hp.threads_per_block_max); + const auto buffer_params = + inner_outer_utils::getPersistentBufferStorageParams( + fusion, + runtime_info, + data_cache, + reduction_tvs, + hp.vectorize_factor, + hp.threads_per_block_min, + hp.threads_per_block_max); const int64_t device_multiprocessor_count = (int64_t)at::cuda::getCurrentDeviceProperties()->multiProcessorCount; @@ -2103,9 +431,9 @@ void InnerOuterPersistentKernelScheduler::schedule( "Incorrect parameters sent to InnerOuterPersistentKernelScheduler::schedule", params); if (rparams->tma_warp_specialized) { - scheduleTmaWarpSpecializedInnerOuter(fusion, rparams); + inner_outer_tma_warp_specialized::scheduleFusion(fusion, rparams); } else { - scheduleInnerOuterPersistentKernel(fusion, rparams); + inner_outer_multi_wave::scheduleFusion(fusion, rparams); } } } // namespace nvfuser diff --git a/csrc/scheduler/normalization_inner_outer_multi_wave.cpp b/csrc/scheduler/normalization_inner_outer_multi_wave.cpp new file mode 100644 index 00000000000..6ab87e871bf --- /dev/null +++ b/csrc/scheduler/normalization_inner_outer_multi_wave.cpp @@ -0,0 +1,717 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#include +#include +#include + +#include + +namespace nvfuser { +namespace inner_outer_multi_wave { +// The innerOuterPersistentHeuristic is tuned for layer_norm backward on A100 +// ======= Method if hidden_size > 1024 ======= +// (1) Inner reduction is one reduction per block. Reduction domain is +// parallelized by TIDx and TIDy, Iteration domain is parallelized by BIDy. +// (2) Outer reduction is done in two-steps. The first step is partial +// reduction, reduction domain is parallelized by BIDy, iteration domain is +// parallelized by TIDx and TIDy. The partial results are written to gmem +// followed by a grid sync. The second step is block reduction, the reduction +// domain is parallelized by TIDy, the iteration domain is parallelized by TIDx +// and BIDy. +// ======= Method if hidden_size <= 1024 ======= +// (1) Inner reduction is multi-reductions per blocks. Reduction domain is +// parallelized by TIDx, Iteration domain is parallelized by BIDy and TIDy. +// (2) Outer reduction is same to cases where hidden_size > 1024 except the +// second step where in this case, the reduction domain is parallelized by TIDx +// and the iteration domain is parallelized by TIDy and BIDy. This switch +// between TIDx and TIDy is because: +// (a) We can do warp reduction with TIDx +// (b) TIDx*BIDy is usually much larger than hidden_size, e.g. 128*216 = 1024*27 +// this means without switch only 1/27 of the threads is used. +void getHeuristics( + ReductionParams* rparams, + const int64_t outer_dim_numel, + const int64_t inner_dim_numel, + const int64_t regs_buffer_size, + const int64_t smem_buffer_size, + const int64_t smem_overhead, + const size_t tmp_gmem_dtype_size, + const size_t vectorize_factor, + const int64_t hp_threads_per_block_min, + const int64_t hp_threads_per_block_max, + const bool project_to_input, + const PrimDataType index_type) { + rparams->project_persistent_buffers = project_to_input; + rparams->cparams.index_type = index_type; + const auto dev_prop = at::cuda::getCurrentDeviceProperties(); + const int64_t device_multiprocessor_count = + (int64_t)dev_prop->multiProcessorCount; + // Parameters for inner reduction: + // Reduction dim: inner_vect, inner_batch, bdimx and bdimy + // Iteration dim: gdimy + + // Parameters for outer reduction: + // Reduction dim: bdimy + // Iteration dim: vectorization_factor_outer, bdimx, gdimy + struct InnerOuterParams { + int64_t inner_vect = -1; + int64_t inner_batch = -1; + int64_t bdimx = -1; + int64_t bdimy = -1; + int64_t bdimz = -1; + int64_t gdimy = -1; + int64_t tmp_gmem_write_vect = -1; + int64_t vectorization_factor_outer = -1; + int64_t threads_per_block = -1; + // derived metrics for sorting + int64_t warps_per_sm = -1; + int64_t required_register_per_thread = -1; + int64_t available_register_per_thread = -1; + + void verify() { + NVF_ERROR(inner_vect != -1, "inner_vect is not set."); + NVF_ERROR(inner_batch != -1, "inner_batch is not set."); + NVF_ERROR(bdimx != -1, "bdimx is not set."); + NVF_ERROR(bdimy != -1, "bdimy is not set."); + NVF_ERROR(gdimy != -1, "gdimy is not set."); + NVF_ERROR(tmp_gmem_write_vect != -1, "tmp_gmem_write_vect is not set."); + NVF_ERROR( + vectorization_factor_outer != -1, + "vectorization_factor_outer is not set."); + } + std::string toString() const { + std::stringstream ss; + ss << "inner_vect: " << inner_vect << ", inner_batch: " << inner_batch + << ", bdimx: " << bdimx << ", bdimy: " << bdimy << ", bdimz: " << bdimz + << ", gdimy: " << gdimy + << ", tmp_gmem_write_vect: " << tmp_gmem_write_vect + << ", vectorization_factor_outer: " << vectorization_factor_outer + << ", threads_per_block: " << threads_per_block + << ", warps_per_sm: " << warps_per_sm + << ", required_register_per_thread: " << required_register_per_thread + << ", available_register_per_thread: " + << available_register_per_thread; + return ss.str(); + } + }; + + // Set a minimum workload for each thread to take advantage of low + // intra-threads communication cost. + // Tuned for layer_norm backward on A100, still works fine on H100. + auto getMinimumBatch = [&]() -> int64_t { + if (inner_dim_numel >= 3072l) { + if (outer_dim_numel <= 2048l && inner_dim_numel == 3072l) { + return 3l; + } else { + return 4l; + } + } else if (inner_dim_numel >= 2048l) { + return 2l; + } + return 1l; + }; + + // Estimate register usage per thread based on buffer size. + // Assuming a constant register overhead for non-buffer related usage, + // and all the register buffers are stored in registers. + auto getEstimatedRegisterUsage = [&](int64_t batch_mul_vect) { + int64_t persistent_buffer_size = + regs_buffer_size / inner_dim_numel * batch_mul_vect; + int64_t estimated_register_count = + persistent_buffer_size / scheduler_utils::bytes_per_register + + scheduler_utils::register_overhead; + return std::min( + estimated_register_count, scheduler_utils::max_registers_per_thread); + }; + + // Estimate max blocks per sm based on register and shared memory usage. + auto getBlocksPerSM = [&](const int64_t threads_per_sm, + const int64_t threads_per_block, + const int64_t warp_size) { + // check register limitation on blocks per sm + constexpr int64_t warp_allocation_granularity = 4; + const int64_t allocated_warps_per_block = + ceilDiv( + ceilDiv(threads_per_block, warp_size), + warp_allocation_granularity) * + warp_allocation_granularity; + int64_t max_blocks_per_sm_regs = scheduler_utils::safeDiv( + threads_per_sm / warp_size, allocated_warps_per_block); + // check shared memory limitation on blocks per sm + int64_t max_blocks_per_sm_smem = + (int64_t)dev_prop->sharedMemPerMultiprocessor / + (smem_overhead + smem_buffer_size); + return std::min(max_blocks_per_sm_regs, max_blocks_per_sm_smem); + }; + + // In the inner reduction part of the kernel, gdimy is used to parallelize the + // outer dimension. The kernel is a cooperative kernel, so the number of + // blocks should be as large as possible to achieve a high occupancy unless + // outer dim is too small which may lead large workload for the final outer + // reduction. So, gdimy is drvied from the number of blocks per sm and limited + // to ensure at least 8 rows per block. + // TODO: re-evaluate this 8 rows per block requirement. + auto getGdimy = [&](int64_t inner_vect, + int64_t threads_per_block, + int64_t inner_batch) { + int64_t reg_per_thread = + getEstimatedRegisterUsage(inner_vect * inner_batch); + int64_t threads_per_sm = getThreadsPerSMGivenRegPerThread(reg_per_thread); + int64_t blocks_per_sm = + getBlocksPerSM(threads_per_sm, threads_per_block, dev_prop->warpSize); + int64_t gdimy = blocks_per_sm * device_multiprocessor_count; + const int64_t outer_iter_min = 8; + const int64_t gdimy_max = scheduler_utils::roundUpToN( + ceilDiv(outer_dim_numel, outer_iter_min), device_multiprocessor_count); + while (gdimy > gdimy_max && blocks_per_sm > 1) { + blocks_per_sm -= 1; + gdimy = blocks_per_sm * device_multiprocessor_count; + } + return gdimy; + }; + + // The inner reduction part of the kernel also does a partial outer reduction + // and stores the partial results in tmp gmem and then reloaded to finish the + // outer reduciton. This function set the vectorization factor for write and + // and read of the partial outer reduction result. + // For write to tmp gmem, follows vectorization factor of inner reduction + // but don't exceed 16 bytes. + // For read from tmp gmem, since the paralelization is changed, a different + // vectorization factor is used to optimize the + // number of reaductions per thread. + auto getOuterReductionBufferVectFactor = [&](int64_t inner_vect) { + constexpr int64_t max_gmem_vect_access_bytes = 16; + const int64_t max_tmp_gmem_vect_factor = std::min( + max_gmem_vect_access_bytes / (int64_t)tmp_gmem_dtype_size, inner_vect); + int64_t tmp_gmem_write_vect = max_tmp_gmem_vect_factor; + const int64_t workload_per_thread = inner_dim_numel >= 4096 ? 4l : 2l; + int64_t vectorization_factor_outer = + std::min(workload_per_thread, max_tmp_gmem_vect_factor); + return std::make_pair(tmp_gmem_write_vect, vectorization_factor_outer); + }; + + // In the outer reduction part of the kernel, inner and outer dims are + // parallelized as: + // --- inner dim: vect, bdimx, gdimy ---- + // --- outer dim: bdimy ----------------- + // This function splits the threads_per_block into bdimx and bdimy using: + // bdimx = ceilDiv(inner_dim_numel / vect, gdimy) + // bdimy = threads_per_block / bdimx + auto getBdimxBdimy = [&](int64_t threads_per_block, + int64_t vectorization_factor_outer, + int64_t gdimy) { + // For widely used hidden sizes, threads_per_block has factor of 8, roundup + // to increase the probability of bdimx * bdimy == threads_per_block. + int64_t bdimx = scheduler_utils::roundUpPow2Or8( + ceilDiv(inner_dim_numel / vectorization_factor_outer, gdimy)); + // if still not divisible, e.g. threads_per_block = 256, bdimx = 40. + // increase bdimx to make it divisible. Under worst case, bdimx equals to + // threads_per_block. + while (threads_per_block % bdimx) { + bdimx = std::min(bdimx + 8, threads_per_block); + } + // Set OuterParams Reduction dim: bdimy. + int64_t bdimy = threads_per_block / bdimx; + NVF_ERROR( + bdimy * bdimx == threads_per_block, + " threads_per_block must be divisible by bdimx and bdimy."); + return std::make_pair(bdimx, bdimy); + }; + + // Get the heuristics given vectorization factor and threads per block + auto getHeuristicsGivenVectThreads = [&](int64_t vect_factor, + int64_t threads_per_block) { + InnerOuterParams iop; + // (1) inner reduction + // Reduction dim: inner_batch, threads_per_block, vect_factor + // Iteration dim: gdimy + iop.inner_vect = vect_factor; + iop.threads_per_block = threads_per_block; + iop.inner_batch = + ceilDiv(inner_dim_numel / iop.inner_vect, iop.threads_per_block); + iop.gdimy = + getGdimy(iop.inner_vect, iop.threads_per_block, iop.inner_batch); + // (2) outer reduction + // Iteration dim: gdimy, bdimx, vectorization_factor_outer + // Reduction dim: bdimy + std::tie(iop.tmp_gmem_write_vect, iop.vectorization_factor_outer) = + getOuterReductionBufferVectFactor(iop.inner_vect); + auto [bdimx, bdimy] = getBdimxBdimy( + threads_per_block, iop.vectorization_factor_outer, iop.gdimy); + iop.bdimx = bdimx; + iop.bdimy = bdimy; + // (3) Derived metrics warps_per_sm and register usage for sorting + iop.warps_per_sm = ceilDiv(iop.threads_per_block, dev_prop->warpSize) * + iop.gdimy / device_multiprocessor_count; + iop.available_register_per_thread = + getRegPerThreadGivenThreadsPerSM(dev_prop->warpSize * iop.warps_per_sm); + iop.required_register_per_thread = + getEstimatedRegisterUsage(iop.inner_vect * iop.inner_batch); + return iop; + }; + + // Use the maximum vectorization factor + const int64_t vect_factor = (int64_t)vectorize_factor; + + // Set a reasonable range for threads per block based on the number of + // elements in the inner dimension after vectorization. + // Start from 128 or a smaller number if inner dim is small. + const int64_t after_vect = inner_dim_numel / vect_factor; + const int64_t batch_min = getMinimumBatch(); + int64_t threads_per_block_min = hp_threads_per_block_min; + threads_per_block_min = std::min(threads_per_block_min, after_vect); + threads_per_block_min = scheduler_utils::roundUpPow2(threads_per_block_min); + + // star max threads per block from min threads per block + int64_t threads_per_block_max = threads_per_block_min; + // increase to cover the whole inner dim + threads_per_block_max = + std::max(threads_per_block_max, ceilDiv(after_vect, batch_min)); + // round up to power of 2 + threads_per_block_max = scheduler_utils::roundUpPow2(threads_per_block_max); + // don't go beyond the maximum threads per block + threads_per_block_max = + std::min(threads_per_block_max, hp_threads_per_block_max); + + // Store all the possible heuristics based on different threads per block. + // Vectorizaton is fixed at the maximum value. + std::vector iop_candidates; + for (auto threads_per_block = threads_per_block_max; + threads_per_block >= threads_per_block_min; + threads_per_block /= 2) { + iop_candidates.emplace_back( + getHeuristicsGivenVectThreads(vect_factor, threads_per_block)); + } + + // Sort the heuristics based on the register usage and occupancy. + std::stable_sort( + iop_candidates.begin(), + iop_candidates.end(), + [](const InnerOuterParams& a, const InnerOuterParams& b) { + // If a thread can use more registers than required, there is a high + // chance that it can avoid register spilling and compiler can optimize + // for better instruction level parallelism. + int64_t extra_regs_a = + a.available_register_per_thread - a.required_register_per_thread; + int64_t extra_regs_b = + b.available_register_per_thread - b.required_register_per_thread; + if (extra_regs_a > 0 && extra_regs_b < 0) { + return true; + } else if (extra_regs_a < 0 && extra_regs_b > 0) { + return false; + } + // High occupancy provides better threads level parallelism. + // 25% is sufficient since ILP is high due to persistent batch sizes + // which is equivalent to unrolling inner dim. + if (a.warps_per_sm != b.warps_per_sm && + (a.warps_per_sm < 16 || b.warps_per_sm < 16)) { + return a.warps_per_sm > b.warps_per_sm; + } + // Tie breaker, smaller threads_per_block to reduce communication + // overhead + return a.threads_per_block < b.threads_per_block; + }); + + // Pick the best heuristic + auto iop = iop_candidates.front(); + + // Special case, when inner_dim_numel <= 1024, bdimx is usually small + // after divide by inner_vect and inner_batch. In this case, bdimy is used to + // parallelize outer_dim instead of inner_dim. This pattern is named multi + // reductions per block (mrpb). + if (inner_dim_numel <= 1024) { + rparams->multiple_reds_per_blk = true; + rparams->tidx_for_outer_reduction = true; + + // Step-1, InnerParams, Reduction dim: inner_vect(reuse), + // inner_batch(reuse), bdimx + iop.bdimx = ceilDiv(inner_dim_numel, iop.inner_vect * iop.inner_batch); + + // Step-2, InnerParams, Iteration dim: gdimy, bdimy (in next step) + iop.gdimy = getGdimy(iop.inner_vect, iop.bdimx, iop.inner_batch); + + // Step-3, OuterParams, Iteration dim: vectorization_factor_outer(reuse), + // bdimy, gdimy (in previous step). + // WAR for https://github.com/NVIDIA/Fuser/issues/3428 + iop.bdimy = 1; + + // Step-4, OuterParams, Reduction dim: bdimx (already done) + iop.warps_per_sm = ceilDiv(iop.bdimx * iop.bdimy, dev_prop->warpSize) * + iop.gdimy / device_multiprocessor_count; + iop.available_register_per_thread = + getRegPerThreadGivenThreadsPerSM(dev_prop->warpSize * iop.warps_per_sm); + + if (iop.bdimx % dev_prop->warpSize == 0) { + rparams->pad_inner_reduction_to_warp = true; + rparams->pad_outer_reduction_to_warp = true; + } + rparams->block_dim_iter_dom = ParallelType::TIDy; + rparams->combined_split_grid_inner_dim = + iop.vectorization_factor_outer * iop.bdimy * iop.gdimy < + inner_dim_numel; + } else { + rparams->block_dim_inner_reduction_extra = ParallelType::TIDy; + rparams->combined_split_grid_inner_dim = + iop.vectorization_factor_outer * iop.bdimx * iop.gdimy < + inner_dim_numel; + rparams->static_bdimx = true; + rparams->static_bdimy = true; + iop.bdimz = ceilDiv( + ceilDiv( + ceilDiv(inner_dim_numel / iop.inner_vect, iop.bdimx), iop.bdimy), + iop.inner_batch); + NVF_ERROR(iop.bdimz == 1, "bdimz must be 1."); + } + + // check all the parameters in InnerOuterParams are set. + iop.verify(); + + rparams->persistent_kernel = true; + rparams->fastest_dim = true; + rparams->combined_inner_outer = true; + // tmp_gmem is the intermediate result of outer reduction, its dtype is float, + // so the maximum vectorization factor is 4. + rparams->vectorization_factor_outer = iop.vectorization_factor_outer; + rparams->vectorization_factor_tmp_gmem_write = iop.tmp_gmem_write_vect; + rparams->cparams.maxrregcount = iop.available_register_per_thread; + rparams->unroll_factor_inner_reduction = iop.inner_vect; + rparams->batches_per_block_inner_reduction = iop.inner_batch; + rparams->block_dim_inner_reduction = ParallelType::TIDx; + rparams->vectorize_inner_reduction = iop.inner_vect > 1; + rparams->split_grid_dim_iter_dom_outer = true; + rparams->grid_dim_iter_dom = ParallelType::BIDy; + + rparams->lparams = LaunchParams( + LaunchParams::UNINITIALIZED_VAL, + iop.gdimy, + LaunchParams::UNINITIALIZED_VAL, + iop.bdimx, + iop.bdimy, + LaunchParams::UNINITIALIZED_VAL); + + if (!rparams->smem_persistent_buffers.empty()) { + rparams->tag = + "InnerOuter Register and Shared Memory Persistent Heuristic.\n"; + } else { + rparams->tag = "InnerOuter Register Persistent Heuristic.\n"; + } + + if (isDebugDumpEnabled(DebugDumpOption::SchedulerDebug)) { + debug() << "\n===== Combined InnerOuter Reduction Stats ========\n" + << "outer_dim_numel: " << outer_dim_numel << "\n" + << "inner_dim_numel: " << inner_dim_numel << "\n" + << "regs_buffer_size: " << regs_buffer_size << "\n" + << "smem_buffer_size: " << smem_buffer_size << "\n" + << "smem_overhead: " << smem_overhead << "\n" + << "vectorize_factor_input: " << iop.inner_vect << "\n" + << "vectorization_factor_tmp_gmem_write: " + << iop.tmp_gmem_write_vect << "\n" + << "vectorization_factor_outer: " << iop.vectorization_factor_outer + << "\n" + << "multiple_reds_per_blk: " << rparams->multiple_reds_per_blk + << "\n" + << "warps_per_sm: " << iop.warps_per_sm << "\n" + << "gdimy: " << iop.gdimy << "\n" + << "block(" << (iop.bdimx) << ", " << iop.bdimy << ", " << 1 << ")"; + debug() << rparams->toString() << std::endl; + } +} + +void scheduleOuterReduction( + Fusion* fusion, + const ReductionParams* rparams, + const std::vector& outer_reduction_tvs, + std::vector& cached_gmem, + std::vector& cached_gmem_reload, + std::vector& outer_reference_tvs, + std::unordered_set& boundaryNodesSet) { + auto mergeReductionOrIterDomains = [](TensorView* tv, bool mergeReduction) { + int prev_i = -1; + for (int i = static_cast(tv->nDims()) - 1; i >= 0; i--) { + if (mergeReduction == tv->axis(i)->isReduction()) { + if (prev_i == -1) { + prev_i = i; + } else { + tv->merge(i, prev_i); + prev_i = i; + } + } + } + }; + for (auto& outer_reduction_tv : outer_reduction_tvs) { + // Similar to the inner reduction, we need to reorder the outer reduction tv + // when there are view operations. + if (!ir_utils::getViewOps(fusion).empty()) { + // Reorder reference_tv after propagating the view operation. This will + // reorder for better merging. + outer_reduction_tv->reorder( + scheduler_utils::domainReorderAsLogicalMap(outer_reduction_tv)); + } + + // merge tensorview to [reduction, iteraiton] domains + mergeReductionOrIterDomains(outer_reduction_tv, true); + mergeReductionOrIterDomains(outer_reduction_tv, false); + if (rparams->multiple_reds_per_blk) { + outer_reduction_tv->split( + 0, NamedScalar::getParallelDim(rparams->block_dim_iter_dom)); + outer_reduction_tv->split( + 0, NamedScalar::getParallelDim(rparams->grid_dim_iter_dom), false); + } else { + outer_reduction_tv->split(0, rparams->lparams.gdimy()); + } + + if (rparams->multiple_reds_per_blk) { + outer_reduction_tv->rFactor({1}); + } + TensorView* partialResult = rparams->multiple_reds_per_blk + ? outer_reduction_tv->rFactor({1}) + : outer_reduction_tv->rFactor({0}); + partialResult->cacheBefore(); + partialResult->setMemoryType(MemoryType::Global); + TensorView* partialResultReload = partialResult->cacheAfter(); + + boundaryNodesSet.insert(partialResultReload); + cached_gmem.emplace_back(partialResult); + cached_gmem_reload.emplace_back(partialResultReload); + + if (rparams->multiple_reds_per_blk) { + if (rparams->tidx_for_outer_reduction) { + outer_reduction_tv->split( + 0, NamedScalar::getParallelDim(ParallelType::TIDx)); + outer_reduction_tv->axis(1)->parallelize(ParallelType::TIDx); + // to use warp reduction + if (rparams->pad_outer_reduction_to_warp) { + outer_reduction_tv->axis(1)->padToMultipleOfWarp(); + } + } else { + outer_reduction_tv->split( + 0, NamedScalar::getParallelDim(ParallelType::TIDy)); + outer_reduction_tv->axis(1)->parallelize(ParallelType::TIDy); + } + // iteration domain + int axisID = -1; + if (rparams->vectorization_factor_outer > 1) { + outer_reduction_tv->split(axisID, rparams->vectorization_factor_outer); + outer_reduction_tv->axis(axisID--)->parallelize( + ParallelType::Vectorize); + } + + if (rparams->tidx_for_outer_reduction) { + outer_reduction_tv->split( + axisID, NamedScalar::getParallelDim(ParallelType::TIDy)); + outer_reduction_tv->axis(axisID--)->parallelize(ParallelType::TIDy); + } else { + outer_reduction_tv->split( + axisID, NamedScalar::getParallelDim(ParallelType::TIDx)); + outer_reduction_tv->axis(axisID--)->parallelize(ParallelType::TIDx); + } + if (rparams->combined_split_grid_inner_dim) { + outer_reduction_tv->split( + axisID, NamedScalar::getParallelDim(ParallelType::BIDy)); + } + outer_reduction_tv->axis(axisID--)->parallelize(ParallelType::BIDy); + + } else { + // reduction domain + outer_reduction_tv->split(0, rparams->lparams.bdimy()); + outer_reduction_tv->axis(1)->parallelize(ParallelType::TIDy); + + // iteration domain + int axisID = -1; + if (rparams->vectorization_factor_outer > 1) { + outer_reduction_tv->split(axisID, rparams->vectorization_factor_outer); + outer_reduction_tv->axis(axisID--)->parallelize( + ParallelType::Vectorize); + } + + if (rparams->lparams.bdimx() > 1) { + outer_reduction_tv->split(axisID, rparams->lparams.bdimx()); + outer_reduction_tv->axis(axisID--)->parallelize(ParallelType::TIDx); + } + + if (rparams->combined_split_grid_inner_dim) { + outer_reduction_tv->split( + axisID, NamedScalar::getParallelDim(ParallelType::BIDy)); + } + + outer_reduction_tv->axis(axisID--)->parallelize(ParallelType::BIDy); + } + auto outer_reference_tv = + reduction_scheduler_utils::sortAndRFactor(outer_reduction_tv); + outer_reference_tvs.emplace_back(outer_reference_tv); + } +} + +// fusion is the input IR that will be modified by this function +void scheduleFusion(Fusion* fusion, const ReductionParams* rparams) { + FusionGuard fg(fusion); + + // Grab the reduction, input, and output tensor views. dummy_outputs are + // helper tensors for persistent buffer projection. + std::vector dummy_outputs, cached_inputs, reduction_tvs, + smem_consumers; + std::vector> cached_outputs; + normalization_scheduler_utils::beforeSchedule( + fusion, + rparams, + dummy_outputs, + cached_inputs, + reduction_tvs, + smem_consumers, + cached_outputs); + + // split reduction_tvs into inner and outer reduction_tvs + std::vector inner_reduction_tvs, outer_reduction_tvs; + for (auto tv : reduction_tvs) { + if (scheduler_utils::isFastestDimReduction(tv)) { + inner_reduction_tvs.emplace_back(tv); + } else { + outer_reduction_tvs.emplace_back(tv); + } + } + NVF_ERROR( + !inner_reduction_tvs.empty(), + "schedulePersistentKernelInnerOuter is called but no inner reduction is found."); + NVF_ERROR( + !outer_reduction_tvs.empty(), + "schedulePersistentKernelInnerOuter is called but no outer reduction is found."); + + // schedule inner reduction, only schedule the first inner reduction tv, + // then will be propagated to other inner reduction tvs. + TensorView* inner_reference_tv = + normalization_scheduler_utils::scheduleReductionGeneral( + fusion, + rparams, + inner_reduction_tvs, + SchedulerType::InnerOuterPersistent); + + // schedule outer reduction, schedule all the outer reduction tvs since we + // need to store the intermediate results. + std::vector cached_gmem; + std::vector cached_gmem_reload; + std::vector outer_reference_tvs; + std::unordered_set boundaryNodesSet; + scheduleOuterReduction( + fusion, + rparams, + outer_reduction_tvs, + cached_gmem, + cached_gmem_reload, + outer_reference_tvs, + boundaryNodesSet); + + // Propagate inner reduction and outer reductions + for (auto output : dummy_outputs) { + fusion->addOutput(output); + } + + const bool is_unroll_or_vectorization = rparams->isUnrolled(); + const bool is_vectorize = + rparams->vectorize_inner_reduction || rparams->vectorize_iter_dom; + const bool is_outer_grid_persistence = rparams->persistent_kernel && + rparams->cross_grid_inner_reduction && !rparams->fastest_dim; + + // Propagate inner reduction. There is a cutoff at boundaryNodesSet, so this + // propagation will not propagate to the final outer reduction. + reduction_scheduler_utils::propagateTransformation( + inner_reference_tv, boundaryNodesSet); + reduction_scheduler_utils::propagateRFactor( + inner_reference_tv, inner_reduction_tvs[0], inner_reduction_tvs); + + // Don't allow parallelization propagation goes through boundaryNodesSet + const auto& selected_tvs_inner = + scheduler_utils::getAllTvsFrom(inner_reduction_tvs, boundaryNodesSet); + const auto& unroll_vectorizable_cached_tvs = + reduction_scheduler_utils::getCachedTvsToUnrollOrVectorize( + inner_reference_tv, is_vectorize, cached_inputs, cached_outputs); + reduction_scheduler_utils::propagateParallelization( + inner_reduction_tvs[0], + inner_reference_tv, + is_unroll_or_vectorization, + is_outer_grid_persistence, + inner_reduction_tvs, + unroll_vectorizable_cached_tvs, + {selected_tvs_inner.begin(), selected_tvs_inner.end()}); + + // Propagate outer reduction. Each outer reduction is connected with its + // cached_gmem and output, since we added all the cached_gmem to the + // boundaryNodesSet, the transformation from one outer reduction can't + // propagate to other outer reductions due to the cutoff at + // boundaryNodesSet. Thus, we need a loop to initiate the propagation from + // each outer reduction. Don't allow parallelization propagation goes + // through cached_gmem, see issue 246. + for (long unsigned int i = 0; i < outer_reference_tvs.size(); i++) { + const auto& selected_tvs_outer = scheduler_utils::getAllTvsFrom( + {outer_reduction_tvs[i]}, {cached_gmem[i]}); + reduction_scheduler_utils::propagateTransformation( + outer_reference_tvs[i], boundaryNodesSet); + const auto& unroll_vectorizable_cached_tvs = + reduction_scheduler_utils::getCachedTvsToUnrollOrVectorize( + outer_reference_tvs[i], + is_vectorize, + cached_inputs, + cached_outputs); + reduction_scheduler_utils::propagateParallelization( + outer_reduction_tvs[i], + outer_reference_tvs[i], + is_unroll_or_vectorization, + is_outer_grid_persistence, + outer_reduction_tvs, + unroll_vectorizable_cached_tvs, + {selected_tvs_outer.begin(), selected_tvs_outer.end()}); + } + + // special vectorization of temp gmem, vectorization_factor_tmp_gmem_write + // is guaranteed to be smaller or equal to input vectorization factor. + if (rparams->vectorization_factor_tmp_gmem_write > 1) { + for (auto tv : cached_gmem) { + NVF_ERROR( + rparams->vectorization_factor_tmp_gmem_write <= + rparams->unroll_factor_inner_reduction, + "vectorization factor of temp gmem write should be smaller than that of inner reduction.") + if (rparams->vectorization_factor_tmp_gmem_write < + rparams->unroll_factor_inner_reduction) { + tv->split(-1, rparams->vectorization_factor_tmp_gmem_write); + } + tv->axis(-1)->parallelize(ParallelType::Vectorize); + } + } + // vectorization propagate through propagateParallelization only works for + // input and output tensors. propagate vectorization to cached_gmem_reload + // directly from output tv using parallelizeAllLike. must propagate + // seperaely for different tvs as outer reductions are transformed + // seperately. + if (rparams->vectorization_factor_outer > 1) { + for (auto tv : cached_gmem_reload) { + auto output_tvs = ir_utils::outputTvsOf(tv); + NVF_ERROR( + !output_tvs.empty(), + "cached_gmem_reload should have at least one output tensor.") + scheduler_utils::parallelizeAllLike( + output_tvs[0], + -1, + {cached_gmem_reload.begin(), cached_gmem_reload.end()}, + {ParallelType::Vectorize}); + } + } + + // Needs special handling of vectorized loading from shared memory due to + // potential different data types of inputs and shared memory tensor. + if (is_vectorize) { + reduction_scheduler_utils::sharedMemoryConsumerVectorization( + smem_consumers, rparams->unroll_factor_inner_reduction); + } + + // Remove dummy outputs as they can inadvertently affect CA positions + for (auto output : dummy_outputs) { + fusion->removeOutput(output); + } + inlineMost(); +} +} // namespace inner_outer_multi_wave +} // namespace nvfuser diff --git a/csrc/scheduler/normalization_inner_outer_multi_wave.h b/csrc/scheduler/normalization_inner_outer_multi_wave.h new file mode 100644 index 00000000000..3b373189246 --- /dev/null +++ b/csrc/scheduler/normalization_inner_outer_multi_wave.h @@ -0,0 +1,31 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#pragma once + +#include +#include + +namespace nvfuser { +namespace inner_outer_multi_wave { +void getHeuristics( + ReductionParams* rparams, + const int64_t outer_dim_numel, + const int64_t inner_dim_numel, + const int64_t regs_buffer_size, + const int64_t smem_buffer_size, + const int64_t smem_overhead, + const size_t tmp_gmem_dtype_size, + const size_t vectorize_factor, + const int64_t hp_threads_per_block_min, + const int64_t hp_threads_per_block_max, + const bool project_to_input, + const PrimDataType index_type); + +void scheduleFusion(Fusion* fusion, const ReductionParams* rparams); +} // namespace inner_outer_multi_wave +} // namespace nvfuser diff --git a/csrc/scheduler/normalization_inner_outer_tma_ws.cpp b/csrc/scheduler/normalization_inner_outer_tma_ws.cpp new file mode 100644 index 00000000000..0da6b734435 --- /dev/null +++ b/csrc/scheduler/normalization_inner_outer_tma_ws.cpp @@ -0,0 +1,647 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#include +#include +#include +#include + +#include +namespace nvfuser { +namespace inner_outer_tma_warp_specialized { +void getHeuristics( + ReductionParams* rparams, + const int64_t outer_dim_numel, + const int64_t inner_dim_numel, + const int64_t regs_buffer_size, + const int64_t smem_buffer_size, + const int64_t smem_overhead, + const size_t tmp_gmem_dtype_size, + const size_t vectorize_factor, + const int64_t hp_threads_per_block_min, + const int64_t hp_threads_per_block_max, + const bool project_to_input, + const PrimDataType index_type) { + rparams->tma_warp_specialized = true; + rparams->project_persistent_buffers = project_to_input; + rparams->cparams.index_type = index_type; + const auto dev_prop = at::cuda::getCurrentDeviceProperties(); + const int64_t device_multiprocessor_count = + (int64_t)dev_prop->multiProcessorCount; + // Parameters for inner reduction: + // Reduction dim: inner_vect, inner_batch, bdimx and bdimy + // Iteration dim: gdimy + + // Parameters for outer reduction: + // Reduction dim: bdimy + // Iteration dim: vectorization_factor_outer, bdimx, gdimy + struct InnerOuterParams { + int64_t inner_vect = -1; + int64_t inner_batch = -1; + int64_t bdimx = -1; + int64_t bdimy = -1; + int64_t bdimz = -1; + int64_t gdimy = -1; + int64_t tmp_gmem_write_vect = -1; + int64_t vectorization_factor_outer = -1; + int64_t threads_per_block = -1; + // derived metrics for sorting + int64_t warps_per_sm = -1; + int64_t required_register_per_thread = -1; + int64_t available_register_per_thread = -1; + + void verify() { + NVF_ERROR(inner_vect != -1, "inner_vect is not set."); + NVF_ERROR(inner_batch != -1, "inner_batch is not set."); + NVF_ERROR(bdimx != -1, "bdimx is not set."); + NVF_ERROR(bdimy != -1, "bdimy is not set."); + NVF_ERROR(gdimy != -1, "gdimy is not set."); + NVF_ERROR(tmp_gmem_write_vect != -1, "tmp_gmem_write_vect is not set."); + NVF_ERROR( + vectorization_factor_outer != -1, + "vectorization_factor_outer is not set."); + } + std::string toString() const { + std::stringstream ss; + ss << "inner_vect: " << inner_vect << ", inner_batch: " << inner_batch + << ", bdimx: " << bdimx << ", bdimy: " << bdimy << ", bdimz: " << bdimz + << ", gdimy: " << gdimy + << ", tmp_gmem_write_vect: " << tmp_gmem_write_vect + << ", vectorization_factor_outer: " << vectorization_factor_outer + << ", threads_per_block: " << threads_per_block + << ", warps_per_sm: " << warps_per_sm + << ", required_register_per_thread: " << required_register_per_thread + << ", available_register_per_thread: " + << available_register_per_thread; + return ss.str(); + } + }; + + // Set a minimum workload for each thread to take advantage of low + // intra-threads communication cost. + // Tuned for layer_norm backward on A100, still works fine on H100. + auto get_minimum_batch = [&]() -> int64_t { + if (inner_dim_numel >= 3072l) { + if (outer_dim_numel <= 2048l && inner_dim_numel == 3072l) { + return 3l; + } else { + return 4l; + } + } else if (inner_dim_numel >= 2048l) { + return 2l; + } + return 1l; + }; + + // Estimate register usage per thread based on buffer size. + // Assuming a constant register overhead for non-buffer related usage, + // and all the register buffers are stored in registers. + auto get_estimated_register_usage = [&](int64_t batch_mul_vect) { + int64_t persistent_buffer_size = + regs_buffer_size / inner_dim_numel * batch_mul_vect; + int64_t estimated_register_count = + persistent_buffer_size / scheduler_utils::bytes_per_register + + scheduler_utils::register_overhead; + return std::min( + estimated_register_count, scheduler_utils::max_registers_per_thread); + }; + + // The inner reduction part of the kernel also does a partial outer reduction + // and stores the partial results in tmp gmem and then reloaded to finish the + // outer reduciton. This function set the vectorization factor for write and + // and read of the partial outer reduction result. + // For write to tmp gmem, follows vectorization factor of inner reduction + // but don't exceed 16 bytes. + // For read from tmp gmem, since the paralelization is changed, a different + // vectorization factor is used to optimize the + // number of reaductions per thread. + auto get_outer_reduction_buffer_vect_factor = [&](int64_t inner_vect) { + constexpr int64_t max_gmem_vect_access_bytes = 16; + const int64_t max_tmp_gmem_vect_factor = std::min( + max_gmem_vect_access_bytes / (int64_t)tmp_gmem_dtype_size, inner_vect); + int64_t tmp_gmem_write_vect = max_tmp_gmem_vect_factor; + const int64_t workload_per_thread = inner_dim_numel >= 4096 ? 4l : 2l; + int64_t vectorization_factor_outer = + std::min(workload_per_thread, max_tmp_gmem_vect_factor); + return std::make_pair(tmp_gmem_write_vect, vectorization_factor_outer); + }; + + // In the outer reduction part of the kernel, inner and outer dims are + // parallelized as: + // --- inner dim: vect, bdimx, gdimy ---- + // --- outer dim: bdimy ----------------- + // This function splits the threads_per_block into bdimx and bdimy using: + // bdimx = ceilDiv(inner_dim_numel / vect, gdimy) + // bdimy = threads_per_block / bdimx + auto get_bdimx_bdimy = [&](int64_t threads_per_block, + int64_t vectorization_factor_outer, + int64_t gdimy) { + // For widely used hidden sizes, threads_per_block has factor of 8, roundup + // to increase the probability of bdimx * bdimy == threads_per_block. + int64_t bdimx = scheduler_utils::roundUpPow2Or8( + ceilDiv(inner_dim_numel / vectorization_factor_outer, gdimy)); + // if still not divisible, e.g. threads_per_block = 256, bdimx = 40. + // increase bdimx to make it divisible. Under worst case, bdimx equals to + // threads_per_block. + while (threads_per_block % bdimx) { + bdimx = std::min(bdimx + 8, threads_per_block); + } + // Set OuterParams Reduction dim: bdimy. + int64_t bdimy = threads_per_block / bdimx; + NVF_ERROR( + bdimy * bdimx == threads_per_block, + " threads_per_block must be divisible by bdimx and bdimy."); + return std::make_pair(bdimx, bdimy); + }; + + // Get the heuristics given vectorization factor and threads per block + auto get_heuristics_given_vect_threads = [&](int64_t vect_factor, + int64_t threads_per_block) { + InnerOuterParams iop; + // (1) inner reduction + // Reduction dim: inner_batch, threads_per_block, vect_factor + // Iteration dim: gdimy + iop.inner_vect = vect_factor; + iop.threads_per_block = threads_per_block; + iop.inner_batch = + ceilDiv(inner_dim_numel / iop.inner_vect, iop.threads_per_block); + iop.gdimy = device_multiprocessor_count; + + // (2) outer reduction + // Iteration dim: gdimy, bdimx, vectorization_factor_outer + // Reduction dim: bdimy + std::tie(iop.tmp_gmem_write_vect, iop.vectorization_factor_outer) = + get_outer_reduction_buffer_vect_factor(iop.inner_vect); + auto [bdimx, bdimy] = get_bdimx_bdimy( + threads_per_block, iop.vectorization_factor_outer, iop.gdimy); + iop.bdimx = bdimx; + iop.bdimy = bdimy; + // (3) Derived metrics warps_per_sm and register usage for sorting + iop.warps_per_sm = ceilDiv(iop.threads_per_block, dev_prop->warpSize) * + iop.gdimy / device_multiprocessor_count; + iop.available_register_per_thread = + getRegPerThreadGivenThreadsPerSM(dev_prop->warpSize * iop.warps_per_sm); + iop.required_register_per_thread = + get_estimated_register_usage(iop.inner_vect * iop.inner_batch); + return iop; + }; + + // Use the maximum vectorization factor + const int64_t vect_factor = (int64_t)vectorize_factor; + + // Set a reasonable range for threads per block based on the number of + // elements in the inner dimension after vectorization. + // Start from 128 or a smaller number if inner dim is small. + const int64_t after_vect = inner_dim_numel / vect_factor; + const int64_t batch_min = get_minimum_batch(); + int64_t threads_per_block_min = hp_threads_per_block_min; + threads_per_block_min = std::min(threads_per_block_min, after_vect); + threads_per_block_min = scheduler_utils::roundUpPow2(threads_per_block_min); + + // star max threads per block from min threads per block + int64_t threads_per_block_max = threads_per_block_min; + // increase to cover the whole inner dim + threads_per_block_max = + std::max(threads_per_block_max, ceilDiv(after_vect, batch_min)); + // round up to power of 2 + threads_per_block_max = scheduler_utils::roundUpPow2(threads_per_block_max); + // don't go beyond the maximum threads per block + threads_per_block_max = + std::min(threads_per_block_max, hp_threads_per_block_max); + + // Store all the possible heuristics based on different threads per block. + // Vectorizaton is fixed at the maximum value. + std::vector iop_candidates; + for (auto threads_per_block = threads_per_block_max; + threads_per_block >= threads_per_block_min; + threads_per_block /= 2) { + iop_candidates.emplace_back( + get_heuristics_given_vect_threads(vect_factor, threads_per_block)); + } + + // Sort the heuristics based on the register usage and occupancy. + std::stable_sort( + iop_candidates.begin(), + iop_candidates.end(), + [](const InnerOuterParams& a, const InnerOuterParams& b) { + // If a thread can use more registers than required, there is a high + // chance that it can avoid register spilling and compiler can optimize + // for better instruction level parallelism. + int64_t extra_regs_a = + a.available_register_per_thread - a.required_register_per_thread; + int64_t extra_regs_b = + b.available_register_per_thread - b.required_register_per_thread; + if (extra_regs_a > 0 && extra_regs_b < 0) { + return true; + } else if (extra_regs_a < 0 && extra_regs_b > 0) { + return false; + } + // High occupancy provides better threads level parallelism. + // 25% is sufficient since ILP is high due to persistent batch sizes + // which is equivalent to unrolling inner dim. + if (a.warps_per_sm != b.warps_per_sm && + (a.warps_per_sm < 16 || b.warps_per_sm < 16)) { + return a.warps_per_sm > b.warps_per_sm; + } + // Tie breaker, smaller threads_per_block to reduce communication + // overhead + return a.threads_per_block < b.threads_per_block; + }); + + // Pick the best heuristic + auto iop = iop_candidates.front(); + rparams->block_dim_inner_reduction_extra = ParallelType::TIDy; + rparams->combined_split_grid_inner_dim = + iop.vectorization_factor_outer * iop.bdimx * iop.gdimy < inner_dim_numel; + rparams->static_bdimx = true; + rparams->static_bdimy = true; + iop.bdimz = ceilDiv( + ceilDiv(ceilDiv(inner_dim_numel / iop.inner_vect, iop.bdimx), iop.bdimy), + iop.inner_batch); + NVF_ERROR(iop.bdimz == 1, "bdimz must be 1."); + + // check all the parameters in InnerOuterParams are set. + iop.verify(); + + rparams->persistent_kernel = true; + rparams->fastest_dim = true; + rparams->combined_inner_outer = true; + // tmp_gmem is the intermediate result of outer reduction, its dtype is float, + // so the maximum vectorization factor is 4. + rparams->vectorization_factor_outer = iop.vectorization_factor_outer; + rparams->vectorization_factor_tmp_gmem_write = iop.tmp_gmem_write_vect; + rparams->cparams.maxrregcount = iop.available_register_per_thread; + rparams->unroll_factor_inner_reduction = iop.inner_vect; + rparams->batches_per_block_inner_reduction = iop.inner_batch; + rparams->block_dim_inner_reduction = ParallelType::TIDx; + rparams->vectorize_inner_reduction = iop.inner_vect > 1; + rparams->split_grid_dim_iter_dom_outer = true; + rparams->grid_dim_iter_dom = ParallelType::BIDy; + + rparams->lparams = LaunchParams( + LaunchParams::UNINITIALIZED_VAL, + iop.gdimy, + LaunchParams::UNINITIALIZED_VAL, + iop.bdimx, + iop.bdimy, + LaunchParams::UNINITIALIZED_VAL); + + rparams->tag = "TMA Warp Specialized Persistent Heuristic.\n"; + + if (isDebugDumpEnabled(DebugDumpOption::SchedulerDebug)) { + debug() << "\n===== Combined InnerOuter Reduction Stats ========\n" + << "outer_dim_numel: " << outer_dim_numel << "\n" + << "inner_dim_numel: " << inner_dim_numel << "\n" + << "regs_buffer_size: " << regs_buffer_size << "\n" + << "smem_buffer_size: " << smem_buffer_size << "\n" + << "smem_overhead: " << smem_overhead << "\n" + << "vectorize_factor_input: " << iop.inner_vect << "\n" + << "vectorization_factor_tmp_gmem_write: " + << iop.tmp_gmem_write_vect << "\n" + << "vectorization_factor_outer: " << iop.vectorization_factor_outer + << "\n" + << "multiple_reds_per_blk: " << rparams->multiple_reds_per_blk + << "\n" + << "warps_per_sm: " << iop.warps_per_sm << "\n" + << "gdimy: " << iop.gdimy << "\n" + << "block(" << (iop.bdimx) << ", " << iop.bdimy << ", " << 1 << ")"; + debug() << rparams->toString() << std::endl; + } +} + +void scheduleOuterReduction( + Fusion* fusion, + const ReductionParams* rparams, + const std::vector& outer_reduction_tvs, + std::vector& cached_gmem, + std::vector& cached_gmem_reload, + std::vector& outer_reference_tvs, + std::unordered_set& boundaryNodesSet) { + auto mergeReductionOrIterDomains = [](TensorView* tv, bool mergeReduction) { + int prev_i = -1; + for (int i = static_cast(tv->nDims()) - 1; i >= 0; i--) { + if (mergeReduction == tv->axis(i)->isReduction()) { + if (prev_i == -1) { + prev_i = i; + } else { + tv->merge(i, prev_i); + prev_i = i; + } + } + } + }; + for (auto& outer_reduction_tv : outer_reduction_tvs) { + // Similar to the inner reduction, we need to reorder the outer reduction tv + // when there are view operations. + if (!ir_utils::getViewOps(fusion).empty()) { + // Reorder reference_tv after propagating the view operation. This will + // reorder for better merging. + outer_reduction_tv->reorder( + scheduler_utils::domainReorderAsLogicalMap(outer_reduction_tv)); + } + + // merge tensorview to [reduction, iteraiton] domains + mergeReductionOrIterDomains(outer_reduction_tv, true); + mergeReductionOrIterDomains(outer_reduction_tv, false); + + // First-stage of outer reduction + outer_reduction_tv->split(0, rparams->lparams.gdimy()); + + TensorView* partialResult = outer_reduction_tv->rFactor({0}); + partialResult->cacheBefore(); + partialResult->setMemoryType(MemoryType::Global); + TensorView* partialResultReload = partialResult->cacheAfter(); + + boundaryNodesSet.insert(partialResultReload); + cached_gmem.emplace_back(partialResult); + cached_gmem_reload.emplace_back(partialResultReload); + + // Second-stage of outer reduction + // reduction domain, [I1/TIDy, TIDy] + outer_reduction_tv->split(0, rparams->lparams.bdimy()); + outer_reduction_tv->axis(1)->parallelize(ParallelType::TIDy); + // iteration domain, [BIDy, TIDx, Vect] + int axisID = -1; + if (rparams->vectorization_factor_outer > 1) { + outer_reduction_tv->split(axisID, rparams->vectorization_factor_outer); + outer_reduction_tv->axis(axisID--)->parallelize(ParallelType::Vectorize); + } + + if (rparams->lparams.bdimx() > 1) { + outer_reduction_tv->split(axisID, rparams->lparams.bdimx()); + outer_reduction_tv->axis(axisID--)->parallelize(ParallelType::TIDx); + } + + if (rparams->combined_split_grid_inner_dim) { + outer_reduction_tv->split( + axisID, NamedScalar::getParallelDim(ParallelType::BIDy)); + } + + outer_reduction_tv->axis(axisID--)->parallelize(ParallelType::BIDy); + + auto outer_reference_tv = + reduction_scheduler_utils::sortAndRFactor(outer_reduction_tv); + outer_reference_tvs.emplace_back(outer_reference_tv); + } +} + +void scheduleFusion(Fusion* fusion, const ReductionParams* rparams) { + FusionGuard fg(fusion); + + // Grab the reduction, input, and output tensor views. dummy_outputs are + // helper tensors for persistent buffer projection. + std::vector dummy_outputs, cached_inputs, reduction_tvs, + smem_consumers; + std::vector> cached_outputs; + normalization_scheduler_utils::beforeSchedule( + fusion, + rparams, + dummy_outputs, + cached_inputs, + reduction_tvs, + smem_consumers, + cached_outputs); + + // split reduction_tvs into inner and outer reduction_tvs + std::vector inner_reduction_tvs, outer_reduction_tvs; + for (auto tv : reduction_tvs) { + if (scheduler_utils::isFastestDimReduction(tv)) { + inner_reduction_tvs.emplace_back(tv); + } else { + outer_reduction_tvs.emplace_back(tv); + } + } + NVF_ERROR( + !inner_reduction_tvs.empty(), + "schedulePersistentKernelInnerOuter is called but no inner reduction is found."); + NVF_ERROR( + !outer_reduction_tvs.empty(), + "schedulePersistentKernelInnerOuter is called but no outer reduction is found."); + + // schedule inner reduction, only schedule the first inner reduction tv, + // then will be propagated to other inner reduction tvs. + TensorView* inner_reference_tv = + normalization_scheduler_utils::scheduleReductionGeneral( + fusion, + rparams, + inner_reduction_tvs, + SchedulerType::InnerOuterPersistent); + + // schedule outer reduction, schedule all the outer reduction tvs since we + // need to store the intermediate results. + std::vector cached_gmem; + std::vector cached_gmem_reload; + std::vector outer_reference_tvs; + std::unordered_set boundaryNodesSet; + scheduleOuterReduction( + fusion, + rparams, + outer_reduction_tvs, + cached_gmem, + cached_gmem_reload, + outer_reference_tvs, + boundaryNodesSet); + + // Propagate inner reduction and outer reductions + for (auto output : dummy_outputs) { + fusion->addOutput(output); + } + + // Collect tvs loaded with TMA, they require special scheduling. + std::vector tma_load_tvs; + if (rparams->tma_warp_specialized) { + for (auto tv : smem_consumers) { + auto smem_tv = ir_utils::getSoleProducerTv(tv); + if (std::find(tma_load_tvs.begin(), tma_load_tvs.end(), smem_tv) == + tma_load_tvs.end()) { + tma_load_tvs.emplace_back(smem_tv); + } + } + } + + const bool is_unroll_or_vectorization = rparams->isUnrolled(); + const bool is_vectorize = + rparams->vectorize_inner_reduction || rparams->vectorize_iter_dom; + const bool is_outer_grid_persistence = rparams->persistent_kernel && + rparams->cross_grid_inner_reduction && !rparams->fastest_dim; + + // Propagate transformations for inner reduction. + // Two steps are used since tma tvs are scheduled differently. + // Step-1, propagate iteration domain in inner reduction. + // Step-2, propagate reduction domain in inner reduction. + if (rparams->tma_warp_specialized) { + // Find the axis that splits the reduction domain and iteration domain. + int first_redu_axis = -1; + int n_dims = (int)inner_reference_tv->nDims(); + for (auto i = 0; i < n_dims; i++) { + if (inner_reference_tv->axis(i)->isReduction() || + inner_reference_tv->axis(i)->isRFactorProduct()) { + first_redu_axis = i; + break; + } + } + + // Step-1, propagate iteration domain in inner reduction. + // outer_reference_tvs are excluded since they are already scheduled + // with a different pattern for the final step of outer reduciton. + if (first_redu_axis > 0) { + TransformPropagator propagator(inner_reference_tv, first_redu_axis - 1); + std::vector all_tvs_except = ir_utils::allTvsExcept( + fusion, {outer_reference_tvs.begin(), outer_reference_tvs.end()}); + SetSelector selector({all_tvs_except.begin(), all_tvs_except.end()}); + MaxLogicalDomainInfoSpanningTree(inner_reference_tv, &selector) + .traverse(&propagator); + } + + // Step-2, propagate reduction domain in inner reduction. + // (a) Tvs in boundaryNodesSet are excluded since they should follow outer + // reduction pattern. + // (b) TMA tvs are excluded since they require special scheduling. + // (3) Excluding tma tvs breaks the propagation path from inner reduction tv + // to cached_gmem which stores the results of the first-stage of outer + // reduction. The solution is adding a dummy output to link them. The same + // trick is used when projecting persistent buffers to inputs. + auto inner_reduction_input = + ir_utils::getSoleProducerTv(inner_reference_tv); + for (auto tv : cached_gmem) { + // T1(smem) --> T2 (l) --> T3 = OuterRedu(T2) --> T4(cached_gmem) + // outer_reduction_input: T2 + // partial_outer_redu_tv: T3 + auto partial_outer_redu_tv = ir_utils::getSoleProducerTv(tv); + auto outer_reduction_input = + ir_utils::getSoleProducerTv(partial_outer_redu_tv); + auto dummy_output = add(inner_reduction_input, outer_reduction_input); + fusion->addOutput(dummy_output); + dummy_outputs.emplace_back(dummy_output); + } + + // Tvs requiring special scheduling + std::unordered_set special_tvs{ + tma_load_tvs.begin(), tma_load_tvs.end()}; + for (auto tv : boundaryNodesSet) { + if (special_tvs.count(tv) == 0) { + special_tvs.emplace(tv); + } + } + TransformPropagator propagator(inner_reference_tv); + std::vector all_tvs_except_cache = ir_utils::allTvsExcept( + fusion, {special_tvs.begin(), special_tvs.end()}); + SetSelector selector( + {all_tvs_except_cache.begin(), all_tvs_except_cache.end()}); + MaxLogicalDomainInfoSpanningTree(inner_reference_tv, &selector) + .traverse(&propagator); + } else { + reduction_scheduler_utils::propagateTransformation( + inner_reference_tv, boundaryNodesSet); + } + reduction_scheduler_utils::propagateRFactor( + inner_reference_tv, inner_reduction_tvs[0], inner_reduction_tvs); + + // parallelization propagation + const auto& selected_tvs_inner = + scheduler_utils::getAllTvsFrom(inner_reduction_tvs, boundaryNodesSet); + const auto& unroll_vectorizable_cached_tvs = + reduction_scheduler_utils::getCachedTvsToUnrollOrVectorize( + inner_reference_tv, is_vectorize, cached_inputs, cached_outputs); + reduction_scheduler_utils::propagateParallelization( + inner_reduction_tvs[0], + inner_reference_tv, + is_unroll_or_vectorization, + is_outer_grid_persistence, + inner_reduction_tvs, + unroll_vectorizable_cached_tvs, + {selected_tvs_inner.begin(), selected_tvs_inner.end()}); + + // Propagate outer reduction. Each outer reduction is connected with its + // cached_gmem and output, since we added all the cached_gmem to the + // boundaryNodesSet, the transformation from one outer reduction can't + // propagate to other outer reductions due to the cutoff at + // boundaryNodesSet. Thus, we need a loop to initiate the propagation from + // each outer reduction. Don't allow parallelization propagation goes + // through cached_gmem, see issue 246. + for (long unsigned int i = 0; i < outer_reference_tvs.size(); i++) { + const auto& selected_tvs_outer = scheduler_utils::getAllTvsFrom( + {outer_reduction_tvs[i]}, {cached_gmem[i]}); + reduction_scheduler_utils::propagateTransformation( + outer_reference_tvs[i], boundaryNodesSet); + const auto& unroll_vectorizable_cached_tvs = + reduction_scheduler_utils::getCachedTvsToUnrollOrVectorize( + outer_reference_tvs[i], + is_vectorize, + cached_inputs, + cached_outputs); + reduction_scheduler_utils::propagateParallelization( + outer_reduction_tvs[i], + outer_reference_tvs[i], + is_unroll_or_vectorization, + is_outer_grid_persistence, + outer_reduction_tvs, + unroll_vectorizable_cached_tvs, + {selected_tvs_outer.begin(), selected_tvs_outer.end()}); + } + + // Up to this point, the outer dimension of the TMA tv is scheduled + // the same way as the inner reduction tv. However, the inner dimension + // has not been scheduled yet. Since 1D TMA allows unrestricted load size, + // we can simply parallelize the entire inner dimension using bulk. + // Example: 2D tensor, [BIDy, S, | Bulk] + // Example: 1D tensor, [Bulk] + if (rparams->tma_warp_specialized) { + for (auto tv : tma_load_tvs) { + tv->axis(-1)->parallelize(ParallelType::Bulk); + } + } + + // special vectorization of temp gmem, vectorization_factor_tmp_gmem_write + // is guaranteed to be smaller or equal to input vectorization factor. + if (rparams->vectorization_factor_tmp_gmem_write > 1) { + for (auto tv : cached_gmem) { + NVF_ERROR( + rparams->vectorization_factor_tmp_gmem_write <= + rparams->unroll_factor_inner_reduction, + "vectorization factor of temp gmem write should be smaller than that of inner reduction.") + if (rparams->vectorization_factor_tmp_gmem_write < + rparams->unroll_factor_inner_reduction) { + tv->split(-1, rparams->vectorization_factor_tmp_gmem_write); + } + tv->axis(-1)->parallelize(ParallelType::Vectorize); + } + } + // vectorization propagate through propagateParallelization only works for + // input and output tensors. propagate vectorization to cached_gmem_reload + // directly from output tv using parallelizeAllLike. must propagate + // seperaely for different tvs as outer reductions are transformed + // seperately. + if (rparams->vectorization_factor_outer > 1) { + for (auto tv : cached_gmem_reload) { + auto output_tvs = ir_utils::outputTvsOf(tv); + NVF_ERROR( + !output_tvs.empty(), + "cached_gmem_reload should have at least one output tensor.") + scheduler_utils::parallelizeAllLike( + output_tvs[0], + -1, + {cached_gmem_reload.begin(), cached_gmem_reload.end()}, + {ParallelType::Vectorize}); + } + } + + // Needs special handling of vectorized loading from shared memory due to + // potential different data types of inputs and shared memory tensor. + if (is_vectorize) { + reduction_scheduler_utils::sharedMemoryConsumerVectorization( + smem_consumers, rparams->unroll_factor_inner_reduction); + } + + // Remove dummy outputs as they can inadvertently affect CA positions + for (auto output : dummy_outputs) { + fusion->removeOutput(output); + } + inlineMost(); +} +} // namespace inner_outer_tma_warp_specialized +} // namespace nvfuser diff --git a/csrc/scheduler/normalization_inner_outer_tma_ws.h b/csrc/scheduler/normalization_inner_outer_tma_ws.h new file mode 100644 index 00000000000..f3d05586508 --- /dev/null +++ b/csrc/scheduler/normalization_inner_outer_tma_ws.h @@ -0,0 +1,31 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#pragma once + +#include +#include + +namespace nvfuser { +namespace inner_outer_tma_warp_specialized { +void getHeuristics( + ReductionParams* rparams, + const int64_t outer_dim_numel, + const int64_t inner_dim_numel, + const int64_t regs_buffer_size, + const int64_t smem_buffer_size, + const int64_t smem_overhead, + const size_t tmp_gmem_dtype_size, + const size_t vectorize_factor, + const int64_t hp_threads_per_block_min, + const int64_t hp_threads_per_block_max, + const bool project_to_input, + const PrimDataType index_type); + +void scheduleFusion(Fusion* fusion, const ReductionParams* rparams); +} // namespace inner_outer_tma_warp_specialized +} // namespace nvfuser diff --git a/csrc/scheduler/normalization_inner_outer_utils.cpp b/csrc/scheduler/normalization_inner_outer_utils.cpp new file mode 100644 index 00000000000..bcaaa3db131 --- /dev/null +++ b/csrc/scheduler/normalization_inner_outer_utils.cpp @@ -0,0 +1,301 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#include +#include +#include +#include +#include + +#include + +namespace nvfuser { +namespace inner_outer_utils { + +int64_t roundUpSharedMemory( + int64_t tv_buffer_size, + int64_t data_type_size, + int64_t vectorize_factor, + int64_t threads_per_block_min, + int64_t threads_per_block_max, + int64_t threads_per_block_step) { + int64_t dim_size = tv_buffer_size / data_type_size; + int64_t after_vect = dim_size / vectorize_factor; + int64_t max_smem = 0; + for (int64_t threads_per_block = threads_per_block_min; + threads_per_block <= threads_per_block_max; + threads_per_block += threads_per_block_step) { + int64_t n_batch = ceilDiv(after_vect, threads_per_block); + max_smem = std::max( + max_smem, + n_batch * vectorize_factor * threads_per_block * data_type_size); + } + return max_smem; +} + +std::vector getOuterBroadcastTvs( + Fusion* fusion, + const std::vector& reduction_tvs) { + // set reference broadcast mask using the first inner reduction tv + std::vector ref_broadcast_mask; + for (auto tv : reduction_tvs) { + if (scheduler_utils::isFastestDimReduction(tv)) { + const auto& logical = tv->getLogicalDomain(); + ref_broadcast_mask.reserve(logical.size()); + for (const auto i : arange(logical.size())) { + ref_broadcast_mask.push_back(!logical.at(i)->isReduction()); + } + break; + } + } + NVF_ERROR(!ref_broadcast_mask.empty(), "ref_broadcast_mask is empty!"); + + // find the broadcast tensor whose broadcast mask is same to the reference + std::vector outer_broadcast_tvs; + for (auto tv : fusion->allTvs()) { + if (std::any_of( + tv->getLoopDomain().begin(), + tv->getLoopDomain().end(), + [](IterDomain* id) { return id->isBroadcast(); })) { + if (auto bcast = dynamic_cast(tv->definition())) { + if (bcast->getBroadcastDimFlags() == ref_broadcast_mask) { + outer_broadcast_tvs.emplace_back(tv); + } + } + } + } + return outer_broadcast_tvs; +} + +int64_t partialOuterReductionBufferSize( + const std::vector& reduction_tvs, + SchedulerRuntimeInfo& runtime_info) { + int64_t partial_reduction_buffer_size = 0; + for (auto buffer : reduction_tvs) { + if (scheduler_utils::isFastestDimReduction(buffer)) { + continue; + } + int64_t buffer_size = -1; + for (auto id : buffer->getLogicalDomain()) { + if (id->isReduction() || id->isBroadcast()) { + continue; + } + auto id_size = runtime_info.expressionEvaluator().evaluate(id->extent()); + NVF_ERROR(id_size.hasValue(), "Could not infer persistent buffer size."); + if (buffer_size == -1) { + buffer_size = id_size.as(); + } else { + buffer_size *= id_size.as(); + } + } + buffer_size = (buffer_size == -1) ? 0 + : buffer_size * + (int64_t)dataTypeSize(buffer->getDataType().value(), + runtime_info.getIndexType()); + partial_reduction_buffer_size += buffer_size; + } + return partial_reduction_buffer_size; +} + +std::vector sortProjectableBufferInputs( + const std::vector& projectable_buffer_inputs, + const std::vector& outer_broadcast_tvs) { + // mark whether the buffer is used by outer broadcast tensors + std::unordered_map is_used_by_outer_bcast; + for (auto buffer : projectable_buffer_inputs) { + is_used_by_outer_bcast[buffer] = std::any_of( + outer_broadcast_tvs.begin(), + outer_broadcast_tvs.end(), + [&buffer](TensorView* tv) { + return DependencyCheck::isDependencyOf(buffer, tv); + }); + } + + // sort based on [is_used_by_outer_bcast] + std::vector sorted_buffer = projectable_buffer_inputs; + std::sort( + sorted_buffer.begin(), + sorted_buffer.end(), + [&](TensorView* a, TensorView* b) { + return !is_used_by_outer_bcast[a] && is_used_by_outer_bcast[b]; + }); + return sorted_buffer; +} + +PersistentBufferStorageParams getPersistentBufferStorageParams( + Fusion* fusion, + SchedulerRuntimeInfo& runtime_info, + HeuristicDataCache* data_cache, + const std::vector& reduction_tvs, + const int64_t vectorize_factor, + const int64_t threads_per_block_min, + const int64_t threads_per_block_max) { + FUSER_PERF_SCOPE( + "normalization_inner_outer::getPersistentBufferStorageParams"); + + PersistentBufferStorageParams buffer_params; + + auto persistent_buffer_info_entry = + HeuristicDataCacheEntry( + data_cache, [&fusion]() { + return std::make_unique( + scheduler_utils::persistentBuffers(fusion)); + }); + + auto& persistent_buffer_info = persistent_buffer_info_entry.get(); + + auto persistent_buffer_size_info = scheduler_utils::persistentBufferSize( + fusion, runtime_info, persistent_buffer_info, data_cache); + + // Project to inputs when there is at least one outer broadcast tensor or + // projected persistent buffer size is smaller. When projecting to inputs, the + // outer broadcast tensor is reused in the loop over the iteration dimension, + // test shows it is faster than the non-projected version which requires + // reload from gmem for each iteration. + // Note: in current use cases (layer norm bwd and RMS norm bwd), there are + // outer broadcast tvs and always project to inputs. + // Warp specialized persistent kernel always cache inputs in shared memory, + // should project to inputs. + const auto& outer_broadcast_tvs = getOuterBroadcastTvs(fusion, reduction_tvs); + bool skip_check_buffer_size = !outer_broadcast_tvs.empty() || + isOptionEnabled(EnableOption::WarpSpecializedNormalization); + normalization_scheduler_utils::BufferProjectionStrategy project_strategy = + normalization_scheduler_utils::isProjectBufferToInputs( + fusion, + runtime_info, + reduction_tvs, + persistent_buffer_info, + persistent_buffer_size_info, + InnerOuterPersistentKernelScheduler::schedulerType(), + /*can_use_smem_persistent=*/true, + !skip_check_buffer_size); + + buffer_params.project_to_input = + (project_strategy == + normalization_scheduler_utils::BufferProjectionStrategy:: + ProjectToInputs); + + const auto dev_prop = at::cuda::getCurrentDeviceProperties(); + int64_t smem_overhead = scheduler_utils::getSharedMemoryOverheadPerBlock( + fusion, reduction_tvs, threads_per_block_max); + int64_t available_smem = + (int64_t)dev_prop->sharedMemPerMultiprocessor - smem_overhead; + int64_t available_regs = scheduler_utils::register_file_size_56k; + buffer_params.smem_overhead = smem_overhead; + + // (1) Use both register and shared memory. + // Start with all the cached input buffers in shared memory, they are loaded + // from global memory uses async copy which bypasses L1 cache. Outer reduction + // buffers are used to accumulate partial results of the outer reduction. They + // are not loaded from global memory and requires frequent read/write. So, + // they are always stored in registers. + // TODO: We may also move outer reduction buffers to shared + // memory to avoid segmentation when there are many outer reductions and + // hardware has larger shared memory, but these applications are rare, so this + // is not considered here. + auto buffers = buffer_params.project_to_input + ? persistent_buffer_info.projectable_buffer_inputs + : persistent_buffer_info.persistent_buffers; + + // Add buffers that are inputs to the fusion. They are not included in + // projectable_buffer_inputs since they are not projectable. + if (buffer_params.project_to_input) { + for (auto tv : persistent_buffer_info.persistent_buffers) { + if (tv->isFusionInput()) { + buffers.push_back(tv); + } + } + } + + // Needs to use rounded shared memory size to avoid over usage. + // key : buffer tv. + // val : register size and rounded shared memory size + std::unordered_map> + required_size_regs_smem_map; + int64_t total_smem_buffer_size = 0; + for (auto buffer : buffers) { + int64_t buffer_size_regs = scheduler_utils::getPersistentBufferSizeOfTensor( + buffer, runtime_info, persistent_buffer_info); + int64_t buffer_size_smem = roundUpSharedMemory( + buffer_size_regs, + dataTypeSize(buffer->getDataType().value()), + vectorize_factor, + threads_per_block_min, + threads_per_block_max, + dev_prop->warpSize); + required_size_regs_smem_map[buffer] = + std::make_pair(buffer_size_regs, buffer_size_smem); + total_smem_buffer_size += buffer_size_smem; + } + buffer_params.smem_buffer_size = total_smem_buffer_size; + buffer_params.regs_buffer_size = + partialOuterReductionBufferSize(reduction_tvs, runtime_info); + if (buffer_params.regs_buffer_size <= available_regs && + buffer_params.smem_buffer_size <= available_smem) { + buffer_params.smem_persistent_buffers = buffers; + buffer_params.has_enough_regs_and_smem = true; + return buffer_params; + } + + // Moving outer reduction buffer to shared memory is not considered yet, + // set to false if the outer reduction buffer size exceeds the register size. + if (buffer_params.regs_buffer_size > available_regs) { + buffer_params.has_enough_regs_and_smem = false; + return buffer_params; + } + + // (2) Now, shared memory is overused, move some buffers to registers. + // (2.1) Sort the candidate persistent buffers. No need to sort since the + // sorting is based on whether the buffer is used by outer broadcast tensors. + if (!outer_broadcast_tvs.empty()) { + buffers = sortProjectableBufferInputs(buffers, outer_broadcast_tvs); + } + // (2.2) Before this loop, all cached input buffers are in shared memory. Move + // buffer from shared memory to register. + int64_t n_regs_buffer = -1; + const int n_buffers = (int)buffers.size(); + for (int i = 0; i < n_buffers; i++) { + auto current_tv = buffers[i]; + auto [buffer_size_regs, buffer_size_smem] = + required_size_regs_smem_map.at(current_tv); + buffer_params.regs_buffer_size += buffer_size_regs; + buffer_params.smem_buffer_size -= buffer_size_smem; + + // The first-i buffers to are moved from shared memory to register + // If both the register buffer size and shared memory buffer size are within + // the allowable limit, we found a good configuration. + if (buffer_params.regs_buffer_size <= available_regs && + buffer_params.smem_buffer_size <= available_smem) { + n_regs_buffer = i + 1; + break; + } + // Register buffer size exceeds the limit, can't move more to registers. + // Break the loop. + if (buffer_params.regs_buffer_size > available_regs) { + break; + } + } + + // n_regs_buffer > 0 indicats a good configuration is found. + // The first n_regs_buffer buffers are stored in registers and last [n_buffers + // - n_regs_buffer] are stored in shared memory. + if (n_regs_buffer > 0) { + buffer_params.has_enough_regs_and_smem = true; + auto n_smem_buffer = n_buffers - n_regs_buffer; + buffer_params.smem_persistent_buffers.reserve(n_smem_buffer); + for (int i = 0; i < n_smem_buffer; i++) { + buffer_params.smem_persistent_buffers.emplace_back( + buffers[n_buffers - 1 - i]); + } + } else { + buffer_params.has_enough_regs_and_smem = false; + } + return buffer_params; +} + +} // namespace inner_outer_utils +} // namespace nvfuser diff --git a/csrc/scheduler/normalization_inner_outer_utils.h b/csrc/scheduler/normalization_inner_outer_utils.h new file mode 100644 index 00000000000..49b0699a00c --- /dev/null +++ b/csrc/scheduler/normalization_inner_outer_utils.h @@ -0,0 +1,98 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#pragma once + +#include +#include + +namespace nvfuser { +class SchedulerRuntimeInfo; +class HeuristicDataCache; + +namespace inner_outer_utils { +// The roundup is due to the fact that the shared memory buffer is allocated +// as: ceilDiv(dim_size / vectorize_factor, threads_per_block). +// Let after_vect = dim_size / vectorize_factor; +// n_batch = ceilDiv(after_vect, threads_per_block); +// Then the shared memory buffer size is n_batch * vectorize_factor * +// threads_per_block * data_type_size. This function returns the maximum +// possible shared memory buffer size considering all possible block sizes. +int64_t roundUpSharedMemory( + int64_t tv_buffer_size, + int64_t data_type_size, + int64_t vectorize_factor, + int64_t threads_per_block_min, + int64_t threads_per_block_max, + int64_t threads_per_block_step); + +// Return the broadcast tvs that are broadcast to the iteration dimensions of +// the inner reduction tv. These tvs are reused in the loop over the iteration +// dimension. This reuse reduced the number loads from gmem and this tensor +// is likely the first candidate to be moved to shared memory when the register +// space runs low. +std::vector getOuterBroadcastTvs( + Fusion* fusion, + const std::vector& reduction_tvs); + +// Size of buffers storing intermediate outer reduction results +// TODO: check if we can directly start with [buffer_size = 1] +int64_t partialOuterReductionBufferSize( + const std::vector& reduction_tvs, + SchedulerRuntimeInfo& runtime_info); + +// Decide where to store persistent buffers. +// By default, they reside in registers. +// If register space runs low but there's ample shared memory, +// move one or more buffers to shared memory until the register space is +// sufficient. +struct PersistentBufferStorageParams { + // representing buffers that are stored in shared memory, other buffers are + // stored in registers. + std::vector smem_persistent_buffers; + + // Total number of bytes occupied by all persistent buffers stored in shared + // memory. + int64_t smem_buffer_size = -1; + + // Total number of bytes occupied by all persistent buffers stored in + // registers. + int64_t regs_buffer_size = -1; + + // Additional shared memory usage per block that is not associated with + // persistent buffers. This includes memory for driver overhead and workspace + // for reductions. + int64_t smem_overhead = -1; + + // Flag indicating whether there are sufficient registers and shared memory + // available to accommodate all persistent buffers as required for efficient + // execution. + bool has_enough_regs_and_smem = false; + + // Flag indicating whether the persistent buffers are recomputed using inputs. + bool project_to_input = false; +}; +PersistentBufferStorageParams getPersistentBufferStorageParams( + Fusion* fusion, + SchedulerRuntimeInfo& runtime_info, + HeuristicDataCache* data_cache, + const std::vector& reduction_tvs, + const int64_t vectorize_factor, + const int64_t threads_per_block_min, + const int64_t threads_per_block_max); + +// Prioritize keeping buffers used by outer broadcast tensors to shared memory +// because: +// (1) They are reused in every iteration of the outer loop, has lower IO. +// (2) Load occurs before the outer loop. Temporary register usage won't +// increase register pressure since the loop is the high-pressure region. +std::vector sortProjectableBufferInputs( + const std::vector& projectable_buffer_inputs, + const std::vector& outer_broadcast_tvs); + +} // namespace inner_outer_utils +} // namespace nvfuser From 5368ed00b65042f8797401412ea1f2d8a6a4cf5e Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Mon, 21 Apr 2025 15:51:52 -0700 Subject: [PATCH 33/68] Clean up multi-GPU python test fixtures (#4284) Move fixtures to conftest.py according to https://docs.pytest.org/en/7.1.x/reference/fixtures.html#conftest-py-sharing-fixtures-across-multiple-files Make setup_default_process_group a common fixture --- .../multidevice/{fixtures.py => conftest.py} | 24 ++++++++++++++++++ .../python/multidevice/test_communication.py | 4 --- tests/python/multidevice/test_deepseek_v3.py | 20 +-------------- tests/python/multidevice/test_dtensor.py | 25 ++----------------- tests/python/multidevice/test_matmul.py | 3 --- tests/python/multidevice/test_multidevice.py | 3 --- tests/python/multidevice/test_overlap.py | 3 --- .../multidevice/test_transformer_engine.py | 15 +---------- 8 files changed, 28 insertions(+), 69 deletions(-) rename tests/python/multidevice/{fixtures.py => conftest.py} (71%) diff --git a/tests/python/multidevice/fixtures.py b/tests/python/multidevice/conftest.py similarity index 71% rename from tests/python/multidevice/fixtures.py rename to tests/python/multidevice/conftest.py index 71da16c2f14..7bd82c3c6ee 100644 --- a/tests/python/multidevice/fixtures.py +++ b/tests/python/multidevice/conftest.py @@ -5,6 +5,7 @@ import nvfuser import pytest import torch +import torch.distributed as dist class MultideviceTest: @@ -65,3 +66,26 @@ def multidevice_test(): yield fixture # Sync all ranks after each test for isolation. fixture.communicator.barrier() + + +# Set up the default process group for torch APIs like +# dist.device_mesh.init_device_mesh. +# +# This fixture is used by multi-GPU tests that use torch.distributed. +# +# I use "session" instead of "module" because +# https://github.com/pytorch/pytorch/issues/119196 reported race conditions +# when reinitializing process groups. +@pytest.fixture(scope="session") +def setup_default_process_group(): + communicator = nvfuser.Communicator.instance() + + # The default port as used by https://github.com/pytorch/pytorch/blob/45a8b5682eb69d865cbf68c7f2f689b56b4efd53/torch/csrc/distributed/c10d/TCPStore.hpp#L51. + dist.init_process_group( + backend="nccl", + init_method="tcp://localhost:29500", + world_size=communicator.size(), + rank=communicator.rank(), + ) + yield + dist.destroy_process_group() diff --git a/tests/python/multidevice/test_communication.py b/tests/python/multidevice/test_communication.py index 5adf06f6882..5b0d097fe8e 100644 --- a/tests/python/multidevice/test_communication.py +++ b/tests/python/multidevice/test_communication.py @@ -5,14 +5,10 @@ import pytest import torch -import fixtures import nvfuser from nvfuser import DataType, FusionDefinition -multidevice_test = fixtures.multidevice_test - - @pytest.mark.mpi def test_allgather(multidevice_test): d = multidevice_test.size diff --git a/tests/python/multidevice/test_deepseek_v3.py b/tests/python/multidevice/test_deepseek_v3.py index 0bbfe4d1a75..284563c398a 100644 --- a/tests/python/multidevice/test_deepseek_v3.py +++ b/tests/python/multidevice/test_deepseek_v3.py @@ -2,7 +2,6 @@ # All rights reserved. # SPDX-License-Identifier: BSD-3-Clause -import nvfuser import pytest import transformers import torch @@ -16,23 +15,6 @@ ) -# Set up the default process group for torch APIs like -# dist.device_mesh.init_device_mesh. -@pytest.fixture(scope="module") -def setup_process_group(): - communicator = nvfuser.Communicator.instance() - - # The default port as used by https://github.com/pytorch/pytorch/blob/45a8b5682eb69d865cbf68c7f2f689b56b4efd53/torch/csrc/distributed/c10d/TCPStore.hpp#L51. - dist.init_process_group( - backend="nccl", - init_method="tcp://localhost:29500", - world_size=communicator.size(), - rank=communicator.rank(), - ) - yield - dist.destroy_process_group() - - @contextmanager def default_tensor_type(dtype=torch.float32, device="cpu"): # Save @@ -55,7 +37,7 @@ def default_tensor_type(dtype=torch.float32, device="cpu"): # http://nv/eCm). I consider this a one-off, but please let me know if this # error becomes consistent. @pytest.mark.mpi -def test_transformer_layer(setup_process_group): +def test_transformer_layer(setup_default_process_group): config = transformers.AutoConfig.from_pretrained( "deepseek-ai/deepseek-v3", trust_remote_code=True ) diff --git a/tests/python/multidevice/test_dtensor.py b/tests/python/multidevice/test_dtensor.py index f1dfc38f62f..51644b7bc0d 100644 --- a/tests/python/multidevice/test_dtensor.py +++ b/tests/python/multidevice/test_dtensor.py @@ -2,7 +2,6 @@ # All rights reserved. # SPDX-License-Identifier: BSD-3-Clause -import fixtures import nvfuser import pytest import torch @@ -16,26 +15,6 @@ from typing import Callable, cast -multidevice_test = fixtures.multidevice_test - - -# Set up the default process group for torch APIs like -# dist.device_mesh.init_device_mesh. -@pytest.fixture(scope="module") -def setup_process_group(): - communicator = nvfuser.Communicator.instance() - - # The default port as used by https://github.com/pytorch/pytorch/blob/45a8b5682eb69d865cbf68c7f2f689b56b4efd53/torch/csrc/distributed/c10d/TCPStore.hpp#L51. - dist.init_process_group( - backend="nccl", - init_method="tcp://localhost:29500", - world_size=communicator.size(), - rank=communicator.rank(), - ) - yield - dist.destroy_process_group() - - class FusionDefinitionWrapper: def __init__(self, define_fusion: Callable[[FusionDefinition], None]): """Wraps a function that defines a fusion without `multidevice_schedule`.""" @@ -98,7 +77,7 @@ def __call__(self, in_dtensors: Iterable[DTensor]) -> list[DTensor]: @pytest.mark.mpi -def test_plus_one(setup_process_group, multidevice_test): +def test_plus_one(setup_default_process_group, multidevice_test): def define_fusion(fd: FusionDefinition): inp = fd.define_tensor((-1, -1), contiguity=False, dtype=DataType.Float) one = fd.define_scalar(1.0, dtype=DataType.Float) @@ -122,7 +101,7 @@ def define_fusion(fd: FusionDefinition): @pytest.mark.mpi -def test_linear(setup_process_group, multidevice_test): +def test_linear(setup_default_process_group, multidevice_test): @dataclass class LinearConfig: def __init__(self, num_devices: int, batch: int, sequence: int, hidden: int): diff --git a/tests/python/multidevice/test_matmul.py b/tests/python/multidevice/test_matmul.py index 3002c4e6caf..76ce1939edf 100644 --- a/tests/python/multidevice/test_matmul.py +++ b/tests/python/multidevice/test_matmul.py @@ -5,12 +5,9 @@ import pytest import torch -import fixtures import nvfuser from nvfuser import DataType, FusionDefinition -multidevice_test = fixtures.multidevice_test - # Avoid doing this when possible. This test started to exist before nvFuser # supports DID loop split. As a result of that, the weight in this test has to be diff --git a/tests/python/multidevice/test_multidevice.py b/tests/python/multidevice/test_multidevice.py index dd3c3fb5877..6489a0a7691 100644 --- a/tests/python/multidevice/test_multidevice.py +++ b/tests/python/multidevice/test_multidevice.py @@ -7,13 +7,10 @@ from enum import Enum, auto from torch.nn.attention import SDPBackend -import fixtures import nvfuser from nvfuser import DataType, FusionDefinition from nvfuser.testing.utils import create_sdpa_rng_tensors, define_sdpa_rng_state -multidevice_test = fixtures.multidevice_test - @pytest.mark.mpi def test_sizes_and_ranks(multidevice_test): diff --git a/tests/python/multidevice/test_overlap.py b/tests/python/multidevice/test_overlap.py index 0ad770e022c..34850477376 100644 --- a/tests/python/multidevice/test_overlap.py +++ b/tests/python/multidevice/test_overlap.py @@ -6,12 +6,9 @@ import torch import os -import fixtures import nvfuser from nvfuser import DataType, FusionDefinition, CommunicatorBackend -multidevice_test = fixtures.multidevice_test - class OverlapAGMatmulStreamOutermost(FusionDefinition): def __init__(self, m, k, n, s, num_devices, communication_backend): diff --git a/tests/python/multidevice/test_transformer_engine.py b/tests/python/multidevice/test_transformer_engine.py index db3046705be..14110e147c6 100644 --- a/tests/python/multidevice/test_transformer_engine.py +++ b/tests/python/multidevice/test_transformer_engine.py @@ -58,19 +58,6 @@ class Parallelism(Enum): SEQUENCE_PARALLEL = auto() -@pytest.fixture(scope="module") -def setup_process_group(mpi_test) -> None: - # The default port as used by https://github.com/pytorch/pytorch/blob/45a8b5682eb69d865cbf68c7f2f689b56b4efd53/torch/csrc/distributed/c10d/TCPStore.hpp#L51. - dist.init_process_group( - backend="nccl", - init_method="tcp://localhost:29500", - world_size=mpi_test.size, - rank=mpi_test.rank, - ) - yield - dist.destroy_process_group() - - # This benchmark is instrumented with cudaProfilerStart/Stop. Therefore, one # can collect stats of the first few non-warmup benchmark iterations using # ```bash @@ -94,7 +81,7 @@ def setup_process_group(mpi_test) -> None: ids=["nonoverlap", "overlap"], ) def test_transformer_layer( - setup_process_group, + setup_default_process_group, monkeypatch, benchmark, compute_type: ComputeType, From fb9b9567044fedc87006ab9ef66284e369385beb Mon Sep 17 00:00:00 2001 From: Ryan Spring Date: Mon, 21 Apr 2025 16:50:35 -0700 Subject: [PATCH 34/68] Split insertion_info into Pipeline and WarpSpecialized parts (#4275) This PR is a step to separating `CircularBufferInserter` into two separate variants that handle warp specialized and pipeline circular buffering separately. * Change `CircularBufferLoopNestInspector::run` to return pipeline and warp-specialized in two separate `InsertionInfo`. * Process `pipeline` before `warp-specialized` because the circular buffer for-loops must be handled from inner to outer-most and we can only nest pipeline circular buffers inside of warp-specialized `IfThenElse`. * Enforce that only four warp-specialized for-loops occur in the fusion. They will be handled by separate warps inside of a single `AsyncWarp` warp group. --- csrc/device_lower/pass/circular_buffer.cpp | 119 ++++++++++++++++++++- 1 file changed, 115 insertions(+), 4 deletions(-) diff --git a/csrc/device_lower/pass/circular_buffer.cpp b/csrc/device_lower/pass/circular_buffer.cpp index c970283daa2..3683aaba450 100644 --- a/csrc/device_lower/pass/circular_buffer.cpp +++ b/csrc/device_lower/pass/circular_buffer.cpp @@ -1144,13 +1144,112 @@ class IsCircularBufferLoadLoop : public kir::IrVisitor { bool result_ = false; }; +namespace { + +bool isWarpSpecialized(ForLoop* loop) { + return std::holds_alternative( + GpuLower::current() + ->circularBufferInfo() + .getCircularBufferOptionsFor(loop->iter_domain()) + .type); +} + +} // namespace + // Traverse lowered loop-nests and find all circular buffer loops and // associated load expressions. class CircularBufferLoopNestInspector : private kir::IrVisitor { public: - static InsertionInfo run(const std::vector& exprs) { + static std::pair run( + const std::vector& exprs) { CircularBufferLoopNestInspector inspector(exprs); - return inspector.insertion_info_; + + // InsertionInfo holds all circular buffer for-loops. Split it into warp + // specialized and pipeline circular buffers. Enforce that we can only nest + // pipeline circular buffering inside of warp-specialization. + + // Get WarpSpecialized InsertionInfo + InsertionInfo ws_info; + int64_t inner_most_ws_position = -1; + for (auto&& [cb_loop, cb_exprs] : inspector.insertion_info_) { + if (!isWarpSpecialized(cb_loop)) { + continue; + } + ws_info[cb_loop] = cb_exprs; + inner_most_ws_position = std::max( + inner_most_ws_position, inspector.loop_position_.at(cb_loop)); + } + + // WarpSpecialized circular buffering pads the thread block size by 128 + // threads. This is to support register sharing, which shares registers from + // four warps to another four warps. Thus, we can have four warps running + // concurrently in AsyncWarp. Each warp can launch an asynchronous operation + // with mbarrier completion mechanism such as TMA Load and Blackwell UTCMMA. + // + // if (Select AsyncWarp) { + // if (Select Warp 0 AND elect-sync()) { + // do-something + // } else if (Select Warp 1 AND elect-sync()) { + // do-something + // } else if (Select Warp 2 AND elect-sync()) { + // do-something + // } else if (Select Warp 3 AND elect-sync()) { + // do-something + // } + // } + NVF_ERROR( + ws_info.size() <= 4, + "At most four for-loops can run concurrently inside the AsyncWarp.\n", + "Detected ", + ws_info.size(), + " WarpSpecialized for-loops."); + + // Get Pipeline InsertionInfo + InsertionInfo pipeline_info; + for (auto&& [cb_loop, cb_exprs] : inspector.insertion_info_) { + if (isWarpSpecialized(cb_loop)) { + continue; + } + + // An example of WarpSpecialized circular buffer nested in Pipeline + // circular buffer. + // * Register sharing would fail because of the return in the AsyncLoop. + // * This scenario is not actively tested, so prohibit it until a valid + // use-case occurs. + // + // warp-specialized mbarrier init + // for (prologue) { + // load something for Prologue + // } + // + // for (main) { + // load something for Main + // if (AsyncWarp) { + // launch async + // maybe return for register sharing + // } else { + // compute something for ComputeWarp + // } + // compute something for Main + // } + // + // for (epilogue) { + // if (AsyncWarp) { + // launch async + // maybe return for register sharing + // } else { + // compute something + // } + // compute something for Epilogue + // } + // warp-specialized mbarrier inval + NVF_ERROR( + inspector.loop_position_.at(cb_loop) > inner_most_ws_position, + "Warp Specialization cannot be nested in Pipeline circular buffering!"); + pipeline_info[cb_loop] = cb_exprs; + } + + return {ws_info, pipeline_info}; } private: @@ -1186,6 +1285,10 @@ class CircularBufferLoopNestInspector : private kir::IrVisitor { validateCircularBufferLoop(circular_buffer_loop); + auto cb_loop_it = + std::find(for_loops_.begin(), for_loops_.end(), circular_buffer_loop); + loop_position_[circular_buffer_loop] = + std::distance(for_loops_.begin(), cb_loop_it); insertion_info_[circular_buffer_loop].push_back(expr); } @@ -1211,6 +1314,8 @@ class CircularBufferLoopNestInspector : private kir::IrVisitor { loop->toString()); } + // Map circular buffer loop to its position in the for_loop_ stack. + std::unordered_map loop_position_; InsertionInfo insertion_info_; }; @@ -1728,8 +1833,14 @@ kir::TensorIndex* TmaCircularBufferInfo::getTensorIndex(const Expr* expr) { } std::vector CircularBufferPass::run(const std::vector& exprs) { - InsertionInfo insertion_info = CircularBufferLoopNestInspector::run(exprs); - return CircularBufferInserter::run(exprs, insertion_info); + auto&& [ws_insertion_info, pipeline_insertion_info] = + CircularBufferLoopNestInspector::run(exprs); + // Process circular buffer for-loops from inner to outer-most. + // Pipeline must come before WarpSpecialized. We cannot nest WarpSpecialized + // inside of Pipeline circular buffering. + std::vector result_exprs = + CircularBufferInserter::run(exprs, pipeline_insertion_info); + return CircularBufferInserter::run(result_exprs, ws_insertion_info); } } // namespace nvfuser From 0b2f5a85cd1ca6f3d41b3abb407361687fb6a756 Mon Sep 17 00:00:00 2001 From: Ryan Spring Date: Mon, 21 Apr 2025 18:50:52 -0700 Subject: [PATCH 35/68] Create separate CircularBufferInserter for WarpSpecialized and Pipeline (#4280) This PR separates `CircularBufferInserter` into two variants: `WarpSpecializedCircularBufferInserter` and `PipelineCircularBufferInserter`. Stacked on: #4275 This refactor is a foundation for more complex changes required for supporting Blackwell UTCMMA and Ping-Pong WarpSpecialization. Further enhancements are not planned for pipeline circular buffering. --- csrc/device_lower/pass/circular_buffer.cpp | 333 ++++++++++++--------- 1 file changed, 198 insertions(+), 135 deletions(-) diff --git a/csrc/device_lower/pass/circular_buffer.cpp b/csrc/device_lower/pass/circular_buffer.cpp index 3683aaba450..02bb05050ac 100644 --- a/csrc/device_lower/pass/circular_buffer.cpp +++ b/csrc/device_lower/pass/circular_buffer.cpp @@ -1334,10 +1334,51 @@ void getAllocInTrivialLoop(ForLoop* fl, std::unordered_set& output) { } } +// Create something like below: +// for (int i = 0; i < prefetch + 1; ++i) { +// mbarrier::arrive(mbarrier0[stage + i]]); +// mbarrier::arrive(mbarrier1[stage + i]); +// ... +// } +// where mbarrierX[stage + i] is the X-th WAR mbarrier for stage i. +// +// This is needed because we prefetch data in circular buffering, and we +// need to make sure the initial prefetches are not blocked by the +// non-existing WAR hazards. +ForLoop* createArrivesForWar(ForLoop* circular_buffer_loop) { + const auto& opt = + GpuLower::current()->circularBufferInfo().getCircularBufferOptionsFor( + circular_buffer_loop->iter_domain()); + auto circular_buffer_tvs = + GpuLower::current()->circularBufferInfo().getCircularBufferTvs( + circular_buffer_loop->iter_domain()); + VectorOfUniqueEntries mbarriers; + for (auto tv : circular_buffer_tvs) { + auto ldst = dynamic_cast(tv->definition()); + NVF_ERROR(ldst != nullptr); + auto it = GpuLower::current()->mbarrierMap().find(ldst); + if (it == GpuLower::current()->mbarrierMap().end()) { + continue; + } + mbarriers.pushBack(it->second); + } + auto prefetch_loop = ir_utils::createRangeLoop(opt.prefetch + 1); + for (auto mbarrier : mbarriers) { + auto mbarrier_to_arrive = IrBuilder::create( + mbarrier, + SimplifyingIrBuilder::addExpr( + prefetch_loop->indexOrStartIfTrivial(), opt.stage)); + auto prefetch = IrBuilder::create( + /*state=*/nullptr, mbarrier_to_arrive); + prefetch_loop->body().push_back(prefetch); + } + return prefetch_loop; +} + } // namespace -// Apply circular buffering transformations -class CircularBufferInserter : private kir::ExprMutator { +// Apply warp specialized circular buffering transformations +class WarpSpecializedCircularBufferInserter : private kir::ExprMutator { public: // When there exist multiple circular buffer loops, apply // transformations to inner-most loops first. A single ExprMutator @@ -1347,14 +1388,15 @@ class CircularBufferInserter : private kir::ExprMutator { InsertionInfo insertion_info) { std::vector inserted_exprs = exprs; while (!insertion_info.empty()) { - CircularBufferInserter inserter(inserted_exprs, insertion_info); + WarpSpecializedCircularBufferInserter inserter( + inserted_exprs, insertion_info); inserted_exprs = inserter.exprs_; } return inserted_exprs; } private: - CircularBufferInserter( + WarpSpecializedCircularBufferInserter( const std::vector& exprs, InsertionInfo& insertion_info) : insertion_info_(insertion_info) { @@ -1380,143 +1422,24 @@ class CircularBufferInserter : private kir::ExprMutator { return; } - auto has_cp_async_bulk = std::any_of( - it->second.begin(), it->second.end(), ir_utils::isCpAsyncBulk); - bool use_warp_specialization = std::holds_alternative( GpuLower::current() ->circularBufferInfo() .getCircularBufferOptionsFor(loop->iter_domain()) .type); - if (use_warp_specialization) { - NVF_ERROR( - std::all_of( - it->second.begin(), it->second.end(), ir_utils::isCpAsyncBulk), - "In order to use warp specialization, all buffers must be loaded by TMA"); - int64_t insertion_position = - GpuLower::current() - ->circularBufferInfo() - .getCircularBufferInsertionPosition(loop->iter_domain()); - insertTmaWarpSpecialized(loop, it->second, insertion_position); - } else if (has_cp_async_bulk) { - insertTmaPipelined(loop, it->second); - } else { - insert(loop, it->second); - } - processed_loop_ = loop; - insertion_info_.erase(loop); - } - - bool hasPrefetch(ForLoop* circular_buffer_loop) { - int64_t prefetch_distance = + NVF_ERROR(use_warp_specialization); + NVF_ERROR( + std::all_of( + it->second.begin(), it->second.end(), ir_utils::isCpAsyncBulk), + "In order to use warp specialization, all buffers must be loaded by TMA"); + int64_t insertion_position = GpuLower::current() ->circularBufferInfo() - .getCircularBufferOptionsFor(circular_buffer_loop->iter_domain()) - .prefetch; - return prefetch_distance > 0; - } - - // Create something like below: - // for (int i = 0; i < prefetch + 1; ++i) { - // mbarrier::arrive(mbarrier0[stage + i]]); - // mbarrier::arrive(mbarrier1[stage + i]); - // ... - // } - // where mbarrierX[stage + i] is the X-th WAR mbarrier for stage i. - // - // This is needed because we prefetch data in circular buffering, and we - // need to make sure the initial prefetches are not blocked by the - // non-existing WAR hazards. - ForLoop* createArrivesForWar(ForLoop* circular_buffer_loop) { - const auto& opt = - GpuLower::current()->circularBufferInfo().getCircularBufferOptionsFor( - circular_buffer_loop->iter_domain()); - auto circular_buffer_tvs = - GpuLower::current()->circularBufferInfo().getCircularBufferTvs( - circular_buffer_loop->iter_domain()); - VectorOfUniqueEntries mbarriers; - for (auto tv : circular_buffer_tvs) { - auto ldst = dynamic_cast(tv->definition()); - NVF_ERROR(ldst != nullptr); - auto it = GpuLower::current()->mbarrierMap().find(ldst); - if (it == GpuLower::current()->mbarrierMap().end()) { - continue; - } - mbarriers.pushBack(it->second); - } - auto prefetch_loop = ir_utils::createRangeLoop(opt.prefetch + 1); - for (auto mbarrier : mbarriers) { - auto mbarrier_to_arrive = IrBuilder::create( - mbarrier, - SimplifyingIrBuilder::addExpr( - prefetch_loop->indexOrStartIfTrivial(), opt.stage)); - auto prefetch = IrBuilder::create( - /*state=*/nullptr, mbarrier_to_arrive); - prefetch_loop->body().push_back(prefetch); - } - return prefetch_loop; - } - - static bool usesMBarrierForWAR(ForLoop* circular_buffer_loop) { - return GpuLower::current() - ->circularBufferInfo() - .getCircularBufferOptionsFor(circular_buffer_loop->iter_domain()) - .usesMBarrierForWAR(); - } - - void insertTmaPipelined( - ForLoop* circular_buffer_loop, - const std::vector& loads) { - // Arrive on the WAR mbarriers to let the prefetching start. - if (usesMBarrierForWAR(circular_buffer_loop)) { - auto prefetch_loop = createArrivesForWar(circular_buffer_loop); - registerInsertBefore(circular_buffer_loop, prefetch_loop); - } - - // Prologue loop: - // - launch only - // - arrive_expect_tx and tma load operations - if (hasPrefetch(circular_buffer_loop)) { - // If there is no prefetch, then we don't need a prologue loop. - ForLoop* prologue_loop = CloneTmaCircularBufferLoopAndInsertSync::clone( - circular_buffer_loop, - loads, - CircularBufferLoopStage::Prolog, - /*insertion_position=*/1); - registerInsertBefore(circular_buffer_loop, prologue_loop); - } - - // Main loop: - // - Launch and wait - // - arrive_expect_tx, tma load operations, and mbarrier_wait - ForLoop* main_loop = CloneTmaCircularBufferLoopAndInsertSync::clone( - circular_buffer_loop, - loads, - CircularBufferLoopStage::Main, - /*insertion_position=*/1); - registerReplace(circular_buffer_loop, main_loop); + .getCircularBufferInsertionPosition(loop->iter_domain()); + insertTmaWarpSpecialized(loop, it->second, insertion_position); - if (!hasPrefetch(circular_buffer_loop)) { - // If there is no prefetch, then we don't need a epilogue loop. - return; - } - - // We can use exclude argument in - // CloneTmaCircularBufferLoopAndInsertSync clone to avoid - // duplicating allocations if main loop is trivial. - std::unordered_set expressions_allocated_in_main_loop; - getAllocInTrivialLoop(main_loop, expressions_allocated_in_main_loop); - - // Epilogue loop: - // - wait only - // - mbarrier_wait - ForLoop* epilogue_loop = CloneTmaCircularBufferLoopAndInsertSync::clone( - circular_buffer_loop, - loads, - CircularBufferLoopStage::Epilog, - /*insertion_position=*/1, - expressions_allocated_in_main_loop); - registerInsertAfter(circular_buffer_loop, epilogue_loop); + processed_loop_ = loop; + insertion_info_.erase(loop); } void insertTmaWarpSpecialized( @@ -1610,6 +1533,145 @@ class CircularBufferInserter : private kir::ExprMutator { registerReplace(circular_buffer_loop, warp_dispatch_ite); } + private: + InsertionInfo& insertion_info_; + ForLoop* processed_loop_ = nullptr; +}; + +// Apply pipeline circular buffering transformations +class PipelineCircularBufferInserter : private kir::ExprMutator { + public: + // When there exist multiple circular buffer loops, apply + // transformations to inner-most loops first. A single ExprMutator + // pass can only process one loop. + static std::vector run( + const std::vector& exprs, + InsertionInfo insertion_info) { + std::vector inserted_exprs = exprs; + while (!insertion_info.empty()) { + PipelineCircularBufferInserter inserter(inserted_exprs, insertion_info); + inserted_exprs = inserter.exprs_; + } + return inserted_exprs; + } + + private: + PipelineCircularBufferInserter( + const std::vector& exprs, + InsertionInfo& insertion_info) + : insertion_info_(insertion_info) { + size_t num_circular_buffer_loops = insertion_info.size(); + traverseAndInsert(exprs); + NVF_ERROR(processed_loop_ != nullptr); + NVF_ERROR(insertion_info.size() == num_circular_buffer_loops - 1); + } + + using kir::ExprMutator::handle; + + void handle(ForLoop* loop) final { + kir::ExprMutator::handle(loop); + + // If another loop is already taken care of, no more loop should + // be done in the same pass + if (processed_loop_ != nullptr) { + return; + } + + auto it = insertion_info_.find(loop); + if (it == insertion_info_.end()) { + return; + } + + bool use_warp_specialization = std::holds_alternative( + GpuLower::current() + ->circularBufferInfo() + .getCircularBufferOptionsFor(loop->iter_domain()) + .type); + NVF_ERROR(!use_warp_specialization); + + auto has_cp_async_bulk = std::any_of( + it->second.begin(), it->second.end(), ir_utils::isCpAsyncBulk); + if (has_cp_async_bulk) { + insertTmaPipelined(loop, it->second); + } else { + insert(loop, it->second); + } + + processed_loop_ = loop; + insertion_info_.erase(loop); + } + + bool hasPrefetch(ForLoop* circular_buffer_loop) { + int64_t prefetch_distance = + GpuLower::current() + ->circularBufferInfo() + .getCircularBufferOptionsFor(circular_buffer_loop->iter_domain()) + .prefetch; + return prefetch_distance > 0; + } + + static bool usesMBarrierForWAR(ForLoop* circular_buffer_loop) { + return GpuLower::current() + ->circularBufferInfo() + .getCircularBufferOptionsFor(circular_buffer_loop->iter_domain()) + .usesMBarrierForWAR(); + } + + void insertTmaPipelined( + ForLoop* circular_buffer_loop, + const std::vector& loads) { + // Arrive on the WAR mbarriers to let the prefetching start. + if (usesMBarrierForWAR(circular_buffer_loop)) { + auto prefetch_loop = createArrivesForWar(circular_buffer_loop); + registerInsertBefore(circular_buffer_loop, prefetch_loop); + } + + // Prologue loop: + // - launch only + // - arrive_expect_tx and tma load operations + if (hasPrefetch(circular_buffer_loop)) { + // If there is no prefetch, then we don't need a prologue loop. + ForLoop* prologue_loop = CloneTmaCircularBufferLoopAndInsertSync::clone( + circular_buffer_loop, + loads, + CircularBufferLoopStage::Prolog, + /*insertion_position=*/1); + registerInsertBefore(circular_buffer_loop, prologue_loop); + } + + // Main loop: + // - Launch and wait + // - arrive_expect_tx, tma load operations, and mbarrier_wait + ForLoop* main_loop = CloneTmaCircularBufferLoopAndInsertSync::clone( + circular_buffer_loop, + loads, + CircularBufferLoopStage::Main, + /*insertion_position=*/1); + registerReplace(circular_buffer_loop, main_loop); + + if (!hasPrefetch(circular_buffer_loop)) { + // If there is no prefetch, then we don't need a epilogue loop. + return; + } + + // We can use exclude argument in + // CloneTmaCircularBufferLoopAndInsertSync clone to avoid + // duplicating allocations if main loop is trivial. + std::unordered_set expressions_allocated_in_main_loop; + getAllocInTrivialLoop(main_loop, expressions_allocated_in_main_loop); + + // Epilogue loop: + // - wait only + // - mbarrier_wait + ForLoop* epilogue_loop = CloneTmaCircularBufferLoopAndInsertSync::clone( + circular_buffer_loop, + loads, + CircularBufferLoopStage::Epilog, + /*insertion_position=*/1, + expressions_allocated_in_main_loop); + registerInsertAfter(circular_buffer_loop, epilogue_loop); + } + void insert(ForLoop* circular_buffer_loop, const std::vector& loads) { NVF_ERROR( !usesMBarrierForWAR(circular_buffer_loop), @@ -1839,8 +1901,9 @@ std::vector CircularBufferPass::run(const std::vector& exprs) { // Pipeline must come before WarpSpecialized. We cannot nest WarpSpecialized // inside of Pipeline circular buffering. std::vector result_exprs = - CircularBufferInserter::run(exprs, pipeline_insertion_info); - return CircularBufferInserter::run(result_exprs, ws_insertion_info); + PipelineCircularBufferInserter::run(exprs, pipeline_insertion_info); + return WarpSpecializedCircularBufferInserter::run( + result_exprs, ws_insertion_info); } } // namespace nvfuser From e697ec9291601a4b937ff8776e628402a3ed1f99 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Tue, 22 Apr 2025 09:20:27 -0700 Subject: [PATCH 36/68] Add mutex guard to protect data race on options. (#4287) Credit goes to @csarofeen for finding the segfault cause by https://github.com/NVIDIA/Fuser/blob/5368ed00b65042f8797401412ea1f2d8a6a4cf5e/csrc/id_model/utils.h#L163-L166 We have observed a reliable segfault coming from `MovePadTest.CascadePadCase2`. This PR preserves the existing behavior, by adding a mutex to guard data access on global options to protect us from racing. An alternative is to make options thread_local, but we needed to figure out how to inherit existing options on the forked threads. I added a note on that. --- csrc/options.cpp | 9 +++++---- csrc/options.h | 19 +++++++++++++++++++ 2 files changed, 24 insertions(+), 4 deletions(-) diff --git a/csrc/options.cpp b/csrc/options.cpp index 391919cc825..be72af22898 100644 --- a/csrc/options.cpp +++ b/csrc/options.cpp @@ -261,10 +261,11 @@ std::unordered_map> Options< namespace { -// These may need to be thread local, or their modifications may need to -// be protected by mutual exclusion for thread safety. At this -// moment, the correctness of modifying option values has to be -// guaranteed by the modifying code. +// Note: Make options thread_local. +// We want the behavior that new threads would inherit options from the *base* +// threads. We need to figure out how to automatically do that before switching +// to thread_local. For now we are using mutex to guard option access, which is +// necessary to avoid data racing. DebugDumpOptions active_dump_options; diff --git a/csrc/options.h b/csrc/options.h index c782a8cea14..b050a0a0199 100644 --- a/csrc/options.h +++ b/csrc/options.h @@ -11,6 +11,7 @@ #include #include +#include #include #include #include @@ -180,16 +181,31 @@ class Options { public: Options() : options_(getOptionsFromEnv()) {} + Options(const Options& other) { + std::lock_guard lock_other(other.mutex_); + options_ = other.options_; + } + + Options& operator=(const Options& other) { + std::lock_guard lock_other(other.mutex_); + std::lock_guard lock(mutex_); + options_ = other.options_; + return *this; + } + bool has(OptionEnum option) const { + std::lock_guard lock(mutex_); return options_.count(option); } bool hasAny() const { + std::lock_guard lock(mutex_); return !options_.empty(); } const std::vector& getArgs(OptionEnum option) const { NVF_ERROR(has(option), "Option not set"); + std::lock_guard lock(mutex_); return options_.at(option); } @@ -202,10 +218,12 @@ class Options { } void set(OptionEnum option_type, std::vector option = {}) { + std::lock_guard lock(mutex_); options_[option_type] = option; } void unset(OptionEnum option_type) { + std::lock_guard lock(mutex_); options_.erase(option_type); } @@ -214,6 +232,7 @@ class Options { protected: std::unordered_map> options_; + mutable std::mutex mutex_; }; //! Utility class to temporarily overrride the Enable options, From 5f9cfb0792e2024dc75204f5041a28cc1527af69 Mon Sep 17 00:00:00 2001 From: Liqiang Lu <116412316+liqiangxl@users.noreply.github.com> Date: Wed, 23 Apr 2025 00:52:55 -0400 Subject: [PATCH 37/68] fix register spills in thread local outer reduction (#4184) **Issue:** For outer reduction with small reduction dim, only thread local serial reduction is used. The reduction tensor is transformed as: `T3_l_float[iblockIdx.x17{..., iV14{4}, rS11{( ceilDiv(i0, 8) )}, rUS12{1}, rUR10{8}]` where `8` is the outer unroll factor and `4` is the inner vectorization factor. Then it is reordered as: `T3_l_float[..., rS11{( ceilDiv(i0, 8) )}, iUS18{1}, iV14{4}, rUS12{1}, rUR10{8}]` The required register array is `Array T2;` However, when `i0 is const`, e.g. `i0/2 = 16/8 = 2`, the tv is reordered as: `T3_l_float[..., iUS18{1}, iV14{4}, rS11{2}, rUS12{1}, rUR10{8}]` The required register array is `Array T2;` **Changes** In this PR, code is added to look for reduction tvs with `all reduction dimensions are constants and not parallelized by threads or blocks`, these tvs require a large register array size of `vect factor x unroll factor x serial loops`. To reduce register usage, the loop domain is further reordered as: `T3_l_float[..., iUS18{1}, rS11{2}, iV14{4}, rUS12{1}, rUR10{8}]`. **Influence** Fix #4172, no register spills. --- csrc/scheduler/reduction_utils.cpp | 45 +++++++++++++++++++++++-- csrc/scheduler/reduction_utils.h | 4 ++- tests/cpp/test_gpu_outer_reduction.cpp | 46 +++++++++++++++++++++++++- 3 files changed, 90 insertions(+), 5 deletions(-) diff --git a/csrc/scheduler/reduction_utils.cpp b/csrc/scheduler/reduction_utils.cpp index a8c7652d196..96c7f8650b3 100644 --- a/csrc/scheduler/reduction_utils.cpp +++ b/csrc/scheduler/reduction_utils.cpp @@ -303,8 +303,10 @@ TensorView* scheduleReductionTV( } } } - - auto reduction_rf_tv = sortAndRFactor(reduction_tv); + const bool is_non_persistent_outer_reduction = + !rparams->persistent_kernel && !rparams->fastest_dim; + auto reduction_rf_tv = + sortAndRFactor(reduction_tv, is_non_persistent_outer_reduction); // In the case of outer grid persistence, make sure the vectorized // domain placed at the innermost position. @@ -647,7 +649,9 @@ bool placedBefore(const IterDomain* id0, const IterDomain* id1) { } } // namespace -TensorView* sortAndRFactor(TensorView* reference_tv) { +TensorView* sortAndRFactor( + TensorView* reference_tv, + bool is_non_persistent_outer_reduction) { auto domain = reference_tv->getLoopDomain(); std::sort(domain.begin(), domain.end(), placedBefore); std::unordered_map reorder_map; @@ -659,6 +663,41 @@ TensorView* sortAndRFactor(TensorView* reference_tv) { reorder_map[old_i] = domain_pos.at(reference_tv->axis(old_i)); } reference_tv->reorder(reorder_map); + // For outer reduction, if an Id after vectorization Id is a constant + // serial Id, swap it with the vectorization Id to reduce register usage. + // For example, in a thread-local outer reduction, we want to transform: + // [..., iV{8}, rS{7}, rUS{1}, rUR{4}] + // to: + // [..., rS{7}, iV{8}, rUS{1}, rUR{4}] + // After change, each thread only needs to cache 8 × 4 elements instead of + // 8 × 7 × 4 elements. + // See https://github.com/NVIDIA/Fuser/issues/4172 for real examples. + if (is_non_persistent_outer_reduction) { + auto vect_iter = + std::find_if(domain.begin(), domain.end(), [](IterDomain* id) { + return id->getParallelType() == ParallelType::Vectorize; + }); + if (vect_iter != domain.end()) { + int64_t vect_id_pos = vect_iter - domain.begin(); + std::unordered_map reorder_map; + for (auto iter = vect_iter + 1; iter != domain.end(); iter++) { + if ((*iter)->getParallelType() == ParallelType::Serial && + (*iter)->extent()->isConstScalar()) { + int64_t id_pos = iter - domain.begin(); + reorder_map[id_pos] = vect_id_pos++; + } + } + // Although we support reordering multiple constant serial IDs after the + // vectorization ID, the current scheduler only emits one. It may be worth + // exploring performance implications if multiple such IDs are introduced + // in the future. + NVF_ERROR( + reorder_map.size() <= 1, + "Expect one constant serial Id after vectorization Id, but found ", + reorder_map.size()); + reference_tv->reorder(reorder_map); + } + } std::vector rfactor_axes; std::vector rfactor_axes_no_unswitch; diff --git a/csrc/scheduler/reduction_utils.h b/csrc/scheduler/reduction_utils.h index 9985ad83ee1..abb8253f636 100644 --- a/csrc/scheduler/reduction_utils.h +++ b/csrc/scheduler/reduction_utils.h @@ -101,7 +101,9 @@ NVF_API void propagateParallelization( // Rfactored axes are reductions bound to grid or blocks. If no axes are bound // to a grid or block dimension it will rfactor the r-unswitch dimension. // Reduction inliner expects an rfactored domain. -NVF_API TensorView* sortAndRFactor(TensorView* reference_tv); +NVF_API TensorView* sortAndRFactor( + TensorView* reference_tv, + const bool is_non_persistent_outer_reduction = false); // If project_to_inputs is true, take all projectable persistent buffers, // and move them to the inputs. Otherwise, try to project to their immediate diff --git a/tests/cpp/test_gpu_outer_reduction.cpp b/tests/cpp/test_gpu_outer_reduction.cpp index e4f8c0cc425..7646db171c0 100644 --- a/tests/cpp/test_gpu_outer_reduction.cpp +++ b/tests/cpp/test_gpu_outer_reduction.cpp @@ -2559,7 +2559,7 @@ TEST_F(OuterReductionTest, IterGroupedMultipleReductions) { } // Repro of https://github.com/NVIDIA/Fuser/pull/2766 -TEST_F(NVFuserTest, SmallOuterBlockReductionIssue2766) { +TEST_F(OuterReductionTest, SmallOuterBlockReductionIssue2766) { std::unique_ptr fusion_ptr = std::make_unique(); auto& fusion = *fusion_ptr; FusionGuard fg(&fusion); @@ -2595,4 +2595,48 @@ TEST_F(NVFuserTest, SmallOuterBlockReductionIssue2766) { testValidate(executor_cache.fusion(), outputs, args, __LINE__, __FILE__); } +TEST_F(OuterReductionTest, SimpleThreadLocalSerialReduction) { + auto fusion_ptr = std::make_unique(); + FusionGuard fg(fusion_ptr.get()); + Fusion& fusion = *fusion_ptr; + + std::vector shape{28, 8192, 128}; + + auto T0 = makeContigConcreteTensor(shape, DataType::BFloat16); + fusion.addInput(T0); + auto T1 = castOp(DataType::Float, T0); + auto T2 = sum(T1, {0}); + fusion.addOutput(T2); + + auto fusion_copy = fusion; + + auto options = at::TensorOptions().dtype(at::kBFloat16).device(at::kCUDA, 0); + auto at_t0 = at::randn(shape, options); + KernelArgumentHolder args = {at_t0}; + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + auto outputs = executor_cache.runFusionWithInputs(args); + + // If thread local reduction is used on the tested GPU, the reduction tv + // should be: [..., rS{7}, iV{x}, rUS{1}, rUR{x}] + auto runtime = executor_cache.getMostRecentKernelRuntime(); + for (auto& params : runtime->schedulerHeuristics()->heuristicsList()) { + if (!params->isA()) { + continue; + } + if (!params->as()->cross_block_inner_reduction) { + Fusion* scheduled_fusion = runtime->executors() + .back() + ->as() + ->compiledKernel() + ->kernel(); + auto redu_tv = scheduler_utils::getReductionTvs(scheduled_fusion).at(0); + EXPECT_TRUE(redu_tv->axis(-4)->isReduction()) + << "Expected redu tv is [..., rS{7}, iV{x}, rUS{1}, rUR{x}], got: " + << redu_tv->toString(); + } + } + + testValidate(&fusion_copy, outputs, {at_t0}, __LINE__, __FILE__); +} + } // namespace nvfuser From 24a5cc9d6f5cb432cdaa8c0d3bb9202928403a94 Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Tue, 22 Apr 2025 22:37:20 -0700 Subject: [PATCH 38/68] Prefer static local to static global (#4289) https://google.github.io/styleguide/cppguide.html#Static_and_Global_Variables --- csrc/options.cpp | 27 +++++++++------------------ 1 file changed, 9 insertions(+), 18 deletions(-) diff --git a/csrc/options.cpp b/csrc/options.cpp index be72af22898..33610d5b8fc 100644 --- a/csrc/options.cpp +++ b/csrc/options.cpp @@ -259,41 +259,32 @@ std::unordered_map> Options< return options; } -namespace { - -// Note: Make options thread_local. -// We want the behavior that new threads would inherit options from the *base* -// threads. We need to figure out how to automatically do that before switching -// to thread_local. For now we are using mutex to guard option access, which is -// necessary to avoid data racing. - -DebugDumpOptions active_dump_options; - -EnableOptions active_enable_options; - -DisableOptions active_disable_options; - -ProfilerOptions active_profiler_options; - -} // namespace - template <> Options& OptionsGuard::getCurOptions() { + // Note: Make options thread_local. + // We want the behavior that new threads would inherit options from the *base* + // threads. We need to figure out how to automatically do that before + // switching to thread_local. For now we are using mutex to guard option + // access, which is necessary to avoid data racing. + static DebugDumpOptions active_dump_options; return active_dump_options; } template <> Options& OptionsGuard::getCurOptions() { + static EnableOptions active_enable_options; return active_enable_options; } template <> Options& OptionsGuard::getCurOptions() { + static DisableOptions active_disable_options; return active_disable_options; } template <> Options& OptionsGuard::getCurOptions() { + static ProfilerOptions active_profiler_options; return active_profiler_options; } From 096b681fa68ce575aa44eb4d9086618016b7923b Mon Sep 17 00:00:00 2001 From: Priya Mishra <52657555+Priya2698@users.noreply.github.com> Date: Wed, 23 Apr 2025 00:09:27 -0700 Subject: [PATCH 39/68] Move `MarkAliasAnalysisPreparePass` before `propagateShardingsPass` (#4274) This makes #3838 performance neutral. Benchmarking results on GH200 nodes: On main: ``` Name (time in ms) Min Max Mean StdDev Median IQR Outliers OPS Rounds Iterations test_transformer_forward 6.2744 7.0567 6.4946 0.3369 6.2961 0.4077 1;0 153.9732 5 1 test_transformer_forward 6.2781 7.0573 6.4949 0.3368 6.2962 0.4076 1;0 153.9664 5 1 ------------------------------------------------------------------------------------------------------------------- test_transformer_backward 12.5244 13.7777 13.0152 0.6278 12.5900 1.1082 1;0 76.8331 5 1 test_transformer_backward 12.5348 13.7620 13.0204 0.6094 12.6391 1.0909 1;0 76.8024 5 1 ----------------------------------------------------------------------------------------------------------------------- ``` This branch: ``` Name (time in ms) Min Max Mean StdDev Median IQR Outliers OPS Rounds Iterations test_transformer_forward 6.2889 7.0885 6.5132 0.3481 6.2960 0.4302 1;0 153.5349 5 1 test_transformer_forward 6.2895 7.0262 6.5010 0.3231 6.2963 0.4195 1;0 153.8221 5 1 Name (time in ms) Min Max Mean StdDev Median IQR Outliers OPS Rounds Iterations test_transformer_backward 12.4542 13.6518 12.9532 0.5625 12.6231 0.9795 1;0 77.2012 5 1 test_transformer_backward 12.4778 13.6544 12.9510 0.5641 12.5828 0.9724 1;0 77.2139 5 1 ----------------------------------------------------------------------------------------------------------------------- --- csrc/alias_analysis.cpp | 4 ---- csrc/preseg_passes/pre_segmenter.cpp | 16 ++++++++++------ tests/cpp/test_alias_analysis.cpp | 6 ++++-- tests/cpp/test_multidevice_matmul.cpp | 2 +- 4 files changed, 15 insertions(+), 13 deletions(-) diff --git a/csrc/alias_analysis.cpp b/csrc/alias_analysis.cpp index 8bca891e4c4..8fd5da9da77 100644 --- a/csrc/alias_analysis.cpp +++ b/csrc/alias_analysis.cpp @@ -239,10 +239,6 @@ void AliasFinder::handle(const ViewOp* view) { } void AliasFinder::handle(const LoadStoreOp* set) { - if (isResharding(set)) { - return; - } - TensorView* in = dynamic_cast(set->in()); if (in == nullptr) { return; diff --git a/csrc/preseg_passes/pre_segmenter.cpp b/csrc/preseg_passes/pre_segmenter.cpp index 042f03191f7..52d949ad95a 100644 --- a/csrc/preseg_passes/pre_segmenter.cpp +++ b/csrc/preseg_passes/pre_segmenter.cpp @@ -39,12 +39,6 @@ namespace nvfuser::preseg_passes { debug() << "========================================" << std::endl; } - // For resharding across GPUs. - OptimizationPass::runPass(fusion); - OptimizationPass::runPass(fusion); - OptimizationPass::runPass(fusion); - OptimizationPass::runPass(fusion); - // Replace TensorViews with zero extent. Outputs and inputs may still be empty OptimizationPass::runPass(fusion); // This pass should be placed before ConsecutiveCastPass as more @@ -81,6 +75,16 @@ namespace nvfuser::preseg_passes { OptimizationPass::runPass(fusion); OptimizationPass::runPass(fusion); OptimizationPass::runPass(fusion); + + // All the multidevice passes are moved after allocation related passes: + // MarkAliasesPreparePass, and AllocationDomainPass Multidevice passes will + // try to set the allocation domain for tvs with device mesh which will + // conflict with these passes. + OptimizationPass::runPass(fusion); + OptimizationPass::runPass(fusion); + OptimizationPass::runPass(fusion); + OptimizationPass::runPass(fusion); + OptimizationPass::runPass(fusion); OptimizationPass::runPass(fusion); OptimizationPass::runPass(fusion); diff --git a/tests/cpp/test_alias_analysis.cpp b/tests/cpp/test_alias_analysis.cpp index 937167a5517..260172fb4c2 100644 --- a/tests/cpp/test_alias_analysis.cpp +++ b/tests/cpp/test_alias_analysis.cpp @@ -270,7 +270,9 @@ TEST_F(AliasAnalysisTest, BroadcastExpandDimensions) { EXPECT_EQ(analysis.getRoot(expanded_tv), in); } -TEST_F(AliasAnalysisTest, NoAliasForReshardingExprs) { +// See PR: https://github.com/NVIDIA/Fuser/pull/4274 +// for alias analysis for resharding exprs +TEST_F(AliasAnalysisTest, AliasForReshardingExprs) { Fusion fusion; FusionGuard fg(&fusion); @@ -288,7 +290,7 @@ TEST_F(AliasAnalysisTest, NoAliasForReshardingExprs) { fusion.addOutput(out); AliasAnalysisResult analysis = findAliases(&fusion); - EXPECT_TRUE(analysis.getRoot(out) == nullptr); + EXPECT_TRUE(analysis.getRoot(out) == in); } } // namespace nvfuser diff --git a/tests/cpp/test_multidevice_matmul.cpp b/tests/cpp/test_multidevice_matmul.cpp index b8aba89aa86..ee24479a0af 100644 --- a/tests/cpp/test_multidevice_matmul.cpp +++ b/tests/cpp/test_multidevice_matmul.cpp @@ -238,7 +238,7 @@ TEST_F(DistributedMatmulTest, Matmul_LayoutTN_Allgather) { executor_cache.getMostRecentKernelRuntime(); EXPECT_THAT( kernel_runtime->fusionSegments()->groups(), - Contains(HeuristicIs(SchedulerType::ExprEval)).Times(2)); + Contains(HeuristicIs(SchedulerType::ExprEval)).Times(3)); } TEST_F(DistributedMatmulTest, Matmul_LayoutNT_AllReduce) { From ab8846a46e66167ce81a8684c23d6ef3baf30177 Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Wed, 23 Apr 2025 07:26:25 -0700 Subject: [PATCH 40/68] Replace an ad-hoc toposort with stablyOrderedExprs (#4285) --- csrc/runtime/fusion_kernel_runtime.cpp | 53 +------------------------- 1 file changed, 1 insertion(+), 52 deletions(-) diff --git a/csrc/runtime/fusion_kernel_runtime.cpp b/csrc/runtime/fusion_kernel_runtime.cpp index c57abe4fcca..4afe4279a85 100644 --- a/csrc/runtime/fusion_kernel_runtime.cpp +++ b/csrc/runtime/fusion_kernel_runtime.cpp @@ -199,56 +199,6 @@ flatbuffers::Offset FusionKernelRuntime::serialize( segmented_fusion_fb); } -namespace { -std::vector toposortExprs( - SegmentedFusion* fusion, - SegmentedGroup* group) { - const std::vector& exprs = group->exprs(); - std::vector exprs_to_print(exprs.begin(), exprs.end()); - std::unordered_set exprs_to_print_set(exprs.begin(), exprs.end()); - std::unordered_set exprs_visited; - std::vector sorted_list; - while (!std::all_of( - exprs_to_print.begin(), - exprs_to_print.end(), - [&exprs_visited](auto expr) { return exprs_visited.count(expr); })) { - bool expr_added_to_sorted_list = false; - for (auto expr : exprs_to_print) { - if (!exprs_visited.count(expr)) { - bool add_this_expr = true; - // Check if any of the inputs of current - // expression within the group - // hasn't been visited - for (auto input : expr->inputs()) { - if (input->definition() && - exprs_to_print_set.count(input->definition()) && - !exprs_visited.count(input->definition())) { - add_this_expr = false; - break; - } - } - - // Append the current group to sorted list - // and mark visited - if (add_this_expr) { - expr_added_to_sorted_list = true; - exprs_visited.insert(expr); - sorted_list.push_back(expr); - break; - } - } - } - NVF_ERROR( - expr_added_to_sorted_list, - "group debug print failed, exprs within given vector not a DAG"); - } - NVF_CHECK( - sorted_list.size() == group->exprs().size(), - "Exprs should not have been lost during toposortExprs"); - return sorted_list; -} -} // namespace - void FusionKernelRuntime::deserialize( const serde::FusionKernelRuntime* buffer, int8_t device_index) { @@ -540,8 +490,7 @@ void FusionKernelRuntime::compileFusionParallel(KernelArgumentHolder args) { } else { // push back segment's exprs into the container as top level // expressions - for (auto* expr : - toposortExprs(segmented_fusion_.get(), group_to_run)) { + for (auto* expr : group_to_run->stablyOrderedExprs()) { auto cloned_expr = ir_cloner.clone(expr); hic->pushBackTopLevelExprs(cloned_expr); } From d8b8cf4366e89ad69caa435d9c3edd9ba3e4b420 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Wed, 23 Apr 2025 09:43:01 -0700 Subject: [PATCH 41/68] check ID coverage for reference_tv in reduction scheduler (#4223) Fixes #3811 Added a compile time check in reduction scheduler to ensure that all output IDs are covered by the reference tv. This ensures that we will not have dangling ID in the loop domain that's not scheduled. Tasks: - [x] verify that the added check doesn't break existing tests/expectations on the capability of reduction scheduler: doesn't seem to produce any code diff in CI. - [x] what about other schedulers? Normalization scheduler should have the same issue. But it's not reproducing the error, because the local TV is fully inlined (compute at == -1). --- csrc/scheduler/reduction.cpp | 13 ++++++++++ csrc/scheduler/tools/domain_map.cpp | 21 +++++++++------- csrc/scheduler/tools/domain_map.h | 2 +- tests/cpp/test_reduction_pointwise.cpp | 35 ++++++++++++++++++++++++++ 4 files changed, 61 insertions(+), 10 deletions(-) diff --git a/csrc/scheduler/reduction.cpp b/csrc/scheduler/reduction.cpp index bf5a7b2e38c..2ca11c97346 100644 --- a/csrc/scheduler/reduction.cpp +++ b/csrc/scheduler/reduction.cpp @@ -15,6 +15,7 @@ #include #include #include +#include #include #include #include @@ -1670,6 +1671,18 @@ bool ReductionScheduler::canScheduleCompileTime(Fusion* fusion) { return false; } + // Reject when output IDs are not covered by reference tv. Assuming reduction + // scheduler simply uses reduction_tvs[0] as the reference, if that changes, + // this needs to be changed. see issue + // https://github.com/NVIDIA/Fuser/issues/3811 + scheduler_tools::DomainMap domain_map(fusion); + if (!domain_map.isValidReference(reduction_tvs[0], /*check_inputs=*/false)) { + scheduler_debug_utils::canScheduleRejectReason( + schedulerType(), + "Output contains ID that's not scheduled by reference tv."); + return false; + } + if (registry_utils::hasNonUniqueBcast(fusion)) { scheduler_debug_utils::canScheduleRejectReason( schedulerType(), diff --git a/csrc/scheduler/tools/domain_map.cpp b/csrc/scheduler/tools/domain_map.cpp index 87dcbd358ee..f129eb6b78a 100644 --- a/csrc/scheduler/tools/domain_map.cpp +++ b/csrc/scheduler/tools/domain_map.cpp @@ -376,15 +376,18 @@ IterDomain* DomainMap::anyMapped( // Determine if output TensorView is a valid reference tensor for this fusion. // The reference tensor must map to all the iterDomains in each input and // output -bool DomainMap::isValidReference(TensorView* tv) const { - for (auto input_tv : ir_utils::filterByType(fusion_->inputs())) { - if (input_tv->uses().empty()) { - continue; - } - // TODO: Same backward traversal from tv is done for all input - // tvs. Consider doing the analysis one for all inputs - if (!areAllInputIdsMappedTo(input_tv, tv)) { - return false; +bool DomainMap::isValidReference(TensorView* tv, bool check_inputs) const { + if (check_inputs) { + for (auto input_tv : + ir_utils::filterByType(fusion_->inputs())) { + if (input_tv->uses().empty()) { + continue; + } + // TODO: Same backward traversal from tv is done for all input + // tvs. Consider doing the analysis one for all inputs + if (!areAllInputIdsMappedTo(input_tv, tv)) { + return false; + } } } // The check on outputs are optional, transpose scheduler might propose a diff --git a/csrc/scheduler/tools/domain_map.h b/csrc/scheduler/tools/domain_map.h index 8a8ccb33e91..d6ed2a3a367 100644 --- a/csrc/scheduler/tools/domain_map.h +++ b/csrc/scheduler/tools/domain_map.h @@ -34,7 +34,7 @@ class DomainMap { // Determine if a TensorView is a valid reference tensor for this fusion. // The reference tensor must map to all the iterDomains in each input and // output. - bool isValidReference(TensorView* tv) const; + bool isValidReference(TensorView* tv, bool check_inputs = true) const; protected: // Determine if all IterDomains are mapped between input and the given tvs diff --git a/tests/cpp/test_reduction_pointwise.cpp b/tests/cpp/test_reduction_pointwise.cpp index 98be573ac83..be711b74fa0 100644 --- a/tests/cpp/test_reduction_pointwise.cpp +++ b/tests/cpp/test_reduction_pointwise.cpp @@ -158,4 +158,39 @@ TEST_F(NVFuserTest, InnerReductionUnrollVectorization) { testValidate(&fusion_copy, cg_outputs, {t0}, __LINE__, __FILE__); } +// https://github.com/NVIDIA/Fuser/issues/3811 +TEST_F(NVFuserTest, ReductionSchedulerWithAdditionalID) { + auto fusion_ptr = std::make_unique(); + auto& fusion = *fusion_ptr; + FusionGuard fg(fusion_ptr.get()); + + // tv0 [ b0, i1 ] + auto tv0 = makeContigConcreteTensor({1, -1}); + fusion.addInput(tv0); + // tv1 [ i2, i1 ] + // current scheduler picks tv0 as the reference TV, transformations are + // propagated to other TVs. + auto tv1 = makeContigTensor(2); + fusion.addInput(tv1); + + auto tv2 = sum(tv0, {0, 1}); + fusion.addOutput(tv2); + auto tv3 = add(tv0, tv1); + fusion.addOutput(tv3); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto t0 = at::randn({1, 100}, options); + auto t1 = at::randn({5, 100}, options); + std::vector inputs({t0, t1}); + + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + auto outputs = executor_cache.runFusionWithInputs(inputs); + + // checking segmentation + auto optimized_fusion = executor_cache.getMostRecentKernelRuntime(); + NVF_CHECK(optimized_fusion->isSegmented(), "segmentation didn't happen!"); + + testValidate(&fusion, outputs, inputs, __LINE__, __FILE__); +} + } // namespace nvfuser From 13a879c72bf304b6fb13cdd266bad066c71cc6d0 Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Wed, 23 Apr 2025 11:06:58 -0700 Subject: [PATCH 42/68] Fix bug in stablyOrderedExprs (#4292) --- csrc/fusion_segmenter.cpp | 5 ++++- tests/cpp/test_expr_sort.cpp | 39 ++++++++++++++++++++++++++++++++++-- 2 files changed, 41 insertions(+), 3 deletions(-) diff --git a/csrc/fusion_segmenter.cpp b/csrc/fusion_segmenter.cpp index bf7693c6028..90041d12375 100644 --- a/csrc/fusion_segmenter.cpp +++ b/csrc/fusion_segmenter.cpp @@ -9,6 +9,7 @@ #include #include +#include #include #include #include @@ -2531,7 +2532,9 @@ std::vector SegmentedGroup::stablyOrderedExprs() const { std::unordered_map num_producers; for (Expr* e : exprs()) { int64_t& n = num_producers[e]; - for (Val* in : e->inputs()) { + // Val::uses(), which is used later to decrement num_producers, contains + // unique `Expr`s. Therefore, it's necessary to also dedup here. + for (auto* in : VectorOfUniqueEntries(e->inputs())) { Expr* def = in->definition(); // Exprs in a SegmentedGroup come from the complete fusion, so the // producer/consumer of an Expr may be outside the group. Therefore, we diff --git a/tests/cpp/test_expr_sort.cpp b/tests/cpp/test_expr_sort.cpp index 81f63adef36..e9f3a0a7670 100644 --- a/tests/cpp/test_expr_sort.cpp +++ b/tests/cpp/test_expr_sort.cpp @@ -13,6 +13,7 @@ #include #include +#include #include #include @@ -25,6 +26,8 @@ namespace nvfuser { using ExprSortTest = NVFuserTest; using testing::ElementsAre; +using testing::IsTrue; +using testing::Property; using testing::SizeIs; // Indirect normalization pattern with zero-dimensional tensors. Originally @@ -174,7 +177,7 @@ MATCHER_P(UnaryOpTypeIs, unary_op_type, "") { } // namespace -TEST_F(ExprSortTest, SegmentedGroup) { +TEST_F(ExprSortTest, SegmentedGroup_Unary) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); @@ -199,8 +202,8 @@ TEST_F(ExprSortTest, SegmentedGroup) { FusionKernelRuntime* runtime = executor_cache.getMostRecentKernelRuntime(); SegmentedFusion* segmented_fusion = runtime->fusionSegments(); ASSERT_THAT(segmented_fusion->groups(), SizeIs(1)); - SegmentedGroup* group = segmented_fusion->groups().front(); + EXPECT_THAT( group->stablyOrderedExprs(), ElementsAre( @@ -209,4 +212,36 @@ TEST_F(ExprSortTest, SegmentedGroup) { UnaryOpTypeIs(UnaryOpType::Cos))); } +TEST_F(ExprSortTest, SegmentedGroup_Binary_SameOperand) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + TensorView* in = makeSymbolicTensor(1); + TensorView* out = neg(in); + out = add(out, out); + + fusion->addInput(in); + fusion->addOutput(out); + + FusionExecutorCache executor_cache(std::move(fusion)); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA); + at::Tensor in_tensor = at::randn({5}, options); + auto out_tensors = executor_cache.runFusionWithInputs({in_tensor}); + + testValidate( + executor_cache.fusion(), out_tensors, {in_tensor}, __LINE__, __FILE__); + + FusionKernelRuntime* runtime = executor_cache.getMostRecentKernelRuntime(); + SegmentedFusion* segmented_fusion = runtime->fusionSegments(); + ASSERT_THAT(segmented_fusion->groups(), SizeIs(1)); + SegmentedGroup* group = segmented_fusion->groups().front(); + + EXPECT_THAT( + group->stablyOrderedExprs(), + ElementsAre( + Property(&Expr::isA, IsTrue()), + Property(&Expr::isA, IsTrue()))); +} + } // namespace nvfuser From 515e65e5945cc5e76aafbd6cd3d599e0c87a4db1 Mon Sep 17 00:00:00 2001 From: Nick Sarkauskas Date: Wed, 23 Apr 2025 15:10:25 -0400 Subject: [PATCH 43/68] Deallocate HostIr Op and Test (#4286) This PR adds the Deallocate HostIr, which erases an Allocation from the expression evaluator. It also modifies LaunchKernel to take in preallocated output arguments. Lastly, it adds a gtest which allocates and deallocates a buffer in a loop, then checks the memory used is 0 bytes. --------- Co-authored-by: Jingyue Wu Co-authored-by: samnordmann --- csrc/dispatch.h | 3 +- csrc/host_ir/executor.cpp | 48 +++++++++++++++++++------- csrc/host_ir/executor.h | 1 + csrc/host_ir/host_ir.cpp | 23 ++++++++++++ csrc/host_ir/host_ir.h | 21 +++++++++++ tests/cpp/test_host_ir_integration.cpp | 26 ++++++++++++++ 6 files changed, 109 insertions(+), 13 deletions(-) diff --git a/csrc/dispatch.h b/csrc/dispatch.h index f1f4153d1d2..fed2b39511c 100644 --- a/csrc/dispatch.h +++ b/csrc/dispatch.h @@ -160,7 +160,8 @@ class Val; f(Synchronize); \ f(StartCoalescing); \ f(EndCoalescing); \ - f(ShareMemHandles); + f(ShareMemHandles); \ + f(Deallocate); // Forward declarations for all Val and Expr types diff --git a/csrc/host_ir/executor.cpp b/csrc/host_ir/executor.cpp index 89710eaae4b..b029f748671 100644 --- a/csrc/host_ir/executor.cpp +++ b/csrc/host_ir/executor.cpp @@ -329,20 +329,35 @@ void HostIrEvaluator::handle(LaunchKernel* launch_kernel) { for (auto& input : launch_kernel->inputs()) { args.push(getKnownConcreteValue(input)); } + + // If all output buffers are known already, pass them to the executor + KernelArgumentHolder outputs; + bool preallocated_outputs = false; + for (Val* output : launch_kernel->outputs()) { + if (isKnown(output)) { + preallocated_outputs = true; + outputs.push(getKnownConcreteValue(output)); + } + } + + NVF_ERROR( + outputs.empty() || outputs.size() == launch_kernel->outputs().size()); + args.setDeviceIndex(); // run the compiled kernel - KernelArgumentHolder outputs = - container_->getKernelExecutor(launch_kernel->getIndex()) - ->run( - args, - {}, - launch_kernel->launch_params(), - launch_kernel->compile_params()); - - // Store the outputs in the context - for (auto output_idx : arange(outputs.size())) { - bind(launch_kernel->outputs().at(output_idx), outputs[output_idx]); + outputs = container_->getKernelExecutor(launch_kernel->getIndex()) + ->run( + args, + outputs, + launch_kernel->launch_params(), + launch_kernel->compile_params()); + + if (!preallocated_outputs) { + // Store the outputs in the context + for (auto output_idx : arange(outputs.size())) { + bind(launch_kernel->outputs().at(output_idx), outputs[output_idx]); + } } } @@ -637,7 +652,7 @@ void HostIrEvaluator::handle(kir::Allocate* allocate) { "Allocation must be on a TensorView but got ", allocate->buffer()); TensorView* tv = allocate->buffer()->as(); - if (expr_evaluator_.isKnown(tv)) { + if (isKnown(tv)) { return; } GlobalBufferInfo info = @@ -654,6 +669,15 @@ void HostIrEvaluator::handle(kir::Allocate* allocate) { bind(tv, tensor); } +void HostIrEvaluator::handle(Deallocate* deallocate) { + auto* tv = deallocate->allocation()->buffer()->as(); + NVF_ERROR( + isKnown(tv), + "Tried to free buffer associated with unknown TensorView", + tv); + invalidate(tv); +} + void HostIrEvaluator::unhandled(Statement* stmt) { NVF_ERROR(stmt->isA(), stmt, " must be an Expr"); auto* expr = stmt->as(); diff --git a/csrc/host_ir/executor.h b/csrc/host_ir/executor.h index d71b74e0dda..c854d2312fc 100644 --- a/csrc/host_ir/executor.h +++ b/csrc/host_ir/executor.h @@ -135,6 +135,7 @@ class HostIrEvaluator final : public OptOutDispatch { void handle(LinearOp* linear) override; void handle(kir::Allocate* allocate) override; void handle(ShareMemHandles* share_mem_handles) override; + void handle(Deallocate* deallocate) override; void unhandled(Statement* stmt) override; c10::cuda::CUDAStream getCUDAStream(Stream* stream); diff --git a/csrc/host_ir/host_ir.cpp b/csrc/host_ir/host_ir.cpp index 9e1386d0d3d..06b20963314 100644 --- a/csrc/host_ir/host_ir.cpp +++ b/csrc/host_ir/host_ir.cpp @@ -153,6 +153,29 @@ std::string LaunchKernel::toInlineString(int indent_size) const { NVF_CHECK(false, "Can not be printed inline"); } +Deallocate::Deallocate(IrBuilderPasskey passkey, kir::Allocate* allocate) + : Expr(passkey) { + addAttribute(allocate); +} + +NVFUSER_DEFINE_CLONE_AND_CREATE(Deallocate) + +const kir::Allocate* Deallocate::allocation() const { + return attributes_.at(0)->as(); +} + +std::string Deallocate::toString(int indent_size) const { + std::stringstream ss; + indent(ss, indent_size) << "Deallocate {" << std::endl; + ss << allocation()->toString(indent_size + 1); + indent(ss, indent_size) << "}" << std::endl; + return ss.str(); +} + +std::string Deallocate::toInlineString(int indent_size) const { + return std::string("Deallocate ") + allocation()->buffer()->toInlineString(); +} + Stream::Stream(IrBuilderPasskey passkey, Val* index) : Val(passkey, ValType::Stream), index_(index) {} diff --git a/csrc/host_ir/host_ir.h b/csrc/host_ir/host_ir.h index bad3a6ef722..09b6d9ba51a 100644 --- a/csrc/host_ir/host_ir.h +++ b/csrc/host_ir/host_ir.h @@ -155,6 +155,27 @@ class LaunchKernel : public Expr { } }; +class Deallocate : public Expr { + public: + using Expr::Expr; + Deallocate(IrBuilderPasskey passkey, kir::Allocate* allocate); + + Deallocate(const Deallocate& other) = delete; + Deallocate& operator=(const Deallocate& other) = delete; + Deallocate(Deallocate&& other) = delete; + Deallocate& operator=(Deallocate&& other) = delete; + + NVFUSER_DECLARE_CLONE_AND_CREATE + + std::string toString(int indent_size = 0) const override; + std::string toInlineString(int indent_size = 0) const override; + const char* getOpString() const override { + return "hir::Deallocate"; + } + + const kir::Allocate* allocation() const; +}; + class Stream : public Val { public: // if index is provided, the IR represents the streams whose index is the diff --git a/tests/cpp/test_host_ir_integration.cpp b/tests/cpp/test_host_ir_integration.cpp index 149f1af7310..e0cc2b7a25f 100644 --- a/tests/cpp/test_host_ir_integration.cpp +++ b/tests/cpp/test_host_ir_integration.cpp @@ -113,6 +113,32 @@ TEST_F(HostIrIntegrationTest, Sum) { ""); } +TEST_F(HostIrIntegrationTest, Deallocate) { + const std::vector sizes = {8, 64}; + c10::DeviceIndex device_index = 0; + + resetPeakMemoryStats(device_index); + + auto hic = std::make_unique(); + FusionGuard fg(hic.get()); + + for (int i = 0; i < 10; i++) { + TensorView* tv = makeConcreteTensor(sizes); + tv->setMemoryType(MemoryType::Global); + auto* allocate = IrBuilder::create(tv, MemoryType::Global); + auto* deallocate = IrBuilder::create(allocate); + + hic->pushBackTopLevelExprs(allocate); + hic->pushBackTopLevelExprs(deallocate); + } + + HostIrEvaluator hie(std::move(hic)); + + hie.runWithInput({}); + + EXPECT_EQ(memoryAllocated(device_index), 0); +} + } // namespace hir } // namespace nvfuser From eef49fca3a91595b66f3e3846f9cbf7c972ef71e Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Wed, 23 Apr 2025 12:50:19 -0700 Subject: [PATCH 44/68] renaming benchmark (#4293) --- benchmarks/python/test_cross_entropy_loss.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/benchmarks/python/test_cross_entropy_loss.py b/benchmarks/python/test_cross_entropy_loss.py index 88bb2101a90..6d9124c56e3 100644 --- a/benchmarks/python/test_cross_entropy_loss.py +++ b/benchmarks/python/test_cross_entropy_loss.py @@ -20,7 +20,7 @@ @pytest.mark.parametrize( "executor", ["eager", "torchcompile", "thunder", "thunder-torchcompile"] ) -def test_rope_fwd_benchmark( +def test_cross_entropy_fwd_benchmark( benchmark, variation: str, executor: str, @@ -52,7 +52,7 @@ def fwd_call(inp): @pytest.mark.parametrize( "executor", ["eager", "torchcompile", "thunder", "thunder-torchcompile"] ) -def test_rope_bwd_benchmark( +def test_cross_entropy_bwd_benchmark( benchmark, variation: str, executor: str, From ce3d607c0730339c2122c9f42e20cb4fc23f09c5 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Wed, 23 Apr 2025 13:00:07 -0700 Subject: [PATCH 45/68] Issue 4063 normalization scheduler (#4281) Extends on issue #3811 Stacked on top of #4223 Fixing the performance issue for inner/outer normalization scheduler. Leaving a dangling ID on the loop domain of an output TV is bad for performance, since we could have a very large for loop. --- csrc/scheduler/normalization_utils.cpp | 14 ++++++ tests/cpp/test_reduction_pointwise.cpp | 67 ++++++++++++++++++++++++++ 2 files changed, 81 insertions(+) diff --git a/csrc/scheduler/normalization_utils.cpp b/csrc/scheduler/normalization_utils.cpp index 07c64a4f9c2..3d20f9a5a5b 100644 --- a/csrc/scheduler/normalization_utils.cpp +++ b/csrc/scheduler/normalization_utils.cpp @@ -16,6 +16,7 @@ #include #include #include +#include #include #include #include @@ -1163,6 +1164,19 @@ bool compileTimeCheck(Fusion* fusion, SchedulerType scheduler_type) { scheduler_type, "no reduction tv"); return false; } + + // Reject when output IDs are not covered by reference tv. Assuming reduction + // scheduler simply uses reduction_tvs[0] as the reference, if that changes, + // this needs to be changed. see issue + // https://github.com/NVIDIA/Fuser/issues/3811 + scheduler_tools::DomainMap domain_map(fusion); + if (!domain_map.isValidReference(reduction_tvs[0], /*check_inputs=*/false)) { + scheduler_debug_utils::canScheduleRejectReason( + scheduler_type, + "Output contains ID that's not scheduled by reference tv."); + return false; + } + auto reduction_type = reduction_scheduler_utils::getReductionType(reduction_tvs); const SchedulerType persistent_heuristic = diff --git a/tests/cpp/test_reduction_pointwise.cpp b/tests/cpp/test_reduction_pointwise.cpp index be711b74fa0..c0ae0c0a65f 100644 --- a/tests/cpp/test_reduction_pointwise.cpp +++ b/tests/cpp/test_reduction_pointwise.cpp @@ -185,6 +185,73 @@ TEST_F(NVFuserTest, ReductionSchedulerWithAdditionalID) { FusionExecutorCache executor_cache(std::move(fusion_ptr)); auto outputs = executor_cache.runFusionWithInputs(inputs); + testValidate(&fusion, outputs, inputs, __LINE__, __FILE__); + + // checking segmentation + auto optimized_fusion = executor_cache.getMostRecentKernelRuntime(); + NVF_CHECK(optimized_fusion->isSegmented(), "segmentation didn't happen!"); +} + +// https://github.com/NVIDIA/Fuser/issues/3811 +TEST_F(NVFuserTest, ReductionSchedulerWithAdditionalIDInnerNormalization) { + auto fusion_ptr = std::make_unique(); + auto& fusion = *fusion_ptr; + FusionGuard fg(fusion_ptr.get()); + + auto tv0 = makeContigConcreteTensor({-1, -1, 1}); + fusion.addInput(tv0); + auto tv1 = makeContigTensor(3); + fusion.addInput(tv1); + + auto tv2 = sum(tv0, {1, 2}, /*keep_dim=*/true); + auto tv3 = add(tv0, tv2); + fusion.addOutput(tv3); + auto tv4 = add(tv0, tv1); + fusion.addOutput(tv4); + + fusion.printMath(); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto t0 = at::randn({100, 20, 1}, options); + auto t1 = at::randn({100, 20, 128}, options); + std::vector inputs({t0, t1}); + + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + auto outputs = executor_cache.runFusionWithInputs(inputs); + testValidate(&fusion, outputs, inputs, __LINE__, __FILE__); + + // checking segmentation + auto optimized_fusion = executor_cache.getMostRecentKernelRuntime(); + NVF_CHECK(optimized_fusion->isSegmented(), "segmentation didn't happen!"); +} + +// https://github.com/NVIDIA/Fuser/issues/3811 +TEST_F(NVFuserTest, ReductionSchedulerWithAdditionalIDOuterNormalization) { + auto fusion_ptr = std::make_unique(); + auto& fusion = *fusion_ptr; + FusionGuard fg(fusion_ptr.get()); + + auto tv0 = makeContigConcreteTensor({1, -1, -1}); + fusion.addInput(tv0); + auto tv1 = makeContigTensor(3); + fusion.addInput(tv1); + + auto tv2 = sum(tv0, {0, 1}, /*keep_dim=*/true); + auto tv3 = add(tv0, tv2); + fusion.addOutput(tv3); + auto tv4 = add(tv0, tv1); + fusion.addOutput(tv4); + + fusion.printMath(); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto t0 = at::randn({1, 20, 100}, options); + auto t1 = at::randn({128, 20, 100}, options); + std::vector inputs({t0, t1}); + + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + auto outputs = executor_cache.runFusionWithInputs(inputs); + testValidate(&fusion, outputs, inputs, __LINE__, __FILE__); // checking segmentation auto optimized_fusion = executor_cache.getMostRecentKernelRuntime(); From db90ef0749725656a4ec3d668668aed3288b2a3d Mon Sep 17 00:00:00 2001 From: snordmann Date: Wed, 23 Apr 2025 14:17:26 -0700 Subject: [PATCH 46/68] add HirAliasSelect --- csrc/dispatch.h | 3 ++- csrc/host_ir/executor.cpp | 9 ++++++++ csrc/host_ir/executor.h | 1 + csrc/host_ir/host_ir.cpp | 45 +++++++++++++++++++++++++++++++++++++ csrc/host_ir/host_ir.h | 43 +++++++++++++++++++++++++++++++++++ tests/cpp/test_host_irs.cpp | 38 +++++++++++++++++++++++++++++++ 6 files changed, 138 insertions(+), 1 deletion(-) diff --git a/csrc/dispatch.h b/csrc/dispatch.h index 218ccd8267a..b9874860bd6 100644 --- a/csrc/dispatch.h +++ b/csrc/dispatch.h @@ -158,7 +158,8 @@ class Val; f(Synchronize); \ f(StartCoalescing); \ f(EndCoalescing); \ - f(ShareMemHandles); + f(ShareMemHandles); \ + f(HirAliasSelect); // Forward declarations for all Val and Expr types diff --git a/csrc/host_ir/executor.cpp b/csrc/host_ir/executor.cpp index 12cf344e549..66ac0ef1d64 100644 --- a/csrc/host_ir/executor.cpp +++ b/csrc/host_ir/executor.cpp @@ -751,6 +751,15 @@ void HostIrEvaluator::handle(ReductionOp* reduction_op) { } } +void HostIrEvaluator::handle(HirAliasSelect* hir_alias_select) { + auto index = + expr_evaluator_.evaluate(hir_alias_select->index()).as(); + auto input = getKnownConcreteValue(hir_alias_select->in()->as()) + .as(); + int64_t axis = hir_alias_select->axis(); + bind(hir_alias_select->out(), input.select(axis, index)); +} + void HostIrEvaluator::unhandled(Statement* stmt) { NVF_ERROR(stmt->isA(), stmt, " must be an Expr"); auto* expr = stmt->as(); diff --git a/csrc/host_ir/executor.h b/csrc/host_ir/executor.h index 89ac5119681..3f147b7801b 100644 --- a/csrc/host_ir/executor.h +++ b/csrc/host_ir/executor.h @@ -142,6 +142,7 @@ class HostIrEvaluator final : public OptOutDispatch { void handle(BinaryOp* binary_op) override; void handle(ReductionOp* reduction_op) override; void handle(ShareMemHandles* share_mem_handles) override; + void handle(HirAliasSelect* hir_alias_select) override; void unhandled(Statement* stmt) override; c10::cuda::CUDAStream getCUDAStream(Stream* stream); diff --git a/csrc/host_ir/host_ir.cpp b/csrc/host_ir/host_ir.cpp index 9e1386d0d3d..bf3d5cef9eb 100644 --- a/csrc/host_ir/host_ir.cpp +++ b/csrc/host_ir/host_ir.cpp @@ -355,6 +355,51 @@ std::string ShareMemHandles::toInlineString(int indent_size) const { NVF_THROW("Cannot be printed inline"); } +HirAliasSelect::HirAliasSelect( + IrBuilderPasskey passkey, + TensorView* in, + TensorView* out, + int64_t axis, + Val* index) + : Expr(passkey, {in, index}, {}, {}) { + NVF_ERROR(passkey.ir_container_ != nullptr); + NVF_ERROR( + passkey.ir_container_->isA(), + this, + "must be registered in a HostIrContainer"); + NVF_ERROR( + static_cast(in->getLogicalDomain().size()) > axis, + "Select axis ", + axis, + " is out of bounds for tensor ", + in->toString(), + " with ", + in->getLogicalDomain().size(), + " dimensions"); + // "out" is not added as an output because the current op doesn't "define" it, + // but rather sets its allocation. Since "out" will be used in another + // producing expression, this avoids unnecessary cyclic dependencies. This + // ressembles how kir::Allocate treats its allocated TensorView. + addAttribute(out); + addDataAttribute(axis); +} + +NVFUSER_DEFINE_CLONE_AND_CREATE(HirAliasSelect) + +std::string HirAliasSelect::toString(int indent_size) const { + std::stringstream ss; + indent(ss, indent_size) << out()->toString() << "\n"; + indent_size++; + indent(ss, indent_size) << " = HirAliasSelect( " << in()->toString() + << ", axis = " << in()->getLogicalDomain().at(axis()) + << ", index = " << index()->toString() << " )\n"; + return ss.str(); +} + +std::string HirAliasSelect::toInlineString(int indent_size) const { + NVF_THROW("Cannot be printed inline"); +} + } // namespace hir } // namespace nvfuser diff --git a/csrc/host_ir/host_ir.h b/csrc/host_ir/host_ir.h index bad3a6ef722..d267d23ab1f 100644 --- a/csrc/host_ir/host_ir.h +++ b/csrc/host_ir/host_ir.h @@ -351,6 +351,49 @@ class ShareMemHandles : public Expr { } }; +// This op mimicks the semantics of SelectOp but is used in HIR non-SSA context +// to index into a TensorView, returning an alias "slice" of the original +// TensorView. +class HirAliasSelect : public Expr { + public: + using Expr::Expr; + HirAliasSelect( + IrBuilderPasskey passkey, + TensorView* in, + TensorView* out, + int64_t axis, + Val* index); + + HirAliasSelect(const HirAliasSelect& other) = delete; + HirAliasSelect& operator=(const HirAliasSelect& other) = delete; + HirAliasSelect(HirAliasSelect&& other) = delete; + HirAliasSelect& operator=(HirAliasSelect&& other) = delete; + + NVFUSER_DECLARE_CLONE_AND_CREATE + + std::string toString(int indent_size = 0) const override; + std::string toInlineString(int indent_size = 0) const override; + const char* getOpString() const override { + return "hir::HirAliasSelect"; + } + + TensorView* in() const { + return inputs().at(0)->as(); + } + + TensorView* out() const { + return attributeVal(0)->as(); + } + + int64_t axis() const { + return attribute(1); + } + + Val* index() const { + return inputs().at(1); + } +}; + } // namespace hir } // namespace nvfuser diff --git a/tests/cpp/test_host_irs.cpp b/tests/cpp/test_host_irs.cpp index 633ebc83504..eb1291de57d 100644 --- a/tests/cpp/test_host_irs.cpp +++ b/tests/cpp/test_host_irs.cpp @@ -1487,6 +1487,44 @@ TEST_F(HirReductionOpTest, NonPreAllocatedOutputs) { << "Expected output: " << expected_out; } +using HirAliasSelectHostIrTest = NVFuserTest; + +TEST_F(HirAliasSelectHostIrTest, SelectingTensor) { + constexpr int64_t ndims = 2; + constexpr int64_t dim = 1; + constexpr int64_t index = 3; + const std::vector input_sizes = {32, 32}; + + ASSERT_LT(dim, ndims); + ASSERT_EQ(input_sizes.size(), ndims); + ASSERT_LT(index, input_sizes.at(dim)); + + auto hic = std::make_unique(); + FusionGuard fg(hic.get()); + + TensorView* in = makeContigTensor(ndims); + TensorView* out = makeContigTensor(ndims - 1); + auto* index_val = IrBuilder::create(index, DataType::Index); + auto* select_op = IrBuilder::create(in, out, dim, index_val); + + hic->addInput(in); + hic->addOutput(out); + hic->pushBackTopLevelExprs(select_op); + + HostIrEvaluator hie(std::move(hic)); + + auto options = at::TensorOptions().device(at::kCUDA, 0).dtype(torch::kFloat); + auto in_aten = at::randn(input_sizes, options); + std::unordered_map concrete_input_buffers = { + {in, in_aten}}; + + auto out_aten = hie.runWithInput(concrete_input_buffers)[0].as(); + + // validate + auto ref_out = in_aten.select(dim, index); + EXPECT_TRUE(ref_out.equal(out_aten)); +} + } // namespace hir } // namespace nvfuser From e32653a383c8e1689d8a1a8d5dbf5fd1d409ea92 Mon Sep 17 00:00:00 2001 From: snordmann Date: Wed, 23 Apr 2025 14:18:05 -0700 Subject: [PATCH 47/68] replace SelectOp by HirAliasSelect in stream lowering --- csrc/host_ir/container.cpp | 10 ++-- csrc/host_ir/pass/stream_parallel_type.cpp | 56 ++++++++++++---------- csrc/ops/indexing.cpp | 10 +--- csrc/ops/indexing.h | 9 +--- csrc/ops/utils.cpp | 27 ++++------- csrc/ops/utils.h | 14 ++---- 6 files changed, 52 insertions(+), 74 deletions(-) diff --git a/csrc/host_ir/container.cpp b/csrc/host_ir/container.cpp index 83e668770fc..9fdcfa376a6 100644 --- a/csrc/host_ir/container.cpp +++ b/csrc/host_ir/container.cpp @@ -35,11 +35,13 @@ Stream* HostIrContainer::getDefaultStream() { std::ostream& HostIrContainer::print(std::ostream& os) const { IrMathPrinter op_exprs(os); op_exprs.handle(this); - os << "Aliases:{"; - for (const auto& alias : alias_) { - os << "\n " << alias.first << " -> " << alias.second; + if (alias_.size() > 0) { + os << "Aliases:{"; + for (const auto& alias : alias_) { + os << "\n " << alias.first << " -> " << alias.second; + } + os << "\n}\n"; } - os << "\n}\n"; return os; } diff --git a/csrc/host_ir/pass/stream_parallel_type.cpp b/csrc/host_ir/pass/stream_parallel_type.cpp index 63ebc9fc42c..3d63290ef17 100644 --- a/csrc/host_ir/pass/stream_parallel_type.cpp +++ b/csrc/host_ir/pass/stream_parallel_type.cpp @@ -119,6 +119,30 @@ int64_t findStreamAxisIndex( return stream_id_logical_index; } +// Helper function to create a sliced version of a tensor for stream +// parallelization +hir::HirAliasSelect* createSlicedTensor( + TensorView* tensor, + int64_t stream_axis_index, + Val* index) { + auto dom = tensor->getLogicalDomain(); + + std::vector new_root; + new_root.reserve(dom.size() - 1); + + for (auto i : arange((int64_t)dom.size())) { + if (i != stream_axis_index) { + new_root.emplace_back(dom[i]->cloneWithoutRFactor()); + } + } + + auto td = IrBuilder::create( + new_root, TensorDomain::getContiguityFilledWith(new_root, true)); + auto out = IrBuilder::create(td, *tensor->getDataType()); + return IrBuilder::create( + tensor, out, stream_axis_index, index); +} + // Step 1: Group expressions into stream-parallel regions std::vector groupStreamParallelRegions( hir::HostIrContainer* hic, @@ -222,12 +246,9 @@ std::vector processForLoopBodies( // Create a sliced version of the input tensor for this stream // iterdomain - TensorView* input_j = select( - input, - input_stream_id_logical_index, - for_loop->index(), - /*keep_reduction_axis=*/true); - new_loop_body.push_back(input_j->definition()); + hir::HirAliasSelect* input_slicing = createSlicedTensor( + input, input_stream_id_logical_index, for_loop->index()); + new_loop_body.push_back(input_slicing); // Update all expressions that use this input to use the sliced version for (auto it_running_expr = current_loop_body.begin(); @@ -238,7 +259,7 @@ std::vector processForLoopBodies( ir_utils::filterByType(running_expr->inputs())) { if (running_input == input) { *it_running_expr = ir_utils::replaceValInExprInputs( - running_expr, input, input_j); + running_expr, input, input_slicing->out()); } } } @@ -256,17 +277,14 @@ std::vector processForLoopBodies( } // Create a sliced version of the output tensor for this stream axis - TensorView* output_j = select( - output, - output_stream_id_logical_index, - for_loop->index(), - /*keep_reduction_axis=*/true); + hir::HirAliasSelect* output_slicing = createSlicedTensor( + output, output_stream_id_logical_index, for_loop->index()); // Allocate memory for the output tensor, and place the allocation IR // before the for-loop, at the top level new_top_level_exprs.push_back( IrBuilder::create(output, MemoryType::Global)); - new_loop_body.push_back(output_j->definition()); + new_loop_body.push_back(output_slicing); // Update all expressions that use this output to use the sliced version for (auto it_running_expr = current_loop_body.begin(); @@ -276,18 +294,8 @@ std::vector processForLoopBodies( for (auto* running_output : ir_utils::filterByType(running_expr->outputs())) { if (running_output == output) { - // Create an alias for the sliced output to maintain the original - // tensor's properties Alias is needed here to avoid that - // transferDefinitionToNewOutputs throws. Indeed, HIC does not - // make the SSA assumption, but the util functions we use (such as - // transferDefinitionToNewOutputs) do, therefore we need to create - // an alias for the sliced output to not create loops in the dag. - TensorView* output_j_alias = - ops::newValLike(output_j, output_j->dtype(), true) - ->as(); - hic->markAlias(output_j, output_j_alias); *it_running_expr = ir_utils::transferDefinitionToNewOutputs( - running_expr, {output_j_alias}); + running_expr, {output_slicing->out()}); } } } diff --git a/csrc/ops/indexing.cpp b/csrc/ops/indexing.cpp index 80c0ff84b85..5ff75065ff2 100644 --- a/csrc/ops/indexing.cpp +++ b/csrc/ops/indexing.cpp @@ -19,14 +19,8 @@ namespace nvfuser { -TensorView* select( - TensorView* tv, - int64_t dim, - Val* index, - bool keep_reduction_axis) { - auto dom = keep_reduction_axis - ? tv->getLogicalDomain() - : TensorDomain::noReductions(tv->getLogicalDomain()); +TensorView* select(TensorView* tv, int64_t dim, Val* index) { + auto dom = TensorDomain::noReductions(tv->getLogicalDomain()); NVF_CHECK(!dom.empty(), "select can not be applied to 0d tensor."); std::vector new_root; diff --git a/csrc/ops/indexing.h b/csrc/ops/indexing.h index 5e0410d95d5..c8152c33f82 100644 --- a/csrc/ops/indexing.h +++ b/csrc/ops/indexing.h @@ -15,14 +15,7 @@ namespace nvfuser { -// When keep_reduction_axis is true, all reduction axis are kept in the -// SelectOp's consumer. This is used in the context of HostIr where SelectOp is -// used to index into Stream-parallelized axes. -NVF_API TensorView* select( - TensorView* tv, - int64_t dim, - Val* index, - bool keep_reduction_axis = false); +NVF_API TensorView* select(TensorView* tv, int64_t dim, Val* index); // torch.index_select NVF_API TensorView* indexSelect( diff --git a/csrc/ops/utils.cpp b/csrc/ops/utils.cpp index 5d32c22e212..8d3870d1a84 100644 --- a/csrc/ops/utils.cpp +++ b/csrc/ops/utils.cpp @@ -432,9 +432,7 @@ IterDomain* newOutputIterDomain( #pragma GCC diagnostic pop #endif -std::vector newOutputDomain( - const std::vector& vals, - bool keep_reduction_axis) { +std::vector newOutputDomain(const std::vector& vals) { std::vector tvs; for (auto val : vals) { if (auto* tv = dynamic_cast(val)) { @@ -445,20 +443,14 @@ std::vector newOutputDomain( !tvs.empty(), "Tried to create new output TensorView but received empty list."); - auto getLogicalDomain = - [keep_reduction_axis](TensorView* tv) -> std::vector { - return keep_reduction_axis - ? tv->getLogicalDomain() - : TensorDomain::noReductions(tv->getLogicalDomain()); - }; - - std::vector out_domain(getLogicalDomain(tvs[0]).size(), nullptr); + std::vector out_domain( + TensorDomain::noReductions(tvs[0]->getLogicalDomain()).size(), nullptr); for (const auto dim_i : arange(out_domain.size())) { std::vector input_ids; input_ids.reserve(tvs.size()); for (auto* tv : tvs) { - auto dom = getLogicalDomain(tv); + auto dom = TensorDomain::noReductions(tv->getLogicalDomain()); input_ids.emplace_back(dom[dim_i]); } out_domain[dim_i] = newOutputIterDomain(input_ids); @@ -466,11 +458,8 @@ std::vector newOutputDomain( return out_domain; } -TensorView* newOutputTV( - const std::vector& vals, - DataType dtype, - bool keep_reduction_axis) { - auto out_domain = newOutputDomain(vals, keep_reduction_axis); +TensorView* newOutputTV(const std::vector& vals, DataType dtype) { + auto out_domain = newOutputDomain(vals); auto* new_out = IrBuilder::create( IrBuilder::create( out_domain, TensorDomain::getContiguityFilledWith(out_domain, true)), @@ -513,12 +502,12 @@ std::vector maybeBroadcast(const std::vector& vals) { return out_vals; } -Val* newValLike(Val* val, DataType dtype, bool keep_reduction_axis) { +Val* newValLike(Val* val, DataType dtype) { NVF_CHECK( dtype != DataType::Null, "Invalid datatype provided for new value."); if (val->isA()) { - return newOutputTV({val}, dtype, keep_reduction_axis); + return newOutputTV({val}, dtype); } return newScalar(ValType::Others, dtype); diff --git a/csrc/ops/utils.h b/csrc/ops/utils.h index 1a2abda03fc..94d6391cf45 100644 --- a/csrc/ops/utils.h +++ b/csrc/ops/utils.h @@ -99,21 +99,13 @@ IterDomain* newOutputIterDomain( // output tensorview, e.g., for BinaryOp. `vals` can contain scalars, e.g, when // creating the output TensorView for `tv0+scalar`. This is for convenience and // scalars will be ignored. -std::vector newOutputDomain( - const std::vector& vals, - bool keep_reduction_axis = false); +std::vector newOutputDomain(const std::vector& vals); -TensorView* newOutputTV( - const std::vector& vals, - DataType dtype, - bool keep_reduction_axis = false); +TensorView* newOutputTV(const std::vector& vals, DataType dtype); std::vector maybeBroadcast(const std::vector& vals); -NVF_API Val* newValLike( - Val* val, - DataType dtype, - bool keep_reduction_axis = false); +NVF_API Val* newValLike(Val* val, DataType dtype); // returns the minimum init value for reduction: // -inf for floating type; From a50b53c90e744cf469779dcccdc613c6af68958f Mon Sep 17 00:00:00 2001 From: snordmann Date: Wed, 23 Apr 2025 14:48:36 -0700 Subject: [PATCH 48/68] add cache for tensor slicing --- csrc/host_ir/pass/stream_parallel_type.cpp | 92 ++++++++++++++++------ 1 file changed, 68 insertions(+), 24 deletions(-) diff --git a/csrc/host_ir/pass/stream_parallel_type.cpp b/csrc/host_ir/pass/stream_parallel_type.cpp index 3d63290ef17..6999b0c4ca5 100644 --- a/csrc/host_ir/pass/stream_parallel_type.cpp +++ b/csrc/host_ir/pass/stream_parallel_type.cpp @@ -119,29 +119,67 @@ int64_t findStreamAxisIndex( return stream_id_logical_index; } -// Helper function to create a sliced version of a tensor for stream -// parallelization -hir::HirAliasSelect* createSlicedTensor( - TensorView* tensor, - int64_t stream_axis_index, - Val* index) { - auto dom = tensor->getLogicalDomain(); - - std::vector new_root; - new_root.reserve(dom.size() - 1); - - for (auto i : arange((int64_t)dom.size())) { - if (i != stream_axis_index) { - new_root.emplace_back(dom[i]->cloneWithoutRFactor()); +// Cache for tensor slicing operations in stream parallelization. +// This cache stores previously created sliced versions of tensors to avoid +// redundant slicing operations. A sliced tensor is created by removing a +// specific axis (stream axis) from the tensor's domain and creating a new +// tensor that represents a slice of the original tensor at a given index. +// The cache key is a tuple of (original tensor, axis index to remove, slice +// index). +struct TensorSlicingCache { + // Type aliases + using Key = std::tuple; + + // Custom hash function for the tuple used as cache key + struct Hash { + size_t operator()(const Key& t) const { + auto [tv, idx, val] = t; + return std::hash{}(tv) ^ std::hash{}(idx) ^ + std::hash{}(val); } + }; + + // Map type for storing cached sliced tensors + using Map = std::unordered_map; + + // Get the expr producing the indexed version of a tensor. If the expr already + // exists in the cache, returns the cached version. Otherwise, creates a new + // expr, producing a tensor "selected" on its dimension `stream_axis_index` at + // index `index`. Returns a pair of (expr, is_new) where is_new indicates + // whether the expr was newly created. + std::pair get( + TensorView* tensor, + int64_t stream_axis_index, + Val* index) { + auto key = std::make_tuple(tensor, stream_axis_index, index); + auto it = cache_.find(key); + if (it != cache_.end()) { + return {it->second, false}; + } + + auto dom = tensor->getLogicalDomain(); + std::vector new_root; + new_root.reserve(dom.size() - 1); + + for (auto i : arange((int64_t)dom.size())) { + if (i != stream_axis_index) { + new_root.emplace_back(dom[i]->cloneWithoutRFactor()); + } + } + + auto td = IrBuilder::create( + new_root, TensorDomain::getContiguityFilledWith(new_root, true)); + auto out = IrBuilder::create(td, *tensor->getDataType()); + auto result = IrBuilder::create( + tensor, out, stream_axis_index, index); + + cache_[key] = result; + return {result, true}; } - auto td = IrBuilder::create( - new_root, TensorDomain::getContiguityFilledWith(new_root, true)); - auto out = IrBuilder::create(td, *tensor->getDataType()); - return IrBuilder::create( - tensor, out, stream_axis_index, index); -} + private: + Map cache_; // Storage for cached sliced tensors +}; // Step 1: Group expressions into stream-parallel regions std::vector groupStreamParallelRegions( @@ -214,6 +252,8 @@ std::vector processForLoopBodies( const IdModel& id_model, std::vector top_level_exprs) { std::vector new_top_level_exprs; + // Create a cache for tensor indexing + TensorSlicingCache tensor_slicing_cache; // Process each top-level expression for (auto top_level_expr : top_level_exprs) { @@ -246,9 +286,11 @@ std::vector processForLoopBodies( // Create a sliced version of the input tensor for this stream // iterdomain - hir::HirAliasSelect* input_slicing = createSlicedTensor( + auto [input_slicing, is_new] = tensor_slicing_cache.get( input, input_stream_id_logical_index, for_loop->index()); - new_loop_body.push_back(input_slicing); + if (is_new) { + new_loop_body.push_back(input_slicing); + } // Update all expressions that use this input to use the sliced version for (auto it_running_expr = current_loop_body.begin(); @@ -277,14 +319,16 @@ std::vector processForLoopBodies( } // Create a sliced version of the output tensor for this stream axis - hir::HirAliasSelect* output_slicing = createSlicedTensor( + auto [output_slicing, is_new] = tensor_slicing_cache.get( output, output_stream_id_logical_index, for_loop->index()); + if (is_new) { + new_loop_body.push_back(output_slicing); + } // Allocate memory for the output tensor, and place the allocation IR // before the for-loop, at the top level new_top_level_exprs.push_back( IrBuilder::create(output, MemoryType::Global)); - new_loop_body.push_back(output_slicing); // Update all expressions that use this output to use the sliced version for (auto it_running_expr = current_loop_body.begin(); From df447be6c984f12febc93d339f9400c1e00b865f Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Wed, 23 Apr 2025 15:53:04 -0700 Subject: [PATCH 49/68] indexAccumulate python api (#4066) Things done in this PR is to support embedding backward, which requires `torch.index_put_(..., accumulate=True)`. Stacked PRs: - [x] #4063 - [ ] #4066 <-- This PR What this PR does: * Added python API ```Tensor fd.ops.index_accumulate(Tensor acc, Tensor index, Tensor value``` --------- Co-authored-by: Ryan Spring --- csrc/python_frontend/fusion_record.h | 24 ++++++++++++++ csrc/python_frontend/python_bindings.cpp | 42 ++++++++++++++++++++++++ csrc/serde/fusion_cache.fbs | 1 + csrc/serde/fusion_record.cpp | 7 ++++ tests/python/opinfo_input_generators.py | 19 +++++++++++ tests/python/opinfos.py | 28 ++++++++++++++++ 6 files changed, 121 insertions(+) diff --git a/csrc/python_frontend/fusion_record.h b/csrc/python_frontend/fusion_record.h index 3a6af8cfeb3..b437c5e247b 100644 --- a/csrc/python_frontend/fusion_record.h +++ b/csrc/python_frontend/fusion_record.h @@ -3095,6 +3095,30 @@ struct EmbeddingFwdOpRecord : RecordFunctor { } }; +struct IndexPutAccumulateOpRecord : RecordFunctor { + IndexPutAccumulateOpRecord( + std::vector args, + std::vector outputs) + : RecordFunctor( + std::move(args), + std::move(outputs), + "ops.index_put_accumulate", + serde::RecordType::IndexPutAccumulateOp) {} + ~IndexPutAccumulateOpRecord() override = default; + RecordFunctor* clone() final { + return new IndexPutAccumulateOpRecord(*this); + } + + void operator()(FusionState& fd) final { + auto acc = fd.getFusionState(args_.at(0).index)->as(); + auto index = fd.getFusionState(args_.at(1).index)->as(); + auto value = fd.getFusionState(args_.at(2).index)->as(); + + auto output = indexPutAccumulate(acc, index, value); + fd.setFusionState(outputs_.at(0).index, output); + } +}; + } // namespace nvfuser::python_frontend //! Creating the template specialized hash and equal_to functions for a diff --git a/csrc/python_frontend/python_bindings.cpp b/csrc/python_frontend/python_bindings.cpp index 123eec51263..d8b7ad291d4 100644 --- a/csrc/python_frontend/python_bindings.cpp +++ b/csrc/python_frontend/python_bindings.cpp @@ -3091,6 +3091,48 @@ void initNvFuserPythonBindings(PyObject* module) { py::arg("index"), py::arg("dim"), py::return_value_policy::reference); + nvf_ops.def( + "index_put_accumulate", + [](FusionDefinition::Operators& self, + Tensor acc, + Tensor index, + Tensor value) -> Tensor { + FUSER_PERF_SCOPE("Operators.index_put_accumulate"); + NVF_CHECK( + self.validUse(), "Attempting to add to a completed definition!"); + FusionDefinition* fd = self.fusion_definition; + Tensor output = fd->defineTensor(acc.dims); + fd->defineRecord(new IndexPutAccumulateOpRecord( + { + fd->recordingState(acc()), + fd->recordingState(index()), + fd->recordingState(value()), + }, + {fd->recordingState(output())})); + return output; + }, + py::arg("acc"), + py::arg("index"), + py::arg("value"), + py::return_value_policy::reference, + R"doc( + Accumulates values into a tensor at specified indices. + + This function performs a restricted version of `torch.index_put`. + It adds the values from `value_tv` to the elements of `acc_tv` at the indices + specified by `index_tv`. + + acc_tv: The tensor to accumulate into (in-place modification). + index_tv: The tensor containing the indices. + value_tv: The tensor containing the values to accumulate. + + Returns: + An alias to the modified `acc_tv` tensor. + + Note: + This is a restricted version and may not support all features of the + full `torch.index_put(..., accumulate=true)` function. + )doc"); nvf_ops.def( "select", [](FusionDefinition::Operators& self, diff --git a/csrc/serde/fusion_cache.fbs b/csrc/serde/fusion_cache.fbs index c2b90b08a5c..def7a760eea 100644 --- a/csrc/serde/fusion_cache.fbs +++ b/csrc/serde/fusion_cache.fbs @@ -48,6 +48,7 @@ enum RecordType: int { FullOp, IotaOp, IndexSelectOp, + IndexPutAccumulateOp, SelectOp, GatherOp, TakeAlongAxisOp, diff --git a/csrc/serde/fusion_record.cpp b/csrc/serde/fusion_record.cpp index ce9412770d9..0233174d260 100644 --- a/csrc/serde/fusion_record.cpp +++ b/csrc/serde/fusion_record.cpp @@ -472,6 +472,13 @@ void RecordFunctorFactory::registerAllParsers() { }; registerParser(RecordType::IndexSelectOp, deserializeIndexSelectRecord); + auto deserializeIndexPutAccumulateRecord = [](const RecordFunctor* buffer) { + return new python_frontend::IndexPutAccumulateOpRecord( + parseStateArgs(buffer->args()), parseStateArgs(buffer->outputs())); + }; + registerParser( + RecordType::IndexPutAccumulateOp, deserializeIndexPutAccumulateRecord); + auto deserializeSelectRecord = [](const RecordFunctor* buffer) { return new python_frontend::SelectOpRecord( parseStateArgs(buffer->args()), diff --git a/tests/python/opinfo_input_generators.py b/tests/python/opinfo_input_generators.py index d6653841e6f..199ced4de8b 100644 --- a/tests/python/opinfo_input_generators.py +++ b/tests/python/opinfo_input_generators.py @@ -822,6 +822,25 @@ def index_select_error_generator( # yield SampleInput(a, b, 0), RuntimeError, "out of bounds index value." +def index_put_accumulate_generator( + op: OpInfo, dtype: torch.dtype, requires_grad: bool = False, **kwargs +): + make_arg = partial( + make_tensor, device="cuda", dtype=dtype, requires_grad=requires_grad + ) + make_index = partial(make_tensor, device="cuda", requires_grad=False) + + # vocab_size, hidden_size, seq_size + cases = ((1024, 12, 300),) + + for vocab, hidden, seq in cases: + for index_dtype in [torch.int, torch.long]: + acc = make_arg((vocab, hidden)) + index = make_index((seq,), low=0, high=vocab, dtype=index_dtype) + value = make_arg((seq, hidden)) + yield SampleInput(acc, index, value) + + def iota_error_generator( op: OpInfo, dtype: torch.dtype, requires_grad: bool = False, **kwargs ): diff --git a/tests/python/opinfos.py b/tests/python/opinfos.py index 14c4f0e8f9f..70fb1c32153 100644 --- a/tests/python/opinfos.py +++ b/tests/python/opinfos.py @@ -29,6 +29,7 @@ gather_generator, index_select_generator, index_select_error_generator, + index_put_accumulate_generator, iota_error_generator, pad_error_generator, permute_generator, @@ -1023,6 +1024,33 @@ def gather_wrapper(fn: callable, input: torch.Tensor, index: torch.Tensor, dim: ) shape_ops.append(index_select_opinfo) + +def index_put_accumulate_ref( + acc: torch.Tensor, index: torch.Tensor, value: torch.Tensor +): + return torch.index_put( + acc, + [ + index, + ], + value, + accumulate=True, + ) + + +index_put_accumulate_opinfo = OpInfo( + lambda fd: fd.ops.index_put_accumulate, + "index_put_accumulate", + sample_input_generator=index_put_accumulate_generator, + reference=index_put_accumulate_ref, + symbolic_parameter_list=( + ArgumentType.Symbolic, + ArgumentType.Symbolic, + ArgumentType.Symbolic, + ), +) +shape_ops.append(index_put_accumulate_opinfo) + # NvFuser's API is significantly different than JAX. # TODO: Change python frontend api to match JAX using a cpp wrapper function. pad_opinfo = OpInfo( From d01c5a27f64db79e568827e25730d9ab6b84cfa4 Mon Sep 17 00:00:00 2001 From: snordmann Date: Wed, 23 Apr 2025 15:58:29 -0700 Subject: [PATCH 50/68] separate out tensor allocation logic --- csrc/host_ir/pass/stream_parallel_type.cpp | 45 +++++++++++++++++----- 1 file changed, 36 insertions(+), 9 deletions(-) diff --git a/csrc/host_ir/pass/stream_parallel_type.cpp b/csrc/host_ir/pass/stream_parallel_type.cpp index 6999b0c4ca5..7532c145ba1 100644 --- a/csrc/host_ir/pass/stream_parallel_type.cpp +++ b/csrc/host_ir/pass/stream_parallel_type.cpp @@ -246,7 +246,36 @@ std::vector groupStreamParallelRegions( return new_top_level_exprs; } -// Step 2: Process for-loop bodies by slicing tensors +// Helper function to add allocations for tensors that need them +std::vector addTensorAllocations( + std::vector top_level_exprs, + const IdModel& id_model) { + std::vector new_top_level_exprs; + + for (auto* expr : top_level_exprs) { + if (expr->isA()) { + // add allocations for tensors produced in the loop that have a stream axes + auto* for_loop = expr->as(); + for (auto* body_expr : for_loop->body().exprs()) { + for (auto* output : ir_utils::filterByType(body_expr->outputs())) { + if (findStreamAxisIndex(output, for_loop->iterDomain(), id_model) != -1) { + new_top_level_exprs.push_back( + IrBuilder::create(output, MemoryType::Global)); + } + } + } + } + new_top_level_exprs.push_back(expr); + } + + // Add all original expressions + new_top_level_exprs.insert( + new_top_level_exprs.end(), top_level_exprs.begin(), top_level_exprs.end()); + + return new_top_level_exprs; +} + +// Step 3: Process for-loop bodies by slicing tensors std::vector processForLoopBodies( hir::HostIrContainer* hic, const IdModel& id_model, @@ -325,11 +354,6 @@ std::vector processForLoopBodies( new_loop_body.push_back(output_slicing); } - // Allocate memory for the output tensor, and place the allocation IR - // before the for-loop, at the top level - new_top_level_exprs.push_back( - IrBuilder::create(output, MemoryType::Global)); - // Update all expressions that use this output to use the sliced version for (auto it_running_expr = current_loop_body.begin(); it_running_expr != current_loop_body.end(); @@ -424,7 +448,7 @@ std::vector addStreamManagement(std::vector top_level_exprs) { // 1. Identifying stream-parallelized axes in tensor operations // 2. Grouping compatible operations into stream-parallel for-loops // 3. Setting up proper stream synchronization and management -// +// 4. Adding allocations for tensors that need them // The pass ensures that: // - Input tensors don't have stream axes // - Only one stream axis exists per tensor @@ -462,11 +486,14 @@ void StreamParallelType::runPass(Fusion* fusion) { std::vector top_level_exprs = groupStreamParallelRegions(hic, id_model); - // Step 2: Process for-loop bodies by slicing tensors + // Step 2: Add allocations for tensors that need them + top_level_exprs = addTensorAllocations(std::move(top_level_exprs), id_model); + + // Step 3: Process for-loop bodies by slicing tensors top_level_exprs = processForLoopBodies(hic, id_model, std::move(top_level_exprs)); - // Step 3: Add stream management and synchronization + // Step 4: Add stream management and synchronization top_level_exprs = addStreamManagement(std::move(top_level_exprs)); // Update the container's top-level expressions From 85f98948a0e858fb78f05e9e61511056dfcc9661 Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Wed, 23 Apr 2025 16:41:26 -0700 Subject: [PATCH 51/68] Revert "Deallocate HostIr Op and Test" (#4303) Reverts NVIDIA/Fuser#4286 because of http://nv/eF0 --- csrc/dispatch.h | 3 +- csrc/host_ir/executor.cpp | 48 +++++++------------------- csrc/host_ir/executor.h | 1 - csrc/host_ir/host_ir.cpp | 23 ------------ csrc/host_ir/host_ir.h | 21 ----------- tests/cpp/test_host_ir_integration.cpp | 26 -------------- 6 files changed, 13 insertions(+), 109 deletions(-) diff --git a/csrc/dispatch.h b/csrc/dispatch.h index fed2b39511c..f1f4153d1d2 100644 --- a/csrc/dispatch.h +++ b/csrc/dispatch.h @@ -160,8 +160,7 @@ class Val; f(Synchronize); \ f(StartCoalescing); \ f(EndCoalescing); \ - f(ShareMemHandles); \ - f(Deallocate); + f(ShareMemHandles); // Forward declarations for all Val and Expr types diff --git a/csrc/host_ir/executor.cpp b/csrc/host_ir/executor.cpp index b029f748671..89710eaae4b 100644 --- a/csrc/host_ir/executor.cpp +++ b/csrc/host_ir/executor.cpp @@ -329,35 +329,20 @@ void HostIrEvaluator::handle(LaunchKernel* launch_kernel) { for (auto& input : launch_kernel->inputs()) { args.push(getKnownConcreteValue(input)); } - - // If all output buffers are known already, pass them to the executor - KernelArgumentHolder outputs; - bool preallocated_outputs = false; - for (Val* output : launch_kernel->outputs()) { - if (isKnown(output)) { - preallocated_outputs = true; - outputs.push(getKnownConcreteValue(output)); - } - } - - NVF_ERROR( - outputs.empty() || outputs.size() == launch_kernel->outputs().size()); - args.setDeviceIndex(); // run the compiled kernel - outputs = container_->getKernelExecutor(launch_kernel->getIndex()) - ->run( - args, - outputs, - launch_kernel->launch_params(), - launch_kernel->compile_params()); - - if (!preallocated_outputs) { - // Store the outputs in the context - for (auto output_idx : arange(outputs.size())) { - bind(launch_kernel->outputs().at(output_idx), outputs[output_idx]); - } + KernelArgumentHolder outputs = + container_->getKernelExecutor(launch_kernel->getIndex()) + ->run( + args, + {}, + launch_kernel->launch_params(), + launch_kernel->compile_params()); + + // Store the outputs in the context + for (auto output_idx : arange(outputs.size())) { + bind(launch_kernel->outputs().at(output_idx), outputs[output_idx]); } } @@ -652,7 +637,7 @@ void HostIrEvaluator::handle(kir::Allocate* allocate) { "Allocation must be on a TensorView but got ", allocate->buffer()); TensorView* tv = allocate->buffer()->as(); - if (isKnown(tv)) { + if (expr_evaluator_.isKnown(tv)) { return; } GlobalBufferInfo info = @@ -669,15 +654,6 @@ void HostIrEvaluator::handle(kir::Allocate* allocate) { bind(tv, tensor); } -void HostIrEvaluator::handle(Deallocate* deallocate) { - auto* tv = deallocate->allocation()->buffer()->as(); - NVF_ERROR( - isKnown(tv), - "Tried to free buffer associated with unknown TensorView", - tv); - invalidate(tv); -} - void HostIrEvaluator::unhandled(Statement* stmt) { NVF_ERROR(stmt->isA(), stmt, " must be an Expr"); auto* expr = stmt->as(); diff --git a/csrc/host_ir/executor.h b/csrc/host_ir/executor.h index c854d2312fc..d71b74e0dda 100644 --- a/csrc/host_ir/executor.h +++ b/csrc/host_ir/executor.h @@ -135,7 +135,6 @@ class HostIrEvaluator final : public OptOutDispatch { void handle(LinearOp* linear) override; void handle(kir::Allocate* allocate) override; void handle(ShareMemHandles* share_mem_handles) override; - void handle(Deallocate* deallocate) override; void unhandled(Statement* stmt) override; c10::cuda::CUDAStream getCUDAStream(Stream* stream); diff --git a/csrc/host_ir/host_ir.cpp b/csrc/host_ir/host_ir.cpp index 06b20963314..9e1386d0d3d 100644 --- a/csrc/host_ir/host_ir.cpp +++ b/csrc/host_ir/host_ir.cpp @@ -153,29 +153,6 @@ std::string LaunchKernel::toInlineString(int indent_size) const { NVF_CHECK(false, "Can not be printed inline"); } -Deallocate::Deallocate(IrBuilderPasskey passkey, kir::Allocate* allocate) - : Expr(passkey) { - addAttribute(allocate); -} - -NVFUSER_DEFINE_CLONE_AND_CREATE(Deallocate) - -const kir::Allocate* Deallocate::allocation() const { - return attributes_.at(0)->as(); -} - -std::string Deallocate::toString(int indent_size) const { - std::stringstream ss; - indent(ss, indent_size) << "Deallocate {" << std::endl; - ss << allocation()->toString(indent_size + 1); - indent(ss, indent_size) << "}" << std::endl; - return ss.str(); -} - -std::string Deallocate::toInlineString(int indent_size) const { - return std::string("Deallocate ") + allocation()->buffer()->toInlineString(); -} - Stream::Stream(IrBuilderPasskey passkey, Val* index) : Val(passkey, ValType::Stream), index_(index) {} diff --git a/csrc/host_ir/host_ir.h b/csrc/host_ir/host_ir.h index 09b6d9ba51a..bad3a6ef722 100644 --- a/csrc/host_ir/host_ir.h +++ b/csrc/host_ir/host_ir.h @@ -155,27 +155,6 @@ class LaunchKernel : public Expr { } }; -class Deallocate : public Expr { - public: - using Expr::Expr; - Deallocate(IrBuilderPasskey passkey, kir::Allocate* allocate); - - Deallocate(const Deallocate& other) = delete; - Deallocate& operator=(const Deallocate& other) = delete; - Deallocate(Deallocate&& other) = delete; - Deallocate& operator=(Deallocate&& other) = delete; - - NVFUSER_DECLARE_CLONE_AND_CREATE - - std::string toString(int indent_size = 0) const override; - std::string toInlineString(int indent_size = 0) const override; - const char* getOpString() const override { - return "hir::Deallocate"; - } - - const kir::Allocate* allocation() const; -}; - class Stream : public Val { public: // if index is provided, the IR represents the streams whose index is the diff --git a/tests/cpp/test_host_ir_integration.cpp b/tests/cpp/test_host_ir_integration.cpp index e0cc2b7a25f..149f1af7310 100644 --- a/tests/cpp/test_host_ir_integration.cpp +++ b/tests/cpp/test_host_ir_integration.cpp @@ -113,32 +113,6 @@ TEST_F(HostIrIntegrationTest, Sum) { ""); } -TEST_F(HostIrIntegrationTest, Deallocate) { - const std::vector sizes = {8, 64}; - c10::DeviceIndex device_index = 0; - - resetPeakMemoryStats(device_index); - - auto hic = std::make_unique(); - FusionGuard fg(hic.get()); - - for (int i = 0; i < 10; i++) { - TensorView* tv = makeConcreteTensor(sizes); - tv->setMemoryType(MemoryType::Global); - auto* allocate = IrBuilder::create(tv, MemoryType::Global); - auto* deallocate = IrBuilder::create(allocate); - - hic->pushBackTopLevelExprs(allocate); - hic->pushBackTopLevelExprs(deallocate); - } - - HostIrEvaluator hie(std::move(hic)); - - hie.runWithInput({}); - - EXPECT_EQ(memoryAllocated(device_index), 0); -} - } // namespace hir } // namespace nvfuser From 25b7695b2db13c85020845fdca4bb1e91f7e8359 Mon Sep 17 00:00:00 2001 From: snordmann Date: Wed, 23 Apr 2025 16:12:23 -0700 Subject: [PATCH 52/68] minor cleanup --- csrc/host_ir/pass/stream_parallel_type.cpp | 147 ++++++--------------- 1 file changed, 42 insertions(+), 105 deletions(-) diff --git a/csrc/host_ir/pass/stream_parallel_type.cpp b/csrc/host_ir/pass/stream_parallel_type.cpp index 7532c145ba1..8bea4c3430f 100644 --- a/csrc/host_ir/pass/stream_parallel_type.cpp +++ b/csrc/host_ir/pass/stream_parallel_type.cpp @@ -183,12 +183,11 @@ struct TensorSlicingCache { // Step 1: Group expressions into stream-parallel regions std::vector groupStreamParallelRegions( - hir::HostIrContainer* hic, + const std::vector& top_level_exprs, const IdModel& id_model) { std::vector new_top_level_exprs; - // Process each top-level expression - for (auto expr : hic->topLevelExprs()) { + for (auto* expr : top_level_exprs) { // Skip expressions with no outputs if (expr->outputs().size() == 0) { new_top_level_exprs.push_back(expr); @@ -229,9 +228,9 @@ std::vector groupStreamParallelRegions( auto* for_loop = IrBuilder::create( stream_axis, /*index=*/NamedScalar::getParallelIndex(ParallelType::Stream), - /*start=*/hic->zeroVal(), + /*start=*/FusionGuard::getCurFusion()->zeroVal(), /*stop=*/stream_axis->extent(), - /*step=*/hic->oneVal(), + /*step=*/FusionGuard::getCurFusion()->oneVal(), /*vectorize=*/false, /*vectorize_shift=*/nullptr, /*unroll_required=*/false, @@ -254,11 +253,14 @@ std::vector addTensorAllocations( for (auto* expr : top_level_exprs) { if (expr->isA()) { - // add allocations for tensors produced in the loop that have a stream axes + // add allocations for tensors produced in the loop that have a stream + // axes auto* for_loop = expr->as(); for (auto* body_expr : for_loop->body().exprs()) { - for (auto* output : ir_utils::filterByType(body_expr->outputs())) { - if (findStreamAxisIndex(output, for_loop->iterDomain(), id_model) != -1) { + for (auto* output : + ir_utils::filterByType(body_expr->outputs())) { + if (findStreamAxisIndex(output, for_loop->iterDomain(), id_model) != + -1) { new_top_level_exprs.push_back( IrBuilder::create(output, MemoryType::Global)); } @@ -268,131 +270,68 @@ std::vector addTensorAllocations( new_top_level_exprs.push_back(expr); } - // Add all original expressions - new_top_level_exprs.insert( - new_top_level_exprs.end(), top_level_exprs.begin(), top_level_exprs.end()); - return new_top_level_exprs; } // Step 3: Process for-loop bodies by slicing tensors std::vector processForLoopBodies( - hir::HostIrContainer* hic, - const IdModel& id_model, - std::vector top_level_exprs) { - std::vector new_top_level_exprs; - // Create a cache for tensor indexing + std::vector top_level_exprs, + const IdModel& id_model) { TensorSlicingCache tensor_slicing_cache; - // Process each top-level expression - for (auto top_level_expr : top_level_exprs) { - // Skip non-for-loop expressions - if (!top_level_expr->isA()) { - new_top_level_exprs.push_back(top_level_expr); + for (auto* expr : top_level_exprs) { + if (!expr->isA()) { continue; } - auto* for_loop = top_level_expr->as(); + auto* for_loop = expr->as(); std::vector new_loop_body; - std::vector current_loop_body = for_loop->body().exprs(); - - // Process each expression in the loop body - for (auto it_expr = current_loop_body.begin(); - it_expr != current_loop_body.end(); - ++it_expr) { - Expr* expr = *it_expr; - - // Process input tensors that might have stream axes - for (auto* input : ir_utils::filterByType(expr->inputs())) { - // Find if this input has a stream axis - int64_t input_stream_id_logical_index = - findStreamAxisIndex(input, for_loop->iterDomain(), id_model); - - // Skip if no stream axis found - if (input_stream_id_logical_index == -1) { - continue; - } - // Create a sliced version of the input tensor for this stream - // iterdomain - auto [input_slicing, is_new] = tensor_slicing_cache.get( - input, input_stream_id_logical_index, for_loop->index()); + // Lambda to process a tensor in a for-loop body + auto processTensor = [&](Expr*& expr, TensorView* tensor) { + if (auto stream_idx = + findStreamAxisIndex(tensor, for_loop->iterDomain(), id_model); + stream_idx != -1) { + auto [slicing, is_new] = + tensor_slicing_cache.get(tensor, stream_idx, for_loop->index()); if (is_new) { - new_loop_body.push_back(input_slicing); + new_loop_body.push_back(slicing); } - - // Update all expressions that use this input to use the sliced version - for (auto it_running_expr = current_loop_body.begin(); - it_running_expr != current_loop_body.end(); - ++it_running_expr) { - Expr* running_expr = *it_running_expr; - for (auto* running_input : - ir_utils::filterByType(running_expr->inputs())) { - if (running_input == input) { - *it_running_expr = ir_utils::replaceValInExprInputs( - running_expr, input, input_slicing->out()); - } - } + expr = ir_utils::replaceValInExprInputs(expr, tensor, slicing->out()); + if (expr->outputs().size() > 0 && expr->outputs()[0] == tensor) { + expr = + ir_utils::transferDefinitionToNewOutputs(expr, {slicing->out()}); } } + }; - // Process output tensors that might have stream axes - for (auto* output : ir_utils::filterByType(expr->outputs())) { - // Find if this output has a stream axis - int64_t output_stream_id_logical_index = - findStreamAxisIndex(output, for_loop->iterDomain(), id_model); - - // Skip if no stream axis found - if (output_stream_id_logical_index == -1) { - continue; - } - - // Create a sliced version of the output tensor for this stream axis - auto [output_slicing, is_new] = tensor_slicing_cache.get( - output, output_stream_id_logical_index, for_loop->index()); - if (is_new) { - new_loop_body.push_back(output_slicing); - } - - // Update all expressions that use this output to use the sliced version - for (auto it_running_expr = current_loop_body.begin(); - it_running_expr != current_loop_body.end(); - ++it_running_expr) { - Expr* running_expr = *it_running_expr; - for (auto* running_output : - ir_utils::filterByType(running_expr->outputs())) { - if (running_output == output) { - *it_running_expr = ir_utils::transferDefinitionToNewOutputs( - running_expr, {output_slicing->out()}); - } - } - } + for (auto* body_expr : for_loop->body().exprs()) { + for (auto* input : + ir_utils::filterByType(body_expr->inputs())) { + processTensor(body_expr, input); } - - // Add the original expression to the new loop body - new_loop_body.push_back(*it_expr); + for (auto* output : + ir_utils::filterByType(body_expr->outputs())) { + processTensor(body_expr, output); + } + new_loop_body.push_back(body_expr); } - // Update the for-loop body with all the processed expressions for_loop->body().clear(); for (auto* expr : new_loop_body) { for_loop->body().push_back(expr); } - new_top_level_exprs.push_back(top_level_expr); } - return new_top_level_exprs; + return top_level_exprs; } -// Step 3: Add stream management and synchronization +// Step 4: Add stream management and synchronization std::vector addStreamManagement(std::vector top_level_exprs) { - std::vector new_top_level_exprs; - // Process each top-level expression for (auto* top_level_expr : top_level_exprs) { // Skip non-for-loop expressions if (!top_level_expr->isA()) { - new_top_level_exprs.push_back(top_level_expr); continue; } @@ -434,10 +373,9 @@ std::vector addStreamManagement(std::vector top_level_exprs) { for (auto* expr : new_loop_body) { for_loop->body().push_back(expr); } - new_top_level_exprs.push_back(top_level_expr); } - return new_top_level_exprs; + return top_level_exprs; } } // anonymous namespace @@ -484,14 +422,13 @@ void StreamParallelType::runPass(Fusion* fusion) { // Step 1: Group expressions into stream-parallel regions std::vector top_level_exprs = - groupStreamParallelRegions(hic, id_model); + groupStreamParallelRegions(hic->topLevelExprs(), id_model); // Step 2: Add allocations for tensors that need them top_level_exprs = addTensorAllocations(std::move(top_level_exprs), id_model); // Step 3: Process for-loop bodies by slicing tensors - top_level_exprs = - processForLoopBodies(hic, id_model, std::move(top_level_exprs)); + top_level_exprs = processForLoopBodies(std::move(top_level_exprs), id_model); // Step 4: Add stream management and synchronization top_level_exprs = addStreamManagement(std::move(top_level_exprs)); From a958bfc9de2af4d75940c4fe26f39f0e396e31c5 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Wed, 23 Apr 2025 17:00:02 -0700 Subject: [PATCH 53/68] Forward full op (#4269) Fusion segmenter sets aside a certain sequence of unary ops starting with fusion inputs, which we call forwarding. It effectively works as an optimization by recomputing (cheap) unary ops instead of passing tensors from one segment to another. This PR extends the forwarding optimization to those starting with factory methods. Here's a motivating example (Litgpt Llama 3 RoPE backward): ![llama_bwd](https://github.com/user-attachments/assets/84f83b2e-d7c6-4fad-9dee-6cc17578285d) The `T81` tensor is the output a full op. The tensor is used inside both yellow and gray segments. The op itself is in the yellow segment, so it's created inside the yellow segment, and that is passed, through gmem, to the gray segment. Obviously, cheap ops like this should be just replicated in the gray segment instead of passing a full tensor. Here's another way to see it: ``` g{(resize) group id: 4 inputs: T1_g___bfloat[bS3{1}, iS4{32}, iS5{8192}, iS6{128}] __bfloat T3_g___bfloat[bS11{1}, iS12{8192}, iS13{128}] __bfloat T9_g___bfloat[bS38{1}, bS39{1 ex 32}, iS40{8192}, iS41{128}] __bfloat outputs: T25_g___bfloat[bS107{1}, bS108{1 ex 32}, iS109{8192}, iS110{128}] __bfloat T54_g___bfloat[bS233{1}, iS238{8}rf, iS239{4}rf, iS235{8192}, iS236{128}] __bfloat T81_g___bfloat[bS366{1}, iS367{32}, iS368{8192}, iS369{128}] __bfloat T81_g___bfloat[bS366{1}, iS367{32}, iS368{8192}, iS369{128}] = full({1, 32, 8192, 128}, __bfloat(0)); (121) ... ``` And `T81` is used in the next segment of: ``` g{(resize) group id: 3 inputs: T18_g___bfloat[bS79{1}, iS80{32}, iS81{8192}, iS82{128}] __bfloat T25_g___bfloat[bS107{1}, bS108{1 ex 32}, iS109{8192}, iS110{128}] __bfloat T34_g___bfloat[bS144{1}, iS145{32}, iS146{8192}, iS147{128}] __bfloat T81_g___bfloat[bS366{1}, iS367{32}, iS368{8192}, iS369{128}] __bfloat outputs: T75_g___bfloat[bS328{1}, iS329{8}, iS331{6}rf, iS332{8192}, iS333{128}] __bfloat T50_l___bfloat[bS212{1}, iS213{32}, iS214{8192}, iS216{64}rf] = slice( T34_g___bfloat[bS144{1}, iS145{32}, iS146{8192}, iS147{128}], { {0, 1, 1} {0, 32, 1} {0, 8192, 1} {64, 128, 1} } ) (52) T55_g___bfloat[bS240{1}, iS241{32}, iS242{8192}, iS244{128}rf] = pad( T50_l___bfloat[bS212{1}, iS213{32}, iS214{8192}, iS216{64}rf], {0, 0, 0, 0, 0, 0, 0, 64} ) (61) T39_g___bfloat[bS166{1}, iS167{32}, iS168{8192}, iS170{64}rf] = slice( T34_g___bfloat[bS144{1}, iS145{32}, iS146{8192}, iS147{128}], { {0, 1, 1} {0, 32, 1} {0, 8192, 1} {0, 64, 1} } ) (39) T43_l_float[bS184{1}, iS185{32}, iS186{8192}, iS187{64}] = __bfloat2float(T39_g___bfloat[bS166{1}, iS167{32}, iS168{8192}, iS170{64}rf]); (44) T46_l_float[bS196{1}, iS197{32}, iS198{8192}, iS199{64}] = -T43_l_float[bS184{1}, iS185{32}, iS186{8192}, iS187{64}]; (47) T48_g___bfloat[bS204{1}, iS205{32}, iS206{8192}, iS207{64}] = __float2bfloat(T46_l_float[bS196{1}, iS197{32}, iS198{8192}, iS199{64}]); (49) T51_g___bfloat[bS217{1}, iS218{32}, iS219{8192}, iS221{128}rf] = pad( T48_g___bfloat[bS204{1}, iS205{32}, iS206{8192}, iS207{64}], {0, 0, 0, 0, 0, 0, 64, 0} ) (54) T38_l_float[bS162{1}, iS163{32}, iS164{8192}, iS165{128}] = __bfloat2float(T81_g___bfloat[bS366{1}, iS367{32}, iS368{8192}, iS369{128}]); (101) ... ``` There are multiple ways to achieve that. What seems to most make sense to me is to extend the existing forwarding method to handle cases like this. The existing method only considers ops starting with fusion inputs, which do not include factory-created tensors. This PR applies a small change to the forwarding logic to include factory ops as well. The end result of this change with the above example case is that the full result is no longer passed around. Here's the first segment: ``` g{(resize) group id: 3 inputs: T0_g___bfloat[bS0{1}, iS1{8192}, iS2{128}] __bfloat T1_g___bfloat[bS3{1}, iS4{32}, iS5{8192}, iS6{128}] __bfloat T3_g___bfloat[bS11{1}, iS12{8192}, iS13{128}] __bfloat outputs: T49_g___bfloat[bS208{1}, iS209{32}, iS210{8192}, iS211{128}] __bfloat T20_l___bfloat[bS87{1}, bS88{1}, iS89{8192}, iS90{128}] = broadcast( T3_g___bfloat[bS11{1}, iS12{8192}, iS13{128}], flags = {false, true, false, false} ) (16) T25_g___bfloat[bS107{1}, bS108{1 ex 32}, iS109{8192}, iS110{128}] = expand( T20_l___bfloat[bS87{1}, bS88{1}, iS89{8192}, iS90{128}], {1, 32, 8192, 128} ) (129) T5_l___bfloat[bS18{1}, bS19{1}, iS20{8192}, iS21{128}] = broadcast( T0_g___bfloat[bS0{1}, iS1{8192}, iS2{128}], flags = {false, true, false, false} ) (0) T9_g___bfloat[bS38{1}, bS39{1 ex 32}, iS40{8192}, iS41{128}] = expand( T5_l___bfloat[bS18{1}, bS19{1}, iS20{8192}, iS21{128}], {1, 32, 8192, 128} ) (128) T81_g___bfloat[bS366{1}, iS367{32}, iS368{8192}, iS369{128}] = full({1, 32, 8192, 128}, __bfloat(0)); ... ``` Notice that `T81` is no longer a segment output. And the second segment is: ``` g{(resize) group id: 4 inputs: T0_g___bfloat[bS0{1}, iS1{8192}, iS2{128}] __bfloat T2_g___bfloat[bS7{1}, iS8{32}, iS9{8192}, iS10{128}] __bfloat T3_g___bfloat[bS11{1}, iS12{8192}, iS13{128}] __bfloat outputs: T74_g___bfloat[bS321{1}, iS326{8}rf, iS327{4}rf, iS323{8192}, iS324{128}] __bfloat T20_l___bfloat[bS87{1}, bS88{1}, iS89{8192}, iS90{128}] = broadcast( T3_g___bfloat[bS11{1}, iS12{8192}, iS13{128}], flags = {false, true, false, false} ) (16) T25_g___bfloat[bS107{1}, bS108{1 ex 32}, iS109{8192}, iS110{128}] = expand( T20_l___bfloat[bS87{1}, bS88{1}, iS89{8192}, iS90{128}], {1, 32, 8192, 128} ) (129) T5_l___bfloat[bS18{1}, bS19{1}, iS20{8192}, iS21{128}] = broadcast( T0_g___bfloat[bS0{1}, iS1{8192}, iS2{128}], flags = {false, true, false, false} ) (0) T9_g___bfloat[bS38{1}, bS39{1 ex 32}, iS40{8192}, iS41{128}] = expand( T5_l___bfloat[bS18{1}, bS19{1}, iS20{8192}, iS21{128}], {1, 32, 8192, 128} ) (128) T81_g___bfloat[bS366{1}, iS367{32}, iS368{8192}, iS369{128}] = full({1, 32, 8192, 128}, __bfloat(0)); (121) ... ``` --- csrc/fusion_segmenter.cpp | 65 ++++++++++++++++++++++++++++++--- csrc/fusion_segmenter.h | 3 ++ tests/cpp/test_segmentation.cpp | 59 +++++++++++++++++++++++++++++- 3 files changed, 119 insertions(+), 8 deletions(-) diff --git a/csrc/fusion_segmenter.cpp b/csrc/fusion_segmenter.cpp index 90041d12375..7619cab1ce6 100644 --- a/csrc/fusion_segmenter.cpp +++ b/csrc/fusion_segmenter.cpp @@ -2157,6 +2157,21 @@ SegmentedGroup* SegmentCandidateFinder::mergeAllGivenGroups( !groups_to_merge.empty(), "fusion segment :(mergeAllGivenGroups) tried to merge no groups"); + // The fusion input auxiliary groups should never be merged. + const auto& aux_input_groups = getAuxiliaryInputGroups(); + std::vector aux_groups_to_merge; + std::ranges::copy_if( + groups_to_merge, + std::back_inserter(aux_groups_to_merge), + [&](SegmentedGroup* group) { + return std::ranges::find(aux_input_groups, group) != + aux_input_groups.end(); + }); + NVF_ERROR( + aux_groups_to_merge.empty(), + "Trying to merge auxiliary input groups: ", + toDelimitedString(aux_groups_to_merge)); + // Make a set to detect internal edges std::unordered_set group_set( groups_to_merge.begin(), groups_to_merge.end()); @@ -3301,7 +3316,6 @@ class CombineReductions { return groups_to_merge_set.has(group); }), groups_with_reductions_.end()); - return joined_group; } } @@ -3606,6 +3620,10 @@ class MergeUpAndDownCast { SegmentedGroup* group = to_visit.front(); to_visit.pop_front(); + if (group->exprs().empty()) { + continue; + } + if (groups_to_merge_set.count(group)) { continue; } @@ -4359,12 +4377,28 @@ void SegmentCandidateFinder::forwardInputs() { excluded_inp_unary_exprs_ = {}; input2group_.clear(); + std::vector extended_fusion_inputs = completeFusion()->inputs(); + + // Grab factory ops that should be forwarded. Add created tensors to + // the fusion input list to make them handled like fusion inputs + // TODO: Handle more factory methods such as IotaOp, EyeOp, + // TensorConstruct. Probably should not include relatively expensive + // ops like RNGOp. + for (auto expr : completeFusion()->exprs()) { + if (expr->isA() && + // Don't bother if it's a fusion output + !expr->output(0)->isFusionOutput()) { + extended_fusion_inputs.push_back(expr->output(0)); + excluded_inp_unary_exprs_.pushBack(expr); + } + } + // "Terminating" outputs from the excluded input unary exprs, these will be // treated as complete fusion inputs. VectorOfUniqueEntries forwarded_inputs; { std::deque to_visit; - for (Val* inp : completeFusion()->inputs()) { + for (Val* inp : extended_fusion_inputs) { if (UnaryOp* unary_use = shouldForward(inp)) { to_visit.push_back(unary_use); } @@ -4387,11 +4421,13 @@ void SegmentCandidateFinder::forwardInputs() { } } - auto excluded_fusion_inputs = IterVisitor::getInputsTo( - {forwarded_inputs.begin(), forwarded_inputs.end()}); + // Stop traversing back at factory vals (and fusion inputs) + auto excluded_fusion_inputs = InputsOf::getInputsTo( + {forwarded_inputs.begin(), forwarded_inputs.end()}, + extended_fusion_inputs); // List of vals to treat as complete fusion inputs for segmentation - forwarded_fusion_inputs_ = completeFusion()->inputs(); + forwarded_fusion_inputs_ = extended_fusion_inputs; forwarded_fusion_inputs_.erase( std::remove_if( @@ -4430,6 +4466,16 @@ void SegmentCandidateFinder::cleanupForwardedInputs() { input2group_.clear(); } +std::vector SegmentCandidateFinder::getAuxiliaryInputGroups() + const { + std::vector aux_groups; + aux_groups.reserve(input2group_.size()); + std::ranges::transform(input2group_, aux_groups.begin(), [](const auto& kv) { + return kv.second; + }); + return aux_groups; +} + void SegmentCandidateFinder::finalMerge() { FUSER_PERF_SCOPE("SegmentCandidateFinder::finalMerge"); auto producer_check = getGroupDependency(); @@ -4642,7 +4688,14 @@ void SegmentCandidateFinder::resolveScalarsInGroup(SegmentedGroup* group) { SegmentedGroup* SegmentCandidateFinder::createInputGroup(Val* forwarded_input) { SegmentedGroup* group = segmented_fusion_->newGroup(); - group->input_vals_ = IterVisitor::getInputsTo({forwarded_input}); + for (auto inp : IterVisitor::getInputsTo({forwarded_input})) { + // inp may be a factory-created tensor, which is not an input to + // the group. + if (std::ranges::find(completeFusion()->inputs(), inp) != + completeFusion()->inputs().end()) { + group->input_vals_.pushBack(inp); + } + } group->exprs_ = StmtSort::getExprsTo({forwarded_input}); return group; } diff --git a/csrc/fusion_segmenter.h b/csrc/fusion_segmenter.h index ace8a0bfb02..cfcbf75fd5d 100644 --- a/csrc/fusion_segmenter.h +++ b/csrc/fusion_segmenter.h @@ -693,6 +693,9 @@ class SegmentCandidateFinder { val) != forwarded_fusion_inputs_.end(); }; + // Get all auxiliary groups created for fusion inputs + std::vector getAuxiliaryInputGroups() const; + protected: //! These are the merge node heuristic passes, should //! eventually should have a dedicated interface diff --git a/tests/cpp/test_segmentation.cpp b/tests/cpp/test_segmentation.cpp index 24db8d7793a..acb65bff7f6 100644 --- a/tests/cpp/test_segmentation.cpp +++ b/tests/cpp/test_segmentation.cpp @@ -695,7 +695,7 @@ TEST_F(SegmentationTest, ForwardInputsToSegmenterSetIssue2658) { } // Test to verify an upcast is replicated between different segments -TEST_F(NVFuserTest, PrivatizeUpcast) { +TEST_F(SegmentationTest, PrivatizeUpcast) { auto fusion_ptr = std::make_unique(); auto& fusion = *fusion_ptr; FusionGuard fg(fusion_ptr.get()); @@ -741,7 +741,7 @@ TEST_F(NVFuserTest, PrivatizeUpcast) { // Unlike PrivatizeUpcast, verify replicated upcast ops are // consolidated back as they are grouped into the same segment -TEST_F(NVFuserTest, RevertPrivatizedUpcast) { +TEST_F(SegmentationTest, RevertPrivatizedUpcast) { auto fusion_ptr = std::make_unique(); auto& fusion = *fusion_ptr; FusionGuard fg(fusion_ptr.get()); @@ -807,4 +807,59 @@ TEST_F(NVFuserTest, RevertPrivatizedUpcast) { } } +TEST_F(SegmentationTest, ForwardFull) { + auto fusion_ptr = std::make_unique(); + auto& fusion = *fusion_ptr; + FusionGuard fg(fusion_ptr.get()); + + auto tv0 = makeSymbolicTensor(1); + fusion.addInput(tv0); + + // FullOp that is used in two segments + auto tv1 = full({tv0->axis(0)->extent()}, fusion.oneVal(), DataType::Float); + + auto tv2 = add(tv0, tv1); + auto tv3 = segment_set(tv2); + + auto tv4 = add(tv3, tv1); + fusion.addOutput(tv4); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto t0 = at::randn({1024}, options); + + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + auto outputs = executor_cache.runFusionWithInputs({t0}); + testValidate(&fusion, outputs, {t0}, __LINE__, __FILE__); + + FusionKernelRuntime* runtime = executor_cache.getMostRecentKernelRuntime(); + EXPECT_THAT(runtime->fusionSegments()->groups(), SizeIs(2)); + + // Make sure the full output should not be a segment input + for (const auto& executor : runtime->executors()) { + auto ke = dynamic_cast(executor.get()); + ASSERT_NE(ke, nullptr); + kir::Kernel* kernel = ke->compiledKernel()->kernel(); + bool full_op_found = false; + for (auto expr : KernelExprVisitor::getAllExprs(kernel)) { + auto out_tv = ir_utils::getTvOutput(expr); + if (out_tv == nullptr) { + continue; + } + auto full_op = dynamic_cast(out_tv->definition()); + if (full_op == nullptr) { + continue; + } + full_op_found = true; + auto output_it = + std::ranges::find_if(kernel->outputs(), [&](Val* output) { + return output->isA() && + output->name() == out_tv->name(); + }); + EXPECT_EQ(output_it, kernel->outputs().end()) + << "FullOp ouput should not be a segment output"; + } + EXPECT_TRUE(full_op_found) << "Each segment has its own FullOp"; + } +} + } // namespace nvfuser From c9d2cc9d808d04b49276422aae8ad6945c2f60a6 Mon Sep 17 00:00:00 2001 From: Priya Mishra <52657555+Priya2698@users.noreply.github.com> Date: Wed, 23 Apr 2025 20:43:18 -0700 Subject: [PATCH 54/68] Update propagateSharding preseg pass for DID loop split (#3838) This PR extends the `propagateSharding` presegmentation pass for DID loop splits. Key changes: 1. We use TransformPropagator for all expressions except `ViewOp` which is handled manually since TransformPropagator does not support it without first propagating the reshape to the producer. 2. `makeReshardingContiguous` sets allocation domain for tvs with device mesh. Ideally, we need to set it only for global tensors but this is not known before segmentation, but should be set before segmentation. 3. ~The following tests are modified: See [discussion](https://github.com/NVIDIA/Fuser/pull/3838#issuecomment-2807645305)~. PR #4274 resolved this. Follow-up PRs: - `ViewOp` will be handled in a followup PR. - Currently, we only backpropagate sharding for a tv that does not already have a device dimension. This can be extended to propagate for all parallel types not present on the tv. This will be done in a followup. Backpropagating shardings can incorrectly change DIDx to serial or modify DIDx to be on another location. `shardAllLike` can be modified to specify which parallel type to propagate. Since `insertResharding` and `propagateSharding` require different behavior, I will handle it in a separate PR. - Use `TransformReplay::CasP` in lieu of TransformPropagator. - Propagate DID transforms within `castOp`: [privatizeUpcast](https://github.com/NVIDIA/Fuser/blob/ed687366cf717837c8ea3e40f56542fec48e1616/csrc/fusion_segmenter.cpp#L4235-L4238) clones cast operations, which fails segmentation since the transforms are not replicated. Findings from experiments: https://github.com/NVIDIA/Fuser/pull/3838#issuecomment-2807645305 --------- Co-authored-by: Jingyue Wu --- .../make_resharding_contiguous.cpp | 139 +++++++- .../make_resharding_contiguous.h | 15 +- csrc/preseg_passes/propagate_shardings.cpp | 331 ++++++++++++++---- tests/cpp/test_multidevice_sharding.cpp | 155 -------- tests/cpp/test_multidevice_transformer.cpp | 138 ++++++++ tests/cpp/test_sharding.cpp | 35 ++ 6 files changed, 561 insertions(+), 252 deletions(-) diff --git a/csrc/preseg_passes/make_resharding_contiguous.cpp b/csrc/preseg_passes/make_resharding_contiguous.cpp index 04fbe0d7173..359f562011d 100644 --- a/csrc/preseg_passes/make_resharding_contiguous.cpp +++ b/csrc/preseg_passes/make_resharding_contiguous.cpp @@ -12,35 +12,140 @@ #include #include #include +#include namespace nvfuser::preseg_passes { namespace { -void setShardedAllocationDomain(TensorView* tv) { - if (!tv->hasAllocation()) { - tv->setAllocationDomain(tv->getLoopDomain(), true); + +// Validates meshes (i.e. all TensorViews have a device mesh or none) and +// returns true if any TensorView has a device mesh. +bool validateMeshes(Fusion* fusion) { + // Validate that meshes are assigned to all TensorViews or none. + bool tv_with_mesh_found = false; + bool tv_without_mesh_found = false; + + for (auto tv : fusion->allTvs()) { + if (tv->isCpuScalar()) { + continue; + } + tv->hasDeviceMesh() ? tv_with_mesh_found = true + : tv_without_mesh_found = true; } + NVF_CHECK( + !(tv_with_mesh_found && tv_without_mesh_found), + "Cannot have some TensorViews with device mesh and some without."); + return tv_with_mesh_found; } + +// Reorders the loop domain in the same relative order as the allocation domain. +// Specifically: +// 1. It uses the exprs between logical and loop domain to split the allocation +// domain +// 2. It reorders the loop domain to match the split allocation domain. +// 3. It computes the contiguity of the transformed allocation domain through +// the split exprs. +// 4. Sets the allocation domain to be the same as the loop domain with the +// computed contiguity. This preserves both the sharding and any stride order. +// Note: Ideally, the loop domain can follow the logical domain and the +// allocation domain can follow the stride order specified/inferred. However, we +// currently require loop domain to be the same as allocation domain. This +// behavior will be modified in the future with allocation and loop domain being +// propagated independently. +void setLoopAndAllocationDomain(TensorView* tv) { + auto alloc_dom = tv->getMaybeAllocationDomain(); + auto contiguity = tv->getContiguity(); + + auto splitContiguity = [](std::optional contiguity) + -> std::pair, std::optional> { + if (!contiguity.has_value()) { + return std::make_pair(std::nullopt, std::nullopt); + } + if (contiguity.value()) { + return std::make_pair(true, true); + } + return std::make_pair(true, false); + }; + + // Allocation domain should be a permutation of logical domain at this point. + std::vector transform_exprs = DependencyCheck::getAllExprsBetween( + {alloc_dom.begin(), alloc_dom.end()}, + {tv->getLoopDomain().begin(), tv->getLoopDomain().end()}); + + NVF_ERROR( + std::all_of( + transform_exprs.begin(), + transform_exprs.end(), + [](Expr* expr) { return expr->isA(); }), + "Expected all transform exprs to be a split between logical and loop domain during sharding propagation."); + + for (auto* expr : transform_exprs) { + Split* split = dynamic_cast(expr); + auto find_it = std::find(alloc_dom.begin(), alloc_dom.end(), split->in()); + NVF_ERROR( + find_it != alloc_dom.end(), + "Split input ", + split->in()->toString(), + " not found in given ids: ", + alloc_dom); + + auto pos = std::distance(alloc_dom.begin(), find_it); + auto [outer_contiguity, inner_contiguity] = + splitContiguity(contiguity.at(pos)); + + alloc_dom[pos] = split->inner(); + alloc_dom.insert(alloc_dom.begin() + pos, split->outer()); + + contiguity[pos] = inner_contiguity; + contiguity.insert(contiguity.begin() + pos, outer_contiguity); + } + + std::optional> permutation = + ir_utils::computePermutation(alloc_dom, tv->getLoopDomain()); + NVF_ERROR( + permutation.has_value(), + "Failed to find a valid permutation for reordering", + tv->getLoopDomain(), + " as ", + alloc_dom); + tv->reorder(permutation.value()); + tv->setAllocationDomain(tv->getLoopDomain(), contiguity); +} + +bool isTvContiguous(TensorView* tv) { + return std::all_of( + tv->getContiguity().begin(), + tv->getContiguity().end(), + [](const std::optional& c) { return c.value_or(true); }); +} + } // namespace void MakeReshardingContiguousPass::runPass(Fusion* fusion) { + bool has_mesh = validateMeshes(fusion); + if (!has_mesh) { + return; + } + for (Expr* expr : fusion->exprs()) { - if (!isResharding(expr)) { - continue; + auto inputs = ir_utils::filterByType(expr->inputs()); + auto outputs = ir_utils::filterByType(expr->outputs()); + + for (auto tv : inputs) { + setLoopAndAllocationDomain(tv); } - for (auto* tv : ir_utils::filterByType(expr->inputs())) { - for (auto c : tv->getContiguity()) { - if (c.has_value()) { - NVF_CHECK( - c.value(), - "Resharding expression input must be contiguous: ", - expr); - } - } - setShardedAllocationDomain(tv); + for (auto tv : outputs) { + setLoopAndAllocationDomain(tv); } - for (auto tv : ir_utils::filterByType(expr->outputs())) { - setShardedAllocationDomain(tv); + + if (isResharding(expr)) { + auto check_contiguity = [&](const auto& tvs) { + return std::all_of(tvs.begin(), tvs.end(), isTvContiguous); + }; + NVF_CHECK( + check_contiguity(inputs) && check_contiguity(outputs), + "Resharding expression must have contiguous inputs and outputs: ", + expr); } } } diff --git a/csrc/preseg_passes/make_resharding_contiguous.h b/csrc/preseg_passes/make_resharding_contiguous.h index 60ded24f76d..8a719683004 100644 --- a/csrc/preseg_passes/make_resharding_contiguous.h +++ b/csrc/preseg_passes/make_resharding_contiguous.h @@ -15,11 +15,18 @@ namespace nvfuser::preseg_passes { -// Resharding expressions are mapped to collective libraries which expect +// This pass: +// 1. Validates that all TensorViews have a device mesh or none. +// 2. Resharding expressions are mapped to collective libraries which expect // contiguous tensors and output contiguous buffers. This pass checks that -// inputs are contiguous and sets the allocation domain of inputs and outputs of -// all resharding expressions. This pass should run after all passes that add or -// update resharding expressions. +// inputs are contiguous. +// 3. Sets the allocation domain of all fusion tvs if they have a device mesh. +// The allocation domain is obtained by transforming the `maybeAllocationDomain` +// using the transforms to loop domain. This ensures that the allocation domain +// has DID loop splits. All iterdomains derived from a given logical iterdomain +// are placed together. See `setLoopAndAllocationDomain` for more details. +// Eventually, this pass should run after `markAliasesPrepare` and +// `AllocationDomainPass` after they are fixed. class MakeReshardingContiguousPass : public OptimizationPass { friend class OptimizationPass; diff --git a/csrc/preseg_passes/propagate_shardings.cpp b/csrc/preseg_passes/propagate_shardings.cpp index 69ba5983060..b3f7344a8e7 100644 --- a/csrc/preseg_passes/propagate_shardings.cpp +++ b/csrc/preseg_passes/propagate_shardings.cpp @@ -13,117 +13,296 @@ #include #include #include +#include +#include namespace nvfuser::preseg_passes { namespace { -void validateMeshes(Fusion* fusion) { - // Validate that meshes are assigned to all TensorViews or none. - TensorView* tv_with_mesh = nullptr; - TensorView* tv_without_mesh = nullptr; - for (TensorView* tv : fusion->allTvs()) { - auto update_if_null = [](TensorView*& lhs, TensorView* rhs) { - if (lhs == nullptr) { - lhs = rhs; - } - }; - if (tv->isCpuScalar()) { - continue; +template +std::vector filterTvsWithMesh(const Range& tvs) { + std::vector tvs_with_mesh; + std::copy_if( + tvs.begin(), + tvs.end(), + std::back_inserter(tvs_with_mesh), + [](TensorView* tv) { return tv != nullptr && tv->hasDeviceMesh(); }); + return tvs_with_mesh; +} + +int64_t numDeviceDims(TensorView* tv) { + return std::count_if( + tv->getLoopDomain().begin(), + tv->getLoopDomain().end(), + std::mem_fn(&IterDomain::isDeviceDim)); +} + +// Sort the given tvs by the number of device dimensions in descending order. +// Break ties by the total number of dimensions. +// Only includes TensorViews that have a device mesh. +template +std::vector sortTvsByDeviceDims(const Range& tvs) { + // Filter out TVs without a device mesh + std::vector tvs_with_mesh = filterTvsWithMesh(tvs); + + // Then sort the filtered TVs + std::stable_sort( + tvs_with_mesh.begin(), tvs_with_mesh.end(), [](auto a, auto b) { + int64_t a_device_dims = numDeviceDims(a); + int64_t b_device_dims = numDeviceDims(b); + if (a_device_dims != b_device_dims) { + return a_device_dims > b_device_dims; + } + // Break ties by the total number of dimensions + return a->nDims() > b->nDims(); + }); + + return tvs_with_mesh; +} + +// Order the inputs of the expression based on their priority. +// For linear op, we use weights and bias before input. +// For matmul op, we use weights before input. +// For other ops, we sort the inputs by the number of device dimensions in +// descending order. +std::vector getOrderedReferenceInputs(Expr* expr) { + const auto& inputs = ir_utils::filterByType(expr->inputs()); + if (LinearOp* linear_op = dynamic_cast(expr)) { + // Use weights and bias before input. + return filterTvsWithMesh(std::vector( + {linear_op->inB(), linear_op->bias(), linear_op->inA()})); + } + + if (MatmulOp* matmul_op = dynamic_cast(expr)) { + // Use weights before input. + return filterTvsWithMesh( + std::vector({matmul_op->inB(), matmul_op->inA()})); + } + + // Sort inputs by number of device dimensions in descending order + std::vector sorted_inputs = sortTvsByDeviceDims(inputs); + + return sorted_inputs; +} + +std::vector getOutputsWithoutMesh(Expr* expr) { + const auto& outputs = ir_utils::filterByType(expr->outputs()); + std::vector outputs_without_mesh; + std::copy_if( + outputs.begin(), + outputs.end(), + std::back_inserter(outputs_without_mesh), + [](TensorView* tv) { return !tv->hasDeviceMesh(); }); + return outputs_without_mesh; +} + +// Custom selector to specify direction of transform propagation. +class PropagateShardingsSelector : public SetSelector { + private: + bool allow_c2p_; + bool allow_p2c_; + + public: + explicit PropagateShardingsSelector( + const std::unordered_set& selected_tvs, + bool allow_c2p = true, + bool allow_p2c = true) + : SetSelector(selected_tvs), + allow_c2p_(allow_c2p), + allow_p2c_(allow_p2c) {} + + bool allowC2P(TensorView* from, TensorView* to) override { + return allow_c2p_ && SetSelector::allowC2P(from, to); + } + + bool allowP2C(TensorView* from, TensorView* to) override { + return allow_p2c_ && SetSelector::allowP2C(from, to); + } +}; + +// Reorder the DID axis with the given parallel types to the front. +// Returns the number of device dimensions that were reordered to the front. +// This allows us to limit propagation to only the relevant DID axis. +int64_t selectiveReorderDIDToFront( + TensorView* tv, + const std::unordered_set& selected_parallel_types) { + std::unordered_map old2new; + int64_t current_pos = 0; + + for (auto&& [pos, id] : enumerate(tv->getLoopDomain())) { + if (id->isDeviceDim() && + selected_parallel_types.count(id->getParallelType())) { + old2new[pos] = current_pos; + current_pos++; } + } - if (tv->hasDeviceMesh()) { - update_if_null(tv_with_mesh, tv); - } else { - update_if_null(tv_without_mesh, tv); + tv->reorder(old2new); + return current_pos; +} + +// Returns the set of parallel types seen on the loop domain of the given tvs. +std::unordered_set getParallelTypesToPropagate( + std::vector tvs) { + // Get the set of parallel types seen on the loop domain of the given tvs. + std::unordered_set existing_parallel_types; + for (auto tv : tvs) { + for (auto id : tv->getLoopDomain()) { + if (id->isDeviceDim()) { + existing_parallel_types.insert(id->getParallelType()); + } + } + } + std::unordered_set selected_parallel_types; + for (ParallelType pt : kParallelTypeDIDs) { + if (!existing_parallel_types.count(pt)) { + selected_parallel_types.insert(pt); } } - NVF_CHECK( - tv_with_mesh == nullptr || tv_without_mesh == nullptr, - "Found ", - tv_with_mesh, - " assigned a mesh and ", - tv_without_mesh, - " not."); + return selected_parallel_types; +} + +void propagateDIDTransform( + TensorView* ref, + std::vector tvs, + int64_t did_pos, + bool allow_c2p, + bool allow_p2c) { + TransformPropagator propagator(ref, did_pos); + PropagateShardingsSelector selector( + {tvs.begin(), tvs.end()}, allow_c2p, allow_p2c); + MaxLogicalDomainInfoSpanningTree(ref, &selector).traverse(&propagator); } + } // namespace +// This presegmentation pass propagates shardings from fusion inputs to +// downstream tensorviews. +// 1. Forward propagating DID loop splits and parallelization from inputs to +// outputs that don't have a mesh using TransformPropagator +// 2. Reshape is handled manually since the DID loop split transforms conflict +// with the reshape root-to-logical transforms if using TransformPropagator +// 3. Back-propagating device meshes to ensure all TensorViews have consistent +// meshes. This also splits and parallelizes unsharded inputs based on outputs. +// See `MultiDevicePresegPassesTest.ResidualAdd` for an example. +// 4. Reorders the loop domain as the allocation order. Ideally, loop domain +// should follow logical domain and allocation domain should follow any stride +// order specified/inferred. However, we currently require loop domain to be the +// same as allocation domain. void PropagateShardingsPass::runPass(Fusion* fusion) { - auto num_device_parallel_dimensions = [](const TensorView* tv) -> int64_t { - return std::count_if( - tv->getLoopDomain().begin(), - tv->getLoopDomain().end(), - std::mem_fn(&IterDomain::isDeviceDim)); - }; - const std::vector& exprs = fusion->exprs(); + for (Expr* expr : exprs) { - const auto& inputs = ir_utils::filterByType(expr->inputs()); - // Pick the "most parallel" input tensor as the reference. This is useful - // for propagating tensor parallelism from weights to MLP's intermediate - // tensors. For example, - // - // x: [b, s, h]; replicated. - // w0: [h, 4*h]; column-wise sharded. - // w1: [4*h, h]; row-wise sharded. - // y = matmul(x, w0) - // z = matmul(y, w1) - // - // With the above heuristic, `y` can be automatically sharded column-wise. - TensorView* ref_input = nullptr; - auto max_num_dids = std::numeric_limits::min(); - for (auto* input : inputs) { - if (!input->hasDeviceMesh()) { - continue; - } - int64_t num_dids = num_device_parallel_dimensions(input); - if (num_dids > max_num_dids) { - max_num_dids = num_dids; - ref_input = input; - } + // Note: Tvs without a mesh are assumed to have no manual sharding + // annotation and are sharded like the first producer Tv. + const auto& outputs_without_mesh = getOutputsWithoutMesh(expr); + if (outputs_without_mesh.empty()) { + continue; } - if (ref_input == nullptr) { + + const auto& reference_inputs = getOrderedReferenceInputs(expr); + + if (reference_inputs.empty()) { continue; } + // Propagate shardings from reference inputs in order. + for (auto* ref_input : reference_inputs) { + // Skip if the input has no device mesh or is nullptr. + NVF_ERROR( + ref_input != nullptr && ref_input->hasDeviceMesh(), + "Reference input ", + ref_input, + " has no device mesh."); - // Note: Tvs without a mesh are assumed to have no manual sharding - // annotation and are sharded like the first producer Tv. - const auto& outputs = ir_utils::filterByType(expr->outputs()); - std::vector outputs_without_mesh; - for (auto* tv : outputs) { - if (!tv->hasDeviceMesh()) { - outputs_without_mesh.push_back(tv); - } + // Reorder the DID axis to the front only if it does not have a parallel + // type already seen on the outputs. This avoids propagating the same + // parallel type on multiple axis of the output when using multiple + // reference inputs. Consider out [M, N] = linear (inp [M, K], weight (N, + // K)) with inp sharded on M ([DIDx(d), M/d, K]) and weight sharded on N + // ([DIDy(d), N/d, K]). We propagate from weights first, so the output + // will be [M, DIDx(d), N/d]. When we propagate from inp next, we should + // not propagate DIDx parallel type to the output. Otherwise, the output + // will have multiple DIDx shardings which is invalid. + std::unordered_set selected_parallel_types = + getParallelTypesToPropagate(outputs_without_mesh); + + // This restricts the transform propagation to only the relevant DID axis. + int64_t did_pos = + selectiveReorderDIDToFront(ref_input, selected_parallel_types); + + // Propagate the DID loop split to the outputs without mesh. + propagateDIDTransform( + /*ref=*/ref_input, + /*tvs=*/outputs_without_mesh, + /*did_pos=*/did_pos, + /*allow_c2p=*/false, + /*allow_p2c=*/true); + + // Apply parallelization on the outputs without mesh. + shardAllLike(ref_input, outputs_without_mesh, selected_parallel_types); } - shardAllLike(ref_input, outputs_without_mesh); } // Back-propagate device meshes. This makes sure all TensorViews have a mesh // if any of them has one. This is needed in addition to the forward // propagation for ops that don't take any TensorView operands, e.g., // `uniform` used in dropout. See MultiDeviceTest.BackpropMeshes for an - // example. - for (auto i_expr = exprs.rbegin(); i_expr != exprs.rend(); i_expr++) { - Expr* expr = *i_expr; + // example. For non-fusion inputs, we also propagate shardings from outputs to + // inputs. See MultiDevicePresegPassesTest.ResidualAdd for an example. + for (Expr* expr : exprs | std::views::reverse) { const auto& outputs = ir_utils::filterByType(expr->outputs()); - auto i_output = std::find_if( - outputs.begin(), - outputs.end(), - std::mem_fn(&TensorView::hasDeviceMesh)); - if (i_output == outputs.end()) { + // All outputs of an expression (Welford, SDPA) should be uniformly sharded. + // We pick the most parallel output as the reference. + // This is to avoid picking seed/offset tvs in SDPA. + std::vector sorted_outputs = sortTvsByDeviceDims(outputs); + + if (sorted_outputs.empty()) { + // No output with a device mesh. continue; } - TensorView* output_with_mesh = *i_output; + TensorView* ref_output = sorted_outputs.front(); + NVF_ERROR( + ref_output != nullptr && ref_output->hasDeviceMesh(), + "Reference output ", + ref_output, + " has no device mesh."); + + // For fusion inputs, only check if they have a device mesh. We do not + // modify their sharding. For non-fusion inputs, we try to propagate + // shardings from the reference output for parallel types that are not + // already present. const auto& inputs = ir_utils::filterByType(expr->inputs()); + std::vector sharding_candidates; for (auto* tv : inputs) { - if (!tv->hasDeviceMesh()) { - tv->setDeviceMesh(output_with_mesh->getDeviceMesh()); + if (tv->isFusionInput()) { + if (!tv->hasDeviceMesh()) { + tv->setDeviceMesh(ref_output->getDeviceMesh()); + } + continue; + } + if (!tv->hasDeviceMesh() || numDeviceDims(tv) == 0) { + sharding_candidates.push_back(tv); } } - } - validateMeshes(fusion); + if (sharding_candidates.empty()) { + continue; + } + + int64_t did_pos = selectiveReorderDIDToFront(ref_output, {}); + // Note: We do not have to manually shard for reshape here. + // TransformPropagator can handle reshapes when going from consumer to + // producer. + propagateDIDTransform( + /*ref=*/ref_output, + /*tvs=*/sharding_candidates, + /*did_pos=*/did_pos, + /*allow_c2p=*/true, + /*allow_p2c=*/false); + shardAllLike(ref_output, sharding_candidates); + } } } // namespace nvfuser::preseg_passes diff --git a/tests/cpp/test_multidevice_sharding.cpp b/tests/cpp/test_multidevice_sharding.cpp index 2309dc4cd36..5a26d5d6622 100644 --- a/tests/cpp/test_multidevice_sharding.cpp +++ b/tests/cpp/test_multidevice_sharding.cpp @@ -894,159 +894,4 @@ TEST_F(MultiDeviceTest, LoopShardedMergeReshapeIds) { __FILE__); } -namespace { -// This is a simplified version of what we will eventually do in the -// pre-segmentation pass -void propagateShardings(Fusion* fusion, int64_t num_devices) { - for (Expr* expr : fusion->exprs()) { - if (expr->isA()) { - NVF_THROW("SliceOp is not currently supported"); - } - - if (expr->isA()) { - // TransformPropagator cannot be directly used. - // It raises an error for conflicting transformations from root domain to - // logical domain. Instead, we manually find the reshaped iterdomain and - // outer split DID. This might have to be extended further in the - // presegmentation pass. - // Note: For simplicity, this assumes that the sharding is on reshaped - // IDs. It is possible that the non-reshaped IDs are sharded, in which - // case we can use the TransformPropagator. - TensorView* reshaped_tv = expr->as()->out(); - auto transform_exprs = StmtSort::getExprsBetween( - {reshaped_tv->getMaybeRootDomain().begin(), - reshaped_tv->getMaybeRootDomain().end()}, - {reshaped_tv->getLogicalDomain().begin(), - reshaped_tv->getLogicalDomain().end()}); - NVF_CHECK(transform_exprs.size() == 1); - auto transform = transform_exprs[0]; - NVF_CHECK(transform->isA() || transform->isA()); - - // Get the reshaped ID (outer ID for split reshape). - // This is the ID that will be parallelized. - IterDomain* reshaped_id = transform->isA() - ? transform->as()->outer() - : transform->as()->out(); - - auto reshaped_it = std::find( - reshaped_tv->getLoopDomain().begin(), - reshaped_tv->getLoopDomain().end(), - reshaped_id); - int64_t reshaped_axis = - std::distance(reshaped_tv->getLoopDomain().begin(), reshaped_it); - - // Apply sharding to the reshaped tensor - reshaped_tv->split(reshaped_axis, num_devices, false); - reshaped_tv->axis(reshaped_axis)->parallelize(ParallelType::DIDx); - reorderDIDToFront(reshaped_tv); - continue; - } - - // For other ops, propagate sharding from input to outputs - auto input_tv = expr->input(0)->as(); - std::vector output_tvs; - for (auto output : expr->outputs()) { - output_tvs.push_back(output->as()); - } - - TransformPropagator propagator(input_tv); - - // Note: We will finally propagate from each input iteratively. - SetSelector selector( - std::unordered_set(output_tvs.begin(), output_tvs.end())); - MaxLogicalDomainInfoSpanningTree(input_tv, &selector).traverse(&propagator); - shardAllLike(input_tv, output_tvs); - } -} - -} // namespace - -TEST_F(MultiDeviceTest, TransformerFwd) { - auto fusion = std::make_unique(); - FusionGuard fg(fusion.get()); - - const int d = communicator_->size(); - const int64_t b = 2, s = 3, h = 8, e = 16; - auto mesh = DeviceMesh::createForNumDevices(d); - - std::vector in_shape = {b, s, d * h * e}; - std::vector out_shape = {b, s, d * h, e}; - - // The transformer block produces hq/hk/hv after slicing the MHA linear - // output. - TensorView* hq = makeConcreteTensor(in_shape, DataType::Half); - TensorView* hk = makeConcreteTensor(in_shape, DataType::Half); - TensorView* hv = makeConcreteTensor(in_shape, DataType::Half); - - TensorView* q = reshape(hq, in_shape, out_shape); - TensorView* q_permuted = permute(q, {0, 2, 1, 3}); - TensorView* k = reshape(hk, in_shape, out_shape); - TensorView* k_permuted = permute(k, {0, 2, 1, 3}); - TensorView* v = reshape(hv, in_shape, out_shape); - TensorView* v_permuted = permute(v, {0, 2, 1, 3}); - - SdpfaFwdResult sdpa_out = sdpfa_fwd( - q_permuted, - k_permuted, - v_permuted, - /*dropout_p=*/IrBuilder::create(0.0), - /*is_causal=*/IrBuilder::create(false), - /*scale=*/nullptr); - - TensorView* attn = sdpa_out.output; - TensorView* attn_permute = permute(attn, {0, 2, 1, 3}); - TensorView* out = reshape(attn_permute, out_shape, in_shape); - - fusion->addInput(hq); - fusion->addInput(hk); - fusion->addInput(hv); - fusion->addOutput(out); - - // Shard input tensors - for (auto* tv : {hq, hk, hv}) { - tv->setDeviceMesh(mesh); - tv->split(-1, d, /*inner_split=*/false); - tv->axis(-2)->parallelize(ParallelType::DIDx); - reorderDIDToFront(tv); - } - propagateShardings(fusion.get(), d); - - for (auto tv : fusion->allTvs()) { - tv->setAllocationDomain(tv->getLoopDomain(), true); - } - - FusionExecutorCache executor_cache(std::move(fusion)); - at::Tensor hq_tensor = at::randn({in_shape}, tensor_options.dtype(at::kHalf)); - at::Tensor hk_tensor = at::randn({in_shape}, tensor_options.dtype(at::kHalf)); - at::Tensor hv_tensor = at::randn({in_shape}, tensor_options.dtype(at::kHalf)); - - at::Tensor sharded_hq = shardTensor(hq_tensor, -1, mesh); - at::Tensor sharded_hk = shardTensor(hk_tensor, -1, mesh); - at::Tensor sharded_hv = shardTensor(hv_tensor, -1, mesh); - - auto nvf_out = - executor_cache - .runFusionWithInputs({sharded_hq, sharded_hk, sharded_hv})[0] - .as(); - - double scale = 1.0 / std::sqrt(e); - auto reference_out = at::_scaled_dot_product_flash_attention( - hq_tensor.view(out_shape).transpose(1, 2), - hk_tensor.view(out_shape).transpose(1, 2), - hv_tensor.view(out_shape).transpose(1, 2), - /*dropout_p=*/0.0, - /*is_causal=*/false, - /*return_debug_mask=*/false, - scale); - at::Tensor ref_attn = shardTensor( - std::get<0>(reference_out).transpose(1, 2).view(in_shape), -1, mesh); - - testValidate( - executor_cache.fusion(), - {nvf_out}, - {sharded_hq, sharded_hk, sharded_hv}, - {ref_attn}, - __LINE__, - __FILE__); -} } // namespace nvfuser diff --git a/tests/cpp/test_multidevice_transformer.cpp b/tests/cpp/test_multidevice_transformer.cpp index 55f23bdc5a1..025726698be 100644 --- a/tests/cpp/test_multidevice_transformer.cpp +++ b/tests/cpp/test_multidevice_transformer.cpp @@ -1016,6 +1016,144 @@ TEST_P(DistributedTransformerTest, Backward) { 0.02}); } +namespace { +at::Tensor reference_loop_split_mlp( + at::Tensor inp, + at::Tensor w0, + at::Tensor w1) { + auto linear0 = at::linear(inp, w0); + auto gelu = at::gelu(linear0, "tanh"); + auto linear1 = at::linear(gelu, w1); + return linear1; +} + +at::Tensor reference_loop_split_mha(at::Tensor inp) { + auto qkv = inp.transpose(1, 2).split(E / H, -1); + double scale = 1.0 / std::sqrt(E / H); + auto sdpa_out = at::_scaled_dot_product_flash_attention( + qkv[0], + qkv[1], + qkv[2], + /*dropout_p=*/kDropoutProb, + /*is_causal=*/true, + /*return_debug_mask=*/false, + scale); + auto attn = std::get<0>(sdpa_out); + return attn.transpose(1, 2); +} +} // namespace + +// TODO: Allow testing for float16 and bfloat16 for loop split mlp and mha +// This currently fails because privatizeUpcast clones cast operations, +// which fails segmentation since the transforms are not replicated. +TEST_F(DistributedTransformerTest, LoopSplitMLP) { + if ((4 * E) % D != 0) { + GTEST_SKIP() << "Requires number of devices=" << D + << " evenly divide 4*E=" << 4 * E; + } + auto dtype = DataType::Float; + at::ScalarType at_dtype = data_type_to_aten(dtype); + + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + const int d = communicator_->size(); + auto mesh = DeviceMesh::createForNumDevices(d); + + TensorView* inp = makeContigConcreteTensor({B, S, E}, dtype); + TensorView* w0 = makeContigConcreteTensor({4 * E, E}, dtype); + TensorView* w1 = makeContigConcreteTensor({E, 4 * E}, dtype); + + TensorView* linear0 = linear(inp, w0); + TensorView* linear0_float = castOp(DataType::Float, linear0); + TensorView* gelu = tanh_gelu(linear0_float); + TensorView* gelu_dtype = castOp(dtype, gelu); + TensorView* linear1 = linear(gelu_dtype, w1); + + std::vector fusion_inputs{inp, w0, w1}; + for (auto tv : fusion_inputs) { + fusion->addInput(tv); + tv->setDeviceMesh(mesh); + } + fusion->addOutput(linear1); + + w0->outer_split(0, d); + w0->axis(0)->parallelize(ParallelType::DIDx); + w1->outer_split(1, d); + w1->axis(1)->parallelize(ParallelType::DIDx); + + FusionExecutorCache executor_cache(std::move(fusion)); + at::Tensor inp_tensor = at::randn({B, S, E}, tensor_options.dtype(at_dtype)); + at::Tensor w0_tensor = at::randn({4 * E, E}, tensor_options.dtype(at_dtype)); + at::Tensor w1_tensor = at::randn({E, 4 * E}, tensor_options.dtype(at_dtype)); + + at::Tensor w0_sharded = shardTensor(w0_tensor, 0, mesh); + at::Tensor w1_sharded = shardTensor(w1_tensor, 1, mesh); + + KernelArgumentHolder args = {inp_tensor, w0_sharded, w1_sharded}; + auto outputs = executor_cache.runFusionWithInputs(args); + at::Tensor nvf_out = outputs[0].as(); + + at::Tensor ref_out = + reference_loop_split_mlp(inp_tensor, w0_tensor, w1_tensor); + validate({ref_out}, {nvf_out}, {0.02}); +} + +TEST_F(DistributedTransformerTest, LoopSplitMHAFwd) { + if (H % D != 0) { + GTEST_SKIP() << "Requires number of devices=" << D + << " evenly divide H=" << H; + } + + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + auto dtype = DataType::Half; + at::ScalarType at_dtype = data_type_to_aten(dtype); + + const int d = communicator_->size(); + + auto mesh = DeviceMesh::createForNumDevices(d); + + TensorView* qkv = makeContigConcreteTensor({B, S, H, 3 * E / H}, dtype); + TensorView* q = slice(qkv, {0, 0, 0, 0}, {B, S, H, E / H}); + TensorView* k = slice(qkv, {0, 0, 0, E / H}, {B, S, H, 2 * E / H}); + TensorView* v = slice(qkv, {0, 0, 0, 2 * E / H}, {B, S, H, 3 * E / H}); + + TensorView* q_permuted = permute(q, {0, 2, 1, 3}); + TensorView* k_permuted = permute(k, {0, 2, 1, 3}); + TensorView* v_permuted = permute(v, {0, 2, 1, 3}); + + SdpfaFwdResult sdpa_out = sdpfa_fwd( + q_permuted, + k_permuted, + v_permuted, + /*dropout_p=*/IrBuilder::create(kDropoutProb), + /*is_causal=*/IrBuilder::create(true), + /*scale=*/nullptr); + + TensorView* attn = sdpa_out.output; + TensorView* attn_permute = permute(attn, {0, 2, 1, 3}); + + fusion->addInput(qkv); + fusion->addOutput(attn_permute); + + qkv->setDeviceMesh(mesh); + qkv->outer_split(2, d); + qkv->axis(2)->parallelize(ParallelType::DIDx); + + FusionExecutorCache executor_cache(std::move(fusion)); + at::Tensor unsharded_inp_tensor = + at::randn({B, S, H, 3 * E / H}, tensor_options.dtype(at_dtype)); + at::Tensor inp_tensor = shardTensor(unsharded_inp_tensor, 2, mesh); + + KernelArgumentHolder args = {inp_tensor}; + auto outputs = executor_cache.runFusionWithInputs(args); + at::Tensor nvf_out = outputs[0].as(); + at::Tensor ref_out = reference_loop_split_mha(inp_tensor); + validate({ref_out}, {nvf_out}, {0.02}); +} + INSTANTIATE_TEST_SUITE_P( , DistributedTransformerTest, diff --git a/tests/cpp/test_sharding.cpp b/tests/cpp/test_sharding.cpp index 1ce1d96d8d0..ffbbabb7402 100644 --- a/tests/cpp/test_sharding.cpp +++ b/tests/cpp/test_sharding.cpp @@ -234,6 +234,41 @@ TEST_F(ShardingTest, MultiDimDeviceMesh) { EXPECT_EQ(mesh3d.getSlice(18, ParallelType::DIDx), slice_didx); } +TEST_F(ShardingTest, ResidualAdd) { + // This is similar to the residual add after MHA dropout in the transformer. + // The output of linear following MHA is all-gathered and sharded on the + // sequence dim. This sharding can be propagated to the linear output through + // backpropagating the shardings from residual add. This information is not + // present during forward propagation. + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + DeviceMesh mesh({0, 1}); + + TensorView* tv0 = makeContigTensor(2); + TensorView* tv1 = uniform( + shape(tv0), + fusion->zeroVal(DataType::Float), + fusion->oneVal(DataType::Float), + DataType::Float); + TensorView* tv2 = add(tv0, tv1); + + tv0->setDeviceMesh(mesh); + tv0->outer_split(0, mesh.size()); + tv0->axis(0)->parallelize(ParallelType::DIDx); + + fusion->addInput(tv0); + fusion->addOutput(tv1); + fusion->addOutput(tv2); + + preseg_passes::OptimizationPass< + preseg_passes::PropagateShardingsPass>::runPass(fusion.get()); + NVF_CHECK(tv1->hasDeviceMesh()); + NVF_CHECK( + getShardedLogicalAxis(tv1, ParallelType::DIDx) == + getShardedLogicalAxis(tv0, ParallelType::DIDx), + "Expected tv1 to be sharded like tv0 due to backpropagation of shardings."); +} + INSTANTIATE_TEST_SUITE_P( , ShardingTest, From 7477e4b186865fd1f37df1dbe2a502db14dd3cb8 Mon Sep 17 00:00:00 2001 From: Priya Mishra <52657555+Priya2698@users.noreply.github.com> Date: Wed, 23 Apr 2025 20:43:36 -0700 Subject: [PATCH 55/68] Extract benchmarking timers into a separate class (#4291) The motivation is to use them in Thunder. --- benchmarks/python/core.py | 97 +++++----------------------------- nvfuser/benchmark_utils.py | 105 +++++++++++++++++++++++++++++++++++++ 2 files changed, 117 insertions(+), 85 deletions(-) create mode 100644 nvfuser/benchmark_utils.py diff --git a/benchmarks/python/core.py b/benchmarks/python/core.py index c56d931cd35..735797c247b 100644 --- a/benchmarks/python/core.py +++ b/benchmarks/python/core.py @@ -4,8 +4,6 @@ from collections.abc import Iterable import pytest_benchmark import torch -from torch.autograd import DeviceType -from torch.profiler import profile, ProfilerActivity from typing import List, Callable, Union import numpy as np from nvfuser import FusionDefinition, FusionCache @@ -13,6 +11,7 @@ import warnings import thunder from thunder.executors.nvfuserex import nvfuserex +from nvfuser.benchmark_utils import TorchProfileTimer, FusionProfileTimer # These variables can be overwritten through CLI commands # --benchmark-rounds=rounds --benchmark-warmup-rounds=warmup_rounds @@ -102,20 +101,14 @@ def __init__( self.benchmark: Underlying pytest-benchmark fixture with timer modified to use torchprofile_timer self.current_time: Global montonic clock incremented based on elapsed CUDA time """ - self.device = device - self.fd = None # Set through setup() for host benchmarking. self.benchmark = benchmark_fixture + # Modify the default timer. if device == "cuda": - # Initialize a Torch Profiler object - self.prof = profile( - activities=[ProfilerActivity.CUDA, ProfilerActivity.CPU] - ) - # Modify the default timer. - benchmark_fixture._timer = self.torchprofile_timer + benchmark_fixture._timer = TorchProfileTimer() else: - benchmark_fixture._timer = self.fusionprofile_timer + benchmark_fixture._timer = FusionProfileTimer() # Externally set the precision to avoid timer calibration. Since the timer uses CUDA times, # calibration using subsequent timer calls produces invalid results. # https://github.com/ionelmc/pytest-benchmark/blob/728752d2976ef53fde7e40beb3e55f09cf4d4736/src/pytest_benchmark/timers.py#L15 @@ -123,13 +116,6 @@ def __init__( self.benchmark = benchmark_fixture - # Global montonic clock - self.current_time = 0.0 - - # Specifies if the timer in host measurement is called at the start/finish of execution. - # Timings are measured at the end of execution. - self.execution_start = True - def __call__(self, function_to_benchmark: Callable, *args, **kwargs): return self.benchmark(function_to_benchmark, *args, **kwargs) @@ -138,73 +124,14 @@ def __getattr__(self, attr): return getattr(self.benchmark, attr) return super().__getattr__(attr) - def torchprofile_timer(self) -> float: - """ - Custom torchprofiler-based timer used by pytest-benchmark. - At every timer call, the profiler is stopped to compute the elapsed CUDA time - and the global clock is incremented. The profiler is restarted before returning to continue tracing. - - Returns: - self.current_time: Global monotonic clock variable - """ - try: - self.prof.stop() - except AssertionError: - self.prof.start() - return self.current_time - - prof_averages = self.prof.key_averages() - elapsed_cuda_time = self._get_kernel_time(prof_averages) - self._increment_global_time(elapsed_cuda_time) - # Clear the internal profiler object to avoid accumulating function events and then restart the profiler - # See PR: https://github.com/pytorch/pytorch/pull/125510 - self.prof.profiler = None - - return self.current_time - - def fusionprofile_timer(self) -> float: - if not self.execution_start: - profile = self.fd.profile() - elapsed_host_time = profile.host_time_ms / 1e3 - self._increment_global_time(elapsed_host_time) - self.execution_start = not self.execution_start - return self.current_time - - def _get_kernel_time( - self, prof_averages: torch.autograd.profiler_util.EventList - ) -> float: - """ - Arguments: - prof_averages: Output of self.prof.key_averages() - Returns: - time_value: Elapsed CUDA time in seconds. - """ - elapsed_cuda_time = 0 - has_cuda_event = False - for event in prof_averages: - if event.device_type != DeviceType.CUDA: - continue - has_cuda_event = True - # Re: torch profiler API changes in https://github.com/pytorch/pytorch/pull/123247 - elapsed_cuda_time = ( - elapsed_cuda_time + event.self_device_time_total - if hasattr(event, "self_device_time_total") - else event.self_cuda_time_total - ) - assert has_cuda_event, "No CUDA events found" - return elapsed_cuda_time / 1e6 - - def _increment_global_time(self, elapsed_time: float) -> None: - self.current_time += elapsed_time + # Set the fd object for fusion profiling. + # fd is returned by setup() for host benchmarking. + def set_fd(self, fd): + assert isinstance(self._timer, FusionProfileTimer) + self._timer.set_fd(fd) - def cleanup(self) -> None: - """ - Stops a running torchprofiler instance if found. - """ - try: - self.prof.stop() - except AssertionError: - pass + def cleanup(self): + self._timer.cleanup() def set_metrics( self, @@ -374,7 +301,7 @@ def setup(): # The host_benchmark_fn uses the `fd` object returned from setup function. def host_benchmark_fn(inputs, fd): # Set the fd variable used to query the profile object - nvf_benchmark.fd = fd + nvf_benchmark.set_fd(fd) return fd.execute(inputs, profile=True) benchmark_fn = benchmark_fn if benchmark_fn is not None else host_benchmark_fn diff --git a/nvfuser/benchmark_utils.py b/nvfuser/benchmark_utils.py new file mode 100644 index 00000000000..4949bbc599b --- /dev/null +++ b/nvfuser/benchmark_utils.py @@ -0,0 +1,105 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-present NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +from torch.autograd import DeviceType +from torch.profiler import profile, ProfilerActivity +import torch + + +# Base class for all timers used by pytest-benchmark. +class Timer: + def __init__(self): + self.current_time = 0.0 + + def _increment_global_time(self, elapsed_time: float) -> None: + self.current_time += elapsed_time + + def __call__(self): + raise NotImplementedError("Subclass must implement this method") + + def cleanup(self): + pass + + +class TorchProfileTimer(Timer): + def __init__(self): + super().__init__() + self.prof = profile(activities=[ProfilerActivity.CUDA]) + + def _get_kernel_time( + self, prof_averages: torch.autograd.profiler_util.EventList + ) -> float: + """ + Arguments: + prof_averages: Output of self.prof.key_averages() + Returns: + time_value: Elapsed CUDA time in seconds. + """ + elapsed_cuda_time = 0 + has_cuda_event = False + for event in prof_averages: + if event.device_type != DeviceType.CUDA: + continue + has_cuda_event = True + # Re: torch profiler API changes in https://github.com/pytorch/pytorch/pull/123247 + elapsed_cuda_time = ( + elapsed_cuda_time + event.self_device_time_total + if hasattr(event, "self_device_time_total") + else event.self_cuda_time_total + ) + assert has_cuda_event, "No CUDA events found" + return elapsed_cuda_time / 1e6 + + def __call__(self): + """ + Custom torchprofiler-based timer used by pytest-benchmark. + At every timer call, the profiler is stopped to compute the elapsed CUDA time + and the global clock is incremented. The profiler is restarted before returning to continue tracing. + + Returns: + self.current_time: Global monotonic clock variable + """ + try: + self.prof.stop() + except AssertionError: + self.prof.start() + return self.current_time + + prof_averages = self.prof.key_averages() + elapsed_cuda_time = self._get_kernel_time(prof_averages) + self._increment_global_time(elapsed_cuda_time) + # Clear the internal profiler object to avoid accumulating function events and then restart the profiler + # See PR: https://github.com/pytorch/pytorch/pull/125510 + self.prof.profiler = None + + return self.current_time + + def cleanup(self): + """ + Stops a running torchprofiler instance if found. + """ + try: + self.prof.stop() + except AssertionError: + pass + + +class FusionProfileTimer(Timer): + def __init__(self): + super().__init__() + self.fd = None + # Specifies if the timer in host measurement is called at the start/finish of execution. + # Timings are measured at the end of execution. + self.execution_start = True + + def set_fd(self, fd): + self.fd = fd + + def __call__(self): + if not self.execution_start: + profile = self.fd.profile() + elapsed_host_time = profile.host_time_ms / 1e3 + self._increment_global_time(elapsed_host_time) + self.execution_start = not self.execution_start + return self.current_time From b2a76e94c68c738043ac3251611299b720e85c02 Mon Sep 17 00:00:00 2001 From: snordmann Date: Thu, 24 Apr 2025 06:33:09 -0700 Subject: [PATCH 56/68] add comment --- csrc/host_ir/lower.cpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/csrc/host_ir/lower.cpp b/csrc/host_ir/lower.cpp index c36fae09e0a..626a7e67e28 100644 --- a/csrc/host_ir/lower.cpp +++ b/csrc/host_ir/lower.cpp @@ -720,6 +720,9 @@ std::unique_ptr HostIrLower::lower( } for (auto tv : hic->allTvs()) { + // set all host tensors to global memory type. This must be the case by + // definition of a host tensor, and setting the memory type to global is + // also required to avoid Allocate HIR nodes to throw tv->setMemoryType(MemoryType::Global); } From e798e06e4f79ec5a87f9baab10b2d4fc21ec3ebb Mon Sep 17 00:00:00 2001 From: Ryan Spring Date: Thu, 24 Apr 2025 06:55:29 -0700 Subject: [PATCH 57/68] Refactor python build (#4193) This PR updates the build to use a `pyproject.toml` and isolates the python bindings into `python` directory. ## Install From Source: ```bash git clone https://github.com/NVIDIA/Fuser.git cd Fuser pip install -r python/requirements.txt [MAX_JOBS] python setup.py develop [args] # DEPRECATED pip install --no-build-isolation -e python -v ``` ## Details - Moved `csrc/python_frontend` and `nvfuser` to `python`. - Moved `tools/gen_nvfuser_version.py` and `tools/memory.py` to `python`. - Created a new `setup.py` in `python`. This is the new primary `setup.py`. - Updated github workflows - Created symbolic links to support `setup.py` in root directory. ## Changes to argument passing to `root/setup.py` and `root/python/setup.py` - `python/utils.py` has the common utilities between `root/setup.py` and `root/python/setup.py` - Updated argument parsing to use `argparse` to create a `dataclass` configuration. - The `argparse` creates a default `dataclass` if no arguments are not provided in the command line. - `NVFUSER_BUILD_ENV_VARS` then overrides the values in the `dataclass`. - The `root/setup.py` only supports command-line arguments. --------- Co-authored-by: Wang, Xiao <24860335+xwang233@users.noreply.github.com> --- .github/workflows/build.yml | 1 + .github/workflows/lint.yml | 6 + .gitignore | 12 +- .lintrunner.toml | 2 +- CMakeLists.txt | 43 +- README.md | 10 + nvfuser | 1 + python/LICENSE | 1 + {nvfuser => python/nvfuser}/README.md | 0 {nvfuser => python/nvfuser}/__init__.py | 0 {nvfuser => python/nvfuser}/__init__.pyi | 0 .../nvfuser}/benchmark_utils.py | 0 .../nvfuser}/contrib/__init__.py | 0 .../nvfuser}/contrib/nn/__init__.py | 0 .../nvfuser}/contrib/nn/normalization.py | 0 .../nvfuser}/nvfuser_version.py | 0 {nvfuser => python/nvfuser}/pytorch_utils.py | 0 .../nvfuser}/testing/__init__.py | 0 {nvfuser => python/nvfuser}/testing/utils.py | 0 {nvfuser => python/nvfuser}/utils.py | 0 python/pyproject.toml | 3 + .../python_frontend/distributed_tensor.cpp | 0 .../python_frontend/distributed_tensor.h | 0 .../python_frontend/fusion_cache.cpp | 0 .../python_frontend/fusion_cache.h | 0 .../python_frontend/fusion_definition.cpp | 0 .../python_frontend/fusion_definition.h | 0 .../python_frontend/fusion_record.h | 0 .../python_frontend/fusion_state.cpp | 0 .../python_frontend/fusion_state.h | 0 .../python_frontend/multidevice_bindings.cpp | 0 .../python_frontend/python_bindings.cpp | 0 .../python_frontend/python_bindings.h | 0 .../python_bindings_extension.cpp | 0 .../python_frontend/schedule_bindings.cpp | 0 .../python_frontend/segmentation.cpp | 0 .../python_frontend/segmentation.h | 0 .../python_frontend/translation.cpp | 0 .../python_frontend/translation.h | 0 .../python_frontend/translation_utils.cpp | 0 .../python_frontend/translation_utils.h | 0 python/setup.py | 109 ++++ .../test_nvfuser_fusion_cache.cpp | 0 .../test_nvfuser_fusion_definition.cpp | 0 .../test_nvfuser_fusion_record.cpp | 0 python/tools/__init__.py | 3 + python/tools/gen_nvfuser_version.py | 75 +++ python/tools/memory.py | 28 + python/utils.py | 567 ++++++++++++++++++ version.txt => python/version.txt | 0 setup.py | 406 +------------ tools/gen_nvfuser_version.py | 76 +-- tools/memory.py | 29 +- 53 files changed, 868 insertions(+), 504 deletions(-) create mode 120000 nvfuser create mode 120000 python/LICENSE rename {nvfuser => python/nvfuser}/README.md (100%) rename {nvfuser => python/nvfuser}/__init__.py (100%) rename {nvfuser => python/nvfuser}/__init__.pyi (100%) rename {nvfuser => python/nvfuser}/benchmark_utils.py (100%) rename {nvfuser => python/nvfuser}/contrib/__init__.py (100%) rename {nvfuser => python/nvfuser}/contrib/nn/__init__.py (100%) rename {nvfuser => python/nvfuser}/contrib/nn/normalization.py (100%) rename {nvfuser => python/nvfuser}/nvfuser_version.py (100%) rename {nvfuser => python/nvfuser}/pytorch_utils.py (100%) rename {nvfuser => python/nvfuser}/testing/__init__.py (100%) rename {nvfuser => python/nvfuser}/testing/utils.py (100%) rename {nvfuser => python/nvfuser}/utils.py (100%) create mode 100644 python/pyproject.toml rename {csrc => python}/python_frontend/distributed_tensor.cpp (100%) rename {csrc => python}/python_frontend/distributed_tensor.h (100%) rename {csrc => python}/python_frontend/fusion_cache.cpp (100%) rename {csrc => python}/python_frontend/fusion_cache.h (100%) rename {csrc => python}/python_frontend/fusion_definition.cpp (100%) rename {csrc => python}/python_frontend/fusion_definition.h (100%) rename {csrc => python}/python_frontend/fusion_record.h (100%) rename {csrc => python}/python_frontend/fusion_state.cpp (100%) rename {csrc => python}/python_frontend/fusion_state.h (100%) rename {csrc => python}/python_frontend/multidevice_bindings.cpp (100%) rename {csrc => python}/python_frontend/python_bindings.cpp (100%) rename {csrc => python}/python_frontend/python_bindings.h (100%) rename {csrc => python}/python_frontend/python_bindings_extension.cpp (100%) rename {csrc => python}/python_frontend/schedule_bindings.cpp (100%) rename {csrc => python}/python_frontend/segmentation.cpp (100%) rename {csrc => python}/python_frontend/segmentation.h (100%) rename {csrc => python}/python_frontend/translation.cpp (100%) rename {csrc => python}/python_frontend/translation.h (100%) rename {csrc => python}/python_frontend/translation_utils.cpp (100%) rename {csrc => python}/python_frontend/translation_utils.h (100%) create mode 100644 python/setup.py rename {tests/cpp => python/tests}/python_frontend/test_nvfuser_fusion_cache.cpp (100%) rename {tests/cpp => python/tests}/python_frontend/test_nvfuser_fusion_definition.cpp (100%) rename {tests/cpp => python/tests}/python_frontend/test_nvfuser_fusion_record.cpp (100%) create mode 100644 python/tools/__init__.py create mode 100644 python/tools/gen_nvfuser_version.py create mode 100644 python/tools/memory.py create mode 100644 python/utils.py rename version.txt => python/version.txt (100%) mode change 100644 => 120000 tools/gen_nvfuser_version.py mode change 100644 => 120000 tools/memory.py diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index ae099633d78..1186bd19da9 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -31,6 +31,7 @@ jobs: tools/pip-install-things.sh & source tools/setup-env.sh wait + cd python python setup.py build --cpp=23 dynamic-type-meson: diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index a9168522e11..2cc23e73227 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -51,6 +51,9 @@ jobs: wait + # Go to python folder to build cmake files + cd python + # Run cmake build python setup.py --cmake-only @@ -58,6 +61,9 @@ jobs: # NOTE: this might cause a compile of flatbuffers if it is missing ninja -C build build_flatbuffer_config + # Return to root to run clang-tidy + cd .. + # Run lintrunner on all csrc files exclude benchmark and test folders this_commit=$(git rev-parse HEAD) git fetch origin main diff --git a/.gitignore b/.gitignore index 89d7c587c4b..82a26694d84 100644 --- a/.gitignore +++ b/.gitignore @@ -4,20 +4,24 @@ bin # cmake build directory build .lintbin - -# pip wheel directory -dist - nvfuser/version.py nvfuser/include nvfuser/lib nvfuser/share nvfuser/cmake +python/build +python/nvfuser/version.py +python/nvfuser/include +python/nvfuser/lib +python/nvfuser/share +python/nvfuser/cmake + .hypothesis *.egg-info/ **/__pycache__ */*.so +python/nvfuser/*.so # Editor temporaries *.swa diff --git a/.lintrunner.toml b/.lintrunner.toml index 7fcac6c3c4d..d2567f57ff1 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -16,7 +16,7 @@ init_command = [ 'python3', 'tools/linter/adapters/pip_init.py', '--dry-run={{DRYRUN}}', - 'flake8==6.0.0', + 'flake8==6.1.0', ] diff --git a/CMakeLists.txt b/CMakeLists.txt index e83f7a13def..0f14bc07f07 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -10,6 +10,7 @@ set(CMAKE_EXPORT_COMPILE_COMMANDS ON) set(NVFUSER_ROOT ${PROJECT_SOURCE_DIR}) set(NVFUSER_SRCS_DIR "${NVFUSER_ROOT}/csrc") +set(NVFUSER_PYTHON_DIR "${NVFUSER_ROOT}/python") set(NVFUSER_THIRD_PARTY_DIR "${NVFUSER_ROOT}/third_party") option(NVFUSER_STANDALONE_BUILD_WITH_UCC "" OFF) @@ -292,13 +293,13 @@ endif() if(BUILD_PYTHON) list(APPEND NVFUSER_SRCS - ${NVFUSER_SRCS_DIR}/python_frontend/distributed_tensor.cpp - ${NVFUSER_SRCS_DIR}/python_frontend/fusion_cache.cpp - ${NVFUSER_SRCS_DIR}/python_frontend/fusion_definition.cpp - ${NVFUSER_SRCS_DIR}/python_frontend/fusion_state.cpp - ${NVFUSER_SRCS_DIR}/python_frontend/segmentation.cpp - ${NVFUSER_SRCS_DIR}/python_frontend/translation.cpp - ${NVFUSER_SRCS_DIR}/python_frontend/translation_utils.cpp + ${NVFUSER_PYTHON_DIR}/python_frontend/distributed_tensor.cpp + ${NVFUSER_PYTHON_DIR}/python_frontend/fusion_cache.cpp + ${NVFUSER_PYTHON_DIR}/python_frontend/fusion_definition.cpp + ${NVFUSER_PYTHON_DIR}/python_frontend/fusion_state.cpp + ${NVFUSER_PYTHON_DIR}/python_frontend/segmentation.cpp + ${NVFUSER_PYTHON_DIR}/python_frontend/translation.cpp + ${NVFUSER_PYTHON_DIR}/python_frontend/translation_utils.cpp ${NVFUSER_SRCS_DIR}/serde/fusion_record.cpp ) endif() @@ -334,6 +335,7 @@ if(NOT MSVC) endif() target_compile_definitions(codegen_internal PRIVATE "-DTORCH_CUDA_BUILD_MAIN_LIB") +target_include_directories(codegen_internal PUBLIC ${NVFUSER_PYTHON_DIR}) target_include_directories(codegen_internal SYSTEM PUBLIC ${CMAKE_SOURCE_DIR}/third_party/flatbuffers/include PRIVATE @@ -460,31 +462,32 @@ if(BUILD_PYTHON) # nvfuser python API sources set(NVFUSER_PYTHON_SRCS) list(APPEND NVFUSER_PYTHON_SRCS - ${NVFUSER_SRCS_DIR}/python_frontend/multidevice_bindings.cpp - ${NVFUSER_SRCS_DIR}/python_frontend/python_bindings.cpp - ${NVFUSER_SRCS_DIR}/python_frontend/python_bindings_extension.cpp - ${NVFUSER_SRCS_DIR}/python_frontend/schedule_bindings.cpp + ${NVFUSER_PYTHON_DIR}/python_frontend/multidevice_bindings.cpp + ${NVFUSER_PYTHON_DIR}/python_frontend/python_bindings.cpp + ${NVFUSER_PYTHON_DIR}/python_frontend/python_bindings_extension.cpp + ${NVFUSER_PYTHON_DIR}/python_frontend/schedule_bindings.cpp ) add_library(nvf_py_internal OBJECT ${NVFUSER_PYTHON_SRCS}) + target_include_directories(nvf_py_internal PUBLIC ${NVFUSER_PYTHON_DIR}) target_include_directories(nvf_py_internal SYSTEM INTERFACE ${CMAKE_SOURCE_DIR}/third_party/flatbuffers/include ) # setup python API version add_custom_command( - OUTPUT ${NVFUSER_ROOT}/nvfuser/version.py + OUTPUT ${NVFUSER_PYTHON_DIR}/nvfuser/version.py COMMAND - "${PYTHON_EXECUTABLE}" -c \"from pathlib import Path\; Path('${NVFUSER_ROOT}/tools/gen_nvfuser_version.py') .touch() \" + "${PYTHON_EXECUTABLE}" -c \"from pathlib import Path\; Path('${NVFUSER_PYTHON_DIR}/tools/gen_nvfuser_version.py') .touch() \" COMMAND - "${PYTHON_EXECUTABLE}" ${NVFUSER_ROOT}/tools/gen_nvfuser_version.py - DEPENDS ${NVFUSER_ROOT}/tools/gen_nvfuser_version.py - DEPENDS ${NVFUSER_ROOT}/version.txt + "${PYTHON_EXECUTABLE}" ${NVFUSER_PYTHON_DIR}/tools/gen_nvfuser_version.py + DEPENDS ${NVFUSER_PYTHON_DIR}/tools/gen_nvfuser_version.py + DEPENDS ${NVFUSER_PYTHON_DIR}/version.txt WORKING_DIRECTORY ${NVFUSER_ROOT}/tools/ ) add_custom_target( gen_nvfuser_version ALL - DEPENDS ${NVFUSER_ROOT}/nvfuser/version.py + DEPENDS ${NVFUSER_PYTHON_DIR}/nvfuser/version.py ) add_dependencies(nvf_py_internal gen_nvfuser_version) @@ -743,9 +746,9 @@ if(BUILD_TEST) if(BUILD_PYTHON) set(PY_FRONTEND_TEST_SRCS) list(APPEND PY_FRONTEND_TEST_SRCS - ${NVFUSER_ROOT}/tests/cpp/python_frontend/test_nvfuser_fusion_cache.cpp - ${NVFUSER_ROOT}/tests/cpp/python_frontend/test_nvfuser_fusion_definition.cpp - ${NVFUSER_ROOT}/tests/cpp/python_frontend/test_nvfuser_fusion_record.cpp + ${NVFUSER_PYTHON_DIR}/tests/python_frontend/test_nvfuser_fusion_cache.cpp + ${NVFUSER_PYTHON_DIR}/tests/python_frontend/test_nvfuser_fusion_definition.cpp + ${NVFUSER_PYTHON_DIR}/tests/python_frontend/test_nvfuser_fusion_record.cpp ) add_test(test_python_frontend "${PY_FRONTEND_TEST_SRCS}" "") list(APPEND TEST_BINARIES test_python_frontend) diff --git a/README.md b/README.md index 32c3bde8f4e..a00e09921c2 100644 --- a/README.md +++ b/README.md @@ -38,6 +38,16 @@ PyPI: [https://pypi.org/project/nvfuser/](https://pypi.org/search/?q=nvfuser) Docs: https://github.com/NVIDIA/Fuser/wiki +### Install From Source: +```bash +git clone https://github.com/NVIDIA/Fuser.git +cd Fuser +pip install -r python/requirements.txt + +[DEPRECATED] `[MAX_JOBS] python setup.py develop [args]` +pip install --no-build-isolation -e python -v +``` + Supported compilers: **GCC:** diff --git a/nvfuser b/nvfuser new file mode 120000 index 00000000000..25e57deb181 --- /dev/null +++ b/nvfuser @@ -0,0 +1 @@ +python/nvfuser \ No newline at end of file diff --git a/python/LICENSE b/python/LICENSE new file mode 120000 index 00000000000..ea5b60640b0 --- /dev/null +++ b/python/LICENSE @@ -0,0 +1 @@ +../LICENSE \ No newline at end of file diff --git a/nvfuser/README.md b/python/nvfuser/README.md similarity index 100% rename from nvfuser/README.md rename to python/nvfuser/README.md diff --git a/nvfuser/__init__.py b/python/nvfuser/__init__.py similarity index 100% rename from nvfuser/__init__.py rename to python/nvfuser/__init__.py diff --git a/nvfuser/__init__.pyi b/python/nvfuser/__init__.pyi similarity index 100% rename from nvfuser/__init__.pyi rename to python/nvfuser/__init__.pyi diff --git a/nvfuser/benchmark_utils.py b/python/nvfuser/benchmark_utils.py similarity index 100% rename from nvfuser/benchmark_utils.py rename to python/nvfuser/benchmark_utils.py diff --git a/nvfuser/contrib/__init__.py b/python/nvfuser/contrib/__init__.py similarity index 100% rename from nvfuser/contrib/__init__.py rename to python/nvfuser/contrib/__init__.py diff --git a/nvfuser/contrib/nn/__init__.py b/python/nvfuser/contrib/nn/__init__.py similarity index 100% rename from nvfuser/contrib/nn/__init__.py rename to python/nvfuser/contrib/nn/__init__.py diff --git a/nvfuser/contrib/nn/normalization.py b/python/nvfuser/contrib/nn/normalization.py similarity index 100% rename from nvfuser/contrib/nn/normalization.py rename to python/nvfuser/contrib/nn/normalization.py diff --git a/nvfuser/nvfuser_version.py b/python/nvfuser/nvfuser_version.py similarity index 100% rename from nvfuser/nvfuser_version.py rename to python/nvfuser/nvfuser_version.py diff --git a/nvfuser/pytorch_utils.py b/python/nvfuser/pytorch_utils.py similarity index 100% rename from nvfuser/pytorch_utils.py rename to python/nvfuser/pytorch_utils.py diff --git a/nvfuser/testing/__init__.py b/python/nvfuser/testing/__init__.py similarity index 100% rename from nvfuser/testing/__init__.py rename to python/nvfuser/testing/__init__.py diff --git a/nvfuser/testing/utils.py b/python/nvfuser/testing/utils.py similarity index 100% rename from nvfuser/testing/utils.py rename to python/nvfuser/testing/utils.py diff --git a/nvfuser/utils.py b/python/nvfuser/utils.py similarity index 100% rename from nvfuser/utils.py rename to python/nvfuser/utils.py diff --git a/python/pyproject.toml b/python/pyproject.toml new file mode 100644 index 00000000000..d7813c1ed06 --- /dev/null +++ b/python/pyproject.toml @@ -0,0 +1,3 @@ +[build-system] +requires = ["setuptools>=42", "wheel", "ninja", "cmake>=3.18"] +build-backend = "setuptools.build_meta:__legacy__" diff --git a/csrc/python_frontend/distributed_tensor.cpp b/python/python_frontend/distributed_tensor.cpp similarity index 100% rename from csrc/python_frontend/distributed_tensor.cpp rename to python/python_frontend/distributed_tensor.cpp diff --git a/csrc/python_frontend/distributed_tensor.h b/python/python_frontend/distributed_tensor.h similarity index 100% rename from csrc/python_frontend/distributed_tensor.h rename to python/python_frontend/distributed_tensor.h diff --git a/csrc/python_frontend/fusion_cache.cpp b/python/python_frontend/fusion_cache.cpp similarity index 100% rename from csrc/python_frontend/fusion_cache.cpp rename to python/python_frontend/fusion_cache.cpp diff --git a/csrc/python_frontend/fusion_cache.h b/python/python_frontend/fusion_cache.h similarity index 100% rename from csrc/python_frontend/fusion_cache.h rename to python/python_frontend/fusion_cache.h diff --git a/csrc/python_frontend/fusion_definition.cpp b/python/python_frontend/fusion_definition.cpp similarity index 100% rename from csrc/python_frontend/fusion_definition.cpp rename to python/python_frontend/fusion_definition.cpp diff --git a/csrc/python_frontend/fusion_definition.h b/python/python_frontend/fusion_definition.h similarity index 100% rename from csrc/python_frontend/fusion_definition.h rename to python/python_frontend/fusion_definition.h diff --git a/csrc/python_frontend/fusion_record.h b/python/python_frontend/fusion_record.h similarity index 100% rename from csrc/python_frontend/fusion_record.h rename to python/python_frontend/fusion_record.h diff --git a/csrc/python_frontend/fusion_state.cpp b/python/python_frontend/fusion_state.cpp similarity index 100% rename from csrc/python_frontend/fusion_state.cpp rename to python/python_frontend/fusion_state.cpp diff --git a/csrc/python_frontend/fusion_state.h b/python/python_frontend/fusion_state.h similarity index 100% rename from csrc/python_frontend/fusion_state.h rename to python/python_frontend/fusion_state.h diff --git a/csrc/python_frontend/multidevice_bindings.cpp b/python/python_frontend/multidevice_bindings.cpp similarity index 100% rename from csrc/python_frontend/multidevice_bindings.cpp rename to python/python_frontend/multidevice_bindings.cpp diff --git a/csrc/python_frontend/python_bindings.cpp b/python/python_frontend/python_bindings.cpp similarity index 100% rename from csrc/python_frontend/python_bindings.cpp rename to python/python_frontend/python_bindings.cpp diff --git a/csrc/python_frontend/python_bindings.h b/python/python_frontend/python_bindings.h similarity index 100% rename from csrc/python_frontend/python_bindings.h rename to python/python_frontend/python_bindings.h diff --git a/csrc/python_frontend/python_bindings_extension.cpp b/python/python_frontend/python_bindings_extension.cpp similarity index 100% rename from csrc/python_frontend/python_bindings_extension.cpp rename to python/python_frontend/python_bindings_extension.cpp diff --git a/csrc/python_frontend/schedule_bindings.cpp b/python/python_frontend/schedule_bindings.cpp similarity index 100% rename from csrc/python_frontend/schedule_bindings.cpp rename to python/python_frontend/schedule_bindings.cpp diff --git a/csrc/python_frontend/segmentation.cpp b/python/python_frontend/segmentation.cpp similarity index 100% rename from csrc/python_frontend/segmentation.cpp rename to python/python_frontend/segmentation.cpp diff --git a/csrc/python_frontend/segmentation.h b/python/python_frontend/segmentation.h similarity index 100% rename from csrc/python_frontend/segmentation.h rename to python/python_frontend/segmentation.h diff --git a/csrc/python_frontend/translation.cpp b/python/python_frontend/translation.cpp similarity index 100% rename from csrc/python_frontend/translation.cpp rename to python/python_frontend/translation.cpp diff --git a/csrc/python_frontend/translation.h b/python/python_frontend/translation.h similarity index 100% rename from csrc/python_frontend/translation.h rename to python/python_frontend/translation.h diff --git a/csrc/python_frontend/translation_utils.cpp b/python/python_frontend/translation_utils.cpp similarity index 100% rename from csrc/python_frontend/translation_utils.cpp rename to python/python_frontend/translation_utils.cpp diff --git a/csrc/python_frontend/translation_utils.h b/python/python_frontend/translation_utils.h similarity index 100% rename from csrc/python_frontend/translation_utils.h rename to python/python_frontend/translation_utils.h diff --git a/python/setup.py b/python/setup.py new file mode 100644 index 00000000000..4b39f2563fe --- /dev/null +++ b/python/setup.py @@ -0,0 +1,109 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-present NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Usage: +# pip install --no-build-isolation -e python -v +# This build command is equivalent to: python setup.py develop +# Options: +# -v: verbose output +# --no-build-isolation: don't build in a temporary directory +# -e: install in development mode +# +# Environment variables used during build: +# MAX_JOBS +# maximum number of compile jobs we should use to compile your code +# +# NVFUSER_BUILD_CMAKE_ONLY +# Only generate ./build directory with cmake setup +# +# NVFUSER_BUILD_NO_PYTHON +# Skips python API target `libnvfuser.so`, i.e. `_C.cpython-xxx.so` +# +# NVFUSER_BUILD_NO_TEST +# Skips cpp tests `test_nvfuser` +# +# NVFUSER_BUILD_NO_BENCHMARK +# Skips benchmark target `nvfuser_bench` +# +# NVFUSER_BUILD_NO_NINJA +# In case you want to use make instead of ninja for build +# +# NVFUSER_BUILD_WITH_UCC +# Build nvfuser with UCC support. You may need to specify environment variables of UCC_HOME, UCC_DIR, UCX_HOME, UCX_DIR. +# +# NVFUSER_BUILD_WITHOUT_DISTRIBUTED +# Build nvfuser without multidevice support +# +# NVFUSER_BUILD_TYPE=Debug +# Building nvfuser in debug mode +# +# NVFUSER_BUILD_TYPE=RelwithDebInfo +# Building nvfuser in release mode with debug info, a.k.a. RelwithDebInfo +# +# NVFUSER_BUILD_DIR= +# Specify in which directory to build nvfuser. If not specified, the default build directory is "./build". +# +# NVFUSER_BUILD_INSTALL_DIR= +# Specify in which directory to install nvfuser. If not specified, the default install directory is "./python/nvfuser". +# +# NVFUSER_BUILD_VERSION_TAG=TAG +# Specify the tag for build nvfuser version, this is used for pip wheel +# package nightly where we might want to add a date tag +# nvfuser-VERSION+TAG+gitSHA1-....-whl +# +# NVFUSER_BUILD_INSTALL_REQUIRES=pkg0[,pkg1...] +# this is used for pip wheel build to specify package required for install +# e.g. NVFUSER_BUILD_INSTALL_REQUIRES=nvidia-cuda-nvrtc-cu12 +# +# NVFUSER_BUILD_WHEEL_NAME=NAME +# Specify the wheel name this is used for pip wheel package where we want +# to identify the cuda toolkit version +# +# NVFUSER_BUILD_CPP_STANDARD=STANDARD +# Specify the C++ standard to use for building nvfuser. The default is C++20. +# + +import sys + +from utils import ( + run, + create_build_config, + override_build_config_from_env, +) + + +def version_tag(config): + from tools.gen_nvfuser_version import get_version + + version = get_version() + if config.overwrite_version: + version = version.split("+")[0] + if len(config.version_tag) != 0: + # use "." to be pypi friendly + version = ".".join([version, config.version_tag]) + return version + + +def main(): + # Parse arguments using argparse + # Use argparse to create description of arguments from command line + config, forward_args = create_build_config() + + # Override build config from environment variables + override_build_config_from_env(config) + + if "clean" in sys.argv: + # only disables BUILD_SETUP, but keep the argument for setuptools + config.build_setup = False + + if config.cpp_standard < 20: + raise ValueError("nvfuser requires C++20 standard or higher") + + sys.argv = [sys.argv[0]] + forward_args + + run(config, version_tag(config), relative_path="..") + + +if __name__ == "__main__": + main() diff --git a/tests/cpp/python_frontend/test_nvfuser_fusion_cache.cpp b/python/tests/python_frontend/test_nvfuser_fusion_cache.cpp similarity index 100% rename from tests/cpp/python_frontend/test_nvfuser_fusion_cache.cpp rename to python/tests/python_frontend/test_nvfuser_fusion_cache.cpp diff --git a/tests/cpp/python_frontend/test_nvfuser_fusion_definition.cpp b/python/tests/python_frontend/test_nvfuser_fusion_definition.cpp similarity index 100% rename from tests/cpp/python_frontend/test_nvfuser_fusion_definition.cpp rename to python/tests/python_frontend/test_nvfuser_fusion_definition.cpp diff --git a/tests/cpp/python_frontend/test_nvfuser_fusion_record.cpp b/python/tests/python_frontend/test_nvfuser_fusion_record.cpp similarity index 100% rename from tests/cpp/python_frontend/test_nvfuser_fusion_record.cpp rename to python/tests/python_frontend/test_nvfuser_fusion_record.cpp diff --git a/python/tools/__init__.py b/python/tools/__init__.py new file mode 100644 index 00000000000..51ba303bccb --- /dev/null +++ b/python/tools/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-present NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause diff --git a/python/tools/gen_nvfuser_version.py b/python/tools/gen_nvfuser_version.py new file mode 100644 index 00000000000..a09eda53539 --- /dev/null +++ b/python/tools/gen_nvfuser_version.py @@ -0,0 +1,75 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +import subprocess +import sys +from pathlib import Path + +UNKNOWN = "Unknown" +nvfuser_root = Path(__file__).parent.parent + + +# note that this root currently is still part of pytorch. +def get_sha() -> str: + try: + return ( + subprocess.check_output(["git", "rev-parse", "HEAD"], cwd=nvfuser_root) + .decode("ascii") + .strip() + ) + except Exception: + import os + + # assume the $NVFUSER_VERSION is in sha form + if nvfuser_version := os.environ.get("NVFUSER_VERSION"): + assert ( + len(nvfuser_version) < 11 + ), "The NVFUSER_VERSION should be in sha form" + return nvfuser_version + return UNKNOWN + + +def get_version() -> str: + sha = get_sha() + version = ( + open((nvfuser_root / "version.txt"), "r").read().strip() + "+git" + sha[:7] + ) + return version + + +def get_pytorch_cmake_prefix(): + from subprocess import Popen, PIPE + + # need to do this in a separate process so we are not going to delete nvfuser library while it's loaded by torch + process_torch_prefix = Popen( + [ + sys.executable, + "-c", + "import torch.utils; print(torch.utils.cmake_prefix_path)", + ], + stdout=PIPE, + ) + stdout_msg, error_msg = process_torch_prefix.communicate() + return stdout_msg.decode("utf-8").rstrip("\n") + + +def get_pytorch_use_distributed(): + from subprocess import Popen, PIPE + + # need to do this in a separate process so we are not going to delete nvfuser library while it's loaded by torch + process_torch_prefix = Popen( + [ + sys.executable, + "-c", + "import torch; print(torch._C._has_distributed())", + ], + stdout=PIPE, + ) + stdout_msg, error_msg = process_torch_prefix.communicate() + return stdout_msg.decode("utf-8").rstrip("\n") + + +if __name__ == "__main__": + version_file = nvfuser_root / "nvfuser" / "version.py" + with open(version_file, "w") as f: + f.write("_version_str = '{}'\n".format(get_version())) diff --git a/python/tools/memory.py b/python/tools/memory.py new file mode 100644 index 00000000000..1ed95f8ded5 --- /dev/null +++ b/python/tools/memory.py @@ -0,0 +1,28 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + + +def get_available_memory_gb(): + """Returns the available memory in GB.""" + try: + import psutil + + return psutil.virtual_memory().available / 1024 / 1024 / 1024 + except: # noqa: E722 + pass + + try: + with open("/proc/meminfo", "r") as f: + while True: + line = f.readline() + if line.startswith("MemAvailable:"): + mem = line.split()[1] + assert line.split()[2] == "kB" + return int(mem) / 1024 / 1024 + if not line: + break + except: # noqa: E722 + pass + + return 0 diff --git a/python/utils.py b/python/utils.py new file mode 100644 index 00000000000..3c13b898f6d --- /dev/null +++ b/python/utils.py @@ -0,0 +1,567 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-present NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +import argparse +import os +import multiprocessing +import subprocess +import sys +import shutil +from dataclasses import dataclass +import setuptools.command.build_ext + + +@dataclass +class BuildConfig: + cmake_only: bool = False + build_setup: bool = True + no_python: bool = False + no_test: bool = False + no_benchmark: bool = False + no_ninja: bool = False + build_with_ucc: bool = False + build_with_asan: bool = False + build_without_distributed: bool = False + build_with_system_nvtx: bool = True + explicit_error_check: bool = False + overwrite_version: bool = False + version_tag: str = None + build_type: str = "Release" + wheel_name: str = "nvfuser" + build_dir: str = "" + install_dir: str = "" + install_requires: list = None + extras_require: dict = None + cpp_standard: int = 20 + + def __post_init__(self): + # dataclass cannot have mutable default values in the class definition + if self.install_requires is None: + self.install_requires = [] + if self.extras_require is None: + self.extras_require = {} + + +def check_env_flag_bool_default(name: str, default: str = "") -> bool: + if name not in os.environ: + return default + return os.getenv(name).upper() in ["ON", "1", "YES", "TRUE", "Y"] + + +def get_env_flag_bool(name: str) -> bool: + assert name in os.environ + return os.getenv(name).upper() in ["ON", "1", "YES", "TRUE", "Y"] + + +def parse_args(): + parser = argparse.ArgumentParser( + description="NVFUSER build options", add_help=False + ) + + # Add arguments that don't go to setuptools + parser.add_argument( + "--cmake-only", + dest="cmake_only", + action="store_true", + help="Only generate ./build directory with cmake setup", + ) + parser.add_argument( + "--no-python", + dest="no_python", + action="store_true", + help="Skips python API target libnvfuser.so", + ) + parser.add_argument( + "--no-test", + dest="no_test", + action="store_true", + help="Skips cpp tests test_nvfuser", + ) + parser.add_argument( + "--no-benchmark", + dest="no_benchmark", + action="store_true", + help="Skips benchmark target nvfuser_bench", + ) + parser.add_argument( + "--no-ninja", + dest="no_ninja", + action="store_true", + help="Use make instead of ninja for build", + ) + parser.add_argument( + "--build-with-ucc", + dest="build_with_ucc", + action="store_true", + help="Build nvfuser with UCC support", + ) + parser.add_argument( + "--explicit-error-check", + dest="explicit_error_check", + action="store_true", + help="Enable explicit error checking", + ) + parser.add_argument( + "--build-with-asan", + dest="build_with_asan", + action="store_true", + help="Build with Address Sanitizer", + ) + parser.add_argument( + "--build-without-distributed", + dest="build_without_distributed", + action="store_true", + help="Build nvfuser without multidevice support", + ) + parser.add_argument( + "--no-system-nvtx", + dest="no_system_nvtx", + action="store_true", + help="Disable system NVTX", + ) + parser.add_argument( + "--debug", + dest="debug_mode", + action="store_true", + help="Building nvfuser in debug mode", + ) + parser.add_argument( + "--debinfo", + dest="debinfo_mode", + action="store_true", + help="Building nvfuser in release mode with debug info", + ) + parser.add_argument( + "--build-dir", + dest="build_dir", + type=str, + default="", + help="Specify in which directory to build nvfuser", + ) + parser.add_argument( + "--install-dir", + dest="install_dir", + type=str, + default="", + help="Specify in which directory to install nvfuser", + ) + parser.add_argument( + "-install_requires", + dest="install_requires", + type=str, + help="Specify package required for installation", + ) + parser.add_argument( + "--extras_require", + dest="extras_require", + type=str, + help="Specify extra requirements", + ) + parser.add_argument( + "-version-tag", + dest="version_tag", + type=str, + help="Specify the tag for build nvfuser version", + ) + parser.add_argument( + "-wheel-name", + dest="wheel_name", + type=str, + default="nvfuser", + help="Specify the wheel name", + ) + parser.add_argument( + "--cpp", + dest="cpp_standard", + type=int, + help="Specify the C++ standard to use", + default=20, + ) + + # Use parse_known_args to separate our arguments from setuptools arguments + args, forward_args = parser.parse_known_args() + return args, forward_args + + +# Create BuildConfig using argparse +def create_build_config(): + # Parse arguments and set global variables accordingly + args, forward_args = parse_args() + + # Create a BuildConfig from args + config = BuildConfig( + cmake_only=args.cmake_only, + no_python=args.no_python, + no_test=args.no_test, + no_benchmark=args.no_benchmark, + no_ninja=args.no_ninja, + build_with_ucc=args.build_with_ucc, + build_with_asan=args.build_with_asan, + build_without_distributed=args.build_without_distributed, + build_with_system_nvtx=not args.no_system_nvtx, + explicit_error_check=args.explicit_error_check, + wheel_name=args.wheel_name, + build_dir=args.build_dir, + install_dir=args.install_dir, + cpp_standard=args.cpp_standard, + ) + + # Apply remaining options + if args.debug_mode: + config.build_type = "Debug" + if args.debinfo_mode: + config.build_type = "RelwithDebInfo" + if args.install_requires: + config.install_requires = args.install_requires.split(",") + if args.extras_require: + config.extras_require = eval(args.extras_require) + if args.version_tag: + config.version_tag = args.version_tag + config.overwrite_version = True + return config, forward_args + + +# Override BuildConfig with environment variables. Only change if variable +# exists. Do not use default to override argparse. +def override_build_config_from_env(config): + # Command line arguments don't work on PEP517 builds and will be silently ignored, + # so we need to pass those options as environment variables instead. + if "NVFUSER_BUILD_CMAKE_ONLY" in os.environ: + config.cmake_only = get_env_flag_bool("NVFUSER_BUILD_CMAKE_ONLY") + if "NVFUSER_BUILD_SETUP" in os.environ: + config.build_setup = get_env_flag_bool("NVFUSER_BUILD_SETUP") + if "NVFUSER_BUILD_NO_PYTHON" in os.environ: + config.no_python = get_env_flag_bool("NVFUSER_BUILD_NO_PYTHON") + if "NVFUSER_BUILD_NO_TEST" in os.environ: + config.no_test = get_env_flag_bool("NVFUSER_BUILD_NO_TEST") + if "NVFUSER_BUILD_NO_BENCHMARK" in os.environ: + config.no_benchmark = get_env_flag_bool("NVFUSER_BUILD_NO_BENCHMARK") + if "NVFUSER_BUILD_NO_NINJA" in os.environ: + config.no_ninja = get_env_flag_bool("NVFUSER_BUILD_NO_NINJA") + if "NVFUSER_BUILD_WITH_UCC" in os.environ: + config.build_with_ucc = get_env_flag_bool("NVFUSER_BUILD_WITH_UCC") + if "NVFUSER_BUILD_WITH_ASAN" in os.environ: + config.build_with_asan = get_env_flag_bool("NVFUSER_BUILD_WITH_ASAN") + if "NVFUSER_BUILD_WITHOUT_DISTRIBUTED" in os.environ: + config.build_without_distributed = get_env_flag_bool( + "NVFUSER_BUILD_WITHOUT_DISTRIBUTED" + ) + if "NVFUSER_BUILD_WITH_SYSTEM_NVTX" in os.environ: + config.build_with_system_nvtx = get_env_flag_bool( + "NVFUSER_BUILD_WITH_SYSTEM_NVTX" + ) + if "NVFUSER_BUILD_EXPLICIT_ERROR_CHECK" in os.environ: + config.explicit_error_check = get_env_flag_bool( + "NVFUSER_BUILD_EXPLICIT_ERROR_CHECK" + ) + if "NVFUSER_BUILD_OVERWRITE_VERSION" in os.environ: + config.overwrite_version = get_env_flag_bool("NVFUSER_BUILD_OVERWRITE_VERSION") + if "NVFUSER_BUILD_VERSION_TAG" in os.environ: + config.version_tag = os.getenv("NVFUSER_BUILD_VERSION_TAG") + if "NVFUSER_BUILD_BUILD_TYPE" in os.environ: + config.build_type = os.getenv("NVFUSER_BUILD_BUILD_TYPE") + if "NVFUSER_BUILD_WHEEL_NAME" in os.environ: + config.wheel_name = os.getenv("NVFUSER_BUILD_WHEEL_NAME") + if "NVFUSER_BUILD_DIR" in os.environ: + config.build_dir = os.getenv("NVFUSER_BUILD_DIR") + if "NVFUSER_BUILD_INSTALL_DIR" in os.environ: + config.install_dir = os.getenv("NVFUSER_BUILD_INSTALL_DIR") + if "NVFUSER_BUILD_INSTALL_REQUIRES" in os.environ: + config.install_requires = os.getenv("NVFUSER_BUILD_INSTALL_REQUIRES").split(",") + if "NVFUSER_BUILD_EXTRAS_REQUIRE" in os.environ: + config.extras_require = eval(os.getenv("NVFUSER_BUILD_EXTRAS_REQUIRE")) + if "NVFUSER_BUILD_CPP_STANDARD" in os.environ: + config.cpp_standard = int(os.getenv("NVFUSER_BUILD_CPP_STANDARD")) + if "NVFUSER_BUILD_VERSION_TAG" in os.environ: + config.overwrite_version = True + config.version_tag = os.getenv("NVFUSER_BUILD_VERSION_TAG") + + +class build_ext(setuptools.command.build_ext.build_ext): + def build_extension(self, ext): + if ext.name == "nvfuser._C": + # Copy files on necessity. + filename = self.get_ext_filename(self.get_ext_fullname(ext.name)) + fileext = os.path.splitext(filename)[1] + + libnvfuser_path = os.path.join("./nvfuser/lib", f"libnvfuser{fileext}") + assert os.path.exists(libnvfuser_path) + install_dst = os.path.join(self.build_lib, filename) + if not os.path.exists(os.path.dirname(install_dst)): + os.makedirs(os.path.dirname(install_dst)) + self.copy_file(libnvfuser_path, install_dst) + else: + super().build_extension(ext) + + +class concat_third_party_license: + def __init__(self, directory="third_party"): + self.license_file = "LICENSE" + self.directory = directory + + def __enter__(self): + # read original license file + with open(self.license_file, "r") as f: + self.nvfuser_license_txt = f.read() + + licenses = {"LICENSE", "LICENSE.txt", "LICENSE.rst", "COPYING.BSD"} + + # aggregated license, we key on project name + aggregated_license = {} + for root, dirs, files in os.walk(self.directory): + license = list(licenses & set(files)) + if license: + project_name = root.split("/")[-1] + # let's worry about multiple license when we see it. + assert len(license) == 1 + license_entry = os.path.join(root, license[0]) + if project_name in aggregated_license: + # Only add it if the license is different + aggregated_license[project_name].append(license_entry) + else: + aggregated_license[project_name] = [license_entry] + return aggregated_license + + def __exit__(self, exception_type, exception_value, traceback): + # restore original license file + with open(self.license_file, "w") as f: + f.write(self.nvfuser_license_txt) + + +try: + from wheel.bdist_wheel import bdist_wheel +except ImportError: + build_whl = None +else: + + class build_whl(bdist_wheel): + def run(self): + with concat_third_party_license() as tp_licenses: + if len(tp_licenses) != 0: + with open("LICENSE", "a") as f: + f.write("\n\n") + f.write( + "NVIDIA/fuser depends on libraries with license listed below:" + ) + + for project_name, license_files in tp_licenses.items(): + # check all license files are identical + with open(license_files[0], "r") as f: + license_ref = f.read() + + def check_file(file_name): + with open(file_name, "r") as f: + return f.read() == license_ref + + identical_flag = all(map(check_file, license_files[1:])) + if not identical_flag: + raise RuntimeError( + "inconsistent license found for project: ", + project_name, + " check its license files under: ", + license_files, + ) + + with open("LICENSE", "a") as f: + f.write("\n\nProject Name: " + project_name) + f.write("\nLicense Files:\n") + for file_name in license_files: + f.write("\t" + file_name) + f.write("\n" + license_ref) + + # generate whl before we restore LICENSE + super().run() + + +def get_cmake_bin(): + # TODO: double check cmake version here and retrieve later version if necessary + return "cmake" + + +def cmake(config, relative_path): + from tools.memory import get_available_memory_gb + + # make build directories + cwd = os.path.dirname(os.path.abspath(__file__)) + cmake_build_dir = ( + os.path.join(cwd, "build") if not config.build_dir else config.build_dir + ) + if not os.path.exists(cmake_build_dir): + os.makedirs(cmake_build_dir) + + install_prefix = ( + os.path.join(cwd, "nvfuser") if not config.install_dir else config.install_dir + ) + + from tools.gen_nvfuser_version import ( + get_pytorch_cmake_prefix, + get_pytorch_use_distributed, + ) + + # this is used to suppress import error. + # so we can get the right pytorch prefix for cmake + import logging + + logger = logging.getLogger("nvfuser") + logger_level = logger.getEffectiveLevel() + logger.setLevel(logging.CRITICAL) + + pytorch_cmake_config = "-DCMAKE_PREFIX_PATH=" + get_pytorch_cmake_prefix() + + logger.setLevel(logger_level) + + pytorch_use_distributed = get_pytorch_use_distributed() + + # generate cmake directory + cmd_str = [ + get_cmake_bin(), + pytorch_cmake_config, + "-DCMAKE_BUILD_TYPE=" + config.build_type, + f"-DCMAKE_INSTALL_PREFIX={install_prefix}", + f"-DNVFUSER_CPP_STANDARD={config.cpp_standard}", + f"-DUSE_DISTRIBUTED={pytorch_use_distributed}", + "-B", + cmake_build_dir, + ] + if config.build_with_ucc: + cmd_str.append("-DNVFUSER_STANDALONE_BUILD_WITH_UCC=ON") + if config.explicit_error_check: + cmd_str.append("-DNVFUSER_EXPLICIT_ERROR_CHECK=ON") + if not config.no_ninja: + cmd_str.append("-G") + cmd_str.append("Ninja") + if not config.no_test: + cmd_str.append("-DBUILD_TEST=ON") + if not config.no_python: + cmd_str.append("-DBUILD_PYTHON=ON") + cmd_str.append(f"-DPython_EXECUTABLE={sys.executable}") + if not config.no_benchmark: + cmd_str.append("-DBUILD_NVFUSER_BENCHMARK=ON") + if config.build_with_asan: + cmd_str.append("-DNVFUSER_BUILD_WITH_ASAN=ON") + if config.build_without_distributed: + cmd_str.append("-DNVFUSER_DISTRIBUTED=OFF") + if config.build_with_system_nvtx: + cmd_str.append("-DUSE_SYSTEM_NVTX=ON") + cmd_str.append(relative_path) + + print(f"Configuring CMake with {' '.join(cmd_str)}") + subprocess.check_call(cmd_str) + + max_jobs = multiprocessing.cpu_count() + mem_gb_per_task = 3 # Currently compilation of nvFuser souce code takes ~3GB of memory per task, we should adjust this value if it changes in the future. + available_mem = get_available_memory_gb() + if available_mem > 0: + max_jobs_mem = int(available_mem / mem_gb_per_task) + max_jobs = min(max_jobs, max_jobs_mem) + + if not config.cmake_only: + # build binary + max_jobs = os.getenv("MAX_JOBS", str(max_jobs)) + print(f"Using {max_jobs} jobs for compilation") + cmd_str = [ + get_cmake_bin(), + "--build", + cmake_build_dir, + "--target", + "install", + "--", + "-j", + max_jobs, + ] + subprocess.check_call(cmd_str) + + +def create_clean(relative_path): + class clean(setuptools.Command): + user_options = [] + + def initialize_options(self): + pass + + def finalize_options(self): + pass + + def run(self): + import glob + + gitignore_path = os.path.join(relative_path, ".gitignore") + assert os.path.exists(gitignore_path) + with open(gitignore_path, "r") as f: + ignores = f.read() + for entry in ignores.split("\n"): + # ignore comment in .gitignore + if len(entry) >= 1 and entry[0] != "#": + for filename in glob.glob(entry): + print("removing: ", filename) + try: + os.remove(filename) + except OSError: + shutil.rmtree(filename, ignore_errors=True) + + return clean + + +def run(config, version_tag, relative_path): + from setuptools import Extension, setup, find_packages + + # NOTE(crcrpar): Deliberately build basically two dynamic libraries here so that they can + # be treated as "nvfuser_package_data". This function call will put the two of "nvfuser" and + # "nvfuser_codegen" into "./nvfuser/lib", and the former will be "nvfuser._C". + if config.build_setup: + cmake(config, relative_path) + if not config.cmake_only: + # NOTE: package include files for cmake + # TODO(crcrpar): Better avoid hardcoding `libnvfuser_codegen.so` + # might can be treated by using `exclude_package_data`. + nvfuser_package_data = [ + "lib/libnvfuser_codegen.so", + "include/nvfuser/*.h", + "include/nvfuser/struct.inl", + "include/nvfuser/C++20/type_traits", + "include/nvfuser/device_lower/*.h", + "include/nvfuser/device_lower/analysis/*.h", + "include/nvfuser/device_lower/pass/*.h", + "include/nvfuser/dynamic_type/*", + "include/nvfuser/dynamic_type/C++20/*", + "include/nvfuser/kernel_db/*.h", + "include/nvfuser/multidevice/*.h", + "include/nvfuser/ops/*.h", + "include/nvfuser/ir/*.h", + "include/nvfuser/python_frontend/*.h", + "include/nvfuser/scheduler/*.h", + "include/nvfuser/serde/*.h", + "include/nvfuser/flatbuffers/*.h", + "include/nvfuser/host_ir/*.h", + "include/nvfuser/id_model/*.h", + "share/cmake/nvfuser/NvfuserConfig*", + # TODO(crcrpar): it'd be better to ship the following two binaries. + # Would need some change in CMakeLists.txt. + # "bin/test_nvfuser", + # "bin/nvfuser_bench" + ] + + setup( + name=config.wheel_name, + version=version_tag, + url="https://github.com/NVIDIA/Fuser", + description="A Fusion Code Generator for NVIDIA GPUs (commonly known as 'nvFuser')", + packages=find_packages(), + ext_modules=[Extension(name="nvfuser._C", sources=[])], + license_files=("LICENSE",), + cmdclass={ + "bdist_wheel": build_whl, + "build_ext": build_ext, + "clean": create_clean(relative_path), + }, + package_data={ + "nvfuser": nvfuser_package_data, + }, + install_requires=config.install_requires, + extras_require={ + "test": ["numpy", "expecttest", "pytest"], + **config.extras_require, + }, + license="BSD-3-Clause", + ) diff --git a/version.txt b/python/version.txt similarity index 100% rename from version.txt rename to python/version.txt diff --git a/setup.py b/setup.py index 4aced7e1a57..e1cd19eb726 100644 --- a/setup.py +++ b/setup.py @@ -1,12 +1,16 @@ # SPDX-FileCopyrightText: Copyright (c) 2024-present NVIDIA CORPORATION & AFFILIATES. # All rights reserved. # SPDX-License-Identifier: BSD-3-Clause +# +# Usage: +# [MAX_JOBS] python setup.py develop [args] +# # Environment variables used during build: # -# MAX_JOBS +# MAX_JOBS # maximum number of compile jobs we should use to compile your code # -# build argument: +# NvFuser build arguments: # # --cmake-only # Only generate ./build directory with cmake setup @@ -39,7 +43,7 @@ # Specify in which directory to build nvfuser. If not specified, the default build directory is "./build". # # --install-dir= -# Specify in which directory to install nvfuser. If not specified, the default install directory is "./nvfuser". +# Specify in which directory to install nvfuser. If not specified, the default install directory is "./python/nvfuser". # # -version-tag=TAG # Specify the tag for build nvfuser version, this is used for pip wheel @@ -58,395 +62,45 @@ # Specify the C++ standard to use for building nvfuser. The default is C++20. # -import multiprocessing -import os -import shutil -import subprocess -import sys - -import setuptools -import setuptools.command.build_ext -from setuptools import Extension, setup, find_packages - -# pick args used by this script -CMAKE_ONLY = False -BUILD_SETUP = True -NO_PYTHON = False -NO_TEST = False -NO_BENCHMARK = False -NO_NINJA = False -BUILD_WITH_UCC = False -BUILD_WITH_ASAN = False -BUILD_WITHOUT_DISTRIBUTED = False -BUILD_WITH_SYSTEM_NVTX = True -OVERWRITE_VERSION = False -EXPLICIT_ERROR_CHECK = False -VERSION_TAG = None -BUILD_TYPE = "Release" -WHEEL_NAME = "nvfuser" -BUILD_DIR = "" -INSTALL_DIR = "" -INSTALL_REQUIRES = [] -EXTRAS_REQUIRE = {} -CPP_STANDARD = 20 -forward_args = [] -for i, arg in enumerate(sys.argv): - if arg == "--cmake-only": - CMAKE_ONLY = True - continue - if arg == "--no-python": - NO_PYTHON = True - continue - if arg == "--no-test": - NO_TEST = True - continue - if arg == "--no-benchmark": - NO_BENCHMARK = True - continue - if arg == "--no-ninja": - NO_NINJA = True - continue - if arg == "--build-with-ucc": - BUILD_WITH_UCC = True - continue - if arg == "--explicit-error-check": - EXPLICIT_ERROR_CHECK = True - continue - if arg == "--build-with-asan": - BUILD_WITH_ASAN = True - continue - if arg == "--build-without-distributed": - BUILD_WITHOUT_DISTRIBUTED = True - continue - if arg == "--no-system-nvtx": - BUILD_WITH_SYSTEM_NVTX = False - continue - if arg == "--debug": - BUILD_TYPE = "Debug" - continue - if arg == "--debinfo": - BUILD_TYPE = "RelwithDebInfo" - continue - if arg.startswith("--build-dir"): - BUILD_DIR = arg.split("=")[1] - continue - if arg.startswith("--install-dir"): - INSTALL_DIR = arg.split("=")[1] - continue - if arg.startswith("-install_requires="): - INSTALL_REQUIRES = arg.split("=")[1].split(",") - continue - if arg.startswith("--extras_require="): - EXTRAS_REQUIRE = eval("=".join(arg.split("=")[1:])) - continue - if arg.startswith("-version-tag="): - OVERWRITE_VERSION = True - VERSION_TAG = arg.split("=")[1] - continue - if arg.startswith("-wheel-name="): - WHEEL_NAME = arg.split("=")[1] - continue - if arg.startswith("--cpp="): - CPP_STANDARD = int(arg.split("=")[1]) - if CPP_STANDARD < 20: - raise ValueError("nvfuser requires C++20 standard or higher") - continue - if arg in ["clean"]: - # only disables BUILD_SETUP, but keep the argument for setuptools - BUILD_SETUP = False - forward_args.append(arg) -sys.argv = forward_args - - -def get_cmake_bin(): - # TODO: double check cmake version here and retrieve later version if necessary - return "cmake" - - -class clean(setuptools.Command): - user_options = [] - - def initialize_options(self): - pass - - def finalize_options(self): - pass - - def run(self): - import glob - - with open(".gitignore", "r") as f: - ignores = f.read() - for entry in ignores.split("\n"): - # ignore comment in .gitignore - if len(entry) >= 1 and entry[0] != "#": - for filename in glob.glob(entry): - print("removing: ", filename) - try: - os.remove(filename) - except OSError: - shutil.rmtree(filename, ignore_errors=True) - - -class build_ext(setuptools.command.build_ext.build_ext): - def build_extension(self, ext): - if ext.name == "nvfuser._C": - # Copy files on necessity. - filename = self.get_ext_filename(self.get_ext_fullname(ext.name)) - fileext = os.path.splitext(filename)[1] - - libnvfuser_path = os.path.join("./nvfuser/lib", f"libnvfuser{fileext}") - assert os.path.exists(libnvfuser_path) - install_dst = os.path.join(self.build_lib, filename) - if not os.path.exists(os.path.dirname(install_dst)): - os.makedirs(os.path.dirname(install_dst)) - self.copy_file(libnvfuser_path, install_dst) - else: - super().build_extension(ext) - - -class concat_third_party_license: - def __init__(self, directory="third_party"): - self.license_file = "LICENSE" - self.directory = directory - - def __enter__(self): - # read original license file - with open(self.license_file, "r") as f: - self.nvfuser_license_txt = f.read() - - licenses = {"LICENSE", "LICENSE.txt", "LICENSE.rst", "COPYING.BSD"} - - # aggregated license, we key on project name - aggregated_license = {} - for root, dirs, files in os.walk(self.directory): - license = list(licenses & set(files)) - if license: - project_name = root.split("/")[-1] - # let's worry about multiple license when we see it. - assert len(license) == 1 - license_entry = os.path.join(root, license[0]) - if project_name in aggregated_license: - # Only add it if the license is different - aggregated_license[project_name].append(license_entry) - else: - aggregated_license[project_name] = [license_entry] - return aggregated_license +# TODO Remove nvfuser symbolic link to python/nvfuser +# TODO Remove tools/gen_nvfuser_version.py symbolic link to python/tools/gen_nvfuser_version.py +# TODO Remove tools/memory.py symbolic link to python/tools/memory.py - def __exit__(self, exception_type, exception_value, traceback): - # restore original license file - with open(self.license_file, "w") as f: - f.write(self.nvfuser_license_txt) - - -try: - from wheel.bdist_wheel import bdist_wheel -except ImportError: - build_whl = None -else: - - class build_whl(bdist_wheel): - def run(self): - with concat_third_party_license() as tp_licenses: - if len(tp_licenses) != 0: - with open("LICENSE", "a") as f: - f.write("\n\n") - f.write( - "NVIDIA/fuser depends on libraries with license listed below:" - ) - - for project_name, license_files in tp_licenses.items(): - # check all license files are identical - with open(license_files[0], "r") as f: - license_ref = f.read() - - def check_file(file_name): - with open(file_name, "r") as f: - return f.read() == license_ref - - identical_flag = all(map(check_file, license_files[1:])) - if not identical_flag: - raise RuntimeError( - "inconsistent license found for project: ", - project_name, - " check its license files under: ", - license_files, - ) +import sys - with open("LICENSE", "a") as f: - f.write("\n\nProject Name: " + project_name) - f.write("\nLicense Files:\n") - for file_name in license_files: - f.write("\t" + file_name) - f.write("\n" + license_ref) - # generate whl before we restore LICENSE - super().run() +from python.utils import ( + run, + create_build_config, +) -def version_tag(): - from tools.gen_nvfuser_version import get_version +def version_tag(config): + from python.tools.gen_nvfuser_version import get_version version = get_version() - if OVERWRITE_VERSION: + if config.overwrite_version: version = version.split("+")[0] - if len(VERSION_TAG) != 0: + if len(config.version_tag) != 0: # use "." to be pypi friendly - version = ".".join([version, VERSION_TAG]) + version = ".".join([version, config.version_tag]) return version -from tools.memory import get_available_memory_gb - - -def cmake(): - # make build directories - cwd = os.path.dirname(os.path.abspath(__file__)) - cmake_build_dir = os.path.join(cwd, "build") if not BUILD_DIR else BUILD_DIR - if not os.path.exists(cmake_build_dir): - os.makedirs(cmake_build_dir) - - install_prefix = os.path.join(cwd, "nvfuser") if not INSTALL_DIR else INSTALL_DIR - - from tools.gen_nvfuser_version import ( - get_pytorch_cmake_prefix, - get_pytorch_use_distributed, - ) - - # this is used to suppress import error. - # so we can get the right pytorch prefix for cmake - import logging - - logger = logging.getLogger("nvfuser") - logger_level = logger.getEffectiveLevel() - logger.setLevel(logging.CRITICAL) - - pytorch_cmake_config = "-DCMAKE_PREFIX_PATH=" + get_pytorch_cmake_prefix() - - logger.setLevel(logger_level) - - pytorch_use_distributed = get_pytorch_use_distributed() - - # generate cmake directory - cmd_str = [ - get_cmake_bin(), - pytorch_cmake_config, - "-DCMAKE_BUILD_TYPE=" + BUILD_TYPE, - f"-DCMAKE_INSTALL_PREFIX={install_prefix}", - f"-DNVFUSER_CPP_STANDARD={CPP_STANDARD}", - f"-DUSE_DISTRIBUTED={pytorch_use_distributed}", - "-B", - cmake_build_dir, - ] - if BUILD_WITH_UCC: - cmd_str.append("-DNVFUSER_STANDALONE_BUILD_WITH_UCC=ON") - if EXPLICIT_ERROR_CHECK: - cmd_str.append("-DNVFUSER_EXPLICIT_ERROR_CHECK=ON") - if not NO_NINJA: - cmd_str.append("-G") - cmd_str.append("Ninja") - if not NO_TEST: - cmd_str.append("-DBUILD_TEST=ON") - if not NO_PYTHON: - cmd_str.append("-DBUILD_PYTHON=ON") - cmd_str.append(f"-DPython_EXECUTABLE={sys.executable}") - if not NO_BENCHMARK: - cmd_str.append("-DBUILD_NVFUSER_BENCHMARK=ON") - if BUILD_WITH_ASAN: - cmd_str.append("-DNVFUSER_BUILD_WITH_ASAN=ON") - if BUILD_WITHOUT_DISTRIBUTED: - cmd_str.append("-DNVFUSER_DISTRIBUTED=OFF") - if BUILD_WITH_SYSTEM_NVTX: - cmd_str.append("-DUSE_SYSTEM_NVTX=ON") - cmd_str.append(".") - - print(f"Configuring CMake with {' '.join(cmd_str)}") - subprocess.check_call(cmd_str) - - max_jobs = multiprocessing.cpu_count() - mem_gb_per_task = 3 # Currently compilation of nvFuser souce code takes ~3GB of memory per task, we should adjust this value if it changes in the future. - available_mem = get_available_memory_gb() - if available_mem > 0: - max_jobs_mem = int(available_mem / mem_gb_per_task) - max_jobs = min(max_jobs, max_jobs_mem) +def main(): + # Parse arguments using argparse + config, forward_args = create_build_config() - if not CMAKE_ONLY: - # build binary - max_jobs = os.getenv("MAX_JOBS", str(max_jobs)) - print(f"Using {max_jobs} jobs for compilation") - cmd_str = [ - get_cmake_bin(), - "--build", - cmake_build_dir, - "--target", - "install", - "--", - "-j", - max_jobs, - ] - subprocess.check_call(cmd_str) + if "clean" in sys.argv: + # only disables BUILD_SETUP, but keep the argument for setuptools + config.build_setup = False + if config.cpp_standard < 20: + raise ValueError("nvfuser requires C++20 standard or higher") -def main(): - # NOTE(crcrpar): Deliberately build basically two dynamic libraries here so that they can - # be treated as "nvfuser_package_data". This function call will put the two of "nvfuser" and - # "nvfuser_codegen" into "./nvfuser/lib", and the former will be "nvfuser._C". - if BUILD_SETUP: - cmake() - if not CMAKE_ONLY: - # NOTE: package include files for cmake - # TODO(crcrpar): Better avoid hardcoding `libnvfuser_codegen.so` - # might can be treated by using `exclude_package_data`. - nvfuser_package_data = [ - "lib/libnvfuser_codegen.so", - "include/nvfuser/*.h", - "include/nvfuser/struct.inl", - "include/nvfuser/C++20/type_traits", - "include/nvfuser/device_lower/*.h", - "include/nvfuser/device_lower/analysis/*.h", - "include/nvfuser/device_lower/pass/*.h", - "include/nvfuser/dynamic_type/*", - "include/nvfuser/dynamic_type/C++20/*", - "include/nvfuser/kernel_db/*.h", - "include/nvfuser/multidevice/*.h", - "include/nvfuser/ops/*.h", - "include/nvfuser/ir/*.h", - "include/nvfuser/python_frontend/*.h", - "include/nvfuser/scheduler/*.h", - "include/nvfuser/serde/*.h", - "include/nvfuser/flatbuffers/*.h", - "include/nvfuser/host_ir/*.h", - "include/nvfuser/id_model/*.h", - "share/cmake/nvfuser/NvfuserConfig*", - # TODO(crcrpar): it'd be better to ship the following two binaries. - # Would need some change in CMakeLists.txt. - # "bin/test_nvfuser", - # "bin/nvfuser_bench" - ] + sys.argv = [sys.argv[0]] + forward_args - setup( - name=WHEEL_NAME, - version=version_tag(), - url="https://github.com/NVIDIA/Fuser", - description="A Fusion Code Generator for NVIDIA GPUs (commonly known as 'nvFuser')", - packages=find_packages(), - ext_modules=[Extension(name="nvfuser._C", sources=[])], - license_files=("LICENSE",), - cmdclass={ - "bdist_wheel": build_whl, - "build_ext": build_ext, - "clean": clean, - }, - package_data={ - "nvfuser": nvfuser_package_data, - }, - install_requires=INSTALL_REQUIRES, - extras_require={ - "test": ["numpy", "expecttest", "pytest"], - **EXTRAS_REQUIRE, - }, - license="BSD-3-Clause", - ) + run(config, version_tag(config), relative_path=".") if __name__ == "__main__": diff --git a/tools/gen_nvfuser_version.py b/tools/gen_nvfuser_version.py deleted file mode 100644 index a09eda53539..00000000000 --- a/tools/gen_nvfuser_version.py +++ /dev/null @@ -1,75 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. -# All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -import subprocess -import sys -from pathlib import Path - -UNKNOWN = "Unknown" -nvfuser_root = Path(__file__).parent.parent - - -# note that this root currently is still part of pytorch. -def get_sha() -> str: - try: - return ( - subprocess.check_output(["git", "rev-parse", "HEAD"], cwd=nvfuser_root) - .decode("ascii") - .strip() - ) - except Exception: - import os - - # assume the $NVFUSER_VERSION is in sha form - if nvfuser_version := os.environ.get("NVFUSER_VERSION"): - assert ( - len(nvfuser_version) < 11 - ), "The NVFUSER_VERSION should be in sha form" - return nvfuser_version - return UNKNOWN - - -def get_version() -> str: - sha = get_sha() - version = ( - open((nvfuser_root / "version.txt"), "r").read().strip() + "+git" + sha[:7] - ) - return version - - -def get_pytorch_cmake_prefix(): - from subprocess import Popen, PIPE - - # need to do this in a separate process so we are not going to delete nvfuser library while it's loaded by torch - process_torch_prefix = Popen( - [ - sys.executable, - "-c", - "import torch.utils; print(torch.utils.cmake_prefix_path)", - ], - stdout=PIPE, - ) - stdout_msg, error_msg = process_torch_prefix.communicate() - return stdout_msg.decode("utf-8").rstrip("\n") - - -def get_pytorch_use_distributed(): - from subprocess import Popen, PIPE - - # need to do this in a separate process so we are not going to delete nvfuser library while it's loaded by torch - process_torch_prefix = Popen( - [ - sys.executable, - "-c", - "import torch; print(torch._C._has_distributed())", - ], - stdout=PIPE, - ) - stdout_msg, error_msg = process_torch_prefix.communicate() - return stdout_msg.decode("utf-8").rstrip("\n") - - -if __name__ == "__main__": - version_file = nvfuser_root / "nvfuser" / "version.py" - with open(version_file, "w") as f: - f.write("_version_str = '{}'\n".format(get_version())) diff --git a/tools/gen_nvfuser_version.py b/tools/gen_nvfuser_version.py new file mode 120000 index 00000000000..fef974811db --- /dev/null +++ b/tools/gen_nvfuser_version.py @@ -0,0 +1 @@ +../python/tools/gen_nvfuser_version.py \ No newline at end of file diff --git a/tools/memory.py b/tools/memory.py deleted file mode 100644 index 1ed95f8ded5..00000000000 --- a/tools/memory.py +++ /dev/null @@ -1,28 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. -# All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause - - -def get_available_memory_gb(): - """Returns the available memory in GB.""" - try: - import psutil - - return psutil.virtual_memory().available / 1024 / 1024 / 1024 - except: # noqa: E722 - pass - - try: - with open("/proc/meminfo", "r") as f: - while True: - line = f.readline() - if line.startswith("MemAvailable:"): - mem = line.split()[1] - assert line.split()[2] == "kB" - return int(mem) / 1024 / 1024 - if not line: - break - except: # noqa: E722 - pass - - return 0 diff --git a/tools/memory.py b/tools/memory.py new file mode 120000 index 00000000000..d818457a563 --- /dev/null +++ b/tools/memory.py @@ -0,0 +1 @@ +../python/tools/memory.py \ No newline at end of file From 95c9bde94c61cda1ad642323aa19a019edf24bf4 Mon Sep 17 00:00:00 2001 From: Ryan Spring Date: Thu, 24 Apr 2025 10:07:47 -0700 Subject: [PATCH 58/68] Change build directory for clang-tidy in lintrunner (#4309) This PR fixes the clang-tidy lintrunner. The `build_dir` argument needs to change from `./build` to `./python/build`. --- .lintrunner.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.lintrunner.toml b/.lintrunner.toml index d2567f57ff1..f10037e39bb 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -185,7 +185,7 @@ command = [ 'python3', 'tools/linter/adapters/clangtidy_linter.py', '--binary=~/.local/bin/clang-tidy', - '--build_dir=./build', + '--build_dir=./python/build', '--', '@{{PATHSFILE}}' ] From fadfde5a44ed28998c99fcb9fd8e796c5b3de142 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle <1454944+jacobhinkle@users.noreply.github.com> Date: Thu, 24 Apr 2025 21:57:50 -0400 Subject: [PATCH 59/68] Switch axis we use to compute swizzled_tiles (#4311) #4242 turned on "grid traversal factor" which is a good thing. However, it exposed a bug in how we limit that factor to prevent overrun in case the swizzled axis has fewer tiles than the factor. This led to a regression from 58% to 35% geomean perf compared to eager on H200. This PR swaps the axes used to compute the number of swizzled tiles and takes us from a geomean of 35% to 65% on `benchmarks/python/test_matmul.py` on H200. --- benchmarks/python/test_matmul.py | 6 ++++++ csrc/scheduler/matmul_utils.cpp | 2 +- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/benchmarks/python/test_matmul.py b/benchmarks/python/test_matmul.py index 669b032bd3f..f7886178198 100644 --- a/benchmarks/python/test_matmul.py +++ b/benchmarks/python/test_matmul.py @@ -41,6 +41,9 @@ def test_matmul_baseline_benchmark( ): m, n, k, layout = config + if (m * k + n * k + m * n) * 2 > 20 * (2**30): + pytest.skip("Case takes more than 20GiB. Skipping to avoid OOM") + torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = half_reduction torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = half_reduction @@ -75,6 +78,9 @@ def test_matmul_nvf_benchmark( ): m, n, k, layout = config + if (m * k + n * k + m * n) * 2 > 20 * (2**30): + pytest.skip("Case takes more than 20GiB. Skipping to avoid OOM") + torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = half_reduction torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = half_reduction diff --git a/csrc/scheduler/matmul_utils.cpp b/csrc/scheduler/matmul_utils.cpp index e65d571ed76..98c0851472e 100644 --- a/csrc/scheduler/matmul_utils.cpp +++ b/csrc/scheduler/matmul_utils.cpp @@ -396,7 +396,7 @@ bool fillDefaultHopperHeuristic( // the _other_ dimension to create a new inner dimension. We find the swizzle // factor that is largest and has the least quantization when we divide that // other dimension by the swizzle factor. - int64_t swizzled_tiles = Mtiles <= Ntiles ? Ntiles : Mtiles; + int64_t swizzled_tiles = Mtiles >= Ntiles ? Ntiles : Mtiles; mparams->cta_order = Mtiles <= Ntiles ? MatmulParams::TileRasterizationOrder::ColumnMajor : MatmulParams::TileRasterizationOrder::RowMajor; From 87be6c3ee8d96b1e05058b068d97183d4f432e9c Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Fri, 25 Apr 2025 12:13:49 -0700 Subject: [PATCH 60/68] Simplify some tests since sharding propagation is in place (#4304) Co-authored-by: root <26priya11@gmail.com> Co-authored-by: Priya Mishra <52657555+Priya2698@users.noreply.github.com> --- tests/python/multidevice/test_matmul.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/tests/python/multidevice/test_matmul.py b/tests/python/multidevice/test_matmul.py index 76ce1939edf..1c877b4acb5 100644 --- a/tests/python/multidevice/test_matmul.py +++ b/tests/python/multidevice/test_matmul.py @@ -81,7 +81,7 @@ def definition(self): self.add_output(self.out) def multidevice_schedule(self): - for t in [self.inp, self.weight, self.bias, self.out]: + for t in [self.inp, self.weight, self.bias]: self.sched._set_device_mesh(t, mesh) # Shard N for weight (N, K) and bias (N) @@ -90,12 +90,6 @@ def multidevice_schedule(self): self.sched.parallelize(t, 0, nvfuser.ParallelType.mesh_x) self.sched.set_allocation_as_loop(t) - # Output of linear: {.., i{M}, i{N}, r{K}} - # Shard N -> axis(-2) - self.sched.split(self.out, -2, d, False) - self.sched.parallelize(self.out, -3, nvfuser.ParallelType.mesh_x) - self.sched.set_allocation_as_loop(self.out) - torch.cuda.set_device(multidevice_test.local_rank) b, s = 2, 1024 @@ -135,7 +129,7 @@ def definition(self): self.add_output(self.out) def multidevice_schedule(self): - for t in [self.inp, self.weight, self.out]: + for t in [self.inp, self.weight]: self.sched._set_device_mesh(t, mesh) self.sched.split(t, -1, d, False) self.sched.parallelize(t, -2, nvfuser.ParallelType.mesh_x) From 07effe8da23ae20ce6ed8e1f9510f69861f0a34f Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Fri, 25 Apr 2025 12:36:48 -0700 Subject: [PATCH 61/68] More precise WAR for resize vectorization (#4305) This is a follow-up to #3906, which added a WAR to #3640. While it's safe, it turned out it's just too conservative. For example, here's a concat pattern appearing in the backward of Litgpt Llama RoPE: ``` Inputs: T0_g___bfloat[bS0{1}, iS1{8}, iS2{4}, iS3{8192}, iS4{128}] T1_g___bfloat[bS5{1}, iS6{8}, bS7{1}, iS8{8192}, iS9{128}] T2_g___bfloat[bS10{1}, iS11{8}, bS12{1}, iS13{8192}, iS14{128}] Outputs: T8_g___bfloat[bS43{1}, iS44{8192}, iS52{6144}rf] %kernel_math { T3_l___bfloat[bS15{1}, iS16{8}, iS18{6}rf, iS19{8192}, iS20{128}] = pad( T0_g___bfloat[bS0{1}, iS1{8}, iS2{4}, iS3{8192}, iS4{128}], {0, 0, 0, 0, 0, 2, 0, 0, 0, 0} ) i31 = 0 + 4; T4_l___bfloat[bS21{1}, iS22{8}, iS24{( ( ( 0 + 4 ) + 1 ) + 1 )}rf, iS25{8192}, iS26{128}] = pad( T1_g___bfloat[bS5{1}, iS6{8}, bS7{1}, iS8{8192}, iS9{128}], {0, 0, 0, 0, i31, 1, 0, 0, 0, 0} ) i47 = i31 + 1; T5_l___bfloat[bS27{1}, iS28{8}, iS30{( ( ( 0 + 4 ) + 1 ) + 1 )}rf, iS31{8192}, iS32{128}] = pad( T2_g___bfloat[bS10{1}, iS11{8}, bS12{1}, iS13{8192}, iS14{128}], {0, 0, 0, 0, i47, 0, 0, 0, 0, 0} ) T6_l___bfloat[bS33{1}, iS34{8}, iS35{6}, iS36{8192}, iS37{128}] = cat( T3_l___bfloat[bS15{1}, iS16{8}, iS18{6}rf, iS19{8192}, iS20{128}], T4_l___bfloat[bS21{1}, iS22{8}, iS24{( ( ( 0 + 4 ) + 1 ) + 1 )}rf, iS25{8192}, iS26{128}], T5_l___bfloat[bS27{1}, iS28{8}, iS30{( ( ( 0 + 4 ) + 1 ) + 1 )}rf, iS31{8192}, iS32{128}], 2 ) T7_l___bfloat[bS38{1}, iS41{8192}, iS39{8}, iS40{6}, iS42{128}] = Set.Permute( T6_l___bfloat[bS33{1}, iS34{8}, iS35{6}, iS36{8192}, iS37{128}], cache_op=Streaming ) T8_g___bfloat[bS43{1}, iS44{8192}, iS52{6144}rf] = view( T7_l___bfloat[bS38{1}, iS41{8192}, iS39{8}, iS40{6}, iS42{128}] ) } // %kernel_math ``` This is currently taken by the pointwise scheduler, which attempts to vectorize the innermost ID of the output (i.e., `iS52{6144}`). Since the resize ops of the three pad ops are reachable from `iS52`, the WAR of #3640 simply takes them into consideration by calculating gcd with the left and right expand factors. In this case, since there's an expand factor of 1, the resulting vectorization factor is also just 1, which is clearly not what we want. Here, while the resized ID itself is not vectorizable due to the expand factor of 1, all of the resized tensors have large enough inner IDs that should allow the maximum vectorization. To make the WAR a little less conservative, this PR also checks if the constraint by a Resize expr may be missed by the vectorization analysis. In the above case, that should not happen as there's only one path through each of the resize-based tensor ops. This change is still not able to eliminate false positives completely. See one of the new tests that is currently disabled. The codediff results all seem to make sense. http://nv/eFb. Previously some of the tests did not have vectorization due to the WAR, which is relaxed in this PR and allows some vectorization. --- csrc/scheduler/vectorize_helper.cpp | 149 ++++++++++++++++------------ tests/cpp/test_resize.cpp | 91 ++++++++++++++++- 2 files changed, 175 insertions(+), 65 deletions(-) diff --git a/csrc/scheduler/vectorize_helper.cpp b/csrc/scheduler/vectorize_helper.cpp index 6ec81585869..091f3519a0d 100644 --- a/csrc/scheduler/vectorize_helper.cpp +++ b/csrc/scheduler/vectorize_helper.cpp @@ -842,6 +842,59 @@ std::vector> getTvToContigInnerSizeMapsOf( return mappers; } +// Check if a traversal from vectorized reference IDs may reach the +// IDs of a resize expr without visiting the Resize expr itself. That's +// problematic for the vectorization analysis as the spanning-tree +// based analysis may miss the constraint by the Resize expr. +// +// For this analysis, we start a traversal from the vectorized +// reference IDs to both the input and output of the Resize expr but +// disallow visiting the Resize expr itself. If the traversal is still +// successful, it means there's a path from the reference IDs to the +// resize input and output IDs without visiting the Resize expr. +// +// Permissive BFS is used in this traversal as the vectorized +// reference IDs may not have all the dependencies for the +// traversal. For example, suppose there's a split resshape, and only +// the innermost ID is vectorized. The standard BFS is not able to +// move forward if only the vectorized ID is give as the backward +// split requires both outputs to be presented. +class CanSkipResize : public ValGraphPermissiveBFS { + public: + static bool run( + const ValGraph& graph, + const ValGroups& ref_groups, + Resize* resize) { + ValGroups resize_in_out_groups; + resize_in_out_groups.pushBack(graph.toGroup(resize->in())); + resize_in_out_groups.pushBack(graph.toGroup(resize->out())); + CanSkipResize bfs(graph, ref_groups, resize_in_out_groups, resize); + bfs.traverse(); + return bfs.allToNodesVisited(); + } + + CanSkipResize( + const ValGraph& graph, + const ValGroups& ref_groups, + const ValGroups& resize_in_out_groups, + Resize* resize) + : ValGraphPermissiveBFS( + graph, + {ref_groups.begin(), ref_groups.end()}, + {resize_in_out_groups.begin(), resize_in_out_groups.end()}, + /*require_all_to_visited=*/false, + /*allowed_direction=*/Direction::Undefined), + resize_(resize) {} + + bool excludeFromTraversal(const NodeType& node) const override { + const ExprGroup* e = std::get_if(&node); + return e != nullptr && (*e)->has(resize_); + } + + private: + Resize* resize_ = nullptr; +}; + // This is a WAR for vectorizing through resized iter domains. The // spanning tree based analysis is not guaranteed to take all resize // ops into considerations (issue @@ -852,84 +905,48 @@ std::unordered_set getResizeVectorizationFactors( TensorView* reference_tv, int64_t break_point) { Fusion* fusion = reference_tv->fusion(); - std::unordered_set factors; const auto resize_based_ops = scheduler_tools::getResizeBasedOps(fusion); if (resize_based_ops.empty()) { - return factors; + return {}; } - IdModel id_model(reference_tv->fusion()); + IdModel id_model(fusion); const auto& graph = id_model.buildExactGraph(); - const auto ref_groups = graph.toGroups(reference_tv->getLogicalDomain()); + std::unordered_set resize_factors; - // For each of resize-based tensor ops, find all resize ops - // that exist between the vectorized reference IDs and the output - // tensor. - for (auto resize_based_op : resize_based_ops) { - auto resize_out = resize_based_op->output(0)->as(); - NVF_ERROR( - resize_out->hasRoot(), "Unexpected op: ", resize_based_op->toString()); - // getAllExprGroupsBetween finds exprs between IDs. To make sure - // the the resize op of this resize_based_op tensor op is found, - // use both the root and logical domains as the traversal targets. - ValGroups resize_inp_out; - resize_inp_out.pushBack(graph.toGroups(resize_out->getRootDomain())); - resize_inp_out.pushBack(graph.toGroups(resize_out->getLogicalDomain())); - - auto expr_path = getAllExprGroupsBetween( - graph, - ref_groups, - resize_inp_out, - /*require_all_to_visited=*/false) - .first; - - ValGroups vectorized_groups; - for (auto it = reference_tv->getLogicalDomain().begin() + break_point; - it != reference_tv->getLogicalDomain().end(); - ++it) { - vectorized_groups.pushBack(graph.toGroup(*it)); + auto add_resize_factors = [&](Resize* resize) { + if (!resize->leftExpand()->isZeroInt()) { + resize_factors.insert(resize->leftExpand()); } + if (!resize->rightExpand()->isZeroInt()) { + resize_factors.insert(resize->rightExpand()); + } + }; - // Find all resize exprs that appear in expr_path and depend on - // vectorized_groups. Since expr_path is not guaranteed to be - // topologically sorted, need to loop through the path until - // converged. - - bool something_has_changed = true; - while (something_has_changed) { - something_has_changed = false; - for (const auto& [expr_g, dir] : expr_path) { - const auto inputs = getInputsOfExprGroup(graph, expr_g, dir); - if (std::none_of( - inputs.begin(), inputs.end(), [&](const ValGroup& inp) { - return vectorized_groups.has(inp); - })) { - continue; - } - - if (vectorized_groups.pushBack( - getOutputsOfExprGroup(graph, expr_g, dir))) { - something_has_changed = true; - } - - auto resize = dynamic_cast(expr_g->front()); - if (resize == nullptr) { - continue; - } + const ValGroups ref_vec_groups = graph.toGroups(std::vector{ + reference_tv->getLogicalDomain().begin() + break_point, + reference_tv->getLogicalDomain().end()}); + + // For each of Resize exprs, if it's reachable from the reference + // vectorized IDs without visiting the Resize expr itself, its + // constraint may not be reflectd in the inner sizes. + for (auto resize : resize_based_ops) { + auto resize_out_tv = resize->output(0)->as(); + for (const auto logical_id : resize_out_tv->getLogicalDomain()) { + auto resize = dynamic_cast(logical_id->definition()); + if (resize == nullptr) { + continue; + } - // These three vals need to be divisible - factors.emplace(resize->leftExpand()); - factors.emplace(resize->rightExpand()); - factors.emplace( - dir == Direction::Forward ? resize->out()->extent() - : resize->in()->extent()); + if (CanSkipResize::run(graph, ref_vec_groups, resize)) { + add_resize_factors(resize); } } } - return factors; + return resize_factors; } } // namespace @@ -1028,7 +1045,11 @@ int64_t getVectorizationFactor( if (!inferred_val.hasValue()) { return 1; } - max_vec_size = std::gcd(max_vec_size, inferred_val.as()); + auto inferred_val_int = inferred_val.as(); + if (inferred_val_int == 0) { + continue; + } + max_vec_size = std::gcd(max_vec_size, inferred_val_int); } return max_vec_size; diff --git a/tests/cpp/test_resize.cpp b/tests/cpp/test_resize.cpp index 12ca31a929f..4aaac477a52 100644 --- a/tests/cpp/test_resize.cpp +++ b/tests/cpp/test_resize.cpp @@ -5970,7 +5970,7 @@ TEST_F(ResizeTest, AvoidCachingSliceInput) { } } -TEST_F(ResizeTest, VectorizeSliceMultiplePaths) { +TEST_F(ResizeTest, VectorizeInnerSliceMultiplePaths) { auto fusion_ptr = std::make_unique(); auto& fusion = *fusion_ptr; FusionGuard fg(fusion_ptr.get()); @@ -6005,6 +6005,50 @@ TEST_F(ResizeTest, VectorizeSliceMultiplePaths) { EXPECT_EQ(tv6->getLoopDomain().back()->extent()->evaluate(), 2); } +// The current analysis is not precise enough to pass this test +TEST_F(ResizeTest, DISABLED_VectorizeOuterSliceMultiplePaths) { + auto fusion_ptr = std::make_unique(); + auto& fusion = *fusion_ptr; + FusionGuard fg(fusion_ptr.get()); + + const std::vector shape{4, 1024 * 1024}; + + auto tv0 = makeContigConcreteTensor(shape); + fusion.addInput(tv0); + + auto tv1 = + pad(tv0, + {fusion.zeroVal(), + fusion.zeroVal(), + IrBuilder::create(2), + IrBuilder::create(2)}); + auto tv2 = + pad(tv0, + {fusion.zeroVal(), + fusion.zeroVal(), + fusion.zeroVal(), + IrBuilder::create(4)}); + auto tv3 = add(tv1, tv2); + fusion.addOutput(tv3); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto t0 = at::randn(shape, options); + + auto outputs = scheduleAndRun(&fusion, SchedulerType::PointWise, {t0}); + testValidate(&fusion, outputs.outputs, {t0}, __LINE__, __FILE__); + + // While there's a pad with factor of 2, it shouldn't matter as the + // inner ID is large enough. + auto out_tv = tv3; + auto vec_id_it = + std::ranges::find_if(out_tv->getLoopDomain(), [](IterDomain* loop_id) { + return loop_id->getParallelType() == ParallelType::Vectorize; + }); + ASSERT_NE(vec_id_it, out_tv->getLoopDomain().end()) + << "Vectorized ID not found: " << out_tv->toString(); + EXPECT_EQ((*vec_id_it)->extent()->evaluate(), 4); +} + // Repro of issue #4202 TEST_F(ResizeTest, PropagateResizeThroughMultiplePaths) { auto fusion_ptr = std::make_unique(); @@ -6040,4 +6084,49 @@ TEST_F(ResizeTest, PropagateResizeThroughMultiplePaths) { testValidate(&fusion, outputs.outputs, {t0, t1}, __LINE__, __FILE__); } +// Check if vectorization is properly applied even when a resized ID +// is reachable from vectorized IDs. Pattern extracted from Litgpt +// LLama RoPE backward. +TEST_F(ResizeTest, VectorizeOuterPad) { + auto fusion_ptr = std::make_unique(); + auto& fusion = *fusion_ptr; + FusionGuard fg(fusion_ptr.get()); + + const std::vector shape1{1, 8, 4, 8192, 128}; + const std::vector shape2{1, 8, 1, 8192, 128}; + auto tv0 = makeContigConcreteTensor(shape1, DataType::BFloat16); + fusion.addInput(tv0); + auto tv1 = makeContigConcreteTensor(shape2, DataType::BFloat16); + fusion.addInput(tv1); + auto tv2 = makeContigConcreteTensor(shape2, DataType::BFloat16); + fusion.addInput(tv2); + + // [1, 8, 6, 8192, 128] + auto tv3 = cat({tv0, tv1, tv2}, 2); + // [1, 8192, 8, 6, 128] + auto tv4 = permute(tv3, {0, 3, 1, 2, 4}); + auto tv5 = reshape(tv4, {1, 8192, 8, 6, 128}, {1, 8192, 6144}); + fusion.addOutput(tv5); + + auto options = at::TensorOptions().dtype(at::kBFloat16).device(at::kCUDA, 0); + auto t0 = at::randn(shape1, options); + auto t1 = at::randn(shape2, options); + auto t2 = at::randn(shape2, options); + + auto outputs = + scheduleAndRun(&fusion, SchedulerType::PointWise, {t0, t1, t2}); + testValidate(&fusion, outputs.outputs, {t0, t1, t2}, __LINE__, __FILE__); + + auto out_tv = tv5; + // While there's a pad with factor of 2, it shouldn't matter as the + // inner ID is large enough. + auto vec_id_it = + std::ranges::find_if(out_tv->getLoopDomain(), [](IterDomain* loop_id) { + return loop_id->getParallelType() == ParallelType::Vectorize; + }); + ASSERT_NE(vec_id_it, out_tv->getLoopDomain().end()) + << "Vectorized ID not found: " << out_tv->toString(); + EXPECT_EQ((*vec_id_it)->extent()->evaluate(), 8); +} + } // namespace nvfuser From 3fe1c32124ea83f90253496cecbe5becf46c9df0 Mon Sep 17 00:00:00 2001 From: samnordmann Date: Sun, 27 Apr 2025 10:58:18 +0200 Subject: [PATCH 62/68] [Cuda Ipc] Add barrier at the end of `IpcHandleCache::exchangeHandles` (#4308) A synchronization is needed at the `IpcHandleCache::exchangeHandles` to avoid the exporter exporting twice before the importer delete the key to the store. Should fix a bug [observed in the CI](https://gitlab-master.nvidia.com/dl/pytorch/fuser-gh-mirror/-/jobs/160018431). cc [Team thread](https://teams.microsoft.com/l/message/19:e875a0fc9d0747b9bde9715c9b6093ae@thread.tacv2/1745028577674?tenantId=43083d15-7273-40c1-b7db-39efd9ccc17a&groupId=74537292-c203-4a6d-aea0-40c527b969f7&parentMessageId=1745028577674&teamName=csarofeen-staff-and-guests&channelName=nvFuser-MultiGPU&createdTime=1745028577674) --- csrc/multidevice/ipc_handle.cpp | 6 ++++++ tests/cpp/test_multidevice_communications.cpp | 2 +- tests/cpp/test_multidevice_host_ir.cpp | 2 +- 3 files changed, 8 insertions(+), 2 deletions(-) diff --git a/csrc/multidevice/ipc_handle.cpp b/csrc/multidevice/ipc_handle.cpp index 6bb700dc2de..9a5ec4286b8 100644 --- a/csrc/multidevice/ipc_handle.cpp +++ b/csrc/multidevice/ipc_handle.cpp @@ -151,6 +151,12 @@ void IpcHandleCache::exchangeHandles( insert(communication, std::move(ipc_handles)); } + + // a second barrier is needed here to ensure all ranks have received the + // memhandles and the keys are deleted from the store before the next call to + // exchangeHandles + // TODO: precisely select what ranks need to wait on that barrier. + communicator->barrier(); } } // namespace nvfuser diff --git a/tests/cpp/test_multidevice_communications.cpp b/tests/cpp/test_multidevice_communications.cpp index 1b6ce59801c..af0c0719aa7 100644 --- a/tests/cpp/test_multidevice_communications.cpp +++ b/tests/cpp/test_multidevice_communications.cpp @@ -417,7 +417,7 @@ INSTANTIATE_TEST_SUITE_P( using P2PCommunicationTest = MultiDeviceTest; -TEST_F(P2PCommunicationTest, DISABLED_CudaComm) { +TEST_F(P2PCommunicationTest, CudaComm) { static constexpr int kTensorSize = 8; static constexpr int kNumRepetitions = 32; diff --git a/tests/cpp/test_multidevice_host_ir.cpp b/tests/cpp/test_multidevice_host_ir.cpp index 0b6efbd15a4..88286d6e4c0 100644 --- a/tests/cpp/test_multidevice_host_ir.cpp +++ b/tests/cpp/test_multidevice_host_ir.cpp @@ -478,7 +478,7 @@ TEST_F(OverlapDistributedMatmulTest, AG_linear) { EXPECT_TRUE(torch::allclose(out_ref, out_at, 1e-1, 1e-1)); } -TEST_F(MultiDeviceTest, DISABLED_ShareIpcMemHandles) { +TEST_F(MultiDeviceTest, ShareIpcMemHandles) { static constexpr int kTensorSize = 4; static constexpr int kNumRepetitions = 10; From cf5c6d2bf40fb7708c45f9dd08ff751f3f129a47 Mon Sep 17 00:00:00 2001 From: samnordmann Date: Sun, 27 Apr 2025 13:16:23 +0200 Subject: [PATCH 63/68] [Host ir] support for set reduce and binary op (#4146) This PR belongs to a series of stacked PRs: 1. #4144 2. #4145 3. **=> You are here:** #4146 4. #4147 Add support for `LoadStoreOp`, `BinaryOp`, `ReductionOp`, including support for pre-allocated output, which is not provided by ExprEvaluator. --------- Co-authored-by: Jingyue Wu --- csrc/host_ir/executor.cpp | 123 ++++++++++++++++ csrc/host_ir/executor.h | 3 + csrc/host_ir/lower.cpp | 28 +++- tests/cpp/test_host_irs.cpp | 180 ++++++++++++++++++++++++ tests/cpp/test_multidevice_pipeline.cpp | 131 ----------------- 5 files changed, 333 insertions(+), 132 deletions(-) diff --git a/csrc/host_ir/executor.cpp b/csrc/host_ir/executor.cpp index 89710eaae4b..4ceef8927ed 100644 --- a/csrc/host_ir/executor.cpp +++ b/csrc/host_ir/executor.cpp @@ -13,6 +13,7 @@ #include #include #include +#include #include #include #include @@ -631,6 +632,56 @@ void HostIrEvaluator::handle(LinearOp* linear) { } } +void HostIrEvaluator::handle(LoadStoreOp* load_store_op) { + NVF_ERROR( + load_store_op->opType() == LoadStoreOpType::Set, + "LoadStoreOp must be a Set"); + NVF_ERROR( + load_store_op->out()->isA(), "out must be a TensorView"); + auto* out_tv = load_store_op->out()->as(); + auto in_tensor = getKnownConcreteValue(load_store_op->in()).as(); + + at::Tensor t; + if (out_tv->hasRoot()) { + std::optional> permutation = + ir_utils::computePermutation( + out_tv->getRootDomain(), out_tv->getLogicalDomain()); + NVF_ERROR( + permutation.has_value(), + "The logical domain of a Set.Permute is supposed to be a permutation" + " of the root domain: ", + out_tv); + t = in_tensor.permute(*permutation); + } else { + t = in_tensor; + } + + if (isKnown(out_tv)) { + auto out_tensor = + getKnownConcreteValue(load_store_op->out()).as(); + out_tensor.copy_(t); + } else { + // For completeness, we may check if out_tv's allocation matches `t` and + // copy data if yes. For example, + // + // clang-format off + // ``` + // const auto& [sizes, strides] = inferShapeOfOutput(out_tv, expr_evaluator_); + // if (strides == t.strides()) { + // bind(out_tv, t); + // } else { + // auto out_tensor = at::empty_strided(sizes, strides, in_tensor.dtype()); + // out_tensor.copy_(t); + // bind_(out_tv, out_tensor); + // } + // ``` + // clang-format on + // + // For now, I choose to keep code simple for the limited use cases. + bind(out_tv, t); + } +} + void HostIrEvaluator::handle(kir::Allocate* allocate) { NVF_ERROR( allocate->buffer()->isA(), @@ -654,6 +705,78 @@ void HostIrEvaluator::handle(kir::Allocate* allocate) { bind(tv, tensor); } +void HostIrEvaluator::handle(BinaryOp* binary_op) { + if (!isKnown(binary_op->outputs().at(0))) { + return unhandled(binary_op); + } + + auto lhs = getKnownConcreteValue(binary_op->inputs().at(0)).as(); + auto rhs = getKnownConcreteValue(binary_op->inputs().at(1)).as(); + auto output = + getKnownConcreteValue(binary_op->outputs().at(0)).as(); + + switch (binary_op->getBinaryOpType()) { + case BinaryOpType::Add: + at::add_out(output, lhs, rhs); + break; + case BinaryOpType::Sub: + at::sub_out(output, lhs, rhs); + break; + case BinaryOpType::Mul: + at::mul_out(output, lhs, rhs); + break; + case BinaryOpType::Div: + at::div_out(output, lhs, rhs); + break; + default: + NVF_THROW( + "Unexpected operator type: ", + binary_op->getBinaryOpType(), + " in ", + binary_op); + } +} + +void HostIrEvaluator::handle(ReductionOp* reduction_op) { + auto input_tv = reduction_op->in()->as(); + auto output_tv = reduction_op->out()->as(); + if (!isKnown(output_tv)) { + return unhandled(reduction_op); + } + + NVF_ERROR( + !output_tv->hasRoot(), + "Evaluation for rFactored reductions is not supported."); + auto input = getKnownConcreteValue(input_tv).as(); + auto output = getKnownConcreteValue(output_tv).as(); + + std::vector reduction_axes; + for (const auto i : + c10::irange(int64_t(output_tv->getLogicalDomain().size()))) { + auto ax = output_tv->getLogicalDomain().at(i); + if (ax->isReduction()) { + reduction_axes.push_back(i); + } + } + switch (reduction_op->getReductionOpType()) { + case BinaryOpType::Add: + at::sum_out(output, input, reduction_axes); + return; + case BinaryOpType::Max: + at::amax_out(output, input, reduction_axes); + return; + case BinaryOpType::Min: + at::amin_out(output, input, reduction_axes); + return; + default: + NVF_THROW( + "Unexpected operator type: ", + reduction_op->getReductionOpType(), + " in ", + reduction_op); + } +} + void HostIrEvaluator::unhandled(Statement* stmt) { NVF_ERROR(stmt->isA(), stmt, " must be an Expr"); auto* expr = stmt->as(); diff --git a/csrc/host_ir/executor.h b/csrc/host_ir/executor.h index d71b74e0dda..dfe84fba068 100644 --- a/csrc/host_ir/executor.h +++ b/csrc/host_ir/executor.h @@ -134,6 +134,9 @@ class HostIrEvaluator final : public OptOutDispatch { void handle(MatmulOp* matmul) override; void handle(LinearOp* linear) override; void handle(kir::Allocate* allocate) override; + void handle(LoadStoreOp* load_store_op) override; + void handle(BinaryOp* binary_op) override; + void handle(ReductionOp* reduction_op) override; void handle(ShareMemHandles* share_mem_handles) override; void unhandled(Statement* stmt) override; diff --git a/csrc/host_ir/lower.cpp b/csrc/host_ir/lower.cpp index 1cdc21e60c0..fd14096b190 100644 --- a/csrc/host_ir/lower.cpp +++ b/csrc/host_ir/lower.cpp @@ -615,7 +615,33 @@ std::vector HostIrLower::lowerToCollectiveBasedPipelinedGemmComm( } bool HostIrLower::isLowerableAsStandaloneHostOp(Expr* expr) { - return isResharding(expr); + if (expr->isOneOf< + MatmulOp, + SliceOp, + SelectOp, + LinearOp, + BinaryOp, + ReductionOp, + Communication, + P2PCommunication>()) { + return true; + } + + // Lower as standalone op "set" ops, i.e., LoadStoreOp of "Set" type with no + // permute + if (expr->isA()) { + auto* load_store = expr->as(); + if (load_store->opType() == LoadStoreOpType::Set && + load_store->out()->isA()) { + auto* tv = load_store->out()->as(); + // If the output tensor has no root, it means it has no permute + if (!tv->hasRoot()) { + return true; + } + } + } + + return false; } bool HostIrLower::shouldMergeSegmentedGroups( diff --git a/tests/cpp/test_host_irs.cpp b/tests/cpp/test_host_irs.cpp index 6a41e47c744..633ebc83504 100644 --- a/tests/cpp/test_host_irs.cpp +++ b/tests/cpp/test_host_irs.cpp @@ -1307,6 +1307,186 @@ TEST_F(HirAlias, ThrowOnInputAlias) { EXPECT_ANY_THROW(HostIrEvaluator hie(std::move(hic))); } +using HirSetTest = NVFuserTest; + +TEST_F(HirSetTest, HostIr) { + const std::vector sizes = {8, 64}; + + auto hic = std::make_unique(); + FusionGuard fg(hic.get()); + + auto* in = makeConcreteTensor(sizes); + auto* out = makeConcreteTensor(sizes); + auto* set = IrBuilder::create(LoadStoreOpType::Set, out, in); + hic->addInput(in); + hic->addInput(out); + hic->pushBackTopLevelExprs(set); + + HostIrEvaluator hie(std::move(hic)); + + auto options = at::TensorOptions().device(at::kCUDA, 0); + auto in_aten = at::randn(sizes, options); + auto out_aten = at::empty(sizes, options); + + hie.runWithInput({{in, in_aten}, {out, out_aten}}); + + EXPECT_TRUE(out_aten.equal(in_aten)) + << "Obtained output: " << out_aten << "\n" + << "Expected output: " << in_aten; +} + +class HirBinaryOpTest : public NVFuserFixtureParamTest { + protected: + at::Tensor executeBinaryOp(at::Tensor lhs, at::Tensor rhs) { + switch (GetParam()) { + case BinaryOpType::Add: + return lhs + rhs; + case BinaryOpType::Sub: + return lhs - rhs; + case BinaryOpType::Mul: + return lhs * rhs; + case BinaryOpType::Div: + return lhs / rhs; + default: + NVF_ERROR("Unsupported binary op type ", GetParam()); + return at::Tensor(); + } + } +}; + +TEST_P(HirBinaryOpTest, PreAllocatedOutputs) { + const std::vector sizes = {8, 64}; + const auto& binary_op_type = GetParam(); + + auto hic = std::make_unique(); + FusionGuard fg(hic.get()); + + auto* lhs = makeConcreteTensor(sizes); + auto* rhs = makeConcreteTensor(sizes); + auto* out = makeConcreteTensor(sizes); + auto* binary_op = IrBuilder::create(binary_op_type, out, lhs, rhs); + hic->addInput(lhs); + hic->addInput(rhs); + hic->addInput(out); + hic->pushBackTopLevelExprs(binary_op); + + HostIrEvaluator hie(std::move(hic)); + + auto options = at::TensorOptions().device(at::kCUDA, 0); + auto lhs_aten = at::randn(sizes, options); + auto rhs_aten = at::randn(sizes, options); + auto out_aten = at::empty(sizes, options); + + hie.runWithInput({{lhs, lhs_aten}, {rhs, rhs_aten}, {out, out_aten}}); + + at::Tensor expected_out = executeBinaryOp(lhs_aten, rhs_aten); + EXPECT_TRUE(expected_out.equal(out_aten)) + << "Obtained output: " << out_aten << "\n" + << "Expected output: " << expected_out; +} + +TEST_P(HirBinaryOpTest, NonPreAllocatedOutputs) { + const std::vector sizes = {8, 64}; + const auto& binary_op_type = GetParam(); + + auto hic = std::make_unique(); + FusionGuard fg(hic.get()); + + auto* lhs = makeConcreteTensor(sizes); + auto* rhs = makeConcreteTensor(sizes); + auto* out = binaryOp(binary_op_type, lhs, rhs); + hic->addInput(lhs); + hic->addInput(rhs); + hic->addOutput(out); + hic->pushBackTopLevelExprs(out->definition()); + + HostIrEvaluator hie(std::move(hic)); + + auto options = at::TensorOptions().device(at::kCUDA, 0); + auto lhs_aten = at::randn(sizes, options); + auto rhs_aten = at::randn(sizes, options); + + auto out_aten = + hie.runWithInput({{lhs, lhs_aten}, {rhs, rhs_aten}})[0].as(); + + at::Tensor expected_out = executeBinaryOp(lhs_aten, rhs_aten); + EXPECT_TRUE(expected_out.equal(out_aten)) + << "Obtained output: " << out_aten << "\n" + << "Expected output: " << expected_out; +} + +INSTANTIATE_TEST_SUITE_P( + , + HirBinaryOpTest, + testing::Values( + BinaryOpType::Add, + BinaryOpType::Sub, + BinaryOpType::Mul, + BinaryOpType::Div), + [](const testing::TestParamInfo& info) -> std::string { + std::stringstream ss; + ss << "BinaryOpType_" << info.param; + return ss.str(); + }); + +using HirReductionOpTest = NVFuserTest; + +TEST_F(HirReductionOpTest, PreAllocatedOutputs) { + constexpr int64_t size0 = 8, size1 = 64; + constexpr int64_t reduction_axis = 1; + + auto hic = std::make_unique(); + FusionGuard fg(hic.get()); + + auto* in = makeConcreteTensor({size0, size1}); + auto* out = newForReduction(in, {reduction_axis}, in->dtype()); + auto* reduction_op = IrBuilder::create( + BinaryOpType::Add, hic->zeroVal(), out, in); + hic->addInput(in); + hic->addOutput(out); + hic->pushBackTopLevelExprs(reduction_op); + + HostIrEvaluator hie(std::move(hic)); + + auto options = at::TensorOptions().device(at::kCUDA, 0); + auto in_aten = at::randn({size0, size1}, options); + auto out_aten = at::empty({size0}, options); + + hie.runWithInput({{in, in_aten}, {out, out_aten}}); + + at::Tensor expected_out = in_aten.sum(reduction_axis); + EXPECT_TRUE(expected_out.equal(out_aten)) + << "Obtained output: " << out_aten << "\n" + << "Expected output: " << expected_out; +} + +TEST_F(HirReductionOpTest, NonPreAllocatedOutputs) { + constexpr int64_t size0 = 8, size1 = 64; + constexpr int64_t reduction_axis = 1; + + auto hic = std::make_unique(); + FusionGuard fg(hic.get()); + + auto* in = makeConcreteTensor({size0, size1}); + auto* out = sum(in, {reduction_axis}); + hic->addInput(in); + hic->addOutput(out); + hic->pushBackTopLevelExprs(out->definition()); + + HostIrEvaluator hie(std::move(hic)); + + auto options = at::TensorOptions().device(at::kCUDA, 0); + auto in_aten = at::randn({size0, size1}, options); + auto out_aten = at::empty({size0}, options); + + hie.runWithInput({{in, in_aten}, {out, out_aten}}); + + at::Tensor expected_out = in_aten.sum(reduction_axis); + EXPECT_TRUE(expected_out.equal(out_aten)) + << "Obtained output: " << out_aten << "\n" + << "Expected output: " << expected_out; +} + } // namespace hir } // namespace nvfuser diff --git a/tests/cpp/test_multidevice_pipeline.cpp b/tests/cpp/test_multidevice_pipeline.cpp index 5985571c57a..12dfed5dd43 100644 --- a/tests/cpp/test_multidevice_pipeline.cpp +++ b/tests/cpp/test_multidevice_pipeline.cpp @@ -457,135 +457,4 @@ INSTANTIATE_TEST_SUITE_P( testing::Values(0, 1), testing::Values(true))); -// Different scheduling modes used in -// PipelineTestStagedReduction.StagedReduction -enum class SchedulingMode { - // Manual interdevice scheduling, no intra-device scheduling - InterDeviceOnly, - // Manual inter-/intra-device scheduling - Manual, - // Manual inter-device scheduling, composed with fully automated intra-device - // scheduling (through FusionExecutorCache) - Automatic, -}; - -std::ostream& operator<<(std::ostream& out, const SchedulingMode& mode) { - switch (mode) { - case SchedulingMode::InterDeviceOnly: - out << "InterDeviceOnly"; - break; - case SchedulingMode::Manual: - out << "Manual"; - break; - case SchedulingMode::Automatic: - out << "Automatic"; - break; - } - return out; -} - -class PipelineTestStagedReduction - : public PipelineTest, - public ::testing::WithParamInterface {}; - -// 1D staged reduction -// Inputs: X[num_devices,B,C] -TEST_P(PipelineTestStagedReduction, StagedReduction) { - auto scheduling_mode = GetParam(); - - const int num_devices = communicator_->size(); - constexpr int B = 8; - constexpr int C = 64; - - FusionGuard fg(fusion.get()); - // The first dimension is made symbolic so `tv_out->definition()` won't - // become a squeeze when num_devices == 1. This wouldn't be a problem for - // automatic mode. However, for the manual mode, the scheduling code below - // assumes `tv_out->definition()` can be lowered to communication. A squeeze - // can't. - TensorView* tv0 = TensorViewBuilder() - .dtype(DataType::Float) - .contiguity(true) - .shape({-1, B, C}) - .build(); - auto mesh = DeviceMesh::createForNumDevices(num_devices); - tv0->setDeviceMesh(mesh); - TensorView* tv1 = sum(tv0, {2}); - TensorView* tv_out = sum(tv1, {0}); - fusion->addInput(tv0); - fusion->addOutput(tv_out); - - for (auto* tv : {tv0, tv1}) { - tv->axis(0)->parallelize(ParallelType::DIDx); - } - - // Intra-device reduction scheduling for the first reduction: - switch (scheduling_mode) { - case SchedulingMode::InterDeviceOnly: - break; - case SchedulingMode::Manual: { - // inspired from NVFuserTest.FusionReduction1_CUDA - // tv0[I0{A}, I1{B}, I2{C}] - tv1->split(2, 32); - // tv1[I0{A}, I1{B}, R2o{C/32}, R2i{32}] = tv0[I0{A}, I1{B}, I2{C}] - tv1->split(2, 4); - // clang-format off - // tv1[I0{A}, I1{B}, R2oo{C/32/4)}, R2oi{4}, R2i{32}] = tv0[I0{A}, I1{B}, I2{C}] - // clang-format on - - TensorView* tv2 = tv1->rFactor({2}); - // clang-format off - // tv2[I0{A}, I1{B}, R2oo{C/32/4)}, I2oi{4}, I2i{32}] = tv0[I0{A}, I1{B}, I2{C}] - // tv1[I0{A}, I1{B}, R2oi{4}, R2i{32}] = tv2[I0{A}, I1{B}, R2oo{C/32/4)}, I2oi{4}, I2i{32}] - // clang-format on - - TensorView* tv3 = tv1->rFactor({2}); - // clang-format off - // tv2[I0{A}, I1{B}, R2oo{C/32/4)}, I2oi{4}, I2i{32}] = tv0[I0{A}, I1{B}, I2{C}] - // tv3[I0{A}, I1{B}, R2oi{4}, I2i{32}] = tv2[I0{A}, I1{B}, R2oo{C/32/4)}, I2oi{4}, I2i{32}] - // tv1[I0{A}, I1{B}, R2i{32}] = tv3[I0{A}, I1{B}, R2oi{4}, I2i{32}] - // clang-format on - - // tv1 is a segment boundary so must be in global. This wouldn't be - // needed if the fusion were scheduled automatically. - tv1->setMemoryType(MemoryType::Global); - - // Use `tv2` as the reference tensor because it contains the most - // parallel IterDomains. - tv2->axis(1)->parallelize(ParallelType::BIDx); - tv2->axis(3)->parallelize(ParallelType::Unroll); - tv2->axis(-1)->parallelize(ParallelType::TIDx); - scheduler_utils::parallelizeAllLike( - tv2, - /*pos=*/-1, - // Don't propagate the parallelization to `tv_out` because that's in - // a different, resharding segment. - /*selected_tv=*/{tv0, tv1, tv2, tv3}); - inlineMost(); - break; - } - case SchedulingMode::Automatic: - host_ir_executor_params.use_fusion_executor_cache = true; - break; - } - - at::Tensor unsharded_input_tensor = - at::randn({num_devices, B, C}, tensor_options); - at::Tensor ref_unsharded_output_tensor = - unsharded_input_tensor.sum(at::IntArrayRef({0, 2})); - unsharded_args = {unsharded_input_tensor}; - ref_unsharded_outputs = {ref_unsharded_output_tensor}; - - executeAndValidate(/* validate_with_prescribed_values */ true); -} - -INSTANTIATE_TEST_SUITE_P( - , - PipelineTestStagedReduction, - testing::Values( - SchedulingMode::InterDeviceOnly, - SchedulingMode::Manual, - SchedulingMode::Automatic), - testing::PrintToStringParamName()); - } // namespace nvfuser From 7f7caf5177936b7387d4e443768788584de2702a Mon Sep 17 00:00:00 2001 From: snordmann Date: Sun, 27 Apr 2025 04:17:41 -0700 Subject: [PATCH 64/68] change namespace of the optimization pass to hir --- csrc/host_ir/lower.cpp | 2 +- csrc/host_ir/pass/stream_parallel_type.cpp | 4 +-- csrc/host_ir/pass/stream_parallel_type.h | 8 +++--- csrc/python_frontend/fusion_definition.cpp | 2 +- tests/cpp/test_host_ir_stream_lowering.cpp | 30 +++++++++++----------- tests/cpp/test_multidevice_host_ir.cpp | 4 +-- 6 files changed, 25 insertions(+), 25 deletions(-) diff --git a/csrc/host_ir/lower.cpp b/csrc/host_ir/lower.cpp index 227b76eb597..313970f703c 100644 --- a/csrc/host_ir/lower.cpp +++ b/csrc/host_ir/lower.cpp @@ -769,7 +769,7 @@ std::unique_ptr HostIrLower::lower( } hic->resetTopLevelExprs(new_top_level_exprs); - preseg_passes::OptimizationPass::runPass( + preseg_passes::OptimizationPass::runPass( hic.get()); return hic; diff --git a/csrc/host_ir/pass/stream_parallel_type.cpp b/csrc/host_ir/pass/stream_parallel_type.cpp index 8bea4c3430f..d7bfa0f090a 100644 --- a/csrc/host_ir/pass/stream_parallel_type.cpp +++ b/csrc/host_ir/pass/stream_parallel_type.cpp @@ -18,7 +18,7 @@ #include #include -namespace nvfuser::preseg_passes { +namespace nvfuser::hir { namespace { @@ -437,4 +437,4 @@ void StreamParallelType::runPass(Fusion* fusion) { hic->resetTopLevelExprs(top_level_exprs); } -} // namespace nvfuser::preseg_passes +} // namespace nvfuser::hir diff --git a/csrc/host_ir/pass/stream_parallel_type.h b/csrc/host_ir/pass/stream_parallel_type.h index 9c0c39efe87..f389dbe1ff7 100644 --- a/csrc/host_ir/pass/stream_parallel_type.h +++ b/csrc/host_ir/pass/stream_parallel_type.h @@ -10,7 +10,7 @@ #include #include -namespace nvfuser::preseg_passes { +namespace nvfuser::hir { // A pass used in HostIrLower that takes a HostIrContainer as input, reads the // TensorView's ParallelType::Stream, and modify the the HostIrContainer's top @@ -22,8 +22,8 @@ namespace nvfuser::preseg_passes { // An illustration of the pass can be found in the tests // `test_host_ir_stream_lowering.cpp` // with the option `NVFUSER_DUMP=host_ir`. -class StreamParallelType : public OptimizationPass { - friend class OptimizationPass; +class StreamParallelType : public preseg_passes::OptimizationPass { + friend class preseg_passes::OptimizationPass; protected: static void runPass(Fusion* fusion); @@ -32,4 +32,4 @@ class StreamParallelType : public OptimizationPass { } }; -} // namespace nvfuser::preseg_passes +} // namespace nvfuser::hir diff --git a/csrc/python_frontend/fusion_definition.cpp b/csrc/python_frontend/fusion_definition.cpp index d6e552032b1..fd3e714f7cb 100644 --- a/csrc/python_frontend/fusion_definition.cpp +++ b/csrc/python_frontend/fusion_definition.cpp @@ -455,7 +455,7 @@ std::pair> FusionDefinition:: params.lower.communicator_backend = backend_type_; // Disable StreamParallelType pass temporarily as proper stream lowering // gets implemented - preseg_passes::OptimizationPassGuard + preseg_passes::OptimizationPassGuard guard(false); scheds->multi_device_executor = std::make_unique( std::make_unique(*scheds->preschedFusion()), diff --git a/tests/cpp/test_host_ir_stream_lowering.cpp b/tests/cpp/test_host_ir_stream_lowering.cpp index f6d74caea87..e03fccb34e0 100644 --- a/tests/cpp/test_host_ir_stream_lowering.cpp +++ b/tests/cpp/test_host_ir_stream_lowering.cpp @@ -36,7 +36,7 @@ TEST_F(HirLowerStreamTest, InputsAreNotStreamParallelized) { tv->axis(0)->parallelize(ParallelType::Stream); EXPECT_ANY_THROW(preseg_passes::OptimizationPass< - preseg_passes::StreamParallelType>::runPass(hic.get())); + StreamParallelType>::runPass(hic.get())); } TEST_F(HirLowerStreamTest, Split) { @@ -51,7 +51,7 @@ TEST_F(HirLowerStreamTest, Split) { tv1->axis(0)->parallelize(ParallelType::Stream); EXPECT_ANY_THROW(preseg_passes::OptimizationPass< - preseg_passes::StreamParallelType>::runPass(hic.get())); + StreamParallelType>::runPass(hic.get())); } TEST_F(HirLowerStreamTest, Merge) { @@ -66,7 +66,7 @@ TEST_F(HirLowerStreamTest, Merge) { tv1->axis(0)->parallelize(ParallelType::Stream); EXPECT_ANY_THROW(preseg_passes::OptimizationPass< - preseg_passes::StreamParallelType>::runPass(hic.get())); + StreamParallelType>::runPass(hic.get())); } TEST_F(HirLowerStreamTest, SingleSetOp) { @@ -81,7 +81,7 @@ TEST_F(HirLowerStreamTest, SingleSetOp) { tv1->setMemoryType(MemoryType::Global); tv1->axis(0)->parallelize(ParallelType::Stream); - preseg_passes::OptimizationPass::runPass( + preseg_passes::OptimizationPass::runPass( hic.get()); EXPECT_EQ(hic->topLevelExprs().size(), 2); @@ -111,7 +111,7 @@ TEST_F(HirLowerStreamTest, SingleSetOpNonOutermost) { tv1->setMemoryType(MemoryType::Global); tv1->axis(1)->parallelize(ParallelType::Stream); - preseg_passes::OptimizationPass::runPass( + preseg_passes::OptimizationPass::runPass( hic.get()); EXPECT_EQ(hic->topLevelExprs().size(), 2); @@ -144,7 +144,7 @@ TEST_F(HirLowerStreamTest, SingleBinaryOp) { tv2->setMemoryType(MemoryType::Global); tv2->axis(0)->parallelize(ParallelType::Stream); - preseg_passes::OptimizationPass::runPass( + preseg_passes::OptimizationPass::runPass( hic.get()); EXPECT_EQ(hic->topLevelExprs().size(), 2); @@ -180,7 +180,7 @@ TEST_F(HirLowerStreamTest, TwoSetOps) { tv1->axis(0)->parallelize(ParallelType::Stream); tv2->axis(0)->parallelize(ParallelType::Stream); - preseg_passes::OptimizationPass::runPass( + preseg_passes::OptimizationPass::runPass( hic.get()); EXPECT_EQ(hic->topLevelExprs().size(), 3); @@ -218,7 +218,7 @@ TEST_F(HirLowerStreamTest, ThreeSetOpsWithDisjointsForLoops) { tv1->axis(0)->parallelize(ParallelType::Stream); tv3->axis(0)->parallelize(ParallelType::Stream); - preseg_passes::OptimizationPass::runPass( + preseg_passes::OptimizationPass::runPass( hic.get()); EXPECT_EQ(hic->topLevelExprs().size(), 5); @@ -252,7 +252,7 @@ TEST_F(HirLowerStreamTest, ReductionUnsupported) { tv1->axis(0)->parallelize(ParallelType::Stream); EXPECT_ANY_THROW(preseg_passes::OptimizationPass< - preseg_passes::StreamParallelType>::runPass(hic.get())); + StreamParallelType>::runPass(hic.get())); } TEST_F(HirLowerStreamTest, Reduction) { @@ -267,7 +267,7 @@ TEST_F(HirLowerStreamTest, Reduction) { tv1->setMemoryType(MemoryType::Global); tv1->axis(0)->parallelize(ParallelType::Stream); - preseg_passes::OptimizationPass::runPass( + preseg_passes::OptimizationPass::runPass( hic.get()); EXPECT_EQ(hic->topLevelExprs().size(), 2); @@ -301,7 +301,7 @@ TEST_F(HirLowerStreamTest, Matmul_M) { c->setMemoryType(MemoryType::Global); c->axis(0)->parallelize(ParallelType::Stream); - preseg_passes::OptimizationPass::runPass( + preseg_passes::OptimizationPass::runPass( hic.get()); EXPECT_EQ(hic->topLevelExprs().size(), 2); @@ -338,7 +338,7 @@ TEST_F(HirLowerStreamTest, BatchedMatmul) { c->setMemoryType(MemoryType::Global); c->axis(0)->parallelize(ParallelType::Stream); - preseg_passes::OptimizationPass::runPass( + preseg_passes::OptimizationPass::runPass( hic.get()); EXPECT_EQ(hic->topLevelExprs().size(), 2); @@ -375,7 +375,7 @@ TEST_F(HirLowerStreamTest, Matmul_N) { c->setMemoryType(MemoryType::Global); c->axis(1)->parallelize(ParallelType::Stream); - preseg_passes::OptimizationPass::runPass( + preseg_passes::OptimizationPass::runPass( hic.get()); EXPECT_EQ(hic->topLevelExprs().size(), 2); @@ -413,7 +413,7 @@ TEST_F(HirLowerStreamTest, Matmul_K) { c->axis(-1)->parallelize(ParallelType::Stream); EXPECT_ANY_THROW(preseg_passes::OptimizationPass< - preseg_passes::StreamParallelType>::runPass(hic.get())); + StreamParallelType>::runPass(hic.get())); } // We don's support PostOnStream because it does not support well pre-allocated @@ -461,7 +461,7 @@ TEST_F(HirLowerStreamTest, DoNotSupportPostOnStream) { output->axis(-1)->parallelize(ParallelType::Stream); EXPECT_ANY_THROW(preseg_passes::OptimizationPass< - preseg_passes::StreamParallelType>::runPass(hic.get())); + StreamParallelType>::runPass(hic.get())); } } // namespace hir diff --git a/tests/cpp/test_multidevice_host_ir.cpp b/tests/cpp/test_multidevice_host_ir.cpp index 7b233bc47db..6932a40fe5c 100644 --- a/tests/cpp/test_multidevice_host_ir.cpp +++ b/tests/cpp/test_multidevice_host_ir.cpp @@ -365,7 +365,7 @@ using OverlapDistributedMatmulTest = MultiDeviceTest; TEST_F(OverlapDistributedMatmulTest, AG_matmul) { // Disable StreamParallelType pass temporarily as proper stream lowering gets // implemented - preseg_passes::OptimizationPassGuard guard( + preseg_passes::OptimizationPassGuard guard( false); constexpr int64_t M = 32768; @@ -424,7 +424,7 @@ TEST_F(OverlapDistributedMatmulTest, AG_matmul) { TEST_F(OverlapDistributedMatmulTest, AG_linear) { // Disable StreamParallelType pass tempor - preseg_passes::OptimizationPassGuard guard( + preseg_passes::OptimizationPassGuard guard( false); constexpr int64_t M = 32768; From bfc7ba836400aa349fab473fa04bab204e9c5601 Mon Sep 17 00:00:00 2001 From: samnordmann Date: Sun, 27 Apr 2025 16:21:05 +0200 Subject: [PATCH 65/68] add HirAliasSelect (#4301) # What Add a `SelectOp`-like HIR to express indexing into ATen tensor. # Why it is used in the context of stream lowering, see https://github.com/NVIDIA/Fuser/pull/4147 and especially the discussion in https://github.com/NVIDIA/Fuser/pull/4147#discussion_r2055365814 --- csrc/dispatch.h | 3 ++- csrc/host_ir/executor.cpp | 9 ++++++++ csrc/host_ir/executor.h | 1 + csrc/host_ir/host_ir.cpp | 45 +++++++++++++++++++++++++++++++++++++ csrc/host_ir/host_ir.h | 43 +++++++++++++++++++++++++++++++++++ tests/cpp/test_host_irs.cpp | 38 +++++++++++++++++++++++++++++++ 6 files changed, 138 insertions(+), 1 deletion(-) diff --git a/csrc/dispatch.h b/csrc/dispatch.h index f1f4153d1d2..007287e49f7 100644 --- a/csrc/dispatch.h +++ b/csrc/dispatch.h @@ -160,7 +160,8 @@ class Val; f(Synchronize); \ f(StartCoalescing); \ f(EndCoalescing); \ - f(ShareMemHandles); + f(ShareMemHandles); \ + f(HirAliasSelect); // Forward declarations for all Val and Expr types diff --git a/csrc/host_ir/executor.cpp b/csrc/host_ir/executor.cpp index 4ceef8927ed..2f2cf9e7b92 100644 --- a/csrc/host_ir/executor.cpp +++ b/csrc/host_ir/executor.cpp @@ -705,6 +705,15 @@ void HostIrEvaluator::handle(kir::Allocate* allocate) { bind(tv, tensor); } +void HostIrEvaluator::handle(HirAliasSelect* hir_alias_select) { + auto index = + expr_evaluator_.evaluate(hir_alias_select->index()).as(); + auto input = getKnownConcreteValue(hir_alias_select->in()->as()) + .as(); + int64_t axis = hir_alias_select->axis(); + bind(hir_alias_select->out(), input.select(axis, index)); +} + void HostIrEvaluator::handle(BinaryOp* binary_op) { if (!isKnown(binary_op->outputs().at(0))) { return unhandled(binary_op); diff --git a/csrc/host_ir/executor.h b/csrc/host_ir/executor.h index dfe84fba068..7894245de75 100644 --- a/csrc/host_ir/executor.h +++ b/csrc/host_ir/executor.h @@ -138,6 +138,7 @@ class HostIrEvaluator final : public OptOutDispatch { void handle(BinaryOp* binary_op) override; void handle(ReductionOp* reduction_op) override; void handle(ShareMemHandles* share_mem_handles) override; + void handle(HirAliasSelect* hir_alias_select) override; void unhandled(Statement* stmt) override; c10::cuda::CUDAStream getCUDAStream(Stream* stream); diff --git a/csrc/host_ir/host_ir.cpp b/csrc/host_ir/host_ir.cpp index 9e1386d0d3d..bf3d5cef9eb 100644 --- a/csrc/host_ir/host_ir.cpp +++ b/csrc/host_ir/host_ir.cpp @@ -355,6 +355,51 @@ std::string ShareMemHandles::toInlineString(int indent_size) const { NVF_THROW("Cannot be printed inline"); } +HirAliasSelect::HirAliasSelect( + IrBuilderPasskey passkey, + TensorView* in, + TensorView* out, + int64_t axis, + Val* index) + : Expr(passkey, {in, index}, {}, {}) { + NVF_ERROR(passkey.ir_container_ != nullptr); + NVF_ERROR( + passkey.ir_container_->isA(), + this, + "must be registered in a HostIrContainer"); + NVF_ERROR( + static_cast(in->getLogicalDomain().size()) > axis, + "Select axis ", + axis, + " is out of bounds for tensor ", + in->toString(), + " with ", + in->getLogicalDomain().size(), + " dimensions"); + // "out" is not added as an output because the current op doesn't "define" it, + // but rather sets its allocation. Since "out" will be used in another + // producing expression, this avoids unnecessary cyclic dependencies. This + // ressembles how kir::Allocate treats its allocated TensorView. + addAttribute(out); + addDataAttribute(axis); +} + +NVFUSER_DEFINE_CLONE_AND_CREATE(HirAliasSelect) + +std::string HirAliasSelect::toString(int indent_size) const { + std::stringstream ss; + indent(ss, indent_size) << out()->toString() << "\n"; + indent_size++; + indent(ss, indent_size) << " = HirAliasSelect( " << in()->toString() + << ", axis = " << in()->getLogicalDomain().at(axis()) + << ", index = " << index()->toString() << " )\n"; + return ss.str(); +} + +std::string HirAliasSelect::toInlineString(int indent_size) const { + NVF_THROW("Cannot be printed inline"); +} + } // namespace hir } // namespace nvfuser diff --git a/csrc/host_ir/host_ir.h b/csrc/host_ir/host_ir.h index bad3a6ef722..d267d23ab1f 100644 --- a/csrc/host_ir/host_ir.h +++ b/csrc/host_ir/host_ir.h @@ -351,6 +351,49 @@ class ShareMemHandles : public Expr { } }; +// This op mimicks the semantics of SelectOp but is used in HIR non-SSA context +// to index into a TensorView, returning an alias "slice" of the original +// TensorView. +class HirAliasSelect : public Expr { + public: + using Expr::Expr; + HirAliasSelect( + IrBuilderPasskey passkey, + TensorView* in, + TensorView* out, + int64_t axis, + Val* index); + + HirAliasSelect(const HirAliasSelect& other) = delete; + HirAliasSelect& operator=(const HirAliasSelect& other) = delete; + HirAliasSelect(HirAliasSelect&& other) = delete; + HirAliasSelect& operator=(HirAliasSelect&& other) = delete; + + NVFUSER_DECLARE_CLONE_AND_CREATE + + std::string toString(int indent_size = 0) const override; + std::string toInlineString(int indent_size = 0) const override; + const char* getOpString() const override { + return "hir::HirAliasSelect"; + } + + TensorView* in() const { + return inputs().at(0)->as(); + } + + TensorView* out() const { + return attributeVal(0)->as(); + } + + int64_t axis() const { + return attribute(1); + } + + Val* index() const { + return inputs().at(1); + } +}; + } // namespace hir } // namespace nvfuser diff --git a/tests/cpp/test_host_irs.cpp b/tests/cpp/test_host_irs.cpp index 633ebc83504..06c029fb6f6 100644 --- a/tests/cpp/test_host_irs.cpp +++ b/tests/cpp/test_host_irs.cpp @@ -1307,6 +1307,44 @@ TEST_F(HirAlias, ThrowOnInputAlias) { EXPECT_ANY_THROW(HostIrEvaluator hie(std::move(hic))); } +using HirAliasSelectHostIrTest = NVFuserTest; + +TEST_F(HirAliasSelectHostIrTest, SelectingTensor) { + constexpr int64_t ndims = 2; + constexpr int64_t dim = 1; + constexpr int64_t index = 3; + const std::vector input_sizes = {32, 32}; + + ASSERT_LT(dim, ndims); + ASSERT_EQ(input_sizes.size(), ndims); + ASSERT_LT(index, input_sizes.at(dim)); + + auto hic = std::make_unique(); + FusionGuard fg(hic.get()); + + TensorView* in = makeContigTensor(ndims); + TensorView* out = makeContigTensor(ndims - 1); + auto* index_val = IrBuilder::create(index, DataType::Index); + auto* select_op = IrBuilder::create(in, out, dim, index_val); + + hic->addInput(in); + hic->addOutput(out); + hic->pushBackTopLevelExprs(select_op); + + HostIrEvaluator hie(std::move(hic)); + + auto options = at::TensorOptions().device(at::kCUDA, 0).dtype(torch::kFloat); + auto in_aten = at::randn(input_sizes, options); + std::unordered_map concrete_input_buffers = { + {in, in_aten}}; + + auto out_aten = hie.runWithInput(concrete_input_buffers)[0].as(); + + // validate + auto ref_out = in_aten.select(dim, index); + EXPECT_TRUE(ref_out.equal(out_aten)); +} + using HirSetTest = NVFuserTest; TEST_F(HirSetTest, HostIr) { From 7777fe0038c0c1568f5f8abc91b20618bb360e63 Mon Sep 17 00:00:00 2001 From: snordmann Date: Sun, 27 Apr 2025 07:27:05 -0700 Subject: [PATCH 66/68] lint --- csrc/host_ir/executor.cpp | 2 +- csrc/host_ir/lower.cpp | 3 +- csrc/host_ir/pass/stream_parallel_type.h | 3 +- python/python_frontend/fusion_definition.cpp | 4 +- tests/cpp/test_host_ir_stream_lowering.cpp | 51 ++++++++------------ tests/cpp/test_multidevice_host_ir.cpp | 6 +-- 6 files changed, 29 insertions(+), 40 deletions(-) diff --git a/csrc/host_ir/executor.cpp b/csrc/host_ir/executor.cpp index 96a462a9444..e089fec32cf 100644 --- a/csrc/host_ir/executor.cpp +++ b/csrc/host_ir/executor.cpp @@ -771,7 +771,7 @@ void HostIrEvaluator::handle(ReductionOp* reduction_op) { } } switch (reduction_op->getReductionOpType()) { - case BinaryOpType::Add: + case BinaryOpType::Add: at::sum_out(output, input, reduction_axes); return; case BinaryOpType::Max: diff --git a/csrc/host_ir/lower.cpp b/csrc/host_ir/lower.cpp index 313970f703c..ca9bb80ae4e 100644 --- a/csrc/host_ir/lower.cpp +++ b/csrc/host_ir/lower.cpp @@ -769,8 +769,7 @@ std::unique_ptr HostIrLower::lower( } hic->resetTopLevelExprs(new_top_level_exprs); - preseg_passes::OptimizationPass::runPass( - hic.get()); + preseg_passes::OptimizationPass::runPass(hic.get()); return hic; } diff --git a/csrc/host_ir/pass/stream_parallel_type.h b/csrc/host_ir/pass/stream_parallel_type.h index f389dbe1ff7..8b5f138ad7e 100644 --- a/csrc/host_ir/pass/stream_parallel_type.h +++ b/csrc/host_ir/pass/stream_parallel_type.h @@ -22,7 +22,8 @@ namespace nvfuser::hir { // An illustration of the pass can be found in the tests // `test_host_ir_stream_lowering.cpp` // with the option `NVFUSER_DUMP=host_ir`. -class StreamParallelType : public preseg_passes::OptimizationPass { +class StreamParallelType + : public preseg_passes::OptimizationPass { friend class preseg_passes::OptimizationPass; protected: diff --git a/python/python_frontend/fusion_definition.cpp b/python/python_frontend/fusion_definition.cpp index fd3e714f7cb..b77947f1415 100644 --- a/python/python_frontend/fusion_definition.cpp +++ b/python/python_frontend/fusion_definition.cpp @@ -455,8 +455,8 @@ std::pair> FusionDefinition:: params.lower.communicator_backend = backend_type_; // Disable StreamParallelType pass temporarily as proper stream lowering // gets implemented - preseg_passes::OptimizationPassGuard - guard(false); + preseg_passes::OptimizationPassGuard guard( + false); scheds->multi_device_executor = std::make_unique( std::make_unique(*scheds->preschedFusion()), Communicator::getInstance(), diff --git a/tests/cpp/test_host_ir_stream_lowering.cpp b/tests/cpp/test_host_ir_stream_lowering.cpp index e03fccb34e0..b77df002bc6 100644 --- a/tests/cpp/test_host_ir_stream_lowering.cpp +++ b/tests/cpp/test_host_ir_stream_lowering.cpp @@ -35,8 +35,8 @@ TEST_F(HirLowerStreamTest, InputsAreNotStreamParallelized) { hic->addInput(tv); tv->axis(0)->parallelize(ParallelType::Stream); - EXPECT_ANY_THROW(preseg_passes::OptimizationPass< - StreamParallelType>::runPass(hic.get())); + EXPECT_ANY_THROW( + preseg_passes::OptimizationPass::runPass(hic.get())); } TEST_F(HirLowerStreamTest, Split) { @@ -50,8 +50,8 @@ TEST_F(HirLowerStreamTest, Split) { tv1->split(0, 2); tv1->axis(0)->parallelize(ParallelType::Stream); - EXPECT_ANY_THROW(preseg_passes::OptimizationPass< - StreamParallelType>::runPass(hic.get())); + EXPECT_ANY_THROW( + preseg_passes::OptimizationPass::runPass(hic.get())); } TEST_F(HirLowerStreamTest, Merge) { @@ -65,8 +65,8 @@ TEST_F(HirLowerStreamTest, Merge) { tv1->merge(0, 1); tv1->axis(0)->parallelize(ParallelType::Stream); - EXPECT_ANY_THROW(preseg_passes::OptimizationPass< - StreamParallelType>::runPass(hic.get())); + EXPECT_ANY_THROW( + preseg_passes::OptimizationPass::runPass(hic.get())); } TEST_F(HirLowerStreamTest, SingleSetOp) { @@ -81,8 +81,7 @@ TEST_F(HirLowerStreamTest, SingleSetOp) { tv1->setMemoryType(MemoryType::Global); tv1->axis(0)->parallelize(ParallelType::Stream); - preseg_passes::OptimizationPass::runPass( - hic.get()); + preseg_passes::OptimizationPass::runPass(hic.get()); EXPECT_EQ(hic->topLevelExprs().size(), 2); EXPECT_TRUE(hic->topLevelExprs().at(0)->isA()); @@ -111,8 +110,7 @@ TEST_F(HirLowerStreamTest, SingleSetOpNonOutermost) { tv1->setMemoryType(MemoryType::Global); tv1->axis(1)->parallelize(ParallelType::Stream); - preseg_passes::OptimizationPass::runPass( - hic.get()); + preseg_passes::OptimizationPass::runPass(hic.get()); EXPECT_EQ(hic->topLevelExprs().size(), 2); EXPECT_TRUE(hic->topLevelExprs().at(0)->isA()); @@ -144,8 +142,7 @@ TEST_F(HirLowerStreamTest, SingleBinaryOp) { tv2->setMemoryType(MemoryType::Global); tv2->axis(0)->parallelize(ParallelType::Stream); - preseg_passes::OptimizationPass::runPass( - hic.get()); + preseg_passes::OptimizationPass::runPass(hic.get()); EXPECT_EQ(hic->topLevelExprs().size(), 2); EXPECT_TRUE(hic->topLevelExprs().at(0)->isA()); @@ -180,8 +177,7 @@ TEST_F(HirLowerStreamTest, TwoSetOps) { tv1->axis(0)->parallelize(ParallelType::Stream); tv2->axis(0)->parallelize(ParallelType::Stream); - preseg_passes::OptimizationPass::runPass( - hic.get()); + preseg_passes::OptimizationPass::runPass(hic.get()); EXPECT_EQ(hic->topLevelExprs().size(), 3); EXPECT_TRUE(hic->topLevelExprs().at(0)->isA()); @@ -218,8 +214,7 @@ TEST_F(HirLowerStreamTest, ThreeSetOpsWithDisjointsForLoops) { tv1->axis(0)->parallelize(ParallelType::Stream); tv3->axis(0)->parallelize(ParallelType::Stream); - preseg_passes::OptimizationPass::runPass( - hic.get()); + preseg_passes::OptimizationPass::runPass(hic.get()); EXPECT_EQ(hic->topLevelExprs().size(), 5); EXPECT_TRUE(hic->topLevelExprs().at(0)->isA()); @@ -251,8 +246,8 @@ TEST_F(HirLowerStreamTest, ReductionUnsupported) { tv1->setMemoryType(MemoryType::Global); tv1->axis(0)->parallelize(ParallelType::Stream); - EXPECT_ANY_THROW(preseg_passes::OptimizationPass< - StreamParallelType>::runPass(hic.get())); + EXPECT_ANY_THROW( + preseg_passes::OptimizationPass::runPass(hic.get())); } TEST_F(HirLowerStreamTest, Reduction) { @@ -267,8 +262,7 @@ TEST_F(HirLowerStreamTest, Reduction) { tv1->setMemoryType(MemoryType::Global); tv1->axis(0)->parallelize(ParallelType::Stream); - preseg_passes::OptimizationPass::runPass( - hic.get()); + preseg_passes::OptimizationPass::runPass(hic.get()); EXPECT_EQ(hic->topLevelExprs().size(), 2); EXPECT_TRUE(hic->topLevelExprs().at(0)->isA()); @@ -301,8 +295,7 @@ TEST_F(HirLowerStreamTest, Matmul_M) { c->setMemoryType(MemoryType::Global); c->axis(0)->parallelize(ParallelType::Stream); - preseg_passes::OptimizationPass::runPass( - hic.get()); + preseg_passes::OptimizationPass::runPass(hic.get()); EXPECT_EQ(hic->topLevelExprs().size(), 2); EXPECT_TRUE(hic->topLevelExprs().at(0)->isA()); @@ -338,8 +331,7 @@ TEST_F(HirLowerStreamTest, BatchedMatmul) { c->setMemoryType(MemoryType::Global); c->axis(0)->parallelize(ParallelType::Stream); - preseg_passes::OptimizationPass::runPass( - hic.get()); + preseg_passes::OptimizationPass::runPass(hic.get()); EXPECT_EQ(hic->topLevelExprs().size(), 2); EXPECT_TRUE(hic->topLevelExprs().at(0)->isA()); @@ -375,8 +367,7 @@ TEST_F(HirLowerStreamTest, Matmul_N) { c->setMemoryType(MemoryType::Global); c->axis(1)->parallelize(ParallelType::Stream); - preseg_passes::OptimizationPass::runPass( - hic.get()); + preseg_passes::OptimizationPass::runPass(hic.get()); EXPECT_EQ(hic->topLevelExprs().size(), 2); EXPECT_TRUE(hic->topLevelExprs().at(0)->isA()); @@ -412,8 +403,8 @@ TEST_F(HirLowerStreamTest, Matmul_K) { c->setMemoryType(MemoryType::Global); c->axis(-1)->parallelize(ParallelType::Stream); - EXPECT_ANY_THROW(preseg_passes::OptimizationPass< - StreamParallelType>::runPass(hic.get())); + EXPECT_ANY_THROW( + preseg_passes::OptimizationPass::runPass(hic.get())); } // We don's support PostOnStream because it does not support well pre-allocated @@ -460,8 +451,8 @@ TEST_F(HirLowerStreamTest, DoNotSupportPostOnStream) { output->axis(-1)->parallelize(ParallelType::Stream); - EXPECT_ANY_THROW(preseg_passes::OptimizationPass< - StreamParallelType>::runPass(hic.get())); + EXPECT_ANY_THROW( + preseg_passes::OptimizationPass::runPass(hic.get())); } } // namespace hir diff --git a/tests/cpp/test_multidevice_host_ir.cpp b/tests/cpp/test_multidevice_host_ir.cpp index 6932a40fe5c..db53f7f114d 100644 --- a/tests/cpp/test_multidevice_host_ir.cpp +++ b/tests/cpp/test_multidevice_host_ir.cpp @@ -365,8 +365,7 @@ using OverlapDistributedMatmulTest = MultiDeviceTest; TEST_F(OverlapDistributedMatmulTest, AG_matmul) { // Disable StreamParallelType pass temporarily as proper stream lowering gets // implemented - preseg_passes::OptimizationPassGuard guard( - false); + preseg_passes::OptimizationPassGuard guard(false); constexpr int64_t M = 32768; constexpr int64_t K = 32768; @@ -424,8 +423,7 @@ TEST_F(OverlapDistributedMatmulTest, AG_matmul) { TEST_F(OverlapDistributedMatmulTest, AG_linear) { // Disable StreamParallelType pass tempor - preseg_passes::OptimizationPassGuard guard( - false); + preseg_passes::OptimizationPassGuard guard(false); constexpr int64_t M = 32768; constexpr int64_t K = 32768; From e517bc31fad5cd39f4d84e5d813f469f6ba90a26 Mon Sep 17 00:00:00 2001 From: snordmann Date: Sun, 27 Apr 2025 07:36:08 -0700 Subject: [PATCH 67/68] fix merge --- csrc/host_ir/executor.cpp | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/csrc/host_ir/executor.cpp b/csrc/host_ir/executor.cpp index e089fec32cf..2f2cf9e7b92 100644 --- a/csrc/host_ir/executor.cpp +++ b/csrc/host_ir/executor.cpp @@ -705,8 +705,6 @@ void HostIrEvaluator::handle(kir::Allocate* allocate) { bind(tv, tensor); } -<<<<<<< HEAD -======= void HostIrEvaluator::handle(HirAliasSelect* hir_alias_select) { auto index = expr_evaluator_.evaluate(hir_alias_select->index()).as(); @@ -716,7 +714,6 @@ void HostIrEvaluator::handle(HirAliasSelect* hir_alias_select) { bind(hir_alias_select->out(), input.select(axis, index)); } ->>>>>>> bfc7ba836400aa349fab473fa04bab204e9c5601 void HostIrEvaluator::handle(BinaryOp* binary_op) { if (!isKnown(binary_op->outputs().at(0))) { return unhandled(binary_op); @@ -789,18 +786,6 @@ void HostIrEvaluator::handle(ReductionOp* reduction_op) { } } -<<<<<<< HEAD -void HostIrEvaluator::handle(HirAliasSelect* hir_alias_select) { - auto index = - expr_evaluator_.evaluate(hir_alias_select->index()).as(); - auto input = getKnownConcreteValue(hir_alias_select->in()->as()) - .as(); - int64_t axis = hir_alias_select->axis(); - bind(hir_alias_select->out(), input.select(axis, index)); -} - -======= ->>>>>>> bfc7ba836400aa349fab473fa04bab204e9c5601 void HostIrEvaluator::unhandled(Statement* stmt) { NVF_ERROR(stmt->isA(), stmt, " must be an Expr"); auto* expr = stmt->as(); From 35ff4dab62b187e3842086ff20a223af267c1a50 Mon Sep 17 00:00:00 2001 From: snordmann Date: Mon, 28 Apr 2025 09:58:44 +0300 Subject: [PATCH 68/68] empty commit to trigger the CI