[MKLDNN] Enable convolution fusion.#12308
Conversation
| } | ||
| const int GetHash() const { | ||
| int hash = 0; | ||
| hash = hash * 2 + this->with_bn ? 1 : 0; |
There was a problem hiding this comment.
Possible hash collision: with_bn=0 and with_relu=1 equals BN=1 and relu0. Consider using bitflags
|
|
||
| static inline std::string PrintArguments(const ConvolutionParam& param_) { | ||
| auto args = ListArguments(param_); | ||
| std::string str = "["; |
There was a problem hiding this comment.
It's better to use std::stringstream to compose such string, otherwise it's a lot of redundant copying internally.
| DMLC_DECLARE_FIELD(with_postsum_relu).set_default(false) | ||
| .describe("Add post relu after sum"); | ||
| } | ||
| const int GetHash() const { |
| LOG(INFO) << "Conv req size: " << req.size(); | ||
| for (size_t k = 0; k < inputs.size(); ++k) { | ||
| auto input = inputs[k]; | ||
| printf("input %ld :", k); |
There was a problem hiding this comment.
Sometimes it's LOG(INFO), sometimes it's printf. I think it's better to use LOG always.
| if (it != attrs.dict.end() && it->second == "true") { | ||
| it = attrs.dict.find("in_sum_at_begin"); | ||
| if (it != attrs.dict.end() && it->second == "true") { | ||
| return std::vector<std::pair<int, int>>{std::pair<int, int>{0, 0}}; |
There was a problem hiding this comment.
You can use std::make_pair and ommit the template argument types
| DefaultSubgraphOpResourceRequest) | ||
| .set_attr<std::string>("key_var_num_args", "num_args") | ||
| .set_attr<nnvm::FInplaceOption>("FInplaceOption", [](const nnvm::NodeAttrs | ||
| &attrs) { |
There was a problem hiding this comment.
If you specify the return type, like -> std::vector<std::pair<int, int>> then you can do things like return {}; further in the lambda or just return std::make_pair(0, 0).
| class SgMKLDNNConvSelector : public SubgraphSelector { | ||
| public: | ||
| /*! \brief pattern match status */ | ||
| enum SelectStatus { |
There was a problem hiding this comment.
It's better to use typed enums since they have better scoping and you don't need any prefixes or all caps. Usage will be like: SelectStatus.Fail
| if (new_node.inputs[1].node.get() == &n) { | ||
| sum_entry = new_node.inputs[0]; | ||
| } | ||
| #if 0 |
| auto last_node = sym.outputs[0].node; | ||
| nnvm::Symbol new_sym; | ||
| new_sym.outputs.emplace_back(nnvm::NodeEntry{last_node, 0, 0}); | ||
| std::string node_name = ""; |
| ConvFusionFallBackCompute(); | ||
| } | ||
|
|
||
| class SgMKLDNNConvOperator { |
There was a problem hiding this comment.
Why not have the definition in the header file?
| nnvm::NodeEntry conv_data; | ||
| std::vector<const nnvm::Node *> matched_list; | ||
|
|
||
| bool HandleMatchStatus() { |
There was a problem hiding this comment.
Why not have the implementations for all methods in the .cc file?
|
@marcoabreu @lebeg @larroy. Thanks for your reviewing comments, I will address them in next version. |
|
@ZhennanQin Thanks for your contribution. Could you address the feedback? Let us know if you need help! |
|
@mxnet-label-bot [pr-awaiting-response] |
|
@ZhennanQin Requesting an update on the PR, have the changes been addressed? |
Description
Implement mkldnn convlution fusion(eg. conv+relu, conv+bn, conv+sum) based on subgraph.
@pengzhao-intel @TaoLv @zheng-da @reminisce
Checklist
Essentials
Please feel free to remove inapplicable items for your PR.
Changes
Comments