diff --git a/CMakeLists.txt b/CMakeLists.txt index 4a747af5170..366262da237 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -185,6 +185,7 @@ list(APPEND NVFUSER_SRCS ${NVFUSER_SRCS_DIR}/optimization/add_axioms.cpp ${NVFUSER_SRCS_DIR}/optimization/consecutive_cast.cpp ${NVFUSER_SRCS_DIR}/optimization/pre_segmenter.cpp + ${NVFUSER_SRCS_DIR}/optimization/remove_empty.cpp ) if(BUILD_PYTHON) diff --git a/csrc/iter_visitor.cpp b/csrc/iter_visitor.cpp index 33a5f0f4b05..e71c86fda6e 100644 --- a/csrc/iter_visitor.cpp +++ b/csrc/iter_visitor.cpp @@ -933,4 +933,163 @@ std::vector InputsOf::outputs( return io.ordered_inputs; } +/* DEAD CODE REMOVER */ +bool DeadCodeRemover::run() { + // First we build a set of all live Statements so that we can detect dead + // branches. + for (auto stmt : StmtSort::getStmts(fusion_, fusion_->outputs())) { + markLive(stmt); + } + + // Note that StmtSort::getStmts() is also run in traverseTo. In the future, + // we could potentially refactor this so that derived classes from + // BackwardVisitor can make use of that traversal instead of repeating it. + traverseTo(fusion_, fusion_->outputs(), false); + + // We do not remove Statements from the Fusion while traversing, to avoid + // dereferencing invalid pointers. Instead, we wait until this point to do the + // removal. + return modifyFusion(); +} + +void DeadCodeRemover::handle(Statement* stmt) { + if (isDead(stmt)) { + // We check whether stmt is dead before we dereference it, since it may + // have been removed from the Fusion. + return; + } + BackwardVisitor::handle(stmt); +} + +void DeadCodeRemover::handle(Expr* expr) { + if (maybeRemoveExpr(expr)) { + // maybeRemoveExp will remove expr from the Fusion if all its uses are + // marked dead. In that case, we should not continue handling it since the + // expr pointer is invalid. + return; + } + BackwardVisitor::handle(expr); +} + +void DeadCodeRemover::handle(TensorView* tv) { + if (!tv->isFusionOutput() && !tv->isFusionInput() && allUsesDead(tv)) { + if (!markDead(tv)) { + return; + } + + if (tv->definition()) { + // If tv has a definition, it can only be removed by removing its + // definition + maybeRemoveExpr(tv->definition()); + } else { + registerRemoval(tv); + } + return; + } + BackwardVisitor::handle(tv); +} + +bool DeadCodeRemover::registerReplacement(Val* old_val, Val* new_val) { + vals_to_replace_.emplace_back(old_val, new_val); + + if (old_val->isFusionInput()) { + // Skip removing Fusion inputs + return false; + } + TORCH_CHECK( + old_val->definition(), + "Found non-input ", + old_val->toString(), + " with no definition."); + + // Mark old_val dead even if we can't yet remove it due to its definition + // having some live outputs + TORCH_CHECK( + markDead(old_val), + "Attempted to replace ", + old_val->toString(), + " which was previously marked dead."); + + // If old_val has a definition, it can only be removed by removing its + // definition + return maybeRemoveExpr(old_val->definition()); +} + +bool DeadCodeRemover::maybeRemoveExpr(Expr* expr) { + if (allOutputsDead(expr)) { + if (!markDead(expr)) { + // Expr was already marked dead, so don't try to remove it again + return false; + } + + const auto outputs = expr->outputs(); + for (auto outp : outputs) { + registerRemoval(outp); + } + registerRemoval(expr); + return true; + } else { + return false; + } +} + +void DeadCodeRemover::registerRemoval(Val* val) { + TORCH_INTERNAL_ASSERT( + !val->isFusionInput(), + "Call to registerRemoval on Fusion input is illegal: ", + val->toString()); + vals_to_remove_.push_back(val); +} + +void DeadCodeRemover::markLiveRecursive(Statement* stmt) { + if (isLive(stmt)) { + return; + } + markLive(stmt); + if (stmt->isVal() && stmt->asVal()->definition()) { + markLiveRecursive(stmt); + } else { + auto expr = stmt->asExpr(); + for (const auto inp : expr->outputs()) { + markLive(inp); + } + for (const auto inp : expr->inputs()) { + markLiveRecursive(inp); + } + } +} + +bool DeadCodeRemover::markDead(Statement* stmt) { + return (bool)live_statements_.erase(stmt); +} + +bool DeadCodeRemover::modifyFusion() const { + bool modified_fusion = false; + for (auto [old_val, new_val] : vals_to_replace_) { + if (old_val->isFusionOutput()) { + fusion_->replaceOutput(old_val, new_val); + } + for (auto use : old_val->uses()) { + ir_utils::replaceValInExpr(use, old_val, new_val); + } + modified_fusion = true; + } + for (auto val : vals_to_remove_) { + fusion_->removeVal(val); + modified_fusion = true; + } + for (auto expr : exprs_to_remove_) { + // Fusion::removeVal(val) actually removes val->definition() from the + // Fusion, and sets all its outputs' definitions to nullptr. So we should + // not need to manually remove Exprs here. Instead, we just assert that + // they have already been removed if they were registered for removal. + TORCH_INTERNAL_ASSERT( + !fusion_->inContainer(expr), + "Expression ", + expr->toString(), + " was marked for removal but has not yet been removed."); + } + return modified_fusion; +} + } // namespace nvfuser diff --git a/csrc/iter_visitor.h b/csrc/iter_visitor.h index b042766750a..425463cbc36 100644 --- a/csrc/iter_visitor.h +++ b/csrc/iter_visitor.h @@ -10,6 +10,7 @@ #include #include +#include #include #include @@ -19,9 +20,6 @@ namespace nvfuser { class Fusion; -class Statement; -class Expr; -class Val; /* * IterVisitor starts from leaf nodes, fusion outputs, or the provided values. @@ -372,4 +370,149 @@ class TORCH_CUDA_CU_API InputsOf : public IterVisitor { const std::vector& outputs_); }; +//! This is a generic traversal class that is used to modify a Fusion graph by +//! replacing Vals to simplify computation or remove dead code. This differs +//! from OptOutMutator, which is built for mutating TensorViews in-place in a +//! graph by altering the associated IterDomains, and which does not easily +//! handle modifying TensorView definitions and Fusion outputs during traversal. +//! +//! Derived classes should override handle() for relevant Exprs and they should +//! make use of registerReplacement() to change the definitions of Vals in the +//! graph. Note that if replacements are made using registerReplacement(old_val, +//! new_val), then neither new_val nor any new Statements produced in creating +//! it will be traversed by this class. Also note that any Vals or Exprs that +//! are previously marked dead will not be processed by handle(). +class DeadCodeRemover : BackwardVisitor { + public: + DeadCodeRemover(Fusion* fusion) : BackwardVisitor(false), fusion_(fusion) {} + + DeadCodeRemover(const DeadCodeRemover& other) = default; + DeadCodeRemover& operator=(const DeadCodeRemover& other) = default; + + DeadCodeRemover(DeadCodeRemover&& other) = default; + DeadCodeRemover& operator=(DeadCodeRemover&& other) = default; + + //! Instead of traverseTo, run() is the entry point for this class, and we + //! always traverse from outputs backward to their inputs. + //! + //! Returns a bool indicating whether the Fusion was modified or not. + bool run(); + + inline Fusion* fusion() const { + return fusion_; + } + + protected: + using BackwardVisitor::handle; + + void handle(Statement* stmt) override; + void handle(Expr* expr) override; + + //! We implement this in order to remove dangling TensorViews whose uses are + //! all dead. Note that we do not remove other ValTypes like Scalars since + //! they might still be used as attributes or members of other objects, which + //! is not reflected by Val::uses(). + void handle(TensorView* tv) override; + + //! Registers a Val for replacement in outputs and in all its uses. + //! + //! Note that replacement does not occur immediately, but will be done after + //! the traversal is completed. This is so that any Val* and Expr* pointers + //! may be safely dereferenced during traversal. + //! + //! The argument old_val is always marked Dead by this method. If old_val is a + //! Fusion input, we do not replace it. If old_val's definition is non-null + //! and has other outputs which are not dead, we do not remove old_val. + //! + //! Returns whether old_val was registered for removal from the Fusion. + bool registerReplacement(Val* old_val, Val* new_val); + + //! Find whether a statement is not marked as live code. + inline bool isDead(Statement* stmt) const { + return live_statements_.find(stmt) == live_statements_.end(); + } + + //! Find whether a statement is marked as live code. + inline bool isLive(Statement* stmt) const { + return !isDead(stmt); + } + + //! Check whether all outputs of an expression have been marked dead + inline bool allOutputsDead(Expr* expr) const { + return std::all_of( + expr->outputs().begin(), expr->outputs().end(), [&](Val* outp) { + return isDead(outp); + }); + } + + //! Check whether all uses have been marked dead + inline bool allUsesDead(Val* val) const { + return std::all_of(val->uses().begin(), val->uses().end(), [&](Expr* use) { + return isDead(use); + }); + } + + private: + //! Removes an Expr* from the Fusion, if possible. + //! + //! The Expr will _only_ be marked dead and removed if all of its outputs are + //! already marked dead. In this case all the outputs will also be removed + //! from the Fusion. + //! + //! Returns whether the Expr was marked dead and removed from the Fusion. + bool maybeRemoveExpr(Expr* expr); + + //! Mark a single Statement as being alive. + inline void markLive(Statement* stmt) { + live_statements_.insert(stmt); + } + + //! Ensure that a Statement and its upstream Statements are alive. If it is an + //! Expr, ensure all its inputs are alive. If it's a Val with a definition, + //! recursive to the definition. Newly-created Statements default to being + //! dead, so this method is called when adding a Statement to the active path + //! of the Fusion inside registerReplacement. + void markLiveRecursive(Statement* stmt); + + //! Mark a single Statement as being dead. This does not remove stmt from the + //! Fusion. It is an error to call this on a Fusion output. + //! + //! Returns true if the statement was previously live, and false otherwise. + bool markDead(Statement* stmt); + + //! Register a Val for later removal. + void registerRemoval(Val* val); + + //! Register an Expr for later removal. + //! + //! Note that if any of its outputs are removed, expr will be removed even if + //! it is not marked for removal, and all its outputs will have their + //! definitions set to nullptr. + inline void registerRemoval(Expr* expr) { + exprs_to_remove_.push_back(expr); + } + + //! All modifications to the Fusion are registered during traversal then + //! later they are committed by this method. For safety, this should only be + //! run after traversing the graph. + //! + //! Returns a bool indicating whether any modifications were performed. + bool modifyFusion() const; + + private: + //! The Fusion associated with live_statements_ + Fusion* fusion_; + + //! Statements are marked dead by removing them from this set + std::unordered_set live_statements_; + + //! Vals to be replaced in outputs and with replaceValInExpr in all uses. + std::vector> vals_to_replace_; + + //! Statements that will be removed. We remove Vals before Exprs, so we track + //! them separately here. + std::vector vals_to_remove_; + std::vector exprs_to_remove_; +}; + } // namespace nvfuser diff --git a/csrc/kernel_cache.cpp b/csrc/kernel_cache.cpp index 6600ebc1d6c..1c6307a8250 100644 --- a/csrc/kernel_cache.cpp +++ b/csrc/kernel_cache.cpp @@ -673,6 +673,12 @@ FusionKernelRuntime::FusionKernelRuntime( optimization::OptimizationPass::runPass( fusion.get()); + if (isDebugDumpEnabled(DebugDumpOption::FusionIrPreseg)) { + std::cout << "Fusion IR after pre-segmenter optimization passes:" + << std::endl; + fusion->printMath(); + } + all_tvs_ = ir_utils::allTvs(fusion.get()); // Run segmentation on the copied fusion diff --git a/csrc/optimization/pre_segmenter.cpp b/csrc/optimization/pre_segmenter.cpp index 30bdf87bf34..b9a058daff0 100644 --- a/csrc/optimization/pre_segmenter.cpp +++ b/csrc/optimization/pre_segmenter.cpp @@ -9,10 +9,13 @@ #include #include +#include namespace nvfuser::optimization { void PreSegmenter::runPass(Fusion* fusion) { + // Replace TensorViews with zero extent. Outputs and inputs may still be empty + OptimizationPass::runPass(fusion); // removes consecutive cast operations OptimizationPass::runPass(fusion); OptimizationPass::runPass(fusion); diff --git a/csrc/optimization/remove_empty.cpp b/csrc/optimization/remove_empty.cpp new file mode 100644 index 00000000000..b133422d8ad --- /dev/null +++ b/csrc/optimization/remove_empty.cpp @@ -0,0 +1,308 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#include + +#include +#include +#include +#include + +#include +#include +#include +#include + +namespace nvfuser::optimization { + +namespace { + +//! Get a vector of the integer positions of constant zero extent axes in the +//! input domain. This will typically be used like +//! `emptyAxes(TensorDomain::noReductions(tv->getMaybeRFactorDomain()))` +std::vector emptyAxes(const std::vector& domain) { + std::vector empty_axes; + for (auto ax : c10::irange(domain.size())) { + auto id = domain.at(ax); + if (id->extent()->isConst() && id->extent()->evaluateInt() == 0) { + empty_axes.push_back((int64_t)ax); + } + } + return empty_axes; +} + +//! Check whether a TensorView is empty. During concretization, we traverse to +//! find a minimal set of TensorViews that have zero extents, and we then set +//! their extents to a constant 0. Here we check for those constant zero +//! extents. +bool isTVEmpty(TensorView* tv) { + return !emptyAxes(TensorDomain::noReductions(tv->getMaybeRFactorDomain())) + .empty(); +} + +//! EmptyTensorRemover performs a backward traversal of the Fusion. When it +//! detects a TensorView that has at least one extent that is zero, we do the +//! following: +//! +//! 1. If the empty Tensorview is a Fusion output, we replace it with a +//! TensorView created by `full` having the same shape. Since the original +//! tensor is empty, there is nothing to compute, so this eliminates a branch +//! of trivial code. +//! 2. If the empty TensorView is the input of a `cat` op along the empty +//! dimension, we replace the cat op with a new one having the empty input +//! removed. +//! 3. If the empty Tensorview is the input to a ReductionOp or WelfordOp and +//! the empty dimensions are reduced, we replace the op with `full` since +//! there are no elements being reduced. Note that if any empty axes are not +//! reduced, we will not encounter this case since it will have been removed +//! earlier in the backward traversal under condition 1. +//! 4. If the empty Tensorview is the input to a PadOp (which is not input to +//! a CatOp) then we replace the pad with `full(pad_value)`. +//! 5. If empty TensorViews are input to an MmaOp and they are empty in +//! contracted axes, we replace with `full({m, n}, zeroVal())`. +//! +class EmptyTensorRemover : public DeadCodeRemover { + public: + EmptyTensorRemover(Fusion* fusion) : DeadCodeRemover(fusion) {} + + protected: + using DeadCodeRemover::handle; + + //! If tv is a fusion output, we check whether it is empty and if so, replace + //! it with full(). For non-outputs that are not inputs, we simply check that + //! the tensor is not provably empty. + void handle(TensorView* tv) final { + DeadCodeRemover::handle(tv); + if (isDead(tv)) { + // DeadCodeRemover::handle might have set this dead, in which case we + // don't need to process it any further + return; + } + + if (isTVEmpty(tv)) { + if (tv->isFusionInput()) { + TORCH_INTERNAL_ASSERT( + allUsesDead(tv), + "Empty Fusion input ", + tv, + " should not have any live uses."); + // Empty inputs do not have a definition to redefine + return; + } + + // Any non-input that we traverse to should be the input to an expression, + // or a Fusion output. If it's the input to an expression, we should have + // replaced that expression by handling the appropriate Expr subclass. + TORCH_INTERNAL_ASSERT( + tv->isFusionOutput(), + "Found unexpected empty intermediate TensorView ", + tv->toString()); + auto shape = noReductionShape(tv); + auto dtype = tv->getDataType().value(); + auto new_tv = full(shape, fusion()->zeroVal(dtype), dtype); + registerReplacement(tv, new_tv); + } + } + + //! Gets a vector of extents for noReduction(tv->getMaybeRFactorDomain()) + static std::vector noReductionShape(TensorView* tv) { + std::vector shape; + for (auto id : TensorDomain::noReductions(tv->getMaybeRFactorDomain())) { + shape.push_back(id->extent()); + } + return shape; + } + + //! A reduction over empty axes is equal to the initial value of the + //! reduction, as if the reduction were written as follows: + //! + //! auto result = init_value; + //! for (auto element : reduction_elements) { + //! result = reduction_op(result, element); + //! } + //! return result; + //! + void handle(ReductionOp* rop) final { + auto in = rop->in()->as(); + auto empty_input_axes = + emptyAxes(TensorDomain::noReductions(in->getMaybeRFactorDomain())); + if (empty_input_axes.empty()) { + // Input is not empty, handle like any other op + return; + } + auto out = rop->out()->as(); + // The input is empty in some axes. Assert that they are all reduced + for (auto ax : empty_input_axes) { + auto id = out->getRootDomain().at(ax); + // Input rfactor domain positions correspond to output root positions + TORCH_INTERNAL_ASSERT( + id->isReduction(), + "Found unexpected unreduced empty axis at position ", + ax, + " in expression ", + rop->toString()); + } + + auto new_tv = + full(noReductionShape(out), rop->init(), out->getDataType().value()); + registerReplacement(out, new_tv); + } + + //! A WelfordOp is similar to a ReductionOp, but has three outputs: avg, var, + //! N. For an empty reduction N will be zero, so we fill the output with zero. + //! The avg and var is obtained by summing then dividing by N. For empty + //! reductions this leads to 0.0 / 0 so we fill it with a constant NAN. The + //! .var variable is actually an unnormalized variance which is a sum without + //! dividing by N or N-1, so we fill it with zeros. + void handle(WelfordOp* wop) final { + auto in = wop->in()->as(); + auto empty_input_axes = + emptyAxes(TensorDomain::noReductions(in->getMaybeRFactorDomain())); + if (empty_input_axes.empty()) { + // Input is not empty, handle like any other op + return; + } + auto avg = wop->outAvg()->as(); + auto var_sum = wop->outVar()->as(); + auto N = wop->outN()->as(); + // The input is empty in some axes. Assert that they are all reduced + for (auto ax : empty_input_axes) { + auto id = avg->getRootDomain().at(ax); + // Input rfactor domain positions correspond to output root positions + TORCH_INTERNAL_ASSERT( + id->isReduction(), + "Found unexpected unreduced empty axis at position ", + ax, + " in expression ", + wop->toString()); + } + + // Since WelfordOp has multiple outputs, we need to check whether each is + // live before replacing it, to avoid replacing a dead output with a live + // one. + auto shape = noReductionShape(avg); + if (isLive(avg)) { + auto nan = IrBuilder::create( + std::numeric_limits::quiet_NaN(), avg->getDataType().value()); + auto nan_tensor = full(shape, nan, avg->getDataType().value()); + registerReplacement(avg, nan_tensor); + } + if (isLive(var_sum)) { + auto new_var_sum = full( + shape, + fusion()->zeroVal(var_sum->getDataType().value()), + var_sum->getDataType().value()); + registerReplacement(var_sum, new_var_sum); + } + if (isLive(N)) { + auto new_N = full(shape, fusion()->zeroVal(), N->getDataType().value()); + registerReplacement(N, new_N); + } + } + + //! A cat op can have input empty tensors and still output a non-empty + //! tensor. This is only possible if there is more than one input, so we + //! only need to handle those cases. We find the non-empty inputs to cat + //! then replace with another cat (or `set` if n=1). + //! + //! The `cat` function creates a CatOp object, but its inputs() are not + //! the original inputs. Rather, they are the inputs after padding to the + //! output extent in the concatenated dimension. Thus, in the IR graph, + //! instead of the following: + //! + //! T0 T1 T2 + //! \ | / + //! CatOp + //! | + //! T3 + //! + //! a cat is represented as: + //! + //! T0 T1 T2 + //! | | | + //! PadOp PadOp PadOp + //! \ | / + //! CatOp + //! | + //! T3 + //! + //! If we determine that one of the inputs, T1, is empty in the cat + //! dimension, then we rewrite this as: + //! + //! T0 T2 + //! | | + //! PadOp PadOp + //! \ / + //! CatOp + //! | + //! T3 + //! + //! This is done by simply calling the cat() command with only {T0, T2}. + void handle(CatOp* cop) final { + auto dim = cop->concatenatedDim(); + std::vector non_empty_inputs; + for (auto inp : cop->inputs()) { + TORCH_INTERNAL_ASSERT( + inp->definition() && inp->definition()->isA(), + "Inputs to CatOp must be outputs of PadOps"); + auto tv = inp->definition()->as()->in()->as(); + auto cat_id = + TensorDomain::noReductions(tv->getMaybeRFactorDomain()).at(dim); + if (cat_id->extent()->isConst() && cat_id->extent()->evaluateInt() == 0) { + continue; + } + non_empty_inputs.push_back(tv); + } + if (non_empty_inputs.size() != cop->inputs().size()) { + // Replace this op with a new cat op + auto old_tv = cop->outputs()[0]->as(); + // NOTE: cat() will translate to set() if non_empty_inputs.size() == 1 + auto new_tv = cat(non_empty_inputs, dim); + registerReplacement(old_tv, new_tv); + } + } + + //! Replace pad(tv) if tv is empty in any dimension. Note that since we detect + //! empty tensors by looking for constant extents, the output extents will be + //! correct here already, so there is no value in removing the empty input + //! extent when we do the replacement. + void handle(PadOp* pop) final { + auto in = pop->in()->as(); + auto in_rfactor = TensorDomain::noReductions(in->getMaybeRFactorDomain()); + if (!emptyAxes(in_rfactor).empty()) { + auto out = pop->out()->as(); + auto shape = noReductionShape(out); + auto dtype = out->getDataType().value(); + auto new_tv = full(shape, pop->value(), dtype); + registerReplacement(out, new_tv); + } + } + + //! We handle MmaOp just as if it were written as a sum ReductionOp. + void handle(MmaOp* mop) final { + auto A = mop->inA()->as(); + auto A_rfactor = TensorDomain::noReductions(A->getMaybeRFactorDomain()); + // We only need to check empty axes in A. If any reduced axes are empty + // here, they will be empty in B also. If any non-reduced axes are empty, + // the output will also be empty, and this expression will already be dead. + if (!emptyAxes(A_rfactor).empty()) { + auto out = mop->out()->as(); + auto shape = noReductionShape(out); + auto dtype = out->getDataType().value(); + auto new_tv = full(shape, fusion()->zeroVal(dtype), dtype); + registerReplacement(out, new_tv); + } + } +}; + +} // namespace + +void RemoveEmptyPass::runPass(Fusion* fusion) { + EmptyTensorRemover(fusion).run(); +} + +} // namespace nvfuser::optimization diff --git a/csrc/optimization/remove_empty.h b/csrc/optimization/remove_empty.h new file mode 100644 index 00000000000..2c190619335 --- /dev/null +++ b/csrc/optimization/remove_empty.h @@ -0,0 +1,22 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#include + +namespace nvfuser::optimization { + +//! RemoveEmptyPass removes intermediate empty tensors (those with at least one +//! extent zero thar are neither a fusion output or input). +class TORCH_CUDA_CU_API RemoveEmptyPass + : public OptimizationPass { + friend class OptimizationPass; + + protected: + static void runPass(Fusion* fusion); +}; + +} // namespace nvfuser::optimization diff --git a/csrc/options.cpp b/csrc/options.cpp index 1d346decb55..195cbc7017e 100644 --- a/csrc/options.cpp +++ b/csrc/options.cpp @@ -117,6 +117,7 @@ std::unordered_map> Options< {"fusion_args", DebugDumpOption::FusionArgs}, {"fusion_ir", DebugDumpOption::FusionIr}, {"fusion_ir_concretized", DebugDumpOption::FusionIrConcretized}, + {"fusion_ir_preseg", DebugDumpOption::FusionIrPreseg}, {"fusion_ir_math", DebugDumpOption::FusionIrMath}, {"fusion_ir_presched", DebugDumpOption::FusionIrPresched}, {"halo", DebugDumpOption::Halo}, diff --git a/csrc/options.h b/csrc/options.h index 4fc85c0f36c..5cf065c0793 100644 --- a/csrc/options.h +++ b/csrc/options.h @@ -25,6 +25,7 @@ enum class DebugDumpOption { FusionIrMath, //!< Dump just the compute (math) part of the Fusion IR FusionIrPresched, //!< Dump the Fusion IR before it is scheduled. FusionIrConcretized, //!< Dump the Fusion IR after concretization + FusionIrPreseg, //!< Dump the Fusion IR after pre-segmenter optimization KernelIr, //!< Dump the compiler Kernel IR ComputeAtMap, //!< Dump the computeAt map CudaKernel, //!< Dump the generated CUDA C++ kernel code diff --git a/test/test_optimization_pass.cpp b/test/test_optimization_pass.cpp index bf7b1f1c88d..158ac4dbd84 100644 --- a/test/test_optimization_pass.cpp +++ b/test/test_optimization_pass.cpp @@ -365,4 +365,307 @@ TEST_F(NVFuserTest, FusionTestCastOptimization_CUDA) { } } +// Test that we remove empty output branch before segmentation +TEST_F(NVFuserTest, FusionRemoveEmptyOutput_CUDA) { + std::unique_ptr fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr.get(); + FusionGuard fg(fusion_ptr.get()); + // Concrete tensor with zero for one extent, so that we can prove the output + // is empty + auto tv0 = makeConcreteTensor({0, 3}); + fusion.addInput(tv0); + auto tv1 = set(tv0); + fusion.addOutput(tv1); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor at0 = at::randn({0, 3}, options); + std::vector aten_inputs = {at0}; + + auto args = KernelArgumentHolder::createKernelArgumentHolder(aten_inputs); + FusionKernelRuntime runtime(std::move(fusion_ptr), args); + + // In the FusionKernelRuntime, before segmentation a number of optimization + // passes are performed. One of those is RemoveEmptyPass, which should replace + // the empty output tv1 with a new TensorView defined by `full({0, 3})` in + // this case. + auto preseg_fusion = runtime.fusionSegments()->completeFusion(); + EXPECT_EQ(preseg_fusion->outputs().size(), 1); + EXPECT_NE(preseg_fusion->outputs()[0], tv1); + EXPECT_NE(preseg_fusion->outputs()[0]->definition(), nullptr); + EXPECT_TRUE(preseg_fusion->outputs()[0]->definition()->isA()); + + runtime.compileFusionParallel(args); + auto outputs = runtime.runWithInputs(args); + + testValidate(preseg_fusion, outputs, aten_inputs, {at0}, __LINE__, __FILE__); +} + +// Test that we replace empty reduction with full +TEST_F(NVFuserTest, FusionRemoveEmptyReduction_CUDA) { + std::unique_ptr fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr.get(); + FusionGuard fg(fusion_ptr.get()); + // Concrete tensor with zero for one extent, so that we can prove the output + // is empty + auto tv0 = makeConcreteTensor({0, 3}); + fusion.addInput(tv0); + auto tv1 = sum(tv0, {0}); + fusion.addOutput(tv1); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor at0 = at::randn({0, 3}, options); + std::vector aten_inputs = {at0}; + + auto args = KernelArgumentHolder::createKernelArgumentHolder(aten_inputs); + FusionKernelRuntime runtime(std::move(fusion_ptr), args); + + auto preseg_fusion = runtime.fusionSegments()->completeFusion(); + EXPECT_EQ(preseg_fusion->outputs().size(), 1); + EXPECT_NE(preseg_fusion->outputs()[0]->definition(), nullptr); + EXPECT_TRUE(preseg_fusion->outputs()[0]->definition()->isA()); + + runtime.compileFusionParallel(args); + auto outputs = runtime.runWithInputs(args); + + testValidate( + preseg_fusion, + outputs, + aten_inputs, + {at::sum(at0, {0})}, + __LINE__, + __FILE__); +} + +// In this test, a reduction over a non-empty axis occurs first, followed by a +// reduction over the remaining empty axis. The output is actually not empty, +// even though the first reduction results in an empty tensor. +TEST_F(NVFuserTest, FusionRemoveEmptyReductionWithNonReduction_CUDA) { + std::unique_ptr fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr.get(); + FusionGuard fg(fusion_ptr.get()); + // Concrete tensor with zero for one extent, so that we can prove the output + // is empty + auto tv0 = makeConcreteTensor({0, 3, 2}); + fusion.addInput(tv0); + auto tv1 = sum(tv0, {1}); + auto tv2 = sum(tv1, {0}); + fusion.addOutput(tv2); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor at0 = at::randn({0, 3, 2}, options); + std::vector aten_inputs = {at0}; + + auto args = KernelArgumentHolder::createKernelArgumentHolder(aten_inputs); + FusionKernelRuntime runtime(std::move(fusion_ptr), args); + + auto preseg_fusion = runtime.fusionSegments()->completeFusion(); + EXPECT_EQ(preseg_fusion->outputs().size(), 1); + EXPECT_NE(preseg_fusion->outputs()[0]->definition(), nullptr); + EXPECT_TRUE(preseg_fusion->outputs()[0]->definition()->isA()); + + runtime.compileFusionParallel(args); + auto outputs = runtime.runWithInputs(args); + + testValidate( + preseg_fusion, + outputs, + aten_inputs, + {at::sum(at::sum(at0, 1), 0)}, + __LINE__, + __FILE__); +} + +// Test that we replace empty Welford with full +TEST_F(NVFuserTest, FusionRemoveEmptyWelford_CUDA) { + std::unique_ptr fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr.get(); + FusionGuard fg(fusion_ptr.get()); + // Concrete tensor with zero for one extent, so that we can prove the output + // is empty + auto tv0 = makeConcreteTensor({0, 3}); + fusion.addInput(tv0); + auto w = Welford(tv0, {0}); + fusion.addOutput(w.avg); + auto var = div(w.var_sum, fusion_ptr->zeroVal(DataType::Float)); + fusion.addOutput(var); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor at0 = at::randn({0, 3}, options); + std::vector aten_inputs = {at0}; + + auto args = KernelArgumentHolder::createKernelArgumentHolder(aten_inputs); + FusionKernelRuntime runtime(std::move(fusion_ptr), args); + + auto preseg_fusion = runtime.fusionSegments()->completeFusion(); + EXPECT_EQ(preseg_fusion->outputs().size(), 2); + + EXPECT_NE(preseg_fusion->outputs()[0]->definition(), nullptr); + EXPECT_TRUE(preseg_fusion->outputs()[0]->definition()->isA()); + + EXPECT_NE(var->definition(), nullptr); + EXPECT_TRUE(var->definition()->isA()); + // We divide in the fusion to normalize the variance, so here we have to peel + // that back + auto var_sum = var->definition()->inputs()[0]->as(); + EXPECT_TRUE(var_sum->definition()->isA()); + + runtime.compileFusionParallel(args); + auto outputs = runtime.runWithInputs(args); + + testValidate( + preseg_fusion, + outputs, + aten_inputs, + {at::mean(at0, 0), at::var(at0, 0)}, + __LINE__, + __FILE__); +} + +// Test that we replace empty tensors in cat properly +TEST_F(NVFuserTest, FusionRemoveEmptyCat_CUDA) { + std::unique_ptr fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr.get(); + FusionGuard fg(fusion_ptr.get()); + // Concrete tensor with zero for one extent, so that we can prove the output + // is empty + auto tv0 = makeConcreteTensor({0, 3}); + fusion.addInput(tv0); + auto tv1 = makeConcreteTensor({2, 3}); + fusion.addInput(tv1); + auto tv2 = makeConcreteTensor({4, 3}); + fusion.addInput(tv2); + + // equivalent to cat({tv1, tv2}, 0) + auto tv3 = cat({tv0, tv1, tv2}, 0); + fusion.addOutput(tv3); + // set(tv1) + auto tv4 = cat({tv0, tv1}, 0); + fusion.addOutput(tv4); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor at0 = at::randn({0, 3}, options); + at::Tensor at1 = at::randn({2, 3}, options); + at::Tensor at2 = at::randn({4, 3}, options); + std::vector aten_inputs = {at0, at1, at2}; + + auto args = KernelArgumentHolder::createKernelArgumentHolder(aten_inputs); + FusionKernelRuntime runtime(std::move(fusion_ptr), args); + + auto preseg_fusion = runtime.fusionSegments()->completeFusion(); + EXPECT_EQ(preseg_fusion->outputs().size(), 2); + + EXPECT_NE(preseg_fusion->outputs()[0]->definition(), nullptr); + EXPECT_TRUE(preseg_fusion->outputs()[0]->definition()->isA()); + EXPECT_EQ(preseg_fusion->outputs()[0]->definition()->inputs().size(), 2); + + EXPECT_NE(preseg_fusion->outputs()[1]->definition(), nullptr); + EXPECT_TRUE(preseg_fusion->outputs()[1]->definition()->isA()); + + runtime.compileFusionParallel(args); + auto outputs = runtime.runWithInputs(args); + + testValidate( + preseg_fusion, + outputs, + aten_inputs, + {at::cat({at1, at2}, 0), at1}, + __LINE__, + __FILE__); +} + +// Test that we replace empty tensors in pad properly +TEST_F(NVFuserTest, FusionRemoveEmptyPad_CUDA) { + std::unique_ptr fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr.get(); + FusionGuard fg(fusion_ptr.get()); + // Concrete tensor with zero for one extent, so that we can prove the output + // is empty + auto tv0 = makeConcreteTensor({3, 0}); + fusion.addInput(tv0); + + // Use a non-zero pad value to verify that it is used in the rewritten fill + auto pad_val = IrBuilder::create(3.14, DataType::Float); + + // equivalent to full({3, 2}, pad_val, DataType::Float) + auto tv1 = pad(tv0, {fusion.oneVal(), fusion.oneVal()}, pad_val); + fusion.addOutput(tv1); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor at0 = at::randn({3, 0}, options); + std::vector aten_inputs = {at0}; + + auto args = KernelArgumentHolder::createKernelArgumentHolder(aten_inputs); + FusionKernelRuntime runtime(std::move(fusion_ptr), args); + + auto preseg_fusion = runtime.fusionSegments()->completeFusion(); + EXPECT_EQ(preseg_fusion->outputs().size(), 1); + + EXPECT_NE(preseg_fusion->outputs()[0]->definition(), nullptr); + auto rewritten_def = preseg_fusion->outputs()[0]->definition(); + EXPECT_TRUE(rewritten_def->isA()); + EXPECT_TRUE(rewritten_def->as()->getFillValue()->sameAs(pad_val)); + + runtime.compileFusionParallel(args); + auto outputs = runtime.runWithInputs(args); + + testValidate( + preseg_fusion, + outputs, + aten_inputs, + {at::pad(at0, {1, 1}, "constant", 3.14)}, + __LINE__, + __FILE__); +} + +// Test that we replace empty tensors in matmuls properly +TEST_F(NVFuserTest, FusionRemoveEmptyMatmul_CUDA) { + std::unique_ptr fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr.get(); + FusionGuard fg(fusion_ptr.get()); + + // [M, K] + auto tv0 = makeConcreteTensor({16, 0}, DataType::Half); + // [K, N] + auto tv1 = makeConcreteTensor({0, 8}, DataType::Half); + fusion.addInput(tv0); + fusion.addInput(tv1); + + // [M, N, K] + auto tv0b = broadcast(tv0, {false, true, false}); + // [M, K, N] + auto tv1b = broadcast(tv1, {true, false, false}); + // [M, N, K] + auto tv1t = transpose(tv1b, 1, 2); + + auto tv2 = fusedMultiplySum(tv0b, tv1t, {2}); + fusion.addOutput(tv2); + + auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0); + at::Tensor at0 = at::randn({16, 0}, options); + at::Tensor at1 = at::randn({0, 8}, options); + std::vector aten_inputs = {at0, at1}; + + auto args = KernelArgumentHolder::createKernelArgumentHolder(aten_inputs); + FusionKernelRuntime runtime(std::move(fusion_ptr), args); + + auto preseg_fusion = runtime.fusionSegments()->completeFusion(); + EXPECT_EQ(preseg_fusion->outputs().size(), 1); + + EXPECT_NE(preseg_fusion->outputs()[0]->definition(), nullptr); + auto rewritten_def = preseg_fusion->outputs()[0]->definition(); + EXPECT_TRUE(rewritten_def->isA()); + EXPECT_EQ(rewritten_def->as()->getFillValue()->evaluateDouble(), 0.0); + + runtime.compileFusionParallel(args); + auto outputs = runtime.runWithInputs(args); + + testValidate( + preseg_fusion, + outputs, + aten_inputs, + {at::zeros({16, 8}, options)}, + __LINE__, + __FILE__); +} + } // namespace nvfuser::optimization