From 05a2528a1f997f201b8cb02d9125d632f0fe1a9b Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 31 Mar 2023 07:55:47 +0900 Subject: [PATCH 1/6] Clean up undo stack for parent and child nodes properly --- include/tvm/relax/dataflow_matcher.h | 10 +-- src/relax/analysis/udchain.cc | 26 ++++---- src/relax/ir/dataflow_matcher.cc | 70 +++++++++++++-------- tests/python/relax/test_dataflow_pattern.py | 33 ++++++++++ 4 files changed, 95 insertions(+), 44 deletions(-) diff --git a/include/tvm/relax/dataflow_matcher.h b/include/tvm/relax/dataflow_matcher.h index 498f77a3f7d5..e4268be882d7 100644 --- a/include/tvm/relax/dataflow_matcher.h +++ b/include/tvm/relax/dataflow_matcher.h @@ -58,12 +58,12 @@ Optional> ExtractMatchedExpr( * \param dfb The function to match. * \param start_hint The starting point expression to match to distinguish multiple matches. * \param must_include_hint If start_hint is given, the return pattern must include start_hint. - * \return tvm::runtime::Map + * \return Matched patterns and corresponding bound variables */ -TVM_DLL tvm::runtime::Map MatchGraph(const PatternContext& ctx, - const DataflowBlock& dfb, - Optional start_hint = NullOpt, - bool must_include_hint = false); +TVM_DLL Optional> MatchGraph(const PatternContext& ctx, + const DataflowBlock& dfb, + Optional start_hint = NullOpt, + bool must_include_hint = false); } // namespace relax } // namespace tvm diff --git a/src/relax/analysis/udchain.cc b/src/relax/analysis/udchain.cc index 77e52408a710..53f13eee5b89 100644 --- a/src/relax/analysis/udchain.cc +++ b/src/relax/analysis/udchain.cc @@ -65,36 +65,36 @@ class UDChain : public relax::ExprVisitor { std::pair>, runtime::Array> FunctionUseDef( const Function& fn) { UDChain udchain; - udchain.VisitExpr_(fn.get()); + udchain.VisitExpr(fn); Map> user_map; Array fn_outs; - for (const auto& kv : udchain.to_users) { + for (const auto& [var, users] : udchain.to_users) { Array uses{}; - uses.reserve(kv.second.size()); - for (const auto& v : kv.second) { - if (nullptr == v && - fn_outs.end() == std::find(fn_outs.begin(), fn_outs.end(), GetRef(kv.first))) { - fn_outs.push_back(GetRef(kv.first)); + uses.reserve(users.size()); + for (const auto& v : users) { + if (v == nullptr && + std::find(fn_outs.begin(), fn_outs.end(), GetRef(var)) == fn_outs.end() ) { + fn_outs.push_back(GetRef(var)); } else { uses.push_back(GetRef(v)); } } - user_map.Set(GetRef(kv.first), std::move(uses)); + user_map.Set(GetRef(var), std::move(uses)); } return std::make_pair(std::move(user_map), std::move(fn_outs)); } runtime::Map> DataflowBlockUseDef(const DataflowBlock& dfb) { UDChain udchain; - udchain.VisitBindingBlock_(dfb.get()); + udchain.VisitBindingBlock(dfb); runtime::Map> ret; - for (const auto& kv : udchain.to_users) { + for (const auto& [var, users] : udchain.to_users) { Array uses{}; - uses.reserve(kv.second.size()); - for (const auto& v : kv.second) uses.push_back(GetRef(v)); - ret.Set(GetRef(kv.first), std::move(uses)); + uses.reserve(users.size()); + for (const auto& v : users) uses.push_back(GetRef(v)); + ret.Set(GetRef(var), std::move(uses)); } return ret; } diff --git a/src/relax/ir/dataflow_matcher.cc b/src/relax/ir/dataflow_matcher.cc index d1b7a1d62cda..1c8267a79cf4 100644 --- a/src/relax/ir/dataflow_matcher.cc +++ b/src/relax/ir/dataflow_matcher.cc @@ -33,6 +33,7 @@ #include #include #include +#include #include #include #include @@ -540,33 +541,48 @@ struct RNode { /** * \brief This method try to match a real node and a pattern node along with its neighbors. */ -static bool try_match(PNode* p, RNode* r, DFPatternMatcher* m, - const std::map>& def2use, - const std::map>& use2def) { - if (p->matched != nullptr && p->matched == r->ptr) return true; // matched before. - if (!m->Match(GetRef(p->ptr), GetRef(r->ptr))) return false; +using UndoStack = std::stack>; +static std::optional try_match( + PNode* p, RNode* r, DFPatternMatcher* m, + const std::map>& def2use, + const std::map>& use2def) { + if (p->matched != nullptr && p->matched == r->ptr) return {}; // matched before. + if (!m->Match(GetRef(p->ptr), GetRef(r->ptr))) return std::nullopt; - std::stack> undo_stack{}; + UndoStack undo; - const auto commit = [&undo_stack](PNode* p, RNode* r) { + const auto commit = [&undo](PNode* p, RNode* r) { // match with each other. // TODO(ganler, masahi): Why commit on the same p-r pair happens more than once? if (p->ptr == r->matched) { ICHECK_EQ(p->matched, r->ptr); return; } - ICHECK(r->matched == nullptr); + // TODO(ganler, masahi): Why this condition can fail? + // ICHECK(r->matched == nullptr); p->matched = r->ptr; r->matched = p->ptr; - undo_stack.emplace(p, r); + undo.emplace(p, r); }; - const auto quit = [&undo_stack] { - while (!undo_stack.empty()) { - auto& top = undo_stack.top(); - top.first->matched = nullptr; - top.second->matched = nullptr; - undo_stack.pop(); + const auto quit = [&undo] { + while (!undo.empty()) { + auto& [p_node, r_node] = undo.top(); + p_node->matched = nullptr; + r_node->matched = nullptr; + undo.pop(); + } + return std::nullopt; + }; + + const auto try_match_update_undo = [&](PNode* p, RNode* r) { + if (auto undo_more = try_match(p, r, m, def2use, use2def)) { + while (!undo_more->empty()) { + auto& [p_node, r_node] = undo_more->top(); + undo.emplace(p_node, r_node); + undo_more->pop(); + } + return true; } return false; }; @@ -604,7 +620,7 @@ static bool try_match(PNode* p, RNode* r, DFPatternMatcher* m, // try all parent R nodes that are not matched yet. // as long as ppattern can match one node. - if (!pparent->matched && try_match(pparent, rparent, m, def2use, use2def)) { + if (!pparent->matched && try_match_update_undo(pparent, rparent)) { commit(pparent, rparent); break; } @@ -639,14 +655,14 @@ static bool try_match(PNode* p, RNode* r, DFPatternMatcher* m, if (!all_cons_pass) continue; any_cons_sat = true; - if (!pchild->matched && try_match(pchild, rchild, m, def2use, use2def)) { + if (!pchild->matched && try_match_update_undo(pchild, rchild)) { commit(pchild, rchild); break; } } if (!pchild->matched || !any_cons_sat) return quit(); } - return true; + return undo; } class MatcherUseDefAnalysis : public relax::ExprVisitor { @@ -686,13 +702,12 @@ class MatcherUseDefAnalysis : public relax::ExprVisitor { } }; -Map MatchGraph(const PatternContext& ctx, const DataflowBlock& dfb, - Optional start_hint, bool must_include_hint) { +Optional> MatchGraph(const PatternContext& ctx, const DataflowBlock& dfb, + Optional start_hint, bool must_include_hint) { if (ctx->src_ordered.size() == 0) { - return {}; + return NullOpt; } - Map ret; // TODO(@ganler): Handle non-may external use. ICHECK(ctx->allow_extern_use == PatternContextNode::kMay) << "Only kMay is supported yet."; ICHECK(!must_include_hint || start_hint.defined()) @@ -737,12 +752,15 @@ Map MatchGraph(const PatternContext& ctx, const DataflowBlock& d } } + Map ret; + if (start_hint) { auto rnode_ptr = var2node.at(start_hint.value().get()); for (auto& p_node : pattern2node) { if (try_match(&p_node.second, &rnode_ptr, &matcher, def2use, caller2callees)) { - for (const auto& [df_pattern, pattern_node] : pattern2node) + for (const auto& [df_pattern, pattern_node] : pattern2node) { ret.Set(GetRef(df_pattern), GetRef(pattern_node.matched)); + } return ret; } } @@ -757,15 +775,15 @@ Map MatchGraph(const PatternContext& ctx, const DataflowBlock& d if (start_hint.defined() && start_hint.value().get() == var) continue; RNode& r_node = var2node[var]; if (try_match(&pnode_start, &r_node, &matcher, def2use, caller2callees)) { - for (const auto& [df_pattern, pattern_node] : pattern2node) + for (const auto& [df_pattern, pattern_node] : pattern2node) { ret.Set(GetRef(df_pattern), GetRef(pattern_node.matched)); - + } return ret; } } } - return ret; + return NullOpt; } TVM_REGISTER_GLOBAL("relax.dpl.match_dfb").set_body_typed(MatchGraph); diff --git a/tests/python/relax/test_dataflow_pattern.py b/tests/python/relax/test_dataflow_pattern.py index 76bce47f7ff7..f18244096ec2 100644 --- a/tests/python/relax/test_dataflow_pattern.py +++ b/tests/python/relax/test_dataflow_pattern.py @@ -1042,5 +1042,38 @@ def main( assert out[V_weight_pat].name_hint == "w2" +def test_attention_fake_qkv(): + @tvm.script.ir_module + class QKV_proj: + @R.function + def main( + x1: R.Tensor((2, 1024, 640), "float32"), + x2: R.Tensor((2, 1024, 640), "float32"), + w0: R.Tensor((640, 640), "float32"), + w1: R.Tensor((640, 640), "float32"), + w2: R.Tensor((640, 640), "float32"), + ) -> R.Tensor: + with R.dataflow(): + lv0 = R.matmul(x1, w0) + lv1 = R.matmul(x2, w1) + lv2 = R.matmul(x2, w2) + out = (lv0, lv1, lv2) + R.output(out) + return out + + with PatternContext() as ctx: + inp_pat = wildcard() + Q_weight_pat = wildcard() + K_weight_pat = wildcard() + V_weight_pat = wildcard() + + matmul1 = is_op("relax.matmul")(inp_pat, Q_weight_pat) + matmul2 = is_op("relax.matmul")(inp_pat, K_weight_pat) + matmul3 = is_op("relax.matmul")(inp_pat, V_weight_pat) + + dfb = QKV_proj["main"].body.blocks[0] + assert ctx.match_dfb(dfb) is None + + if __name__ == "__main__": tvm.testing.main() From fe4ee72bd65a24650b1e099d2b545482ef543c71 Mon Sep 17 00:00:00 2001 From: masahi Date: Fri, 31 Mar 2023 08:37:24 +0900 Subject: [PATCH 2/6] Update src/relax/analysis/udchain.cc Co-authored-by: Jiawei Liu --- src/relax/analysis/udchain.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/relax/analysis/udchain.cc b/src/relax/analysis/udchain.cc index 53f13eee5b89..1c49fd581f7d 100644 --- a/src/relax/analysis/udchain.cc +++ b/src/relax/analysis/udchain.cc @@ -75,7 +75,7 @@ std::pair>, runtime::Array> FunctionU uses.reserve(users.size()); for (const auto& v : users) { if (v == nullptr && - std::find(fn_outs.begin(), fn_outs.end(), GetRef(var)) == fn_outs.end() ) { + std::find(fn_outs.begin(), fn_outs.end(), GetRef(var)) == fn_outs.end()) { fn_outs.push_back(GetRef(var)); } else { uses.push_back(GetRef(v)); From 8e8cab19db4c27cf1eb4736a597ec02d205b2a51 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 31 Mar 2023 08:26:53 +0900 Subject: [PATCH 3/6] minor change --- src/relax/ir/dataflow_matcher.cc | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/relax/ir/dataflow_matcher.cc b/src/relax/ir/dataflow_matcher.cc index 1c8267a79cf4..59170d54b883 100644 --- a/src/relax/ir/dataflow_matcher.cc +++ b/src/relax/ir/dataflow_matcher.cc @@ -558,8 +558,6 @@ static std::optional try_match( ICHECK_EQ(p->matched, r->ptr); return; } - // TODO(ganler, masahi): Why this condition can fail? - // ICHECK(r->matched == nullptr); p->matched = r->ptr; r->matched = p->ptr; undo.emplace(p, r); @@ -567,7 +565,7 @@ static std::optional try_match( const auto quit = [&undo] { while (!undo.empty()) { - auto& [p_node, r_node] = undo.top(); + const auto& [p_node, r_node] = undo.top(); p_node->matched = nullptr; r_node->matched = nullptr; undo.pop(); From c27b17de91bb15abdccd0d273adbaeee02bd175a Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 31 Mar 2023 08:47:58 +0900 Subject: [PATCH 4/6] stack -> vector --- src/relax/ir/dataflow_matcher.cc | 19 ++++++------------- 1 file changed, 6 insertions(+), 13 deletions(-) diff --git a/src/relax/ir/dataflow_matcher.cc b/src/relax/ir/dataflow_matcher.cc index 59170d54b883..fc29e893f1c5 100644 --- a/src/relax/ir/dataflow_matcher.cc +++ b/src/relax/ir/dataflow_matcher.cc @@ -541,15 +541,15 @@ struct RNode { /** * \brief This method try to match a real node and a pattern node along with its neighbors. */ -using UndoStack = std::stack>; -static std::optional try_match( +using UndoItems = std::vector>; +static std::optional try_match( PNode* p, RNode* r, DFPatternMatcher* m, const std::map>& def2use, const std::map>& use2def) { if (p->matched != nullptr && p->matched == r->ptr) return {}; // matched before. if (!m->Match(GetRef(p->ptr), GetRef(r->ptr))) return std::nullopt; - UndoStack undo; + UndoItems undo; const auto commit = [&undo](PNode* p, RNode* r) { // match with each other. @@ -560,27 +560,20 @@ static std::optional try_match( } p->matched = r->ptr; r->matched = p->ptr; - undo.emplace(p, r); + undo.emplace_back(p, r); }; const auto quit = [&undo] { - while (!undo.empty()) { - const auto& [p_node, r_node] = undo.top(); + for (auto& [p_node, r_node] : undo) { p_node->matched = nullptr; r_node->matched = nullptr; - undo.pop(); } return std::nullopt; }; const auto try_match_update_undo = [&](PNode* p, RNode* r) { if (auto undo_more = try_match(p, r, m, def2use, use2def)) { - while (!undo_more->empty()) { - auto& [p_node, r_node] = undo_more->top(); - undo.emplace(p_node, r_node); - undo_more->pop(); - } - return true; + undo.insert(undo.end(), undo_more->begin(), undo_more->end()); } return false; }; From 191b055e43745c7a5ea960c325f0350d1dbbe102 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 31 Mar 2023 08:49:31 +0900 Subject: [PATCH 5/6] remove stack header --- src/relax/ir/dataflow_matcher.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/src/relax/ir/dataflow_matcher.cc b/src/relax/ir/dataflow_matcher.cc index fc29e893f1c5..c39cf528a445 100644 --- a/src/relax/ir/dataflow_matcher.cc +++ b/src/relax/ir/dataflow_matcher.cc @@ -34,7 +34,6 @@ #include #include #include -#include #include #include #include From 4bce8b7445bd2a7e7d7eef8eb319f33d15341525 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sat, 1 Apr 2023 01:54:46 +0900 Subject: [PATCH 6/6] fix accidentally removed statement from recent commit --- src/relax/ir/dataflow_matcher.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/src/relax/ir/dataflow_matcher.cc b/src/relax/ir/dataflow_matcher.cc index c39cf528a445..88381d6e26d9 100644 --- a/src/relax/ir/dataflow_matcher.cc +++ b/src/relax/ir/dataflow_matcher.cc @@ -573,6 +573,7 @@ static std::optional try_match( const auto try_match_update_undo = [&](PNode* p, RNode* r) { if (auto undo_more = try_match(p, r, m, def2use, use2def)) { undo.insert(undo.end(), undo_more->begin(), undo_more->end()); + return true; } return false; };