diff --git a/src/operator/regression_output.cc b/src/operator/regression_output.cc index 7b0fbae3bccb..0b8ce69062bd 100644 --- a/src/operator/regression_output.cc +++ b/src/operator/regression_output.cc @@ -23,6 +23,8 @@ */ #include "./regression_output-inl.h" +#include "./elemwise_op_common.h" + #define MXNET_OPERATOR_REGISTER_REGRESSION_FWD(__name$, __kernel$, __bwdop$) \ NNVM_REGISTER_OP(__name$) \ @@ -33,6 +35,7 @@ return std::vector{"data", "label"}; \ }) \ .set_attr("FInferShape", RegressionOpShape) \ + .set_attr("FInferType", ElemwiseType<2, 1>) \ .set_attr("FGradient", RegressionOpGrad{__bwdop$}) \ .set_attr("FInplaceOption", \ [](const NodeAttrs& attrs){ \ @@ -48,6 +51,7 @@ .set_num_inputs(2) \ .set_num_outputs(2) \ .set_attr_parser(ParamParser) \ + .set_attr("FInferType", ElemwiseType<2, 2>) \ .set_attr("TIsBackward", true) \ .set_attr("FInplaceOption", \ [](const NodeAttrs& attrs){ \