From ca7a659093dbf20667f8f74422a6f1d7f246fb20 Mon Sep 17 00:00:00 2001
From: Wuwei Lin
Date: Mon, 12 Nov 2018 11:43:43 +0800
Subject: [PATCH 01/15] Add FoldConv2D pass

---
 python/tvm/relay/ir_pass.py                 |  17 ++
 src/relay/pass/expr_subst.cc                |  34 ++++
 src/relay/pass/expr_subst.h                 |  17 ++
 src/relay/pass/fold_conv2d.cc               | 162 ++++++++++++++++++++
 src/relay/pass/pattern_util.h               |  19 ++-
 tests/python/relay/test_pass_fold_conv2d.py |  65 ++++++++
 6 files changed, 313 insertions(+), 1 deletion(-)
 create mode 100644 src/relay/pass/expr_subst.cc
 create mode 100644 src/relay/pass/expr_subst.h
 create mode 100644 src/relay/pass/fold_conv2d.cc
 create mode 100644 tests/python/relay/test_pass_fold_conv2d.py

diff --git a/python/tvm/relay/ir_pass.py b/python/tvm/relay/ir_pass.py
index 9d59980f6127..1ba2bb21a067 100644
--- a/python/tvm/relay/ir_pass.py
+++ b/python/tvm/relay/ir_pass.py
@@ -292,3 +292,20 @@ def fuse_ops(expr, opt_level=1):
         Transformed expression, containing fused result.
     """
     return _ir_pass.FuseOps(expr, opt_level)
+
+
+def fold_conv2d(expr):
+    """Fold multiple conv2d into one.
+
+    Parameters
+    ----------
+    expr : tvm.relay.Expr
+        The input expression.
+
+    Returns
+    -------
+    transformed_expr : tvm.relay.Expr
+        Transformed expression, containing fused result.
+    """
+    return _ir_pass.FoldConv2D(expr)
+>>>>>>> Add FoldConv2D pass
diff --git a/src/relay/pass/expr_subst.cc b/src/relay/pass/expr_subst.cc
new file mode 100644
index 000000000000..3e342dee5061
--- /dev/null
+++ b/src/relay/pass/expr_subst.cc
@@ -0,0 +1,34 @@
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file expr_subst.cc
+ * \brief Utility functions for substituting expressions.
+ */
+
+#include <tvm/relay/expr_functor.h>
+#include "./expr_subst.h"
+
+namespace tvm {
+namespace relay {
+
+class ExprSubstituter : public ExprMutator {
+ public:
+  explicit ExprSubstituter(tvm::Map<Expr, Expr> subst_map) : subst_map_(subst_map) {}
+
+  Expr VisitExpr(const Expr& expr) final {
+    auto it = subst_map_.find(expr);
+    if (it != subst_map_.end()) {
+      return (*it).second;
+    }
+    return ExprMutator::VisitExpr(expr);
+  }
+
+ private:
+  tvm::Map<Expr, Expr> subst_map_;
+};
+
+Expr ExprSubst(const Expr& expr, tvm::Map<Expr, Expr> subst_map) {
+  return ExprSubstituter(std::move(subst_map)).Mutate(expr);
+}
+
+}  // namespace relay
+}  // namespace tvm
diff --git a/src/relay/pass/expr_subst.h b/src/relay/pass/expr_subst.h
new file mode 100644
index 000000000000..7656baba1fa6
--- /dev/null
+++ b/src/relay/pass/expr_subst.h
@@ -0,0 +1,17 @@
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file expr_subst.h
+ * \brief Utility functions for substituting expressions.
+ */
+#ifndef TVM_RELAY_PASS_EXPR_SUBST_H_
+#define TVM_RELAY_PASS_EXPR_SUBST_H_
+#include <tvm/relay/expr.h>
+
+namespace tvm {
+namespace relay {
+
+  Expr ExprSubst(const Expr& expr, tvm::Map<Expr, Expr> subst_map);
+
+}  // namespace relay
+}  // namespace tvm
+#endif  // TVM_RELAY_PASS_EXPR_SUBST_H_
diff --git a/src/relay/pass/fold_conv2d.cc b/src/relay/pass/fold_conv2d.cc
new file mode 100644
index 000000000000..e4ee45d3dbe0
--- /dev/null
+++ b/src/relay/pass/fold_conv2d.cc
@@ -0,0 +1,162 @@
+/*!
+ * Copyright (c) 2018 by Contributors
+ *
+ * \file fold_conv2d.cc
+ *
+ * \brief Fold multiple 2d convolutions into a single convolution.
+ *
+ * This pass replaces convolutions that share the same input node and the same arguments (except
+ * that the number of output channels can be different) with a single convolution. The weight of
+ * the new 2d convolution is the concatenation of the original weights.
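+ *
+ * A rough sketch of the rewrite (simplified illustration; the per-branch
+ * outputs are recovered by indexing the channel axis of the combined output):
+ *
+ *        data                                       data
+ *       /    \                                        |
+ *  conv2d(w1) conv2d(w2)   ==>   conv2d(concatenate((w1, w2)) along 'O')
+ *      |         |                        /                  \
+ *    out1      out2            take(channels of w1)   take(channels of w2)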
+ */ + +#include +#include +#include +#include +#include "./expr_subst.h" +#include "./pattern_util.h" + +namespace tvm { +namespace relay { + +class SiblingConv2DFinder : public ExprVisitor { + public: + std::unordered_map, NodeHash, NodeEqual> + Find(const Expr& expr) { + this->VisitExpr(expr); + return std::move(children_map_); + } + + void VisitExpr_(const CallNode* n) final { + static const Op& conv2d = Op::Get("nn.conv2d"); + ExprVisitor::VisitExpr_(n); + if (n->op.same_as(conv2d) && n->attrs.as()->groups == 1) { + children_map_[n->args[0]].push_back(n); + } + } + + private: + std::unordered_map, NodeHash, NodeEqual> children_map_; +}; + +std::tuple TransformWeight(std::vector convolutions) { + int64_t num_filters = 0; // number of filters of the transformed weight + Array weights; + for (const CallNode* n : convolutions) { + weights.push_back(n->args[1]); + auto channels = as_const_int(n->attrs.as()->channels); + CHECK(channels); + num_filters += *channels; + } + return std::tuple{ MakeConcatenate(TupleNode::make(weights), 0), + MakeConstScalar(Int(32), num_filters) }; +} + +// Two 2d convolutions can be combined if they have the same attributes or only have +// different output channels. +bool IsCompatibleConv2D(const Conv2DAttrs& a, const Conv2DAttrs& b) { + AttrsEqual eq; + return eq(a.strides, b.strides) && + eq(a.padding, b.padding) && + eq(a.dilation, b.dilation) && + eq(a.groups, b.groups) && + eq(a.kernel_size, b.kernel_size) && + eq(a.data_layout, b.data_layout) && + eq(a.weight_layout, b.weight_layout) && + eq(a.out_dtype, b.out_dtype) && + eq(a.out_layout, b.out_layout); +} + +Expr MakeFoldedConv2D(const Expr& data, const std::vector& convolutions) { + static const Op& conv2d = Op::Get("nn.conv2d"); + + Expr new_weight; + IndexExpr new_channels; + std::tie(new_weight, new_channels) = TransformWeight(convolutions); + + const CallNode* group_root = *(convolutions).begin(); + auto attrs = group_root->attrs.as(); + auto new_attrs = make_node(); + new_attrs->strides = attrs->strides; + new_attrs->padding = attrs->padding; + new_attrs->dilation = attrs->dilation; + new_attrs->groups = attrs->groups; + new_attrs->kernel_size = attrs->kernel_size; + new_attrs->data_layout = attrs->data_layout; + new_attrs->weight_layout = attrs->weight_layout; + new_attrs->out_layout = attrs->out_layout; + new_attrs->out_dtype = attrs->out_dtype; + new_attrs->channels = new_channels; + + return CallNode::make(conv2d, {data, new_weight}, Attrs{new_attrs}, {}); +} + +Expr FoldConv2D(const Expr& expr) { + // data -> array of conv2d with the same input + auto children_map = SiblingConv2DFinder().Find(expr); + Map subst_map; + + for (const auto& pair : children_map) { + Expr data = pair.first; + std::vector children = pair.second; + + if (children.size() < 2) + continue; + + std::vector group_ids(children.size()); + std::vector> groups; + + for (size_t i = 0; i < children.size(); i++) { + const CallNode* n = children[i]; + auto args = n->attrs.as(); + + // assign a group id or create a new group for each conv2d + auto it = + std::find_if(groups.begin(), groups.end(), + [&](std::vector group) { + const CallNode* group_root = *(group.begin()); + auto group_args = group_root->attrs.as(); + return IsCompatibleConv2D(*args, *group_args); + }); + + if (it != groups.end()) { + auto group_id = std::distance(groups.begin(), it); + group_ids[i] = group_id; + groups[group_id].push_back(n); + } else { + group_ids[i] = groups.size(); + groups.emplace_back(std::vector{n}); + } + } + + for (const auto& convs : 
groups) { + if (convs.size() < 2) { + continue; + } + auto new_conv2d = MakeFoldedConv2D(data, convs); + + int64_t start = 0; + // replace original conv2d with slice of output of the new conv2d + for (const auto& conv2d : convs) { + auto params = conv2d->attrs.as(); + auto channels = as_const_int(params->channels); + CHECK(channels); + auto indices = MakeConstantArrayFromRange(Int(64), start, start + *channels); + auto take = MakeTake(new_conv2d, indices, 1); + start += *channels; + subst_map.Set(GetRef(conv2d), take); + } + } + } + + return ExprSubst(expr, std::move(subst_map)); +} + +TVM_REGISTER_API("relay._ir_pass.FoldConv2D") +.set_body([](TVMArgs args, TVMRetValue* ret) { + *ret = FoldConv2D(args[0]); + }); + +} // namespace relay +} // namespace tvm diff --git a/src/relay/pass/pattern_util.h b/src/relay/pass/pattern_util.h index 1c855d9a53cb..6270b83e1988 100644 --- a/src/relay/pass/pattern_util.h +++ b/src/relay/pass/pattern_util.h @@ -13,7 +13,6 @@ #include #include "../op/layout.h" - namespace tvm { namespace relay { @@ -136,6 +135,20 @@ inline Constant MakeConstantScalar(DataType dtype, T value) { return ConstantNode::make(arr); } +template::value>::type> +inline Constant MakeConstantArrayFromRange(DataType dtype, T start, T end, T step = 1) { + CHECK_EQ(sizeof(T) * 8, dtype.bits()) << "data type mismatch"; + CHECK(step); + CHECK_GE((end - start) / step, 0); + runtime::NDArray arr = runtime::NDArray::Empty({(int64_t)(end - start) / step}, + Type2TVMType(dtype), {kDLCPU, 0}); + for (auto *data = static_cast(arr->data); (step > 0) ? (start < end) : (start > end); + start += step, data++) { + *data = start; + } + return ConstantNode::make(arr); +} + inline Expr Negative(Expr x) { static const Op& op = Op::Get("negative"); @@ -172,6 +185,10 @@ inline Expr ReshapeLike(Expr lhs, Expr rhs) { return CallNode::make(op, {lhs, rhs}, Attrs(), {}); } +Expr MakeConcatenate(Expr data, int axis); + +Expr MakeTake(Expr data, Expr indices, Integer axis); + } // namespace relay } // namespace tvm #endif // TVM_RELAY_PASS_PATTERN_UTIL_H_ diff --git a/tests/python/relay/test_pass_fold_conv2d.py b/tests/python/relay/test_pass_fold_conv2d.py new file mode 100644 index 000000000000..ef716f707900 --- /dev/null +++ b/tests/python/relay/test_pass_fold_conv2d.py @@ -0,0 +1,65 @@ +from tvm import relay +import numpy as np + + +def test_fold_conv2d(): + """Simple testcase.""" + def before(x, w1, w2, w3, w4, channels1, channels2, channels3, channels4): + args = [x, w1, w2, w3, w4] + y1 = relay.nn.conv2d(x, w1, + channels=channels1, + kernel_size=(3, 3), + padding=(1, 1)) + y2 = relay.nn.conv2d(x, w2, + channels=channels2, + kernel_size=(3, 3), + padding=(1, 1)) + # y3 is not foldable + y3 = relay.nn.conv2d(x, w3, + channels=channels3, + kernel_size=(1, 1), + padding=(1, 1)) + y4 = relay.nn.conv2d(x, w4, + channels=channels4, + kernel_size=(3, 3), + padding=(1, 1)) + y = relay.Tuple((y1, y2, y3, y4)) + return relay.Function(args, y) + + def expected(x, w1, w2, w3, w4, channels1, channels2, channels3, channels4): + # use a fixed order of args so alpha equal check can pass + args = [x, w1, w2, w3, w4] + w = relay.concatenate((w1, w2, w4), axis=0) + y = relay.nn.conv2d(x, w, + channels=channels1 + channels2 + channels4, + kernel_size=(3, 3), + padding=(1, 1)) + y1 = relay.take(y, relay.const(np.arange(channels1)), axis=1) + y2 = relay.take(y, relay.const(np.arange(channels1, channels1 + channels2)), axis=1) + y3 = relay.nn.conv2d(x, w3, + channels=channels3, + kernel_size=(1, 1), + padding=(1, 1)) + y4 = 
relay.take(y, relay.const(np.arange(channels1 + channels2, + channels1 + channels2 + channels4)), axis=1) + y = relay.Tuple((y1, y2, y3, y4)) + return relay.Function(args, y) + + def check(channels1, channels2, channels3, channels4): + x = relay.var("x") + w1 = relay.var("w1") + w2 = relay.var("w2") + w3 = relay.var("w3") + w4 = relay.var("w4") + + y_before = before(x, w1, w2, w3, w4, channels1, channels2, channels3, channels4) + y = relay.ir_pass.fold_conv2d(y_before) + y_expected = expected(x, w1, w2, w3, w4, channels1, channels2, channels3, channels4) + assert relay.ir_pass.alpha_equal(y, y_expected) + + check(4, 4, 4, 4) + check(4, 8, 4, 7) + + +if __name__ == "__main__": + test_fold_conv2d() From c1b73c1f5a506ca4dc3d177aa58366944e5eb8bf Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Mon, 12 Nov 2018 12:35:17 +0800 Subject: [PATCH 02/15] Fix test on i386 --- tests/python/relay/test_pass_fold_conv2d.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/python/relay/test_pass_fold_conv2d.py b/tests/python/relay/test_pass_fold_conv2d.py index ef716f707900..38733dc3cdfa 100644 --- a/tests/python/relay/test_pass_fold_conv2d.py +++ b/tests/python/relay/test_pass_fold_conv2d.py @@ -34,14 +34,14 @@ def expected(x, w1, w2, w3, w4, channels1, channels2, channels3, channels4): channels=channels1 + channels2 + channels4, kernel_size=(3, 3), padding=(1, 1)) - y1 = relay.take(y, relay.const(np.arange(channels1)), axis=1) - y2 = relay.take(y, relay.const(np.arange(channels1, channels1 + channels2)), axis=1) + y1 = relay.take(y, relay.const(np.arange(channels1, dtype='int64')), axis=1) + y2 = relay.take(y, relay.const(np.arange(channels1, channels1 + channels2, dtype='int64')), axis=1) y3 = relay.nn.conv2d(x, w3, channels=channels3, kernel_size=(1, 1), padding=(1, 1)) y4 = relay.take(y, relay.const(np.arange(channels1 + channels2, - channels1 + channels2 + channels4)), axis=1) + channels1 + channels2 + channels4, dtype='int64')), axis=1) y = relay.Tuple((y1, y2, y3, y4)) return relay.Function(args, y) From febef4e6527b2169ac0701b31aa4996384bf2fec Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Mon, 12 Nov 2018 13:31:03 +0800 Subject: [PATCH 03/15] Update comments --- python/tvm/relay/ir_pass.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/relay/ir_pass.py b/python/tvm/relay/ir_pass.py index 1ba2bb21a067..fd022ae2261a 100644 --- a/python/tvm/relay/ir_pass.py +++ b/python/tvm/relay/ir_pass.py @@ -305,7 +305,7 @@ def fold_conv2d(expr): Returns ------- transformed_expr : tvm.relay.Expr - Transformed expression, containing fused result. + Transformed expression, containing folded conv2d. 
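+
+    Examples
+    --------
+    A minimal usage sketch; ``expr`` is assumed to be a Relay expression
+    containing parallel conv2d branches that share the same input.
+
+    >>> folded_expr = relay.ir_pass.fold_conv2d(expr)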
""" return _ir_pass.FoldConv2D(expr) >>>>>>> Add FoldConv2D pass From aaade9b9b5228eb34d4c6e6c0b64dbd055e32938 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Mon, 12 Nov 2018 13:35:12 +0800 Subject: [PATCH 04/15] Fix style --- src/relay/pass/expr_subst.cc | 2 +- src/relay/pass/expr_subst.h | 2 +- src/relay/pass/fold_conv2d.cc | 47 ++++++++++++++++++----------------- 3 files changed, 26 insertions(+), 25 deletions(-) diff --git a/src/relay/pass/expr_subst.cc b/src/relay/pass/expr_subst.cc index 3e342dee5061..bac66bc0acf1 100644 --- a/src/relay/pass/expr_subst.cc +++ b/src/relay/pass/expr_subst.cc @@ -27,7 +27,7 @@ class ExprSubstituter : public ExprMutator { }; Expr ExprSubst(const Expr& expr, tvm::Map subst_map) { - return ExprSubstituter(std::move(subst_map)).Mutate(expr); + return ExprSubstituter(std::move(subst_map)).Mutate(expr); } } // namespace relay diff --git a/src/relay/pass/expr_subst.h b/src/relay/pass/expr_subst.h index 7656baba1fa6..02f4179dae66 100644 --- a/src/relay/pass/expr_subst.h +++ b/src/relay/pass/expr_subst.h @@ -10,7 +10,7 @@ namespace tvm { namespace relay { - Expr ExprSubst(const Expr& expr, tvm::Map subst_map); +Expr ExprSubst(const Expr& expr, tvm::Map subst_map); } // namespace relay } // namespace tvm diff --git a/src/relay/pass/fold_conv2d.cc b/src/relay/pass/fold_conv2d.cc index e4ee45d3dbe0..e49ba6706f01 100644 --- a/src/relay/pass/fold_conv2d.cc +++ b/src/relay/pass/fold_conv2d.cc @@ -8,6 +8,9 @@ * This pass replaces convolutions that share the same input node and the same arguments (except * that the number of output channels can be different) with a single convolution. The weight of * the new 2d convolution is the concatenation of the original weights. + * + * This prevents launching multiple kernels in networks with multiple convolution branches, such + * as Inception block. */ #include @@ -22,8 +25,8 @@ namespace relay { class SiblingConv2DFinder : public ExprVisitor { public: - std::unordered_map, NodeHash, NodeEqual> - Find(const Expr& expr) { + std::unordered_map, NodeHash, NodeEqual> Find( + const Expr& expr) { this->VisitExpr(expr); return std::move(children_map_); } @@ -49,23 +52,23 @@ std::tuple TransformWeight(std::vector convolu CHECK(channels); num_filters += *channels; } - return std::tuple{ MakeConcatenate(TupleNode::make(weights), 0), - MakeConstScalar(Int(32), num_filters) }; + return std::tuple{MakeConcatenate(TupleNode::make(weights), 0), + MakeConstScalar(Int(32), num_filters)}; } // Two 2d convolutions can be combined if they have the same attributes or only have // different output channels. 
bool IsCompatibleConv2D(const Conv2DAttrs& a, const Conv2DAttrs& b) { - AttrsEqual eq; - return eq(a.strides, b.strides) && - eq(a.padding, b.padding) && - eq(a.dilation, b.dilation) && - eq(a.groups, b.groups) && - eq(a.kernel_size, b.kernel_size) && - eq(a.data_layout, b.data_layout) && - eq(a.weight_layout, b.weight_layout) && - eq(a.out_dtype, b.out_dtype) && - eq(a.out_layout, b.out_layout); + AttrsEqual eq; + return eq(a.strides, b.strides) && + eq(a.padding, b.padding) && + eq(a.dilation, b.dilation) && + eq(a.groups, b.groups) && + eq(a.kernel_size, b.kernel_size) && + eq(a.data_layout, b.data_layout) && + eq(a.weight_layout, b.weight_layout) && + eq(a.out_dtype, b.out_dtype) && + eq(a.out_layout, b.out_layout); } Expr MakeFoldedConv2D(const Expr& data, const std::vector& convolutions) { @@ -101,8 +104,7 @@ Expr FoldConv2D(const Expr& expr) { Expr data = pair.first; std::vector children = pair.second; - if (children.size() < 2) - continue; + if (children.size() < 2) continue; std::vector group_ids(children.size()); std::vector> groups; @@ -112,13 +114,12 @@ Expr FoldConv2D(const Expr& expr) { auto args = n->attrs.as(); // assign a group id or create a new group for each conv2d - auto it = - std::find_if(groups.begin(), groups.end(), - [&](std::vector group) { - const CallNode* group_root = *(group.begin()); - auto group_args = group_root->attrs.as(); - return IsCompatibleConv2D(*args, *group_args); - }); + auto it = std::find_if(groups.begin(), groups.end(), + [&](const std::vector& group) { + const CallNode* group_root = *(group.begin()); + auto group_args = group_root->attrs.as(); + return IsCompatibleConv2D(*args, *group_args); + }); if (it != groups.end()) { auto group_id = std::distance(groups.begin(), it); From 0318eeffe2740b1d6b85e6a61d8e1f1295048b78 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Mon, 12 Nov 2018 17:37:58 +0800 Subject: [PATCH 05/15] Minor enhancement --- src/relay/pass/fold_conv2d.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/relay/pass/fold_conv2d.cc b/src/relay/pass/fold_conv2d.cc index e49ba6706f01..aa33b7d3e4f5 100644 --- a/src/relay/pass/fold_conv2d.cc +++ b/src/relay/pass/fold_conv2d.cc @@ -52,8 +52,8 @@ std::tuple TransformWeight(std::vector convolu CHECK(channels); num_filters += *channels; } - return std::tuple{MakeConcatenate(TupleNode::make(weights), 0), - MakeConstScalar(Int(32), num_filters)}; + return std::make_tuple(MakeConcatenate(TupleNode::make(weights), 0), + MakeConstScalar(Int(32), num_filters)); } // Two 2d convolutions can be combined if they have the same attributes or only have @@ -78,7 +78,7 @@ Expr MakeFoldedConv2D(const Expr& data, const std::vector& conv IndexExpr new_channels; std::tie(new_weight, new_channels) = TransformWeight(convolutions); - const CallNode* group_root = *(convolutions).begin(); + const CallNode* group_root = convolutions[0]; auto attrs = group_root->attrs.as(); auto new_attrs = make_node(); new_attrs->strides = attrs->strides; From 002b4e56bb2abc91ffe3da1ec0beec70c1d66fe5 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Tue, 13 Nov 2018 10:26:47 +0800 Subject: [PATCH 06/15] Use proper index in data_layout and weight_layout --- src/relay/pass/fold_conv2d.cc | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/relay/pass/fold_conv2d.cc b/src/relay/pass/fold_conv2d.cc index aa33b7d3e4f5..ed029551373f 100644 --- a/src/relay/pass/fold_conv2d.cc +++ b/src/relay/pass/fold_conv2d.cc @@ -52,7 +52,9 @@ std::tuple TransformWeight(std::vector convolu 
CHECK(channels); num_filters += *channels; } - return std::make_tuple(MakeConcatenate(TupleNode::make(weights), 0), + auto index = convolutions[0]->attrs.as()->weight_layout.find('O'); + CHECK_NE(index, std::string::npos); + return std::make_tuple(MakeConcatenate(TupleNode::make(weights), index), MakeConstScalar(Int(32), num_filters)); } @@ -144,7 +146,9 @@ Expr FoldConv2D(const Expr& expr) { auto channels = as_const_int(params->channels); CHECK(channels); auto indices = MakeConstantArrayFromRange(Int(64), start, start + *channels); - auto take = MakeTake(new_conv2d, indices, 1); + auto channel_index = params->data_layout.find('C'); + CHECK_NE(channel_index, std::string::npos); + auto take = MakeTake(new_conv2d, indices, channel_index); start += *channels; subst_map.Set(GetRef(conv2d), take); } From f923a162480e294829a2f19214acd3133f3950d0 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Tue, 13 Nov 2018 10:33:39 +0800 Subject: [PATCH 07/15] Rename to CombineParallelConv2D --- python/tvm/relay/ir_pass.py | 5 ++--- .../{fold_conv2d.cc => combine_parallel_conv2d.cc} | 14 +++++++------- ...v2d.py => test_pass_combine_parallel_conv2d.py} | 6 +++--- 3 files changed, 12 insertions(+), 13 deletions(-) rename src/relay/pass/{fold_conv2d.cc => combine_parallel_conv2d.cc} (93%) rename tests/python/relay/{test_pass_fold_conv2d.py => test_pass_combine_parallel_conv2d.py} (95%) diff --git a/python/tvm/relay/ir_pass.py b/python/tvm/relay/ir_pass.py index fd022ae2261a..8d54081d7090 100644 --- a/python/tvm/relay/ir_pass.py +++ b/python/tvm/relay/ir_pass.py @@ -294,7 +294,7 @@ def fuse_ops(expr, opt_level=1): return _ir_pass.FuseOps(expr, opt_level) -def fold_conv2d(expr): +def combine_parallel_conv2d(expr): """Fold multiple conv2d into one. Parameters @@ -307,5 +307,4 @@ def fold_conv2d(expr): transformed_expr : tvm.relay.Expr Transformed expression, containing folded conv2d. """ - return _ir_pass.FoldConv2D(expr) ->>>>>>> Add FoldConv2D pass + return _ir_pass.CombineParallelConv2D(expr) diff --git a/src/relay/pass/fold_conv2d.cc b/src/relay/pass/combine_parallel_conv2d.cc similarity index 93% rename from src/relay/pass/fold_conv2d.cc rename to src/relay/pass/combine_parallel_conv2d.cc index ed029551373f..3d9fb5891c2a 100644 --- a/src/relay/pass/fold_conv2d.cc +++ b/src/relay/pass/combine_parallel_conv2d.cc @@ -1,9 +1,9 @@ /*! * Copyright (c) 2018 by Contributors * - * \file fold_conv2d.cc + * \file combine_parallel_conv2d.cc * - * \brief Fold multiple 2d convolutions into a single convolution. + * \brief Combine parallel 2d convolutions into a single convolution. * * This pass replaces convolutions that share the same input node and the same arguments (except * that the number of output channels can be different) with a single convolution. 
The weight of @@ -73,7 +73,7 @@ bool IsCompatibleConv2D(const Conv2DAttrs& a, const Conv2DAttrs& b) { eq(a.out_layout, b.out_layout); } -Expr MakeFoldedConv2D(const Expr& data, const std::vector& convolutions) { +Expr MakeCombinedConv2D(const Expr& data, const std::vector& convolutions) { static const Op& conv2d = Op::Get("nn.conv2d"); Expr new_weight; @@ -97,7 +97,7 @@ Expr MakeFoldedConv2D(const Expr& data, const std::vector& conv return CallNode::make(conv2d, {data, new_weight}, Attrs{new_attrs}, {}); } -Expr FoldConv2D(const Expr& expr) { +Expr CombineParallelConv2D(const Expr& expr) { // data -> array of conv2d with the same input auto children_map = SiblingConv2DFinder().Find(expr); Map subst_map; @@ -137,7 +137,7 @@ Expr FoldConv2D(const Expr& expr) { if (convs.size() < 2) { continue; } - auto new_conv2d = MakeFoldedConv2D(data, convs); + auto new_conv2d = MakeCombinedConv2D(data, convs); int64_t start = 0; // replace original conv2d with slice of output of the new conv2d @@ -158,9 +158,9 @@ Expr FoldConv2D(const Expr& expr) { return ExprSubst(expr, std::move(subst_map)); } -TVM_REGISTER_API("relay._ir_pass.FoldConv2D") +TVM_REGISTER_API("relay._ir_pass.CombineParallelConv2D") .set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = FoldConv2D(args[0]); + *ret = CombineParallelConv2D(args[0]); }); } // namespace relay diff --git a/tests/python/relay/test_pass_fold_conv2d.py b/tests/python/relay/test_pass_combine_parallel_conv2d.py similarity index 95% rename from tests/python/relay/test_pass_fold_conv2d.py rename to tests/python/relay/test_pass_combine_parallel_conv2d.py index 38733dc3cdfa..ce34f0caed89 100644 --- a/tests/python/relay/test_pass_fold_conv2d.py +++ b/tests/python/relay/test_pass_combine_parallel_conv2d.py @@ -2,7 +2,7 @@ import numpy as np -def test_fold_conv2d(): +def test_combine_parallel_conv2d(): """Simple testcase.""" def before(x, w1, w2, w3, w4, channels1, channels2, channels3, channels4): args = [x, w1, w2, w3, w4] @@ -53,7 +53,7 @@ def check(channels1, channels2, channels3, channels4): w4 = relay.var("w4") y_before = before(x, w1, w2, w3, w4, channels1, channels2, channels3, channels4) - y = relay.ir_pass.fold_conv2d(y_before) + y = relay.ir_pass.combine_parallel_conv2d(y_before) y_expected = expected(x, w1, w2, w3, w4, channels1, channels2, channels3, channels4) assert relay.ir_pass.alpha_equal(y, y_expected) @@ -62,4 +62,4 @@ def check(channels1, channels2, channels3, channels4): if __name__ == "__main__": - test_fold_conv2d() + test_combine_parallel_conv2d() From 1934fd20dbfe766eb06ec0820c90d66a36ddd8f8 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Tue, 13 Nov 2018 15:38:37 +0800 Subject: [PATCH 08/15] Use unordered_map --- src/relay/pass/combine_parallel_conv2d.cc | 4 ++-- src/relay/pass/expr_subst.cc | 5 +++-- src/relay/pass/expr_subst.h | 3 ++- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/src/relay/pass/combine_parallel_conv2d.cc b/src/relay/pass/combine_parallel_conv2d.cc index 3d9fb5891c2a..60f643e436b2 100644 --- a/src/relay/pass/combine_parallel_conv2d.cc +++ b/src/relay/pass/combine_parallel_conv2d.cc @@ -100,7 +100,7 @@ Expr MakeCombinedConv2D(const Expr& data, const std::vector& co Expr CombineParallelConv2D(const Expr& expr) { // data -> array of conv2d with the same input auto children_map = SiblingConv2DFinder().Find(expr); - Map subst_map; + std::unordered_map subst_map; for (const auto& pair : children_map) { Expr data = pair.first; @@ -150,7 +150,7 @@ Expr CombineParallelConv2D(const Expr& expr) { CHECK_NE(channel_index, 
std::string::npos); auto take = MakeTake(new_conv2d, indices, channel_index); start += *channels; - subst_map.Set(GetRef(conv2d), take); + subst_map[GetRef(conv2d)] = take; } } } diff --git a/src/relay/pass/expr_subst.cc b/src/relay/pass/expr_subst.cc index bac66bc0acf1..586f748abef5 100644 --- a/src/relay/pass/expr_subst.cc +++ b/src/relay/pass/expr_subst.cc @@ -12,7 +12,8 @@ namespace relay { class ExprSubstituter : public ExprMutator { public: - explicit ExprSubstituter(tvm::Map subst_map) : subst_map_(subst_map) {} + explicit ExprSubstituter(std::unordered_map subst_map) + : subst_map_(subst_map) {} Expr VisitExpr(const Expr& expr) final { auto it = subst_map_.find(expr); @@ -26,7 +27,7 @@ class ExprSubstituter : public ExprMutator { tvm::Map subst_map_; }; -Expr ExprSubst(const Expr& expr, tvm::Map subst_map) { +Expr ExprSubst(const Expr& expr, std::unordered_map subst_map) { return ExprSubstituter(std::move(subst_map)).Mutate(expr); } diff --git a/src/relay/pass/expr_subst.h b/src/relay/pass/expr_subst.h index 02f4179dae66..67892b3a0af7 100644 --- a/src/relay/pass/expr_subst.h +++ b/src/relay/pass/expr_subst.h @@ -6,11 +6,12 @@ #ifndef TVM_RELAY_PASS_EXPR_SUBST_H_ #define TVM_RELAY_PASS_EXPR_SUBST_H_ #include +#include namespace tvm { namespace relay { -Expr ExprSubst(const Expr& expr, tvm::Map subst_map); +Expr ExprSubst(const Expr& expr, std::unordered_map subst_map); } // namespace relay } // namespace tvm From c638625667d1eecdf815b6f01fa50d9ceb311ebf Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Tue, 13 Nov 2018 19:50:48 +0800 Subject: [PATCH 09/15] Get channels info from type instead of attrs --- src/relay/pass/combine_parallel_conv2d.cc | 54 ++++++++--------- src/relay/pass/pattern_util.h | 14 +++++ .../test_pass_combine_parallel_conv2d.py | 58 +++++++------------ 3 files changed, 64 insertions(+), 62 deletions(-) diff --git a/src/relay/pass/combine_parallel_conv2d.cc b/src/relay/pass/combine_parallel_conv2d.cc index 60f643e436b2..61dfc632086f 100644 --- a/src/relay/pass/combine_parallel_conv2d.cc +++ b/src/relay/pass/combine_parallel_conv2d.cc @@ -48,9 +48,8 @@ std::tuple TransformWeight(std::vector convolu Array weights; for (const CallNode* n : convolutions) { weights.push_back(n->args[1]); - auto channels = as_const_int(n->attrs.as()->channels); - CHECK(channels); - num_filters += *channels; + auto channels = GetConv2DSuperChannelsDim(n); + num_filters += channels; } auto index = convolutions[0]->attrs.as()->weight_layout.find('O'); CHECK_NE(index, std::string::npos); @@ -60,17 +59,26 @@ std::tuple TransformWeight(std::vector convolu // Two 2d convolutions can be combined if they have the same attributes or only have // different output channels. 
-bool IsCompatibleConv2D(const Conv2DAttrs& a, const Conv2DAttrs& b) { +bool IsCompatibleConv2D(const CallNode* a, const CallNode* b) { AttrsEqual eq; - return eq(a.strides, b.strides) && - eq(a.padding, b.padding) && - eq(a.dilation, b.dilation) && - eq(a.groups, b.groups) && - eq(a.kernel_size, b.kernel_size) && - eq(a.data_layout, b.data_layout) && - eq(a.weight_layout, b.weight_layout) && - eq(a.out_dtype, b.out_dtype) && - eq(a.out_layout, b.out_layout); + static const Layout kOIHW("OIHW"); + auto attrs_a = a->attrs.as(); + auto attrs_b = b->attrs.as(); + auto tweight_a = a->args[1]->type_as(); + auto tweight_b = b->args[1]->type_as(); + auto shape_a = ConvertLayout(tweight_a->shape, attrs_a->weight_layout, kOIHW); + auto shape_b = ConvertLayout(tweight_b->shape, attrs_b->weight_layout, kOIHW); + + return eq(attrs_a->strides, attrs_b->strides) && + eq(attrs_a->padding, attrs_b->padding) && + eq(attrs_a->dilation, attrs_b->dilation) && + eq(attrs_a->groups, attrs_b->groups) && + eq(attrs_a->data_layout, attrs_b->data_layout) && + eq(attrs_a->weight_layout, attrs_b->weight_layout) && + eq(attrs_a->out_dtype, attrs_b->out_dtype) && + eq(attrs_a->out_layout, attrs_b->out_layout) && + eq(shape_a[2], shape_b[2]) && + eq(shape_a[3], shape_b[3]); } Expr MakeCombinedConv2D(const Expr& data, const std::vector& convolutions) { @@ -113,14 +121,11 @@ Expr CombineParallelConv2D(const Expr& expr) { for (size_t i = 0; i < children.size(); i++) { const CallNode* n = children[i]; - auto args = n->attrs.as(); // assign a group id or create a new group for each conv2d auto it = std::find_if(groups.begin(), groups.end(), [&](const std::vector& group) { - const CallNode* group_root = *(group.begin()); - auto group_args = group_root->attrs.as(); - return IsCompatibleConv2D(*args, *group_args); + return IsCompatibleConv2D(n, group[0]); }); if (it != groups.end()) { @@ -134,22 +139,19 @@ Expr CombineParallelConv2D(const Expr& expr) { } for (const auto& convs : groups) { - if (convs.size() < 2) { - continue; - } - auto new_conv2d = MakeCombinedConv2D(data, convs); + if (convs.size() < 2) continue; + auto new_conv2d = MakeCombinedConv2D(data, convs); int64_t start = 0; // replace original conv2d with slice of output of the new conv2d - for (const auto& conv2d : convs) { + for (const CallNode* conv2d : convs) { auto params = conv2d->attrs.as(); - auto channels = as_const_int(params->channels); - CHECK(channels); - auto indices = MakeConstantArrayFromRange(Int(64), start, start + *channels); + auto channels = GetConv2DSuperChannelsDim(conv2d); + auto indices = MakeConstantArrayFromRange(Int(64), start, start + channels); auto channel_index = params->data_layout.find('C'); CHECK_NE(channel_index, std::string::npos); auto take = MakeTake(new_conv2d, indices, channel_index); - start += *channels; + start += channels; subst_map[GetRef(conv2d)] = take; } } diff --git a/src/relay/pass/pattern_util.h b/src/relay/pass/pattern_util.h index 6270b83e1988..3775b77d6840 100644 --- a/src/relay/pass/pattern_util.h +++ b/src/relay/pass/pattern_util.h @@ -11,6 +11,7 @@ #include #include #include +#include #include "../op/layout.h" namespace tvm { @@ -119,6 +120,19 @@ inline bool IsDepthwiseConv2D(const Call& call, is_const_int(wshape[1], 1); } +/*! + * \brief Get super-dimension of output channels of conv2d + * \param call The conv2d call. + * \return Super-dimension size of output channels of conv2d. 
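+ *
+ * For example, with an "OIHW" weight layout this is simply the extent of the
+ * 'O' axis of the weight tensor, which must be a compile-time constant.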
+ */ +inline int64_t GetConv2DSuperChannelsDim(const CallNode* call) { + auto param = call->attrs.as(); + auto tweight = call->args[1]->type_as(); + auto index = param->weight_layout.find('O'); + CHECK_NE(index, std::string::npos); + auto channels = as_const_int(tweight->shape[index]); + return *channels; +} /*! * \brief Create a Constant with a scalar diff --git a/tests/python/relay/test_pass_combine_parallel_conv2d.py b/tests/python/relay/test_pass_combine_parallel_conv2d.py index ce34f0caed89..25c788bc9ff2 100644 --- a/tests/python/relay/test_pass_combine_parallel_conv2d.py +++ b/tests/python/relay/test_pass_combine_parallel_conv2d.py @@ -4,25 +4,13 @@ def test_combine_parallel_conv2d(): """Simple testcase.""" - def before(x, w1, w2, w3, w4, channels1, channels2, channels3, channels4): + def before(x, w1, w2, w3, w4): args = [x, w1, w2, w3, w4] - y1 = relay.nn.conv2d(x, w1, - channels=channels1, - kernel_size=(3, 3), - padding=(1, 1)) - y2 = relay.nn.conv2d(x, w2, - channels=channels2, - kernel_size=(3, 3), - padding=(1, 1)) + y1 = relay.nn.conv2d(x, w1) + y2 = relay.nn.conv2d(x, w2) # y3 is not foldable - y3 = relay.nn.conv2d(x, w3, - channels=channels3, - kernel_size=(1, 1), - padding=(1, 1)) - y4 = relay.nn.conv2d(x, w4, - channels=channels4, - kernel_size=(3, 3), - padding=(1, 1)) + y3 = relay.nn.conv2d(x, w3) + y4 = relay.nn.conv2d(x, w4) y = relay.Tuple((y1, y2, y3, y4)) return relay.Function(args, y) @@ -30,35 +18,33 @@ def expected(x, w1, w2, w3, w4, channels1, channels2, channels3, channels4): # use a fixed order of args so alpha equal check can pass args = [x, w1, w2, w3, w4] w = relay.concatenate((w1, w2, w4), axis=0) - y = relay.nn.conv2d(x, w, - channels=channels1 + channels2 + channels4, - kernel_size=(3, 3), - padding=(1, 1)) + y = relay.nn.conv2d(x, w, channels=channels1 + channels2 + channels4) y1 = relay.take(y, relay.const(np.arange(channels1, dtype='int64')), axis=1) y2 = relay.take(y, relay.const(np.arange(channels1, channels1 + channels2, dtype='int64')), axis=1) - y3 = relay.nn.conv2d(x, w3, - channels=channels3, - kernel_size=(1, 1), - padding=(1, 1)) + y3 = relay.nn.conv2d(x, w3) y4 = relay.take(y, relay.const(np.arange(channels1 + channels2, channels1 + channels2 + channels4, dtype='int64')), axis=1) y = relay.Tuple((y1, y2, y3, y4)) return relay.Function(args, y) - def check(channels1, channels2, channels3, channels4): - x = relay.var("x") - w1 = relay.var("w1") - w2 = relay.var("w2") - w3 = relay.var("w3") - w4 = relay.var("w4") - - y_before = before(x, w1, w2, w3, w4, channels1, channels2, channels3, channels4) - y = relay.ir_pass.combine_parallel_conv2d(y_before) + def check(x_shape, channels1, channels2, channels3, channels4): + x = relay.var("x", shape=x_shape) + in_c = x_shape[1] + w1 = relay.var("w1", shape=(channels1, in_c, 1, 1)) + w2 = relay.var("w2", shape=(channels2, in_c, 1, 1)) + w3 = relay.var("w3", shape=(channels3, in_c, 3, 3)) + w4 = relay.var("w4", shape=(channels4, in_c, 1, 1)) + + y_before = before(x, w1, w2, w3, w4) + y = relay.ir_pass.infer_type(y_before) + y = relay.ir_pass.combine_parallel_conv2d(y) + y = relay.ir_pass.infer_type(y) y_expected = expected(x, w1, w2, w3, w4, channels1, channels2, channels3, channels4) + y_expected = relay.ir_pass.infer_type(y_expected) assert relay.ir_pass.alpha_equal(y, y_expected) - check(4, 4, 4, 4) - check(4, 8, 4, 7) + check((1, 4, 16, 16), 4, 4, 4, 4) + check((1, 4, 16, 16), 4, 8, 4, 7) if __name__ == "__main__": From 818647947ae70e0b984c36f30de052d642bf3065 Mon Sep 17 00:00:00 2001 From: Wuwei Lin 
Date: Wed, 14 Nov 2018 17:30:35 +0800 Subject: [PATCH 10/15] Replace take with strided_slice --- src/relay/pass/combine_parallel_conv2d.cc | 24 ++++++++++++------- src/relay/pass/pattern_util.h | 16 +------------ .../test_pass_combine_parallel_conv2d.py | 10 ++++---- 3 files changed, 22 insertions(+), 28 deletions(-) diff --git a/src/relay/pass/combine_parallel_conv2d.cc b/src/relay/pass/combine_parallel_conv2d.cc index 61dfc632086f..f128f93f800f 100644 --- a/src/relay/pass/combine_parallel_conv2d.cc +++ b/src/relay/pass/combine_parallel_conv2d.cc @@ -142,17 +142,25 @@ Expr CombineParallelConv2D(const Expr& expr) { if (convs.size() < 2) continue; auto new_conv2d = MakeCombinedConv2D(data, convs); - int64_t start = 0; + int64_t index = 0; // replace original conv2d with slice of output of the new conv2d for (const CallNode* conv2d : convs) { auto params = conv2d->attrs.as(); - auto channels = GetConv2DSuperChannelsDim(conv2d); - auto indices = MakeConstantArrayFromRange(Int(64), start, start + channels); - auto channel_index = params->data_layout.find('C'); - CHECK_NE(channel_index, std::string::npos); - auto take = MakeTake(new_conv2d, indices, channel_index); - start += channels; - subst_map[GetRef(conv2d)] = take; + int64_t channels = GetConv2DSuperChannelsDim(conv2d); + size_t channel_pos = params->data_layout.find('C'); + CHECK_NE(channel_pos, std::string::npos); + Array begin; + Array end; + for (size_t i = 0; i < channel_pos; i++) { + begin.push_back(0); + end.push_back(NullValue()); + } + begin.push_back(index); + index += channels; + end.push_back(index); + auto slice = MakeStridedSlice(new_conv2d, std::move(begin), std::move(end), + Array{}); + subst_map[GetRef(conv2d)] = slice; } } } diff --git a/src/relay/pass/pattern_util.h b/src/relay/pass/pattern_util.h index 3775b77d6840..550d008634ff 100644 --- a/src/relay/pass/pattern_util.h +++ b/src/relay/pass/pattern_util.h @@ -149,20 +149,6 @@ inline Constant MakeConstantScalar(DataType dtype, T value) { return ConstantNode::make(arr); } -template::value>::type> -inline Constant MakeConstantArrayFromRange(DataType dtype, T start, T end, T step = 1) { - CHECK_EQ(sizeof(T) * 8, dtype.bits()) << "data type mismatch"; - CHECK(step); - CHECK_GE((end - start) / step, 0); - runtime::NDArray arr = runtime::NDArray::Empty({(int64_t)(end - start) / step}, - Type2TVMType(dtype), {kDLCPU, 0}); - for (auto *data = static_cast(arr->data); (step > 0) ? 
(start < end) : (start > end); - start += step, data++) { - *data = start; - } - return ConstantNode::make(arr); -} - inline Expr Negative(Expr x) { static const Op& op = Op::Get("negative"); @@ -201,7 +187,7 @@ inline Expr ReshapeLike(Expr lhs, Expr rhs) { Expr MakeConcatenate(Expr data, int axis); -Expr MakeTake(Expr data, Expr indices, Integer axis); +Expr MakeStridedSlice(Expr data, Array begin, Array end, Array strides); } // namespace relay } // namespace tvm diff --git a/tests/python/relay/test_pass_combine_parallel_conv2d.py b/tests/python/relay/test_pass_combine_parallel_conv2d.py index 25c788bc9ff2..75270d52f70e 100644 --- a/tests/python/relay/test_pass_combine_parallel_conv2d.py +++ b/tests/python/relay/test_pass_combine_parallel_conv2d.py @@ -8,7 +8,7 @@ def before(x, w1, w2, w3, w4): args = [x, w1, w2, w3, w4] y1 = relay.nn.conv2d(x, w1) y2 = relay.nn.conv2d(x, w2) - # y3 is not foldable + # y3 cannot be combined y3 = relay.nn.conv2d(x, w3) y4 = relay.nn.conv2d(x, w4) y = relay.Tuple((y1, y2, y3, y4)) @@ -19,11 +19,11 @@ def expected(x, w1, w2, w3, w4, channels1, channels2, channels3, channels4): args = [x, w1, w2, w3, w4] w = relay.concatenate((w1, w2, w4), axis=0) y = relay.nn.conv2d(x, w, channels=channels1 + channels2 + channels4) - y1 = relay.take(y, relay.const(np.arange(channels1, dtype='int64')), axis=1) - y2 = relay.take(y, relay.const(np.arange(channels1, channels1 + channels2, dtype='int64')), axis=1) + y1 = relay.strided_slice(y, [0, 0], [None, channels1]) + y2 = relay.strided_slice(y, [0, channels1], [None, channels1 + channels2]) y3 = relay.nn.conv2d(x, w3) - y4 = relay.take(y, relay.const(np.arange(channels1 + channels2, - channels1 + channels2 + channels4, dtype='int64')), axis=1) + y4 = relay.strided_slice(y, [0, channels1 + channels2], + [None, channels1 + channels2 + channels4]) y = relay.Tuple((y1, y2, y3, y4)) return relay.Function(args, y) From 74ca32de4f738c4f1bfc815cd88b7291d585363e Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Fri, 16 Nov 2018 13:52:40 +0800 Subject: [PATCH 11/15] Combine subsequent elemwise/broadcast ops --- python/tvm/relay/build_module.py | 5 + python/tvm/relay/ir_pass.py | 2 +- src/relay/pass/combine_parallel_conv2d.cc | 410 ++++++++++++------ .../test_pass_combine_parallel_conv2d.py | 87 ++++ 4 files changed, 373 insertions(+), 131 deletions(-) diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py index 557e4edac681..5a45ac276de9 100644 --- a/python/tvm/relay/build_module.py +++ b/python/tvm/relay/build_module.py @@ -13,6 +13,7 @@ # List of optimization pass and level when switch on OPT_PASS_LEVEL = { "SimplifyInference": 0, + "CombineParallelConv2D": 1, "OpFusion": 1, "FoldConstant": 2, "FoldScaleAxis": 3, @@ -144,6 +145,10 @@ def optimize(func, params=None): func = ir_pass.infer_type(func) func = ir_pass.simplify_inference(func) + if cfg.pass_enabled("CombineParallelConv2D"): + func = ir_pass.infer_type(func) + func = ir_pass.combine_parallel_conv2d(func) + if cfg.pass_enabled("FoldScaleAxis"): func = ir_pass.infer_type(func) func = ir_pass.backward_fold_scale_axis(func) diff --git a/python/tvm/relay/ir_pass.py b/python/tvm/relay/ir_pass.py index 8d54081d7090..ef0a59cd3f6d 100644 --- a/python/tvm/relay/ir_pass.py +++ b/python/tvm/relay/ir_pass.py @@ -305,6 +305,6 @@ def combine_parallel_conv2d(expr): Returns ------- transformed_expr : tvm.relay.Expr - Transformed expression, containing folded conv2d. 
+ Transformed expression """ return _ir_pass.CombineParallelConv2D(expr) diff --git a/src/relay/pass/combine_parallel_conv2d.cc b/src/relay/pass/combine_parallel_conv2d.cc index f128f93f800f..d1b065a30782 100644 --- a/src/relay/pass/combine_parallel_conv2d.cc +++ b/src/relay/pass/combine_parallel_conv2d.cc @@ -2,176 +2,326 @@ * Copyright (c) 2018 by Contributors * * \file combine_parallel_conv2d.cc - * * \brief Combine parallel 2d convolutions into a single convolution. * - * This pass replaces convolutions that share the same input node and the same arguments (except - * that the number of output channels can be different) with a single convolution. The weight of - * the new 2d convolution is the concatenation of the original weights. + * This pass replaces convolutions that share the same input node and the same + * arguments (except that the number of output channels can be different) with a + * single convolution. The weight of the new 2d convolution is the concatenation + * of the original weights. Elemwise and broadcast ops following conv2d are also + * combined if possible. * - * This prevents launching multiple kernels in networks with multiple convolution branches, such - * as Inception block. + * This prevents launching multiple kernels in networks with multiple + * convolution branches, such as Inception block. */ #include #include #include #include +#include +#include +#include +#include #include "./expr_subst.h" #include "./pattern_util.h" namespace tvm { namespace relay { -class SiblingConv2DFinder : public ExprVisitor { +using Group = std::vector>; + +/* + Find parallel branches starting with conv2d as shown below and then group branches by kernel + shape and attributes of conv2d. Conv2d can be followed by zero or more elemwise or broadcast ops. + Intermediate nodes have exactly one successor. It is possible that branches meet at a point, + which should be handled in ParallelConv2DCombiner. + + data + / \ + conv2d conv2d + | | + op op + | | +*/ +class BranchGroupFinder : private ExprVisitor { public: - std::unordered_map, NodeHash, NodeEqual> Find( - const Expr& expr) { + std::vector Find(const Expr& expr) { this->VisitExpr(expr); - return std::move(children_map_); + + std::vector groups; + for (const auto& root : conv_roots_) { + const auto& convs = children_map_.at(root); + for (const CallNode* conv : convs) { + auto&& branch = CreateBranch(conv); + // add the branch to a group, or create a new group + auto it = std::find_if(groups.begin(), groups.end(), [&](const Group& group) { + CHECK(group.size() && group[0].size()); + return IsCompatibleConv2D(conv, group[0][0]); + }); + if (it != groups.end()) { + it->push_back(branch); + } else { + groups.emplace_back(); + // each group has at least one branch + groups.back().push_back(branch); + } + } + } + return groups; + } + + private: + std::unordered_set conv_roots_; + std::unordered_map, NodeHash, NodeEqual> children_map_; + + // Two 2d convolutions can be combined if they have the same attributes or + // only have different output channels. 
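+  // Concretely, as checked below, strides, padding, dilation, groups, the data,
+  // weight and output layouts, the output dtype and the kernel height/width must
+  // all match; only the output-channel ('O') dimension of the weight may differ.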
+ bool IsCompatibleConv2D(const CallNode* a, const CallNode* b) { + AttrsEqual eq; + static const Layout kOIHW("OIHW"); + const auto* attrs_a = a->attrs.as(); + const auto* attrs_b = b->attrs.as(); + CHECK(attrs_a); + CHECK(attrs_b); + const auto* tweight_a = a->args[1]->type_as(); + const auto* tweight_b = b->args[1]->type_as(); + const auto shape_a = ConvertLayout(tweight_a->shape, attrs_a->weight_layout, kOIHW); + const auto shape_b = ConvertLayout(tweight_b->shape, attrs_b->weight_layout, kOIHW); + + return eq(attrs_a->strides, attrs_b->strides) && eq(attrs_a->padding, attrs_b->padding) && + eq(attrs_a->dilation, attrs_b->dilation) && eq(attrs_a->groups, attrs_b->groups) && + eq(attrs_a->data_layout, attrs_b->data_layout) && + eq(attrs_a->weight_layout, attrs_b->weight_layout) && + eq(attrs_a->out_dtype, attrs_b->out_dtype) && + eq(attrs_a->out_layout, attrs_b->out_layout) && eq(shape_a[2], shape_b[2]) && + eq(shape_a[3], shape_b[3]); + } + + // Create a branch starting from conv2d. + std::vector CreateBranch(const CallNode* conv) { + static auto fpattern = Op::GetAttr("TOpPattern"); + // each branch has at least one element, the first element is always conv2d + std::vector branch{conv}; + auto it = children_map_.find(GetRef(branch.back())); + while (it != children_map_.end() && it->second.size() == 1) { + const CallNode* call = it->second[0]; + auto pattern = fpattern[Downcast(call->op)]; + if (pattern <= kBroadcast) { + branch.push_back(it->second[0]); + it = children_map_.find(GetRef(branch.back())); + } else { + break; + } + } + return branch; } void VisitExpr_(const CallNode* n) final { static const Op& conv2d = Op::Get("nn.conv2d"); ExprVisitor::VisitExpr_(n); if (n->op.same_as(conv2d) && n->attrs.as()->groups == 1) { + conv_roots_.insert(n->args[0]); children_map_[n->args[0]].push_back(n); + } else { + for (size_t i = 0; i < n->args.size(); i++) { + children_map_[n->args[i]].push_back(n); + } } } +}; + +class ParallelConv2DCombiner { + public: + Expr Combine(const Expr& expr) { + auto groups = BranchGroupFinder().Find(expr); + for (const Group& group : groups) { + if (group.size() < 2) continue; + CombineBranches(group); + } + return ExprSubst(expr, std::move(subst_map_)); + } private: - std::unordered_map, NodeHash, NodeEqual> children_map_; -}; + std::unordered_map subst_map_; -std::tuple TransformWeight(std::vector convolutions) { - int64_t num_filters = 0; // number of filters of the transformed weight - Array weights; - for (const CallNode* n : convolutions) { - weights.push_back(n->args[1]); - auto channels = GetConv2DSuperChannelsDim(n); - num_filters += channels; + std::tuple TransformWeight(const Group& branches) { + int64_t num_filters = 0; // number of filters of the transformed weight + Array weights; + for (const auto& branch : branches) { + auto conv2d = branch[0]; + weights.push_back(conv2d->args[1]); + auto channels = GetConv2DSuperChannelsDim(conv2d); + num_filters += channels; + } + auto index = branches[0][0]->attrs.as()->weight_layout.find('O'); + CHECK_NE(index, std::string::npos); + return std::make_tuple(MakeConcatenate(TupleNode::make(weights), index), + MakeConstScalar(Int(32), num_filters)); } - auto index = convolutions[0]->attrs.as()->weight_layout.find('O'); - CHECK_NE(index, std::string::npos); - return std::make_tuple(MakeConcatenate(TupleNode::make(weights), index), - MakeConstScalar(Int(32), num_filters)); -} - -// Two 2d convolutions can be combined if they have the same attributes or only have -// different output channels. 
-bool IsCompatibleConv2D(const CallNode* a, const CallNode* b) { - AttrsEqual eq; - static const Layout kOIHW("OIHW"); - auto attrs_a = a->attrs.as(); - auto attrs_b = b->attrs.as(); - auto tweight_a = a->args[1]->type_as(); - auto tweight_b = b->args[1]->type_as(); - auto shape_a = ConvertLayout(tweight_a->shape, attrs_a->weight_layout, kOIHW); - auto shape_b = ConvertLayout(tweight_b->shape, attrs_b->weight_layout, kOIHW); - - return eq(attrs_a->strides, attrs_b->strides) && - eq(attrs_a->padding, attrs_b->padding) && - eq(attrs_a->dilation, attrs_b->dilation) && - eq(attrs_a->groups, attrs_b->groups) && - eq(attrs_a->data_layout, attrs_b->data_layout) && - eq(attrs_a->weight_layout, attrs_b->weight_layout) && - eq(attrs_a->out_dtype, attrs_b->out_dtype) && - eq(attrs_a->out_layout, attrs_b->out_layout) && - eq(shape_a[2], shape_b[2]) && - eq(shape_a[3], shape_b[3]); -} - -Expr MakeCombinedConv2D(const Expr& data, const std::vector& convolutions) { - static const Op& conv2d = Op::Get("nn.conv2d"); - - Expr new_weight; - IndexExpr new_channels; - std::tie(new_weight, new_channels) = TransformWeight(convolutions); - - const CallNode* group_root = convolutions[0]; - auto attrs = group_root->attrs.as(); - auto new_attrs = make_node(); - new_attrs->strides = attrs->strides; - new_attrs->padding = attrs->padding; - new_attrs->dilation = attrs->dilation; - new_attrs->groups = attrs->groups; - new_attrs->kernel_size = attrs->kernel_size; - new_attrs->data_layout = attrs->data_layout; - new_attrs->weight_layout = attrs->weight_layout; - new_attrs->out_layout = attrs->out_layout; - new_attrs->out_dtype = attrs->out_dtype; - new_attrs->channels = new_channels; - - return CallNode::make(conv2d, {data, new_weight}, Attrs{new_attrs}, {}); -} - -Expr CombineParallelConv2D(const Expr& expr) { - // data -> array of conv2d with the same input - auto children_map = SiblingConv2DFinder().Find(expr); - std::unordered_map subst_map; - - for (const auto& pair : children_map) { - Expr data = pair.first; - std::vector children = pair.second; - - if (children.size() < 2) continue; - - std::vector group_ids(children.size()); - std::vector> groups; - - for (size_t i = 0; i < children.size(); i++) { - const CallNode* n = children[i]; - - // assign a group id or create a new group for each conv2d - auto it = std::find_if(groups.begin(), groups.end(), - [&](const std::vector& group) { - return IsCompatibleConv2D(n, group[0]); - }); - - if (it != groups.end()) { - auto group_id = std::distance(groups.begin(), it); - group_ids[i] = group_id; - groups[group_id].push_back(n); - } else { - group_ids[i] = groups.size(); - groups.emplace_back(std::vector{n}); - } + + Call MakeCombinedConv2D(const Group& branches) { + static const Op& conv2d = Op::Get("nn.conv2d"); + Expr data = branches[0][0]->args[0]; + Expr new_weight; + IndexExpr new_channels; + std::tie(new_weight, new_channels) = TransformWeight(branches); + + const CallNode* group_root = branches[0][0]; + const auto* attrs = group_root->attrs.as(); + CHECK(attrs); + const auto new_attrs = make_node(); + new_attrs->strides = attrs->strides; + new_attrs->padding = attrs->padding; + new_attrs->dilation = attrs->dilation; + new_attrs->groups = attrs->groups; + new_attrs->kernel_size = attrs->kernel_size; + new_attrs->data_layout = attrs->data_layout; + new_attrs->weight_layout = attrs->weight_layout; + new_attrs->out_layout = attrs->out_layout; + new_attrs->out_dtype = attrs->out_dtype; + new_attrs->channels = new_channels; + + return CallNode::make(conv2d, {data, 
new_weight}, Attrs{new_attrs}, {}); + } + + bool IsArgCompatible(const CallNode* a, const CallNode* b, size_t index, size_t channel_pos) { + AttrsEqual eq; + auto ta = a->args[index]->type_as(); + auto tb = b->args[index]->type_as(); + auto toutput_a = a->type_as(); + auto toutput_b = b->type_as(); + + if (!eq(ta->dtype, tb->dtype) || ta->shape.size() != tb->shape.size()) + return false; + + // Position of the 'C' dimension in the argument + int64_t arg_channel_pos = channel_pos - toutput_a->shape.size() + ta->shape.size(); + + // Channel super-dimension shoule be present and not broadcasted + if ((arg_channel_pos < 0) || + !eq(ta->shape[arg_channel_pos], toutput_a->shape[channel_pos]) || + !eq(tb->shape[arg_channel_pos], toutput_b->shape[channel_pos])) + return false; + + for (size_t i = 0; i < ta->shape.size(); i++) { + if (i == static_cast(arg_channel_pos)) continue; + if (!eq(ta->shape[i], tb->shape[i])) + return false; } + return true; + } + + // Check if ops in depth-th level can be combined + bool CheckLevel(const Group& branches, size_t depth, size_t channel_pos, size_t parent_index) { + const CallNode* call = branches[0][depth]; + AttrsEqual attrs_equal; + // check if all branches in current depth can be combined + for (auto it = branches.begin() + 1; it != branches.end(); it++) { + const std::vector& branch = *it; + if (!branch[depth]->op.same_as(call->op) || + !attrs_equal(branch[depth]->attrs, call->attrs) || + branch[depth]->args.size() != call->args.size()) { + return false; + } - for (const auto& convs : groups) { - if (convs.size() < 2) continue; - - auto new_conv2d = MakeCombinedConv2D(data, convs); - int64_t index = 0; - // replace original conv2d with slice of output of the new conv2d - for (const CallNode* conv2d : convs) { - auto params = conv2d->attrs.as(); - int64_t channels = GetConv2DSuperChannelsDim(conv2d); - size_t channel_pos = params->data_layout.find('C'); - CHECK_NE(channel_pos, std::string::npos); - Array begin; - Array end; - for (size_t i = 0; i < channel_pos; i++) { - begin.push_back(0); - end.push_back(NullValue()); + if (branch[depth]->args[parent_index].get() != branch[depth - 1]) + return false; + + // Check args + for (size_t i = 0; i < call->args.size(); i++) { + if (i == parent_index) continue; + + if (!IsArgCompatible(call, branch[depth], i, channel_pos) || + !attrs_equal(call->attrs, branch[depth]->attrs)) { + return false; } - begin.push_back(index); - index += channels; - end.push_back(index); - auto slice = MakeStridedSlice(new_conv2d, std::move(begin), std::move(end), - Array{}); - subst_map[GetRef(conv2d)] = slice; } } + return true; } - return ExprSubst(expr, std::move(subst_map)); -} + // Combine args and make the combined CallNode + Call MakeCombinedCall(const Expr& data, const Group& branches, size_t depth, size_t channel_pos, + size_t parent_index) { + Array new_args; + const CallNode* call = branches[0][depth]; + size_t ndim = call->type_as()->shape.size(); + + for (size_t i = 0; i < call->args.size(); i++) { + if (i == parent_index) { + new_args.push_back(data); + continue; + } + size_t arg_ndim = call->args[i]->type_as()->shape.size(); + size_t arg_channel_pos = channel_pos - ndim + arg_ndim; + Array tuple; + for (const auto& branch : branches) { + tuple.push_back(branch[depth]->args[i]); + } + auto concat = MakeConcatenate(TupleNode::make(tuple), arg_channel_pos); + new_args.push_back(std::move(concat)); + } + return CallNode::make(call->op, new_args, call->attrs, {}); + } + + // Replace output of each branch with slices of the 
combined output + void UpdateGroupOutput(const Expr& data, const Group& branches, size_t depth, + size_t channel_pos) { + int64_t index = 0; + for (const auto& branch : branches) { + const CallNode* conv2d = branch[0]; + int64_t channels = GetConv2DSuperChannelsDim(conv2d); + Array begin; + Array end; + for (size_t i = 0; i < channel_pos; i++) { + begin.push_back(0); + end.push_back(NullValue()); + } + begin.push_back(index); + index += channels; + end.push_back(index); + auto slice = MakeStridedSlice(data, std::move(begin), std::move(end), Array{}); + subst_map_[GetRef(branch[depth])] = slice; + } + } + + // Combine branches in a group. Conv2d in different branches in the same group are safe to + // combined. Subsequent ops may or may not be combined. We start from conv2d and try to + // combine ops from all branches in the same depth. + void CombineBranches(const Group& branches) { + Call combined = MakeCombinedConv2D(branches); + auto conv_param = combined->attrs.as(); + const std::string& layout = + conv_param->out_layout == "" ? conv_param->data_layout : conv_param->out_layout; + size_t channel_pos = layout.find('C'); + CHECK_NE(channel_pos, std::string::npos); + auto it = std::min_element(branches.begin(), branches.end(), + [](const std::vector& branch_a, + const std::vector& branch_b) { + return branch_a.size() < branch_b.size(); + }); + size_t depth = it->size(); + size_t i; + // starting from 1 to skip the conv2d + for (i = 1; i < depth; i++) { + size_t parent_index; + for (parent_index = 0; parent_index < branches[0][i]->args.size(); parent_index++) { + if (branches[0][i]->args[parent_index].get() == branches[0][i - 1]) break; + } + CHECK_NE(parent_index, branches[0][i]->args.size()); + if (!CheckLevel(branches, i, channel_pos, parent_index)) break; + combined = MakeCombinedCall(combined, branches, i, channel_pos, parent_index); + } + UpdateGroupOutput(combined, branches, i - 1, channel_pos); + } +}; + +Expr CombineParallelConv2D(const Expr& expr) { return ParallelConv2DCombiner().Combine(expr); } TVM_REGISTER_API("relay._ir_pass.CombineParallelConv2D") .set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = CombineParallelConv2D(args[0]); - }); + *ret = CombineParallelConv2D(args[0]); +}); } // namespace relay } // namespace tvm diff --git a/tests/python/relay/test_pass_combine_parallel_conv2d.py b/tests/python/relay/test_pass_combine_parallel_conv2d.py index 75270d52f70e..31dfe095f682 100644 --- a/tests/python/relay/test_pass_combine_parallel_conv2d.py +++ b/tests/python/relay/test_pass_combine_parallel_conv2d.py @@ -47,5 +47,92 @@ def check(x_shape, channels1, channels2, channels3, channels4): check((1, 4, 16, 16), 4, 8, 4, 7) +def test_combine_parallel_conv2d_scale_relu(): + """Testcase of combining conv2d + scale + relu""" + def before(x, w1, w2, scale1, scale2, bias): + args = [x, w1, w2, scale1, scale2, bias] + y1 = relay.nn.conv2d(x, w1) + y1 = relay.multiply(y1, scale1) + y1 = relay.nn.relu(y1) + y2 = relay.nn.conv2d(x, w2) + y2 = relay.multiply(y2, scale2) + y2 = relay.nn.relu(y2) + y2 = relay.add(y2, bias) + y = relay.Tuple((y1, y2)) + return relay.Function(args, y) + + def expected(x, w1, w2, scale1, scale2, bias, channels1, channels2): + args = [x, w1, w2, scale1, scale2, bias] + w = relay.concatenate((w1, w2), axis=0) + scale = relay.concatenate((scale1, scale2), axis=0) + y = relay.nn.conv2d(x, w, channels=channels1 + channels2) + y = relay.multiply(y, scale) + y = relay.nn.relu(y) + y1 = relay.strided_slice(y, [0, 0], [None, channels1]) + y2 = 
+        y2 = relay.add(y2, bias)
+        y = relay.Tuple((y1, y2))
+        return relay.Function(args, y)
+
+    def check(x_shape, channels1, channels2):
+        x = relay.var("x", shape=x_shape)
+        in_c = x_shape[1]
+        w1 = relay.var("w1", shape=(channels1, in_c, 1, 1))
+        w2 = relay.var("w2", shape=(channels2, in_c, 1, 1))
+        scale1 = relay.var("scale1", shape=(channels1, 1, 1))
+        scale2 = relay.var("scale2", shape=(channels2, 1, 1))
+        bias = relay.var("bias", shape=(channels2, 1, 1))
+        y_before = before(x, w1, w2, scale1, scale2, bias)
+        y = relay.ir_pass.infer_type(y_before)
+        y = relay.ir_pass.combine_parallel_conv2d(y)
+        y = relay.ir_pass.infer_type(y)
+        y_expected = expected(x, w1, w2, scale1, scale2, bias, channels1, channels2)
+        y_expected = relay.ir_pass.infer_type(y_expected)
+        assert relay.ir_pass.alpha_equal(y, y_expected)
+
+    check((1, 4, 16, 16), 4, 8)
+
+
+def test_combine_parallel_conv2d_scale():
+    """Testcase of un-combinable scale"""
+    def before(x, w1, w2, scale1, scale2):
+        args = [x, w1, w2, scale1, scale2]
+        y1 = relay.nn.conv2d(x, w1)
+        y1 = relay.multiply(y1, scale1)
+        y2 = relay.nn.conv2d(x, w2)
+        y2 = relay.multiply(y2, scale2)
+        y = relay.Tuple((y1, y2))
+        return relay.Function(args, y)
+
+    def expected(x, w1, w2, scale1, scale2, channels1, channels2):
+        args = [x, w1, w2, scale1, scale2]
+        w = relay.concatenate((w1, w2), axis=0)
+        y = relay.nn.conv2d(x, w, channels=channels1 + channels2)
+        y1 = relay.strided_slice(y, [0, 0], [None, channels1])
+        y2 = relay.strided_slice(y, [0, channels1], [None, channels1 + channels2])
+        y1 = relay.multiply(y1, scale1)
+        y2 = relay.multiply(y2, scale2)
+        y = relay.Tuple((y1, y2))
+        return relay.Function(args, y)
+
+    def check(x_shape, channels1, channels2):
+        x = relay.var("x", shape=x_shape)
+        in_c = x_shape[1]
+        w1 = relay.var("w1", shape=(channels1, in_c, 1, 1))
+        w2 = relay.var("w2", shape=(channels2, in_c, 1, 1))
+        scale1 = relay.var("scale1", shape=(1,))
+        scale2 = relay.var("scale2", shape=(1,))
+        y_before = before(x, w1, w2, scale1, scale2)
+        y = relay.ir_pass.infer_type(y_before)
+        y = relay.ir_pass.combine_parallel_conv2d(y)
+        y = relay.ir_pass.infer_type(y)
+        y_expected = expected(x, w1, w2, scale1, scale2, channels1, channels2)
+        y_expected = relay.ir_pass.infer_type(y_expected)
+        assert relay.ir_pass.alpha_equal(y, y_expected)
+
+    check((1, 4, 16, 16), 4, 8)
+
+
 if __name__ == "__main__":
     test_combine_parallel_conv2d()
+    test_combine_parallel_conv2d_scale_relu()
+    test_combine_parallel_conv2d_scale()
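Not part of the patch series: a minimal sketch of how the new pass is driven end to end, using only the relay APIs that the tests above already exercise; the shapes are illustrative.

    from tvm import relay

    x = relay.var("x", shape=(1, 4, 16, 16))
    w1 = relay.var("w1", shape=(8, 4, 1, 1))
    w2 = relay.var("w2", shape=(8, 4, 1, 1))
    # Two conv2d calls sharing the same input and attrs are candidates for combination.
    y = relay.Tuple((relay.nn.conv2d(x, w1), relay.nn.conv2d(x, w2)))
    f = relay.Function([x, w1, w2], y)
    f = relay.ir_pass.infer_type(f)               # the pass inspects checked types
    f = relay.ir_pass.combine_parallel_conv2d(f)  # one conv2d plus a strided_slice per branch
    print(f)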
From 9ee9396eb2fc3ae5e3c586a8653ab65a0d45f1d4 Mon Sep 17 00:00:00 2001
From: Wuwei Lin
Date: Mon, 19 Nov 2018 20:45:16 +0800
Subject: [PATCH 12/15] Minor improvement

---
 src/relay/pass/combine_parallel_conv2d.cc | 17 +++++++++--------
 1 file changed, 9 insertions(+), 8 deletions(-)

diff --git a/src/relay/pass/combine_parallel_conv2d.cc b/src/relay/pass/combine_parallel_conv2d.cc
index d1b065a30782..96b6faac0dc6 100644
--- a/src/relay/pass/combine_parallel_conv2d.cc
+++ b/src/relay/pass/combine_parallel_conv2d.cc
@@ -28,7 +28,8 @@
 namespace tvm {
 namespace relay {
 
-using Group = std::vector<std::vector<const CallNode*>>;
+using Branch = std::vector<const CallNode*>;
+using Group = std::vector<Branch>;
 
 /* Find parallel branches starting with conv2d as shown below and then group branches by kernel
@@ -55,7 +56,7 @@ class BranchGroupFinder : private ExprVisitor {
       auto&& branch = CreateBranch(conv);
       // add the branch to a group, or create a new group
       auto it = std::find_if(groups.begin(), groups.end(), [&](const Group& group) {
-        CHECK(group.size() && group[0].size());
+        CHECK(!group.empty() && !group[0].empty());
         return IsCompatibleConv2D(conv, group[0][0]);
       });
       if (it != groups.end()) {
@@ -98,10 +99,10 @@ class BranchGroupFinder : private ExprVisitor {
   }
 
   // Create a branch starting from conv2d.
-  std::vector<const CallNode*> CreateBranch(const CallNode* conv) {
+  Branch CreateBranch(const CallNode* conv) {
     static auto fpattern = Op::GetAttr<TOpPattern>("TOpPattern");
     // each branch has at least one element, the first element is always conv2d
-    std::vector<const CallNode*> branch{conv};
+    Branch branch{conv};
     auto it = children_map_.find(GetRef<Expr>(branch.back()));
     while (it != children_map_.end() && it->second.size() == 1) {
       const CallNode* call = it->second[0];
@@ -217,7 +218,7 @@ class ParallelConv2DCombiner {
     AttrsEqual attrs_equal;
     // check if all branches in current depth can be combined
     for (auto it = branches.begin() + 1; it != branches.end(); it++) {
-      const std::vector<const CallNode*>& branch = *it;
+      const Branch& branch = *it;
       if (!branch[depth]->op.same_as(call->op) ||
           !attrs_equal(branch[depth]->attrs, call->attrs) ||
           branch[depth]->args.size() != call->args.size()) {
@@ -286,7 +287,7 @@
   }
 
   // Combine branches in a group. Conv2d in different branches in the same group are safe to
-  // combined. Subsequent ops may or may not be combined. We start from conv2d and try to
+  // combine. Subsequent ops may or may not be combined. We start from conv2d and try to
   // combine ops from all branches in the same depth.
   void CombineBranches(const Group& branches) {
     Call combined = MakeCombinedConv2D(branches);
@@ -296,8 +297,8 @@
     size_t channel_pos = layout.find('C');
     CHECK_NE(channel_pos, std::string::npos);
     auto it = std::min_element(branches.begin(), branches.end(),
-                               [](const std::vector<const CallNode*>& branch_a,
-                                  const std::vector<const CallNode*>& branch_b) {
+                               [](const Branch& branch_a,
+                                  const Branch& branch_b) {
                                  return branch_a.size() < branch_b.size();
                               });
     size_t depth = it->size();

From 9aafe00a1aef68d5f937602e7efa7468e5502789 Mon Sep 17 00:00:00 2001
From: Wuwei Lin
Date: Tue, 20 Nov 2018 00:29:00 +0800
Subject: [PATCH 13/15] Fix size_t issue for i386

---
 src/relay/pass/combine_parallel_conv2d.cc | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/relay/pass/combine_parallel_conv2d.cc b/src/relay/pass/combine_parallel_conv2d.cc
index 96b6faac0dc6..5dc17edea943 100644
--- a/src/relay/pass/combine_parallel_conv2d.cc
+++ b/src/relay/pass/combine_parallel_conv2d.cc
@@ -196,16 +196,16 @@ class ParallelConv2DCombiner {
       return false;
 
     // Position of the 'C' dimension in the argument
-    int64_t arg_channel_pos = channel_pos - toutput_a->shape.size() + ta->shape.size();
+    size_t arg_channel_pos = channel_pos - toutput_a->shape.size() + ta->shape.size();
 
     // Channel super-dimension should be present and not broadcasted
-    if ((arg_channel_pos < 0) ||
+    if ((arg_channel_pos > channel_pos) || // size_t overflow
         !eq(ta->shape[arg_channel_pos], toutput_a->shape[channel_pos]) ||
         !eq(tb->shape[arg_channel_pos], toutput_b->shape[channel_pos]))
       return false;
 
     for (size_t i = 0; i < ta->shape.size(); i++) {
-      if (i == static_cast<size_t>(arg_channel_pos)) continue;
+      if (i == arg_channel_pos) continue;
       if (!eq(ta->shape[i], tb->shape[i]))
         return false;
     }
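Not part of the patch series: a standalone sketch of the wraparound that [PATCH 13/15] guards against. The expression channel_pos - toutput_a->shape.size() + ta->shape.size() is evaluated in size_t, so a logically negative channel position wraps around; on 64-bit builds converting that back to int64_t happens to yield a negative number, but with a 32-bit size_t (i386) the wrapped value is a large positive number and the old "< 0" test presumably misses it. Comparing against channel_pos catches the wraparound on either platform. The ranks below mirror the (1,)-shaped scale from the un-combinable testcase.

    #include <cstddef>
    #include <iostream>

    int main() {
      std::size_t channel_pos = 1;  // position of 'C' in an NCHW conv2d output
      std::size_t out_ndim = 4;     // rank of the conv2d output
      std::size_t arg_ndim = 1;     // rank of a 1-D scale argument
      std::size_t arg_channel_pos = channel_pos - out_ndim + arg_ndim;  // wraps to a huge value
      std::cout << (arg_channel_pos > channel_pos) << "\n";  // prints 1: the argument is rejected
      return 0;
    }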
From 43027897d1d26b77c09c80e4e72dfbea44e530cb Mon Sep 17 00:00:00 2001
From: Wuwei Lin
Date: Tue, 20 Nov 2018 00:30:33 +0800
Subject: [PATCH 14/15] Fix indent

---
 src/relay/pass/combine_parallel_conv2d.cc | 26 +++++++++++------------
 1 file changed, 13 insertions(+), 13 deletions(-)

diff --git a/src/relay/pass/combine_parallel_conv2d.cc b/src/relay/pass/combine_parallel_conv2d.cc
index 5dc17edea943..099ab0c16537 100644
--- a/src/relay/pass/combine_parallel_conv2d.cc
+++ b/src/relay/pass/combine_parallel_conv2d.cc
@@ -199,7 +199,7 @@ class ParallelConv2DCombiner {
     size_t arg_channel_pos = channel_pos - toutput_a->shape.size() + ta->shape.size();
 
     // Channel super-dimension should be present and not broadcasted
-    if ((arg_channel_pos > channel_pos) || // size_t overflow
+    if ((arg_channel_pos > channel_pos) ||  // size_t overflow
         !eq(ta->shape[arg_channel_pos], toutput_a->shape[channel_pos]) ||
         !eq(tb->shape[arg_channel_pos], toutput_b->shape[channel_pos]))
       return false;
@@ -270,18 +270,18 @@
                          size_t channel_pos) {
     int64_t index = 0;
     for (const auto& branch : branches) {
-        const CallNode* conv2d = branch[0];
-        int64_t channels = GetConv2DSuperChannelsDim(conv2d);
-        Array<Integer> begin;
-        Array<Integer> end;
-        for (size_t i = 0; i < channel_pos; i++) {
-          begin.push_back(0);
-          end.push_back(NullValue<Integer>());
-        }
-        begin.push_back(index);
-        index += channels;
-        end.push_back(index);
-        auto slice = MakeStridedSlice(data, std::move(begin), std::move(end), Array<Integer>{});
+      const CallNode* conv2d = branch[0];
+      int64_t channels = GetConv2DSuperChannelsDim(conv2d);
+      Array<Integer> begin;
+      Array<Integer> end;
+      for (size_t i = 0; i < channel_pos; i++) {
+        begin.push_back(0);
+        end.push_back(NullValue<Integer>());
+      }
+      begin.push_back(index);
+      index += channels;
+      end.push_back(index);
+      auto slice = MakeStridedSlice(data, std::move(begin), std::move(end), Array<Integer>{});
       subst_map_[GetRef<Expr>(branch[depth])] = slice;
     }
   }

From bed290033821a7c552b2476010ce7e8ae52ee234 Mon Sep 17 00:00:00 2001
From: Wuwei Lin
Date: Thu, 22 Nov 2018 10:24:29 +0800
Subject: [PATCH 15/15] Remove unused header

---
 src/relay/pass/combine_parallel_conv2d.cc | 2 +-
 src/relay/pass/pattern_util.h             | 1 +
 2 files changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/relay/pass/combine_parallel_conv2d.cc b/src/relay/pass/combine_parallel_conv2d.cc
index 099ab0c16537..48d5d77990d6 100644
--- a/src/relay/pass/combine_parallel_conv2d.cc
+++ b/src/relay/pass/combine_parallel_conv2d.cc
@@ -19,12 +19,12 @@
 #include
 #include
 #include
-#include
 #include
 #include
 #include "./expr_subst.h"
 #include "./pattern_util.h"
+
 namespace tvm {
 namespace relay {
 
diff --git a/src/relay/pass/pattern_util.h b/src/relay/pass/pattern_util.h
index 550d008634ff..38ae923c5274 100644
--- a/src/relay/pass/pattern_util.h
+++ b/src/relay/pass/pattern_util.h
@@ -14,6 +14,7 @@
 #include
 #include "../op/layout.h"
+
 namespace tvm {
 namespace relay {