From 8b81a7aa0b854b7555e5d4f9418a8483688f6168 Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Wed, 5 Apr 2023 19:23:45 +0900
Subject: [PATCH 01/20] stub

---
 include/tvm/relax/dataflow_matcher.h     |  1 +
 src/relax/ir/dataflow_matcher.cc         | 11 +--
 .../transform/combine_parallel_matmul.cc | 88 +++++++++++++++++++
 3 files changed, 95 insertions(+), 5 deletions(-)
 create mode 100644 src/relax/transform/combine_parallel_matmul.cc

diff --git a/include/tvm/relax/dataflow_matcher.h b/include/tvm/relax/dataflow_matcher.h
index cf7c58f093e6..2235ea16af2a 100644
--- a/include/tvm/relax/dataflow_matcher.h
+++ b/include/tvm/relax/dataflow_matcher.h
@@ -58,6 +58,7 @@ Optional<Map<DFPattern, Expr>> ExtractMatchedExpr(
 TVM_DLL Optional<Map<DFPattern, Var>> MatchGraph(const PatternContext& ctx,
                                                  const DataflowBlock& dfb);
 
+TVM_DLL Function RewriteBindings(const PatternContext& ctx, PackedFunc rewriter, Function f);
 
 }  // namespace relax
 }  // namespace tvm
diff --git a/src/relax/ir/dataflow_matcher.cc b/src/relax/ir/dataflow_matcher.cc
index 6e8211cfd314..d980a046ac75 100644
--- a/src/relax/ir/dataflow_matcher.cc
+++ b/src/relax/ir/dataflow_matcher.cc
@@ -791,7 +791,7 @@ class PatternRewriter : ExprMutator {
       : ctx_(ctx), rewriter_func_(rewriter_func), params_(params) {}
 
   template <typename PatternType>
-  static Expr Run(PatternType pat, PackedFunc rewriter_func, Function f) {
+  static Function Run(PatternType pat, PackedFunc rewriter_func, Function f) {
     std::unordered_set<const VarNode*> params;
     for (const auto& p : f->params) {
       params.insert(p.get());
     }
@@ -909,15 +909,16 @@ class PatternRewriter : ExprMutator {
   std::unordered_map<const VarNode*, Expr> memo_;
 };
 
+Function RewriteBindings(const PatternContext& ctx, PackedFunc rewriter, Function f) {
+  return PatternRewriter::Run(ctx, rewriter, f);
+}
+
 TVM_REGISTER_GLOBAL("relax.dpl.rewrite_call")
     .set_body_typed([](DFPattern pat, PackedFunc rewriter, Function f) {
       return PatternRewriter::Run(pat, rewriter, f);
     });
 
-TVM_REGISTER_GLOBAL("relax.dpl.rewrite_bindings")
-    .set_body_typed([](const PatternContext& ctx, PackedFunc rewriter, Function f) {
-      return PatternRewriter::Run(ctx, rewriter, f);
-    });
+TVM_REGISTER_GLOBAL("relax.dpl.rewrite_bindings").set_body_typed(RewriteBindings);
 
 }  // namespace relax
 }  // namespace tvm
diff --git a/src/relax/transform/combine_parallel_matmul.cc b/src/relax/transform/combine_parallel_matmul.cc
new file mode 100644
index 000000000000..a982d78cb191
--- /dev/null
+++ b/src/relax/transform/combine_parallel_matmul.cc
@@ -0,0 +1,88 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <tvm/relax/dataflow_matcher.h>
+#include <tvm/relax/dataflow_pattern.h>
+#include <tvm/relax/struct_info.h>
+#include <tvm/relax/transform.h>
+
+#include <vector>
+
+#include "../op/tensor/index.h"
+#include "../op/tensor/linear_algebra.h"
+#include "../op/tensor/manipulate.h"
+
+namespace tvm {
+namespace relax {
+
+using runtime::Map;
+
+Function CombineParallelMatmul(Function f) {
+  PatternContext ctx;
+  WildcardPattern input_pattern;
+  std::vector<WildcardPattern> weight_patterns;
+  std::vector<CallPattern> matmul_patterns;
+  const int num_branches = 32;
+
+  runtime::TypedPackedFunc<Map<Var, Expr>(Map<DFPattern, Var>)> rewriter =
+      [=](Map<DFPattern, Var> matchings) {
+        auto inp = matchings[input_pattern];
+
+        Array<Expr> weights;
+        for (const auto& weight_pat : weight_patterns) {
+          weights.push_back(matchings[weight_pat]);
+        }
+
+        auto concat_weights = concat(Tuple(weights), Integer(1));
+        auto matmul_combined = matmul(inp, concat_weights, DataType::Float(16));
+
+        Map<Var, Expr> replacements;
+        PrimExpr begin{0};
+        int slice_axis = 2;
+        Array<PrimExpr> strides{1};
+
+        for (int i = 0; i < num_branches; ++i) {
+          auto sinfo = GetStructInfo(weights[i]);
+          auto width = Downcast<TensorStructInfo>(sinfo)->GetShape().value()[1];
+          auto bound_var = matchings[matmul_patterns[i]];
+          auto slice =
+              strided_slice(matmul_combined, {slice_axis}, {begin}, {begin + width}, strides);
+          replacements.Set(bound_var, slice);
+          begin += width;
+        }
+
+        return replacements;
+      };
+  return RewriteBindings(ctx, rewriter, f);
+}
+
+namespace transform {
+
+Pass CombineParallelMatmul() {
+  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
+      [=](Function f, IRModule m, PassContext pc) { return relax::CombineParallelMatmul(f); };
+  return CreateFunctionPass(/*pass_function=*/pass_func,            //
+                            /*opt_level=*/0,                        //
+                            /*pass_name=*/"CombineParallelMatmul",  //
+                            /*required=*/{});
+}
+
+}  // namespace transform
+
+}  // namespace relax
+}  // namespace tvm
From 3298c00c4dd9fd0a22f7ed502fd786da99b20a49 Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Thu, 6 Apr 2023 04:40:24 +0900
Subject: [PATCH 02/20] make EnterWithScope and ExitWithScope public

---
 include/tvm/relax/dataflow_pattern.h | 10 ++++------
 src/relax/ir/dataflow_pattern.cc     | 18 ++++++++----------
 2 files changed, 12 insertions(+), 16 deletions(-)

diff --git a/include/tvm/relax/dataflow_pattern.h b/include/tvm/relax/dataflow_pattern.h
index e4c27f3558ba..68cfdd83ad2a 100644
--- a/include/tvm/relax/dataflow_pattern.h
+++ b/include/tvm/relax/dataflow_pattern.h
@@ -248,14 +248,12 @@ class PatternContext : public ObjectRef {
   /*! \brief Get the constraint context object on the top of the stack */
   TVM_DLL static Optional<PatternContext> Current();
 
-  class Internal;
-
- private:
   /*! \brief The RAII-like entry of a constraint context scope */
-  TVM_DLL void EnterWithScope();
+  TVM_DLL void EnterWithScope() const;
   /*!
\brief The RAII-like exit of a constraint context scope */ - TVM_DLL void ExitWithScope(); - friend class Internal; + TVM_DLL void ExitWithScope() const; + + private: friend class With; }; diff --git a/src/relax/ir/dataflow_pattern.cc b/src/relax/ir/dataflow_pattern.cc index 4d225ceecfe7..cd1376303ccf 100644 --- a/src/relax/ir/dataflow_pattern.cc +++ b/src/relax/ir/dataflow_pattern.cc @@ -412,9 +412,9 @@ PatternContext::PatternContext(bool incremental) { data_ = std::move(n); } -void PatternContext::EnterWithScope() { pattern_ctx_stack().push(*this); } +void PatternContext::EnterWithScope() const { pattern_ctx_stack().push(*this); } -void PatternContext::ExitWithScope() { +void PatternContext::ExitWithScope() const { ICHECK(pattern_ctx_stack().top().same_as(*this)); pattern_ctx_stack().pop(); } @@ -610,15 +610,13 @@ TVM_REGISTER_GLOBAL("relax.dpl.current_context").set_body_typed([] { return PatternContext::Current(); }); -class PatternContext::Internal { - public: - static void EnterScope(PatternContext pass_ctx) { pass_ctx.EnterWithScope(); } - static void ExitScope(PatternContext pass_ctx) { pass_ctx.ExitWithScope(); } -}; - -TVM_REGISTER_GLOBAL("relax.dpl.enter_context").set_body_typed(PatternContext::Internal::EnterScope); +TVM_REGISTER_GLOBAL("relax.dpl.enter_context").set_body_typed([](const PatternContext& ctx) { + ctx.EnterWithScope(); +}); -TVM_REGISTER_GLOBAL("relax.dpl.exit_context").set_body_typed(PatternContext::Internal::ExitScope); +TVM_REGISTER_GLOBAL("relax.dpl.exit_context").set_body_typed([](const PatternContext& ctx) { + ctx.ExitWithScope(); +}); } // namespace relax } // namespace tvm From b7d55079d03a436cb1059e7e64ca6aa3829281b5 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 6 Apr 2023 04:54:48 +0900 Subject: [PATCH 03/20] qkv combining works --- python/tvm/relax/transform/transform.py | 4 +++ .../transform/combine_parallel_matmul.cc | 29 +++++++++++++++---- 2 files changed, 28 insertions(+), 5 deletions(-) diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py index a53c45b655ab..82762feedeff 100644 --- a/python/tvm/relax/transform/transform.py +++ b/python/tvm/relax/transform/transform.py @@ -934,6 +934,10 @@ def SplitCallTIRByPattern(patterns, fcodegen) -> tvm.ir.transform.Pass: return _ffi_api.SplitCallTIRByPattern(patterns, fcodegen) # type: ignore +def CombineParallelMatmul(): + return _ffi_api.CombineParallelMatmul() # type: ignore + + def _wrap_class_function_pass(pass_cls, pass_info): """Wrap a python class as function pass.""" diff --git a/src/relax/transform/combine_parallel_matmul.cc b/src/relax/transform/combine_parallel_matmul.cc index a982d78cb191..5920e2cdfd59 100644 --- a/src/relax/transform/combine_parallel_matmul.cc +++ b/src/relax/transform/combine_parallel_matmul.cc @@ -33,14 +33,28 @@ namespace relax { using runtime::Map; Function CombineParallelMatmul(Function f) { + const int num_branches = 32; PatternContext ctx; - WildcardPattern input_pattern; + + ctx.EnterWithScope(); + + auto input_pattern = Wildcard(); std::vector weight_patterns; std::vector matmul_patterns; - const int num_branches = 32; + auto matmul_op = Op::Get("relax.matmul"); + + for (int i = 0; i < 32; ++i) { + auto w_pat = Wildcard(); + CallPattern matmul_pat{ExprPattern(matmul_op), {input_pattern, w_pat}}; + weight_patterns.push_back(w_pat); + matmul_patterns.push_back(matmul_pat); + ctx.add_constraint(input_pattern, matmul_pat, PairCons(PairCons::kUsedBy, 0)); + ctx.add_constraint(w_pat, matmul_pat, PairCons(PairCons::kUsedBy, 
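// Illustration (hypothetical IR, not part of the patch): with three branches this
// context matches dataflow bindings of the form
//   lv0 = matmul(x, w0); lv1 = matmul(x, w1); lv2 = matmul(x, w2)
// where `input_pattern` is the shared LHS of every matmul and each weight
// wildcard feeds only its own matmul (the kUsedBy constraints here).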
1)); + } runtime::TypedPackedFunc(Map)> rewriter = [=](Map matchings) { + LOG(INFO) << "matched"; auto inp = matchings[input_pattern]; Array weights; @@ -49,11 +63,11 @@ Function CombineParallelMatmul(Function f) { } auto concat_weights = concat(Tuple(weights), Integer(1)); - auto matmul_combined = matmul(inp, concat_weights, DataType::Float(16)); + auto matmul_combined = matmul(inp, concat_weights, DataType::Float(16)); // TODO dtype Map replacements; PrimExpr begin{0}; - int slice_axis = 2; + int slice_axis = 2; // TODO Array strides{1}; for (size_t i = 0; i < num_branches; ++i) { @@ -68,7 +82,10 @@ Function CombineParallelMatmul(Function f) { return replacements; }; - return RewriteBindings(ctx, rewriter, f); + + auto rewritten = RewriteBindings(ctx, rewriter, f); + ctx.ExitWithScope(); + return rewritten; } namespace transform { @@ -82,6 +99,8 @@ Pass CombineParallelMatmul() { /*required=*/{}); } +TVM_REGISTER_GLOBAL("relax.transform.CombineParallelMatmul").set_body_typed(CombineParallelMatmul); + } // namespace transform } // namespace relax From 6653d63526f268810c82a887e6612bdbdf8a1a74 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 6 Apr 2023 18:30:37 +0900 Subject: [PATCH 04/20] automatic branch extraction --- .../transform/combine_parallel_matmul.cc | 60 ++++++++++++++++--- 1 file changed, 51 insertions(+), 9 deletions(-) diff --git a/src/relax/transform/combine_parallel_matmul.cc b/src/relax/transform/combine_parallel_matmul.cc index 5920e2cdfd59..b02f349f671b 100644 --- a/src/relax/transform/combine_parallel_matmul.cc +++ b/src/relax/transform/combine_parallel_matmul.cc @@ -21,6 +21,7 @@ #include #include +#include #include #include "../op/tensor/index.h" @@ -32,18 +33,17 @@ namespace relax { using runtime::Map; -Function CombineParallelMatmul(Function f) { - const int num_branches = 32; - PatternContext ctx; +static auto matmul_op = Op::Get("relax.matmul"); +Function Rewrite(Function f, int num_branches, int slice_axis) { + PatternContext ctx; ctx.EnterWithScope(); auto input_pattern = Wildcard(); std::vector weight_patterns; std::vector matmul_patterns; - auto matmul_op = Op::Get("relax.matmul"); - for (int i = 0; i < 32; ++i) { + for (int i = 0; i < num_branches; ++i) { auto w_pat = Wildcard(); CallPattern matmul_pat{ExprPattern(matmul_op), {input_pattern, w_pat}}; weight_patterns.push_back(w_pat); @@ -54,7 +54,6 @@ Function CombineParallelMatmul(Function f) { runtime::TypedPackedFunc(Map)> rewriter = [=](Map matchings) { - LOG(INFO) << "matched"; auto inp = matchings[input_pattern]; Array weights; @@ -62,12 +61,13 @@ Function CombineParallelMatmul(Function f) { weights.push_back(matchings[weight_pat]); } - auto concat_weights = concat(Tuple(weights), Integer(1)); - auto matmul_combined = matmul(inp, concat_weights, DataType::Float(16)); // TODO dtype + auto concat_weights = concat(Tuple(weights), Integer(1)); // TODO: axis + auto out_dtype = + Downcast(GetStructInfo(matchings[matmul_patterns[0]]))->dtype; + auto matmul_combined = matmul(inp, concat_weights, out_dtype); Map replacements; PrimExpr begin{0}; - int slice_axis = 2; // TODO Array strides{1}; for (size_t i = 0; i < num_branches; ++i) { @@ -88,6 +88,48 @@ Function CombineParallelMatmul(Function f) { return rewritten; } +struct BranchInfo { + int num_branches; + int slice_axis; +}; + +std::vector GetBranchInfo(Function f) { + std::unordered_map groups; + PostOrderVisit(f, [&](const Expr& e) { + if (auto call = e.as(); call && call->op.same_as(matmul_op)) { + auto lhs = Downcast(call->args[0]); + if (auto it 
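// Counting sketch (hypothetical IR):
//   lv0 = matmul(x, w0); lv1 = matmul(x, w1); lv2 = matmul(y, w2)
// gives groups[x].num_branches == 2 and groups[y].num_branches == 1; only
// groups with more than one branch survive the num_branches > 1 filter below.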
= groups.find(lhs.get()); it == groups.end()) { + auto sinfo = GetStructInfo(e); + auto slice_axis = Downcast(sinfo)->ndim - 1; + groups[lhs.get()] = {1, slice_axis}; + } else { + it->second.num_branches += 1; + } + } + }); + + std::vector info; + + for (const auto& group : groups) { + if (group.second.num_branches > 1) { + info.push_back(group.second); + } + } + + std::sort(info.begin(), info.end(), + [](const auto& b1, const auto& b2) { return b1.num_branches > b2.num_branches; }); + + return info; +} + +Function CombineParallelMatmul(Function f) { + auto branches = GetBranchInfo(f); + for (const auto& branch : branches) { + f = Rewrite(f, branch.num_branches, branch.slice_axis); + } + return f; +} + namespace transform { Pass CombineParallelMatmul() { From aab49d64a93db12546b5421526a7ec191cc51380 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 6 Apr 2023 20:00:34 +0900 Subject: [PATCH 05/20] fix hardcoded concat axis --- src/relax/transform/combine_parallel_matmul.cc | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/relax/transform/combine_parallel_matmul.cc b/src/relax/transform/combine_parallel_matmul.cc index b02f349f671b..3bf293d558e7 100644 --- a/src/relax/transform/combine_parallel_matmul.cc +++ b/src/relax/transform/combine_parallel_matmul.cc @@ -18,6 +18,7 @@ */ #include #include +#include #include #include @@ -61,7 +62,8 @@ Function Rewrite(Function f, int num_branches, int slice_axis) { weights.push_back(matchings[weight_pat]); } - auto concat_weights = concat(Tuple(weights), Integer(1)); // TODO: axis + auto concat_axis = Downcast(GetStructInfo(weights[0]))->ndim - 1; + auto concat_weights = concat(Tuple(weights), Integer(concat_axis)); auto out_dtype = Downcast(GetStructInfo(matchings[matmul_patterns[0]]))->dtype; auto matmul_combined = matmul(inp, concat_weights, out_dtype); From 5f7f71dc31b05e8ef55480d4f2f6a8a86a77042f Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 7 Apr 2023 04:44:46 +0900 Subject: [PATCH 06/20] handle non-uniform rhs ranks --- src/relax/ir/dataflow_matcher.cc | 3 + .../transform/combine_parallel_matmul.cc | 76 ++++++++++++------- 2 files changed, 52 insertions(+), 27 deletions(-) diff --git a/src/relax/ir/dataflow_matcher.cc b/src/relax/ir/dataflow_matcher.cc index d980a046ac75..038f44854f1c 100644 --- a/src/relax/ir/dataflow_matcher.cc +++ b/src/relax/ir/dataflow_matcher.cc @@ -865,6 +865,9 @@ class PatternRewriter : ExprMutator { if (auto matches = MatchGraph(ctx_.value(), Downcast(block))) { builder_->BeginDataflowBlock(); Map replacements = rewriter_func_(matches.value()); + if (replacements.empty()) { + return block; + } std::unordered_set emitted_vars; diff --git a/src/relax/transform/combine_parallel_matmul.cc b/src/relax/transform/combine_parallel_matmul.cc index 3bf293d558e7..dd6b3cc18ec4 100644 --- a/src/relax/transform/combine_parallel_matmul.cc +++ b/src/relax/transform/combine_parallel_matmul.cc @@ -36,18 +36,31 @@ using runtime::Map; static auto matmul_op = Op::Get("relax.matmul"); -Function Rewrite(Function f, int num_branches, int slice_axis) { +std::unordered_map> GroupShapes( + const std::vector>& shapes) { + std::unordered_map> indices_map; + for (size_t i = 0; i < shapes.size(); ++i) { + indices_map[shapes[i].size()].push_back(i); + } + return indices_map; +} + +inline TensorStructInfo GetTensorSInfo(Expr e) { + return Downcast(GetStructInfo(e)); +} + +Function Rewrite(Function f, int num_branches) { PatternContext ctx; ctx.EnterWithScope(); auto input_pattern = Wildcard(); - std::vector 
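// Sketch of the GroupShapes helper above: it groups the RHS operands by rank,
// so a hypothetical
//   GroupShapes({(640, 640), (8, 640, 640), (640, 320)})  // ranks 2, 3, 2
// returns {2: [0, 2], 3: [1]}, ensuring only same-rank weights are concatenated.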
weight_patterns; + std::vector rhs_patterns; std::vector matmul_patterns; for (int i = 0; i < num_branches; ++i) { auto w_pat = Wildcard(); CallPattern matmul_pat{ExprPattern(matmul_op), {input_pattern, w_pat}}; - weight_patterns.push_back(w_pat); + rhs_patterns.push_back(w_pat); matmul_patterns.push_back(matmul_pat); ctx.add_constraint(input_pattern, matmul_pat, PairCons(PairCons::kUsedBy, 0)); ctx.add_constraint(w_pat, matmul_pat, PairCons(PairCons::kUsedBy, 1)); @@ -56,30 +69,42 @@ Function Rewrite(Function f, int num_branches, int slice_axis) { runtime::TypedPackedFunc(Map)> rewriter = [=](Map matchings) { auto inp = matchings[input_pattern]; + auto lhs_dim = GetTensorSInfo(inp)->ndim; - Array weights; - for (const auto& weight_pat : weight_patterns) { - weights.push_back(matchings[weight_pat]); + std::vector> rhs_shapes; + for (const auto& rhs_pat : rhs_patterns) { + auto r = matchings[rhs_pat]; + rhs_shapes.push_back(GetTensorSInfo(r)->GetShape().value()); } - auto concat_axis = Downcast(GetStructInfo(weights[0]))->ndim - 1; - auto concat_weights = concat(Tuple(weights), Integer(concat_axis)); - auto out_dtype = - Downcast(GetStructInfo(matchings[matmul_patterns[0]]))->dtype; - auto matmul_combined = matmul(inp, concat_weights, out_dtype); + auto shape_groups = GroupShapes(rhs_shapes); Map replacements; - PrimExpr begin{0}; - Array strides{1}; - - for (size_t i = 0; i < num_branches; ++i) { - auto sinfo = GetStructInfo(weights[i]); - auto width = Downcast(sinfo)->GetShape().value()[1]; - auto bound_var = matchings[matmul_patterns[i]]; - auto slice = - strided_slice(matmul_combined, {slice_axis}, {begin}, {begin + width}, strides); - replacements.Set(bound_var, slice); - begin += width; + + for (const auto& [rhs_dim, indices] : shape_groups) { + if (indices.size() == 1) continue; + + Array rhs; + for (auto ind : indices) { + rhs.push_back(matchings[rhs_patterns[ind]]); + } + + auto concat_rhs = concat(Tuple(rhs), Integer(rhs_dim - 1)); + auto out_dtype = GetTensorSInfo(matchings[matmul_patterns[indices[0]]])->dtype; + auto matmul_combined = matmul(inp, concat_rhs, out_dtype); + + PrimExpr begin{0}; + Array strides{1}; + int slice_axis = std::max(lhs_dim, rhs_dim) - 1; + + for (size_t i = 0; i < indices.size(); ++i) { + auto width = GetTensorSInfo(rhs[i])->GetShape().value()[rhs_dim - 1]; + auto bound_var = matchings[matmul_patterns[indices[i]]]; + auto slice = + strided_slice(matmul_combined, {slice_axis}, {begin}, {begin + width}, strides); + replacements.Set(bound_var, slice); + begin += width; + } } return replacements; @@ -92,7 +117,6 @@ Function Rewrite(Function f, int num_branches, int slice_axis) { struct BranchInfo { int num_branches; - int slice_axis; }; std::vector GetBranchInfo(Function f) { @@ -101,9 +125,7 @@ std::vector GetBranchInfo(Function f) { if (auto call = e.as(); call && call->op.same_as(matmul_op)) { auto lhs = Downcast(call->args[0]); if (auto it = groups.find(lhs.get()); it == groups.end()) { - auto sinfo = GetStructInfo(e); - auto slice_axis = Downcast(sinfo)->ndim - 1; - groups[lhs.get()] = {1, slice_axis}; + groups[lhs.get()] = {1}; } else { it->second.num_branches += 1; } @@ -127,7 +149,7 @@ std::vector GetBranchInfo(Function f) { Function CombineParallelMatmul(Function f) { auto branches = GetBranchInfo(f); for (const auto& branch : branches) { - f = Rewrite(f, branch.num_branches, branch.slice_axis); + f = Rewrite(f, branch.num_branches); } return f; } From 611a431b8f8114610104e8927074ba1f6db368e0 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 
7 Apr 2023 05:53:15 +0900 Subject: [PATCH 07/20] wip --- .../transform/combine_parallel_matmul.cc | 42 +++++++++++++------ 1 file changed, 30 insertions(+), 12 deletions(-) diff --git a/src/relax/transform/combine_parallel_matmul.cc b/src/relax/transform/combine_parallel_matmul.cc index dd6b3cc18ec4..cfc77976b9aa 100644 --- a/src/relax/transform/combine_parallel_matmul.cc +++ b/src/relax/transform/combine_parallel_matmul.cc @@ -34,8 +34,6 @@ namespace relax { using runtime::Map; -static auto matmul_op = Op::Get("relax.matmul"); - std::unordered_map> GroupShapes( const std::vector>& shapes) { std::unordered_map> indices_map; @@ -49,21 +47,43 @@ inline TensorStructInfo GetTensorSInfo(Expr e) { return Downcast(GetStructInfo(e)); } -Function Rewrite(Function f, int num_branches) { +struct BranchInfo { + int num_branches; + bool has_bias; + std::optional activation; +}; + +Function Rewrite(Function f, const BranchInfo& branch_info) { PatternContext ctx; ctx.EnterWithScope(); auto input_pattern = Wildcard(); std::vector rhs_patterns; - std::vector matmul_patterns; + std::vector bias_patterns; + std::vector matmul_patterns, bias_add_patterns, activation_patterns; - for (int i = 0; i < num_branches; ++i) { + for (int i = 0; i < branch_info.num_branches; ++i) { auto w_pat = Wildcard(); - CallPattern matmul_pat{ExprPattern(matmul_op), {input_pattern, w_pat}}; rhs_patterns.push_back(w_pat); + auto matmul_pat = IsOp("relax.matmul")(input_pattern, w_pat); matmul_patterns.push_back(matmul_pat); ctx.add_constraint(input_pattern, matmul_pat, PairCons(PairCons::kUsedBy, 0)); ctx.add_constraint(w_pat, matmul_pat, PairCons(PairCons::kUsedBy, 1)); + + CallPattern matmul_out = matmul_pat; + + if (branch_info.has_bias) { + auto bias_pat = Wildcard(); + bias_patterns.push_back(bias_pat); + auto bias_add = IsOp("relax.add")(matmul_pat, bias_pat); + bias_add_patterns.push_back(bias_add); + matmul_out = bias_add; + } + + if (branch_info.activation) { + activation_patterns.push_back(IsOp(*branch_info.activation)(matmul_out)); + } + } runtime::TypedPackedFunc(Map)> rewriter = @@ -115,17 +135,15 @@ Function Rewrite(Function f, int num_branches) { return rewritten; } -struct BranchInfo { - int num_branches; -}; - std::vector GetBranchInfo(Function f) { std::unordered_map groups; + static auto matmul_op = Op::Get("relax.matmul"); + PostOrderVisit(f, [&](const Expr& e) { if (auto call = e.as(); call && call->op.same_as(matmul_op)) { auto lhs = Downcast(call->args[0]); if (auto it = groups.find(lhs.get()); it == groups.end()) { - groups[lhs.get()] = {1}; + groups[lhs.get()] = {1, false, std::nullopt}; } else { it->second.num_branches += 1; } @@ -149,7 +167,7 @@ std::vector GetBranchInfo(Function f) { Function CombineParallelMatmul(Function f) { auto branches = GetBranchInfo(f); for (const auto& branch : branches) { - f = Rewrite(f, branch.num_branches); + f = Rewrite(f, branch); } return f; } From cc24b3d46babcbd214f13432a4e4b5d4dd21d850 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 7 Apr 2023 06:03:07 +0900 Subject: [PATCH 08/20] improve termination check in binding rewriter --- src/relax/ir/dataflow_matcher.cc | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/src/relax/ir/dataflow_matcher.cc b/src/relax/ir/dataflow_matcher.cc index 038f44854f1c..b06da62c2696 100644 --- a/src/relax/ir/dataflow_matcher.cc +++ b/src/relax/ir/dataflow_matcher.cc @@ -22,6 +22,7 @@ * \brief The dataflow pattern matcher for Relax. 
*/ +#include #include #include #include @@ -865,21 +866,20 @@ class PatternRewriter : ExprMutator { if (auto matches = MatchGraph(ctx_.value(), Downcast(block))) { builder_->BeginDataflowBlock(); Map replacements = rewriter_func_(matches.value()); - if (replacements.empty()) { - return block; - } std::unordered_set emitted_vars; + bool changed = false; for (size_t i = 0; i < block->bindings.size(); ++i) { const auto& binding = block->bindings[i]; if (auto var_bind = binding.as()) { - if (replacements.count(var_bind->var)) { - auto new_val = replacements[var_bind->var]; + if (auto new_val = replacements.Get(var_bind->var).value_or(var_bind->value); + !StructuralEqual()(var_bind->value, new_val)) { Array pending_bindings(block->bindings.begin() + i + 1, block->bindings.end()); // Make sure there is no unbound variable used in the new value before it is emitted EmitUsedVars(new_val, pending_bindings, &emitted_vars); this->ReEmitBinding(var_bind, builder_->Normalize(new_val)); + changed = true; } else if (!emitted_vars.count(var_bind->var.get())) { this->VisitBinding(binding); emitted_vars.insert(var_bind->var.get()); @@ -888,7 +888,11 @@ class PatternRewriter : ExprMutator { this->VisitBinding(binding); } } - return RewriteDataflowBlockFixedPoint(builder_->EndBlock()); + + auto new_block = builder_->EndBlock(); + + if (!changed) return new_block; + return RewriteDataflowBlockFixedPoint(new_block); } return block; } From 947e025544102ab18e71e7012faa6ba454ca2c43 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 7 Apr 2023 06:27:44 +0900 Subject: [PATCH 09/20] wip --- .../transform/combine_parallel_matmul.cc | 31 +++++++++++++++++-- 1 file changed, 29 insertions(+), 2 deletions(-) diff --git a/src/relax/transform/combine_parallel_matmul.cc b/src/relax/transform/combine_parallel_matmul.cc index cfc77976b9aa..c546815aa6b9 100644 --- a/src/relax/transform/combine_parallel_matmul.cc +++ b/src/relax/transform/combine_parallel_matmul.cc @@ -25,6 +25,8 @@ #include #include +#include "../op/nn/nn.h" +#include "../op/tensor/binary.h" #include "../op/tensor/index.h" #include "../op/tensor/linear_algebra.h" #include "../op/tensor/manipulate.h" @@ -83,7 +85,6 @@ Function Rewrite(Function f, const BranchInfo& branch_info) { if (branch_info.activation) { activation_patterns.push_back(IsOp(*branch_info.activation)(matmul_out)); } - } runtime::TypedPackedFunc(Map)> rewriter = @@ -105,21 +106,47 @@ Function Rewrite(Function f, const BranchInfo& branch_info) { if (indices.size() == 1) continue; Array rhs; + Array bias; for (auto ind : indices) { rhs.push_back(matchings[rhs_patterns[ind]]); + if (branch_info.has_bias) { + bias.push_back(matchings[bias_patterns[ind]]); + } } auto concat_rhs = concat(Tuple(rhs), Integer(rhs_dim - 1)); auto out_dtype = GetTensorSInfo(matchings[matmul_patterns[indices[0]]])->dtype; auto matmul_combined = matmul(inp, concat_rhs, out_dtype); + auto pattern_to_replace = &matmul_patterns; + + if (branch_info.has_bias) { + auto bias_dim = GetTensorSInfo(bias[0])->ndim; + auto concat_bias = concat(Tuple(bias), Integer(bias_dim - 1)); + matmul_combined = add(matmul_combined, concat_bias); + pattern_to_replace = &bias_add_patterns; + } + + if (branch_info.activation) { + pattern_to_replace = &activation_patterns; + if (*branch_info.activation == "relu") { + matmul_combined = relu(matmul_combined); + } else if (*branch_info.activation == "gelu") { + matmul_combined = gelu(matmul_combined); + } else if (*branch_info.activation == "silu") { + matmul_combined = silu(matmul_combined); + } 
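// Net effect for two branches with bias and activation (a sketch; d_i is the
// last dimension of w_i):
//   act(matmul(x, w0) + b0)  and  act(matmul(x, w1) + b1)
// become
//   c  = act(matmul(x, concat(w0, w1)) + concat(b0, b1))
//   o0 = slice(c, begin=0,  end=d0)        // along the last output axis
//   o1 = slice(c, begin=d0, end=d0 + d1)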
else { + LOG(INFO) << "Unsupported activation: " << *branch_info.activation; + } + } + PrimExpr begin{0}; Array strides{1}; int slice_axis = std::max(lhs_dim, rhs_dim) - 1; for (size_t i = 0; i < indices.size(); ++i) { auto width = GetTensorSInfo(rhs[i])->GetShape().value()[rhs_dim - 1]; - auto bound_var = matchings[matmul_patterns[indices[i]]]; + auto bound_var = matchings[(*pattern_to_replace)[indices[i]]]; auto slice = strided_slice(matmul_combined, {slice_axis}, {begin}, {begin + width}, strides); replacements.Set(bound_var, slice); From 746ae9dc2d53a1eaca34f60e54fffcd5ae673d1f Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 7 Apr 2023 07:09:18 +0900 Subject: [PATCH 10/20] properly handle rhs with same rank but different batch size --- .../transform/combine_parallel_matmul.cc | 33 +++++++++++++++---- 1 file changed, 27 insertions(+), 6 deletions(-) diff --git a/src/relax/transform/combine_parallel_matmul.cc b/src/relax/transform/combine_parallel_matmul.cc index c546815aa6b9..53cb375b2bb8 100644 --- a/src/relax/transform/combine_parallel_matmul.cc +++ b/src/relax/transform/combine_parallel_matmul.cc @@ -16,6 +16,8 @@ * specific language governing permissions and limitations * under the License. */ + +#include #include #include #include @@ -95,18 +97,34 @@ Function Rewrite(Function f, const BranchInfo& branch_info) { std::vector> rhs_shapes; for (const auto& rhs_pat : rhs_patterns) { auto r = matchings[rhs_pat]; - rhs_shapes.push_back(GetTensorSInfo(r)->GetShape().value()); + auto rhs_shape_opt = GetTensorSInfo(r)->GetShape(); + if (!rhs_shape_opt) { + return Map{}; + } + rhs_shapes.push_back(rhs_shape_opt.value()); } - auto shape_groups = GroupShapes(rhs_shapes); + auto batch_dims_compatible = [&rhs_shapes](int rhs_dim, + const std::vector& indices) { + arith::Analyzer ana; + for (auto ind : indices) { + ICHECK_EQ(static_cast(rhs_shapes[ind].size()), rhs_dim); + // -2 for reduction and concat axes + for (size_t i = 0; i < rhs_dim - 2; ++i) { + if (!ana.CanProve(rhs_shapes[indices[0]][i] == rhs_shapes[ind][i])) { + return false; + } + } + } + return true; + }; Map replacements; - for (const auto& [rhs_dim, indices] : shape_groups) { - if (indices.size() == 1) continue; + for (const auto& [rhs_dim, indices] : GroupShapes(rhs_shapes)) { + if (indices.size() == 1 || !batch_dims_compatible(rhs_dim, indices)) continue; - Array rhs; - Array bias; + Array rhs, bias; for (auto ind : indices) { rhs.push_back(matchings[rhs_patterns[ind]]); if (branch_info.has_bias) { @@ -122,6 +140,9 @@ Function Rewrite(Function f, const BranchInfo& branch_info) { if (branch_info.has_bias) { auto bias_dim = GetTensorSInfo(bias[0])->ndim; + for (auto b : bias) { + ICHECK(GetTensorSInfo(b)->ndim == bias_dim); + } auto concat_bias = concat(Tuple(bias), Integer(bias_dim - 1)); matmul_combined = add(matmul_combined, concat_bias); pattern_to_replace = &bias_add_patterns; From 9df60d5fe80cdd09b82f0e291ad7a6729c01f102 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 7 Apr 2023 08:33:55 +0900 Subject: [PATCH 11/20] support bias and activation --- .../transform/combine_parallel_matmul.cc | 112 +++++++++++++++--- 1 file changed, 98 insertions(+), 14 deletions(-) diff --git a/src/relax/transform/combine_parallel_matmul.cc b/src/relax/transform/combine_parallel_matmul.cc index 53cb375b2bb8..fdfedcd2897e 100644 --- a/src/relax/transform/combine_parallel_matmul.cc +++ b/src/relax/transform/combine_parallel_matmul.cc @@ -18,12 +18,14 @@ */ #include +#include #include #include #include #include #include 
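// Only weights whose leading (batch) dimensions are provably equal may be
// combined: e.g. RHS shapes (8, 640, 640) and (8, 640, 320) can share one
// concat on the last axis, while (8, 640, 640) and (4, 640, 640) cannot
// (see batch_dims_compatible below).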
+#include #include #include @@ -53,7 +55,7 @@ inline TensorStructInfo GetTensorSInfo(Expr e) { struct BranchInfo { int num_branches; - bool has_bias; + std::optional bias_dim; std::optional activation; }; @@ -76,10 +78,12 @@ Function Rewrite(Function f, const BranchInfo& branch_info) { CallPattern matmul_out = matmul_pat; - if (branch_info.has_bias) { + if (branch_info.bias_dim) { auto bias_pat = Wildcard(); bias_patterns.push_back(bias_pat); auto bias_add = IsOp("relax.add")(matmul_pat, bias_pat); + ctx.add_constraint(matmul_pat, bias_add, PairCons(PairCons::kUsedBy, 0)); + ctx.add_constraint(bias_pat, bias_add, PairCons(PairCons::kUsedBy, 1)); bias_add_patterns.push_back(bias_add); matmul_out = bias_add; } @@ -127,7 +131,8 @@ Function Rewrite(Function f, const BranchInfo& branch_info) { Array rhs, bias; for (auto ind : indices) { rhs.push_back(matchings[rhs_patterns[ind]]); - if (branch_info.has_bias) { + if (branch_info.bias_dim) { + ICHECK(matchings.count(bias_patterns[ind])); bias.push_back(matchings[bias_patterns[ind]]); } } @@ -138,7 +143,7 @@ Function Rewrite(Function f, const BranchInfo& branch_info) { auto pattern_to_replace = &matmul_patterns; - if (branch_info.has_bias) { + if (branch_info.bias_dim) { auto bias_dim = GetTensorSInfo(bias[0])->ndim; for (auto b : bias) { ICHECK(GetTensorSInfo(b)->ndim == bias_dim); @@ -157,7 +162,7 @@ Function Rewrite(Function f, const BranchInfo& branch_info) { } else if (*branch_info.activation == "silu") { matmul_combined = silu(matmul_combined); } else { - LOG(INFO) << "Unsupported activation: " << *branch_info.activation; + LOG(FATAL) << "Unsupported activation: " << *branch_info.activation; } } @@ -184,25 +189,104 @@ Function Rewrite(Function f, const BranchInfo& branch_info) { } std::vector GetBranchInfo(Function f) { - std::unordered_map groups; - static auto matmul_op = Op::Get("relax.matmul"); + auto lhs_pat = Wildcard(); + auto rhs_pat = Wildcard(); + auto bias_pat = Wildcard(); + + auto matmul_pat = IsOp("relax.matmul")(lhs_pat, rhs_pat); + auto bias_add_pat = IsOp("relax.add")(matmul_pat, bias_pat); + + std::vector activations{"relax.nn.relu", "relax.nn.gelu", "relax.nn.silu"}; + + std::vector activation_pat, bias_activation_pat; + for (const auto& act : activations) { + activation_pat.push_back(IsOp(act)(matmul_pat)); + bias_activation_pat.push_back(IsOp(act)(bias_add_pat)); + } + + auto bindings = AnalyzeVar2Value(f); + + std::unordered_map groups_activation, groups_bias, groups_matmul; + + PostOrderVisit(f, [&](const Expr& e) { + if (!e->IsInstance()) return; + if (auto match = ExtractMatchedExpr(bias_add_pat, e, bindings)) { + auto matmul_call = Downcast(match.value()[matmul_pat]); + auto matmul_lhs = Downcast(matmul_call->args[0]); + auto bias_dim = GetTensorSInfo(match.value()[bias_pat])->ndim; + if (auto it = groups_bias.find(matmul_lhs.get()); it == groups_bias.end()) { + groups_bias[matmul_lhs.get()] = {1, bias_dim, std::nullopt}; + } else { + it->second.num_branches += 1; + if (it->second.bias_dim && *it->second.bias_dim != bias_dim) { + it->second.bias_dim = std::nullopt; + } + } + return; + } + }); PostOrderVisit(f, [&](const Expr& e) { - if (auto call = e.as(); call && call->op.same_as(matmul_op)) { - auto lhs = Downcast(call->args[0]); - if (auto it = groups.find(lhs.get()); it == groups.end()) { - groups[lhs.get()] = {1, false, std::nullopt}; + if (!e->IsInstance()) return; + if (auto match = ExtractMatchedExpr(matmul_pat, e, bindings)) { + auto matmul_call = Downcast(match.value()[matmul_pat]); + auto matmul_lhs = 
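// Grouping sketch: matmul(x, w0) + b0 and matmul(x, w1) + b1 with rank-1
// biases yield groups_bias[x] = {num_branches: 2, bias_dim: 1}; if the bias
// ranks disagree, bias_dim is reset to nullopt and the bias concat is skipped.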
Downcast(matmul_call->args[0]); + if (groups_bias.count(matmul_lhs.get()) || groups_activation.count(matmul_lhs.get())) return; + if (auto it = groups_matmul.find(matmul_lhs.get()); it == groups_matmul.end()) { + groups_matmul[matmul_lhs.get()] = {1, std::nullopt, std::nullopt}; } else { it->second.num_branches += 1; } + return; } }); + // for (size_t i = 0; i < activations.size(); ++i) { + // if (auto match = ExtractMatchedExpr(bias_activation_pat[i], e, bindings)) { + // auto matmul_lhs = Downcast(match.value()[lhs_pat]); + // auto bias_dim = GetTensorSInfo(match.value()[bias_pat])->ndim; + // if (auto it = groups.find(matmul_lhs.get()); it == groups.end()) { + // groups[matmul_lhs.get()] = {1, bias_dim, activations[i]}; + // } else { + // it->second.num_branches += 1; + + // if (it->second.bias_dim != bias_dim) { + // it->second.bias_dim = std::nullopt; + // } + + // if (!it->second.activation || (*it->second.activation != activations[i])) { + // it->second.activation = std::nullopt; + // } + // } + + // for (auto pat : {matmul_pat, bias_add_pat}) { + // seen.insert(match.value()[pat].get()); + // } + // return; + // } + // if (auto match = ExtractMatchedExpr(activation_pat[i], e, bindings)) { + // auto matmul = match.value()[matmul_pat]; + // auto matmul_lhs = Downcast(match.value()[lhs_pat]); + // if (auto it = groups.find(matmul_lhs.get()); it == groups.end()) { + // groups[matmul_lhs.get()] = {1, std::nullopt, activations[i]}; + // } else { + // it->second.num_branches += 1; + + // if (!it->second.activation || (*it->second.activation != activations[i])) { + // it->second.activation = std::nullopt; + // } + // } + // return; + // } + // } + std::vector info; - for (const auto& group : groups) { - if (group.second.num_branches > 1) { - info.push_back(group.second); + for (auto groups : {groups_matmul, groups_activation, groups_bias}) { + for (const auto& group : groups) { + if (group.second.num_branches > 1) { + info.push_back(group.second); + } } } From 3c8848b1c49cd1530d4762aa1e90c7d8faa99089 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sat, 8 Apr 2023 06:47:09 +0900 Subject: [PATCH 12/20] refactor --- .../transform/combine_parallel_matmul.cc | 343 +++++++++--------- 1 file changed, 169 insertions(+), 174 deletions(-) diff --git a/src/relax/transform/combine_parallel_matmul.cc b/src/relax/transform/combine_parallel_matmul.cc index fdfedcd2897e..45324be91317 100644 --- a/src/relax/transform/combine_parallel_matmul.cc +++ b/src/relax/transform/combine_parallel_matmul.cc @@ -59,141 +59,147 @@ struct BranchInfo { std::optional activation; }; -Function Rewrite(Function f, const BranchInfo& branch_info) { +struct Patterns { + Patterns() : input(Wildcard()) { ctx.EnterWithScope(); } + PatternContext ctx; - ctx.EnterWithScope(); + WildcardPattern input; + std::vector rhs; + std::vector bias; + std::vector matmul, bias_add, activation; +}; - auto input_pattern = Wildcard(); - std::vector rhs_patterns; - std::vector bias_patterns; - std::vector matmul_patterns, bias_add_patterns, activation_patterns; +Patterns CreatePatterns(const BranchInfo& branch_info) { + Patterns patterns; for (int i = 0; i < branch_info.num_branches; ++i) { auto w_pat = Wildcard(); - rhs_patterns.push_back(w_pat); - auto matmul_pat = IsOp("relax.matmul")(input_pattern, w_pat); - matmul_patterns.push_back(matmul_pat); - ctx.add_constraint(input_pattern, matmul_pat, PairCons(PairCons::kUsedBy, 0)); - ctx.add_constraint(w_pat, matmul_pat, PairCons(PairCons::kUsedBy, 1)); + auto matmul_pat = 
IsOp("relax.matmul")(patterns.input, w_pat); + patterns.rhs.push_back(w_pat); + patterns.matmul.push_back(matmul_pat); + patterns.ctx.add_constraint(patterns.input, matmul_pat, PairCons(PairCons::kUsedBy, 0)); + patterns.ctx.add_constraint(w_pat, matmul_pat, PairCons(PairCons::kUsedBy, 1)); CallPattern matmul_out = matmul_pat; if (branch_info.bias_dim) { auto bias_pat = Wildcard(); - bias_patterns.push_back(bias_pat); - auto bias_add = IsOp("relax.add")(matmul_pat, bias_pat); - ctx.add_constraint(matmul_pat, bias_add, PairCons(PairCons::kUsedBy, 0)); - ctx.add_constraint(bias_pat, bias_add, PairCons(PairCons::kUsedBy, 1)); - bias_add_patterns.push_back(bias_add); - matmul_out = bias_add; + auto bias_add_pat = IsOp("relax.add")(matmul_pat, bias_pat); + patterns.bias.push_back(bias_pat); + patterns.bias_add.push_back(bias_add_pat); + patterns.ctx.add_constraint(matmul_pat, bias_add_pat, PairCons(PairCons::kUsedBy, 0)); + patterns.ctx.add_constraint(bias_pat, bias_add_pat, PairCons(PairCons::kUsedBy, 1)); + matmul_out = bias_add_pat; } if (branch_info.activation) { - activation_patterns.push_back(IsOp(*branch_info.activation)(matmul_out)); + auto act_pat = IsOp(*branch_info.activation)(matmul_out); + patterns.activation.push_back(act_pat); + patterns.ctx.add_constraint(matmul_out, act_pat, PairCons(PairCons::kUsedBy, 0)); } } - runtime::TypedPackedFunc(Map)> rewriter = - [=](Map matchings) { - auto inp = matchings[input_pattern]; - auto lhs_dim = GetTensorSInfo(inp)->ndim; - - std::vector> rhs_shapes; - for (const auto& rhs_pat : rhs_patterns) { - auto r = matchings[rhs_pat]; - auto rhs_shape_opt = GetTensorSInfo(r)->GetShape(); - if (!rhs_shape_opt) { - return Map{}; - } - rhs_shapes.push_back(rhs_shape_opt.value()); - } - - auto batch_dims_compatible = [&rhs_shapes](int rhs_dim, - const std::vector& indices) { - arith::Analyzer ana; - for (auto ind : indices) { - ICHECK_EQ(static_cast(rhs_shapes[ind].size()), rhs_dim); - // -2 for reduction and concat axes - for (size_t i = 0; i < rhs_dim - 2; ++i) { - if (!ana.CanProve(rhs_shapes[indices[0]][i] == rhs_shapes[ind][i])) { - return false; - } - } - } - return true; - }; - - Map replacements; - - for (const auto& [rhs_dim, indices] : GroupShapes(rhs_shapes)) { - if (indices.size() == 1 || !batch_dims_compatible(rhs_dim, indices)) continue; + return patterns; +} - Array rhs, bias; - for (auto ind : indices) { - rhs.push_back(matchings[rhs_patterns[ind]]); - if (branch_info.bias_dim) { - ICHECK(matchings.count(bias_patterns[ind])); - bias.push_back(matchings[bias_patterns[ind]]); - } - } +runtime::TypedPackedFunc(Map)> GetRewriter( + const Patterns& patterns, const BranchInfo& branch_info) { + auto batch_dims_compatible = [](int rhs_dim, const std::vector& indices, + const std::vector>& rhs_shapes) { + arith::Analyzer ana; + for (auto ind : indices) { + ICHECK_EQ(static_cast(rhs_shapes[ind].size()), rhs_dim); + // -2 for reduction and concat axes + for (size_t i = 0; i < rhs_dim - 2; ++i) { + if (!ana.CanProve(rhs_shapes[indices[0]][i] == rhs_shapes[ind][i])) { + return false; + } + } + } + return true; + }; + + return [=](Map matchings) { + std::vector> rhs_shapes; + for (const auto& rhs_pat : patterns.rhs) { + auto rhs_shape_opt = GetTensorSInfo(matchings[rhs_pat])->GetShape(); + if (!rhs_shape_opt) { + return Map{}; + } + rhs_shapes.push_back(rhs_shape_opt.value()); + } - auto concat_rhs = concat(Tuple(rhs), Integer(rhs_dim - 1)); - auto out_dtype = GetTensorSInfo(matchings[matmul_patterns[indices[0]]])->dtype; - auto matmul_combined = 
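// Worked example (mirrors the expected test outputs further below): three
// (640, 640) weights concatenate to (640, 1920), and each branch is then
// recovered from the combined product as
//   strided_slice(combined, axes={1}, begin={0},    end={640})
//   strided_slice(combined, axes={1}, begin={640},  end={1280})
//   strided_slice(combined, axes={1}, begin={1280}, end={1920})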
matmul(inp, concat_rhs, out_dtype); + Map replacements; - auto pattern_to_replace = &matmul_patterns; + for (const auto& [rhs_dim, indices] : GroupShapes(rhs_shapes)) { + if (indices.size() == 1 || !batch_dims_compatible(rhs_dim, indices, rhs_shapes)) continue; - if (branch_info.bias_dim) { - auto bias_dim = GetTensorSInfo(bias[0])->ndim; - for (auto b : bias) { - ICHECK(GetTensorSInfo(b)->ndim == bias_dim); - } - auto concat_bias = concat(Tuple(bias), Integer(bias_dim - 1)); - matmul_combined = add(matmul_combined, concat_bias); - pattern_to_replace = &bias_add_patterns; - } + Array rhs, bias; + for (auto ind : indices) { + rhs.push_back(matchings[patterns.rhs[ind]]); + if (branch_info.bias_dim) { + ICHECK(matchings.count(patterns.bias[ind])); + bias.push_back(matchings[patterns.bias[ind]]); + } + } - if (branch_info.activation) { - pattern_to_replace = &activation_patterns; - if (*branch_info.activation == "relu") { - matmul_combined = relu(matmul_combined); - } else if (*branch_info.activation == "gelu") { - matmul_combined = gelu(matmul_combined); - } else if (*branch_info.activation == "silu") { - matmul_combined = silu(matmul_combined); - } else { - LOG(FATAL) << "Unsupported activation: " << *branch_info.activation; - } - } + auto inp = matchings[patterns.input]; + auto concat_rhs = concat(Tuple(rhs), Integer(rhs_dim - 1)); + auto out_dtype = GetTensorSInfo(matchings[patterns.matmul[indices[0]]])->dtype; + auto matmul_combined = matmul(inp, concat_rhs, out_dtype); + + const auto& pattern_to_replace = [&patterns, &branch_info]() { + if (branch_info.activation) return patterns.activation; + if (branch_info.bias_dim) return patterns.bias_add; + return patterns.matmul; + }(); + + if (branch_info.bias_dim) { + auto bias_dim = GetTensorSInfo(bias[0])->ndim; + auto concat_bias = concat(Tuple(bias), Integer(bias_dim - 1)); + matmul_combined = add(matmul_combined, concat_bias); + } - PrimExpr begin{0}; - Array strides{1}; - int slice_axis = std::max(lhs_dim, rhs_dim) - 1; - - for (size_t i = 0; i < indices.size(); ++i) { - auto width = GetTensorSInfo(rhs[i])->GetShape().value()[rhs_dim - 1]; - auto bound_var = matchings[(*pattern_to_replace)[indices[i]]]; - auto slice = - strided_slice(matmul_combined, {slice_axis}, {begin}, {begin + width}, strides); - replacements.Set(bound_var, slice); - begin += width; - } + if (branch_info.activation) { + if (*branch_info.activation == "relu") { + matmul_combined = relu(matmul_combined); + } else if (*branch_info.activation == "gelu") { + matmul_combined = gelu(matmul_combined); + } else if (*branch_info.activation == "silu") { + matmul_combined = silu(matmul_combined); + } else { + LOG(FATAL) << "Unsupported activation: " << *branch_info.activation; } + } - return replacements; - }; + PrimExpr begin{0}; + Array strides{1}; + int lhs_dim = GetTensorSInfo(inp)->ndim; + int slice_axis = std::max(lhs_dim, rhs_dim) - 1; + + for (size_t i = 0; i < indices.size(); ++i) { + auto width = GetTensorSInfo(rhs[i])->GetShape().value()[rhs_dim - 1]; + auto bound_var = matchings[pattern_to_replace[indices[i]]]; + auto slice = + strided_slice(matmul_combined, {slice_axis}, {begin}, {begin + width}, strides); + replacements.Set(bound_var, slice); + begin += width; + } + } - auto rewritten = RewriteBindings(ctx, rewriter, f); - ctx.ExitWithScope(); - return rewritten; + return replacements; + }; +} + +Function Rewrite(Function f, const BranchInfo& branch_info) { + auto patterns = CreatePatterns(branch_info); + auto rewriter = GetRewriter(patterns, branch_info); + return 
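// RewriteBindings applies `rewriter` to each match and re-runs the rewrite
// until no binding changes (RewriteDataflowBlockFixedPoint above). From the
// Python side the whole pass is invoked as, e.g.:
//   mod = tvm.relax.transform.CombineParallelMatmul()(mod)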
RewriteBindings(patterns.ctx, rewriter, f); } std::vector GetBranchInfo(Function f) { - auto lhs_pat = Wildcard(); - auto rhs_pat = Wildcard(); auto bias_pat = Wildcard(); - - auto matmul_pat = IsOp("relax.matmul")(lhs_pat, rhs_pat); + auto matmul_pat = IsOp("relax.matmul")(Wildcard(), Wildcard()); auto bias_add_pat = IsOp("relax.add")(matmul_pat, bias_pat); std::vector activations{"relax.nn.relu", "relax.nn.gelu", "relax.nn.silu"}; @@ -206,79 +212,68 @@ std::vector GetBranchInfo(Function f) { auto bindings = AnalyzeVar2Value(f); - std::unordered_map groups_activation, groups_bias, groups_matmul; - - PostOrderVisit(f, [&](const Expr& e) { - if (!e->IsInstance()) return; - if (auto match = ExtractMatchedExpr(bias_add_pat, e, bindings)) { - auto matmul_call = Downcast(match.value()[matmul_pat]); - auto matmul_lhs = Downcast(matmul_call->args[0]); - auto bias_dim = GetTensorSInfo(match.value()[bias_pat])->ndim; - if (auto it = groups_bias.find(matmul_lhs.get()); it == groups_bias.end()) { - groups_bias[matmul_lhs.get()] = {1, bias_dim, std::nullopt}; - } else { - it->second.num_branches += 1; - if (it->second.bias_dim && *it->second.bias_dim != bias_dim) { - it->second.bias_dim = std::nullopt; + using BranchGroups = std::unordered_map; + + auto create_group = [&](DFPattern pat, const std::vector& ignore_groups) { + BranchGroups groups; + + PostOrderVisit(f, [&](const Expr& e) { + if (!e->IsInstance()) return; + if (auto match = ExtractMatchedExpr(pat, e, bindings)) { + auto matmul_call = Downcast(match.value()[matmul_pat]); + auto matmul_lhs = Downcast(matmul_call->args[0]); + + for (const auto& prev_group : ignore_groups) { + if (prev_group.count(matmul_lhs.get())) return; } + + auto it = groups.find(matmul_lhs.get()); + BranchInfo* branch = it != groups.end() ? 
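// Possible outcome (a sketch): two gelu branches sharing `x` would yield
//   groups[x] == BranchInfo{2, std::nullopt, "relax.nn.gelu"}
// i.e. two combinable matmuls with no common bias and a common activation.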
&it->second : nullptr; + std::optional bias_dim = std::nullopt; + std::optional activation = std::nullopt; + + if (match.value().count(bias_pat)) { + bias_dim = GetTensorSInfo(match.value()[bias_pat])->ndim; + } + + for (size_t i = 0; i < activations.size(); ++i) { + if (match.value().count(activation_pat[i])) { + activation = activations[i]; + } + } + + if (!branch) { + groups[matmul_lhs.get()] = {1, bias_dim, activation}; + } else { + branch->num_branches += 1; + + if (branch->bias_dim && branch->bias_dim != bias_dim) { + branch->bias_dim = std::nullopt; + } + + if ((branch->activation && activation) && *branch->activation != *activation) { + branch->activation = std::nullopt; + } + } + return; } - return; - } - }); - - PostOrderVisit(f, [&](const Expr& e) { - if (!e->IsInstance()) return; - if (auto match = ExtractMatchedExpr(matmul_pat, e, bindings)) { - auto matmul_call = Downcast(match.value()[matmul_pat]); - auto matmul_lhs = Downcast(matmul_call->args[0]); - if (groups_bias.count(matmul_lhs.get()) || groups_activation.count(matmul_lhs.get())) return; - if (auto it = groups_matmul.find(matmul_lhs.get()); it == groups_matmul.end()) { - groups_matmul[matmul_lhs.get()] = {1, std::nullopt, std::nullopt}; - } else { - it->second.num_branches += 1; - } - return; - } - }); + }); + + return groups; + }; + BranchGroups groups_activation; // for (size_t i = 0; i < activations.size(); ++i) { - // if (auto match = ExtractMatchedExpr(bias_activation_pat[i], e, bindings)) { - // auto matmul_lhs = Downcast(match.value()[lhs_pat]); - // auto bias_dim = GetTensorSInfo(match.value()[bias_pat])->ndim; - // if (auto it = groups.find(matmul_lhs.get()); it == groups.end()) { - // groups[matmul_lhs.get()] = {1, bias_dim, activations[i]}; - // } else { - // it->second.num_branches += 1; - - // if (it->second.bias_dim != bias_dim) { - // it->second.bias_dim = std::nullopt; - // } - - // if (!it->second.activation || (*it->second.activation != activations[i])) { - // it->second.activation = std::nullopt; - // } - // } - - // for (auto pat : {matmul_pat, bias_add_pat}) { - // seen.insert(match.value()[pat].get()); - // } - // return; - // } - // if (auto match = ExtractMatchedExpr(activation_pat[i], e, bindings)) { - // auto matmul = match.value()[matmul_pat]; - // auto matmul_lhs = Downcast(match.value()[lhs_pat]); - // if (auto it = groups.find(matmul_lhs.get()); it == groups.end()) { - // groups[matmul_lhs.get()] = {1, std::nullopt, activations[i]}; - // } else { - // it->second.num_branches += 1; - - // if (!it->second.activation || (*it->second.activation != activations[i])) { - // it->second.activation = std::nullopt; - // } - // } - // return; - // } + // auto groups = create_group(bias_activation_pat[i], {}); + // groups_activation.merge(std::move(groups)); // } + // for (size_t i = 0; i < activations.size(); ++i) { + // auto groups = create_group(activation_pat[i], {}); + // groups_activation.merge(std::move(groups)); + // } + + auto groups_bias = create_group(bias_add_pat, {groups_activation}); + auto groups_matmul = create_group(matmul_pat, {groups_activation, groups_bias}); std::vector info; From 8b4ae01162dbd867662a20a82da067771ee429ec Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 11 Apr 2023 11:15:03 +0900 Subject: [PATCH 13/20] fixed activation handling --- .../transform/combine_parallel_matmul.cc | 21 +++++++++++-------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/src/relax/transform/combine_parallel_matmul.cc b/src/relax/transform/combine_parallel_matmul.cc index 
45324be91317..9ed9c6e34a9c 100644 --- a/src/relax/transform/combine_parallel_matmul.cc +++ b/src/relax/transform/combine_parallel_matmul.cc @@ -224,7 +224,9 @@ std::vector GetBranchInfo(Function f) { auto matmul_lhs = Downcast(matmul_call->args[0]); for (const auto& prev_group : ignore_groups) { - if (prev_group.count(matmul_lhs.get())) return; + if (auto it = prev_group.find(matmul_lhs.get()); + it != prev_group.end() && it->second.num_branches > 1) + return; } auto it = groups.find(matmul_lhs.get()); @@ -263,14 +265,15 @@ std::vector GetBranchInfo(Function f) { }; BranchGroups groups_activation; - // for (size_t i = 0; i < activations.size(); ++i) { - // auto groups = create_group(bias_activation_pat[i], {}); - // groups_activation.merge(std::move(groups)); - // } - // for (size_t i = 0; i < activations.size(); ++i) { - // auto groups = create_group(activation_pat[i], {}); - // groups_activation.merge(std::move(groups)); - // } + for (size_t i = 0; i < activations.size(); ++i) { + auto groups = create_group(bias_activation_pat[i], {}); + groups_activation.merge(std::move(groups)); + } + + for (size_t i = 0; i < activations.size(); ++i) { + auto groups = create_group(activation_pat[i], {}); + groups_activation.merge(std::move(groups)); + } auto groups_bias = create_group(bias_add_pat, {groups_activation}); auto groups_matmul = create_group(matmul_pat, {groups_activation, groups_bias}); From e0aa3a0d2d3b423b5ef5104d72e8b8cec456fd97 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 11 Apr 2023 12:35:31 +0900 Subject: [PATCH 14/20] wip --- tests/python/relax/test_dataflow_pattern.py | 6 +- .../test_transform_combine_parallel_matmul.py | 97 +++++++++++++++++++ 2 files changed, 100 insertions(+), 3 deletions(-) create mode 100644 tests/python/relax/test_transform_combine_parallel_matmul.py diff --git a/tests/python/relax/test_dataflow_pattern.py b/tests/python/relax/test_dataflow_pattern.py index a73a62eeef8d..ed221f54be1e 100644 --- a/tests/python/relax/test_dataflow_pattern.py +++ b/tests/python/relax/test_dataflow_pattern.py @@ -18,7 +18,7 @@ import pytest import tvm.testing -from tvm import relay, relax +from tvm import relay from tvm.relax.dpl import * from tvm.relax.analysis import get_var2val from tvm import relax as rx, tir @@ -1177,9 +1177,9 @@ def expected( # make sure it builds mod = tvm.IRModule() mod["main"] = rewritten - mod = relax.transform.LegalizeOps()(mod) + mod = rx.transform.LegalizeOps()(mod) - relax.build(mod, target="llvm") + rx.build(mod, target="llvm") if __name__ == "__main__": diff --git a/tests/python/relax/test_transform_combine_parallel_matmul.py b/tests/python/relax/test_transform_combine_parallel_matmul.py new file mode 100644 index 000000000000..76d223671d15 --- /dev/null +++ b/tests/python/relax/test_transform_combine_parallel_matmul.py @@ -0,0 +1,97 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. +import tvm.testing + +from tvm import relax, tir +from tvm.script import relax as R, tir as T +from tvm.relax.transform import CombineParallelMatmul +from tvm.script.ir_builder import IRBuilder +from tvm.script.ir_builder import relax as relax_builder + + +def get_parallel_matmul( + num_branches, + with_bias=False, + activation=None, +): + shape = (640, 640) + dtype = "float32" + + with IRBuilder() as builder: + with relax_builder.function(): + R.func_name("main") + x = R.arg("x", R.Tensor(shape, dtype)) + + rhs = [] + bias = [] + + for _ in range(num_branches): + rhs.append(R.arg("y", R.Tensor(shape, dtype))) + + if with_bias: + bias.append(R.arg("bias", R.Tensor((shape[1],), dtype))) + + with R.dataflow() as frame: + branches = [] + + for i, r in enumerate(rhs): + result = R.emit(R.matmul(x, r, out_dtype=dtype)) + if with_bias: + result = R.emit(result + bias[i]) + if activation is not None: + result = R.emit(activation(result)) + + branches.append(result) + + R.output(R.emit(R.concat(branches, axis=1))) + + R.func_ret_value(frame.output_vars[0]) + + func = builder.get() + return tvm.IRModule({"main": func}) + + +def test_attention_qkv(): + @tvm.script.ir_module + class QKV_proj: + @R.function + def main( + x: R.Tensor((2, 1024, 640), "float32"), + w0: R.Tensor((640, 640), "float32"), + w1: R.Tensor((640, 640), "float32"), + w2: R.Tensor((640, 640), "float32"), + ) -> R.Tensor: + with R.dataflow(): + lv0 = R.matmul(x, w0) + lv1 = R.matmul(x, w1) + lv2 = R.matmul(x, w2) + out = (lv0, lv1, lv2) + R.output(out) + return out + + mod = get_parallel_matmul(3) + + # tvm.ir.assert_structural_equal(mod, QKV_proj) + mod = CombineParallelMatmul()(mod) + + + print(mod) + + +if __name__ == "__main__": + # tvm.testing.main() + test_attention_qkv() From c3b20b7dca45978248c765d33d10849752b076f2 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 11 Apr 2023 17:11:18 +0900 Subject: [PATCH 15/20] clean --- .../test_transform_combine_parallel_matmul.py | 77 +++++++++++++------ 1 file changed, 53 insertions(+), 24 deletions(-) diff --git a/tests/python/relax/test_transform_combine_parallel_matmul.py b/tests/python/relax/test_transform_combine_parallel_matmul.py index 76d223671d15..e345096e8a0a 100644 --- a/tests/python/relax/test_transform_combine_parallel_matmul.py +++ b/tests/python/relax/test_transform_combine_parallel_matmul.py @@ -25,25 +25,26 @@ def get_parallel_matmul( num_branches, + lhs_shape=(640, 640), + rhs_shape=(640, 640), with_bias=False, activation=None, ): - shape = (640, 640) dtype = "float32" with IRBuilder() as builder: with relax_builder.function(): R.func_name("main") - x = R.arg("x", R.Tensor(shape, dtype)) + x = R.arg("x", R.Tensor(lhs_shape, dtype)) rhs = [] bias = [] for _ in range(num_branches): - rhs.append(R.arg("y", R.Tensor(shape, dtype))) + rhs.append(R.arg("y", R.Tensor(rhs_shape, dtype))) if with_bias: - bias.append(R.arg("bias", R.Tensor((shape[1],), dtype))) + bias.append(R.arg("bias", R.Tensor((rhs_shape[1],), dtype))) with R.dataflow() as frame: branches = [] @@ -65,33 +66,61 @@ def get_parallel_matmul( return tvm.IRModule({"main": func}) -def test_attention_qkv(): - @tvm.script.ir_module - class QKV_proj: - @R.function - def main( - x: R.Tensor((2, 1024, 640), "float32"), - w0: R.Tensor((640, 640), "float32"), - w1: R.Tensor((640, 640), "float32"), - w2: R.Tensor((640, 640), "float32"), - ) -> R.Tensor: - with R.dataflow(): - lv0 = R.matmul(x, w0) - lv1 = 
R.matmul(x, w1) - lv2 = R.matmul(x, w2) - out = (lv0, lv1, lv2) - R.output(out) - return out - +def test_simple(): mod = get_parallel_matmul(3) + mod = CombineParallelMatmul()(mod) - # tvm.ir.assert_structural_equal(mod, QKV_proj) + @R.function + def expected1( + x: R.Tensor((640, 640), dtype="float32"), + y: R.Tensor((640, 640), dtype="float32"), + y_1: R.Tensor((640, 640), dtype="float32"), + y_2: R.Tensor((640, 640), dtype="float32"), + ) -> R.Tensor((640, 1920), dtype="float32"): + with R.dataflow(): + lv = R.concat((y, y_1, y_2), axis=1) + lv1 = R.matmul(x, lv, out_dtype="float32") + lv_1 = R.strided_slice(lv1, axes=[1], begin=[0], end=[640], strides=[1]) + lv1_1 = R.strided_slice(lv1, axes=[1], begin=[640], end=[1280], strides=[1]) + lv2 = R.strided_slice(lv1, axes=[1], begin=[1280], end=[1920], strides=[1]) + lv3 = R.concat((lv_1, lv1_1, lv2), axis=1) + R.output(lv3) + return lv3 + + tvm.ir.assert_structural_equal(mod["main"], expected1) + + # Test a batched LHS case, slicing is done on the axis 2 + mod = get_parallel_matmul(3, lhs_shape=(2, 1024, 640)) mod = CombineParallelMatmul()(mod) + @R.function + def expected2( + x: R.Tensor((2, 1024, 640), dtype="float32"), + y: R.Tensor((640, 640), dtype="float32"), + y_1: R.Tensor((640, 640), dtype="float32"), + y_2: R.Tensor((640, 640), dtype="float32"), + ) -> R.Tensor((2, 3072, 640), dtype="float32"): + with R.dataflow(): + lv = R.concat((y, y_1, y_2), axis=1) + lv1 = R.matmul(x, lv, out_dtype="float32") + lv_1 = R.strided_slice(lv1, axes=[2], begin=[0], end=[640], strides=[1]) + lv1_1 = R.strided_slice(lv1, axes=[2], begin=[640], end=[1280], strides=[1]) + lv2 = R.strided_slice(lv1, axes=[2], begin=[1280], end=[1920], strides=[1]) + lv3 = R.concat((lv_1, lv1_1, lv2), axis=1) + R.output(lv3) + return lv3 + + tvm.ir.assert_structural_equal(mod["main"], expected2) + + +def test_bias(): + mod = get_parallel_matmul(3, with_bias=True) + print(mod) + mod = CombineParallelMatmul()(mod) print(mod) if __name__ == "__main__": # tvm.testing.main() - test_attention_qkv() + test_simple() From d7f83cdfdfc2bf38afd7b807bcdaf3623b780b9a Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 11 Apr 2023 17:49:35 +0900 Subject: [PATCH 16/20] fix bias and activation combine logic --- .../transform/combine_parallel_matmul.cc | 38 ++++++----- .../test_transform_combine_parallel_matmul.py | 65 ++++++++++++++++--- 2 files changed, 78 insertions(+), 25 deletions(-) diff --git a/src/relax/transform/combine_parallel_matmul.cc b/src/relax/transform/combine_parallel_matmul.cc index 9ed9c6e34a9c..c62141cf35a7 100644 --- a/src/relax/transform/combine_parallel_matmul.cc +++ b/src/relax/transform/combine_parallel_matmul.cc @@ -212,10 +212,8 @@ std::vector GetBranchInfo(Function f) { auto bindings = AnalyzeVar2Value(f); - using BranchGroups = std::unordered_map; - - auto create_group = [&](DFPattern pat, const std::vector& ignore_groups) { - BranchGroups groups; + auto create_group = [&](DFPattern pat) { + std::unordered_map groups; PostOrderVisit(f, [&](const Expr& e) { if (!e->IsInstance()) return; @@ -223,12 +221,6 @@ std::vector GetBranchInfo(Function f) { auto matmul_call = Downcast(match.value()[matmul_pat]); auto matmul_lhs = Downcast(matmul_call->args[0]); - for (const auto& prev_group : ignore_groups) { - if (auto it = prev_group.find(matmul_lhs.get()); - it != prev_group.end() && it->second.num_branches > 1) - return; - } - auto it = groups.find(matmul_lhs.get()); BranchInfo* branch = it != groups.end() ? 
&it->second : nullptr;
       std::optional<int> bias_dim = std::nullopt;
@@ -249,11 +241,11 @@ std::vector<BranchInfo> GetBranchInfo(Function f) {
       } else {
         branch->num_branches += 1;
 
-        if (branch->bias_dim && branch->bias_dim != bias_dim) {
+        if (!bias_dim || (branch->bias_dim && *branch->bias_dim != *bias_dim)) {
           branch->bias_dim = std::nullopt;
         }
 
-        if ((branch->activation && activation) && *branch->activation != *activation) {
+        if (!activation || (branch->activation && *branch->activation != *activation)) {
           branch->activation = std::nullopt;
         }
       }
@@ -264,19 +256,31 @@ std::vector<BranchInfo> GetBranchInfo(Function f) {
     return groups;
   };
 
-  BranchGroups groups_activation;
+  std::unordered_map<const VarNode*, BranchInfo> groups_activation;
 
   for (size_t i = 0; i < activations.size(); ++i) {
-    auto groups = create_group(bias_activation_pat[i], {});
+    auto groups = create_group(bias_activation_pat[i]);
     groups_activation.merge(std::move(groups));
   }
 
   for (size_t i = 0; i < activations.size(); ++i) {
-    auto groups = create_group(activation_pat[i], {});
+    auto groups = create_group(activation_pat[i]);
     groups_activation.merge(std::move(groups));
   }
 
-  auto groups_bias = create_group(bias_add_pat, {groups_activation});
-  auto groups_matmul = create_group(matmul_pat, {groups_activation, groups_bias});
+  auto groups_bias = create_group(bias_add_pat);
+  auto groups_matmul = create_group(matmul_pat);
+
+  for (auto groups : {groups_bias, groups_activation}) {
+    for (const auto& [lhs, branch] : groups) {
+      // Prefer combining more matmuls than combining fewer ones and leaving additional uncombined
+      // matmuls followed by bias or activation So we combine matmuls + fused ops patterns only when
+      // all branches have the same fused ops.
+      if (auto it = groups_matmul.find(lhs);
+          it != groups_matmul.end() && it->second.num_branches == branch.num_branches) {
+        it->second = branch;
+      }
+    }
+  }
 
   std::vector<BranchInfo> info;
 
diff --git a/tests/python/relax/test_transform_combine_parallel_matmul.py b/tests/python/relax/test_transform_combine_parallel_matmul.py
index e345096e8a0a..bfeaae4b612e 100644
--- a/tests/python/relax/test_transform_combine_parallel_matmul.py
+++ b/tests/python/relax/test_transform_combine_parallel_matmul.py
@@ -27,7 +27,7 @@ def get_parallel_matmul(
     num_branches,
     lhs_shape=(640, 640),
     rhs_shape=(640, 640),
-    with_bias=False,
+    with_bias=None,
     activation=None,
 ):
     dtype = "float32"
@@ -40,18 +40,20 @@ def get_parallel_matmul(
         rhs = []
         bias = []
 
-        for _ in range(num_branches):
+        for i in range(num_branches):
             rhs.append(R.arg("y", R.Tensor(rhs_shape, dtype)))
 
-            if with_bias:
+            if with_bias and with_bias[i]:
                 bias.append(R.arg("bias", R.Tensor((rhs_shape[1],), dtype)))
+            else:
+                bias.append(None)
 
         with R.dataflow() as frame:
             branches = []
 
             for i, r in enumerate(rhs):
                 result = R.emit(R.matmul(x, r, out_dtype=dtype))
-                if with_bias:
+                if bias[i]:
                     result = R.emit(result + bias[i])
                 if activation is not None:
                     result = R.emit(activation(result))
@@ -114,13 +116,60 @@ def expected2(
     tvm.ir.assert_structural_equal(mod["main"], expected2)
 
 
 def test_bias():
-    mod = get_parallel_matmul(3, with_bias=True)
-    print(mod)
+    mod = get_parallel_matmul(3, with_bias=[True, True, True])
     mod = CombineParallelMatmul()(mod)
-    print(mod)
 
+    @R.function
+    def expected1(
+        x: R.Tensor((640, 640), dtype="float32"),
+        y: R.Tensor((640, 640), dtype="float32"),
+        bias: R.Tensor((640,), dtype="float32"),
+        y_1: R.Tensor((640, 640), dtype="float32"),
+        bias_1: R.Tensor((640,), dtype="float32"),
+        y_2: R.Tensor((640, 640), dtype="float32"),
+        bias_2: R.Tensor((640,), dtype="float32"),
+    ) -> R.Tensor((640, 1920), dtype="float32"):
+        with R.dataflow():
+            lv = 
R.concat((y, y_1, y_2), axis=1) + lv1 = R.matmul(x, lv, out_dtype="float32") + lv2 = R.concat((bias, bias_1, bias_2), axis=0) + lv3 = R.add(lv1, lv2) + lv1_1 = R.strided_slice(lv3, axes=[1], begin=[0], end=[640], strides=[1]) + lv3_1 = R.strided_slice(lv3, axes=[1], begin=[640], end=[1280], strides=[1]) + lv5 = R.strided_slice(lv3, axes=[1], begin=[1280], end=[1920], strides=[1]) + lv6 = R.concat((lv1_1, lv3_1, lv5), axis=1) + R.output(lv6) + return lv6 + + tvm.ir.assert_structural_equal(mod["main"], expected1) + + mod = get_parallel_matmul(3, with_bias=[True, False, True]) + mod = CombineParallelMatmul()(mod) + + @R.function + def expected2( + x: R.Tensor((640, 640), dtype="float32"), + y: R.Tensor((640, 640), dtype="float32"), + bias: R.Tensor((640,), dtype="float32"), + y_1: R.Tensor((640, 640), dtype="float32"), + y_2: R.Tensor((640, 640), dtype="float32"), + bias_1: R.Tensor((640,), dtype="float32"), + ) -> R.Tensor((640, 1920), dtype="float32"): + with R.dataflow(): + lv = R.concat((y, y_1, y_2), axis=1) + lv1 = R.matmul(x, lv, out_dtype="float32") + lv_1 = R.strided_slice(lv1, axes=[1], begin=[0], end=[640], strides=[1]) + lv1_1 = R.add(lv_1, bias) + lv2 = R.strided_slice(lv1, axes=[1], begin=[640], end=[1280], strides=[1]) + lv3 = R.strided_slice(lv1, axes=[1], begin=[1280], end=[1920], strides=[1]) + lv4 = R.add(lv3, bias_1) + lv5 = R.concat((lv1_1, lv2, lv4), axis=1) + R.output(lv5) + return lv5 + + tvm.ir.assert_structural_equal(mod["main"], expected2) if __name__ == "__main__": # tvm.testing.main() - test_simple() + test_bias() From 90793c4b4e4fd35d45c1c63f6d573107af83a28a Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 11 Apr 2023 18:03:03 +0900 Subject: [PATCH 17/20] add tests --- .../transform/combine_parallel_matmul.cc | 9 +- .../test_transform_combine_parallel_matmul.py | 302 +++++++++++++++++- 2 files changed, 303 insertions(+), 8 deletions(-) diff --git a/src/relax/transform/combine_parallel_matmul.cc b/src/relax/transform/combine_parallel_matmul.cc index c62141cf35a7..dbd13e43c6e5 100644 --- a/src/relax/transform/combine_parallel_matmul.cc +++ b/src/relax/transform/combine_parallel_matmul.cc @@ -161,11 +161,11 @@ runtime::TypedPackedFunc(Map)> GetRewriter( } if (branch_info.activation) { - if (*branch_info.activation == "relu") { + if (*branch_info.activation == "relax.nn.relu") { matmul_combined = relu(matmul_combined); - } else if (*branch_info.activation == "gelu") { + } else if (*branch_info.activation == "relax.nn.gelu") { matmul_combined = gelu(matmul_combined); - } else if (*branch_info.activation == "silu") { + } else if (*branch_info.activation == "relax.nn.silu") { matmul_combined = silu(matmul_combined); } else { LOG(FATAL) << "Unsupported activation: " << *branch_info.activation; @@ -231,7 +231,8 @@ std::vector GetBranchInfo(Function f) { } for (size_t i = 0; i < activations.size(); ++i) { - if (match.value().count(activation_pat[i])) { + if (match.value().count(activation_pat[i]) || + match.value().count(bias_activation_pat[i])) { activation = activations[i]; } } diff --git a/tests/python/relax/test_transform_combine_parallel_matmul.py b/tests/python/relax/test_transform_combine_parallel_matmul.py index bfeaae4b612e..f5cc269620f7 100644 --- a/tests/python/relax/test_transform_combine_parallel_matmul.py +++ b/tests/python/relax/test_transform_combine_parallel_matmul.py @@ -32,6 +32,8 @@ def get_parallel_matmul( ): dtype = "float32" + activation_map = {"relu": R.nn.relu, "gelu": R.nn.gelu} + with IRBuilder() as builder: with 
relax_builder.function(): R.func_name("main") @@ -55,8 +57,8 @@ def get_parallel_matmul( result = R.emit(R.matmul(x, r, out_dtype=dtype)) if bias[i]: result = R.emit(result + bias[i]) - if activation is not None: - result = R.emit(activation(result)) + if activation and activation[i]: + result = R.emit(activation_map[activation[i]](result)) branches.append(result) @@ -69,6 +71,11 @@ def get_parallel_matmul( def test_simple(): + mod_orig = get_parallel_matmul(1) + mod = CombineParallelMatmul()(mod_orig) + + tvm.ir.assert_structural_equal(mod, mod_orig) + mod = get_parallel_matmul(3) mod = CombineParallelMatmul()(mod) @@ -170,6 +177,293 @@ def expected2( tvm.ir.assert_structural_equal(mod["main"], expected2) +def test_activation(): + mod = get_parallel_matmul(3, activation=["relu", "relu", "relu"]) + mod = CombineParallelMatmul()(mod) + + @R.function + def expected1( + x: R.Tensor((640, 640), dtype="float32"), + y: R.Tensor((640, 640), dtype="float32"), + y_1: R.Tensor((640, 640), dtype="float32"), + y_2: R.Tensor((640, 640), dtype="float32"), + ) -> R.Tensor((640, 1920), dtype="float32"): + with R.dataflow(): + lv = R.concat((y, y_1, y_2), axis=1) + lv1 = R.matmul(x, lv, out_dtype="float32") + lv2 = R.nn.relu(lv1) + lv1_1 = R.strided_slice(lv2, axes=[1], begin=[0], end=[640], strides=[1]) + lv3 = R.strided_slice(lv2, axes=[1], begin=[640], end=[1280], strides=[1]) + lv5 = R.strided_slice(lv2, axes=[1], begin=[1280], end=[1920], strides=[1]) + lv6 = R.concat((lv1_1, lv3, lv5), axis=1) + R.output(lv6) + return lv6 + + tvm.ir.assert_structural_equal(mod["main"], expected1) + + mod = get_parallel_matmul(3, activation=["gelu", "relu", "relu"]) + mod = CombineParallelMatmul()(mod) + + @R.function + def expected2( + x: R.Tensor((640, 640), dtype="float32"), + y: R.Tensor((640, 640), dtype="float32"), + y_1: R.Tensor((640, 640), dtype="float32"), + y_2: R.Tensor((640, 640), dtype="float32"), + ) -> R.Tensor((640, 1920), dtype="float32"): + with R.dataflow(): + lv = R.concat((y, y_1, y_2), axis=1) + lv1 = R.matmul(x, lv, out_dtype="float32") + lv_1 = R.strided_slice(lv1, axes=[1], begin=[0], end=[640], strides=[1]) + lv1_1 = R.nn.gelu(lv_1) + lv2 = R.strided_slice(lv1, axes=[1], begin=[640], end=[1280], strides=[1]) + lv3 = R.nn.relu(lv2) + lv4 = R.strided_slice(lv1, axes=[1], begin=[1280], end=[1920], strides=[1]) + lv5 = R.nn.relu(lv4) + lv6 = R.concat((lv1_1, lv3, lv5), axis=1) + R.output(lv6) + return lv6 + + tvm.ir.assert_structural_equal(mod["main"], expected2) + + mod = get_parallel_matmul(3, activation=["relu", None, None]) + mod = CombineParallelMatmul()(mod) + + @R.function + def expected3( + x: R.Tensor((640, 640), dtype="float32"), + y: R.Tensor((640, 640), dtype="float32"), + y_1: R.Tensor((640, 640), dtype="float32"), + y_2: R.Tensor((640, 640), dtype="float32"), + ) -> R.Tensor((640, 1920), dtype="float32"): + with R.dataflow(): + lv = R.concat((y, y_1, y_2), axis=1) + lv1 = R.matmul(x, lv, out_dtype="float32") + lv_1 = R.strided_slice(lv1, axes=[1], begin=[0], end=[640], strides=[1]) + lv1_1 = R.nn.relu(lv_1) + lv2 = R.strided_slice(lv1, axes=[1], begin=[640], end=[1280], strides=[1]) + lv3 = R.strided_slice(lv1, axes=[1], begin=[1280], end=[1920], strides=[1]) + lv4 = R.concat((lv1_1, lv2, lv3), axis=1) + R.output(lv4) + return lv4 + + tvm.ir.assert_structural_equal(mod["main"], expected3) + + +def test_bias_activation(): + mod = get_parallel_matmul(3, with_bias=[True, True, True], activation=["relu", "relu", "relu"]) + mod = CombineParallelMatmul()(mod) + + @R.function + def 
expected1( + x: R.Tensor((640, 640), dtype="float32"), + y: R.Tensor((640, 640), dtype="float32"), + bias: R.Tensor((640,), dtype="float32"), + y_1: R.Tensor((640, 640), dtype="float32"), + bias_1: R.Tensor((640,), dtype="float32"), + y_2: R.Tensor((640, 640), dtype="float32"), + bias_2: R.Tensor((640,), dtype="float32"), + ) -> R.Tensor((640, 1920), dtype="float32"): + with R.dataflow(): + lv = R.concat((y, y_1, y_2), axis=1) + lv1 = R.matmul(x, lv, out_dtype="float32") + lv2 = R.concat((bias, bias_1, bias_2), axis=0) + lv3 = R.add(lv1, lv2) + lv4 = R.nn.relu(lv3) + lv2_1 = R.strided_slice(lv4, axes=[1], begin=[0], end=[640], strides=[1]) + lv5 = R.strided_slice(lv4, axes=[1], begin=[640], end=[1280], strides=[1]) + lv8 = R.strided_slice(lv4, axes=[1], begin=[1280], end=[1920], strides=[1]) + lv9 = R.concat((lv2_1, lv5, lv8), axis=1) + R.output(lv9) + return lv9 + + tvm.ir.assert_structural_equal(mod["main"], expected1) + + mod = get_parallel_matmul(3, with_bias=[True, True, True], activation=["relu", None, "relu"]) + mod = CombineParallelMatmul()(mod) + + @R.function + def expected2( + x: R.Tensor((640, 640), dtype="float32"), + y: R.Tensor((640, 640), dtype="float32"), + bias: R.Tensor((640,), dtype="float32"), + y_1: R.Tensor((640, 640), dtype="float32"), + bias_1: R.Tensor((640,), dtype="float32"), + y_2: R.Tensor((640, 640), dtype="float32"), + bias_2: R.Tensor((640,), dtype="float32"), + ) -> R.Tensor((640, 1920), dtype="float32"): + with R.dataflow(): + lv = R.concat((y, y_1, y_2), axis=1) + lv1 = R.matmul(x, lv, out_dtype="float32") + lv2 = R.concat((bias, bias_1, bias_2), axis=0) + lv3 = R.add(lv1, lv2) + lv1_1 = R.strided_slice(lv3, axes=[1], begin=[0], end=[640], strides=[1]) + lv2_1 = R.nn.relu(lv1_1) + lv4 = R.strided_slice(lv3, axes=[1], begin=[640], end=[1280], strides=[1]) + lv6 = R.strided_slice(lv3, axes=[1], begin=[1280], end=[1920], strides=[1]) + lv7 = R.nn.relu(lv6) + lv8 = R.concat((lv2_1, lv4, lv7), axis=1) + R.output(lv8) + return lv8 + + tvm.ir.assert_structural_equal(mod["main"], expected2) + + mod = get_parallel_matmul(3, with_bias=[True, False, True], activation=["relu", None, "relu"]) + mod = CombineParallelMatmul()(mod) + + @R.function + def expected3( + x: R.Tensor((640, 640), dtype="float32"), + y: R.Tensor((640, 640), dtype="float32"), + bias: R.Tensor((640,), dtype="float32"), + y_1: R.Tensor((640, 640), dtype="float32"), + y_2: R.Tensor((640, 640), dtype="float32"), + bias_1: R.Tensor((640,), dtype="float32"), + ) -> R.Tensor((640, 1920), dtype="float32"): + with R.dataflow(): + lv = R.concat((y, y_1, y_2), axis=1) + lv1 = R.matmul(x, lv, out_dtype="float32") + lv_1 = R.strided_slice(lv1, axes=[1], begin=[0], end=[640], strides=[1]) + lv1_1 = R.add(lv_1, bias) + lv2 = R.nn.relu(lv1_1) + lv3 = R.strided_slice(lv1, axes=[1], begin=[640], end=[1280], strides=[1]) + lv4 = R.strided_slice(lv1, axes=[1], begin=[1280], end=[1920], strides=[1]) + lv5 = R.add(lv4, bias_1) + lv6 = R.nn.relu(lv5) + lv7 = R.concat((lv2, lv3, lv6), axis=1) + R.output(lv7) + return lv7 + + tvm.ir.assert_structural_equal(mod["main"], expected3) + + +def test_rhs_batched(): + @tvm.script.ir_module + class four_matmul: + @R.function + def main( + x: R.Tensor((1024, 640), "float32"), + w0: R.Tensor((2, 640, 640), "float32"), + w1: R.Tensor((640, 640), "float32"), + w2: R.Tensor((2, 640, 640), "float32"), + w3: R.Tensor((3, 4, 640, 640), "float32"), + ) -> R.Tensor: + with R.dataflow(): + lv0 = R.matmul(x, w0) + lv1 = R.matmul(x, w1) + lv2 = R.matmul(x, w2) + lv3 = R.matmul(x, w3) + out = 
(lv0, lv1, lv2, lv3) + R.output(out) + return out + + mod = CombineParallelMatmul()(four_matmul) + + @R.function + def expected1( + x: R.Tensor((1024, 640), dtype="float32"), + w0: R.Tensor((2, 640, 640), dtype="float32"), + w1: R.Tensor((640, 640), dtype="float32"), + w2: R.Tensor((2, 640, 640), dtype="float32"), + w3: R.Tensor((3, 4, 640, 640), dtype="float32"), + ) -> R.Tensor: + with R.dataflow(): + lv = R.concat((w0, w2), axis=2) + lv1 = R.matmul(x, lv, out_dtype="float32") + lv0 = R.strided_slice(lv1, axes=[2], begin=[0], end=[640], strides=[1]) + lv1_1 = R.matmul(x, w1, out_dtype="void") + lv2 = R.strided_slice(lv1, axes=[2], begin=[640], end=[1280], strides=[1]) + lv3 = R.matmul(x, w3, out_dtype="void") + out = lv0, lv1_1, lv2, lv3 + R.output(out) + return out + + tvm.ir.assert_structural_equal(mod["main"], expected1) + + @tvm.script.ir_module + class four_matmul_incompatible_batches: + @R.function + def main( + x: R.Tensor((1024, 640), "float32"), + w0: R.Tensor((2, 640, 640), "float32"), + w1: R.Tensor((3, 640, 640), "float32"), + w2: R.Tensor((2, 640, 640), "float32"), + w3: R.Tensor((2, 640, 640), "float32"), + ) -> R.Tensor: + with R.dataflow(): + lv0 = R.matmul(x, w0) + lv1 = R.matmul(x, w1) + lv2 = R.matmul(x, w2) + lv3 = R.matmul(x, w3) + out = (lv0, lv1, lv2, lv3) + R.output(out) + return out + + mod = CombineParallelMatmul()(four_matmul_incompatible_batches) + # For now, when rhs matrices have the same rank but different batch sizes, we don't + # combine any of them. + tvm.ir.assert_structural_equal(mod, four_matmul_incompatible_batches) + + +def test_multiple_combine(): + @tvm.script.ir_module + class multiple_combine: + @R.function + def main( + x1: R.Tensor((2, 1024, 640), "float32"), + x2: R.Tensor((2, 1024, 640), "float32"), + w0: R.Tensor((640, 640), "float32"), + w1: R.Tensor((640, 640), "float32"), + w2: R.Tensor((640, 640), "float32"), + w3: R.Tensor((640, 640), "float32"), + w4: R.Tensor((640, 640), "float32"), + b0: R.Tensor((640,), "float32"), + b1: R.Tensor((640,), "float32"), + ) -> R.Tensor: + with R.dataflow(): + lv0 = R.matmul(x1, w0) + lv3 = R.matmul(x2, w3) + lv1 = R.matmul(x1, w1) + lv5 = R.add(lv3, b0) + lv2 = R.matmul(x1, w2) + lv4 = R.matmul(x2, w4) + lv6 = R.add(lv4, b1) + out = (lv0, lv1, lv2, lv5, lv6) + R.output(out) + return out + + mod = CombineParallelMatmul()(multiple_combine) + + @R.function + def expected1( + x1: R.Tensor((2, 1024, 640), dtype="float32"), + x2: R.Tensor((2, 1024, 640), dtype="float32"), + w0: R.Tensor((640, 640), dtype="float32"), + w1: R.Tensor((640, 640), dtype="float32"), + w2: R.Tensor((640, 640), dtype="float32"), + w3: R.Tensor((640, 640), dtype="float32"), + w4: R.Tensor((640, 640), dtype="float32"), + b0: R.Tensor((640,), dtype="float32"), + b1: R.Tensor((640,), dtype="float32"), + ) -> R.Tensor: + with R.dataflow(): + lv = R.concat((w0, w1, w2), axis=1) + lv1 = R.matmul(x1, lv, out_dtype="float32") + lv0 = R.strided_slice(lv1, axes=[2], begin=[0], end=[640], strides=[1]) + lv1_1 = R.strided_slice(lv1, axes=[2], begin=[640], end=[1280], strides=[1]) + lv_1 = R.concat((w3, w4), axis=1) + lv1_2 = R.matmul(x2, lv_1, out_dtype="float32") + lv2 = R.concat((b0, b1), axis=0) + lv3 = R.add(lv1_2, lv2) + lv5 = R.strided_slice(lv3, axes=[2], begin=[0], end=[640], strides=[1]) + lv2_1 = R.strided_slice(lv1, axes=[2], begin=[1280], end=[1920], strides=[1]) + lv6 = R.strided_slice(lv3, axes=[2], begin=[640], end=[1280], strides=[1]) + out = lv0, lv1_1, lv2_1, lv5, lv6 + R.output(out) + return out + + 
tvm.ir.assert_structural_equal(mod["main"], expected1)
+
+
 if __name__ == "__main__":
-    # tvm.testing.main()
-    test_bias()
+    tvm.testing.main()

From e7b6e1ffef4251195cee23e60215628a2f66c3bf Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Tue, 11 Apr 2023 19:52:51 +0900
Subject: [PATCH 18/20] add comment

---
 .../transform/combine_parallel_matmul.cc      | 25 +++++++++++++------
 1 file changed, 18 insertions(+), 7 deletions(-)

diff --git a/src/relax/transform/combine_parallel_matmul.cc b/src/relax/transform/combine_parallel_matmul.cc
index dbd13e43c6e5..d7fb928a5b1d 100644
--- a/src/relax/transform/combine_parallel_matmul.cc
+++ b/src/relax/transform/combine_parallel_matmul.cc
@@ -40,6 +40,9 @@ namespace relax {
 
 using runtime::Map;
 
+/*! \brief Group shapes of the RHS matrices by rank. Matrices in a group whose batch sizes
+    are compatible are combined.
+*/
 std::unordered_map<size_t, std::vector<size_t>> GroupShapes(
     const std::vector<Array<PrimExpr>>& shapes) {
   std::unordered_map<size_t, std::vector<size_t>> indices_map;
@@ -102,6 +105,7 @@ Patterns CreatePatterns(const BranchInfo& branch_info) {
   return patterns;
 }
 
+/*! \brief Create a rewriter for the given parallel matmul branches. */
 runtime::TypedPackedFunc<Map<Var, Expr>(Map<DFPattern, Var>)> GetRewriter(
     const Patterns& patterns, const BranchInfo& branch_info) {
   auto batch_dims_compatible = [](int rhs_dim, const std::vector<size_t>& indices,
@@ -197,6 +201,9 @@ Function Rewrite(Function f, const BranchInfo& branch_info) {
   return RewriteBindings(patterns.ctx, rewriter, f);
 }
 
+/*! \brief Look for subtrees with parallel matmul and return information about
+    them (the number of branches and the kind of fused ops).
+*/
 std::vector<BranchInfo> GetBranchInfo(Function f) {
   auto bias_pat = Wildcard();
   auto matmul_pat = IsOp("relax.matmul")(Wildcard(), Wildcard());
@@ -213,6 +220,7 @@ std::vector<BranchInfo> GetBranchInfo(Function f) {
   auto bindings = AnalyzeVar2Value(f);
 
   auto create_group = [&](DFPattern pat) {
+    // Maps a LHS matrix to consumer parallel matmuls
     std::unordered_map<const VarNode*, BranchInfo> groups;
 
     PostOrderVisit(f, [&](const Expr& e) {
@@ -238,8 +246,11 @@ std::vector<BranchInfo> GetBranchInfo(Function f) {
       }
 
       if (!branch) {
+        // Create a new subgraph with one matmul
        groups[matmul_lhs.get()] = {1, bias_dim, activation};
       } else {
+        // Create a new branch in the existing parallel matmul subtree, and
+        // invalidate bias and activation information when needed.
         branch->num_branches += 1;
 
         if (!bias_dim || (branch->bias_dim && *branch->bias_dim != *bias_dim)) {
@@ -271,11 +282,11 @@ std::vector<BranchInfo> GetBranchInfo(Function f) {
   auto groups_bias = create_group(bias_add_pat);
   auto groups_matmul = create_group(matmul_pat);
 
-  for (auto groups : {groups_bias, groups_activation}) {
+  for (const auto& groups : {groups_bias, groups_activation}) {
     for (const auto& [lhs, branch] : groups) {
       // Prefer combining more matmuls than combining fewer ones and leaving additional uncombined
-      // matmuls followed by bias or activation So we combine matmuls + fused ops patterns only when
-      // all branches have the same fused ops.
+      // matmuls followed by bias or activation. So we combine matmuls + fused ops patterns only
+      // when all branches have the same fused ops.
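+      // For example, if three matmuls share the same LHS but only two of them are
+      // followed by a bias add, the three-branch matmul-only group takes precedence
+      // over the two-branch matmul + bias group: all three matmuls are combined, and
+      // the bias adds are applied to slices of the combined output afterwards (this
+      // is the behavior exercised by the with_bias=[True, False, True] case in the
+      // tests above).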
      if (auto it = groups_matmul.find(lhs);
           it != groups_matmul.end() && it->second.num_branches == branch.num_branches) {
         it->second = branch;
       }
@@ -285,7 +296,7 @@ std::vector<BranchInfo> GetBranchInfo(Function f) {
 
   std::vector<BranchInfo> info;
 
-  for (auto groups : {groups_matmul, groups_activation, groups_bias}) {
+  for (const auto& groups : {groups_matmul, groups_activation, groups_bias}) {
     for (const auto& group : groups) {
       if (group.second.num_branches > 1) {
         info.push_back(group.second);
@@ -293,14 +304,14 @@ std::vector<BranchInfo> GetBranchInfo(Function f) {
     }
   }
 
-  std::sort(info.begin(), info.end(),
-            [](const auto& b1, const auto& b2) { return b1.num_branches > b2.num_branches; });
-
   return info;
 }
 
 Function CombineParallelMatmul(Function f) {
   auto branches = GetBranchInfo(f);
+  std::sort(branches.begin(), branches.end(),
+            [](const auto& b1, const auto& b2) { return b1.num_branches > b2.num_branches; });
+
   for (const auto& branch : branches) {
     f = Rewrite(f, branch);
   }

From 4399cc3c83bcac95d543f78cc77f0cac99556eaa Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Tue, 11 Apr 2023 20:02:31 +0900
Subject: [PATCH 19/20] add doc

---
 include/tvm/relax/dataflow_matcher.h    |  9 +++++++++
 python/tvm/relax/transform/transform.py | 12 ++++++++++++
 2 files changed, 21 insertions(+)

diff --git a/include/tvm/relax/dataflow_matcher.h b/include/tvm/relax/dataflow_matcher.h
index 2235ea16af2a..16249377a27d 100644
--- a/include/tvm/relax/dataflow_matcher.h
+++ b/include/tvm/relax/dataflow_matcher.h
@@ -58,6 +58,15 @@ Optional<Map<DFPattern, Expr>> ExtractMatchedExpr(
 TVM_DLL Optional<Map<DFPattern, Var>> MatchGraph(const PatternContext& ctx,
                                                  const DataflowBlock& dfb);
 
+/**
+ * \brief Rewrite a function with the given pattern and the rewriter function.
+ * \param ctx The pattern constraint context under which rewriting takes place.
+ * \param rewriter The function to be called on a successful matching for rewriting.
+ *        Given the map of patterns and corresponding variables (bound variables or
+ *        parameters), it should return a map that specifies new values for matched
+ *        bound variables.
+ * \param f The function to rewrite.
+ * \return The rewritten or the input function, depending on the pattern matching result.
+ */
 TVM_DLL Function RewriteBindings(const PatternContext& ctx, PackedFunc rewriter, Function f);
 }  // namespace relax
 }  // namespace tvm
diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py
index 82762feedeff..f0277151bbdc 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -935,6 +935,18 @@ def SplitCallTIRByPattern(patterns, fcodegen) -> tvm.ir.transform.Pass:
 
 
 def CombineParallelMatmul():
+    """Combine multiple matmul operators sharing the same LHS matrix into one,
+    followed by slicing. When all matmul branches in a tree have the same set of fused ops,
+    the fused ops are applied to the combined matmul output before slicing.
+
+    Currently, only a limited set of fused ops is supported: bias add and the
+    relu, gelu, and silu activations.
+
+    Returns
+    -------
+    ret : tvm.transform.Pass
+        The corresponding pass.
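+
+    Examples
+    --------
+    A sketch of the rewrite, following the shapes used in the tests added
+    earlier in this series (variable names are illustrative):
+
+    .. code-block:: python
+
+        # Before: three matmuls sharing the LHS x; each w of shape (640, 640)
+        lv0 = R.matmul(x, w0)
+        lv1 = R.matmul(x, w1)
+        lv2 = R.matmul(x, w2)
+
+        # After: one combined matmul followed by slicing
+        w = R.concat((w0, w1, w2), axis=1)  # (640, 1920)
+        lv = R.matmul(x, w)
+        lv0 = R.strided_slice(lv, axes=[1], begin=[0], end=[640], strides=[1])
+        lv1 = R.strided_slice(lv, axes=[1], begin=[640], end=[1280], strides=[1])
+        lv2 = R.strided_slice(lv, axes=[1], begin=[1280], end=[1920], strides=[1])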
+ """ return _ffi_api.CombineParallelMatmul() # type: ignore From ceee41c0409504a2b195f4031f83195814d09ebb Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 12 Apr 2023 04:10:28 +0900 Subject: [PATCH 20/20] fix compile warning --- src/relax/transform/combine_parallel_matmul.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/relax/transform/combine_parallel_matmul.cc b/src/relax/transform/combine_parallel_matmul.cc index d7fb928a5b1d..d6435ec8292f 100644 --- a/src/relax/transform/combine_parallel_matmul.cc +++ b/src/relax/transform/combine_parallel_matmul.cc @@ -108,7 +108,7 @@ Patterns CreatePatterns(const BranchInfo& branch_info) { /*! \brief Create a rewriter for the given parallel matmul branches. */ runtime::TypedPackedFunc(Map)> GetRewriter( const Patterns& patterns, const BranchInfo& branch_info) { - auto batch_dims_compatible = [](int rhs_dim, const std::vector& indices, + auto batch_dims_compatible = [](size_t rhs_dim, const std::vector& indices, const std::vector>& rhs_shapes) { arith::Analyzer ana; for (auto ind : indices) {