From 4994d8487a0e84615db9653f20c4d063397653ee Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Thu, 13 Feb 2025 18:05:44 -0800 Subject: [PATCH 1/6] Do some more simplifications specific to extents --- csrc/ir/builder.cpp | 11 +++++++ csrc/ops/alias.cpp | 73 +++++++++++++++++++++++++++++++++++++-------- 2 files changed, 72 insertions(+), 12 deletions(-) diff --git a/csrc/ir/builder.cpp b/csrc/ir/builder.cpp index ef8b95a70c8..5121a6104ce 100644 --- a/csrc/ir/builder.cpp +++ b/csrc/ir/builder.cpp @@ -380,6 +380,17 @@ Val* SimplifyingIrBuilder::addExpr(Val* lhs, Val* rhs) { } else if (rhs->isConst()) { return addExpr(lhs, rhs->value(), rhs->dtype()); } else { + // Simplify x + (-x) to 0 + if (auto neg_expr = dynamic_cast(lhs->definition()); + neg_expr != nullptr && neg_expr->getUnaryOpType() == UnaryOpType::Neg && + neg_expr->in()->sameAs(rhs)) { + return lhs->fusion()->zeroVal(lhs->dtype()); + } else if (auto neg_expr = dynamic_cast(rhs->definition()); + neg_expr != nullptr && + neg_expr->getUnaryOpType() == UnaryOpType::Neg && + neg_expr->in()->sameAs(lhs)) { + return lhs->fusion()->zeroVal(lhs->dtype()); + } return IrBuilder::addExpr(lhs, rhs); } } diff --git a/csrc/ops/alias.cpp b/csrc/ops/alias.cpp index 5729fed5b3f..fd543e37df4 100644 --- a/csrc/ops/alias.cpp +++ b/csrc/ops/alias.cpp @@ -747,17 +747,58 @@ TensorView* slice( ", Expected: ", ndims); - const auto normalize_slice_range = [&manual_normalization]( - Slice range, Val* extent) -> Slice { + const auto get_int = [](Val* x) -> std::optional { + if (x != nullptr && x->isConstInt()) { + return x->evaluate().as(); + } else { + return std::nullopt; + } + }; + + // Specialized min for extents. Do some more simplification beyond + // SimplifyingIrBuilder that are only valid for extents. + const auto min_extents = [&](Val* x, Val* y) -> Val* { + auto x_int = get_int(x); + auto y_int = get_int(y); + // Since extents are never negative, if one is 0, that must be the mininum. + if (x_int == 0) { + return x; + } else if (y_int == 0) { + return y; + } + // Simplify patterns like min(min(x, 32), 32) to min(x, 32) as it + // isn't uncommon. + auto bop = dynamic_cast(x->definition()); + if (y_int != std::nullopt && bop != nullptr && + bop->getBinaryOpType() == BinaryOpType::Min) { + if (auto lhs_int = get_int(bop->lhs()); lhs_int != std::nullopt) { + return SimplifyingIrBuilder::minExpr( + bop->rhs(), IrBuilder::create(std::min(*lhs_int, *y_int))); + } else if (auto rhs_int = get_int(bop->rhs()); rhs_int != std::nullopt) { + return SimplifyingIrBuilder::minExpr( + bop->lhs(), IrBuilder::create(std::min(*rhs_int, *y_int))); + } + } + + return SimplifyingIrBuilder::minExpr(x, y); + }; + + const auto normalize_slice_range = + [&manual_normalization, &min_extents, &get_int]( + Slice range, Val* extent) -> Slice { auto cast_extent = SimplifyingIrBuilder::maybeCastExpr(DataType::Index, extent); auto zero = FusionGuard::getCurFusion()->zeroVal(DataType::Index); + auto start_int = get_int(range.start); + auto stop_int = get_int(range.stop); + // norm_start = max(0, start < 0 ? start + extent : start) if (range.start == nullptr) { range.start = zero; - } else if (!range.start->isZeroInt()) { + start_int = 0; + } else if (start_int != 0) { range.start = SimplifyingIrBuilder::maybeCastExpr(DataType::Index, range.start); if (!manual_normalization) { @@ -768,6 +809,9 @@ TensorView* slice( SimplifyingIrBuilder::addExpr(range.start, cast_extent), range.start)); } + if (range.start->isConstInt()) { + start_int = range.start->evaluate().as(); + } } // norm_stop = max(norm_start, min(extent, stop < 0 ? stop + extent : stop) @@ -776,15 +820,20 @@ TensorView* slice( } else if (!range.stop->sameAs(extent)) { range.stop = SimplifyingIrBuilder::maybeCastExpr(DataType::Index, range.stop); - if (!manual_normalization) { - range.stop = SimplifyingIrBuilder::maxExpr( - range.start, - SimplifyingIrBuilder::minExpr( - cast_extent, - SimplifyingIrBuilder::whereExpr( - SimplifyingIrBuilder::ltExpr(range.stop, zero), - SimplifyingIrBuilder::addExpr(range.stop, cast_extent), - range.stop))); + // Commonly, range.start is zero and stop is non negative + if (start_int == 0 && stop_int >= 0) { + range.stop = min_extents(cast_extent, range.stop); + } else { + if (!manual_normalization) { + range.stop = SimplifyingIrBuilder::maxExpr( + range.start, + min_extents( + cast_extent, + SimplifyingIrBuilder::whereExpr( + SimplifyingIrBuilder::ltExpr(range.stop, zero), + SimplifyingIrBuilder::addExpr(range.stop, cast_extent), + range.stop))); + } } } From cd59f29943b951470592623085bc16b7681c3aa1 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Thu, 13 Feb 2025 19:16:49 -0800 Subject: [PATCH 2/6] test fix --- tests/cpp/test_gpu3.cpp | 22 ++++++++++++++++++++++ tests/cpp/test_resize.cpp | 4 ++-- 2 files changed, 24 insertions(+), 2 deletions(-) diff --git a/tests/cpp/test_gpu3.cpp b/tests/cpp/test_gpu3.cpp index 69b11a5d69f..43d3bc4ed19 100644 --- a/tests/cpp/test_gpu3.cpp +++ b/tests/cpp/test_gpu3.cpp @@ -9342,6 +9342,28 @@ TEST_F(NVFuserTest, RegisteredExactMappingWithExtentReplacment) { } } +TEST_F(NVFuserTest, TMP) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(2); + fusion.addInput(tv0); + + auto tv1 = sum(tv0, {1}); + fusion.addOutput(tv1); + + tv1->split(1, 4); + auto rf_tv = tv1->rFactor({-1}); + std::cerr << "RF: " << rf_tv->toString() << "\n"; + + fusion.print(); + + ComputeAtMap ca_map(&fusion); + scheduler_utils::propagateReshapeTransforms(&fusion, ca_map); + + fusion.print(); +} + // Test file size should be up to 10K LoC. Create a new file for more tests. } // namespace nvfuser diff --git a/tests/cpp/test_resize.cpp b/tests/cpp/test_resize.cpp index a4988592269..d70bb7c17e8 100644 --- a/tests/cpp/test_resize.cpp +++ b/tests/cpp/test_resize.cpp @@ -1397,7 +1397,7 @@ TEST_F(ResizeTest, SliceExtentSimplification) { // By default, the extent of the tv1 domain is: // i0 + ( ( fmax(0, ( fmin(i0, 1) )) ) + ( -i0 ) ) // This should be simplified to just: - // fmax(0, ( fmin(i0, 1) )) + // fmin(i0, 1) fusion.addOutput(tv1); @@ -1405,7 +1405,7 @@ TEST_F(ResizeTest, SliceExtentSimplification) { auto bop = dynamic_cast(resize_extent->definition()); ASSERT_TRUE(bop != nullptr) << "Unexpected resize output extent: " << resize_extent->toInlineString(); - EXPECT_EQ(bop->getBinaryOpType(), BinaryOpType::Max) + EXPECT_EQ(bop->getBinaryOpType(), BinaryOpType::Min) << "Unexpected resize output extent: " << resize_extent->toInlineString(); } From b979fc58e0be68ccc1d50b5c3fd9a00a4c62d00c Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Thu, 13 Feb 2025 22:42:17 -0800 Subject: [PATCH 3/6] cleanup --- tests/cpp/test_gpu3.cpp | 22 ---------------------- 1 file changed, 22 deletions(-) diff --git a/tests/cpp/test_gpu3.cpp b/tests/cpp/test_gpu3.cpp index 43d3bc4ed19..69b11a5d69f 100644 --- a/tests/cpp/test_gpu3.cpp +++ b/tests/cpp/test_gpu3.cpp @@ -9342,28 +9342,6 @@ TEST_F(NVFuserTest, RegisteredExactMappingWithExtentReplacment) { } } -TEST_F(NVFuserTest, TMP) { - Fusion fusion; - FusionGuard fg(&fusion); - - auto tv0 = makeSymbolicTensor(2); - fusion.addInput(tv0); - - auto tv1 = sum(tv0, {1}); - fusion.addOutput(tv1); - - tv1->split(1, 4); - auto rf_tv = tv1->rFactor({-1}); - std::cerr << "RF: " << rf_tv->toString() << "\n"; - - fusion.print(); - - ComputeAtMap ca_map(&fusion); - scheduler_utils::propagateReshapeTransforms(&fusion, ca_map); - - fusion.print(); -} - // Test file size should be up to 10K LoC. Create a new file for more tests. } // namespace nvfuser From d72668ed3ba2b6366b840b57df7e4382335a9e20 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Sun, 16 Feb 2025 10:59:37 -0800 Subject: [PATCH 4/6] clang-tidy --- csrc/ir/builder.cpp | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/csrc/ir/builder.cpp b/csrc/ir/builder.cpp index 5121a6104ce..756b0e1b98c 100644 --- a/csrc/ir/builder.cpp +++ b/csrc/ir/builder.cpp @@ -381,14 +381,18 @@ Val* SimplifyingIrBuilder::addExpr(Val* lhs, Val* rhs) { return addExpr(lhs, rhs->value(), rhs->dtype()); } else { // Simplify x + (-x) to 0 - if (auto neg_expr = dynamic_cast(lhs->definition()); - neg_expr != nullptr && neg_expr->getUnaryOpType() == UnaryOpType::Neg && - neg_expr->in()->sameAs(rhs)) { - return lhs->fusion()->zeroVal(lhs->dtype()); - } else if (auto neg_expr = dynamic_cast(rhs->definition()); - neg_expr != nullptr && - neg_expr->getUnaryOpType() == UnaryOpType::Neg && - neg_expr->in()->sameAs(lhs)) { + Val* x = nullptr; + auto uop = dynamic_cast(lhs->definition()); + if (uop != nullptr) { + // lhs may be (-x). Pick rhs as x + x = rhs; + } else { + uop = dynamic_cast(rhs->definition()); + // rhs may be (-x). Pick lhs as x + x = lhs; + } + if (uop != nullptr && uop->getUnaryOpType() == UnaryOpType::Neg && + uop->in()->sameAs(x)) { return lhs->fusion()->zeroVal(lhs->dtype()); } return IrBuilder::addExpr(lhs, rhs); From a891d4964bf404e76981091877b03011259e14b2 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Wed, 19 Feb 2025 16:11:42 -0800 Subject: [PATCH 5/6] PR feedback --- csrc/ir/builder.cpp | 20 ++++++++------------ csrc/ops/alias.cpp | 14 ++++++++------ 2 files changed, 16 insertions(+), 18 deletions(-) diff --git a/csrc/ir/builder.cpp b/csrc/ir/builder.cpp index 756b0e1b98c..e4551312c3a 100644 --- a/csrc/ir/builder.cpp +++ b/csrc/ir/builder.cpp @@ -380,21 +380,17 @@ Val* SimplifyingIrBuilder::addExpr(Val* lhs, Val* rhs) { } else if (rhs->isConst()) { return addExpr(lhs, rhs->value(), rhs->dtype()); } else { - // Simplify x + (-x) to 0 - Val* x = nullptr; - auto uop = dynamic_cast(lhs->definition()); - if (uop != nullptr) { - // lhs may be (-x). Pick rhs as x - x = rhs; - } else { - uop = dynamic_cast(rhs->definition()); - // rhs may be (-x). Pick lhs as x - x = lhs; + // Simplify (-x) + x to 0 + if (auto uop = dynamic_cast(lhs->definition()); uop != nullptr && + uop->getUnaryOpType() == UnaryOpType::Neg && uop->in()->sameAs(rhs)) { + return lhs->fusion()->zeroVal(lhs->dtype()); } - if (uop != nullptr && uop->getUnaryOpType() == UnaryOpType::Neg && - uop->in()->sameAs(x)) { + // Simplify x + (-x) to 0 + if (auto uop = dynamic_cast(rhs->definition()); uop != nullptr && + uop->getUnaryOpType() == UnaryOpType::Neg && uop->in()->sameAs(lhs)) { return lhs->fusion()->zeroVal(lhs->dtype()); } + return IrBuilder::addExpr(lhs, rhs); } } diff --git a/csrc/ops/alias.cpp b/csrc/ops/alias.cpp index fd543e37df4..759edb35d89 100644 --- a/csrc/ops/alias.cpp +++ b/csrc/ops/alias.cpp @@ -5,6 +5,7 @@ * SPDX-License-Identifier: BSD-3-Clause */ // clang-format on +#include #include #include #include @@ -747,9 +748,12 @@ TensorView* slice( ", Expected: ", ndims); - const auto get_int = [](Val* x) -> std::optional { - if (x != nullptr && x->isConstInt()) { - return x->evaluate().as(); + ExpressionEvaluator expr_eval; + + const auto get_int = [&expr_eval](Val* x) -> std::optional { + auto inferred_val = expr_eval.evaluate(x); + if (inferred_val.hasValue()) { + return inferred_val.as(); } else { return std::nullopt; } @@ -809,9 +813,7 @@ TensorView* slice( SimplifyingIrBuilder::addExpr(range.start, cast_extent), range.start)); } - if (range.start->isConstInt()) { - start_int = range.start->evaluate().as(); - } + start_int = get_int(range.start); } // norm_stop = max(norm_start, min(extent, stop < 0 ? stop + extent : stop) From 885d52f98a1b540017d8d40c09d508d5f52cc1d9 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Wed, 19 Feb 2025 17:38:53 -0800 Subject: [PATCH 6/6] fix --- csrc/ops/alias.cpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/csrc/ops/alias.cpp b/csrc/ops/alias.cpp index 759edb35d89..7e056a277f9 100644 --- a/csrc/ops/alias.cpp +++ b/csrc/ops/alias.cpp @@ -751,6 +751,9 @@ TensorView* slice( ExpressionEvaluator expr_eval; const auto get_int = [&expr_eval](Val* x) -> std::optional { + if (x == nullptr) { + return std::nullopt; + } auto inferred_val = expr_eval.evaluate(x); if (inferred_val.hasValue()) { return inferred_val.as();