From 00ae900530244a3d14b1e4af2e4d4af659ef34d1 Mon Sep 17 00:00:00 2001 From: ZiyueHuang Date: Wed, 21 Feb 2018 13:25:19 +0000 Subject: [PATCH 1/2] add infer_type for regression ops --- src/operator/regression_output.cc | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/operator/regression_output.cc b/src/operator/regression_output.cc index 7b0fbae3bccb..c212f05f2572 100644 --- a/src/operator/regression_output.cc +++ b/src/operator/regression_output.cc @@ -23,6 +23,7 @@ */ #include "./regression_output-inl.h" +#include "./elemwise_op_common.h" #define MXNET_OPERATOR_REGISTER_REGRESSION_FWD(__name$, __kernel$, __bwdop$) \ NNVM_REGISTER_OP(__name$) \ @@ -33,6 +34,7 @@ return std::vector{"data", "label"}; \ }) \ .set_attr("FInferShape", RegressionOpShape) \ + .set_attr("FInferType", ElemwiseType<2, 1>) \ .set_attr("FGradient", RegressionOpGrad{__bwdop$}) \ .set_attr("FInplaceOption", \ [](const NodeAttrs& attrs){ \ @@ -48,6 +50,7 @@ .set_num_inputs(2) \ .set_num_outputs(2) \ .set_attr_parser(ParamParser) \ + .set_attr("FInferType", ElemwiseType<2, 2>) \ .set_attr("TIsBackward", true) \ .set_attr("FInplaceOption", \ [](const NodeAttrs& attrs){ \ From 02a365c0126a745e1b97a8f7f02da2ab3c8672a4 Mon Sep 17 00:00:00 2001 From: ZiyueHuang Date: Fri, 23 Feb 2018 11:29:31 +0000 Subject: [PATCH 2/2] trigger CI --- src/operator/regression_output.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/src/operator/regression_output.cc b/src/operator/regression_output.cc index c212f05f2572..0b8ce69062bd 100644 --- a/src/operator/regression_output.cc +++ b/src/operator/regression_output.cc @@ -25,6 +25,7 @@ #include "./regression_output-inl.h" #include "./elemwise_op_common.h" + #define MXNET_OPERATOR_REGISTER_REGRESSION_FWD(__name$, __kernel$, __bwdop$) \ NNVM_REGISTER_OP(__name$) \ .set_num_inputs(2) \