diff --git a/src/operator/nn/deconvolution-inl.h b/src/operator/nn/deconvolution-inl.h
index 58f9be702396..8724d4e5a366 100644
--- a/src/operator/nn/deconvolution-inl.h
+++ b/src/operator/nn/deconvolution-inl.h
@@ -34,8 +34,10 @@
 #include
 #include
 #include
+#include
 #include "../operator_common.h"
 #include "../linalg.h"
+#include "./im2col.h"
 
 namespace mxnet {
@@ -118,7 +120,7 @@ struct DeconvolutionParam : public dmlc::Parameter<DeconvolutionParam> {
   }
 
   template<int ndim>
-  void InferPad(mxnet::TShape input, index_t (&o_pad)[ndim], index_t (&o_adj)[ndim] ) const {
+  void InferPad(const mxnet::TShape &input, index_t (&o_pad)[ndim], index_t (&o_adj)[ndim]) const {
     // Modified by Li.bs
     // Use tag to control the calculation of pad
     bool bCal = false;
@@ -227,106 +229,122 @@ class DeconvolutionOp {
     size_t expected = param_.no_bias ? 2 : 3;
     CHECK_EQ(in_data.size(), expected);
     CHECK_EQ(out_data.size(), 1U);
+    LayerSetUp(in_data[deconv::kData].shape_, out_data[deconv::kOut].shape_);
     Stream<xpu> *s = ctx.get_stream<xpu>();
+#if defined(__CUDACC__)
+    CHECK_EQ(s->blas_handle_ownership_, Stream<gpu>::OwnHandle)
+        << "Must init cuBLAS handle in stream";
+#endif
     auto in_data_shape = in_data[deconv::kData].shape_;
-    Tensor<xpu, 4, DType> data = TBlobTo4DTensor(in_data[deconv::kData], s);
-    Tensor<xpu, 4, DType> out = TBlobTo4DTensor(out_data[deconv::kOut], s);
+    // G:  num of groups
+    // N:  num of batches
+    // C:  num of input channels
+    // IH: input height
+    // IW: input width
+    // KH: kernel height
+    // KW: kernel width
+    // OH: output height
+    // OW: output width
+    // OC: num of output channels
+
+    // input_4d:  (N, C, IH, IW)
+    // output_4d: (N, OC, OH, OW)
+    Tensor<xpu, 4, DType> input_4d = TBlobTo4DTensor(in_data[deconv::kData], s);
+    Tensor<xpu, 4, DType> output_4d = TBlobTo4DTensor(out_data[deconv::kOut], s);
     index_t o_pad[2], o_adj[2];
     if (param_.kernel.ndim() == 2) {
       param_.InferPad(mxnet::TShape({in_data_shape[2], in_data_shape[3]}), o_pad, o_adj);
     } else {
       index_t o_pad_1D[1], o_adj_1D[1];
+      // ndim of InferPad is deduced from the array size, so 1-D deconvolution
+      // must keep passing 1-element arrays; only the copy below is simplified
       param_.InferPad({in_data_shape[2]}, o_pad_1D, o_adj_1D);
-      o_pad[0] = 0;
-      o_pad[1] = o_pad_1D[0];
-      o_adj[0] = 0;
-      o_adj[1] = o_adj_1D[0];
+      o_pad[0] = o_pad_1D[0];
+      o_adj[0] = o_adj_1D[0];
     }
-    auto stride = param_.kernel.ndim() == 2 ? param_.stride : mxnet::TShape({1, param_.stride[0]});
-    auto dilate = param_.kernel.ndim() == 2 ? param_.dilate : mxnet::TShape({1, param_.dilate[0]});
-    auto kernel = param_.kernel.ndim() == 2 ? param_.kernel : mxnet::TShape({1, param_.kernel[0]});
-    auto kernel_size = kernel.Size();
-
-    Shape<3> wmat_shape =
-        Shape3(param_.num_group,
-               data.shape_[1] / param_.num_group,
-               param_.num_filter / param_.num_group * kernel_size);
-    Tensor<xpu, 3, DType> wmat =
-        in_data[deconv::kWeight].get_with_shape<xpu, 3, DType>(wmat_shape, s);
-#if defined(__CUDACC__)
-    CHECK_EQ(s->blas_handle_ownership_, Stream<gpu>::OwnHandle)
-        << "Must init CuBLAS handle in stream";
-#endif
-    const index_t nbatch = data.size(0);
-    Tensor<xpu, 1, DType> workspace =
-        ctx.requested[deconv::kTempSpace].get_space_typed<xpu, 1, DType>(
-            Shape1(this->InitTemp(out.shape_, data.shape_)), s);
-    for (index_t i = 0; i < nbatch; i += nstep_) {
-      const index_t step = std::min(nstep_, nbatch - i);
-      Tensor<xpu, 2, DType> temp_col = Tensor<xpu, 2, DType>(
-          workspace.dptr_,
-          Shape2(shape_colunit_[0],
-                 shape_colunit_[1] * step), s);
-      Tensor<xpu, 3, DType> temp_dst = Tensor<xpu, 3, DType>(
-          workspace.dptr_ + temp_col.shape_.Size(),
-          Shape3(shape_dstunit_[0],
-                 shape_dstunit_[1],
-                 shape_dstunit_[2] * step), s);
-      temp_dst = reshape(swapaxis<1, 0>(data.Slice(i, i + step)), temp_dst.shape_);
-      if (o_pad[0] == 0 && o_pad[1] == 0) {
-        temp_col = unpack_patch2col(out.Slice(i, i + step),
-                                    kernel[0],
-                                    kernel[1],
-                                    stride[0],
-                                    stride[1],
-                                    dilate[0],
-                                    dilate[1]);
-      } else {
-        temp_col = unpack_patch2col(pad(out.Slice(i, i + step),
-                                        o_pad[0], o_pad[1]),
-                                    kernel[0],
-                                    kernel[1],
-                                    stride[0],
-                                    stride[1],
-                                    dilate[0],
-                                    dilate[1]);
-      }
-      const index_t gstride = temp_col.size(0) / param_.num_group;
-      for (uint32_t gid = 0; gid < param_.num_group; ++gid) {
-        mshadow::Tensor<xpu, 2, DType> tmpc = temp_col.Slice(gstride * gid,
-                                                             gstride * (gid + 1));
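+
+    // The forward pass of deconvolution is the backward-data pass of the
+    // matching convolution: per sample, one gemm per group computes
+    //   col_buffer_3d[g] = weight_3d[g]^T * data_3d[g]
+    // and col2im then scatter-adds the columns back into the output image.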
+    auto stride = param_.kernel.ndim() == 2 ? param_.stride : TShape({1, param_.stride[0]});
+    auto dilate = param_.kernel.ndim() == 2 ? param_.dilate : TShape({1, param_.dilate[0]});
+    auto padding = param_.kernel.ndim() == 2 ?
+        TShape({o_pad[0], o_pad[1]}) : TShape({0, o_pad[0]});
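+
+    // The deconvolution weight blob is stored as (C, OC/G, KH, KW); viewing
+    // it as (G, C/G, OC/G * KH * KW) lets each group be handled by a single
+    // gemm against the corresponding (C/G, IH * IW) slice of the input.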
+    // weight_3d: (G, C/G, OC/G * KH * KW)
+    Tensor<xpu, 3, DType> weight_3d = in_data[deconv::kWeight].get_with_shape<xpu, 3, DType>(
+        Shape3(group_, conv_in_channels_ / group_, kernel_dim_), s);
+
+    Tensor<xpu, 1, DType> workspace = ctx.requested[deconv::kTempSpace]
+        .get_space_typed<xpu, 1, DType>(
+            Shape1(col_buffer_size_ + in_data[deconv::kData].shape_.Size()), s);
+
+    // calculate the shape of the column buffer
+    mxnet::TShape col_buffer_shape(num_spatial_axes_ + 1, 1);
+    col_buffer_shape[0] = conv_out_channels_ * param_.kernel.Size();
+    for (int i = 1; i < col_buffer_shape.ndim(); ++i) {
+      col_buffer_shape[i] = in_data[deconv::kData].shape_[i + 1];
+    }
+
+    // create a column buffer to hold the matrix product weight_3d^T * data_3d
+    TBlob col_buffer(workspace.dptr_, col_buffer_shape, xpu::kDevMask, DataType<DType>::kFlag);
+
+    // col_buffer_3d: (G, OC/G * KH * KW, IH * IW)
+    Tensor<xpu, 3, DType> col_buffer_3d = col_buffer.get_with_shape<xpu, 3, DType>(
+        Shape3(group_, kernel_dim_, conv_in_spatial_dim_), s);
+
+    for (index_t i = 0; i < num_; ++i) {
+      // data_3d: (G, C/G, IH * IW), a transposed copy of sample i staged in
+      // the workspace behind the column buffer
+      Tensor<xpu, 3, DType> data_3d = Tensor<xpu, 3, DType>(
+          workspace.dptr_ + col_buffer_size_,
+          Shape3(group_, input_4d.shape_[1] / group_, conv_in_spatial_dim_), s);
+      data_3d = reshape(swapaxis<1, 0>(input_4d.Slice(i, i + 1)), data_3d.shape_);
+
+      for (int g = 0; g < group_; ++g) {
         // Legacy approach shown here for comparison:
-        //   tmpc = dot(wmat[gid].T(), temp_dst[gid]);
-        linalg_gemm(wmat[gid], temp_dst[gid], tmpc, true, false, s);
-      }
-      if (o_pad[0] == 0 && o_pad[1] == 0) {
-        out.Slice(i, i + step) = pack_col2patch(temp_col,
-                                                out.Slice(i, i + step).shape_,
-                                                kernel[0],
-                                                kernel[1],
-                                                stride[0],
-                                                stride[1],
-                                                dilate[0],
-                                                dilate[1]);
-      } else {
-        Shape<4> pshape = out.Slice(i, i + step).shape_;
-        pshape[2] += 2 * o_pad[0];
-        pshape[3] += 2 * o_pad[1];
-        out.Slice(i, i + step) = crop(pack_col2patch(temp_col,
-                                                     pshape,
-                                                     kernel[0],
-                                                     kernel[1],
-                                                     stride[0],
-                                                     stride[1],
-                                                     dilate[0],
-                                                     dilate[1]),
-                                      out[i][0].shape_);
+        //   col_buffer_3d[g] = dot(weight_3d[g].T(), data_3d[g]);
+        linalg_gemm(weight_3d[g], data_3d[g], col_buffer_3d[g], true, false, s);
+      }
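+
+      // col2im is the exact adjoint of im2col: it scatter-adds the
+      // (OC * KH * KW, IH * IW) columns back into the (OC, OH, OW) output
+      // image of sample i, accumulating where kernel footprints overlap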
+      col2im(s, col_buffer.dptr<DType>(), out_data[deconv::kOut].shape_, col_buffer.shape_,
+             param_.kernel, padding, stride, dilate,
+             out_data[deconv::kOut].dptr<DType>() + i * output_dim_, req[deconv::kOut]);
     }
-    if (!param_.no_bias) {
+
+    if (bias_term_) {
       // add bias, broadcast bias to dim 1: channel
       Tensor<xpu, 1, DType> bias = in_data[deconv::kBias].get<xpu, 1, DType>(s);
-      out += mshadow::expr::broadcast<1>(bias, out.shape_);
+      output_4d += mshadow::expr::broadcast<1>(bias, output_4d.shape_);
     }
   }
 
@@ -337,112 +355,93 @@ class DeconvolutionOp {
                         const std::vector<TBlob> &in_grad) {
     using namespace mshadow;
     using namespace mshadow::expr;
-    // TODO(bing): check the BLAS Handle, be careful
     CHECK_EQ(out_grad.size(), 1U);
     size_t expected = param_.no_bias == 0 ? 3 : 2;
     CHECK_EQ(in_data.size(), expected);
     CHECK_EQ(in_grad.size(), expected);
     CHECK_EQ(req.size(), expected);
     CHECK_EQ(in_data[deconv::kWeight].CheckContiguous(), true);
-    // get data
+
+    LayerSetUp(in_grad[deconv::kData].shape_, out_grad[deconv::kOut].shape_);
     Stream<xpu> *s = ctx.get_stream<xpu>();
+#if defined(__CUDACC__)
+    CHECK_EQ(s->blas_handle_ownership_, Stream<gpu>::OwnHandle)
+        << "Must init cuBLAS handle in stream";
+#endif
+
     auto in_data_shape = in_data[deconv::kData].shape_;
-    Tensor<xpu, 4, DType> data = TBlobTo4DTensor(in_data[deconv::kData], s);
+    Tensor<xpu, 4, DType> data_4d = TBlobTo4DTensor(in_data[deconv::kData], s);
     Tensor<xpu, 4, DType> grad = TBlobTo4DTensor(out_grad[deconv::kOut], s);
     Tensor<xpu, 4, DType> gdata = TBlobTo4DTensor(in_grad[deconv::kData], s);
-
     index_t o_pad[2], o_adj[2];
     if (param_.kernel.ndim() == 2) {
       param_.InferPad(mxnet::TShape({in_data_shape[2], in_data_shape[3]}), o_pad, o_adj);
     } else {
       index_t o_pad_1D[1], o_adj_1D[1];
       param_.InferPad({in_data_shape[2]}, o_pad_1D, o_adj_1D);
-      o_pad[0] = 0;
-      o_pad[1] = o_pad_1D[0];
-      o_adj[0] = 0;
-      o_adj[1] = o_adj_1D[0];
+      o_pad[0] = o_pad_1D[0];
+      o_adj[0] = o_adj_1D[0];
     }
+    auto stride = param_.kernel.ndim() == 2 ? param_.stride : TShape({1, param_.stride[0]});
+    auto dilate = param_.kernel.ndim() == 2 ? param_.dilate : TShape({1, param_.dilate[0]});
+    auto padding = param_.kernel.ndim() == 2 ?
+        TShape({o_pad[0], o_pad[1]}) : TShape({0, o_pad[0]});
+
+    // weight_3d:  (G, C/G, OC/G * KH * KW)
+    Tensor<xpu, 3, DType> weight_3d = in_data[deconv::kWeight]
+        .get_with_shape<xpu, 3, DType>(Shape3(group_, conv_in_channels_ / group_, kernel_dim_), s);
+
+    // dweight_3d: (G, C/G, OC/G * KH * KW)
+    Tensor<xpu, 3, DType> dweight_3d = in_grad[deconv::kWeight]
+        .get_with_shape<xpu, 3, DType>(Shape3(group_, conv_in_channels_ / group_, kernel_dim_), s);
+
+    Tensor<xpu, 1, DType> workspace = ctx.requested[deconv::kTempSpace]
+        .get_space_typed<xpu, 1, DType>(Shape1(col_buffer_size_ + data_4d.shape_.Size()), s);
+
+    // calculate the shape of the column buffer
+    mxnet::TShape col_buffer_shape(num_spatial_axes_ + 1, 1);
+    col_buffer_shape[0] = conv_out_channels_ * param_.kernel.Size();
+    for (int i = 1; i < col_buffer_shape.ndim(); ++i) {
+      col_buffer_shape[i] = in_data[deconv::kData].shape_[i + 1];
+    }
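+
+    // The backward pass mirrors a convolution forward pass: im2col unfolds
+    // the output gradient into columns; per group,
+    //   dweight_3d[g] = data_3d[g] * col_buffer_3d[g]^T
+    // and the data gradient is weight_3d[g] * col_buffer_3d[g].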
-    auto stride = param_.kernel.ndim() == 2 ? param_.stride : mxnet::TShape({1, param_.stride[0]});
-    auto dilate = param_.kernel.ndim() == 2 ? param_.dilate : mxnet::TShape({1, param_.dilate[0]});
-    auto kernel = param_.kernel.ndim() == 2 ? param_.kernel : mxnet::TShape({1, param_.kernel[0]});
-    auto kernel_size = kernel.Size();
-
-    Shape<3> wmat_shape =
-        Shape3(param_.num_group,
-               data.shape_[1] / param_.num_group,
-               param_.num_filter / param_.num_group * kernel_size);
-    Tensor<xpu, 3, DType> wmat =
-        in_data[deconv::kWeight].get_with_shape<xpu, 3, DType>(wmat_shape, s);
-    Tensor<xpu, 3, DType> gwmat =
-        in_grad[deconv::kWeight].get_with_shape<xpu, 3, DType>(wmat_shape, s);
-#if defined(__CUDACC__)
-    CHECK_EQ(s->blas_handle_ownership_, Stream<gpu>::OwnHandle)
-        << "Must init CuBLAS handle in stream";
-#endif
-    const index_t nbatch = data.size(0);
-    Tensor<xpu, 1, DType> workspace =
-        ctx.requested[deconv::kTempSpace].get_space_typed<xpu, 1, DType>(
-            Shape1(this->InitTemp(grad.shape_, data.shape_)), s);
-    for (index_t i = 0; i < nbatch; i += nstep_) {
-      const index_t step = std::min(nstep_, nbatch - i);
-      Tensor<xpu, 2, DType> temp_col = Tensor<xpu, 2, DType>(
-          workspace.dptr_,
-          Shape2(shape_colunit_[0],
-                 shape_colunit_[1] * step), s);
-      Tensor<xpu, 3, DType> temp_dst = Tensor<xpu, 3, DType>(
-          workspace.dptr_ + temp_col.shape_.Size(),
-          Shape3(shape_dstunit_[0],
-                 shape_dstunit_[1],
-                 shape_dstunit_[2] * step), s);
-      temp_dst = reshape(swapaxis<1, 0>(data.Slice(i, i + step)), temp_dst.shape_);
-      if (o_pad[0] == 0 && o_pad[1] == 0) {
-        temp_col = unpack_patch2col(grad.Slice(i, i + step),
-                                    kernel[0],
-                                    kernel[1],
-                                    stride[0],
-                                    stride[1],
-                                    dilate[0],
-                                    dilate[1]);
-      } else {
-        temp_col = unpack_patch2col(pad(grad.Slice(i, i + step), o_pad[0], o_pad[1]),
-                                    kernel[0],
-                                    kernel[1],
-                                    stride[0],
-                                    stride[1],
-                                    dilate[0],
-                                    dilate[1]);
-      }
-      const index_t gstride = temp_col.size(0) / param_.num_group;
-      for (uint32_t gid = 0; gid < param_.num_group; ++gid) {
-        Tensor<xpu, 2, DType> tmpc = temp_col.Slice(gstride * gid, gstride * (gid + 1));
-        if (i == 0) {
-          Tensor<xpu, 2, DType> tmp_gwmat = gwmat[gid];
-          // Legacy approach shown here for comparison:
-          //   Assign(tmp_gwmat, req[deconv::kWeight], dot(temp_dst[gid], tmpc.T()));
-          linalg_gemm(temp_dst[gid], tmpc, tmp_gwmat, false, true, s, req[deconv::kWeight]);
-        } else {
-          // Legacy approach shown here for comparison:
-          //   gwmat[gid] += dot(temp_dst[gid], tmpc.T());
-          linalg_gemm(temp_dst[gid], tmpc, gwmat[gid], false, true, s, kAddTo);
-        }
+    // create a column buffer to hold the im2col'ed output gradient
+    TBlob col_buffer(workspace.dptr_, col_buffer_shape, xpu::kDevMask, DataType<DType>::kFlag);
+
+    // col_buffer_3d: (G, OC/G * KH * KW, IH * IW)
+    Tensor<xpu, 3, DType> col_buffer_3d = col_buffer.get_with_shape<xpu, 3, DType>(
+        Shape3(group_, kernel_dim_, conv_in_spatial_dim_), s);
+
+    for (index_t i = 0; i < num_; ++i) {
+      // data_3d: (G, C/G, IH * IW)
+      Tensor<xpu, 3, DType> data_3d = Tensor<xpu, 3, DType>(
+          workspace.dptr_ + col_buffer_size_,
+          Shape3(group_, data_4d.shape_[1] / group_, conv_in_spatial_dim_), s);
+      data_3d = reshape(swapaxis<1, 0>(data_4d.Slice(i, i + 1)), data_3d.shape_);
+
+      // convert the output gradient of sample i to a column buffer
+      im2col(s, out_grad[deconv::kOut].dptr<DType>() + i * output_dim_,
+             out_grad[deconv::kOut].shape_, col_buffer.shape_,
+             param_.kernel, padding, stride, dilate, col_buffer.dptr<DType>());
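+
+      // the weight gradient is accumulated over the batch: honour the
+      // requested write mode for sample 0, then add for the remaining samples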
+      for (int g = 0; g < group_; ++g) {
+        auto request = (i == 0) ? req[deconv::kWeight] : kAddTo;
+        // Legacy approach shown here for comparison:
+        //   dweight_3d[g] += dot(data_3d[g], col_buffer_3d[g].T());
+        linalg_gemm(data_3d[g], col_buffer_3d[g], dweight_3d[g], false, true, s, request);
       }
 
       if (req[deconv::kData] == kWriteTo || req[deconv::kData] == kWriteInplace ||
           req[deconv::kData] == kAddTo) {
-        for (uint32_t gid = 0; gid < param_.num_group; ++gid) {
-          Tensor<xpu, 2, DType> tmpc = temp_col.Slice(gstride * gid, gstride * (gid + 1));
+        for (int g = 0; g < group_; ++g) {
           // Legacy approach shown here for comparison:
-          //   temp_dst[gid] = dot(wmat[gid], tmpc);
-          linalg_gemm(wmat[gid], tmpc, temp_dst[gid], false, false, s);
+          //   data_3d[g] = dot(weight_3d[g], col_buffer_3d[g]);
+          linalg_gemm(weight_3d[g], col_buffer_3d[g], data_3d[g], false, false, s);
         }
-        Assign(gdata.Slice(i, i + step),
+        Assign(gdata.Slice(i, i + 1),
               req[deconv::kData],
-               (swapaxis<1, 0>(reshape(temp_dst,
-                                       mshadow::Shape4(gdata.shape_[1],
-                                                       step,
-                                                       gdata.size(2),
-                                                       gdata.size(3))))));
+               (swapaxis<1, 0>(reshape(data_3d,
+                                       Shape4(gdata.shape_[1],
+                                              1,
+                                              gdata.size(2),
+                                              gdata.size(3))))));
       }
     }
     if (!param_.no_bias) {
@@ -452,28 +451,6 @@ class DeconvolutionOp {
   }
 
  private:
-  inline index_t InitTemp(const mshadow::Shape<4> &ishape,
-                          const mshadow::Shape<4> &oshape) {
-    const int ksize = param_.kernel.Size();
-    shape_colunit_ = mshadow::Shape2(ishape[1] * ksize,
-                                     oshape[2] * oshape[3]);
-    shape_dstunit_ = mshadow::Shape3(param_.num_group,
-                                     oshape[1] / param_.num_group,
-                                     oshape[2] * oshape[3]);
-    // See convolution for workspace calculations. nstep_ will be the effective batch size
-    nstep_ = std::max(
-        std::min(param_.workspace /
-                 (shape_colunit_.Size() + shape_dstunit_.Size()), ishape[0]),
-        1);
-
-    mshadow::Shape<2> scol = mshadow::Shape2(shape_colunit_[0],
-                                             shape_colunit_[1] * nstep_);
-    mshadow::Shape<3> sdst = mshadow::Shape3(shape_dstunit_[0],
-                                             shape_dstunit_[1],
-                                             shape_dstunit_[2] * nstep_);
-    index_t required_size = scol.Size() + sdst.Size();
-    return required_size;
-  }
 
   inline Tensor<xpu, 4, DType> TBlobTo4DTensor(const TBlob &tb, Stream<xpu> *s) {
     using namespace mshadow;
@@ -484,10 +461,48 @@ class DeconvolutionOp {
       Shape4(tb.shape_[0], tb.shape_[1], 1, tb.shape_[2]), s);
   }
 
+  void LayerSetUp(const mxnet::TShape& ishape, const mxnet::TShape& oshape) {
+    channel_axis_ = 1;  // hard code channel axis
+    const index_t first_spatial_axis = channel_axis_ + 1;
+    const int num_axes = param_.kernel.ndim() + 2;
+    num_spatial_axes_ = num_axes - first_spatial_axis;
+
+    // batch size
+    num_ = ishape[0];
+    // number of input channels
+    channels_ = ishape[1];
+    group_ = param_.num_group;
+    conv_out_channels_ = param_.num_filter;
+    conv_in_channels_ = channels_;
+    bias_term_ = !param_.no_bias;
+    kernel_dim_ = conv_out_channels_ / group_ * param_.kernel.Size();
+    weight_offset_ = conv_out_channels_ * kernel_dim_ / group_;
+    conv_out_spatial_dim_ = oshape.ProdShape(2, oshape.ndim());
+    conv_in_spatial_dim_ = ishape.ProdShape(2, ishape.ndim());
+    // size of the column buffer used for storing im2col-ed pixels
+    col_buffer_size_ = kernel_dim_ * group_ * conv_in_spatial_dim_;
+    // input/output image size (#channels * height * width)
+    input_dim_ = ishape.ProdShape(1, ishape.ndim());
+    output_dim_ = oshape.ProdShape(1, oshape.ndim());
+  }
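+
+  // Worked example, matching test_deconv_dilation below: N = 2, C = OC = 1,
+  // G = 1, 3x3 kernel, 3x3 input, pad 1, stride 1, dilate 2 gives
+  // kernel_dim_ = 1 * 9 = 9, conv_in_spatial_dim_ = 9, col_buffer_size_ = 81,
+  // input_dim_ = 9 and output_dim_ = 25 (a 5x5 output image).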
+
+ private:
   DeconvolutionParam param_;
-  mshadow::Shape<2> shape_colunit_;
-  mshadow::Shape<3> shape_dstunit_;
-  index_t nstep_;
+  index_t channel_axis_;          // channel axis of the input
+  index_t channels_;              // number of channels of the input image
+  index_t num_spatial_axes_;      // number of spatial axes
+  index_t num_;                   // batch size
+  index_t group_;                 // number of groups
+  index_t conv_out_channels_;     // number of output channels (num_filter)
+  index_t conv_out_spatial_dim_;  // number of pixels of output image per channel
+  index_t conv_in_spatial_dim_;   // number of pixels of input image per channel
+  index_t conv_in_channels_;      // number of input channels
+  index_t kernel_dim_;            // number of output channels per group * kernel size
+  index_t weight_offset_;         // number of output channels per group * kernel_dim_
+  index_t col_buffer_size_;       // size of the column buffer for one sample
+  index_t input_dim_;             // input image size (C * IH * IW)
+  index_t output_dim_;            // output image size (OC * OH * OW)
+  bool bias_term_;                // whether the layer has a bias term
 };  // class DeconvolutionOp
 
 template<typename xpu>
diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py
index efa04f4fa47a..d1bad8c0c71f 100644
--- a/tests/python/unittest/test_gluon.py
+++ b/tests/python/unittest/test_gluon.py
@@ -503,6 +503,40 @@ def test_deconv():
     #     layer = nn.Conv3DTranspose(16, (3, 3, 3), layout='NDHWC', in_channels=4)
     #     # check_layer_forward(layer, (1, 10, 10, 10, 4))
 
+@with_seed()
+def test_deconv_dilation():
+    data = mx.nd.array((((0, 0, 0),
+                         (0, 1, 0),
+                         (0, 0, 0)),
+                        ((0, 0, 0),
+                         (0, 2, 0),
+                         (0, 0, 0))))
+    kernel = mx.nd.array(((1, 2, 3),
+                          (4, 5, 6),
+                          (7, 8, 9)))
+    data_batch = data.expand_dims(1)
+    weight = kernel.expand_dims(0).expand_dims(0)
+    layer = nn.Conv2DTranspose(in_channels=1, channels=1,
+                               kernel_size=(3, 3), padding=(1, 1),
+                               strides=(1, 1), dilation=(2, 2))
+    layer.initialize()
+    layer.weight.set_data(weight)
+    out = layer(data_batch).asnumpy()
+    # output size: (I - 1) * s - 2 * p + (K - 1) * d + 1 = 2 - 2 + 4 + 1 = 5;
+    # each unit impulse is replaced by the kernel spread out by the dilation
+    expected = np.array([[[[1., 0., 2., 0., 3.],
+                           [0., 0., 0., 0., 0.],
+                           [4., 0., 5., 0., 6.],
+                           [0., 0., 0., 0., 0.],
+                           [7., 0., 8., 0., 9.]]],
+                         [[[2., 0., 4., 0., 6.],
+                           [0., 0., 0., 0., 0.],
+                           [8., 0., 10., 0., 12.],
+                           [0., 0., 0., 0., 0.],
+                           [14., 0., 16., 0., 18.]]]])
+    assert_almost_equal(out, expected)
+
 
 @with_seed()
 def test_pool():
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 7db07596d7f8..8c03d7c8c933 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -1291,7 +1291,7 @@ def test_abs():
     assert_almost_equal(out, npout)
 
     out_grad = mx.nd.empty(shape)
-    out_grad[:] = 2;
+    out_grad[:] = 2
     npout_grad = out_grad.asnumpy()
     npout_grad = npout_grad * np.sign(data_tmp)
     exe_test.backward(out_grad)