diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py index 557e4edac681..5a45ac276de9 100644 --- a/python/tvm/relay/build_module.py +++ b/python/tvm/relay/build_module.py @@ -13,6 +13,7 @@ # List of optimization pass and level when switch on OPT_PASS_LEVEL = { "SimplifyInference": 0, + "CombineParallelConv2D": 1, "OpFusion": 1, "FoldConstant": 2, "FoldScaleAxis": 3, @@ -144,6 +145,10 @@ def optimize(func, params=None): func = ir_pass.infer_type(func) func = ir_pass.simplify_inference(func) + if cfg.pass_enabled("CombineParallelConv2D"): + func = ir_pass.infer_type(func) + func = ir_pass.combine_parallel_conv2d(func) + if cfg.pass_enabled("FoldScaleAxis"): func = ir_pass.infer_type(func) func = ir_pass.backward_fold_scale_axis(func) diff --git a/python/tvm/relay/ir_pass.py b/python/tvm/relay/ir_pass.py index 9d59980f6127..ef0a59cd3f6d 100644 --- a/python/tvm/relay/ir_pass.py +++ b/python/tvm/relay/ir_pass.py @@ -292,3 +292,19 @@ def fuse_ops(expr, opt_level=1): Transformed expression, containing fused result. """ return _ir_pass.FuseOps(expr, opt_level) + + +def combine_parallel_conv2d(expr): + """Fold multiple conv2d into one. + + Parameters + ---------- + expr : tvm.relay.Expr + The input expression. + + Returns + ------- + transformed_expr : tvm.relay.Expr + Transformed expression + """ + return _ir_pass.CombineParallelConv2D(expr) diff --git a/src/relay/pass/combine_parallel_conv2d.cc b/src/relay/pass/combine_parallel_conv2d.cc new file mode 100644 index 000000000000..48d5d77990d6 --- /dev/null +++ b/src/relay/pass/combine_parallel_conv2d.cc @@ -0,0 +1,328 @@ +/*! + * Copyright (c) 2018 by Contributors + * + * \file combine_parallel_conv2d.cc + * \brief Combine parallel 2d convolutions into a single convolution. + * + * This pass replaces convolutions that share the same input node and the same + * arguments (except that the number of output channels can be different) with a + * single convolution. The weight of the new 2d convolution is the concatenation + * of the original weights. Elemwise and broadcast ops following conv2d are also + * combined if possible. + * + * This prevents launching multiple kernels in networks with multiple + * convolution branches, such as Inception block. + */ + +#include +#include +#include +#include +#include +#include +#include +#include "./expr_subst.h" +#include "./pattern_util.h" + + +namespace tvm { +namespace relay { + +using Branch = std::vector; +using Group = std::vector; + +/* + Find parallel branches starting with conv2d as shown below and then group branches by kernel + shape and attributes of conv2d. Conv2d can be followed by zero or more elemwise or broadcast ops. + Intermediate nodes have exactly one successor. It is possible that branches meet at a point, + which should be handled in ParallelConv2DCombiner. + + data + / \ + conv2d conv2d + | | + op op + | | +*/ +class BranchGroupFinder : private ExprVisitor { + public: + std::vector Find(const Expr& expr) { + this->VisitExpr(expr); + + std::vector groups; + for (const auto& root : conv_roots_) { + const auto& convs = children_map_.at(root); + for (const CallNode* conv : convs) { + auto&& branch = CreateBranch(conv); + // add the branch to a group, or create a new group + auto it = std::find_if(groups.begin(), groups.end(), [&](const Group& group) { + CHECK(!group.empty() && !group[0].empty()); + return IsCompatibleConv2D(conv, group[0][0]); + }); + if (it != groups.end()) { + it->push_back(branch); + } else { + groups.emplace_back(); + // each group has at least one branch + groups.back().push_back(branch); + } + } + } + return groups; + } + + private: + std::unordered_set conv_roots_; + std::unordered_map, NodeHash, NodeEqual> children_map_; + + // Two 2d convolutions can be combined if they have the same attributes or + // only have different output channels. + bool IsCompatibleConv2D(const CallNode* a, const CallNode* b) { + AttrsEqual eq; + static const Layout kOIHW("OIHW"); + const auto* attrs_a = a->attrs.as(); + const auto* attrs_b = b->attrs.as(); + CHECK(attrs_a); + CHECK(attrs_b); + const auto* tweight_a = a->args[1]->type_as(); + const auto* tweight_b = b->args[1]->type_as(); + const auto shape_a = ConvertLayout(tweight_a->shape, attrs_a->weight_layout, kOIHW); + const auto shape_b = ConvertLayout(tweight_b->shape, attrs_b->weight_layout, kOIHW); + + return eq(attrs_a->strides, attrs_b->strides) && eq(attrs_a->padding, attrs_b->padding) && + eq(attrs_a->dilation, attrs_b->dilation) && eq(attrs_a->groups, attrs_b->groups) && + eq(attrs_a->data_layout, attrs_b->data_layout) && + eq(attrs_a->weight_layout, attrs_b->weight_layout) && + eq(attrs_a->out_dtype, attrs_b->out_dtype) && + eq(attrs_a->out_layout, attrs_b->out_layout) && eq(shape_a[2], shape_b[2]) && + eq(shape_a[3], shape_b[3]); + } + + // Create a branch starting from conv2d. + Branch CreateBranch(const CallNode* conv) { + static auto fpattern = Op::GetAttr("TOpPattern"); + // each branch has at least one element, the first element is always conv2d + Branch branch{conv}; + auto it = children_map_.find(GetRef(branch.back())); + while (it != children_map_.end() && it->second.size() == 1) { + const CallNode* call = it->second[0]; + auto pattern = fpattern[Downcast(call->op)]; + if (pattern <= kBroadcast) { + branch.push_back(it->second[0]); + it = children_map_.find(GetRef(branch.back())); + } else { + break; + } + } + return branch; + } + + void VisitExpr_(const CallNode* n) final { + static const Op& conv2d = Op::Get("nn.conv2d"); + ExprVisitor::VisitExpr_(n); + if (n->op.same_as(conv2d) && n->attrs.as()->groups == 1) { + conv_roots_.insert(n->args[0]); + children_map_[n->args[0]].push_back(n); + } else { + for (size_t i = 0; i < n->args.size(); i++) { + children_map_[n->args[i]].push_back(n); + } + } + } +}; + +class ParallelConv2DCombiner { + public: + Expr Combine(const Expr& expr) { + auto groups = BranchGroupFinder().Find(expr); + for (const Group& group : groups) { + if (group.size() < 2) continue; + CombineBranches(group); + } + return ExprSubst(expr, std::move(subst_map_)); + } + + private: + std::unordered_map subst_map_; + + std::tuple TransformWeight(const Group& branches) { + int64_t num_filters = 0; // number of filters of the transformed weight + Array weights; + for (const auto& branch : branches) { + auto conv2d = branch[0]; + weights.push_back(conv2d->args[1]); + auto channels = GetConv2DSuperChannelsDim(conv2d); + num_filters += channels; + } + auto index = branches[0][0]->attrs.as()->weight_layout.find('O'); + CHECK_NE(index, std::string::npos); + return std::make_tuple(MakeConcatenate(TupleNode::make(weights), index), + MakeConstScalar(Int(32), num_filters)); + } + + Call MakeCombinedConv2D(const Group& branches) { + static const Op& conv2d = Op::Get("nn.conv2d"); + Expr data = branches[0][0]->args[0]; + Expr new_weight; + IndexExpr new_channels; + std::tie(new_weight, new_channels) = TransformWeight(branches); + + const CallNode* group_root = branches[0][0]; + const auto* attrs = group_root->attrs.as(); + CHECK(attrs); + const auto new_attrs = make_node(); + new_attrs->strides = attrs->strides; + new_attrs->padding = attrs->padding; + new_attrs->dilation = attrs->dilation; + new_attrs->groups = attrs->groups; + new_attrs->kernel_size = attrs->kernel_size; + new_attrs->data_layout = attrs->data_layout; + new_attrs->weight_layout = attrs->weight_layout; + new_attrs->out_layout = attrs->out_layout; + new_attrs->out_dtype = attrs->out_dtype; + new_attrs->channels = new_channels; + + return CallNode::make(conv2d, {data, new_weight}, Attrs{new_attrs}, {}); + } + + bool IsArgCompatible(const CallNode* a, const CallNode* b, size_t index, size_t channel_pos) { + AttrsEqual eq; + auto ta = a->args[index]->type_as(); + auto tb = b->args[index]->type_as(); + auto toutput_a = a->type_as(); + auto toutput_b = b->type_as(); + + if (!eq(ta->dtype, tb->dtype) || ta->shape.size() != tb->shape.size()) + return false; + + // Position of the 'C' dimension in the argument + size_t arg_channel_pos = channel_pos - toutput_a->shape.size() + ta->shape.size(); + + // Channel super-dimension shoule be present and not broadcasted + if ((arg_channel_pos > channel_pos) || // size_t overflow + !eq(ta->shape[arg_channel_pos], toutput_a->shape[channel_pos]) || + !eq(tb->shape[arg_channel_pos], toutput_b->shape[channel_pos])) + return false; + + for (size_t i = 0; i < ta->shape.size(); i++) { + if (i == arg_channel_pos) continue; + if (!eq(ta->shape[i], tb->shape[i])) + return false; + } + return true; + } + + // Check if ops in depth-th level can be combined + bool CheckLevel(const Group& branches, size_t depth, size_t channel_pos, size_t parent_index) { + const CallNode* call = branches[0][depth]; + AttrsEqual attrs_equal; + // check if all branches in current depth can be combined + for (auto it = branches.begin() + 1; it != branches.end(); it++) { + const Branch& branch = *it; + if (!branch[depth]->op.same_as(call->op) || + !attrs_equal(branch[depth]->attrs, call->attrs) || + branch[depth]->args.size() != call->args.size()) { + return false; + } + + if (branch[depth]->args[parent_index].get() != branch[depth - 1]) + return false; + + // Check args + for (size_t i = 0; i < call->args.size(); i++) { + if (i == parent_index) continue; + + if (!IsArgCompatible(call, branch[depth], i, channel_pos) || + !attrs_equal(call->attrs, branch[depth]->attrs)) { + return false; + } + } + } + return true; + } + + // Combine args and make the combined CallNode + Call MakeCombinedCall(const Expr& data, const Group& branches, size_t depth, size_t channel_pos, + size_t parent_index) { + Array new_args; + const CallNode* call = branches[0][depth]; + size_t ndim = call->type_as()->shape.size(); + + for (size_t i = 0; i < call->args.size(); i++) { + if (i == parent_index) { + new_args.push_back(data); + continue; + } + size_t arg_ndim = call->args[i]->type_as()->shape.size(); + size_t arg_channel_pos = channel_pos - ndim + arg_ndim; + Array tuple; + for (const auto& branch : branches) { + tuple.push_back(branch[depth]->args[i]); + } + auto concat = MakeConcatenate(TupleNode::make(tuple), arg_channel_pos); + new_args.push_back(std::move(concat)); + } + return CallNode::make(call->op, new_args, call->attrs, {}); + } + + // Replace output of each branch with slices of the combined output + void UpdateGroupOutput(const Expr& data, const Group& branches, size_t depth, + size_t channel_pos) { + int64_t index = 0; + for (const auto& branch : branches) { + const CallNode* conv2d = branch[0]; + int64_t channels = GetConv2DSuperChannelsDim(conv2d); + Array begin; + Array end; + for (size_t i = 0; i < channel_pos; i++) { + begin.push_back(0); + end.push_back(NullValue()); + } + begin.push_back(index); + index += channels; + end.push_back(index); + auto slice = MakeStridedSlice(data, std::move(begin), std::move(end), Array{}); + subst_map_[GetRef(branch[depth])] = slice; + } + } + + // Combine branches in a group. Conv2d in different branches in the same group are safe to + // combine. Subsequent ops may or may not be combined. We start from conv2d and try to + // combine ops from all branches in the same depth. + void CombineBranches(const Group& branches) { + Call combined = MakeCombinedConv2D(branches); + auto conv_param = combined->attrs.as(); + const std::string& layout = + conv_param->out_layout == "" ? conv_param->data_layout : conv_param->out_layout; + size_t channel_pos = layout.find('C'); + CHECK_NE(channel_pos, std::string::npos); + auto it = std::min_element(branches.begin(), branches.end(), + [](const Branch& branch_a, + const Branch& branch_b) { + return branch_a.size() < branch_b.size(); + }); + size_t depth = it->size(); + size_t i; + // starting from 1 to skip the conv2d + for (i = 1; i < depth; i++) { + size_t parent_index; + for (parent_index = 0; parent_index < branches[0][i]->args.size(); parent_index++) { + if (branches[0][i]->args[parent_index].get() == branches[0][i - 1]) break; + } + CHECK_NE(parent_index, branches[0][i]->args.size()); + if (!CheckLevel(branches, i, channel_pos, parent_index)) break; + combined = MakeCombinedCall(combined, branches, i, channel_pos, parent_index); + } + UpdateGroupOutput(combined, branches, i - 1, channel_pos); + } +}; + +Expr CombineParallelConv2D(const Expr& expr) { return ParallelConv2DCombiner().Combine(expr); } + +TVM_REGISTER_API("relay._ir_pass.CombineParallelConv2D") +.set_body([](TVMArgs args, TVMRetValue* ret) { + *ret = CombineParallelConv2D(args[0]); +}); + +} // namespace relay +} // namespace tvm diff --git a/src/relay/pass/expr_subst.cc b/src/relay/pass/expr_subst.cc new file mode 100644 index 000000000000..586f748abef5 --- /dev/null +++ b/src/relay/pass/expr_subst.cc @@ -0,0 +1,35 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file expr_subst.h + * \brief Utility functions for substituting expressions. + */ + +#include +#include "./expr_subst.h" + +namespace tvm { +namespace relay { + +class ExprSubstituter : public ExprMutator { + public: + explicit ExprSubstituter(std::unordered_map subst_map) + : subst_map_(subst_map) {} + + Expr VisitExpr(const Expr& expr) final { + auto it = subst_map_.find(expr); + if (it != subst_map_.end()) { + return (*it).second; + } + return ExprMutator::VisitExpr(expr); + } + + private: + tvm::Map subst_map_; +}; + +Expr ExprSubst(const Expr& expr, std::unordered_map subst_map) { + return ExprSubstituter(std::move(subst_map)).Mutate(expr); +} + +} // namespace relay +} // namespace tvm diff --git a/src/relay/pass/expr_subst.h b/src/relay/pass/expr_subst.h new file mode 100644 index 000000000000..67892b3a0af7 --- /dev/null +++ b/src/relay/pass/expr_subst.h @@ -0,0 +1,18 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file expr_subst.h + * \brief Utility functions for substituting expressions. + */ +#ifndef TVM_RELAY_PASS_EXPR_SUBST_H_ +#define TVM_RELAY_PASS_EXPR_SUBST_H_ +#include +#include + +namespace tvm { +namespace relay { + +Expr ExprSubst(const Expr& expr, std::unordered_map subst_map); + +} // namespace relay +} // namespace tvm +#endif // TVM_RELAY_PASS_EXPR_SUBST_H_ diff --git a/src/relay/pass/pattern_util.h b/src/relay/pass/pattern_util.h index 1c855d9a53cb..38ae923c5274 100644 --- a/src/relay/pass/pattern_util.h +++ b/src/relay/pass/pattern_util.h @@ -11,6 +11,7 @@ #include #include #include +#include #include "../op/layout.h" @@ -120,6 +121,19 @@ inline bool IsDepthwiseConv2D(const Call& call, is_const_int(wshape[1], 1); } +/*! + * \brief Get super-dimension of output channels of conv2d + * \param call The conv2d call. + * \return Super-dimension size of output channels of conv2d. + */ +inline int64_t GetConv2DSuperChannelsDim(const CallNode* call) { + auto param = call->attrs.as(); + auto tweight = call->args[1]->type_as(); + auto index = param->weight_layout.find('O'); + CHECK_NE(index, std::string::npos); + auto channels = as_const_int(tweight->shape[index]); + return *channels; +} /*! * \brief Create a Constant with a scalar @@ -172,6 +186,10 @@ inline Expr ReshapeLike(Expr lhs, Expr rhs) { return CallNode::make(op, {lhs, rhs}, Attrs(), {}); } +Expr MakeConcatenate(Expr data, int axis); + +Expr MakeStridedSlice(Expr data, Array begin, Array end, Array strides); + } // namespace relay } // namespace tvm #endif // TVM_RELAY_PASS_PATTERN_UTIL_H_ diff --git a/tests/python/relay/test_pass_combine_parallel_conv2d.py b/tests/python/relay/test_pass_combine_parallel_conv2d.py new file mode 100644 index 000000000000..31dfe095f682 --- /dev/null +++ b/tests/python/relay/test_pass_combine_parallel_conv2d.py @@ -0,0 +1,138 @@ +from tvm import relay +import numpy as np + + +def test_combine_parallel_conv2d(): + """Simple testcase.""" + def before(x, w1, w2, w3, w4): + args = [x, w1, w2, w3, w4] + y1 = relay.nn.conv2d(x, w1) + y2 = relay.nn.conv2d(x, w2) + # y3 cannot be combined + y3 = relay.nn.conv2d(x, w3) + y4 = relay.nn.conv2d(x, w4) + y = relay.Tuple((y1, y2, y3, y4)) + return relay.Function(args, y) + + def expected(x, w1, w2, w3, w4, channels1, channels2, channels3, channels4): + # use a fixed order of args so alpha equal check can pass + args = [x, w1, w2, w3, w4] + w = relay.concatenate((w1, w2, w4), axis=0) + y = relay.nn.conv2d(x, w, channels=channels1 + channels2 + channels4) + y1 = relay.strided_slice(y, [0, 0], [None, channels1]) + y2 = relay.strided_slice(y, [0, channels1], [None, channels1 + channels2]) + y3 = relay.nn.conv2d(x, w3) + y4 = relay.strided_slice(y, [0, channels1 + channels2], + [None, channels1 + channels2 + channels4]) + y = relay.Tuple((y1, y2, y3, y4)) + return relay.Function(args, y) + + def check(x_shape, channels1, channels2, channels3, channels4): + x = relay.var("x", shape=x_shape) + in_c = x_shape[1] + w1 = relay.var("w1", shape=(channels1, in_c, 1, 1)) + w2 = relay.var("w2", shape=(channels2, in_c, 1, 1)) + w3 = relay.var("w3", shape=(channels3, in_c, 3, 3)) + w4 = relay.var("w4", shape=(channels4, in_c, 1, 1)) + + y_before = before(x, w1, w2, w3, w4) + y = relay.ir_pass.infer_type(y_before) + y = relay.ir_pass.combine_parallel_conv2d(y) + y = relay.ir_pass.infer_type(y) + y_expected = expected(x, w1, w2, w3, w4, channels1, channels2, channels3, channels4) + y_expected = relay.ir_pass.infer_type(y_expected) + assert relay.ir_pass.alpha_equal(y, y_expected) + + check((1, 4, 16, 16), 4, 4, 4, 4) + check((1, 4, 16, 16), 4, 8, 4, 7) + + +def test_combine_parallel_conv2d_scale_relu(): + """Testcase of combining conv2d + scale + relu""" + def before(x, w1, w2, scale1, scale2, bias): + args = [x, w1, w2, scale1, scale2, bias] + y1 = relay.nn.conv2d(x, w1) + y1 = relay.multiply(y1, scale1) + y1 = relay.nn.relu(y1) + y2 = relay.nn.conv2d(x, w2) + y2 = relay.multiply(y2, scale2) + y2 = relay.nn.relu(y2) + y2 = relay.add(y2, bias) + y = relay.Tuple((y1, y2)) + return relay.Function(args, y) + + def expected(x, w1, w2, scale1, scale2, bias, channels1, channels2): + args = [x, w1, w2, scale1, scale2, bias] + w = relay.concatenate((w1, w2), axis=0) + scale = relay.concatenate((scale1, scale2), axis=0) + y = relay.nn.conv2d(x, w, channels=channels1 + channels2) + y = relay.multiply(y, scale) + y = relay.nn.relu(y) + y1 = relay.strided_slice(y, [0, 0], [None, channels1]) + y2 = relay.strided_slice(y, [0, channels1], [None, channels1 + channels2]) + y2 = relay.add(y2, bias) + y = relay.Tuple((y1, y2)) + return relay.Function(args, y) + + def check(x_shape, channels1, channels2): + x = relay.var("x", shape=x_shape) + in_c = x_shape[1] + w1 = relay.var("w1", shape=(channels1, in_c, 1, 1)) + w2 = relay.var("w2", shape=(channels2, in_c, 1, 1)) + scale1 = relay.var("scale1", shape=(channels1, 1, 1)) + scale2 = relay.var("scale2", shape=(channels2, 1, 1)) + bias = relay.var("bias", shape=(channels2, 1, 1)) + y_before = before(x, w1, w2, scale1, scale2, bias) + y = relay.ir_pass.infer_type(y_before) + y = relay.ir_pass.combine_parallel_conv2d(y) + y = relay.ir_pass.infer_type(y) + y_expected = expected(x, w1, w2, scale1, scale2, bias, channels1, channels2) + y_expected = relay.ir_pass.infer_type(y_expected) + assert relay.ir_pass.alpha_equal(y, y_expected) + + check((1, 4, 16, 16), 4, 8) + + +def test_combine_parallel_conv2d_scale(): + """Testcase of un-combinable scale""" + def before(x, w1, w2, scale1, scale2): + args = [x, w1, w2, scale1, scale2] + y1 = relay.nn.conv2d(x, w1) + y1 = relay.multiply(y1, scale1) + y2 = relay.nn.conv2d(x, w2) + y2 = relay.multiply(y2, scale2) + y = relay.Tuple((y1, y2)) + return relay.Function(args, y) + + def expected(x, w1, w2, scale1, scale2, channels1, channels2): + args = [x, w1, w2, scale1, scale2] + w = relay.concatenate((w1, w2), axis=0) + y = relay.nn.conv2d(x, w, channels=channels1 + channels2) + y1 = relay.strided_slice(y, [0, 0], [None, channels1]) + y2 = relay.strided_slice(y, [0, channels1], [None, channels1 + channels2]) + y1 = relay.multiply(y1, scale1) + y2 = relay.multiply(y2, scale2) + y = relay.Tuple((y1, y2)) + return relay.Function(args, y) + + def check(x_shape, channels1, channels2): + x = relay.var("x", shape=x_shape) + in_c = x_shape[1] + w1 = relay.var("w1", shape=(channels1, in_c, 1, 1)) + w2 = relay.var("w2", shape=(channels2, in_c, 1, 1)) + scale1 = relay.var("scale1", shape=(1,)) + scale2 = relay.var("scale2", shape=(1,)) + y_before = before(x, w1, w2, scale1, scale2) + y = relay.ir_pass.infer_type(y_before) + y = relay.ir_pass.combine_parallel_conv2d(y) + y = relay.ir_pass.infer_type(y) + y_expected = expected(x, w1, w2, scale1, scale2, channels1, channels2) + y_expected = relay.ir_pass.infer_type(y_expected) + assert relay.ir_pass.alpha_equal(y, y_expected) + + check((1, 4, 16, 16), 4, 8) + +if __name__ == "__main__": + test_combine_parallel_conv2d() + test_combine_parallel_conv2d_scale_relu() + test_combine_parallel_conv2d_scale()