Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
161 changes: 161 additions & 0 deletions src/arith/canonical_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,27 @@ inline PrimExpr DivImpl(PrimExpr a, PrimExpr b, DivMode mode) {
}
}

/*!
 * \brief Check whether value fits in dtype, i.e. whether casting it to dtype is safe.
 * \param dtype The target dtype
 * \param value The value to be analyzed
 * \param analyzer The analyzer
 * \return Whether value fits in dtype
 */
bool CastIsSafe(DataType dtype, PrimExpr value, Analyzer* analyzer) {
  // Targets rejected by IsIndexType are never reported as safe.
  if (!IsIndexType(dtype)) {
    return false;
  }
  // Constant integer bounds of the value, as known to the analyzer.
  const ConstIntBound value_range = analyzer->const_int_bound(value);
  // Representable range of the target dtype.
  const int64_t dtype_max = Downcast<IntImm>(max_value(dtype))->value;
  const int64_t dtype_min = Downcast<IntImm>(min_value(dtype))->value;
  // A cast that does not reduce the bit width is treated as an upcast and accepted.
  const bool is_upcast = value.dtype().bits() <= dtype.bits();
  // Otherwise the proven bounds must lie entirely inside the target range.
  const bool within_range =
      value_range->max_value <= dtype_max && value_range->min_value >= dtype_min;
  return is_upcast || within_range;
}

/*!
* \brief Internal "Split normal form" of expression.
*
Expand Down Expand Up @@ -128,6 +149,58 @@ class SplitExprNode : public CanonicalExprNode {

void MulToSelf(int64_t scale) {
  // Multiply the stored scale member in place by the given factor
  // (the parameter shadows the member, hence the explicit this->).
  this->scale = this->scale * scale;
}

/*!
* \brief check if cast can be pushed to sub-expressions
* \param dtype The target datatype
* \param analyzer The analyzer
* \return whether the cast can be safely pushed to children
*/
bool CanPushCastToChildren(DataType dtype, Analyzer* analyzer) const {
  // cast(dtype, index % upper_factor / lower_factor * scale) ==
  // cast(dtype, index) % upper_factor / lower_factor * scale
  // holds when the cast does not narrow (dtype.bits >= self.dtype.bits),
  // or when every intermediate result fits in the range of dtype.
  const bool is_narrowing = dtype.bits() < this->dtype.bits();
  // Non-narrowing casts are accepted outright; so is a zero scale.
  if (!is_narrowing || this->scale == 0) {
    return true;
  }

  // Shorthand for "this intermediate expression is safe to cast to dtype".
  const auto fits = [dtype, analyzer](const PrimExpr& expr) {
    return CastIsSafe(dtype, expr, analyzer);
  };

  // Rebuild the expression step by step, validating each partial result.
  PrimExpr partial = this->index;
  if (!fits(partial)) {
    return false;
  }

  const bool has_mod = this->upper_factor != SplitExprNode::kPosInf;
  if (has_mod) {
    partial = ModImpl(partial, make_const(this->dtype, this->upper_factor), div_mode);
    if (!fits(partial)) {
      return false;
    }
  }

  const bool has_div = this->lower_factor != 1;
  if (has_div) {
    partial = DivImpl(partial, make_const(this->dtype, this->lower_factor), div_mode);
    if (!fits(partial)) {
      return false;
    }
  }

  const bool has_scale = this->scale != 1;
  if (has_scale) {
    // An unsigned expression must not carry a non-positive scale.
    ICHECK(!this->dtype.is_uint() || this->scale > 0);
    partial = partial * make_const(this->dtype, this->scale);
    if (!fits(partial)) {
      return false;
    }
  }

  return true;
}

/*!
* \brief self = cast(dtype, self)
* \param dtype The target datatype
*/
void PushCastToChildren(DataType dtype) {
  // Wrap the current index in a cast to the target type, then record the
  // target type as this node's dtype. The parameter shadows the member, so
  // members are always spelled with this->.
  PrimExpr casted_index = cast(dtype, this->index);
  this->index = casted_index;
  this->dtype = dtype;
}

inline bool IndexEqual(const SplitExpr& other) const;
inline bool DivModeCompatibleTo(DivMode mode) const;

Expand Down Expand Up @@ -255,6 +328,69 @@ class SumExprNode : public CanonicalExprNode {

void AddToSelf(const SumExpr& other, int64_t scale);

/*!
* \brief check if cast can be pushed to sub-expressions
* \param dtype The target datatype
* \param analyzer The analyzer
* \return whether the cast can be safely pushed to children
*/
bool CanPushCastToChildren(DataType dtype, Analyzer* analyzer) const {
  // cast(dtype, arg_1 + arg_2 + ... arg_n) ==
  // cast(dtype, arg_1) + ... + cast(dtype, arg_n)
  // iff it is an upcast (dtype.bits >= self.dtype.bits) or all of
  // its intermediate results fit in the range of dtype
  if (dtype.bits() >= this->dtype.bits()) {
    return true;  // upcast is safe
  }
  // From here on the cast is narrowing. All probe expressions below are built
  // in the current (wider) this->dtype: constants such as base or -base are
  // not guaranteed to be representable in the narrower target dtype (e.g. the
  // negation of the target type's minimum value).
  //
  // PushCastToChildren re-types this node to dtype while leaving base
  // untouched, so base itself must be representable in dtype. Rejecting an
  // out-of-range base here also guarantees that -base below cannot overflow
  // int64_t (NOTE(review): relies on CastIsSafe only accepting a narrowing
  // cast when the constant's bounds lie inside dtype's range).
  if (base != 0 && !CastIsSafe(dtype, make_const(this->dtype, base), analyzer)) {
    return false;
  }
  PrimExpr res = make_const(this->dtype, 0);
  // Accumulate terms in the same order as the original check: positive-scale
  // arguments first, validating every partial sum.
  for (size_t i = 0; i < args.size(); ++i) {
    if (args[i]->scale > 0) {
      res = res + args[i]->Normalize();
      if (!CastIsSafe(dtype, res, analyzer)) {
        return false;
      }
    }
  }
  if (base > 0) {
    res = res + make_const(this->dtype, base);
    if (!CastIsSafe(dtype, res, analyzer)) {
      return false;
    }
  }
  // negative scales follows using sub.
  for (size_t i = 0; i < args.size(); ++i) {
    if (args[i]->scale < 0) {
      res = res - args[i]->NormalizeWithScale(-1);
      if (!CastIsSafe(dtype, res, analyzer)) {
        return false;
      }
    }
  }
  if (base < 0) {
    res = res - make_const(this->dtype, -base);
    if (!CastIsSafe(dtype, res, analyzer)) {
      return false;
    }
  }
  // Finally every argument must itself allow the cast to be pushed inside.
  for (const auto& arg : args) {
    if (!arg->CanPushCastToChildren(dtype, analyzer)) {
      return false;
    }
  }
  return true;
}

/*!
* \brief self = cast(dtype, self)
* \param dtype The target datatype
*/
void PushCastToChildren(DataType dtype) {
  // Forward the cast to every argument, going through CopyOnWrite() to obtain
  // a writable node for each one, then adopt the target type for the sum.
  for (size_t idx = 0; idx < args.size(); ++idx) {
    auto* writable_arg = args[idx].CopyOnWrite();
    writable_arg->PushCastToChildren(dtype);
  }
  this->dtype = dtype;
}

static constexpr const char* _type_key = "arith.SumExpr";
TVM_DECLARE_FINAL_OBJECT_INFO(SumExprNode, CanonicalExprNode);

Expand Down Expand Up @@ -430,6 +566,7 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl {
PrimExpr VisitExpr_(const FloorDivNode* op) final;
PrimExpr VisitExpr_(const FloorModNode* op) final;
PrimExpr VisitExpr_(const ReduceNode* op) final;
PrimExpr VisitExpr_(const CastNode* op) final;

private:
/*!
Expand Down Expand Up @@ -1071,6 +1208,30 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const ReduceNode* op) {
return ret;
}

PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const CastNode* op) {
  const DataType target = op->dtype;
  // Only casts to index types are candidates; everything else takes the
  // regular rewriter path.
  if (!IsIndexType(target)) {
    return Rewriter::VisitExpr_(op);
  }

  // Bring the operand into canonical form first.
  PrimExpr canonical = this->CanonicalMutate(op->value);

  // Sum form: push the cast into every argument when that is safe.
  if (canonical.as<SumExprNode>() != nullptr) {
    SumExpr sum = Downcast<SumExpr>(canonical);
    if (sum->CanPushCastToChildren(target, analyzer_)) {
      sum.CopyOnWrite()->PushCastToChildren(target);
      return std::move(sum);
    }
  }

  // Split form: push the cast onto the index when that is safe.
  if (canonical.as<SplitExprNode>() != nullptr) {
    SplitExpr split = Downcast<SplitExpr>(canonical);
    if (split->CanPushCastToChildren(target, analyzer_)) {
      split.CopyOnWrite()->PushCastToChildren(target);
      return std::move(split);
    }
  }

  // The cast could not be pushed down; defer to the base rewriter.
  return Rewriter::VisitExpr_(op);
}

PrimExpr CanonicalSimplifier::operator()(const PrimExpr& expr) {
  // Thin forwarding wrapper: the work is done by the implementation object.
  PrimExpr simplified = impl_->CanonicalSimplify(expr);
  return simplified;
}
Expand Down
41 changes: 41 additions & 0 deletions tests/python/unittest/test_arith_canonical_simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,46 @@ def test_complex_cases():
ck.verify(res3, tdiv((x * 1024) + y, 256) - tdiv(y, 256) - (x * 4))


def test_simplify_cast():
    """Exercise canonical simplification of casts between int32 and int64 index expressions."""
    ck = CanonicalChecker()
    tcast = tvm.tir.Cast
    flm = tvm.te.floormod

    def const_i64(value):
        # Build an int64 constant; keeps the expressions below readable.
        return tvm.tir.const(value, "int64")

    # Upcast: cast(i64, i + j + 1) - cast(i64, i) with int32 operands.
    i = te.var("i", dtype="int32")
    j = te.var("j", dtype="int32")
    res = tcast("int64", i + j + 1) - tcast("int64", i)
    ck.verify(res, tcast("int64", j) + const_i64(1))

    # Downcast with small bounds: cast(i32, i + j + 1) - cast(i32, i).
    # Both variables are bounded to [0, 10], so every partial sum fits in int32.
    i = te.var("i", dtype="int64")
    j = te.var("j", dtype="int64")
    ck.analyzer.update(i, tvm.arith.ConstIntBound(0, 10))
    ck.analyzer.update(j, tvm.arith.ConstIntBound(0, 10))
    res = tcast("int32", i + j + 1) - tcast("int32", i)
    ck.verify(res, tcast("int32", j) + 1)

    # Downcast that must be left untouched: cast(i32, i + j - 100).
    # i may reach 2 ** 31 - 1, so the intermediate i + j can exceed int32.
    i = te.var("i", dtype="int64")
    j = te.var("j", dtype="int64")
    ck.analyzer.update(i, tvm.arith.ConstIntBound(0, 2 ** 31 - 1))
    ck.analyzer.update(j, tvm.arith.ConstIntBound(0, 10))
    res = tcast("int32", i + j - 100)
    ck.verify(res, res)

    # cast(i32, flm(axis, 7i64) * 2i64 + 1i64) + 1i32
    #     - cast(i32, flm(axis, 7i64) * 2i64)
    axis = te.var("axis", dtype="int64")
    ck.analyzer.update(axis, tvm.arith.ConstIntBound(0, 42))
    scaled_mod = flm(axis, const_i64(7)) * const_i64(2)
    res = (
        tcast("int32", scaled_mod + const_i64(1))
        + tvm.tir.const(1, "int32")
        - tcast("int32", scaled_mod)
    )
    ck.verify(res, 2)


if __name__ == "__main__":
test_floormod_simplify()
test_mul_sum_simplify()
Expand All @@ -321,3 +361,4 @@ def test_complex_cases():
test_split_index_simplify()
test_canonical_mixed()
test_complex_cases()
test_simplify_cast()