Do some more simplifications specific to extents#3891
Conversation
|
Review updated until commit 885d52f Description
Changes walkthrough 📝
PR Reviewer Guide 🔍Here are some key observations to aid the review process:
|
|
!test |
|
!test |
|
Maybe it's time for me to revive #511. With runtime info we should be able to fully simplify these expressions at concretization. |
jacobhinkle
left a comment
There was a problem hiding this comment.
LGTM. Comments are minor.
csrc/ir/builder.cpp
Outdated
| // Simplify x + (-x) to 0 | ||
| Val* x = nullptr; | ||
| auto uop = dynamic_cast<UnaryOp*>(lhs->definition()); | ||
| if (uop != nullptr) { | ||
| // lhs may be (-x). Pick rhs as x | ||
| x = rhs; | ||
| } else { | ||
| uop = dynamic_cast<UnaryOp*>(rhs->definition()); | ||
| // rhs may be (-x). Pick lhs as x | ||
| x = lhs; | ||
| } | ||
| if (uop != nullptr && uop->getUnaryOpType() == UnaryOpType::Neg && | ||
| uop->in()->sameAs(x)) { | ||
| return lhs->fusion()->zeroVal(lhs->dtype()); | ||
| } |
There was a problem hiding this comment.
I think if you had abs(y) + (-abs(y)) this might not catch it because lhs=abs(y) is not a Neg. Instead, what about something like this?
| // Simplify x + (-x) to 0 | |
| Val* x = nullptr; | |
| auto uop = dynamic_cast<UnaryOp*>(lhs->definition()); | |
| if (uop != nullptr) { | |
| // lhs may be (-x). Pick rhs as x | |
| x = rhs; | |
| } else { | |
| uop = dynamic_cast<UnaryOp*>(rhs->definition()); | |
| // rhs may be (-x). Pick lhs as x | |
| x = lhs; | |
| } | |
| if (uop != nullptr && uop->getUnaryOpType() == UnaryOpType::Neg && | |
| uop->in()->sameAs(x)) { | |
| return lhs->fusion()->zeroVal(lhs->dtype()); | |
| } | |
| // Simplify x + (-x) to 0 | |
| Val* x = nullptr; | |
| Val* neg_x_in = nullptr; | |
| if (auto uop = dynamic_cast<UnaryOp*>(lhs->definition()); uop && uop->getUnaryOpType() == UnaryOpType::Neg) { | |
| // lhs is -x. Pick rhs as x | |
| neg_x_in = uop->in(); | |
| x = rhs; | |
| } else if (auto uop = dynamic_cast<UnaryOp*>(rhs->definition()); uop && uop->getUnaryOpType() == UnaryOpType::Neg) { | |
| // rhs is -x. Pick lhs as x | |
| neg_x_in = uop->in(); | |
| x = lhs; | |
| } | |
| if (x != nullptr && neg_x_in->sameAs(x)) { | |
| return lhs->fusion()->zeroVal(lhs->dtype()); | |
| } |
There was a problem hiding this comment.
That's right, thanks. My initial version was actually exactly like your suggested code, but clang-tidy didn't like it because of the repetition of the condition expressions. Let me try rewriting the code again.
csrc/ops/alias.cpp
Outdated
| const auto normalize_slice_range = [&manual_normalization]( | ||
| Slice range, Val* extent) -> Slice { | ||
| const auto get_int = [](Val* x) -> std::optional<int64_t> { | ||
| if (x != nullptr && x->isConstInt()) { |
There was a problem hiding this comment.
Minor nit: if you initialize an ExpressionEvaluator in this function I think you could use expr_eval.evaluate(x) to return a PolymorphicValue and check pv.hasValue() instead of using optional<int64_t>. Then you could avoid the isConstInt() condition here which will redundantly evaluate the val and replace it with isIntegralScalar().
csrc/ops/alias.cpp
Outdated
| if (range.start->isConstInt()) { | ||
| start_int = range.start->evaluate().as<int64_t>(); | ||
| } |
There was a problem hiding this comment.
See comment above if we have an expr_eval. Not important but will avoid re-computing these vals.
|
!test |
|
!test |
While working on #3848, I noticed
test_unpadded_catop_issue2275_repro2took an extremely long time (> 30 min). That seems largely due to index hoisting and expression simplification. It took just a few seconds when they were disabled. That's likely due to a lot ofminandmaxdue to slicing of symbolic extents, as shown below:This PR tries to simplifies these extents a little further, which results in:
The test time is reduced to several seconds by these simplifications.
Confirmed no failure with manual_ci.sh on an H100 machine.