From fc53ed6b13f50f65db088e0753338141261cd81e Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Mon, 26 Nov 2018 10:35:33 +0800 Subject: [PATCH 1/5] [Relay][Pass] Fix CombineParallelConv2D --- src/relay/pass/combine_parallel_conv2d.cc | 12 ++++++++---- .../relay/test_pass_combine_parallel_conv2d.py | 6 ++++-- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/src/relay/pass/combine_parallel_conv2d.cc b/src/relay/pass/combine_parallel_conv2d.cc index 48d5d77990d6..f691b4db0c3d 100644 --- a/src/relay/pass/combine_parallel_conv2d.cc +++ b/src/relay/pass/combine_parallel_conv2d.cc @@ -47,17 +47,21 @@ using Group = std::vector; class BranchGroupFinder : private ExprVisitor { public: std::vector Find(const Expr& expr) { + static const Op& conv2d = Op::Get("nn.conv2d"); + this->VisitExpr(expr); std::vector groups; for (const auto& root : conv_roots_) { - const auto& convs = children_map_.at(root); - for (const CallNode* conv : convs) { - auto&& branch = CreateBranch(conv); + const auto& children = children_map_.at(root); + for (const CallNode* child : children) { + if (!child->op.same_as(conv2d)) continue; + + auto&& branch = CreateBranch(child); // add the branch to a group, or create a new group auto it = std::find_if(groups.begin(), groups.end(), [&](const Group& group) { CHECK(!group.empty() && !group[0].empty()); - return IsCompatibleConv2D(conv, group[0][0]); + return IsCompatibleConv2D(child, group[0][0]); }); if (it != groups.end()) { it->push_back(branch); diff --git a/tests/python/relay/test_pass_combine_parallel_conv2d.py b/tests/python/relay/test_pass_combine_parallel_conv2d.py index 31dfe095f682..6fea201d64c8 100644 --- a/tests/python/relay/test_pass_combine_parallel_conv2d.py +++ b/tests/python/relay/test_pass_combine_parallel_conv2d.py @@ -11,7 +11,8 @@ def before(x, w1, w2, w3, w4): # y3 cannot be combined y3 = relay.nn.conv2d(x, w3) y4 = relay.nn.conv2d(x, w4) - y = relay.Tuple((y1, y2, y3, y4)) + y5 = relay.nn.max_pool2d(x) + y = relay.Tuple((y1, y2, y3, y4, y5)) return relay.Function(args, y) def expected(x, w1, w2, w3, w4, channels1, channels2, channels3, channels4): @@ -24,7 +25,8 @@ def expected(x, w1, w2, w3, w4, channels1, channels2, channels3, channels4): y3 = relay.nn.conv2d(x, w3) y4 = relay.strided_slice(y, [0, channels1 + channels2], [None, channels1 + channels2 + channels4]) - y = relay.Tuple((y1, y2, y3, y4)) + y5 = relay.nn.max_pool2d(x) + y = relay.Tuple((y1, y2, y3, y4, y5)) return relay.Function(args, y) def check(x_shape, channels1, channels2, channels3, channels4): From 10a23773207d771d0e45b93476376b0912341138 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Mon, 26 Nov 2018 12:46:40 +0800 Subject: [PATCH 2/5] Fix group searching --- src/relay/pass/combine_parallel_conv2d.cc | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/relay/pass/combine_parallel_conv2d.cc b/src/relay/pass/combine_parallel_conv2d.cc index f691b4db0c3d..cbd89decab78 100644 --- a/src/relay/pass/combine_parallel_conv2d.cc +++ b/src/relay/pass/combine_parallel_conv2d.cc @@ -54,12 +54,13 @@ class BranchGroupFinder : private ExprVisitor { std::vector groups; for (const auto& root : conv_roots_) { const auto& children = children_map_.at(root); + size_t ngroups = groups.size(); for (const CallNode* child : children) { if (!child->op.same_as(conv2d)) continue; auto&& branch = CreateBranch(child); // add the branch to a group, or create a new group - auto it = std::find_if(groups.begin(), groups.end(), [&](const Group& group) { + auto it = std::find_if(groups.begin() + ngroups, groups.end(), [&](const Group& group) { CHECK(!group.empty() && !group[0].empty()); return IsCompatibleConv2D(child, group[0][0]); }); @@ -112,7 +113,7 @@ class BranchGroupFinder : private ExprVisitor { const CallNode* call = it->second[0]; auto pattern = fpattern[Downcast(call->op)]; if (pattern <= kBroadcast) { - branch.push_back(it->second[0]); + branch.push_back(call); it = children_map_.find(GetRef(branch.back())); } else { break; From 7ba5a69b44c7690811403f23c6d28a2708347100 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Mon, 26 Nov 2018 13:52:42 +0800 Subject: [PATCH 3/5] Set opt_pass_level to 1 --- python/tvm/relay/build_module.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py index d67bc89702d3..41c3fe20cb3a 100644 --- a/python/tvm/relay/build_module.py +++ b/python/tvm/relay/build_module.py @@ -13,7 +13,7 @@ # List of optimization pass and level when switch on OPT_PASS_LEVEL = { "SimplifyInference": 0, - "CombineParallelConv2D": 4, + "CombineParallelConv2D": 1, "OpFusion": 1, "FoldConstant": 2, "FoldScaleAxis": 3, From 7ebcbcb2c871db04ae153e18ec6c78b82bc34349 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Mon, 26 Nov 2018 13:53:18 +0800 Subject: [PATCH 4/5] Fix doc --- src/relay/pass/combine_parallel_conv2d.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/relay/pass/combine_parallel_conv2d.cc b/src/relay/pass/combine_parallel_conv2d.cc index cbd89decab78..e346aea518e9 100644 --- a/src/relay/pass/combine_parallel_conv2d.cc +++ b/src/relay/pass/combine_parallel_conv2d.cc @@ -37,7 +37,7 @@ using Group = std::vector; Intermediate nodes have exactly one successor. It is possible that branches meet at a point, which should be handled in ParallelConv2DCombiner. - data + data / \ conv2d conv2d | | From ac46f691dfd8c6e141c8c22266b6fe619ef5937e Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Mon, 26 Nov 2018 14:04:38 +0800 Subject: [PATCH 5/5] Set opt_pass_level to 3 --- python/tvm/relay/build_module.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py index 41c3fe20cb3a..c92483dc747f 100644 --- a/python/tvm/relay/build_module.py +++ b/python/tvm/relay/build_module.py @@ -13,9 +13,9 @@ # List of optimization pass and level when switch on OPT_PASS_LEVEL = { "SimplifyInference": 0, - "CombineParallelConv2D": 1, "OpFusion": 1, "FoldConstant": 2, + "CombineParallelConv2D": 3, "FoldScaleAxis": 3, }