From ecbff528f62b761c80e6d31049a796a684eb2a35 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 4 Apr 2023 04:49:10 +0900 Subject: [PATCH 01/16] remove start_hint from MatchGraph --- include/tvm/relax/dataflow_matcher.h | 9 +---- python/tvm/relax/dpl/context.py | 8 +--- src/relax/ir/dataflow_matcher.cc | 20 +-------- tests/python/relax/test_dataflow_pattern.py | 45 --------------------- 4 files changed, 3 insertions(+), 79 deletions(-) diff --git a/include/tvm/relax/dataflow_matcher.h b/include/tvm/relax/dataflow_matcher.h index e4268be882d7..cf7c58f093e6 100644 --- a/include/tvm/relax/dataflow_matcher.h +++ b/include/tvm/relax/dataflow_matcher.h @@ -51,19 +51,12 @@ Optional> ExtractMatchedExpr( /** * \brief Match a sub-graph in a DataflowBlock with a graph of patterns and return the mapping. - * \note This algorithm returns the first matched sub-graph. Use `start_hint` to specify the - * starting point of the matching so that we can distinguish multiple matches. - * * \param ctx The graph-wise patterns. * \param dfb The function to match. - * \param start_hint The starting point expression to match to distinguish multiple matches. - * \param must_include_hint If start_hint is given, the return pattern must include start_hint. * \return Matched patterns and corresponding bound variables */ TVM_DLL Optional> MatchGraph(const PatternContext& ctx, - const DataflowBlock& dfb, - Optional start_hint = NullOpt, - bool must_include_hint = false); + const DataflowBlock& dfb); } // namespace relax } // namespace tvm diff --git a/python/tvm/relax/dpl/context.py b/python/tvm/relax/dpl/context.py index 69a5e70ed0f1..de0ca264d119 100644 --- a/python/tvm/relax/dpl/context.py +++ b/python/tvm/relax/dpl/context.py @@ -63,8 +63,6 @@ def current() -> "PatternContext": def match_dfb( self, dfb: DataflowBlock, - start_hint: Optional[Var] = None, - must_include_hint: bool = False, ) -> Dict[DFPattern, Var]: """ Match a DataflowBlock via a graph of DFPattern and corresponding constraints @@ -73,14 +71,10 @@ def match_dfb( ---------- dfb : DataflowBlock The DataflowBlock to match - start_hint : Optional[Var], optional - Indicating the starting expression to match, by default None - must_include_hint : bool, optional - Whether the start_hint expression must be matched, by default False Returns ------- Dict[DFPattern, Var] The mapping from DFPattern to matched expression """ - return ffi.match_dfb(self, dfb, start_hint, must_include_hint) # type: ignore + return ffi.match_dfb(self, dfb) # type: ignore diff --git a/src/relax/ir/dataflow_matcher.cc b/src/relax/ir/dataflow_matcher.cc index c1306ff69093..9c5ca15bae49 100644 --- a/src/relax/ir/dataflow_matcher.cc +++ b/src/relax/ir/dataflow_matcher.cc @@ -693,16 +693,13 @@ class MatcherUseDefAnalysis : public relax::ExprVisitor { } }; -Optional> MatchGraph(const PatternContext& ctx, const DataflowBlock& dfb, - Optional start_hint, bool must_include_hint) { +Optional> MatchGraph(const PatternContext& ctx, const DataflowBlock& dfb) { if (ctx->src_ordered.size() == 0) { return NullOpt; } // TODO(@ganler): Handle non-may external use. ICHECK(ctx->allow_extern_use == PatternContextNode::kMay) << "Only kMay is supported yet."; - ICHECK(!must_include_hint || start_hint.defined()) - << "must_include_hint is only supported with start_hint."; const auto var2val = AnalyzeVar2Value(dfb); DFPatternMatcher matcher(var2val); @@ -745,25 +742,10 @@ Optional> MatchGraph(const PatternContext& ctx, const Datafl Map ret; - if (start_hint) { - auto rnode_ptr = var2node.at(start_hint.value().get()); - for (auto& p_node : pattern2node) { - if (try_match(&p_node.second, &rnode_ptr, &matcher, def2use, caller2callees)) { - for (const auto& [df_pattern, pattern_node] : pattern2node) { - ret.Set(GetRef(df_pattern), GetRef(pattern_node.matched)); - } - return ret; - } - } - - if (must_include_hint) return ret; - } - PNode& pnode_start = pattern2node[ctx->src_ordered[0].get()]; if (!pnode_start.matched) { for (const auto& var : ud_analysis.vars) { - if (start_hint.defined() && start_hint.value().get() == var) continue; RNode& r_node = var2node[var]; if (try_match(&pnode_start, &r_node, &matcher, def2use, caller2callees)) { for (const auto& [df_pattern, pattern_node] : pattern2node) { diff --git a/tests/python/relax/test_dataflow_pattern.py b/tests/python/relax/test_dataflow_pattern.py index b85543cafcb8..a73a62eeef8d 100644 --- a/tests/python/relax/test_dataflow_pattern.py +++ b/tests/python/relax/test_dataflow_pattern.py @@ -519,51 +519,6 @@ def main( return lv6 -def test_single_cbr(): - with PatternContext() as ctx: - ( - is_call_dps_packed("conv1x1") - >> is_call_dps_packed("bias_add") - >> is_call_dps_packed("my_relu") - ) - dfb = CBRx2["main"].body.blocks[0] - matched = ctx.match_dfb(dfb) - assert matched - - with PatternContext() as ctx: - chain = ( - is_call_dps_packed("conv1x1") - >> is_call_dps_packed("bias_add") - >> is_call_dps_packed("my_relu") - ) - dfb = CBRx2["main"].body.blocks[0] - # we want to specifically match the first CBR (lv0) - matched = ctx.match_dfb(dfb, start_hint=dfb.bindings[0].var) - assert matched - assert matched[chain[0]] == dfb.bindings[0].var - # we want to specifically match the second CBR (lv3) - matched = ctx.match_dfb(dfb, start_hint=dfb.bindings[3].var) - assert matched - assert matched[chain[0]] == dfb.bindings[3].var - - -def test_counter_single_crb(): - with PatternContext() as ctx: - ( - is_call_dps_packed("conv1x1") - >> is_call_dps_packed("my_relu") - >> is_call_dps_packed("bias_add") - ) - dfb = CBRx2["main"].body.blocks[0] - assert not ctx.match_dfb(dfb) - # Quickly fails unpromising matches by assuming `start_hint` must be matched by a pattern. - # This is usually faster than the full match: - # Full match: let one pattern to match -> all Var: complexity ~ #Var - # must_include_hint: let `start_hint` to match -> all patterns: complexity ~ #patterns - # Usually #patterns is much smaller than #Var, so this is faster. - assert not ctx.match_dfb(dfb, start_hint=dfb.bindings[0].var, must_include_hint=True) - - def test_nested_context(): dfb = CBRx2["main"].body.blocks[0] with PatternContext() as ctx0: From e3a3c8fad9f3a7be8738b4759b2e367200a007f4 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 4 Apr 2023 10:29:36 +0900 Subject: [PATCH 02/16] improve graph matching algorithm --- src/relax/ir/dataflow_matcher.cc | 189 ++++++++++++++++--------------- src/relax/ir/dataflow_pattern.cc | 1 + 2 files changed, 99 insertions(+), 91 deletions(-) diff --git a/src/relax/ir/dataflow_matcher.cc b/src/relax/ir/dataflow_matcher.cc index 9c5ca15bae49..113d72e535eb 100644 --- a/src/relax/ir/dataflow_matcher.cc +++ b/src/relax/ir/dataflow_matcher.cc @@ -523,6 +523,43 @@ bool MatchExpr(DFPattern pattern, Expr expr, Optional> bindings_o TVM_REGISTER_GLOBAL("relax.dpl.match_expr").set_body_typed(MatchExpr); +class MatcherUseDefAnalysis : public relax::ExprVisitor { + public: + std::vector vars; + std::map> def2use; + // caller -> callee table. + std::map> caller2callees; + + const VarNode* cur_user_; + + void VisitBinding_(const VarBindingNode* binding) override { + // init + cur_user_ = binding->var.get(); + this->VisitVarDef(binding->var); + this->VisitExpr(binding->value); + cur_user_ = nullptr; + } + + void VisitExpr_(const VarNode* op) override { + if (nullptr == cur_user_) return; + + auto check_and_push = [](std::vector& vec, const VarNode* var) { + if (std::find(vec.begin(), vec.end(), var) == vec.end()) { + vec.push_back(var); + } + }; + + check_and_push(def2use[op], cur_user_); + check_and_push(vars, op); + + caller2callees[cur_user_].push_back(op); + } + + void VisitExpr_(const DataflowVarNode* op) override { + VisitExpr_(static_cast(op)); + } +}; + struct PNode { const DFPatternNode* ptr; const VarNode* matched = nullptr; @@ -541,18 +578,16 @@ struct RNode { * \brief This method try to match a real node and a pattern node along with its neighbors. */ using UndoItems = std::vector>; -static std::optional try_match( - PNode* p, RNode* r, DFPatternMatcher* m, - const std::map>& def2use, - const std::map>& use2def) { - if (p->matched != nullptr && p->matched == r->ptr) return {}; // matched before. + +static std::optional TryMatch(PNode* p, RNode* r, DFPatternMatcher* m, + const MatcherUseDefAnalysis& ud_analysis) { + if (r->matched != nullptr && p->matched != r->ptr) return std::nullopt; if (!m->Match(GetRef(p->ptr), GetRef(r->ptr))) return std::nullopt; UndoItems undo; const auto commit = [&undo](PNode* p, RNode* r) { // match with each other. - // TODO(ganler, masahi): Why commit on the same p-r pair happens more than once? if (p->ptr == r->matched) { ICHECK_EQ(p->matched, r->ptr); return; @@ -571,61 +606,32 @@ static std::optional try_match( }; const auto try_match_update_undo = [&](PNode* p, RNode* r) { - if (auto undo_more = try_match(p, r, m, def2use, use2def)) { + if (auto undo_more = TryMatch(p, r, m, ud_analysis)) { undo.insert(undo.end(), undo_more->begin(), undo_more->end()); return true; } return false; }; - commit(p, r); - - // match parent patterns. - for (auto& [pparent, constraints] : p->parents) { - bool any_cons_sat = false; - for (auto& rparent : r->parents) { - // skip if mismatch. - if (rparent->matched && rparent->matched != pparent->ptr) continue; - - const auto& uses = def2use.at(rparent->ptr); - - // check edge constraints. - bool cons_sat = true; - for (const auto& cons : constraints) { - if (cons.type == PairCons::kOnlyUsedBy && uses.size() != 1) { - cons_sat = false; - break; - } - - if (cons.index != -1) { - const auto& callees = use2def.at(r->ptr); - if (callees.size() <= static_cast(cons.index) || - callees[cons.index] != rparent->ptr) { - cons_sat = false; - break; - } - } - } - if (!cons_sat) continue; - any_cons_sat = true; - - // try all parent R nodes that are not matched yet. - // as long as ppattern can match one node. - if (!pparent->matched && try_match_update_undo(pparent, rparent)) { - commit(pparent, rparent); - break; + for (size_t i = 0; i < p->parents.size(); ++i) { + auto p_node_parent = p->parents[i].first; + if (p_node_parent->ptr->IsInstance()) { + if (p_node_parent->matched && p_node_parent->matched != r->parents[i]->ptr) { + return std::nullopt; } + commit(p_node_parent, r->parents[i]); } - if (!pparent->matched || !any_cons_sat) return quit(); } + commit(p, r); + // forward matching; for (auto& [pchild, constraints] : p->children) { bool any_cons_sat = false; for (auto& rchild : r->children) { if (rchild->matched && rchild->matched != pchild->ptr) continue; - const auto& uses = def2use.at(r->ptr); + const auto& uses = ud_analysis.def2use.at(r->ptr); // check edge constraints. bool all_cons_pass = true; @@ -636,7 +642,7 @@ static std::optional try_match( } if (cons.index != -1) { - const auto& callees = use2def.at(rchild->ptr); + const auto& callees = ud_analysis.caller2callees.at(rchild->ptr); if (callees.size() <= static_cast(cons.index) || callees[cons.index] != r->ptr) { all_cons_pass = false; break; @@ -653,45 +659,45 @@ static std::optional try_match( } if (!pchild->matched || !any_cons_sat) return quit(); } + return undo; } -class MatcherUseDefAnalysis : public relax::ExprVisitor { - public: - std::vector vars; - std::map> def2use; - // caller -> callee table. - std::map> caller2callees; - - const VarNode* cur_user_; - - void VisitBinding_(const VarBindingNode* binding) override { - // init - cur_user_ = binding->var.get(); - this->VisitVarDef(binding->var); - this->VisitExpr(binding->value); - cur_user_ = nullptr; +static bool MatchTree(int current_root_idx, + std::unordered_map& pattern2node, + std::unordered_map& var2node, + DFPatternMatcher* matcher, const std::vector& roots, + const MatcherUseDefAnalysis& ud_analysis) { + PNode* root = nullptr; + for (; current_root_idx < roots.size(); ++current_root_idx) { + root = &pattern2node[roots[current_root_idx].get()]; + if (!root->matched) { + break; + } } - void VisitExpr_(const VarNode* op) override { - if (nullptr == cur_user_) return; + if (root == nullptr || root->matched) { + return true; + } - auto check_and_push = [](std::vector& vec, const VarNode* var) { - if (std::find(vec.begin(), vec.end(), var) == vec.end()) { - vec.push_back(var); + for (const auto& var : ud_analysis.vars) { + RNode& r_node = var2node[var]; + if (r_node.matched) continue; + if (auto undo_items = TryMatch(root, &r_node, matcher, ud_analysis)) { + if (MatchTree(current_root_idx + 1, pattern2node, var2node, matcher, roots, ud_analysis)) { + return true; } - }; - - check_and_push(def2use[op], cur_user_); - check_and_push(vars, op); - - caller2callees[cur_user_].push_back(op); + // clean up undo stack and backtrack + for (auto& [p_node, r_node] : *undo_items) { + p_node->matched = nullptr; + r_node->matched = nullptr; + } + continue; + } } - void VisitExpr_(const DataflowVarNode* op) override { - VisitExpr_(static_cast(op)); - } -}; + return false; +} Optional> MatchGraph(const PatternContext& ctx, const DataflowBlock& dfb) { if (ctx->src_ordered.size() == 0) { @@ -706,15 +712,13 @@ Optional> MatchGraph(const PatternContext& ctx, const Datafl MatcherUseDefAnalysis ud_analysis; ud_analysis.VisitBindingBlock_(dfb.get()); - const auto& def2use = ud_analysis.def2use; - const auto& caller2callees = ud_analysis.caller2callees; // First construct a graph of PNode and RNode. std::unordered_map var2node; var2node.reserve(dfb->bindings.size()); for (const VarNode* cur_var : ud_analysis.vars) { - const auto& uses = def2use.at(cur_var); + const auto& uses = ud_analysis.def2use.at(cur_var); RNode& cur_node = var2node[cur_var]; cur_node.ptr = cur_var; for (const VarNode* use : uses) { @@ -728,8 +732,9 @@ Optional> MatchGraph(const PatternContext& ctx, const Datafl std::unordered_map pattern2node; pattern2node.reserve(ctx->constraints.size()); - for (const auto& [def_pattern, uses] : ctx->constraints) { + for (const auto& def_pattern : ctx->src_ordered) { PNode& def_node = pattern2node[def_pattern.get()]; + const auto& uses = ctx->constraints.at(def_pattern); def_node.ptr = def_pattern.get(); def_node.children.reserve(uses.size()); for (const auto& [use_pattern, cons] : uses) { @@ -740,20 +745,22 @@ Optional> MatchGraph(const PatternContext& ctx, const Datafl } } - Map ret; + std::vector roots; + for (const auto& pat : ctx->src_ordered) { + if (pattern2node[pat.get()].parents.empty()) { + roots.push_back(pat); + } + } - PNode& pnode_start = pattern2node[ctx->src_ordered[0].get()]; + ICHECK_GT(roots.size(), 0); - if (!pnode_start.matched) { - for (const auto& var : ud_analysis.vars) { - RNode& r_node = var2node[var]; - if (try_match(&pnode_start, &r_node, &matcher, def2use, caller2callees)) { - for (const auto& [df_pattern, pattern_node] : pattern2node) { - ret.Set(GetRef(df_pattern), GetRef(pattern_node.matched)); - } - return ret; - } + if (MatchTree(0, pattern2node, var2node, &matcher, roots, ud_analysis)) { + Map ret; + for (const auto& [df_pattern, pattern_node] : pattern2node) { + ICHECK(pattern_node.matched); + ret.Set(GetRef(df_pattern), GetRef(pattern_node.matched)); } + return ret; } return NullOpt; diff --git a/src/relax/ir/dataflow_pattern.cc b/src/relax/ir/dataflow_pattern.cc index 5580f6a1ab74..4d225ceecfe7 100644 --- a/src/relax/ir/dataflow_pattern.cc +++ b/src/relax/ir/dataflow_pattern.cc @@ -406,6 +406,7 @@ PatternContext::PatternContext(bool incremental) { << "Incremental context needs to be built inside a existing context."; n->allow_extern_use = pattern_ctx_stack().top()->allow_extern_use; n->constraints = pattern_ctx_stack().top()->constraints; + n->src_ordered = pattern_ctx_stack().top()->src_ordered; } data_ = std::move(n); From a0e80d43c1929f11878dc738d52ec8a24305795b Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 5 Apr 2023 07:35:28 +0900 Subject: [PATCH 03/16] remove side effect from matching algo --- src/relax/ir/dataflow_matcher.cc | 166 +++++++++++++++---------------- 1 file changed, 81 insertions(+), 85 deletions(-) diff --git a/src/relax/ir/dataflow_matcher.cc b/src/relax/ir/dataflow_matcher.cc index 113d72e535eb..5324364307c1 100644 --- a/src/relax/ir/dataflow_matcher.cc +++ b/src/relax/ir/dataflow_matcher.cc @@ -562,76 +562,74 @@ class MatcherUseDefAnalysis : public relax::ExprVisitor { struct PNode { const DFPatternNode* ptr; - const VarNode* matched = nullptr; std::vector&>> children; std::vector&>> parents; }; struct RNode { const VarNode* ptr; - const DFPatternNode* matched = nullptr; std::vector children; std::vector parents; }; -/** - * \brief This method try to match a real node and a pattern node along with its neighbors. - */ -using UndoItems = std::vector>; +struct MatchState { + void add(const PNode* p, const RNode* r) { + match_p_r[p] = r; + match_r_p[r] = p; + } -static std::optional TryMatch(PNode* p, RNode* r, DFPatternMatcher* m, - const MatcherUseDefAnalysis& ud_analysis) { - if (r->matched != nullptr && p->matched != r->ptr) return std::nullopt; - if (!m->Match(GetRef(p->ptr), GetRef(r->ptr))) return std::nullopt; + void add(const MatchState& other) { + for (const auto& [p, r] : other.match_p_r) { + add(p, r); + } + } - UndoItems undo; + const VarNode* matched(const PNode* p) const { + if (!match_p_r.count(p)) return nullptr; + return match_p_r.at(p)->ptr; + } - const auto commit = [&undo](PNode* p, RNode* r) { - // match with each other. - if (p->ptr == r->matched) { - ICHECK_EQ(p->matched, r->ptr); - return; - } - p->matched = r->ptr; - r->matched = p->ptr; - undo.emplace_back(p, r); - }; + const DFPatternNode* matched(const RNode* r) const { + if (!match_r_p.count(r)) return nullptr; + return match_r_p.at(r)->ptr; + } - const auto quit = [&undo] { - for (auto& [p_node, r_node] : undo) { - p_node->matched = nullptr; - r_node->matched = nullptr; - } - return std::nullopt; - }; + const VarNode* matched(const PNode& p) const { return matched(&p); } + const DFPatternNode* matched(const RNode& r) const { return matched(&r); } - const auto try_match_update_undo = [&](PNode* p, RNode* r) { - if (auto undo_more = TryMatch(p, r, m, ud_analysis)) { - undo.insert(undo.end(), undo_more->begin(), undo_more->end()); - return true; - } - return false; - }; + private: + std::unordered_map match_p_r; + std::unordered_map match_r_p; +}; - for (size_t i = 0; i < p->parents.size(); ++i) { - auto p_node_parent = p->parents[i].first; +/** + * \brief This method try to match a real node and a pattern node along with its neighbors. + */ +static std::optional TryMatch(const PNode& p, const RNode& r, DFPatternMatcher* m, + const MatcherUseDefAnalysis& ud_analysis) { + if (!m->Match(GetRef(p.ptr), GetRef(r.ptr))) return std::nullopt; + + MatchState result; + + for (size_t i = 0; i < p.parents.size(); ++i) { + auto p_node_parent = p.parents[i].first; if (p_node_parent->ptr->IsInstance()) { - if (p_node_parent->matched && p_node_parent->matched != r->parents[i]->ptr) { + if (auto v = result.matched(p_node_parent); v && v != r.parents[i]->ptr) { return std::nullopt; } - commit(p_node_parent, r->parents[i]); + result.add(p_node_parent, r.parents[i]); } } - commit(p, r); + result.add(&p, &r); // forward matching; - for (auto& [pchild, constraints] : p->children) { + for (auto& [pchild, constraints] : p.children) { bool any_cons_sat = false; - for (auto& rchild : r->children) { - if (rchild->matched && rchild->matched != pchild->ptr) continue; + for (auto& rchild : r.children) { + if (auto p = result.matched(rchild); p && p != pchild->ptr) continue; - const auto& uses = ud_analysis.def2use.at(r->ptr); + const auto& uses = ud_analysis.def2use.at(r.ptr); // check edge constraints. bool all_cons_pass = true; @@ -643,67 +641,63 @@ static std::optional TryMatch(PNode* p, RNode* r, DFPatternMatcher* m if (cons.index != -1) { const auto& callees = ud_analysis.caller2callees.at(rchild->ptr); - if (callees.size() <= static_cast(cons.index) || callees[cons.index] != r->ptr) { + if (callees.size() <= static_cast(cons.index) || callees[cons.index] != r.ptr) { all_cons_pass = false; break; } } } - if (!all_cons_pass) continue; + if (!all_cons_pass || result.matched(pchild)) continue; any_cons_sat = true; - if (!pchild->matched && try_match_update_undo(pchild, rchild)) { - commit(pchild, rchild); - break; + if (auto match_rec = TryMatch(*pchild, *rchild, m, ud_analysis)) { + result.add(pchild, rchild); + result.add(*match_rec); } } - if (!pchild->matched || !any_cons_sat) return quit(); + if (!result.matched(pchild) || !any_cons_sat) return std::nullopt; } - return undo; + return result; } -static bool MatchTree(int current_root_idx, - std::unordered_map& pattern2node, - std::unordered_map& var2node, - DFPatternMatcher* matcher, const std::vector& roots, - const MatcherUseDefAnalysis& ud_analysis) { - PNode* root = nullptr; - for (; current_root_idx < roots.size(); ++current_root_idx) { - root = &pattern2node[roots[current_root_idx].get()]; - if (!root->matched) { - break; +static std::optional MatchTree( + const MatchState& current_matches, int current_root_idx, + const std::unordered_map& pattern2node, + const std::unordered_map& var2node, DFPatternMatcher* matcher, + const std::vector& roots, const MatcherUseDefAnalysis& ud_analysis) { + auto get_next_root = [&](int root_idx) -> const PNode* { + for (; root_idx < roots.size(); ++root_idx) { + const auto& root = pattern2node.at(roots[root_idx].get()); + if (!current_matches.matched(root)) { + return &root; + } } - } + return nullptr; + }; + + const auto root = get_next_root(current_root_idx); - if (root == nullptr || root->matched) { - return true; + if (!root || current_matches.matched(root)) { + return current_matches; } for (const auto& var : ud_analysis.vars) { - RNode& r_node = var2node[var]; - if (r_node.matched) continue; - if (auto undo_items = TryMatch(root, &r_node, matcher, ud_analysis)) { - if (MatchTree(current_root_idx + 1, pattern2node, var2node, matcher, roots, ud_analysis)) { - return true; - } - // clean up undo stack and backtrack - for (auto& [p_node, r_node] : *undo_items) { - p_node->matched = nullptr; - r_node->matched = nullptr; + const RNode& r_node = var2node.at(var); + if (current_matches.matched(r_node)) continue; + if (auto matches = TryMatch(*root, r_node, matcher, ud_analysis)) { + if (auto matches_rec = MatchTree(*matches, current_root_idx + 1, pattern2node, var2node, + matcher, roots, ud_analysis)) { + matches->add(*matches_rec); + return matches; } - continue; } } - return false; + return std::nullopt; } Optional> MatchGraph(const PatternContext& ctx, const DataflowBlock& dfb) { - if (ctx->src_ordered.size() == 0) { - return NullOpt; - } - // TODO(@ganler): Handle non-may external use. ICHECK(ctx->allow_extern_use == PatternContextNode::kMay) << "Only kMay is supported yet."; @@ -752,13 +746,15 @@ Optional> MatchGraph(const PatternContext& ctx, const Datafl } } - ICHECK_GT(roots.size(), 0); + if (roots.empty()) { + return NullOpt; + } - if (MatchTree(0, pattern2node, var2node, &matcher, roots, ud_analysis)) { + if (auto match = MatchTree({}, 0, pattern2node, var2node, &matcher, roots, ud_analysis)) { Map ret; - for (const auto& [df_pattern, pattern_node] : pattern2node) { - ICHECK(pattern_node.matched); - ret.Set(GetRef(df_pattern), GetRef(pattern_node.matched)); + for (const auto& [pat, p_node] : pattern2node) { + ICHECK(match->matched(p_node)); + ret.Set(GetRef(pat), GetRef(match->matched(p_node))); } return ret; } From e5ccd8a7f035629056421432cc2ffcd8abd620e3 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 5 Apr 2023 08:55:21 +0900 Subject: [PATCH 04/16] pylint --- python/tvm/relax/dpl/context.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/relax/dpl/context.py b/python/tvm/relax/dpl/context.py index de0ca264d119..16d86fb32df0 100644 --- a/python/tvm/relax/dpl/context.py +++ b/python/tvm/relax/dpl/context.py @@ -17,7 +17,7 @@ """The Graph Matching Context Manager for Dataflow Pattern Language.""" -from typing import Optional, Dict +from typing import Dict import tvm from ..expr import DataflowBlock, Var From aed075f4b6b9d89e885b4aa91c8f632fcb2f725a Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 5 Apr 2023 09:08:26 +0900 Subject: [PATCH 05/16] add comments --- src/relax/ir/dataflow_matcher.cc | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/relax/ir/dataflow_matcher.cc b/src/relax/ir/dataflow_matcher.cc index 5324364307c1..e909e0d51cd2 100644 --- a/src/relax/ir/dataflow_matcher.cc +++ b/src/relax/ir/dataflow_matcher.cc @@ -615,6 +615,7 @@ static std::optional TryMatch(const PNode& p, const RNode& r, DFPatt auto p_node_parent = p.parents[i].first; if (p_node_parent->ptr->IsInstance()) { if (auto v = result.matched(p_node_parent); v && v != r.parents[i]->ptr) { + // A parent wildcard pattern is already matched to other variable. return std::nullopt; } result.add(p_node_parent, r.parents[i]); @@ -667,6 +668,7 @@ static std::optional MatchTree( const std::unordered_map& var2node, DFPatternMatcher* matcher, const std::vector& roots, const MatcherUseDefAnalysis& ud_analysis) { auto get_next_root = [&](int root_idx) -> const PNode* { + // Look for the next unmatched root node. for (; root_idx < roots.size(); ++root_idx) { const auto& root = pattern2node.at(roots[root_idx].get()); if (!current_matches.matched(root)) { @@ -678,7 +680,8 @@ static std::optional MatchTree( const auto root = get_next_root(current_root_idx); - if (!root || current_matches.matched(root)) { + if (!root) { + // All root nodes have been matched return current_matches; } @@ -686,11 +689,14 @@ static std::optional MatchTree( const RNode& r_node = var2node.at(var); if (current_matches.matched(r_node)) continue; if (auto matches = TryMatch(*root, r_node, matcher, ud_analysis)) { + // Recursivly try to match the next subtree. if (auto matches_rec = MatchTree(*matches, current_root_idx + 1, pattern2node, var2node, matcher, roots, ud_analysis)) { matches->add(*matches_rec); return matches; } + // Recursive matching has failed, backtrack. + continue; } } From 795449a57f437845e03b4d49bb28cf7a4d359a9c Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 5 Apr 2023 09:13:09 +0900 Subject: [PATCH 06/16] add more const now that we can --- src/relax/ir/dataflow_matcher.cc | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/relax/ir/dataflow_matcher.cc b/src/relax/ir/dataflow_matcher.cc index e909e0d51cd2..d8432f9bcd4e 100644 --- a/src/relax/ir/dataflow_matcher.cc +++ b/src/relax/ir/dataflow_matcher.cc @@ -612,8 +612,9 @@ static std::optional TryMatch(const PNode& p, const RNode& r, DFPatt MatchState result; for (size_t i = 0; i < p.parents.size(); ++i) { - auto p_node_parent = p.parents[i].first; + const auto p_node_parent = p.parents[i].first; if (p_node_parent->ptr->IsInstance()) { + ICHECK_EQ(p.parents.size(), r.parents.size()); if (auto v = result.matched(p_node_parent); v && v != r.parents[i]->ptr) { // A parent wildcard pattern is already matched to other variable. return std::nullopt; @@ -625,9 +626,9 @@ static std::optional TryMatch(const PNode& p, const RNode& r, DFPatt result.add(&p, &r); // forward matching; - for (auto& [pchild, constraints] : p.children) { + for (const auto& [pchild, constraints] : p.children) { bool any_cons_sat = false; - for (auto& rchild : r.children) { + for (const auto& rchild : r.children) { if (auto p = result.matched(rchild); p && p != pchild->ptr) continue; const auto& uses = ud_analysis.def2use.at(r.ptr); From 57215f8f557874bdd2cdc7c6b2e38731d3c67a03 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 5 Apr 2023 09:35:51 +0900 Subject: [PATCH 07/16] cpplint --- src/relax/ir/dataflow_matcher.cc | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/relax/ir/dataflow_matcher.cc b/src/relax/ir/dataflow_matcher.cc index d8432f9bcd4e..74ead22ce84d 100644 --- a/src/relax/ir/dataflow_matcher.cc +++ b/src/relax/ir/dataflow_matcher.cc @@ -629,7 +629,9 @@ static std::optional TryMatch(const PNode& p, const RNode& r, DFPatt for (const auto& [pchild, constraints] : p.children) { bool any_cons_sat = false; for (const auto& rchild : r.children) { - if (auto p = result.matched(rchild); p && p != pchild->ptr) continue; + if (auto p = result.matched(rchild); p && p != pchild->ptr) { + continue; + } const auto& uses = ud_analysis.def2use.at(r.ptr); From 3a4d8ea2edd29b3362bb48de26a7434e86422805 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 5 Apr 2023 16:24:58 +0900 Subject: [PATCH 08/16] fix compile warning --- src/relax/ir/dataflow_matcher.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/relax/ir/dataflow_matcher.cc b/src/relax/ir/dataflow_matcher.cc index 74ead22ce84d..3a8b84c41edb 100644 --- a/src/relax/ir/dataflow_matcher.cc +++ b/src/relax/ir/dataflow_matcher.cc @@ -666,11 +666,11 @@ static std::optional TryMatch(const PNode& p, const RNode& r, DFPatt } static std::optional MatchTree( - const MatchState& current_matches, int current_root_idx, + const MatchState& current_matches, size_t current_root_idx, const std::unordered_map& pattern2node, const std::unordered_map& var2node, DFPatternMatcher* matcher, const std::vector& roots, const MatcherUseDefAnalysis& ud_analysis) { - auto get_next_root = [&](int root_idx) -> const PNode* { + auto get_next_root = [&](size_t root_idx) -> const PNode* { // Look for the next unmatched root node. for (; root_idx < roots.size(); ++root_idx) { const auto& root = pattern2node.at(roots[root_idx].get()); From 5c23f602df7f9dfd10c0e3a700703e22f6d647ce Mon Sep 17 00:00:00 2001 From: masahi Date: Thu, 6 Apr 2023 09:02:01 +0900 Subject: [PATCH 09/16] Update src/relax/ir/dataflow_matcher.cc Co-authored-by: Jiawei Liu --- src/relax/ir/dataflow_matcher.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/relax/ir/dataflow_matcher.cc b/src/relax/ir/dataflow_matcher.cc index 3a8b84c41edb..531987114d49 100644 --- a/src/relax/ir/dataflow_matcher.cc +++ b/src/relax/ir/dataflow_matcher.cc @@ -585,8 +585,8 @@ struct MatchState { } const VarNode* matched(const PNode* p) const { - if (!match_p_r.count(p)) return nullptr; - return match_p_r.at(p)->ptr; + if (auto it = match_p_r.find(p); it != match_p_r.end()) return it->ptr; + return nullptr; } const DFPatternNode* matched(const RNode* r) const { From 08a63cd59d90d0039f893cb5745975c0e843c022 Mon Sep 17 00:00:00 2001 From: masahi Date: Thu, 6 Apr 2023 09:02:09 +0900 Subject: [PATCH 10/16] Update src/relax/ir/dataflow_matcher.cc Co-authored-by: Jiawei Liu --- src/relax/ir/dataflow_matcher.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/relax/ir/dataflow_matcher.cc b/src/relax/ir/dataflow_matcher.cc index 531987114d49..faad7abd31cb 100644 --- a/src/relax/ir/dataflow_matcher.cc +++ b/src/relax/ir/dataflow_matcher.cc @@ -590,8 +590,8 @@ struct MatchState { } const DFPatternNode* matched(const RNode* r) const { - if (!match_r_p.count(r)) return nullptr; - return match_r_p.at(r)->ptr; + if (auto it = match_r_p.find(p); it != match_r_p.end()) return it->ptr; + return nullptr; } const VarNode* matched(const PNode& p) const { return matched(&p); } From c7f97ded9598a3ea89a198cad11516deb68df4f3 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 6 Apr 2023 09:03:14 +0900 Subject: [PATCH 11/16] use insert for merging MatchState --- src/relax/ir/dataflow_matcher.cc | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/relax/ir/dataflow_matcher.cc b/src/relax/ir/dataflow_matcher.cc index faad7abd31cb..42f41ffa0e07 100644 --- a/src/relax/ir/dataflow_matcher.cc +++ b/src/relax/ir/dataflow_matcher.cc @@ -579,9 +579,8 @@ struct MatchState { } void add(const MatchState& other) { - for (const auto& [p, r] : other.match_p_r) { - add(p, r); - } + match_p_r.insert(other.match_p_r.cbegin(), other.match_p_r.cend()); + match_r_p.insert(other.match_r_p.cbegin(), other.match_r_p.cend()); } const VarNode* matched(const PNode* p) const { From 6a18c64bd6c545fe03548de88f6378d34a72f238 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 6 Apr 2023 09:17:03 +0900 Subject: [PATCH 12/16] fix --- src/relax/ir/dataflow_matcher.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/relax/ir/dataflow_matcher.cc b/src/relax/ir/dataflow_matcher.cc index 42f41ffa0e07..882d8be6a87f 100644 --- a/src/relax/ir/dataflow_matcher.cc +++ b/src/relax/ir/dataflow_matcher.cc @@ -584,12 +584,12 @@ struct MatchState { } const VarNode* matched(const PNode* p) const { - if (auto it = match_p_r.find(p); it != match_p_r.end()) return it->ptr; + if (auto it = match_p_r.find(p); it != match_p_r.end()) return it->second->ptr; return nullptr; } const DFPatternNode* matched(const RNode* r) const { - if (auto it = match_r_p.find(p); it != match_r_p.end()) return it->ptr; + if (auto it = match_r_p.find(r); it != match_r_p.end()) return it->second->ptr; return nullptr; } From 55f653998635555380e90787248e146aa828b279 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 6 Apr 2023 09:35:30 +0900 Subject: [PATCH 13/16] parent check is not specific to wildcard --- src/relax/ir/dataflow_matcher.cc | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/relax/ir/dataflow_matcher.cc b/src/relax/ir/dataflow_matcher.cc index 882d8be6a87f..da8f41a6e102 100644 --- a/src/relax/ir/dataflow_matcher.cc +++ b/src/relax/ir/dataflow_matcher.cc @@ -612,12 +612,12 @@ static std::optional TryMatch(const PNode& p, const RNode& r, DFPatt for (size_t i = 0; i < p.parents.size(); ++i) { const auto p_node_parent = p.parents[i].first; + if (auto v = result.matched(p_node_parent); v && v != r.parents[i]->ptr) { + // A parent pattern is already matched to other variable. + return std::nullopt; + } if (p_node_parent->ptr->IsInstance()) { ICHECK_EQ(p.parents.size(), r.parents.size()); - if (auto v = result.matched(p_node_parent); v && v != r.parents[i]->ptr) { - // A parent wildcard pattern is already matched to other variable. - return std::nullopt; - } result.add(p_node_parent, r.parents[i]); } } From 1c5af30e014dc735c8bb10e975cf51d0e97dc538 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 6 Apr 2023 09:45:12 +0900 Subject: [PATCH 14/16] use map merge --- src/relax/ir/dataflow_matcher.cc | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/relax/ir/dataflow_matcher.cc b/src/relax/ir/dataflow_matcher.cc index da8f41a6e102..daa25dfdc13d 100644 --- a/src/relax/ir/dataflow_matcher.cc +++ b/src/relax/ir/dataflow_matcher.cc @@ -578,9 +578,9 @@ struct MatchState { match_r_p[r] = p; } - void add(const MatchState& other) { - match_p_r.insert(other.match_p_r.cbegin(), other.match_p_r.cend()); - match_r_p.insert(other.match_r_p.cbegin(), other.match_r_p.cend()); + void add(MatchState&& other) { + match_p_r.merge(std::move(other.match_p_r)); + match_r_p.merge(std::move(other.match_r_p)); } const VarNode* matched(const PNode* p) const { @@ -655,7 +655,7 @@ static std::optional TryMatch(const PNode& p, const RNode& r, DFPatt if (auto match_rec = TryMatch(*pchild, *rchild, m, ud_analysis)) { result.add(pchild, rchild); - result.add(*match_rec); + result.add(std::move(*match_rec)); } } if (!result.matched(pchild) || !any_cons_sat) return std::nullopt; @@ -694,7 +694,7 @@ static std::optional MatchTree( // Recursivly try to match the next subtree. if (auto matches_rec = MatchTree(*matches, current_root_idx + 1, pattern2node, var2node, matcher, roots, ud_analysis)) { - matches->add(*matches_rec); + matches->add(std::move(*matches_rec)); return matches; } // Recursive matching has failed, backtrack. From 8c04060dd24a243b869ed9e792a6f462bbc4a4be Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 6 Apr 2023 09:49:39 +0900 Subject: [PATCH 15/16] cpplint --- src/relax/ir/dataflow_matcher.cc | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/relax/ir/dataflow_matcher.cc b/src/relax/ir/dataflow_matcher.cc index daa25dfdc13d..e3f1ddaebc92 100644 --- a/src/relax/ir/dataflow_matcher.cc +++ b/src/relax/ir/dataflow_matcher.cc @@ -584,12 +584,16 @@ struct MatchState { } const VarNode* matched(const PNode* p) const { - if (auto it = match_p_r.find(p); it != match_p_r.end()) return it->second->ptr; + if (auto it = match_p_r.find(p); it != match_p_r.end()) { + return it->second->ptr; + } return nullptr; } const DFPatternNode* matched(const RNode* r) const { - if (auto it = match_r_p.find(r); it != match_r_p.end()) return it->second->ptr; + if (auto it = match_r_p.find(r); it != match_r_p.end()) { + return it->second->ptr; + } return nullptr; } From f2eaddae05a631f0070ef875ab6b62ea30154f7c Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 6 Apr 2023 12:28:41 +0900 Subject: [PATCH 16/16] Pass and check current_match in TryMatch --- src/relax/ir/dataflow_matcher.cc | 59 +++++++++++++++----------------- 1 file changed, 28 insertions(+), 31 deletions(-) diff --git a/src/relax/ir/dataflow_matcher.cc b/src/relax/ir/dataflow_matcher.cc index e3f1ddaebc92..6e8211cfd314 100644 --- a/src/relax/ir/dataflow_matcher.cc +++ b/src/relax/ir/dataflow_matcher.cc @@ -608,31 +608,25 @@ struct MatchState { /** * \brief This method try to match a real node and a pattern node along with its neighbors. */ -static std::optional TryMatch(const PNode& p, const RNode& r, DFPatternMatcher* m, +static std::optional TryMatch(const PNode& p, const RNode& r, + const MatchState& current_match, DFPatternMatcher* m, const MatcherUseDefAnalysis& ud_analysis) { if (!m->Match(GetRef(p.ptr), GetRef(r.ptr))) return std::nullopt; - MatchState result; + MatchState new_match; - for (size_t i = 0; i < p.parents.size(); ++i) { - const auto p_node_parent = p.parents[i].first; - if (auto v = result.matched(p_node_parent); v && v != r.parents[i]->ptr) { - // A parent pattern is already matched to other variable. - return std::nullopt; - } - if (p_node_parent->ptr->IsInstance()) { - ICHECK_EQ(p.parents.size(), r.parents.size()); - result.add(p_node_parent, r.parents[i]); - } - } - - result.add(&p, &r); + new_match.add(&p, &r); // forward matching; for (const auto& [pchild, constraints] : p.children) { bool any_cons_sat = false; for (const auto& rchild : r.children) { - if (auto p = result.matched(rchild); p && p != pchild->ptr) { + if (new_match.matched(rchild)) { + // The child variable is already matched to other child pattern in a previous iteration. + continue; + } + if (auto v = current_match.matched(pchild); v && v != rchild->ptr) { + // The child pattern is already matched to other variable in a earlier call to TryMatch. continue; } @@ -654,22 +648,22 @@ static std::optional TryMatch(const PNode& p, const RNode& r, DFPatt } } } - if (!all_cons_pass || result.matched(pchild)) continue; + if (!all_cons_pass || new_match.matched(pchild)) continue; any_cons_sat = true; - if (auto match_rec = TryMatch(*pchild, *rchild, m, ud_analysis)) { - result.add(pchild, rchild); - result.add(std::move(*match_rec)); + if (auto match_rec = TryMatch(*pchild, *rchild, current_match, m, ud_analysis)) { + new_match.add(pchild, rchild); + new_match.add(std::move(*match_rec)); } } - if (!result.matched(pchild) || !any_cons_sat) return std::nullopt; + if (!new_match.matched(pchild) || !any_cons_sat) return std::nullopt; } - return result; + return new_match; } static std::optional MatchTree( - const MatchState& current_matches, size_t current_root_idx, + const MatchState& current_match, size_t current_root_idx, const std::unordered_map& pattern2node, const std::unordered_map& var2node, DFPatternMatcher* matcher, const std::vector& roots, const MatcherUseDefAnalysis& ud_analysis) { @@ -677,7 +671,7 @@ static std::optional MatchTree( // Look for the next unmatched root node. for (; root_idx < roots.size(); ++root_idx) { const auto& root = pattern2node.at(roots[root_idx].get()); - if (!current_matches.matched(root)) { + if (!current_match.matched(root)) { return &root; } } @@ -688,18 +682,21 @@ static std::optional MatchTree( if (!root) { // All root nodes have been matched - return current_matches; + return current_match; } + MatchState new_match = current_match; + for (const auto& var : ud_analysis.vars) { const RNode& r_node = var2node.at(var); - if (current_matches.matched(r_node)) continue; - if (auto matches = TryMatch(*root, r_node, matcher, ud_analysis)) { + if (new_match.matched(r_node)) continue; + if (auto match = TryMatch(*root, r_node, new_match, matcher, ud_analysis)) { // Recursivly try to match the next subtree. - if (auto matches_rec = MatchTree(*matches, current_root_idx + 1, pattern2node, var2node, - matcher, roots, ud_analysis)) { - matches->add(std::move(*matches_rec)); - return matches; + new_match.add(std::move(*match)); + if (auto match_rec = MatchTree(new_match, current_root_idx + 1, pattern2node, var2node, + matcher, roots, ud_analysis)) { + new_match.add(std::move(*match_rec)); + return new_match; } // Recursive matching has failed, backtrack. continue;