diff --git a/src/operator/subgraph/dnnl/dnnl_elemwisemul_post_quantize_property.h b/src/operator/subgraph/dnnl/dnnl_elemwisemul_post_quantize_property.h deleted file mode 100644 index 5e015cbf14e1..000000000000 --- a/src/operator/subgraph/dnnl/dnnl_elemwisemul_post_quantize_property.h +++ /dev/null @@ -1,231 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file dnnl_elemwisemul_post_quantize_property.cc - * \brief Partition gragph property for oneDNN Quantized ElemwiseMul operator - * \author Xinyu Chen - */ - -#ifndef MXNET_OPERATOR_SUBGRAPH_DNNL_DNNL_ELEMWISEMUL_POST_QUANTIZE_PROPERTY_H_ -#define MXNET_OPERATOR_SUBGRAPH_DNNL_DNNL_ELEMWISEMUL_POST_QUANTIZE_PROPERTY_H_ -#if MXNET_USE_ONEDNN == 1 - -#include -#include -#include - -#include "../../quantization/requantize-inl.h" -#include "../../tensor/elemwise_binary_op-inl.h" -#include "../common.h" -#include "dnnl_subgraph_base-inl.h" - -namespace mxnet { -namespace op { - -#define QUANTIZED_ElemwiseMul_NAME "_contrib_quantized_elemwise_mul" - -class ElemwiseMulPostQuantizeSelector : public SubgraphSelectorV2 { - public: - /*! \brief pattern match status */ - enum SelectStatus { - kFail = 0, - kStart, - kRequantize, - kSuccess, - }; - - private: - bool disable_all; - bool disable_float_output; - SelectStatus status; - std::vector matched_list; - - public: - explicit ElemwiseMulPostQuantizeSelector(const bool dis_all, const bool dis_float_output) - : disable_all(dis_all), disable_float_output(dis_float_output) {} - - bool Select(const BiDirectedNode& n) override { - const auto rawnode = n.node; - if ((!disable_all) && rawnode->op() == Op::Get(QUANTIZED_ElemwiseMul_NAME)) { - status = disable_all ? kSuccess : kStart; - matched_list.clear(); - matched_list.push_back(&n); - return true; - } - return false; - } - - bool SelectInput(const BiDirectedNode& n, const BiDirectedNode& new_node) override { - return false; - } - - bool SelectOutput(const BiDirectedNode& n, const BiDirectedNode& new_node) override { - const auto raw_node = n.node; - const auto raw_new_node = new_node.node; - if (status == kFail || status == kSuccess || raw_new_node->is_variable()) - return false; - // If n isn't the last matched node, then we encoutered a internal - // branch, we should pop out the node behind n and stop fusion. - if (matched_list.back() != &n) { - if (std::find(matched_list.begin(), matched_list.end(), &n) != matched_list.end()) { - while (matched_list.back() != &n) { - matched_list.pop_back(); - } - } - - status = kSuccess; - return false; - } - - switch (status) { - case kStart: - if (raw_new_node->op() == Op::Get("_contrib_requantize")) { - auto const& param = nnvm::get(raw_new_node->attrs.parsed); - if (param.min_calib_range.has_value() && param.max_calib_range.has_value()) { - matched_list.push_back(&new_node); - status = kRequantize; - return true; - } - } - case kRequantize: - if ((!disable_float_output) && (raw_new_node->op() == Op::Get("_contrib_dequantize"))) { - CHECK(raw_node->op() == Op::Get("_contrib_requantize")); - if (n.outputs.size() > 1) { - // check if requantize have other outputs than dequantize - // if it has we can't fuse dequantize into elemwise_mul - for (auto kv : n.outputs) { - const auto& node = kv.first; - if (node->op() != Op::Get("_contrib_dequantize")) { - status = kSuccess; - return false; - } - } - } - - matched_list.push_back(&new_node); - status = kSuccess; - return true; - } - default: - status = kSuccess; - return false; - } - } - - std::vector Filter(const std::vector& candidates) override { - if ((status != kSuccess) || (matched_list.size() <= 1)) { - return std::vector(0); - } else { - std::vector ret; - for (auto i : matched_list) { - auto non_const_i = const_cast(i); - if (std::find(candidates.begin(), candidates.end(), non_const_i) != candidates.end()) { - ret.push_back(non_const_i); - } - } - return ret; - } - } - - void Reset() override { - CHECK_GE(matched_list.size(), 1); - auto new_selector = ElemwiseMulPostQuantizeSelector(disable_all, disable_float_output); - new_selector.Select(*matched_list[0]); - *this = new_selector; - } -}; - -class ElemwiseMulPostQuantizeProperty : public SubgraphProperty { - public: - ElemwiseMulPostQuantizeProperty() { - disable_fuse_all = dmlc::GetEnv("MXNET_DISABLE_ONEDNN_QEM_FUSE_ALL", false); - disable_float_output = dmlc::GetEnv("MXNET_DISABLE_ONEDNN_QEM_FLOAT_OUTPUT", false); - } - - static SubgraphPropertyPtr Create() { - static const std::string& name = "oneDNN EltwiseMul post-quantization optimization pass"; - auto property = std::make_shared(); - property->SetAttr("property_name", name); - property->SetAttr("inference_only", true); - return property; - } - - nnvm::ObjectPtr CreateSubgraphNode(const nnvm::Symbol& sym, - const int subgraph_id = 0) const override { - nnvm::ObjectPtr em_node = nullptr; - nnvm::ObjectPtr requantize_node = nullptr; - nnvm::ObjectPtr dequantize_node = nullptr; - - DFSVisit(sym.outputs, [&](const nnvm::ObjectPtr& node) { - if (node->is_variable()) - return; - if (node->op() == Op::Get(QUANTIZED_ElemwiseMul_NAME)) { - em_node = node; - } else if (node->op() == Op::Get("_contrib_requantize")) { - requantize_node = node; - } else if (node->op() == Op::Get("_contrib_dequantize")) { - dequantize_node = node; - } - }); - - CHECK_NOTNULL(em_node); - CHECK_NOTNULL(requantize_node); - auto const& requantize_param = nnvm::get(requantize_node->attrs.parsed); - CHECK(requantize_param.min_calib_range.has_value()); - CHECK(requantize_param.max_calib_range.has_value()); - - // When only fused quantized_elemwise_mul and requantize, set min/max_cablib_range, - // When fused quantized_elemwise_mul + requantize + dequantize, set dequantize flag to true. - if (dequantize_node != nullptr) { - em_node->attrs.dict["enable_float_output"] = "True"; - } else { - em_node->attrs.dict["min_calib_range"] = - std::to_string(requantize_param.min_calib_range.value()); - em_node->attrs.dict["max_calib_range"] = - std::to_string(requantize_param.max_calib_range.value()); - } - em_node->op()->attr_parser(&(em_node->attrs)); - return em_node; - } - - SubgraphSelectorV2Ptr CreateSubgraphSelectorV2() const override { - auto selector = - std::make_shared(disable_fuse_all, disable_float_output); - return selector; - } - - void ConnectSubgraphOutputs(const nnvm::ObjectPtr n, - std::vector* output_entries) const override { - for (size_t i = 0; i < output_entries->size(); ++i) { - auto entry_ptr = output_entries->at(i); - *entry_ptr = nnvm::NodeEntry{n, entry_ptr->index, 0}; - } - } - - private: - bool disable_fuse_all; - bool disable_float_output; -}; - -} // namespace op -} // namespace mxnet - -#endif // if MXNET_USE_ONEDNN == 1 -#endif // MXNET_OPERATOR_SUBGRAPH_DNNL_DNNL_ELEMWISEMUL_POST_QUANTIZE_PROPERTY_H_ diff --git a/src/operator/subgraph/dnnl/dnnl_fc_post_quantize_property.h b/src/operator/subgraph/dnnl/dnnl_fc_post_quantize_property.h deleted file mode 100644 index b1ae5373ece9..000000000000 --- a/src/operator/subgraph/dnnl/dnnl_fc_post_quantize_property.h +++ /dev/null @@ -1,230 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file dnnl_fc_post_quantize_property.cc - * \brief Partition gragph property for oneDNN Quantized FullyConnected operator - * \author Ciyong Chen - */ - -#ifndef MXNET_OPERATOR_SUBGRAPH_DNNL_DNNL_FC_POST_QUANTIZE_PROPERTY_H_ -#define MXNET_OPERATOR_SUBGRAPH_DNNL_DNNL_FC_POST_QUANTIZE_PROPERTY_H_ -#if MXNET_USE_ONEDNN == 1 - -#include -#include -#include - -#include "../../nn/fully_connected-inl.h" -#include "../../quantization/requantize-inl.h" -#include "../common.h" -#include "dnnl_subgraph_base-inl.h" - -namespace mxnet { -namespace op { - -#define QUANTIZED_FC_NAME "_sg_onednn_fully_connected" - -class SgDNNLFCPostQuantizeSelector : public SubgraphSelectorV2 { - public: - /*! \brief pattern match status */ - enum SelectStatus { - kFail = 0, - kStart, - kRequantize, - kSuccess, - }; - - private: - bool disable_all; - bool disable_float_output; - SelectStatus status; - std::vector matched_list; - - public: - explicit SgDNNLFCPostQuantizeSelector(const bool dis_all, const bool dis_float_output) - : disable_all(dis_all), disable_float_output(dis_float_output) {} - - bool Select(const BiDirectedNode& n) override { - const auto rawnode = n.node; - if ((!disable_all) && rawnode->op() == Op::Get(QUANTIZED_FC_NAME)) { - status = disable_all ? kSuccess : kStart; - matched_list.clear(); - matched_list.push_back(&n); - return true; - } - return false; - } - - bool SelectInput(const BiDirectedNode& n, const BiDirectedNode& new_node) override { - return false; - } - - bool SelectOutput(const BiDirectedNode& n, const BiDirectedNode& new_node) override { - const auto raw_node = n.node; - const auto raw_new_node = new_node.node; - if (status == kFail || status == kSuccess || raw_new_node->is_variable()) - return false; - // If n isn't the last matched node, then we encoutered a internal - // branch, we should pop out the node behind n and stop fusion. - if (matched_list.back() != &n) { - if (std::find(matched_list.begin(), matched_list.end(), &n) != matched_list.end()) { - while (matched_list.back() != &n) { - matched_list.pop_back(); - } - } - - status = kSuccess; - return false; - } - - switch (status) { - case kStart: - if (raw_new_node->op() == Op::Get("_contrib_requantize")) { - auto const& param = nnvm::get(raw_new_node->attrs.parsed); - if (param.min_calib_range.has_value() && param.max_calib_range.has_value()) { - matched_list.push_back(&new_node); - status = kRequantize; - return true; - } - } - case kRequantize: - if ((!disable_float_output) && (raw_new_node->op() == Op::Get("_contrib_dequantize"))) { - CHECK(raw_node->op() == Op::Get("_contrib_requantize")); - if (n.outputs.size() > 1) { - // check if requantize have other outputs than dequantize - // if it has we can't fuse dequantize into FC - for (auto kv : n.outputs) { - const auto& node = kv.first; - if (node->op() != Op::Get("_contrib_dequantize")) { - status = kSuccess; - return false; - } - } - } - matched_list.push_back(&new_node); - status = kSuccess; - return true; - } - default: - status = kSuccess; - return false; - } - } - - std::vector Filter(const std::vector& candidates) override { - if ((status != kSuccess) || (matched_list.size() <= 1)) { - return std::vector(0); - } else { - std::vector ret; - for (auto i : matched_list) { - auto non_const_i = const_cast(i); - if (std::find(candidates.begin(), candidates.end(), non_const_i) != candidates.end()) { - ret.push_back(non_const_i); - } - } - return ret; - } - } - - void Reset() override { - CHECK_GE(matched_list.size(), 1); - auto new_selector = SgDNNLFCPostQuantizeSelector(disable_all, disable_float_output); - new_selector.Select(*matched_list[0]); - *this = new_selector; - } -}; - -class SgDNNLFCPostQuantizeProperty : public SubgraphProperty { - public: - SgDNNLFCPostQuantizeProperty() { - disable_fuse_all = dmlc::GetEnv("MXNET_DISABLE_ONEDNN_QFC_FUSE_ALL", false); - disable_float_output = dmlc::GetEnv("MXNET_DISABLE_ONEDNN_QFC_FLOAT_OUTPUT", false); - } - - static SubgraphPropertyPtr Create() { - static const std::string& name = "oneDNN FullyConected post-quantization optimization pass"; - auto property = std::make_shared(); - property->SetAttr("property_name", name); - property->SetAttr("inference_only", true); - return property; - } - - nnvm::ObjectPtr CreateSubgraphNode(const nnvm::Symbol& sym, - const int subgraph_id = 0) const override { - nnvm::ObjectPtr fc_node = nullptr; - nnvm::ObjectPtr requantize_node = nullptr; - nnvm::ObjectPtr dequantize_node = nullptr; - - DFSVisit(sym.outputs, [&](const nnvm::ObjectPtr& node) { - if (node->is_variable()) - return; - if (node->op() == Op::Get(QUANTIZED_FC_NAME)) { - fc_node = node; - } else if (node->op() == Op::Get("_contrib_requantize")) { - requantize_node = node; - } else if (node->op() == Op::Get("_contrib_dequantize")) { - dequantize_node = node; - } - }); - - CHECK_NOTNULL(fc_node); - CHECK_NOTNULL(requantize_node); - auto const& requantize_param = nnvm::get(requantize_node->attrs.parsed); - CHECK(requantize_param.min_calib_range.has_value()); - CHECK(requantize_param.max_calib_range.has_value()); - - // When only fused quantized_fullyconnected and requantize, set min/max_cablib_range, - // When fused quantized_fullyconnected + requantize + dequantize, set dequantize flag to true. - if (dequantize_node != nullptr) { - fc_node->attrs.dict["enable_float_output"] = "True"; - } else { - fc_node->attrs.dict["min_calib_range"] = - std::to_string(requantize_param.min_calib_range.value()); - fc_node->attrs.dict["max_calib_range"] = - std::to_string(requantize_param.max_calib_range.value()); - } - fc_node->op()->attr_parser(&(fc_node->attrs)); - return fc_node; - } - - SubgraphSelectorV2Ptr CreateSubgraphSelectorV2() const override { - auto selector = - std::make_shared(disable_fuse_all, disable_float_output); - return selector; - } - - void ConnectSubgraphOutputs(const nnvm::ObjectPtr n, - std::vector* output_entries) const override { - for (size_t i = 0; i < output_entries->size(); ++i) { - auto entry_ptr = output_entries->at(i); - *entry_ptr = nnvm::NodeEntry{n, entry_ptr->index, 0}; - } - } - - private: - bool disable_fuse_all; - bool disable_float_output; -}; - -} // namespace op -} // namespace mxnet - -#endif // if MXNET_USE_ONEDNN == 1 -#endif // MXNET_OPERATOR_SUBGRAPH_DNNL_DNNL_FC_POST_QUANTIZE_PROPERTY_H_ diff --git a/src/operator/subgraph/dnnl/dnnl_matmul_post_quantize_property.h b/src/operator/subgraph/dnnl/dnnl_matmul_post_quantize_property.h deleted file mode 100644 index 6c384a18f703..000000000000 --- a/src/operator/subgraph/dnnl/dnnl_matmul_post_quantize_property.h +++ /dev/null @@ -1,202 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -#ifndef MXNET_OPERATOR_SUBGRAPH_DNNL_DNNL_MATMUL_POST_QUANTIZE_PROPERTY_H_ -#define MXNET_OPERATOR_SUBGRAPH_DNNL_DNNL_MATMUL_POST_QUANTIZE_PROPERTY_H_ -#if MXNET_USE_ONEDNN == 1 - -#include -#include - -#include "../../quantization/requantize-inl.h" -#include "../common.h" -#include "dnnl_subgraph_base-inl.h" - -namespace mxnet { -namespace op { - -class SgDNNLMatmulPostQuantizeSelector : public SubgraphSelector { - public: - /*! \brief pattern match status */ - enum SelectStatus { - kFail = 0, - kStart, - kRequantize, - kSuccess, - }; - - private: - bool disable_all; - bool disable_float_output; - SelectStatus status; - std::vector matched_list; - - public: - explicit SgDNNLMatmulPostQuantizeSelector(const bool dis_all, const bool dis_float_output) - : disable_all(dis_all), disable_float_output(dis_float_output) {} - - bool Select(const nnvm::Node& n) override { - if ((!disable_all) && (n.op() == Op::Get("_sg_onednn_selfatt_qk") || - n.op() == Op::Get("_sg_onednn_selfatt_valatt") || - n.op() == Op::Get("_sg_onednn_batch_dot"))) { - status = disable_all ? kSuccess : kStart; - matched_list.clear(); - matched_list.push_back(&n); - return true; - } - return false; - } - - bool SelectInput(const nnvm::Node& n, const nnvm::Node& new_node) override { - return false; - } - - bool SelectOutput(const nnvm::Node& n, const nnvm::Node& new_node) override { - if (status == kFail || status == kSuccess || new_node.is_variable()) - return false; - // If n isn't the last matched node, then we encoutered a internal - // branch, we should pop out the node behind n and stop fusion. - if (matched_list.back() != &n) { - if (std::find(matched_list.begin(), matched_list.end(), &n) != matched_list.end()) { - while (matched_list.back() != &n) { - matched_list.pop_back(); - } - } - - status = kSuccess; - return false; - } - - switch (status) { - case kStart: - if (new_node.op() == Op::Get("_contrib_requantize")) { - auto const& param = nnvm::get(new_node.attrs.parsed); - if (param.min_calib_range.has_value() && param.max_calib_range.has_value()) { - matched_list.push_back(&new_node); - status = kRequantize; - return true; - } - } - case kRequantize: - if ((!disable_float_output) && (new_node.op() == Op::Get("_contrib_dequantize"))) { - matched_list.push_back(&new_node); - status = kSuccess; - return true; - } - default: - status = kSuccess; - return false; - } - } - - std::vector Filter(const std::vector& candidates) override { - if ((status != kSuccess) || (matched_list.size() <= 1)) { - return std::vector(0); - } else { - std::vector ret; - for (auto i : matched_list) { - auto non_const_i = const_cast(i); - if (std::find(candidates.begin(), candidates.end(), non_const_i) != candidates.end()) { - ret.push_back(non_const_i); - } - } - return ret; - } - } - - void Reset() override { - CHECK_GE(matched_list.size(), 1); - auto new_selector = SgDNNLMatmulPostQuantizeSelector(disable_all, disable_float_output); - new_selector.Select(*matched_list[0]); - *this = new_selector; - } -}; - -class SgDNNLMatmulPostQuantizeProperty : public SubgraphProperty { - public: - SgDNNLMatmulPostQuantizeProperty() { - disable_fuse_all = dmlc::GetEnv("MXNET_DISABLE_DNNL_QMATMUL_FUSE_ALL", false); - disable_float_output = dmlc::GetEnv("MXNET_DISABLE_DNNL_QMATMUL_FLOAT_OUTPUT", false); - } - - static SubgraphPropertyPtr Create() { - static const std::string& name = "oneDNN Matmul post-quantization optimization pass"; - auto property = std::make_shared(); - property->SetAttr("property_name", name); - property->SetAttr("inference_only", true); - return property; - } - - nnvm::ObjectPtr CreateSubgraphNode(const nnvm::Symbol& sym, - const int subgraph_id = 0) const override { - nnvm::ObjectPtr interleaved_node = nullptr; - nnvm::ObjectPtr requantize_node = nullptr; - nnvm::ObjectPtr dequantize_node = nullptr; - - DFSVisit(sym.outputs, [&](const nnvm::ObjectPtr& node) { - if (node->is_variable()) - return; - if (node->op() == Op::Get("_sg_onednn_selfatt_qk") || - node->op() == Op::Get("_sg_onednn_selfatt_valatt") || - node->op() == Op::Get("_sg_onednn_batch_dot")) { - interleaved_node = node; - } else if (node->op() == Op::Get("_contrib_requantize")) { - requantize_node = node; - } else if (node->op() == Op::Get("_contrib_dequantize")) { - dequantize_node = node; - } - }); - - CHECK_NOTNULL(interleaved_node); - CHECK_NOTNULL(requantize_node); - auto const& requantize_param = nnvm::get(requantize_node->attrs.parsed); - CHECK(requantize_param.min_calib_range.has_value()); - CHECK(requantize_param.max_calib_range.has_value()); - - // When only fusing quantized_interleaved_matmul and requantize, set min/max_cablib_range, - // When fusing quantized_interleaved_matmul + requantize + dequantize, - // set dequantize flag to true. - if (dequantize_node != nullptr) { - interleaved_node->attrs.dict["enable_float_output"] = "True"; - } else { - interleaved_node->attrs.dict["min_calib_range"] = - std::to_string(requantize_param.min_calib_range.value()); - interleaved_node->attrs.dict["max_calib_range"] = - std::to_string(requantize_param.max_calib_range.value()); - } - interleaved_node->op()->attr_parser(&(interleaved_node->attrs)); - return interleaved_node; - } - - SubgraphSelectorPtr CreateSubgraphSelector() const override { - auto selector = - std::make_shared(disable_fuse_all, disable_float_output); - return selector; - } - - private: - bool disable_fuse_all; - bool disable_float_output; -}; - -} // namespace op -} // namespace mxnet - -#endif // if MXNET_USE_ONEDNN == 1 -#endif // MXNET_OPERATOR_SUBGRAPH_DNNL_DNNL_MATMUL_POST_QUANTIZE_PROPERTY_H_ diff --git a/src/operator/subgraph/dnnl/dnnl_post_quantize_property.h b/src/operator/subgraph/dnnl/dnnl_post_quantize_property.h index 662b792d737d..cddf4b447810 100644 --- a/src/operator/subgraph/dnnl/dnnl_post_quantize_property.h +++ b/src/operator/subgraph/dnnl/dnnl_post_quantize_property.h @@ -20,110 +20,161 @@ #define MXNET_OPERATOR_SUBGRAPH_DNNL_DNNL_POST_QUANTIZE_PROPERTY_H_ #if MXNET_USE_ONEDNN == 1 +#include #include #include #include -#include "../../nn/dnnl/dnnl_convolution-inl.h" -#include "../../quantization/requantize-inl.h" -#include "../common.h" +#include "operator/nn/dnnl/dnnl_convolution-inl.h" +#include "operator/nn/fully_connected-inl.h" +#include "operator/quantization/requantize-inl.h" +#include "operator/tensor/elemwise_binary_op-inl.h" +#include "operator/subgraph/common.h" #include "dnnl_conv-inl.h" #include "dnnl_subgraph_base-inl.h" namespace mxnet { namespace op { - -class SgDNNLPostQuantizeSelector : public SubgraphSelector { - public: +namespace { +const std::set support_req_fusion_op = {"_contrib_quantized_elemwise_add", + "_contrib_quantized_elemwise_mul", + "_contrib_quantized_npi_add", + "_sg_onednn_conv", + "_sg_onednn_fully_connected", + "_sg_onednn_selfatt_qk", + "_sg_onednn_selfatt_valatt", + "_sg_onednn_batch_dot"}; +} // namespace + +class SgDNNLPostQuantizeSelector : public SubgraphSelectorV2 { + private: /*! \brief pattern match status */ - enum SelectStatus { + enum class SelectStatus { kFail = 0, kStart, + kRequantize, kSuccess, }; - private: + bool fuse_all; + bool float_output; SelectStatus status; - std::vector matched_list; + std::vector matched_list; std::set support_requantize_fusion_op_name; public: - SgDNNLPostQuantizeSelector() { - support_requantize_fusion_op_name.insert("_sg_onednn_conv"); - support_requantize_fusion_op_name.insert("_contrib_quantized_elemwise_add"); - support_requantize_fusion_op_name.insert("_contrib_quantized_npi_add"); + explicit SgDNNLPostQuantizeSelector(const bool fuse_all, const bool float_output) + : fuse_all(fuse_all), float_output(float_output) { + support_requantize_fusion_op_name = support_req_fusion_op; } - bool Select(const nnvm::Node& n) override { - if (n.op() && support_requantize_fusion_op_name.count(n.op()->name)) { - if (n.op() == Op::Get("_sg_onednn_conv")) { - auto const& param = nnvm::get(n.attrs.parsed); - if (param.full_conv_param.dnnl_param.quantized) { - status = kStart; - matched_list.clear(); - matched_list.push_back(&n); - return true; - } - } else if (n.op()->name == "_contrib_quantized_elemwise_add" || - n.op()->name == "_contrib_quantized_npi_add") { - status = kStart; - matched_list.clear(); - matched_list.push_back(&n); - return true; - } + bool Select(const BiDirectedNode& n) override { + const nnvm::Node* raw_node = n.node; + if (fuse_all && raw_node->op() && + support_requantize_fusion_op_name.count(raw_node->op()->name)) { + status = SelectStatus::kStart; + matched_list.clear(); + matched_list.emplace_back(&n); + return true; } return false; } - bool SelectInput(const nnvm::Node& n, const nnvm::Node& new_node) override { + bool SelectInput(const BiDirectedNode& n, const BiDirectedNode& new_node) override { return false; } - bool SelectOutput(const nnvm::Node& n, const nnvm::Node& new_node) override { - if (status == kFail || status == kSuccess || new_node.is_variable()) + bool SelectOutput(const BiDirectedNode& n, const BiDirectedNode& new_node) override { + const nnvm::Node* raw_node = n.node; + const nnvm::Node* raw_new_node = new_node.node; + if (status == SelectStatus::kFail || status == SelectStatus::kSuccess || + raw_new_node->is_variable()) return false; // If n isn't the last matched node, then we encoutered a internal // branch, we should pop out the node behind n and stop fusion. if (matched_list.back() != &n) { - status = kFail; + if (std::find(matched_list.begin(), matched_list.end(), &n) != matched_list.end()) { + while (matched_list.back() != &n) { + matched_list.pop_back(); + } + } + status = SelectStatus::kSuccess; return false; } - if (new_node.op()->name == "_contrib_requantize") { - auto const& param = nnvm::get(new_node.attrs.parsed); - if (param.min_calib_range.has_value() && param.max_calib_range.has_value()) { - matched_list.push_back(&new_node); - status = kSuccess; - return true; - } else { - status = kFail; - } + + switch (status) { + case SelectStatus::kStart: + if (raw_new_node->op() == Op::Get("_contrib_requantize")) { + auto const& param = nnvm::get(raw_new_node->attrs.parsed); + if (param.min_calib_range.has_value() && param.max_calib_range.has_value()) { + matched_list.emplace_back(&new_node); + status = SelectStatus::kRequantize; + if (raw_node->op() == Op::Get("_sg_onednn_conv")) { + status = SelectStatus::kSuccess; + } + return true; + } + } + case SelectStatus::kRequantize: + if (float_output && raw_new_node->op() == Op::Get("_contrib_dequantize")) { + CHECK(raw_node->op() == Op::Get("_contrib_requantize")); + if (n.outputs.size() > 1) { + // check if requantize have other outputs than dequantize + // if it has we can't fuse dequantize + for (const auto& kv : n.outputs) { + const auto& node = kv.first; + if (node->op() != Op::Get("_contrib_dequantize")) { + status = SelectStatus::kSuccess; + return false; + } + } + } + matched_list.emplace_back(&new_node); + status = SelectStatus::kSuccess; + return true; + } + default: + status = SelectStatus::kSuccess; + return false; } - return false; } - std::vector Filter(const std::vector& candidates) override { - if (status != kSuccess) { - return std::vector(0); + std::vector Filter(const std::vector& candidates) override { + if (status != SelectStatus::kSuccess || (matched_list.size() <= 1)) { + return std::vector(0); } else { - return candidates; + std::vector ret; + for (auto i : matched_list) { + auto non_const_i = const_cast(i); + if (std::find(candidates.begin(), candidates.end(), non_const_i) != candidates.end()) { + ret.push_back(non_const_i); + } + } + return ret; } } void Reset() override { CHECK_GE(matched_list.size(), 1); - auto new_selector = SgDNNLPostQuantizeSelector(); + auto new_selector = SgDNNLPostQuantizeSelector(fuse_all, float_output); new_selector.Select(*matched_list[0]); *this = new_selector; } }; class SgDNNLPostQuantizeProperty : public SubgraphProperty { + private: + bool fuse_all; + bool float_output; + std::set support_requantize_fusion_op_name; + public: SgDNNLPostQuantizeProperty() { - support_requantize_fusion_op_name.insert("_sg_onednn_conv"); - support_requantize_fusion_op_name.insert("_contrib_quantized_elemwise_add"); - support_requantize_fusion_op_name.insert("_contrib_quantized_npi_add"); + fuse_all = dmlc::GetEnv("MXNET_ONEDNN_FUSE_REQUANTIZE", true); + float_output = dmlc::GetEnv("MXNET_ONEDNN_FUSE_DEQUANTIZE", true); + support_requantize_fusion_op_name = support_req_fusion_op; } + static SubgraphPropertyPtr Create() { static const std::string& name = "oneDNN post-quantization optimization pass"; auto property = std::make_shared(); @@ -131,35 +182,47 @@ class SgDNNLPostQuantizeProperty : public SubgraphProperty { property->SetAttr("inference_only", true); return property; } + nnvm::ObjectPtr CreateSubgraphNode(const nnvm::Symbol& sym, const int subgraph_id = 0) const override { nnvm::ObjectPtr fuse_node = nullptr; nnvm::ObjectPtr requantize_node = nullptr; + nnvm::ObjectPtr dequantize_node = nullptr; + DFSVisit(sym.outputs, [&](const nnvm::ObjectPtr& node) { if (node->is_variable()) return; - auto& op_name = node->op()->name; - if (support_requantize_fusion_op_name.count(op_name)) { + if (node->op() && support_requantize_fusion_op_name.count(node->op()->name)) { fuse_node = node; - } else if (op_name == "_contrib_requantize") { + } else if (node->op() == Op::Get("_contrib_requantize")) { requantize_node = node; + } else if (node->op() == Op::Get("_contrib_dequantize")) { + dequantize_node = node; } }); + CHECK_NOTNULL(fuse_node); CHECK_NOTNULL(requantize_node); auto const& requantize_param = nnvm::get(requantize_node->attrs.parsed); CHECK(requantize_param.min_calib_range.has_value()); CHECK(requantize_param.max_calib_range.has_value()); - fuse_node->attrs.dict["min_calib_range"] = - std::to_string(requantize_param.min_calib_range.value()); - fuse_node->attrs.dict["max_calib_range"] = - std::to_string(requantize_param.max_calib_range.value()); + + // When only fused quantized operator and requantize, set min/max_cablib_range, + // When fused quantized operator + requantize + dequantize, set dequantize flag to true. + if (dequantize_node != nullptr) { + fuse_node->attrs.dict["enable_float_output"] = "True"; + } else { + fuse_node->attrs.dict["min_calib_range"] = + std::to_string(requantize_param.min_calib_range.value()); + fuse_node->attrs.dict["max_calib_range"] = + std::to_string(requantize_param.max_calib_range.value()); + } fuse_node->op()->attr_parser(&(fuse_node->attrs)); return fuse_node; } - SubgraphSelectorPtr CreateSubgraphSelector() const override { - auto selector = std::make_shared(); + SubgraphSelectorV2Ptr CreateSubgraphSelectorV2() const override { + auto selector = std::make_shared(fuse_all, float_output); return selector; } @@ -170,10 +233,8 @@ class SgDNNLPostQuantizeProperty : public SubgraphProperty { *entry_ptr = nnvm::NodeEntry{n, entry_ptr->index, 0}; } } - - private: - std::set support_requantize_fusion_op_name; }; + } // namespace op } // namespace mxnet diff --git a/src/operator/subgraph/dnnl/dnnl_subgraph_property.cc b/src/operator/subgraph/dnnl/dnnl_subgraph_property.cc index 4a5f6a6d129f..9727187ab9fd 100644 --- a/src/operator/subgraph/dnnl/dnnl_subgraph_property.cc +++ b/src/operator/subgraph/dnnl/dnnl_subgraph_property.cc @@ -22,10 +22,7 @@ #include "dnnl_batch_dot_property.h" #include "dnnl_bn_relu_property.h" #include "dnnl_conv_property.h" -#include "dnnl_elemwisemul_post_quantize_property.h" -#include "dnnl_fc_post_quantize_property.h" #include "dnnl_fc_property.h" -#include "dnnl_matmul_post_quantize_property.h" #include "dnnl_post_quantize_align_scale_property.h" #include "dnnl_post_quantize_property.h" #include "dnnl_transformer_qk_property.h" @@ -54,11 +51,7 @@ MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN_QUANTIZE, SgDNNLTransformerValAttPropert MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN_QUANTIZE, SgDNNLBatchDotProperty) .set_attr("quantize", true); MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN_QUANTIZE, SgDNNLPostQuantizeProperty); -MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN_QUANTIZE, SgDNNLFCPostQuantizeProperty); -MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN_QUANTIZE, ElemwiseMulPostQuantizeProperty); MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN_QUANTIZE, SgDNNLPostQuantizeAlignScaleProperty); -MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN_QUANTIZE, SgDNNLMatmulPostQuantizeProperty) - .set_attr("quantize", true); } // namespace op } // namespace mxnet