diff --git a/3rdparty/mshadow/mshadow/extension/spatial_upsampling_nearest.h b/3rdparty/mshadow/mshadow/extension/spatial_upsampling_nearest.h index 534fbdd9ebe0..c194a1266d63 100644 --- a/3rdparty/mshadow/mshadow/extension/spatial_upsampling_nearest.h +++ b/3rdparty/mshadow/mshadow/extension/spatial_upsampling_nearest.h @@ -11,7 +11,7 @@ namespace mshadow { namespace expr { -/*! \brief nearest neighboor upsampling +/*! \brief nearest neighbor upsampling * out(x, y) = in(int(x / scale_x), int(y / scale_y)) * \tparam SrcExp source expression * \tparam DType data type @@ -24,23 +24,25 @@ struct UpSamplingNearestExp : /*! \brief source oprand */ const SrcExp &src_; /*! \brief up sampling scale */ - index_t scale_; + index_t scale_h_; + index_t scale_w_; + /*! \brief constructor */ - UpSamplingNearestExp(const SrcExp &src, index_t scale) - : src_(src), scale_(scale) { + UpSamplingNearestExp(const SrcExp &src, index_t scale_h, index_t scale_w) + : src_(src), scale_h_(scale_h), scale_w_(scale_w) { this->shape_ = ShapeCheck::Check(src_); - this->shape_[srcdim - 2] *= scale_; - this->shape_[srcdim - 1] *= scale_; + this->shape_[srcdim - 2] *= scale_h; + this->shape_[srcdim - 1] *= scale_w; } }; template inline UpSamplingNearestExp::kDim> -upsampling_nearest(const Exp &src, index_t scale) { +upsampling_nearest(const Exp &src, index_t scale_h, index_t scale_w) { TypeCheckPass::kDim >= 2> ::Error_Expression_Does_Not_Meet_Dimension_Req(); - return UpSamplingNearestExp::kDim>(src.self(), scale); + return UpSamplingNearestExp::kDim>(src.self(), scale_h, scale_w); } template @@ -48,23 +50,29 @@ struct Plan, DType> { public: explicit Plan(const UpSamplingNearestExp &e) : src_(MakePlan(e.src_)), - scale_(e.scale_), + scale_h_(e.scale_h_), + scale_w_(e.scale_w_), new_height_(e.shape_[srcdim - 2]), - src_height_(static_cast(e.shape_[srcdim - 2] / e.scale_)) {} + new_width_(e.shape_[srcdim - 1]), + src_height_(static_cast(e.shape_[srcdim - 2] / e.scale_h_)), + src_width_(static_cast(e.shape_[srcdim - 1] / e.scale_w_)) {} MSHADOW_XINLINE DType Eval(index_t i, index_t j) const { const index_t x = j; const index_t y = i % new_height_; const index_t c = i / new_height_; - const index_t h = static_cast(y / scale_); - const index_t w = static_cast(x / scale_); + const index_t h = static_cast(y / scale_h_); + const index_t w = static_cast(x / scale_w_); return src_.Eval(c * src_height_ + h, w); } private: Plan src_; - const index_t scale_; + const index_t scale_h_; + const index_t scale_w_; const index_t new_height_; + const index_t new_width_; const index_t src_height_; + const index_t src_width_; }; } // namespace expr } // namespace mshadow diff --git a/contrib/clojure-package/test/org/apache/clojure_mxnet/operator_test.clj b/contrib/clojure-package/test/org/apache/clojure_mxnet/operator_test.clj index 5e1b127d18bd..641f1236e414 100644 --- a/contrib/clojure-package/test/org/apache/clojure_mxnet/operator_test.clj +++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/operator_test.clj @@ -590,7 +590,7 @@ scale (range 1 4) num-shape (range 1 4) base (range 1 4)] - (let [shape-vecs (mapv (fn [i] [1 3 (* base root-scale (int (Math/pow scale (- (dec num-shape) i)))) + (let [shape-vecs (mapv (fn [i] [1 3 (* base root-scale (int (Math/pow scale (- (dec num-shape) i)))) (* base root-scale (int (Math/pow scale (- (dec num-shape) i))))]) (range 0 num-shape))] (check-nearest-up-sampling-with-shape {:shape-vecs shape-vecs :scale scale :root-scale root-scale}))))) diff --git a/src/operator/nn/upsampling-inl.h b/src/operator/nn/upsampling-inl.h index 8219e3e9bd8d..63742b091a00 100644 --- a/src/operator/nn/upsampling-inl.h +++ b/src/operator/nn/upsampling-inl.h @@ -48,7 +48,7 @@ enum UpSamplingMultiInputMode {kConcat, kSum}; } // namespace up_enum struct UpSamplingParam : public dmlc::Parameter { - int scale; + mxnet::TShape scale; int num_filter; int sample_type; int num_args; @@ -56,8 +56,10 @@ struct UpSamplingParam : public dmlc::Parameter { uint64_t workspace; DMLC_DECLARE_PARAMETER(UpSamplingParam) { DMLC_DECLARE_FIELD(scale) - .set_range(1, 1000) - .describe("Up sampling scale"); + .set_default(mxnet::TShape()) + .describe("Up sampling scale. Integer or tuple of integers. " + "Different scale per dimension is allowed only for " + "nearest neighbor upsampling."); DMLC_DECLARE_FIELD(num_filter) .describe("Input filter. Only used by bilinear sample_type." "Since bilinear upsampling uses deconvolution, num_filters " @@ -84,6 +86,21 @@ struct UpSamplingParam : public dmlc::Parameter { } }; // struct UpSamplingParam +inline std::vector scaleComp(const UpSamplingParam ¶m) { + std::vector scaleArr{ 1, 1 }; + if (param.scale.ndim() == 1) { + scaleArr[0] = param.scale[0]; + scaleArr[1] = param.scale[0]; + } else if (param.scale.ndim() == 2) { + scaleArr[0] = param.scale[0]; + scaleArr[1] = param.scale[1]; + } else if (param.scale.ndim() == 4) { + scaleArr[0] = param.scale[2]; + scaleArr[1] = param.scale[3]; + } + return scaleArr; +} + template void UpSamplingForward(const OpContext &ctx, const UpSamplingParam ¶m, const std::vector &in_data, @@ -103,21 +120,27 @@ void UpSamplingForward(const OpContext &ctx, const UpSamplingParam ¶m, for (int i = 0; i < param.num_args; ++i) { Tensor data = in_data[i].get(s); int end = begin + data.size(1); - int scale = out_data[up_enum::kOut].size(2)/in_data[i].size(2); + // 3rd dimension of TBlob + int scale_h = out_data[up_enum::kOut].size(2)/in_data[i].size(2); + // 4th dimension of TBlob + int scale_w = out_data[up_enum::kOut].size(3)/in_data[i].size(3); if (param.multi_input_mode == up_enum::kSum) { if (i == 0) { - Assign(out, req[up_enum::kOut], upsampling_nearest(data, scale)); + Assign(out, req[up_enum::kOut], upsampling_nearest(data, scale_h, scale_w)); } else { - out += upsampling_nearest(data, scale); + out += upsampling_nearest(data, scale_h, scale_w); } } else { - Assign(slice<1>(out, begin, end), req[up_enum::kOut], upsampling_nearest(data, scale)); + Assign(slice<1>(out, begin, end), + req[up_enum::kOut], + upsampling_nearest(data, scale_h, scale_w)); } begin = end; } } else { Tensor data = in_data[up_enum::kData].get(s); - Assign(out, req[up_enum::kOut], upsampling_nearest(data, param.scale)); + std::vector scale_hw = scaleComp(param); + Assign(out, req[up_enum::kOut], upsampling_nearest(data, scale_hw[0], scale_hw[1])); } } @@ -136,44 +159,49 @@ void UpSamplingBackward(const OpContext &ctx, const UpSamplingParam ¶m, Tensor input_grad = in_grad[i].get(s); mshadow::Shape<2> in_shape = Shape2(input_grad.shape_[2], input_grad.shape_[3]); int end = begin + input_grad.size(1); - int scale = grad.size(2)/in_shape[0]; + int scale_h = grad.size(2)/in_shape[0]; + int scale_w = grad.size(3)/in_shape[1]; if (param.multi_input_mode == up_enum::kSum) { Assign(input_grad, req[i], pool(grad, in_shape, - scale, - scale, - scale, - scale)); + scale_h, + scale_w, + scale_h, + scale_w)); } else { Assign(input_grad, req[i], pool(slice<1>(grad, begin, end), in_shape, - scale, - scale, - scale, - scale)); + scale_h, + scale_w, + scale_h, + scale_w)); } begin = end; } } else { Tensor input_grad = in_grad[up_enum::kData].get(s); mshadow::Shape<2> in_shape = Shape2(input_grad.shape_[2], input_grad.shape_[3]); + std::vector scale_hw = scaleComp(param); Assign(input_grad, req[up_enum::kData], pool(grad, in_shape, - param.scale, - param.scale, - param.scale, - param.scale)); + scale_hw[0], + scale_hw[1], + scale_hw[0], + scale_hw[1])); } } static inline DeconvolutionParam GetDeconvolutionParam(const UpSamplingParam& param) { DeconvolutionParam p = DeconvolutionParam(); - int kernel = 2 * param.scale - param.scale % 2; - int stride = param.scale; - int pad = static_cast(ceil((param.scale - 1) / 2.)); + std::vector scale_hw = scaleComp(param); + CHECK_EQ(scale_hw[0], scale_hw[1]) << + "UpSamplingBilinear: Scale should be the same along all dimensions for bilinear upsampling"; + int kernel = static_cast(2.0 * scale_hw[0] - ::fmod(scale_hw[0], 2)); + int stride = scale_hw[0]; + int pad = static_cast(ceil((scale_hw[0] - 1) / 2.)); p.workspace = param.workspace; p.num_group = param.num_filter; p.num_filter = param.num_filter; diff --git a/src/operator/nn/upsampling.cc b/src/operator/nn/upsampling.cc index 971ff6ad560b..dd3dc43d4e25 100644 --- a/src/operator/nn/upsampling.cc +++ b/src/operator/nn/upsampling.cc @@ -37,13 +37,16 @@ static bool UpSamplingShape(const nnvm::NodeAttrs& attrs, CHECK_GE(in_shape->size(), 1U); const mxnet::TShape &dshape = (*in_shape)[0]; mxnet::TShape oshape = dshape; + std::vector scale_hw = scaleComp(param_); + int scale_h = scale_hw[0]; + int scale_w = scale_hw[1]; if (param_.sample_type == up_enum::kNearest) { CHECK_EQ(in_shape->size(), static_cast(param_.num_args)); oshape[1] = 0; for (auto& shape : *in_shape) { CHECK_EQ(shape.ndim(), 4U) << \ "UpSamplingNearest: Input data should be 4D in (batch, channel, y, x)"; - int oh = dshape[2]*param_.scale, ow = dshape[3]*param_.scale; + int oh = dshape[2] * scale_h, ow = dshape[3] * scale_w; CHECK_EQ(oh%shape[2], 0U) << "UpSamplingNearest: input height of " << shape[2] << \ "does not divide output height of " << oh; CHECK_EQ(ow%shape[3], 0U) << "UpSamplingNearest: input width of " << shape[3] << \ @@ -58,17 +61,19 @@ static bool UpSamplingShape(const nnvm::NodeAttrs& attrs, } } else { CHECK_EQ(in_shape->size(), 2U) << "Input:[data, weight]"; + CHECK_EQ(scale_h, scale_w) << + "UpSamplingBilinear: Scale should be the same along all dimensions for bilinear upsampling"; CHECK_EQ(dshape.ndim(), 4U) << \ "UpSamplingBilinear: Input data should be 4D in (batch, channel, y, x)"; if (!shape_is_known(dshape)) return false; - int kernel = 2 * param_.scale - param_.scale % 2; + int kernel = static_cast(2.0 * scale_h - (scale_h & 1)); SHAPE_ASSIGN_CHECK(*in_shape, up_enum::kWeight, mshadow::Shape4(dshape[1], 1, kernel, kernel)); oshape = dshape; } - oshape[2] = dshape[2] * param_.scale; - oshape[3] = dshape[3] * param_.scale; + oshape[2] = dshape[2] * scale_h; + oshape[3] = dshape[3] * scale_w; out_shape->clear(); out_shape->push_back(oshape); return true; diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index f0e16b2729a2..cd66fe48d317 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -1559,15 +1559,23 @@ def check_deconvolution_forward_with_bias(shape=(1, 16, 5, 5), num_filter=32, nu def check_nearest_upsampling_with_shape(shapes, scale, root_scale): arr = {'arg_%d'%i: mx.random.uniform(-10.0, 10.0, shape, ctx=mx.cpu()).copyto(default_context()) for i, shape in zip(range(len(shapes)), shapes)} arr_grad = {'arg_%d'%i: mx.nd.zeros(shape) for i, shape in zip(range(len(shapes)), shapes)} - up = mx.sym.UpSampling(*[mx.sym.Variable('arg_%d'%i) for i in range(len(shapes))], sample_type='nearest', scale=root_scale) exe = up.bind(default_context(), args=arr, args_grad=arr_grad) exe.forward(is_train=True) exe.backward(exe.outputs) for k in range(len(shapes)): name = 'arg_%d'%k - assert_allclose(arr[name].asnumpy()*root_scale**2*scale**(2*k), arr_grad[name].asnumpy(), rtol=1e-4) - + out = arr_grad[name].asnumpy() + root_h = root_w = 1 + if type(root_scale) is int: + root_h = root_w = root_scale + elif len(root_scale) == 1: + root_h = root_w = root_scale[0] + elif len(root_scale) >= 2: + root_h = root_scale[0] + root_w = root_scale[1] + exp = arr[name].asnumpy() * root_h * root_w * scale ** (2 * k) + assert_allclose(exp, out, rtol=1.5e-4) def check_bilinear_upsampling_with_shape(data_shape, weight_shape, scale, root_scale, num_filter): def _init_bilinear(arr, f): @@ -1597,13 +1605,27 @@ def _init_bilinear(arr, f): assert out.shape == data_shape[:2] + target_shape +""" +The test cases include integer, tuple, +and empty tuple scales on up to 3 shapes +at once with the shapes having various sizes +for their heights and widths +""" @with_seed() def test_nearest_upsampling(): - for root_scale in [1,2,3]: - for scale in [1,2,3]: - for num_shape in [1,2,3]: - for base in [1,2,3]: - shapes = [(1,3,base*root_scale*scale**(num_shape-1-i),base*root_scale*scale**(num_shape-1-i)) for i in range(num_shape)] + for root_scale in [1, 2, (3), rand_shape_nd(1, 10), (5,1), (2,3), rand_shape_nd(2, 10), rand_shape_nd(2, 10), ()]: + for scale in [1, 2, 3]: + for num_shape in [1, 2, 3]: + for base in [1, 2, 3]: + root_h = root_w = 1 + if type(root_scale) is int: + root_h = root_w = root_scale + elif len(root_scale) == 1: + root_h = root_w = root_scale[0] + elif len(root_scale) >= 2: + root_h = root_scale[0] + root_w = root_scale[1] + shapes = [(1, 3, base*root_h*scale**(num_shape-1-i), base*root_w*scale**(num_shape-1-i)) for i in range(num_shape)] check_nearest_upsampling_with_shape(shapes, scale, root_scale)