From 65e9808cbe1ecf249a91233f32b948fef098ba7d Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Sat, 16 Mar 2024 09:41:04 -0500 Subject: [PATCH] [Relax] Refactor PatternRewriter into separate Block/Expr mutators Prior to this commit, the `PatternRewriter` mutator handled pattern rewriting at either the expression level (`rewrite_call`) or the dataflow block level (`rewrite_bindings`). These two functionalities had different external APIs, defined diffierent member variables, and visited different IR nodes. In effect, it had two entirely independent implementations, which just happened to be implemented within the same class. This commit refactors the single `PatternRewriter` mutator into separate `BlockPatternRewriter` and `ExprPatternRewriter` mutators. --- include/tvm/relax/dataflow_matcher.h | 4 +- src/relax/ir/dataflow_matcher.cc | 238 +++++++++++++++------------ 2 files changed, 140 insertions(+), 102 deletions(-) diff --git a/include/tvm/relax/dataflow_matcher.h b/include/tvm/relax/dataflow_matcher.h index bbc8e9382ed0..8f2024f26403 100644 --- a/include/tvm/relax/dataflow_matcher.h +++ b/include/tvm/relax/dataflow_matcher.h @@ -67,7 +67,9 @@ TVM_DLL Optional> MatchGraph(const PatternContext& ctx, * \param f The function to rewrite * \return The rewritten or the input function, depending on the pattern matching result. */ -TVM_DLL Function RewriteBindings(const PatternContext& ctx, PackedFunc rewriter, Function f); +TVM_DLL Function RewriteBindings( + const PatternContext& ctx, + TypedPackedFunc(Map, Map)> rewriter, Function f); /** * \brief Rewrite a function with the given pattern and the rewriter function. diff --git a/src/relax/ir/dataflow_matcher.cc b/src/relax/ir/dataflow_matcher.cc index a14d43f6d386..531971d3db5d 100644 --- a/src/relax/ir/dataflow_matcher.cc +++ b/src/relax/ir/dataflow_matcher.cc @@ -973,102 +973,33 @@ TVM_REGISTER_GLOBAL("relax.dpl.match_dfb") }); /*! - * \brief Apply pattern matching to each call node and dataflow block, and replace matching ones + * \brief Apply pattern matching to each dataflow block, replacing matches * with the output of a user-provided rewriter function. */ -class PatternRewriter : ExprMutator { +class BlockPatternRewriter : ExprMutator { public: using ExprMutator::VisitBindingBlock_; using ExprMutator::VisitExpr_; - PatternRewriter(DFPattern pat, PackedFunc rewriter_func, - const std::unordered_set& params) - : pattern_(pat), rewriter_func_(rewriter_func), params_(params) {} - - PatternRewriter(const PatternContext& ctx, PackedFunc rewriter_func, - const std::unordered_set& params) - : ctx_(ctx), rewriter_func_(rewriter_func), params_(params) {} + BlockPatternRewriter( + const PatternContext& ctx, + TypedPackedFunc(Map, Map)> rewriter_func) + : ctx_(ctx), rewriter_func_(rewriter_func) {} template - static Function Run(PatternType pat, PackedFunc rewriter_func, Function f) { - std::unordered_set params; - for (const auto& p : f->params) { - params.insert(p.get()); - } - PatternRewriter rewriter(pat, rewriter_func, params); - return Downcast(RemoveAllUnused(rewriter.VisitExpr(f))); - } - - Expr VisitExpr_(const SeqExprNode* seq) override { - if (ctx_) { - return ExprMutator::VisitExpr_(seq); - } - - auto cache = bindings_; - SeqExpr prev = GetRef(seq); - - StructuralEqual struct_equal; - - while (true) { - SeqExpr next = Downcast(builder_->Normalize(ExprMutator::VisitExpr_(prev.get()))); - if (struct_equal(prev, next)) { - return std::move(next); - } - - // Canonicalization may result in two previously-different - // expressions being recognized as identical. Elimination of - // common subexpressions may result in trival var-to-var - // bindings that can be canonicalized. Therefore, iterate the - // simplification steps until converged. - while (true) { - auto start_of_loop = next; - next = Downcast(CanonicalizeBindings(next)); - next = Downcast(EliminateCommonSubexpr(next)); - next = Downcast(RemoveAllUnused(next)); - if (struct_equal(start_of_loop, next)) { - break; - } - } - - if (struct_equal(prev, next)) { - return std::move(next); - } - - // Reset all knowledge of bindings that were collected from - // this DataflowBlock. The collected bindings are only after - // the point where they were collected, and we are repeating - // the mutation of this DataflowBlock. - bindings_ = cache; - prev = next; - } + static Function Run( + PatternType pat, + TypedPackedFunc(Map, Map)> rewriter_func, + Function func) { + BlockPatternRewriter rewriter(pat, rewriter_func); + + func = Downcast(rewriter(func)); + func = Downcast(RemoveAllUnused(func)); + return func; } BindingBlock VisitBindingBlock_(const DataflowBlockNode* block_node) override { - if (ctx_) { - return RewriteDataflowBlockFixedPoint(GetRef(block_node)); - } else { - return ExprMutator::VisitBindingBlock_(block_node); - } - } - - void VisitBinding_(const VarBindingNode* binding) override { - auto expr = VisitExpr(binding->value); - bindings_.Set(binding->var, expr); - ReEmitBinding(binding, expr); - } - - Expr VisitExpr(const Expr& expr) override { - auto node = ExprMutator::VisitExpr(expr); - - if (pattern_) { - if (auto matches_opt = ExtractMatchedExpr(pattern_.value(), node, bindings_)) { - Expr rewritten_expr = rewriter_func_(node, matches_opt.value()); - if (!rewritten_expr.same_as(node)) { - return builder_->Normalize(rewritten_expr); - } - } - } - return node; + return RewriteDataflowBlockFixedPoint(GetRef(block_node)); } private: @@ -1106,7 +1037,7 @@ class PatternRewriter : ExprMutator { BindingBlock RewriteDataflowBlockFixedPoint(BindingBlock block) { auto df_block = Downcast(block); Map bindings = AnalyzeVar2Value(df_block); - if (auto matches = MatchGraph(ctx_.value(), df_block, bindings)) { + if (auto matches = MatchGraph(ctx_, df_block, bindings)) { builder_->BeginDataflowBlock(); Map replacements = rewriter_func_(matches.value(), bindings); @@ -1140,34 +1071,139 @@ class PatternRewriter : ExprMutator { return block; } - /*! \brief The pattern for rewriting call nodes */ - Optional pattern_; /*! \brief The pattern constraint contexts for rewriting dataflow blocks */ - Optional ctx_; + PatternContext ctx_; /*! * \brief The user-provided rewriter function. Its signature and semantics are: - * - (Call, Map) -> Call for call node rewriting. Given the matched - * call node and the map of patterns and matched expressions, it should return a new call node - * to replace the original one or the original matched call node as is. - * - (Map, Map) -> Map for dataflow block rewriting. - * Given the map of patterns and corresponding variables (bound variables or parameters), - * it should return a map that specifies new values for matched bound variables. It can refer + * + * - (Map, Map) -> Map + * + * Given the map of patterns and corresponding variables (bound + * variables or parameters), it should return a map that + * specifies new values for matched bound variables. It can refer * to the passed bindings to create the replacement expressions. */ - PackedFunc rewriter_func_; - std::unordered_set params_; + TypedPackedFunc(Map, Map)> rewriter_func_; +}; + +/*! + * \brief Apply pattern matching to each expression, replacing + * matches with the output of a user-provided rewriter function. + */ +class ExprPatternRewriter : ExprMutator { + public: + using ExprMutator::VisitBindingBlock_; + using ExprMutator::VisitExpr_; + + ExprPatternRewriter(DFPattern pat, + TypedPackedFunc)> rewriter_func) + : pattern_(pat), rewriter_func_(rewriter_func) {} + + template + static Function Run(PatternType pat, + TypedPackedFunc)> rewriter_func, + Function func) { + ExprPatternRewriter rewriter(pat, rewriter_func); + func = Downcast(rewriter(func)); + func = Downcast(RemoveAllUnused(func)); + return func; + } + + Expr VisitExpr_(const SeqExprNode* seq) override { + auto cache = bindings_; + SeqExpr prev = GetRef(seq); + + StructuralEqual struct_equal; + + while (true) { + SeqExpr next = Downcast(builder_->Normalize(ExprMutator::VisitExpr_(prev.get()))); + if (struct_equal(prev, next)) { + return std::move(next); + } + + // Canonicalization may result in two previously-different + // expressions being recognized as identical. Elimination of + // common subexpressions may result in trival var-to-var + // bindings that can be canonicalized. Therefore, iterate the + // simplification steps until converged. + while (true) { + auto start_of_loop = next; + next = Downcast(CanonicalizeBindings(next)); + next = Downcast(EliminateCommonSubexpr(next)); + next = Downcast(RemoveAllUnused(next)); + if (struct_equal(start_of_loop, next)) { + break; + } + } + + if (struct_equal(prev, next)) { + return std::move(next); + } + + // Reset all knowledge of bindings that were collected from + // this SeqExpr. The collected bindings are only after + // the point where they were collected, and we are repeating + // the mutation of this SeqExpr. + bindings_ = cache; + prev = next; + } + } + + void VisitBinding_(const VarBindingNode* binding) override { + auto expr = VisitExpr(binding->value); + bindings_.Set(binding->var, expr); + ReEmitBinding(binding, expr); + } + + Expr VisitExpr(const Expr& expr) override { + auto node = ExprMutator::VisitExpr(expr); + + if (auto matches_opt = ExtractMatchedExpr(pattern_, node, bindings_)) { + Expr rewritten_expr = rewriter_func_(node, matches_opt.value()); + if (!rewritten_expr.same_as(node)) { + return builder_->Normalize(rewritten_expr); + } + } + + return node; + } + + private: + /*! \brief The pattern for rewriting call nodes */ + DFPattern pattern_; + /*! + * \brief The user-provided rewriter function. Its signature and semantics are: + * + * - (Call, Map) -> Call + * + * Given the matched call node and the map of patterns and + * matched expressions, it should return a new call node to + * replace the original one or the original matched call node as + * is. + */ + TypedPackedFunc)> rewriter_func_; + + /*! \brief The known variable bindings + * + * The variable bindings whose value is known. This must be tracked + * separately from the block builder, so that it can be reset after + * each iteration of the mutate-until-converged loop applied to + * `SeqExpr`. + */ Map bindings_; }; -Function RewriteBindings(const PatternContext& ctx, PackedFunc rewriter, Function f) { - return PatternRewriter::Run(ctx, rewriter, f); +Function RewriteBindings( + const PatternContext& ctx, + TypedPackedFunc(Map, Map)> rewriter, Function func) { + return BlockPatternRewriter::Run(ctx, rewriter, func); } TVM_REGISTER_GLOBAL("relax.dpl.rewrite_bindings").set_body_typed(RewriteBindings); Function RewriteCall(const DFPattern& pat, - TypedPackedFunc)> rewriter, Function f) { - return PatternRewriter::Run(pat, rewriter, f); + TypedPackedFunc)> rewriter, Function func) { + return ExprPatternRewriter::Run(pat, rewriter, func); } TVM_REGISTER_GLOBAL("relax.dpl.rewrite_call").set_body_typed(RewriteCall);