Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 16 additions & 3 deletions python/mxnet/initializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@ def __call__(self, name, arr):
raise TypeError('name must be string')
if not isinstance(arr, NDArray):
raise TypeError('arr must be NDArray')
if name.endswith('bias'):
if name.startswith('upsampling'):
self._init_bilinear(name, arr)
elif name.endswith('bias'):
self._init_bias(name, arr)
elif name.endswith('gamma'):
self._init_gamma(name, arr)
Expand All @@ -39,7 +41,18 @@ def __call__(self, name, arr):
self._init_zero(name, arr)
else:
self._init_default(name, arr)
# pylint: disable=no-self-use, missing-docstring
# pylint: disable=no-self-use, missing-docstring, invalid-name
def _init_bilinear(self, _, arr):
weight = np.zeros(np.prod(arr.shape), dtype='float32')
shape = arr.shape
f = shape[3] / 2.
c = (2 * f - 1 - f % 2) / (2. * f)
for i in range(np.prod(shape)):
x = i % shape[3]
y = (i / shape[3]) % shape[2]
weight[i] = (1 - abs(x / f - c)) * (1 - abs(y / f - c))
arr[:] = weight.reshape(shape)

def _init_zero(self, _, arr):
arr[:] = 0.0

Expand All @@ -58,7 +71,7 @@ def _init_weight(self, name, arr):

def _init_default(self, name, _):
raise ValueError('Unknown initialization pattern for %s' % name)
# pylint: enable=no-self-use, missing-docstring
# pylint: enable=no-self-use, missing-docstring, invalid-name

class Uniform(Initializer):
"""Initialize the weight with uniform [-scale, scale]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
* \brief
* \author Bing Xu
*/
#ifndef MXNET_OPERATOR_UPSAMPLING_NEAREST_INL_H_
#define MXNET_OPERATOR_UPSAMPLING_NEAREST_INL_H_
#ifndef MXNET_OPERATOR_UPSAMPLING_INL_H_
#define MXNET_OPERATOR_UPSAMPLING_INL_H_

#include <dmlc/logging.h>
#include <dmlc/parameter.h>
Expand All @@ -21,23 +21,32 @@ namespace mxnet {
namespace op {

namespace up_enum {
enum UpSamplingNearestOpInputs {kData};
enum UpSamplingNearestOpOutputs {kOut};
enum UpSamplingOpInputs {kData, kWeight};
enum UpSamplingOpOutputs {kOut};
enum UpSamplingType {kNearest, kBilinear};
} // namespace up_enum

struct UpSamplingNearestParam : public dmlc::Parameter<UpSamplingNearestParam> {
struct UpSamplingParam : public dmlc::Parameter<UpSamplingParam> {
index_t scale;
DMLC_DECLARE_PARAMETER(UpSamplingNearestParam) {
index_t num_filter;
int sample_type;
DMLC_DECLARE_PARAMETER(UpSamplingParam) {
DMLC_DECLARE_FIELD(scale)
.set_range(1, 1000)
.describe("Up sampling scale");
DMLC_DECLARE_FIELD(num_filter)
.describe("input filter");
DMLC_DECLARE_FIELD(sample_type)
.add_enum("nearest", up_enum::kNearest)
.add_enum("bilinear", up_enum::kBilinear)
.describe("upsampling method");
}
}; // struct UpSamplingNearestParam
}; // struct UpSamplingParam

template<typename xpu>
class UpSamplingNearestOp : public Operator {
public:
explicit UpSamplingNearestOp(UpSamplingNearestParam p) {
explicit UpSamplingNearestOp(UpSamplingParam p) {
this->param_ = p;
}

Expand Down Expand Up @@ -81,15 +90,15 @@ class UpSamplingNearestOp : public Operator {
}

private:
UpSamplingNearestParam param_;
UpSamplingParam param_;
}; // class UpSamplingNearestOp

template<typename xpu>
Operator *CreateOp(UpSamplingNearestParam param);
Operator *CreateOp(UpSamplingParam param);


#if DMLC_USE_CXX11
class UpSamplingNearestProp : public OperatorProperty {
class UpSamplingProp : public OperatorProperty {
public:
void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) override {
param_.Init(kwargs);
Expand All @@ -99,14 +108,35 @@ class UpSamplingNearestProp : public OperatorProperty {
return param_.__DICT__();
}

std::vector<std::string> ListArguments() const override {
if (param_.sample_type == up_enum::kNearest) {
return {"data"};
} else {
return {"data", "weight"};
}
}

bool InferShape(std::vector<TShape> *in_shape,
std::vector<TShape> *out_shape,
std::vector<TShape> *aux_shape) const override {
CHECK_EQ(in_shape->size(), 1);
CHECK_GE(in_shape->size(), 1);
const TShape &dshape = (*in_shape)[0];
CHECK_EQ(dshape.ndim(), 4) << \
"UpSamplingNearest: Input data should be 4D in (batch, channel, y, x)";
if (dshape.ndim() == 0) return false;
if (param_.sample_type == up_enum::kNearest) {
CHECK_EQ(in_shape->size(), 1) << "Input:[data]";
CHECK_EQ(dshape.ndim(), 4) << \
"UpSamplingNearest: Input data should be 4D in (batch, channel, y, x)";
if (dshape.ndim() == 0) return false;
} else {
CHECK_EQ(in_shape->size(), 2) << "Input:[data, weight]";
CHECK_EQ(dshape.ndim(), 4) << \
"UpSamplingNearest: Input data should be 4D in (batch, channel, y, x)";
if (dshape.ndim() == 0) return false;
// param_.num_filter = dshape[1];
int kernel = 2 * param_.scale - param_.scale % 2;
SHAPE_ASSIGN_CHECK(*in_shape,
up_enum::kWeight,
mshadow::Shape4(dshape[1], 1, kernel, kernel));
}
TShape oshape = dshape;
oshape[2] = dshape[2] * param_.scale;
oshape[3] = dshape[3] * param_.scale;
Expand All @@ -116,38 +146,64 @@ class UpSamplingNearestProp : public OperatorProperty {
}

OperatorProperty* Copy() const override {
auto ptr = new UpSamplingNearestProp();
auto ptr = new UpSamplingProp();
ptr->param_ = this->param_;
return ptr;
}

std::string TypeString() const override {
return "UpSamplingNearest";
return "UpSampling";
}

std::vector<int> DeclareBackwardDependency(
const std::vector<int> &out_grad,
const std::vector<int> &in_data,
const std::vector<int> &out_data) const override {
return {out_grad[up_enum::kOut]};
if (param_.sample_type == up_enum::kNearest) {
return {out_grad[up_enum::kOut]};
} else {
return {out_grad[up_enum::kOut], in_data[up_enum::kData], in_data[up_enum::kWeight]};
}
}

std::vector<std::pair<int, void*> > BackwardInplaceOption(
const std::vector<int> &out_grad,
const std::vector<int> &in_data,
const std::vector<int> &out_data,
const std::vector<void*> &in_grad) const override {
return {{in_data[up_enum::kData], in_grad[up_enum::kData]}};
if (param_.sample_type == up_enum::kNearest) {
return {{in_data[up_enum::kData], in_grad[up_enum::kData]}};
} else {
return {};
}
}

std::vector<ResourceRequest> ForwardResource(
const std::vector<TShape> &in_shape) const override {
if (param_.sample_type == up_enum::kNearest) {
return {};
} else {
return {ResourceRequest::kTempSpace};
}
}

std::vector<ResourceRequest> BackwardResource(
const std::vector<TShape> &in_shape) const override {
if (param_.sample_type == up_enum::kNearest) {
return {};
} else {
return {ResourceRequest::kTempSpace};
}
}

Operator* CreateOperator(Context ctx) const override;

private:
UpSamplingNearestParam param_;
}; // class UpSamplingNearestProp
UpSamplingParam param_;
}; // class UpSamplingProp
#endif // DMLC_USE_CXX11
} // namespace op
} // namespace mxnet

#endif // MXNET_OPERATOR_UPSAMPLING_NEAREST_INL_H_
#endif // MXNET_OPERATOR_UPSAMPLING_INL_H_

49 changes: 49 additions & 0 deletions src/operator/upsampling.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
/*!
* Copyright (c) 2015 by Contributors
* \file upsampling_nearest.cc
* \brief
* \author Bing Xu
*/

#include "./deconvolution-inl.h"
#include "./upsampling-inl.h"

namespace mxnet {
namespace op {
template<>
Operator *CreateOp<cpu>(UpSamplingParam param) {
if (param.sample_type == up_enum::kNearest) {
return new UpSamplingNearestOp<cpu>(param);
} else if (param.sample_type == up_enum::kBilinear) {
DeconvolutionParam p;
int kernel = 2 * param.scale - param.scale % 2;
int stride = param.scale;
int pad = static_cast<int>(ceil((param.scale - 1) / 2.));
p.num_group = param.num_filter;
p.num_filter = param.num_filter;
p.no_bias = true;
int shape[] = {1, 1};
shape[0] = shape[1] = kernel;
p.kernel = TShape(shape, shape + 2);
shape[0] = shape[1] = stride;
p.stride = TShape(shape, shape + 2);
shape[0] = shape[1] = pad;
p.pad = TShape(shape, shape + 2);
return new DeconvolutionOp<cpu>(p);
} else {
LOG(FATAL) << "Unknown sample type";
return NULL;
}
}

Operator* UpSamplingProp::CreateOperator(Context ctx) const {
DO_BIND_DISPATCH(CreateOp, param_);
}

DMLC_REGISTER_PARAMETER(UpSamplingParam);

MXNET_REGISTER_OP_PROPERTY(UpSampling, UpSamplingProp)
.describe("Perform simple nearest neighboor up sampling to inputs")
.add_arguments(UpSamplingParam::__FIELDS__());
} // namespace op
} // namespace mxnet
40 changes: 40 additions & 0 deletions src/operator/upsampling.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/*!
* Copyright (c) 2015 by Contributors
* \file upsampling_nearest.cc
* \brief
* \author Bing Xu
*/

#include "./deconvolution-inl.h"
#include "./upsampling-inl.h"

namespace mxnet {
namespace op {
template<>
Operator *CreateOp<gpu>(UpSamplingParam param) {
if (param.sample_type == up_enum::kNearest) {
return new UpSamplingNearestOp<gpu>(param);
} else if (param.sample_type == up_enum::kBilinear) {
DeconvolutionParam p;
int kernel = 2 * param.scale - param.scale % 2;
int stride = param.scale;
int pad = static_cast<int>(ceil((param.scale - 1) / 2.));
p.num_group = param.num_filter;
p.num_filter = param.num_filter;
p.no_bias = true;
int shape[] = {1, 1};
shape[0] = shape[1] = kernel;
p.kernel = TShape(shape, shape + 2);
shape[0] = shape[1] = stride;
p.stride = TShape(shape, shape + 2);
shape[0] = shape[1] = pad;
p.pad = TShape(shape, shape + 2);
return new DeconvolutionOp<gpu>(p);
} else {
LOG(FATAL) << "Unknown sample type";
return NULL;
}
}

} // namespace op
} // namespace mxnet
29 changes: 0 additions & 29 deletions src/operator/upsampling_nearest.cc

This file was deleted.

19 changes: 0 additions & 19 deletions src/operator/upsampling_nearest.cu

This file was deleted.

2 changes: 1 addition & 1 deletion tests/travis/setup.sh
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,5 @@ if [ ${TRAVIS_OS_NAME} == "osx" ]; then
fi

if [ ${TASK} == "lint" ]; then
pip install cpplint pylint --user `whoami`
pip install cpplint 'pylint==1.4.4' 'astroid==1.3.6' --user `whoami`
fi