diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index f0463ae8f9b..2a2bc935280 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -61,7 +61,9 @@ jobs: git checkout origin/main head_commit=$(git rev-parse HEAD) git checkout $this_commit - git --no-pager diff --name-only $head_commit | grep -e "csrc/.*\.cpp" -e "csrc/.*\.h" | xargs lintrunner --take CLANGTIDY --force-color + # diff-filter for lower case letter: + # https://github.com/git/git/commit/7f2ea5f0f2fb056314092cce23202096ca70f076 + git --no-pager diff --diff-filter=d --name-only $head_commit | grep -e "csrc/.*\.cpp" -e "csrc/.*\.h" | xargs lintrunner --take CLANGTIDY --force-color lintrunner: runs-on: ubuntu-latest diff --git a/CMakeLists.txt b/CMakeLists.txt index 2f97c6518c1..654a36ae7d4 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -63,7 +63,6 @@ endif() # nvfuser codegen sources set(NVFUSER_SRCS) list(APPEND NVFUSER_SRCS - ${NVFUSER_SRCS_DIR}/assume.cpp ${NVFUSER_SRCS_DIR}/compute_at.cpp ${NVFUSER_SRCS_DIR}/inlining.cpp ${NVFUSER_SRCS_DIR}/compute_at_map.cpp @@ -181,6 +180,7 @@ list(APPEND NVFUSER_SRCS ${NVFUSER_SRCS_DIR}/utils.cpp ${NVFUSER_SRCS_DIR}/mma_type.cpp ${NVFUSER_SRCS_DIR}/scheduler/mma_utils.cpp + ${NVFUSER_SRCS_DIR}/optimization/add_axioms.cpp ${NVFUSER_SRCS_DIR}/optimization/consecutive_cast.cpp ${NVFUSER_SRCS_DIR}/optimization/pre_segmenter.cpp ) diff --git a/benchmark/matmul.cpp b/benchmark/matmul.cpp index a5f1a1a71d9..404c9ac4715 100644 --- a/benchmark/matmul.cpp +++ b/benchmark/matmul.cpp @@ -11,6 +11,7 @@ #include #include #include +#include #include #include #include @@ -141,6 +142,8 @@ static void SingleMatmulBase( // Define fusion graph setupMatmul(fusion, layout, params, turing_or_later); + optimization::OptimizationPass::runPass(fusion); + // inputs at::manual_seed(0); diff --git a/csrc/assume.cpp b/csrc/assume.cpp deleted file mode 100644 index c0ca3be778c..00000000000 --- a/csrc/assume.cpp +++ /dev/null @@ -1,48 +0,0 @@ -#include -#include -#include - -#include - -namespace nvfuser::assume { - -Bool* tensorsAreNotEmpty(Val* value) { - std::vector todo{value}; - std::vector tensor_sizes; - while (!todo.empty()) { - auto v = todo.back(); - todo.pop_back(); - TORCH_INTERNAL_ASSERT(v != nullptr); - if (auto ns = dynamic_cast(v)) { - if (ns->isTensorSize()) { - tensor_sizes.emplace_back(v); - continue; - } - } - if (auto def = v->definition()) { - for (auto inp : def->inputs()) { - todo.emplace_back(inp); - } - } - } - Bool* result = nullptr; - // tensor_sizes might contain duplicate, and we should remove this duplication - std::vector tensor_sizes_applied; - for (auto ts : tensor_sizes) { - bool is_duplicate = false; - for (auto existing : tensor_sizes_applied) { - if (existing->sameAs(ts)) { - is_duplicate = true; - break; - } - } - if (!is_duplicate) { - tensor_sizes_applied.emplace_back(ts); - result = SimplifyingIrBuilder::andExpr( - result, SimplifyingIrBuilder::gtExpr(ts, ts->container()->zeroVal())); - } - } - return result; -} - -} // namespace nvfuser::assume diff --git a/csrc/assume.h b/csrc/assume.h deleted file mode 100644 index 9645d145e9e..00000000000 --- a/csrc/assume.h +++ /dev/null @@ -1,18 +0,0 @@ -#include - -// Return boolean predicates representing the conditional you want to assume. -// The return value is typically used as the `assumptions` argument of -// `simplifyExpr` - -namespace nvfuser::assume { - -// Return a boolean predicate stating that all tensor sizes appearing in `value` -// are positive. Return nullptr if `value` does not depend on any tensor size. -// For example: -// tensorsAreNotEmpty(ceilDiv(T0.size[0], 5) * T0.size[1]) -// -> T0.size[0] > 0 && T0.size[1] > 0 -// tensorsAreNotEmpty(ceilDiv(i1, 5) * i2) -// -> nullptr -Bool* tensorsAreNotEmpty(Val* value); - -} // namespace nvfuser::assume diff --git a/csrc/ir/container.cpp b/csrc/ir/container.cpp index ed23cccdf18..2f0b9f5d559 100644 --- a/csrc/ir/container.cpp +++ b/csrc/ir/container.cpp @@ -61,6 +61,13 @@ IrCloner IrContainer::copy(const IrContainer* from, IrContainer* to) { to->val_type_name_map_ = from->val_type_name_map_; to->expr_name_counter_ = from->expr_name_counter_; + if (from->axioms_ != nullptr) { + to->axioms_ = std::make_unique>(); + for (auto pred : *from->axioms_) { + to->axioms_->emplace_back(ir_cloner.clone(pred)); + } + } + return ir_cloner; } @@ -189,7 +196,7 @@ void IrContainer::clear() noexcept { exprs_.clear(); exprs_up_.clear(); raw_ptrs_.clear(); - + axioms_.reset(); val_type_name_map_.clear(); expr_name_counter_ = 0; } @@ -305,7 +312,7 @@ NamedScalar* IrContainer::magicZeroVal() { return magic_zero_val_.get(); } -const std::vector& IrContainer::axioms() { +void IrContainer::lazyInitAxioms() { if (!axioms_) { axioms_ = std::make_unique>(); axioms_->reserve(kParallelTypeThreads.size() * 3); @@ -318,7 +325,18 @@ const std::vector& IrContainer::axioms() { axioms_->push_back(SimplifyingIrBuilder::ltExpr(pidx, pdim)); } } - return *axioms_; +} + +void IrContainer::assumePositive(Val* val) { + TORCH_INTERNAL_ASSERT(val->container() == this); + lazyInitAxioms(); + axioms_->emplace_back(IrBuilder::gtExpr(val, zeroVal())); +} + +void IrContainer::assumeNonNegative(Val* val) { + TORCH_INTERNAL_ASSERT(val->container() == this); + lazyInitAxioms(); + axioms_->emplace_back(IrBuilder::geExpr(val, zeroVal())); } } // namespace nvfuser diff --git a/csrc/ir/container.h b/csrc/ir/container.h index 6415400d0ad..453dee59444 100644 --- a/csrc/ir/container.h +++ b/csrc/ir/container.h @@ -96,7 +96,13 @@ class TORCH_CUDA_CU_API IrContainer : public PolymorphicBase { Val* zeroVal(DataType dtype); Val* oneVal(DataType dtype); // Axioms about CUDA programming, for example: threadIdx.x < blockDim.x - const std::vector& axioms(); + const std::vector& axioms() { + lazyInitAxioms(); + return *axioms_; + } + + void assumePositive(Val* val); + void assumeNonNegative(Val* val); protected: static IrCloner copy(const IrContainer* from, IrContainer* to); @@ -131,6 +137,8 @@ class TORCH_CUDA_CU_API IrContainer : public PolymorphicBase { void clear() noexcept; + void lazyInitAxioms(); + // Deque of unique pointer is the memory owning data structure std::deque> vals_up_; diff --git a/csrc/ir/utils.cpp b/csrc/ir/utils.cpp index 0f7efda48c1..bc48db2f000 100644 --- a/csrc/ir/utils.cpp +++ b/csrc/ir/utils.cpp @@ -477,6 +477,11 @@ class ValReplacementMutator : private OptOutMutator { more.emplace_back(v); } } + for (auto v : fusion->axioms()) { + if (std::find(stmts.begin(), stmts.end(), v) == stmts.end()) { + more.emplace_back(v); + } + } auto more_stmts = StmtSort::getStmts(fusion, more, true, true); more_stmts.insert(more_stmts.end(), stmts.begin(), stmts.end()); diff --git a/csrc/optimization/add_axioms.cpp b/csrc/optimization/add_axioms.cpp new file mode 100644 index 00000000000..461448a3d89 --- /dev/null +++ b/csrc/optimization/add_axioms.cpp @@ -0,0 +1,41 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#include + +#include +#include + +#include + +namespace nvfuser::optimization { + +void AddAxiomsPass::runPass(Fusion* fusion) { + auto all_vals = fusion->usedMathVals(); + std::unordered_set assumed_vals; + for (auto tv : ir_utils::filterByType(all_vals)) { + std::vector*> interested_domains{ + &tv->getRootDomain()}; + if (tv->hasRFactor()) { + interested_domains.push_back(&tv->getRFactorDomain()); + } + if (tv->hasAllocation()) { + interested_domains.push_back(&tv->getAllocationDomain()); + } + for (auto dom : interested_domains) { + for (auto id : *dom) { + auto extent = id->extent(); + if (extent->definition() == nullptr && !extent->isConstScalar() && + assumed_vals.insert(extent).second) { + fusion->assumePositive(extent); + } + } + } + } +} + +} // namespace nvfuser::optimization diff --git a/csrc/optimization/add_axioms.h b/csrc/optimization/add_axioms.h new file mode 100644 index 00000000000..7bdd2fa87c9 --- /dev/null +++ b/csrc/optimization/add_axioms.h @@ -0,0 +1,20 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#include + +namespace nvfuser::optimization { + +//! AddAxiomsPass adds extent > 0 as axioms of the IR container for all tensors +class TORCH_CUDA_CU_API AddAxiomsPass : public OptimizationPass { + friend class OptimizationPass; + + protected: + static void runPass(Fusion* fusion); +}; + +} // namespace nvfuser::optimization diff --git a/csrc/optimization/pre_segmenter.cpp b/csrc/optimization/pre_segmenter.cpp index baaa69fd7e4..30bdf87bf34 100644 --- a/csrc/optimization/pre_segmenter.cpp +++ b/csrc/optimization/pre_segmenter.cpp @@ -7,6 +7,7 @@ // clang-format on #include +#include #include namespace nvfuser::optimization { @@ -14,6 +15,7 @@ namespace nvfuser::optimization { void PreSegmenter::runPass(Fusion* fusion) { // removes consecutive cast operations OptimizationPass::runPass(fusion); + OptimizationPass::runPass(fusion); } } // namespace nvfuser::optimization diff --git a/csrc/parallel_dimension_map.cpp b/csrc/parallel_dimension_map.cpp index 7bfad576cc7..1a3e049c368 100644 --- a/csrc/parallel_dimension_map.cpp +++ b/csrc/parallel_dimension_map.cpp @@ -8,7 +8,6 @@ #include #include -#include #include #include #include @@ -70,16 +69,7 @@ void ParallelDimensionMap::build(Fusion* fusion) { // Simplify dim_map_ for (auto& [k, v] : dim_map_) { - // Well, this isn't really correct, but we need this assumption to better - // handle non-empty cases. If this turn out to be an issue, I believe we - // then need to find a more systematic way to handle empty tensor, rather - // than just disable this assumption. - auto assume = assume::tensorsAreNotEmpty(v); - if (assume != nullptr) { - v = simplifyExpr(v, {}, {assume}); - } else { - v = simplifyExpr(v); - } + v = simplifyExpr(v); } // Compute exact_types_ diff --git a/test/test_expr_simplifier.cpp b/test/test_expr_simplifier.cpp index f8a71b4500b..5985c289143 100644 --- a/test/test_expr_simplifier.cpp +++ b/test/test_expr_simplifier.cpp @@ -7,7 +7,6 @@ // clang-format on #include -#include #include #include #include @@ -1011,23 +1010,10 @@ TEST_F(ExprSimplifierTest, MinMax_CUDA) { auto expr = "max( max( ceilDiv( T0.size[0] , 128 ) * 4 , ceilDiv( T0.size[0] , 128 ) ) , 4 )"_; - EXPECT_TRUE(simplify(expr, assume::tensorsAreNotEmpty(expr)) + EXPECT_TRUE(simplify(expr, "T0.size[0] > 0"_b) ->sameAs("ceilDiv( T0.size[0] , 128 ) * 4"_)); } -TEST_F(ExprSimplifierTest, Assume_CUDA) { - auto expr = - "max( max( ceilDiv( T0.size[0] , 128 ) * 4 , ceilDiv( T0.size[1] , 128 ) ) , 4 )"_; - EXPECT_EQ( - simplifyExpr(IrBuilder::eqExpr( - assume::tensorsAreNotEmpty(expr), - "T0.size[0] > 0 && T0.size[1] > 0"_)) - ->getBool(), - true); - expr = "ceilDiv( T0.size[0] , T0.size[0] ) * T0.size[0]"_; - EXPECT_TRUE(assume::tensorsAreNotEmpty(expr)->sameAs("T0.size[0] > 0"_)); -} - TEST_F(ExprSimplifierTest, PredicateDivToMul_CUDA) { auto simplified = simplifyExpr("i1 / T0.size[0] < i2"_, {}, {"i1 >= 0"_b}); auto expect = "i1 < ( i2 * T0.size[0] )"_; diff --git a/test/test_gpu_tensorcore.cpp b/test/test_gpu_tensorcore.cpp index b7ad14fcb61..2b29c553799 100644 --- a/test/test_gpu_tensorcore.cpp +++ b/test/test_gpu_tensorcore.cpp @@ -27,6 +27,7 @@ #include #include #include +#include #include #include #include @@ -936,6 +937,9 @@ TEST_F(NVFuserTest, FusionAmpereSwizzle_CUDA) { fusion.addOutput(tv2); + optimization::OptimizationPass::runPass( + &fusion); + MatMulTileOptions gemm_tile; gemm_tile.cta_tile = GemmTile(128, 128, 32); gemm_tile.warp_tile = GemmTile(64, 64, 32);