From 5d77b510c6fd4b6b1a91dc6949d94b4e5e65b8b4 Mon Sep 17 00:00:00 2001 From: Min Chen Date: Thu, 5 Jan 2023 07:19:01 +0000 Subject: [PATCH 01/11] [TIR][Arith] Add common sub expr analyzer --- python/tvm/arith/__init__.py | 2 +- python/tvm/arith/pattern.py | 19 +++++ src/arith/detect_common_subexpr.cc | 75 +++++++++++++++++++ .../transforms/common_subexpr_elim_tools.cc | 6 +- .../transforms/common_subexpr_elim_tools.h | 3 +- .../python/unittest/test_arith_detect_cse.py | 33 ++++++++ 6 files changed, 133 insertions(+), 5 deletions(-) create mode 100644 src/arith/detect_common_subexpr.cc create mode 100644 tests/python/unittest/test_arith_detect_cse.py diff --git a/python/tvm/arith/__init__.py b/python/tvm/arith/__init__.py index 03c0769850c9..423aafe5d69f 100644 --- a/python/tvm/arith/__init__.py +++ b/python/tvm/arith/__init__.py @@ -25,7 +25,7 @@ ) from .analyzer import ModularSet, ConstIntBound, Analyzer from .bound import deduce_bound -from .pattern import detect_linear_equation, detect_clip_bound +from .pattern import detect_linear_equation, detect_clip_bound, detect_common_subexpr from .int_solver import solve_linear_equations, solve_linear_inequalities from .iter_affine_map import IterMapExpr, IterMark, IterSplitExpr, IterSumExpr from .iter_affine_map import ( diff --git a/python/tvm/arith/pattern.py b/python/tvm/arith/pattern.py index 53f8eb62b6e1..d5dd6b9ae7bc 100644 --- a/python/tvm/arith/pattern.py +++ b/python/tvm/arith/pattern.py @@ -58,3 +58,22 @@ def detect_clip_bound(expr, var_list): An empty list if the match failed. """ return _ffi_api.DetectClipBound(expr, var_list) + + +def detect_common_subexpr(expr, thresh): + """Detect common sub expression which shows up more than a threshold times + + Parameters + ---------- + expr : PrimExpr + The expression to be analyzed. + + thresh : int + The threshold of repeat times that determines a common sub expression + + Returns + ------- + cse : Dict{PrimExpr: int} + The detected common sub expression dict, with sub expression and repeat times + """ + return _ffi_api.DetectCommonSubExpr(expr, thresh) diff --git a/src/arith/detect_common_subexpr.cc b/src/arith/detect_common_subexpr.cc new file mode 100644 index 000000000000..4b42fc0ac0bf --- /dev/null +++ b/src/arith/detect_common_subexpr.cc @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file detect_common_subexpr.cc + * \brief Utility to detect common sub expressions. + */ +#include + +#include + +#include "../tir/transforms/common_subexpr_elim_tools.h" + +namespace tvm { +namespace arith { + +using namespace tir; + +Map DetectCommonSubExpr(const PrimExpr& e, const Integer thresh) { + // Check the treshold in range of size_t + int64_t i64_thr = thresh.IntValue(); + CHECK_GE(i64_thr, std::numeric_limits::min()); + CHECK_LE(i64_thr, std::numeric_limits::max()); + size_t repeat_thr = static_cast(i64_thr); + auto IsEligibleComputation = [](const PrimExpr& expr) { + return (SideEffect(expr) <= CallEffectKind::kPure && CalculateExprComplexity(expr) > 1 && + (expr.as() == nullptr) && (expr.as() == nullptr)); + }; + + // Analyze the sub expressions + ComputationTable table_syntactic_comp_done_by_expr = ComputationsDoneBy::GetComputationsDoneBy( + e, IsEligibleComputation, [](const PrimExpr& expr) { return true; }); + + std::vector> semantic_comp_done_by_expr = + SyntacticToSemanticComputations(table_syntactic_comp_done_by_expr, true); + + // Find eligible sub expr if occurrence is under thresh + for (size_t i = 0; i < semantic_comp_done_by_expr.size(); i++) { + std::pair& computation_and_nb = semantic_comp_done_by_expr[i]; + if (computation_and_nb.second < repeat_thr) { + std::vector direct_subexprs = + DirectSubexpr::GetDirectSubexpressions(computation_and_nb.first, IsEligibleComputation, + [](const PrimExpr& expr) { return true; }); + InsertVectorToSortedSemanticComputations(&semantic_comp_done_by_expr, direct_subexprs, true, + computation_and_nb.second); + } + } + + // Return the common sub expr that occur more than thresh times + Map results; + for (auto& it : semantic_comp_done_by_expr) { + if (it.second >= repeat_thr) results.Set(it.first, it.second); + } + return results; +} + +TVM_REGISTER_GLOBAL("arith.DetectCommonSubExpr").set_body_typed(DetectCommonSubExpr); +} // namespace arith +} // namespace tvm diff --git a/src/tir/transforms/common_subexpr_elim_tools.cc b/src/tir/transforms/common_subexpr_elim_tools.cc index 130004c51cd8..c118d1db7d8e 100644 --- a/src/tir/transforms/common_subexpr_elim_tools.cc +++ b/src/tir/transforms/common_subexpr_elim_tools.cc @@ -902,7 +902,7 @@ void InsertElemToSortedSemanticComputations(std::vector>* sorted_vec, const std::vector& vec_to_add, - bool identify_equiv_terms) { + bool identify_equiv_terms, size_t increase_count) { if (sorted_vec == nullptr) { return; } @@ -918,10 +918,10 @@ void InsertVectorToSortedSemanticComputations(std::vectorend()) { // then we just increase its associated count - it_found->second++; + it_found->second += increase_count; } else { // Otherwise we add the pair (`elem_to_add`,1) at the right place - InsertElemToSortedSemanticComputations(sorted_vec, {elem_to_add, 1}); + InsertElemToSortedSemanticComputations(sorted_vec, {elem_to_add, increase_count}); } } } diff --git a/src/tir/transforms/common_subexpr_elim_tools.h b/src/tir/transforms/common_subexpr_elim_tools.h index 0871fd009149..841f1d65a6f6 100644 --- a/src/tir/transforms/common_subexpr_elim_tools.h +++ b/src/tir/transforms/common_subexpr_elim_tools.h @@ -210,9 +210,10 @@ template std::vector VectorMap(const std::vector void InsertElemToSortedSemanticComputations(std::vector>* sorted_vec, const std::pair& pair); + void InsertVectorToSortedSemanticComputations(std::vector>* sorted_vec, const std::vector& vec_to_add, - bool identify_equiv_terms); + bool identify_equiv_terms, size_t increase_count = 1); } // namespace tir } // namespace tvm diff --git a/tests/python/unittest/test_arith_detect_cse.py b/tests/python/unittest/test_arith_detect_cse.py new file mode 100644 index 000000000000..eba0920cb2da --- /dev/null +++ b/tests/python/unittest/test_arith_detect_cse.py @@ -0,0 +1,33 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import tvm +import tvm.testing +from tvm.script import tir as T + + +def test_detect_cs(): + x = T.Var("x", dtype="int32") + y = T.Var("y", dtype="int32") + z = T.Var("z", dtype="int32") + c = T.floor(x + y + 0.5) + x + z * (T.floor(x + y + 0.5)) + m = tvm.arith.detect_common_subexpr(c, 2) + assert c.a.a in m + assert m[c.a.a] == 2 + + +if __name__ == "__main__": + tvm.testing.main() From 25c69309c72f227ed5d477083ec2334fa0d31f56 Mon Sep 17 00:00:00 2001 From: multiverstack <39256082+multiverstack-intellif@users.noreply.github.com> Date: Fri, 6 Jan 2023 16:11:45 +0800 Subject: [PATCH 02/11] Update python/tvm/arith/pattern.py Co-authored-by: Siyuan Feng --- python/tvm/arith/pattern.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/arith/pattern.py b/python/tvm/arith/pattern.py index d5dd6b9ae7bc..2a1a3c234f68 100644 --- a/python/tvm/arith/pattern.py +++ b/python/tvm/arith/pattern.py @@ -73,7 +73,7 @@ def detect_common_subexpr(expr, thresh): Returns ------- - cse : Dict{PrimExpr: int} + cse_dict : Dict[PrimExpr, int] The detected common sub expression dict, with sub expression and repeat times """ return _ffi_api.DetectCommonSubExpr(expr, thresh) From fc8a9ada37542a38c5b2ddf2658c93ad72de4a82 Mon Sep 17 00:00:00 2001 From: multiverstack <39256082+multiverstack-intellif@users.noreply.github.com> Date: Fri, 6 Jan 2023 16:16:22 +0800 Subject: [PATCH 03/11] Update src/arith/detect_common_subexpr.cc Co-authored-by: Siyuan Feng --- src/arith/detect_common_subexpr.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/arith/detect_common_subexpr.cc b/src/arith/detect_common_subexpr.cc index 4b42fc0ac0bf..586983b6b945 100644 --- a/src/arith/detect_common_subexpr.cc +++ b/src/arith/detect_common_subexpr.cc @@ -33,7 +33,7 @@ namespace arith { using namespace tir; Map DetectCommonSubExpr(const PrimExpr& e, const Integer thresh) { - // Check the treshold in range of size_t + // Check the threshold in the range of size_t int64_t i64_thr = thresh.IntValue(); CHECK_GE(i64_thr, std::numeric_limits::min()); CHECK_LE(i64_thr, std::numeric_limits::max()); From a6a3025dac4873c425c033aa79222244cff572e3 Mon Sep 17 00:00:00 2001 From: multiverstack <39256082+multiverstack-intellif@users.noreply.github.com> Date: Fri, 6 Jan 2023 16:16:44 +0800 Subject: [PATCH 04/11] Update python/tvm/arith/pattern.py Co-authored-by: Siyuan Feng --- python/tvm/arith/pattern.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/arith/pattern.py b/python/tvm/arith/pattern.py index 2a1a3c234f68..09c8f8bfbbc8 100644 --- a/python/tvm/arith/pattern.py +++ b/python/tvm/arith/pattern.py @@ -60,7 +60,7 @@ def detect_clip_bound(expr, var_list): return _ffi_api.DetectClipBound(expr, var_list) -def detect_common_subexpr(expr, thresh): +def detect_common_subexpr(expr: PrimExpr, threshold: int) -> Dict[PrimExpr, int]: """Detect common sub expression which shows up more than a threshold times Parameters From 8141036a815309dadb1e51086ef481e0f60e93e6 Mon Sep 17 00:00:00 2001 From: multiverstack <39256082+multiverstack-intellif@users.noreply.github.com> Date: Fri, 6 Jan 2023 16:17:09 +0800 Subject: [PATCH 05/11] Update python/tvm/arith/pattern.py Co-authored-by: Siyuan Feng --- python/tvm/arith/pattern.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/arith/pattern.py b/python/tvm/arith/pattern.py index 09c8f8bfbbc8..ff8802261329 100644 --- a/python/tvm/arith/pattern.py +++ b/python/tvm/arith/pattern.py @@ -68,7 +68,7 @@ def detect_common_subexpr(expr: PrimExpr, threshold: int) -> Dict[PrimExpr, int] expr : PrimExpr The expression to be analyzed. - thresh : int + threshold : int The threshold of repeat times that determines a common sub expression Returns From a2b8cef32b6f5571921ee0af7b25995eca466cbf Mon Sep 17 00:00:00 2001 From: multiverstack <39256082+multiverstack-intellif@users.noreply.github.com> Date: Fri, 6 Jan 2023 16:18:08 +0800 Subject: [PATCH 06/11] Update src/arith/detect_common_subexpr.cc Co-authored-by: Siyuan Feng --- src/arith/detect_common_subexpr.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/arith/detect_common_subexpr.cc b/src/arith/detect_common_subexpr.cc index 586983b6b945..8d814a3c5355 100644 --- a/src/arith/detect_common_subexpr.cc +++ b/src/arith/detect_common_subexpr.cc @@ -32,7 +32,7 @@ namespace arith { using namespace tir; -Map DetectCommonSubExpr(const PrimExpr& e, const Integer thresh) { +Map DetectCommonSubExpr(const PrimExpr& e, int thresh) { // Check the threshold in the range of size_t int64_t i64_thr = thresh.IntValue(); CHECK_GE(i64_thr, std::numeric_limits::min()); From c48b99d2f9de40f13e9572ad255bb892746ff780 Mon Sep 17 00:00:00 2001 From: multiverstack <39256082+multiverstack-intellif@users.noreply.github.com> Date: Fri, 6 Jan 2023 16:25:50 +0800 Subject: [PATCH 07/11] Update detect_common_subexpr.cc --- src/arith/detect_common_subexpr.cc | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/arith/detect_common_subexpr.cc b/src/arith/detect_common_subexpr.cc index 8d814a3c5355..b496e7fefca5 100644 --- a/src/arith/detect_common_subexpr.cc +++ b/src/arith/detect_common_subexpr.cc @@ -34,10 +34,9 @@ using namespace tir; Map DetectCommonSubExpr(const PrimExpr& e, int thresh) { // Check the threshold in the range of size_t - int64_t i64_thr = thresh.IntValue(); - CHECK_GE(i64_thr, std::numeric_limits::min()); - CHECK_LE(i64_thr, std::numeric_limits::max()); - size_t repeat_thr = static_cast(i64_thr); + CHECK_GE(thresh, std::numeric_limits::min()); + CHECK_LE(thresh, std::numeric_limits::max()); + size_t repeat_thr = static_cast(thresh); auto IsEligibleComputation = [](const PrimExpr& expr) { return (SideEffect(expr) <= CallEffectKind::kPure && CalculateExprComplexity(expr) > 1 && (expr.as() == nullptr) && (expr.as() == nullptr)); From e9c8b45d18878e17b4e36d77d920cd40a00ae066 Mon Sep 17 00:00:00 2001 From: multiverstack <39256082+multiverstack-intellif@users.noreply.github.com> Date: Fri, 6 Jan 2023 17:16:17 +0800 Subject: [PATCH 08/11] Update pattern.py --- python/tvm/arith/pattern.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/tvm/arith/pattern.py b/python/tvm/arith/pattern.py index ff8802261329..b9cbb2c5661b 100644 --- a/python/tvm/arith/pattern.py +++ b/python/tvm/arith/pattern.py @@ -16,6 +16,7 @@ # under the License. """Detect common patterns.""" from . import _ffi_api +from tvm.tir import PrimExpr def detect_linear_equation(expr, var_list): From 8c439e838301f12b2d0d3b0a19e44a376d1dff90 Mon Sep 17 00:00:00 2001 From: multiverstack <39256082+multiverstack-intellif@users.noreply.github.com> Date: Fri, 6 Jan 2023 17:57:35 +0800 Subject: [PATCH 09/11] Update pattern.py --- python/tvm/arith/pattern.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/tvm/arith/pattern.py b/python/tvm/arith/pattern.py index b9cbb2c5661b..3a5a55b2f06e 100644 --- a/python/tvm/arith/pattern.py +++ b/python/tvm/arith/pattern.py @@ -17,6 +17,7 @@ """Detect common patterns.""" from . import _ffi_api from tvm.tir import PrimExpr +from typing import Dict def detect_linear_equation(expr, var_list): From 0de318accb55f3a0ddde3c53d57565608c78bc43 Mon Sep 17 00:00:00 2001 From: multiverstack <39256082+multiverstack-intellif@users.noreply.github.com> Date: Fri, 6 Jan 2023 21:32:07 +0800 Subject: [PATCH 10/11] Update pattern.py --- python/tvm/arith/pattern.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/arith/pattern.py b/python/tvm/arith/pattern.py index 3a5a55b2f06e..d4d7be69e365 100644 --- a/python/tvm/arith/pattern.py +++ b/python/tvm/arith/pattern.py @@ -78,4 +78,4 @@ def detect_common_subexpr(expr: PrimExpr, threshold: int) -> Dict[PrimExpr, int] cse_dict : Dict[PrimExpr, int] The detected common sub expression dict, with sub expression and repeat times """ - return _ffi_api.DetectCommonSubExpr(expr, thresh) + return _ffi_api.DetectCommonSubExpr(expr, threshold) From 0e68d8f32cadaec2b0080fc80b8c596b5a47a32f Mon Sep 17 00:00:00 2001 From: multiverstack <39256082+multiverstack-intellif@users.noreply.github.com> Date: Fri, 6 Jan 2023 23:58:08 +0800 Subject: [PATCH 11/11] Update pattern.py --- python/tvm/arith/pattern.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/python/tvm/arith/pattern.py b/python/tvm/arith/pattern.py index d4d7be69e365..3c822dc52399 100644 --- a/python/tvm/arith/pattern.py +++ b/python/tvm/arith/pattern.py @@ -15,10 +15,12 @@ # specific language governing permissions and limitations # under the License. """Detect common patterns.""" -from . import _ffi_api -from tvm.tir import PrimExpr + from typing import Dict +from tvm.tir import PrimExpr +from . import _ffi_api + def detect_linear_equation(expr, var_list): """Match `expr = sum_{i=0}^{n-1} var[i] * coeff[i] + coeff[n]`