diff --git a/python/mxnet/ndarray/numpy_extension/_op.py b/python/mxnet/ndarray/numpy_extension/_op.py index 63ff99ef3721..144d3ec28b35 100644 --- a/python/mxnet/ndarray/numpy_extension/_op.py +++ b/python/mxnet/ndarray/numpy_extension/_op.py @@ -617,9 +617,9 @@ def convolution(data=None, weight=None, bias=None, kernel=None, stride=None, dil @set_module('mxnet.ndarray.numpy_extension') def deconvolution(data=None, weight=None, bias=None, kernel=None, stride=None, dilate=None, pad=None, adj=None, target_shape=None, num_filter=1, num_group=1, - workspace=512, no_bias=False, cudnn_tune=None, + workspace=1024, no_bias=False, cudnn_tune=None, cudnn_off=False, layout=None): - r"""Computes 1D or 2D transposed convolution (aka fractionally strided convolution) of + r"""Computes 1D, 2D or 3D transposed convolution (aka fractionally strided convolution) of the input tensor. This operation can be seen as the gradient of Convolution operation with respect to its input. Convolution usually reduces the size of the input. Transposed convolution works the other way, going from a smaller input diff --git a/python/mxnet/numpy_extension/_op.py b/python/mxnet/numpy_extension/_op.py index 4124988c1536..61a1cce4d763 100644 --- a/python/mxnet/numpy_extension/_op.py +++ b/python/mxnet/numpy_extension/_op.py @@ -586,9 +586,9 @@ def convolution(data=None, weight=None, bias=None, kernel=None, stride=None, dil @set_module('mxnet.numpy_extension') def deconvolution(data=None, weight=None, bias=None, kernel=None, stride=None, dilate=None, pad=None, adj=None, target_shape=None, num_filter=1, num_group=1, - workspace=512, no_bias=False, cudnn_tune=None, + workspace=1024, no_bias=False, cudnn_tune=None, cudnn_off=False, layout=None): - r"""Computes 1D or 2D transposed convolution (aka fractionally strided convolution) of + r"""Computes 1D, 2D or 3D transposed convolution (aka fractionally strided convolution) of the input tensor. This operation can be seen as the gradient of Convolution operation with respect to its input. Convolution usually reduces the size of the input. Transposed convolution works the other way, going from a smaller input diff --git a/src/operator/deformable_convolution-inl.h b/src/operator/deformable_convolution-inl.h index 7782a8bdd2bd..2ab1ad65ab3a 100644 --- a/src/operator/deformable_convolution-inl.h +++ b/src/operator/deformable_convolution-inl.h @@ -347,7 +347,7 @@ class DeformableConvolutionOp : public Operator { index_t num_kernels_col2im_; bool bias_term_; // has bias term? bool is_1x1_; -}; // class ConvolutionOp +}; // class DeformableConvolutionOp template Operator* CreateOp(DeformableConvolutionParam param, diff --git a/src/operator/modulated_deformable_convolution-inl.h b/src/operator/modulated_deformable_convolution-inl.h index 7ffa204fbacb..9f6adb9fdbb1 100644 --- a/src/operator/modulated_deformable_convolution-inl.h +++ b/src/operator/modulated_deformable_convolution-inl.h @@ -408,7 +408,7 @@ class ModulatedDeformableConvolutionOp : public Operator { index_t im2col_step_; bool bias_term_; // has bias term? bool is_1x1_; -}; // class ConvolutionOp +}; // class ModulatedDeformableConvolutionOp template Operator* CreateOp(ModulatedDeformableConvolutionParam param, diff --git a/src/operator/nn/convolution-inl.h b/src/operator/nn/convolution-inl.h index 46ff8ca50aea..77f7e4c6cc8a 100644 --- a/src/operator/nn/convolution-inl.h +++ b/src/operator/nn/convolution-inl.h @@ -228,9 +228,11 @@ class ConvolutionOp { this->param_ = p; // convert MBytes first to Bytes and then to elements. 
param_.workspace = (param_.workspace << 20) / sizeof(DType); - CHECK(param_.layout.value() == mshadow::kNCW || param_.layout.value() == mshadow::kNCHW || - param_.layout.value() == mshadow::kNCDHW) - << "Only support NCW, NCHW and NCDHW layout"; + if (param_.layout.has_value()) { + CHECK(param_.layout.value() == mshadow::kNCW || param_.layout.value() == mshadow::kNCHW || + param_.layout.value() == mshadow::kNCDHW) + << "Only support NCW, NCHW and NCDHW layout"; + } } void Forward(const OpContext& ctx, @@ -239,44 +241,88 @@ class ConvolutionOp { const std::vector& out_data) { using namespace mshadow; using namespace mshadow::expr; - CHECK_EQ(req[conv::kOut], kWriteTo); size_t expected = param_.no_bias ? 2 : 3; CHECK_EQ(in_data.size(), expected); CHECK_EQ(out_data.size(), 1U); - CHECK_EQ(req[conv::kOut], kWriteTo); - LayerSetUp(in_data[conv::kData].shape_, out_data[conv::kOut].shape_); + // CHECK_EQ(req[conv::kOut], kWriteTo); + _Forward(ctx, + in_data[conv::kData], + in_data[conv::kWeight], + param_.no_bias ? nullptr : &in_data[conv::kBias], + req[conv::kOut], + out_data[conv::kOut]); + } + + void Backward(const OpContext& ctx, + const std::vector& out_grad, + const std::vector& in_data, + const std::vector& req, + const std::vector& in_grad) { + using namespace mshadow; + using namespace mshadow::expr; + CHECK_EQ(out_grad.size(), 1U); + // We expect 2 inputs: in data and weight. We don't need bias for + // computing gradient. + size_t expected = param_.no_bias ? 2 : 3; + CHECK_EQ(in_data.size(), expected); + CHECK_EQ(in_grad.size(), expected); + CHECK_EQ(req.size(), expected); + CHECK_EQ(in_data[conv::kWeight].CheckContiguous(), true); + + auto workspace = _BackwardData( + ctx, out_grad[conv::kOut], in_data[conv::kWeight], req[conv::kData], in_grad[conv::kData]); + _BackwardWeightsBias(workspace, + ctx, + out_grad[conv::kOut], + in_data[conv::kData], + req[conv::kWeight], + in_grad[conv::kWeight], + param_.no_bias ? OpReqType() : req[conv::kBias], + param_.no_bias ? 
nullptr : &in_grad[conv::kBias]); + } + + private: + Tensor _Forward(const OpContext& ctx, + const TBlob& in_data, + const TBlob& in_weights, + const TBlob* in_bias, + const OpReqType req, + const TBlob& out_data) { + using namespace mshadow; + using namespace mshadow::expr; + LayerSetUp(in_data.shape_, out_data.shape_); Stream* s = ctx.get_stream(); + Tensor workspace; // initialize weight and col_buffer 3D tensors for using gemm index_t M = conv_out_channels_ / group_; index_t N = conv_out_spatial_dim_; index_t K = kernel_dim_; Tensor weight_3d = - in_data[conv::kWeight].get_with_shape(Shape3(group_, M, K), s); + in_weights.get_with_shape(Shape3(group_, M, K), s); Tensor output_4d = - out_data[conv::kOut].get_with_shape(Shape4(num_, group_, M, N), s); + out_data.get_with_shape(Shape4(num_, group_, M, N), s); // no need to allocating memory and reordering in memory if (is_1x1_) { Tensor input_4d = - in_data[conv::kData].get_with_shape(Shape4(num_, group_, K, N), s); + in_data.get_with_shape(Shape4(num_, group_, K, N), s); for (index_t n = 0; n < num_; ++n) { Tensor input_3d = input_4d[n]; Tensor output_3d = output_4d[n]; for (index_t g = 0; g < group_; ++g) { - linalg_gemm(weight_3d[g], input_3d[g], output_3d[g], false, false, s, req[conv::kOut]); + linalg_gemm(weight_3d[g], input_3d[g], output_3d[g], false, false, s, req); } } } else { // allocate workspace for col_buffer - Tensor workspace = - ctx.requested[conv::kTempSpace].get_space_typed(Shape1(col_buffer_size_), - s); + workspace = ctx.requested[conv::kTempSpace].get_space_typed( + Shape1(col_buffer_size_), s); // calculate the shape of col_buffer mxnet::TShape col_buffer_shape(num_spatial_axes_ + 1, 1); col_buffer_shape[0] = conv_in_channels_ * param_.kernel.Size(); for (int i = 1; i < col_buffer_shape.ndim(); ++i) { - col_buffer_shape[i] = out_data[0].shape_[i + 1]; + col_buffer_shape[i] = out_data.shape_[i + 1]; } // create a column buffer using workspace and col_buffer_shape TBlob col_buffer(workspace.dptr_, col_buffer_shape, xpu::kDevMask, DataType::kFlag); @@ -285,8 +331,8 @@ class ConvolutionOp { for (index_t n = 0; n < num_; ++n) { // transform image to col_buffer in order to use gemm im2col(s, - in_data[conv::kData].dptr() + n * input_dim_, - in_data[conv::kData].shape_, + in_data.dptr() + n * input_dim_, + in_data.shape_, col_buffer.shape_, param_.kernel, param_.pad, @@ -296,80 +342,65 @@ class ConvolutionOp { Tensor output_3d = output_4d[n]; for (index_t g = 0; g < group_; ++g) { // Legacy approach shown here for comparison: - // Assign(output_3d[g], req[conv::kOut], dot(weight_3d[g], col_buffer_3d[g])); - linalg_gemm( - weight_3d[g], col_buffer_3d[g], output_3d[g], false, false, s, req[conv::kOut]); + // Assign(output_3d[g], req, dot(weight_3d[g], col_buffer_3d[g])); + linalg_gemm(weight_3d[g], col_buffer_3d[g], output_3d[g], false, false, s, req); } } } if (bias_term_) { - Tensor bias = in_data[conv::kBias].get(s); - Tensor output_3d = out_data[conv::kOut].get_with_shape( + CHECK(in_bias != nullptr); + Tensor bias = in_bias->get(s); + Tensor output_3d = out_data.get_with_shape( Shape3(num_, conv_out_channels_, conv_out_spatial_dim_), s); // has bias term, broadcast it to the same shape of output_3d in channel dim output_3d += mshadow::expr::broadcast<1>(bias, output_3d.shape_); } + return workspace; } - void Backward(const OpContext& ctx, - const std::vector& out_grad, - const std::vector& in_data, - const std::vector& req, - const std::vector& in_grad) { + // Computes dLoss/dData + Tensor _BackwardData(const OpContext& 
ctx, + const TBlob& out_grad, + const TBlob& weights, + const OpReqType data_grad_req, + const TBlob& data_grad_dst) { using namespace mshadow; using namespace mshadow::expr; - CHECK_EQ(out_grad.size(), 1U); - // We expect 2 inputs: in data and weight. We don't need bias for - // computing gradient. - size_t expected = param_.no_bias == 0 ? 3 : 2; - CHECK_EQ(in_data.size(), expected); - CHECK_EQ(in_grad.size(), expected); - CHECK_EQ(req.size(), expected); - CHECK_EQ(in_data[conv::kWeight].CheckContiguous(), true); - LayerSetUp(in_grad[conv::kData].shape_, out_grad[conv::kOut].shape_); + CHECK_EQ(weights.CheckContiguous(), true); + LayerSetUp(data_grad_dst.shape_, out_grad.shape_); Stream* s = ctx.get_stream(); + Tensor workspace; // initialize weight and col_buffer 3D tensors for using gemm - // For computing dLoss/d(in_data[kData]) index_t M = kernel_dim_; index_t N = conv_out_spatial_dim_; index_t K = conv_out_channels_ / group_; Tensor weight_3d = - in_data[conv::kWeight].get_with_shape(Shape3(group_, K, M), s); + weights.get_with_shape(Shape3(group_, K, M), s); Tensor out_grad_4d = - out_grad[conv::kOut].get_with_shape(Shape4(num_, group_, K, N), s); - // For computing dLoss/dWeight - Tensor dweight_3d = - in_grad[conv::kWeight].get_with_shape(Shape3(group_, K, M), s); + out_grad.get_with_shape(Shape4(num_, group_, K, N), s); // no need to allocating memory and reordering in memory if (is_1x1_) { - Tensor input_4d = - in_data[conv::kData].get_with_shape(Shape4(num_, group_, M, N), s); Tensor in_grad_4d = - in_grad[conv::kData].get_with_shape(Shape4(num_, group_, M, N), s); + data_grad_dst.get_with_shape(Shape4(num_, group_, M, N), s); for (index_t n = 0; n < num_; ++n) { - Tensor input_3d = input_4d[n]; Tensor in_grad_3d = in_grad_4d[n]; Tensor out_grad_3d = out_grad_4d[n]; - // gradient w.r.t. input data for (index_t g = 0; g < group_; ++g) { linalg_gemm(weight_3d[g], out_grad_3d[g], in_grad_3d[g], true, false, s); - auto request = (n == 0) ? req[conv::kWeight] : kAddTo; - linalg_gemm(out_grad_3d[g], input_3d[g], dweight_3d[g], false, true, s, request); } } } else { // allocate workspace for col_buffer - Tensor workspace = - ctx.requested[conv::kTempSpace].get_space_typed(Shape1(col_buffer_size_), - s); + workspace = ctx.requested[conv::kTempSpace].get_space_typed( + Shape1(col_buffer_size_), s); // calculate the shape of col_buffer mxnet::TShape col_buffer_shape(num_spatial_axes_ + 1, 1); col_buffer_shape[0] = conv_in_channels_ * param_.kernel.Size(); for (int i = 1; i < col_buffer_shape.ndim(); ++i) { - col_buffer_shape[i] = out_grad[conv::kData].shape_[i + 1]; + col_buffer_shape[i] = out_grad.shape_[i + 1]; } // create a column buffer using workspace and col_buffer_shape TBlob col_buffer(workspace.dptr_, col_buffer_shape, xpu::kDevMask, DataType::kFlag); @@ -377,27 +408,81 @@ class ConvolutionOp { col_buffer.get_with_shape(Shape3(group_, M, N), s); for (index_t n = 0; n < num_; ++n) { Tensor out_grad_3d = out_grad_4d[n]; - // gradient w.r.t. 
input data for (index_t g = 0; g < group_; ++g) { - // Legacy approach shown here for comparison: - // col_buffer_3d[g] = dot(weight_3d[g].T(), out_grad_3d[g]); linalg_gemm(weight_3d[g], out_grad_3d[g], col_buffer_3d[g], true, false, s); } col2im(s, col_buffer.dptr(), - in_grad[conv::kData].shape_, + data_grad_dst.shape_, col_buffer.shape_, param_.kernel, param_.pad, param_.stride, param_.dilate, - in_grad[conv::kData].dptr() + n * input_dim_, - req[conv::kData]); + data_grad_dst.dptr() + n * input_dim_, + data_grad_req); + } + } + return workspace; + } - // gradient w.r.t. weight, dWeight should accumulate across the batch and group + // Computes dLoss/dWeights and dLoss/dBias + void _BackwardWeightsBias(Tensor workspace, + const OpContext& ctx, + const TBlob& out_grad, + const TBlob& data, + const OpReqType weights_grad_req, + const TBlob& weights_grad_dst, + const OpReqType bias_grad_req, + const TBlob* const bias_grad_dst) { + using namespace mshadow; + using namespace mshadow::expr; + LayerSetUp(data.shape_, out_grad.shape_); + Stream* s = ctx.get_stream(); + + // initialize weight and col_buffer 3D tensors for using gemm + index_t M = kernel_dim_; + index_t N = conv_out_spatial_dim_; + index_t K = conv_out_channels_ / group_; + Tensor out_grad_4d = + out_grad.get_with_shape(Shape4(num_, group_, K, N), s); + Tensor dweight_3d = + weights_grad_dst.get_with_shape(Shape3(group_, K, M), s); + + // no need to allocating memory and reordering in memory + if (is_1x1_) { + Tensor input_4d = + data.get_with_shape(Shape4(num_, group_, M, N), s); + for (index_t n = 0; n < num_; ++n) { + Tensor input_3d = input_4d[n]; + Tensor out_grad_3d = out_grad_4d[n]; + for (index_t g = 0; g < group_; ++g) { + auto request = (n == 0) ? weights_grad_req : kAddTo; + linalg_gemm(out_grad_3d[g], input_3d[g], dweight_3d[g], false, true, s, request); + } + } + } else { + // allocate workspace for col_buffer + if (workspace.dptr_ == nullptr) { + workspace = ctx.requested[conv::kTempSpace].get_space_typed( + Shape1(col_buffer_size_), s); + } + // calculate the shape of col_buffer + mxnet::TShape col_buffer_shape(num_spatial_axes_ + 1, 1); + col_buffer_shape[0] = conv_in_channels_ * param_.kernel.Size(); + for (int i = 1; i < col_buffer_shape.ndim(); ++i) { + col_buffer_shape[i] = out_grad.shape_[i + 1]; + } + // create a column buffer using workspace and col_buffer_shape + TBlob col_buffer(workspace.dptr_, col_buffer_shape, xpu::kDevMask, DataType::kFlag); + Tensor col_buffer_3d = + col_buffer.get_with_shape(Shape3(group_, M, N), s); + for (index_t n = 0; n < num_; ++n) { + Tensor out_grad_3d = out_grad_4d[n]; + // dWeight should accumulate across the batch and group im2col(s, - in_data[conv::kData].dptr() + n * input_dim_, - in_data[conv::kData].shape_, + data.dptr() + n * input_dim_, + data.shape_, col_buffer.shape_, param_.kernel, param_.pad, @@ -405,24 +490,22 @@ class ConvolutionOp { param_.dilate, col_buffer.dptr()); for (index_t g = 0; g < group_; ++g) { - auto request = (n == 0) ? req[conv::kWeight] : kAddTo; - // Legacy approach shown here for comparison: - // Assign(dweight_3d[g], request, dot(out_grad_3d[g], col_buffer_3d[g].T())); + auto request = (n == 0) ? 
weights_grad_req : kAddTo; linalg_gemm(out_grad_3d[g], col_buffer_3d[g], dweight_3d[g], false, true, s, request); } } } - // gradient w.r.t bias + // bias gradient if (bias_term_) { - Tensor dbias = in_grad[conv::kBias].get(s); - Tensor dout = out_grad[conv::kOut].get_with_shape( + CHECK(bias_grad_dst != nullptr); + Tensor dbias = bias_grad_dst->get(s); + Tensor dout = out_grad.get_with_shape( Shape3(num_, conv_out_channels_, conv_out_spatial_dim_), s); - ASSIGN_DISPATCH(dbias, req[conv::kBias], sumall_except_dim<1>(dout)); + ASSIGN_DISPATCH(dbias, bias_grad_req, sumall_except_dim<1>(dout)); } } - private: void LayerSetUp(const mxnet::TShape& ishape, const mxnet::TShape& oshape) { channel_axis_ = 1; // hard code channel axis const index_t first_spatial_axis = channel_axis_ + 1; @@ -478,6 +561,9 @@ class ConvolutionOp { index_t num_kernels_col2im_; bool bias_term_; // has bias term? bool is_1x1_; + + template + friend class DeconvolutionOp; }; // class ConvolutionOp template diff --git a/src/operator/nn/deconvolution-inl.h b/src/operator/nn/deconvolution-inl.h index 8b2ae8e0c8a8..51a7a8038dee 100644 --- a/src/operator/nn/deconvolution-inl.h +++ b/src/operator/nn/deconvolution-inl.h @@ -36,6 +36,7 @@ #include #include "../operator_common.h" #include "../linalg.h" +#include "convolution-inl.h" namespace mxnet { namespace op { @@ -96,7 +97,7 @@ struct DeconvolutionParam : public dmlc::Parameter { .describe("Shape of the output tensor: (w,), (h, w) or (d, h, w)."); DMLC_DECLARE_FIELD(num_filter).set_lower_bound(1).describe("Number of output filters."); DMLC_DECLARE_FIELD(num_group).set_default(1).describe("Number of groups partition."); - DMLC_DECLARE_FIELD(workspace).set_default(512).set_lower_bound(0).describe( + DMLC_DECLARE_FIELD(workspace).set_default(1024).set_lower_bound(0).describe( "Maximum temporary workspace allowed (MB) in deconvolution." "This parameter has two usages. When CUDNN is not used, it determines the " "effective batch size of the deconvolution kernel. When CUDNN is used, " @@ -278,9 +279,8 @@ namespace op { template class DeconvolutionOp { public: - void Init(DeconvolutionParam p) { - this->param_ = p; - // convert MBytes first to Bytes and then to elements. + void Init(DeconvolutionParam dp) { + param_ = dp; param_.workspace = (param_.workspace << 20) / sizeof(DType); } @@ -288,106 +288,28 @@ class DeconvolutionOp { const std::vector& in_data, const std::vector& req, const std::vector& out_data) { - using namespace mshadow; - using namespace mshadow::expr; - - if (param_.kernel.ndim() > 2) { - LOG(FATAL) << "If not using CUDNN, only 1D or 2D Deconvolution is supported"; - } - - CHECK_EQ(req[deconv::kOut], kWriteTo); size_t expected = param_.no_bias ? 2 : 3; + CHECK_EQ(req[deconv::kOut], kWriteTo); CHECK_EQ(in_data.size(), expected); CHECK_EQ(out_data.size(), 1U); - Stream* s = ctx.get_stream(); - auto in_data_shape = in_data[deconv::kData].shape_; - Tensor data = TBlobTo4DTensor(in_data[deconv::kData], s); - Tensor out = TBlobTo4DTensor(out_data[deconv::kOut], s); - index_t o_pad[2], o_adj[2]; - if (param_.kernel.ndim() == 2) { - param_.InferPad(mxnet::TShape({in_data_shape[2], in_data_shape[3]}), o_pad, o_adj); - } else { - index_t o_pad_1D[1], o_adj_1D[1]; - param_.InferPad({in_data_shape[2]}, o_pad_1D, o_adj_1D); - o_pad[0] = 0; - o_pad[1] = o_pad_1D[0]; - o_adj[0] = 0; - o_adj[1] = o_adj_1D[0]; - } - auto stride = param_.kernel.ndim() == 2 ? param_.stride : mxnet::TShape({1, param_.stride[0]}); - auto dilate = param_.kernel.ndim() == 2 ? 
param_.dilate : mxnet::TShape({1, param_.dilate[0]}); - auto kernel = param_.kernel.ndim() == 2 ? param_.kernel : mxnet::TShape({1, param_.kernel[0]}); - auto kernel_size = kernel.Size(); - - Shape<3> wmat_shape = Shape3(param_.num_group, - data.shape_[1] / param_.num_group, - param_.num_filter / param_.num_group * kernel_size); - Tensor wmat = - in_data[deconv::kWeight].get_with_shape(wmat_shape, s); -#if defined(__CUDACC__) - CHECK_EQ(s->blas_handle_ownership_, Stream::OwnHandle) - << "Must init CuBLAS handle in stream"; -#endif - const index_t nbatch = data.size(0); - Tensor workspace = - ctx.requested[deconv::kTempSpace].get_space_typed( - Shape1(this->InitTemp(out.shape_, data.shape_)), s); - for (index_t i = 0; i < nbatch; i += nstep_) { - const index_t step = std::min(nstep_, nbatch - i); - Tensor temp_col = Tensor( - workspace.dptr_, Shape2(shape_colunit_[0], shape_colunit_[1] * step), s); - Tensor temp_dst = Tensor( - workspace.dptr_ + temp_col.shape_.Size(), - Shape3(shape_dstunit_[0], shape_dstunit_[1], shape_dstunit_[2] * step), - s); - temp_dst = reshape(swapaxis<1, 0>(data.Slice(i, i + step)), temp_dst.shape_); - if (o_pad[0] == 0 && o_pad[1] == 0) { - temp_col = unpack_patch2col(out.Slice(i, i + step), - kernel[0], - kernel[1], - stride[0], - stride[1], - dilate[0], - dilate[1]); - } else { - temp_col = unpack_patch2col(pad(out.Slice(i, i + step), o_pad[0], o_pad[1]), - kernel[0], - kernel[1], - stride[0], - stride[1], - dilate[0], - dilate[1]); - } - const index_t gstride = temp_col.size(0) / param_.num_group; - for (uint32_t gid = 0; gid < param_.num_group; ++gid) { - mshadow::Tensor tmpc = temp_col.Slice(gstride * gid, gstride * (gid + 1)); - // Legacy approach shown here for comparison: - // tmpc = dot(wmat[gid].T(), temp_dst[gid]); - linalg_gemm(wmat[gid], temp_dst[gid], tmpc, true, false, s); - } - if (o_pad[0] == 0 && o_pad[1] == 0) { - out.Slice(i, i + step) = pack_col2patch(temp_col, - out.Slice(i, i + step).shape_, - kernel[0], - kernel[1], - stride[0], - stride[1], - dilate[0], - dilate[1]); - } else { - Shape<4> pshape = out.Slice(i, i + step).shape_; - pshape[2] += 2 * o_pad[0]; - pshape[3] += 2 * o_pad[1]; - out.Slice(i, i + step) = crop( - pack_col2patch( - temp_col, pshape, kernel[0], kernel[1], stride[0], stride[1], dilate[0], dilate[1]), - out[i][0].shape_); - } - } + + if (need_init_conv) + InitConv(in_data[deconv::kData]); + + conv_op._BackwardData(ctx, + in_data[deconv::kData], + in_data[deconv::kWeight], + req[deconv::kOut], + out_data[deconv::kOut]); + if (!param_.no_bias) { - // add bias, broadcast bias to dim 1: channel - Tensor bias = in_data[deconv::kBias].get(s); - out += mshadow::expr::broadcast<1>(bias, out.shape_); + Stream* s = ctx.get_stream(); + const TShape& out_shape = out_data[deconv::kOut].shape_; + Tensor bias = in_data[deconv::kBias].get(s); + Tensor output_3d = out_data[deconv::kOut].get_with_shape( + Shape3(out_shape[0], out_shape[1], out_shape.ProdShape(2, out_shape.ndim())), s); + // broadcast bias to the same shape of output_3d in channel dim + output_3d += mshadow::expr::broadcast<1>(bias, output_3d.shape_); } } @@ -398,145 +320,64 @@ class DeconvolutionOp { const std::vector& in_grad) { using namespace mshadow; using namespace mshadow::expr; - // TODO(bing): check the BLAS Handle, be careful + + const size_t expected = param_.no_bias == 0 ? 3 : 2; CHECK_EQ(out_grad.size(), 1U); - size_t expected = param_.no_bias == 0 ? 
3 : 2; CHECK_EQ(in_data.size(), expected); CHECK_EQ(in_grad.size(), expected); CHECK_EQ(req.size(), expected); - CHECK_EQ(in_data[deconv::kWeight].CheckContiguous(), true); - // get data - Stream* s = ctx.get_stream(); - auto in_data_shape = in_data[deconv::kData].shape_; - Tensor data = TBlobTo4DTensor(in_data[deconv::kData], s); - Tensor grad = TBlobTo4DTensor(out_grad[deconv::kOut], s); - Tensor gdata = TBlobTo4DTensor(in_grad[deconv::kData], s); - - index_t o_pad[2], o_adj[2]; - if (param_.kernel.ndim() == 2) { - param_.InferPad(mxnet::TShape({in_data_shape[2], in_data_shape[3]}), o_pad, o_adj); - } else { - index_t o_pad_1D[1], o_adj_1D[1]; - param_.InferPad({in_data_shape[2]}, o_pad_1D, o_adj_1D); - o_pad[0] = 0; - o_pad[1] = o_pad_1D[0]; - o_adj[0] = 0; - o_adj[1] = o_adj_1D[0]; - } - auto stride = param_.kernel.ndim() == 2 ? param_.stride : mxnet::TShape({1, param_.stride[0]}); - auto dilate = param_.kernel.ndim() == 2 ? param_.dilate : mxnet::TShape({1, param_.dilate[0]}); - auto kernel = param_.kernel.ndim() == 2 ? param_.kernel : mxnet::TShape({1, param_.kernel[0]}); - auto kernel_size = kernel.Size(); - - Shape<3> wmat_shape = Shape3(param_.num_group, - data.shape_[1] / param_.num_group, - param_.num_filter / param_.num_group * kernel_size); - Tensor wmat = - in_data[deconv::kWeight].get_with_shape(wmat_shape, s); - Tensor gwmat = - in_grad[deconv::kWeight].get_with_shape(wmat_shape, s); -#if defined(__CUDACC__) - CHECK_EQ(s->blas_handle_ownership_, Stream::OwnHandle) - << "Must init CuBLAS handle in stream"; -#endif - - const index_t nbatch = data.size(0); - Tensor workspace = - ctx.requested[deconv::kTempSpace].get_space_typed( - Shape1(this->InitTemp(grad.shape_, data.shape_)), s); - for (index_t i = 0; i < nbatch; i += nstep_) { - const index_t step = std::min(nstep_, nbatch - i); - Tensor temp_col = Tensor( - workspace.dptr_, Shape2(shape_colunit_[0], shape_colunit_[1] * step), s); - Tensor temp_dst = Tensor( - workspace.dptr_ + temp_col.shape_.Size(), - Shape3(shape_dstunit_[0], shape_dstunit_[1], shape_dstunit_[2] * step), - s); - temp_dst = reshape(swapaxis<1, 0>(data.Slice(i, i + step)), temp_dst.shape_); - if (o_pad[0] == 0 && o_pad[1] == 0) { - temp_col = unpack_patch2col(grad.Slice(i, i + step), - kernel[0], - kernel[1], - stride[0], - stride[1], - dilate[0], - dilate[1]); - } else { - temp_col = unpack_patch2col(pad(grad.Slice(i, i + step), o_pad[0], o_pad[1]), - kernel[0], - kernel[1], - stride[0], - stride[1], - dilate[0], - dilate[1]); - } - const index_t gstride = temp_col.size(0) / param_.num_group; - for (uint32_t gid = 0; gid < param_.num_group; ++gid) { - Tensor tmpc = temp_col.Slice(gstride * gid, gstride * (gid + 1)); - if (i == 0) { - Tensor tmp_gwmat = gwmat[gid]; - // Legacy approach shown here for comparison: - // Assign(tmp_gwmat, req[deconv::kWeight], dot(temp_dst[gid], tmpc.T())); - linalg_gemm(temp_dst[gid], tmpc, tmp_gwmat, false, true, s, req[deconv::kWeight]); - } else { - // Legacy approach shown here for comparison: - // gwmat[gid] += dot(temp_dst[gid], tmpc.T()); - linalg_gemm(temp_dst[gid], tmpc, gwmat[gid], false, true, s, kAddTo); - } - } - if (req[deconv::kData] == kWriteTo || req[deconv::kData] == kWriteInplace || - req[deconv::kData] == kAddTo) { - for (uint32_t gid = 0; gid < param_.num_group; ++gid) { - Tensor tmpc = temp_col.Slice(gstride * gid, gstride * (gid + 1)); - // Legacy approach shown here for comparison: - // temp_dst[gid] = dot(wmat[gid], tmpc); - linalg_gemm(wmat[gid], tmpc, temp_dst[gid], false, false, s); - } - Assign( 
- gdata.Slice(i, i + step), - req[deconv::kData], - (swapaxis<1, 0>(reshape( - temp_dst, mshadow::Shape4(gdata.shape_[1], step, gdata.size(2), gdata.size(3)))))); - } - } + + if (need_init_conv) + InitConv(in_data[deconv::kData]); + + // data gradient + auto workspace = conv_op._Forward(ctx, + out_grad[deconv::kOut], + in_data[deconv::kWeight], + nullptr, + req[deconv::kData], + in_grad[deconv::kData]); + // weights gradient + conv_op._BackwardWeightsBias(workspace, + ctx, + in_data[deconv::kData], + out_grad[deconv::kOut], + req[deconv::kWeight], + in_grad[deconv::kWeight], + OpReqType(), + nullptr); + // bias gradient if (!param_.no_bias) { - Tensor gbias = in_grad[deconv::kBias].get(s); - Assign(gbias, req[deconv::kBias], sumall_except_dim<1>(grad)); + Stream* s = ctx.get_stream(); + const TShape& out_shape = out_grad[deconv::kOut].shape_; + Tensor dbias = in_grad[deconv::kBias].get(s); + Tensor dout = out_grad[deconv::kOut].get_with_shape( + Shape3(out_shape[0], out_shape[1], out_shape.ProdShape(2, out_shape.ndim())), s); + ASSIGN_DISPATCH(dbias, req[deconv::kBias], sumall_except_dim<1>(dout)); } } private: - inline index_t InitTemp(const mshadow::Shape<4>& ishape, const mshadow::Shape<4>& oshape) { - const index_t ksize = param_.kernel.Size(); - shape_colunit_ = mshadow::Shape2(ishape[1] * ksize, oshape[2] * oshape[3]); - shape_dstunit_ = - mshadow::Shape3(param_.num_group, oshape[1] / param_.num_group, oshape[2] * oshape[3]); - // See convolution for workspace calculations. nstep_ will be the effective batch size - nstep_ = std::max( - std::min(param_.workspace / (shape_colunit_.Size() + shape_dstunit_.Size()), - ishape[0]), - 1); - - mshadow::Shape<2> scol = mshadow::Shape2(shape_colunit_[0], shape_colunit_[1] * nstep_); - mshadow::Shape<3> sdst = - mshadow::Shape3(shape_dstunit_[0], shape_dstunit_[1], shape_dstunit_[2] * nstep_); - index_t required_size = scol.Size() + sdst.Size(); - return required_size; - } - - inline Tensor TBlobTo4DTensor(const TBlob& tb, Stream* s) { - using namespace mshadow; - if (param_.kernel.ndim() == 2) - return tb.get(s); - else - return tb.get_with_shape(Shape4(tb.shape_[0], tb.shape_[1], 1, tb.shape_[2]), - s); + void InitConv(const TBlob& in_data) { + ConvolutionParam cp; + cp.kernel = param_.kernel; + cp.stride = param_.stride; + cp.dilate = param_.dilate; + cp.pad = param_.pad; + cp.num_filter = in_data.shape_[1]; + cp.num_group = param_.num_group; + cp.workspace = (param_.workspace * sizeof(DType)) >> 20; + cp.no_bias = true; + cp.cudnn_tune = param_.cudnn_tune; + cp.cudnn_off = param_.cudnn_off; + cp.layout = param_.layout; + conv_op.Init(cp); + need_init_conv = false; } + bool need_init_conv = true; DeconvolutionParam param_; - mshadow::Shape<2> shape_colunit_; - mshadow::Shape<3> shape_dstunit_; - index_t nstep_; + ConvolutionOp conv_op; }; // class DeconvolutionOp template diff --git a/src/operator/nn/deconvolution.cc b/src/operator/nn/deconvolution.cc index aa1f1daa36aa..947ba7f2a640 100644 --- a/src/operator/nn/deconvolution.cc +++ b/src/operator/nn/deconvolution.cc @@ -98,12 +98,6 @@ static bool DeconvolutionShape(const nnvm::NodeAttrs& attrs, mxnet::ShapeVector* in_shape, mxnet::ShapeVector* out_shape) { const DeconvolutionParam& param_ = nnvm::get(attrs.parsed); -#if MXNET_USE_CUDNN == 0 - if (param_.kernel.ndim() > 2) { - LOG(FATAL) << "If not using CUDNN, only 1D or 2D Deconvolution is supported"; - return false; - } -#endif // CUDNN using namespace mshadow; if (!param_.no_bias) { @@ -413,9 +407,9 @@ 
DMLC_REGISTER_PARAMETER(DeconvolutionParam); NNVM_REGISTER_OP(Deconvolution) .add_alias("_npx_deconvolution") .describe( - "Computes 1D or 2D transposed convolution (aka fractionally strided convolution) of the " - "input tensor. This operation can be seen as the gradient of Convolution operation with " - "respect to its input. Convolution usually reduces the size of the input. Transposed " + "Computes 1D, 2D or 3D transposed convolution (aka fractionally strided convolution) of " + "the input tensor. This operation can be seen as the gradient of Convolution operation " + "with respect to its input. Convolution usually reduces the size of the input. Transposed " "convolution works the other way, going from a smaller input to a larger output while " "preserving the connectivity pattern.") .set_num_inputs([](const NodeAttrs& attrs) { diff --git a/src/operator/nn/mkldnn/mkldnn_deconvolution-inl.h b/src/operator/nn/mkldnn/mkldnn_deconvolution-inl.h new file mode 100644 index 000000000000..a66d3a887326 --- /dev/null +++ b/src/operator/nn/mkldnn/mkldnn_deconvolution-inl.h @@ -0,0 +1,412 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file mkldnn_deconvolution-inl.h + * Naming convention: + * ________ + * (src) data --->|Deconv| + * weights --->| FWD |---> out (dst) + * bias --->|______| + * ________ + * (diff_src) data_grad <---|Deconv|<--- out_grad (diff_dst) + * (diff_weight) weights_grad <---| BWD |<--- data (src) + * (diff_bias) bias_grad <---| |<--- weight + * |______|<--- bias + * + * "out" in this (and .cc) file will always refer to the output of Deconv FWD and + * "out_grad" to its gradient. The corresponding MKLDNN names are in parentheses. + */ +#ifndef MXNET_OPERATOR_NN_MKLDNN_MKLDNN_DECONVOLUTION_INL_H_ +#define MXNET_OPERATOR_NN_MKLDNN_MKLDNN_DECONVOLUTION_INL_H_ + +#if MXNET_USE_ONEDNN == 1 +#include +#include +#include + +#include "../deconvolution-inl.h" +#include "./mkldnn_base-inl.h" +#include "./mkldnn_ops-inl.h" + +namespace mxnet { +namespace op { + +using deconv_fwd_t = mkldnn::deconvolution_forward; +using deconv_fwd_pd_t = mkldnn::deconvolution_forward::primitive_desc; + +using deconv_bwd_data_t = mkldnn::deconvolution_backward_data; +using deconv_bwd_data_pd_t = mkldnn::deconvolution_backward_data::primitive_desc; + +using deconv_bwd_weights_t = mkldnn::deconvolution_backward_weights; +using deconv_bwd_weights_pd_t = mkldnn::deconvolution_backward_weights::primitive_desc; + +// Swaps the logical order of dimensions that in plain format would correspond to input and output +// channels (for example: oihw => iohw, iohw => oihw, goihw => giohw). 
+inline mkldnn::memory::desc IOLogicalSwapDesc(const mkldnn::memory::desc& desc, + const uint32_t num_group) { + std::vector order(desc.data.ndims); + std::iota(std::begin(order), std::end(order), 0); + const int offset = static_cast(num_group > 1); + std::swap(order[offset + 0], order[offset + 1]); + return desc.permute_axes(order); +} + +// Applies IOLogicalSwapDesc to MKLDNN memory of arr +inline void IOLogicalSwapMKLDNNMem(const NDArray& arr, const uint32_t num_group) { + mkldnn::memory::desc desc; + if (arr.IsMKLDNNData()) { + desc = arr.GetMKLDNNData()->get_desc(); + } else { + // GetMKLDNNData won't take groups into account when creating mkldnn::memory, we need to use + // descriptor from GetWeightDesc but with default format + const auto& temp = GetWeightDesc(arr, num_group); + desc = mkldnn::memory::desc( + temp.dims(), + temp.data_type(), + static_cast(GetDefaultFormat(temp.data.ndims))); + } + const_cast(arr).UpdateMKLDNNMemDesc(IOLogicalSwapDesc(desc, num_group)); +} + +// Version of GetWeightsDesc for deconvolution (with swap) +inline mkldnn::memory::desc GetDeconvWeightsDesc(const NDArray& weights, const uint32_t num_group) { + return IOLogicalSwapDesc(GetWeightDesc(weights, num_group), num_group); +} + +class MKLDNNDeconvFwd { + public: + struct Tensors { + Tensors(const NDArray& data, + const NDArray& weights, + const NDArray* const bias, + const NDArray& out); + Tensors(const bool no_bias, + const std::vector& inputs, + const std::vector& outputs); + + const NDArray& data; + const NDArray& weights; + const NDArray* const bias; + const NDArray& out; + }; + + static MKLDNNDeconvFwd& GetCached(const DeconvolutionParam& param, const Tensors& tensors); + static std::shared_ptr CreatePrimitiveDesc(const DeconvolutionParam& param, + const Tensors& tensors); + + MKLDNNDeconvFwd(const DeconvolutionParam& param, const Tensors& tensors); + void ControlWeightsFormat(const uint32_t num_group, + const bool is_train, + const NDArray& weights) const; + void Execute(const uint32_t num_group, const OpReqType req, const Tensors& tensors) const; + + private: + const mkldnn::memory* DataMem(const NDArray& data) const; + const mkldnn::memory* WeightsMem(const uint32_t num_group, const NDArray& weights) const; + const mkldnn::memory* BiasMem(const NDArray& bias) const; + + mkldnn_output_t OutMem(const OpReqType req, const NDArray& out) const; + + private: + std::shared_ptr fwd; + std::shared_ptr fwd_pd; +}; + +MKLDNNDeconvFwd::Tensors::Tensors(const bool no_bias, + const std::vector& inputs, + const std::vector& outputs) + : data(inputs[deconv::kData]), + weights(inputs[deconv::kWeight]), + bias(no_bias ? 
nullptr : &inputs[deconv::kBias]), + out(outputs[deconv::kOut]) {} + +MKLDNNDeconvFwd::Tensors::Tensors(const NDArray& data, + const NDArray& weights, + const NDArray* const bias, + const NDArray& out) + : data(data), weights(weights), bias(bias), out(out) {} + +MKLDNNDeconvFwd::MKLDNNDeconvFwd(const DeconvolutionParam& param, const Tensors& tensors) + : fwd_pd(CreatePrimitiveDesc(param, tensors)) { + fwd = std::make_shared(*fwd_pd); +} + +inline const mkldnn::memory* MKLDNNDeconvFwd::DataMem(const NDArray& data) const { + return data.GetMKLDNNDataReorder(fwd_pd->src_desc()); +} + +inline const mkldnn::memory* MKLDNNDeconvFwd::WeightsMem(const uint32_t num_group, + const NDArray& weights) const { + return GetWeights(weights, fwd_pd->weights_desc(), num_group); +} + +inline const mkldnn::memory* MKLDNNDeconvFwd::BiasMem(const NDArray& bias) const { + return bias.GetMKLDNNData(); +} + +inline mkldnn_output_t MKLDNNDeconvFwd::OutMem(const OpReqType req, const NDArray& out) const { + return CreateMKLDNNMem(out, fwd_pd->dst_desc(), req); +} + +class MKLDNNDeconvBwd { + public: + struct ReadTensors { + ReadTensors(const bool no_bias, const std::vector& inputs); + const NDArray& data; + const NDArray& weights; + const NDArray* const bias; + const NDArray& out_grad; + }; + struct WriteTensors { + WriteTensors(const bool no_bias, const std::vector& outputs); + const NDArray& data_grad; + const NDArray& weights_grad; + const NDArray* const bias_grad; + }; + + static MKLDNNDeconvBwd& GetCached(const DeconvolutionParam& param, + const ReadTensors& read_tensors); + + static std::shared_ptr CreateDataPrimitiveDesc( + const DeconvolutionParam& param, + const ReadTensors& read_tensors, + const deconv_fwd_pd_t& fwd_pd); + + static std::shared_ptr CreateWeightsPrimitiveDesc( + const DeconvolutionParam& param, + const ReadTensors& read_tensors, + const deconv_fwd_pd_t& fwd_pd); + + MKLDNNDeconvBwd(const DeconvolutionParam& param, const ReadTensors& read_tensors); + + void Execute(const uint32_t num_group, + const std::vector& req, + const ReadTensors& read_tensors, + const WriteTensors& write_tensors) const; + + private: + void IOSwapWeightsTensors(const uint32_t num_group, + const std::vector& req, + const NDArray& weights, + const NDArray& weights_grad) const; + + // returns the output gradient memory used to calculate the data (input) gradient, + // which might be reused when calculating the gradient of weights + const mkldnn::memory* ScheduleBwdData(const uint32_t num_group, + const OpReqType req, + const ReadTensors& read_tensors, + const WriteTensors& write_tensors) const; + + void ScheduleBwdWeights(const uint32_t num_group, + const std::vector& req, + const ReadTensors& read_tensors, + const WriteTensors& write_tensors, + const mkldnn::memory* const out_grad_mem) const; + + const mkldnn::memory* DataMem(const NDArray& data) const; + const mkldnn::memory* WeightsMem(const uint32_t num_group, const NDArray& weights) const; + + // for calculating the gradient of data (input) + const mkldnn::memory* OutGradMem(const NDArray& out_grad) const; + // for calculating the gradient of weights + const mkldnn::memory* OutGradMem(const NDArray& out_grad, + const mkldnn::memory* const out_grad_mem) const; + + mkldnn_output_t DataGradMem(const OpReqType req, const NDArray& data_grad) const; + mkldnn_output_t WeightsGradMem(const uint32_t num_group, + const OpReqType req, + const NDArray& weights_grad) const; + mkldnn_output_t BiasGradMem(const OpReqType req, const NDArray* const bias) const; + + std::shared_ptr 
bwd_data_pd; + std::shared_ptr bwd_weights_pd; + std::shared_ptr bwd_data; + std::shared_ptr bwd_weights; +}; + +MKLDNNDeconvBwd::ReadTensors::ReadTensors(const bool no_bias, const std::vector& inputs) + : data(inputs[deconv::kData + 1]), + weights(inputs[deconv::kWeight + 1]), + bias(no_bias ? nullptr : &inputs[deconv::kBias + 1]), + out_grad(inputs[deconv::kOut]) {} + +MKLDNNDeconvBwd::WriteTensors::WriteTensors(const bool no_bias, const std::vector& outputs) + : data_grad(outputs[deconv::kData]), + weights_grad(outputs[deconv::kWeight]), + bias_grad(no_bias ? nullptr : &outputs[deconv::kBias]) {} + +MKLDNNDeconvBwd::MKLDNNDeconvBwd(const DeconvolutionParam& param, const ReadTensors& read_tensors) { + const auto& fwd_pd = MKLDNNDeconvFwd::CreatePrimitiveDesc( + param, + MKLDNNDeconvFwd::Tensors( + read_tensors.data, read_tensors.weights, read_tensors.bias, read_tensors.out_grad)); + bwd_data_pd = CreateDataPrimitiveDesc(param, read_tensors, *fwd_pd); + bwd_weights_pd = CreateWeightsPrimitiveDesc(param, read_tensors, *fwd_pd); + bwd_data = std::make_shared(*bwd_data_pd); + bwd_weights = std::make_shared(*bwd_weights_pd); +} + +inline void MKLDNNDeconvBwd::IOSwapWeightsTensors(const uint32_t num_group, + const std::vector& req, + const NDArray& weights, + const NDArray& weights_grad) const { + if (req[deconv::kData]) { + IOLogicalSwapMKLDNNMem(weights, num_group); + } + if (req[deconv::kWeight] || (req.size() < deconv::kBias && req[deconv::kBias])) { + IOLogicalSwapMKLDNNMem(weights_grad, num_group); + } +} + +inline const mkldnn::memory* MKLDNNDeconvBwd::DataMem(const NDArray& data) const { + return data.GetMKLDNNDataReorder(bwd_weights_pd->src_desc()); +} + +inline const mkldnn::memory* MKLDNNDeconvBwd::WeightsMem(const uint32_t num_group, + const NDArray& weights) const { + return GetWeights(weights, bwd_data_pd->weights_desc(), num_group); +} + +inline const mkldnn::memory* MKLDNNDeconvBwd::OutGradMem(const NDArray& out_grad) const { + return out_grad.GetMKLDNNDataReorder(bwd_data_pd->diff_dst_desc()); +} + +inline const mkldnn::memory* MKLDNNDeconvBwd::OutGradMem( + const NDArray& out_grad, + const mkldnn::memory* const out_grad_mem) const { + return (out_grad_mem && out_grad_mem->get_desc() == bwd_weights_pd->diff_dst_desc()) + ? out_grad_mem + : out_grad.GetMKLDNNDataReorder(bwd_weights_pd->diff_dst_desc()); +} + +inline mkldnn_output_t MKLDNNDeconvBwd::DataGradMem(const OpReqType req, + const NDArray& data_grad) const { + return CreateMKLDNNMem(data_grad, bwd_data_pd->diff_src_desc(), req); +} + +inline mkldnn_output_t MKLDNNDeconvBwd::WeightsGradMem(const uint32_t num_group, + const OpReqType req, + const NDArray& weights_grad) const { + // CreateMKLDNNWeightGrad always creates a new tensor as IsDefaultFormat always fails (because + // of the logical swap - explained in MKLDNNDeconvFwd::Execute). We try to reuse weights_grad + // memory (which, when not swapped, is always in default format), so here we check if after a + // swap, weights_md will have a default format + const auto& weights_md = bwd_weights_pd->diff_weights_desc(); + if (req == OpReqType::kWriteTo && IsDefaultFormat(IOLogicalSwapDesc(weights_md, num_group))) { + return {OutDataOp::Noop, const_cast(weights_grad).CreateMKLDNNData(weights_md)}; + } + return CreateMKLDNNWeightGrad(weights_grad, weights_md, req); +} + +inline mkldnn_output_t MKLDNNDeconvBwd::BiasGradMem(const OpReqType req, + const NDArray* const bias) const { + return bias ? 
CreateMKLDNNMem(*bias, bwd_weights_pd->diff_bias_desc(), req) + : mkldnn_output_t(OutDataOp::Noop, nullptr); +} + +// Utility class for creating operation descriptors of deconvolution primitives +class DeconvDescCreator { + public: + DeconvDescCreator(const DeconvolutionParam& param, + const NDArray& data, + const NDArray& weights, + const NDArray* const bias, + const NDArray& out); + + // Imposes plain formats on memory descriptors with padding (so the next selected implementation + // will pass CheckImplSizeReq). After calling this method, new primitive descriptor (with new + // operator descriptor) should be created, which should select an implementation with matching + // size requirements. + // data_size, weights_size, out_size - size requirements of current implementation + // Returns whether successfully imposed a plain format on any of the data, weights, and output + // memory descriptors. + bool ImposePlainWherePadding(const size_t data_size, + const size_t weights_size, + const size_t out_size); + bool CheckImplSizeReq(const size_t data_size, + const size_t weights_size, + const size_t out_size) const; + + deconv_fwd_t::desc CreateFwdDesc() const; + deconv_bwd_data_t::desc CreateBwdDataDesc() const; + deconv_bwd_weights_t::desc CreateBwdWeightsDesc() const; + + private: + mkldnn::memory::desc data_md; + mkldnn::memory::desc weights_md; + mkldnn::memory::desc bias_md; + mkldnn::memory::desc out_md; + + mkldnn::memory::dims strides; + mkldnn::memory::dims padding; + mkldnn::memory::dims dilates; +}; + +inline bool DeconvDescCreator::CheckImplSizeReq(const size_t data_size, + const size_t weights_size, + const size_t out_size) const { + // MKLDNN introduced padded formats since 0.15 which require more memory + // compared to the actual size of the tensor. Currently, MKLDNN operators + // still reuse memory from memory planning, so here we need to accept only a + // kernel that has the expected memory size requirements (which is suboptimal) + return (data_size == GetMemDescSize(data_md) && weights_size == GetMemDescSize(weights_md) && + out_size == GetMemDescSize(out_md)); +} + +inline deconv_fwd_t::desc DeconvDescCreator::CreateFwdDesc() const { + return deconv_fwd_t::desc(mkldnn::prop_kind::forward_training, + mkldnn::algorithm::deconvolution_direct, + data_md, + weights_md, + bias_md, + out_md, + strides, + dilates, + padding, + padding); +} + +inline deconv_bwd_data_t::desc DeconvDescCreator::CreateBwdDataDesc() const { + return deconv_bwd_data_t::desc(mkldnn::algorithm::deconvolution_direct, + data_md, + weights_md, + out_md, + strides, + dilates, + padding, + padding); +} + +inline deconv_bwd_weights_t::desc DeconvDescCreator::CreateBwdWeightsDesc() const { + return deconv_bwd_weights_t::desc(mkldnn::algorithm::deconvolution_direct, + data_md, + weights_md, + bias_md, + out_md, + strides, + dilates, + padding, + padding); +} + +} // namespace op +} // namespace mxnet +#endif // MXNET_USE_ONEDNN == 1 +#endif // MXNET_OPERATOR_NN_MKLDNN_MKLDNN_DECONVOLUTION_INL_H__ diff --git a/src/operator/nn/mkldnn/mkldnn_deconvolution.cc b/src/operator/nn/mkldnn/mkldnn_deconvolution.cc index f188f9fd041b..7621a510a0fa 100644 --- a/src/operator/nn/mkldnn/mkldnn_deconvolution.cc +++ b/src/operator/nn/mkldnn/mkldnn_deconvolution.cc @@ -19,529 +19,340 @@ /*! 
* \file mkldnn_deconvolution.cc - * \brief */ #if MXNET_USE_ONEDNN == 1 -#include "./mkldnn_base-inl.h" -#include "./mkldnn_ops-inl.h" - #include "../deconvolution-inl.h" +#include "./mkldnn_deconvolution-inl.h" namespace mxnet { namespace op { bool SupportMKLDNNDeconv(const DeconvolutionParam& params, const NDArray& input) { - if (params.kernel.ndim() != 2) - return false; - return (input.dtype() == mshadow::kFloat32 || input.dtype() == mshadow::kBfloat16) && - input.shape().ndim() == 4; -} - -static inline mkldnn::memory::desc GetBiasDesc(mkldnn::memory::desc md) { - mkldnn::memory::dims dims(1); - // This is deconvolution on 4D data. The second dimension is the channel. - dims[0] = md.data.dims[1]; - return mkldnn::memory::desc(dims, - static_cast(md.data.data_type), - mkldnn::memory::format_tag::any); -} - -std::shared_ptr GetDeconvBwd_( - const mkldnn::memory::desc& data_md, - const mkldnn::memory::desc& weights_md, - bool has_bias, - const mkldnn::memory::desc& out_md, - const mkldnn::engine& engine, - const mkldnn::memory::dims& strides, - const mkldnn::memory::dims& padding, - const mkldnn::memory::dims& dilates) { - // MKL-DNN introduced padded formats since 0.15 which require more memory - // compared to the actual size of the tensor. Currently, MKL-DNN operators - // still reuse memory from memory planning, so here we need to select a - // suboptimal kernel for computation that has the expected memory size requirements - if (!has_bias) { - mkldnn::convolution_forward::desc desc(mkldnn::prop_kind::forward_training, - mkldnn::algorithm::convolution_direct, - out_md, - weights_md, - data_md, - strides, - dilates, - padding, - padding); - auto deconv_pd = std::make_shared(desc, engine); - while (deconv_pd->dst_desc().get_size() != GetMemDescSize(data_md) || - deconv_pd->src_desc().get_size() != GetMemDescSize(out_md) || - deconv_pd->weights_desc().get_size() != GetMemDescSize(weights_md)) { - CHECK(deconv_pd->next_impl()) << "No implementation"; - } - return deconv_pd; - } else { - auto bias_md = GetBiasDesc(data_md); - mkldnn::convolution_forward::desc desc(mkldnn::prop_kind::forward_training, - mkldnn::algorithm::convolution_direct, - out_md, - weights_md, - bias_md, - data_md, - strides, - dilates, - padding, - padding); - auto deconv_pd = std::make_shared(desc, engine); - while (deconv_pd->dst_desc().get_size() != GetMemDescSize(data_md) || - deconv_pd->src_desc().get_size() != GetMemDescSize(out_md) || - deconv_pd->weights_desc().get_size() != GetMemDescSize(weights_md)) { - CHECK(deconv_pd->next_impl()) << "No implementation"; - } - return deconv_pd; - } -} - -std::shared_ptr GetDeconvFwdImpl( - const DeconvolutionParam& param, - const NDArray& data, - const NDArray& weights, - bool has_bias, - const NDArray& output) { - auto data_md = GetMemDesc(data); - auto weight_md = GetWeightDesc(weights, param.num_group); - auto out_md = GetMemDesc(output); - auto engine = CpuEngine::Get()->get_engine(); - CHECK_GE(param.stride.ndim(), 2); - CHECK_GE(param.pad.ndim(), 2); - CHECK_GE(param.dilate.ndim(), 2); - mkldnn::memory::dims strides{0, 0}; - strides[0] = param.stride[0]; - strides[1] = param.stride[1]; - mkldnn::memory::dims padding{0, 0}; - padding[0] = param.pad[0]; - padding[1] = param.pad[1]; - mkldnn::memory::dims dilate{0, 0}; - dilate[0] = param.dilate[0] - 1; - dilate[1] = param.dilate[1] - 1; - auto bwd_pd = - GetDeconvBwd_(data_md, weight_md, has_bias, out_md, engine, strides, padding, dilate); - mkldnn::convolution_backward_data::desc 
desc(mkldnn::algorithm::convolution_direct, - out_md, - weight_md, - data_md, - strides, - dilate, - padding, - padding); - auto deconv_pd = - std::make_shared(desc, engine, *bwd_pd); - // MKL-DNN introduced padded formats since 0.15 which require more memory - // compared to the actual size of the tensor. Currently, MKL-DNN operators - // still reuse memory from memory planning, so here we need to select a - // suboptimal kernel for computation that has the expected memory size requirements - while (deconv_pd->diff_dst_desc().get_size() != GetMemDescSize(data_md) || - deconv_pd->diff_src_desc().get_size() != GetMemDescSize(out_md) || - deconv_pd->weights_desc().get_size() != GetMemDescSize(weight_md)) { - CHECK(deconv_pd->next_impl()) << "No implementation"; - } - return deconv_pd; -} - -std::shared_ptr GetDeconvBwdDataImpl( - const DeconvolutionParam& param, - const NDArray& data, - const NDArray& weights, - bool has_bias, - const NDArray& output) { - auto data_md = GetMemDesc(data); - auto weight_md = GetWeightDesc(weights, param.num_group); - auto out_md = GetMemDesc(output); - auto engine = CpuEngine::Get()->get_engine(); - CHECK_GE(param.stride.ndim(), 2); - CHECK_GE(param.pad.ndim(), 2); - CHECK_GE(param.dilate.ndim(), 2); - mkldnn::memory::dims strides{0, 0}; - strides[0] = param.stride[0]; - strides[1] = param.stride[1]; - mkldnn::memory::dims padding{0, 0}; - padding[0] = param.pad[0]; - padding[1] = param.pad[1]; - mkldnn::memory::dims dilate{0, 0}; - dilate[0] = param.dilate[0] - 1; - dilate[1] = param.dilate[1] - 1; - return GetDeconvBwd_(data_md, weight_md, has_bias, out_md, engine, strides, padding, dilate); + return params.kernel.ndim() >= 1 && params.kernel.ndim() <= 3 && + input.shape().ndim() == (params.kernel.ndim() + 2) && + (input.dtype() == mshadow::kFloat32 || input.dtype() == mshadow::kBfloat16); } -std::shared_ptr GetDeconvBwdWeightsImpl( - const DeconvolutionParam& param, - const NDArray& data, - const NDArray& weights, - bool has_bias, - const NDArray& output, - const mkldnn::convolution_forward::primitive_desc& fwd_pd) { - auto data_md = GetMemDesc(data); - auto weight_md = GetWeightDesc(weights, param.num_group); - auto out_md = GetMemDesc(output); - auto engine = CpuEngine::Get()->get_engine(); - CHECK_GE(param.stride.ndim(), 2); - CHECK_GE(param.pad.ndim(), 2); - CHECK_GE(param.dilate.ndim(), 2); - mkldnn::memory::dims strides{0, 0}; - strides[0] = param.stride[0]; - strides[1] = param.stride[1]; - mkldnn::memory::dims padding{0, 0}; - padding[0] = param.pad[0]; - padding[1] = param.pad[1]; - mkldnn::memory::dims dilate{0, 0}; - dilate[0] = param.dilate[0] - 1; - dilate[1] = param.dilate[1] - 1; - - // MKL-DNN introduced padded formats since 0.15 which require more memory - // compared to the actual size of the tensor. 
Currently, MKL-DNN operators - // still reuse memory from memory planning, so here we need to select a - // suboptimal kernel for computation that has the expected memory size requirements - if (!has_bias) { - mkldnn::convolution_backward_weights::desc desc(mkldnn::algorithm::convolution_direct, - out_md, - weight_md, - data_md, - strides, - dilate, - padding, - padding); - auto deconv_pd = std::make_shared( - desc, engine, fwd_pd); - while (deconv_pd->diff_dst_desc().get_size() != GetMemDescSize(data_md) || - deconv_pd->src_desc().get_size() != GetMemDescSize(out_md) || - deconv_pd->diff_weights_desc().get_size() != GetMemDescSize(weight_md)) { - CHECK(deconv_pd->next_impl()) << "No implementation"; - } - return deconv_pd; - } else { - auto bias_md = GetBiasDesc(data_md); - mkldnn::convolution_backward_weights::desc desc(mkldnn::algorithm::convolution_direct, - out_md, - weight_md, - bias_md, - data_md, - strides, - dilate, - padding, - padding); - auto deconv_pd = std::make_shared( - desc, engine, fwd_pd); - while (deconv_pd->diff_dst_desc().get_size() != GetMemDescSize(data_md) || - deconv_pd->src_desc().get_size() != GetMemDescSize(out_md) || - deconv_pd->diff_weights_desc().get_size() != GetMemDescSize(weight_md)) { - CHECK(deconv_pd->next_impl()) << "No implementation"; - } - return deconv_pd; - } -} - -class MKLDNNDeconvForward { - public: - MKLDNNDeconvForward(const DeconvolutionParam& param, - const NDArray& data, - const NDArray& weights, - bool has_bias, - const NDArray& output); - const mkldnn::convolution_backward_data& GetFwd() const { - return *fwd; - } - - const mkldnn::convolution_backward_data::primitive_desc& GetPd() const { - return *fwd_pd; - } - - private: - std::shared_ptr fwd; - std::shared_ptr fwd_pd; -}; // class MKLDNNDeconvForward - -MKLDNNDeconvForward::MKLDNNDeconvForward(const DeconvolutionParam& param, - const NDArray& data, - const NDArray& weights, - bool has_bias, - const NDArray& output) - : fwd_pd(GetDeconvFwdImpl(param, data, weights, has_bias, output)) { - fwd = std::make_shared(GetPd()); -} +void MKLDNNDeconvolutionForward(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + TmpMemMgr::Get()->Init(ctx.requested[deconv::kTempSpace]); + const auto& param = nnvm::get(attrs.parsed); + const auto tensors = MKLDNNDeconvFwd::Tensors(param.no_bias, inputs, outputs); + const auto& fwd = MKLDNNDeconvFwd::GetCached(param, tensors); -static void MKLDNNDeconvFwdBiasPostProcess(const DeconvolutionParam& param, - const OpContext& ctx, - const NDArray& bias, - const std::vector& out_data) { - // add bias, broadcast bias to dim 1: channel - if (!param.no_bias) { - // MKLDNN only supports float right now. - typedef float DType; - Stream* s = ctx.get_stream(); - Tensor b = bias.data().get(s); - // The output data is stored in a special MKLDNN format, - // converts its format to the default format. - // Unfortunately, MKLDNN doesn't support broadcast. 
- auto out_data_def = out_data[deconv::kOut].Reorder2Default(); - Tensor out_cpu = out_data_def.data().get(s); - out_cpu += mshadow::expr::broadcast<1>(b, out_cpu.shape_); - } + fwd.ControlWeightsFormat(param.num_group, ctx.is_train, tensors.weights); + fwd.Execute(param.num_group, req[deconv::kOut], tensors); } -MKLDNNDeconvForward& GetDeconvFwd(const nnvm::NodeAttrs& attrs, - const NDArray& data, - const NDArray& weights, - const NDArray* bias, - const NDArray& output) { +MKLDNNDeconvFwd& MKLDNNDeconvFwd::GetCached(const DeconvolutionParam& param, + const Tensors& tensors) { + using deconv_fwd_map = std::unordered_map; #if DMLC_CXX11_THREAD_LOCAL - static thread_local std::unordered_map fwds; + static thread_local deconv_fwd_map fwds; #else - static MX_THREAD_LOCAL std::unordered_map fwds; + static MX_THREAD_LOCAL deconv_fwd_map fwds; #endif - const DeconvolutionParam& param = nnvm::get(attrs.parsed); DeconvSignature key(param); - // Here we can sign the conv op with NDArray because conv primitive will - // decide the right layout for the, so we only need to get the shape and the - // data type of the arrays. - key.AddSign(data); - key.AddSign(weights); - key.AddSign(output); - if (bias) - key.AddSign(*bias); + key.AddSign(tensors.data); + key.AddSign(tensors.weights); + key.AddSign(tensors.out); + if (tensors.bias) { + key.AddSign(*tensors.bias); + } auto it = fwds.find(key); if (it == fwds.end()) { - bool has_bias = (bias != nullptr); - auto fwd = MKLDNNDeconvForward(param, data, weights, has_bias, output); - it = AddToCache(&fwds, key, fwd); + const MKLDNNDeconvFwd fwd(param, tensors); + it = AddToCache(&fwds, key, fwd); } return it->second; } -void MKLDNNDeconvolutionForward(const nnvm::NodeAttrs& attrs, - const OpContext& ctx, - const std::vector& in_data, - const std::vector& req, - const std::vector& out_data) { - TmpMemMgr::Get()->Init(ctx.requested[deconv::kTempSpace]); - const DeconvolutionParam& param = nnvm::get(attrs.parsed); - - auto& data = in_data[deconv::kData]; - auto& weight = in_data[deconv::kWeight]; - const NDArray* bias = param.no_bias ? 
nullptr : &in_data[deconv::kBias]; - - MKLDNNDeconvForward& fwd = GetDeconvFwd(attrs, data, weight, bias, out_data[deconv::kOut]); +std::shared_ptr MKLDNNDeconvFwd::CreatePrimitiveDesc( + const DeconvolutionParam& param, + const Tensors& tensors) { + DeconvDescCreator ddc(param, tensors.data, tensors.weights, tensors.bias, tensors.out); + const auto& engine = CpuEngine::Get()->get_engine(); + const auto pd = std::make_shared(ddc.CreateFwdDesc(), engine); + const auto get_data_size = [&pd]() { return pd->src_desc().get_size(); }; + const auto get_weights_size = [&pd]() { return pd->weights_desc().get_size(); }; + const auto get_out_size = [&pd]() { return pd->dst_desc().get_size(); }; + + while (!ddc.CheckImplSizeReq(get_data_size(), get_weights_size(), get_out_size())) { + // ImposePlainWherePadding fails when all memory descriptors already have plain formats + // imposed, meaning there is no implementation with plain formats + CHECK(ddc.ImposePlainWherePadding(get_data_size(), get_weights_size(), get_out_size())) + << "No implementation of deconvolution forward propagation"; + *pd = deconv_fwd_pd_t(ddc.CreateFwdDesc(), engine); + } + return pd; +} - auto data_mem = data.GetMKLDNNDataReorder(fwd.GetPd().diff_dst_desc()); - const mkldnn::memory* weight_mem; - if (ctx.is_train) { +void MKLDNNDeconvFwd::ControlWeightsFormat(const uint32_t num_group, + const bool is_train, + const NDArray& weights) const { + if (is_train) { // TODO(zhengda) kvstore doesn't handle MKLDNN correctly. Let's reorder it // to the default format for now. - if (weight.IsMKLDNNData()) - // This asks the engine to change the layout of the weight array after - // it's used. - weight.Reorder2DefaultAsync(); - weight_mem = GetWeights(weight, fwd.GetPd().weights_desc(), param.num_group); + if (weights.IsMKLDNNData()) { + // This asks the engine to change the layout of the weights array after it's used. + weights.Reorder2DefaultAsync(); + } } else { - // For inference, we want to reorder the weight array so we don't need to + // For inference, we want to reorder the weights array so we don't need to // reorder data every time. - if (weight.IsDefaultData()) { - // We also need to modify the layout on the original weight array. The - // data conversion happens after the weight array is used. - weight.MKLDNNDataReorderAsync(fwd.GetPd().weights_desc()); - weight_mem = GetWeights(weight, fwd.GetPd().weights_desc(), param.num_group); - + if (weights.IsDefaultData()) { + // We also need to modify the layout on the original weights array. + // The data conversion happens after the weights array is used. 
+ weights.MKLDNNDataReorderAsync(IOLogicalSwapDesc(fwd_pd->weights_desc(), num_group)); } else { - weight_mem = weight.GetMKLDNNData(); - CHECK(weight_mem->get_desc() == fwd.GetPd().weights_desc()); + CHECK(weights.GetMKLDNNData()->get_desc() == + IOLogicalSwapDesc(fwd_pd->weights_desc(), num_group)); } } - mkldnn_output_t out_mem; - out_mem = CreateMKLDNNMem(out_data[deconv::kOut], fwd.GetPd().diff_src_desc(), req[deconv::kOut]); - - mkldnn_args_map_t net_args; +} - net_args.insert({MKLDNN_ARG_DIFF_DST, *data_mem}); - net_args.insert({MKLDNN_ARG_WEIGHTS, *weight_mem}); - net_args.insert({MKLDNN_ARG_DIFF_SRC, *out_mem.second}); - MKLDNNStream::Get()->RegisterPrimArgs(fwd.GetFwd(), net_args); - CommitOutput(out_data[deconv::kOut], out_mem); - MKLDNNStream::Get()->Submit(); +void MKLDNNDeconvFwd::Execute(const uint32_t num_group, + const OpReqType req, + const Tensors& tensors) const { + // MXNet (correctly) assumes that deconvolution is implemented using convolution primitives. + // For that, we would pass input tensor in place of output and output tensor in place of input + // (for appropriate convolution primitives: deconvolution forward = convolution backward data, + // deconvolution backward data = convolution forward). + // The convolution primitive expects weights tensor with the shape of + // (primitive_out_channels, primitive_in_channels, h, w), but with swapped input and output: + // primitive_out_channels = deconv_in_channels, primitive_in_channels = deconv_out_channels, + // so it becomes (deconv_in_channels, deconv_out_channels, h, w) and MXNet provides such tensor. + // + // MKLDNN deconvolution primitive also (as convolution) expects weights tensor with the shape of + // (primitive_out_channels, primitive_in_channels, h, w), but this time we don't swap input and + // output tensors, so: + // primitive_out_channels = deconv_out_channels, primitive_in_channels = deconv_in_channels, + // thus the current weights tensor won't fit (when deconv_out_channels != deconv_in_channels). + // However, underneath deconvolution MKLDNN also uses convolution, so even though it expects the + // weights tensor with the logical order of oihw, it wants its physical representation to + // match the order of iohw, which is the same as current weights tensor. + // + // So here we swap logical order of input and output dimensions for weights tensor just for + // MKLDNN operations. 
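Note on the comment above: the swap it describes is purely a relabelling of the logical dimension order of the weights tensor; the underlying bytes are left untouched. A minimal NumPy sketch of the idea, with hypothetical shapes and num_group == 1 assumed:

# Sketch of the logical I/O swap described above (hypothetical shapes, num_group == 1).
# Swapping the two leading axes changes only the logical view of the weights;
# no data is copied or reordered.
import numpy as np

deconv_in_channels, deconv_out_channels, kH, kW = 4, 16, 3, 3

# MXNet stores deconvolution weights in iohw order:
weights_iohw = np.zeros((deconv_in_channels, deconv_out_channels, kH, kW), dtype=np.float32)

# The deconvolution primitive logically expects oihw, but with the physical
# layout of the existing buffer, so only the axis labels are exchanged:
weights_oihw_view = np.swapaxes(weights_iohw, 0, 1)

assert weights_oihw_view.shape == (deconv_out_channels, deconv_in_channels, kH, kW)
assert np.shares_memory(weights_iohw, weights_oihw_view)  # same underlying memory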
+ IOLogicalSwapMKLDNNMem(tensors.weights, num_group); + { + mkldnn_args_map_t net_args; + const auto& out_mem = OutMem(req, tensors.out); + + net_args.insert({MKLDNN_ARG_SRC, *DataMem(tensors.data)}); + net_args.insert({MKLDNN_ARG_WEIGHTS, *WeightsMem(num_group, tensors.weights)}); + net_args.insert({MKLDNN_ARG_DST, *out_mem.second}); + if (tensors.bias) { + net_args.insert({MKLDNN_ARG_BIAS, *BiasMem(*tensors.bias)}); + } - MKLDNNDeconvFwdBiasPostProcess(param, ctx, *bias, out_data); + // CommitOutput should run after RegisterPrimArgs for memory dependency + MKLDNNStream::Get()->RegisterPrimArgs(*fwd, net_args); + CommitOutput(tensors.out, out_mem); + MKLDNNStream::Get()->Submit(); + } + IOLogicalSwapMKLDNNMem(tensors.weights, num_group); // swap back from oihw to iohw } -class MKLDNNDeconvBackwardData { - std::shared_ptr bwd; +void MKLDNNDeconvolutionBackward(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + CHECK_NE(req[deconv::kWeight], kWriteInplace) << "Cannot write weights inplace"; - public: - std::shared_ptr bwd_pd; - MKLDNNDeconvBackwardData(const DeconvolutionParam& param, - const NDArray& data, - const NDArray& weights, - const NDArray& output); + TmpMemMgr::Get()->Init(ctx.requested[deconv::kTempSpace]); + const auto& param = nnvm::get(attrs.parsed); + const auto read_tensors = MKLDNNDeconvBwd::ReadTensors(param.no_bias, inputs); + const auto write_tensors = MKLDNNDeconvBwd::WriteTensors(param.no_bias, outputs); + MKLDNNDeconvBwd& bwd = MKLDNNDeconvBwd::GetCached(param, read_tensors); - const mkldnn::convolution_forward& GetBwd() const { - return *bwd; - } - const mkldnn::convolution_forward::primitive_desc& GetDataPd() const { - return *bwd_pd; - } -}; - -MKLDNNDeconvBackwardData::MKLDNNDeconvBackwardData(const DeconvolutionParam& param, - const NDArray& data, - const NDArray& weights, - const NDArray& output) - : bwd_pd(GetDeconvBwdDataImpl(param, data, weights, false, output)) { - bwd = std::make_shared(GetDataPd()); + bwd.Execute(param.num_group, req, read_tensors, write_tensors); } -typedef ParamOpSign MKLDNNDeconvSignature; - -static inline MKLDNNDeconvBackwardData& GetDeconvBwdData(const DeconvolutionParam& param, - const NDArray& data, - const NDArray& weights, - const NDArray& output) { +MKLDNNDeconvBwd& MKLDNNDeconvBwd::GetCached(const DeconvolutionParam& param, + const ReadTensors& read_tensors) { + using deconv_bwd_map = std::unordered_map; #if DMLC_CXX11_THREAD_LOCAL - static thread_local std::unordered_map - bwds; + static thread_local deconv_bwd_map bwds; #else - static MX_THREAD_LOCAL std::unordered_map - bwds; + static MX_THREAD_LOCAL deconv_bwd_map bwds; #endif - MKLDNNDeconvSignature key(param); - // Here we can sign the conv op with NDArray because conv primitive will - // decide the right layout for the, so we only need to get the shape and the - // data type of the arrays. 
- key.AddSign(data); - key.AddSign(weights); - key.AddSign(output); + DeconvSignature key(param); + key.AddSign(read_tensors.data); + key.AddSign(read_tensors.weights); + key.AddSign(read_tensors.out_grad); + if (read_tensors.bias) { + key.AddSign(*read_tensors.bias); + } auto it = bwds.find(key); if (it == bwds.end()) { - auto bwd = MKLDNNDeconvBackwardData(param, data, weights, output); - it = AddToCache(&bwds, key, bwd); + const MKLDNNDeconvBwd bwd(param, read_tensors); + it = AddToCache(&bwds, key, bwd); } return it->second; } -class MKLDNNDeconvBackwardWeights { - std::shared_ptr bwd; - - public: - std::shared_ptr bwd_data_pd; - MKLDNNDeconvBackwardWeights(const DeconvolutionParam& param, - const NDArray& data, - const NDArray& weights, - const NDArray& output, - const mkldnn::convolution_forward::primitive_desc& bwd_data_pd); - const mkldnn::convolution_backward_weights& GetBwd() const { - return *bwd; - } - const mkldnn::convolution_backward_weights::primitive_desc& GetWeightsPd() const { - return *bwd_data_pd; - } -}; - -MKLDNNDeconvBackwardWeights::MKLDNNDeconvBackwardWeights( +std::shared_ptr MKLDNNDeconvBwd::CreateDataPrimitiveDesc( const DeconvolutionParam& param, - const NDArray& data, - const NDArray& weights, - const NDArray& output, - const mkldnn::convolution_forward::primitive_desc& bwd_data_pd) - : bwd_data_pd(GetDeconvBwdWeightsImpl(param, data, weights, false, output, bwd_data_pd)) { - bwd = std::make_shared(GetWeightsPd()); + const ReadTensors& read_tensors, + const deconv_fwd_pd_t& fwd_pd) { + DeconvDescCreator ddc( + param, read_tensors.data, read_tensors.weights, nullptr, read_tensors.out_grad); + const auto& engine = CpuEngine::Get()->get_engine(); + const auto pd = std::make_shared(ddc.CreateBwdDataDesc(), engine, fwd_pd); + const auto get_data_size = [&pd]() { return pd->diff_src_desc().get_size(); }; + const auto get_weights_size = [&pd]() { return pd->weights_desc().get_size(); }; + const auto get_out_size = [&pd]() { return pd->diff_dst_desc().get_size(); }; + + while (!ddc.CheckImplSizeReq(get_data_size(), get_weights_size(), get_out_size())) { + // ImposePlainWherePadding fails when all memory descriptors already have plain formats + // imposed, meaning there is no implementation with plain formats + CHECK(ddc.ImposePlainWherePadding(get_data_size(), get_weights_size(), get_out_size())) + << "No implementation of deconvolution backward propagation"; + *pd = deconv_bwd_data_pd_t(ddc.CreateBwdDataDesc(), engine, fwd_pd); + } + return pd; } -static inline MKLDNNDeconvBackwardWeights& GetDeconvBwdWeights( +std::shared_ptr MKLDNNDeconvBwd::CreateWeightsPrimitiveDesc( const DeconvolutionParam& param, - const NDArray& data, - const NDArray& weights, - const NDArray& output, - const mkldnn::convolution_forward::primitive_desc& bwd_data_pd) { -#if DMLC_CXX11_THREAD_LOCAL - static thread_local std::unordered_map - bwds; -#else - static MX_THREAD_LOCAL - std::unordered_map - bwds; -#endif - MKLDNNDeconvSignature key(param); - // Here we can sign the conv op with NDArray because conv primitive will - // decide the right layout for the, so we only need to get the shape and the - // data type of the arrays. 
- key.AddSign(data); - key.AddSign(weights); - key.AddSign(output); - - auto it = bwds.find(key); - if (it == bwds.end()) { - auto bwd = MKLDNNDeconvBackwardWeights(param, data, weights, output, bwd_data_pd); - auto ins_ret = - bwds.insert(std::pair(key, bwd)); - CHECK(ins_ret.second); - it = ins_ret.first; + const ReadTensors& read_tensors, + const deconv_fwd_pd_t& fwd_pd) { + DeconvDescCreator ddc( + param, read_tensors.data, read_tensors.weights, read_tensors.bias, read_tensors.out_grad); + const auto& engine = CpuEngine::Get()->get_engine(); + const auto pd = + std::make_shared(ddc.CreateBwdWeightsDesc(), engine, fwd_pd); + const auto get_data_size = [&pd]() { return pd->src_desc().get_size(); }; + const auto get_weights_size = [&pd]() { return pd->diff_weights_desc().get_size(); }; + const auto get_out_size = [&pd]() { return pd->diff_dst_desc().get_size(); }; + + while (!ddc.CheckImplSizeReq(get_data_size(), get_weights_size(), get_out_size())) { + // ImposePlainWherePadding fails when all memory descriptors already have plain formats + // imposed, meaning there is no implementation with plain formats + CHECK(ddc.ImposePlainWherePadding(get_data_size(), get_weights_size(), get_out_size())) + << "No implementation of calculating deconvolution weights gradient"; + *pd = deconv_bwd_weights_pd_t(ddc.CreateBwdWeightsDesc(), engine, fwd_pd); } - return it->second; + return pd; } -void MKLDNNDeconvolutionBackward(const nnvm::NodeAttrs& attrs, - const OpContext& ctx, - const std::vector& inputs, - const std::vector& req, - const std::vector& outputs) { - TmpMemMgr::Get()->Init(ctx.requested[deconv::kTempSpace]); - const std::vector& in_grad = outputs; - const DeconvolutionParam& param = nnvm::get(attrs.parsed); - - auto& data = inputs[deconv::kData + 1]; - auto& weight = inputs[deconv::kWeight + 1]; - auto& out_grad = inputs[deconv::kOut]; - - CHECK_NE(req[deconv::kWeight], kWriteInplace) << "cannot write weight inplace"; - MKLDNNDeconvBackwardData& bwd_data = GetDeconvBwdData(param, data, weight, inputs[deconv::kOut]); - auto out_grad_mem = out_grad.GetMKLDNNDataReorder(bwd_data.GetDataPd().src_desc()); - if (req[deconv::kData]) { - auto weight_mem = GetWeights(weight, bwd_data.GetDataPd().weights_desc(), param.num_group); - auto in_grad_mem = CreateMKLDNNMem( - in_grad[deconv::kData], bwd_data.GetDataPd().dst_desc(), req[deconv::kData]); - mkldnn_args_map_t net_args = {{MKLDNN_ARG_SRC, *out_grad_mem}, - {MKLDNN_ARG_WEIGHTS, *weight_mem}, - {MKLDNN_ARG_DST, *in_grad_mem.second}}; - MKLDNNStream::Get()->RegisterPrimArgs(bwd_data.GetBwd(), net_args); - CommitOutput(in_grad[deconv::kData], in_grad_mem); +void MKLDNNDeconvBwd::Execute(const uint32_t num_group, + const std::vector& req, + const ReadTensors& read_tensors, + const WriteTensors& write_tensors) const { + // swaps are explained in MKLDNNDeconvFwd::Execute + IOSwapWeightsTensors(num_group, req, read_tensors.weights, write_tensors.weights_grad); + { + auto* const out_grad_mem = + ScheduleBwdData(num_group, req[deconv::kData], read_tensors, write_tensors); + ScheduleBwdWeights(num_group, req, read_tensors, write_tensors, out_grad_mem); + MKLDNNStream::Get()->Submit(); } - if (req[deconv::kWeight]) { - MKLDNNDeconvBackwardWeights& bwd_weights = - GetDeconvBwdWeights(param, data, weight, inputs[deconv::kOut], bwd_data.GetDataPd()); - if (bwd_data.GetDataPd().src_desc() != bwd_weights.GetWeightsPd().src_desc()) - out_grad_mem = out_grad.GetMKLDNNDataReorder(bwd_weights.GetWeightsPd().src_desc()); - auto data_mem = 
data.GetMKLDNNDataReorder(bwd_weights.GetWeightsPd().diff_dst_desc()); - auto in_grad_weight = CreateMKLDNNWeightGrad(in_grad[deconv::kWeight], - bwd_weights.GetWeightsPd().diff_weights_desc(), - req[deconv::kWeight]); - - mkldnn_args_map_t net_args = {{MKLDNN_ARG_SRC, *out_grad_mem}, - {MKLDNN_ARG_DIFF_DST, *data_mem}, - {MKLDNN_ARG_DIFF_WEIGHTS, *in_grad_weight.second}}; - MKLDNNStream::Get()->RegisterPrimArgs(bwd_weights.GetBwd(), net_args); - CommitOutput(in_grad[deconv::kWeight], in_grad_weight); + IOSwapWeightsTensors(num_group, req, read_tensors.weights, write_tensors.weights_grad); +} + +const mkldnn::memory* MKLDNNDeconvBwd::ScheduleBwdData(const uint32_t num_group, + const OpReqType req, + const ReadTensors& read_tensors, + const WriteTensors& write_tensors) const { + if (req) { + mkldnn_args_map_t net_args; + auto* const out_grad_mem = OutGradMem(read_tensors.out_grad); + const auto& data_grad_mem = DataGradMem(req, write_tensors.data_grad); + + net_args.insert({MKLDNN_ARG_DIFF_DST, *out_grad_mem}); + net_args.insert({MKLDNN_ARG_WEIGHTS, *WeightsMem(num_group, read_tensors.weights)}); + net_args.insert({MKLDNN_ARG_DIFF_SRC, *data_grad_mem.second}); + + // CommitOutput should run after RegisterPrimArgs for memory dependency + MKLDNNStream::Get()->RegisterPrimArgs(*bwd_data, net_args); + CommitOutput(write_tensors.data_grad, data_grad_mem); + return out_grad_mem; } - MKLDNNStream::Get()->Submit(); + return nullptr; +} - if (!param.no_bias) { - typedef float DType; - Stream* s = ctx.get_stream(); - Tensor gbias = in_grad[deconv::kBias].data().get(s); +void MKLDNNDeconvBwd::ScheduleBwdWeights(const uint32_t num_group, + const std::vector& req, + const ReadTensors& read_tensors, + const WriteTensors& write_tensors, + const mkldnn::memory* const out_grad_mem) const { + OpReqType weight_req = req[deconv::kWeight]; + OpReqType bias_req = req.size() > deconv::kBias ? req[deconv::kBias] : OpReqType::kNullOp; + if (weight_req || bias_req) { + mkldnn_args_map_t net_args; + const auto& weights_grad_mem = + WeightsGradMem(num_group, weight_req, write_tensors.weights_grad); + const auto& bias_grad_mem = BiasGradMem(bias_req, write_tensors.bias_grad); + + net_args.insert({MKLDNN_ARG_DIFF_DST, *OutGradMem(read_tensors.out_grad, out_grad_mem)}); + net_args.insert({MKLDNN_ARG_SRC, *DataMem(read_tensors.data)}); + net_args.insert({MKLDNN_ARG_DIFF_WEIGHTS, *weights_grad_mem.second}); + if (bias_grad_mem.second) { + net_args.insert({MKLDNN_ARG_DIFF_BIAS, *bias_grad_mem.second}); + } - NDArray temp = inputs[deconv::kOut]; - if (temp.IsMKLDNNData()) { - temp = temp.Reorder2Default(); + // CommitOutput should run after RegisterPrimArgs for memory dependency + MKLDNNStream::Get()->RegisterPrimArgs(*bwd_weights, net_args); + CommitOutput(write_tensors.weights_grad, weights_grad_mem); + if (bias_grad_mem.second) { + CommitOutput(*write_tensors.bias_grad, bias_grad_mem); } + } +} + +DeconvDescCreator::DeconvDescCreator(const DeconvolutionParam& param, + const NDArray& data, + const NDArray& weights, + const NDArray* const bias, + const NDArray& out) + : data_md(GetMemDesc(data)), + weights_md(GetDeconvWeightsDesc(weights, param.num_group)), + bias_md(bias ? 
GetMemDesc(*bias) : mkldnn::memory::desc()), + out_md(GetMemDesc(out)), + strides(param.stride.ndim()), + padding(param.pad.ndim()), + dilates(param.dilate.ndim()) { + CHECK_EQ(param.stride.ndim(), param.pad.ndim()); + CHECK_EQ(param.stride.ndim(), param.dilate.ndim()); + CHECK_GE(param.stride.ndim(), 1); + CHECK_LE(param.stride.ndim(), 3); + for (int i = 0; i < param.stride.ndim(); ++i) { + strides[i] = param.stride[i]; + padding[i] = param.pad[i]; + dilates[i] = param.dilate[i] - 1; + } +} - Tensor grad = temp.data().get(s); - Assign(gbias, req[deconv::kBias], mshadow::expr::sumall_except_dim<1>(grad)); +bool DeconvDescCreator::ImposePlainWherePadding(const size_t data_size, + const size_t weights_size, + const size_t out_size) { + // Changing only one at a time, so maybe better implementations will be selected (than entirely + // plain one) + if (data_md.data.format_kind == dnnl_format_kind_any && data_size != GetMemDescSize(data_md)) { + data_md = GetDesc(data_md, GetDefaultFormat(data_md)); + return true; + } else if (out_md.data.format_kind == dnnl_format_kind_any && + out_size != GetMemDescSize(out_md)) { + out_md = GetDesc(out_md, GetDefaultFormat(out_md)); + return true; + } else if (weights_md.data.format_kind == dnnl_format_kind_any && + weights_size != GetMemDescSize(weights_md)) { + const int num_gr = (weights_md.data.ndims > data_md.data.ndims) ? weights_md.data.dims[0] : 1; + weights_md = IOLogicalSwapDesc(weights_md, num_gr); + weights_md = IOLogicalSwapDesc(GetDesc(weights_md, GetDefaultFormat(weights_md)), num_gr); + return true; } + return false; } } // namespace op diff --git a/tests/python/mkl/test_mkldnn.py b/tests/python/mkl/test_mkldnn.py index 52de29c59aa3..8d855e65cfb0 100644 --- a/tests/python/mkl/test_mkldnn.py +++ b/tests/python/mkl/test_mkldnn.py @@ -427,11 +427,10 @@ def check_convolution_training(stype): check_convolution_training(stype) -@pytest.mark.skip(reason="Flaky test https://github.com/apache/incubator-mxnet/issues/12579") def test_Deconvolution(): def check_Deconvolution_training(stype): - for shape in [(3, 3, 10), (3, 3, 10, 10)]: - data_tmp = np.random.randint(256, size=shape) + for shape in [(3, 3, 10), (3, 3, 10, 10), (3, 3, 3, 10, 10)]: + data_tmp = np.random.normal(-0.1, 1, size=shape) data = mx.symbol.Variable('data', stype=stype) if np.array(shape).shape[0] == 3: @@ -440,6 +439,9 @@ def check_Deconvolution_training(stype): elif np.array(shape).shape[0] == 4: test = mx.symbol.Deconvolution(data=data, kernel=(3, 3), stride=(2, 2), num_filter=4) weight_tmp = np.random.normal(-0.1, 0.1, size=(3, 4, 3, 3)) + elif np.array(shape).shape[0] == 5: + test = mx.symbol.Deconvolution(data=data, kernel=(3,3,3), stride=(2,2,2), num_filter=4) + weight_tmp = np.random.normal(-0.1, 0.1, size=(3, 4, 3, 3, 3)) else: return 0 bias_tmp = np.random.normal(0.1, 0.1, size=(4,)) diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py index 791ecefa6c58..d34519c332cc 100644 --- a/tests/python/unittest/test_gluon.py +++ b/tests/python/unittest/test_gluon.py @@ -432,45 +432,60 @@ def test_conv_nhwc(layer, shape): check_layer_forward(layer, shape) -def test_deconv(): - # layers1d = [ - # nn.Conv1DTranspose(16, 3, in_channels=4), - # nn.Conv1DTranspose(16, 3, groups=2, in_channels=4), - # nn.Conv1DTranspose(16, 3, strides=3, groups=2, in_channels=4), - # ] - # for layer in layers1d: - # check_layer_forward(layer, (1, 4, 10)) - - - layers2d = [ - nn.Conv2DTranspose(16, (3, 4), in_channels=4), - nn.Conv2DTranspose(16, (5, 4), in_channels=4), - 
nn.Conv2DTranspose(16, (3, 4), groups=2, in_channels=4), - nn.Conv2DTranspose(16, (3, 4), strides=4, in_channels=4), - nn.Conv2DTranspose(16, (3, 4), dilation=4, in_channels=4), - # nn.Conv2DTranspose(16, (3, 4), padding=4, in_channels=4), - nn.Conv2DTranspose(16, (3, 4), strides=4, output_padding=3, in_channels=4), - ] - for layer in layers2d: - check_layer_forward(layer, (1, 4, 20, 20)) - - - # layers3d = [ - # nn.Conv3DTranspose(16, (1, 8, 4), in_channels=4), - # nn.Conv3DTranspose(16, (5, 4, 3), in_channels=4), - # nn.Conv3DTranspose(16, (3, 3, 3), groups=2, in_channels=4), - # nn.Conv3DTranspose(16, 4, strides=4, in_channels=4), - # nn.Conv3DTranspose(16, (3, 3, 3), padding=4, in_channels=4), - # ] - # for layer in layers3d: - # check_layer_forward(layer, (1, 4, 10, 10, 10)) - # - # - # layer = nn.Conv2DTranspose(16, (3, 3), layout='NHWC', in_channels=4) - # # check_layer_forward(layer, (1, 10, 10, 4)) - # - # layer = nn.Conv3DTranspose(16, (3, 3, 3), layout='NDHWC', in_channels=4) - # # check_layer_forward(layer, (1, 10, 10, 10, 4)) +@pytest.mark.parametrize('layer,shape', [ + (nn.Conv1DTranspose(16, 3, in_channels=4), (1, 4, 10)), + (nn.Conv1DTranspose(16, 3, groups=2, in_channels=4), (1, 4, 10)), + (nn.Conv1DTranspose(16, 3, strides=3, groups=2, in_channels=4, output_padding=2), (1, 4, 10)), + (nn.Conv2DTranspose(16, (3, 4), in_channels=4), (1, 4, 20, 20)), + (nn.Conv2DTranspose(16, (5, 4), in_channels=4), (1, 4, 20, 20)), + (nn.Conv2DTranspose(16, (3, 4), groups=2, in_channels=4), (1, 4, 20, 20)), + (nn.Conv2DTranspose(16, (3, 4), strides=4, in_channels=4, output_padding=3), (1, 4, 20, 20)), + (nn.Conv2DTranspose(16, (3, 4), dilation=4, in_channels=4), (1, 4, 20, 20)), + (nn.Conv2DTranspose(16, (3, 4), padding=4, in_channels=4), (1, 4, 20, 20)), + (nn.Conv3DTranspose(16, (1, 8, 4), in_channels=4, activation='relu'), (1, 4, 10, 10, 10)), + (nn.Conv3DTranspose(16, (5, 4, 3), in_channels=4), (1, 4, 10, 10, 10)), + (nn.Conv3DTranspose(16, (3, 3, 3), groups=2, in_channels=4), (1, 4, 10, 10, 10)), + (nn.Conv3DTranspose(16, 4, strides=4, in_channels=4, output_padding=3), (1, 4, 10, 10, 10)), + (nn.Conv3DTranspose(16, (3, 3, 3), padding=4, in_channels=4), (1, 4, 10, 10, 10)), +]) +def test_deconv(layer, shape): + if len(shape) == 5 and mx.current_context().device_type == 'gpu': + pytest.skip('Skipping Conv3DTranspose tests for GPU') + check_layer_forward(layer, shape) + + +@use_np +def test_deconv_dilation(): + data = mx.np.array([[[[0, 0, 0], + [0, 1, 0], + [0, 0, 0]]], + [[[0, 0, 0], + [0, 2, 0], + [0, 0, 0]]]]) + + weight = mx.np.array([[[[1, 2, 3], + [4, 5, 6], + [7, 8, 9]]]]) + + layer = nn.Conv2DTranspose(in_channels=1, channels=1, + kernel_size=(3, 3), padding=(1, 1), + strides=(1, 1), dilation=(2, 2)) + layer.initialize() + layer.weight.set_data(weight) + out = layer(data) + expected = mx.np.array( + [[[[1., 0., 2., 0., 3.], + [0., 0., 0., 0., 0.], + [4., 0., 5., 0., 6.], + [0., 0., 0., 0., 0.], + [7., 0., 8., 0., 9.]]], + [[[2., 0., 4., 0., 6.], + [0., 0., 0., 0., 0.], + [8., 0., 10., 0., 12.], + [0., 0., 0., 0., 0.], + [14., 0., 16., 0., 18.]]] + ]) + assert_almost_equal(out, expected) def test_pool(): diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index e21e8fdc49b8..488f1a80285d 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -10832,3 +10832,76 @@ def test_slice_like(): xx[:] = 0.0 xx[idx] = x.asnumpy()[idx] assert_allclose(x1.grad.asnumpy(), np.zeros_like(x1.grad).asnumpy()) + + 
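For reference, the expected tensors in test_deconv_dilation above are simply the 3x3 kernel dilated by a factor of 2 and scaled by the value of the input impulse. A small NumPy sketch, independent of MXNet, reproduces them:

# Sketch: with a single unit impulse in the input, stride 1, padding 1 and
# dilation 2, a transposed convolution stamps the dilated kernel (scaled by the
# impulse value) around the impulse location, which is exactly the 5x5 tensor
# used as `expected` in test_deconv_dilation above.
import numpy as np

kernel = np.arange(1.0, 10.0).reshape(3, 3)       # [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
dilation = 2

dilated = np.zeros((2 * dilation + 1, 2 * dilation + 1))
dilated[::dilation, ::dilation] = kernel          # insert zeros between the kernel taps

print(1.0 * dilated)                              # first sample (impulse value 1)
print(2.0 * dilated)                              # second sample (impulse value 2)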
+@use_np
+@pytest.mark.parametrize('shape,num_filter,num_group,kernel,pad', [
+    ((1, 4, 15), 16, 2, (2,), (0,)),
+    ((8, 4, 16), 16, 1, (3,), (1,)),
+
+    ((1, 4, 15, 16), 16, 2, (2, 2), (0, 0)),
+    ((8, 4, 16, 16), 16, 1, (3, 3), (1, 1)),
+
+    ((1, 4, 3, 15, 16), 16, 2, (2, 2, 2), (0, 0, 0)),
+    ((8, 4, 3, 16, 16), 16, 1, (3, 3, 3), (1, 1, 1))])
+def test_npx_deconvolution(shape, num_filter, num_group, kernel, pad):
+    if len(kernel) == 3 and mx.current_context().device_type == 'gpu':
+        pytest.skip('Skipping deconvolution 3D tests for GPU')
+
+    class TestConv(mx.gluon.HybridBlock):
+        def __init__(self, w):
+            super().__init__()
+            self.weight = w
+
+        def forward(self, x, *args):
+            return npx.convolution(x, self.weight.data(x.ctx), no_bias=True, kernel=kernel,
+                                   pad=pad, num_filter=self.weight.shape[0], num_group=num_group)
+
+    class TestDeconv(mx.gluon.HybridBlock):
+        def __init__(self):
+            super().__init__()
+            self.weight = mx.gluon.Parameter('weight', shape=(shape[1], int(num_filter/num_group),
+                                                              *kernel))
+            self.bias = mx.gluon.Parameter('bias', shape=num_filter)
+
+        def forward(self, x, *args):
+            return npx.deconvolution(x, self.weight.data(x.ctx), self.bias.data(x.ctx), kernel,
+                                     pad=pad, num_filter=num_filter, num_group=num_group)
+
+    deconvNet = TestDeconv()
+    deconvNet.initialize()
+
+    # test imperative
+    deconvData = np.random.uniform(0, 1, size=shape)
+    npx_out_imp = deconvNet(deconvData)
+
+    # test symbolic
+    deconvNet.hybridize()
+    deconvNet(deconvData)
+    npx_out_sym = deconvNet(deconvData)
+    assert_almost_equal(npx_out_imp, npx_out_sym)
+
+    # compare outputs with reference tensors generated using convolution
+    convNet = TestConv(deconvNet.weight)
+    convNet.initialize()
+    convData = np.random.uniform(0, 1, size=npx_out_imp.shape)
+    convData.attach_grad()
+    with mx.autograd.record():
+        convOut = convNet(convData)
+        y = np.reshape(convOut, -1)
+        y = np.sum(y)
+    y.backward()
+
+    deconvData = np.ones_like(convOut)  # gradient of convOut
+    deconvBias = np.repeat(deconvNet.bias.data(), int(np.prod(np.array(convData.grad.shape[2:])).item()))
+    deconvRefOut = np.copy(convData.grad) + deconvBias.reshape((convData.grad.shape[1:]))
+    deconvData.attach_grad()
+    with mx.autograd.record():
+        deconvOut = deconvNet(deconvData)
+    deconvOut.backward()
+
+    convData = np.ones_like(deconvOut)
+    deconvRefGrad = convNet(convData)
+
+    assert_almost_equal(deconvOut, deconvRefOut)
+    assert_almost_equal(deconvData.grad, deconvRefGrad)
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 610b70032853..b4a6d6d40f24 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -1297,26 +1297,46 @@ def test_deconvolution():
         pad = (3,)
     )

-def test_deconvolution_forward_with_bias():
+@pytest.mark.parametrize('shape,num_filter,num_group,kernel,pad', [
+    ((1, 4, 15), 16, 2, (2,), (0,)),
+    ((8, 4, 16), 16, 1, (3,), (1,)),
+
+    ((1, 4, 15, 16), 16, 2, (2, 2), (0, 0)),
+    ((8, 4, 16, 16), 16, 1, (3, 3), (1, 1)),
+
+    ((1, 4, 3, 15, 16), 16, 2, (2, 2, 2), (0, 0, 0)),
+    ((8, 4, 3, 16, 16), 16, 1, (3, 3, 3), (1, 1, 1))])
+def test_deconvolution_forward_with_bias(shape, num_filter, num_group, kernel, pad):
     """Check if deconvolution forward can work well with bias=True
     """
-    def check_deconvolution_forward_with_bias(shape=(1, 16, 5, 5), num_filter=32, num_group=1, kernel=(3, 3), pad=(1, 1)):
-        x = mx.sym.Variable('x')
-        w = mx.sym.Variable('w')
-        input_data = mx.random.uniform(-5, 5, shape, ctx=mx.cpu())
-        y = mx.sym.Deconvolution(data=x, weight=w, num_filter=num_filter, num_group=num_group, kernel=kernel, no_bias=False, pad=pad)
-        exe = y._simple_bind(ctx=mx.cpu(), x=shape, grad_req='null')
-
-        exe.arg_arrays[0][:] = np.random.normal(size=exe.arg_arrays[0].shape)
-        exe.arg_arrays[1][:] = np.random.normal(size=exe.arg_arrays[1].shape)
-
-        exe.forward(is_train=False)
-        o = exe.outputs[0]
-        t = o.asnumpy()
-    check_deconvolution_forward_with_bias((1, 16, 5), 32, 1, (3,), (1,))
-    check_deconvolution_forward_with_bias((32, 16, 5), 32, 1, (3,), (1,))
-    check_deconvolution_forward_with_bias((1, 16, 5, 5), 32, 1, (3, 3), (1, 1))
-    check_deconvolution_forward_with_bias((32, 16, 5, 5), 32, 1, (3, 3), (1, 1))
+    if len(kernel) == 3 and mx.current_context().device_type == 'gpu':
+        pytest.skip('Skipping Conv3DTranspose tests for GPU')
+
+    x = mx.sym.Variable('x')
+    w = mx.sym.Variable('w')
+    b = mx.sym.Variable('b')
+    y_nb = mx.sym.Deconvolution(data=x, weight=w, num_filter=num_filter, num_group=num_group, kernel=kernel, no_bias=True, pad=pad)
+    y_b = mx.sym.Deconvolution(data=x, weight=w, bias=b, num_filter=num_filter, num_group=num_group, kernel=kernel, no_bias=False, pad=pad)
+
+    exe_nb = y_nb._simple_bind(ctx=mx.cpu(), x=shape, grad_req='null')
+    exe_b = y_b._simple_bind(ctx=mx.cpu(), x=shape, grad_req='null')
+
+    data = np.random.uniform(-5, 5, size=exe_b.arg_arrays[0].shape)
+    weights = np.random.normal(size=exe_b.arg_arrays[1].shape)
+    bias = np.random.normal(size=exe_b.arg_arrays[2].shape)
+
+    def exe_forward(exe):
+        exe.arg_arrays[0][:] = data
+        exe.arg_arrays[1][:] = weights
+        if len(exe.arg_arrays) == 3:
+            exe.arg_arrays[2][:] = bias
+        return exe.forward(is_train=False)[0].asnumpy()
+
+    out_nb = exe_forward(exe_nb)
+    out_b = exe_forward(exe_b)
+    bias = np.broadcast_to(bias, [np.prod(out_nb.shape[2:])] + [num_filter]).T
+    bias = np.broadcast_to(bias.reshape((num_filter, *out_nb.shape[2:])), out_b.shape)
+    assert_almost_equal(out_nb + bias, out_b)


 def check_nearest_upsampling_with_shape(shapes, scale, root_scale):
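The bias reference built at the end of test_deconvolution_forward_with_bias amounts to a per-channel broadcast of the bias vector over the no-bias output. A short NumPy sketch with hypothetical shapes (out_nb below is random data standing in for the no-bias output) shows the equivalence:

# Sketch: the broadcast_to/transpose construction used in the test is equivalent
# to reshaping the bias to (1, num_filter, 1, ..., 1) and relying on broadcasting.
# Shapes are hypothetical.
import numpy as np

num_filter = 16
out_nb = np.random.uniform(size=(8, num_filter, 16, 16))   # stand-in for the no-bias output
bias = np.random.normal(size=(num_filter,))

bias_nd = bias.reshape((1, num_filter) + (1,) * (out_nb.ndim - 2))
out_b_ref = out_nb + bias_nd                               # plain channel-wise add

# Same result as the test's broadcast_to-based construction:
bias_bt = np.broadcast_to(bias, [int(np.prod(out_nb.shape[2:]))] + [num_filter]).T
bias_bt = np.broadcast_to(bias_bt.reshape((num_filter, *out_nb.shape[2:])), out_nb.shape)
np.testing.assert_allclose(out_b_ref, out_nb + bias_bt)

Both formulations add the same per-channel constant; the reshape-based form is shown only for comparison.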