diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index f1838f5a9099..482e50be7bf4 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -221,12 +221,21 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AddNode* op) { TVM_TRY_REWRITE_IF(min(x, y + z * c1) + z * c2, min(x + z * c2, y), c1.Eval()->value == -c2.Eval()->value); - TVM_TRY_REWRITE_IF(max(x, y + z * c1) + z * c2, max(x + z * c2, y), - c1.Eval()->value == -c2.Eval()->value); TVM_TRY_REWRITE_IF(min(y + z * c1, x) + z * c2, min(x + z * c2, y), c1.Eval()->value == -c2.Eval()->value); + TVM_TRY_REWRITE_IF(min(x, z * c1 + y) + z * c2, min(x + z * c2, y), + c1.Eval()->value == -c2.Eval()->value); + TVM_TRY_REWRITE_IF(min(z * c1 + y, x) + z * c2, min(x + z * c2, y), + c1.Eval()->value == -c2.Eval()->value); + + TVM_TRY_REWRITE_IF(max(x, y + z * c1) + z * c2, max(x + z * c2, y), + c1.Eval()->value == -c2.Eval()->value); TVM_TRY_REWRITE_IF(max(y + z * c1, x) + z * c2, max(x + z * c2, y), c1.Eval()->value == -c2.Eval()->value); + TVM_TRY_REWRITE_IF(max(x, z * c1 + y) + z * c2, max(x + z * c2, y), + c1.Eval()->value == -c2.Eval()->value); + TVM_TRY_REWRITE_IF(max(z * c1 + y, x) + z * c2, max(x + z * c2, y), + c1.Eval()->value == -c2.Eval()->value); TVM_TRY_REWRITE(max(x, y) + min(x, y), x + y); TVM_TRY_REWRITE(min(x, y) + max(x, y), x + y); @@ -241,13 +250,22 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AddNode* op) { // constant folding // NOTE: canonicalization might better at this. TVM_TRY_REWRITE((x + c1) + c2, x + (c1 + c2)); + TVM_TRY_REWRITE((c1 - y) + c2, (c1 + c2) - y); + TVM_TRY_REWRITE((y - c1) + c2, y + (c2 - c1)); + + TVM_TRY_REWRITE(c1 + (x + c2), x + (c1 + c2)); + TVM_TRY_REWRITE((x + c1) + c2, x + (c1 + c2)); + TVM_TRY_REWRITE(c1 + (x - c2), x + (c1 - c2)); + TVM_TRY_REWRITE((x - c1) + c2, x + (c2 - c1)); + TVM_TRY_REWRITE(c1 + (c2 - x), (c1 + c2) - x); + TVM_TRY_REWRITE((c1 - x) + c2, (c1 + c2) - x); // mul co-efficient folding TVM_TRY_REWRITE(x + x, x * 2); TVM_TRY_REWRITE(x * y + x, x * (y + 1)); - TVM_TRY_REWRITE(y * x + x, x * (y + 1)); - TVM_TRY_REWRITE(x + y * x, x * (1 + y)); - TVM_TRY_REWRITE(x + x * y, x * (1 + y)); + TVM_TRY_REWRITE(y * x + x, (y + 1) * x); + TVM_TRY_REWRITE(x + y * x, (y + 1) * x); + TVM_TRY_REWRITE(x + x * y, x * (y + 1)); TVM_TRY_REWRITE(x * y + x * z, x * (y + z)); TVM_TRY_REWRITE(y * x + x * z, x * (y + z)); TVM_TRY_REWRITE(x * y + z * x, x * (y + z)); @@ -265,12 +283,20 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AddNode* op) { TVM_TRY_REWRITE_IF(floordiv(floormod(x, c2) + c1, c2) + floordiv(x, c2), floordiv(x + c1, c2), c2.Eval()->value > 0); + auto one = PConst(make_const(op->dtype, 1)); + TVM_TRY_REWRITE(floormod(x, 2) + floormod(x + 1, 2), one); + TVM_TRY_REWRITE(floormod(x + 1, 2) + floormod(x, 2), one); + // canonicalization rule // will try rewrite again after canonicalization. + TVM_TRY_RECURSIVE_REWRITE(c1 + x, x + c1); TVM_TRY_RECURSIVE_REWRITE(x + (c1 - y), (x - y) + c1); TVM_TRY_RECURSIVE_REWRITE((c1 - y) + x, (x - y) + c1); - TVM_TRY_RECURSIVE_REWRITE(x + c1 + y, (x + y) + c1); - TVM_TRY_RECURSIVE_REWRITE(x + (c1 + y), (x + y) + c1); + TVM_TRY_RECURSIVE_REWRITE(x + (y - c1), (x + y) - c1); + TVM_TRY_RECURSIVE_REWRITE((y - c1) + x, (x + y) - c1); + TVM_TRY_RECURSIVE_REWRITE((x + c1) + y, (x + y) + c1); + TVM_TRY_RECURSIVE_REWRITE(x + (y + c1), (x + y) + c1); + TVM_TRY_RECURSIVE_REWRITE(x + max(y, z), max(y, z) + x); TVM_TRY_RECURSIVE_REWRITE(x + min(y, z), min(y, z) + x); @@ -359,6 +385,15 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const SubNode* op) { TVM_TRY_REWRITE(x - min(x, y), max(0, x - y)); TVM_TRY_REWRITE(y - min(x, y), max(y - x, 0)); + // constant cancelation + TVM_TRY_REWRITE(c1 - (x + c2), (c1 - c2) - x); + TVM_TRY_REWRITE(c1 - (x - c2), (c1 + c2) - x); + TVM_TRY_REWRITE(c1 - (c2 - x), x + (c1 - c2)); + TVM_TRY_REWRITE((x + c1) - c2, x + (c1 - c2)); + TVM_TRY_REWRITE((x - c1) - c2, x - (c1 + c2)); + TVM_TRY_REWRITE((c1 - x) - c2, (c1 - c2) - x); + TVM_TRY_RECURSIVE_REWRITE((c1 - x) - (c2 - y), (y - x) + (c1 - c2)); + // mul co-efficient folding TVM_TRY_REWRITE(x - x, ZeroWithTypeLike(x)); TVM_TRY_REWRITE(x * y - x, x * (y - 1)); @@ -370,10 +405,6 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const SubNode* op) { TVM_TRY_REWRITE(x * y - z * x, x * (y - z)); TVM_TRY_REWRITE(y * x - z * x, x * (y - z)); - // constant cancelation - TVM_TRY_REWRITE((x + c1) - c2, x + (c1 - c2)); - TVM_TRY_REWRITE((c1 - x) - (c2 - y), (y - x) + (c1 - c2)); - // cancelization rule involving 4 operands TVM_TRY_REWRITE((x + y) - (x + z), y - z); TVM_TRY_REWRITE((x + y) - (z + x), y - z); @@ -493,6 +524,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const SubNode* op) { // canonicalization rule // will try rewrite again after canonicalization. TVM_TRY_REWRITE(x - c1, x + (0 - c1)); + TVM_TRY_RECURSIVE_REWRITE(x - (y + c1), (x - y) - c1); TVM_TRY_RECURSIVE_REWRITE((x + c1) - y, (x - y) + c1); TVM_TRY_RECURSIVE_REWRITE(x - (y - z), (x + z) - y); TVM_TRY_RECURSIVE_REWRITE(x - y * c1, x + y * (0 - c1)); @@ -536,7 +568,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const MulNode* op) { if (IsIndexType(op->dtype)) { // constant simplification rule - TVM_TRY_REWRITE((x + c1) * c2, x * c2 + c1 * c2); + TVM_TRY_REWRITE_IF((x + c1) * c2, x * c2 + c1 * c2, c1.Eval()->value != 1); TVM_TRY_REWRITE((x * c1) * c2, x * (c1 * c2)); TVM_TRY_REWRITE(min(x, y) * max(x, y), x * y); TVM_TRY_REWRITE(max(x, y) * min(x, y), x * y); @@ -870,7 +902,8 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorDivNode* op) { PrimExpr y_div = CanProveEqual(floordiv(yval, c2val), 0) ? 0 : floordiv(yval, c2val); auto bound = analyzer_->const_int_bound(residue); if (bound.defined() && bound->max_value == bound->min_value) { - return x.Eval() * floordiv(c1val, c2.Eval()) + (y_div + Integer(bound->max_value)); + PrimExpr out = x.Eval() * floordiv(c1val, c2.Eval()) + (y_div + Integer(bound->max_value)); + return RecursiveRewrite(out); } // try simplify divisor @@ -882,7 +915,8 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorDivNode* op) { // ==> x' + d + (b * c1 + e) // c2 // ==> x' + d since 0 <= b * c1 <= (a-1) * c1, 0 <= e < c1 // ==> x // (c2 // c1) + (y // c2) - return floordiv(x.Eval(), floordiv(c2val, c1val)) + y_div; + PrimExpr out = floordiv(x.Eval(), floordiv(c2val, c1val)) + y_div; + return RecursiveRewrite(out); } } @@ -903,6 +937,13 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorDivNode* op) { TVM_TRY_REWRITE_IF(floordiv(max(y, x * c1), c2), max(floordiv(y, c2), x * floordiv(c1, c2)), c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0); + TVM_TRY_RECURSIVE_REWRITE_IF( + floordiv(x + c1, c2), floordiv(x + floormod(c1, c2), c2) + floordiv(c1, c2), + c2.Eval()->value > 0 && (c1.Eval()->value != floormod(c1.Eval()->value, c2.Eval()->value))); + TVM_TRY_RECURSIVE_REWRITE_IF(floordiv(x - c1, c2), + floordiv(x + floormod(-1 * c1, c2), c2) + floordiv(-1 * c1, c2), + c2.Eval()->value > 0); + // Rules involving 3-operands. TVM_TRY_REWRITE_IF(floordiv(x * c1 + y + z, c2), x * floordiv(c1, c2) + floordiv(y + z, c2), c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0); @@ -951,6 +992,28 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorDivNode* op) { CanProveGreaterEqual(z.Eval(), 0)); TVM_TRY_REWRITE_IF(floordiv(x - floormod(x, c1), c1), floordiv(x, c1), c1.Eval()->value != 0); + + if (floordiv(x * c1, c2).Match(ret)) { + auto c1val = c1.Eval()->value; + auto c2val = c2.Eval()->value; + IntImm gcd(x.Eval()->dtype, ZeroAwareGCD(c1val, c2val)); + + if (gcd->value > 1) { + // floormod(i*gcd, j*gcd) == floormod(i,j) * gcd + return floordiv(x.Eval() * floordiv(c1val, gcd), floordiv(c2val, gcd)); + } + } + + if (floordiv(x, c1).Match(ret)) { + ModularSet mod = analyzer_->modular_set(x.Eval()); + int64_t c1val = c1.Eval()->value; + IntImm gcd(x.Eval()->dtype, ZeroAwareGCD(c1val, mod->coeff)); + if (gcd->value > 1) { + // floordiv(i*gcd, j*gcd) == floordiv(i,j) + PrimExpr out = floordiv(floordiv(x.Eval(), gcd), floordiv(c1.Eval(), gcd)); + return RecursiveRewrite(out); + } + } } return ret; } @@ -1002,24 +1065,38 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorModNode* op) { if (IsIndexType(op->dtype)) { // Be-aware of the division rules: we use floordiv/floormod here - TVM_TRY_REWRITE_IF(floormod(x * c1, c2), floormod(x * floormod(c1, c2), c2), - c2.Eval()->value != 0); + TVM_TRY_REWRITE_IF( + floormod(x * c1, c2), floormod(x * floormod(c1, c2), c2), + c2.Eval()->value != 0 && c1.Eval()->value != floormod(c1.Eval()->value, c2.Eval()->value)); TVM_TRY_REWRITE_IF(floormod(x * c1 + y, c2), floormod(x, floordiv(c2, c1)) * c1 + y, c1.Eval()->value > 0 && c2.Eval()->value > 0 && c2.Eval()->value % c1.Eval()->value == 0 && CanProveEqual(floordiv(y.Eval(), c1.Eval()), 0)); - TVM_TRY_REWRITE_IF(floormod(x * c1 + y, c2), floormod(x * floormod(c1, c2) + y, c2), - c2.Eval()->value > 0); + TVM_TRY_RECURSIVE_REWRITE_IF( + floormod(x * c1 + y, c2), floormod(x * floormod(c1, c2) + y, c2), + c2.Eval()->value > 0 && (c1.Eval()->value != floormod(c1.Eval()->value, c2.Eval()->value))); - TVM_TRY_REWRITE_IF(floormod(x + c1, c2), floormod(x, c2), - c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0); + TVM_TRY_RECURSIVE_REWRITE_IF( + floormod(x + c1, c2), floormod(x + floormod(c1, c2), c2), + c2.Eval()->value > 0 && (c1.Eval()->value != floormod(c1.Eval()->value, c2.Eval()->value))); + TVM_TRY_RECURSIVE_REWRITE_IF(floormod(x - c1, c2), floormod(x + floormod(-1 * c1, c2), c2), + c2.Eval()->value > 0); - TVM_TRY_REWRITE_IF(floormod(x + y * c1, c2), floormod(x + y * floormod(c1, c2), c2), - c2.Eval()->value > 0); + TVM_TRY_RECURSIVE_REWRITE_IF( + floormod(x + y * c1, c2), floormod(x + y * floormod(c1, c2), c2), + c2.Eval()->value > 0 && (c1.Eval()->value != floormod(c1.Eval()->value, c2.Eval()->value))); + + TVM_TRY_RECURSIVE_REWRITE_IF( + floormod(x * c1, x * c2), x * floormod(c1, c2), + c2.Eval()->value != 0 && + (c1.Eval()->value != floormod(c1.Eval()->value, c2.Eval()->value))); - TVM_TRY_REWRITE_IF(floormod(x * c1, x * c2), x * floormod(c1, c2), c2.Eval()->value != 0); + TVM_TRY_RECURSIVE_REWRITE_IF(floormod(floormod(x, c1) + y, c2), floormod(x + y, c2), + floormod(c1.Eval()->value, c2.Eval()->value) == 0); + TVM_TRY_RECURSIVE_REWRITE_IF(floormod(x + floormod(y, c1), c2), floormod(x + y, c2), + floormod(c1.Eval()->value, c2.Eval()->value) == 0); TVM_TRY_REWRITE(floormod(x * y, y), ZeroWithTypeLike(x)); TVM_TRY_REWRITE(floormod(y * x, y), ZeroWithTypeLike(y)); @@ -1031,6 +1108,13 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorModNode* op) { if (mod->coeff % c1val == 0 && c1val > 0) { return floormod(mod->base, c1).Eval(); } + if (c1val > 0 && mod->coeff > 1) { + IntImm gcd(x.Eval()->dtype, ZeroAwareGCD(c1val, mod->coeff)); + if (gcd->value > 1) { + PrimExpr out = floormod(floordiv(x.Eval(), gcd), floordiv(c1.Eval(), gcd)) * gcd; + return RecursiveRewrite(out); + } + } } } return ret; @@ -1402,7 +1486,7 @@ PrimExpr RewriteSimplifier::Impl::ApplyRewriteRules(EQ ret) { // Pattern var to match any expression PVar x, y; // Pattern var match IntImm - PVar c1; + PVar c1, c2, c3; PVar lanes; // vector rule @@ -1424,6 +1508,9 @@ PrimExpr RewriteSimplifier::Impl::ApplyRewriteRules(EQ ret) { TVM_TRY_REWRITE(c1 - x == 0, x == c1); TVM_TRY_REWRITE(x + c1 == 0, x == 0 - c1); TVM_TRY_RECURSIVE_REWRITE(x * y == 0, x == 0 || y == 0); + + TVM_TRY_RECURSIVE_REWRITE(floormod(x + c1, c2) == c3, + c3 == floormod(c3, c2) && floormod(x, c2) == floormod(c3 - c1, c2)); } return std::move(ret); } diff --git a/tests/python/unittest/test_arith_rewrite_simplify.py b/tests/python/unittest/test_arith_rewrite_simplify.py index d6c2cfe8bbdd..d6ab1126e86d 100644 --- a/tests/python/unittest/test_arith_rewrite_simplify.py +++ b/tests/python/unittest/test_arith_rewrite_simplify.py @@ -82,7 +82,7 @@ def test_vector_simplify(): ck.verify(fld(tvm.tir.Ramp(x, 8, 5), tvm.tir.Broadcast(4, 5)), tvm.tir.Ramp(fld(x, 4), 2, 5)) ck.verify( fld(tvm.tir.Ramp(flm(x * 4, 256), 1, 4), tvm.tir.Broadcast(8, 4)), - tvm.tir.Broadcast(fld(flm(x * 4, 256), 8), 4), + tvm.tir.Broadcast(fld(flm(x, 64), 2), 4), ) ck.verify( fld(tvm.tir.Ramp(x, 7, 4), tvm.tir.Broadcast(4, 4)), @@ -136,10 +136,10 @@ def test_vector_simplify(): flm(tvm.tir.Ramp(3, 1, 4), tvm.tir.Broadcast(4, 4)), ) ck.verify( - flm(tvm.tir.Ramp(x * 4, 1, 4), tvm.tir.Broadcast(64, 4)), tvm.tir.Ramp(flm(x * 4, 64), 1, 4) + flm(tvm.tir.Ramp(x * 4, 1, 4), tvm.tir.Broadcast(64, 4)), tvm.tir.Ramp(flm(x, 16) * 4, 1, 4) ) ck.verify( - flm(tvm.tir.Ramp(x * 8, 2, 4), tvm.tir.Broadcast(64, 4)), tvm.tir.Ramp(flm(x * 8, 64), 2, 4) + flm(tvm.tir.Ramp(x * 8, 2, 4), tvm.tir.Broadcast(64, 4)), tvm.tir.Ramp(flm(x, 8) * 8, 2, 4) ) ck.verify( flm(tvm.tir.Ramp(x * 4, 1, 5), tvm.tir.Broadcast(64, 5)), @@ -255,7 +255,7 @@ def test_add_index_simplify(): ck.verify(x * y + 10 * x, x * (y + 10)) ck.verify((2 * z) + tvm.te.min(x, y - (2 * z)), tvm.te.min(x + (z * 2), y)) - ck.verify(y * x + x, x * (y + 1)) + ck.verify(y * x + x, (y + 1) * x) ck.verify(x * y + x, x * (y + 1)) ck.verify((x + 10) + 13, x + 23) ck.verify((x + 10) + (13 + z), x + z + 23) @@ -281,7 +281,7 @@ def test_add_index_simplify(): flm = tvm.te.floormod ck.verify(y * flm(x, 8) + 10 * flm(x, 8), flm(x, 8) * (y + 10)) ck.verify(fld(x, 8) * 8 + flm(x, 8), x) - ck.verify(fld(flm(x, 2) + 7, 2) + fld(x, 2), fld(x + 7, 2)) + ck.verify(fld(flm(x, 2) + 7, 2) + fld(x, 2), fld(x + 1, 2) + 3) def test_sub_index_simplify(): @@ -377,12 +377,12 @@ def test_sub_index_simplify(): ck.analyzer.update(x, tvm.arith.ConstIntBound(-1000, 1000), override=True) ck.analyzer.update(y, tvm.arith.ConstIntBound(-1000, 1000), override=True) ck.verify(x - fld(x, 3) * 3, flm(x, 3)) - ck.verify(fld(x + 5, 3) - fld(x, 3), fld(flm(x, 3) + 5, 3)) - ck.verify(fld(x + 5, 3) - fld(x + 2, 3), fld(flm(x + 2, 3), 3) + 1) + ck.verify(fld(x + 5, 3) - fld(x, 3), fld(flm(x, 3) + 2, 3) + 1) + ck.verify(fld(x + 5, 3) - fld(x + 2, 3), 1) ck.verify(fld(y, 3) * 3 - y, 0 - flm(y, 3)) - ck.verify(y - fld(y - 6, 5) * 5, flm(y + (-6), 5) + 6) - ck.verify(fld(y - 6, 5) * 5 - y, (-6) - flm(y + (-6), 5)) + ck.verify(y - fld(y - 6, 5) * 5, flm(y + 4, 5) + 6) + ck.verify(fld(y - 6, 5) * 5 - y, (-6) - flm(y + 4, 5)) ck.verify(y - fld(y + z, 5) * 5, flm(y + z, 5) - z) ck.verify(fld(y + z, 5) * 5 - y, z - flm(y + z, 5)) ck.verify(y - fld(y - z, 5) * 5, flm(y - z, 5) + z) @@ -471,14 +471,14 @@ def test_floordiv_index_simplify(): ck.verify(fld(x * 4, 2), x * 2) ck.verify(fld(x * 8 + 7, 16), fld(x, 2)) ck.verify(fld(x * 8 + 39, 16), fld(x, 2) + 2) - ck.verify(fld(x * 8 - 1, 16), fld(x * 8 + -1, 16)) + ck.verify(fld(x * 8 - 1, 16), fld(x + 1, 2) + -1) ck.verify(fld(x * 8 - 9, 16), fld(x, 2) + -1) ck.analyzer.update(x, tvm.arith.ConstIntBound(0, 1), override=True) ck.analyzer.update(y, tvm.arith.ConstIntBound(0, 7), override=True) ck.verify(fld(x * 360 + y, 16), x * 22) ck.verify(fld(x * 360 + y, 25), x * 14) - ck.verify(fld(x * 360 - 8, 25), fld(x * 360 + -8, 25)) + ck.verify(fld(x * 360 - 8, 25), fld(x * 72 + 3, 5) + -1) ck.verify(fld(x * 4 + y, 2), x * 2 + fld(y, 2)) ck.verify(fld(tvm.te.min(x * 6, y), 2), tvm.te.min(x * 3, fld(y, 2))) @@ -488,6 +488,9 @@ def test_floordiv_index_simplify(): ck.verify(fld(tvm.te.min(y, x * 6), 2), tvm.te.min(fld(y, 2), x * 3)) ck.verify(fld(tvm.te.max(y, x * 6), 2), tvm.te.max(fld(y, 2), x * 3)) + # removal of negative offsets + ck.verify(fld(x - 17, 5), fld(x + 3, 5) + -4) + # 3-operands ck.verify(fld(x * 6 + y + z, 2), x * 3 + fld(y + z, 2)) ck.verify(fld(x * 6 - y + (y + 3), 2), x * 3 + 1) @@ -562,7 +565,7 @@ def test_floormod_index_simplify(): x, y, nx, ny, z = te.var("x"), te.var("y"), te.var("nx"), te.var("ny"), te.var("z") ck.verify(flm(x * 10, 2), 0) - ck.verify(flm(x * 9600, 6400), flm(x * 3200, 6400)) + ck.verify(flm(x * 9600, 6400), flm(x, 2) * 3200) ck.verify(flm(x * 10 + y, 2), flm(y, 2)) ck.verify(flm(x * 360 + y, 16), flm(x * 8 + y, 16)) ck.verify(flm(x + 10, 2), flm(x, 2)) @@ -574,6 +577,10 @@ def test_floormod_index_simplify(): ck.verify(flm(x + (-10), 2), flm(x, 2)) ck.verify(flm(x + y * (-10), 2), flm(x, 2)) + # removal of negative offsets + ck.verify(flm(x - 17, 5), flm(x + 3, 5)) + ck.verify(flm(x + 17, 5), flm(x + 2, 5)) + ck.analyzer.update(y, tvm.arith.ConstIntBound(0, 31), override=True) ck.verify(flm(x * 32 + y, 64), flm(x, 2) * 32 + y) ck.verify(flm(x * 32 - y, 64), flm(x * 32 - y, 64)) diff --git a/tests/python/unittest/test_target_codegen_vulkan.py b/tests/python/unittest/test_target_codegen_vulkan.py index 76cad250e053..7b71f4d4ab17 100644 --- a/tests/python/unittest/test_target_codegen_vulkan.py +++ b/tests/python/unittest/test_target_codegen_vulkan.py @@ -28,6 +28,7 @@ import tvm.testing from tvm import relay, te from tvm.topi.math import cast +from tvm.script import tir as T dtype = tvm.testing.parameter("float32", "int32", "float16", "int8") @@ -558,5 +559,46 @@ def do_compute(ins, outs): tvm.build(s, [Out], target) +def test_negative_operand_divmod(target, dev): + """Test handling of negative offsets to floormod/floordiv + + Even though the SPIR-V spec states that OpSRem and OpSMod can give + the signed modulo, the Vulkan spec states that any use of negative + operands is undefined behavior. This test starts with negative + operands to floordiv, validating that they are simplified into the + corresponding positive operands, such that the final TIR can be + expressed using only positive operands. + + SPIR-V: https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpSRem + Vulkan: https://registry.khronos.org/vulkan/specs/1.3/html/chap37.html#spirvenv-op-prec + """ + + N = 32 + offset = 16 + divisor = 5 + + @T.prim_func + def func(A: T.Buffer[(N, 2), "int32"]): + for i in T.serial(N): + with T.block("A"): + v_i = T.axis.spatial(N, i) + A[v_i, 0] = T.floordiv(v_i - offset, divisor) + A[v_i, 1] = T.floormod(v_i - offset, divisor) + + if "gpu" in tvm.target.Target(target).keys: + sch = tvm.tir.Schedule(func) + sch.bind(sch.get_loops("A")[0], "threadIdx.x") + func = sch.mod["main"] + + built = tvm.build(func, target=target) + + a_dev = tvm.nd.empty([N, 2], "int32", dev) + built(a_dev) + a = a_dev.numpy() + + np.testing.assert_array_equal(a[:, 0], (np.arange(N) - offset) // divisor) + np.testing.assert_array_equal(a[:, 1], (np.arange(N) - offset) % divisor) + + if __name__ == "__main__": tvm.testing.main()