diff --git a/csrc/ir/builder.cpp b/csrc/ir/builder.cpp index ef8b95a70c8..e4551312c3a 100644 --- a/csrc/ir/builder.cpp +++ b/csrc/ir/builder.cpp @@ -380,6 +380,17 @@ Val* SimplifyingIrBuilder::addExpr(Val* lhs, Val* rhs) { } else if (rhs->isConst()) { return addExpr(lhs, rhs->value(), rhs->dtype()); } else { + // Simplify (-x) + x to 0 + if (auto uop = dynamic_cast(lhs->definition()); uop != nullptr && + uop->getUnaryOpType() == UnaryOpType::Neg && uop->in()->sameAs(rhs)) { + return lhs->fusion()->zeroVal(lhs->dtype()); + } + // Simplify x + (-x) to 0 + if (auto uop = dynamic_cast(rhs->definition()); uop != nullptr && + uop->getUnaryOpType() == UnaryOpType::Neg && uop->in()->sameAs(lhs)) { + return lhs->fusion()->zeroVal(lhs->dtype()); + } + return IrBuilder::addExpr(lhs, rhs); } } diff --git a/csrc/ops/alias.cpp b/csrc/ops/alias.cpp index 5729fed5b3f..7e056a277f9 100644 --- a/csrc/ops/alias.cpp +++ b/csrc/ops/alias.cpp @@ -5,6 +5,7 @@ * SPDX-License-Identifier: BSD-3-Clause */ // clang-format on +#include #include #include #include @@ -747,17 +748,64 @@ TensorView* slice( ", Expected: ", ndims); - const auto normalize_slice_range = [&manual_normalization]( - Slice range, Val* extent) -> Slice { + ExpressionEvaluator expr_eval; + + const auto get_int = [&expr_eval](Val* x) -> std::optional { + if (x == nullptr) { + return std::nullopt; + } + auto inferred_val = expr_eval.evaluate(x); + if (inferred_val.hasValue()) { + return inferred_val.as(); + } else { + return std::nullopt; + } + }; + + // Specialized min for extents. Do some more simplification beyond + // SimplifyingIrBuilder that are only valid for extents. + const auto min_extents = [&](Val* x, Val* y) -> Val* { + auto x_int = get_int(x); + auto y_int = get_int(y); + // Since extents are never negative, if one is 0, that must be the mininum. + if (x_int == 0) { + return x; + } else if (y_int == 0) { + return y; + } + // Simplify patterns like min(min(x, 32), 32) to min(x, 32) as it + // isn't uncommon. + auto bop = dynamic_cast(x->definition()); + if (y_int != std::nullopt && bop != nullptr && + bop->getBinaryOpType() == BinaryOpType::Min) { + if (auto lhs_int = get_int(bop->lhs()); lhs_int != std::nullopt) { + return SimplifyingIrBuilder::minExpr( + bop->rhs(), IrBuilder::create(std::min(*lhs_int, *y_int))); + } else if (auto rhs_int = get_int(bop->rhs()); rhs_int != std::nullopt) { + return SimplifyingIrBuilder::minExpr( + bop->lhs(), IrBuilder::create(std::min(*rhs_int, *y_int))); + } + } + + return SimplifyingIrBuilder::minExpr(x, y); + }; + + const auto normalize_slice_range = + [&manual_normalization, &min_extents, &get_int]( + Slice range, Val* extent) -> Slice { auto cast_extent = SimplifyingIrBuilder::maybeCastExpr(DataType::Index, extent); auto zero = FusionGuard::getCurFusion()->zeroVal(DataType::Index); + auto start_int = get_int(range.start); + auto stop_int = get_int(range.stop); + // norm_start = max(0, start < 0 ? start + extent : start) if (range.start == nullptr) { range.start = zero; - } else if (!range.start->isZeroInt()) { + start_int = 0; + } else if (start_int != 0) { range.start = SimplifyingIrBuilder::maybeCastExpr(DataType::Index, range.start); if (!manual_normalization) { @@ -768,6 +816,7 @@ TensorView* slice( SimplifyingIrBuilder::addExpr(range.start, cast_extent), range.start)); } + start_int = get_int(range.start); } // norm_stop = max(norm_start, min(extent, stop < 0 ? stop + extent : stop) @@ -776,15 +825,20 @@ TensorView* slice( } else if (!range.stop->sameAs(extent)) { range.stop = SimplifyingIrBuilder::maybeCastExpr(DataType::Index, range.stop); - if (!manual_normalization) { - range.stop = SimplifyingIrBuilder::maxExpr( - range.start, - SimplifyingIrBuilder::minExpr( - cast_extent, - SimplifyingIrBuilder::whereExpr( - SimplifyingIrBuilder::ltExpr(range.stop, zero), - SimplifyingIrBuilder::addExpr(range.stop, cast_extent), - range.stop))); + // Commonly, range.start is zero and stop is non negative + if (start_int == 0 && stop_int >= 0) { + range.stop = min_extents(cast_extent, range.stop); + } else { + if (!manual_normalization) { + range.stop = SimplifyingIrBuilder::maxExpr( + range.start, + min_extents( + cast_extent, + SimplifyingIrBuilder::whereExpr( + SimplifyingIrBuilder::ltExpr(range.stop, zero), + SimplifyingIrBuilder::addExpr(range.stop, cast_extent), + range.stop))); + } } } diff --git a/tests/cpp/test_resize.cpp b/tests/cpp/test_resize.cpp index eaa1418b6b1..3f2b0c36f3a 100644 --- a/tests/cpp/test_resize.cpp +++ b/tests/cpp/test_resize.cpp @@ -1368,7 +1368,7 @@ TEST_F(ResizeTest, SliceExtentSimplification) { // By default, the extent of the tv1 domain is: // i0 + ( ( fmax(0, ( fmin(i0, 1) )) ) + ( -i0 ) ) // This should be simplified to just: - // fmax(0, ( fmin(i0, 1) )) + // fmin(i0, 1) fusion.addOutput(tv1); @@ -1376,7 +1376,7 @@ TEST_F(ResizeTest, SliceExtentSimplification) { auto bop = dynamic_cast(resize_extent->definition()); ASSERT_TRUE(bop != nullptr) << "Unexpected resize output extent: " << resize_extent->toInlineString(); - EXPECT_EQ(bop->getBinaryOpType(), BinaryOpType::Max) + EXPECT_EQ(bop->getBinaryOpType(), BinaryOpType::Min) << "Unexpected resize output extent: " << resize_extent->toInlineString(); }