Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 111 additions & 24 deletions src/arith/rewrite_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -221,12 +221,21 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AddNode* op) {

TVM_TRY_REWRITE_IF(min(x, y + z * c1) + z * c2, min(x + z * c2, y),
c1.Eval()->value == -c2.Eval()->value);
TVM_TRY_REWRITE_IF(max(x, y + z * c1) + z * c2, max(x + z * c2, y),
c1.Eval()->value == -c2.Eval()->value);
TVM_TRY_REWRITE_IF(min(y + z * c1, x) + z * c2, min(x + z * c2, y),
c1.Eval()->value == -c2.Eval()->value);
TVM_TRY_REWRITE_IF(min(x, z * c1 + y) + z * c2, min(x + z * c2, y),
c1.Eval()->value == -c2.Eval()->value);
TVM_TRY_REWRITE_IF(min(z * c1 + y, x) + z * c2, min(x + z * c2, y),
c1.Eval()->value == -c2.Eval()->value);

TVM_TRY_REWRITE_IF(max(x, y + z * c1) + z * c2, max(x + z * c2, y),
c1.Eval()->value == -c2.Eval()->value);
TVM_TRY_REWRITE_IF(max(y + z * c1, x) + z * c2, max(x + z * c2, y),
c1.Eval()->value == -c2.Eval()->value);
TVM_TRY_REWRITE_IF(max(x, z * c1 + y) + z * c2, max(x + z * c2, y),
c1.Eval()->value == -c2.Eval()->value);
TVM_TRY_REWRITE_IF(max(z * c1 + y, x) + z * c2, max(x + z * c2, y),
c1.Eval()->value == -c2.Eval()->value);

TVM_TRY_REWRITE(max(x, y) + min(x, y), x + y);
TVM_TRY_REWRITE(min(x, y) + max(x, y), x + y);
Expand All @@ -241,13 +250,22 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AddNode* op) {
// constant folding
// NOTE: canonicalization might better at this.
TVM_TRY_REWRITE((x + c1) + c2, x + (c1 + c2));
TVM_TRY_REWRITE((c1 - y) + c2, (c1 + c2) - y);
TVM_TRY_REWRITE((y - c1) + c2, y + (c2 - c1));

TVM_TRY_REWRITE(c1 + (x + c2), x + (c1 + c2));
TVM_TRY_REWRITE((x + c1) + c2, x + (c1 + c2));
TVM_TRY_REWRITE(c1 + (x - c2), x + (c1 - c2));
TVM_TRY_REWRITE((x - c1) + c2, x + (c2 - c1));
TVM_TRY_REWRITE(c1 + (c2 - x), (c1 + c2) - x);
TVM_TRY_REWRITE((c1 - x) + c2, (c1 + c2) - x);

// mul co-efficient folding
TVM_TRY_REWRITE(x + x, x * 2);
TVM_TRY_REWRITE(x * y + x, x * (y + 1));
TVM_TRY_REWRITE(y * x + x, x * (y + 1));
TVM_TRY_REWRITE(x + y * x, x * (1 + y));
TVM_TRY_REWRITE(x + x * y, x * (1 + y));
TVM_TRY_REWRITE(y * x + x, (y + 1) * x);
TVM_TRY_REWRITE(x + y * x, (y + 1) * x);
TVM_TRY_REWRITE(x + x * y, x * (y + 1));
TVM_TRY_REWRITE(x * y + x * z, x * (y + z));
TVM_TRY_REWRITE(y * x + x * z, x * (y + z));
TVM_TRY_REWRITE(x * y + z * x, x * (y + z));
Expand All @@ -265,12 +283,20 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AddNode* op) {
TVM_TRY_REWRITE_IF(floordiv(floormod(x, c2) + c1, c2) + floordiv(x, c2), floordiv(x + c1, c2),
c2.Eval()->value > 0);

auto one = PConst<PrimExpr>(make_const(op->dtype, 1));
TVM_TRY_REWRITE(floormod(x, 2) + floormod(x + 1, 2), one);
TVM_TRY_REWRITE(floormod(x + 1, 2) + floormod(x, 2), one);

// canonicalization rule
// will try rewrite again after canonicalization.
TVM_TRY_RECURSIVE_REWRITE(c1 + x, x + c1);
TVM_TRY_RECURSIVE_REWRITE(x + (c1 - y), (x - y) + c1);
TVM_TRY_RECURSIVE_REWRITE((c1 - y) + x, (x - y) + c1);
TVM_TRY_RECURSIVE_REWRITE(x + c1 + y, (x + y) + c1);
TVM_TRY_RECURSIVE_REWRITE(x + (c1 + y), (x + y) + c1);
TVM_TRY_RECURSIVE_REWRITE(x + (y - c1), (x + y) - c1);
TVM_TRY_RECURSIVE_REWRITE((y - c1) + x, (x + y) - c1);
TVM_TRY_RECURSIVE_REWRITE((x + c1) + y, (x + y) + c1);
TVM_TRY_RECURSIVE_REWRITE(x + (y + c1), (x + y) + c1);

TVM_TRY_RECURSIVE_REWRITE(x + max(y, z), max(y, z) + x);
TVM_TRY_RECURSIVE_REWRITE(x + min(y, z), min(y, z) + x);

Expand Down Expand Up @@ -359,6 +385,15 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const SubNode* op) {
TVM_TRY_REWRITE(x - min(x, y), max(0, x - y));
TVM_TRY_REWRITE(y - min(x, y), max(y - x, 0));

// constant cancelation
TVM_TRY_REWRITE(c1 - (x + c2), (c1 - c2) - x);
TVM_TRY_REWRITE(c1 - (x - c2), (c1 + c2) - x);
TVM_TRY_REWRITE(c1 - (c2 - x), x + (c1 - c2));
TVM_TRY_REWRITE((x + c1) - c2, x + (c1 - c2));
TVM_TRY_REWRITE((x - c1) - c2, x - (c1 + c2));
TVM_TRY_REWRITE((c1 - x) - c2, (c1 - c2) - x);
TVM_TRY_RECURSIVE_REWRITE((c1 - x) - (c2 - y), (y - x) + (c1 - c2));

// mul co-efficient folding
TVM_TRY_REWRITE(x - x, ZeroWithTypeLike(x));
TVM_TRY_REWRITE(x * y - x, x * (y - 1));
Expand All @@ -370,10 +405,6 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const SubNode* op) {
TVM_TRY_REWRITE(x * y - z * x, x * (y - z));
TVM_TRY_REWRITE(y * x - z * x, x * (y - z));

// constant cancelation
TVM_TRY_REWRITE((x + c1) - c2, x + (c1 - c2));
TVM_TRY_REWRITE((c1 - x) - (c2 - y), (y - x) + (c1 - c2));

// cancelization rule involving 4 operands
TVM_TRY_REWRITE((x + y) - (x + z), y - z);
TVM_TRY_REWRITE((x + y) - (z + x), y - z);
Expand Down Expand Up @@ -493,6 +524,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const SubNode* op) {
// canonicalization rule
// will try rewrite again after canonicalization.
TVM_TRY_REWRITE(x - c1, x + (0 - c1));
TVM_TRY_RECURSIVE_REWRITE(x - (y + c1), (x - y) - c1);
TVM_TRY_RECURSIVE_REWRITE((x + c1) - y, (x - y) + c1);
TVM_TRY_RECURSIVE_REWRITE(x - (y - z), (x + z) - y);
TVM_TRY_RECURSIVE_REWRITE(x - y * c1, x + y * (0 - c1));
Expand Down Expand Up @@ -536,7 +568,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const MulNode* op) {

if (IsIndexType(op->dtype)) {
// constant simplification rule
TVM_TRY_REWRITE((x + c1) * c2, x * c2 + c1 * c2);
TVM_TRY_REWRITE_IF((x + c1) * c2, x * c2 + c1 * c2, c1.Eval()->value != 1);
TVM_TRY_REWRITE((x * c1) * c2, x * (c1 * c2));
TVM_TRY_REWRITE(min(x, y) * max(x, y), x * y);
TVM_TRY_REWRITE(max(x, y) * min(x, y), x * y);
Expand Down Expand Up @@ -870,7 +902,8 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorDivNode* op) {
PrimExpr y_div = CanProveEqual(floordiv(yval, c2val), 0) ? 0 : floordiv(yval, c2val);
auto bound = analyzer_->const_int_bound(residue);
if (bound.defined() && bound->max_value == bound->min_value) {
return x.Eval() * floordiv(c1val, c2.Eval()) + (y_div + Integer(bound->max_value));
PrimExpr out = x.Eval() * floordiv(c1val, c2.Eval()) + (y_div + Integer(bound->max_value));
return RecursiveRewrite(out);
}

// try simplify divisor
Expand All @@ -882,7 +915,8 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorDivNode* op) {
// ==> x' + d + (b * c1 + e) // c2
// ==> x' + d since 0 <= b * c1 <= (a-1) * c1, 0 <= e < c1
// ==> x // (c2 // c1) + (y // c2)
return floordiv(x.Eval(), floordiv(c2val, c1val)) + y_div;
PrimExpr out = floordiv(x.Eval(), floordiv(c2val, c1val)) + y_div;
return RecursiveRewrite(out);
}
}

Expand All @@ -903,6 +937,13 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorDivNode* op) {
TVM_TRY_REWRITE_IF(floordiv(max(y, x * c1), c2), max(floordiv(y, c2), x * floordiv(c1, c2)),
c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0);

TVM_TRY_RECURSIVE_REWRITE_IF(
floordiv(x + c1, c2), floordiv(x + floormod(c1, c2), c2) + floordiv(c1, c2),
c2.Eval()->value > 0 && (c1.Eval()->value != floormod(c1.Eval()->value, c2.Eval()->value)));
TVM_TRY_RECURSIVE_REWRITE_IF(floordiv(x - c1, c2),
floordiv(x + floormod(-1 * c1, c2), c2) + floordiv(-1 * c1, c2),
c2.Eval()->value > 0);

// Rules involving 3-operands.
TVM_TRY_REWRITE_IF(floordiv(x * c1 + y + z, c2), x * floordiv(c1, c2) + floordiv(y + z, c2),
c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0);
Expand Down Expand Up @@ -951,6 +992,28 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorDivNode* op) {
CanProveGreaterEqual(z.Eval(), 0));

TVM_TRY_REWRITE_IF(floordiv(x - floormod(x, c1), c1), floordiv(x, c1), c1.Eval()->value != 0);

if (floordiv(x * c1, c2).Match(ret)) {
auto c1val = c1.Eval()->value;
auto c2val = c2.Eval()->value;
IntImm gcd(x.Eval()->dtype, ZeroAwareGCD(c1val, c2val));

if (gcd->value > 1) {
// floormod(i*gcd, j*gcd) == floormod(i,j) * gcd
return floordiv(x.Eval() * floordiv(c1val, gcd), floordiv(c2val, gcd));
}
}

if (floordiv(x, c1).Match(ret)) {
ModularSet mod = analyzer_->modular_set(x.Eval());
int64_t c1val = c1.Eval()->value;
IntImm gcd(x.Eval()->dtype, ZeroAwareGCD(c1val, mod->coeff));
if (gcd->value > 1) {
// floordiv(i*gcd, j*gcd) == floordiv(i,j)
PrimExpr out = floordiv(floordiv(x.Eval(), gcd), floordiv(c1.Eval(), gcd));
return RecursiveRewrite(out);
}
}
}
return ret;
}
Expand Down Expand Up @@ -1002,24 +1065,38 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorModNode* op) {

if (IsIndexType(op->dtype)) {
// Be-aware of the division rules: we use floordiv/floormod here
TVM_TRY_REWRITE_IF(floormod(x * c1, c2), floormod(x * floormod(c1, c2), c2),
c2.Eval()->value != 0);
TVM_TRY_REWRITE_IF(
floormod(x * c1, c2), floormod(x * floormod(c1, c2), c2),
c2.Eval()->value != 0 && c1.Eval()->value != floormod(c1.Eval()->value, c2.Eval()->value));

TVM_TRY_REWRITE_IF(floormod(x * c1 + y, c2), floormod(x, floordiv(c2, c1)) * c1 + y,
c1.Eval()->value > 0 && c2.Eval()->value > 0 &&
c2.Eval()->value % c1.Eval()->value == 0 &&
CanProveEqual(floordiv(y.Eval(), c1.Eval()), 0));

TVM_TRY_REWRITE_IF(floormod(x * c1 + y, c2), floormod(x * floormod(c1, c2) + y, c2),
c2.Eval()->value > 0);
TVM_TRY_RECURSIVE_REWRITE_IF(
floormod(x * c1 + y, c2), floormod(x * floormod(c1, c2) + y, c2),
c2.Eval()->value > 0 && (c1.Eval()->value != floormod(c1.Eval()->value, c2.Eval()->value)));

TVM_TRY_REWRITE_IF(floormod(x + c1, c2), floormod(x, c2),
c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0);
TVM_TRY_RECURSIVE_REWRITE_IF(
floormod(x + c1, c2), floormod(x + floormod(c1, c2), c2),
c2.Eval()->value > 0 && (c1.Eval()->value != floormod(c1.Eval()->value, c2.Eval()->value)));
TVM_TRY_RECURSIVE_REWRITE_IF(floormod(x - c1, c2), floormod(x + floormod(-1 * c1, c2), c2),
c2.Eval()->value > 0);

TVM_TRY_REWRITE_IF(floormod(x + y * c1, c2), floormod(x + y * floormod(c1, c2), c2),
c2.Eval()->value > 0);
TVM_TRY_RECURSIVE_REWRITE_IF(
floormod(x + y * c1, c2), floormod(x + y * floormod(c1, c2), c2),
c2.Eval()->value > 0 && (c1.Eval()->value != floormod(c1.Eval()->value, c2.Eval()->value)));

TVM_TRY_RECURSIVE_REWRITE_IF(
floormod(x * c1, x * c2), x * floormod(c1, c2),
c2.Eval()->value != 0 &&
(c1.Eval()->value != floormod(c1.Eval()->value, c2.Eval()->value)));

TVM_TRY_REWRITE_IF(floormod(x * c1, x * c2), x * floormod(c1, c2), c2.Eval()->value != 0);
TVM_TRY_RECURSIVE_REWRITE_IF(floormod(floormod(x, c1) + y, c2), floormod(x + y, c2),
floormod(c1.Eval()->value, c2.Eval()->value) == 0);
TVM_TRY_RECURSIVE_REWRITE_IF(floormod(x + floormod(y, c1), c2), floormod(x + y, c2),
floormod(c1.Eval()->value, c2.Eval()->value) == 0);

TVM_TRY_REWRITE(floormod(x * y, y), ZeroWithTypeLike(x));
TVM_TRY_REWRITE(floormod(y * x, y), ZeroWithTypeLike(y));
Expand All @@ -1031,6 +1108,13 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorModNode* op) {
if (mod->coeff % c1val == 0 && c1val > 0) {
return floormod(mod->base, c1).Eval();
}
if (c1val > 0 && mod->coeff > 1) {
IntImm gcd(x.Eval()->dtype, ZeroAwareGCD(c1val, mod->coeff));
if (gcd->value > 1) {
PrimExpr out = floormod(floordiv(x.Eval(), gcd), floordiv(c1.Eval(), gcd)) * gcd;
return RecursiveRewrite(out);
}
}
}
}
return ret;
Expand Down Expand Up @@ -1402,7 +1486,7 @@ PrimExpr RewriteSimplifier::Impl::ApplyRewriteRules(EQ ret) {
// Pattern var to match any expression
PVar<PrimExpr> x, y;
// Pattern var match IntImm
PVar<IntImm> c1;
PVar<IntImm> c1, c2, c3;
PVar<int> lanes;

// vector rule
Expand All @@ -1424,6 +1508,9 @@ PrimExpr RewriteSimplifier::Impl::ApplyRewriteRules(EQ ret) {
TVM_TRY_REWRITE(c1 - x == 0, x == c1);
TVM_TRY_REWRITE(x + c1 == 0, x == 0 - c1);
TVM_TRY_RECURSIVE_REWRITE(x * y == 0, x == 0 || y == 0);

TVM_TRY_RECURSIVE_REWRITE(floormod(x + c1, c2) == c3,
c3 == floormod(c3, c2) && floormod(x, c2) == floormod(c3 - c1, c2));
}
return std::move(ret);
}
Expand Down
31 changes: 19 additions & 12 deletions tests/python/unittest/test_arith_rewrite_simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def test_vector_simplify():
ck.verify(fld(tvm.tir.Ramp(x, 8, 5), tvm.tir.Broadcast(4, 5)), tvm.tir.Ramp(fld(x, 4), 2, 5))
ck.verify(
fld(tvm.tir.Ramp(flm(x * 4, 256), 1, 4), tvm.tir.Broadcast(8, 4)),
tvm.tir.Broadcast(fld(flm(x * 4, 256), 8), 4),
tvm.tir.Broadcast(fld(flm(x, 64), 2), 4),
)
ck.verify(
fld(tvm.tir.Ramp(x, 7, 4), tvm.tir.Broadcast(4, 4)),
Expand Down Expand Up @@ -136,10 +136,10 @@ def test_vector_simplify():
flm(tvm.tir.Ramp(3, 1, 4), tvm.tir.Broadcast(4, 4)),
)
ck.verify(
flm(tvm.tir.Ramp(x * 4, 1, 4), tvm.tir.Broadcast(64, 4)), tvm.tir.Ramp(flm(x * 4, 64), 1, 4)
flm(tvm.tir.Ramp(x * 4, 1, 4), tvm.tir.Broadcast(64, 4)), tvm.tir.Ramp(flm(x, 16) * 4, 1, 4)
)
ck.verify(
flm(tvm.tir.Ramp(x * 8, 2, 4), tvm.tir.Broadcast(64, 4)), tvm.tir.Ramp(flm(x * 8, 64), 2, 4)
flm(tvm.tir.Ramp(x * 8, 2, 4), tvm.tir.Broadcast(64, 4)), tvm.tir.Ramp(flm(x, 8) * 8, 2, 4)
)
ck.verify(
flm(tvm.tir.Ramp(x * 4, 1, 5), tvm.tir.Broadcast(64, 5)),
Expand Down Expand Up @@ -255,7 +255,7 @@ def test_add_index_simplify():
ck.verify(x * y + 10 * x, x * (y + 10))

ck.verify((2 * z) + tvm.te.min(x, y - (2 * z)), tvm.te.min(x + (z * 2), y))
ck.verify(y * x + x, x * (y + 1))
ck.verify(y * x + x, (y + 1) * x)
ck.verify(x * y + x, x * (y + 1))
ck.verify((x + 10) + 13, x + 23)
ck.verify((x + 10) + (13 + z), x + z + 23)
Expand All @@ -281,7 +281,7 @@ def test_add_index_simplify():
flm = tvm.te.floormod
ck.verify(y * flm(x, 8) + 10 * flm(x, 8), flm(x, 8) * (y + 10))
ck.verify(fld(x, 8) * 8 + flm(x, 8), x)
ck.verify(fld(flm(x, 2) + 7, 2) + fld(x, 2), fld(x + 7, 2))
ck.verify(fld(flm(x, 2) + 7, 2) + fld(x, 2), fld(x + 1, 2) + 3)


def test_sub_index_simplify():
Expand Down Expand Up @@ -377,12 +377,12 @@ def test_sub_index_simplify():
ck.analyzer.update(x, tvm.arith.ConstIntBound(-1000, 1000), override=True)
ck.analyzer.update(y, tvm.arith.ConstIntBound(-1000, 1000), override=True)
ck.verify(x - fld(x, 3) * 3, flm(x, 3))
ck.verify(fld(x + 5, 3) - fld(x, 3), fld(flm(x, 3) + 5, 3))
ck.verify(fld(x + 5, 3) - fld(x + 2, 3), fld(flm(x + 2, 3), 3) + 1)
ck.verify(fld(x + 5, 3) - fld(x, 3), fld(flm(x, 3) + 2, 3) + 1)
ck.verify(fld(x + 5, 3) - fld(x + 2, 3), 1)

ck.verify(fld(y, 3) * 3 - y, 0 - flm(y, 3))
ck.verify(y - fld(y - 6, 5) * 5, flm(y + (-6), 5) + 6)
ck.verify(fld(y - 6, 5) * 5 - y, (-6) - flm(y + (-6), 5))
ck.verify(y - fld(y - 6, 5) * 5, flm(y + 4, 5) + 6)
ck.verify(fld(y - 6, 5) * 5 - y, (-6) - flm(y + 4, 5))
ck.verify(y - fld(y + z, 5) * 5, flm(y + z, 5) - z)
ck.verify(fld(y + z, 5) * 5 - y, z - flm(y + z, 5))
ck.verify(y - fld(y - z, 5) * 5, flm(y - z, 5) + z)
Expand Down Expand Up @@ -471,14 +471,14 @@ def test_floordiv_index_simplify():
ck.verify(fld(x * 4, 2), x * 2)
ck.verify(fld(x * 8 + 7, 16), fld(x, 2))
ck.verify(fld(x * 8 + 39, 16), fld(x, 2) + 2)
ck.verify(fld(x * 8 - 1, 16), fld(x * 8 + -1, 16))
ck.verify(fld(x * 8 - 1, 16), fld(x + 1, 2) + -1)
ck.verify(fld(x * 8 - 9, 16), fld(x, 2) + -1)

ck.analyzer.update(x, tvm.arith.ConstIntBound(0, 1), override=True)
ck.analyzer.update(y, tvm.arith.ConstIntBound(0, 7), override=True)
ck.verify(fld(x * 360 + y, 16), x * 22)
ck.verify(fld(x * 360 + y, 25), x * 14)
ck.verify(fld(x * 360 - 8, 25), fld(x * 360 + -8, 25))
ck.verify(fld(x * 360 - 8, 25), fld(x * 72 + 3, 5) + -1)

ck.verify(fld(x * 4 + y, 2), x * 2 + fld(y, 2))
ck.verify(fld(tvm.te.min(x * 6, y), 2), tvm.te.min(x * 3, fld(y, 2)))
Expand All @@ -488,6 +488,9 @@ def test_floordiv_index_simplify():
ck.verify(fld(tvm.te.min(y, x * 6), 2), tvm.te.min(fld(y, 2), x * 3))
ck.verify(fld(tvm.te.max(y, x * 6), 2), tvm.te.max(fld(y, 2), x * 3))

# removal of negative offsets
ck.verify(fld(x - 17, 5), fld(x + 3, 5) + -4)

# 3-operands
ck.verify(fld(x * 6 + y + z, 2), x * 3 + fld(y + z, 2))
ck.verify(fld(x * 6 - y + (y + 3), 2), x * 3 + 1)
Expand Down Expand Up @@ -562,7 +565,7 @@ def test_floormod_index_simplify():
x, y, nx, ny, z = te.var("x"), te.var("y"), te.var("nx"), te.var("ny"), te.var("z")

ck.verify(flm(x * 10, 2), 0)
ck.verify(flm(x * 9600, 6400), flm(x * 3200, 6400))
ck.verify(flm(x * 9600, 6400), flm(x, 2) * 3200)
ck.verify(flm(x * 10 + y, 2), flm(y, 2))
ck.verify(flm(x * 360 + y, 16), flm(x * 8 + y, 16))
ck.verify(flm(x + 10, 2), flm(x, 2))
Expand All @@ -574,6 +577,10 @@ def test_floormod_index_simplify():
ck.verify(flm(x + (-10), 2), flm(x, 2))
ck.verify(flm(x + y * (-10), 2), flm(x, 2))

# removal of negative offsets
ck.verify(flm(x - 17, 5), flm(x + 3, 5))
ck.verify(flm(x + 17, 5), flm(x + 2, 5))

ck.analyzer.update(y, tvm.arith.ConstIntBound(0, 31), override=True)
ck.verify(flm(x * 32 + y, 64), flm(x, 2) * 32 + y)
ck.verify(flm(x * 32 - y, 64), flm(x * 32 - y, 64))
Expand Down
Loading