Skip to content
2 changes: 1 addition & 1 deletion python/tvm/arith/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
)
from .analyzer import ModularSet, ConstIntBound, Analyzer
from .bound import deduce_bound
from .pattern import detect_linear_equation, detect_clip_bound
from .pattern import detect_linear_equation, detect_clip_bound, detect_common_subexpr
from .int_solver import solve_linear_equations, solve_linear_inequalities
from .iter_affine_map import IterMapExpr, IterMark, IterSplitExpr, IterSumExpr
from .iter_affine_map import (
Expand Down
23 changes: 23 additions & 0 deletions python/tvm/arith/pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@
# specific language governing permissions and limitations
# under the License.
"""Detect common patterns."""

from typing import Dict

from tvm.tir import PrimExpr
from . import _ffi_api


Expand Down Expand Up @@ -58,3 +62,22 @@ def detect_clip_bound(expr, var_list):
An empty list if the match failed.
"""
return _ffi_api.DetectClipBound(expr, var_list)


def detect_common_subexpr(expr: PrimExpr, threshold: int) -> Dict[PrimExpr, int]:
"""Detect common sub expression which shows up more than a threshold times

Parameters
----------
expr : PrimExpr
The expression to be analyzed.

threshold : int
The threshold of repeat times that determines a common sub expression

Returns
-------
cse_dict : Dict[PrimExpr, int]
The detected common sub expression dict, with sub expression and repeat times
"""
return _ffi_api.DetectCommonSubExpr(expr, threshold)
74 changes: 74 additions & 0 deletions src/arith/detect_common_subexpr.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file detect_common_subexpr.cc
* \brief Utility to detect common sub expressions.
*/
#include <tvm/tir/expr.h>

#include <limits>

#include "../tir/transforms/common_subexpr_elim_tools.h"

namespace tvm {
namespace arith {

using namespace tir;

Map<PrimExpr, Integer> DetectCommonSubExpr(const PrimExpr& e, int thresh) {
// Check the threshold in the range of size_t
CHECK_GE(thresh, std::numeric_limits<size_t>::min());
CHECK_LE(thresh, std::numeric_limits<size_t>::max());
size_t repeat_thr = static_cast<size_t>(thresh);
auto IsEligibleComputation = [](const PrimExpr& expr) {
return (SideEffect(expr) <= CallEffectKind::kPure && CalculateExprComplexity(expr) > 1 &&
(expr.as<RampNode>() == nullptr) && (expr.as<BroadcastNode>() == nullptr));
};

// Analyze the sub expressions
ComputationTable table_syntactic_comp_done_by_expr = ComputationsDoneBy::GetComputationsDoneBy(
e, IsEligibleComputation, [](const PrimExpr& expr) { return true; });

std::vector<std::pair<PrimExpr, size_t>> semantic_comp_done_by_expr =
SyntacticToSemanticComputations(table_syntactic_comp_done_by_expr, true);

// Find eligible sub expr if occurrence is under thresh
for (size_t i = 0; i < semantic_comp_done_by_expr.size(); i++) {
std::pair<PrimExpr, size_t>& computation_and_nb = semantic_comp_done_by_expr[i];
if (computation_and_nb.second < repeat_thr) {
std::vector<PrimExpr> direct_subexprs =
DirectSubexpr::GetDirectSubexpressions(computation_and_nb.first, IsEligibleComputation,
[](const PrimExpr& expr) { return true; });
InsertVectorToSortedSemanticComputations(&semantic_comp_done_by_expr, direct_subexprs, true,
computation_and_nb.second);
}
}

// Return the common sub expr that occur more than thresh times
Map<PrimExpr, Integer> results;
for (auto& it : semantic_comp_done_by_expr) {
if (it.second >= repeat_thr) results.Set(it.first, it.second);
}
return results;
}

TVM_REGISTER_GLOBAL("arith.DetectCommonSubExpr").set_body_typed(DetectCommonSubExpr);
} // namespace arith
} // namespace tvm
6 changes: 3 additions & 3 deletions src/tir/transforms/common_subexpr_elim_tools.cc
Original file line number Diff line number Diff line change
Expand Up @@ -902,7 +902,7 @@ void InsertElemToSortedSemanticComputations(std::vector<std::pair<PrimExpr, size
*/
void InsertVectorToSortedSemanticComputations(std::vector<std::pair<PrimExpr, size_t>>* sorted_vec,
const std::vector<PrimExpr>& vec_to_add,
bool identify_equiv_terms) {
bool identify_equiv_terms, size_t increase_count) {
if (sorted_vec == nullptr) {
return;
}
Expand All @@ -918,10 +918,10 @@ void InsertVectorToSortedSemanticComputations(std::vector<std::pair<PrimExpr, si
// If we found `elem_to_add` (or an equivalent expression) already in sorted_vec
if (it_found != sorted_vec->end()) {
// then we just increase its associated count
it_found->second++;
it_found->second += increase_count;
} else {
// Otherwise we add the pair (`elem_to_add`,1) at the right place
InsertElemToSortedSemanticComputations(sorted_vec, {elem_to_add, 1});
InsertElemToSortedSemanticComputations(sorted_vec, {elem_to_add, increase_count});
}
}
}
Expand Down
3 changes: 2 additions & 1 deletion src/tir/transforms/common_subexpr_elim_tools.h
Original file line number Diff line number Diff line change
Expand Up @@ -210,9 +210,10 @@ template std::vector<Var> VectorMap(const std::vector<std::pair<Var, MaybeValue>

void InsertElemToSortedSemanticComputations(std::vector<std::pair<PrimExpr, size_t>>* sorted_vec,
const std::pair<PrimExpr, size_t>& pair);

void InsertVectorToSortedSemanticComputations(std::vector<std::pair<PrimExpr, size_t>>* sorted_vec,
const std::vector<PrimExpr>& vec_to_add,
bool identify_equiv_terms);
bool identify_equiv_terms, size_t increase_count = 1);

} // namespace tir
} // namespace tvm
Expand Down
33 changes: 33 additions & 0 deletions tests/python/unittest/test_arith_detect_cse.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
import tvm.testing
from tvm.script import tir as T


def test_detect_cs():
x = T.Var("x", dtype="int32")
y = T.Var("y", dtype="int32")
z = T.Var("z", dtype="int32")
c = T.floor(x + y + 0.5) + x + z * (T.floor(x + y + 0.5))
m = tvm.arith.detect_common_subexpr(c, 2)
assert c.a.a in m
assert m[c.a.a] == 2


if __name__ == "__main__":
tvm.testing.main()