From 6061f3ba83a8394988b68d7c77658550d319ec6a Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Wed, 28 Jun 2023 09:11:50 -0400 Subject: [PATCH 01/37] Initial draft of empty remover pass --- CMakeLists.txt | 1 + csrc/optimization/pre_segmenter.cpp | 3 + csrc/optimization/remove_empty.cpp | 221 ++++++++++++++++++++++++++++ csrc/optimization/remove_empty.h | 21 +++ 4 files changed, 246 insertions(+) create mode 100644 csrc/optimization/remove_empty.cpp create mode 100644 csrc/optimization/remove_empty.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 137dde0f373..e2755b802d3 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -184,6 +184,7 @@ list(APPEND NVFUSER_SRCS ${NVFUSER_SRCS_DIR}/optimization/add_axioms.cpp ${NVFUSER_SRCS_DIR}/optimization/consecutive_cast.cpp ${NVFUSER_SRCS_DIR}/optimization/pre_segmenter.cpp + ${NVFUSER_SRCS_DIR}/optimization/remove_empty.cpp ) set(NVFUSER_CODEGEN ${PROJECT_NAME}_codegen) diff --git a/csrc/optimization/pre_segmenter.cpp b/csrc/optimization/pre_segmenter.cpp index 30bdf87bf34..b9a058daff0 100644 --- a/csrc/optimization/pre_segmenter.cpp +++ b/csrc/optimization/pre_segmenter.cpp @@ -9,10 +9,13 @@ #include #include +#include namespace nvfuser::optimization { void PreSegmenter::runPass(Fusion* fusion) { + // Replace TensorViews with zero extent. Outputs and inputs may still be empty + OptimizationPass::runPass(fusion); // removes consecutive cast operations OptimizationPass::runPass(fusion); OptimizationPass::runPass(fusion); diff --git a/csrc/optimization/remove_empty.cpp b/csrc/optimization/remove_empty.cpp new file mode 100644 index 00000000000..204d71da764 --- /dev/null +++ b/csrc/optimization/remove_empty.cpp @@ -0,0 +1,221 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#include + +#include +#include + +#include + +namespace nvfuser::optimization { + +namespace { + +//! Get a vector of the integer positions of constant zero extent axes in the +//! input domain. This will typically be used like +//! `emptyAxes(TensorDomain::noReductions(tv->getMaybeRFactorDomain()))` +std::vector emptyAxes(std::vector domain) { + std::vector empty_axes; + for (auto ax : c10::irange(domain.size())) { + auto id = domain.at(ax); + if (id->extent()->isConst() && id->extent()->evaluateInt() == 0) { + empty_axes.push_back(ax); + } + } + return empty_axes; +} + +//! Check whether a TensorView is empty. During concretization, we traverse to +//! find a minimal set of TensorViews that have zero extents, and we then set +//! their extents to a constant 0. Here we check for those constant zero extents. +bool isTVEmpty(TensorView* tv) { + return !emptyAxes(TensorDomain::noReductions(tv->getMaybeRFactorDomain())) + .empty(); +} + +//! removeEmptyPass performs a backward traversal of the Fusion. When it detects +//! a TensorView that has at least one extent that is zero, we do the following: +//! +//! 1. If the empty Tensorview is a Fusion output, we replace it with a +//! TensorView created by `full` having the same shape. Since the original +//! tensor is empty, there is nothing to compute, so this eliminates a branch +//! of trivial code. +//! 2. If the empty TensorView is the input of a `cat` op along the empty +//! dimension, we replace the cat op with a new one having the empty input +//! removed. +//! 3. If the empty Tensorview is the input to a ReductionOp or WelfordOp and +//! the empty dimensions are reduced, we replace the op with `full` since +//! there are no elements being reduced. Note that if any empty axes are not +//! reduced, we will not encounter this case since it will have been removed +//! earlier in the backward traversal under condition 1. +//! 4. If the empty Tensorview is the input to a PadOp (which is not input to +//! a CatOp) then we replace the pad with `full(pad_value)`. +//! +//! Note that we do not use BackwardVisitor here even though we are doing a +//! backward traversal. This is because we will actually be changing the Fusion +//! graph as we traverse. BackwardVisitor works by creating a stack of +//! statements using InputsOf and a forward traversal, then iteratively pops +//! that static stack of statements. This does not work well when we want to +//! eliminate dead code while traversing, so instead we implement a simple +//! stack-based depth-first backward traversal manually here. +class EmptyTensorRemover { + public: + EmptyTensorRemover(Fusion* fusion) : fusion_(fusion) { + for (auto outp : fusion->outputs()) { + stmt_stack_.push_back(outp); + } + } + + void run() { + while (!stmt_stack_.empty()) { + auto stmt = stmt_stack_.back(); + stmt_stack_.pop_back(); + handle(stmt); + } + } + + void handle(Statement* stmt) { + if (stmt->isVal()) { + handle(stmt->asVal()); + } else { + TORCH_INTERNAL_ASSERT(stmt->isExpr(), "Statement is neither a Val or Expr: ", stmt->toString()); + handle(stmt->asExpr()); + } + } + + void handle(Val* v) { + if (auto tv = dynamic_cast(v)) { + handle(tv); + } else if (v->definition()) { + // TensorView vals might be overwritten, in which case we should not keep + // traversing. For all other Vals, push their definition to the stack. + stmt_stack_.push_back(v->definition()); + } + } + + //! If tv is a fusion output, we check whether it is empty and if so, replace + //! it with full(). For non-outputs that are not inputs, we simply check that + //! the tensor is not provably empty. + void handle(TensorView* tv) { + if (tv->isFusionOutput()) { + const auto rfactor = TensorDomain::noReductions(tv->getMaybeRFactorDomain()); + const auto empty_axes = emptyAxes(rfactor); + if (!empty_axes.empty()) { + std::vector shape(rfactor.size()); + std::transform( + rfactor.begin(), rfactor.end(), shape.begin(), [](IterDomain* id) { + return id->extent(); + }); + for (auto ax : empty_axes) { + shape[ax] = fusion_->zeroVal(); + } + auto dtype = tv->getDataType().value(); + auto new_tv = full(shape, fusion_->oneVal(dtype), dtype); + replaceTV(tv, new_tv); + // Do not keep traversing upstream if we've replaced tv + return; + } + } else if (!tv->isFusionInput()) { + // TODO: This should be a warning instead + TORCH_INTERNAL_ASSERT( + !isTVEmpty(tv), + "Found unexpected empty intermediate TensorView ", + tv->toString()); + } + if (tv->definition()) { + stmt_stack_.push_back(tv->definition()); + } + } + + void pushInputs(Expr* e) { + for (auto inp : e->inputs()) { + stmt_stack_.push_back(inp); + } + } + + void handle(ReductionOp* rop) { + auto in = rop->in()->as(); + auto empty_input_axes = + emptyAxes(TensorDomain::noReductions(in->getMaybeRFactorDomain())); + if (empty_input_axes.empty()) { + // Input is not empty, handle like any other op + pushInputs(rop); + return; + } + auto out = rop->out()->as(); + // The input is empty in some axes. Assert that they are all reduced + const auto& out_root = out->getRootDomain(); + + std::vector shape; + for (auto id : out_root) { + if (!id->isReduction() && !id->isStride()) { // same as noReductions() + shape.push_back(id->extent()); + } + } + + for (auto ax : empty_input_axes) { + auto id = out_root.at(ax); + // Input rfactor domain positions correspond to output root positions + TORCH_INTERNAL_ASSERT( + id->isReduction(), + "Found unexpected unreduced empty axis at position ", + ax, + " in expression ", + rop->toString()); + shape[ax] = fusion_->zeroVal(); + } + // Find output shape to replace with full + + auto new_tv = full(shape, rop->init(), out->getDataType().value()); + replaceTV(out, new_tv); + } + + //! Replaces a TensorView in outputs, and in all uses. If old_tv is a Fusion + //! input, we do not replace it. After replacement, unless it is a Fusion + //! input, we remove it from the fusion and set the original pointer to zero + //! (hence why old_tv is passed by reference). + void replaceTV(TensorView*& old_tv, TensorView* new_tv) { + if (old_tv->isFusionOutput()) { + fusion_->replaceOutput(old_tv, new_tv); + } + for (auto use : old_tv->uses()) { + ir_utils::replaceValInExpr(use, old_tv, new_tv); + } + if (!old_tv->isFusionInput()) { + fusion_->removeVal(old_tv); + old_tv = nullptr; + } + } + + void handle(Expr* e) { + if (auto rop = dynamic_cast(e)) { + handle(rop); + } else if (auto wop = dynamic_cast(e)) { + handle(wop); + } else if (auto pop = dynamic_cast(e)) { + handle(pop); + } else { + // The handled ops above may terminate this branch, so they will need to + // manually handle their inputs. For unhandled ops, we just handle all + // inputs here. + pushInputs(e); + } + } + + private: + Fusion* fusion_; + std::vector stmt_stack_; +}; + +} // namespace + +void RemoveEmptyPass::runPass(Fusion* fusion) { + EmptyTensorRemover(fusion).run(); +} + +} // namespace nvfuser::optimization diff --git a/csrc/optimization/remove_empty.h b/csrc/optimization/remove_empty.h new file mode 100644 index 00000000000..6b4f4aee6fa --- /dev/null +++ b/csrc/optimization/remove_empty.h @@ -0,0 +1,21 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#include + +namespace nvfuser::optimization { + +//! RemoveEmptyPass removes empty tensors (those with at least one extent zero). +class TORCH_CUDA_CU_API RemoveEmptyPass + : public OptimizationPass { + friend class OptimizationPass; + + protected: + static void runPass(Fusion* fusion); +}; + +} // namespace nvfuser::optimization From 4bce330438abe9f7f642221c876c681b2c9e784c Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Wed, 28 Jun 2023 09:20:34 -0400 Subject: [PATCH 02/37] Add NVFUSER_DUMP=fusion_ir_preseg to view opt pass output --- csrc/kernel_cache.cpp | 6 ++++++ csrc/optimization/remove_empty.cpp | 15 ++++++++++----- csrc/options.cpp | 1 + csrc/options.h | 1 + 4 files changed, 18 insertions(+), 5 deletions(-) diff --git a/csrc/kernel_cache.cpp b/csrc/kernel_cache.cpp index 1f11e83dfd5..8d167826a8f 100644 --- a/csrc/kernel_cache.cpp +++ b/csrc/kernel_cache.cpp @@ -672,6 +672,12 @@ FusionKernelRuntime::FusionKernelRuntime( optimization::OptimizationPass::runPass( fusion.get()); + if (isDebugDumpEnabled(DebugDumpOption::FusionIrPreseg)) { + std::cout << "Fusion IR after pre-segmenter optimization passes:" + << std::endl; + fusion->printMath(); + } + all_tvs_ = ir_utils::allTvs(fusion.get()); // Run segmentation on the copied fusion diff --git a/csrc/optimization/remove_empty.cpp b/csrc/optimization/remove_empty.cpp index 204d71da764..e9dee341eb3 100644 --- a/csrc/optimization/remove_empty.cpp +++ b/csrc/optimization/remove_empty.cpp @@ -7,8 +7,8 @@ // clang-format on #include -#include #include +#include #include @@ -32,7 +32,8 @@ std::vector emptyAxes(std::vector domain) { //! Check whether a TensorView is empty. During concretization, we traverse to //! find a minimal set of TensorViews that have zero extents, and we then set -//! their extents to a constant 0. Here we check for those constant zero extents. +//! their extents to a constant 0. Here we check for those constant zero +//! extents. bool isTVEmpty(TensorView* tv) { return !emptyAxes(TensorDomain::noReductions(tv->getMaybeRFactorDomain())) .empty(); @@ -83,7 +84,10 @@ class EmptyTensorRemover { if (stmt->isVal()) { handle(stmt->asVal()); } else { - TORCH_INTERNAL_ASSERT(stmt->isExpr(), "Statement is neither a Val or Expr: ", stmt->toString()); + TORCH_INTERNAL_ASSERT( + stmt->isExpr(), + "Statement is neither a Val or Expr: ", + stmt->toString()); handle(stmt->asExpr()); } } @@ -103,7 +107,8 @@ class EmptyTensorRemover { //! the tensor is not provably empty. void handle(TensorView* tv) { if (tv->isFusionOutput()) { - const auto rfactor = TensorDomain::noReductions(tv->getMaybeRFactorDomain()); + const auto rfactor = + TensorDomain::noReductions(tv->getMaybeRFactorDomain()); const auto empty_axes = emptyAxes(rfactor); if (!empty_axes.empty()) { std::vector shape(rfactor.size()); @@ -150,7 +155,7 @@ class EmptyTensorRemover { auto out = rop->out()->as(); // The input is empty in some axes. Assert that they are all reduced const auto& out_root = out->getRootDomain(); - + std::vector shape; for (auto id : out_root) { if (!id->isReduction() && !id->isStride()) { // same as noReductions() diff --git a/csrc/options.cpp b/csrc/options.cpp index 6c059a83f0c..fe2cd41f9a8 100644 --- a/csrc/options.cpp +++ b/csrc/options.cpp @@ -117,6 +117,7 @@ std::unordered_map> Options< {"fusion_args", DebugDumpOption::FusionArgs}, {"fusion_ir", DebugDumpOption::FusionIr}, {"fusion_ir_concretized", DebugDumpOption::FusionIrConcretized}, + {"fusion_ir_preseg", DebugDumpOption::FusionIrPreseg}, {"fusion_ir_math", DebugDumpOption::FusionIrMath}, {"fusion_ir_presched", DebugDumpOption::FusionIrPresched}, {"halo", DebugDumpOption::Halo}, diff --git a/csrc/options.h b/csrc/options.h index 168171a6644..076999f87bd 100644 --- a/csrc/options.h +++ b/csrc/options.h @@ -25,6 +25,7 @@ enum class DebugDumpOption { FusionIrMath, //!< Dump just the compute (math) part of the Fusion IR FusionIrPresched, //!< Dump the Fusion IR before it is scheduled. FusionIrConcretized, //!< Dump the Fusion IR after concretization + FusionIrPreseg, //!< Dump the Fusion IR after pre-segmenter optimization KernelIr, //!< Dump the compiler Kernel IR ComputeAtMap, //!< Dump the computeAt map CudaKernel, //!< Dump the generated CUDA C++ kernel code From 6c649150f551c5c68abc6722fa16ce406ef37029 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Wed, 28 Jun 2023 09:35:02 -0400 Subject: [PATCH 03/37] Add FusionRemoveEmptyOutput_CUDA --- test/test_dynamic_transform.cpp | 35 +++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/test/test_dynamic_transform.cpp b/test/test_dynamic_transform.cpp index 64a8827d32f..129ab18a4f4 100644 --- a/test/test_dynamic_transform.cpp +++ b/test/test_dynamic_transform.cpp @@ -994,4 +994,39 @@ TEST_F(NVFuserTest, FusionDynamicSliceToBroadcast_CUDA) { testValidate(&fusion, outputs, aten_inputs, {at2}, __LINE__, __FILE__); } +// Test that we remove empty output branch before segmentation +TEST_F(NVFuserTest, FusionRemoveEmptyOutput_CUDA) { + std::unique_ptr fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr.get(); + FusionGuard fg(fusion_ptr.get()); + // Concrete tensor with zero for one extent, so that we can prove the output + // is empty + auto tv0 = makeConcreteTensor({0, 3}); + fusion.addInput(tv0); + auto tv1 = set(tv0); + fusion.addOutput(tv1); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor at0 = at::randn({0, 3}, options); + std::vector aten_inputs = {at0}; + + auto args = KernelArgumentHolder::createKernelArgumentHolder(aten_inputs); + FusionKernelRuntime runtime(std::move(fusion_ptr), args); + + // In the FusionKernelRuntime, before segmentation a number of optimization + // passes are performed. One of those is RemoveEmptyPass, which should replace + // the empty output tv1 with a new TensorView defined by `full({0, 3})` in + // this case. + auto preseg_fusion = runtime.fusionSegments()->completeFusion(); + EXPECT_EQ(preseg_fusion->outputs().size(), 1); + EXPECT_NE(preseg_fusion->outputs()[0], tv1); + EXPECT_NE(preseg_fusion->outputs()[0]->definition(), nullptr); + EXPECT_TRUE(preseg_fusion->outputs()[0]->definition()->isA()); + + runtime.compileFusionParallel(args); + auto outputs = runtime.runWithInputs(args); + + testValidate(preseg_fusion, outputs, aten_inputs, {at0}, __LINE__, __FILE__); +} + } // namespace nvfuser From 6fb368cab13940765b7f008f3751849c242aa460 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Wed, 28 Jun 2023 09:40:03 -0400 Subject: [PATCH 04/37] Fill with zeroVal instead of oneVal --- csrc/optimization/remove_empty.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/optimization/remove_empty.cpp b/csrc/optimization/remove_empty.cpp index e9dee341eb3..8720f32582b 100644 --- a/csrc/optimization/remove_empty.cpp +++ b/csrc/optimization/remove_empty.cpp @@ -120,7 +120,7 @@ class EmptyTensorRemover { shape[ax] = fusion_->zeroVal(); } auto dtype = tv->getDataType().value(); - auto new_tv = full(shape, fusion_->oneVal(dtype), dtype); + auto new_tv = full(shape, fusion_->zeroVal(dtype), dtype); replaceTV(tv, new_tv); // Do not keep traversing upstream if we've replaced tv return; From 0d8b4fcb34557dae070cc5b1850ea1a583402c7a Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Wed, 28 Jun 2023 09:40:12 -0400 Subject: [PATCH 05/37] Add FusionRemoveEmptyReduction_CUDA --- test/test_dynamic_transform.cpp | 40 +++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/test/test_dynamic_transform.cpp b/test/test_dynamic_transform.cpp index 129ab18a4f4..43b5e7e89dc 100644 --- a/test/test_dynamic_transform.cpp +++ b/test/test_dynamic_transform.cpp @@ -1029,4 +1029,44 @@ TEST_F(NVFuserTest, FusionRemoveEmptyOutput_CUDA) { testValidate(preseg_fusion, outputs, aten_inputs, {at0}, __LINE__, __FILE__); } +// Test that we replace empty reduction with full +TEST_F(NVFuserTest, FusionRemoveEmptyReduction_CUDA) { + std::unique_ptr fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr.get(); + FusionGuard fg(fusion_ptr.get()); + // Concrete tensor with zero for one extent, so that we can prove the output + // is empty + auto tv0 = makeConcreteTensor({0, 3}); + fusion.addInput(tv0); + auto tv1 = sum(tv0, {0}); + fusion.addOutput(tv1); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor at0 = at::randn({0, 3}, options); + std::vector aten_inputs = {at0}; + + auto args = KernelArgumentHolder::createKernelArgumentHolder(aten_inputs); + FusionKernelRuntime runtime(std::move(fusion_ptr), args); + + // In the FusionKernelRuntime, before segmentation a number of optimization + // passes are performed. One of those is RemoveEmptyPass, which should replace + // the empty output tv1 with a new TensorView defined by `full({0, 3})` in + // this case. + auto preseg_fusion = runtime.fusionSegments()->completeFusion(); + EXPECT_EQ(preseg_fusion->outputs().size(), 1); + EXPECT_NE(preseg_fusion->outputs()[0]->definition(), nullptr); + EXPECT_TRUE(preseg_fusion->outputs()[0]->definition()->isA()); + + runtime.compileFusionParallel(args); + auto outputs = runtime.runWithInputs(args); + + testValidate( + preseg_fusion, + outputs, + aten_inputs, + {at::sum(at0, {0})}, + __LINE__, + __FILE__); +} + } // namespace nvfuser From 436832f426caf1720f380cd17638359f31e4a94e Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Wed, 28 Jun 2023 10:06:38 -0400 Subject: [PATCH 06/37] Handle CatOp, with test --- csrc/optimization/remove_empty.cpp | 79 ++++++++++++++++++++++++++++-- test/test_dynamic_transform.cpp | 56 +++++++++++++++++++++ 2 files changed, 132 insertions(+), 3 deletions(-) diff --git a/csrc/optimization/remove_empty.cpp b/csrc/optimization/remove_empty.cpp index 8720f32582b..a2bb28a279a 100644 --- a/csrc/optimization/remove_empty.cpp +++ b/csrc/optimization/remove_empty.cpp @@ -8,6 +8,7 @@ #include #include +#include #include #include @@ -180,6 +181,76 @@ class EmptyTensorRemover { replaceTV(out, new_tv); } + void handle(WelfordOp* wop) {} + + //! A cat op can have input empty tensors and still output a non-empty + //! tensor. This is only possible if there is more than one input, so we + //! only need to handle those cases. We find the non-empty inputs to cat + //! then replace with another cat (or `set` if n=1). + //! + //! The `cat` function creates a CatOp object, but its inputs() are not + //! the original inputs. Rather, they are the inputs after padding to the + //! output extent in the concatenated dimension. Thus, in the IR graph, + //! instead of the following: + //! + //! T0 T1 T2 + //! \ | / + //! CatOp + //! | + //! T3 + //! + //! a cat is represented as: + //! + //! T0 T1 T2 + //! | | | + //! PadOp PadOp PadOp + //! \ | / + //! CatOp + //! | + //! T3 + //! + //! If we determine that one of the inputs, T1, is empty in the cat + //! dimension, then we rewrite this as: + //! + //! T0 T2 + //! | | + //! PadOp PadOp + //! \ / + //! CatOp + //! | + //! T3 + //! + //! This is done by simply calling the cat() command with only {T0, T2}. + void handle(CatOp* cop) { + auto dim = cop->concatenatedDim(); + std::vector non_empty_inputs; + for (auto inp : cop->inputs()) { + TORCH_INTERNAL_ASSERT( + inp->definition() && inp->definition()->isA(), + "Inputs to CatOp must be outputs of PadOps"); + auto tv = inp->definition()->as()->in()->as(); + auto cat_id = + TensorDomain::noReductions(tv->getMaybeRFactorDomain()).at(dim); + if (cat_id->extent()->isConst() && cat_id->extent()->evaluateInt() == 0) { + continue; + } + non_empty_inputs.push_back(tv); + } + if (non_empty_inputs.size() != cop->inputs().size()) { + // Replace this op with a new cat op + auto old_tv = cop->outputs()[0]->as(); + // NOTE: cat() will translate to set() if non_empty_inputs.size() == 1 + auto new_tv = cat(non_empty_inputs, dim); + replaceTV(old_tv, new_tv); + } + for (auto tv : non_empty_inputs) { + // Continue processing non-empty inputs + stmt_stack_.push_back(tv); + } + } + + void handle(PadOp* wop) {} + //! Replaces a TensorView in outputs, and in all uses. If old_tv is a Fusion //! input, we do not replace it. After replacement, unless it is a Fusion //! input, we remove it from the fusion and set the original pointer to zero @@ -202,12 +273,14 @@ class EmptyTensorRemover { handle(rop); } else if (auto wop = dynamic_cast(e)) { handle(wop); + } else if (auto pop = dynamic_cast(e)) { + handle(pop); } else if (auto pop = dynamic_cast(e)) { handle(pop); } else { - // The handled ops above may terminate this branch, so they will need to - // manually handle their inputs. For unhandled ops, we just handle all - // inputs here. + // The handled ops above may terminate this branch of the traversal, so + // they will need to manually handle their inputs. For unhandled ops, we + // just handle all inputs here. pushInputs(e); } } diff --git a/test/test_dynamic_transform.cpp b/test/test_dynamic_transform.cpp index 43b5e7e89dc..c50a3bc6b84 100644 --- a/test/test_dynamic_transform.cpp +++ b/test/test_dynamic_transform.cpp @@ -1069,4 +1069,60 @@ TEST_F(NVFuserTest, FusionRemoveEmptyReduction_CUDA) { __FILE__); } +// Test that we replace empty tensors in cat properly +TEST_F(NVFuserTest, FusionRemoveEmptyCat_CUDA) { + std::unique_ptr fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr.get(); + FusionGuard fg(fusion_ptr.get()); + // Concrete tensor with zero for one extent, so that we can prove the output + // is empty + auto tv0 = makeConcreteTensor({0, 3}); + fusion.addInput(tv0); + auto tv1 = makeConcreteTensor({2, 3}); + fusion.addInput(tv1); + auto tv2 = makeConcreteTensor({4, 3}); + fusion.addInput(tv2); + + // equivalent to cat({tv1, tv2}, 0) + auto tv3 = cat({tv0, tv1, tv2}, 0); + fusion.addOutput(tv3); + // set(tv1) + auto tv4 = cat({tv0, tv1}, 0); + fusion.addOutput(tv4); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor at0 = at::randn({0, 3}, options); + at::Tensor at1 = at::randn({2, 3}, options); + at::Tensor at2 = at::randn({4, 3}, options); + std::vector aten_inputs = {at0, at1, at2}; + + auto args = KernelArgumentHolder::createKernelArgumentHolder(aten_inputs); + FusionKernelRuntime runtime(std::move(fusion_ptr), args); + + // In the FusionKernelRuntime, before segmentation a number of optimization + // passes are performed. One of those is RemoveEmptyPass, which should replace + // the empty output tv1 with a new TensorView defined by `full({0, 3})` in + // this case. + auto preseg_fusion = runtime.fusionSegments()->completeFusion(); + EXPECT_EQ(preseg_fusion->outputs().size(), 2); + + EXPECT_NE(preseg_fusion->outputs()[0]->definition(), nullptr); + EXPECT_TRUE(preseg_fusion->outputs()[0]->definition()->isA()); + EXPECT_EQ(preseg_fusion->outputs()[0]->definition()->inputs().size(), 2); + + EXPECT_NE(preseg_fusion->outputs()[1]->definition(), nullptr); + EXPECT_TRUE(preseg_fusion->outputs()[1]->definition()->isA()); + + runtime.compileFusionParallel(args); + auto outputs = runtime.runWithInputs(args); + + testValidate( + preseg_fusion, + outputs, + aten_inputs, + {at::cat({at1, at2}, 0), at1}, + __LINE__, + __FILE__); +} + } // namespace nvfuser From 50d763f3bafcb0e5c14b9558f243aed538e13536 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Wed, 28 Jun 2023 10:19:26 -0400 Subject: [PATCH 07/37] Move tests to test_optimization_pass.cpp --- test/test_dynamic_transform.cpp | 131 -------------------------------- test/test_optimization_pass.cpp | 131 ++++++++++++++++++++++++++++++++ 2 files changed, 131 insertions(+), 131 deletions(-) diff --git a/test/test_dynamic_transform.cpp b/test/test_dynamic_transform.cpp index c50a3bc6b84..64a8827d32f 100644 --- a/test/test_dynamic_transform.cpp +++ b/test/test_dynamic_transform.cpp @@ -994,135 +994,4 @@ TEST_F(NVFuserTest, FusionDynamicSliceToBroadcast_CUDA) { testValidate(&fusion, outputs, aten_inputs, {at2}, __LINE__, __FILE__); } -// Test that we remove empty output branch before segmentation -TEST_F(NVFuserTest, FusionRemoveEmptyOutput_CUDA) { - std::unique_ptr fusion_ptr = std::make_unique(); - Fusion& fusion = *fusion_ptr.get(); - FusionGuard fg(fusion_ptr.get()); - // Concrete tensor with zero for one extent, so that we can prove the output - // is empty - auto tv0 = makeConcreteTensor({0, 3}); - fusion.addInput(tv0); - auto tv1 = set(tv0); - fusion.addOutput(tv1); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor at0 = at::randn({0, 3}, options); - std::vector aten_inputs = {at0}; - - auto args = KernelArgumentHolder::createKernelArgumentHolder(aten_inputs); - FusionKernelRuntime runtime(std::move(fusion_ptr), args); - - // In the FusionKernelRuntime, before segmentation a number of optimization - // passes are performed. One of those is RemoveEmptyPass, which should replace - // the empty output tv1 with a new TensorView defined by `full({0, 3})` in - // this case. - auto preseg_fusion = runtime.fusionSegments()->completeFusion(); - EXPECT_EQ(preseg_fusion->outputs().size(), 1); - EXPECT_NE(preseg_fusion->outputs()[0], tv1); - EXPECT_NE(preseg_fusion->outputs()[0]->definition(), nullptr); - EXPECT_TRUE(preseg_fusion->outputs()[0]->definition()->isA()); - - runtime.compileFusionParallel(args); - auto outputs = runtime.runWithInputs(args); - - testValidate(preseg_fusion, outputs, aten_inputs, {at0}, __LINE__, __FILE__); -} - -// Test that we replace empty reduction with full -TEST_F(NVFuserTest, FusionRemoveEmptyReduction_CUDA) { - std::unique_ptr fusion_ptr = std::make_unique(); - Fusion& fusion = *fusion_ptr.get(); - FusionGuard fg(fusion_ptr.get()); - // Concrete tensor with zero for one extent, so that we can prove the output - // is empty - auto tv0 = makeConcreteTensor({0, 3}); - fusion.addInput(tv0); - auto tv1 = sum(tv0, {0}); - fusion.addOutput(tv1); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor at0 = at::randn({0, 3}, options); - std::vector aten_inputs = {at0}; - - auto args = KernelArgumentHolder::createKernelArgumentHolder(aten_inputs); - FusionKernelRuntime runtime(std::move(fusion_ptr), args); - - // In the FusionKernelRuntime, before segmentation a number of optimization - // passes are performed. One of those is RemoveEmptyPass, which should replace - // the empty output tv1 with a new TensorView defined by `full({0, 3})` in - // this case. - auto preseg_fusion = runtime.fusionSegments()->completeFusion(); - EXPECT_EQ(preseg_fusion->outputs().size(), 1); - EXPECT_NE(preseg_fusion->outputs()[0]->definition(), nullptr); - EXPECT_TRUE(preseg_fusion->outputs()[0]->definition()->isA()); - - runtime.compileFusionParallel(args); - auto outputs = runtime.runWithInputs(args); - - testValidate( - preseg_fusion, - outputs, - aten_inputs, - {at::sum(at0, {0})}, - __LINE__, - __FILE__); -} - -// Test that we replace empty tensors in cat properly -TEST_F(NVFuserTest, FusionRemoveEmptyCat_CUDA) { - std::unique_ptr fusion_ptr = std::make_unique(); - Fusion& fusion = *fusion_ptr.get(); - FusionGuard fg(fusion_ptr.get()); - // Concrete tensor with zero for one extent, so that we can prove the output - // is empty - auto tv0 = makeConcreteTensor({0, 3}); - fusion.addInput(tv0); - auto tv1 = makeConcreteTensor({2, 3}); - fusion.addInput(tv1); - auto tv2 = makeConcreteTensor({4, 3}); - fusion.addInput(tv2); - - // equivalent to cat({tv1, tv2}, 0) - auto tv3 = cat({tv0, tv1, tv2}, 0); - fusion.addOutput(tv3); - // set(tv1) - auto tv4 = cat({tv0, tv1}, 0); - fusion.addOutput(tv4); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor at0 = at::randn({0, 3}, options); - at::Tensor at1 = at::randn({2, 3}, options); - at::Tensor at2 = at::randn({4, 3}, options); - std::vector aten_inputs = {at0, at1, at2}; - - auto args = KernelArgumentHolder::createKernelArgumentHolder(aten_inputs); - FusionKernelRuntime runtime(std::move(fusion_ptr), args); - - // In the FusionKernelRuntime, before segmentation a number of optimization - // passes are performed. One of those is RemoveEmptyPass, which should replace - // the empty output tv1 with a new TensorView defined by `full({0, 3})` in - // this case. - auto preseg_fusion = runtime.fusionSegments()->completeFusion(); - EXPECT_EQ(preseg_fusion->outputs().size(), 2); - - EXPECT_NE(preseg_fusion->outputs()[0]->definition(), nullptr); - EXPECT_TRUE(preseg_fusion->outputs()[0]->definition()->isA()); - EXPECT_EQ(preseg_fusion->outputs()[0]->definition()->inputs().size(), 2); - - EXPECT_NE(preseg_fusion->outputs()[1]->definition(), nullptr); - EXPECT_TRUE(preseg_fusion->outputs()[1]->definition()->isA()); - - runtime.compileFusionParallel(args); - auto outputs = runtime.runWithInputs(args); - - testValidate( - preseg_fusion, - outputs, - aten_inputs, - {at::cat({at1, at2}, 0), at1}, - __LINE__, - __FILE__); -} - } // namespace nvfuser diff --git a/test/test_optimization_pass.cpp b/test/test_optimization_pass.cpp index bf7b1f1c88d..00cc3fd53a5 100644 --- a/test/test_optimization_pass.cpp +++ b/test/test_optimization_pass.cpp @@ -365,4 +365,135 @@ TEST_F(NVFuserTest, FusionTestCastOptimization_CUDA) { } } +// Test that we remove empty output branch before segmentation +TEST_F(NVFuserTest, FusionRemoveEmptyOutput_CUDA) { + std::unique_ptr fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr.get(); + FusionGuard fg(fusion_ptr.get()); + // Concrete tensor with zero for one extent, so that we can prove the output + // is empty + auto tv0 = makeConcreteTensor({0, 3}); + fusion.addInput(tv0); + auto tv1 = set(tv0); + fusion.addOutput(tv1); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor at0 = at::randn({0, 3}, options); + std::vector aten_inputs = {at0}; + + auto args = KernelArgumentHolder::createKernelArgumentHolder(aten_inputs); + FusionKernelRuntime runtime(std::move(fusion_ptr), args); + + // In the FusionKernelRuntime, before segmentation a number of optimization + // passes are performed. One of those is RemoveEmptyPass, which should replace + // the empty output tv1 with a new TensorView defined by `full({0, 3})` in + // this case. + auto preseg_fusion = runtime.fusionSegments()->completeFusion(); + EXPECT_EQ(preseg_fusion->outputs().size(), 1); + EXPECT_NE(preseg_fusion->outputs()[0], tv1); + EXPECT_NE(preseg_fusion->outputs()[0]->definition(), nullptr); + EXPECT_TRUE(preseg_fusion->outputs()[0]->definition()->isA()); + + runtime.compileFusionParallel(args); + auto outputs = runtime.runWithInputs(args); + + testValidate(preseg_fusion, outputs, aten_inputs, {at0}, __LINE__, __FILE__); +} + +// Test that we replace empty reduction with full +TEST_F(NVFuserTest, FusionRemoveEmptyReduction_CUDA) { + std::unique_ptr fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr.get(); + FusionGuard fg(fusion_ptr.get()); + // Concrete tensor with zero for one extent, so that we can prove the output + // is empty + auto tv0 = makeConcreteTensor({0, 3}); + fusion.addInput(tv0); + auto tv1 = sum(tv0, {0}); + fusion.addOutput(tv1); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor at0 = at::randn({0, 3}, options); + std::vector aten_inputs = {at0}; + + auto args = KernelArgumentHolder::createKernelArgumentHolder(aten_inputs); + FusionKernelRuntime runtime(std::move(fusion_ptr), args); + + // In the FusionKernelRuntime, before segmentation a number of optimization + // passes are performed. One of those is RemoveEmptyPass, which should replace + // the empty output tv1 with a new TensorView defined by `full({0, 3})` in + // this case. + auto preseg_fusion = runtime.fusionSegments()->completeFusion(); + EXPECT_EQ(preseg_fusion->outputs().size(), 1); + EXPECT_NE(preseg_fusion->outputs()[0]->definition(), nullptr); + EXPECT_TRUE(preseg_fusion->outputs()[0]->definition()->isA()); + + runtime.compileFusionParallel(args); + auto outputs = runtime.runWithInputs(args); + + testValidate( + preseg_fusion, + outputs, + aten_inputs, + {at::sum(at0, {0})}, + __LINE__, + __FILE__); +} + +// Test that we replace empty tensors in cat properly +TEST_F(NVFuserTest, FusionRemoveEmptyCat_CUDA) { + std::unique_ptr fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr.get(); + FusionGuard fg(fusion_ptr.get()); + // Concrete tensor with zero for one extent, so that we can prove the output + // is empty + auto tv0 = makeConcreteTensor({0, 3}); + fusion.addInput(tv0); + auto tv1 = makeConcreteTensor({2, 3}); + fusion.addInput(tv1); + auto tv2 = makeConcreteTensor({4, 3}); + fusion.addInput(tv2); + + // equivalent to cat({tv1, tv2}, 0) + auto tv3 = cat({tv0, tv1, tv2}, 0); + fusion.addOutput(tv3); + // set(tv1) + auto tv4 = cat({tv0, tv1}, 0); + fusion.addOutput(tv4); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor at0 = at::randn({0, 3}, options); + at::Tensor at1 = at::randn({2, 3}, options); + at::Tensor at2 = at::randn({4, 3}, options); + std::vector aten_inputs = {at0, at1, at2}; + + auto args = KernelArgumentHolder::createKernelArgumentHolder(aten_inputs); + FusionKernelRuntime runtime(std::move(fusion_ptr), args); + + // In the FusionKernelRuntime, before segmentation a number of optimization + // passes are performed. One of those is RemoveEmptyPass, which should replace + // the empty output tv1 with a new TensorView defined by `full({0, 3})` in + // this case. + auto preseg_fusion = runtime.fusionSegments()->completeFusion(); + EXPECT_EQ(preseg_fusion->outputs().size(), 2); + + EXPECT_NE(preseg_fusion->outputs()[0]->definition(), nullptr); + EXPECT_TRUE(preseg_fusion->outputs()[0]->definition()->isA()); + EXPECT_EQ(preseg_fusion->outputs()[0]->definition()->inputs().size(), 2); + + EXPECT_NE(preseg_fusion->outputs()[1]->definition(), nullptr); + EXPECT_TRUE(preseg_fusion->outputs()[1]->definition()->isA()); + + runtime.compileFusionParallel(args); + auto outputs = runtime.runWithInputs(args); + + testValidate( + preseg_fusion, + outputs, + aten_inputs, + {at::cat({at1, at2}, 0), at1}, + __LINE__, + __FILE__); +} + } // namespace nvfuser::optimization From 56dcfae2e34859f3f4c73b680096941febe94367 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Wed, 28 Jun 2023 11:37:46 -0400 Subject: [PATCH 08/37] Handle PadOp --- csrc/optimization/remove_empty.cpp | 23 +++++++++++++- test/test_optimization_pass.cpp | 48 ++++++++++++++++++++++++++++++ 2 files changed, 70 insertions(+), 1 deletion(-) diff --git a/csrc/optimization/remove_empty.cpp b/csrc/optimization/remove_empty.cpp index a2bb28a279a..212d37a85ab 100644 --- a/csrc/optimization/remove_empty.cpp +++ b/csrc/optimization/remove_empty.cpp @@ -249,7 +249,28 @@ class EmptyTensorRemover { } } - void handle(PadOp* wop) {} + //! Replace pad(tv) if tv is empty in any dimension. Note that since we detect + //! empty tensors by looking for constant extents, the output extents will be + //! correct here already, so there is no value in removing the empty input + //! extent when we do the replacement. + void handle(PadOp* pop) { + auto in = pop->in()->as(); + auto in_rfactor = TensorDomain::noReductions(in->getMaybeRFactorDomain()); + if (!emptyAxes(in_rfactor).empty()) { + auto out = pop->out()->as(); + auto out_rfactor = + TensorDomain::noReductions(out->getMaybeRFactorDomain()); + std::vector shape; + shape.reserve(out_rfactor.size()); + for (auto id : out_rfactor) { + shape.push_back(id->extent()); + } + auto new_tv = full(shape, pop->value(), out->getDataType().value()); + replaceTV(out, new_tv); + } else { + pushInputs(pop); + } + } //! Replaces a TensorView in outputs, and in all uses. If old_tv is a Fusion //! input, we do not replace it. After replacement, unless it is a Fusion diff --git a/test/test_optimization_pass.cpp b/test/test_optimization_pass.cpp index 00cc3fd53a5..dc5f92d4d16 100644 --- a/test/test_optimization_pass.cpp +++ b/test/test_optimization_pass.cpp @@ -496,4 +496,52 @@ TEST_F(NVFuserTest, FusionRemoveEmptyCat_CUDA) { __FILE__); } +// Test that we replace empty tensors in pad properly +TEST_F(NVFuserTest, FusionRemoveEmptyPad_CUDA) { + std::unique_ptr fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr.get(); + FusionGuard fg(fusion_ptr.get()); + // Concrete tensor with zero for one extent, so that we can prove the output + // is empty + auto tv0 = makeConcreteTensor({3, 0}); + fusion.addInput(tv0); + + // Use a non-zero pad value to verify that it is used in the rewritten fill + auto pad_val = IrBuilder::create(3.14, DataType::Float); + + // equivalent to full({3, 2}, pad_val, DataType::Float) + auto tv1 = pad(tv0, {fusion.oneVal(), fusion.oneVal()}, pad_val); + fusion.addOutput(tv1); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor at0 = at::randn({3, 0}, options); + std::vector aten_inputs = {at0}; + + auto args = KernelArgumentHolder::createKernelArgumentHolder(aten_inputs); + FusionKernelRuntime runtime(std::move(fusion_ptr), args); + + // In the FusionKernelRuntime, before segmentation a number of optimization + // passes are performed. One of those is RemoveEmptyPass, which should replace + // the empty output tv1 with a new TensorView defined by `full({0, 3})` in + // this case. + auto preseg_fusion = runtime.fusionSegments()->completeFusion(); + EXPECT_EQ(preseg_fusion->outputs().size(), 1); + + EXPECT_NE(preseg_fusion->outputs()[0]->definition(), nullptr); + auto rewritten_def = preseg_fusion->outputs()[0]->definition(); + EXPECT_TRUE(rewritten_def->isA()); + EXPECT_TRUE(rewritten_def->as()->getFillValue()->sameAs(pad_val)); + + runtime.compileFusionParallel(args); + auto outputs = runtime.runWithInputs(args); + + testValidate( + preseg_fusion, + outputs, + aten_inputs, + {at::pad(at0, {1, 1}, "constant", 3.14)}, + __LINE__, + __FILE__); +} + } // namespace nvfuser::optimization From c7768788433ba4cc4b4618d8576928b7290b4f7d Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Wed, 28 Jun 2023 11:43:13 -0400 Subject: [PATCH 09/37] Use deque instead of stack. We remove from the front and push to the back to process Statements in FIFO order. This ensures we have traversed in a reverse topological order, so that we can safely remove any TensorViews downstream of the Statement we're looking at, as they have already been processed and should not appear later in the stack (though we should check still because they might be an output). --- csrc/optimization/remove_empty.cpp | 45 ++++++++++++++++++++++-------- 1 file changed, 33 insertions(+), 12 deletions(-) diff --git a/csrc/optimization/remove_empty.cpp b/csrc/optimization/remove_empty.cpp index 212d37a85ab..f6a6dd8ef79 100644 --- a/csrc/optimization/remove_empty.cpp +++ b/csrc/optimization/remove_empty.cpp @@ -12,6 +12,7 @@ #include #include +#include namespace nvfuser::optimization { @@ -60,23 +61,23 @@ bool isTVEmpty(TensorView* tv) { //! //! Note that we do not use BackwardVisitor here even though we are doing a //! backward traversal. This is because we will actually be changing the Fusion -//! graph as we traverse. BackwardVisitor works by creating a stack of +//! graph as we traverse. BackwardVisitor works by creating a queue of //! statements using InputsOf and a forward traversal, then iteratively pops -//! that static stack of statements. This does not work well when we want to +//! that static queue of statements. This does not work well when we want to //! eliminate dead code while traversing, so instead we implement a simple -//! stack-based depth-first backward traversal manually here. +//! queue-based breadth-first backward traversal manually here. class EmptyTensorRemover { public: EmptyTensorRemover(Fusion* fusion) : fusion_(fusion) { for (auto outp : fusion->outputs()) { - stmt_stack_.push_back(outp); + stmt_queue_.push_back(outp); } } void run() { - while (!stmt_stack_.empty()) { - auto stmt = stmt_stack_.back(); - stmt_stack_.pop_back(); + while (!stmt_queue_.empty()) { + auto stmt = stmt_queue_.front(); + stmt_queue_.pop_front(); handle(stmt); } } @@ -99,7 +100,7 @@ class EmptyTensorRemover { } else if (v->definition()) { // TensorView vals might be overwritten, in which case we should not keep // traversing. For all other Vals, push their definition to the stack. - stmt_stack_.push_back(v->definition()); + stmt_queue_.push_back(v->definition()); } } @@ -134,16 +135,27 @@ class EmptyTensorRemover { tv->toString()); } if (tv->definition()) { - stmt_stack_.push_back(tv->definition()); + stmt_queue_.push_back(tv->definition()); } } + //! Push the inputs of an expression onto the statement stack for further + //! processing. void pushInputs(Expr* e) { for (auto inp : e->inputs()) { - stmt_stack_.push_back(inp); + stmt_queue_.push_back(inp); } } + //! A reduction over empty axes is equal to the initial value of the + //! reduction, as if the reduction were written as follows: + //! + //! auto result = init_value; + //! for (auto element : reduction_elements) { + //! result = reduction_op(result, element); + //! } + //! return result; + //! void handle(ReductionOp* rop) { auto in = rop->in()->as(); auto empty_input_axes = @@ -181,6 +193,15 @@ class EmptyTensorRemover { replaceTV(out, new_tv); } + //! A reduction over empty axes is equal to the initial value of the + //! reduction, as if the reduction were written as follows: + //! + //! auto result = init_value; + //! for (auto element : reduction_elements) { + //! result = reduction_op(result, element); + //! } + //! return result; + //! void handle(WelfordOp* wop) {} //! A cat op can have input empty tensors and still output a non-empty @@ -245,7 +266,7 @@ class EmptyTensorRemover { } for (auto tv : non_empty_inputs) { // Continue processing non-empty inputs - stmt_stack_.push_back(tv); + stmt_queue_.push_back(tv); } } @@ -308,7 +329,7 @@ class EmptyTensorRemover { private: Fusion* fusion_; - std::vector stmt_stack_; + std::deque stmt_queue_; }; } // namespace From fb1b670ae32bddb540fd6ee31c9594ed53b3dea3 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Wed, 28 Jun 2023 12:33:36 -0400 Subject: [PATCH 10/37] Use BackwardVisitor --- csrc/optimization/remove_empty.cpp | 124 ++++++++++------------------- 1 file changed, 43 insertions(+), 81 deletions(-) diff --git a/csrc/optimization/remove_empty.cpp b/csrc/optimization/remove_empty.cpp index f6a6dd8ef79..969dbceb461 100644 --- a/csrc/optimization/remove_empty.cpp +++ b/csrc/optimization/remove_empty.cpp @@ -13,6 +13,8 @@ #include #include +#include +#include namespace nvfuser::optimization { @@ -59,55 +61,33 @@ bool isTVEmpty(TensorView* tv) { //! 4. If the empty Tensorview is the input to a PadOp (which is not input to //! a CatOp) then we replace the pad with `full(pad_value)`. //! -//! Note that we do not use BackwardVisitor here even though we are doing a -//! backward traversal. This is because we will actually be changing the Fusion -//! graph as we traverse. BackwardVisitor works by creating a queue of -//! statements using InputsOf and a forward traversal, then iteratively pops -//! that static queue of statements. This does not work well when we want to -//! eliminate dead code while traversing, so instead we implement a simple -//! queue-based breadth-first backward traversal manually here. -class EmptyTensorRemover { +class EmptyTensorRemover : BackwardVisitor { public: - EmptyTensorRemover(Fusion* fusion) : fusion_(fusion) { - for (auto outp : fusion->outputs()) { - stmt_queue_.push_back(outp); - } - } + EmptyTensorRemover(Fusion* fusion) + : BackwardVisitor(false), fusion_(fusion) {} void run() { - while (!stmt_queue_.empty()) { - auto stmt = stmt_queue_.front(); - stmt_queue_.pop_front(); - handle(stmt); - } + traverseTo(fusion_, fusion_->outputs()); } - void handle(Statement* stmt) { - if (stmt->isVal()) { - handle(stmt->asVal()); - } else { - TORCH_INTERNAL_ASSERT( - stmt->isExpr(), - "Statement is neither a Val or Expr: ", - stmt->toString()); - handle(stmt->asExpr()); - } - } - - void handle(Val* v) { - if (auto tv = dynamic_cast(v)) { - handle(tv); - } else if (v->definition()) { - // TensorView vals might be overwritten, in which case we should not keep - // traversing. For all other Vals, push their definition to the stack. - stmt_queue_.push_back(v->definition()); + void handle(Statement* stmt) final { + if (isDead(stmt)) { + // We check whether stmt is dead before we dereference it, since it may + // have been removed from the Fusion. + return; } + BackwardVisitor::handle(stmt); } //! If tv is a fusion output, we check whether it is empty and if so, replace //! it with full(). For non-outputs that are not inputs, we simply check that //! the tensor is not provably empty. - void handle(TensorView* tv) { + void handle(TensorView* tv) final { + if (tv->isFusionInput()) { + // Skip inputs since they do not have a definition to redefine + return; + } + if (tv->isFusionOutput()) { const auto rfactor = TensorDomain::noReductions(tv->getMaybeRFactorDomain()); @@ -127,24 +107,13 @@ class EmptyTensorRemover { // Do not keep traversing upstream if we've replaced tv return; } - } else if (!tv->isFusionInput()) { - // TODO: This should be a warning instead + } else { + // TODO: This should be a warning instead of an assert TORCH_INTERNAL_ASSERT( !isTVEmpty(tv), "Found unexpected empty intermediate TensorView ", tv->toString()); } - if (tv->definition()) { - stmt_queue_.push_back(tv->definition()); - } - } - - //! Push the inputs of an expression onto the statement stack for further - //! processing. - void pushInputs(Expr* e) { - for (auto inp : e->inputs()) { - stmt_queue_.push_back(inp); - } } //! A reduction over empty axes is equal to the initial value of the @@ -156,13 +125,12 @@ class EmptyTensorRemover { //! } //! return result; //! - void handle(ReductionOp* rop) { + void handle(ReductionOp* rop) final { auto in = rop->in()->as(); auto empty_input_axes = emptyAxes(TensorDomain::noReductions(in->getMaybeRFactorDomain())); if (empty_input_axes.empty()) { // Input is not empty, handle like any other op - pushInputs(rop); return; } auto out = rop->out()->as(); @@ -202,7 +170,7 @@ class EmptyTensorRemover { //! } //! return result; //! - void handle(WelfordOp* wop) {} + void handle(WelfordOp* wop) final {} //! A cat op can have input empty tensors and still output a non-empty //! tensor. This is only possible if there is more than one input, so we @@ -242,7 +210,7 @@ class EmptyTensorRemover { //! T3 //! //! This is done by simply calling the cat() command with only {T0, T2}. - void handle(CatOp* cop) { + void handle(CatOp* cop) final { auto dim = cop->concatenatedDim(); std::vector non_empty_inputs; for (auto inp : cop->inputs()) { @@ -264,17 +232,13 @@ class EmptyTensorRemover { auto new_tv = cat(non_empty_inputs, dim); replaceTV(old_tv, new_tv); } - for (auto tv : non_empty_inputs) { - // Continue processing non-empty inputs - stmt_queue_.push_back(tv); - } } //! Replace pad(tv) if tv is empty in any dimension. Note that since we detect //! empty tensors by looking for constant extents, the output extents will be //! correct here already, so there is no value in removing the empty input //! extent when we do the replacement. - void handle(PadOp* pop) { + void handle(PadOp* pop) final { auto in = pop->in()->as(); auto in_rfactor = TensorDomain::noReductions(in->getMaybeRFactorDomain()); if (!emptyAxes(in_rfactor).empty()) { @@ -288,8 +252,6 @@ class EmptyTensorRemover { } auto new_tv = full(shape, pop->value(), out->getDataType().value()); replaceTV(out, new_tv); - } else { - pushInputs(pop); } } @@ -304,32 +266,32 @@ class EmptyTensorRemover { for (auto use : old_tv->uses()) { ir_utils::replaceValInExpr(use, old_tv, new_tv); } - if (!old_tv->isFusionInput()) { - fusion_->removeVal(old_tv); - old_tv = nullptr; + // old_tv as well as its definition will be removed by fusion_->removeVal(), + // after which the pointers will be invalid. We mark them as dead to avoid + // dereferencing and processing those here. + markDead(old_tv); + if (old_tv->definition()) { + markDead(old_tv->definition()); } + + fusion_->removeVal(old_tv); } - void handle(Expr* e) { - if (auto rop = dynamic_cast(e)) { - handle(rop); - } else if (auto wop = dynamic_cast(e)) { - handle(wop); - } else if (auto pop = dynamic_cast(e)) { - handle(pop); - } else if (auto pop = dynamic_cast(e)) { - handle(pop); - } else { - // The handled ops above may terminate this branch of the traversal, so - // they will need to manually handle their inputs. For unhandled ops, we - // just handle all inputs here. - pushInputs(e); - } + //! Find whether a statement has been marked dead + bool isDead(Statement* stmt) { + return dead_.find(stmt) != dead_.end(); + } + + //! Mark a Statement* as dead so that we avoid dereferencing it later + void markDead(Statement* stmt) { + dead_.insert(stmt); } private: Fusion* fusion_; - std::deque stmt_queue_; + + //! Statements are marked dead when they are removed from the Fusion + std::unordered_set dead_; }; } // namespace From f6487e0646a191742dee940559df7c8f4130432d Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Wed, 28 Jun 2023 14:22:18 -0400 Subject: [PATCH 11/37] Handle WelfordOp --- csrc/optimization/remove_empty.cpp | 79 +++++++++++++++++++++--------- test/test_optimization_pass.cpp | 58 +++++++++++++++++----- 2 files changed, 101 insertions(+), 36 deletions(-) diff --git a/csrc/optimization/remove_empty.cpp b/csrc/optimization/remove_empty.cpp index 969dbceb461..e5c86211674 100644 --- a/csrc/optimization/remove_empty.cpp +++ b/csrc/optimization/remove_empty.cpp @@ -12,7 +12,7 @@ #include #include -#include +#include #include #include @@ -116,6 +116,15 @@ class EmptyTensorRemover : BackwardVisitor { } } + //! Gets a vector of extents for noReduction(tv->getMaybeRFactorDomain()) + static std::vector noReductionShape(TensorView* tv) { + std::vector shape; + for (auto id : TensorDomain::noReductions(tv->getMaybeRFactorDomain())) { + shape.push_back(id->extent()); + } + return shape; + } + //! A reduction over empty axes is equal to the initial value of the //! reduction, as if the reduction were written as follows: //! @@ -135,17 +144,8 @@ class EmptyTensorRemover : BackwardVisitor { } auto out = rop->out()->as(); // The input is empty in some axes. Assert that they are all reduced - const auto& out_root = out->getRootDomain(); - - std::vector shape; - for (auto id : out_root) { - if (!id->isReduction() && !id->isStride()) { // same as noReductions() - shape.push_back(id->extent()); - } - } - for (auto ax : empty_input_axes) { - auto id = out_root.at(ax); + auto id = out->getRootDomain().at(ax); // Input rfactor domain positions correspond to output root positions TORCH_INTERNAL_ASSERT( id->isReduction(), @@ -153,24 +153,55 @@ class EmptyTensorRemover : BackwardVisitor { ax, " in expression ", rop->toString()); - shape[ax] = fusion_->zeroVal(); } - // Find output shape to replace with full - auto new_tv = full(shape, rop->init(), out->getDataType().value()); + auto new_tv = + full(noReductionShape(out), rop->init(), out->getDataType().value()); replaceTV(out, new_tv); } - //! A reduction over empty axes is equal to the initial value of the - //! reduction, as if the reduction were written as follows: - //! - //! auto result = init_value; - //! for (auto element : reduction_elements) { - //! result = reduction_op(result, element); - //! } - //! return result; - //! - void handle(WelfordOp* wop) final {} + //! A WelfordOp is similar to a ReductionOp, but has three outputs: avg, var, + //! N. For an empty reduction N will be zero, so we fill the output with zero. + //! The avg and var is obtained by summing then dividing by N. For empty + //! reductions this leads to 0.0 / 0 so we fill it with a constant NAN. The + //! .var variable is actually an unnormalized variance which is a sum without + //! dividing by N or N-1, so we fill it with zeros. + void handle(WelfordOp* wop) final { + auto in = wop->in()->as(); + auto empty_input_axes = + emptyAxes(TensorDomain::noReductions(in->getMaybeRFactorDomain())); + if (empty_input_axes.empty()) { + // Input is not empty, handle like any other op + return; + } + auto avg = wop->outAvg()->as(); + auto var_sum = wop->outVar()->as(); + auto N = wop->outN()->as(); + // The input is empty in some axes. Assert that they are all reduced + for (auto ax : empty_input_axes) { + auto id = avg->getRootDomain().at(ax); + // Input rfactor domain positions correspond to output root positions + TORCH_INTERNAL_ASSERT( + id->isReduction(), + "Found unexpected unreduced empty axis at position ", + ax, + " in expression ", + wop->toString()); + } + + auto shape = noReductionShape(avg); + auto nan = IrBuilder::create( + std::numeric_limits::quiet_NaN(), avg->getDataType().value()); + auto nan_tensor = full(shape, nan, avg->getDataType().value()); + auto new_var_sum = full( + shape, + fusion_->zeroVal(var_sum->getDataType().value()), + var_sum->getDataType().value()); + auto new_N = full(shape, fusion_->zeroVal(), N->getDataType().value()); + replaceTV(avg, nan_tensor); + replaceTV(var_sum, new_var_sum); + replaceTV(N, new_N); + } //! A cat op can have input empty tensors and still output a non-empty //! tensor. This is only possible if there is more than one input, so we diff --git a/test/test_optimization_pass.cpp b/test/test_optimization_pass.cpp index dc5f92d4d16..32d45e9d9ee 100644 --- a/test/test_optimization_pass.cpp +++ b/test/test_optimization_pass.cpp @@ -419,10 +419,6 @@ TEST_F(NVFuserTest, FusionRemoveEmptyReduction_CUDA) { auto args = KernelArgumentHolder::createKernelArgumentHolder(aten_inputs); FusionKernelRuntime runtime(std::move(fusion_ptr), args); - // In the FusionKernelRuntime, before segmentation a number of optimization - // passes are performed. One of those is RemoveEmptyPass, which should replace - // the empty output tv1 with a new TensorView defined by `full({0, 3})` in - // this case. auto preseg_fusion = runtime.fusionSegments()->completeFusion(); EXPECT_EQ(preseg_fusion->outputs().size(), 1); EXPECT_NE(preseg_fusion->outputs()[0]->definition(), nullptr); @@ -440,6 +436,52 @@ TEST_F(NVFuserTest, FusionRemoveEmptyReduction_CUDA) { __FILE__); } +// Test that we replace empty Welford with full +TEST_F(NVFuserTest, FusionRemoveEmptyWelford_CUDA) { + std::unique_ptr fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr.get(); + FusionGuard fg(fusion_ptr.get()); + // Concrete tensor with zero for one extent, so that we can prove the output + // is empty + auto tv0 = makeConcreteTensor({0, 3}); + fusion.addInput(tv0); + auto w = Welford(tv0, {0}); + fusion.addOutput(w.avg); + auto var = div(w.var_sum, fusion_ptr->zeroVal(DataType::Float)); + fusion.addOutput(var); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor at0 = at::randn({0, 3}, options); + std::vector aten_inputs = {at0}; + + auto args = KernelArgumentHolder::createKernelArgumentHolder(aten_inputs); + FusionKernelRuntime runtime(std::move(fusion_ptr), args); + + auto preseg_fusion = runtime.fusionSegments()->completeFusion(); + EXPECT_EQ(preseg_fusion->outputs().size(), 2); + + EXPECT_NE(preseg_fusion->outputs()[0]->definition(), nullptr); + EXPECT_TRUE(preseg_fusion->outputs()[0]->definition()->isA()); + + EXPECT_NE(var->definition(), nullptr); + EXPECT_TRUE(var->definition()->isA()); + // We divide in the fusion to normalize the variance, so here we have to peel + // that back + auto var_sum = var->definition()->inputs()[0]->as(); + EXPECT_TRUE(var_sum->definition()->isA()); + + runtime.compileFusionParallel(args); + auto outputs = runtime.runWithInputs(args); + + testValidate( + preseg_fusion, + outputs, + aten_inputs, + {at::mean(at0, {0}), at::var(at0, {0})}, + __LINE__, + __FILE__); +} + // Test that we replace empty tensors in cat properly TEST_F(NVFuserTest, FusionRemoveEmptyCat_CUDA) { std::unique_ptr fusion_ptr = std::make_unique(); @@ -470,10 +512,6 @@ TEST_F(NVFuserTest, FusionRemoveEmptyCat_CUDA) { auto args = KernelArgumentHolder::createKernelArgumentHolder(aten_inputs); FusionKernelRuntime runtime(std::move(fusion_ptr), args); - // In the FusionKernelRuntime, before segmentation a number of optimization - // passes are performed. One of those is RemoveEmptyPass, which should replace - // the empty output tv1 with a new TensorView defined by `full({0, 3})` in - // this case. auto preseg_fusion = runtime.fusionSegments()->completeFusion(); EXPECT_EQ(preseg_fusion->outputs().size(), 2); @@ -520,10 +558,6 @@ TEST_F(NVFuserTest, FusionRemoveEmptyPad_CUDA) { auto args = KernelArgumentHolder::createKernelArgumentHolder(aten_inputs); FusionKernelRuntime runtime(std::move(fusion_ptr), args); - // In the FusionKernelRuntime, before segmentation a number of optimization - // passes are performed. One of those is RemoveEmptyPass, which should replace - // the empty output tv1 with a new TensorView defined by `full({0, 3})` in - // this case. auto preseg_fusion = runtime.fusionSegments()->completeFusion(); EXPECT_EQ(preseg_fusion->outputs().size(), 1); From b9953f74350a4afada8a9e92eea34fcd9cb5f9af Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Wed, 28 Jun 2023 15:00:42 -0400 Subject: [PATCH 12/37] Silence clang-tidy and convert empty check to TORCH_WARN --- csrc/optimization/remove_empty.cpp | 15 ++++++++++----- csrc/optimization/remove_empty.h | 3 ++- test/test_optimization_pass.cpp | 2 +- 3 files changed, 13 insertions(+), 7 deletions(-) diff --git a/csrc/optimization/remove_empty.cpp b/csrc/optimization/remove_empty.cpp index e5c86211674..de5c447dd1c 100644 --- a/csrc/optimization/remove_empty.cpp +++ b/csrc/optimization/remove_empty.cpp @@ -70,6 +70,8 @@ class EmptyTensorRemover : BackwardVisitor { traverseTo(fusion_, fusion_->outputs()); } + using BackwardVisitor::handle; + void handle(Statement* stmt) final { if (isDead(stmt)) { // We check whether stmt is dead before we dereference it, since it may @@ -108,11 +110,14 @@ class EmptyTensorRemover : BackwardVisitor { return; } } else { - // TODO: This should be a warning instead of an assert - TORCH_INTERNAL_ASSERT( - !isTVEmpty(tv), - "Found unexpected empty intermediate TensorView ", - tv->toString()); + // Note that if there empty intermediate tensors with uses that do not + // lead to outputs, this check might fail. + if (!tv->uses().empty() && isTVEmpty(tv)) { + TORCH_WARN( + "Found unexpected empty intermediate TensorView ", + tv->toString(), + ". This TensorView has un-removed uses that might not be used in this Fusion."); + } } } diff --git a/csrc/optimization/remove_empty.h b/csrc/optimization/remove_empty.h index 6b4f4aee6fa..2c190619335 100644 --- a/csrc/optimization/remove_empty.h +++ b/csrc/optimization/remove_empty.h @@ -9,7 +9,8 @@ namespace nvfuser::optimization { -//! RemoveEmptyPass removes empty tensors (those with at least one extent zero). +//! RemoveEmptyPass removes intermediate empty tensors (those with at least one +//! extent zero thar are neither a fusion output or input). class TORCH_CUDA_CU_API RemoveEmptyPass : public OptimizationPass { friend class OptimizationPass; diff --git a/test/test_optimization_pass.cpp b/test/test_optimization_pass.cpp index 32d45e9d9ee..ff2a664b19c 100644 --- a/test/test_optimization_pass.cpp +++ b/test/test_optimization_pass.cpp @@ -477,7 +477,7 @@ TEST_F(NVFuserTest, FusionRemoveEmptyWelford_CUDA) { preseg_fusion, outputs, aten_inputs, - {at::mean(at0, {0}), at::var(at0, {0})}, + {at::mean(at0, 0), at::var(at0, 0)}, __LINE__, __FILE__); } From 32ab55a8d1ee78757f525fbe258d638fb8cd93a2 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Wed, 28 Jun 2023 15:02:10 -0400 Subject: [PATCH 13/37] Use TORCH_WARN_ONCE instead of TORCH_WARN --- csrc/optimization/remove_empty.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/optimization/remove_empty.cpp b/csrc/optimization/remove_empty.cpp index de5c447dd1c..21a7e3f0a3c 100644 --- a/csrc/optimization/remove_empty.cpp +++ b/csrc/optimization/remove_empty.cpp @@ -113,7 +113,7 @@ class EmptyTensorRemover : BackwardVisitor { // Note that if there empty intermediate tensors with uses that do not // lead to outputs, this check might fail. if (!tv->uses().empty() && isTVEmpty(tv)) { - TORCH_WARN( + TORCH_WARN_ONCE( "Found unexpected empty intermediate TensorView ", tv->toString(), ". This TensorView has un-removed uses that might not be used in this Fusion."); From 7b1d7630346fdf1d2ab5b9894676af4baa894b49 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Thu, 29 Jun 2023 07:11:23 -0400 Subject: [PATCH 14/37] Cleanup, mark upstream unused tensors dead --- csrc/optimization/remove_empty.cpp | 36 +++++++++++++++++------------- 1 file changed, 21 insertions(+), 15 deletions(-) diff --git a/csrc/optimization/remove_empty.cpp b/csrc/optimization/remove_empty.cpp index 21a7e3f0a3c..89106600df0 100644 --- a/csrc/optimization/remove_empty.cpp +++ b/csrc/optimization/remove_empty.cpp @@ -70,6 +70,7 @@ class EmptyTensorRemover : BackwardVisitor { traverseTo(fusion_, fusion_->outputs()); } + private: using BackwardVisitor::handle; void handle(Statement* stmt) final { @@ -109,15 +110,18 @@ class EmptyTensorRemover : BackwardVisitor { // Do not keep traversing upstream if we've replaced tv return; } - } else { + } else if (tv->uses().empty()) { + // TensorViews that are not Fusion inputs or outputs and which have no + // uses are dead, so remove them and skip processing them and their + // definition. + removeAndMarkDead(tv); + } else if (isTVEmpty(tv)) { // Note that if there empty intermediate tensors with uses that do not // lead to outputs, this check might fail. - if (!tv->uses().empty() && isTVEmpty(tv)) { - TORCH_WARN_ONCE( - "Found unexpected empty intermediate TensorView ", - tv->toString(), - ". This TensorView has un-removed uses that might not be used in this Fusion."); - } + TORCH_WARN_ONCE( + "Found unexpected empty intermediate TensorView ", + tv->toString(), + ". This TensorView has un-removed uses that might not be used in this Fusion."); } } @@ -302,15 +306,17 @@ class EmptyTensorRemover : BackwardVisitor { for (auto use : old_tv->uses()) { ir_utils::replaceValInExpr(use, old_tv, new_tv); } - // old_tv as well as its definition will be removed by fusion_->removeVal(), - // after which the pointers will be invalid. We mark them as dead to avoid - // dereferencing and processing those here. - markDead(old_tv); - if (old_tv->definition()) { - markDead(old_tv->definition()); - } + removeAndMarkDead(old_tv); + } - fusion_->removeVal(old_tv); + //! Guard removeVal with calls to markDead so that we always detect removed + //! Vals and Exprs before derefencing them. + void removeAndMarkDead(Val* val) { + markDead(val); + if (val->definition()) { + markDead(val->definition()); + } + fusion_->removeVal(val); } //! Find whether a statement has been marked dead From bc57aa375bf626008e2a2181d6870d8108b7d430 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Thu, 29 Jun 2023 08:55:15 -0400 Subject: [PATCH 15/37] Add live statement tracking. I am going to abstract this out into a DeadCodeEliminator since this pattern will also be used for concretizing slice, and potentially we may want to combine multiple passes to share the traversal machinery. --- csrc/optimization/remove_empty.cpp | 176 ++++++++++++++++++++++++----- 1 file changed, 150 insertions(+), 26 deletions(-) diff --git a/csrc/optimization/remove_empty.cpp b/csrc/optimization/remove_empty.cpp index 89106600df0..7cf4df008a8 100644 --- a/csrc/optimization/remove_empty.cpp +++ b/csrc/optimization/remove_empty.cpp @@ -61,12 +61,56 @@ bool isTVEmpty(TensorView* tv) { //! 4. If the empty Tensorview is the input to a PadOp (which is not input to //! a CatOp) then we replace the pad with `full(pad_value)`. //! +//! We use BackwardVisitor::traversal_exprs_ which tracks the active expressions +//! in this Fusion, to determine when it is safe to remove statements. We +//! augment this by creating an unordered_set called active_statements_, which +//! is initialized as the Exprs in traversal_exprs_ as well as their inputs and +//! outputs. Marking a Statement as dead removes it from active_statements_, and +//! replacing a Val inserts the Val and its definition, recursively. Since we +//! traverse backwards, and we handle all active Expr outputs, this ensures that +//! it is safe to do so whenever we remove an Expr (i.e. it will not result in +//! erasing definitions of active Expr outputs). class EmptyTensorRemover : BackwardVisitor { public: EmptyTensorRemover(Fusion* fusion) : BackwardVisitor(false), fusion_(fusion) {} void run() { + // First we build a set of all live Statements so that we can detect dead + // branches. + auto exprs = StmtSort::getExprs(fusion_, fusion_->outputs()); + // Mark every Expr, as well as its inputs and outputs as live initially + for (auto expr : exprs) { + markLive(expr); + for (auto inp : expr->inputs()) { + markLive(inp); + } + for (auto outp : expr->outputs()) { + markLive(outp); + } + } + + // We do not traverse all outputs of all Exprs, since this requires that all + // paths lead to fusion_->outputs(). Instead, here now mark all Vals dead + // which do not have any live uses. After this, it is safe to check whether + // all outputs of an Expr are actually dead, which helps us determine when + // it is safe to remove a multi-output definition of a dead TV. + std::vector dead_expr_outputs; + for (auto stmt : live_statements_) { + if (stmt->isVal() && !stmt->asVal()->isFusionOutput() && + allUsesDead(stmt->asVal())) { + // We should not erase from live_statements_ while traversing it, so we + // save the dead expr outputs for later removal instead. + dead_expr_outputs.push_back(stmt); + } + } + for (auto stmt : dead_expr_outputs) { + markDead(stmt); + } + + // Note that getExprs is also run in traverseTo. In the future, we + // could potentially refactor this so that derived classes from + // BackwardVisitor can make use of that traversal instead of repeating it. traverseTo(fusion_, fusion_->outputs()); } @@ -74,7 +118,7 @@ class EmptyTensorRemover : BackwardVisitor { using BackwardVisitor::handle; void handle(Statement* stmt) final { - if (isDead(stmt)) { + if (!isLive(stmt)) { // We check whether stmt is dead before we dereference it, since it may // have been removed from the Fusion. return; @@ -82,6 +126,21 @@ class EmptyTensorRemover : BackwardVisitor { BackwardVisitor::handle(stmt); } + void handle(Expr* expr) final { + bool all_outputs_dead = true; + for (auto outp : expr->outputs()) { + if (isLive(outp)) { + all_outputs_dead = false; + } + } + if (all_outputs_dead) { + markDead(expr); + fusion_->removeExpr(expr); + } else { + BackwardVisitor::handle(expr); + } + } + //! If tv is a fusion output, we check whether it is empty and if so, replace //! it with full(). For non-outputs that are not inputs, we simply check that //! the tensor is not provably empty. @@ -107,14 +166,12 @@ class EmptyTensorRemover : BackwardVisitor { auto dtype = tv->getDataType().value(); auto new_tv = full(shape, fusion_->zeroVal(dtype), dtype); replaceTV(tv, new_tv); - // Do not keep traversing upstream if we've replaced tv - return; } - } else if (tv->uses().empty()) { + } else if (allUsesDead(tv)) { // TensorViews that are not Fusion inputs or outputs and which have no // uses are dead, so remove them and skip processing them and their // definition. - removeAndMarkDead(tv); + markDeadAndMaybeRemove(tv); } else if (isTVEmpty(tv)) { // Note that if there empty intermediate tensors with uses that do not // lead to outputs, this check might fail. @@ -198,18 +255,27 @@ class EmptyTensorRemover : BackwardVisitor { wop->toString()); } + // Since WelfordOp has multiple outputs, we need to check whether each is + // live before replacing it, to avoid replacing a dead output with a live + // one. Note that replaceTV will mark the replacement as live automatically. auto shape = noReductionShape(avg); - auto nan = IrBuilder::create( - std::numeric_limits::quiet_NaN(), avg->getDataType().value()); - auto nan_tensor = full(shape, nan, avg->getDataType().value()); - auto new_var_sum = full( - shape, - fusion_->zeroVal(var_sum->getDataType().value()), - var_sum->getDataType().value()); - auto new_N = full(shape, fusion_->zeroVal(), N->getDataType().value()); - replaceTV(avg, nan_tensor); - replaceTV(var_sum, new_var_sum); - replaceTV(N, new_N); + if (isLive(avg)) { + auto nan = IrBuilder::create( + std::numeric_limits::quiet_NaN(), avg->getDataType().value()); + auto nan_tensor = full(shape, nan, avg->getDataType().value()); + replaceTV(avg, nan_tensor); + } + if (isLive(var_sum)) { + auto new_var_sum = full( + shape, + fusion_->zeroVal(var_sum->getDataType().value()), + var_sum->getDataType().value()); + replaceTV(var_sum, new_var_sum); + } + if (isLive(N)) { + auto new_N = full(shape, fusion_->zeroVal(), N->getDataType().value()); + replaceTV(N, new_N); + } } //! A cat op can have input empty tensors and still output a non-empty @@ -306,34 +372,92 @@ class EmptyTensorRemover : BackwardVisitor { for (auto use : old_tv->uses()) { ir_utils::replaceValInExpr(use, old_tv, new_tv); } - removeAndMarkDead(old_tv); + markLiveRecursive(new_tv); + markDeadAndMaybeRemove(old_tv); } //! Guard removeVal with calls to markDead so that we always detect removed //! Vals and Exprs before derefencing them. - void removeAndMarkDead(Val* val) { + //! + //! Note that we only remove val if all of its definition's outputs are marked + //! dead. + void markDeadAndMaybeRemove(Val* val) { markDead(val); if (val->definition()) { - markDead(val->definition()); + // When all outputs of def are marked dead, mark the def dead + bool all_outputs_dead = true; + for (auto outp : val->definition()->outputs()) { + if (isLive(outp)) { + all_outputs_dead = false; + break; + } + } + if (all_outputs_dead) { + // If all other outputs are dead, it's safe to remove the definition as + // well as all its outputs + markDead(val->definition()); + const auto outputs = val->definition()->outputs(); + for (auto outp : outputs) { + fusion_->removeVal(outp); + } + } + } else { + TORCH_INTERNAL_ASSERT( + !val->isFusionInput(), "Refusing to remove Fusion input"); + fusion_->removeVal(val); } - fusion_->removeVal(val); } - //! Find whether a statement has been marked dead - bool isDead(Statement* stmt) { - return dead_.find(stmt) != dead_.end(); + //! Find whether a statement is live (i.e. not dead code) + //! + //! Inside BackwardVisitor::traverseTo, traversal_exprs_ is built, containing + //! all active expressions in the original graph. + bool isLive(Statement* stmt) { + return live_statements_.find(stmt) != live_statements_.end(); + } + + //! Check whether all uses have been marked dead + bool allUsesDead(Val* val) { + for (const auto use : val->uses()) { + if (isLive(use)) { + return false; + } + } + return true; + } + + //! Mark a single Statement as being alive + void markLive(Statement* stmt) { + live_statements_.insert(stmt); + } + + //! Ensure that a Statement and its upstream Statements are alive. If it is an + //! Expr, ensure all its inputs are alive. If it's a Val with a definition, + //! recursive to the definition. + void markLiveRecursive(Statement* stmt) { + if (isLive(stmt)) { + return; + } + markLive(stmt); + if (stmt->isVal() && stmt->asVal()->definition()) { + markLiveRecursive(stmt); + } else { + for (const auto inp : stmt->asExpr()->inputs()) { + markLiveRecursive(inp); + } + } } //! Mark a Statement* as dead so that we avoid dereferencing it later void markDead(Statement* stmt) { - dead_.insert(stmt); + live_statements_.erase(stmt); } private: Fusion* fusion_; - //! Statements are marked dead when they are removed from the Fusion - std::unordered_set dead_; + //! Statements are marked dead by removing them from this set + std::unordered_set live_statements_; }; } // namespace From cf24bd07179ab9e87ad8ef278c6f34778fff4e37 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Thu, 29 Jun 2023 09:11:50 -0400 Subject: [PATCH 16/37] Refactor into parent class DeadCodeRemover --- csrc/optimization/remove_empty.cpp | 311 +++++++++++++++-------------- 1 file changed, 165 insertions(+), 146 deletions(-) diff --git a/csrc/optimization/remove_empty.cpp b/csrc/optimization/remove_empty.cpp index 7cf4df008a8..23360e426f9 100644 --- a/csrc/optimization/remove_empty.cpp +++ b/csrc/optimization/remove_empty.cpp @@ -20,46 +20,10 @@ namespace nvfuser::optimization { namespace { -//! Get a vector of the integer positions of constant zero extent axes in the -//! input domain. This will typically be used like -//! `emptyAxes(TensorDomain::noReductions(tv->getMaybeRFactorDomain()))` -std::vector emptyAxes(std::vector domain) { - std::vector empty_axes; - for (auto ax : c10::irange(domain.size())) { - auto id = domain.at(ax); - if (id->extent()->isConst() && id->extent()->evaluateInt() == 0) { - empty_axes.push_back(ax); - } - } - return empty_axes; -} - -//! Check whether a TensorView is empty. During concretization, we traverse to -//! find a minimal set of TensorViews that have zero extents, and we then set -//! their extents to a constant 0. Here we check for those constant zero -//! extents. -bool isTVEmpty(TensorView* tv) { - return !emptyAxes(TensorDomain::noReductions(tv->getMaybeRFactorDomain())) - .empty(); -} - -//! removeEmptyPass performs a backward traversal of the Fusion. When it detects -//! a TensorView that has at least one extent that is zero, we do the following: -//! -//! 1. If the empty Tensorview is a Fusion output, we replace it with a -//! TensorView created by `full` having the same shape. Since the original -//! tensor is empty, there is nothing to compute, so this eliminates a branch -//! of trivial code. -//! 2. If the empty TensorView is the input of a `cat` op along the empty -//! dimension, we replace the cat op with a new one having the empty input -//! removed. -//! 3. If the empty Tensorview is the input to a ReductionOp or WelfordOp and -//! the empty dimensions are reduced, we replace the op with `full` since -//! there are no elements being reduced. Note that if any empty axes are not -//! reduced, we will not encounter this case since it will have been removed -//! earlier in the backward traversal under condition 1. -//! 4. If the empty Tensorview is the input to a PadOp (which is not input to -//! a CatOp) then we replace the pad with `full(pad_value)`. +//! This is a generic traversal class that is used to modify a Fusion graph by +//! replacing TensorViews so that their definitions can be altered to simplify +//! computation or remove dead code. Derived classes should override handle() +//! and make use of replaceTV(), markDeadAndMaybeRemove(), and allUsesDead(). //! //! We use BackwardVisitor::traversal_exprs_ which tracks the active expressions //! in this Fusion, to determine when it is safe to remove statements. We @@ -70,10 +34,9 @@ bool isTVEmpty(TensorView* tv) { //! traverse backwards, and we handle all active Expr outputs, this ensures that //! it is safe to do so whenever we remove an Expr (i.e. it will not result in //! erasing definitions of active Expr outputs). -class EmptyTensorRemover : BackwardVisitor { +class DeadCodeRemover : BackwardVisitor { public: - EmptyTensorRemover(Fusion* fusion) - : BackwardVisitor(false), fusion_(fusion) {} + DeadCodeRemover(Fusion* fusion) : BackwardVisitor(false), fusion_(fusion) {} void run() { // First we build a set of all live Statements so that we can detect dead @@ -114,7 +77,11 @@ class EmptyTensorRemover : BackwardVisitor { traverseTo(fusion_, fusion_->outputs()); } - private: + Fusion* fusion() const { + return fusion_; + } + + protected: using BackwardVisitor::handle; void handle(Statement* stmt) final { @@ -141,6 +108,156 @@ class EmptyTensorRemover : BackwardVisitor { } } + //! Replaces a TensorView in outputs, and in all uses. If old_tv is a Fusion + //! input, we do not replace it. After replacement, unless it is a Fusion + //! input, we remove it from the fusion and set the original pointer to zero + //! (hence why old_tv is passed by reference). + void replaceTV(TensorView*& old_tv, TensorView* new_tv) { + if (old_tv->isFusionOutput()) { + fusion_->replaceOutput(old_tv, new_tv); + } + for (auto use : old_tv->uses()) { + ir_utils::replaceValInExpr(use, old_tv, new_tv); + } + markLiveRecursive(new_tv); + markDeadAndMaybeRemove(old_tv); + } + + //! Guard removeVal with calls to markDead so that we always detect removed + //! Vals and Exprs before derefencing them. + //! + //! Note that we only remove val if all of its definition's outputs are marked + //! dead. + void markDeadAndMaybeRemove(Val* val) { + markDead(val); + if (val->definition()) { + // When all outputs of def are marked dead, mark the def dead + bool all_outputs_dead = true; + for (auto outp : val->definition()->outputs()) { + if (isLive(outp)) { + all_outputs_dead = false; + break; + } + } + if (all_outputs_dead) { + // If all other outputs are dead, it's safe to remove the definition as + // well as all its outputs + markDead(val->definition()); + const auto outputs = val->definition()->outputs(); + for (auto outp : outputs) { + fusion_->removeVal(outp); + } + } + } else { + TORCH_INTERNAL_ASSERT( + !val->isFusionInput(), "Refusing to remove Fusion input"); + fusion_->removeVal(val); + } + } + + //! Find whether a statement is live (i.e. not dead code) + //! + //! Inside BackwardVisitor::traverseTo, traversal_exprs_ is built, containing + //! all active expressions in the original graph. + bool isLive(Statement* stmt) { + return live_statements_.find(stmt) != live_statements_.end(); + } + + //! Check whether all uses have been marked dead + bool allUsesDead(Val* val) { + for (const auto use : val->uses()) { + if (isLive(use)) { + return false; + } + } + return true; + } + + //! Mark a single Statement as being alive + void markLive(Statement* stmt) { + live_statements_.insert(stmt); + } + + //! Ensure that a Statement and its upstream Statements are alive. If it is an + //! Expr, ensure all its inputs are alive. If it's a Val with a definition, + //! recursive to the definition. + void markLiveRecursive(Statement* stmt) { + if (isLive(stmt)) { + return; + } + markLive(stmt); + if (stmt->isVal() && stmt->asVal()->definition()) { + markLiveRecursive(stmt); + } else { + for (const auto inp : stmt->asExpr()->inputs()) { + markLiveRecursive(inp); + } + } + } + + //! Mark a Statement* as dead so that we avoid dereferencing it later + void markDead(Statement* stmt) { + live_statements_.erase(stmt); + } + + private: + Fusion* fusion_; + + //! Statements are marked dead by removing them from this set + std::unordered_set live_statements_; +}; + +//! Get a vector of the integer positions of constant zero extent axes in the +//! input domain. This will typically be used like +//! `emptyAxes(TensorDomain::noReductions(tv->getMaybeRFactorDomain()))` +std::vector emptyAxes(std::vector domain) { + std::vector empty_axes; + for (auto ax : c10::irange(domain.size())) { + auto id = domain.at(ax); + if (id->extent()->isConst() && id->extent()->evaluateInt() == 0) { + empty_axes.push_back(ax); + } + } + return empty_axes; +} + +//! Check whether a TensorView is empty. During concretization, we traverse to +//! find a minimal set of TensorViews that have zero extents, and we then set +//! their extents to a constant 0. Here we check for those constant zero +//! extents. +bool isTVEmpty(TensorView* tv) { + return !emptyAxes(TensorDomain::noReductions(tv->getMaybeRFactorDomain())) + .empty(); +} + +//! EmptyTensorRemover performs a backward traversal of the Fusion. When it +//! detects a TensorView that has at least one extent that is zero, we do the +//! following: +//! +//! 1. If the empty Tensorview is a Fusion output, we replace it with a +//! TensorView created by `full` having the same shape. Since the original +//! tensor is empty, there is nothing to compute, so this eliminates a branch +//! of trivial code. +//! 2. If the empty TensorView is the input of a `cat` op along the empty +//! dimension, we replace the cat op with a new one having the empty input +//! removed. +//! 3. If the empty Tensorview is the input to a ReductionOp or WelfordOp and +//! the empty dimensions are reduced, we replace the op with `full` since +//! there are no elements being reduced. Note that if any empty axes are not +//! reduced, we will not encounter this case since it will have been removed +//! earlier in the backward traversal under condition 1. +//! 4. If the empty Tensorview is the input to a PadOp (which is not input to +//! a CatOp) then we replace the pad with `full(pad_value)`. +//! +class EmptyTensorRemover : DeadCodeRemover { + public: + EmptyTensorRemover(Fusion* fusion) : DeadCodeRemover(fusion) {} + + using DeadCodeRemover::run; + + protected: + using DeadCodeRemover::handle; + //! If tv is a fusion output, we check whether it is empty and if so, replace //! it with full(). For non-outputs that are not inputs, we simply check that //! the tensor is not provably empty. @@ -161,10 +278,10 @@ class EmptyTensorRemover : BackwardVisitor { return id->extent(); }); for (auto ax : empty_axes) { - shape[ax] = fusion_->zeroVal(); + shape[ax] = fusion()->zeroVal(); } auto dtype = tv->getDataType().value(); - auto new_tv = full(shape, fusion_->zeroVal(dtype), dtype); + auto new_tv = full(shape, fusion()->zeroVal(dtype), dtype); replaceTV(tv, new_tv); } } else if (allUsesDead(tv)) { @@ -268,12 +385,12 @@ class EmptyTensorRemover : BackwardVisitor { if (isLive(var_sum)) { auto new_var_sum = full( shape, - fusion_->zeroVal(var_sum->getDataType().value()), + fusion()->zeroVal(var_sum->getDataType().value()), var_sum->getDataType().value()); replaceTV(var_sum, new_var_sum); } if (isLive(N)) { - auto new_N = full(shape, fusion_->zeroVal(), N->getDataType().value()); + auto new_N = full(shape, fusion()->zeroVal(), N->getDataType().value()); replaceTV(N, new_N); } } @@ -360,104 +477,6 @@ class EmptyTensorRemover : BackwardVisitor { replaceTV(out, new_tv); } } - - //! Replaces a TensorView in outputs, and in all uses. If old_tv is a Fusion - //! input, we do not replace it. After replacement, unless it is a Fusion - //! input, we remove it from the fusion and set the original pointer to zero - //! (hence why old_tv is passed by reference). - void replaceTV(TensorView*& old_tv, TensorView* new_tv) { - if (old_tv->isFusionOutput()) { - fusion_->replaceOutput(old_tv, new_tv); - } - for (auto use : old_tv->uses()) { - ir_utils::replaceValInExpr(use, old_tv, new_tv); - } - markLiveRecursive(new_tv); - markDeadAndMaybeRemove(old_tv); - } - - //! Guard removeVal with calls to markDead so that we always detect removed - //! Vals and Exprs before derefencing them. - //! - //! Note that we only remove val if all of its definition's outputs are marked - //! dead. - void markDeadAndMaybeRemove(Val* val) { - markDead(val); - if (val->definition()) { - // When all outputs of def are marked dead, mark the def dead - bool all_outputs_dead = true; - for (auto outp : val->definition()->outputs()) { - if (isLive(outp)) { - all_outputs_dead = false; - break; - } - } - if (all_outputs_dead) { - // If all other outputs are dead, it's safe to remove the definition as - // well as all its outputs - markDead(val->definition()); - const auto outputs = val->definition()->outputs(); - for (auto outp : outputs) { - fusion_->removeVal(outp); - } - } - } else { - TORCH_INTERNAL_ASSERT( - !val->isFusionInput(), "Refusing to remove Fusion input"); - fusion_->removeVal(val); - } - } - - //! Find whether a statement is live (i.e. not dead code) - //! - //! Inside BackwardVisitor::traverseTo, traversal_exprs_ is built, containing - //! all active expressions in the original graph. - bool isLive(Statement* stmt) { - return live_statements_.find(stmt) != live_statements_.end(); - } - - //! Check whether all uses have been marked dead - bool allUsesDead(Val* val) { - for (const auto use : val->uses()) { - if (isLive(use)) { - return false; - } - } - return true; - } - - //! Mark a single Statement as being alive - void markLive(Statement* stmt) { - live_statements_.insert(stmt); - } - - //! Ensure that a Statement and its upstream Statements are alive. If it is an - //! Expr, ensure all its inputs are alive. If it's a Val with a definition, - //! recursive to the definition. - void markLiveRecursive(Statement* stmt) { - if (isLive(stmt)) { - return; - } - markLive(stmt); - if (stmt->isVal() && stmt->asVal()->definition()) { - markLiveRecursive(stmt); - } else { - for (const auto inp : stmt->asExpr()->inputs()) { - markLiveRecursive(inp); - } - } - } - - //! Mark a Statement* as dead so that we avoid dereferencing it later - void markDead(Statement* stmt) { - live_statements_.erase(stmt); - } - - private: - Fusion* fusion_; - - //! Statements are marked dead by removing them from this set - std::unordered_set live_statements_; }; } // namespace From 4d0c6566154e25a54d3d371216c0d14297a354cc Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Thu, 29 Jun 2023 09:15:24 -0400 Subject: [PATCH 17/37] Update comment for DeadCodeRemover --- csrc/optimization/remove_empty.cpp | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/csrc/optimization/remove_empty.cpp b/csrc/optimization/remove_empty.cpp index 23360e426f9..3256b5d1936 100644 --- a/csrc/optimization/remove_empty.cpp +++ b/csrc/optimization/remove_empty.cpp @@ -25,15 +25,13 @@ namespace { //! computation or remove dead code. Derived classes should override handle() //! and make use of replaceTV(), markDeadAndMaybeRemove(), and allUsesDead(). //! -//! We use BackwardVisitor::traversal_exprs_ which tracks the active expressions -//! in this Fusion, to determine when it is safe to remove statements. We -//! augment this by creating an unordered_set called active_statements_, which -//! is initialized as the Exprs in traversal_exprs_ as well as their inputs and -//! outputs. Marking a Statement as dead removes it from active_statements_, and -//! replacing a Val inserts the Val and its definition, recursively. Since we -//! traverse backwards, and we handle all active Expr outputs, this ensures that -//! it is safe to do so whenever we remove an Expr (i.e. it will not result in -//! erasing definitions of active Expr outputs). +//! We use unordered_set called live_statements_, which is initialized as the +//! Exprs in traversal_exprs_ as well as their inputs and their outputs with +//! live uses. Marking a Statement as dead removes it from live_statements_, +//! and replacing a Val inserts the Val and its definition, recursively. Since +//! we traverse backwards, and we handle all active Expr outputs, this ensures +//! that it is safe to removing an Expr will not result in erasing definitions +//! of active Expr outputs. class DeadCodeRemover : BackwardVisitor { public: DeadCodeRemover(Fusion* fusion) : BackwardVisitor(false), fusion_(fusion) {} From d1532c29ae2647c18e03e80a961fda638e86e5bb Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Thu, 29 Jun 2023 09:17:12 -0400 Subject: [PATCH 18/37] Comment update --- csrc/optimization/remove_empty.cpp | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/csrc/optimization/remove_empty.cpp b/csrc/optimization/remove_empty.cpp index 3256b5d1936..baaf6ec5e04 100644 --- a/csrc/optimization/remove_empty.cpp +++ b/csrc/optimization/remove_empty.cpp @@ -52,10 +52,11 @@ class DeadCodeRemover : BackwardVisitor { } // We do not traverse all outputs of all Exprs, since this requires that all - // paths lead to fusion_->outputs(). Instead, here now mark all Vals dead - // which do not have any live uses. After this, it is safe to check whether - // all outputs of an Expr are actually dead, which helps us determine when - // it is safe to remove a multi-output definition of a dead TV. + // paths lead to fusion_->outputs(). Instead, here we mark any Vals dead + // that do not have any live uses. After this, it is safe to check whether + // all outputs of an Expr are marked dead to determine if it should be dead, + // which helps us determine when it is safe to remove a multi-output + // definition of a dead TV. std::vector dead_expr_outputs; for (auto stmt : live_statements_) { if (stmt->isVal() && !stmt->asVal()->isFusionOutput() && From 4f83f3fdb5c4506a94e465d5cfa0641997d8c670 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Thu, 29 Jun 2023 09:31:20 -0400 Subject: [PATCH 19/37] Update comment --- csrc/optimization/remove_empty.cpp | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/csrc/optimization/remove_empty.cpp b/csrc/optimization/remove_empty.cpp index baaf6ec5e04..d45de792639 100644 --- a/csrc/optimization/remove_empty.cpp +++ b/csrc/optimization/remove_empty.cpp @@ -22,8 +22,9 @@ namespace { //! This is a generic traversal class that is used to modify a Fusion graph by //! replacing TensorViews so that their definitions can be altered to simplify -//! computation or remove dead code. Derived classes should override handle() -//! and make use of replaceTV(), markDeadAndMaybeRemove(), and allUsesDead(). +//! computation or remove dead code. This differs from OptOutMutator, which is +//! built for mutating TensorViews in a graph, and does not easily handle +//! modifying TensorView definitions and Expr Fusion inputs during traversal. //! //! We use unordered_set called live_statements_, which is initialized as the //! Exprs in traversal_exprs_ as well as their inputs and their outputs with @@ -32,6 +33,11 @@ namespace { //! we traverse backwards, and we handle all active Expr outputs, this ensures //! that it is safe to removing an Expr will not result in erasing definitions //! of active Expr outputs. +//! +//! Derived classes should override handle() and make use of replaceTV(), +//! markDeadAndMaybeRemove(), and allUsesDead(). Note that if replacements are +//! made using replaceTV(old_tv, new_tv), then neither new_tv or any new +//! Statements produced in creating it will be traversed by this class. class DeadCodeRemover : BackwardVisitor { public: DeadCodeRemover(Fusion* fusion) : BackwardVisitor(false), fusion_(fusion) {} @@ -70,8 +76,8 @@ class DeadCodeRemover : BackwardVisitor { markDead(stmt); } - // Note that getExprs is also run in traverseTo. In the future, we - // could potentially refactor this so that derived classes from + // Note that StmtSort::getExprs() is also run in traverseTo. In the future, + // we could potentially refactor this so that derived classes from // BackwardVisitor can make use of that traversal instead of repeating it. traverseTo(fusion_, fusion_->outputs()); } From 940dd73e811a4cb9b523c57a5bec2e8a90fcbafe Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Thu, 29 Jun 2023 19:26:01 -0400 Subject: [PATCH 20/37] Use getStmts in DeadCodeRemover::run --- csrc/optimization/remove_empty.cpp | 33 +++--------------------------- 1 file changed, 3 insertions(+), 30 deletions(-) diff --git a/csrc/optimization/remove_empty.cpp b/csrc/optimization/remove_empty.cpp index d45de792639..a5fa8a8e38e 100644 --- a/csrc/optimization/remove_empty.cpp +++ b/csrc/optimization/remove_empty.cpp @@ -45,38 +45,11 @@ class DeadCodeRemover : BackwardVisitor { void run() { // First we build a set of all live Statements so that we can detect dead // branches. - auto exprs = StmtSort::getExprs(fusion_, fusion_->outputs()); - // Mark every Expr, as well as its inputs and outputs as live initially - for (auto expr : exprs) { - markLive(expr); - for (auto inp : expr->inputs()) { - markLive(inp); - } - for (auto outp : expr->outputs()) { - markLive(outp); - } - } - - // We do not traverse all outputs of all Exprs, since this requires that all - // paths lead to fusion_->outputs(). Instead, here we mark any Vals dead - // that do not have any live uses. After this, it is safe to check whether - // all outputs of an Expr are marked dead to determine if it should be dead, - // which helps us determine when it is safe to remove a multi-output - // definition of a dead TV. - std::vector dead_expr_outputs; - for (auto stmt : live_statements_) { - if (stmt->isVal() && !stmt->asVal()->isFusionOutput() && - allUsesDead(stmt->asVal())) { - // We should not erase from live_statements_ while traversing it, so we - // save the dead expr outputs for later removal instead. - dead_expr_outputs.push_back(stmt); - } - } - for (auto stmt : dead_expr_outputs) { - markDead(stmt); + for (auto stmt : StmtSort::getStmts(fusion_, fusion_->outputs())) { + markLive(stmt); } - // Note that StmtSort::getExprs() is also run in traverseTo. In the future, + // Note that StmtSort::getStmts() is also run in traverseTo. In the future, // we could potentially refactor this so that derived classes from // BackwardVisitor can make use of that traversal instead of repeating it. traverseTo(fusion_, fusion_->outputs()); From 5349160019318995f9b5457bc7dd491a08f8028c Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Thu, 29 Jun 2023 19:40:14 -0400 Subject: [PATCH 21/37] Simplify PadOp handling --- csrc/optimization/remove_empty.cpp | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/csrc/optimization/remove_empty.cpp b/csrc/optimization/remove_empty.cpp index a5fa8a8e38e..08e94c1de9f 100644 --- a/csrc/optimization/remove_empty.cpp +++ b/csrc/optimization/remove_empty.cpp @@ -444,14 +444,9 @@ class EmptyTensorRemover : DeadCodeRemover { auto in_rfactor = TensorDomain::noReductions(in->getMaybeRFactorDomain()); if (!emptyAxes(in_rfactor).empty()) { auto out = pop->out()->as(); - auto out_rfactor = - TensorDomain::noReductions(out->getMaybeRFactorDomain()); - std::vector shape; - shape.reserve(out_rfactor.size()); - for (auto id : out_rfactor) { - shape.push_back(id->extent()); - } - auto new_tv = full(shape, pop->value(), out->getDataType().value()); + auto shape = noReductionShape(out); + auto dtype = out->getDataType().value(); + auto new_tv = full(shape, pop->value(), dtype); replaceTV(out, new_tv); } } From 267011cac47532387b6b2f51fe8686cafe45497a Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Thu, 29 Jun 2023 20:15:19 -0400 Subject: [PATCH 22/37] Handle MmaOp --- csrc/optimization/remove_empty.cpp | 18 +++++++++++ test/test_optimization_pass.cpp | 51 ++++++++++++++++++++++++++++++ 2 files changed, 69 insertions(+) diff --git a/csrc/optimization/remove_empty.cpp b/csrc/optimization/remove_empty.cpp index 08e94c1de9f..9e4b92ab39f 100644 --- a/csrc/optimization/remove_empty.cpp +++ b/csrc/optimization/remove_empty.cpp @@ -226,6 +226,8 @@ bool isTVEmpty(TensorView* tv) { //! earlier in the backward traversal under condition 1. //! 4. If the empty Tensorview is the input to a PadOp (which is not input to //! a CatOp) then we replace the pad with `full(pad_value)`. +//! 5. If empty TensorViews are input to an MmaOp and they are empty in +//! contracted axes, we replace with `full({m, n}, zeroVal())`. //! class EmptyTensorRemover : DeadCodeRemover { public: @@ -450,6 +452,22 @@ class EmptyTensorRemover : DeadCodeRemover { replaceTV(out, new_tv); } } + + //! We handle MmaOp just as if it were written as a sum ReductionOp. + void handle(MmaOp* mop) final { + auto A = mop->inA()->as(); + auto A_rfactor = TensorDomain::noReductions(A->getMaybeRFactorDomain()); + // We only need to check empty axes in A. If any reduced axes are empty + // here, they will be empty in B also. If any non-reduced axes are empty, + // the output will also be empty, and this expression will already be dead. + if (!emptyAxes(A_rfactor).empty()) { + auto out = mop->out()->as(); + auto shape = noReductionShape(out); + auto dtype = out->getDataType().value(); + auto new_tv = full(shape, fusion()->zeroVal(dtype), dtype); + replaceTV(out, new_tv); + } + } }; } // namespace diff --git a/test/test_optimization_pass.cpp b/test/test_optimization_pass.cpp index ff2a664b19c..d04e887df4f 100644 --- a/test/test_optimization_pass.cpp +++ b/test/test_optimization_pass.cpp @@ -578,4 +578,55 @@ TEST_F(NVFuserTest, FusionRemoveEmptyPad_CUDA) { __FILE__); } +// Test that we replace empty tensors in matmuls properly +TEST_F(NVFuserTest, FusionRemoveEmptyMatmul_CUDA) { + std::unique_ptr fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr.get(); + FusionGuard fg(fusion_ptr.get()); + + // [M, K] + auto tv0 = makeConcreteTensor({16, 0}, DataType::Half); + // [K, N] + auto tv1 = makeConcreteTensor({0, 8}, DataType::Half); + fusion.addInput(tv0); + fusion.addInput(tv1); + + // [M, N, K] + auto tv0b = broadcast(tv0, {false, true, false}); + // [M, K, N] + auto tv1b = broadcast(tv1, {true, false, false}); + // [M, N, K] + auto tv1t = transpose(tv1b, 1, 2); + + auto tv2 = fusedMultiplySum(tv0b, tv1t, {2}); + fusion.addOutput(tv2); + + auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); + at::Tensor at0 = at::randn({16, 0}, options); + at::Tensor at1 = at::randn({0, 8}, options); + std::vector aten_inputs = {at0, at1}; + + auto args = KernelArgumentHolder::createKernelArgumentHolder(aten_inputs); + FusionKernelRuntime runtime(std::move(fusion_ptr), args); + + auto preseg_fusion = runtime.fusionSegments()->completeFusion(); + EXPECT_EQ(preseg_fusion->outputs().size(), 1); + + EXPECT_NE(preseg_fusion->outputs()[0]->definition(), nullptr); + auto rewritten_def = preseg_fusion->outputs()[0]->definition(); + EXPECT_TRUE(rewritten_def->isA()); + EXPECT_EQ(rewritten_def->as()->getFillValue()->evaluateDouble(), 0.0); + + runtime.compileFusionParallel(args); + auto outputs = runtime.runWithInputs(args); + + testValidate( + preseg_fusion, + outputs, + aten_inputs, + {at::zeros({16, 8}, options)}, + __LINE__, + __FILE__); +} + } // namespace nvfuser::optimization From d84f7fec8f8152d285ff27b824de07af4e8050b3 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Thu, 29 Jun 2023 20:19:08 -0400 Subject: [PATCH 23/37] Use public inheritance in EmptyTensorRemover So that we no longer need `using DeadCodeRemover::run` --- csrc/optimization/remove_empty.cpp | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/csrc/optimization/remove_empty.cpp b/csrc/optimization/remove_empty.cpp index 9e4b92ab39f..34f050d83ea 100644 --- a/csrc/optimization/remove_empty.cpp +++ b/csrc/optimization/remove_empty.cpp @@ -229,12 +229,10 @@ bool isTVEmpty(TensorView* tv) { //! 5. If empty TensorViews are input to an MmaOp and they are empty in //! contracted axes, we replace with `full({m, n}, zeroVal())`. //! -class EmptyTensorRemover : DeadCodeRemover { +class EmptyTensorRemover : public DeadCodeRemover { public: EmptyTensorRemover(Fusion* fusion) : DeadCodeRemover(fusion) {} - using DeadCodeRemover::run; - protected: using DeadCodeRemover::handle; From ea2374c16cf692de30e961b5044259d671426ead Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Fri, 30 Jun 2023 09:04:26 -0400 Subject: [PATCH 24/37] Create allOutputsDead method --- csrc/optimization/remove_empty.cpp | 26 ++++++++++---------------- 1 file changed, 10 insertions(+), 16 deletions(-) diff --git a/csrc/optimization/remove_empty.cpp b/csrc/optimization/remove_empty.cpp index 34f050d83ea..ef9ec0c63b4 100644 --- a/csrc/optimization/remove_empty.cpp +++ b/csrc/optimization/remove_empty.cpp @@ -72,13 +72,7 @@ class DeadCodeRemover : BackwardVisitor { } void handle(Expr* expr) final { - bool all_outputs_dead = true; - for (auto outp : expr->outputs()) { - if (isLive(outp)) { - all_outputs_dead = false; - } - } - if (all_outputs_dead) { + if (allOutputsDead(expr)) { markDead(expr); fusion_->removeExpr(expr); } else { @@ -109,15 +103,7 @@ class DeadCodeRemover : BackwardVisitor { void markDeadAndMaybeRemove(Val* val) { markDead(val); if (val->definition()) { - // When all outputs of def are marked dead, mark the def dead - bool all_outputs_dead = true; - for (auto outp : val->definition()->outputs()) { - if (isLive(outp)) { - all_outputs_dead = false; - break; - } - } - if (all_outputs_dead) { + if (allOutputsDead(val->definition())) { // If all other outputs are dead, it's safe to remove the definition as // well as all its outputs markDead(val->definition()); @@ -141,6 +127,14 @@ class DeadCodeRemover : BackwardVisitor { return live_statements_.find(stmt) != live_statements_.end(); } + //! Check whether all outputs of an expression have been marked dead + bool allOutputsDead(Expr* expr) { + return std::all_of( + expr->outputs().begin(), expr->outputs().end(), [&](Val* outp) { + return !isLive(outp); + }); + } + //! Check whether all uses have been marked dead bool allUsesDead(Val* val) { for (const auto use : val->uses()) { From d7d43fc254d09650f91d18a84bdef2504bdd14fc Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Fri, 30 Jun 2023 09:07:46 -0400 Subject: [PATCH 25/37] Convert empty intermediate warning to comment --- csrc/optimization/remove_empty.cpp | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/csrc/optimization/remove_empty.cpp b/csrc/optimization/remove_empty.cpp index ef9ec0c63b4..702289ea065 100644 --- a/csrc/optimization/remove_empty.cpp +++ b/csrc/optimization/remove_empty.cpp @@ -261,14 +261,10 @@ class EmptyTensorRemover : public DeadCodeRemover { // uses are dead, so remove them and skip processing them and their // definition. markDeadAndMaybeRemove(tv); - } else if (isTVEmpty(tv)) { - // Note that if there empty intermediate tensors with uses that do not - // lead to outputs, this check might fail. - TORCH_WARN_ONCE( - "Found unexpected empty intermediate TensorView ", - tv->toString(), - ". This TensorView has un-removed uses that might not be used in this Fusion."); } + // Note: we should not encounter isTVEmpty(tv)==true at this point if we + // have properly set all empty extents in non Fusion-input TensorViews to + // constant zeros. } //! Gets a vector of extents for noReduction(tv->getMaybeRFactorDomain()) From 1afed11da688bced23475ebf3a3d9b62719e9879 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Fri, 30 Jun 2023 09:14:50 -0400 Subject: [PATCH 26/37] Revert "Convert empty intermediate warning to comment" This reverts commit d7d43fc254d09650f91d18a84bdef2504bdd14fc. --- csrc/optimization/remove_empty.cpp | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/csrc/optimization/remove_empty.cpp b/csrc/optimization/remove_empty.cpp index 702289ea065..ef9ec0c63b4 100644 --- a/csrc/optimization/remove_empty.cpp +++ b/csrc/optimization/remove_empty.cpp @@ -261,10 +261,14 @@ class EmptyTensorRemover : public DeadCodeRemover { // uses are dead, so remove them and skip processing them and their // definition. markDeadAndMaybeRemove(tv); + } else if (isTVEmpty(tv)) { + // Note that if there empty intermediate tensors with uses that do not + // lead to outputs, this check might fail. + TORCH_WARN_ONCE( + "Found unexpected empty intermediate TensorView ", + tv->toString(), + ". This TensorView has un-removed uses that might not be used in this Fusion."); } - // Note: we should not encounter isTVEmpty(tv)==true at this point if we - // have properly set all empty extents in non Fusion-input TensorViews to - // constant zeros. } //! Gets a vector of extents for noReduction(tv->getMaybeRFactorDomain()) From b5c46ef7837a79e61679db88ba32c9dcba6a053a Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Fri, 30 Jun 2023 09:16:02 -0400 Subject: [PATCH 27/37] Convert intermediate empty check to assertion --- csrc/optimization/remove_empty.cpp | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/csrc/optimization/remove_empty.cpp b/csrc/optimization/remove_empty.cpp index ef9ec0c63b4..d2ffe3cecb5 100644 --- a/csrc/optimization/remove_empty.cpp +++ b/csrc/optimization/remove_empty.cpp @@ -261,14 +261,11 @@ class EmptyTensorRemover : public DeadCodeRemover { // uses are dead, so remove them and skip processing them and their // definition. markDeadAndMaybeRemove(tv); - } else if (isTVEmpty(tv)) { - // Note that if there empty intermediate tensors with uses that do not - // lead to outputs, this check might fail. - TORCH_WARN_ONCE( - "Found unexpected empty intermediate TensorView ", - tv->toString(), - ". This TensorView has un-removed uses that might not be used in this Fusion."); } + TORCH_INTERNAL_ASSERT( + !isTVEmpty(tv), + "Found unexpected empty intermediate TensorView ", + tv->toString()); } //! Gets a vector of extents for noReduction(tv->getMaybeRFactorDomain()) From 6d62c13b675a80903682791cb14310378612db7a Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Fri, 30 Jun 2023 13:44:48 -0400 Subject: [PATCH 28/37] Move assert to else block. This is the version I tested. I just missed it in the previous commit... --- csrc/optimization/remove_empty.cpp | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/csrc/optimization/remove_empty.cpp b/csrc/optimization/remove_empty.cpp index d2ffe3cecb5..dfa4afb6c24 100644 --- a/csrc/optimization/remove_empty.cpp +++ b/csrc/optimization/remove_empty.cpp @@ -261,11 +261,12 @@ class EmptyTensorRemover : public DeadCodeRemover { // uses are dead, so remove them and skip processing them and their // definition. markDeadAndMaybeRemove(tv); + } else { + TORCH_INTERNAL_ASSERT( + !isTVEmpty(tv), + "Found unexpected empty intermediate TensorView ", + tv->toString()); } - TORCH_INTERNAL_ASSERT( - !isTVEmpty(tv), - "Found unexpected empty intermediate TensorView ", - tv->toString()); } //! Gets a vector of extents for noReduction(tv->getMaybeRFactorDomain()) From acdfa4c4deb86182a5af978b19add54b4ae29693 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Fri, 30 Jun 2023 13:45:33 -0400 Subject: [PATCH 29/37] Test empty-output reduction -> reduction over empty --- test/test_optimization_pass.cpp | 39 +++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/test/test_optimization_pass.cpp b/test/test_optimization_pass.cpp index d04e887df4f..158ac4dbd84 100644 --- a/test/test_optimization_pass.cpp +++ b/test/test_optimization_pass.cpp @@ -436,6 +436,45 @@ TEST_F(NVFuserTest, FusionRemoveEmptyReduction_CUDA) { __FILE__); } +// In this test, a reduction over a non-empty axis occurs first, followed by a +// reduction over the remaining empty axis. The output is actually not empty, +// even though the first reduction results in an empty tensor. +TEST_F(NVFuserTest, FusionRemoveEmptyReductionWithNonReduction_CUDA) { + std::unique_ptr fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr.get(); + FusionGuard fg(fusion_ptr.get()); + // Concrete tensor with zero for one extent, so that we can prove the output + // is empty + auto tv0 = makeConcreteTensor({0, 3, 2}); + fusion.addInput(tv0); + auto tv1 = sum(tv0, {1}); + auto tv2 = sum(tv1, {0}); + fusion.addOutput(tv2); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor at0 = at::randn({0, 3, 2}, options); + std::vector aten_inputs = {at0}; + + auto args = KernelArgumentHolder::createKernelArgumentHolder(aten_inputs); + FusionKernelRuntime runtime(std::move(fusion_ptr), args); + + auto preseg_fusion = runtime.fusionSegments()->completeFusion(); + EXPECT_EQ(preseg_fusion->outputs().size(), 1); + EXPECT_NE(preseg_fusion->outputs()[0]->definition(), nullptr); + EXPECT_TRUE(preseg_fusion->outputs()[0]->definition()->isA()); + + runtime.compileFusionParallel(args); + auto outputs = runtime.runWithInputs(args); + + testValidate( + preseg_fusion, + outputs, + aten_inputs, + {at::sum(at::sum(at0, 1), 0)}, + __LINE__, + __FILE__); +} + // Test that we replace empty Welford with full TEST_F(NVFuserTest, FusionRemoveEmptyWelford_CUDA) { std::unique_ptr fusion_ptr = std::make_unique(); From 3591bface59e9095a39348fcc9e17b084497c4e7 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Fri, 30 Jun 2023 13:58:05 -0400 Subject: [PATCH 30/37] Move DeadCodeRemover to iter_visitor.cpp --- csrc/iter_visitor.cpp | 124 ++++++++++++++++++++++ csrc/iter_visitor.h | 84 +++++++++++++++ csrc/optimization/remove_empty.cpp | 160 +---------------------------- 3 files changed, 209 insertions(+), 159 deletions(-) diff --git a/csrc/iter_visitor.cpp b/csrc/iter_visitor.cpp index 33a5f0f4b05..dbbd6578ff1 100644 --- a/csrc/iter_visitor.cpp +++ b/csrc/iter_visitor.cpp @@ -933,4 +933,128 @@ std::vector InputsOf::outputs( return io.ordered_inputs; } +/* DEAD CODE REMOVER */ +void DeadCodeRemover::run() { + // First we build a set of all live Statements so that we can detect dead + // branches. + for (auto stmt : StmtSort::getStmts(fusion_, fusion_->outputs())) { + markLive(stmt); + } + + // Note that StmtSort::getStmts() is also run in traverseTo. In the future, + // we could potentially refactor this so that derived classes from + // BackwardVisitor can make use of that traversal instead of repeating it. + traverseTo(fusion_, fusion_->outputs()); +} + +Fusion* DeadCodeRemover::fusion() const { + return fusion_; +} + +void DeadCodeRemover::handle(Statement* stmt) { + if (!isLive(stmt)) { + // We check whether stmt is dead before we dereference it, since it may + // have been removed from the Fusion. + return; + } + BackwardVisitor::handle(stmt); +} + +void DeadCodeRemover::handle(Expr* expr) { + if (allOutputsDead(expr)) { + markDead(expr); + fusion_->removeExpr(expr); + } else { + BackwardVisitor::handle(expr); + } +} + +void DeadCodeRemover::replaceTV(TensorView*& old_tv, TensorView* new_tv) { + if (old_tv->isFusionOutput()) { + fusion_->replaceOutput(old_tv, new_tv); + } + for (auto use : old_tv->uses()) { + ir_utils::replaceValInExpr(use, old_tv, new_tv); + } + markLiveRecursive(new_tv); + markDeadAndMaybeRemove(old_tv); +} + +//! Guard removeVal with calls to markDead so that we always detect removed +//! Vals and Exprs before derefencing them. +//! +//! Note that we only remove val if all of its definition's outputs are marked +//! dead. +void DeadCodeRemover::markDeadAndMaybeRemove(Val* val) { + markDead(val); + if (val->definition()) { + if (allOutputsDead(val->definition())) { + // If all other outputs are dead, it's safe to remove the definition as + // well as all its outputs + markDead(val->definition()); + const auto outputs = val->definition()->outputs(); + for (auto outp : outputs) { + fusion_->removeVal(outp); + } + } + } else { + TORCH_INTERNAL_ASSERT( + !val->isFusionInput(), "Refusing to remove Fusion input"); + fusion_->removeVal(val); + } +} + +//! Find whether a statement is live (i.e. not dead code) +//! +//! Inside BackwardVisitor::traverseTo, traversal_exprs_ is built, containing +//! all active expressions in the original graph. +bool DeadCodeRemover::isLive(Statement* stmt) { + return live_statements_.find(stmt) != live_statements_.end(); +} + +//! Check whether all outputs of an expression have been marked dead +bool DeadCodeRemover::allOutputsDead(Expr* expr) { + return std::all_of( + expr->outputs().begin(), expr->outputs().end(), [&](Val* outp) { + return !isLive(outp); + }); +} + +//! Check whether all uses have been marked dead +bool DeadCodeRemover::allUsesDead(Val* val) { + for (const auto use : val->uses()) { + if (isLive(use)) { + return false; + } + } + return true; +} + +//! Mark a single Statement as being alive +void DeadCodeRemover::markLive(Statement* stmt) { + live_statements_.insert(stmt); +} + +//! Ensure that a Statement and its upstream Statements are alive. If it is an +//! Expr, ensure all its inputs are alive. If it's a Val with a definition, +//! recursive to the definition. +void DeadCodeRemover::markLiveRecursive(Statement* stmt) { + if (isLive(stmt)) { + return; + } + markLive(stmt); + if (stmt->isVal() && stmt->asVal()->definition()) { + markLiveRecursive(stmt); + } else { + for (const auto inp : stmt->asExpr()->inputs()) { + markLiveRecursive(inp); + } + } +} + +//! Mark a Statement* as dead so that we avoid dereferencing it later +void DeadCodeRemover::markDead(Statement* stmt) { + live_statements_.erase(stmt); +} + } // namespace nvfuser diff --git a/csrc/iter_visitor.h b/csrc/iter_visitor.h index b042766750a..ca992cefb55 100644 --- a/csrc/iter_visitor.h +++ b/csrc/iter_visitor.h @@ -372,4 +372,88 @@ class TORCH_CUDA_CU_API InputsOf : public IterVisitor { const std::vector& outputs_); }; +//! This is a generic traversal class that is used to modify a Fusion graph by +//! replacing TensorViews so that their definitions can be altered to simplify +//! computation or remove dead code. This differs from OptOutMutator, which is +//! built for mutating TensorViews in a graph, and does not easily handle +//! modifying TensorView definitions and Expr Fusion inputs during traversal. +//! +//! We use unordered_set called live_statements_, which is initialized as the +//! Exprs in traversal_exprs_ as well as their inputs and their outputs with +//! live uses. Marking a Statement as dead removes it from live_statements_, +//! and replacing a Val inserts the Val and its definition, recursively. Since +//! we traverse backwards, and we handle all active Expr outputs, this ensures +//! that it is safe to removing an Expr will not result in erasing definitions +//! of active Expr outputs. +//! +//! Derived classes should override handle() and make use of replaceTV(), +//! markDeadAndMaybeRemove(), and allUsesDead(). Note that if replacements are +//! made using replaceTV(old_tv, new_tv), then neither new_tv or any new +//! Statements produced in creating it will be traversed by this class. +class DeadCodeRemover : BackwardVisitor { + public: + DeadCodeRemover(Fusion* fusion) : BackwardVisitor(false), fusion_(fusion) {} + + DeadCodeRemover(const DeadCodeRemover& other) = default; + DeadCodeRemover& operator=(const DeadCodeRemover& other) = default; + + DeadCodeRemover(DeadCodeRemover&& other) = default; + DeadCodeRemover& operator=(DeadCodeRemover&& other) = default; + + //! Instead of traverseTo, run() is the entry point for this class, and we + //! always traverse from outputs backward to their inputs. + void run(); + + Fusion* fusion() const; + + protected: + using BackwardVisitor::handle; + + void handle(Statement* stmt) override; + void handle(Expr* expr) override; + + //! Replaces a TensorView in outputs, and in all uses. If old_tv is a Fusion + //! input, we do not replace it. After replacement, unless it is a Fusion + //! input, we remove it from the fusion and set the original pointer to zero + //! (hence why old_tv is passed by reference). + void replaceTV(TensorView*& old_tv, TensorView* new_tv); + + //! Guard removeVal with calls to markDead so that we always detect removed + //! Vals and Exprs before derefencing them. + //! + //! Note that we only remove val if all of its definition's outputs are marked + //! dead. + void markDeadAndMaybeRemove(Val* val); + + //! Find whether a statement is live (i.e. not dead code) + //! + //! Inside BackwardVisitor::traverseTo, traversal_exprs_ is built, containing + //! all active expressions in the original graph. + bool isLive(Statement* stmt); + + //! Check whether all outputs of an expression have been marked dead + bool allOutputsDead(Expr* expr); + + //! Check whether all uses have been marked dead + bool allUsesDead(Val* val); + + //! Mark a single Statement as being alive + void markLive(Statement* stmt); + + //! Ensure that a Statement and its upstream Statements are alive. If it is an + //! Expr, ensure all its inputs are alive. If it's a Val with a definition, + //! recursive to the definition. + void markLiveRecursive(Statement* stmt); + + //! Mark a Statement* as dead so that we avoid dereferencing it later + void markDead(Statement* stmt); + + private: + //! The Fusion associated with live_statements_ + Fusion* fusion_; + + //! Statements are marked dead by removing them from this set + std::unordered_set live_statements_; +}; + } // namespace nvfuser diff --git a/csrc/optimization/remove_empty.cpp b/csrc/optimization/remove_empty.cpp index dfa4afb6c24..117f87f86d3 100644 --- a/csrc/optimization/remove_empty.cpp +++ b/csrc/optimization/remove_empty.cpp @@ -8,6 +8,7 @@ #include #include +#include #include #include @@ -20,165 +21,6 @@ namespace nvfuser::optimization { namespace { -//! This is a generic traversal class that is used to modify a Fusion graph by -//! replacing TensorViews so that their definitions can be altered to simplify -//! computation or remove dead code. This differs from OptOutMutator, which is -//! built for mutating TensorViews in a graph, and does not easily handle -//! modifying TensorView definitions and Expr Fusion inputs during traversal. -//! -//! We use unordered_set called live_statements_, which is initialized as the -//! Exprs in traversal_exprs_ as well as their inputs and their outputs with -//! live uses. Marking a Statement as dead removes it from live_statements_, -//! and replacing a Val inserts the Val and its definition, recursively. Since -//! we traverse backwards, and we handle all active Expr outputs, this ensures -//! that it is safe to removing an Expr will not result in erasing definitions -//! of active Expr outputs. -//! -//! Derived classes should override handle() and make use of replaceTV(), -//! markDeadAndMaybeRemove(), and allUsesDead(). Note that if replacements are -//! made using replaceTV(old_tv, new_tv), then neither new_tv or any new -//! Statements produced in creating it will be traversed by this class. -class DeadCodeRemover : BackwardVisitor { - public: - DeadCodeRemover(Fusion* fusion) : BackwardVisitor(false), fusion_(fusion) {} - - void run() { - // First we build a set of all live Statements so that we can detect dead - // branches. - for (auto stmt : StmtSort::getStmts(fusion_, fusion_->outputs())) { - markLive(stmt); - } - - // Note that StmtSort::getStmts() is also run in traverseTo. In the future, - // we could potentially refactor this so that derived classes from - // BackwardVisitor can make use of that traversal instead of repeating it. - traverseTo(fusion_, fusion_->outputs()); - } - - Fusion* fusion() const { - return fusion_; - } - - protected: - using BackwardVisitor::handle; - - void handle(Statement* stmt) final { - if (!isLive(stmt)) { - // We check whether stmt is dead before we dereference it, since it may - // have been removed from the Fusion. - return; - } - BackwardVisitor::handle(stmt); - } - - void handle(Expr* expr) final { - if (allOutputsDead(expr)) { - markDead(expr); - fusion_->removeExpr(expr); - } else { - BackwardVisitor::handle(expr); - } - } - - //! Replaces a TensorView in outputs, and in all uses. If old_tv is a Fusion - //! input, we do not replace it. After replacement, unless it is a Fusion - //! input, we remove it from the fusion and set the original pointer to zero - //! (hence why old_tv is passed by reference). - void replaceTV(TensorView*& old_tv, TensorView* new_tv) { - if (old_tv->isFusionOutput()) { - fusion_->replaceOutput(old_tv, new_tv); - } - for (auto use : old_tv->uses()) { - ir_utils::replaceValInExpr(use, old_tv, new_tv); - } - markLiveRecursive(new_tv); - markDeadAndMaybeRemove(old_tv); - } - - //! Guard removeVal with calls to markDead so that we always detect removed - //! Vals and Exprs before derefencing them. - //! - //! Note that we only remove val if all of its definition's outputs are marked - //! dead. - void markDeadAndMaybeRemove(Val* val) { - markDead(val); - if (val->definition()) { - if (allOutputsDead(val->definition())) { - // If all other outputs are dead, it's safe to remove the definition as - // well as all its outputs - markDead(val->definition()); - const auto outputs = val->definition()->outputs(); - for (auto outp : outputs) { - fusion_->removeVal(outp); - } - } - } else { - TORCH_INTERNAL_ASSERT( - !val->isFusionInput(), "Refusing to remove Fusion input"); - fusion_->removeVal(val); - } - } - - //! Find whether a statement is live (i.e. not dead code) - //! - //! Inside BackwardVisitor::traverseTo, traversal_exprs_ is built, containing - //! all active expressions in the original graph. - bool isLive(Statement* stmt) { - return live_statements_.find(stmt) != live_statements_.end(); - } - - //! Check whether all outputs of an expression have been marked dead - bool allOutputsDead(Expr* expr) { - return std::all_of( - expr->outputs().begin(), expr->outputs().end(), [&](Val* outp) { - return !isLive(outp); - }); - } - - //! Check whether all uses have been marked dead - bool allUsesDead(Val* val) { - for (const auto use : val->uses()) { - if (isLive(use)) { - return false; - } - } - return true; - } - - //! Mark a single Statement as being alive - void markLive(Statement* stmt) { - live_statements_.insert(stmt); - } - - //! Ensure that a Statement and its upstream Statements are alive. If it is an - //! Expr, ensure all its inputs are alive. If it's a Val with a definition, - //! recursive to the definition. - void markLiveRecursive(Statement* stmt) { - if (isLive(stmt)) { - return; - } - markLive(stmt); - if (stmt->isVal() && stmt->asVal()->definition()) { - markLiveRecursive(stmt); - } else { - for (const auto inp : stmt->asExpr()->inputs()) { - markLiveRecursive(inp); - } - } - } - - //! Mark a Statement* as dead so that we avoid dereferencing it later - void markDead(Statement* stmt) { - live_statements_.erase(stmt); - } - - private: - Fusion* fusion_; - - //! Statements are marked dead by removing them from this set - std::unordered_set live_statements_; -}; - //! Get a vector of the integer positions of constant zero extent axes in the //! input domain. This will typically be used like //! `emptyAxes(TensorDomain::noReductions(tv->getMaybeRFactorDomain()))` From 53d51ed92ce0bea8b512c9b59fde5dac547dc4b6 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Mon, 3 Jul 2023 12:07:03 -0400 Subject: [PATCH 31/37] Fix constness and use all_of in allUsesDead. Check I/O in markDead --- csrc/iter_visitor.cpp | 26 +++++++++++++++++--------- csrc/iter_visitor.h | 6 +++--- 2 files changed, 20 insertions(+), 12 deletions(-) diff --git a/csrc/iter_visitor.cpp b/csrc/iter_visitor.cpp index dbbd6578ff1..085ced79713 100644 --- a/csrc/iter_visitor.cpp +++ b/csrc/iter_visitor.cpp @@ -1008,12 +1008,12 @@ void DeadCodeRemover::markDeadAndMaybeRemove(Val* val) { //! //! Inside BackwardVisitor::traverseTo, traversal_exprs_ is built, containing //! all active expressions in the original graph. -bool DeadCodeRemover::isLive(Statement* stmt) { +bool DeadCodeRemover::isLive(Statement* stmt) const { return live_statements_.find(stmt) != live_statements_.end(); } //! Check whether all outputs of an expression have been marked dead -bool DeadCodeRemover::allOutputsDead(Expr* expr) { +bool DeadCodeRemover::allOutputsDead(Expr* expr) const { return std::all_of( expr->outputs().begin(), expr->outputs().end(), [&](Val* outp) { return !isLive(outp); @@ -1021,13 +1021,10 @@ bool DeadCodeRemover::allOutputsDead(Expr* expr) { } //! Check whether all uses have been marked dead -bool DeadCodeRemover::allUsesDead(Val* val) { - for (const auto use : val->uses()) { - if (isLive(use)) { - return false; - } - } - return true; +bool DeadCodeRemover::allUsesDead(Val* val) const { + return std::all_of(val->uses().begin(), val->uses().end(), [&](Expr* use) { + return !isLive(use); + }); } //! Mark a single Statement as being alive @@ -1054,6 +1051,17 @@ void DeadCodeRemover::markLiveRecursive(Statement* stmt) { //! Mark a Statement* as dead so that we avoid dereferencing it later void DeadCodeRemover::markDead(Statement* stmt) { + if (stmt->isVal()) { + auto val = stmt->asVal(); + TORCH_INTERNAL_ASSERT( + !val->isFusionOutput(), + "Call to markDead on Fusion output is illegal: ", + val->toString()); + TORCH_INTERNAL_ASSERT( + !val->isFusionInput(), + "Call to markDead on Fusion input is illegal: ", + val->toString()); + } live_statements_.erase(stmt); } diff --git a/csrc/iter_visitor.h b/csrc/iter_visitor.h index ca992cefb55..dee5b548716 100644 --- a/csrc/iter_visitor.h +++ b/csrc/iter_visitor.h @@ -429,13 +429,13 @@ class DeadCodeRemover : BackwardVisitor { //! //! Inside BackwardVisitor::traverseTo, traversal_exprs_ is built, containing //! all active expressions in the original graph. - bool isLive(Statement* stmt); + bool isLive(Statement* stmt) const; //! Check whether all outputs of an expression have been marked dead - bool allOutputsDead(Expr* expr); + bool allOutputsDead(Expr* expr) const; //! Check whether all uses have been marked dead - bool allUsesDead(Val* val); + bool allUsesDead(Val* val) const; //! Mark a single Statement as being alive void markLive(Statement* stmt); From b4c237292d0e81690fbaf972e2847a7fb923461b Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Mon, 3 Jul 2023 15:02:18 -0400 Subject: [PATCH 32/37] Rework interface: only replaceVal& removeVal used in child classes --- csrc/iter_visitor.cpp | 109 +++++++++++++++-------------- csrc/iter_visitor.h | 99 ++++++++++++++++++-------- csrc/optimization/remove_empty.cpp | 23 +++--- 3 files changed, 137 insertions(+), 94 deletions(-) diff --git a/csrc/iter_visitor.cpp b/csrc/iter_visitor.cpp index 085ced79713..9aef7a90418 100644 --- a/csrc/iter_visitor.cpp +++ b/csrc/iter_visitor.cpp @@ -952,7 +952,7 @@ Fusion* DeadCodeRemover::fusion() const { } void DeadCodeRemover::handle(Statement* stmt) { - if (!isLive(stmt)) { + if (isDead(stmt)) { // We check whether stmt is dead before we dereference it, since it may // have been removed from the Fusion. return; @@ -961,80 +961,80 @@ void DeadCodeRemover::handle(Statement* stmt) { } void DeadCodeRemover::handle(Expr* expr) { - if (allOutputsDead(expr)) { - markDead(expr); - fusion_->removeExpr(expr); - } else { - BackwardVisitor::handle(expr); + if (maybeRemoveExpr(expr)) { + // maybeRemoveExp will remove expr from the Fusion if all its uses are + // marked dead. In that case, we should not continue handling it since the + // expr pointer is invalid. + return; } + BackwardVisitor::handle(expr); } -void DeadCodeRemover::replaceTV(TensorView*& old_tv, TensorView* new_tv) { - if (old_tv->isFusionOutput()) { - fusion_->replaceOutput(old_tv, new_tv); +bool DeadCodeRemover::replaceVal(Val* old_val, Val* new_val) { + if (old_val->isFusionOutput()) { + fusion_->replaceOutput(old_val, new_val); } - for (auto use : old_tv->uses()) { - ir_utils::replaceValInExpr(use, old_tv, new_tv); + for (auto use : old_val->uses()) { + ir_utils::replaceValInExpr(use, old_val, new_val); } - markLiveRecursive(new_tv); - markDeadAndMaybeRemove(old_tv); + return removeVal(old_val); } -//! Guard removeVal with calls to markDead so that we always detect removed -//! Vals and Exprs before derefencing them. -//! -//! Note that we only remove val if all of its definition's outputs are marked -//! dead. -void DeadCodeRemover::markDeadAndMaybeRemove(Val* val) { - markDead(val); +bool DeadCodeRemover::removeVal(Val* val) { + // Mark val dead even if we can't yet remove it due to its definition having + // some live outputs + if (!markDead(val)) { + // val is already marked dead + return false; + } + if (val->definition()) { - if (allOutputsDead(val->definition())) { - // If all other outputs are dead, it's safe to remove the definition as - // well as all its outputs - markDead(val->definition()); - const auto outputs = val->definition()->outputs(); - for (auto outp : outputs) { - fusion_->removeVal(outp); - } - } + // If val has a definition, it can only be removed by removing its + // definition + return maybeRemoveExpr(val->definition()); } else { TORCH_INTERNAL_ASSERT( !val->isFusionInput(), "Refusing to remove Fusion input"); + markDead(val); fusion_->removeVal(val); + return true; } } -//! Find whether a statement is live (i.e. not dead code) -//! -//! Inside BackwardVisitor::traverseTo, traversal_exprs_ is built, containing -//! all active expressions in the original graph. -bool DeadCodeRemover::isLive(Statement* stmt) const { - return live_statements_.find(stmt) != live_statements_.end(); +bool DeadCodeRemover::maybeRemoveExpr(Expr* expr) { + if (allOutputsDead(expr)) { + if (!markDead(expr)) { + // Expr was already marked dead, so don't try to remove it again + return false; + } + + const auto outputs = expr->outputs(); + for (auto outp : outputs) { + fusion_->removeVal(outp); + } + // Fusion::removeVal(val) calls removeExpr on the definition of val. That + // means expr will be removed while processing its first output. Here we + // check that it was properly removed. + TORCH_INTERNAL_ASSERT(!fusion_->inContainer(expr), "Failed to remove Expr"); + return true; + } else { + return false; + } } -//! Check whether all outputs of an expression have been marked dead bool DeadCodeRemover::allOutputsDead(Expr* expr) const { return std::all_of( expr->outputs().begin(), expr->outputs().end(), [&](Val* outp) { - return !isLive(outp); + return isDead(outp); }); } -//! Check whether all uses have been marked dead bool DeadCodeRemover::allUsesDead(Val* val) const { return std::all_of(val->uses().begin(), val->uses().end(), [&](Expr* use) { - return !isLive(use); + return isDead(use); }); } -//! Mark a single Statement as being alive -void DeadCodeRemover::markLive(Statement* stmt) { - live_statements_.insert(stmt); -} - -//! Ensure that a Statement and its upstream Statements are alive. If it is an -//! Expr, ensure all its inputs are alive. If it's a Val with a definition, -//! recursive to the definition. void DeadCodeRemover::markLiveRecursive(Statement* stmt) { if (isLive(stmt)) { return; @@ -1043,14 +1043,17 @@ void DeadCodeRemover::markLiveRecursive(Statement* stmt) { if (stmt->isVal() && stmt->asVal()->definition()) { markLiveRecursive(stmt); } else { - for (const auto inp : stmt->asExpr()->inputs()) { + auto expr = stmt->asExpr(); + for (const auto inp : expr->outputs()) { + markLive(inp); + } + for (const auto inp : expr->inputs()) { markLiveRecursive(inp); } } } -//! Mark a Statement* as dead so that we avoid dereferencing it later -void DeadCodeRemover::markDead(Statement* stmt) { +bool DeadCodeRemover::markDead(Statement* stmt) { if (stmt->isVal()) { auto val = stmt->asVal(); TORCH_INTERNAL_ASSERT( @@ -1061,8 +1064,12 @@ void DeadCodeRemover::markDead(Statement* stmt) { !val->isFusionInput(), "Call to markDead on Fusion input is illegal: ", val->toString()); + TORCH_INTERNAL_ASSERT( + allUsesDead(val), + "Attempted to remove Val with live uses: ", + val->toString()); } - live_statements_.erase(stmt); + return (bool)live_statements_.erase(stmt); } } // namespace nvfuser diff --git a/csrc/iter_visitor.h b/csrc/iter_visitor.h index dee5b548716..22cdf0334dc 100644 --- a/csrc/iter_visitor.h +++ b/csrc/iter_visitor.h @@ -373,23 +373,28 @@ class TORCH_CUDA_CU_API InputsOf : public IterVisitor { }; //! This is a generic traversal class that is used to modify a Fusion graph by -//! replacing TensorViews so that their definitions can be altered to simplify -//! computation or remove dead code. This differs from OptOutMutator, which is -//! built for mutating TensorViews in a graph, and does not easily handle -//! modifying TensorView definitions and Expr Fusion inputs during traversal. +//! replacing Vals to simplify computation or remove dead code. This differs +//! from OptOutMutator, which is built for mutating TensorViews in-place in a +//! graph by altering the associated IterDomains, and which does not easily +//! handle modifying TensorView definitions and Fusion outputs during traversal. //! -//! We use unordered_set called live_statements_, which is initialized as the +//! We use an unordered_set called live_statements_, which is initialized as the //! Exprs in traversal_exprs_ as well as their inputs and their outputs with //! live uses. Marking a Statement as dead removes it from live_statements_, //! and replacing a Val inserts the Val and its definition, recursively. Since //! we traverse backwards, and we handle all active Expr outputs, this ensures -//! that it is safe to removing an Expr will not result in erasing definitions -//! of active Expr outputs. +//! that removing an Expr will not result in erasing definitions of active Expr +//! outputs. //! -//! Derived classes should override handle() and make use of replaceTV(), -//! markDeadAndMaybeRemove(), and allUsesDead(). Note that if replacements are -//! made using replaceTV(old_tv, new_tv), then neither new_tv or any new -//! Statements produced in creating it will be traversed by this class. +//! Derived classes should override handle() for relevant Exprs and they should +//! make use of replaceVal() to change the definitions of Vals in the graph. +//! Note that if replacements are made using replaceVal(old_val, new_val), then +//! neither new_val nor any new Statements produced in creating it will be +//! traversed by this class. +//! +//! removeVal() may also be used in derived classes to explicitly mark tensors +//! as dead. Note that it is an error to call removeVal() on a Val that has live +//! uses, so this should be used carefully. class DeadCodeRemover : BackwardVisitor { public: DeadCodeRemover(Fusion* fusion) : BackwardVisitor(false), fusion_(fusion) {} @@ -412,24 +417,39 @@ class DeadCodeRemover : BackwardVisitor { void handle(Statement* stmt) override; void handle(Expr* expr) override; - //! Replaces a TensorView in outputs, and in all uses. If old_tv is a Fusion - //! input, we do not replace it. After replacement, unless it is a Fusion - //! input, we remove it from the fusion and set the original pointer to zero - //! (hence why old_tv is passed by reference). - void replaceTV(TensorView*& old_tv, TensorView* new_tv); - - //! Guard removeVal with calls to markDead so that we always detect removed - //! Vals and Exprs before derefencing them. + //! Replaces a Val in outputs, and in all uses. + //! + //! The argument old_val is always marked Dead by this method. If old_val is a + //! Fusion input, we do not replace it. If old_val's definition is non-null + //! and has other outputs which are not dead, we do not remove it. //! - //! Note that we only remove val if all of its definition's outputs are marked - //! dead. - void markDeadAndMaybeRemove(Val* val); + //! Returns whether old_val was removed from the Fusion. + bool replaceVal(Val* old_val, Val* new_val); - //! Find whether a statement is live (i.e. not dead code) + //! Remove a Val* from the Fusion, if possible. //! - //! Inside BackwardVisitor::traverseTo, traversal_exprs_ is built, containing - //! all active expressions in the original graph. - bool isLive(Statement* stmt) const; + //! It is an error to call this function on a Val with any live uses, or on + //! any Fusion input or output. + //! + //! The Val is always marked dead by this function. Additionally, it is + //! removed from the Fusion if possible. Removal is possible if the Val has no + //! definition, or if its definition can be removed by removeExpr(), meaning + //! the definition has no other live outputs. + //! + //! Returns whether the Val was removed from the Fusion. + bool removeVal(Val* val); + + //! Find whether a statement is not marked as live code. Note that if this + //! returns true, the pointer may be invalid. + inline bool isDead(Statement* stmt) const { + return live_statements_.find(stmt) == live_statements_.end(); + } + + //! Find whether a statement is marked as live code. Note that if this returns + //! false, the pointer may be invalid. + inline bool isLive(Statement* stmt) const { + return !isDead(stmt); + } //! Check whether all outputs of an expression have been marked dead bool allOutputsDead(Expr* expr) const; @@ -437,16 +457,33 @@ class DeadCodeRemover : BackwardVisitor { //! Check whether all uses have been marked dead bool allUsesDead(Val* val) const; - //! Mark a single Statement as being alive - void markLive(Statement* stmt); + private: + //! Removes an Expr* from the Fusion, if possible. + //! + //! The Expr will _only_ be marked dead and removed if all of its outputs are + //! already marked dead. In this case all the outputs will also be removed + //! from the Fusion. + //! + //! Returns whether the Expr was marked dead and removed from the Fusion. + bool maybeRemoveExpr(Expr* expr); + + //! Mark a single Statement as being alive. + inline void markLive(Statement* stmt) { + live_statements_.insert(stmt); + } //! Ensure that a Statement and its upstream Statements are alive. If it is an //! Expr, ensure all its inputs are alive. If it's a Val with a definition, - //! recursive to the definition. + //! recursive to the definition. Newly-created Statements default to being + //! dead, so this method is called when adding a Statement to the active path + //! of the Fusion inside replaceVal. void markLiveRecursive(Statement* stmt); - //! Mark a Statement* as dead so that we avoid dereferencing it later - void markDead(Statement* stmt); + //! Mark a single Statement as being dead. This does not remove stmt from the + //! Fusion. + //! + //! Returns true if the statement was previously live, and false otherwise. + bool markDead(Statement* stmt); private: //! The Fusion associated with live_statements_ diff --git a/csrc/optimization/remove_empty.cpp b/csrc/optimization/remove_empty.cpp index 117f87f86d3..d137e55d23c 100644 --- a/csrc/optimization/remove_empty.cpp +++ b/csrc/optimization/remove_empty.cpp @@ -96,13 +96,12 @@ class EmptyTensorRemover : public DeadCodeRemover { } auto dtype = tv->getDataType().value(); auto new_tv = full(shape, fusion()->zeroVal(dtype), dtype); - replaceTV(tv, new_tv); + replaceVal(tv, new_tv); } } else if (allUsesDead(tv)) { // TensorViews that are not Fusion inputs or outputs and which have no - // uses are dead, so remove them and skip processing them and their - // definition. - markDeadAndMaybeRemove(tv); + // uses are dead, so remove them. + removeVal(tv); } else { TORCH_INTERNAL_ASSERT( !isTVEmpty(tv), @@ -152,7 +151,7 @@ class EmptyTensorRemover : public DeadCodeRemover { auto new_tv = full(noReductionShape(out), rop->init(), out->getDataType().value()); - replaceTV(out, new_tv); + replaceVal(out, new_tv); } //! A WelfordOp is similar to a ReductionOp, but has three outputs: avg, var, @@ -186,24 +185,24 @@ class EmptyTensorRemover : public DeadCodeRemover { // Since WelfordOp has multiple outputs, we need to check whether each is // live before replacing it, to avoid replacing a dead output with a live - // one. Note that replaceTV will mark the replacement as live automatically. + // one. auto shape = noReductionShape(avg); if (isLive(avg)) { auto nan = IrBuilder::create( std::numeric_limits::quiet_NaN(), avg->getDataType().value()); auto nan_tensor = full(shape, nan, avg->getDataType().value()); - replaceTV(avg, nan_tensor); + replaceVal(avg, nan_tensor); } if (isLive(var_sum)) { auto new_var_sum = full( shape, fusion()->zeroVal(var_sum->getDataType().value()), var_sum->getDataType().value()); - replaceTV(var_sum, new_var_sum); + replaceVal(var_sum, new_var_sum); } if (isLive(N)) { auto new_N = full(shape, fusion()->zeroVal(), N->getDataType().value()); - replaceTV(N, new_N); + replaceVal(N, new_N); } } @@ -265,7 +264,7 @@ class EmptyTensorRemover : public DeadCodeRemover { auto old_tv = cop->outputs()[0]->as(); // NOTE: cat() will translate to set() if non_empty_inputs.size() == 1 auto new_tv = cat(non_empty_inputs, dim); - replaceTV(old_tv, new_tv); + replaceVal(old_tv, new_tv); } } @@ -281,7 +280,7 @@ class EmptyTensorRemover : public DeadCodeRemover { auto shape = noReductionShape(out); auto dtype = out->getDataType().value(); auto new_tv = full(shape, pop->value(), dtype); - replaceTV(out, new_tv); + replaceVal(out, new_tv); } } @@ -297,7 +296,7 @@ class EmptyTensorRemover : public DeadCodeRemover { auto shape = noReductionShape(out); auto dtype = out->getDataType().value(); auto new_tv = full(shape, fusion()->zeroVal(dtype), dtype); - replaceTV(out, new_tv); + replaceVal(out, new_tv); } } }; From 4373c1b9d5da41618c2140188297f246ab2224de Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Mon, 3 Jul 2023 15:12:42 -0400 Subject: [PATCH 33/37] Return vector in emptyAxes --- csrc/optimization/remove_empty.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/csrc/optimization/remove_empty.cpp b/csrc/optimization/remove_empty.cpp index d137e55d23c..2bf5984f23b 100644 --- a/csrc/optimization/remove_empty.cpp +++ b/csrc/optimization/remove_empty.cpp @@ -24,12 +24,12 @@ namespace { //! Get a vector of the integer positions of constant zero extent axes in the //! input domain. This will typically be used like //! `emptyAxes(TensorDomain::noReductions(tv->getMaybeRFactorDomain()))` -std::vector emptyAxes(std::vector domain) { - std::vector empty_axes; +std::vector emptyAxes(const std::vector& domain) { + std::vector empty_axes; for (auto ax : c10::irange(domain.size())) { auto id = domain.at(ax); if (id->extent()->isConst() && id->extent()->evaluateInt() == 0) { - empty_axes.push_back(ax); + empty_axes.push_back((int64_t)ax); } } return empty_axes; From 0a1d21f1190e3c61c6a78a1350db64c2ca3e3181 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Mon, 3 Jul 2023 16:02:01 -0400 Subject: [PATCH 34/37] Defer removal of Vals and Exprs until after traversal --- csrc/iter_visitor.cpp | 31 +++++++++++++++++++++++++------ csrc/iter_visitor.h | 33 ++++++++++++++++++++++++++++----- 2 files changed, 53 insertions(+), 11 deletions(-) diff --git a/csrc/iter_visitor.cpp b/csrc/iter_visitor.cpp index 9aef7a90418..d5675af7a22 100644 --- a/csrc/iter_visitor.cpp +++ b/csrc/iter_visitor.cpp @@ -945,6 +945,11 @@ void DeadCodeRemover::run() { // we could potentially refactor this so that derived classes from // BackwardVisitor can make use of that traversal instead of repeating it. traverseTo(fusion_, fusion_->outputs()); + + // We do not remove Statements from the Fusion while traversing, to avoid + // dereferencing invalid pointers. Instead, we wait until this point to do the + // removal. + doRemoval(); } Fusion* DeadCodeRemover::fusion() const { @@ -996,7 +1001,7 @@ bool DeadCodeRemover::removeVal(Val* val) { TORCH_INTERNAL_ASSERT( !val->isFusionInput(), "Refusing to remove Fusion input"); markDead(val); - fusion_->removeVal(val); + registerRemoval(val); return true; } } @@ -1010,12 +1015,9 @@ bool DeadCodeRemover::maybeRemoveExpr(Expr* expr) { const auto outputs = expr->outputs(); for (auto outp : outputs) { - fusion_->removeVal(outp); + registerRemoval(outp); } - // Fusion::removeVal(val) calls removeExpr on the definition of val. That - // means expr will be removed while processing its first output. Here we - // check that it was properly removed. - TORCH_INTERNAL_ASSERT(!fusion_->inContainer(expr), "Failed to remove Expr"); + registerRemoval(expr); return true; } else { return false; @@ -1072,4 +1074,21 @@ bool DeadCodeRemover::markDead(Statement* stmt) { return (bool)live_statements_.erase(stmt); } +void DeadCodeRemover::doRemoval() const { + for (auto val : vals_to_remove_) { + fusion_->removeVal(val); + } + for (auto expr : exprs_to_remove_) { + // Fusion::removeVal(val) actually removes val->definition() from the + // Fusion, and sets all its outputs' definitions to nullptr. So we should + // not need to manually remove Exprs here. Instead, we just assert that they + // have already been removed if they were registered for removal. + TORCH_INTERNAL_ASSERT( + !fusion_->inContainer(expr), + "Expression ", + expr->toString(), + " was marked for removal but has not yet been removed."); + } +} + } // namespace nvfuser diff --git a/csrc/iter_visitor.h b/csrc/iter_visitor.h index 22cdf0334dc..5a44c9e3f46 100644 --- a/csrc/iter_visitor.h +++ b/csrc/iter_visitor.h @@ -423,7 +423,7 @@ class DeadCodeRemover : BackwardVisitor { //! Fusion input, we do not replace it. If old_val's definition is non-null //! and has other outputs which are not dead, we do not remove it. //! - //! Returns whether old_val was removed from the Fusion. + //! Returns whether old_val was registered for removal from the Fusion. bool replaceVal(Val* old_val, Val* new_val); //! Remove a Val* from the Fusion, if possible. @@ -432,11 +432,11 @@ class DeadCodeRemover : BackwardVisitor { //! any Fusion input or output. //! //! The Val is always marked dead by this function. Additionally, it is - //! removed from the Fusion if possible. Removal is possible if the Val has no - //! definition, or if its definition can be removed by removeExpr(), meaning - //! the definition has no other live outputs. + //! registered for removal from the Fusion if possible. Removal is possible if + //! the Val has no definition, or if its definition can be removed by + //! removeExpr(), meaning the definition has no other live outputs. //! - //! Returns whether the Val was removed from the Fusion. + //! Returns whether the Val was registered for removal from the Fusion. bool removeVal(Val* val); //! Find whether a statement is not marked as live code. Note that if this @@ -485,12 +485,35 @@ class DeadCodeRemover : BackwardVisitor { //! Returns true if the statement was previously live, and false otherwise. bool markDead(Statement* stmt); + //! Register a Val for later removal. + inline void registerRemoval(Val* val) { + vals_to_remove_.push_back(val); + } + + //! Register an Expr for later removal. + //! + //! Note that if any of its outputs are removed, expr will be removed even if + //! it is not marked for removal, and all its outputs will have their + //! definitions set to nullptr. + inline void registerRemoval(Expr* expr) { + exprs_to_remove_.push_back(expr); + } + + //! Actually remove Statements that were previously registered. For safety, + //! this should be run after traversing the graph. + void doRemoval() const; + private: //! The Fusion associated with live_statements_ Fusion* fusion_; //! Statements are marked dead by removing them from this set std::unordered_set live_statements_; + + //! Statements that will be removed. We remove Vals before Exprs, so we track + //! them separately here. + std::vector vals_to_remove_; + std::vector exprs_to_remove_; }; } // namespace nvfuser From fa7aed10e1c9482108d5ac6aac4b5cbf12e03f63 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Thu, 6 Jul 2023 12:27:52 -0400 Subject: [PATCH 35/37] Defer modifying Fusion at all until after traversal. This renames doRemoval() to modifyFusion() and places ir_utils::replaceValInExpr() there. It performs these replacements in the same order as it receives them. Note that some other cleanup was also done in this commit. --- csrc/iter_visitor.cpp | 114 ++++++++++++++--------------- csrc/iter_visitor.h | 82 ++++++++++----------- csrc/optimization/remove_empty.cpp | 46 ++++++------ 3 files changed, 116 insertions(+), 126 deletions(-) diff --git a/csrc/iter_visitor.cpp b/csrc/iter_visitor.cpp index d5675af7a22..6ac2b385d42 100644 --- a/csrc/iter_visitor.cpp +++ b/csrc/iter_visitor.cpp @@ -944,16 +944,12 @@ void DeadCodeRemover::run() { // Note that StmtSort::getStmts() is also run in traverseTo. In the future, // we could potentially refactor this so that derived classes from // BackwardVisitor can make use of that traversal instead of repeating it. - traverseTo(fusion_, fusion_->outputs()); + traverseTo(fusion_, fusion_->outputs(), false); // We do not remove Statements from the Fusion while traversing, to avoid // dereferencing invalid pointers. Instead, we wait until this point to do the // removal. - doRemoval(); -} - -Fusion* DeadCodeRemover::fusion() const { - return fusion_; + modifyFusion(); } void DeadCodeRemover::handle(Statement* stmt) { @@ -975,35 +971,47 @@ void DeadCodeRemover::handle(Expr* expr) { BackwardVisitor::handle(expr); } -bool DeadCodeRemover::replaceVal(Val* old_val, Val* new_val) { - if (old_val->isFusionOutput()) { - fusion_->replaceOutput(old_val, new_val); - } - for (auto use : old_val->uses()) { - ir_utils::replaceValInExpr(use, old_val, new_val); +void DeadCodeRemover::handle(TensorView* tv) { + if (!tv->isFusionOutput() && !tv->isFusionInput() && allUsesDead(tv)) { + if (!markDead(tv)) { + return; + } + + if (tv->definition()) { + // If tv has a definition, it can only be removed by removing its + // definition + maybeRemoveExpr(tv->definition()); + } else { + registerRemoval(tv); + } + return; } - return removeVal(old_val); + BackwardVisitor::handle(tv); } -bool DeadCodeRemover::removeVal(Val* val) { - // Mark val dead even if we can't yet remove it due to its definition having - // some live outputs - if (!markDead(val)) { - // val is already marked dead +bool DeadCodeRemover::replaceVal(Val* old_val, Val* new_val) { + registerReplacement(old_val, new_val); + if (old_val->isFusionInput()) { + // Skip removing Fusion inputs return false; } - - if (val->definition()) { - // If val has a definition, it can only be removed by removing its - // definition - return maybeRemoveExpr(val->definition()); - } else { - TORCH_INTERNAL_ASSERT( - !val->isFusionInput(), "Refusing to remove Fusion input"); - markDead(val); - registerRemoval(val); - return true; - } + TORCH_CHECK( + old_val->definition(), + "Found non-input ", + old_val->toString(), + " with no definition."); + + // Mark old_val dead even if we can't yet remove it due to its definition + // having some live outputs + TORCH_CHECK( + markDead(old_val), + "Attempted to replace ", + old_val->toString(), + " which was previously marked dead."); + + // If old_val has a definition, it can only be removed by removing its + // definition + return maybeRemoveExpr(old_val->definition()); } bool DeadCodeRemover::maybeRemoveExpr(Expr* expr) { @@ -1024,17 +1032,12 @@ bool DeadCodeRemover::maybeRemoveExpr(Expr* expr) { } } -bool DeadCodeRemover::allOutputsDead(Expr* expr) const { - return std::all_of( - expr->outputs().begin(), expr->outputs().end(), [&](Val* outp) { - return isDead(outp); - }); -} - -bool DeadCodeRemover::allUsesDead(Val* val) const { - return std::all_of(val->uses().begin(), val->uses().end(), [&](Expr* use) { - return isDead(use); - }); +void DeadCodeRemover::registerRemoval(Val* val) { + TORCH_INTERNAL_ASSERT( + !val->isFusionInput(), + "Call to registerRemoval on Fusion input is illegal: ", + val->toString()); + vals_to_remove_.push_back(val); } void DeadCodeRemover::markLiveRecursive(Statement* stmt) { @@ -1056,33 +1059,26 @@ void DeadCodeRemover::markLiveRecursive(Statement* stmt) { } bool DeadCodeRemover::markDead(Statement* stmt) { - if (stmt->isVal()) { - auto val = stmt->asVal(); - TORCH_INTERNAL_ASSERT( - !val->isFusionOutput(), - "Call to markDead on Fusion output is illegal: ", - val->toString()); - TORCH_INTERNAL_ASSERT( - !val->isFusionInput(), - "Call to markDead on Fusion input is illegal: ", - val->toString()); - TORCH_INTERNAL_ASSERT( - allUsesDead(val), - "Attempted to remove Val with live uses: ", - val->toString()); - } return (bool)live_statements_.erase(stmt); } -void DeadCodeRemover::doRemoval() const { +void DeadCodeRemover::modifyFusion() const { + for (auto [old_val, new_val] : vals_to_replace_) { + if (old_val->isFusionOutput()) { + fusion_->replaceOutput(old_val, new_val); + } + for (auto use : old_val->uses()) { + ir_utils::replaceValInExpr(use, old_val, new_val); + } + } for (auto val : vals_to_remove_) { fusion_->removeVal(val); } for (auto expr : exprs_to_remove_) { // Fusion::removeVal(val) actually removes val->definition() from the // Fusion, and sets all its outputs' definitions to nullptr. So we should - // not need to manually remove Exprs here. Instead, we just assert that they - // have already been removed if they were registered for removal. + // not need to manually remove Exprs here. Instead, we just assert that + // they have already been removed if they were registered for removal. TORCH_INTERNAL_ASSERT( !fusion_->inContainer(expr), "Expression ", diff --git a/csrc/iter_visitor.h b/csrc/iter_visitor.h index 5a44c9e3f46..cab743439c9 100644 --- a/csrc/iter_visitor.h +++ b/csrc/iter_visitor.h @@ -10,6 +10,7 @@ #include #include +#include #include #include @@ -19,9 +20,6 @@ namespace nvfuser { class Fusion; -class Statement; -class Expr; -class Val; /* * IterVisitor starts from leaf nodes, fusion outputs, or the provided values. @@ -378,23 +376,12 @@ class TORCH_CUDA_CU_API InputsOf : public IterVisitor { //! graph by altering the associated IterDomains, and which does not easily //! handle modifying TensorView definitions and Fusion outputs during traversal. //! -//! We use an unordered_set called live_statements_, which is initialized as the -//! Exprs in traversal_exprs_ as well as their inputs and their outputs with -//! live uses. Marking a Statement as dead removes it from live_statements_, -//! and replacing a Val inserts the Val and its definition, recursively. Since -//! we traverse backwards, and we handle all active Expr outputs, this ensures -//! that removing an Expr will not result in erasing definitions of active Expr -//! outputs. -//! //! Derived classes should override handle() for relevant Exprs and they should //! make use of replaceVal() to change the definitions of Vals in the graph. //! Note that if replacements are made using replaceVal(old_val, new_val), then //! neither new_val nor any new Statements produced in creating it will be -//! traversed by this class. -//! -//! removeVal() may also be used in derived classes to explicitly mark tensors -//! as dead. Note that it is an error to call removeVal() on a Val that has live -//! uses, so this should be used carefully. +//! traversed by this class. Also note that any Vals or Exprs that are +//! previously marked dead will not be processed by handle(). class DeadCodeRemover : BackwardVisitor { public: DeadCodeRemover(Fusion* fusion) : BackwardVisitor(false), fusion_(fusion) {} @@ -409,7 +396,9 @@ class DeadCodeRemover : BackwardVisitor { //! always traverse from outputs backward to their inputs. void run(); - Fusion* fusion() const; + inline Fusion* fusion() const { + return fusion_; + } protected: using BackwardVisitor::handle; @@ -417,45 +406,45 @@ class DeadCodeRemover : BackwardVisitor { void handle(Statement* stmt) override; void handle(Expr* expr) override; + //! We implement this in order to remove dangling TensorViews whose uses are + //! all dead. Note that we do not remove other ValTypes like Scalars since + //! they might still be used as attributes or members of other objects, which + //! is not reflected by Val::uses(). + void handle(TensorView* tv) override; + //! Replaces a Val in outputs, and in all uses. //! //! The argument old_val is always marked Dead by this method. If old_val is a //! Fusion input, we do not replace it. If old_val's definition is non-null - //! and has other outputs which are not dead, we do not remove it. + //! and has other outputs which are not dead, we do not remove old_val. //! //! Returns whether old_val was registered for removal from the Fusion. bool replaceVal(Val* old_val, Val* new_val); - //! Remove a Val* from the Fusion, if possible. - //! - //! It is an error to call this function on a Val with any live uses, or on - //! any Fusion input or output. - //! - //! The Val is always marked dead by this function. Additionally, it is - //! registered for removal from the Fusion if possible. Removal is possible if - //! the Val has no definition, or if its definition can be removed by - //! removeExpr(), meaning the definition has no other live outputs. - //! - //! Returns whether the Val was registered for removal from the Fusion. - bool removeVal(Val* val); - - //! Find whether a statement is not marked as live code. Note that if this - //! returns true, the pointer may be invalid. + //! Find whether a statement is not marked as live code. inline bool isDead(Statement* stmt) const { return live_statements_.find(stmt) == live_statements_.end(); } - //! Find whether a statement is marked as live code. Note that if this returns - //! false, the pointer may be invalid. + //! Find whether a statement is marked as live code. inline bool isLive(Statement* stmt) const { return !isDead(stmt); } //! Check whether all outputs of an expression have been marked dead - bool allOutputsDead(Expr* expr) const; + inline bool allOutputsDead(Expr* expr) const { + return std::all_of( + expr->outputs().begin(), expr->outputs().end(), [&](Val* outp) { + return isDead(outp); + }); + } //! Check whether all uses have been marked dead - bool allUsesDead(Val* val) const; + inline bool allUsesDead(Val* val) const { + return std::all_of(val->uses().begin(), val->uses().end(), [&](Expr* use) { + return isDead(use); + }); + } private: //! Removes an Expr* from the Fusion, if possible. @@ -480,14 +469,17 @@ class DeadCodeRemover : BackwardVisitor { void markLiveRecursive(Statement* stmt); //! Mark a single Statement as being dead. This does not remove stmt from the - //! Fusion. + //! Fusion. It is an error to call this on a Fusion output. //! //! Returns true if the statement was previously live, and false otherwise. bool markDead(Statement* stmt); //! Register a Val for later removal. - inline void registerRemoval(Val* val) { - vals_to_remove_.push_back(val); + void registerRemoval(Val* val); + + //! Register a Val for later replacement + inline void registerReplacement(Val* old_val, Val* new_val) { + vals_to_replace_.emplace_back(old_val, new_val); } //! Register an Expr for later removal. @@ -499,9 +491,10 @@ class DeadCodeRemover : BackwardVisitor { exprs_to_remove_.push_back(expr); } - //! Actually remove Statements that were previously registered. For safety, - //! this should be run after traversing the graph. - void doRemoval() const; + //! All modifications to the Fusion are registered during traversal then + //! later they are committed by this method. For safety, this should only be + //! run after traversing the graph. + void modifyFusion() const; private: //! The Fusion associated with live_statements_ @@ -510,6 +503,9 @@ class DeadCodeRemover : BackwardVisitor { //! Statements are marked dead by removing them from this set std::unordered_set live_statements_; + //! Vals to be replaced in outputs and with replaceValInExpr in all uses. + std::vector> vals_to_replace_; + //! Statements that will be removed. We remove Vals before Exprs, so we track //! them separately here. std::vector vals_to_remove_; diff --git a/csrc/optimization/remove_empty.cpp b/csrc/optimization/remove_empty.cpp index 2bf5984f23b..9a94eecb3bd 100644 --- a/csrc/optimization/remove_empty.cpp +++ b/csrc/optimization/remove_empty.cpp @@ -76,37 +76,35 @@ class EmptyTensorRemover : public DeadCodeRemover { //! it with full(). For non-outputs that are not inputs, we simply check that //! the tensor is not provably empty. void handle(TensorView* tv) final { - if (tv->isFusionInput()) { - // Skip inputs since they do not have a definition to redefine + DeadCodeRemover::handle(tv); + if (isDead(tv)) { + // DeadCodeRemover::handle might have set this dead, in which case we + // don't need to process it any further return; } - if (tv->isFusionOutput()) { - const auto rfactor = - TensorDomain::noReductions(tv->getMaybeRFactorDomain()); - const auto empty_axes = emptyAxes(rfactor); - if (!empty_axes.empty()) { - std::vector shape(rfactor.size()); - std::transform( - rfactor.begin(), rfactor.end(), shape.begin(), [](IterDomain* id) { - return id->extent(); - }); - for (auto ax : empty_axes) { - shape[ax] = fusion()->zeroVal(); - } - auto dtype = tv->getDataType().value(); - auto new_tv = full(shape, fusion()->zeroVal(dtype), dtype); - replaceVal(tv, new_tv); + if (isTVEmpty(tv)) { + if (tv->isFusionInput()) { + TORCH_INTERNAL_ASSERT( + allUsesDead(tv), + "Empty Fusion input ", + tv, + " should not have any live uses."); + // Empty inputs do not have a definition to redefine + return; } - } else if (allUsesDead(tv)) { - // TensorViews that are not Fusion inputs or outputs and which have no - // uses are dead, so remove them. - removeVal(tv); - } else { + + // Any non-input that we traverse to should be the input to an expression, + // or a Fusion output. If it's the input to an expression, we should have + // replaced that expression by handling the appropriate Expr subclass. TORCH_INTERNAL_ASSERT( - !isTVEmpty(tv), + tv->isFusionOutput(), "Found unexpected empty intermediate TensorView ", tv->toString()); + auto shape = noReductionShape(tv); + auto dtype = tv->getDataType().value(); + auto new_tv = full(shape, fusion()->zeroVal(dtype), dtype); + replaceVal(tv, new_tv); } } From 7ed25be8511a36ab67e1ea943a84368274f178e5 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Thu, 6 Jul 2023 15:31:12 -0400 Subject: [PATCH 36/37] Rename replaceVal as registerReplacement. --- csrc/iter_visitor.cpp | 5 +++-- csrc/iter_visitor.h | 25 ++++++++++++------------- csrc/optimization/remove_empty.cpp | 16 ++++++++-------- 3 files changed, 23 insertions(+), 23 deletions(-) diff --git a/csrc/iter_visitor.cpp b/csrc/iter_visitor.cpp index 6ac2b385d42..24e6640b2f5 100644 --- a/csrc/iter_visitor.cpp +++ b/csrc/iter_visitor.cpp @@ -989,8 +989,9 @@ void DeadCodeRemover::handle(TensorView* tv) { BackwardVisitor::handle(tv); } -bool DeadCodeRemover::replaceVal(Val* old_val, Val* new_val) { - registerReplacement(old_val, new_val); +bool DeadCodeRemover::registerReplacement(Val* old_val, Val* new_val) { + vals_to_replace_.emplace_back(old_val, new_val); + if (old_val->isFusionInput()) { // Skip removing Fusion inputs return false; diff --git a/csrc/iter_visitor.h b/csrc/iter_visitor.h index cab743439c9..fcee5b64665 100644 --- a/csrc/iter_visitor.h +++ b/csrc/iter_visitor.h @@ -377,11 +377,11 @@ class TORCH_CUDA_CU_API InputsOf : public IterVisitor { //! handle modifying TensorView definitions and Fusion outputs during traversal. //! //! Derived classes should override handle() for relevant Exprs and they should -//! make use of replaceVal() to change the definitions of Vals in the graph. -//! Note that if replacements are made using replaceVal(old_val, new_val), then -//! neither new_val nor any new Statements produced in creating it will be -//! traversed by this class. Also note that any Vals or Exprs that are -//! previously marked dead will not be processed by handle(). +//! make use of registerReplacement() to change the definitions of Vals in the +//! graph. Note that if replacements are made using registerReplacement(old_val, +//! new_val), then neither new_val nor any new Statements produced in creating +//! it will be traversed by this class. Also note that any Vals or Exprs that +//! are previously marked dead will not be processed by handle(). class DeadCodeRemover : BackwardVisitor { public: DeadCodeRemover(Fusion* fusion) : BackwardVisitor(false), fusion_(fusion) {} @@ -412,14 +412,18 @@ class DeadCodeRemover : BackwardVisitor { //! is not reflected by Val::uses(). void handle(TensorView* tv) override; - //! Replaces a Val in outputs, and in all uses. + //! Registers a Val for replacement in outputs and in all its uses. + //! + //! Note that replacement does not occur immediately, but will be done after + //! the traversal is completed. This is so that any Val* and Expr* pointers + //! may be safely dereferenced during traversal. //! //! The argument old_val is always marked Dead by this method. If old_val is a //! Fusion input, we do not replace it. If old_val's definition is non-null //! and has other outputs which are not dead, we do not remove old_val. //! //! Returns whether old_val was registered for removal from the Fusion. - bool replaceVal(Val* old_val, Val* new_val); + bool registerReplacement(Val* old_val, Val* new_val); //! Find whether a statement is not marked as live code. inline bool isDead(Statement* stmt) const { @@ -465,7 +469,7 @@ class DeadCodeRemover : BackwardVisitor { //! Expr, ensure all its inputs are alive. If it's a Val with a definition, //! recursive to the definition. Newly-created Statements default to being //! dead, so this method is called when adding a Statement to the active path - //! of the Fusion inside replaceVal. + //! of the Fusion inside registerReplacement. void markLiveRecursive(Statement* stmt); //! Mark a single Statement as being dead. This does not remove stmt from the @@ -477,11 +481,6 @@ class DeadCodeRemover : BackwardVisitor { //! Register a Val for later removal. void registerRemoval(Val* val); - //! Register a Val for later replacement - inline void registerReplacement(Val* old_val, Val* new_val) { - vals_to_replace_.emplace_back(old_val, new_val); - } - //! Register an Expr for later removal. //! //! Note that if any of its outputs are removed, expr will be removed even if diff --git a/csrc/optimization/remove_empty.cpp b/csrc/optimization/remove_empty.cpp index 9a94eecb3bd..b133422d8ad 100644 --- a/csrc/optimization/remove_empty.cpp +++ b/csrc/optimization/remove_empty.cpp @@ -104,7 +104,7 @@ class EmptyTensorRemover : public DeadCodeRemover { auto shape = noReductionShape(tv); auto dtype = tv->getDataType().value(); auto new_tv = full(shape, fusion()->zeroVal(dtype), dtype); - replaceVal(tv, new_tv); + registerReplacement(tv, new_tv); } } @@ -149,7 +149,7 @@ class EmptyTensorRemover : public DeadCodeRemover { auto new_tv = full(noReductionShape(out), rop->init(), out->getDataType().value()); - replaceVal(out, new_tv); + registerReplacement(out, new_tv); } //! A WelfordOp is similar to a ReductionOp, but has three outputs: avg, var, @@ -189,18 +189,18 @@ class EmptyTensorRemover : public DeadCodeRemover { auto nan = IrBuilder::create( std::numeric_limits::quiet_NaN(), avg->getDataType().value()); auto nan_tensor = full(shape, nan, avg->getDataType().value()); - replaceVal(avg, nan_tensor); + registerReplacement(avg, nan_tensor); } if (isLive(var_sum)) { auto new_var_sum = full( shape, fusion()->zeroVal(var_sum->getDataType().value()), var_sum->getDataType().value()); - replaceVal(var_sum, new_var_sum); + registerReplacement(var_sum, new_var_sum); } if (isLive(N)) { auto new_N = full(shape, fusion()->zeroVal(), N->getDataType().value()); - replaceVal(N, new_N); + registerReplacement(N, new_N); } } @@ -262,7 +262,7 @@ class EmptyTensorRemover : public DeadCodeRemover { auto old_tv = cop->outputs()[0]->as(); // NOTE: cat() will translate to set() if non_empty_inputs.size() == 1 auto new_tv = cat(non_empty_inputs, dim); - replaceVal(old_tv, new_tv); + registerReplacement(old_tv, new_tv); } } @@ -278,7 +278,7 @@ class EmptyTensorRemover : public DeadCodeRemover { auto shape = noReductionShape(out); auto dtype = out->getDataType().value(); auto new_tv = full(shape, pop->value(), dtype); - replaceVal(out, new_tv); + registerReplacement(out, new_tv); } } @@ -294,7 +294,7 @@ class EmptyTensorRemover : public DeadCodeRemover { auto shape = noReductionShape(out); auto dtype = out->getDataType().value(); auto new_tv = full(shape, fusion()->zeroVal(dtype), dtype); - replaceVal(out, new_tv); + registerReplacement(out, new_tv); } } }; From 96099220393b962e5c0e7190598e38297f2a2dd4 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Fri, 7 Jul 2023 08:20:46 -0400 Subject: [PATCH 37/37] Return bool from run() and modifyFusion() This just lets us return whether or not any modification was performed, which can give us a termination criterion in optimization. --- csrc/iter_visitor.cpp | 10 +++++++--- csrc/iter_visitor.h | 8 ++++++-- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/csrc/iter_visitor.cpp b/csrc/iter_visitor.cpp index 24e6640b2f5..e71c86fda6e 100644 --- a/csrc/iter_visitor.cpp +++ b/csrc/iter_visitor.cpp @@ -934,7 +934,7 @@ std::vector InputsOf::outputs( } /* DEAD CODE REMOVER */ -void DeadCodeRemover::run() { +bool DeadCodeRemover::run() { // First we build a set of all live Statements so that we can detect dead // branches. for (auto stmt : StmtSort::getStmts(fusion_, fusion_->outputs())) { @@ -949,7 +949,7 @@ void DeadCodeRemover::run() { // We do not remove Statements from the Fusion while traversing, to avoid // dereferencing invalid pointers. Instead, we wait until this point to do the // removal. - modifyFusion(); + return modifyFusion(); } void DeadCodeRemover::handle(Statement* stmt) { @@ -1063,7 +1063,8 @@ bool DeadCodeRemover::markDead(Statement* stmt) { return (bool)live_statements_.erase(stmt); } -void DeadCodeRemover::modifyFusion() const { +bool DeadCodeRemover::modifyFusion() const { + bool modified_fusion = false; for (auto [old_val, new_val] : vals_to_replace_) { if (old_val->isFusionOutput()) { fusion_->replaceOutput(old_val, new_val); @@ -1071,9 +1072,11 @@ void DeadCodeRemover::modifyFusion() const { for (auto use : old_val->uses()) { ir_utils::replaceValInExpr(use, old_val, new_val); } + modified_fusion = true; } for (auto val : vals_to_remove_) { fusion_->removeVal(val); + modified_fusion = true; } for (auto expr : exprs_to_remove_) { // Fusion::removeVal(val) actually removes val->definition() from the @@ -1086,6 +1089,7 @@ void DeadCodeRemover::modifyFusion() const { expr->toString(), " was marked for removal but has not yet been removed."); } + return modified_fusion; } } // namespace nvfuser diff --git a/csrc/iter_visitor.h b/csrc/iter_visitor.h index fcee5b64665..425463cbc36 100644 --- a/csrc/iter_visitor.h +++ b/csrc/iter_visitor.h @@ -394,7 +394,9 @@ class DeadCodeRemover : BackwardVisitor { //! Instead of traverseTo, run() is the entry point for this class, and we //! always traverse from outputs backward to their inputs. - void run(); + //! + //! Returns a bool indicating whether the Fusion was modified or not. + bool run(); inline Fusion* fusion() const { return fusion_; @@ -493,7 +495,9 @@ class DeadCodeRemover : BackwardVisitor { //! All modifications to the Fusion are registered during traversal then //! later they are committed by this method. For safety, this should only be //! run after traversing the graph. - void modifyFusion() const; + //! + //! Returns a bool indicating whether any modifications were performed. + bool modifyFusion() const; private: //! The Fusion associated with live_statements_