diff --git a/src/tir/transforms/ir_utils.cc b/src/tir/transforms/ir_utils.cc index bc2f7ad6f357..4eb9cc5b1a90 100644 --- a/src/tir/transforms/ir_utils.cc +++ b/src/tir/transforms/ir_utils.cc @@ -258,8 +258,8 @@ Map ConditionalBoundsContext::GetVarBoundsFromCondition() { arith::Analyzer analyzer; PrimExpr condition = is_true_branch_ ? condition_ : analyzer.Simplify(!condition_); Array equations; - std::unordered_set var_set; - std::function fvisit = [&equations, &var_set, &fvisit](const PrimExpr& e) { + Array vars; + std::function fvisit = [&equations, &vars, &fvisit](const PrimExpr& e) { if (e->IsInstance() || e->IsInstance() || e->IsInstance() || e->IsInstance() || e->IsInstance() || e->IsInstance()) { bool is_simple = true; @@ -278,7 +278,12 @@ Map ConditionalBoundsContext::GetVarBoundsFromCondition() { } }); if (is_simple && !cand_vars.empty()) { - for (const Var& var : cand_vars) var_set.insert(var); + for (const Var& new_var : cand_vars) { + if (!std::any_of(vars.begin(), vars.end(), + [&new_var](const Var& v) { return v.same_as(new_var); })) { + vars.push_back(new_var); + } + } equations.push_back(Downcast(e)); } } else if (e->IsInstance()) { @@ -293,11 +298,10 @@ Map ConditionalBoundsContext::GetVarBoundsFromCondition() { } }; fvisit(condition); - if (equations.empty() || var_set.empty()) { + if (equations.empty() || vars.empty()) { return Map(); } // build dom ranges for related vars - Array vars = Array(var_set.begin(), var_set.end()); Map ranges; for (const Var& v : vars) { arith::IntSet dom; diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_auto_inline.py b/tests/python/unittest/test_meta_schedule_schedule_rule_auto_inline.py index 826054eb0b9b..e206fcc4502c 100644 --- a/tests/python/unittest/test_meta_schedule_schedule_rule_auto_inline.py +++ b/tests/python/unittest/test_meta_schedule_schedule_rule_auto_inline.py @@ -16,7 +16,6 @@ # under the License. # pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring import tvm -import pytest from tvm.meta_schedule.space_generator.post_order_apply import PostOrderApply from tvm.meta_schedule.testing.schedule_rule import auto_inline from tvm.meta_schedule.tune_context import TuneContext @@ -271,7 +270,6 @@ def test_inline_consumer_chain(): tvm.ir.assert_structural_equal(lhs=space.mod, rhs=Conv2DBiasBnReLUInlined) -@pytest.mark.skip(reason="Flaky test") def test_inline_into_cache(): mod = MultiLevelTiledConv2D target = Target("cuda", host="llvm")