From 92593cd11f74b0e8e872926907d634978148dda5 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sat, 1 Apr 2023 04:13:15 +0900 Subject: [PATCH 1/3] Add pattern-based dataflow block rewriting --- python/tvm/relax/dpl/pattern.py | 34 ++++- src/relax/ir/dataflow_matcher.cc | 113 +++++++++++++-- tests/python/relax/test_dataflow_pattern.py | 153 +++++++++++++++++++- 3 files changed, 285 insertions(+), 15 deletions(-) diff --git a/python/tvm/relax/dpl/pattern.py b/python/tvm/relax/dpl/pattern.py index acabac2dcbf1..3026213ba2f6 100644 --- a/python/tvm/relax/dpl/pattern.py +++ b/python/tvm/relax/dpl/pattern.py @@ -1125,7 +1125,7 @@ def make_fused_bias_activation_pattern(op_name, with_bias=False, activation=None return out -def rewrite( +def rewrite_call( pattern: DFPattern, rewriter: Callable[[Expr, Dict[DFPattern, Expr]], Expr], func: Function ) -> Function: """ @@ -1158,4 +1158,34 @@ def rewriter(orig, matchings): rewritten_func: Function The rewritten or the input function, depending on the pattern matching result. """ - return ffi.rewrite(pattern, rewriter, func) + return ffi.rewrite_call(pattern, rewriter, func) + + +def rewrite_bindings( + ctx, rewriter: Callable[[Expr, Dict[DFPattern, Expr]], Expr], func: Function +) -> Function: + """ + Rewrite a function with the given pattern and the rewriter function. + Parameters + ---------- + pattern: DFPattern + The pattern to match. + rewriter: Callable[[Expr, Dict[DFPattern, Expr]], Expr] + The function to be called on a successful matching for rewriting. Given the matched + call node and the map of patterns and matched expressions, it should return a new call node + to replace the original one or the original matched call node as is. + For example, to replace x + x with 2 * x, we can write the rewriter as follows: + ``` + x = wildcard() + pattern = is_op("relax.add")(x, x) + def rewriter(orig, matchings): + return R.multiply(matchings[x], R.const(2, "float32")) + ``` + func: Function + The function to rewrite. 
+ Returns + ------- + rewritten_func: Function + The rewritten or the input function, depending on the pattern matching result. + """ + return ffi.rewrite_bindings(ctx, rewriter, func) diff --git a/src/relax/ir/dataflow_matcher.cc b/src/relax/ir/dataflow_matcher.cc index 88381d6e26d9..0055929a78b1 100644 --- a/src/relax/ir/dataflow_matcher.cc +++ b/src/relax/ir/dataflow_matcher.cc @@ -780,18 +780,29 @@ Optional> MatchGraph(const PatternContext& ctx, const Datafl TVM_REGISTER_GLOBAL("relax.dpl.match_dfb").set_body_typed(MatchGraph); /*! - * \brief Apply pattern matching to each call node and replace matching ones with the output of - * a user-provided rewriter function. + * \brief Apply pattern matching to each call node and dataflow block, and replace matching ones + * with the output of a user-provided rewriter function. */ class PatternRewriter : ExprMutator { public: + using ExprMutator::VisitBindingBlock_; using ExprMutator::VisitExpr_; - PatternRewriter(DFPattern pat, PackedFunc rewriter_func) - : pattern_(pat), rewriter_func_(rewriter_func) {} + PatternRewriter(DFPattern pat, PackedFunc rewriter_func, + const std::unordered_set& params) + : pattern_(pat), rewriter_func_(rewriter_func), params_(params) {} - static Expr Run(DFPattern pat, PackedFunc rewriter_func, Function f) { - PatternRewriter rewriter(pat, rewriter_func); + PatternRewriter(const PatternContext& ctx, PackedFunc rewriter_func, + const std::unordered_set& params) + : ctx_(ctx), rewriter_func_(rewriter_func), params_(params) {} + + template + static Expr Run(PatternType pat, PackedFunc rewriter_func, Function f) { + std::unordered_set params; + for (const auto& p : f->params) { + params.insert(p.get()); + } + PatternRewriter rewriter(pat, rewriter_func, params); return RemoveAllUnused(Downcast(rewriter.VisitExpr(f))); } @@ -807,7 +818,9 @@ class PatternRewriter : ExprMutator { Expr VisitExpr_(const CallNode* call_node) final { auto call = ExprMutator::VisitExpr_(call_node); - if (auto 
matches_opt = ExtractMatchedExpr(pattern_, call, bindings_)) { + if (!pattern_) { + return call; + } else if (auto matches_opt = ExtractMatchedExpr(pattern_.value(), call, bindings_)) { auto rewriten_expr = rewriter_func_(call, matches_opt.value()); memo_[call_node] = rewriten_expr; return rewriten_expr; @@ -815,17 +828,99 @@ class PatternRewriter : ExprMutator { return call; } + BindingBlock VisitBindingBlock_(const DataflowBlockNode* block_node) final { + if (!ctx_) { + return ExprMutator::VisitBindingBlock_(block_node); + } + return RewriteDataflowBlockFixedPoint(GetRef(block_node)); + } + private: - DFPattern pattern_; + void EmitUsedVars(Expr val, const Array& pending_bindings, + std::unordered_set* emitted_vars) { + std::unordered_set unemitted_vars; + PostOrderVisit(val, [=, &unemitted_vars](Expr e) { + if (auto v = e.as(); v && !emitted_vars->count(v)) { + unemitted_vars.insert(v); + } + }); + + if (unemitted_vars.empty()) { + return; + } + + size_t num_unemitted = unemitted_vars.size(); + for (const auto& binding : pending_bindings) { + if (auto var_bind = binding.as(); + var_bind && unemitted_vars.count(var_bind->var.get())) { + EmitUsedVars(var_bind->value, pending_bindings, emitted_vars); + this->VisitBinding(binding); + emitted_vars->insert(var_bind->var.get()); + if (--num_unemitted == 0) { + return; + } + } + } + } + + // Repeat until all matchable subsets of bindings are rewritten. 
+ BindingBlock RewriteDataflowBlockFixedPoint(BindingBlock block) { + if (auto matches = MatchGraph(ctx_.value(), Downcast(block))) { + builder_->BeginDataflowBlock(); + Map replacements = rewriter_func_(matches.value()); + + std::unordered_set emitted_vars; + + for (size_t i = 0; i < block->bindings.size(); ++i) { + const auto& binding = block->bindings[i]; + if (auto var_bind = binding.as()) { + if (replacements.count(var_bind->var)) { + auto new_val = replacements[var_bind->var]; + Array pending_bindings(block->bindings.begin() + i + 1, block->bindings.end()); + // Make sure there is no unbound variable used in the new value before it is emitted + EmitUsedVars(new_val, pending_bindings, &emitted_vars); + this->ReEmitBinding(var_bind, builder_->Normalize(new_val)); + } else if (!emitted_vars.count(var_bind->var.get())) { + this->VisitBinding(binding); + emitted_vars.insert(var_bind->var.get()); + } + } else { + this->VisitBinding(binding); + } + } + return RewriteDataflowBlockFixedPoint(builder_->EndBlock()); + } + return block; + } + + /*! \brief The pattern for rewriting call nodes */ + Optional pattern_; + /*! \brief The pattern constraint contexts for rewriting dataflow blocks */ + Optional ctx_; + /*! + * \brief The user-provided rewriter function. Its signature and semantics are: + * - (Call, Map) -> Call for call node rewriting. Given the matched + * call node and the map of patterns and matched expressions, it should return a new call node + * to replace the original one or the original matched call node as is. + * - Map -> Map for dataflow block rewriting. Given the map of patterns + * and corresponding variables (bound variables or parameters), it should return a map that + * specifies new values for matched bound variables. 
+ */ PackedFunc rewriter_func_; + std::unordered_set params_; Map bindings_; std::unordered_map memo_; }; -TVM_REGISTER_GLOBAL("relax.dpl.rewrite") +TVM_REGISTER_GLOBAL("relax.dpl.rewrite_call") .set_body_typed([](DFPattern pat, PackedFunc rewriter, Function f) { return PatternRewriter::Run(pat, rewriter, f); }); +TVM_REGISTER_GLOBAL("relax.dpl.rewrite_bindings") + .set_body_typed([](const PatternContext& ctx, PackedFunc rewriter, Function f) { + return PatternRewriter::Run(ctx, rewriter, f); + }); + } // namespace relax } // namespace tvm diff --git a/tests/python/relax/test_dataflow_pattern.py b/tests/python/relax/test_dataflow_pattern.py index f18244096ec2..e4d7f7972c6e 100644 --- a/tests/python/relax/test_dataflow_pattern.py +++ b/tests/python/relax/test_dataflow_pattern.py @@ -918,7 +918,7 @@ def expected2(x: R.Tensor((16, 16), dtype="float32")) -> R.Tensor((16, 16), dtyp def rewriter(_, matchings): return R.multiply(matchings[x], R.const(2, "float32")) - rewritten = rewrite(pattern, rewriter, main) + rewritten = rewrite_call(pattern, rewriter, main) tvm.ir.assert_structural_equal(rewritten, expected1) add1 = is_op("relax.add")(x, x) @@ -927,14 +927,14 @@ def rewriter(_, matchings): def rewriter(_, matchings): return R.multiply(matchings[x], R.const(4, "float32")) - rewritten = rewrite(pattern, rewriter, main) + rewritten = rewrite_call(pattern, rewriter, main) tvm.ir.assert_structural_equal(rewritten, expected2) # No rewriting, return the original call node as is def rewriter(orig, _): return orig - rewritten = rewrite(pattern, rewriter, main) + rewritten = rewrite_call(pattern, rewriter, main) tvm.ir.assert_structural_equal(rewritten, main) @@ -1002,7 +1002,7 @@ def BSH_to_BSNH(tensor): def rewriter(_, matchings): return R.nn.attention(matchings[Q], matchings[K], matchings[V]) - rewritten = rewrite(pattern, rewriter, main) + rewritten = rewrite_call(pattern, rewriter, main) tvm.ir.assert_structural_equal(rewritten, expected) @@ -1075,5 +1075,150 @@ def main( 
assert ctx.match_dfb(dfb) is None +def get_qkv_proj_rewriter( + inp_pat, Q_weight_pat, K_weight_pat, V_weight_pat, matmul1, matmul2, matmul3 +): + def qkv_proj_rewriter(matchings): + inp = matchings[inp_pat] + Q_weight = matchings[Q_weight_pat] + K_weight = matchings[K_weight_pat] + V_weight = matchings[V_weight_pat] + width = Q_weight.struct_info.shape[1] + + concat = R.concat([Q_weight, K_weight, V_weight], axis=1) + matmul = R.matmul(inp, concat) + Q = R.strided_slice(matmul, axes=[2], begin=[0], end=[width]) + K = R.strided_slice(matmul, axes=[2], begin=[width], end=[width * 2]) + V = R.strided_slice(matmul, axes=[2], begin=[width * 2], end=[width * 3]) + + return {matchings[matmul1]: Q, matchings[matmul2]: K, matchings[matmul3]: V} + + return qkv_proj_rewriter + + +def test_combine_matmul_twice(): + @R.function + def qkv_x2( + x1: R.Tensor((2, 1024, 640), "float32"), + x2: R.Tensor((2, 1024, 640), "float32"), + w0: R.Tensor((640, 640), "float32"), + w1: R.Tensor((640, 640), "float32"), + w2: R.Tensor((640, 640), "float32"), + w3: R.Tensor((640, 640), "float32"), + w4: R.Tensor((640, 640), "float32"), + w5: R.Tensor((640, 640), "float32"), + ) -> R.Tensor: + with R.dataflow(): + lv0 = R.matmul(x1, w0) + lv1 = R.matmul(x1, w1) + lv2 = R.matmul(x1, w2) + lv3 = R.matmul(x2, w3) + lv4 = R.matmul(x2, w4) + lv5 = R.matmul(x2, w5) + out = (lv0, lv1, lv2, lv3, lv4, lv5) + R.output(out) + return out + + @R.function + def expected( + x1: R.Tensor((2, 1024, 640), "float32"), + x2: R.Tensor((2, 1024, 640), "float32"), + w0: R.Tensor((640, 640), "float32"), + w1: R.Tensor((640, 640), "float32"), + w2: R.Tensor((640, 640), "float32"), + w3: R.Tensor((640, 640), "float32"), + w4: R.Tensor((640, 640), "float32"), + w5: R.Tensor((640, 640), "float32"), + ) -> R.Tensor: + with R.dataflow(): + lv = R.concat((w0, w1, w2), axis=1) + lv1 = R.matmul(x1, lv) + lv0 = R.strided_slice(lv1, axes=[2], begin=[0], end=[640]) + lv1_1 = R.strided_slice(lv1, axes=[2], begin=[640], end=[1280]) + 
lv2 = R.strided_slice(lv1, axes=[2], begin=[1280], end=[1920]) + lv2_1 = R.concat((w3, w4, w5), axis=1) + lv3 = R.matmul(x2, lv2_1, out_dtype="void") + lv3_1 = R.strided_slice(lv3, axes=[2], begin=[0], end=[640]) + lv4 = R.strided_slice(lv3, axes=[2], begin=[640], end=[1280]) + lv5 = R.strided_slice(lv3, axes=[2], begin=[1280], end=[1920]) + out = lv0, lv1_1, lv2, lv3_1, lv4, lv5 + R.output(out) + return out + + with PatternContext() as ctx: + inp_pat = wildcard() + Q_weight_pat = wildcard() + K_weight_pat = wildcard() + V_weight_pat = wildcard() + + matmul1 = is_op("relax.matmul")(inp_pat, Q_weight_pat) + matmul2 = is_op("relax.matmul")(inp_pat, K_weight_pat) + matmul3 = is_op("relax.matmul")(inp_pat, V_weight_pat) + + rewriter = get_qkv_proj_rewriter( + inp_pat, Q_weight_pat, K_weight_pat, V_weight_pat, matmul1, matmul2, matmul3 + ) + rewritten = rewrite_bindings(ctx, rewriter, qkv_x2) + tvm.ir.assert_structural_equal(rewritten, expected) + + +def test_combine_matmul_emit_order(): + @R.function + def main( + x1: R.Tensor((2, 1024, 640), "float32"), + w0: R.Tensor((640, 640), "float32"), + w1: R.Tensor((640, 640), "float32"), + w2: R.Tensor((640, 640), "float32"), + ) -> R.Tensor: + with R.dataflow(): + w0_t = R.permute_dims(w0, axes=None) + lv0 = R.matmul(x1, w0_t) + w1_t = R.permute_dims(w1, axes=None) + w1_t_t = R.permute_dims(w1_t, axes=None) + lv1 = R.matmul(x1, w1_t_t) + w2_t = R.permute_dims(w2, axes=None) + lv2 = R.matmul(x1, w2_t) + out = (lv0, lv1, lv2) + R.output(out) + return out + + @R.function + def expected( + x1: R.Tensor((2, 1024, 640), dtype="float32"), + w0: R.Tensor((640, 640), dtype="float32"), + w1: R.Tensor((640, 640), dtype="float32"), + w2: R.Tensor((640, 640), dtype="float32"), + ) -> R.Tensor: + with R.dataflow(): + w0_t = R.permute_dims(w0, axes=None) + w1_t = R.permute_dims(w1, axes=None) + w1_t_t = R.permute_dims(w1_t, axes=None) + w2_t = R.permute_dims(w2, axes=None) + lv = R.concat((w0_t, w1_t_t, w2_t), axis=1) + lv1 = R.matmul(x1, 
lv, out_dtype="void") + lv0 = R.strided_slice(lv1, axes=[2], begin=[0], end=[640]) + lv1_1 = R.strided_slice(lv1, axes=[2], begin=[640], end=[1280]) + lv2 = R.strided_slice(lv1, axes=[2], begin=[1280], end=[1920]) + out = lv0, lv1_1, lv2 + R.output(out) + return out + + with PatternContext() as ctx: + inp_pat = wildcard() + Q_weight_pat = wildcard() + K_weight_pat = wildcard() + V_weight_pat = wildcard() + + matmul1 = is_op("relax.matmul")(inp_pat, Q_weight_pat) + matmul2 = is_op("relax.matmul")(inp_pat, K_weight_pat) + matmul3 = is_op("relax.matmul")(inp_pat, V_weight_pat) + + rewriter = get_qkv_proj_rewriter( + inp_pat, Q_weight_pat, K_weight_pat, V_weight_pat, matmul1, matmul2, matmul3 + ) + rewritten = rewrite_bindings(ctx, rewriter, main) + tvm.ir.assert_structural_equal(rewritten, expected) + + if __name__ == "__main__": tvm.testing.main() From 41012f8767765968d246b5a69540a17c1a44171e Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sat, 1 Apr 2023 04:28:59 +0900 Subject: [PATCH 2/3] minor improvement in EmitUsedVars --- src/relax/ir/dataflow_matcher.cc | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/relax/ir/dataflow_matcher.cc b/src/relax/ir/dataflow_matcher.cc index 0055929a78b1..c1306ff69093 100644 --- a/src/relax/ir/dataflow_matcher.cc +++ b/src/relax/ir/dataflow_matcher.cc @@ -850,10 +850,13 @@ class PatternRewriter : ExprMutator { } size_t num_unemitted = unemitted_vars.size(); - for (const auto& binding : pending_bindings) { + for (size_t i = 0; i < pending_bindings.size(); ++i) { + const auto& binding = pending_bindings[i]; if (auto var_bind = binding.as(); var_bind && unemitted_vars.count(var_bind->var.get())) { - EmitUsedVars(var_bind->value, pending_bindings, emitted_vars); + // var_bind->value may also depend on other unemitted vars in this range + Array prev_bindings(pending_bindings.begin(), pending_bindings.begin() + i); + EmitUsedVars(var_bind->value, prev_bindings, emitted_vars); this->VisitBinding(binding); 
emitted_vars->insert(var_bind->var.get()); if (--num_unemitted == 0) { From d7218ee14cdcc9d7e1453c0288fe2d9f68cc3316 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sat, 1 Apr 2023 04:48:57 +0900 Subject: [PATCH 3/3] update doc --- python/tvm/relax/dpl/__init__.py | 1 + python/tvm/relax/dpl/pattern.py | 70 +----------- python/tvm/relax/dpl/rewrite.py | 115 ++++++++++++++++++++ tests/python/relax/test_dataflow_pattern.py | 9 +- 4 files changed, 126 insertions(+), 69 deletions(-) create mode 100644 python/tvm/relax/dpl/rewrite.py diff --git a/python/tvm/relax/dpl/__init__.py b/python/tvm/relax/dpl/__init__.py index e0bbdaff0512..6451238428c2 100644 --- a/python/tvm/relax/dpl/__init__.py +++ b/python/tvm/relax/dpl/__init__.py @@ -19,3 +19,4 @@ from .pattern import * from .context import * +from .rewrite import rewrite_call, rewrite_bindings diff --git a/python/tvm/relax/dpl/pattern.py b/python/tvm/relax/dpl/pattern.py index 3026213ba2f6..79883b9161ec 100644 --- a/python/tvm/relax/dpl/pattern.py +++ b/python/tvm/relax/dpl/pattern.py @@ -20,7 +20,7 @@ # pylint: disable=pointless-statement import typing -from typing import Dict, List, Optional, Tuple, Union, Callable +from typing import Dict, List, Optional, Tuple, Union import tvm import tvm._ffi as tvm_ffi @@ -31,7 +31,7 @@ from ...ir import make_node from ...ir.base import Node from ...runtime import Object -from ..expr import Expr, Var, Function +from ..expr import Expr, Var from . import _ffi as ffi @@ -1123,69 +1123,3 @@ def make_fused_bias_activation_pattern(op_name, with_bias=False, activation=None return is_op(activation)(out) return out - - -def rewrite_call( - pattern: DFPattern, rewriter: Callable[[Expr, Dict[DFPattern, Expr]], Expr], func: Function -) -> Function: - """ - Rewrite a function with the given pattern and the rewriter function. - - Parameters - ---------- - pattern: DFPattern - The pattern to match. 
- - rewriter: Callable[[Expr, Dict[DFPattern, Expr]], Expr] - The function to be called on a successful matching for rewriting. Given the matched - call node and the map of patterns and matched expressions, it should return a new call node - to replace the original one or the original matched call node as is. - - For example, to replace x + x with 2 * x, we can write the rewriter as follows: - ``` - x = wildcard() - pattern = is_op("relax.add")(x, x) - - def rewriter(orig, matchings): - return R.multiply(matchings[x], R.const(2, "float32")) - ``` - - func: Function - The function to rewrite. - - Returns - ------- - rewritten_func: Function - The rewritten or the input function, depending on the pattern matching result. - """ - return ffi.rewrite_call(pattern, rewriter, func) - - -def rewrite_bindings( - ctx, rewriter: Callable[[Expr, Dict[DFPattern, Expr]], Expr], func: Function -) -> Function: - """ - Rewrite a function with the given pattern and the rewriter function. - Parameters - ---------- - pattern: DFPattern - The pattern to match. - rewriter: Callable[[Expr, Dict[DFPattern, Expr]], Expr] - The function to be called on a successful matching for rewriting. Given the matched - call node and the map of patterns and matched expressions, it should return a new call node - to replace the original one or the original matched call node as is. - For example, to replace x + x with 2 * x, we can write the rewriter as follows: - ``` - x = wildcard() - pattern = is_op("relax.add")(x, x) - def rewriter(orig, matchings): - return R.multiply(matchings[x], R.const(2, "float32")) - ``` - func: Function - The function to rewrite. - Returns - ------- - rewritten_func: Function - The rewritten or the input function, depending on the pattern matching result. 
- """ - return ffi.rewrite_bindings(ctx, rewriter, func) diff --git a/python/tvm/relax/dpl/rewrite.py b/python/tvm/relax/dpl/rewrite.py new file mode 100644 index 000000000000..1b62a429030e --- /dev/null +++ b/python/tvm/relax/dpl/rewrite.py @@ -0,0 +1,115 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""APIs for pattern-based rewriting.""" +from typing import Dict, Callable +from .pattern import DFPattern +from .context import PatternContext + +from ..expr import Expr, Function, Var +from . import _ffi as ffi + + +def rewrite_call( + pattern: DFPattern, rewriter: Callable[[Expr, Dict[DFPattern, Expr]], Expr], func: Function +) -> Function: + """ + Rewrite a function with the given pattern and the rewriter function. + + Parameters + ---------- + pattern: DFPattern + The pattern to match. + + rewriter: Callable[[Expr, Dict[DFPattern, Expr]], Expr] + The function to be called on a successful matching for rewriting. Given the matched + call node and the map of patterns and matched expressions, it should return a new call node + to replace the original one or the original matched call node as is. 
+ + For example, to replace x + x with 2 * x, we can write the rewriter as follows: + ``` + x = wildcard() + pattern = is_op("relax.add")(x, x) + + def rewriter(orig, matchings): + return R.multiply(matchings[x], R.const(2, "float32")) + ``` + + func: Function + The function to rewrite. + + Returns + ------- + rewritten_func: Function + The rewritten or the input function, depending on the pattern matching result. + """ + return ffi.rewrite_call(pattern, rewriter, func) + + +def rewrite_bindings( + ctx: PatternContext, rewriter: Callable[[Dict[DFPattern, Var]], Dict[Var, Expr]], func: Function +) -> Function: + """ + Rewrite a function with the given pattern and the rewriter function. + + Parameters + ---------- + ctx: PatternContext + The pattern constraint context under which rewriting takes place. + + rewriter: Callable[[Dict[DFPattern, Var]], Dict[Var, Expr]] + The function to be called on a successful matching for rewriting. Given the map of patterns + and corresponding variables (bound variables or parameters), it should return a map that + specifies new values for matched bound variables. 
+
+    For example, to rewrite three matmuls for QKV projection in transformer models into one
+    matmul followed by slicing, one can use the following rewriter:
+    ```
+    inp_pat = wildcard()
+    Q_weight_pat, K_weight_pat, V_weight_pat = wildcard(), wildcard(), wildcard()
+
+    matmul1 = is_op("relax.matmul")(inp_pat, Q_weight_pat)
+    matmul2 = is_op("relax.matmul")(inp_pat, K_weight_pat)
+    matmul3 = is_op("relax.matmul")(inp_pat, V_weight_pat)
+
+    def rewriter(matchings):
+        inp = matchings[inp_pat]
+        Q_weight = matchings[Q_weight_pat]
+        K_weight = matchings[K_weight_pat]
+        V_weight = matchings[V_weight_pat]
+        width = Q_weight.struct_info.shape[1]
+
+        concat = R.concat([Q_weight, K_weight, V_weight], axis=1)
+        matmul = R.matmul(inp, concat)
+        Q = R.strided_slice(matmul, axes=[2], begin=[0], end=[width])
+        K = R.strided_slice(matmul, axes=[2], begin=[width], end=[width * 2])
+        V = R.strided_slice(matmul, axes=[2], begin=[width * 2], end=[width * 3])
+
+        # matchings[matmul1] gives the bound variable in the binding whose RHS matches with
+        # the matmul1 pattern. For example, lv0 in lv0 = R.matmul(x1, w0).
+        # We want to replace the RHS of this binding with Q.
+        return {matchings[matmul1]: Q, matchings[matmul2]: K, matchings[matmul3]: V}
+    ```
+
+    func: Function
+        The function to rewrite.
+
+    Returns
+    -------
+    rewritten_func: Function
+        The rewritten or the input function, depending on the pattern matching result.
+ """ + return ffi.rewrite_bindings(ctx, rewriter, func) diff --git a/tests/python/relax/test_dataflow_pattern.py b/tests/python/relax/test_dataflow_pattern.py index e4d7f7972c6e..b85543cafcb8 100644 --- a/tests/python/relax/test_dataflow_pattern.py +++ b/tests/python/relax/test_dataflow_pattern.py @@ -18,7 +18,7 @@ import pytest import tvm.testing -from tvm import relay +from tvm import relay, relax from tvm.relax.dpl import * from tvm.relax.analysis import get_var2val from tvm import relax as rx, tir @@ -1219,6 +1219,13 @@ def expected( rewritten = rewrite_bindings(ctx, rewriter, main) tvm.ir.assert_structural_equal(rewritten, expected) + # make sure it builds + mod = tvm.IRModule() + mod["main"] = rewritten + mod = relax.transform.LegalizeOps()(mod) + + relax.build(mod, target="llvm") + if __name__ == "__main__": tvm.testing.main()