2 changes: 2 additions & 0 deletions src/operator/operator_tune.cc
@@ -286,12 +286,14 @@ IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::plus); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::minus); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::mul); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::div); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::minus_sign); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::rminus); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::rdiv); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::plus); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::minus); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::mul); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::div); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::minus_sign); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::rminus); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::rdiv); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::div_grad); // NOLINT()
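The two added registrations cover mshadow_op::minus_sign, which the refactored regression ops (below) now launch through the tuned mxnet_op::Kernel path for the MAE backward pass. As orientation, here is a standalone stand-in for the semantics being registered; the real functor is a struct with a static Map() in src/operator/mshadow_op.h, and its exact handling of ties (a == b) may differ from this sketch:

#include <iostream>

// Illustrative stand-in for mshadow_op::minus_sign, i.e. the MAE regression
// gradient sign(prediction - label); not the real functor.
template <typename DType>
DType minus_sign(DType a, DType b) {
  if (a - b > DType(0)) return DType(1);
  if (a - b < DType(0)) return DType(-1);
  return DType(0);  // tie handling here is a guess; see mshadow_op.h
}

int main() {
  std::cout << minus_sign(2.5f, 1.0f) << '\n';  // 1: prediction above label
  std::cout << minus_sign(0.0f, 3.0f) << '\n';  // -1: prediction below label
}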
228 changes: 86 additions & 142 deletions src/operator/regression_output-inl.h
@@ -18,28 +18,28 @@
*/

/*!
* Copyright (c) 2015 by Contributors
* \file regression_output-inl.h
* \brief Regression output operator.
*/
#ifndef MXNET_OPERATOR_REGRESSION_OUTPUT_INL_H_
#define MXNET_OPERATOR_REGRESSION_OUTPUT_INL_H_

#include <dmlc/logging.h>
#include <mxnet/operator.h>
#include <map>
#include <string>
#include <mxnet/operator_util.h>
#include <vector>
#include <utility>
#include "./mshadow_op.h"
#include "./mxnet_op.h"
#include "./operator_common.h"

namespace mxnet {
namespace op {

/*!
* \brief regression namespace
*/
namespace reg_enum {
enum RegressionOutputOpInputs {kData, kLabel};
enum RegressionOutputOutputs {kOut};
enum RegressionOutputType {kLinear, kLogistic, kMAE};
} // reg_enum

struct RegressionOutputParam : public dmlc::Parameter<RegressionOutputParam> {
@@ -50,146 +50,90 @@ struct RegressionOutputParam : public dmlc::Parameter<RegressionOutputParam> {
};
};

// Special operator that outputs the regression value in the forward pass
// and computes the gradient in the backward pass.
template<typename xpu, typename ForwardOp, typename BackwardOp>
class RegressionOutputOp : public Operator {
public:
explicit RegressionOutputOp(RegressionOutputParam param) : param_(param) {}

virtual void Forward(const OpContext &ctx,
const std::vector<TBlob> &in_data,
const std::vector<OpReqType> &req,
const std::vector<TBlob> &out_data,
const std::vector<TBlob> &aux_args) {
using namespace mshadow;
using namespace mshadow::expr;
CHECK_EQ(in_data.size(), 2U) << "RegressionOutputOp Input: [data, label]";
CHECK_EQ(out_data.size(), 1U) << "RegressionOutputOp Output: [output]";
Stream<xpu> *s = ctx.get_stream<xpu>();
Tensor<xpu, 2> data = in_data[reg_enum::kData].FlatTo2D<xpu, real_t>(s);
Tensor<xpu, 2> out = out_data[reg_enum::kOut].FlatTo2D<xpu, real_t>(s);
Assign(out, req[reg_enum::kOut], F<ForwardOp>(data));
}

virtual void Backward(const OpContext &ctx,
const std::vector<TBlob> &out_grad,
const std::vector<TBlob> &in_data,
const std::vector<TBlob> &out_data,
const std::vector<OpReqType> &req,
const std::vector<TBlob> &in_grad,
const std::vector<TBlob> &aux_args) {
using namespace mshadow;
using namespace mshadow::expr;
CHECK_EQ(in_data.size(), 2U);
CHECK_EQ(out_grad.size(), 1U);
CHECK_GE(in_grad.size(), 1U);
CHECK_GE(req.size(), 1U);
Stream<xpu> *s = ctx.get_stream<xpu>();
real_t num_output =
in_data[reg_enum::kLabel].Size()/in_data[reg_enum::kLabel].shape_[0];
Tensor<xpu, 2> out = out_data[reg_enum::kOut].FlatTo2D<xpu, real_t>(s);
Tensor<xpu, 2> grad = in_grad[reg_enum::kData].FlatTo2D<xpu, real_t>(s);
Tensor<xpu, 2> label = in_data[reg_enum::kLabel]
.get_with_shape<xpu, 2, real_t>(out.shape_, s);
Assign(grad, req[reg_enum::kData], param_.grad_scale/num_output*
F<BackwardOp>(out, reshape(label, grad.shape_)));
}

private:
RegressionOutputParam param_;
};
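This legacy class is what the PR removes. Its forward pass goes through mshadow expression templates: F<ForwardOp>(data) builds a lazy elementwise expression and Assign materializes it according to the write request. A standalone sketch of the elementwise effect for LogisticRegressionOutput, where ForwardOp is the sigmoid; the names below are illustrative, not the mshadow API:

#include <cmath>
#include <cstddef>
#include <iostream>

// Elementwise effect of Assign(out, kWriteTo, F<sigmoid>(data)) in the
// legacy Forward above, written as a plain loop for clarity.
void LogisticForward(float* out, const float* data, std::size_t n) {
  for (std::size_t i = 0; i < n; ++i) {
    out[i] = 1.0f / (1.0f + std::exp(-data[i]));  // mshadow_op::sigmoid
  }
}

int main() {
  const float data[2] = {0.0f, 2.0f};
  float out[2];
  LogisticForward(out, data, 2);
  std::cout << out[0] << ' ' << out[1] << '\n';  // 0.5 0.880797
}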

// Declare factory function, used for dispatch specialization
template<typename xpu>
Operator* CreateRegressionOutputOp(reg_enum::RegressionOutputType type,
RegressionOutputParam param);

#if DMLC_USE_CXX11
template<reg_enum::RegressionOutputType type>
class RegressionOutputProp : public OperatorProperty {
public:
std::vector<std::string> ListArguments() const override {
return {"data", "label"};
}

void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) override {
param_.Init(kwargs);
}

std::map<std::string, std::string> GetParams() const override {
return param_.__DICT__();
}

bool InferShape(std::vector<TShape> *in_shape,
std::vector<TShape> *out_shape,
std::vector<TShape> *aux_shape) const override {
using namespace mshadow;
CHECK_EQ(in_shape->size(), 2) << "Input:[data, label]";
const TShape &dshape = in_shape->at(0);
if (dshape.ndim() == 0) return false;
auto &lshape = (*in_shape)[1];
if (lshape.ndim() == 0) {
// Special treatment for 1-D output, to allow a 1-D label by default.
// Consider changing this convention later.
if (dshape.ndim() == 2 && dshape[1] == 1) {
lshape = Shape1(dshape[0]);
} else {
lshape = dshape;
}
} else if (lshape[0] != dshape[0] || lshape.Size() != dshape.Size()) {
std::ostringstream os;
os << "Shape inconsistent, Provided=" << lshape << ','
<< " inferred shape=" << dshape;
throw ::mxnet::op::InferShapeError(os.str(), 1);
}
out_shape->clear();
out_shape->push_back(dshape);
return true;
}

OperatorProperty* Copy() const override {
auto ptr = new RegressionOutputProp<type>();
ptr->param_ = param_;
return ptr;
}

std::string TypeString() const override {
switch (type) {
case reg_enum::kLinear: return "LinearRegressionOutput";
case reg_enum::kLogistic: return "LogisticRegressionOutput";
case reg_enum::kMAE: return "MAERegressionOutput";
default: LOG(FATAL) << "unknown type"; return "";
inline bool RegressionOpShape(const nnvm::NodeAttrs& attrs,
std::vector<TShape> *in_attrs,
std::vector<TShape> *out_attrs) {
using namespace mshadow;
CHECK_EQ(in_attrs->size(), 2U) << "Input:[data, label]";
const TShape &dshape = in_attrs->at(0);
if (dshape.ndim() == 0) return false;
auto &lshape = (*in_attrs)[1];
if (lshape.ndim() == 0) {
// Special treatment for 1-D output, to allow a 1-D label by default.
// Consider changing this convention later.
if (dshape.ndim() == 2 && dshape[1] == 1) {
lshape = Shape1(dshape[0]);
} else {
lshape = dshape;
}
} else if (lshape[0] != dshape[0] || lshape.Size() != dshape.Size()) {
std::ostringstream os;
os << "Shape inconsistent, Provided=" << lshape << ','
<< " inferred shape=" << dshape;
throw ::mxnet::op::InferShapeError(os.str(), 1);
}

std::vector<int> DeclareBackwardDependency(
const std::vector<int> &out_grad,
const std::vector<int> &in_data,
const std::vector<int> &out_data) const override {
return {in_data[reg_enum::kLabel], out_data[reg_enum::kOut]};
}

std::vector<std::pair<int, void*> > BackwardInplaceOption(
const std::vector<int> &out_grad,
const std::vector<int> &in_data,
const std::vector<int> &out_data,
const std::vector<void*> &in_grad) const override {
return {{out_data[reg_enum::kOut], in_grad[reg_enum::kData]}};
}

std::vector<std::pair<int, void*> > ForwardInplaceOption(
const std::vector<int> &in_data,
const std::vector<void*> &out_data) const override {
return {{in_data[reg_enum::kData], out_data[reg_enum::kOut]}};
out_attrs->clear();
out_attrs->push_back(dshape);
return true;
}
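RegressionOpShape keeps the semantics of the old InferShape: an unspecified label defaults to the data shape, except that (N, 1) data also accepts a plain 1-D label of length N. A standalone sketch of just that defaulting rule, with std::vector standing in for TShape (illustrative only):

#include <cassert>
#include <cstdint>
#include <vector>

// Given a known data shape and an unspecified label shape, return the label
// shape the rule above would infer.
std::vector<int64_t> DefaultLabelShape(const std::vector<int64_t>& dshape) {
  if (dshape.size() == 2 && dshape[1] == 1) {
    return {dshape[0]};  // (N, 1) data accepts a 1-D label of shape (N,)
  }
  return dshape;         // otherwise the label mirrors the data shape
}

int main() {
  assert((DefaultLabelShape({64, 1}) == std::vector<int64_t>{64}));
  assert((DefaultLabelShape({64, 10}) == std::vector<int64_t>({64, 10})));
  return 0;
}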

template<typename xpu, typename ForwardOp>
void RegressionForward(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
MSHADOW_REAL_TYPE_SWITCH(inputs[reg_enum::kData].type_flag_, DType, {
MXNET_ASSIGN_REQ_SWITCH(req[reg_enum::kOut], Req, {
const DType* in_data = inputs[reg_enum::kData].dptr<DType>();
DType* out_data = outputs[reg_enum::kOut].dptr<DType>();
using namespace mxnet_op;
Kernel<op_with_req<ForwardOp, Req>, xpu>::Launch(
s, outputs[reg_enum::kOut].Size(), out_data, in_data);
});
});
}
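RegressionForward dispatches twice at compile time: MSHADOW_REAL_TYPE_SWITCH selects DType from the runtime type flag, and MXNET_ASSIGN_REQ_SWITCH turns the runtime OpReqType into a template parameter so op_with_req<ForwardOp, Req> can apply the functor with the right assignment. A serial, CPU-only sketch of that pattern under simplified, assumed names (the real templates live in src/operator/mxnet_op.h):

#include <cstddef>
#include <iostream>

enum OpReq { kWriteTo, kAddTo };  // simplified stand-in for mxnet::OpReqType

// Stand-in for op_with_req<OP, Req>: applies OP elementwise, overwriting or
// accumulating depending on Req, mirroring what KERNEL_ASSIGN does.
template <typename OP, OpReq Req>
struct op_with_req {
  template <typename DType>
  static void Map(std::size_t i, DType* out, const DType* in) {
    if (Req == kAddTo) out[i] += OP::Map(in[i]);
    else               out[i]  = OP::Map(in[i]);
  }
};

struct identity {  // the ForwardOp of LinearRegressionOutput
  template <typename DType>
  static DType Map(DType a) { return a; }
};

// Serial stand-in for Kernel<...>::Launch; the real one parallelizes on CPU
// and maps i to threads in the CUDA specialization.
template <typename Kernel, typename DType>
void Launch(std::size_t n, DType* out, const DType* in) {
  for (std::size_t i = 0; i < n; ++i) Kernel::Map(i, out, in);
}

int main() {
  const float in[3] = {0.5f, -1.0f, 2.0f};
  float out[3] = {0.f, 0.f, 0.f};
  Launch<op_with_req<identity, kWriteTo>>(3, out, in);
  std::cout << out[0] << ' ' << out[1] << ' ' << out[2] << '\n';  // 0.5 -1 2
}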

template<typename xpu, typename BackwardOp>
void RegressionBackward(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
const RegressionOutputParam& param = nnvm::get<RegressionOutputParam>(attrs.parsed);
mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
// inputs are in_label, out_data
// outputs are data_grad, label_grad
MSHADOW_REAL_TYPE_SWITCH(inputs[1].type_flag_, DType, {
MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
const DType* in_label = inputs[0].dptr<DType>();
const DType* out_data = inputs[1].dptr<DType>();
DType* data_grad = outputs[0].dptr<DType>();
const real_t num_output = inputs[0].Size()/inputs[0].shape_[0];
using namespace mxnet_op;
Kernel<op_with_req<BackwardOp, Req>, xpu>::Launch(
s, outputs[0].Size(), data_grad, out_data, in_label);
Kernel<op_with_req<mshadow_op::mul, Req>, xpu>::Launch(
s, outputs[0].Size(), data_grad, data_grad,
static_cast<DType>(param.grad_scale/num_output));
});
});
}
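The backward pass is two kernel launches: BackwardOp(out, label) into data_grad, followed by an in-place scale by grad_scale / num_output, where num_output is the number of label elements per batch row. A reference sketch for linear regression, where BackwardOp is mshadow_op::minus, with plain loops standing in for the kernels (illustrative only):

#include <cstddef>
#include <iostream>

// data_grad = (out - label) * grad_scale / num_output, mirroring the two
// launches above for BackwardOp == mshadow_op::minus.
void LinearRegressionBackward(std::size_t n, float* data_grad,
                              const float* out, const float* label,
                              float grad_scale, float num_output) {
  for (std::size_t i = 0; i < n; ++i) data_grad[i] = out[i] - label[i];
  const float scale = grad_scale / num_output;
  for (std::size_t i = 0; i < n; ++i) data_grad[i] *= scale;
}

int main() {
  const float out[2] = {1.5f, -0.5f};
  const float label[2] = {1.0f, 0.0f};
  float grad[2];
  // batch of 2 samples with 1 output each, so num_output == 1
  LinearRegressionBackward(2, grad, out, label, /*grad_scale=*/1.0f, 1.0f);
  std::cout << grad[0] << ' ' << grad[1] << '\n';  // 0.5 -0.5
}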

struct RegressionOpGrad {
const char *op_name;
std::vector<nnvm::NodeEntry> operator()(const nnvm::NodePtr& n,
const std::vector<nnvm::NodeEntry>& ograds) const {
std::vector<nnvm::NodeEntry> heads;
heads.push_back(n->inputs[reg_enum::kLabel]);
heads.emplace_back(nnvm::NodeEntry{n, reg_enum::kOut, 0});
return MakeGradNode(op_name, n, heads, n->attrs.dict);
}
};

Operator* CreateOperator(Context ctx) const override;

protected:
RegressionOutputParam param_;
};
#endif // DMLC_USE_CXX11
} // namespace op
} // namespace mxnet

#endif // MXNET_OPERATOR_REGRESSION_OUTPUT_INL_H_