From 0e09ddcfc07fe2126d7335cf630b64a7c21b1d82 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 29 Mar 2023 06:40:09 +0900 Subject: [PATCH 1/6] Remove all non-determinsm from graph matching --- include/tvm/relax/dataflow_pattern.h | 24 +++++++-- src/relax/ir/dataflow_matcher.cc | 75 +++++++++++++++------------- 2 files changed, 60 insertions(+), 39 deletions(-) diff --git a/include/tvm/relax/dataflow_pattern.h b/include/tvm/relax/dataflow_pattern.h index 37640750a8ef..14516b6a8866 100644 --- a/include/tvm/relax/dataflow_pattern.h +++ b/include/tvm/relax/dataflow_pattern.h @@ -191,7 +191,10 @@ class PatternContextNode : public Object { kMustNot, /*!< All nodes except outputs only have internal depedencies in the matched graph. */ } allow_extern_use = kMay; // src node -> constraints. - std::map>> constraints; + // Dst nodes are kept in a vector to keep them ordered. + std::map>>> constraints; + // Keep a seperate vector of patterns to process constraints in a fixed order + std::vector src_ordered; static constexpr const char* _type_key = "relax.dpl.PatternContext"; TVM_DECLARE_FINAL_OBJECT_INFO(PatternContextNode, Object); @@ -224,9 +227,22 @@ class PatternContext : public ObjectRef { * \param cons The constraint type. \sa PairCons */ void add_constraint(DFPattern producer, DFPattern consumer, PairCons cons) { - auto& vec = (*this)->constraints[producer][consumer]; - ICHECK(std::find(vec.cbegin(), vec.cend(), cons) == vec.cend()) << "Constraint already exists"; - vec.push_back(cons); + auto& pairs = (*this)->constraints[producer]; + auto it = std::find_if(pairs.begin(), pairs.end(), + [consumer](auto p) { return p.first == consumer; }); + if (it == pairs.end()) { + pairs.emplace_back(consumer, std::vector{cons}); + } else { + auto& vec = it->second; + ICHECK(std::find(vec.cbegin(), vec.cend(), cons) == vec.cend()) + << "Constraint already exists"; + vec.push_back(cons); + } + + auto& patterns = (*this)->src_ordered; + if (std::find(patterns.begin(), patterns.end(), producer) == patterns.end()) { + patterns.push_back(producer); + } } /*! \brief Get the pass context object on the top of the stack */ diff --git a/src/relax/ir/dataflow_matcher.cc b/src/relax/ir/dataflow_matcher.cc index c6d705b5b482..e85b95f82434 100644 --- a/src/relax/ir/dataflow_matcher.cc +++ b/src/relax/ir/dataflow_matcher.cc @@ -541,7 +541,7 @@ struct RNode { * \brief This method try to match a real node and a pattern node along with its neighbors. */ static bool try_match(PNode* p, RNode* r, DFPatternMatcher* m, - const std::map>& def2use, + const std::map>& def2use, const std::map>& use2def) { if (nullptr != p->matched && p->matched == r->ptr) return true; // matched before. if (!m->Match(GetRef(p->ptr), GetRef(r->ptr))) return false; @@ -550,6 +550,12 @@ static bool try_match(PNode* p, RNode* r, DFPatternMatcher* m, const auto commit = [&undo_stack](PNode* p, RNode* r) { // match with each other. + // TODO(ganler, masahi): Why commit on the same p-r pair happens more than once? + if (p->ptr == r->matched) { + ICHECK_EQ(p->matched, r->ptr); + return; + } + ICHECK_EQ(r->matched, nullptr); p->matched = r->ptr; r->matched = p->ptr; undo_stack.emplace(p, r); @@ -568,18 +574,13 @@ static bool try_match(PNode* p, RNode* r, DFPatternMatcher* m, commit(p, r); // match parent patterns. - for (auto& pparent_pairs : p->parents) { - PNode* pparent = pparent_pairs.first; - const std::vector& constraints = pparent_pairs.second; - + for (auto& [pparent, constraints] : p->parents) { bool any_cons_sat = false; for (auto& rparent : r->parents) { // skip if mismatch. if (rparent->matched && rparent->matched != pparent->ptr) continue; const auto& uses = def2use.at(rparent->ptr); - // skip if `rparent` is not used by `r`. - if (uses.cend() == uses.find(r->ptr)) continue; // check edge constraints. bool cons_sat = true; @@ -612,15 +613,12 @@ static bool try_match(PNode* p, RNode* r, DFPatternMatcher* m, } // forward matching; - for (auto& pchild_pairs : p->children) { - PNode* pchild = pchild_pairs.first; - const std::vector& constraints = pchild_pairs.second; + for (auto& [pchild, constraints] : p->children) { bool any_cons_sat = false; for (auto& rchild : r->children) { if (rchild->matched && rchild->matched != pchild->ptr) continue; const auto& uses = def2use.at(r->ptr); - if (uses.cend() == uses.find(rchild->ptr)) continue; // check edge constraints. bool all_cons_pass = true; @@ -648,13 +646,13 @@ static bool try_match(PNode* p, RNode* r, DFPatternMatcher* m, } if (!pchild->matched || !any_cons_sat) return quit(); } - return true; } class MatcherUseDefAnalysis : public relax::ExprVisitor { public: - std::map> def2use; + std::vector vars; + std::map> def2use; // caller -> callee table. std::map> caller2callees; @@ -671,7 +669,15 @@ class MatcherUseDefAnalysis : public relax::ExprVisitor { void VisitExpr_(const VarNode* op) override { if (nullptr == cur_user_) return; - def2use[op].insert(cur_user_); + auto check_and_push = [](std::vector& vec, const VarNode* var) { + if (std::find(vec.begin(), vec.end(), var) == vec.end()) { + vec.push_back(var); + } + }; + + check_and_push(def2use[op], cur_user_); + check_and_push(vars, op); + caller2callees[cur_user_].push_back(op); } @@ -682,6 +688,10 @@ class MatcherUseDefAnalysis : public relax::ExprVisitor { Map MatchGraph(const PatternContext& ctx, const DataflowBlock& dfb, Optional start_hint, bool must_include_hint) { + if (ctx->src_ordered.size() == 0) { + return {}; + } + Map ret; // TODO(@ganler): Handle non-may external use. ICHECK(ctx->allow_extern_use == PatternContextNode::kMay) << "Only kMay is supported yet."; @@ -691,7 +701,6 @@ Map MatchGraph(const PatternContext& ctx, const DataflowBlock& d const auto var2val = AnalyzeVar2Value(dfb); DFPatternMatcher matcher(var2val); - // std::map> MatcherUseDefAnalysis ud_analysis; ud_analysis.VisitBindingBlock_(dfb.get()); const auto& def2use = ud_analysis.def2use; @@ -701,9 +710,8 @@ Map MatchGraph(const PatternContext& ctx, const DataflowBlock& d std::unordered_map var2node; var2node.reserve(dfb->bindings.size()); - for (const auto& du : def2use) { - const VarNode* cur_var = du.first; - const std::set& uses = du.second; + for (const VarNode* cur_var : ud_analysis.vars) { + const auto& uses = def2use.at(cur_var); RNode& cur_node = var2node[cur_var]; cur_node.ptr = cur_var; for (const VarNode* use : uses) { @@ -717,17 +725,13 @@ Map MatchGraph(const PatternContext& ctx, const DataflowBlock& d std::unordered_map pattern2node; pattern2node.reserve(ctx->constraints.size()); - for (const auto& def2use_pattern : ctx->constraints) { - const DFPatternNode* def_pattern = def2use_pattern.first.get(); - const std::map>& uses = def2use_pattern.second; - PNode& def_node = pattern2node[def_pattern]; - def_node.ptr = def_pattern; + for (const auto& [def_pattern, uses] : ctx->constraints) { + PNode& def_node = pattern2node[def_pattern.get()]; + def_node.ptr = def_pattern.get(); def_node.children.reserve(uses.size()); - for (const auto& use : uses) { - const auto& cons = use.second; - const DFPatternNode* use_pattern = use.first.get(); - PNode& use_node = pattern2node[use_pattern]; - use_node.ptr = use_pattern; + for (const auto& [use_pattern, cons] : uses) { + PNode& use_node = pattern2node[use_pattern.get()]; + use_node.ptr = use_pattern.get(); use_node.parents.emplace_back(&def_node, std::ref(cons)); def_node.children.emplace_back(&use_node, std::ref(cons)); } @@ -747,14 +751,15 @@ Map MatchGraph(const PatternContext& ctx, const DataflowBlock& d if (must_include_hint) return ret; } - PNode* pnode_start = &pattern2node.begin()->second; + PNode& pnode_start = pattern2node[ctx->src_ordered[0].get()]; - if (!pnode_start->matched) { - for (auto& rpair : var2node) { - if (start_hint.defined() && start_hint.value().get() == rpair.first) continue; - if (try_match(pnode_start, &rpair.second, &matcher, def2use, caller2callees)) { - for (auto ppair : pattern2node) - ret.Set(GetRef(ppair.first), GetRef(ppair.second.matched)); + if (!pnode_start.matched) { + for (const auto& var : ud_analysis.vars) { + if (start_hint.defined() && start_hint.value().get() == var) continue; + RNode& r_node = var2node[var]; + if (try_match(&pnode_start, &r_node, &matcher, def2use, caller2callees)) { + for (const auto& [df_pattern, pattern_node] : pattern2node) + ret.Set(GetRef(df_pattern), GetRef(pattern_node.matched)); return ret; } From e9a116af28aefb330a86f993a7f8e2e29648d055 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 29 Mar 2023 06:45:10 +0900 Subject: [PATCH 2/6] add test --- tests/python/relax/test_dataflow_pattern.py | 45 +++++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/tests/python/relax/test_dataflow_pattern.py b/tests/python/relax/test_dataflow_pattern.py index a40faf3bcbc1..9679e14fffe7 100644 --- a/tests/python/relax/test_dataflow_pattern.py +++ b/tests/python/relax/test_dataflow_pattern.py @@ -1006,5 +1006,50 @@ def rewriter(_, matchings): tvm.ir.assert_structural_equal(rewritten, expected) +def test_attention_qkv(): + @tvm.script.ir_module + class QKV_proj: + @R.function + def main( + x: R.Tensor((2, 1024, 640), "float32"), + w0: R.Tensor((640, 640), "float32"), + w1: R.Tensor((640, 640), "float32"), + w2: R.Tensor((640, 640), "float32"), + ) -> R.Tensor: + with R.dataflow(): + lv0 = R.matmul(x, w0) + lv1 = R.matmul(x, w1) + lv2 = R.matmul(x, w2) + out = (lv0, lv1, lv2) + R.output(out) + return out + + with PatternContext() as ctx: + inp_pat = wildcard() + Q_weight_pat = wildcard() + K_weight_pat = wildcard() + V_weight_pat = wildcard() + + matmul1 = is_op("relax.matmul")(inp_pat, Q_weight_pat) + matmul2 = is_op("relax.matmul")(inp_pat, K_weight_pat) + matmul3 = is_op("relax.matmul")(inp_pat, V_weight_pat) + + # TODO(masahi): Automate addition of used_by constraints during is_op + inp_pat.used_by(matmul1, 0) + inp_pat.used_by(matmul2, 0) + inp_pat.used_by(matmul3, 0) + + Q_weight_pat.only_used_by(matmul1, 1) + K_weight_pat.only_used_by(matmul2, 1) + V_weight_pat.only_used_by(matmul3, 1) + + dfb = QKV_proj["main"].body.blocks[0] + out = ctx.match_dfb(dfb) + + assert out[Q_weight_pat].name_hint == "w0" + assert out[K_weight_pat].name_hint == "w1" + assert out[V_weight_pat].name_hint == "w2" + + if __name__ == "__main__": tvm.testing.main() From ccd23dd5c643a93a9b53d3bbca787bab66165424 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 29 Mar 2023 07:05:22 +0900 Subject: [PATCH 3/6] typo --- include/tvm/relax/dataflow_pattern.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/tvm/relax/dataflow_pattern.h b/include/tvm/relax/dataflow_pattern.h index 14516b6a8866..144a7f45bf57 100644 --- a/include/tvm/relax/dataflow_pattern.h +++ b/include/tvm/relax/dataflow_pattern.h @@ -193,7 +193,7 @@ class PatternContextNode : public Object { // src node -> constraints. // Dst nodes are kept in a vector to keep them ordered. std::map>>> constraints; - // Keep a seperate vector of patterns to process constraints in a fixed order + // Keep a separate vector of patterns to process constraints in a fixed order. std::vector src_ordered; static constexpr const char* _type_key = "relax.dpl.PatternContext"; From 7174cbc23847e1dbc3221f2831c450f69e163ea8 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 29 Mar 2023 08:28:46 +0900 Subject: [PATCH 4/6] try fixing compile error for gcc --- src/relax/ir/dataflow_matcher.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/relax/ir/dataflow_matcher.cc b/src/relax/ir/dataflow_matcher.cc index e85b95f82434..4b1ad2a82ee4 100644 --- a/src/relax/ir/dataflow_matcher.cc +++ b/src/relax/ir/dataflow_matcher.cc @@ -555,7 +555,7 @@ static bool try_match(PNode* p, RNode* r, DFPatternMatcher* m, ICHECK_EQ(p->matched, r->ptr); return; } - ICHECK_EQ(r->matched, nullptr); + ICHECK(r->matched == nullptr); p->matched = r->ptr; r->matched = p->ptr; undo_stack.emplace(p, r); From 6032281d077d012a42709450797c40e0d025b026 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 29 Mar 2023 20:29:56 +0900 Subject: [PATCH 5/6] more style update --- src/relax/ir/dataflow_matcher.cc | 29 ++++++++++++++--------------- 1 file changed, 14 insertions(+), 15 deletions(-) diff --git a/src/relax/ir/dataflow_matcher.cc b/src/relax/ir/dataflow_matcher.cc index 4b1ad2a82ee4..863cda86e551 100644 --- a/src/relax/ir/dataflow_matcher.cc +++ b/src/relax/ir/dataflow_matcher.cc @@ -543,7 +543,7 @@ struct RNode { static bool try_match(PNode* p, RNode* r, DFPatternMatcher* m, const std::map>& def2use, const std::map>& use2def) { - if (nullptr != p->matched && p->matched == r->ptr) return true; // matched before. + if (p->matched != nullptr && p->matched == r->ptr) return true; // matched before. if (!m->Match(GetRef(p->ptr), GetRef(r->ptr))) return false; std::stack> undo_stack{}; @@ -585,15 +585,15 @@ static bool try_match(PNode* p, RNode* r, DFPatternMatcher* m, // check edge constraints. bool cons_sat = true; for (const auto& cons : constraints) { - if (PairCons::kOnlyUsedBy == cons.type && uses.size() != 1) { + if (cons.type == PairCons::kOnlyUsedBy && uses.size() != 1) { cons_sat = false; break; } - if (-1 != cons.index) { + if (cons.index != -1) { const auto& callees = use2def.at(r->ptr); - if (static_cast(cons.index) >= callees.size() || - rparent->ptr != callees[cons.index]) { + if (callees.size() <= static_cast(cons.index) || + callees[cons.index] != rparent->ptr) { cons_sat = false; break; } @@ -623,14 +623,14 @@ static bool try_match(PNode* p, RNode* r, DFPatternMatcher* m, // check edge constraints. bool all_cons_pass = true; for (const auto& cons : constraints) { - if (PairCons::kOnlyUsedBy == cons.type && uses.size() != 1) { + if (cons.type == PairCons::kOnlyUsedBy && uses.size() != 1) { all_cons_pass = false; break; } - if (-1 != cons.index) { + if (cons.index != -1) { const auto& callees = use2def.at(rchild->ptr); - if (static_cast(cons.index) >= callees.size() || r->ptr != callees[cons.index]) { + if (callees.size() <= static_cast(cons.index) || callees[cons.index] != r->ptr) { all_cons_pass = false; break; } @@ -737,13 +737,12 @@ Map MatchGraph(const PatternContext& ctx, const DataflowBlock& d } } - if (start_hint.defined()) { - Var v = start_hint.value(); - auto rnode_ptr = var2node.find(v.get()); - for (auto& ppair : pattern2node) { - if (try_match(&ppair.second, &rnode_ptr->second, &matcher, def2use, caller2callees)) { - for (auto ppair : pattern2node) - ret.Set(GetRef(ppair.first), GetRef(ppair.second.matched)); + if (start_hint) { + auto rnode_ptr = var2node.at(start_hint.value().get()); + for (auto& [df_pattern, pattern_node] : pattern2node) { + if (try_match(&pattern_node, &rnode_ptr, &matcher, def2use, caller2callees)) { + for (const auto& [df_pattern, pattern_node] : pattern2node) + ret.Set(GetRef(df_pattern), GetRef(pattern_node.matched)); return ret; } } From 278bc9bbdf7d9b12c44bb78b71f91e4872b14353 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 30 Mar 2023 04:07:50 +0900 Subject: [PATCH 6/6] suppress compile warning --- src/relax/ir/dataflow_matcher.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/relax/ir/dataflow_matcher.cc b/src/relax/ir/dataflow_matcher.cc index 863cda86e551..d1b7a1d62cda 100644 --- a/src/relax/ir/dataflow_matcher.cc +++ b/src/relax/ir/dataflow_matcher.cc @@ -739,8 +739,8 @@ Map MatchGraph(const PatternContext& ctx, const DataflowBlock& d if (start_hint) { auto rnode_ptr = var2node.at(start_hint.value().get()); - for (auto& [df_pattern, pattern_node] : pattern2node) { - if (try_match(&pattern_node, &rnode_ptr, &matcher, def2use, caller2callees)) { + for (auto& p_node : pattern2node) { + if (try_match(&p_node.second, &rnode_ptr, &matcher, def2use, caller2callees)) { for (const auto& [df_pattern, pattern_node] : pattern2node) ret.Set(GetRef(df_pattern), GetRef(pattern_node.matched)); return ret;