From ed03f1463a985cc9d4d15983ea3611aeebae410d Mon Sep 17 00:00:00 2001 From: tqchen Date: Mon, 12 Jun 2023 22:23:11 -0400 Subject: [PATCH] [ARITH] Improve arith simplify to handle symbolic reshape pattern This PR enhances arith simplify to handle symbolic reshape patterns. Lift the CombineIters to callers of TryFuseIters so they can be used in early return simplifications. Testcases are added. Also updates a minor spelling issue in the testcase. --- src/arith/iter_affine_map.cc | 11 +++-- .../unittest/test_arith_iter_affine_map.py | 41 ++++++++++++------- 2 files changed, 35 insertions(+), 17 deletions(-) diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc index ed2c40da72a1..377f8bb7c9b1 100644 --- a/src/arith/iter_affine_map.cc +++ b/src/arith/iter_affine_map.cc @@ -722,6 +722,10 @@ class IterMapRewriter : public ExprMutator { IterSumExpr NormalizeToIterWithOffset(IterSumExpr expr) { // We are normalizing a regular iter if (expr->args.size() < 1) return expr; + if (auto opt = TryCombineSplitFromSameSource(expr)) { + expr = opt.value(); + if (expr->args.size() < 1) return expr; + } Optional opt = TryFuseIters(expr, check_level_); if (opt.defined()) { return opt.value(); @@ -995,9 +999,6 @@ class IterMapRewriter : public ExprMutator { * \return The sum with the fused IterMark and extra offset if succeed. */ Optional TryFuseIters(IterSumExpr expr, IterMapLevel check_level) { - if (auto opt = TryCombineSplitFromSameSource(expr)) { - expr = opt.value(); - } // select the iterators in order std::vector visited(expr->args.size(), false); int base_index = FindBaseIter(expr, visited, NullOpt); @@ -1553,6 +1554,10 @@ IterSumExpr IterMapRewriter::PreprocessDividend(IterMapExpr dividend, PrimExpr o return IterSumExpr(); } else if (sum->args.size() == 1) { return sum; + } else if (auto opt = TryCombineSplitFromSameSource(sum)) { + if (opt.value()->args.size() == 1) { + return opt.value(); + } } auto opt_fused = TryFuseIters(sum, check_level_); if (!opt_fused) { diff --git a/tests/python/unittest/test_arith_iter_affine_map.py b/tests/python/unittest/test_arith_iter_affine_map.py index cbca1bb325d8..640d7592ad88 100644 --- a/tests/python/unittest/test_arith_iter_affine_map.py +++ b/tests/python/unittest/test_arith_iter_affine_map.py @@ -83,7 +83,7 @@ def assert_iter_sum_pattern( tvm.ir.assert_structural_equal(sum_expr, expect_iter) -def assert_iter_map_simplfy( +def assert_iter_map_simplify( expect_dict, dom_map, predicate=True, check_level="surjective", simplify_trivial_iterators=True ): keys = list(expect_dict.keys()) @@ -1120,28 +1120,28 @@ def test_iter_map_simplify_symbolic_case(): def simple_fuse0(x): return (x // n) * n + x % n - assert_iter_map_simplfy({simple_fuse0(x): x}, var_dom([(x, n * 32)])) + assert_iter_map_simplify({simple_fuse0(x): x}, var_dom([(x, n * 32)])) - assert_iter_map_simplfy({simple_fuse0(z): z}, var_dom([(x, n), (y, 32)])) + assert_iter_map_simplify({simple_fuse0(z): z}, var_dom([(x, n), (y, 32)])) def fsymbolic_fuse0(x): return ((x // (n * n)) % 32) * (n * n) + ((x // n) % n) * n + x % n - assert_iter_map_simplfy({fsymbolic_fuse0(x): x}, var_dom([(x, n * n * 32)])) + assert_iter_map_simplify({fsymbolic_fuse0(x): x}, var_dom([(x, n * n * 32)])) - assert_iter_map_simplfy({fsymbolic_fuse0(z): z}, var_dom([(x, n * n), (y, 32)])) + assert_iter_map_simplify({fsymbolic_fuse0(z): z}, var_dom([(x, n * n), (y, 32)])) def fsymbolic_fuse1(x): return ((x % (n * n * 32)) // (n * n) * n + (x % (n * n) // n)) * n + x % n - assert_iter_map_simplfy({fsymbolic_fuse1(x): x}, var_dom([(x, n * n * 32)])) + assert_iter_map_simplify({fsymbolic_fuse1(x): x}, var_dom([(x, n * n * 32)])) - assert_iter_map_simplfy({fsymbolic_fuse1(z): z}, var_dom([(x, n * n), (y, 32)])) + assert_iter_map_simplify({fsymbolic_fuse1(z): z}, var_dom([(x, n * n), (y, 32)])) def fsymbolic_fuse2(i): return (i // (n * n) * n + i % (n * n) // n) * n + i % n - assert_iter_map_simplfy({fsymbolic_fuse2(x): x}, var_dom([(x, n * n * 32)])) + assert_iter_map_simplify({fsymbolic_fuse2(x): x}, var_dom([(x, n * n * 32)])) def test_iter_map_simplify_symbolic_predicate(): @@ -1155,7 +1155,7 @@ def simple_fuse0(x): return (x // n) * n + x % n z = x * 32 + y - assert_iter_map_simplfy( + assert_iter_map_simplify( {simple_fuse0(z): z}, var_dom([(x, (n + 1) // 2), (y, 32)]), predicate=(z < n * 16) ) @@ -1163,13 +1163,26 @@ def fsymbolic_fuse2(i): return (i // (n * n) * n + i % (n * n) // n) * n + i % n z = x * 64 + y - assert_iter_map_simplfy( + assert_iter_map_simplify( {fsymbolic_fuse2(z): z}, var_dom([(x, (n * n + 1) // 2), (y, 64)]), predicate=(z < n * n * 32), ) +def test_iter_map_simplify_symbolic_reshape(): + n = tvm.tir.Var("n", "int64") + fused = tvm.tir.Var("fused", "int64") + + ax0 = (fused // 4096) // n + ax1 = (fused // 4096) % n + ax2 = fused % 4096 + + rhs_index = ((ax2 // 4096 + ax0 * n + ax1) % n) * 4096 + ax2 % 4096 + + assert_iter_map_simplify({rhs_index: fused}, var_dom([(fused, n * 4096)])) + + def test_iter_map_simplify_unit_loop_order(): """Test itermap simplify""" x = tvm.tir.Var("x", "int64") @@ -1178,18 +1191,18 @@ def test_iter_map_simplify_unit_loop_order(): # trivial iterators can be found at any when comparing via scale # ensure order unchange - assert_iter_map_simplfy( + assert_iter_map_simplify( {x + y + z: x + y + z}, var_dom([(x, 1), (y, 1), (z, 1)]), simplify_trivial_iterators=False ) # Even with simplifcation, it should follow the original order - assert_iter_map_simplfy( + assert_iter_map_simplify( {x + y + (z // 4) * 4 + z % 4: z + x + y}, var_dom([(x, 1), (y, 1), (z, 32)]), simplify_trivial_iterators=False, ) - assert_iter_map_simplfy( + assert_iter_map_simplify( {y + 64 - x % 2 * 64: y + 64 - x % 2 * 64}, var_dom([(x, 6), (y, 64)]), simplify_trivial_iterators=False, @@ -1197,7 +1210,7 @@ def test_iter_map_simplify_unit_loop_order(): # When we have iterators that have same scale but one of them come # with unit extent, we should prioritize unit extent - assert_iter_map_simplfy( + assert_iter_map_simplify( {x // 128 + y + z: y + x // 128 + z}, var_dom([(x, 128), (y, 128), (z, 1)]), simplify_trivial_iterators=False,