From 7911c836e18bad5b018f45346b561cb1b0968674 Mon Sep 17 00:00:00 2001 From: Egor Churaev Date: Thu, 17 Jun 2021 16:03:52 +0200 Subject: [PATCH 1/3] [Metal] Add pass for splitting kernel with huge number of args The Metal has some limitations on the number of input parameters. More information can be found here: https://developer.apple.com/documentation/metal/buffers/about_argument_buffers?language=objc In this commit a new pass for splitting functions with big number of arguments to smaller parts was added. In parameter `max_function_args` we can specify the maximum number of kernel arguments for specific target and then split kernel when the number of arguments exceeds the value of `max_function_args`. Currently this pass works only for concat layer. --- include/tvm/relay/transform.h | 7 ++ python/tvm/relay/transform/transform.py | 11 +++ python/tvm/target/target.py | 4 + src/relay/backend/build_module.cc | 6 ++ src/relay/transforms/pattern_utils.h | 7 ++ src/relay/transforms/split_args.cc | 91 ++++++++++++++++++++ src/target/source/codegen_metal.cc | 9 +- src/target/source/codegen_metal.h | 3 +- src/target/target_kind.cc | 5 ++ tests/python/relay/test_pass_split_args.py | 96 ++++++++++++++++++++++ 10 files changed, 236 insertions(+), 3 deletions(-) create mode 100644 src/relay/transforms/split_args.cc create mode 100644 tests/python/relay/test_pass_split_args.py diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index b090e3e40063..bdc46d71a77d 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -97,6 +97,13 @@ TVM_DLL Pass LazyGradientInit(); */ TVM_DLL Pass FoldConstant(); +/*! + * \brief Split function with huge number of arguments to smaller pieces. + * + * \return The pass. + */ +TVM_DLL Pass SplitArgs(int max_function_args); + /*! * \brief Fuse operations into expr into seperate functions. * diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py index cdfd97c780dd..6294e7acea15 100644 --- a/python/tvm/relay/transform/transform.py +++ b/python/tvm/relay/transform/transform.py @@ -1228,3 +1228,14 @@ def ToMixedPrecision(mixed_precision_type="float16", missing_op_mode=1): if missing_op_mode < 0 or missing_op_mode > 2: raise ValueError("Missing op mode is either 0, 1, or 2") return _ffi_api.ToMixedPrecision(mixed_precision_type, missing_op_mode) + + +def SplitArgs(max_function_args): + """Split function with huge number of arguments to smaller pieces. + + Returns + ------- + ret : tvm.transform.Pass + The registered pass for constant folding. + """ + return _ffi_api.SplitArgs(max_function_args) diff --git a/python/tvm/target/target.py b/python/tvm/target/target.py index be39a6f6bd25..439674e0468e 100644 --- a/python/tvm/target/target.py +++ b/python/tvm/target/target.py @@ -139,6 +139,10 @@ def max_num_threads(self): def thread_warp_size(self): return int(self.attrs["thread_warp_size"]) + @property + def max_function_args(self): + return int(self.attrs.get("max_function_args", -1)) + @property def device_name(self): return str(self.attrs.get("device", "")) diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index 23670109e527..590cdcb08218 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -365,6 +365,12 @@ class RelayBuildModule : public runtime::ModuleNode { pass_seqs.push_back(transform::FastMath()); pass_seqs.push_back(transform::FoldConstant()); + if (targets.size() == 1) { + const auto& target = (*targets.begin()).second; + pass_seqs.push_back( + transform::SplitArgs(target->GetAttr("max_function_args").value())); + } + // Create a sequential pass and perform optimizations. transform::Pass seq = transform::Sequential(pass_seqs); if (targets.size() == 1) { diff --git a/src/relay/transforms/pattern_utils.h b/src/relay/transforms/pattern_utils.h index 920ac153b63d..f1f0092b691b 100644 --- a/src/relay/transforms/pattern_utils.h +++ b/src/relay/transforms/pattern_utils.h @@ -700,6 +700,13 @@ Expr StopFusion(Expr data); Expr CastHint(Expr data, DataType dtype); +inline Expr Concat(Expr x, int axis = 0) { + static const Op& op = Op::Get("concatenate"); + auto attrs = make_object(); + attrs->axis = axis; + return Call(op, {x}, Attrs(attrs), {}); +} + } // namespace relay } // namespace tvm #endif // TVM_RELAY_TRANSFORMS_PATTERN_UTILS_H_ diff --git a/src/relay/transforms/split_args.cc b/src/relay/transforms/split_args.cc new file mode 100644 index 000000000000..cdd596be37a5 --- /dev/null +++ b/src/relay/transforms/split_args.cc @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file split_args.cc + */ +#include +#include + +#include "pattern_utils.h" + +namespace tvm { +namespace relay { + +class ArgumentSplitter : public ExprRewriter { + public: + explicit ArgumentSplitter(int max_function_args) + : max_function_args_(max_function_args), concat_op_(Op::Get("concatenate")) {} + + Expr Rewrite_(const CallNode* call, const Expr& post) final { + if (max_function_args_ < 0) return post; + if (call->op == concat_op_) { + auto op = call->args[0].as(); + const auto param = call->attrs.as(); + const int limit = max_function_args_ - 1; // one buffer with output + int argsNum = op->fields.size(); + if (argsNum < limit) return post; + int splitNum = argsNum / limit; + splitNum = (argsNum % limit) ? splitNum + 1 : splitNum; + + std::vector splitted(splitNum); + for (int i = 0; i < splitNum; ++i) { + int startIdx = i * limit; + int argsCount = std::min(limit, argsNum - startIdx); + tvm::Array args; + for (int j = 0; j < argsCount; ++j) { + args.push_back(op->fields[j + startIdx]); + } + Tuple tuple(args); + Expr body = Concat(tuple, param->axis); + splitted[i] = StopFusion(body); + } + tvm::Array tupleArgs(splitted); + Tuple tuple(tupleArgs); + return Concat(tuple, param->axis); + } + return post; + } + + private: + const int max_function_args_; + const Op& concat_op_; +}; + +Expr SplitArgs(const Expr& expr, int max_function_args) { + auto rewriter = ArgumentSplitter(max_function_args); + return PostOrderRewrite(expr, &rewriter); +} + +namespace transform { + +Pass SplitArgs(int max_function_args) { + runtime::TypedPackedFunc pass_func = + [=](Function f, IRModule m, PassContext pc) { + return Downcast(SplitArgs(f, max_function_args)); + }; + return CreateFunctionPass(pass_func, 1, "SplitArgs", {}); +} + +TVM_REGISTER_GLOBAL("relay._transform.SplitArgs").set_body_typed(SplitArgs); + +} // namespace transform + +} // namespace relay +} // namespace tvm diff --git a/src/target/source/codegen_metal.cc b/src/target/source/codegen_metal.cc index 71e3529e0d80..3c5ff89c6f1d 100644 --- a/src/target/source/codegen_metal.cc +++ b/src/target/source/codegen_metal.cc @@ -43,7 +43,7 @@ void CodeGenMetal::InitFuncState(const PrimFunc& f) { } } -CodeGenMetal::CodeGenMetal() { +CodeGenMetal::CodeGenMetal(Target target) : target_(target) { decl_stream << "#include \n"; decl_stream << "using namespace metal;\n\n"; decl_stream << "union __TVMArgUnion {\n" @@ -67,6 +67,11 @@ void CodeGenMetal::AddFunction(const PrimFunc& f) { // Buffer arguments size_t num_buffer = 0; + int limit = target_->GetAttr("max_function_args").value(); + if (f->params.size() > limit) { + LOG(WARNING) << "Probably you won't be able to execute your kernel due to high number of " + "buffers in the kernel"; + } for (size_t i = 0; i < f->params.size(); ++i, ++num_buffer) { Var v = f->params[i]; if (!v.dtype().is_handle()) break; @@ -332,7 +337,7 @@ runtime::Module BuildMetal(IRModule mod, Target target) { for (auto kv : mod->functions) { ICHECK(kv.second->IsInstance()) << "CodeGenMetal: Can only take PrimFunc"; code << "// Function: " << kv.first->name_hint << std::endl; - CodeGenMetal cg; + CodeGenMetal cg(target); cg.Init(output_ssa); auto f = Downcast(kv.second); auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); diff --git a/src/target/source/codegen_metal.h b/src/target/source/codegen_metal.h index 614a191907af..9fb8f80303f9 100644 --- a/src/target/source/codegen_metal.h +++ b/src/target/source/codegen_metal.h @@ -35,7 +35,7 @@ namespace codegen { class CodeGenMetal final : public CodeGenC { public: - CodeGenMetal(); + explicit CodeGenMetal(Target target); // override print thread tag. void PrintArgUnionDecl(); void AddFunction(const PrimFunc& f); // NOLINT(*) @@ -58,6 +58,7 @@ class CodeGenMetal final : public CodeGenC { private: int thread_index_bits_{32}; + Target target_; }; } // namespace codegen } // namespace tvm diff --git a/src/target/target_kind.cc b/src/target/target_kind.cc index b9d9706773f7..d037b9dfdbdb 100644 --- a/src/target/target_kind.cc +++ b/src/target/target_kind.cc @@ -347,10 +347,15 @@ TVM_REGISTER_TARGET_KIND("opencl", kDLOpenCL) .add_attr_option("thread_warp_size", Integer(1)) .set_default_keys({"opencl", "gpu"}); +// The metal has some limitations on the number of input parameters. This is why attribute +// `max_function_args` was introduced. It specifies the maximum number of kernel argumetns. More +// information about this limitation can be found here: +// https://developer.apple.com/documentation/metal/buffers/about_argument_buffers?language=objc TVM_REGISTER_TARGET_KIND("metal", kDLMetal) .add_attr_option("system-lib") .add_attr_option("max_num_threads", Integer(256)) .add_attr_option("thread_warp_size", Integer(16)) + .add_attr_option("max_function_args", Integer(31)) .set_default_keys({"metal", "gpu"}); TVM_REGISTER_TARGET_KIND("vulkan", kDLVulkan) diff --git a/tests/python/relay/test_pass_split_args.py b/tests/python/relay/test_pass_split_args.py new file mode 100644 index 000000000000..2039f464751f --- /dev/null +++ b/tests/python/relay/test_pass_split_args.py @@ -0,0 +1,96 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import numpy as np +import tvm +from tvm import relay +from tvm.relay import transform +from tvm.relay.build_module import bind_params_by_name +from tvm.relay.testing import run_infer_type, create_workload + + +def run_opt_pass(expr, opt_pass): + assert isinstance(opt_pass, tvm.transform.Pass) + + mod = tvm.IRModule.from_expr(expr) + mod = relay.transform.InferType()(mod) + mod = opt_pass(mod) + entry = mod["main"] + return entry if isinstance(expr, relay.Function) else entry.body + + +def test_split_concat_metal(): + shape = (1, 1, 1, 3) + dtype = "float32" + axis = 1 + inputs = [] + for i in range(100): + inputs.append(relay.var("p{}".format(i), shape=shape, dtype=dtype)) + + def before(): + inp = relay.Tuple(inputs) + return relay.op.concatenate(inp, axis) + + def expected(): + limit = tvm.target.Target("metal").max_function_args - 1 # one buffer with output + splitNum = int(len(inputs) / limit) + if len(inputs) % limit > 0: + splitNum += 1 + + splitted = [] + for i in range(splitNum): + startIdx = i * limit + argsCount = min(limit, len(inputs) - startIdx) + args = [] + for j in range(argsCount): + args.append(inputs[j + startIdx]) + t = relay.Tuple(args) + concat = relay.op.concatenate(t, axis) + splitted.append(relay.annotation.stop_fusion(concat)) + inp = relay.Tuple(splitted) + return relay.op.concatenate(inp, axis) + + # the fold constant should work on any context. + res = run_opt_pass(before(), transform.SplitArgs(tvm.target.Target("metal").max_function_args)) + exp = run_opt_pass(expected(), transform.InferType()) + assert tvm.ir.structural_equal(res, exp) + + +def test_split_concat_cuda(): + shape = (1, 1, 1, 3) + dtype = "float32" + axis = 1 + inputs = [] + for i in range(100): + inputs.append(relay.var("p{}".format(i), shape=shape, dtype=dtype)) + + def before(): + inp = relay.Tuple(inputs) + return relay.op.concatenate(inp, axis) + + def expected(): + inp = relay.Tuple(inputs) + return relay.op.concatenate(inp, axis) + + # the fold constant should work on any context. + res = run_opt_pass(before(), transform.SplitArgs(tvm.target.Target("cuda").max_function_args)) + exp = run_opt_pass(expected(), transform.InferType()) + assert tvm.ir.structural_equal(res, exp) + + +if __name__ == "__main__": + test_split_concat_metal() + test_split_concat_cuda() From 858d3f0e82b16655532489291dfe8864cf60b7e5 Mon Sep 17 00:00:00 2001 From: Egor Churaev Date: Thu, 24 Jun 2021 08:15:18 +0300 Subject: [PATCH 2/3] Add getting number of output parameters --- src/relay/transforms/split_args.cc | 8 ++++++-- src/target/source/codegen_metal.cc | 2 +- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/src/relay/transforms/split_args.cc b/src/relay/transforms/split_args.cc index cdd596be37a5..dadfb51a3ea6 100644 --- a/src/relay/transforms/split_args.cc +++ b/src/relay/transforms/split_args.cc @@ -38,7 +38,11 @@ class ArgumentSplitter : public ExprRewriter { if (call->op == concat_op_) { auto op = call->args[0].as(); const auto param = call->attrs.as(); - const int limit = max_function_args_ - 1; // one buffer with output + int outputsNum = 1; + if (const auto* tuple_type = call->checked_type().as()) { + outputsNum = tuple_type->fields.size(); + } + const int limit = max_function_args_ - outputsNum; int argsNum = op->fields.size(); if (argsNum < limit) return post; int splitNum = argsNum / limit; @@ -80,7 +84,7 @@ Pass SplitArgs(int max_function_args) { [=](Function f, IRModule m, PassContext pc) { return Downcast(SplitArgs(f, max_function_args)); }; - return CreateFunctionPass(pass_func, 1, "SplitArgs", {}); + return CreateFunctionPass(pass_func, 1, "SplitArgs", {"InferType"}); } TVM_REGISTER_GLOBAL("relay._transform.SplitArgs").set_body_typed(SplitArgs); diff --git a/src/target/source/codegen_metal.cc b/src/target/source/codegen_metal.cc index 3c5ff89c6f1d..b44afec57d5d 100644 --- a/src/target/source/codegen_metal.cc +++ b/src/target/source/codegen_metal.cc @@ -68,7 +68,7 @@ void CodeGenMetal::AddFunction(const PrimFunc& f) { // Buffer arguments size_t num_buffer = 0; int limit = target_->GetAttr("max_function_args").value(); - if (f->params.size() > limit) { + if (static_cast(f->params.size()) > limit) { LOG(WARNING) << "Probably you won't be able to execute your kernel due to high number of " "buffers in the kernel"; } From f71fc9600679200bd846105c8e45246ec8e9784e Mon Sep 17 00:00:00 2001 From: Egor Churaev Date: Tue, 29 Jun 2021 12:12:55 +0300 Subject: [PATCH 3/3] Fix CI and apply comments --- src/relay/backend/build_module.cc | 2 +- src/relay/transforms/pattern_utils.h | 7 ------- src/relay/transforms/split_args.cc | 4 ++-- 3 files changed, 3 insertions(+), 10 deletions(-) diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index 590cdcb08218..ea53c34c793b 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -368,7 +368,7 @@ class RelayBuildModule : public runtime::ModuleNode { if (targets.size() == 1) { const auto& target = (*targets.begin()).second; pass_seqs.push_back( - transform::SplitArgs(target->GetAttr("max_function_args").value())); + transform::SplitArgs(target->GetAttr("max_function_args", -1).value())); } // Create a sequential pass and perform optimizations. diff --git a/src/relay/transforms/pattern_utils.h b/src/relay/transforms/pattern_utils.h index f1f0092b691b..920ac153b63d 100644 --- a/src/relay/transforms/pattern_utils.h +++ b/src/relay/transforms/pattern_utils.h @@ -700,13 +700,6 @@ Expr StopFusion(Expr data); Expr CastHint(Expr data, DataType dtype); -inline Expr Concat(Expr x, int axis = 0) { - static const Op& op = Op::Get("concatenate"); - auto attrs = make_object(); - attrs->axis = axis; - return Call(op, {x}, Attrs(attrs), {}); -} - } // namespace relay } // namespace tvm #endif // TVM_RELAY_TRANSFORMS_PATTERN_UTILS_H_ diff --git a/src/relay/transforms/split_args.cc b/src/relay/transforms/split_args.cc index dadfb51a3ea6..70d37d822d71 100644 --- a/src/relay/transforms/split_args.cc +++ b/src/relay/transforms/split_args.cc @@ -57,12 +57,12 @@ class ArgumentSplitter : public ExprRewriter { args.push_back(op->fields[j + startIdx]); } Tuple tuple(args); - Expr body = Concat(tuple, param->axis); + Expr body = MakeConcatenate(tuple, param->axis); splitted[i] = StopFusion(body); } tvm::Array tupleArgs(splitted); Tuple tuple(tupleArgs); - return Concat(tuple, param->axis); + return MakeConcatenate(tuple, param->axis); } return post; }