Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 32 additions & 8 deletions csrc/ir/nodes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3041,15 +3041,39 @@ IterDomain* IterDomain::resize(
// The overall extent is (in->extent() + left_expansion +
// right_expansion). This can be simplified for a slice op as
// the right expansion should look like (slice_end_offset -
// in->extent()), so the overall extent is left_expansion + slice_end_offset.
// in->extent()), or (slice_end_offset + (- in->extent())), so the
// overall extent is left_expansion + slice_end_offset.

// Detect common slice patterns and return a simplified Val
// representing (in->extent() + right_expansion) if possible
auto simplify_input_extent_plus_right_expansion = [](Val* right_expansion,
Val* in_extent) -> Val* {
auto bop = dynamic_cast<BinaryOp*>(right_expansion->definition());
if (bop == nullptr) {
return nullptr;
}
Val* sub_rhs = nullptr;
if (bop->getBinaryOpType() == BinaryOpType::Sub) {
sub_rhs = bop->rhs();
} else if (bop->getBinaryOpType() == BinaryOpType::Add) {
// Note that SimplifyingIrBuilder may turn (a - b) to (a + (- b))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems to be the crux of it, and it's why the previous pattern was not firing. Thanks for figuring this out!

In the future, perhaps we could perform this simplification in SimplifyingIrBuilder itself. To do so, we would need to "flatten" associative/commutative ops like add and mul, then perform cancellation before unflattening. Then anywhere we use SimplifyingIrBuilder::addExpr to construct a + (b + (c + (-a))) we could return b + c. That kind of transform is already done by simplifyExpr but that is only run automatically in certain cases like index expressions.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See #2020 for the SimplifyingIrBuilder approach.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Commented at the PR

if (auto uop = dynamic_cast<UnaryOp*>(bop->rhs()->definition());
uop != nullptr && uop->getUnaryOpType() == UnaryOpType::Neg) {
sub_rhs = uop->in();
}
}
if (sub_rhs == in_extent) {
return bop->lhs();
} else {
return nullptr;
}
};

Val* resized_id_size = nullptr;
if (right_expansion->definition() != nullptr &&
right_expansion->definition()->isA<BinaryOp>() &&
right_expansion->definition()->as<BinaryOp>()->getBinaryOpType() ==
BinaryOpType::Sub &&
right_expansion->definition()->as<BinaryOp>()->rhs() == in->extent()) {
resized_id_size = SimplifyingIrBuilder::addExpr(
left_expansion, right_expansion->definition()->as<BinaryOp>()->lhs());
if (auto simplified_val = simplify_input_extent_plus_right_expansion(
right_expansion, in->extent())) {
resized_id_size =
SimplifyingIrBuilder::addExpr(left_expansion, simplified_val);
} else {
resized_id_size = SimplifyingIrBuilder::addExpr(
SimplifyingIrBuilder::addExpr(
Expand Down
Loading