diff --git a/include/tvm/relax/dataflow_matcher.h b/include/tvm/relax/dataflow_matcher.h index 498f77a3f7d5..e4268be882d7 100644 --- a/include/tvm/relax/dataflow_matcher.h +++ b/include/tvm/relax/dataflow_matcher.h @@ -58,12 +58,12 @@ Optional> ExtractMatchedExpr( * \param dfb The function to match. * \param start_hint The starting point expression to match to distinguish multiple matches. * \param must_include_hint If start_hint is given, the return pattern must include start_hint. - * \return tvm::runtime::Map + * \return Matched patterns and corresponding bound variables */ -TVM_DLL tvm::runtime::Map MatchGraph(const PatternContext& ctx, - const DataflowBlock& dfb, - Optional start_hint = NullOpt, - bool must_include_hint = false); +TVM_DLL Optional> MatchGraph(const PatternContext& ctx, + const DataflowBlock& dfb, + Optional start_hint = NullOpt, + bool must_include_hint = false); } // namespace relax } // namespace tvm diff --git a/src/relax/analysis/udchain.cc b/src/relax/analysis/udchain.cc index 77e52408a710..1c49fd581f7d 100644 --- a/src/relax/analysis/udchain.cc +++ b/src/relax/analysis/udchain.cc @@ -65,36 +65,36 @@ class UDChain : public relax::ExprVisitor { std::pair>, runtime::Array> FunctionUseDef( const Function& fn) { UDChain udchain; - udchain.VisitExpr_(fn.get()); + udchain.VisitExpr(fn); Map> user_map; Array fn_outs; - for (const auto& kv : udchain.to_users) { + for (const auto& [var, users] : udchain.to_users) { Array uses{}; - uses.reserve(kv.second.size()); - for (const auto& v : kv.second) { - if (nullptr == v && - fn_outs.end() == std::find(fn_outs.begin(), fn_outs.end(), GetRef(kv.first))) { - fn_outs.push_back(GetRef(kv.first)); + uses.reserve(users.size()); + for (const auto& v : users) { + if (v == nullptr && + std::find(fn_outs.begin(), fn_outs.end(), GetRef(var)) == fn_outs.end()) { + fn_outs.push_back(GetRef(var)); } else { uses.push_back(GetRef(v)); } } - user_map.Set(GetRef(kv.first), std::move(uses)); + user_map.Set(GetRef(var), std::move(uses)); } return std::make_pair(std::move(user_map), std::move(fn_outs)); } runtime::Map> DataflowBlockUseDef(const DataflowBlock& dfb) { UDChain udchain; - udchain.VisitBindingBlock_(dfb.get()); + udchain.VisitBindingBlock(dfb); runtime::Map> ret; - for (const auto& kv : udchain.to_users) { + for (const auto& [var, users] : udchain.to_users) { Array uses{}; - uses.reserve(kv.second.size()); - for (const auto& v : kv.second) uses.push_back(GetRef(v)); - ret.Set(GetRef(kv.first), std::move(uses)); + uses.reserve(users.size()); + for (const auto& v : users) uses.push_back(GetRef(v)); + ret.Set(GetRef(var), std::move(uses)); } return ret; } diff --git a/src/relax/ir/dataflow_matcher.cc b/src/relax/ir/dataflow_matcher.cc index d1b7a1d62cda..88381d6e26d9 100644 --- a/src/relax/ir/dataflow_matcher.cc +++ b/src/relax/ir/dataflow_matcher.cc @@ -33,7 +33,7 @@ #include #include #include -#include +#include #include #include #include @@ -540,33 +540,40 @@ struct RNode { /** * \brief This method try to match a real node and a pattern node along with its neighbors. */ -static bool try_match(PNode* p, RNode* r, DFPatternMatcher* m, - const std::map>& def2use, - const std::map>& use2def) { - if (p->matched != nullptr && p->matched == r->ptr) return true; // matched before. - if (!m->Match(GetRef(p->ptr), GetRef(r->ptr))) return false; +using UndoItems = std::vector>; +static std::optional try_match( + PNode* p, RNode* r, DFPatternMatcher* m, + const std::map>& def2use, + const std::map>& use2def) { + if (p->matched != nullptr && p->matched == r->ptr) return {}; // matched before. + if (!m->Match(GetRef(p->ptr), GetRef(r->ptr))) return std::nullopt; - std::stack> undo_stack{}; + UndoItems undo; - const auto commit = [&undo_stack](PNode* p, RNode* r) { + const auto commit = [&undo](PNode* p, RNode* r) { // match with each other. // TODO(ganler, masahi): Why commit on the same p-r pair happens more than once? if (p->ptr == r->matched) { ICHECK_EQ(p->matched, r->ptr); return; } - ICHECK(r->matched == nullptr); p->matched = r->ptr; r->matched = p->ptr; - undo_stack.emplace(p, r); + undo.emplace_back(p, r); }; - const auto quit = [&undo_stack] { - while (!undo_stack.empty()) { - auto& top = undo_stack.top(); - top.first->matched = nullptr; - top.second->matched = nullptr; - undo_stack.pop(); + const auto quit = [&undo] { + for (auto& [p_node, r_node] : undo) { + p_node->matched = nullptr; + r_node->matched = nullptr; + } + return std::nullopt; + }; + + const auto try_match_update_undo = [&](PNode* p, RNode* r) { + if (auto undo_more = try_match(p, r, m, def2use, use2def)) { + undo.insert(undo.end(), undo_more->begin(), undo_more->end()); + return true; } return false; }; @@ -604,7 +611,7 @@ static bool try_match(PNode* p, RNode* r, DFPatternMatcher* m, // try all parent R nodes that are not matched yet. // as long as ppattern can match one node. - if (!pparent->matched && try_match(pparent, rparent, m, def2use, use2def)) { + if (!pparent->matched && try_match_update_undo(pparent, rparent)) { commit(pparent, rparent); break; } @@ -639,14 +646,14 @@ static bool try_match(PNode* p, RNode* r, DFPatternMatcher* m, if (!all_cons_pass) continue; any_cons_sat = true; - if (!pchild->matched && try_match(pchild, rchild, m, def2use, use2def)) { + if (!pchild->matched && try_match_update_undo(pchild, rchild)) { commit(pchild, rchild); break; } } if (!pchild->matched || !any_cons_sat) return quit(); } - return true; + return undo; } class MatcherUseDefAnalysis : public relax::ExprVisitor { @@ -686,13 +693,12 @@ class MatcherUseDefAnalysis : public relax::ExprVisitor { } }; -Map MatchGraph(const PatternContext& ctx, const DataflowBlock& dfb, - Optional start_hint, bool must_include_hint) { +Optional> MatchGraph(const PatternContext& ctx, const DataflowBlock& dfb, + Optional start_hint, bool must_include_hint) { if (ctx->src_ordered.size() == 0) { - return {}; + return NullOpt; } - Map ret; // TODO(@ganler): Handle non-may external use. ICHECK(ctx->allow_extern_use == PatternContextNode::kMay) << "Only kMay is supported yet."; ICHECK(!must_include_hint || start_hint.defined()) @@ -737,12 +743,15 @@ Map MatchGraph(const PatternContext& ctx, const DataflowBlock& d } } + Map ret; + if (start_hint) { auto rnode_ptr = var2node.at(start_hint.value().get()); for (auto& p_node : pattern2node) { if (try_match(&p_node.second, &rnode_ptr, &matcher, def2use, caller2callees)) { - for (const auto& [df_pattern, pattern_node] : pattern2node) + for (const auto& [df_pattern, pattern_node] : pattern2node) { ret.Set(GetRef(df_pattern), GetRef(pattern_node.matched)); + } return ret; } } @@ -757,15 +766,15 @@ Map MatchGraph(const PatternContext& ctx, const DataflowBlock& d if (start_hint.defined() && start_hint.value().get() == var) continue; RNode& r_node = var2node[var]; if (try_match(&pnode_start, &r_node, &matcher, def2use, caller2callees)) { - for (const auto& [df_pattern, pattern_node] : pattern2node) + for (const auto& [df_pattern, pattern_node] : pattern2node) { ret.Set(GetRef(df_pattern), GetRef(pattern_node.matched)); - + } return ret; } } } - return ret; + return NullOpt; } TVM_REGISTER_GLOBAL("relax.dpl.match_dfb").set_body_typed(MatchGraph); diff --git a/tests/python/relax/test_dataflow_pattern.py b/tests/python/relax/test_dataflow_pattern.py index 76bce47f7ff7..f18244096ec2 100644 --- a/tests/python/relax/test_dataflow_pattern.py +++ b/tests/python/relax/test_dataflow_pattern.py @@ -1042,5 +1042,38 @@ def main( assert out[V_weight_pat].name_hint == "w2" +def test_attention_fake_qkv(): + @tvm.script.ir_module + class QKV_proj: + @R.function + def main( + x1: R.Tensor((2, 1024, 640), "float32"), + x2: R.Tensor((2, 1024, 640), "float32"), + w0: R.Tensor((640, 640), "float32"), + w1: R.Tensor((640, 640), "float32"), + w2: R.Tensor((640, 640), "float32"), + ) -> R.Tensor: + with R.dataflow(): + lv0 = R.matmul(x1, w0) + lv1 = R.matmul(x2, w1) + lv2 = R.matmul(x2, w2) + out = (lv0, lv1, lv2) + R.output(out) + return out + + with PatternContext() as ctx: + inp_pat = wildcard() + Q_weight_pat = wildcard() + K_weight_pat = wildcard() + V_weight_pat = wildcard() + + matmul1 = is_op("relax.matmul")(inp_pat, Q_weight_pat) + matmul2 = is_op("relax.matmul")(inp_pat, K_weight_pat) + matmul3 = is_op("relax.matmul")(inp_pat, V_weight_pat) + + dfb = QKV_proj["main"].body.blocks[0] + assert ctx.match_dfb(dfb) is None + + if __name__ == "__main__": tvm.testing.main()