Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion python/tvm/relay/build_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@
# List of optimization passes and the opt level at which each is switched on
OPT_PASS_LEVEL = {
"SimplifyInference": 0,
"CombineParallelConv2D": 4,
"OpFusion": 1,
"FoldConstant": 2,
"CombineParallelConv2D": 3,
"FoldScaleAxis": 3,
}

Expand Down
19 changes: 12 additions & 7 deletions src/relay/pass/combine_parallel_conv2d.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ using Group = std::vector<Branch>;
Intermediate nodes have exactly one successor. It is possible that branches meet at a point,
which should be handled in ParallelConv2DCombiner.

data
data
/ \
conv2d conv2d
| |
Expand All @@ -47,17 +47,22 @@ using Group = std::vector<Branch>;
class BranchGroupFinder : private ExprVisitor {
public:
std::vector<Group> Find(const Expr& expr) {
static const Op& conv2d = Op::Get("nn.conv2d");

this->VisitExpr(expr);

std::vector<Group> groups;
for (const auto& root : conv_roots_) {
const auto& convs = children_map_.at(root);
for (const CallNode* conv : convs) {
auto&& branch = CreateBranch(conv);
const auto& children = children_map_.at(root);
size_t ngroups = groups.size();
for (const CallNode* child : children) {
if (!child->op.same_as(conv2d)) continue;

auto&& branch = CreateBranch(child);
// add the branch to a group, or create a new group
auto it = std::find_if(groups.begin(), groups.end(), [&](const Group& group) {
auto it = std::find_if(groups.begin() + ngroups, groups.end(), [&](const Group& group) {
CHECK(!group.empty() && !group[0].empty());
return IsCompatibleConv2D(conv, group[0][0]);
return IsCompatibleConv2D(child, group[0][0]);
});
if (it != groups.end()) {
it->push_back(branch);
Expand Down Expand Up @@ -108,7 +113,7 @@ class BranchGroupFinder : private ExprVisitor {
const CallNode* call = it->second[0];
auto pattern = fpattern[Downcast<Op>(call->op)];
if (pattern <= kBroadcast) {
branch.push_back(it->second[0]);
branch.push_back(call);
it = children_map_.find(GetRef<Expr>(branch.back()));
} else {
break;
Expand Down
6 changes: 4 additions & 2 deletions tests/python/relay/test_pass_combine_parallel_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ def before(x, w1, w2, w3, w4):
# y3 cannot be combined
y3 = relay.nn.conv2d(x, w3)
y4 = relay.nn.conv2d(x, w4)
y = relay.Tuple((y1, y2, y3, y4))
y5 = relay.nn.max_pool2d(x)
y = relay.Tuple((y1, y2, y3, y4, y5))
return relay.Function(args, y)

def expected(x, w1, w2, w3, w4, channels1, channels2, channels3, channels4):
Expand All @@ -24,7 +25,8 @@ def expected(x, w1, w2, w3, w4, channels1, channels2, channels3, channels4):
y3 = relay.nn.conv2d(x, w3)
y4 = relay.strided_slice(y, [0, channels1 + channels2],
[None, channels1 + channels2 + channels4])
y = relay.Tuple((y1, y2, y3, y4))
y5 = relay.nn.max_pool2d(x)
y = relay.Tuple((y1, y2, y3, y4, y5))
return relay.Function(args, y)

def check(x_shape, channels1, channels2, channels3, channels4):
Expand Down