diff --git a/python/tvm/arith/__init__.py b/python/tvm/arith/__init__.py index 03c0769850c9..423aafe5d69f 100644 --- a/python/tvm/arith/__init__.py +++ b/python/tvm/arith/__init__.py @@ -25,7 +25,7 @@ ) from .analyzer import ModularSet, ConstIntBound, Analyzer from .bound import deduce_bound -from .pattern import detect_linear_equation, detect_clip_bound +from .pattern import detect_linear_equation, detect_clip_bound, detect_common_subexpr from .int_solver import solve_linear_equations, solve_linear_inequalities from .iter_affine_map import IterMapExpr, IterMark, IterSplitExpr, IterSumExpr from .iter_affine_map import ( diff --git a/python/tvm/arith/pattern.py b/python/tvm/arith/pattern.py index 53f8eb62b6e1..3c822dc52399 100644 --- a/python/tvm/arith/pattern.py +++ b/python/tvm/arith/pattern.py @@ -15,6 +15,10 @@ # specific language governing permissions and limitations # under the License. """Detect common patterns.""" + +from typing import Dict + +from tvm.tir import PrimExpr from . import _ffi_api @@ -58,3 +62,22 @@ def detect_clip_bound(expr, var_list): An empty list if the match failed. """ return _ffi_api.DetectClipBound(expr, var_list) + + +def detect_common_subexpr(expr: PrimExpr, threshold: int) -> Dict[PrimExpr, int]: + """Detect common sub expression which shows up more than a threshold times + + Parameters + ---------- + expr : PrimExpr + The expression to be analyzed. + + threshold : int + The threshold of repeat times that determines a common sub expression + + Returns + ------- + cse_dict : Dict[PrimExpr, int] + The detected common sub expression dict, with sub expression and repeat times + """ + return _ffi_api.DetectCommonSubExpr(expr, threshold) diff --git a/src/arith/detect_common_subexpr.cc b/src/arith/detect_common_subexpr.cc new file mode 100644 index 000000000000..b496e7fefca5 --- /dev/null +++ b/src/arith/detect_common_subexpr.cc @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file detect_common_subexpr.cc + * \brief Utility to detect common sub expressions. + */ +#include + +#include + +#include "../tir/transforms/common_subexpr_elim_tools.h" + +namespace tvm { +namespace arith { + +using namespace tir; + +Map DetectCommonSubExpr(const PrimExpr& e, int thresh) { + // Check the threshold in the range of size_t + CHECK_GE(thresh, std::numeric_limits::min()); + CHECK_LE(thresh, std::numeric_limits::max()); + size_t repeat_thr = static_cast(thresh); + auto IsEligibleComputation = [](const PrimExpr& expr) { + return (SideEffect(expr) <= CallEffectKind::kPure && CalculateExprComplexity(expr) > 1 && + (expr.as() == nullptr) && (expr.as() == nullptr)); + }; + + // Analyze the sub expressions + ComputationTable table_syntactic_comp_done_by_expr = ComputationsDoneBy::GetComputationsDoneBy( + e, IsEligibleComputation, [](const PrimExpr& expr) { return true; }); + + std::vector> semantic_comp_done_by_expr = + SyntacticToSemanticComputations(table_syntactic_comp_done_by_expr, true); + + // Find eligible sub expr if occurrence is under thresh + for (size_t i = 0; i < semantic_comp_done_by_expr.size(); i++) { + std::pair& computation_and_nb = semantic_comp_done_by_expr[i]; + if (computation_and_nb.second < repeat_thr) { + std::vector direct_subexprs = + DirectSubexpr::GetDirectSubexpressions(computation_and_nb.first, IsEligibleComputation, + [](const PrimExpr& expr) { return true; }); + InsertVectorToSortedSemanticComputations(&semantic_comp_done_by_expr, direct_subexprs, true, + computation_and_nb.second); + } + } + + // Return the common sub expr that occur more than thresh times + Map results; + for (auto& it : semantic_comp_done_by_expr) { + if (it.second >= repeat_thr) results.Set(it.first, it.second); + } + return results; +} + +TVM_REGISTER_GLOBAL("arith.DetectCommonSubExpr").set_body_typed(DetectCommonSubExpr); +} // namespace arith +} // namespace tvm diff --git a/src/tir/transforms/common_subexpr_elim_tools.cc b/src/tir/transforms/common_subexpr_elim_tools.cc index 130004c51cd8..c118d1db7d8e 100644 --- a/src/tir/transforms/common_subexpr_elim_tools.cc +++ b/src/tir/transforms/common_subexpr_elim_tools.cc @@ -902,7 +902,7 @@ void InsertElemToSortedSemanticComputations(std::vector>* sorted_vec, const std::vector& vec_to_add, - bool identify_equiv_terms) { + bool identify_equiv_terms, size_t increase_count) { if (sorted_vec == nullptr) { return; } @@ -918,10 +918,10 @@ void InsertVectorToSortedSemanticComputations(std::vectorend()) { // then we just increase its associated count - it_found->second++; + it_found->second += increase_count; } else { // Otherwise we add the pair (`elem_to_add`,1) at the right place - InsertElemToSortedSemanticComputations(sorted_vec, {elem_to_add, 1}); + InsertElemToSortedSemanticComputations(sorted_vec, {elem_to_add, increase_count}); } } } diff --git a/src/tir/transforms/common_subexpr_elim_tools.h b/src/tir/transforms/common_subexpr_elim_tools.h index 0871fd009149..841f1d65a6f6 100644 --- a/src/tir/transforms/common_subexpr_elim_tools.h +++ b/src/tir/transforms/common_subexpr_elim_tools.h @@ -210,9 +210,10 @@ template std::vector VectorMap(const std::vector void InsertElemToSortedSemanticComputations(std::vector>* sorted_vec, const std::pair& pair); + void InsertVectorToSortedSemanticComputations(std::vector>* sorted_vec, const std::vector& vec_to_add, - bool identify_equiv_terms); + bool identify_equiv_terms, size_t increase_count = 1); } // namespace tir } // namespace tvm diff --git a/tests/python/unittest/test_arith_detect_cse.py b/tests/python/unittest/test_arith_detect_cse.py new file mode 100644 index 000000000000..eba0920cb2da --- /dev/null +++ b/tests/python/unittest/test_arith_detect_cse.py @@ -0,0 +1,33 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import tvm +import tvm.testing +from tvm.script import tir as T + + +def test_detect_cs(): + x = T.Var("x", dtype="int32") + y = T.Var("y", dtype="int32") + z = T.Var("z", dtype="int32") + c = T.floor(x + y + 0.5) + x + z * (T.floor(x + y + 0.5)) + m = tvm.arith.detect_common_subexpr(c, 2) + assert c.a.a in m + assert m[c.a.a] == 2 + + +if __name__ == "__main__": + tvm.testing.main()