-
Notifications
You must be signed in to change notification settings - Fork 3.8k
[RELAY][PASS] CombineParallelConv2D #2089
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
ca7a659
c1b73c1
febef4e
aaade9b
0318eef
002b4e5
f923a16
1934fd2
c638625
8186479
74ca32d
9ee9396
9aafe00
4302789
bed2900
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,328 @@ | ||
| /*! | ||
| * Copyright (c) 2018 by Contributors | ||
| * | ||
| * \file combine_parallel_conv2d.cc | ||
| * \brief Combine parallel 2d convolutions into a single convolution. | ||
| * | ||
| * This pass replaces convolutions that share the same input node and the same | ||
| * arguments (except that the number of output channels can be different) with a | ||
| * single convolution. The weight of the new 2d convolution is the concatenation | ||
| * of the original weights. Elemwise and broadcast ops following conv2d are also | ||
| * combined if possible. | ||
| * | ||
| * This prevents launching multiple kernels in networks with multiple | ||
| * convolution branches, such as Inception block. | ||
| */ | ||
|
|
||
| #include <tvm/relay/pass.h> | ||
| #include <tvm/relay/expr_functor.h> | ||
| #include <tvm/relay/attrs/nn.h> | ||
| #include <tvm/relay/attrs/transform.h> | ||
| #include <tvm/relay/op_attr_types.h> | ||
| #include <unordered_map> | ||
| #include <unordered_set> | ||
| #include "./expr_subst.h" | ||
| #include "./pattern_util.h" | ||
|
|
||
|
|
||
| namespace tvm { | ||
| namespace relay { | ||
|
|
||
| using Branch = std::vector<const CallNode*>; | ||
| using Group = std::vector<Branch>; | ||
|
|
||
| /* | ||
| Find parallel branches starting with conv2d as shown below and then group branches by kernel | ||
| shape and attributes of conv2d. Conv2d can be followed by zero or more elemwise or broadcast ops. | ||
| Intermediate nodes have exactly one successor. It is possible that branches meet at a point, | ||
| which should be handled in ParallelConv2DCombiner. | ||
|
|
||
| data | ||
| / \ | ||
| conv2d conv2d | ||
| | | | ||
| op op | ||
| | | | ||
| */ | ||
| class BranchGroupFinder : private ExprVisitor { | ||
| public: | ||
| std::vector<Group> Find(const Expr& expr) { | ||
| this->VisitExpr(expr); | ||
|
|
||
| std::vector<Group> groups; | ||
| for (const auto& root : conv_roots_) { | ||
| const auto& convs = children_map_.at(root); | ||
| for (const CallNode* conv : convs) { | ||
| auto&& branch = CreateBranch(conv); | ||
| // add the branch to a group, or create a new group | ||
| auto it = std::find_if(groups.begin(), groups.end(), [&](const Group& group) { | ||
| CHECK(!group.empty() && !group[0].empty()); | ||
| return IsCompatibleConv2D(conv, group[0][0]); | ||
| }); | ||
| if (it != groups.end()) { | ||
| it->push_back(branch); | ||
| } else { | ||
| groups.emplace_back(); | ||
| // each group has at least one branch | ||
| groups.back().push_back(branch); | ||
| } | ||
| } | ||
| } | ||
| return groups; | ||
| } | ||
|
|
||
| private: | ||
| std::unordered_set<Expr, NodeHash, NodeEqual> conv_roots_; | ||
| std::unordered_map<Expr, std::vector<const CallNode*>, NodeHash, NodeEqual> children_map_; | ||
|
|
||
| // Two 2d convolutions can be combined if they have the same attributes or | ||
| // only have different output channels. | ||
| bool IsCompatibleConv2D(const CallNode* a, const CallNode* b) { | ||
| AttrsEqual eq; | ||
| static const Layout kOIHW("OIHW"); | ||
| const auto* attrs_a = a->attrs.as<Conv2DAttrs>(); | ||
| const auto* attrs_b = b->attrs.as<Conv2DAttrs>(); | ||
| CHECK(attrs_a); | ||
| CHECK(attrs_b); | ||
| const auto* tweight_a = a->args[1]->type_as<TensorTypeNode>(); | ||
| const auto* tweight_b = b->args[1]->type_as<TensorTypeNode>(); | ||
| const auto shape_a = ConvertLayout(tweight_a->shape, attrs_a->weight_layout, kOIHW); | ||
| const auto shape_b = ConvertLayout(tweight_b->shape, attrs_b->weight_layout, kOIHW); | ||
|
|
||
| return eq(attrs_a->strides, attrs_b->strides) && eq(attrs_a->padding, attrs_b->padding) && | ||
| eq(attrs_a->dilation, attrs_b->dilation) && eq(attrs_a->groups, attrs_b->groups) && | ||
| eq(attrs_a->data_layout, attrs_b->data_layout) && | ||
| eq(attrs_a->weight_layout, attrs_b->weight_layout) && | ||
| eq(attrs_a->out_dtype, attrs_b->out_dtype) && | ||
| eq(attrs_a->out_layout, attrs_b->out_layout) && eq(shape_a[2], shape_b[2]) && | ||
| eq(shape_a[3], shape_b[3]); | ||
| } | ||
|
|
||
| // Create a branch starting from conv2d. | ||
| Branch CreateBranch(const CallNode* conv) { | ||
| static auto fpattern = Op::GetAttr<TOpPattern>("TOpPattern"); | ||
| // each branch has at least one element, the first element is always conv2d | ||
| Branch branch{conv}; | ||
| auto it = children_map_.find(GetRef<Expr>(branch.back())); | ||
| while (it != children_map_.end() && it->second.size() == 1) { | ||
| const CallNode* call = it->second[0]; | ||
| auto pattern = fpattern[Downcast<Op>(call->op)]; | ||
| if (pattern <= kBroadcast) { | ||
| branch.push_back(it->second[0]); | ||
| it = children_map_.find(GetRef<Expr>(branch.back())); | ||
| } else { | ||
| break; | ||
| } | ||
| } | ||
| return branch; | ||
| } | ||
|
|
||
| void VisitExpr_(const CallNode* n) final { | ||
| static const Op& conv2d = Op::Get("nn.conv2d"); | ||
| ExprVisitor::VisitExpr_(n); | ||
| if (n->op.same_as(conv2d) && n->attrs.as<Conv2DAttrs>()->groups == 1) { | ||
| conv_roots_.insert(n->args[0]); | ||
| children_map_[n->args[0]].push_back(n); | ||
| } else { | ||
| for (size_t i = 0; i < n->args.size(); i++) { | ||
| children_map_[n->args[i]].push_back(n); | ||
| } | ||
| } | ||
| } | ||
| }; | ||
|
|
||
| class ParallelConv2DCombiner { | ||
| public: | ||
| Expr Combine(const Expr& expr) { | ||
| auto groups = BranchGroupFinder().Find(expr); | ||
| for (const Group& group : groups) { | ||
| if (group.size() < 2) continue; | ||
| CombineBranches(group); | ||
| } | ||
| return ExprSubst(expr, std::move(subst_map_)); | ||
| } | ||
|
|
||
| private: | ||
| std::unordered_map<Expr, Expr, NodeHash, NodeEqual> subst_map_; | ||
|
|
||
| std::tuple<Expr, IndexExpr> TransformWeight(const Group& branches) { | ||
| int64_t num_filters = 0; // number of filters of the transformed weight | ||
| Array<Expr> weights; | ||
| for (const auto& branch : branches) { | ||
| auto conv2d = branch[0]; | ||
| weights.push_back(conv2d->args[1]); | ||
| auto channels = GetConv2DSuperChannelsDim(conv2d); | ||
| num_filters += channels; | ||
| } | ||
| auto index = branches[0][0]->attrs.as<Conv2DAttrs>()->weight_layout.find('O'); | ||
| CHECK_NE(index, std::string::npos); | ||
| return std::make_tuple(MakeConcatenate(TupleNode::make(weights), index), | ||
| MakeConstScalar(Int(32), num_filters)); | ||
| } | ||
|
|
||
| Call MakeCombinedConv2D(const Group& branches) { | ||
| static const Op& conv2d = Op::Get("nn.conv2d"); | ||
| Expr data = branches[0][0]->args[0]; | ||
| Expr new_weight; | ||
| IndexExpr new_channels; | ||
| std::tie(new_weight, new_channels) = TransformWeight(branches); | ||
|
|
||
| const CallNode* group_root = branches[0][0]; | ||
| const auto* attrs = group_root->attrs.as<Conv2DAttrs>(); | ||
| CHECK(attrs); | ||
| const auto new_attrs = make_node<Conv2DAttrs>(); | ||
| new_attrs->strides = attrs->strides; | ||
| new_attrs->padding = attrs->padding; | ||
| new_attrs->dilation = attrs->dilation; | ||
| new_attrs->groups = attrs->groups; | ||
| new_attrs->kernel_size = attrs->kernel_size; | ||
| new_attrs->data_layout = attrs->data_layout; | ||
| new_attrs->weight_layout = attrs->weight_layout; | ||
| new_attrs->out_layout = attrs->out_layout; | ||
| new_attrs->out_dtype = attrs->out_dtype; | ||
| new_attrs->channels = new_channels; | ||
|
|
||
| return CallNode::make(conv2d, {data, new_weight}, Attrs{new_attrs}, {}); | ||
| } | ||
|
|
||
| bool IsArgCompatible(const CallNode* a, const CallNode* b, size_t index, size_t channel_pos) { | ||
| AttrsEqual eq; | ||
| auto ta = a->args[index]->type_as<TensorTypeNode>(); | ||
| auto tb = b->args[index]->type_as<TensorTypeNode>(); | ||
| auto toutput_a = a->type_as<TensorTypeNode>(); | ||
| auto toutput_b = b->type_as<TensorTypeNode>(); | ||
|
|
||
| if (!eq(ta->dtype, tb->dtype) || ta->shape.size() != tb->shape.size()) | ||
| return false; | ||
|
|
||
| // Position of the 'C' dimension in the argument | ||
| size_t arg_channel_pos = channel_pos - toutput_a->shape.size() + ta->shape.size(); | ||
|
|
||
| // Channel super-dimension shoule be present and not broadcasted | ||
| if ((arg_channel_pos > channel_pos) || // size_t overflow | ||
| !eq(ta->shape[arg_channel_pos], toutput_a->shape[channel_pos]) || | ||
| !eq(tb->shape[arg_channel_pos], toutput_b->shape[channel_pos])) | ||
| return false; | ||
|
|
||
| for (size_t i = 0; i < ta->shape.size(); i++) { | ||
| if (i == arg_channel_pos) continue; | ||
| if (!eq(ta->shape[i], tb->shape[i])) | ||
| return false; | ||
| } | ||
| return true; | ||
| } | ||
|
|
||
| // Check if ops in depth-th level can be combined | ||
| bool CheckLevel(const Group& branches, size_t depth, size_t channel_pos, size_t parent_index) { | ||
| const CallNode* call = branches[0][depth]; | ||
| AttrsEqual attrs_equal; | ||
| // check if all branches in current depth can be combined | ||
| for (auto it = branches.begin() + 1; it != branches.end(); it++) { | ||
| const Branch& branch = *it; | ||
| if (!branch[depth]->op.same_as(call->op) || | ||
| !attrs_equal(branch[depth]->attrs, call->attrs) || | ||
| branch[depth]->args.size() != call->args.size()) { | ||
| return false; | ||
| } | ||
|
|
||
| if (branch[depth]->args[parent_index].get() != branch[depth - 1]) | ||
| return false; | ||
|
|
||
| // Check args | ||
| for (size_t i = 0; i < call->args.size(); i++) { | ||
| if (i == parent_index) continue; | ||
|
|
||
| if (!IsArgCompatible(call, branch[depth], i, channel_pos) || | ||
| !attrs_equal(call->attrs, branch[depth]->attrs)) { | ||
| return false; | ||
| } | ||
| } | ||
| } | ||
| return true; | ||
| } | ||
|
|
||
| // Combine args and make the combined CallNode | ||
| Call MakeCombinedCall(const Expr& data, const Group& branches, size_t depth, size_t channel_pos, | ||
| size_t parent_index) { | ||
| Array<Expr> new_args; | ||
| const CallNode* call = branches[0][depth]; | ||
| size_t ndim = call->type_as<TensorTypeNode>()->shape.size(); | ||
|
|
||
| for (size_t i = 0; i < call->args.size(); i++) { | ||
| if (i == parent_index) { | ||
| new_args.push_back(data); | ||
| continue; | ||
| } | ||
| size_t arg_ndim = call->args[i]->type_as<TensorTypeNode>()->shape.size(); | ||
| size_t arg_channel_pos = channel_pos - ndim + arg_ndim; | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can you give an example of values of
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ndim is the dimension of output of con2
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ok, then for NCHW layout
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yes in this case. |
||
| Array<Expr> tuple; | ||
| for (const auto& branch : branches) { | ||
| tuple.push_back(branch[depth]->args[i]); | ||
| } | ||
| auto concat = MakeConcatenate(TupleNode::make(tuple), arg_channel_pos); | ||
| new_args.push_back(std::move(concat)); | ||
| } | ||
| return CallNode::make(call->op, new_args, call->attrs, {}); | ||
| } | ||
|
|
||
| // Replace output of each branch with slices of the combined output | ||
| void UpdateGroupOutput(const Expr& data, const Group& branches, size_t depth, | ||
| size_t channel_pos) { | ||
| int64_t index = 0; | ||
| for (const auto& branch : branches) { | ||
| const CallNode* conv2d = branch[0]; | ||
| int64_t channels = GetConv2DSuperChannelsDim(conv2d); | ||
| Array<Integer> begin; | ||
| Array<Integer> end; | ||
| for (size_t i = 0; i < channel_pos; i++) { | ||
| begin.push_back(0); | ||
| end.push_back(NullValue<Integer>()); | ||
| } | ||
| begin.push_back(index); | ||
| index += channels; | ||
| end.push_back(index); | ||
| auto slice = MakeStridedSlice(data, std::move(begin), std::move(end), Array<Integer>{}); | ||
| subst_map_[GetRef<Expr>(branch[depth])] = slice; | ||
| } | ||
| } | ||
|
|
||
| // Combine branches in a group. Conv2d in different branches in the same group are safe to | ||
| // combine. Subsequent ops may or may not be combined. We start from conv2d and try to | ||
| // combine ops from all branches in the same depth. | ||
| void CombineBranches(const Group& branches) { | ||
| Call combined = MakeCombinedConv2D(branches); | ||
| auto conv_param = combined->attrs.as<Conv2DAttrs>(); | ||
| const std::string& layout = | ||
| conv_param->out_layout == "" ? conv_param->data_layout : conv_param->out_layout; | ||
| size_t channel_pos = layout.find('C'); | ||
| CHECK_NE(channel_pos, std::string::npos); | ||
| auto it = std::min_element(branches.begin(), branches.end(), | ||
| [](const Branch& branch_a, | ||
| const Branch& branch_b) { | ||
| return branch_a.size() < branch_b.size(); | ||
| }); | ||
| size_t depth = it->size(); | ||
| size_t i; | ||
| // starting from 1 to skip the conv2d | ||
| for (i = 1; i < depth; i++) { | ||
| size_t parent_index; | ||
| for (parent_index = 0; parent_index < branches[0][i]->args.size(); parent_index++) { | ||
| if (branches[0][i]->args[parent_index].get() == branches[0][i - 1]) break; | ||
| } | ||
| CHECK_NE(parent_index, branches[0][i]->args.size()); | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. When branches[0][i] is a node where multiple branches meet,
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. another case is that the parent is on rhs of binary op |
||
| if (!CheckLevel(branches, i, channel_pos, parent_index)) break; | ||
| combined = MakeCombinedCall(combined, branches, i, channel_pos, parent_index); | ||
| } | ||
| UpdateGroupOutput(combined, branches, i - 1, channel_pos); | ||
| } | ||
| }; | ||
|
|
||
| Expr CombineParallelConv2D(const Expr& expr) { return ParallelConv2DCombiner().Combine(expr); } | ||
|
|
||
| TVM_REGISTER_API("relay._ir_pass.CombineParallelConv2D") | ||
| .set_body([](TVMArgs args, TVMRetValue* ret) { | ||
| *ret = CombineParallelConv2D(args[0]); | ||
| }); | ||
|
|
||
| } // namespace relay | ||
| } // namespace tvm | ||
Uh oh!
There was an error while loading. Please reload this page.