From 68ed790ab97adbb87e91710cba85b8364ea2ef9c Mon Sep 17 00:00:00 2001
From: ZiyueHuang
Date: Tue, 23 Jan 2018 16:14:49 +0000
Subject: [PATCH 1/7] refactor regression ops

---
 src/operator/regression_output-inl.h | 225 ++++++++++-----------------
 src/operator/regression_output.cc    | 109 +++++++++----
 src/operator/regression_output.cu    |  41 ++---
 3 files changed, 185 insertions(+), 190 deletions(-)

diff --git a/src/operator/regression_output-inl.h b/src/operator/regression_output-inl.h
index 08b2f0a4a813..672e0aa269d7 100644
--- a/src/operator/regression_output-inl.h
+++ b/src/operator/regression_output-inl.h
@@ -18,28 +18,29 @@
  */

 /*!
- * Copyright (c) 2015 by Contributors
  * \file regression_output-inl.h
  * \brief Regression output operator.
- */
+*/
 #ifndef MXNET_OPERATOR_REGRESSION_OUTPUT_INL_H_
 #define MXNET_OPERATOR_REGRESSION_OUTPUT_INL_H_

-#include
-#include
-#include
-#include
+#include
 #include
 #include
+#include "./tensor/init_op.h"
+#include "./mshadow_op.h"
+#include "./mxnet_op.h"
 #include "./operator_common.h"

 namespace mxnet {
 namespace op {

+/*!
+ * \brief regression namespace
+ */
 namespace reg_enum {
 enum RegressionOutputOpInputs {kData, kLabel};
 enum RegressionOutputOutputs {kOut};
-enum RegressionOutputType {kLinear, kLogistic, kMAE};
 }  // reg_enum

 struct RegressionOutputParam : public dmlc::Parameter<RegressionOutputParam> {
@@ -50,146 +51,86 @@ struct RegressionOutputParam : public dmlc::Parameter<RegressionOutputParam> {
   };
 };

-// Special Operator to output regression value in forward
-// And get gradient in calculation.
-template<typename xpu, typename ForwardOp, typename BackwardOp>
-class RegressionOutputOp : public Operator {
- public:
-  explicit RegressionOutputOp(RegressionOutputParam param) : param_(param) {}
-
-  virtual void Forward(const OpContext &ctx,
-                       const std::vector<TBlob> &in_data,
-                       const std::vector<OpReqType> &req,
-                       const std::vector<TBlob> &out_data,
-                       const std::vector<TBlob> &aux_args) {
-    using namespace mshadow;
-    using namespace mshadow::expr;
-    CHECK_EQ(in_data.size(), 2U) << "RegressionOutputOp Input: [data, label]";
-    CHECK_EQ(out_data.size(), 1U) << "RegressionOutputOp Output: [output]";
-    Stream<xpu> *s = ctx.get_stream<xpu>();
-    Tensor<xpu, 2> data = in_data[reg_enum::kData].FlatTo2D<xpu, real_t>(s);
-    Tensor<xpu, 2> out = out_data[reg_enum::kOut].FlatTo2D<xpu, real_t>(s);
-    Assign(out, req[reg_enum::kOut], F<ForwardOp>(data));
-  }
-
-  virtual void Backward(const OpContext &ctx,
-                        const std::vector<TBlob> &out_grad,
-                        const std::vector<TBlob> &in_data,
-                        const std::vector<TBlob> &out_data,
-                        const std::vector<OpReqType> &req,
-                        const std::vector<TBlob> &in_grad,
-                        const std::vector<TBlob> &aux_args) {
-    using namespace mshadow;
-    using namespace mshadow::expr;
-    CHECK_EQ(in_data.size(), 2U);
-    CHECK_EQ(out_grad.size(), 1U);
-    CHECK_GE(in_grad.size(), 1U);
-    CHECK_GE(req.size(), 1U);
-    Stream<xpu> *s = ctx.get_stream<xpu>();
-    real_t num_output =
-      in_data[reg_enum::kLabel].Size()/in_data[reg_enum::kLabel].shape_[0];
-    Tensor<xpu, 2> out = out_data[reg_enum::kOut].FlatTo2D<xpu, real_t>(s);
-    Tensor<xpu, 2> grad = in_grad[reg_enum::kData].FlatTo2D<xpu, real_t>(s);
-    Tensor<xpu, 2> label = in_data[reg_enum::kLabel]
-      .get_with_shape<xpu, 2, real_t>(out.shape_, s);
-    Assign(grad, req[reg_enum::kData], param_.grad_scale/num_output*
-           F<BackwardOp>(out, reshape(label, grad.shape_)));
-  }
-
- private:
-  RegressionOutputParam param_;
-};
-
-// Declare factory function, used for dispatch specialization
-template<typename xpu>
-Operator* CreateRegressionOutputOp(reg_enum::RegressionOutputType type,
-                                   RegressionOutputParam param);
-
-#if DMLC_USE_CXX11
-template<reg_enum::RegressionOutputType type>
-class RegressionOutputProp : public OperatorProperty {
- public:
-  std::vector<std::string> ListArguments() const override {
-    return {"data", "label"};
-  }
-
-  void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) override {
-    param_.Init(kwargs);
-  }
-
-  std::map<std::string, std::string> GetParams() const override {
-    return param_.__DICT__();
-  }
-
-  bool InferShape(std::vector<TShape> *in_shape,
-                  std::vector<TShape> *out_shape,
-                  std::vector<TShape> *aux_shape) const override {
-    using namespace mshadow;
-    CHECK_EQ(in_shape->size(), 2) << "Input:[data, label]";
-    const TShape &dshape = in_shape->at(0);
-    if (dshape.ndim() == 0) return false;
-    auto &lshape = (*in_shape)[1];
-    if (lshape.ndim() == 0) {
-      // special treatment for 1D output, to allow 1D label by default.
-      // Think about change convention later
-      if (dshape.ndim() == 2 && dshape[1] == 1) {
-        lshape = Shape1(dshape[0]);
-      } else {
-        lshape = dshape;
-      }
-    } else if (lshape[0] != dshape[0] || lshape.Size() != dshape.Size()) {
-      std::ostringstream os;
-      os << "Shape inconsistent, Provided=" << lshape << ','
-         << " inferred shape=" << dshape;
-      throw ::mxnet::op::InferShapeError(os.str(), 1);
-    }
-    out_shape->clear();
-    out_shape->push_back(dshape);
-    return true;
-  }
-
-  OperatorProperty* Copy() const override {
-    auto ptr = new RegressionOutputProp<type>();
-    ptr->param_ = param_;
-    return ptr;
-  }
-
-  std::string TypeString() const override {
-    switch (type) {
-      case reg_enum::kLinear: return "LinearRegressionOutput";
-      case reg_enum::kLogistic: return "LogisticRegressionOutput";
-      case reg_enum::kMAE: return "MAERegressionOutput";
-      default: LOG(FATAL) << "unknown type"; return "";
+inline bool RegressionOpShape(const nnvm::NodeAttrs& attrs,
+                              std::vector<TShape> *in_attrs,
+                              std::vector<TShape> *out_attrs) {
+  using namespace mshadow;
+  CHECK_EQ(in_attrs->size(), 2U) << "Input:[data, label]";
+  const TShape &dshape = in_attrs->at(0);
+  if (dshape.ndim() == 0) return false;
+  auto &lshape = (*in_attrs)[1];
+  if (lshape.ndim() == 0) {
+    // special treatment for 1D output, to allow 1D label by default.
+    // Think about change convention later
+    if (dshape.ndim() == 2 && dshape[1] == 1) {
+      lshape = Shape1(dshape[0]);
+    } else {
+      lshape = dshape;
     }
+  } else if (lshape[0] != dshape[0] || lshape.Size() != dshape.Size()) {
+    std::ostringstream os;
+    os << "Shape inconsistent, Provided=" << lshape << ','
+       << " inferred shape=" << dshape;
+    throw ::mxnet::op::InferShapeError(os.str(), 1);
   }
-
-  std::vector<int> DeclareBackwardDependency(
-    const std::vector<int> &out_grad,
-    const std::vector<int> &in_data,
-    const std::vector<int> &out_data) const override {
-    return {in_data[reg_enum::kLabel], out_data[reg_enum::kOut]};
-  }
-
-  std::vector<std::pair<int, void*> > BackwardInplaceOption(
-    const std::vector<int> &out_grad,
-    const std::vector<int> &in_data,
-    const std::vector<int> &out_data,
-    const std::vector<void*> &in_grad) const override {
-    return {{out_data[reg_enum::kOut], in_grad[reg_enum::kData]}};
-  }
-
-  std::vector<std::pair<int, void*> > ForwardInplaceOption(
-    const std::vector<int> &in_data,
-    const std::vector<void*> &out_data) const override {
-    return {{in_data[reg_enum::kData], out_data[reg_enum::kOut]}};
+  out_attrs->clear();
+  out_attrs->push_back(dshape);
+  return true;
+}
+
+template<typename xpu, typename ForwardOp>
+void RegressionForward(const nnvm::NodeAttrs& attrs,
+                       const OpContext& ctx,
+                       const std::vector<TBlob>& inputs,
+                       const std::vector<OpReqType>& req,
+                       const std::vector<TBlob>& outputs) {
+  mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
+  MSHADOW_REAL_TYPE_SWITCH(inputs[reg_enum::kData].type_flag_, DType, {
+    MXNET_ASSIGN_REQ_SWITCH(req[reg_enum::kOut], Req, {
+      DType* in_data = inputs[reg_enum::kData].dptr<DType>();
+      DType* out_data = outputs[reg_enum::kOut].dptr<DType>();
+      using namespace mxnet_op;
+      Kernel<op_with_req<ForwardOp, Req>, xpu>::Launch(
+        s, outputs[reg_enum::kOut].Size(), out_data, in_data);
+    });
+  });
+}
+
+template<typename xpu, typename BackwardOp>
+void RegressionBackward(const nnvm::NodeAttrs& attrs,
+                        const OpContext& ctx,
+                        const std::vector<TBlob>& inputs,
+                        const std::vector<OpReqType>& req,
+                        const std::vector<TBlob>& outputs) {
+  const RegressionOutputParam& param = nnvm::get<RegressionOutputParam>(attrs.parsed);
+  mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
+  // inputs are in_label, out_data
+  // outputs are data_grad, label_grad
+  MSHADOW_REAL_TYPE_SWITCH(inputs[1].type_flag_, DType, {
+    MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
+      DType* in_label = inputs[0].dptr<DType>();
+      DType* out_data = inputs[1].dptr<DType>();
+      DType* data_grad = outputs[0].dptr<DType>();
+      using namespace mxnet_op;
+      Kernel<op_with_req<BackwardOp, Req>, xpu>::Launch(
+        s, outputs[0].Size(), data_grad, out_data, in_label);
+    });
+  });
+}
+
+struct RegressionOpGrad {
+  const char *op_name;
+  std::vector<nnvm::NodeEntry> operator()(const nnvm::NodePtr& n,
+                                          const std::vector<nnvm::NodeEntry>& ograds) const {
+    std::vector<nnvm::NodeEntry> heads;
+    heads.push_back(n->inputs[reg_enum::kLabel]);
+    heads.emplace_back(nnvm::NodeEntry{n, reg_enum::kOut, 0});
+    return MakeGradNode(op_name, n, heads, n->attrs.dict);
   }
+};

-  Operator* CreateOperator(Context ctx) const override;
-
- protected:
-  RegressionOutputParam param_;
-};
-#endif  // DMLC_USE_CXX11
 }  // namespace op
 }  // namespace mxnet
+
 #endif  // MXNET_OPERATOR_REGRESSION_OUTPUT_INL_H_
diff --git a/src/operator/regression_output.cc b/src/operator/regression_output.cc
index 2f8042e9e831..ccd50996e5c8 100644
--- a/src/operator/regression_output.cc
+++ b/src/operator/regression_output.cc
@@ -18,41 +18,19 @@
  */

 /*!
- * Copyright (c) 2015 by Contributors
- * \file regression_output.cc
- * \brief regression output operator
+ * \file regression_output.cc
+ * \brief Regression output operator.
  */
+
 #include "./regression_output-inl.h"
-#include "./mshadow_op.h"

 namespace mxnet {
 namespace op {

-template<>
-Operator *CreateRegressionOutputOp<cpu>(reg_enum::RegressionOutputType type,
-                                        RegressionOutputParam param) {
-  switch (type) {
-    case reg_enum::kLinear:
-      return new RegressionOutputOp<cpu, mshadow::op::identity, mshadow::op::minus>(param);
-    case reg_enum::kLogistic:
-      return new RegressionOutputOp<cpu, mshadow_op::sigmoid, mshadow::op::minus>(param);
-    case reg_enum::kMAE:
-      return new RegressionOutputOp<cpu, mshadow::op::identity, mshadow_op::minus_sign>(param);
-    default:
-      LOG(FATAL) << "unknown activation type " << type;
-  }
-  return nullptr;
-}
-
-// DO_BIND_DISPATCH comes from operator_common.h
-template<reg_enum::RegressionOutputType type>
-Operator *RegressionOutputProp<type>::CreateOperator(Context ctx) const {
-  DO_BIND_DISPATCH(CreateRegressionOutputOp, type, param_);
-}

 DMLC_REGISTER_PARAMETER(RegressionOutputParam);

-MXNET_REGISTER_OP_PROPERTY(LinearRegressionOutput, RegressionOutputProp<reg_enum::kLinear>)
+NNVM_REGISTER_OP(LinearRegressionOutput)
 .describe(R"code(Computes and optimizes for squared loss during backward propagation.
 Just outputs ``data`` during forward propagation.

 If :math:`\hat{y}_i` is the predicted value of the i-th sample, and :math:`y_i`
 is the corresponding target value, then the squared loss estimated over :math:`n` samples is defined as

 :math:`\text{SquaredLoss}(y, \hat{y} ) = \frac{1}{n} \sum_{i=0}^{n-1} \left( y_i - \hat{y}_i \right)^2`

 .. note::
    Use the LinearRegressionOutput as the final output layer of a net.

 By default, gradients of this loss function are scaled by factor `1/n`, where n is the number of training examples.
 The parameter `grad_scale` can be used to change this scale to `grad_scale/n`.
 )code" ADD_FILELINE)
+.set_num_inputs(2)
+.set_num_outputs(1)
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+  [](const NodeAttrs& attrs) {
+    return std::vector<std::string>{"data", "label"};
+  })
+.set_attr<nnvm::FInferShape>("FInferShape", RegressionOpShape)
+.set_attr<nnvm::FGradient>("FGradient", RegressionOpGrad{"_backward_linear_reg_out"})
+.set_attr<nnvm::FInplaceOption>("FInplaceOption",
+[](const NodeAttrs& attrs){
+  return std::vector<std::pair<int, int> >{{0, 0}};
+})
+.set_attr<FCompute>("FCompute", RegressionForward<cpu, mshadow_op::identity>)
 .add_argument("data", "NDArray-or-Symbol", "Input data to the function.")
 .add_argument("label", "NDArray-or-Symbol", "Input label to the function.")
 .add_arguments(RegressionOutputParam::__FIELDS__());

-MXNET_REGISTER_OP_PROPERTY(MAERegressionOutput, RegressionOutputProp<reg_enum::kMAE>)
+NNVM_REGISTER_OP(_backward_linear_reg_out)
+.set_num_inputs(2)
+.set_num_outputs(2)
+.set_attr_parser(ParamParser<RegressionOutputParam>)
+.set_attr<nnvm::TIsBackward>("TIsBackward", true)
+.set_attr<nnvm::FInplaceOption>("FInplaceOption",
+[](const NodeAttrs& attrs){  // input in_label, out_data, output data_grad, label_grad
+  return std::vector<std::pair<int, int> >{{1, 0}};
+})
+.set_attr<FCompute>("FCompute", RegressionBackward<cpu, mshadow_op::minus>);
+
+NNVM_REGISTER_OP(MAERegressionOutput)
 .describe(R"code(Computes mean absolute error of the input.

 MAE is a risk metric corresponding to the expected value of the absolute error.

 If :math:`\hat{y}_i` is the predicted value of the i-th sample, and :math:`y_i`
 is the corresponding target value, then the mean absolute error (MAE) estimated over :math:`n` samples is defined as

 :math:`\text{MAE}(y, \hat{y} ) = \frac{1}{n} \sum_{i=0}^{n-1} \left| y_i - \hat{y}_i \right|`

 .. note::
    Use the MAERegressionOutput as the final output layer of a net.

@@ -89,11 +91,36 @@ By default, gradients of this loss function are scaled by factor `1/n`, where n
 The parameter `grad_scale` can be used to change this scale to `grad_scale/n`.

 )code" ADD_FILELINE)
+.set_num_inputs(2)
+.set_num_outputs(1)
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+  [](const NodeAttrs& attrs) {
+    return std::vector<std::string>{"data", "label"};
+  })
+.set_attr<nnvm::FInferShape>("FInferShape", RegressionOpShape)
+.set_attr<nnvm::FGradient>("FGradient", RegressionOpGrad{"_backward_mae_reg_out"})
+.set_attr<nnvm::FInplaceOption>("FInplaceOption",
+[](const NodeAttrs& attrs){
+  return std::vector<std::pair<int, int> >{{0, 0}};
+})
+.set_attr<FCompute>("FCompute", RegressionForward<cpu, mshadow_op::identity>)
 .add_argument("data", "NDArray-or-Symbol", "Input data to the function.")
 .add_argument("label", "NDArray-or-Symbol", "Input label to the function.")
 .add_arguments(RegressionOutputParam::__FIELDS__());

-MXNET_REGISTER_OP_PROPERTY(LogisticRegressionOutput, RegressionOutputProp<reg_enum::kLogistic>)
+NNVM_REGISTER_OP(_backward_mae_reg_out)
+.set_num_inputs(2)
+.set_num_outputs(2)
+.set_attr_parser(ParamParser<RegressionOutputParam>)
+.set_attr<nnvm::TIsBackward>("TIsBackward", true)
+.set_attr<nnvm::FInplaceOption>("FInplaceOption",
+[](const NodeAttrs& attrs){  // input in_label, out_data, output data_grad, label_grad
+  return std::vector<std::pair<int, int> >{{1, 0}};
+})
+.set_attr<FCompute>("FCompute", RegressionBackward<cpu, mshadow_op::minus_sign>);
+
+
+NNVM_REGISTER_OP(LogisticRegressionOutput)
 .describe(R"code(Applies a logistic function to the input.

 The logistic function, also known as the sigmoid function, is computed as
 :math:`\frac{1}{1+exp(-x)}`.

 Commonly, the sigmoid is used to squash the real-valued output of a linear model
 :math:`w^T x + b` into the [0,1] range so that it can be interpreted as a probability.
 It is suitable for binary classification or probability prediction tasks.

 .. note::
    Use the LogisticRegressionOutput as the final output layer of a net.

@@ -110,9 +137,35 @@ By default, gradients of this loss function are scaled by factor `1/n`, where n
 The parameter `grad_scale` can be used to change this scale to `grad_scale/n`.
 )code" ADD_FILELINE)
+.set_num_inputs(2)
+.set_num_outputs(1)
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+  [](const NodeAttrs& attrs) {
+    return std::vector<std::string>{"data", "label"};
+  })
+.set_attr<nnvm::FInferShape>("FInferShape", RegressionOpShape)
+.set_attr<nnvm::FGradient>("FGradient", RegressionOpGrad{"_backward_logistic_reg_out"})
+.set_attr<nnvm::FInplaceOption>("FInplaceOption",
+[](const NodeAttrs& attrs){
+  return std::vector<std::pair<int, int> >{{0, 0}};
+})
+.set_attr<FCompute>("FCompute", RegressionForward<cpu, mshadow_op::sigmoid>)
 .add_argument("data", "NDArray-or-Symbol", "Input data to the function.")
 .add_argument("label", "NDArray-or-Symbol", "Input label to the function.")
 .add_arguments(RegressionOutputParam::__FIELDS__());

+NNVM_REGISTER_OP(_backward_logistic_reg_out)
+.set_num_inputs(2)
+.set_num_outputs(2)
+.set_attr_parser(ParamParser<RegressionOutputParam>)
+.set_attr<nnvm::TIsBackward>("TIsBackward", true)
+.set_attr<nnvm::FInplaceOption>("FInplaceOption",
+[](const NodeAttrs& attrs){  // input in_label, out_data, output data_grad, label_grad
+  return std::vector<std::pair<int, int> >{{1, 0}};
+})
+.set_attr<FCompute>("FCompute", RegressionBackward<cpu, mshadow_op::minus>);
+
+
 }  // namespace op
 }  // namespace mxnet
diff --git a/src/operator/regression_output.cu b/src/operator/regression_output.cu
index cb951f1fd29f..e3a2e7ea2b2b 100644
--- a/src/operator/regression_output.cu
+++ b/src/operator/regression_output.cu
@@ -18,31 +18,32 @@
  */

 /*!
- * Copyright (c) 2015 by Contributors
- * \file regression_output.cu
- * \brief regression output operator
+ * \file regression_output.cu
+ * \brief Regression output operator.
  */
 #include "./regression_output-inl.h"
-#include "./mshadow_op.h"

 namespace mxnet {
 namespace op {

-template<>
-Operator *CreateRegressionOutputOp<gpu>(reg_enum::RegressionOutputType type,
-                                        RegressionOutputParam param) {
-  switch (type) {
-    case reg_enum::kLinear:
-      return new RegressionOutputOp<gpu, mshadow::op::identity, mshadow::op::minus>(param);
-    case reg_enum::kLogistic:
-      return new RegressionOutputOp<gpu, mshadow_op::sigmoid, mshadow::op::minus>(param);
-    case reg_enum::kMAE:
-      return new RegressionOutputOp<gpu, mshadow::op::identity, mshadow_op::minus_sign>(param);
-    default:
-      LOG(FATAL) << "unknown activation type " << type;
-  }
-  return NULL;
-}
+NNVM_REGISTER_OP(LinearRegressionOutput)
+.set_attr<FCompute>("FCompute", RegressionForward<gpu, mshadow_op::identity>);
+
+NNVM_REGISTER_OP(_backward_linear_reg_out)
+.set_attr<FCompute>("FCompute", RegressionBackward<gpu, mshadow_op::minus>);
+
+NNVM_REGISTER_OP(MAERegressionOutput)
+.set_attr<FCompute>("FCompute", RegressionForward<gpu, mshadow_op::identity>);
+
+NNVM_REGISTER_OP(_backward_mae_reg_out)
+.set_attr<FCompute>("FCompute", RegressionBackward<gpu, mshadow_op::minus_sign>);
+
+NNVM_REGISTER_OP(LogisticRegressionOutput)
+.set_attr<FCompute>("FCompute", RegressionForward<gpu, mshadow_op::sigmoid>);
+
+NNVM_REGISTER_OP(_backward_logistic_reg_out)
+.set_attr<FCompute>("FCompute", RegressionBackward<gpu, mshadow_op::minus>);

 }  // namespace op
 }  // namespace mxnet
-
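The behaviour this first patch must preserve can be exercised from Python. The sketch below is illustrative only — it assumes an MXNet build that includes this patch series and uses the stock NDArray/Symbol APIs of this era; none of it is part of the patches themselves:

    import mxnet as mx

    data = mx.nd.array([[0.0, 1.0], [2.0, 3.0]])
    label = mx.nd.array([[0.5, 0.5], [0.5, 0.5]])

    # LinearRegressionOutput and MAERegressionOutput are identity maps in the
    # forward pass (RegressionForward with mshadow_op::identity); only their
    # backward kernels differ.
    print(mx.nd.LinearRegressionOutput(data=data, label=label))  # equals data
    print(mx.nd.MAERegressionOutput(data=data, label=label))     # equals data

    # LogisticRegressionOutput applies the sigmoid 1/(1+exp(-x)) forward.
    print(mx.nd.LogisticRegressionOutput(data=data, label=label))

    # RegressionOpShape keeps the old special case: a 1-D label is accepted
    # for 2-D data of shape (n, 1).
    sym = mx.sym.LinearRegressionOutput(data=mx.sym.Variable('data'),
                                        label=mx.sym.Variable('label'))
    arg_shapes, out_shapes, _ = sym.infer_shape(data=(4, 1), label=(4,))
    print(arg_shapes, out_shapes)  # [(4, 1), (4,)] [(4, 1)]

The `{0, 0}` FInplaceOption in the forward registrations reflects the same idea: with an elementwise forward kernel the output may simply reuse the input buffer.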
From 4815e85e2ee8547416e6a7743d4b325bd1b24bf3 Mon Sep 17 00:00:00 2001
From: ZiyueHuang
Date: Wed, 24 Jan 2018 07:56:24 +0000
Subject: [PATCH 2/7] fix err for instantiation of minus_sign

---
 src/operator/operator_tune.cc        |  2 ++
 src/operator/regression_output-inl.h |  4 ++++
 src/operator/regression_output.cc    | 27 +++++++++++++++------------
 3 files changed, 21 insertions(+), 12 deletions(-)

diff --git a/src/operator/operator_tune.cc b/src/operator/operator_tune.cc
index 7cdf7a2078cc..e0f8306565d9 100644
--- a/src/operator/operator_tune.cc
+++ b/src/operator/operator_tune.cc
@@ -286,12 +286,14 @@ IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::plus);  // NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::minus);  // NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::mul);  // NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::div);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::minus_sign);  // NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::rminus);  // NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::rdiv);  // NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::plus);  // NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::minus);  // NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::mul);  // NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::div);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::minus_sign);  // NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::rminus);  // NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::rdiv);  // NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::div_grad);  // NOLINT()
diff --git a/src/operator/regression_output-inl.h b/src/operator/regression_output-inl.h
index 672e0aa269d7..d886bb324c58 100644
--- a/src/operator/regression_output-inl.h
+++ b/src/operator/regression_output-inl.h
@@ -111,9 +111,13 @@ void RegressionBackward(const nnvm::NodeAttrs& attrs,
       DType* in_label = inputs[0].dptr<DType>();
       DType* out_data = inputs[1].dptr<DType>();
       DType* data_grad = outputs[0].dptr<DType>();
+      real_t num_output = inputs[0].Size()/inputs[0].shape_[0];
       using namespace mxnet_op;
       Kernel<op_with_req<BackwardOp, Req>, xpu>::Launch(
         s, outputs[0].Size(), data_grad, out_data, in_label);
+      Kernel<op_with_req<mshadow_op::mul, Req>, xpu>::Launch(
+        s, outputs[0].Size(), data_grad, data_grad,
+        static_cast<DType>(param.grad_scale/num_output));
     });
   });
 }
diff --git a/src/operator/regression_output.cc b/src/operator/regression_output.cc
index ccd50996e5c8..3c758ad6561c 100644
--- a/src/operator/regression_output.cc
+++ b/src/operator/regression_output.cc
@@ -37,13 +37,13 @@ Just outputs ``data`` during forward propagation.
 If :math:`\hat{y}_i` is the predicted value of the i-th sample, and :math:`y_i`
 is the corresponding target value, then the squared loss estimated over :math:`n` samples is defined as

-:math:`\text{SquaredLoss}(y, \hat{y} ) = \frac{1}{n} \sum_{i=0}^{n-1} \left( y_i - \hat{y}_i \right)^2`
+:math:`\text{SquaredLoss}(\textbf{Y}, \hat{\textbf{Y}} ) = \frac{1}{n} \sum_{i=0}^{n-1} \lVert \textbf{y}_i - \hat{\textbf{y}}_i \rVert_2^2`

 .. note::
    Use the LinearRegressionOutput as the final output layer of a net.

-By default, gradients of this loss function are scaled by factor `1/n`, where n is the number of training examples.
-The parameter `grad_scale` can be used to change this scale to `grad_scale/n`.
+By default, gradients of this loss function are scaled by factor `1/m`, where m is the number of dimensions of a training example.
+The parameter `grad_scale` can be used to change this scale to `grad_scale/m`.

 )code" ADD_FILELINE)
@@ -69,7 +69,8 @@ NNVM_REGISTER_OP(_backward_linear_reg_out)
 .set_attr_parser(ParamParser<RegressionOutputParam>)
 .set_attr<nnvm::TIsBackward>("TIsBackward", true)
 .set_attr<nnvm::FInplaceOption>("FInplaceOption",
-[](const NodeAttrs& attrs){  // input in_label, out_data, output data_grad, label_grad
+[](const NodeAttrs& attrs){
+  // inputs are in_label and out_data, outputs are data_grad and label_grad
   return std::vector<std::pair<int, int> >{{1, 0}};
 })
 .set_attr<FCompute>("FCompute", RegressionBackward<cpu, mshadow_op::minus>);
@@ -82,13 +83,13 @@ MAE is a risk metric corresponding to the expected value of the absolute error.
 If :math:`\hat{y}_i` is the predicted value of the i-th sample, and :math:`y_i`
 is the corresponding target value, then the mean absolute error (MAE) estimated over :math:`n` samples is defined as

-:math:`\text{MAE}(y, \hat{y} ) = \frac{1}{n} \sum_{i=0}^{n-1} \left| y_i - \hat{y}_i \right|`
+:math:`\text{MAE}(\textbf{Y}, \hat{\textbf{Y}} ) = \frac{1}{n} \sum_{i=0}^{n-1} \lVert \textbf{y}_i - \hat{\textbf{y}}_i \rVert_1`
 .. note::
    Use the MAERegressionOutput as the final output layer of a net.

-By default, gradients of this loss function are scaled by factor `1/n`, where n is the number of training examples.
-The parameter `grad_scale` can be used to change this scale to `grad_scale/n`.
+By default, gradients of this loss function are scaled by factor `1/m`, where m is the number of dimensions of a training example.
+The parameter `grad_scale` can be used to change this scale to `grad_scale/m`.

 )code" ADD_FILELINE)
@@ -114,7 +115,8 @@ NNVM_REGISTER_OP(_backward_mae_reg_out)
 .set_attr_parser(ParamParser<RegressionOutputParam>)
 .set_attr<nnvm::TIsBackward>("TIsBackward", true)
 .set_attr<nnvm::FInplaceOption>("FInplaceOption",
-[](const NodeAttrs& attrs){  // input in_label, out_data, output data_grad, label_grad
+[](const NodeAttrs& attrs){
+  // inputs are in_label and out_data, outputs are data_grad and label_grad
   return std::vector<std::pair<int, int> >{{1, 0}};
 })
 .set_attr<FCompute>("FCompute", RegressionBackward<cpu, mshadow_op::minus_sign>);
@@ -124,7 +126,7 @@ NNVM_REGISTER_OP(LogisticRegressionOutput)
 .describe(R"code(Applies a logistic function to the input.

 The logistic function, also known as the sigmoid function, is computed as
-:math:`\frac{1}{1+exp(-x)}`.
+:math:`\frac{1}{1+exp(-\textbf{x})}`.

 Commonly, the sigmoid is used to squash the real-valued output of a linear model
 :math:`w^T x + b` into the [0,1] range so that it can be interpreted as a probability.
 It is suitable for binary classification or probability prediction tasks.

 .. note::
    Use the LogisticRegressionOutput as the final output layer of a net.

-By default, gradients of this loss function are scaled by factor `1/n`, where n is the number of training examples.
-The parameter `grad_scale` can be used to change this scale to `grad_scale/n`.
+By default, gradients of this loss function are scaled by factor `1/m`, where m is the number of dimensions of a training example.
+The parameter `grad_scale` can be used to change this scale to `grad_scale/m`.

 )code" ADD_FILELINE)
@@ -160,7 +162,8 @@ NNVM_REGISTER_OP(_backward_logistic_reg_out)
 .set_attr_parser(ParamParser<RegressionOutputParam>)
 .set_attr<nnvm::TIsBackward>("TIsBackward", true)
 .set_attr<nnvm::FInplaceOption>("FInplaceOption",
-[](const NodeAttrs& attrs){  // input in_label, out_data, output data_grad, label_grad
+[](const NodeAttrs& attrs){
+  // inputs are in_label and out_data, outputs are data_grad and label_grad
   return std::vector<std::pair<int, int> >{{1, 0}};
 })
 .set_attr<FCompute>("FCompute", RegressionBackward<cpu, mshadow_op::minus>);
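The scaling this commit introduces can be checked numerically. A minimal sketch, under the same assumptions as the example above (the bind/executor API shown is the symbolic one current in early 2018; names like `exe` are illustrative):

    import mxnet as mx
    import numpy as np

    n, m = 4, 3  # batch size, regression outputs per example
    data = mx.sym.Variable('data')
    label = mx.sym.Variable('label')
    sym = mx.sym.LinearRegressionOutput(data=data, label=label, grad_scale=2.0)

    exe = sym.simple_bind(ctx=mx.cpu(), data=(n, m), label=(n, m))
    x = np.random.rand(n, m).astype('float32')
    y = np.random.rand(n, m).astype('float32')
    exe.forward(is_train=True, data=x, label=y)
    exe.backward()

    # RegressionBackward now computes grad_scale/num_output * (out - label),
    # with num_output = label.size / label.shape[0] = m.
    expected = 2.0 / m * (x - y)
    print(np.allclose(exe.grad_dict['data'].asnumpy(), expected))  # True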
From 0312f87622ece115606fe4908c5bbf0488e6c556 Mon Sep 17 00:00:00 2001
From: ZiyueHuang
Date: Wed, 24 Jan 2018 08:06:28 +0000
Subject: [PATCH 3/7] remove useless header file init_op.h

---
 src/operator/regression_output-inl.h | 1 -
 1 file changed, 1 deletion(-)

diff --git a/src/operator/regression_output-inl.h b/src/operator/regression_output-inl.h
index d886bb324c58..b9aff707c776 100644
--- a/src/operator/regression_output-inl.h
+++ b/src/operator/regression_output-inl.h
@@ -27,7 +27,6 @@
 #include
 #include
 #include
-#include "./tensor/init_op.h"
 #include "./mshadow_op.h"
 #include "./mxnet_op.h"
 #include "./operator_common.h"

From 3cd1e0d72ee93670e8368b8c1ba8250f990ad630 Mon Sep 17 00:00:00 2001
From: ZiyueHuang
Date: Thu, 25 Jan 2018 07:16:02 +0000
Subject: [PATCH 4/7] replace with macro and address other comments

---
 src/operator/regression_output-inl.h |   8 +-
 src/operator/regression_output.cc    | 147 +++++++++------------------
 2 files changed, 54 insertions(+), 101 deletions(-)

diff --git a/src/operator/regression_output-inl.h b/src/operator/regression_output-inl.h
index b9aff707c776..4642f8dc4679 100644
--- a/src/operator/regression_output-inl.h
+++ b/src/operator/regression_output-inl.h
@@ -86,7 +86,7 @@ void RegressionForward(const nnvm::NodeAttrs& attrs,
   mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
   MSHADOW_REAL_TYPE_SWITCH(inputs[reg_enum::kData].type_flag_, DType, {
     MXNET_ASSIGN_REQ_SWITCH(req[reg_enum::kOut], Req, {
-      DType* in_data = inputs[reg_enum::kData].dptr<DType>();
+      const DType* in_data = inputs[reg_enum::kData].dptr<DType>();
       DType* out_data = outputs[reg_enum::kOut].dptr<DType>();
       using namespace mxnet_op;
       Kernel<op_with_req<ForwardOp, Req>, xpu>::Launch(
@@ -107,10 +107,10 @@ void RegressionBackward(const nnvm::NodeAttrs& attrs,
   // outputs are data_grad, label_grad
   MSHADOW_REAL_TYPE_SWITCH(inputs[1].type_flag_, DType, {
     MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
-      DType* in_label = inputs[0].dptr<DType>();
-      DType* out_data = inputs[1].dptr<DType>();
+      const DType* in_label = inputs[0].dptr<DType>();
+      const DType* out_data = inputs[1].dptr<DType>();
       DType* data_grad = outputs[0].dptr<DType>();
-      real_t num_output = inputs[0].Size()/inputs[0].shape_[0];
+      const real_t num_output = inputs[0].Size()/inputs[0].shape_[0];
       using namespace mxnet_op;
       Kernel<op_with_req<BackwardOp, Req>, xpu>::Launch(
         s, outputs[0].Size(), data_grad, out_data, in_label);
diff --git a/src/operator/regression_output.cc b/src/operator/regression_output.cc
index 3c758ad6561c..56aabad9daab 100644
--- a/src/operator/regression_output.cc
+++ b/src/operator/regression_output.cc
@@ -24,13 +24,45 @@

 #include "./regression_output-inl.h"

+#define MXNET_OPERATOR_REGISTER_REGRESSION_FWD(__name$, __kernel$, __bwdop$) \
+  NNVM_REGISTER_OP(__name$) \
+  .set_num_inputs(2) \
+  .set_num_outputs(1) \
+  .set_attr<nnvm::FListInputNames>("FListInputNames", \
+    [](const NodeAttrs& attrs) { \
+      return std::vector<std::string>{"data", "label"}; \
+    }) \
+  .set_attr<nnvm::FInferShape>("FInferShape", RegressionOpShape) \
+  .set_attr<nnvm::FGradient>("FGradient", RegressionOpGrad{__bwdop$}) \
+  .set_attr<nnvm::FInplaceOption>("FInplaceOption", \
+    [](const NodeAttrs& attrs){ \
+      return std::vector<std::pair<int, int> >{{0, 0}}; \
+    }) \
+  .set_attr<FCompute>("FCompute", RegressionForward<cpu, __kernel$>) \
+  .add_argument("data", "NDArray-or-Symbol", "Input data to the function.") \
"NDArray-or-Symbol", "Input data to the function.") \ + .add_argument("label", "NDArray-or-Symbol", "Input label to the function.") \ + .add_arguments(RegressionOutputParam::__FIELDS__()) + +#define MXNET_OPERATOR_REGISTER_REGRESSION_BWD(__name$, __kernel$) \ + NNVM_REGISTER_OP(__name$) \ + .set_num_inputs(2) \ + .set_num_outputs(2) \ + .set_attr_parser(ParamParser) \ + .set_attr("TIsBackward", true) \ + .set_attr("FInplaceOption", \ + [](const NodeAttrs& attrs){ \ + return std::vector >{{1, 0}}; \ + }) \ + .set_attr("FCompute", RegressionBackward) + namespace mxnet { namespace op { DMLC_REGISTER_PARAMETER(RegressionOutputParam); -NNVM_REGISTER_OP(LinearRegressionOutput) +MXNET_OPERATOR_REGISTER_REGRESSION_FWD(LinearRegressionOutputi, + mshadow_op::identity, "_backward_linear_reg_out") .describe(R"code(Computes and optimizes for squared loss during backward propagation. Just outputs ``data`` during forward propagation. @@ -42,40 +74,15 @@ then the squared loss estimated over :math:`n` samples is defined as .. note:: Use the LinearRegressionOutput as the final output layer of a net. -By default, gradients of this loss function are scaled by factor `1/m`, where m is the number of dimensions of a training example. +By default, gradients of this loss function are scaled by factor `1/m`, where m is the number of features of a training example. The parameter `grad_scale` can be used to change this scale to `grad_scale/m`. -)code" ADD_FILELINE) -.set_num_inputs(2) -.set_num_outputs(1) -.set_attr("FListInputNames", - [](const NodeAttrs& attrs) { - return std::vector{"data", "label"}; - }) -.set_attr("FInferShape", RegressionOpShape) -.set_attr("FGradient", RegressionOpGrad{"_backward_linear_reg_out"}) -.set_attr("FInplaceOption", -[](const NodeAttrs& attrs){ - return std::vector >{{0, 0}}; -}) -.set_attr("FCompute", RegressionForward) -.add_argument("data", "NDArray-or-Symbol", "Input data to the function.") -.add_argument("label", "NDArray-or-Symbol", "Input label to the function.") -.add_arguments(RegressionOutputParam::__FIELDS__()); - -NNVM_REGISTER_OP(_backward_linear_reg_out) -.set_num_inputs(2) -.set_num_outputs(2) -.set_attr_parser(ParamParser) -.set_attr("TIsBackward", true) -.set_attr("FInplaceOption", -[](const NodeAttrs& attrs){ - // inputs are in_label and out_data, outputs are data_grad and label_grad - return std::vector >{{1, 0}}; -}) -.set_attr("FCompute", RegressionBackward); - -NNVM_REGISTER_OP(MAERegressionOutput) +)code" ADD_FILELINE); + +MXNET_OPERATOR_REGISTER_REGRESSION_BWD(_backward_linear_reg_out, mshadow_op::minus); + +MXNET_OPERATOR_REGISTER_REGRESSION_FWD(MAERegressionOutput, + mshadow_op::identity, "_backward_mae_reg_out") .describe(R"code(Computes mean absolute error of the input. MAE is a risk metric corresponding to the expected value of the absolute error. @@ -88,41 +95,15 @@ then the mean absolute error (MAE) estimated over :math:`n` samples is defined a .. note:: Use the MAERegressionOutput as the final output layer of a net. -By default, gradients of this loss function are scaled by factor `1/m`, where m is the number of dimensions of a training example. +By default, gradients of this loss function are scaled by factor `1/m`, where m is the number of features of a training example. The parameter `grad_scale` can be used to change this scale to `grad_scale/m`. 
-)code" ADD_FILELINE)
-.set_num_inputs(2)
-.set_num_outputs(1)
-.set_attr<nnvm::FListInputNames>("FListInputNames",
-  [](const NodeAttrs& attrs) {
-    return std::vector<std::string>{"data", "label"};
-  })
-.set_attr<nnvm::FInferShape>("FInferShape", RegressionOpShape)
-.set_attr<nnvm::FGradient>("FGradient", RegressionOpGrad{"_backward_mae_reg_out"})
-.set_attr<nnvm::FInplaceOption>("FInplaceOption",
-[](const NodeAttrs& attrs){
-  return std::vector<std::pair<int, int> >{{0, 0}};
-})
-.set_attr<FCompute>("FCompute", RegressionForward<cpu, mshadow_op::identity>)
-.add_argument("data", "NDArray-or-Symbol", "Input data to the function.")
-.add_argument("label", "NDArray-or-Symbol", "Input label to the function.")
-.add_arguments(RegressionOutputParam::__FIELDS__());
-
-NNVM_REGISTER_OP(_backward_mae_reg_out)
-.set_num_inputs(2)
-.set_num_outputs(2)
-.set_attr_parser(ParamParser<RegressionOutputParam>)
-.set_attr<nnvm::TIsBackward>("TIsBackward", true)
-.set_attr<nnvm::FInplaceOption>("FInplaceOption",
-[](const NodeAttrs& attrs){
-  // inputs are in_label and out_data, outputs are data_grad and label_grad
-  return std::vector<std::pair<int, int> >{{1, 0}};
-})
-.set_attr<FCompute>("FCompute", RegressionBackward<cpu, mshadow_op::minus_sign>);
-
-
-NNVM_REGISTER_OP(LogisticRegressionOutput)
+)code" ADD_FILELINE);
+
+MXNET_OPERATOR_REGISTER_REGRESSION_BWD(_backward_mae_reg_out, mshadow_op::minus_sign);
+
+MXNET_OPERATOR_REGISTER_REGRESSION_FWD(LogisticRegressionOutput,
+  mshadow_op::sigmoid, "_backward_logistic_reg_out")
 .describe(R"code(Applies a logistic function to the input.

 The logistic function, also known as the sigmoid function, is computed as
 :math:`\frac{1}{1+exp(-\textbf{x})}`.

 Commonly, the sigmoid is used to squash the real-valued output of a linear model
 :math:`w^T x + b` into the [0,1] range so that it can be interpreted as a probability.
 It is suitable for binary classification or probability prediction tasks.

 .. note::
    Use the LogisticRegressionOutput as the final output layer of a net.

-By default, gradients of this loss function are scaled by factor `1/m`, where m is the number of dimensions of a training example.
+By default, gradients of this loss function are scaled by factor `1/m`, where m is the number of features of a training example.
 The parameter `grad_scale` can be used to change this scale to `grad_scale/m`.
-)code" ADD_FILELINE)
-.set_num_inputs(2)
-.set_num_outputs(1)
-.set_attr<nnvm::FListInputNames>("FListInputNames",
-  [](const NodeAttrs& attrs) {
-    return std::vector<std::string>{"data", "label"};
-  })
-.set_attr<nnvm::FInferShape>("FInferShape", RegressionOpShape)
-.set_attr<nnvm::FGradient>("FGradient", RegressionOpGrad{"_backward_logistic_reg_out"})
-.set_attr<nnvm::FInplaceOption>("FInplaceOption",
-[](const NodeAttrs& attrs){
-  return std::vector<std::pair<int, int> >{{0, 0}};
-})
-.set_attr<FCompute>("FCompute", RegressionForward<cpu, mshadow_op::sigmoid>)
-.add_argument("data", "NDArray-or-Symbol", "Input data to the function.")
-.add_argument("label", "NDArray-or-Symbol", "Input label to the function.")
-.add_arguments(RegressionOutputParam::__FIELDS__());
-
-NNVM_REGISTER_OP(_backward_logistic_reg_out)
-.set_num_inputs(2)
-.set_num_outputs(2)
-.set_attr_parser(ParamParser<RegressionOutputParam>)
-.set_attr<nnvm::TIsBackward>("TIsBackward", true)
-.set_attr<nnvm::FInplaceOption>("FInplaceOption",
-[](const NodeAttrs& attrs){
-  // inputs are in_label and out_data, outputs are data_grad and label_grad
-  return std::vector<std::pair<int, int> >{{1, 0}};
-})
-.set_attr<FCompute>("FCompute", RegressionBackward<cpu, mshadow_op::minus>);
+)code" ADD_FILELINE);
+
+MXNET_OPERATOR_REGISTER_REGRESSION_BWD(_backward_logistic_reg_out, mshadow_op::minus);

 }  // namespace op
 }  // namespace mxnet

From 08e4b0d0ce20822bd06e2a7a4a739ef4c4edc9c6 Mon Sep 17 00:00:00 2001
From: ZiyueHuang
Date: Thu, 25 Jan 2018 07:20:22 +0000
Subject: [PATCH 5/7] update

---
 src/operator/regression_output.cc | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/operator/regression_output.cc b/src/operator/regression_output.cc
index 56aabad9daab..f457bfcfc918 100644
--- a/src/operator/regression_output.cc
+++ b/src/operator/regression_output.cc
@@ -61,7 +61,7 @@ namespace op {

 DMLC_REGISTER_PARAMETER(RegressionOutputParam);

-MXNET_OPERATOR_REGISTER_REGRESSION_FWD(LinearRegressionOutputi,
+MXNET_OPERATOR_REGISTER_REGRESSION_FWD(LinearRegressionOutput,
   mshadow_op::identity, "_backward_linear_reg_out")
 .describe(R"code(Computes and optimizes for squared loss during backward propagation.
 Just outputs ``data`` during forward propagation.

From 81dd5ccfbd159d2d901e56dd072089498e283e74 Mon Sep 17 00:00:00 2001
From: ZiyueHuang
Date: Fri, 26 Jan 2018 03:21:15 +0000
Subject: [PATCH 6/7] minor revise docs

---
 src/operator/regression_output.cc | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/operator/regression_output.cc b/src/operator/regression_output.cc
index f457bfcfc918..7b0fbae3bccb 100644
--- a/src/operator/regression_output.cc
+++ b/src/operator/regression_output.cc
@@ -74,7 +74,7 @@ then the squared loss estimated over :math:`n` samples is defined as
 .. note::
    Use the LinearRegressionOutput as the final output layer of a net.

-By default, gradients of this loss function are scaled by factor `1/m`, where m is the number of features of a training example.
+By default, gradients of this loss function are scaled by factor `1/m`, where m is the number of regression outputs of a training example.
 The parameter `grad_scale` can be used to change this scale to `grad_scale/m`.

 )code" ADD_FILELINE);
@@ -95,7 +95,7 @@ then the mean absolute error (MAE) estimated over :math:`n` samples is defined as
 .. note::
    Use the MAERegressionOutput as the final output layer of a net.

-By default, gradients of this loss function are scaled by factor `1/m`, where m is the number of features of a training example.
+By default, gradients of this loss function are scaled by factor `1/m`, where m is the number of regression outputs of a training example.
 The parameter `grad_scale` can be used to change this scale to `grad_scale/m`.
 )code" ADD_FILELINE);
@@ -116,7 +116,7 @@ It is suitable for binary classification or probability prediction tasks.
 .. note::
    Use the LogisticRegressionOutput as the final output layer of a net.

-By default, gradients of this loss function are scaled by factor `1/m`, where m is the number of features of a training example.
+By default, gradients of this loss function are scaled by factor `1/m`, where m is the number of regression outputs of a training example.
 The parameter `grad_scale` can be used to change this scale to `grad_scale/m`.

 )code" ADD_FILELINE);

From 4cd096ea590432171e48ac65c78ae37d3b80b9d6 Mon Sep 17 00:00:00 2001
From: ZiyueHuang
Date: Sat, 27 Jan 2018 03:01:13 +0000
Subject: [PATCH 7/7] add mae test

---
 tests/python/unittest/test_operator.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index d169a5455bb8..aedd7c9ab590 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -244,7 +244,9 @@ def test_regression():
     check_regression(mx.symbol.LinearRegressionOutput,
                      lambda x: x,
                      lambda x, y : x - y)
-
+    check_regression(mx.symbol.MAERegressionOutput,
+                     lambda x: x,
+                     lambda x, y : np.where(x > y, np.ones(x.shape), -np.ones(x.shape)))

 def check_softmax_grad(xpu):
     x = mx.sym.Variable('x')
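The MAE case that the new test covers has a closed-form gradient, so it can also be checked without the framework. A numpy-only sketch of what `_backward_mae_reg_out` (mshadow_op::minus_sign plus the 1/m scaling) is expected to produce; the helper name is hypothetical:

    import numpy as np

    def expected_mae_grad(out, label, grad_scale=1.0):
        # sign(out - label), scaled by grad_scale / m with
        # m = label.size / label.shape[0], matching RegressionBackward.
        m = label.size / label.shape[0]
        return np.where(out > label, 1.0, -1.0) * grad_scale / m

    out = np.array([[0.2], [0.9]])
    label = np.array([[0.5], [0.5]])
    print(expected_mae_grad(out, label))  # [[-1.], [1.]] since m == 1

This mirrors the `np.where(x > y, ...)` expectation in the added test, where each example has a single output and the 1/m factor is therefore 1.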