Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions src/tir/transforms/ir_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -258,8 +258,8 @@ Map<Var, Range> ConditionalBoundsContext::GetVarBoundsFromCondition() {
arith::Analyzer analyzer;
PrimExpr condition = is_true_branch_ ? condition_ : analyzer.Simplify(!condition_);
Array<PrimExpr> equations;
std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> var_set;
std::function<void(const PrimExpr&)> fvisit = [&equations, &var_set, &fvisit](const PrimExpr& e) {
Array<Var> vars;
std::function<void(const PrimExpr&)> fvisit = [&equations, &vars, &fvisit](const PrimExpr& e) {
if (e->IsInstance<GENode>() || e->IsInstance<GTNode>() || e->IsInstance<LENode>() ||
e->IsInstance<LTNode>() || e->IsInstance<EQNode>() || e->IsInstance<NENode>()) {
bool is_simple = true;
Expand All @@ -278,7 +278,12 @@ Map<Var, Range> ConditionalBoundsContext::GetVarBoundsFromCondition() {
}
});
if (is_simple && !cand_vars.empty()) {
for (const Var& var : cand_vars) var_set.insert(var);
for (const Var& new_var : cand_vars) {
if (!std::any_of(vars.begin(), vars.end(),
[&new_var](const Var& v) { return v.same_as(new_var); })) {
vars.push_back(new_var);
}
}
equations.push_back(Downcast<PrimExpr>(e));
}
} else if (e->IsInstance<AndNode>()) {
Expand All @@ -293,11 +298,10 @@ Map<Var, Range> ConditionalBoundsContext::GetVarBoundsFromCondition() {
}
};
fvisit(condition);
if (equations.empty() || var_set.empty()) {
if (equations.empty() || vars.empty()) {
return Map<Var, Range>();
}
// build dom ranges for related vars
Array<Var> vars = Array<Var>(var_set.begin(), var_set.end());
Map<Var, Range> ranges;
for (const Var& v : vars) {
arith::IntSet dom;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
# under the License.
# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring
import tvm
import pytest
from tvm.meta_schedule.space_generator.post_order_apply import PostOrderApply
from tvm.meta_schedule.testing.schedule_rule import auto_inline
from tvm.meta_schedule.tune_context import TuneContext
Expand Down Expand Up @@ -271,7 +270,6 @@ def test_inline_consumer_chain():
tvm.ir.assert_structural_equal(lhs=space.mod, rhs=Conv2DBiasBnReLUInlined)


@pytest.mark.skip(reason="Flaky test")
def test_inline_into_cache():
mod = MultiLevelTiledConv2D
target = Target("cuda", host="llvm")
Expand Down