diff --git a/src/operator/subgraph/mkldnn/mkldnn_fc-inl.h b/src/operator/subgraph/mkldnn/mkldnn_fc-inl.h index 9a09d91ae5d0..438495870166 100644 --- a/src/operator/subgraph/mkldnn/mkldnn_fc-inl.h +++ b/src/operator/subgraph/mkldnn/mkldnn_fc-inl.h @@ -36,7 +36,8 @@ static inline bool SupportMKLDNNFCEltwiseFusion(const std::string op_name) { op_name == "sqrt" || op_name == "exp" || op_name == "abs" || - op_name == "clip") { + op_name == "clip" || + op_name == "LeakyReLU") { return true; } else { return false; diff --git a/src/operator/subgraph/mkldnn/mkldnn_fc.cc b/src/operator/subgraph/mkldnn/mkldnn_fc.cc index e2b1807b6559..dbaffc379ec3 100644 --- a/src/operator/subgraph/mkldnn/mkldnn_fc.cc +++ b/src/operator/subgraph/mkldnn/mkldnn_fc.cc @@ -286,8 +286,16 @@ void SgMKLDNNFCOp::Forward(const OpContext &ctx, if (fuse_requantize || mkldnn_param.enable_float_output) { float tmp_scale_ = 1.0f; if (fuse_requantize) { - tmp_scale_ = - GetQuantizeScale(output.dtype(), cached_min_output_, cached_max_output_) / data_scale_; + if (mkldnn_param.with_eltwise) { + tmp_scale_ = 1.0 / data_scale_; + full_param_.eltwise_param.scale = + GetQuantizeScale(output.dtype(), cached_min_output_, cached_max_output_); + } else { + tmp_scale_ = + GetQuantizeScale(output.dtype(), + cached_min_output_, + cached_max_output_) / data_scale_; + } } else { tmp_scale_ = 1.0 / data_scale_; } @@ -405,6 +413,10 @@ static void SgMKLDNNFCParamParser(nnvm::NodeAttrs *attrs) { if (op_name == "Activation") { const ActivationParam act_param = nnvm::get<ActivationParam>(node->attrs.parsed); full_param.eltwise_param.alg = GetMKLDNNActAlgo(act_param); + } else if (op_name == "LeakyReLU") { + const auto act_param = nnvm::get<LeakyReLUParam>(node->attrs.parsed); + full_param.eltwise_param.alpha = act_param.slope; + full_param.eltwise_param.alg = GetMKLDNNActAlgo(act_param); } else if (op_name == "clip") { const ClipParam clip_param = nnvm::get<ClipParam>(node->attrs.parsed); full_param.eltwise_param.alg = mkldnn::algorithm::eltwise_bounded_relu; diff --git 
a/src/operator/subgraph/mkldnn/mkldnn_fc_property.h b/src/operator/subgraph/mkldnn/mkldnn_fc_property.h index aecb3a7a8477..432772d36298 100644 --- a/src/operator/subgraph/mkldnn/mkldnn_fc_property.h +++ b/src/operator/subgraph/mkldnn/mkldnn_fc_property.h @@ -102,6 +102,16 @@ class SgMKLDNNFCSelector : public SubgraphSelector { return true; } } + if (new_node.op() == Op::Get("LeakyReLU")) { + const LeakyReLUParam &param = nnvm::get<LeakyReLUParam>(new_node.attrs.parsed); + if (param.act_type == leakyrelu::kLeakyReLU || + param.act_type == leakyrelu::kELU || + param.act_type == leakyrelu::kGELU) { + matched_list_.push_back(&new_node); + status_ = kSuccess; + return true; + } + } if (!quantized_ && (new_node.op() == Op::Get("square") || new_node.op() == Op::Get("sqrt") || new_node.op() == Op::Get("exp"))) { diff --git a/tests/python/mkl/test_subgraph.py b/tests/python/mkl/test_subgraph.py index 65b73e438ea6..61f738f9c7d3 100644 --- a/tests/python/mkl/test_subgraph.py +++ b/tests/python/mkl/test_subgraph.py @@ -49,7 +49,7 @@ } DATA_SHAPE=[(64, 4, 10, 10), (4, 3, 24, 24), (1, 16, 32, 32)] -fc_post_ops_list=['relu', 'sigmoid', 'tanh', 'softrelu', +fc_post_ops_list=['relu', 'sigmoid', 'tanh', 'softrelu', 'gelu', 'square', 'square_root', 'abs', 'exp', 'bounded_relu'] def check_qsym_calibrated(qsym, out_type, name='conv'): @@ -664,6 +664,8 @@ def fc_eltwise(no_bias, data_shape, flatten=True, alg='relu'): no_bias=no_bias, flatten=flatten) if alg in ['relu', 'sigmoid', 'tanh', 'softrelu']: sym = mx.symbol.Activation(data=fc, name='act', act_type=alg) + elif alg == "gelu": + sym = mx.symbol.LeakyReLU(data=fc, act_type='gelu')  elif alg == 'square': sym = mx.symbol.square(data=fc, name='square') elif alg == 'square_root':