Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 0 additions & 5 deletions src/arith/iter_affine_map.cc
Original file line number Diff line number Diff line change
Expand Up @@ -898,11 +898,6 @@ class IterMapRewriter : public ExprMutator {
PrimExpr SplitFloorModConst(IterSplitExpr lhs, PrimExpr base, PrimExpr rhs);

static void AddToLhs(IterSumExprNode* lhs, IterSplitExpr rhs, int sign) {
if (sign < 0 && is_const_int(rhs->extent, 2)) {
lhs->base -= rhs->scale;
sign = 1;
}

tir::ExprDeepEqual equal;
for (size_t i = 0; i < lhs->args.size(); ++i) {
IterSplitExpr lvalue = lhs->args[i];
Expand Down
17 changes: 13 additions & 4 deletions src/arith/rewrite_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,15 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AddNode* op) {

TVM_TRY_RECURSIVE_REWRITE(floordiv(x, 2) + floormod(x, 2), floordiv(x + 1, 2));

// Simplify (x + 1) % 2 + x % 2 => 1
// NOTE: we should avoid simplifying (x + 1) %2 => 1 - x % 2 though
// mainly because introducing extra negative signs to expression can harm itertaor
// analysis which usually relies on positive itertator co-efficients.
TVM_TRY_REWRITE_IF(floormod(x + c1, 2) + floormod(x, 2), OneWithTypeLike(x),
floormod(c1.Eval()->value, 2) == 1);
TVM_TRY_REWRITE_IF(floormod(x, 2) + floormod(x + c1, 2), OneWithTypeLike(x),
floormod(c1.Eval()->value, 2) == 1);

// canonicalization rule
// will try rewrite again after canonicalization.

Expand Down Expand Up @@ -1018,10 +1027,10 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorModNode* op) {
TVM_TRY_REWRITE_IF(floormod(x * c1 + y, c2), floormod(x * floormod(c1, c2) + y, c2),
c2.Eval()->value > 0);

TVM_TRY_RECURSIVE_REWRITE_IF(floormod(x + c1, 2), floormod(x, 2) * (-1) + 1,
floormod(c1.Eval()->value, 2) == 1);
TVM_TRY_REWRITE_IF(floormod(x + c1, c2), floormod(x, c2),
c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0);
// (x + 5) % 2 -> (x + 1) %2, (x + 3) % 3 => x
TVM_TRY_REWRITE_IF(
floormod(x + c1, c2), floormod(x + floormod(c1, c2), c2),
c2.Eval()->value > 0 && (c1.Eval()->value >= c2.Eval()->value || c1.Eval()->value < 0));

TVM_TRY_REWRITE_IF(floormod(x + y * c1, c2), floormod(x + y * floormod(c1, c2), c2),
c2.Eval()->value > 0);
Expand Down
36 changes: 36 additions & 0 deletions tests/python/unittest/test_arith_canonical_simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,5 +386,41 @@ def test_simplify_normalize_min_value_expr():
ck.verify(0 == x + te.min_value("int32"), False)


def test_proddiv_simplify():
ck = CanonicalChecker()
flm = tvm.te.floormod
fld = tvm.te.floordiv
tdiv = tvm.te.truncdiv
tmod = tvm.te.truncmod

x, y, z = te.var("x"), te.var("y"), te.var("y")

ck.verify(flm(x * 32 * x, x), 0)
ck.verify(flm(z * x * 32 * x * y, x * z), 0)
ck.verify(flm(z * x * 32 * x * y, x * z * y * 8 * x), 0)
ck.verify(flm(z * x * 32 * (x * y), 6 * x * z), flm(x * y * 16, 3) * (x * z * 2))
ck.verify(flm(x * 32 * x, x * z), flm(x * 32, z) * x)

ck.verify(tmod(x * 32 * x, x), 0)
ck.verify(tmod(z * x * 32 * x * y, x * z), 0)
ck.verify(tmod(z * x * 32 * (x * y), 6 * x * z), tmod(x * y * 16, 3) * (x * z * 2))
ck.verify(tmod(x * 32 * x, x * z), tmod(x * 32, z) * x)

ck.verify(fld(x * 2 * x * z, 4 * x * x * x), fld(z, x * 2))
ck.verify(fld(x * (2 * y) * 3, 3 * y), x * 2)
ck.verify(fld(x * (2 * y) * 3, 3 * y * z), fld(x * 2, z))

ck.verify(tdiv(x * 2 * x * z, 4 * x * x * x), tdiv(z, x * 2))
ck.verify(tdiv(x * (2 * y) * 3, 3 * y), x * 2)
ck.verify(tdiv(x * (2 * y) * 3, 3 * y * z), tdiv(x * 2, z))


def test_floormod_two():
ck = CanonicalChecker()
flm = tvm.te.floormod
x, y = te.var("x"), te.var("y")
ck.verify(flm(x * 10 + 1 + y * 2 + 2, 2), 1)


if __name__ == "__main__":
tvm.testing.main()
10 changes: 5 additions & 5 deletions tests/python/unittest/test_arith_iter_affine_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,14 +199,14 @@ def test_compound():
assert_iter_sum_pattern({z[0]: (18, 0, 1, sz), xi[0]: (5, 0)}, var_dom([(x, 10), (y, 9)]))


def test_compound_floormod_two():
def test_compound_floormod_two_regression():
x = tvm.tir.Var("x", "int32")
fld = tvm.tir.floordiv
flm = tvm.tir.floormod

# extent of 2 are normalized to positive scale
assert_iter_sum_pattern(
expect_dict={fld(x, 2) * 2 - flm(x, 2) + 1: (8, 0, 1)},
# regression
# extent of 2 of negative scale cannot be normalized
assert_iter_sum_failure(
[fld(x, 2) * 2 - flm(x, 2) + 1],
dom_map=var_dom([(x, 8)]),
)

Expand Down
22 changes: 15 additions & 7 deletions tests/python/unittest/test_arith_rewrite_simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,8 +392,8 @@ class TestSubIndex(BaseCompare):
TestCase(fld(x + 5, 3) - fld(x, 3), fld(flm(x, 3) + 5, 3)),
TestCase(fld(x + 5, 3) - fld(x + 2, 3), fld(flm(x + 2, 3), 3) + 1),
TestCase(fld(y, 3) * 3 - y, 0 - flm(y, 3)),
TestCase(y - fld(y - 6, 5) * 5, flm(y + (-6), 5) + 6),
TestCase(fld(y - 6, 5) * 5 - y, (-6) - flm(y + (-6), 5)),
TestCase(y - fld(y - 6, 5) * 5, flm(y + 4, 5) + 6),
TestCase(fld(y - 6, 5) * 5 - y, (-6) - flm(y + 4, 5)),
TestCase(y - fld(y + z, 5) * 5, flm(y + z, 5) - z),
TestCase(fld(y + z, 5) * 5 - y, z - flm(y + z, 5)),
TestCase(y - fld(y - z, 5) * 5, flm(y - z, 5) + z),
Expand Down Expand Up @@ -554,13 +554,15 @@ class TestFloormodIndex(BaseCompare):
TestCase(flm(x + 10, 2), flm(x, 2)),
TestCase(flm(x + y * 10, 2), flm(x, 2)),
TestCase(flm(x + y * 360, 16), flm(x + y * 8, 16)),
TestCase(flm(x * 10 + 1 + y * 2 + 2, 2), 1),
TestCase(flm(x * (-10), 2), 0),
TestCase(flm(x * (-10) + y, 2), flm(y, 2)),
TestCase(flm(x + (-10), 2), flm(x, 2)),
TestCase(flm(x + y * (-10), 2), flm(x, 2)),
TestCase(flm(x * 32 + y, 64), flm(x, 2) * 32 + y, [y >= 0, y < 32]),
TestCase(flm(x * 32 - y, 64), flm(x * 32 - y, 64), [y >= 0, y < 32]),
# NOTE: the followng case is covered by canonical simplify
# long range simplifcation in general can be covered by canonical simplify
# TestCase(flm(x * 10 + 1 + y * 2 + 2, 2), 1),
)


Expand All @@ -574,13 +576,14 @@ class TestFloorModTwo(BaseCompare):
require identifying more related terms in order to apply.

(x + c1)//2 - (x+c2)//2 => (x%2)*( c1%2 - c1%2 ) + (c1//2 - c2//2)

We should not introduce extra negative coeficient to iterators
however during simplification
"""

x, y, z = te.var("x"), te.var("y"), te.var("z")
test_case = tvm.testing.parameter(
# Removing offsets from floormod
TestCase(flm(x + 1, 2), flm(x, 2) * (-1) + 1),
TestCase(flm(x + 5, 2), flm(x, 2) * (-1) + 1),
TestCase(flm(x, 2) + flm(x + 1, 2), 1),
TestCase(flm(x + 1, 2) + flm(x, 2), 1),
# Difference of floordiv yields floormod
Expand All @@ -592,8 +595,13 @@ class TestFloorModTwo(BaseCompare):
# Sum of floordiv and floormod to yield floordiv
TestCase(fld(x + 1, 2) - flm(x, 2), fld(x, 2)),
TestCase(fld(x, 2) + flm(x, 2), fld(x + 1, 2)),
# Removal of floormod where possible
TestCase(flm(x + 1, 2) * 8192, x * (-8192) + 8192, [x >= 0, x < 2]),
# regression: although we can rewrite (x + 1) %2 => 1 - x%2
# doing so would introduce negative co-efficient to iterators
# which makes later iter map detection harder, in principle we
# should not introduce additional negative signs of iterator in rewriting
TestCase(flm(x + 1, 2), flm(x + 1, 2)),
TestCase(flm(x + 5, 2), flm(x + 1, 2)),
TestCase(flm(x + 1, 2) * 8192, flm(x + 1, 2) * 8192, [x >= 0, x < 2]),
)


Expand Down
Loading