Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions csrc/ir/builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,17 @@ Val* SimplifyingIrBuilder::addExpr(Val* lhs, Val* rhs) {
} else if (rhs->isConst()) {
return addExpr(lhs, rhs->value(), rhs->dtype());
} else {
// Simplify (-x) + x to 0
if (auto uop = dynamic_cast<UnaryOp*>(lhs->definition()); uop != nullptr &&
uop->getUnaryOpType() == UnaryOpType::Neg && uop->in()->sameAs(rhs)) {
return lhs->fusion()->zeroVal(lhs->dtype());
}
// Simplify x + (-x) to 0
if (auto uop = dynamic_cast<UnaryOp*>(rhs->definition()); uop != nullptr &&
uop->getUnaryOpType() == UnaryOpType::Neg && uop->in()->sameAs(lhs)) {
return lhs->fusion()->zeroVal(lhs->dtype());
}

return IrBuilder::addExpr(lhs, rhs);
}
}
Expand Down
78 changes: 66 additions & 12 deletions csrc/ops/alias.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
* SPDX-License-Identifier: BSD-3-Clause
*/
// clang-format on
#include <expr_evaluator.h>
#include <expr_simplifier.h>
#include <ir/builder.h>
#include <ir/utils.h>
Expand Down Expand Up @@ -747,17 +748,64 @@ TensorView* slice(
", Expected: ",
ndims);

const auto normalize_slice_range = [&manual_normalization](
Slice range, Val* extent) -> Slice {
ExpressionEvaluator expr_eval;

const auto get_int = [&expr_eval](Val* x) -> std::optional<int64_t> {
if (x == nullptr) {
return std::nullopt;
}
auto inferred_val = expr_eval.evaluate(x);
if (inferred_val.hasValue()) {
return inferred_val.as<int64_t>();
} else {
return std::nullopt;
}
};

// Specialized min for extents. Do some more simplification beyond
// SimplifyingIrBuilder that are only valid for extents.
const auto min_extents = [&](Val* x, Val* y) -> Val* {
auto x_int = get_int(x);
auto y_int = get_int(y);
// Since extents are never negative, if one is 0, that must be the mininum.
if (x_int == 0) {
return x;
} else if (y_int == 0) {
return y;
}
// Simplify patterns like min(min(x, 32), 32) to min(x, 32) as it
// isn't uncommon.
auto bop = dynamic_cast<BinaryOp*>(x->definition());
if (y_int != std::nullopt && bop != nullptr &&
bop->getBinaryOpType() == BinaryOpType::Min) {
if (auto lhs_int = get_int(bop->lhs()); lhs_int != std::nullopt) {
return SimplifyingIrBuilder::minExpr(
bop->rhs(), IrBuilder::create<Val>(std::min(*lhs_int, *y_int)));
} else if (auto rhs_int = get_int(bop->rhs()); rhs_int != std::nullopt) {
return SimplifyingIrBuilder::minExpr(
bop->lhs(), IrBuilder::create<Val>(std::min(*rhs_int, *y_int)));
}
}

return SimplifyingIrBuilder::minExpr(x, y);
};

const auto normalize_slice_range =
[&manual_normalization, &min_extents, &get_int](
Slice range, Val* extent) -> Slice {
auto cast_extent =
SimplifyingIrBuilder::maybeCastExpr(DataType::Index, extent);

auto zero = FusionGuard::getCurFusion()->zeroVal(DataType::Index);

auto start_int = get_int(range.start);
auto stop_int = get_int(range.stop);

// norm_start = max(0, start < 0 ? start + extent : start)
if (range.start == nullptr) {
range.start = zero;
} else if (!range.start->isZeroInt()) {
start_int = 0;
} else if (start_int != 0) {
range.start =
SimplifyingIrBuilder::maybeCastExpr(DataType::Index, range.start);
if (!manual_normalization) {
Expand All @@ -768,6 +816,7 @@ TensorView* slice(
SimplifyingIrBuilder::addExpr(range.start, cast_extent),
range.start));
}
start_int = get_int(range.start);
}

// norm_stop = max(norm_start, min(extent, stop < 0 ? stop + extent : stop)
Expand All @@ -776,15 +825,20 @@ TensorView* slice(
} else if (!range.stop->sameAs(extent)) {
range.stop =
SimplifyingIrBuilder::maybeCastExpr(DataType::Index, range.stop);
if (!manual_normalization) {
range.stop = SimplifyingIrBuilder::maxExpr(
range.start,
SimplifyingIrBuilder::minExpr(
cast_extent,
SimplifyingIrBuilder::whereExpr(
SimplifyingIrBuilder::ltExpr(range.stop, zero),
SimplifyingIrBuilder::addExpr(range.stop, cast_extent),
range.stop)));
// Commonly, range.start is zero and stop is non negative
if (start_int == 0 && stop_int >= 0) {
range.stop = min_extents(cast_extent, range.stop);
} else {
if (!manual_normalization) {
range.stop = SimplifyingIrBuilder::maxExpr(
range.start,
min_extents(
cast_extent,
SimplifyingIrBuilder::whereExpr(
SimplifyingIrBuilder::ltExpr(range.stop, zero),
SimplifyingIrBuilder::addExpr(range.stop, cast_extent),
range.stop)));
}
}
}

Expand Down
4 changes: 2 additions & 2 deletions tests/cpp/test_resize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1368,15 +1368,15 @@ TEST_F(ResizeTest, SliceExtentSimplification) {
// By default, the extent of the tv1 domain is:
// i0 + ( ( fmax(0, ( fmin(i0, 1) )) ) + ( -i0 ) )
// This should be simplified to just:
// fmax(0, ( fmin(i0, 1) ))
// fmin(i0, 1)

fusion.addOutput(tv1);

auto resize_extent = tv1->axis(0)->extent();
auto bop = dynamic_cast<BinaryOp*>(resize_extent->definition());
ASSERT_TRUE(bop != nullptr)
<< "Unexpected resize output extent: " << resize_extent->toInlineString();
EXPECT_EQ(bop->getBinaryOpType(), BinaryOpType::Max)
EXPECT_EQ(bop->getBinaryOpType(), BinaryOpType::Min)
<< "Unexpected resize output extent: " << resize_extent->toInlineString();
}

Expand Down