diff --git a/mshadow b/mshadow index cce0b32a892f..da390521662f 160000 --- a/mshadow +++ b/mshadow @@ -1 +1 @@ -Subproject commit cce0b32a892fbf5de28fa1feabea50b8cf85441b +Subproject commit da390521662f99adcc7963e97141738b57974573 diff --git a/ps-lite b/ps-lite index e99b0f288096..55ee9bdbf7e5 160000 --- a/ps-lite +++ b/ps-lite @@ -1 +1 @@ -Subproject commit e99b0f288096c21ab943dff55dc3ff854c7904b4 +Subproject commit 55ee9bdbf7e5bf1cbde423ba118041e1a7dcca1b diff --git a/src/operator/upsampling_nearest-inl.h b/src/operator/upsampling_nearest-inl.h new file mode 100644 index 000000000000..305e740f903e --- /dev/null +++ b/src/operator/upsampling_nearest-inl.h @@ -0,0 +1,153 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file upsampling-inl.h + * \brief + * \author Bing Xu +*/ +#ifndef MXNET_OPERATOR_UPSAMPLING_NEAREST_INL_H_ +#define MXNET_OPERATOR_UPSAMPLING_NEAREST_INL_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include "./operator_common.h" + +namespace mxnet { +namespace op { + +namespace up_enum { +enum UpSamplingNearestOpInputs {kData}; +enum UpSamplingNearestOpOutputs {kOut}; +} // namespace up_enum + +struct UpSamplingNearestParam : public dmlc::Parameter { + index_t scale; + DMLC_DECLARE_PARAMETER(UpSamplingNearestParam) { + DMLC_DECLARE_FIELD(scale) + .set_range(1, 1000) + .describe("Up sampling scale"); + } +}; // struct UpSamplingNearestParam + +template +class UpSamplingNearestOp : public Operator { + public: + explicit UpSamplingNearestOp(UpSamplingNearestParam p) { + this->param_ = p; + } + + virtual void Forward(const OpContext &ctx, + const std::vector &in_data, + const std::vector &req, + const std::vector &out_data, + const std::vector &aux_args) { + using namespace mshadow; + using namespace mshadow::expr; + CHECK_EQ(in_data.size(), 1); + CHECK_EQ(out_data.size(), 1); + Stream *s = ctx.get_stream(); + Tensor data = in_data[up_enum::kData].get(s); + Tensor out = out_data[up_enum::kOut].get(s); + Assign(out, req[up_enum::kOut], upsampling_nearest(data, param_.scale)); + } + + virtual void Backward(const OpContext &ctx, + const std::vector &out_grad, + const std::vector &in_data, + const std::vector &out_data, + const std::vector &req, + const std::vector &in_grad, + const std::vector &aux_args) { + using namespace mshadow; + using namespace mshadow::expr; + CHECK_EQ(out_grad.size(), 1); + CHECK_EQ(in_grad.size(), 1); + Stream *s = ctx.get_stream(); + Tensor grad = out_grad[up_enum::kOut].get(s); + Tensor input_grad = in_grad[up_enum::kData].get(s); + mshadow::Shape<2> in_shape = Shape2(input_grad.shape_[2], input_grad.shape_[3]); + Assign(input_grad, req[up_enum::kData], + static_cast(1.0f / param_.scale / param_.scale) * \ + pool(grad, + in_shape, + param_.scale, + param_.scale, + param_.scale)); + } + + private: + UpSamplingNearestParam param_; +}; // class UpSamplingNearestOp + +template +Operator *CreateOp(UpSamplingNearestParam param); + + +#if DMLC_USE_CXX11 +class UpSamplingNearestProp : public OperatorProperty { + public: + void Init(const std::vector >& kwargs) override { + param_.Init(kwargs); + } + + std::map GetParams() const override { + return param_.__DICT__(); + } + + bool InferShape(std::vector *in_shape, + std::vector *out_shape, + std::vector *aux_shape) const override { + CHECK_EQ(in_shape->size(), 1); + const TShape &dshape = (*in_shape)[0]; + CHECK_EQ(dshape.ndim(), 4) << \ + "UpSamplingNearest: Input data should be 4D in (batch, channel, y, x)"; + if (dshape.ndim() == 0) return false; + TShape oshape = dshape; + oshape[2] = dshape[2] * param_.scale; + oshape[3] = dshape[3] * param_.scale; + out_shape->clear(); + out_shape->push_back(oshape); + return true; + } + + OperatorProperty* Copy() const override { + auto ptr = new UpSamplingNearestProp(); + ptr->param_ = this->param_; + return ptr; + } + + std::string TypeString() const override { + return "UpSamplingNearest"; + } + + std::vector DeclareBackwardDependency( + const std::vector &out_grad, + const std::vector &in_data, + const std::vector &out_data) const override { + return {out_grad[up_enum::kOut]}; + } + + std::vector > BackwardInplaceOption( + const std::vector &out_grad, + const std::vector &in_data, + const std::vector &out_data, + const std::vector &in_grad) const override { + return {{in_data[up_enum::kData], in_grad[up_enum::kData]}}; + } + + Operator* CreateOperator(Context ctx) const override; + + private: + UpSamplingNearestParam param_; +}; // class UpSamplingNearestProp +#endif // DMLC_USE_CXX11 +} // namespace op +} // namespace mxnet + +#endif // MXNET_OPERATOR_UPSAMPLING_NEAREST_INL_H_ + diff --git a/src/operator/upsampling_nearest.cc b/src/operator/upsampling_nearest.cc new file mode 100644 index 000000000000..dc377974466d --- /dev/null +++ b/src/operator/upsampling_nearest.cc @@ -0,0 +1,29 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file upsampling_nearest.cc + * \brief + * \author Bing Xu +*/ + + +#include "./upsampling_nearest-inl.h" + +namespace mxnet { +namespace op { +template<> +Operator *CreateOp(UpSamplingNearestParam param) { + return new UpSamplingNearestOp(param); +} + +Operator* UpSamplingNearestProp::CreateOperator(Context ctx) const { + DO_BIND_DISPATCH(CreateOp, param_); +} + +DMLC_REGISTER_PARAMETER(UpSamplingNearestParam); + +MXNET_REGISTER_OP_PROPERTY(UpSamplingNearest, UpSamplingNearestProp) +.describe("Perform simple nearest neighboor up sampling to inputs") +.add_argument("data", "Symbol", "Input data to the up sampling operator.") +.add_arguments(UpSamplingNearestParam::__FIELDS__()); +} // namespace op +} // namespace mxnet diff --git a/src/operator/upsampling_nearest.cu b/src/operator/upsampling_nearest.cu new file mode 100644 index 000000000000..b6f64e42ead9 --- /dev/null +++ b/src/operator/upsampling_nearest.cu @@ -0,0 +1,19 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file upsampling.cc + * \brief + * \author Bing Xu +*/ + + +#include "./upsampling_nearest-inl.h" + +namespace mxnet { +namespace op { +template<> +Operator *CreateOp(UpSamplingNearestParam param) { + return new UpSamplingNearestOp(param); +} + +} // namespace op +} // namespace mxnet