diff --git a/include/tvm/relax/dataflow_matcher.h b/include/tvm/relax/dataflow_matcher.h index e4268be882d7..cf7c58f093e6 100644 --- a/include/tvm/relax/dataflow_matcher.h +++ b/include/tvm/relax/dataflow_matcher.h @@ -51,19 +51,12 @@ Optional> ExtractMatchedExpr( /** * \brief Match a sub-graph in a DataflowBlock with a graph of patterns and return the mapping. - * \note This algorithm returns the first matched sub-graph. Use `start_hint` to specify the - * starting point of the matching so that we can distinguish multiple matches. - * * \param ctx The graph-wise patterns. * \param dfb The function to match. - * \param start_hint The starting point expression to match to distinguish multiple matches. - * \param must_include_hint If start_hint is given, the return pattern must include start_hint. * \return Matched patterns and corresponding bound variables */ TVM_DLL Optional> MatchGraph(const PatternContext& ctx, - const DataflowBlock& dfb, - Optional start_hint = NullOpt, - bool must_include_hint = false); + const DataflowBlock& dfb); } // namespace relax } // namespace tvm diff --git a/python/tvm/relax/dpl/context.py b/python/tvm/relax/dpl/context.py index 69a5e70ed0f1..16d86fb32df0 100644 --- a/python/tvm/relax/dpl/context.py +++ b/python/tvm/relax/dpl/context.py @@ -17,7 +17,7 @@ """The Graph Matching Context Manager for Dataflow Pattern Language.""" -from typing import Optional, Dict +from typing import Dict import tvm from ..expr import DataflowBlock, Var @@ -63,8 +63,6 @@ def current() -> "PatternContext": def match_dfb( self, dfb: DataflowBlock, - start_hint: Optional[Var] = None, - must_include_hint: bool = False, ) -> Dict[DFPattern, Var]: """ Match a DataflowBlock via a graph of DFPattern and corresponding constraints @@ -73,14 +71,10 @@ def match_dfb( ---------- dfb : DataflowBlock The DataflowBlock to match - start_hint : Optional[Var], optional - Indicating the starting expression to match, by default None - must_include_hint : bool, optional - Whether the start_hint expression must be matched, by default False Returns ------- Dict[DFPattern, Var] The mapping from DFPattern to matched expression """ - return ffi.match_dfb(self, dfb, start_hint, must_include_hint) # type: ignore + return ffi.match_dfb(self, dfb) # type: ignore diff --git a/src/relax/ir/dataflow_matcher.cc b/src/relax/ir/dataflow_matcher.cc index c1306ff69093..6e8211cfd314 100644 --- a/src/relax/ir/dataflow_matcher.cc +++ b/src/relax/ir/dataflow_matcher.cc @@ -523,109 +523,114 @@ bool MatchExpr(DFPattern pattern, Expr expr, Optional> bindings_o TVM_REGISTER_GLOBAL("relax.dpl.match_expr").set_body_typed(MatchExpr); +class MatcherUseDefAnalysis : public relax::ExprVisitor { + public: + std::vector vars; + std::map> def2use; + // caller -> callee table. + std::map> caller2callees; + + const VarNode* cur_user_; + + void VisitBinding_(const VarBindingNode* binding) override { + // init + cur_user_ = binding->var.get(); + this->VisitVarDef(binding->var); + this->VisitExpr(binding->value); + cur_user_ = nullptr; + } + + void VisitExpr_(const VarNode* op) override { + if (nullptr == cur_user_) return; + + auto check_and_push = [](std::vector& vec, const VarNode* var) { + if (std::find(vec.begin(), vec.end(), var) == vec.end()) { + vec.push_back(var); + } + }; + + check_and_push(def2use[op], cur_user_); + check_and_push(vars, op); + + caller2callees[cur_user_].push_back(op); + } + + void VisitExpr_(const DataflowVarNode* op) override { + VisitExpr_(static_cast(op)); + } +}; + struct PNode { const DFPatternNode* ptr; - const VarNode* matched = nullptr; std::vector&>> children; std::vector&>> parents; }; struct RNode { const VarNode* ptr; - const DFPatternNode* matched = nullptr; std::vector children; std::vector parents; }; -/** - * \brief This method try to match a real node and a pattern node along with its neighbors. - */ -using UndoItems = std::vector>; -static std::optional try_match( - PNode* p, RNode* r, DFPatternMatcher* m, - const std::map>& def2use, - const std::map>& use2def) { - if (p->matched != nullptr && p->matched == r->ptr) return {}; // matched before. - if (!m->Match(GetRef(p->ptr), GetRef(r->ptr))) return std::nullopt; - - UndoItems undo; - - const auto commit = [&undo](PNode* p, RNode* r) { - // match with each other. - // TODO(ganler, masahi): Why commit on the same p-r pair happens more than once? - if (p->ptr == r->matched) { - ICHECK_EQ(p->matched, r->ptr); - return; - } - p->matched = r->ptr; - r->matched = p->ptr; - undo.emplace_back(p, r); - }; +struct MatchState { + void add(const PNode* p, const RNode* r) { + match_p_r[p] = r; + match_r_p[r] = p; + } - const auto quit = [&undo] { - for (auto& [p_node, r_node] : undo) { - p_node->matched = nullptr; - r_node->matched = nullptr; - } - return std::nullopt; - }; + void add(MatchState&& other) { + match_p_r.merge(std::move(other.match_p_r)); + match_r_p.merge(std::move(other.match_r_p)); + } - const auto try_match_update_undo = [&](PNode* p, RNode* r) { - if (auto undo_more = try_match(p, r, m, def2use, use2def)) { - undo.insert(undo.end(), undo_more->begin(), undo_more->end()); - return true; + const VarNode* matched(const PNode* p) const { + if (auto it = match_p_r.find(p); it != match_p_r.end()) { + return it->second->ptr; } - return false; - }; + return nullptr; + } - commit(p, r); + const DFPatternNode* matched(const RNode* r) const { + if (auto it = match_r_p.find(r); it != match_r_p.end()) { + return it->second->ptr; + } + return nullptr; + } - // match parent patterns. - for (auto& [pparent, constraints] : p->parents) { - bool any_cons_sat = false; - for (auto& rparent : r->parents) { - // skip if mismatch. - if (rparent->matched && rparent->matched != pparent->ptr) continue; + const VarNode* matched(const PNode& p) const { return matched(&p); } + const DFPatternNode* matched(const RNode& r) const { return matched(&r); } - const auto& uses = def2use.at(rparent->ptr); + private: + std::unordered_map match_p_r; + std::unordered_map match_r_p; +}; - // check edge constraints. - bool cons_sat = true; - for (const auto& cons : constraints) { - if (cons.type == PairCons::kOnlyUsedBy && uses.size() != 1) { - cons_sat = false; - break; - } +/** + * \brief This method try to match a real node and a pattern node along with its neighbors. + */ +static std::optional TryMatch(const PNode& p, const RNode& r, + const MatchState& current_match, DFPatternMatcher* m, + const MatcherUseDefAnalysis& ud_analysis) { + if (!m->Match(GetRef(p.ptr), GetRef(r.ptr))) return std::nullopt; - if (cons.index != -1) { - const auto& callees = use2def.at(r->ptr); - if (callees.size() <= static_cast(cons.index) || - callees[cons.index] != rparent->ptr) { - cons_sat = false; - break; - } - } - } - if (!cons_sat) continue; - any_cons_sat = true; + MatchState new_match; - // try all parent R nodes that are not matched yet. - // as long as ppattern can match one node. - if (!pparent->matched && try_match_update_undo(pparent, rparent)) { - commit(pparent, rparent); - break; - } - } - if (!pparent->matched || !any_cons_sat) return quit(); - } + new_match.add(&p, &r); // forward matching; - for (auto& [pchild, constraints] : p->children) { + for (const auto& [pchild, constraints] : p.children) { bool any_cons_sat = false; - for (auto& rchild : r->children) { - if (rchild->matched && rchild->matched != pchild->ptr) continue; + for (const auto& rchild : r.children) { + if (new_match.matched(rchild)) { + // The child variable is already matched to other child pattern in a previous iteration. + continue; + } + if (auto v = current_match.matched(pchild); v && v != rchild->ptr) { + // The child pattern is already matched to other variable in a earlier call to TryMatch. + continue; + } - const auto& uses = def2use.at(r->ptr); + const auto& uses = ud_analysis.def2use.at(r.ptr); // check edge constraints. bool all_cons_pass = true; @@ -636,88 +641,87 @@ static std::optional try_match( } if (cons.index != -1) { - const auto& callees = use2def.at(rchild->ptr); - if (callees.size() <= static_cast(cons.index) || callees[cons.index] != r->ptr) { + const auto& callees = ud_analysis.caller2callees.at(rchild->ptr); + if (callees.size() <= static_cast(cons.index) || callees[cons.index] != r.ptr) { all_cons_pass = false; break; } } } - if (!all_cons_pass) continue; + if (!all_cons_pass || new_match.matched(pchild)) continue; any_cons_sat = true; - if (!pchild->matched && try_match_update_undo(pchild, rchild)) { - commit(pchild, rchild); - break; + if (auto match_rec = TryMatch(*pchild, *rchild, current_match, m, ud_analysis)) { + new_match.add(pchild, rchild); + new_match.add(std::move(*match_rec)); } } - if (!pchild->matched || !any_cons_sat) return quit(); + if (!new_match.matched(pchild) || !any_cons_sat) return std::nullopt; } - return undo; + + return new_match; } -class MatcherUseDefAnalysis : public relax::ExprVisitor { - public: - std::vector vars; - std::map> def2use; - // caller -> callee table. - std::map> caller2callees; +static std::optional MatchTree( + const MatchState& current_match, size_t current_root_idx, + const std::unordered_map& pattern2node, + const std::unordered_map& var2node, DFPatternMatcher* matcher, + const std::vector& roots, const MatcherUseDefAnalysis& ud_analysis) { + auto get_next_root = [&](size_t root_idx) -> const PNode* { + // Look for the next unmatched root node. + for (; root_idx < roots.size(); ++root_idx) { + const auto& root = pattern2node.at(roots[root_idx].get()); + if (!current_match.matched(root)) { + return &root; + } + } + return nullptr; + }; - const VarNode* cur_user_; + const auto root = get_next_root(current_root_idx); - void VisitBinding_(const VarBindingNode* binding) override { - // init - cur_user_ = binding->var.get(); - this->VisitVarDef(binding->var); - this->VisitExpr(binding->value); - cur_user_ = nullptr; + if (!root) { + // All root nodes have been matched + return current_match; } - void VisitExpr_(const VarNode* op) override { - if (nullptr == cur_user_) return; + MatchState new_match = current_match; - auto check_and_push = [](std::vector& vec, const VarNode* var) { - if (std::find(vec.begin(), vec.end(), var) == vec.end()) { - vec.push_back(var); + for (const auto& var : ud_analysis.vars) { + const RNode& r_node = var2node.at(var); + if (new_match.matched(r_node)) continue; + if (auto match = TryMatch(*root, r_node, new_match, matcher, ud_analysis)) { + // Recursivly try to match the next subtree. + new_match.add(std::move(*match)); + if (auto match_rec = MatchTree(new_match, current_root_idx + 1, pattern2node, var2node, + matcher, roots, ud_analysis)) { + new_match.add(std::move(*match_rec)); + return new_match; } - }; - - check_and_push(def2use[op], cur_user_); - check_and_push(vars, op); - - caller2callees[cur_user_].push_back(op); - } - - void VisitExpr_(const DataflowVarNode* op) override { - VisitExpr_(static_cast(op)); + // Recursive matching has failed, backtrack. + continue; + } } -}; -Optional> MatchGraph(const PatternContext& ctx, const DataflowBlock& dfb, - Optional start_hint, bool must_include_hint) { - if (ctx->src_ordered.size() == 0) { - return NullOpt; - } + return std::nullopt; +} +Optional> MatchGraph(const PatternContext& ctx, const DataflowBlock& dfb) { // TODO(@ganler): Handle non-may external use. ICHECK(ctx->allow_extern_use == PatternContextNode::kMay) << "Only kMay is supported yet."; - ICHECK(!must_include_hint || start_hint.defined()) - << "must_include_hint is only supported with start_hint."; const auto var2val = AnalyzeVar2Value(dfb); DFPatternMatcher matcher(var2val); MatcherUseDefAnalysis ud_analysis; ud_analysis.VisitBindingBlock_(dfb.get()); - const auto& def2use = ud_analysis.def2use; - const auto& caller2callees = ud_analysis.caller2callees; // First construct a graph of PNode and RNode. std::unordered_map var2node; var2node.reserve(dfb->bindings.size()); for (const VarNode* cur_var : ud_analysis.vars) { - const auto& uses = def2use.at(cur_var); + const auto& uses = ud_analysis.def2use.at(cur_var); RNode& cur_node = var2node[cur_var]; cur_node.ptr = cur_var; for (const VarNode* use : uses) { @@ -731,8 +735,9 @@ Optional> MatchGraph(const PatternContext& ctx, const Datafl std::unordered_map pattern2node; pattern2node.reserve(ctx->constraints.size()); - for (const auto& [def_pattern, uses] : ctx->constraints) { + for (const auto& def_pattern : ctx->src_ordered) { PNode& def_node = pattern2node[def_pattern.get()]; + const auto& uses = ctx->constraints.at(def_pattern); def_node.ptr = def_pattern.get(); def_node.children.reserve(uses.size()); for (const auto& [use_pattern, cons] : uses) { @@ -743,35 +748,24 @@ Optional> MatchGraph(const PatternContext& ctx, const Datafl } } - Map ret; - - if (start_hint) { - auto rnode_ptr = var2node.at(start_hint.value().get()); - for (auto& p_node : pattern2node) { - if (try_match(&p_node.second, &rnode_ptr, &matcher, def2use, caller2callees)) { - for (const auto& [df_pattern, pattern_node] : pattern2node) { - ret.Set(GetRef(df_pattern), GetRef(pattern_node.matched)); - } - return ret; - } + std::vector roots; + for (const auto& pat : ctx->src_ordered) { + if (pattern2node[pat.get()].parents.empty()) { + roots.push_back(pat); } - - if (must_include_hint) return ret; } - PNode& pnode_start = pattern2node[ctx->src_ordered[0].get()]; + if (roots.empty()) { + return NullOpt; + } - if (!pnode_start.matched) { - for (const auto& var : ud_analysis.vars) { - if (start_hint.defined() && start_hint.value().get() == var) continue; - RNode& r_node = var2node[var]; - if (try_match(&pnode_start, &r_node, &matcher, def2use, caller2callees)) { - for (const auto& [df_pattern, pattern_node] : pattern2node) { - ret.Set(GetRef(df_pattern), GetRef(pattern_node.matched)); - } - return ret; - } + if (auto match = MatchTree({}, 0, pattern2node, var2node, &matcher, roots, ud_analysis)) { + Map ret; + for (const auto& [pat, p_node] : pattern2node) { + ICHECK(match->matched(p_node)); + ret.Set(GetRef(pat), GetRef(match->matched(p_node))); } + return ret; } return NullOpt; diff --git a/src/relax/ir/dataflow_pattern.cc b/src/relax/ir/dataflow_pattern.cc index 5580f6a1ab74..4d225ceecfe7 100644 --- a/src/relax/ir/dataflow_pattern.cc +++ b/src/relax/ir/dataflow_pattern.cc @@ -406,6 +406,7 @@ PatternContext::PatternContext(bool incremental) { << "Incremental context needs to be built inside a existing context."; n->allow_extern_use = pattern_ctx_stack().top()->allow_extern_use; n->constraints = pattern_ctx_stack().top()->constraints; + n->src_ordered = pattern_ctx_stack().top()->src_ordered; } data_ = std::move(n); diff --git a/tests/python/relax/test_dataflow_pattern.py b/tests/python/relax/test_dataflow_pattern.py index b85543cafcb8..a73a62eeef8d 100644 --- a/tests/python/relax/test_dataflow_pattern.py +++ b/tests/python/relax/test_dataflow_pattern.py @@ -519,51 +519,6 @@ def main( return lv6 -def test_single_cbr(): - with PatternContext() as ctx: - ( - is_call_dps_packed("conv1x1") - >> is_call_dps_packed("bias_add") - >> is_call_dps_packed("my_relu") - ) - dfb = CBRx2["main"].body.blocks[0] - matched = ctx.match_dfb(dfb) - assert matched - - with PatternContext() as ctx: - chain = ( - is_call_dps_packed("conv1x1") - >> is_call_dps_packed("bias_add") - >> is_call_dps_packed("my_relu") - ) - dfb = CBRx2["main"].body.blocks[0] - # we want to specifically match the first CBR (lv0) - matched = ctx.match_dfb(dfb, start_hint=dfb.bindings[0].var) - assert matched - assert matched[chain[0]] == dfb.bindings[0].var - # we want to specifically match the second CBR (lv3) - matched = ctx.match_dfb(dfb, start_hint=dfb.bindings[3].var) - assert matched - assert matched[chain[0]] == dfb.bindings[3].var - - -def test_counter_single_crb(): - with PatternContext() as ctx: - ( - is_call_dps_packed("conv1x1") - >> is_call_dps_packed("my_relu") - >> is_call_dps_packed("bias_add") - ) - dfb = CBRx2["main"].body.blocks[0] - assert not ctx.match_dfb(dfb) - # Quickly fails unpromising matches by assuming `start_hint` must be matched by a pattern. - # This is usually faster than the full match: - # Full match: let one pattern to match -> all Var: complexity ~ #Var - # must_include_hint: let `start_hint` to match -> all patterns: complexity ~ #patterns - # Usually #patterns is much smaller than #Var, so this is faster. - assert not ctx.match_dfb(dfb, start_hint=dfb.bindings[0].var, must_include_hint=True) - - def test_nested_context(): dfb = CBRx2["main"].body.blocks[0] with PatternContext() as ctx0: