diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py
index d67bc89702d3..c92483dc747f 100644
--- a/python/tvm/relay/build_module.py
+++ b/python/tvm/relay/build_module.py
@@ -13,9 +13,9 @@
 # List of optimization pass and level when switch on
 OPT_PASS_LEVEL = {
     "SimplifyInference": 0,
-    "CombineParallelConv2D": 4,
     "OpFusion": 1,
     "FoldConstant": 2,
+    "CombineParallelConv2D": 3,
     "FoldScaleAxis": 3,
 }
 
diff --git a/src/relay/pass/combine_parallel_conv2d.cc b/src/relay/pass/combine_parallel_conv2d.cc
index 48d5d77990d6..e346aea518e9 100644
--- a/src/relay/pass/combine_parallel_conv2d.cc
+++ b/src/relay/pass/combine_parallel_conv2d.cc
@@ -37,7 +37,7 @@ using Group = std::vector<Branch>;
   Intermediate nodes have exactly one successor. It is possible that branches meet at a
   point, which should be handled in ParallelConv2DCombiner.
 
-        data
+          data
           /    \
       conv2d   conv2d
          |        |
@@ -47,17 +47,22 @@ using Group = std::vector<Branch>;
 class BranchGroupFinder : private ExprVisitor {
  public:
   std::vector<Group> Find(const Expr& expr) {
+    static const Op& conv2d = Op::Get("nn.conv2d");
+
     this->VisitExpr(expr);
 
     std::vector<Group> groups;
     for (const auto& root : conv_roots_) {
-      const auto& convs = children_map_.at(root);
-      for (const CallNode* conv : convs) {
-        auto&& branch = CreateBranch(conv);
+      const auto& children = children_map_.at(root);
+      size_t ngroups = groups.size();
+      for (const CallNode* child : children) {
+        if (!child->op.same_as(conv2d)) continue;
+
+        auto&& branch = CreateBranch(child);
         // add the branch to a group, or create a new group
-        auto it = std::find_if(groups.begin(), groups.end(), [&](const Group& group) {
+        auto it = std::find_if(groups.begin() + ngroups, groups.end(), [&](const Group& group) {
           CHECK(!group.empty() && !group[0].empty());
-          return IsCompatibleConv2D(conv, group[0][0]);
+          return IsCompatibleConv2D(child, group[0][0]);
         });
         if (it != groups.end()) {
           it->push_back(branch);
@@ -108,7 +113,7 @@ class BranchGroupFinder : private ExprVisitor {
       const CallNode* call = it->second[0];
       auto pattern = fpattern[Downcast<Op>(call->op)];
       if (pattern <= kBroadcast) {
-        branch.push_back(it->second[0]);
+        branch.push_back(call);
         it = children_map_.find(GetRef<Expr>(branch.back()));
       } else {
         break;
diff --git a/tests/python/relay/test_pass_combine_parallel_conv2d.py b/tests/python/relay/test_pass_combine_parallel_conv2d.py
index 31dfe095f682..6fea201d64c8 100644
--- a/tests/python/relay/test_pass_combine_parallel_conv2d.py
+++ b/tests/python/relay/test_pass_combine_parallel_conv2d.py
@@ -11,7 +11,8 @@ def before(x, w1, w2, w3, w4):
         # y3 cannot be combined
         y3 = relay.nn.conv2d(x, w3)
         y4 = relay.nn.conv2d(x, w4)
-        y = relay.Tuple((y1, y2, y3, y4))
+        y5 = relay.nn.max_pool2d(x)
+        y = relay.Tuple((y1, y2, y3, y4, y5))
         return relay.Function(args, y)
 
     def expected(x, w1, w2, w3, w4, channels1, channels2, channels3, channels4):
@@ -24,7 +25,8 @@ def expected(x, w1, w2, w3, w4, channels1, channels2, channels3, channels4):
         y3 = relay.nn.conv2d(x, w3)
         y4 = relay.strided_slice(y, [0, channels1 + channels2],
                                  [None, channels1 + channels2 + channels4])
-        y = relay.Tuple((y1, y2, y3, y4))
+        y5 = relay.nn.max_pool2d(x)
+        y = relay.Tuple((y1, y2, y3, y4, y5))
         return relay.Function(args, y)
 
     def check(x_shape, channels1, channels2, channels3, channels4):
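
Note on the `build_module.py` hunk: moving `CombineParallelConv2D` from level 4 to level 3 means the pass now runs under the default optimization level rather than requiring an explicit opt-in. A minimal sketch of what that gating looks like from the user side, assuming the `relay.build_config`/`relay.build` API of this TVM version (the shapes and target are illustrative, not part of this patch):

```python
import tvm
from tvm import relay

# Illustrative function with two parallel conv2d branches sharing one input.
x = relay.var("x", shape=(1, 4, 16, 16))
w1 = relay.var("w1", shape=(8, 4, 1, 1))
w2 = relay.var("w2", shape=(8, 4, 1, 1))
func = relay.Function([x, w1, w2],
                      relay.Tuple([relay.nn.conv2d(x, w1),
                                   relay.nn.conv2d(x, w2)]))

# Previously CombineParallelConv2D only ran at opt_level >= 4; with this
# patch the default level 3 picks it up as well.
with relay.build_config(opt_level=3):
    graph, lib, params = relay.build(func, target="llvm")
```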
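
And a sketch of the behavior the updated test pins down, using the same `relay.ir_pass` interface the test file already calls: the compatible conv2d branches are merged into one wider conv2d plus `strided_slice` ops, while a non-conv2d sibling such as `max_pool2d`, which previously tripped `BranchGroupFinder` because `children_map_` may hold non-conv2d children of a conv root, is simply skipped:

```python
from tvm import relay

x = relay.var("x", shape=(1, 4, 16, 16))
w1 = relay.var("w1", shape=(8, 4, 1, 1))
w2 = relay.var("w2", shape=(8, 4, 1, 1))

y1 = relay.nn.conv2d(x, w1)
y2 = relay.nn.conv2d(x, w2)
y5 = relay.nn.max_pool2d(x)  # non-conv2d child of x; must not start a branch
func = relay.Function([x, w1, w2], relay.Tuple([y1, y2, y5]))

func = relay.ir_pass.infer_type(func)
func = relay.ir_pass.combine_parallel_conv2d(func)
# y1/y2 are now a single conv2d followed by two strided_slice ops;
# the max_pool2d branch is left untouched.
print(func)
```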