diff --git a/include/tvm/relax/dataflow_matcher.h b/include/tvm/relax/dataflow_matcher.h index bbc8e9382ed0..8f2024f26403 100644 --- a/include/tvm/relax/dataflow_matcher.h +++ b/include/tvm/relax/dataflow_matcher.h @@ -67,7 +67,9 @@ TVM_DLL Optional> MatchGraph(const PatternContext& ctx, * \param f The function to rewrite * \return The rewritten or the input function, depending on the pattern matching result. */ -TVM_DLL Function RewriteBindings(const PatternContext& ctx, PackedFunc rewriter, Function f); +TVM_DLL Function RewriteBindings( + const PatternContext& ctx, + TypedPackedFunc(Map, Map)> rewriter, Function f); /** * \brief Rewrite a function with the given pattern and the rewriter function. diff --git a/src/relax/ir/dataflow_matcher.cc b/src/relax/ir/dataflow_matcher.cc index a14d43f6d386..531971d3db5d 100644 --- a/src/relax/ir/dataflow_matcher.cc +++ b/src/relax/ir/dataflow_matcher.cc @@ -973,102 +973,33 @@ TVM_REGISTER_GLOBAL("relax.dpl.match_dfb") }); /*! - * \brief Apply pattern matching to each call node and dataflow block, and replace matching ones + * \brief Apply pattern matching to each dataflow block, replacing matches * with the output of a user-provided rewriter function. */ -class PatternRewriter : ExprMutator { +class BlockPatternRewriter : ExprMutator { public: using ExprMutator::VisitBindingBlock_; using ExprMutator::VisitExpr_; - PatternRewriter(DFPattern pat, PackedFunc rewriter_func, - const std::unordered_set& params) - : pattern_(pat), rewriter_func_(rewriter_func), params_(params) {} - - PatternRewriter(const PatternContext& ctx, PackedFunc rewriter_func, - const std::unordered_set& params) - : ctx_(ctx), rewriter_func_(rewriter_func), params_(params) {} + BlockPatternRewriter( + const PatternContext& ctx, + TypedPackedFunc(Map, Map)> rewriter_func) + : ctx_(ctx), rewriter_func_(rewriter_func) {} template - static Function Run(PatternType pat, PackedFunc rewriter_func, Function f) { - std::unordered_set params; - for (const auto& p : f->params) { - params.insert(p.get()); - } - PatternRewriter rewriter(pat, rewriter_func, params); - return Downcast(RemoveAllUnused(rewriter.VisitExpr(f))); - } - - Expr VisitExpr_(const SeqExprNode* seq) override { - if (ctx_) { - return ExprMutator::VisitExpr_(seq); - } - - auto cache = bindings_; - SeqExpr prev = GetRef(seq); - - StructuralEqual struct_equal; - - while (true) { - SeqExpr next = Downcast(builder_->Normalize(ExprMutator::VisitExpr_(prev.get()))); - if (struct_equal(prev, next)) { - return std::move(next); - } - - // Canonicalization may result in two previously-different - // expressions being recognized as identical. Elimination of - // common subexpressions may result in trival var-to-var - // bindings that can be canonicalized. Therefore, iterate the - // simplification steps until converged. - while (true) { - auto start_of_loop = next; - next = Downcast(CanonicalizeBindings(next)); - next = Downcast(EliminateCommonSubexpr(next)); - next = Downcast(RemoveAllUnused(next)); - if (struct_equal(start_of_loop, next)) { - break; - } - } - - if (struct_equal(prev, next)) { - return std::move(next); - } - - // Reset all knowledge of bindings that were collected from - // this DataflowBlock. The collected bindings are only after - // the point where they were collected, and we are repeating - // the mutation of this DataflowBlock. - bindings_ = cache; - prev = next; - } + static Function Run( + PatternType pat, + TypedPackedFunc(Map, Map)> rewriter_func, + Function func) { + BlockPatternRewriter rewriter(pat, rewriter_func); + + func = Downcast(rewriter(func)); + func = Downcast(RemoveAllUnused(func)); + return func; } BindingBlock VisitBindingBlock_(const DataflowBlockNode* block_node) override { - if (ctx_) { - return RewriteDataflowBlockFixedPoint(GetRef(block_node)); - } else { - return ExprMutator::VisitBindingBlock_(block_node); - } - } - - void VisitBinding_(const VarBindingNode* binding) override { - auto expr = VisitExpr(binding->value); - bindings_.Set(binding->var, expr); - ReEmitBinding(binding, expr); - } - - Expr VisitExpr(const Expr& expr) override { - auto node = ExprMutator::VisitExpr(expr); - - if (pattern_) { - if (auto matches_opt = ExtractMatchedExpr(pattern_.value(), node, bindings_)) { - Expr rewritten_expr = rewriter_func_(node, matches_opt.value()); - if (!rewritten_expr.same_as(node)) { - return builder_->Normalize(rewritten_expr); - } - } - } - return node; + return RewriteDataflowBlockFixedPoint(GetRef(block_node)); } private: @@ -1106,7 +1037,7 @@ class PatternRewriter : ExprMutator { BindingBlock RewriteDataflowBlockFixedPoint(BindingBlock block) { auto df_block = Downcast(block); Map bindings = AnalyzeVar2Value(df_block); - if (auto matches = MatchGraph(ctx_.value(), df_block, bindings)) { + if (auto matches = MatchGraph(ctx_, df_block, bindings)) { builder_->BeginDataflowBlock(); Map replacements = rewriter_func_(matches.value(), bindings); @@ -1140,34 +1071,139 @@ class PatternRewriter : ExprMutator { return block; } - /*! \brief The pattern for rewriting call nodes */ - Optional pattern_; /*! \brief The pattern constraint contexts for rewriting dataflow blocks */ - Optional ctx_; + PatternContext ctx_; /*! * \brief The user-provided rewriter function. Its signature and semantics are: - * - (Call, Map) -> Call for call node rewriting. Given the matched - * call node and the map of patterns and matched expressions, it should return a new call node - * to replace the original one or the original matched call node as is. - * - (Map, Map) -> Map for dataflow block rewriting. - * Given the map of patterns and corresponding variables (bound variables or parameters), - * it should return a map that specifies new values for matched bound variables. It can refer + * + * - (Map, Map) -> Map + * + * Given the map of patterns and corresponding variables (bound + * variables or parameters), it should return a map that + * specifies new values for matched bound variables. It can refer * to the passed bindings to create the replacement expressions. */ - PackedFunc rewriter_func_; - std::unordered_set params_; + TypedPackedFunc(Map, Map)> rewriter_func_; +}; + +/*! + * \brief Apply pattern matching to each expression, replacing + * matches with the output of a user-provided rewriter function. + */ +class ExprPatternRewriter : ExprMutator { + public: + using ExprMutator::VisitBindingBlock_; + using ExprMutator::VisitExpr_; + + ExprPatternRewriter(DFPattern pat, + TypedPackedFunc)> rewriter_func) + : pattern_(pat), rewriter_func_(rewriter_func) {} + + template + static Function Run(PatternType pat, + TypedPackedFunc)> rewriter_func, + Function func) { + ExprPatternRewriter rewriter(pat, rewriter_func); + func = Downcast(rewriter(func)); + func = Downcast(RemoveAllUnused(func)); + return func; + } + + Expr VisitExpr_(const SeqExprNode* seq) override { + auto cache = bindings_; + SeqExpr prev = GetRef(seq); + + StructuralEqual struct_equal; + + while (true) { + SeqExpr next = Downcast(builder_->Normalize(ExprMutator::VisitExpr_(prev.get()))); + if (struct_equal(prev, next)) { + return std::move(next); + } + + // Canonicalization may result in two previously-different + // expressions being recognized as identical. Elimination of + // common subexpressions may result in trival var-to-var + // bindings that can be canonicalized. Therefore, iterate the + // simplification steps until converged. + while (true) { + auto start_of_loop = next; + next = Downcast(CanonicalizeBindings(next)); + next = Downcast(EliminateCommonSubexpr(next)); + next = Downcast(RemoveAllUnused(next)); + if (struct_equal(start_of_loop, next)) { + break; + } + } + + if (struct_equal(prev, next)) { + return std::move(next); + } + + // Reset all knowledge of bindings that were collected from + // this SeqExpr. The collected bindings are only after + // the point where they were collected, and we are repeating + // the mutation of this SeqExpr. + bindings_ = cache; + prev = next; + } + } + + void VisitBinding_(const VarBindingNode* binding) override { + auto expr = VisitExpr(binding->value); + bindings_.Set(binding->var, expr); + ReEmitBinding(binding, expr); + } + + Expr VisitExpr(const Expr& expr) override { + auto node = ExprMutator::VisitExpr(expr); + + if (auto matches_opt = ExtractMatchedExpr(pattern_, node, bindings_)) { + Expr rewritten_expr = rewriter_func_(node, matches_opt.value()); + if (!rewritten_expr.same_as(node)) { + return builder_->Normalize(rewritten_expr); + } + } + + return node; + } + + private: + /*! \brief The pattern for rewriting call nodes */ + DFPattern pattern_; + /*! + * \brief The user-provided rewriter function. Its signature and semantics are: + * + * - (Call, Map) -> Call + * + * Given the matched call node and the map of patterns and + * matched expressions, it should return a new call node to + * replace the original one or the original matched call node as + * is. + */ + TypedPackedFunc)> rewriter_func_; + + /*! \brief The known variable bindings + * + * The variable bindings whose value is known. This must be tracked + * separately from the block builder, so that it can be reset after + * each iteration of the mutate-until-converged loop applied to + * `SeqExpr`. + */ Map bindings_; }; -Function RewriteBindings(const PatternContext& ctx, PackedFunc rewriter, Function f) { - return PatternRewriter::Run(ctx, rewriter, f); +Function RewriteBindings( + const PatternContext& ctx, + TypedPackedFunc(Map, Map)> rewriter, Function func) { + return BlockPatternRewriter::Run(ctx, rewriter, func); } TVM_REGISTER_GLOBAL("relax.dpl.rewrite_bindings").set_body_typed(RewriteBindings); Function RewriteCall(const DFPattern& pat, - TypedPackedFunc)> rewriter, Function f) { - return PatternRewriter::Run(pat, rewriter, f); + TypedPackedFunc)> rewriter, Function func) { + return ExprPatternRewriter::Run(pat, rewriter, func); } TVM_REGISTER_GLOBAL("relax.dpl.rewrite_call").set_body_typed(RewriteCall);