From a67f85042c59747bef7dad2aefd0c0e2d3d3bf71 Mon Sep 17 00:00:00 2001 From: wrongtest Date: Tue, 18 Jan 2022 20:38:40 +0800 Subject: [PATCH 1/2] fix to stablize the var orders when solve bounds in region analysis --- src/tir/transforms/ir_utils.cc | 12 +++++++++--- .../test_meta_schedule_schedule_rule_auto_inline.py | 1 - 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/src/tir/transforms/ir_utils.cc b/src/tir/transforms/ir_utils.cc index bc2f7ad6f357..03942a6254f0 100644 --- a/src/tir/transforms/ir_utils.cc +++ b/src/tir/transforms/ir_utils.cc @@ -259,7 +259,9 @@ Map ConditionalBoundsContext::GetVarBoundsFromCondition() { PrimExpr condition = is_true_branch_ ? condition_ : analyzer.Simplify(!condition_); Array equations; std::unordered_set var_set; - std::function fvisit = [&equations, &var_set, &fvisit](const PrimExpr& e) { + Array vars; + std::function fvisit = [&equations, &vars, &var_set, + &fvisit](const PrimExpr& e) { if (e->IsInstance() || e->IsInstance() || e->IsInstance() || e->IsInstance() || e->IsInstance() || e->IsInstance()) { bool is_simple = true; @@ -278,7 +280,12 @@ Map ConditionalBoundsContext::GetVarBoundsFromCondition() { } }); if (is_simple && !cand_vars.empty()) { - for (const Var& var : cand_vars) var_set.insert(var); + for (const Var& var : cand_vars) { + if (!var_set.count(var)) { + vars.push_back(var); + var_set.insert(var); + } + } equations.push_back(Downcast(e)); } } else if (e->IsInstance()) { @@ -297,7 +304,6 @@ Map ConditionalBoundsContext::GetVarBoundsFromCondition() { return Map(); } // build dom ranges for related vars - Array vars = Array(var_set.begin(), var_set.end()); Map ranges; for (const Var& v : vars) { arith::IntSet dom; diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_auto_inline.py b/tests/python/unittest/test_meta_schedule_schedule_rule_auto_inline.py index 826054eb0b9b..0b51ca589817 100644 --- a/tests/python/unittest/test_meta_schedule_schedule_rule_auto_inline.py +++ b/tests/python/unittest/test_meta_schedule_schedule_rule_auto_inline.py @@ -271,7 +271,6 @@ def test_inline_consumer_chain(): tvm.ir.assert_structural_equal(lhs=space.mod, rhs=Conv2DBiasBnReLUInlined) -@pytest.mark.skip(reason="Flaky test") def test_inline_into_cache(): mod = MultiLevelTiledConv2D target = Target("cuda", host="llvm") From 44f1e798500225d8bf2f7c13b4662454c4323a1f Mon Sep 17 00:00:00 2001 From: wrongtest Date: Wed, 19 Jan 2022 12:59:11 +0800 Subject: [PATCH 2/2] change to std::find_if since num of vars is generally small --- src/tir/transforms/ir_utils.cc | 14 ++++++-------- ...test_meta_schedule_schedule_rule_auto_inline.py | 1 - 2 files changed, 6 insertions(+), 9 deletions(-) diff --git a/src/tir/transforms/ir_utils.cc b/src/tir/transforms/ir_utils.cc index 03942a6254f0..4eb9cc5b1a90 100644 --- a/src/tir/transforms/ir_utils.cc +++ b/src/tir/transforms/ir_utils.cc @@ -258,10 +258,8 @@ Map ConditionalBoundsContext::GetVarBoundsFromCondition() { arith::Analyzer analyzer; PrimExpr condition = is_true_branch_ ? condition_ : analyzer.Simplify(!condition_); Array equations; - std::unordered_set var_set; Array vars; - std::function fvisit = [&equations, &vars, &var_set, - &fvisit](const PrimExpr& e) { + std::function fvisit = [&equations, &vars, &fvisit](const PrimExpr& e) { if (e->IsInstance() || e->IsInstance() || e->IsInstance() || e->IsInstance() || e->IsInstance() || e->IsInstance()) { bool is_simple = true; @@ -280,10 +278,10 @@ Map ConditionalBoundsContext::GetVarBoundsFromCondition() { } }); if (is_simple && !cand_vars.empty()) { - for (const Var& var : cand_vars) { - if (!var_set.count(var)) { - vars.push_back(var); - var_set.insert(var); + for (const Var& new_var : cand_vars) { + if (!std::any_of(vars.begin(), vars.end(), + [&new_var](const Var& v) { return v.same_as(new_var); })) { + vars.push_back(new_var); } } equations.push_back(Downcast(e)); @@ -300,7 +298,7 @@ Map ConditionalBoundsContext::GetVarBoundsFromCondition() { } }; fvisit(condition); - if (equations.empty() || var_set.empty()) { + if (equations.empty() || vars.empty()) { return Map(); } // build dom ranges for related vars diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_auto_inline.py b/tests/python/unittest/test_meta_schedule_schedule_rule_auto_inline.py index 0b51ca589817..e206fcc4502c 100644 --- a/tests/python/unittest/test_meta_schedule_schedule_rule_auto_inline.py +++ b/tests/python/unittest/test_meta_schedule_schedule_rule_auto_inline.py @@ -16,7 +16,6 @@ # under the License. # pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring import tvm -import pytest from tvm.meta_schedule.space_generator.post_order_apply import PostOrderApply from tvm.meta_schedule.testing.schedule_rule import auto_inline from tvm.meta_schedule.tune_context import TuneContext