-
Notifications
You must be signed in to change notification settings - Fork 3.8k
[Unity] Pattern-based rewriting for dataflow block #14446
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -19,3 +19,4 @@ | |
|
|
||
| from .pattern import * | ||
| from .context import * | ||
| from .rewrite import rewrite_call, rewrite_bindings | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,115 @@ | ||
| # Licensed to the Apache Software Foundation (ASF) under one | ||
| # or more contributor license agreements. See the NOTICE file | ||
| # distributed with this work for additional information | ||
| # regarding copyright ownership. The ASF licenses this file | ||
| # to you under the Apache License, Version 2.0 (the | ||
| # "License"); you may not use this file except in compliance | ||
| # with the License. You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, | ||
| # software distributed under the License is distributed on an | ||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
| # KIND, either express or implied. See the License for the | ||
| # specific language governing permissions and limitations | ||
| # under the License. | ||
| """APIs for pattern-based rewriting.""" | ||
| from typing import Dict, Callable | ||
| from .pattern import DFPattern | ||
| from .context import PatternContext | ||
|
|
||
| from ..expr import Expr, Function, Var | ||
| from . import _ffi as ffi | ||
|
|
||
|
|
||
| def rewrite_call( | ||
| pattern: DFPattern, rewriter: Callable[[Expr, Dict[DFPattern, Expr]], Expr], func: Function | ||
| ) -> Function: | ||
| """ | ||
| Rewrite a function with the given pattern and the rewriter function. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| pattern: DFPattern | ||
| The pattern to match. | ||
|
|
||
| rewriter: Callable[[Expr, Dict[DFPattern, Expr]], Expr] | ||
| The function to be called on a successful matching for rewriting. Given the matched | ||
| call node and the map of patterns and matched expressions, it should return a new call node | ||
| to replace the original one or the original matched call node as is. | ||
|
|
||
| For example, to replace x + x with 2 * x, we can write the rewriter as follows: | ||
| ``` | ||
| x = wildcard() | ||
| pattern = is_op("relax.add")(x, x) | ||
|
|
||
| def rewriter(orig, matchings): | ||
| return R.multiply(matchings[x], R.const(2, "float32")) | ||
| ``` | ||
|
|
||
| func: Function | ||
| The function to rewrite. | ||
|
|
||
| Returns | ||
| ------- | ||
| rewritten_func: Function | ||
| The rewritten or the input function, depending on the pattern matching result. | ||
| """ | ||
| return ffi.rewrite_call(pattern, rewriter, func) | ||
|
|
||
|
|
||
| def rewrite_bindings( | ||
| ctx: PatternContext, rewriter: Callable[[Dict[DFPattern, Var]], Dict[Var, Expr]], func: Function | ||
| ) -> Function: | ||
| """ | ||
| Rewrite a function with the given pattern and the rewriter function. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| ctx: PatternContext | ||
| The pattern constraint context under which rewriting takes place. | ||
|
|
||
| rewriter: Callable[[Dict[DFPattern, Var]], Dict[Var, Expr]] | ||
| The function to be called on a successful matching for rewriting. Given the map of patterns | ||
| and corresponding variables (bound variables or parameters), it should return a map that | ||
| specifies new values for matched bound variables. | ||
|
|
||
| For example, to rewrite three matmuls for QKV projection in transformer models into one | ||
| matmul followed by slicing, one can use the follwoing rewriter: | ||
| ``` | ||
| inp_pat = wildcard() | ||
| Q_weight_pat, K_weight_pat, V_weight_pat = wildcard(), wildcard(), wildcard() | ||
|
|
||
| matmul1 = is_op("relax.matmul")(inp_pat, Q_weight_pat) | ||
| matmul2 = is_op("relax.matmul")(inp_pat, K_weight_pat) | ||
| matmul3 = is_op("relax.matmul")(inp_pat, V_weight_pat) | ||
|
|
||
| def rewriter(matchings): | ||
| inp = matchings[inp_pat] | ||
| Q_weight = matchings[Q_weight_pat] | ||
| K_weight = matchings[K_weight_pat] | ||
| V_weight = matchings[V_weight_pat] | ||
| width = Q_weight.struct_info.shape[1] | ||
|
|
||
| concat = R.concat([Q_weight, K_weight, V_weight], axis=1) | ||
| matmul = R.matmul(inp, concat) | ||
| Q = R.strided_slice(matmul, axes=[2], begin=[0], end=[width]) | ||
| K = R.strided_slice(matmul, axes=[2], begin=[width], end=[width * 2]) | ||
| V = R.strided_slice(matmul, axes=[2], begin=[width * 2], end=[width * 3]) | ||
|
|
||
| # matchings[matmul1] gives the bound variable in the binding whose RHS matches with | ||
| # the matmul1 pattern. For example, lv0 in lv0 = R.matmul(x1, w0). | ||
| # We want to replace the RHS of this binding with Q. | ||
| return {matchings[matmul1]: Q, matchings[matmul2]: K, matchings[matmul3]: V} | ||
| ``` | ||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I hope the API for the rewriter makes sense and the usage is intuitive. It took me a while to workout this interface together with how the rewriting mutator should be implemented in cc @ganler |
||
|
|
||
| func: Function | ||
| The function to rewrite. | ||
|
|
||
| Returns | ||
| ------- | ||
| rewritten_func: Function | ||
| The rewritten or the input function, depending on the pattern matching result. | ||
| """ | ||
| return ffi.rewrite_bindings(ctx, rewriter, func) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -780,18 +780,29 @@ Optional<Map<DFPattern, Var>> MatchGraph(const PatternContext& ctx, const Datafl | |
| TVM_REGISTER_GLOBAL("relax.dpl.match_dfb").set_body_typed(MatchGraph); | ||
|
|
||
| /*! | ||
| * \brief Apply pattern matching to each call node and replace matching ones with the output of | ||
| * a user-provided rewriter function. | ||
| * \brief Apply pattern matching to each call node and dataflow block, and replace matching ones | ||
| * with the output of a user-provided rewriter function. | ||
| */ | ||
| class PatternRewriter : ExprMutator { | ||
| public: | ||
| using ExprMutator::VisitBindingBlock_; | ||
| using ExprMutator::VisitExpr_; | ||
|
|
||
| PatternRewriter(DFPattern pat, PackedFunc rewriter_func) | ||
| : pattern_(pat), rewriter_func_(rewriter_func) {} | ||
| PatternRewriter(DFPattern pat, PackedFunc rewriter_func, | ||
| const std::unordered_set<const VarNode*>& params) | ||
| : pattern_(pat), rewriter_func_(rewriter_func), params_(params) {} | ||
|
|
||
| static Expr Run(DFPattern pat, PackedFunc rewriter_func, Function f) { | ||
| PatternRewriter rewriter(pat, rewriter_func); | ||
| PatternRewriter(const PatternContext& ctx, PackedFunc rewriter_func, | ||
| const std::unordered_set<const VarNode*>& params) | ||
| : ctx_(ctx), rewriter_func_(rewriter_func), params_(params) {} | ||
|
|
||
| template <typename PatternType> | ||
| static Expr Run(PatternType pat, PackedFunc rewriter_func, Function f) { | ||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Depending on the passed pattern type ( |
||
| std::unordered_set<const VarNode*> params; | ||
| for (const auto& p : f->params) { | ||
| params.insert(p.get()); | ||
| } | ||
| PatternRewriter rewriter(pat, rewriter_func, params); | ||
| return RemoveAllUnused(Downcast<Function>(rewriter.VisitExpr(f))); | ||
| } | ||
|
|
||
|
|
@@ -807,25 +818,112 @@ class PatternRewriter : ExprMutator { | |
|
|
||
| Expr VisitExpr_(const CallNode* call_node) final { | ||
| auto call = ExprMutator::VisitExpr_(call_node); | ||
| if (auto matches_opt = ExtractMatchedExpr(pattern_, call, bindings_)) { | ||
| if (!pattern_) { | ||
| return call; | ||
| } else if (auto matches_opt = ExtractMatchedExpr(pattern_.value(), call, bindings_)) { | ||
| auto rewriten_expr = rewriter_func_(call, matches_opt.value()); | ||
| memo_[call_node] = rewriten_expr; | ||
| return rewriten_expr; | ||
| } | ||
| return call; | ||
| } | ||
|
|
||
| BindingBlock VisitBindingBlock_(const DataflowBlockNode* block_node) final { | ||
| if (!ctx_) { | ||
| return ExprMutator::VisitBindingBlock_(block_node); | ||
| } | ||
| return RewriteDataflowBlockFixedPoint(GetRef<DataflowBlock>(block_node)); | ||
| } | ||
|
|
||
| private: | ||
| DFPattern pattern_; | ||
| void EmitUsedVars(Expr val, const Array<Binding>& pending_bindings, | ||
| std::unordered_set<const VarNode*>* emitted_vars) { | ||
| std::unordered_set<const VarNode*> unemitted_vars; | ||
| PostOrderVisit(val, [=, &unemitted_vars](Expr e) { | ||
| if (auto v = e.as<VarNode>(); v && !emitted_vars->count(v)) { | ||
| unemitted_vars.insert(v); | ||
| } | ||
| }); | ||
|
|
||
| if (unemitted_vars.empty()) { | ||
| return; | ||
| } | ||
|
|
||
| size_t num_unemitted = unemitted_vars.size(); | ||
| for (size_t i = 0; i < pending_bindings.size(); ++i) { | ||
| const auto& binding = pending_bindings[i]; | ||
| if (auto var_bind = binding.as<VarBindingNode>(); | ||
| var_bind && unemitted_vars.count(var_bind->var.get())) { | ||
| // var_bind->value may also depend on other unemitted vars in this range | ||
| Array<Binding> prev_bindings(pending_bindings.begin(), pending_bindings.begin() + i); | ||
| EmitUsedVars(var_bind->value, prev_bindings, emitted_vars); | ||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I want to get rid of this recursive call and make sure we traverse we need to emit all I think we can use |
||
| this->VisitBinding(binding); | ||
| emitted_vars->insert(var_bind->var.get()); | ||
| if (--num_unemitted == 0) { | ||
| return; | ||
| } | ||
| } | ||
| } | ||
| } | ||
|
|
||
| // Repeat until all matchable subsets of bindings are rewritten. | ||
| BindingBlock RewriteDataflowBlockFixedPoint(BindingBlock block) { | ||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We need to apply rewriting repeatedly since, for example, the same QKV projection pattern appears a number of times in a single DFB. |
||
| if (auto matches = MatchGraph(ctx_.value(), Downcast<DataflowBlock>(block))) { | ||
| builder_->BeginDataflowBlock(); | ||
| Map<Var, Expr> replacements = rewriter_func_(matches.value()); | ||
|
|
||
| std::unordered_set<const VarNode*> emitted_vars; | ||
|
|
||
| for (size_t i = 0; i < block->bindings.size(); ++i) { | ||
| const auto& binding = block->bindings[i]; | ||
| if (auto var_bind = binding.as<VarBindingNode>()) { | ||
| if (replacements.count(var_bind->var)) { | ||
| auto new_val = replacements[var_bind->var]; | ||
| Array<Binding> pending_bindings(block->bindings.begin() + i + 1, block->bindings.end()); | ||
| // Make sure there is no unbound variable used in the new value before it is emitted | ||
| EmitUsedVars(new_val, pending_bindings, &emitted_vars); | ||
| this->ReEmitBinding(var_bind, builder_->Normalize(new_val)); | ||
| } else if (!emitted_vars.count(var_bind->var.get())) { | ||
| this->VisitBinding(binding); | ||
| emitted_vars.insert(var_bind->var.get()); | ||
| } | ||
| } else { | ||
| this->VisitBinding(binding); | ||
| } | ||
| } | ||
| return RewriteDataflowBlockFixedPoint(builder_->EndBlock()); | ||
| } | ||
| return block; | ||
| } | ||
|
|
||
| /*! \brief The pattern for rewriting call nodes */ | ||
| Optional<DFPattern> pattern_; | ||
| /*! \brief The pattern constraint contexts for rewriting dataflow blocks */ | ||
| Optional<PatternContext> ctx_; | ||
| /*! | ||
| * \brief The user-provided rewriter function. Its signature and semantics are: | ||
| * - (Call, Map<DFPattern, Expr>) -> Call for call node rewriting. Given the matched | ||
| * call node and the map of patterns and matched expressions, it should return a new call node | ||
| * to replace the original one or the original matched call node as is. | ||
| * - Map<DFPattern, Var> -> Map<Var, Expr> for dataflow block rewriting. Given the map of patterns | ||
| * and corresponding variables (bound variables or parameters), it should return a map that | ||
| * specifies new values for matched bound variables. | ||
| */ | ||
| PackedFunc rewriter_func_; | ||
| std::unordered_set<const VarNode*> params_; | ||
| Map<Var, Expr> bindings_; | ||
| std::unordered_map<const Object*, Expr> memo_; | ||
| }; | ||
|
|
||
| TVM_REGISTER_GLOBAL("relax.dpl.rewrite") | ||
| TVM_REGISTER_GLOBAL("relax.dpl.rewrite_call") | ||
| .set_body_typed([](DFPattern pat, PackedFunc rewriter, Function f) { | ||
| return PatternRewriter::Run(pat, rewriter, f); | ||
| }); | ||
|
|
||
| TVM_REGISTER_GLOBAL("relax.dpl.rewrite_bindings") | ||
| .set_body_typed([](const PatternContext& ctx, PackedFunc rewriter, Function f) { | ||
| return PatternRewriter::Run(ctx, rewriter, f); | ||
| }); | ||
|
|
||
| } // namespace relax | ||
| } // namespace tvm | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The existing
rewritefunction is renamed torewrite_callto make it clear that it is CallNode rewriting. And together with the new dataflow block rewriting function, it is put under the new file.