From fa5eea28e48f7b143bd6ea831c99e993a10a374d Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Thu, 1 Jul 2021 13:36:52 -0400 Subject: [PATCH 1/3] [Arith] Inverse affine map * [Arith] Inverse affine map * Update iter_affine_map.h * Update iter_affine_map.h * Update iter_affine_map.py * Topology order visit * doc * fix * address comments --- include/tvm/arith/iter_affine_map.h | 21 +++ python/tvm/arith/__init__.py | 3 +- python/tvm/arith/iter_affine_map.py | 27 ++++ src/arith/iter_affine_map.cc | 142 ++++++++++++++++++ .../unittest/test_arith_iter_affine_map.py | 61 ++++++++ 5 files changed, 253 insertions(+), 1 deletion(-) diff --git a/include/tvm/arith/iter_affine_map.h b/include/tvm/arith/iter_affine_map.h index 641d0e0f5321..d671339fb66b 100644 --- a/include/tvm/arith/iter_affine_map.h +++ b/include/tvm/arith/iter_affine_map.h @@ -283,6 +283,27 @@ Array DetectIterMap(const Array& indices, const Map InverseAffineIterMap(const Array& iter_map, + const Array outputs); + /*! * \brief Detect if bindings can be written as * [a_0*e_0 + b_0 + c_0, a_1*e_1 + b_1, ..., a_n*e_n + b_n] diff --git a/python/tvm/arith/__init__.py b/python/tvm/arith/__init__.py index d1e4431a2e0e..bcccac6bedab 100644 --- a/python/tvm/arith/__init__.py +++ b/python/tvm/arith/__init__.py @@ -22,4 +22,5 @@ from .pattern import detect_linear_equation, detect_clip_bound from .int_solver import solve_linear_equations, solve_linear_inequalities from .iter_affine_map import IterMapExpr, IterMark, IterSplitExpr, IterSumExpr -from .iter_affine_map import detect_iter_map, normalize_iter_map_to_expr, subspace_divide +from .iter_affine_map import detect_iter_map, normalize_iter_map_to_expr, subspace_divide, \ + inverse_affine_iter_map diff --git a/python/tvm/arith/iter_affine_map.py b/python/tvm/arith/iter_affine_map.py index bfd5dfadc800..891d4b55f6e8 100644 --- a/python/tvm/arith/iter_affine_map.py +++ b/python/tvm/arith/iter_affine_map.py @@ -173,3 +173,30 @@ def subspace_divide(bindings, input_iters, sub_iters, predicate=True, require_bi Empty array if no match can be found. """ return _ffi_api.SubspaceDivide(bindings, input_iters, sub_iters, predicate, require_bijective) + + +def inverse_affine_iter_map(iter_map, outputs): + """ Apply the inverse of the affine transformation to the outputs. + Similar to the back-propagation, starting from the outputs, it visits the DAG of the expressions + in reverse topology order and applies the inverse of the affine transformation until it reaches + the input. The affine iter map is required to be bijective. + + For example, iter_map = [l0 // 16, l0 % 16], outputs = [output_0, output_1], + the affine transformation specified by `iter_map` will be applied to `outputs` and the result + will be {l0: ((output_0*16) + output_1)}. + + See also :any:`detect_iter_map`. + + Parameters + ---------- + iter_map : List[IterSumExpr] + The bijective affine iter map. + outputs : List[PrimExpr] + The outputs of the affine transformation. + + Returns + ------- + results : Map[Var, PrimExpr] + The map from the input to the transformed result. + """ + return _ffi_api.InverseAffineIterMap(iter_map, outputs) diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc index c1daae967b47..4231ff9d256d 100644 --- a/src/arith/iter_affine_map.cc +++ b/src/arith/iter_affine_map.cc @@ -1385,5 +1385,147 @@ TVM_REGISTER_GLOBAL("arith.SubspaceDivide") return SubspaceDivide(bindings, root_iters, sub_iters, predicate, require_bijective, &ana); }); +class InverseAffineIterMapTransformer { + public: + explicit InverseAffineIterMapTransformer(Analyzer* analyzer) : analyzer_(analyzer) {} + + Map operator()(const Array& iter_map, + const Array& outputs) { + ICHECK(iter_map.size() == outputs.size()); + std::vector post_dfs_order = ReverseTopologyOrder(iter_map); + + // initialize back propagation accumulator + for (const IterMapExprNode* node: post_dfs_order) { + backprop_.Set(GetRef(node), Integer(0)); + } + for (size_t i = 0; i < iter_map.size(); i++) { + backprop_.Set(iter_map[i], outputs[i]); + } + + // run back propagation + for (const IterMapExprNode* node: post_dfs_order) { + if (node->IsInstance()) { + Visit_(Downcast(GetRef(node))); + } else { + ICHECK(node->IsInstance()); + Visit_(Downcast(GetRef(node))); + } + } + return std::move(inverse_); + } + + private: + void Visit_(const IterSumExpr& iter_map_expr) { + PrimExpr input = backprop_.at(iter_map_expr) - iter_map_expr->base; + + // Case 1: Propagate to the input node directly when the sum expression has only one components + if (iter_map_expr->args.size() == 1) { + const auto& source = iter_map_expr->args[0]; + backprop_.Set(source, backprop_.at(source) + input); + return; + } + + // Case 2: If the sum expression has multiple components, match the fuse pattern and then split + // the sum expression for each components. + // For example, consider the iterator i1[dom = (0, 16)], i2[dom = (0, 8)], fusing i1 and i2 + // we will have i1_i2_fused[dom = (0, 64)]. During back propagation, we need to split the + // propagated value to get the corresponding components of i1 and i2, which are + // floordiv(i1_i2_fused, 8) and floormod(i1_i2_fused, 8), respectively. + Array splits = MatchFusePattern(iter_map_expr); + ICHECK(!splits.empty()); + + for (const IterSplitExpr& split : splits) { + backprop_.Set(split, backprop_.at(split) + floormod(floordiv(input, split->scale), split->extent)); + } + } + + + std::vector ReverseTopologyOrder(const Array& iter_map) { + std::vector post_dfs_order; + std::unordered_map visited; + + std::function fvisit = [&](const IterMapExpr& expr) { + if (visited[expr]) { + return; + } + visited[expr] = true; + if (const auto* sum_expr = expr.as()) { + for (const IterSplitExpr& child : sum_expr->args) { + fvisit(child); + } + } else { + const auto* split_expr = expr.as(); + ICHECK(split_expr); + if (const auto* source = split_expr->source->source.as()) { + fvisit(GetRef(source)); + } + } + post_dfs_order.push_back(expr.get()); + }; + for (const IterSumExpr& expr : iter_map) { + fvisit(expr); + } + std::reverse(post_dfs_order.begin(), post_dfs_order.end()); + return post_dfs_order; + } + + void Visit_(const IterSplitExpr& iter_map_expr) { + PrimExpr input = backprop_.at(iter_map_expr) * iter_map_expr->lower_factor; + const IterMark& source = iter_map_expr->source; + if (source->source.as()) { + IterSumExpr source_expr = Downcast(source->source); + backprop_.Set(source_expr, backprop_.at(source_expr) + input); + } else { + Var source_var = Downcast(source->source); + if (inverse_.count(source_var)) { + inverse_.Set(source_var, inverse_.at(source_var) + input); + } else { + inverse_.Set(source_var, input); + } + } + } + + Array MatchFusePattern(const IterSumExpr sum_expr) { + IntImm base_scale(nullptr); + size_t base_index = 0; + for (size_t i = 0; i < sum_expr->args.size(); ++i) { + if (const auto* op = sum_expr->args[i]->scale.as()) { + if (!base_scale.defined() || op->value < base_scale->value) { + base_scale = GetRef(op); + base_index = i; + } + } + } + ICHECK(base_scale.defined()); + std::vector iters; + std::vector visited(sum_expr->args.size(), false); + PrimExpr expected_scale = base_scale; + for (size_t i = 0; i < sum_expr->args.size(); i++) { + size_t j = i == 0 ? base_index : 0; + for (; j < sum_expr->args.size(); ++j) { + if (!visited[j] && analyzer_->CanProveEqual(sum_expr->args[j]->scale, expected_scale)) + break; + } + ICHECK(j != sum_expr->args.size()); + visited[j] = true; + iters.push_back(sum_expr->args[j]); + expected_scale *= sum_expr->args[j]->extent; + } + return iters; + } + + Analyzer* analyzer_; + Map backprop_; // the accumulator of backpropgation + Map inverse_; // the result of inverse transformation +}; + +Map InverseAffineIterMap(const Array& iter_map, + const Array outputs) { + Analyzer analyzer; + return InverseAffineIterMapTransformer(&analyzer)(iter_map, outputs); +} + +TVM_REGISTER_GLOBAL("arith.InverseAffineIterMap").set_body_typed(InverseAffineIterMap); + } // namespace arith } // namespace tvm diff --git a/tests/python/unittest/test_arith_iter_affine_map.py b/tests/python/unittest/test_arith_iter_affine_map.py index 7bfdfc676b67..e6b09f860b9a 100644 --- a/tests/python/unittest/test_arith_iter_affine_map.py +++ b/tests/python/unittest/test_arith_iter_affine_map.py @@ -643,6 +643,66 @@ def test_normalize_iter_map_to_expr(): tvm.ir.assert_structural_equal(tvm.arith.normalize_iter_map_to_expr(res[1]), flm(x[0], 5)) +def test_inverse_affine_iter_map(): + analyzer = tvm.arith.Analyzer() + l0 = create_iter("l0", 64) + l1 = create_iter("l1", 64) + l2 = create_iter("l2", 64) + + # simple case + l0_0, l0_1 = isplit(l0, 16) + l1_0, l1_1 = isplit(l1, 4) + l0_1_l1_1_fused = ifuse([l0_1, l1_1]) + + iter_map = tvm.arith.detect_iter_map([l0_1_l1_1_fused[0], l0_0[0], l1_0[0]], + var_dom([l0, l1])) + outputs = [tvm.tir.Var("output_{}".format(i), "int32") for i in range(len(iter_map))] + res = tvm.arith.inverse_affine_iter_map(iter_map, outputs) + print(res) + assert len(res) == 2 + l0_inverse = floormod(floordiv(outputs[0], 4), 16) + outputs[1] * 16 + l1_inverse = floormod(outputs[0], 4) + outputs[2] * 4 + assert analyzer.simplify(res[l0[0]] - l0_inverse) == 0 + assert analyzer.simplify(res[l1[0]] - l1_inverse) == 0 + + # compound case + l0_0, l0_1 = isplit(l0, 16) + l1_0, l1_1 = isplit(l1, 4) + l2_1, l2_2 = isplit(l2, 4) + l2_0, l2_1 = isplit(l2_1, 4) + + l0_1_l2_1_l1_1_l2_0_fused = ifuse([l0_1, l2_1, l1_1, l2_0]) + + iter_map = tvm.arith.detect_iter_map([l0_1_l2_1_l1_1_l2_0_fused[0], l0_0[0], l2_2[0], l1_0[0]], + var_dom([l0, l1, l2])) + outputs = [tvm.tir.Var("output_{}".format(i), "int32") for i in range(len(iter_map))] + res = tvm.arith.inverse_affine_iter_map(iter_map, outputs) + assert len(res) == 3 + l0_inverse = floormod(floordiv(outputs[0], 64), 16) + outputs[1] * 16 + l1_inverse = floormod(floordiv(outputs[0], 4), 4) + outputs[3] * 4 + l2_inverse = floormod(outputs[0], 4) * 16 + floormod(floordiv(outputs[0], 16), 4) * 4 \ + + outputs[2] + + assert analyzer.simplify(res[l0[0]] - l0_inverse) == 0 + assert analyzer.simplify(res[l1[0]] - l1_inverse) == 0 + assert analyzer.simplify(res[l2[0]] - l2_inverse) == 0 + + # diamond-shape DAG + l0_0, l0_1 = isplit(l0, 16) + l1 = ifuse([l0_1, l0_0]) + l1_0, l1_1 = isplit(l1, 8) + l2 = ifuse([l1_1, l1_0]) + + iter_map = tvm.arith.detect_iter_map([l2[0]], var_dom([l0])) + outputs = [tvm.tir.Var("output_{}".format(i), "int32") for i in range(len(iter_map))] + res = tvm.arith.inverse_affine_iter_map(iter_map, outputs) + assert len(res) == 1 + l1_inverse = floormod(outputs[0], 8) * 8 + floormod(floordiv(outputs[0], 8), 8) + l0_inverse = floormod(l1_inverse, 4) * 16 + floormod(floordiv(l1_inverse, 4), 16) + + assert analyzer.simplify(res[l0[0]] - l0_inverse) == 0 + + if __name__ == "__main__": test_split() test_trivial() @@ -652,3 +712,4 @@ def test_normalize_iter_map_to_expr(): test_normalize_iter_map_to_expr() test_subspace_division() test_complex() + test_inverse_affine_iter_map() From ff4f58d46a581d9be064f618125dce94c07d2e7b Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Thu, 1 Jul 2021 11:53:59 -0700 Subject: [PATCH 2/3] lint --- python/tvm/arith/__init__.py | 8 ++++++-- python/tvm/arith/iter_affine_map.py | 2 +- src/arith/iter_affine_map.cc | 10 +++++----- tests/python/unittest/test_arith_iter_affine_map.py | 13 +++++++------ 4 files changed, 19 insertions(+), 14 deletions(-) diff --git a/python/tvm/arith/__init__.py b/python/tvm/arith/__init__.py index bcccac6bedab..f5a0478dc008 100644 --- a/python/tvm/arith/__init__.py +++ b/python/tvm/arith/__init__.py @@ -22,5 +22,9 @@ from .pattern import detect_linear_equation, detect_clip_bound from .int_solver import solve_linear_equations, solve_linear_inequalities from .iter_affine_map import IterMapExpr, IterMark, IterSplitExpr, IterSumExpr -from .iter_affine_map import detect_iter_map, normalize_iter_map_to_expr, subspace_divide, \ - inverse_affine_iter_map +from .iter_affine_map import ( + detect_iter_map, + normalize_iter_map_to_expr, + subspace_divide, + inverse_affine_iter_map, +) diff --git a/python/tvm/arith/iter_affine_map.py b/python/tvm/arith/iter_affine_map.py index 891d4b55f6e8..85513ecae5c4 100644 --- a/python/tvm/arith/iter_affine_map.py +++ b/python/tvm/arith/iter_affine_map.py @@ -176,7 +176,7 @@ def subspace_divide(bindings, input_iters, sub_iters, predicate=True, require_bi def inverse_affine_iter_map(iter_map, outputs): - """ Apply the inverse of the affine transformation to the outputs. + """Apply the inverse of the affine transformation to the outputs. Similar to the back-propagation, starting from the outputs, it visits the DAG of the expressions in reverse topology order and applies the inverse of the affine transformation until it reaches the input. The affine iter map is required to be bijective. diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc index 4231ff9d256d..e885195b3d42 100644 --- a/src/arith/iter_affine_map.cc +++ b/src/arith/iter_affine_map.cc @@ -1395,7 +1395,7 @@ class InverseAffineIterMapTransformer { std::vector post_dfs_order = ReverseTopologyOrder(iter_map); // initialize back propagation accumulator - for (const IterMapExprNode* node: post_dfs_order) { + for (const IterMapExprNode* node : post_dfs_order) { backprop_.Set(GetRef(node), Integer(0)); } for (size_t i = 0; i < iter_map.size(); i++) { @@ -1403,7 +1403,7 @@ class InverseAffineIterMapTransformer { } // run back propagation - for (const IterMapExprNode* node: post_dfs_order) { + for (const IterMapExprNode* node : post_dfs_order) { if (node->IsInstance()) { Visit_(Downcast(GetRef(node))); } else { @@ -1435,11 +1435,11 @@ class InverseAffineIterMapTransformer { ICHECK(!splits.empty()); for (const IterSplitExpr& split : splits) { - backprop_.Set(split, backprop_.at(split) + floormod(floordiv(input, split->scale), split->extent)); + backprop_.Set(split, + backprop_.at(split) + floormod(floordiv(input, split->scale), split->extent)); } } - std::vector ReverseTopologyOrder(const Array& iter_map) { std::vector post_dfs_order; std::unordered_map visited; @@ -1516,7 +1516,7 @@ class InverseAffineIterMapTransformer { Analyzer* analyzer_; Map backprop_; // the accumulator of backpropgation - Map inverse_; // the result of inverse transformation + Map inverse_; // the result of inverse transformation }; Map InverseAffineIterMap(const Array& iter_map, diff --git a/tests/python/unittest/test_arith_iter_affine_map.py b/tests/python/unittest/test_arith_iter_affine_map.py index e6b09f860b9a..ba264070f555 100644 --- a/tests/python/unittest/test_arith_iter_affine_map.py +++ b/tests/python/unittest/test_arith_iter_affine_map.py @@ -654,8 +654,7 @@ def test_inverse_affine_iter_map(): l1_0, l1_1 = isplit(l1, 4) l0_1_l1_1_fused = ifuse([l0_1, l1_1]) - iter_map = tvm.arith.detect_iter_map([l0_1_l1_1_fused[0], l0_0[0], l1_0[0]], - var_dom([l0, l1])) + iter_map = tvm.arith.detect_iter_map([l0_1_l1_1_fused[0], l0_0[0], l1_0[0]], var_dom([l0, l1])) outputs = [tvm.tir.Var("output_{}".format(i), "int32") for i in range(len(iter_map))] res = tvm.arith.inverse_affine_iter_map(iter_map, outputs) print(res) @@ -673,15 +672,17 @@ def test_inverse_affine_iter_map(): l0_1_l2_1_l1_1_l2_0_fused = ifuse([l0_1, l2_1, l1_1, l2_0]) - iter_map = tvm.arith.detect_iter_map([l0_1_l2_1_l1_1_l2_0_fused[0], l0_0[0], l2_2[0], l1_0[0]], - var_dom([l0, l1, l2])) + iter_map = tvm.arith.detect_iter_map( + [l0_1_l2_1_l1_1_l2_0_fused[0], l0_0[0], l2_2[0], l1_0[0]], var_dom([l0, l1, l2]) + ) outputs = [tvm.tir.Var("output_{}".format(i), "int32") for i in range(len(iter_map))] res = tvm.arith.inverse_affine_iter_map(iter_map, outputs) assert len(res) == 3 l0_inverse = floormod(floordiv(outputs[0], 64), 16) + outputs[1] * 16 l1_inverse = floormod(floordiv(outputs[0], 4), 4) + outputs[3] * 4 - l2_inverse = floormod(outputs[0], 4) * 16 + floormod(floordiv(outputs[0], 16), 4) * 4 \ - + outputs[2] + l2_inverse = ( + floormod(outputs[0], 4) * 16 + floormod(floordiv(outputs[0], 16), 4) * 4 + outputs[2] + ) assert analyzer.simplify(res[l0[0]] - l0_inverse) == 0 assert analyzer.simplify(res[l1[0]] - l1_inverse) == 0 From 0ee70404a6d16ef5aa856d2b74074c9b306201bf Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Fri, 2 Jul 2021 19:52:21 -0700 Subject: [PATCH 3/3] remove print --- tests/python/unittest/test_arith_iter_affine_map.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/python/unittest/test_arith_iter_affine_map.py b/tests/python/unittest/test_arith_iter_affine_map.py index ba264070f555..b34acb9ae359 100644 --- a/tests/python/unittest/test_arith_iter_affine_map.py +++ b/tests/python/unittest/test_arith_iter_affine_map.py @@ -657,7 +657,6 @@ def test_inverse_affine_iter_map(): iter_map = tvm.arith.detect_iter_map([l0_1_l1_1_fused[0], l0_0[0], l1_0[0]], var_dom([l0, l1])) outputs = [tvm.tir.Var("output_{}".format(i), "int32") for i in range(len(iter_map))] res = tvm.arith.inverse_affine_iter_map(iter_map, outputs) - print(res) assert len(res) == 2 l0_inverse = floormod(floordiv(outputs[0], 4), 16) + outputs[1] * 16 l1_inverse = floormod(outputs[0], 4) + outputs[2] * 4