From 492fc5b7ea208733eb70260f58d747ec5cb76e62 Mon Sep 17 00:00:00 2001 From: tqchen Date: Mon, 2 Jun 2025 09:14:17 -0400 Subject: [PATCH 1/2] [ARITH] Canonicalize mul-coefficient to rhs This PR updates the rewrite simplify logic to canonicalize mul-coefficient to rhs. This change is consistent with rest of the code base and allows better simplification of more cases. A test case of floormod with linear offset is added. Co-authored-by: Ghosts381937 --- src/arith/rewrite_simplify.cc | 12 ++++---- .../arith/test_arith_rewrite_simplify.py | 28 +++++++++---------- tests/python/arith/test_arith_simplify.py | 13 +++++++++ .../test_tir_transform_common_subexpr_elim.py | 2 +- 4 files changed, 34 insertions(+), 21 deletions(-) diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index cda27663520e..c911124700fe 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -446,10 +446,10 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AddNode* op) { // mul co-efficient folding TVM_TRY_REWRITE(x + x, x * 2); - TVM_TRY_REWRITE(matches_one_of(x * y + x, y * x + x, x + y * x, x + x * y), x * (y + 1)); + TVM_TRY_REWRITE(matches_one_of(x * y + x, y * x + x, x + y * x, x + x * y), (y + 1) * x); TVM_TRY_REWRITE(matches_one_of(x * y + x * z, y * x + x * z, x * y + z * x, y * x + z * x), - x * (y + z)); + (y + z) * x); // DivMod rules // truc div @@ -563,12 +563,12 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const SubNode* op) { TVM_TRY_REWRITE(matches_one_of(max(x, y) - y, x - min(y, x)), max(x - y, 0)); TVM_TRY_REWRITE(matches_one_of(x - min(x, y), max(y, x) - y), max(0, x - y)); - // mul co-efficient folding + // mul co-efficient folding: pefer co-effiicent to stay at rhs TVM_TRY_REWRITE(x - x, ZeroWithTypeLike(x)); - TVM_TRY_REWRITE(matches_one_of(x * y - x, y * x - x), x * (y - 1)); - TVM_TRY_REWRITE(matches_one_of(x - y * x, x - x * y), x * (1 - y)); + TVM_TRY_REWRITE(matches_one_of(x * y - x, y * x - x), (y - 1) * x); + TVM_TRY_REWRITE(matches_one_of(x - y * x, x - x * y), (1 - y) * x); TVM_TRY_REWRITE(matches_one_of(x * y - x * z, y * x - x * z, x * y - z * x, y * x - z * x), - x * (y - z)); + (y - z) * x); // constant cancelation TVM_TRY_REWRITE((x + c1) - c2, x + (c1 - c2)); diff --git a/tests/python/arith/test_arith_rewrite_simplify.py b/tests/python/arith/test_arith_rewrite_simplify.py index ad4abdfe2934..6954cf4e1d5c 100644 --- a/tests/python/arith/test_arith_rewrite_simplify.py +++ b/tests/python/arith/test_arith_rewrite_simplify.py @@ -391,17 +391,17 @@ class TestAddIndex(BaseCompare): TestCase(tvm.te.max(2 - x * 4, 0) + x * 4, tvm.te.max(x * 4, 2)), TestCase(tvm.te.min(0, 1 - x * 4) + x * 4, tvm.te.min(x * 4, 1)), TestCase(tvm.te.min(2 - x * 4, 0) + x * 4, tvm.te.min(x * 4, 2)), - TestCase(x * y + x * 10, x * (y + 10)), - TestCase(y * x + x * 10, x * (y + 10)), - TestCase(y * x + 10 * x, x * (y + 10)), - TestCase(x * y + 10 * x, x * (y + 10)), + TestCase(x * y + x * 10, (y + 10) * x), + TestCase(y * x + x * 10, (y + 10) * x), + TestCase(y * x + 10 * x, (y + 10) * x), + TestCase(x * y + 10 * x, (y + 10) * x), TestCase((2 * z) + tvm.te.min(x, y - (2 * z)), tvm.te.min(x + (z * 2), y)), - TestCase(y * x + x, x * (y + 1)), - TestCase(x * y + x, x * (y + 1)), + TestCase(y * x + x, (y + 1) * x), + TestCase(x * y + x, (y + 1) * x), TestCase((x + 10) + 13, x + 23), TestCase((x + 10) + (13 + z), x + z + 23), - TestCase(x * y + 10 * x, x * (y + 10)), - TestCase(y * x + x * 3, x * (y + 3)), + TestCase(x * y + 10 * x, (y + 10) * x), + TestCase(y * x + x * 3, (y + 3) * x), TestCase(x + 3 + y, x + y + 3), TestCase((3 - y) + x, x - y + 3), # canonicalization @@ -409,10 +409,10 @@ class TestAddIndex(BaseCompare): TestCase(x + 2 + 3 + 4 + x * 3, x * 4 + 9), # DivMod rules # trunc div - TestCase(y * tmod(x, 8) + 10 * tmod(x, 8), tmod(x, 8) * (y + 10)), + TestCase(y * tmod(x, 8) + 10 * tmod(x, 8), (y + 10) * tmod(x, 8)), TestCase(tdiv(x, 8) * 8 + tmod(x, 8), x), # floor div - TestCase(y * flm(x, 8) + 10 * flm(x, 8), flm(x, 8) * (y + 10)), + TestCase(y * flm(x, 8) + 10 * flm(x, 8), (y + 10) * flm(x, 8)), TestCase(fld(x, 8) * 8 + flm(x, 8), x), TestCase(fld(flm(x, 2) + 7, 2) + fld(x, 2), fld(x + 7, 2)), ) @@ -436,10 +436,10 @@ class TestSubIndex(BaseCompare): TestCase(y - tvm.te.max(x, y), tvm.te.min(y - x, 0)), # mul co-efficient foldng TestCase(x - x, 0), - TestCase(x * y - x, x * (y + (-1))), - TestCase(x * y - 10 * x, x * (y + (-10))), - TestCase(y * x - x * z, x * (y - z)), - TestCase(y * x - z * x, x * (y - z)), + TestCase(x * y - x, (y + (-1)) * x), + TestCase(x * y - 10 * x, (y + (-10)) * x), + TestCase(y * x - x * z, (y - z) * x), + TestCase(y * x - z * x, (y - z) * x), TestCase(x + 10 - 20, x + (-10)), # 4-operands pattern TestCase((x + y) - (x + z), y - z), diff --git a/tests/python/arith/test_arith_simplify.py b/tests/python/arith/test_arith_simplify.py index 4971acbd4512..5a61cb8a52a9 100644 --- a/tests/python/arith/test_arith_simplify.py +++ b/tests/python/arith/test_arith_simplify.py @@ -131,5 +131,18 @@ def test_regression_simplify_inf_recursion(): ana.rewrite_simplify(res) +def test_simplify_floor_mod_with_linear_offset(): + """ + Test that the floor_mod is simplified correctly when the offset is linear. + """ + ana = tvm.arith.Analyzer() + past_decoder_sequence_length = tir.Var("past_decoder_sequence_length", "int64") + expr1 = (past_decoder_sequence_length + 1) * 64 + divisor1 = (past_decoder_sequence_length + 1) * 32 + assert ana.can_prove_equal(tvm.tir.floormod(expr1, divisor1), 0) + divisor2 = 32 * (past_decoder_sequence_length + 1) + assert ana.can_prove_equal(tvm.tir.floormod(expr1, divisor2), 0) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/tir-transform/test_tir_transform_common_subexpr_elim.py b/tests/python/tir-transform/test_tir_transform_common_subexpr_elim.py index 5208262221b9..e7e64d89168e 100644 --- a/tests/python/tir-transform/test_tir_transform_common_subexpr_elim.py +++ b/tests/python/tir-transform/test_tir_transform_common_subexpr_elim.py @@ -352,7 +352,7 @@ def test_no_normalization_without_commoning(): def func_distributivity( B: T.Buffer((50,), "int32"), i1: T.int32, i2: T.int32, x: T.int32, y: T.int32, z: T.int32 ) -> None: - B[i1] = x * (y + z) + B[i1] = (y + z) * x B[i2] = x * y + x * z From 695b2dadd3a0254176705f0237733df4367affc8 Mon Sep 17 00:00:00 2001 From: tqchen Date: Mon, 2 Jun 2025 12:53:31 -0400 Subject: [PATCH 2/2] Fix the grad testcase caused by the changed simplification behavior --- tests/python/relax/test_transform_legalize_ops_grad.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/python/relax/test_transform_legalize_ops_grad.py b/tests/python/relax/test_transform_legalize_ops_grad.py index f5a20b298a57..44469acdc1c0 100644 --- a/tests/python/relax/test_transform_legalize_ops_grad.py +++ b/tests/python/relax/test_transform_legalize_ops_grad.py @@ -282,7 +282,8 @@ def avg_pool2d_backward(rxplaceholder: T.Buffer((T.int64(3), T.int64(2), T.int64 T.writes(T_pool_grad[v_ax0, v_ax1, v_ax2, v_ax3]) with T.init(): T_pool_grad[v_ax0, v_ax1, v_ax2, v_ax3] = T.float32(0) - T_pool_grad[v_ax0, v_ax1, v_ax2, v_ax3] = T_pool_grad[v_ax0, v_ax1, v_ax2, v_ax3] + T.if_then_else(T.Select(v_ax2 < T.int64(3), T.int64(0), T.Div(v_ax2 - T.int64(3), T.int64(2)) + T.int64(1)) <= T.Div(v_ax2 + T.int64(2), T.int64(2)) - v_wh and T.Div(v_ax2 + T.int64(2), T.int64(2)) - v_wh < T.int64(6) and T.Select(v_ax3 < T.int64(4), T.int64(0), T.Div(v_ax3 - T.int64(4), T.int64(2)) + T.int64(1)) <= T.Div(v_ax3 + T.int64(1), T.int64(2)) - v_ww and T.Div(v_ax3 + T.int64(1), T.int64(2)) - v_ww < T.int64(5), rxplaceholder[v_ax0, v_ax1, T.Div(v_ax2 + T.int64(2), T.int64(2)) - v_wh, T.Div(v_ax3 + T.int64(1), T.int64(2)) - v_ww] / T.Cast("float32", T.max((T.min(T.Div(v_ax2 + T.int64(2), T.int64(2)) * T.int64(2) + T.int64(3) - v_wh * T.int64(2), T.int64(10)) - T.max(T.Div(v_ax2 + T.int64(2), T.int64(2)) * T.int64(2) - v_wh * T.int64(2) - T.int64(2), T.int64(0))) * (T.min(T.Div(v_ax3 + T.int64(1), T.int64(2)) * T.int64(2) + T.int64(4) - v_ww * T.int64(2), T.int64(10)) - T.max(T.Div(v_ax3 + T.int64(1), T.int64(2)) * T.int64(2) - v_ww * T.int64(2) - T.int64(1), T.int64(0))), T.int64(1))), T.float32(0)) + T_pool_grad[v_ax0, v_ax1, v_ax2, v_ax3] = T_pool_grad[v_ax0, v_ax1, v_ax2, v_ax3] + T.if_then_else(T.Select(v_ax2 < T.int64(3), T.int64(0), T.Div(v_ax2 - T.int64(3), T.int64(2)) + T.int64(1)) <= T.Div(v_ax2 + T.int64(2), T.int64(2)) - v_wh and T.Div(v_ax2 + T.int64(2), T.int64(2)) - v_wh < T.int64(6) and T.Select(v_ax3 < T.int64(4), T.int64(0), T.Div(v_ax3 - T.int64(4), T.int64(2)) + T.int64(1)) <= T.Div(v_ax3 + T.int64(1), T.int64(2)) - v_ww and T.Div(v_ax3 + T.int64(1), T.int64(2)) - v_ww < T.int64(5), rxplaceholder[v_ax0, v_ax1, T.Div(v_ax2 + T.int64(2), T.int64(2)) - v_wh, T.Div(v_ax3 + T.int64(1), T.int64(2)) - v_ww] / T.Cast("float32", T.max((T.min(T.Div(v_ax2 + T.int64(2), T.int64(2)) * T.int64(2) + T.int64(3) - v_wh * T.int64(2), T.int64(10)) - T.max(T.Div(v_ax2 + T.int64(2), T.int64(2)) - v_wh - T.int64(1), T.int64(0)) * T.int64(2)) * (T.min(T.Div(v_ax3 + T.int64(1), T.int64(2)) * T.int64(2) + T.int64(4) - v_ww * T.int64(2), T.int64(10)) - T.max(T.Div(v_ax3 + T.int64(1), T.int64(2)) * T.int64(2) - v_ww * T.int64(2) - T.int64(1), T.int64(0))), T.int64(1))), T.float32(0.0)) + @R.function def main(output_grad: R.Tensor((3, 2, 6, 5), dtype="float32"), data: R.Tensor((3, 2, 10, 10), dtype="float32")) -> R.Tensor((3, 2, 10, 10), dtype="float32"): cls = Expected