diff --git a/src/operator/cudnn_algoreg-inl.h b/src/operator/cudnn_algoreg-inl.h new file mode 100644 index 000000000000..3778125d7e2a --- /dev/null +++ b/src/operator/cudnn_algoreg-inl.h @@ -0,0 +1,89 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file cudnn_algoreg-inl.h + * \brief + * \author Bing Xu + */ +#ifndef MXNET_OPERATOR_CUDNN_ALGOREG_INL_H_ +#define MXNET_OPERATOR_CUDNN_ALGOREG_INL_H_ + +#include +#include +#include +#include +#include "../common/cuda_utils.h" +#include "./convolution-inl.h" +#include "./deconvolution-inl.h" + +namespace mxnet { +namespace op { +#if MXNET_USE_CUDNN == 1 + +class CuDNNAlgoReg { + public: + template + std::string GetKey(const Param ¶m, const std::vector &in_shape, + const std::vector &out_shape) { + std::ostringstream oss; + for (auto &i : in_shape) + oss << i << ";"; + for (auto &i : out_shape) + oss << i << ";"; + auto dict = param.__DICT__(); + for (auto &k : dict) + oss << k.first << "=" << k.second << ";"; + return oss.str(); + } + + bool Find(std::string key, cudnnConvolutionFwdAlgo_t *fwd, + cudnnConvolutionBwdDataAlgo_t *bwd, + cudnnConvolutionBwdFilterAlgo_t *flt) { + std::lock_guard guard(lock_); + auto i = reg_.find(key); + if (i != reg_.end()) { + *fwd = i->second.fwd; + *bwd = i->second.bwd; + *flt = i->second.flt; + return true; + } + return false; + } + + void Register(std::string key, cudnnConvolutionFwdAlgo_t fwd, + cudnnConvolutionBwdDataAlgo_t bwd, + cudnnConvolutionBwdFilterAlgo_t flt) { + std::lock_guard guard(lock_); + if (reg_.size() % 50 == 0) { + LOG(INFO) << "Running performance tests to find the best convolution " + "algorithm, " + "this can take a while... (setting env variable " + "MXNET_CUDNN_AUTOTUNE_DEFAULT to 0 to disable)"; + if (reg_.size() >= 1000) { + LOG(INFO) + << "If you see this message in the middle of training, you are " + "probably using bucketing. 
Consider setting env variable " + "MXNET_CUDNN_AUTOTUNE_DEFAULT to 0 to disable cudnn tuning."; + } + } + reg_[key].fwd = fwd; + reg_[key].bwd = bwd; + reg_[key].flt = flt; + } + + static CuDNNAlgoReg *Get(); + + private: + struct CudnnAlgorithms { + cudnnConvolutionFwdAlgo_t fwd; + cudnnConvolutionBwdDataAlgo_t bwd; + cudnnConvolutionBwdFilterAlgo_t flt; + }; + + std::mutex lock_; + std::unordered_map reg_; +}; +#endif // __CUDACC__ && CUDNN +} // namespace op +} // namespace mxnet + +#endif // MXNET_OPERATOR_CUDNN_ALGOREG_INL_H_ diff --git a/src/operator/cudnn_convolution.cc b/src/operator/cudnn_algoreg.cc similarity index 66% rename from src/operator/cudnn_convolution.cc rename to src/operator/cudnn_algoreg.cc index 6e65a7262d73..103c4819d951 100644 --- a/src/operator/cudnn_convolution.cc +++ b/src/operator/cudnn_algoreg.cc @@ -1,10 +1,10 @@ /*! * Copyright (c) 2015 by Contributors - * \file cudnn_convolution.cc + * \file cudnn_algoreg.cc * \brief * \author Junyuan Xie */ -#include "./cudnn_convolution-inl.h" +#include "./cudnn_algoreg-inl.h" #include #include @@ -14,8 +14,8 @@ namespace mxnet { namespace op { #if MXNET_USE_CUDNN == 1 -CuDNNAlgoReg* CuDNNAlgoReg::Get() { - static CuDNNAlgoReg* ptr = new CuDNNAlgoReg(); +CuDNNAlgoReg *CuDNNAlgoReg::Get() { + static CuDNNAlgoReg *ptr = new CuDNNAlgoReg(); return ptr; } #endif // CUDNN diff --git a/src/operator/cudnn_convolution-inl.h b/src/operator/cudnn_convolution-inl.h index c9b8de6e3430..73fbeac6371b 100644 --- a/src/operator/cudnn_convolution-inl.h +++ b/src/operator/cudnn_convolution-inl.h @@ -12,75 +12,13 @@ #include #include #include "./convolution-inl.h" +#include "./cudnn_algoreg-inl.h" #include "../common/cuda_utils.h" namespace mxnet { namespace op { #if MXNET_USE_CUDNN == 1 -class CuDNNAlgoReg { - public: - std::string GetKey(const ConvolutionParam& param, - const std::vector& in_shape, - const std::vector& out_shape) { - std::ostringstream oss; - for (auto& i : in_shape) oss << i << ";"; - for (auto& 
i : out_shape) oss << i << ";"; - auto dict = param.__DICT__(); - for (auto& k : dict) oss << k.first << "=" << k.second << ";"; - return oss.str(); - } - - bool Find(std::string key, - cudnnConvolutionFwdAlgo_t *fwd, - cudnnConvolutionBwdDataAlgo_t *bwd, - cudnnConvolutionBwdFilterAlgo_t *flt) { - std::lock_guard guard(lock_); - auto i = reg_.find(key); - if (i != reg_.end()) { - *fwd = i->second.fwd; - *bwd = i->second.bwd; - *flt = i->second.flt; - return true; - } - return false; - } - - void Register(std::string key, - cudnnConvolutionFwdAlgo_t fwd, - cudnnConvolutionBwdDataAlgo_t bwd, - cudnnConvolutionBwdFilterAlgo_t flt) { - std::lock_guard guard(lock_); - if (reg_.size() % 50 == 0) { - LOG(INFO) - << "Running performance tests to find the best convolution algorithm, " - "this can take a while... (setting env variable " - "MXNET_CUDNN_AUTOTUNE_DEFAULT to 0 to disable)"; - if (reg_.size() >= 1000) { - LOG(INFO) - << "If you see this message in the middle of training, you are " - "probably using bucketing. Consider setting env variable " - "MXNET_CUDNN_AUTOTUNE_DEFAULT to 0 to disable cudnn tuning."; - } - } - reg_[key].fwd = fwd; - reg_[key].bwd = bwd; - reg_[key].flt = flt; - } - - static CuDNNAlgoReg* Get(); - - private: - struct CudnnAlgorithms { - cudnnConvolutionFwdAlgo_t fwd; - cudnnConvolutionBwdDataAlgo_t bwd; - cudnnConvolutionBwdFilterAlgo_t flt; - }; - - std::mutex lock_; - std::unordered_map reg_; -}; - template class CuDNNConvolutionOp : public Operator { public: diff --git a/src/operator/cudnn_deconvolution-inl.h b/src/operator/cudnn_deconvolution-inl.h index 9805813d1605..fd820fe6ce92 100644 --- a/src/operator/cudnn_deconvolution-inl.h +++ b/src/operator/cudnn_deconvolution-inl.h @@ -1,29 +1,55 @@ /*! 
- * Copyright (c) 2015 by Contributors + * Copyright (c) 2017 by Contributors * \file cudnn_deconvolution-inl.h * \brief - * \author Wei Wu + * \author Wei Wu, Leonard Lausen */ #ifndef MXNET_OPERATOR_CUDNN_DECONVOLUTION_INL_H_ #define MXNET_OPERATOR_CUDNN_DECONVOLUTION_INL_H_ #include #include +#include +#include #include "./deconvolution-inl.h" +#include "./cudnn_algoreg-inl.h" +#include "../common/cuda_utils.h" namespace mxnet { namespace op { -#if defined(__CUDACC__) && MXNET_USE_CUDNN == 1 +#if MXNET_USE_CUDNN == 1 + template class CuDNNDeconvolutionOp : public Operator { public: - explicit CuDNNDeconvolutionOp(DeconvolutionParam param) { + explicit CuDNNDeconvolutionOp(DeconvolutionParam param, + const std::vector& in_shape, + const std::vector& out_shape, + const Context& ctx) { + using namespace mshadow; this->param_ = param; + // convert MB to words param_.workspace = (param_.workspace << 20) / sizeof(DType); init_cudnn_ = false; - // TODO(xxx): fp16 + init_temp_size_ = false; dtype_ = mshadow::DataType::kCudnnFlag; + +#if CUDNN_MAJOR >= 5 + MSHADOW_LAYOUT_SWITCH(param_.layout.value(), Layout, { + format_ = LayoutType::kCudnnFlag; + }); +#else + CHECK(param_.layout.value() == kNCHW || param_.layout.value() == kNCDHW) + << "Need CuDNN > 5.0 for layout support"; +#endif + + InitDescriptors(ctx, in_shape, out_shape); + + if (!param_.cudnn_tune) { + param_.cudnn_tune = dmlc::GetEnv("MXNET_CUDNN_AUTOTUNE_DEFAULT", 1); + } + SelectAlgo(ctx, in_shape, out_shape); } ~CuDNNDeconvolutionOp() { @@ -43,22 +69,39 @@ class CuDNNDeconvolutionOp : public Operator { const std::vector &aux_args) { using namespace mshadow; size_t expected = param_.no_bias ? 
2 : 3; + DType *data_ptr = NULL; + DType *wmat_ptr = NULL; + DType *out_ptr = NULL; CHECK_EQ(in_data.size(), expected); CHECK_EQ(out_data.size(), 1U); Stream *s = ctx.get_stream(); - Tensor data = in_data[deconv::kData].get(s); - Tensor wmat = in_data[deconv::kWeight].get(s); - Tensor out = out_data[deconv::kOut].get(s); - - CHECK_EQ(data.CheckContiguous(), true); - CHECK_EQ(wmat.CheckContiguous(), true); - CHECK_EQ(out.CheckContiguous(), true); - if (!init_cudnn_) { - Init(s, in_data, out_data); - } + GetTempSize(ctx); Tensor workspace = - ctx.requested[deconv::kTempSpace].get_space_typed( - mshadow::Shape1(forward_workspace_), s); + ctx.requested[deconv::kTempSpace].get_space_typed( + mshadow::Shape1(forward_workspace_), s); + + if (param_.kernel.ndim() == 2) { + Tensor data = in_data[deconv::kData].get(s); + Tensor wmat = in_data[deconv::kWeight].get(s); + Tensor out = out_data[deconv::kOut].get(s); + CHECK_EQ(data.CheckContiguous(), true); + CHECK_EQ(wmat.CheckContiguous(), true); + CHECK_EQ(out.CheckContiguous(), true); + data_ptr = data.dptr_; + wmat_ptr = wmat.dptr_; + out_ptr = out.dptr_; + } else { + Tensor data = in_data[deconv::kData].get(s); + Tensor wmat = in_data[deconv::kWeight].get(s); + Tensor out = out_data[deconv::kOut].get(s); + CHECK_EQ(data.CheckContiguous(), true); + CHECK_EQ(wmat.CheckContiguous(), true); + CHECK_EQ(out.CheckContiguous(), true); + data_ptr = data.dptr_; + wmat_ptr = wmat.dptr_; + out_ptr = out.dptr_; + } + for (uint32_t g = 0; g < param_.num_group; ++g) { typename DataType::ScaleType alpha = 1.0f; typename DataType::ScaleType beta = 0.0f; @@ -66,30 +109,30 @@ class CuDNNDeconvolutionOp : public Operator { CHECK_EQ(cudnnConvolutionBackwardData_v3(s->dnn_handle_, &alpha, filter_desc_, - wmat.dptr_ + weight_offset_ * g, + wmat_ptr + weight_offset_ * g, in_desc_, - data.dptr_ + data_offset_ * g, + data_ptr + data_offset_ * g, conv_desc_, back_algo_, workspace.dptr_, backward_workspace_byte_, &beta, out_desc_, - out.dptr_ + 
out_offset_ * g), CUDNN_STATUS_SUCCESS); + out_ptr + out_offset_ * g), CUDNN_STATUS_SUCCESS); #elif CUDNN_MAJOR == 5 CHECK_EQ(cudnnConvolutionBackwardData(s->dnn_handle_, &alpha, filter_desc_, - wmat.dptr_ + weight_offset_ * g, + wmat_ptr + weight_offset_ * g, in_desc_, - data.dptr_ + data_offset_ * g, + data_ptr + data_offset_ * g, conv_desc_, back_algo_, workspace.dptr_, backward_workspace_byte_, &beta, out_desc_, - out.dptr_ + out_offset_ * g), CUDNN_STATUS_SUCCESS); + out_ptr + out_offset_ * g), CUDNN_STATUS_SUCCESS); #endif if (!param_.no_bias) { beta = 1.0f; @@ -101,7 +144,7 @@ class CuDNNDeconvolutionOp : public Operator { bias.dptr_ + bias_offset_ * g, &beta, out_desc_, - out.dptr_ + out_offset_ * g), CUDNN_STATUS_SUCCESS); + out_ptr + out_offset_ * g), CUDNN_STATUS_SUCCESS); #endif #if CUDNN_MAJOR == 3 CHECK_EQ(cudnnAddTensor(s->dnn_handle_, @@ -111,7 +154,7 @@ class CuDNNDeconvolutionOp : public Operator { bias.dptr_ + bias_offset_ * g, &beta, out_desc_, - out.dptr_ + out_offset_ * g), CUDNN_STATUS_SUCCESS); + out_ptr + out_offset_ * g), CUDNN_STATUS_SUCCESS); #endif } } @@ -127,17 +170,40 @@ class CuDNNDeconvolutionOp : public Operator { using namespace mshadow; using namespace mshadow::expr; size_t expected = param_.no_bias == 0 ? 
3 : 2; + DType *grad_ptr = NULL; + DType *wmat_ptr = NULL; + DType *gwmat_ptr = NULL; + DType *data_ptr = NULL; + DType *gdata_ptr = NULL; CHECK_EQ(out_grad.size(), 1U); CHECK(in_data.size() == expected && in_grad.size() == expected); + Stream *s = ctx.get_stream(); + if (param_.kernel.ndim() == 2) { + Tensor grad = out_grad[deconv::kOut].get(s); + Tensor wmat = in_data[deconv::kWeight].get(s); + Tensor gwmat = in_grad[deconv::kWeight].get(s); + Tensor data = in_data[deconv::kData].get(s); + Tensor gdata = in_grad[deconv::kData].get(s); + grad_ptr = grad.dptr_; + wmat_ptr = wmat.dptr_; + gwmat_ptr = gwmat.dptr_; + data_ptr = data.dptr_; + gdata_ptr = gdata.dptr_; + } else { + Tensor grad = out_grad[deconv::kOut].get(s); + Tensor wmat = in_data[deconv::kWeight].get(s); + Tensor gwmat = in_grad[deconv::kWeight].get(s); + Tensor data = in_data[deconv::kData].get(s); + Tensor gdata = in_grad[deconv::kData].get(s); + grad_ptr = grad.dptr_; + wmat_ptr = wmat.dptr_; + gwmat_ptr = gwmat.dptr_; + data_ptr = data.dptr_; + gdata_ptr = gdata.dptr_; + } CHECK_NE(req[deconv::kWeight], kWriteInplace); CHECK_NE(req[deconv::kBias], kWriteInplace); CHECK_NE(req[deconv::kData], kWriteInplace); - Stream *s = ctx.get_stream(); - Tensor grad = out_grad[deconv::kOut].get(s); - Tensor wmat = in_data[deconv::kWeight].get(s); - Tensor gwmat = in_grad[deconv::kWeight].get(s); - Tensor data = in_data[deconv::kData].get(s); - Tensor gdata = in_grad[deconv::kData].get(s); Tensor workspace = ctx.requested[deconv::kTempSpace].get_space_typed( mshadow::Shape1(backward_workspace_), s); @@ -155,7 +221,7 @@ class CuDNNDeconvolutionOp : public Operator { CHECK_EQ(cudnnConvolutionBackwardBias(s->dnn_handle_, &alpha, out_desc_, - grad.dptr_ + out_offset_ * g, + grad_ptr + out_offset_ * g, &bias_beta, bias_desc_, gbias.dptr_ + bias_offset_ * g), @@ -166,187 +232,341 @@ class CuDNNDeconvolutionOp : public Operator { CHECK_EQ(cudnnConvolutionBackwardFilter_v3(s->dnn_handle_, &alpha, out_desc_, - grad.dptr_ 
+ out_offset_ * g, + grad_ptr + out_offset_ * g, in_desc_, - data.dptr_ + data_offset_ * g, + data_ptr + data_offset_ * g, conv_desc_, back_algo_w_, workspace.dptr_, backward_workspace_byte_, &weight_beta, filter_desc_, - gwmat.dptr_ + weight_offset_ * g), CUDNN_STATUS_SUCCESS); + gwmat_ptr + weight_offset_ * g), CUDNN_STATUS_SUCCESS); #elif CUDNN_MAJOR == 5 CHECK_EQ(cudnnConvolutionBackwardFilter(s->dnn_handle_, &alpha, out_desc_, - grad.dptr_ + out_offset_ * g, + grad_ptr + out_offset_ * g, in_desc_, - data.dptr_ + data_offset_ * g, + data_ptr + data_offset_ * g, conv_desc_, back_algo_w_, workspace.dptr_, backward_workspace_byte_, &weight_beta, filter_desc_, - gwmat.dptr_ + weight_offset_ * g), CUDNN_STATUS_SUCCESS); + gwmat_ptr + weight_offset_ * g), CUDNN_STATUS_SUCCESS); #endif } if (req[deconv::kData] != kNullOp) { CHECK_EQ(cudnnConvolutionForward(s->dnn_handle_, &alpha, out_desc_, - grad.dptr_ + out_offset_ * g, + grad_ptr + out_offset_ * g, filter_desc_, - wmat.dptr_ + weight_offset_ * g, + wmat_ptr + weight_offset_ * g, conv_desc_, algo_, workspace.dptr_, forward_workspace_byte_, &data_beta, in_desc_, - gdata.dptr_ + data_offset_ * g), CUDNN_STATUS_SUCCESS); + gdata_ptr + data_offset_ * g), CUDNN_STATUS_SUCCESS); } } } private: - inline void Init(mshadow::Stream *s, - const std::vector &in_data, - const std::vector &out_data) { + inline void InitDescriptors(const Context& ctx, + const std::vector &in_shape, + const std::vector &out_shape) { using namespace mshadow; - #if CUDNN_MAJOR == 5 - format_ = CUDNN_TENSOR_NCHW; - #endif size_t expected = param_.no_bias ? 
2 : 3; - CHECK_EQ(in_data.size(), expected); - CHECK_EQ(out_data.size(), 1U); - if (!init_cudnn_) { - init_cudnn_ = true; - size_t workspace_byte = static_cast(param_.workspace * sizeof(DType)); - size_t back_size = 0; - size_t back_size_w = 0; - Tensor data = in_data[deconv::kData].get(s); - Tensor out = out_data[deconv::kOut].get(s); - index_t pad_y, pad_x, adj_y, adj_x; - param_.InferPad(data.size(2), data.size(3), &pad_y, &pad_x, &adj_y, &adj_x); - data_offset_ = data.shape_[1] / param_.num_group * data.shape_[2] * data.shape_[3]; - out_offset_ = out.shape_[1] /param_.num_group * out.shape_[2] * out.shape_[3]; - weight_offset_ = data.shape_[1] / param_.num_group * param_.num_filter / param_.num_group - * param_.kernel[0] * param_.kernel[1]; - CHECK_EQ(cudnnCreateTensorDescriptor(&in_desc_), CUDNN_STATUS_SUCCESS); - CHECK_EQ(cudnnCreateTensorDescriptor(&out_desc_), CUDNN_STATUS_SUCCESS); - CHECK_EQ(cudnnCreateTensorDescriptor(&bias_desc_), CUDNN_STATUS_SUCCESS); - CHECK_EQ(cudnnCreateFilterDescriptor(&filter_desc_), CUDNN_STATUS_SUCCESS); - CHECK_EQ(cudnnCreateConvolutionDescriptor(&conv_desc_), CUDNN_STATUS_SUCCESS); - #if CUDNN_MAJOR <=4 - CHECK_EQ(cudnnSetFilter4dDescriptor(filter_desc_, - dtype_, - data.shape_[1] / param_.num_group, - param_.num_filter / param_.num_group, - param_.kernel[0], - param_.kernel[1]), CUDNN_STATUS_SUCCESS); - #elif CUDNN_MAJOR ==5 - CHECK_EQ(cudnnSetFilter4dDescriptor(filter_desc_, - dtype_, - format_, - data.shape_[1] / param_.num_group, - param_.num_filter / param_.num_group, - param_.kernel[0], - param_.kernel[1]), CUDNN_STATUS_SUCCESS); - #endif + CHECK_EQ(in_shape.size(), expected); + CHECK_EQ(out_shape.size(), 1U); + CHECK_EQ(cudnnCreateTensorDescriptor(&in_desc_), CUDNN_STATUS_SUCCESS); + CHECK_EQ(cudnnCreateTensorDescriptor(&out_desc_), CUDNN_STATUS_SUCCESS); + CHECK_EQ(cudnnCreateTensorDescriptor(&bias_desc_), CUDNN_STATUS_SUCCESS); + CHECK_EQ(cudnnCreateFilterDescriptor(&filter_desc_), CUDNN_STATUS_SUCCESS); + 
CHECK_EQ(cudnnCreateConvolutionDescriptor(&conv_desc_), CUDNN_STATUS_SUCCESS); + + TShape dshape = in_shape[deconv::kData]; + TShape wshape = in_shape[deconv::kWeight]; + TShape oshape = out_shape[deconv::kOut]; + TShape dstride, ostride; + wshape[0] /= param_.num_group; + + if (param_.kernel.ndim() == 2) { + // 2d conv + index_t o_pad[2]; + index_t o_adj[2]; + param_.InferPad(dshape, o_pad, o_adj); + CHECK_EQ(cudnnSetConvolution2dDescriptor(conv_desc_, - pad_y, - pad_x, + o_pad[0], + o_pad[1], param_.stride[0], param_.stride[1], 1, 1, CUDNN_CROSS_CORRELATION), CUDNN_STATUS_SUCCESS); - CHECK_EQ(cudnnSetTensor4dDescriptorEx(in_desc_, - dtype_, - data.shape_[0], - data.shape_[1] / param_.num_group, - data.shape_[2], - data.shape_[3], - data.shape_[1] * data.shape_[2] * data.shape_[3], - data.shape_[2] * data.shape_[3], - data.shape_[3], - 1), CUDNN_STATUS_SUCCESS); - CHECK_EQ(cudnnSetTensor4dDescriptorEx(out_desc_, - dtype_, - out.shape_[0], - out.shape_[1] / param_.num_group, - out.shape_[2], - out.shape_[3], - out.shape_[1] * out.shape_[2] * out.shape_[3], - out.shape_[2] * out.shape_[3], - out.shape_[3], - 1), CUDNN_STATUS_SUCCESS); - if (!param_.no_bias) { - Tensor bias = in_data[deconv::kBias].get(s); - bias_offset_ = bias.shape_[0] / param_.num_group; - CHECK_EQ(cudnnSetTensor4dDescriptor(bias_desc_, - CUDNN_TENSOR_NCHW, - dtype_, - 1, - bias.shape_[0] / param_.num_group, - 1, - 1), CUDNN_STATUS_SUCCESS); + + #if CUDNN_MAJOR >= 5 + wshape = ConvertLayout(wshape.get<4>(), param_.layout.value(), kNCHW); + CHECK_EQ(cudnnSetFilter4dDescriptor(filter_desc_, + dtype_, + format_, + wshape[0], + wshape[1], + wshape[2], + wshape[3]), CUDNN_STATUS_SUCCESS); + #else + CHECK_EQ(param_.layout.value(), kNCHW) << "CuDNN V4 only support NCHW layout"; + CHECK_EQ(cudnnSetFilter4dDescriptor(filter_desc_, + dtype_, + wshape[0], + wshape[1], + wshape[2], + wshape[3]), CUDNN_STATUS_SUCCESS); + #endif + + dstride = ConvertLayout(Shape4(dshape[1] * dshape[2] * dshape[3], + dshape[2] * 
dshape[3], + dshape[3], + 1), + param_.layout.value(), kNCHW); + dshape = ConvertLayout(dshape.get<4>(), param_.layout.value(), kNCHW); + + ostride = ConvertLayout(Shape4(oshape[1] * oshape[2] * oshape[3], + oshape[2] * oshape[3], + oshape[3], + 1), + param_.layout.value(), kNCHW); + oshape = ConvertLayout(oshape.get<4>(), param_.layout.value(), kNCHW); + } else if (param_.kernel.ndim() == 3) { + // 3d conv + std::vector upscale_vec = {1, 1, 1}; + + index_t o_pad[3]; + index_t o_adj[3]; + param_.InferPad(dshape, o_pad, o_adj); + + #if CUDNN_MAJOR >= 5 + CHECK_EQ(param_.layout.value(), kNCDHW) << "CuDNN only support 3D conv with NCDHW layout"; + CHECK_EQ(cudnnSetFilterNdDescriptor(filter_desc_, + dtype_, + CUDNN_TENSOR_NCHW, + static_cast(wshape.ndim()), + reinterpret_cast(&wshape[0])), + CUDNN_STATUS_SUCCESS); + #else + LOG(FATAL) << "Only support CUDNN V5 for 3D convolution"; + #endif + CHECK_EQ(cudnnSetConvolutionNdDescriptor(conv_desc_, + 3, + reinterpret_cast(&o_pad[0]), + reinterpret_cast(¶m_.stride[0]), + &upscale_vec[0], + CUDNN_CROSS_CORRELATION, + dtype_), CUDNN_STATUS_SUCCESS); + + dstride = ConvertLayout(Shape5(dshape[1] * dshape[2] * dshape[3] * dshape[4], + dshape[2] * dshape[3] * dshape[4], + dshape[3] * dshape[4], + dshape[4], + 1), + param_.layout.value(), kNCDHW); + dshape = ConvertLayout(dshape.get<5>(), param_.layout.value(), kNCDHW); + + ostride = ConvertLayout(Shape5(oshape[1] * oshape[2] * oshape[3] * oshape[4], + oshape[2] * oshape[3] * oshape[4], + oshape[3] * oshape[4], + oshape[4], + 1), + param_.layout.value(), kNCDHW); + oshape = ConvertLayout(oshape.get<5>(), param_.layout.value(), kNCDHW); + } + dshape[1] /= param_.num_group; + oshape[1] /= param_.num_group; + weight_offset_ = wshape.Size(); + data_offset_ = dstride[1] * dshape[1]; + out_offset_ = ostride[1] * oshape[1]; + + CHECK_EQ(cudnnSetTensorNdDescriptor(in_desc_, + dtype_, + static_cast(dshape.ndim()), + reinterpret_cast(&dshape[0]), + reinterpret_cast(&dstride[0])), + 
CUDNN_STATUS_SUCCESS); + + CHECK_EQ(cudnnSetTensorNdDescriptor(out_desc_, + dtype_, + static_cast(oshape.ndim()), + reinterpret_cast(&oshape[0]), + reinterpret_cast(&ostride[0])), + CUDNN_STATUS_SUCCESS); + + if (!param_.no_bias) { + TShape bias = in_shape[deconv::kBias]; + bias_offset_ = bias[0] / param_.num_group; + std::vector bias_shape = {1, + static_cast(bias[0] / param_.num_group), + 1, 1}; + std::vector bias_stride = {static_cast(bias_offset_), 1, 1, 1}; + if (param_.kernel.ndim() == 3) { + bias_shape.push_back(1); + bias_stride.push_back(1); } + CHECK_EQ(cudnnSetTensorNdDescriptor(bias_desc_, + dtype_, + static_cast(bias_shape.size()), + &bias_shape[0], + &bias_stride[0]), CUDNN_STATUS_SUCCESS); + } + init_cudnn_ = true; + } + + void SelectAlgo(const Context& ctx, + const std::vector& in_shape, + const std::vector& out_shape) { + std::string key = CuDNNAlgoReg::Get()->GetKey(param_, in_shape, out_shape); + if (CuDNNAlgoReg::Get()->Find(key, &algo_, &back_algo_, &back_algo_w_)) return; + + Engine::VarHandle var = Engine::Get()->NewVariable(); + Engine::Get()->PushSync([=](RunContext rctx) { + mshadow::Stream *s = rctx.get_stream(); CHECK_EQ(s->dnn_handle_ownership_, mshadow::Stream::OwnHandle); - CHECK_EQ(cudnnGetConvolutionForwardAlgorithm(s->dnn_handle_, - out_desc_, - filter_desc_, - conv_desc_, - in_desc_, - CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT, - workspace_byte, - &algo_), CUDNN_STATUS_SUCCESS); - CHECK_EQ(cudnnGetConvolutionBackwardFilterAlgorithm(s->dnn_handle_, - out_desc_, - in_desc_, - conv_desc_, - filter_desc_, - CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT, - workspace_byte, - &back_algo_w_), CUDNN_STATUS_SUCCESS); - CHECK_EQ(cudnnGetConvolutionBackwardDataAlgorithm(s->dnn_handle_, - filter_desc_, - in_desc_, - conv_desc_, - out_desc_, - CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT, - workspace_byte, - &back_algo_), CUDNN_STATUS_SUCCESS); - CHECK_EQ(cudnnGetConvolutionBackwardDataWorkspaceSize(s->dnn_handle_, + size_t 
workspace_byte = static_cast(param_.workspace * sizeof(DType)); + if (!param_.cudnn_tune.value()) { + CHECK_EQ(cudnnGetConvolutionForwardAlgorithm(s->dnn_handle_, + out_desc_, + filter_desc_, + conv_desc_, + in_desc_, + CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT, + workspace_byte, + &(this->algo_)), CUDNN_STATUS_SUCCESS); + CHECK_EQ(cudnnGetConvolutionBackwardFilterAlgorithm(s->dnn_handle_, + out_desc_, + in_desc_, + conv_desc_, + filter_desc_, + CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT, + workspace_byte, + &(this->back_algo_w_)), CUDNN_STATUS_SUCCESS); + CHECK_EQ(cudnnGetConvolutionBackwardDataAlgorithm(s->dnn_handle_, + filter_desc_, + in_desc_, + conv_desc_, + out_desc_, + CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT, + workspace_byte, + &(this->back_algo_)), CUDNN_STATUS_SUCCESS); + } else { + const int kMaxAlgos = 10; + int nalgo = kMaxAlgos; + int i; + + cudnnConvolutionFwdAlgoPerf_t fwd_algo[kMaxAlgos]; + CHECK_EQ(cudnnFindConvolutionForwardAlgorithm(s->dnn_handle_, + out_desc_, + filter_desc_, + conv_desc_, + in_desc_, + kMaxAlgos, + &nalgo, + fwd_algo), CUDNN_STATUS_SUCCESS); + i = 0; + while (i < nalgo + && (fwd_algo[i].status != CUDNN_STATUS_SUCCESS + || (param_.cudnn_tune.value() == deconv::kLimited + && fwd_algo[i].memory > workspace_byte))) ++i; + if (i == nalgo) { + LOG(FATAL) << "Failed to find an convolution algorithm."; + } else { + this->algo_ = fwd_algo[i].algo; + } + + cudnnConvolutionBwdFilterAlgoPerf_t bwd_filter_algo[kMaxAlgos]; + CHECK_EQ(cudnnFindConvolutionBackwardFilterAlgorithm(s->dnn_handle_, + out_desc_, + in_desc_, + conv_desc_, + filter_desc_, + kMaxAlgos, + &nalgo, + bwd_filter_algo), CUDNN_STATUS_SUCCESS); + i = 0; + while (i < nalgo + && (bwd_filter_algo[i].status != CUDNN_STATUS_SUCCESS + || (param_.cudnn_tune.value() == deconv::kLimited + && bwd_filter_algo[i].memory > workspace_byte))) ++i; + if (i == nalgo) { + LOG(FATAL) << "Failed to find an convolution algorithm."; + } else { + this->back_algo_w_ = 
bwd_filter_algo[i].algo; + } + + cudnnConvolutionBwdDataAlgoPerf_t bwd_data_algo[kMaxAlgos]; + CHECK_EQ(cudnnFindConvolutionBackwardDataAlgorithm(s->dnn_handle_, + filter_desc_, + in_desc_, + conv_desc_, + out_desc_, + kMaxAlgos, + &nalgo, + bwd_data_algo), CUDNN_STATUS_SUCCESS); + i = 0; + while (i < nalgo + && (bwd_data_algo[i].status != CUDNN_STATUS_SUCCESS + || (param_.cudnn_tune.value() == deconv::kLimited + && bwd_data_algo[i].memory > workspace_byte))) ++i; + if (i == nalgo) { + LOG(FATAL) << "Failed to find an convolution algorithm."; + } else { + this->back_algo_ = bwd_data_algo[i].algo; + } + CuDNNAlgoReg::Get()->Register(key, this->algo_, this->back_algo_, this->back_algo_w_); + } + }, ctx, {}, {var}); + Engine::Get()->WaitForVar(var); + Engine::Get()->DeleteVariable([](RunContext s) {}, ctx, var); + } + + void GetTempSize(const OpContext& ctx) { + if (init_temp_size_) return; + mshadow::Stream *s = ctx.get_stream(); + size_t back_size = 0, back_size_w = 0; + CHECK_EQ(cudnnGetConvolutionBackwardDataWorkspaceSize(s->dnn_handle_, filter_desc_, in_desc_, conv_desc_, out_desc_, back_algo_, &back_size), CUDNN_STATUS_SUCCESS); - CHECK_EQ(cudnnGetConvolutionBackwardFilterWorkspaceSize(s->dnn_handle_, + CHECK_EQ(cudnnGetConvolutionBackwardFilterWorkspaceSize(s->dnn_handle_, out_desc_, in_desc_, conv_desc_, filter_desc_, back_algo_w_, &back_size_w), CUDNN_STATUS_SUCCESS); - backward_workspace_byte_ = std::max(back_size, back_size_w); - CHECK_EQ(cudnnGetConvolutionForwardWorkspaceSize(s->dnn_handle_, + backward_workspace_byte_ = std::max(back_size, back_size_w); + CHECK_EQ(cudnnGetConvolutionForwardWorkspaceSize(s->dnn_handle_, out_desc_, filter_desc_, conv_desc_, in_desc_, algo_, &forward_workspace_byte_), CUDNN_STATUS_SUCCESS); - forward_workspace_ = forward_workspace_byte_ / sizeof(DType) + 1; - backward_workspace_ = backward_workspace_byte_ / sizeof(DType) + 1; - } + + forward_workspace_ = forward_workspace_byte_ / sizeof(DType) + 1; + backward_workspace_ = 
backward_workspace_byte_ / sizeof(DType) + 1; + init_temp_size_ = true; } bool init_cudnn_; + bool init_temp_size_; size_t forward_workspace_; size_t backward_workspace_; size_t forward_workspace_byte_; @@ -364,12 +584,10 @@ class CuDNNDeconvolutionOp : public Operator { cudnnConvolutionFwdAlgo_t algo_; cudnnConvolutionBwdDataAlgo_t back_algo_; cudnnConvolutionBwdFilterAlgo_t back_algo_w_; - #if CUDNN_MAJOR == 5 cudnnTensorFormat_t format_; - #endif DeconvolutionParam param_; }; -#endif // __CUDACC__ && CUDNN +#endif // CUDNN } // namespace op } // namespace mxnet diff --git a/src/operator/deconvolution-inl.h b/src/operator/deconvolution-inl.h index 94f38e531dde..991937bd038f 100644 --- a/src/operator/deconvolution-inl.h +++ b/src/operator/deconvolution-inl.h @@ -25,6 +25,7 @@ namespace deconv { enum DeconvolutionOpInputs {kData, kWeight, kBias}; enum DeconvolutionOpOutputs {kOut}; enum DeconvolutionOpResource {kTempSpace}; + enum DeconvolutionOpCudnnTune {kOff, kLimited, kFastest}; } struct DeconvolutionParam : public dmlc::Parameter { @@ -37,54 +38,69 @@ struct DeconvolutionParam : public dmlc::Parameter { uint32_t num_group; uint64_t workspace; bool no_bias; + dmlc::optional cudnn_tune; + dmlc::optional layout; DMLC_DECLARE_PARAMETER(DeconvolutionParam) { - int shape[] = {1, 1}; - DMLC_DECLARE_FIELD(kernel).describe("deconvolution kernel size: (y, x)"); - DMLC_DECLARE_FIELD(stride).set_default(TShape(shape, shape + 2)) - .describe("deconvolution stride: (y, x)"); - shape[0] = shape[1] = 0; - DMLC_DECLARE_FIELD(pad).set_default(TShape(shape, shape + 2)) - .describe("pad for deconvolution: (y, x), a good number is : (kernel-1)/2, " - "if target_shape set, pad will be ignored and will be computed " - "automatically"); - DMLC_DECLARE_FIELD(adj).set_default(TShape(shape, shape + 2)) - .describe("adjustment for output shape: (y, x), if target_shape set, adj " - "will be ignored and will be computed automatically"); - 
DMLC_DECLARE_FIELD(target_shape).set_default(TShape(shape, shape + 2)) - .describe("output shape with targe shape : (y, x)"); + DMLC_DECLARE_FIELD(kernel).describe("deconvolution kernel size: (h, w) or (d, h, w)"); + DMLC_DECLARE_FIELD(stride).set_default(TShape()) + .describe("deconvolution stride: (h, w) or (d, h, w)"); + DMLC_DECLARE_FIELD(pad).set_default(TShape()) + .describe("pad for deconvolution: (h, w) or (d, h, w). " + "A good number is : (kernel-1)/2. " + "If target_shape is set, " + "pad will be ignored and computed accordingly"); + DMLC_DECLARE_FIELD(adj).set_default(TShape()) + .describe("adjustment for output shape: (h, w) or (d, h, w). " + "If target_shape is set, " + "ad will be ignored and computed accordingly"); + DMLC_DECLARE_FIELD(target_shape).set_default(TShape()) + .describe("output shape with target shape : (h, w) or (d, h, w)"); DMLC_DECLARE_FIELD(num_filter).set_range(1, 100000) .describe("deconvolution filter(channel) number"); DMLC_DECLARE_FIELD(num_group).set_default(1) .describe("number of groups partition"); DMLC_DECLARE_FIELD(workspace).set_default(512).set_range(0, 8192) - .describe("Tmp workspace for deconvolution (MB)"); + .describe("Maximum temporal workspace allowed for deconvolution (MB)."); DMLC_DECLARE_FIELD(no_bias).set_default(true) .describe("Whether to disable bias parameter."); + DMLC_DECLARE_FIELD(cudnn_tune) + .add_enum("off", deconv::kOff) + .add_enum("limited_workspace", deconv::kLimited) + .add_enum("fastest", deconv::kFastest) + .set_default(dmlc::optional()) + .describe("Whether to pick convolution algo by running performance test."); + DMLC_DECLARE_FIELD(layout) + .add_enum("NCW", mshadow::kNCW) + .add_enum("NCHW", mshadow::kNCHW) + .add_enum("NCDHW", mshadow::kNCDHW) + .add_enum("NHWC", mshadow::kNHWC) + .add_enum("NDHWC", mshadow::kNDHWC) + .set_default(dmlc::optional()) + .describe("Set layout for input, output and weight. 
Empty for\n " + "default layout: NCW for 1d, NCHW for 2d and NCDHW for 3d."); } - inline void InferPad(index_t input_y, index_t input_x, - index_t* o_pad_y, index_t* o_pad_x, - index_t* o_adj_y, index_t* o_adj_x) const { - index_t& pad_y = *o_pad_y; - index_t& pad_x = *o_pad_x; - index_t& adj_y = *o_adj_y; - index_t& adj_x = *o_adj_x; - if (target_shape[0] != 0 || target_shape[1] != 0) { - pad_y = stride[0] * (input_y - 1) + kernel[0]; - pad_x = stride[1] * (input_x - 1) + kernel[1]; - CHECK_GE(pad_y, target_shape[0]) - << "too big target shape"; - CHECK_GE(pad_x, target_shape[1]) + template + void InferPad(TShape input, index_t (&o_pad)[ndim], index_t (&o_adj)[ndim] ) const { + if (target_shape.ndim() != 0) { + size_t input_ndim = input.ndim(); + + for (unsigned int i = 0; i < ndim; i++) { + // input.ndim() can be larger than ndim, in case that the complete input + // shape was passed and not only the ndim last ones + o_pad[i] = stride[i] * (input[(input_ndim - ndim) + i] - 1) + kernel[i]; + + CHECK_GE(o_pad[i], target_shape[i]) << "too big target shape"; - pad_y -= target_shape[0]; - pad_x -= target_shape[1]; - adj_y = pad_y % 2; pad_y = (pad_y + 1) / 2; - adj_x = pad_x % 2; pad_x = (pad_x + 1) / 2; + + o_pad[i] -= target_shape[i]; + o_adj[i] = o_pad[i] % 2; o_pad[i] = (o_pad[i] + 1) / 2; + } } else { - pad_y = pad[0]; - pad_x = pad[1]; - adj_y = adj[0]; - adj_x = adj[1]; + for (unsigned int i = 0; i < ndim; i++) { + o_pad[i] = pad[i]; + o_adj[i] = adj[i]; + } } } }; @@ -105,6 +121,11 @@ class DeconvolutionOp : public Operator { const std::vector &aux_args) { using namespace mshadow; using namespace mshadow::expr; + + if (param_.kernel.ndim() != 2) { + LOG(FATAL) << "If not using CUDNN only 2D-Deconvolution is supported"; + } + CHECK_EQ(req[deconv::kOut], kWriteTo); size_t expected = param_.no_bias ? 
2 : 3; CHECK_EQ(in_data.size(), expected); @@ -113,8 +134,9 @@ class DeconvolutionOp : public Operator { Tensor data = in_data[deconv::kData].get(s); Tensor out = out_data[deconv::kOut].get(s); - index_t pad_y, pad_x, adj_y, adj_x; - param_.InferPad(data.size(2), data.size(3), &pad_y, &pad_x, &adj_y, &adj_x); + index_t o_pad[2], o_adj[2]; + TShape dshape = {data.size(2), data.size(3)}; + param_.InferPad(dshape, o_pad, o_adj); Shape<3> wmat_shape = Shape3(param_.num_group, @@ -142,7 +164,7 @@ class DeconvolutionOp : public Operator { shape_dstunit_[1], shape_dstunit_[2] * step), s); temp_dst = reshape(swapaxis<1, 0>(data.Slice(i, i + step)), temp_dst.shape_); - if (pad_y == 0 && pad_x == 0) { + if (o_pad[0] == 0 && o_pad[1] == 0) { temp_col = unpack_patch2col(out.Slice(i, i + step), param_.kernel[0], param_.kernel[1], @@ -151,7 +173,7 @@ class DeconvolutionOp : public Operator { 1, 1); // Deconvolution only support dilate equals 1 } else { temp_col = unpack_patch2col(pad(out.Slice(i, i + step), - pad_y, pad_x), + o_pad[0], o_pad[1]), param_.kernel[0], param_.kernel[1], param_.stride[0], @@ -164,7 +186,7 @@ class DeconvolutionOp : public Operator { gstride * (gid + 1)); tmpc = dot(wmat[gid].T(), temp_dst[gid]); } - if (pad_y == 0 && pad_x == 0) { + if (o_pad[0] == 0 && o_pad[1] == 0) { out.Slice(i, i + step) = pack_col2patch(temp_col, out.Slice(i, i + step).shape_, param_.kernel[0], @@ -175,8 +197,8 @@ class DeconvolutionOp : public Operator { 1); // Deconvolution only support dilate equals 1 } else { Shape<4> pshape = out.Slice(i, i + step).shape_; - pshape[2] += 2 * pad_y; - pshape[3] += 2 * pad_x; + pshape[2] += 2 * o_pad[0]; + pshape[3] += 2 * o_pad[1]; out.Slice(i, i + step) = crop(pack_col2patch(temp_col, pshape, param_.kernel[0], @@ -227,8 +249,9 @@ class DeconvolutionOp : public Operator { CHECK_EQ(s->blas_handle_ownership_, Stream::OwnHandle) << "Must init CuBLAS handle in stream"; #endif - index_t pad_y, pad_x, adj_y, adj_x; - param_.InferPad(data.size(2), 
data.size(3), &pad_y, &pad_x, &adj_y, &adj_x); + index_t o_pad[2], o_adj[2]; + TShape dshape = {data.size(2), data.size(3)}; + param_.InferPad(dshape, o_pad, o_adj); const index_t nbatch = data.size(0); Tensor workspace = @@ -246,7 +269,7 @@ class DeconvolutionOp : public Operator { shape_dstunit_[1], shape_dstunit_[2] * step), s); temp_dst = reshape(swapaxis<1, 0>(data.Slice(i, i + step)), temp_dst.shape_); - if (pad_y == 0 && pad_x == 0) { + if (o_pad[0] == 0 && o_pad[1] == 0) { temp_col = unpack_patch2col(grad.Slice(i, i + step), param_.kernel[0], param_.kernel[1], @@ -254,7 +277,7 @@ class DeconvolutionOp : public Operator { param_.stride[1], 1, 1); // Deconvolution only support dilate equals 1 } else { - temp_col = unpack_patch2col(pad(grad.Slice(i, i + step), pad_y, pad_x), + temp_col = unpack_patch2col(pad(grad.Slice(i, i + step), o_pad[0], o_pad[1]), param_.kernel[0], param_.kernel[1], param_.stride[0], @@ -329,7 +352,10 @@ class DeconvolutionOp : public Operator { }; // class DeconvolutionOp template -Operator* CreateOp(DeconvolutionParam param, int dtype); +Operator* CreateOp(DeconvolutionParam param, int dtype, + std::vector *in_shape, + std::vector *out_shape, + Context ctx); #if DMLC_USE_CXX11 class DeconvolutionProp : public OperatorProperty { @@ -343,7 +369,25 @@ class DeconvolutionProp : public OperatorProperty { } void Init(const std::vector >& kwargs) override { + using namespace mshadow; param_.Init(kwargs); + if (param_.kernel.ndim() == 1) { + param_.layout = param_.layout? param_.layout.value() : mshadow::kNCW; + if (param_.stride.ndim() == 0) param_.stride = Shape1(1); + if (param_.pad.ndim() == 0) param_.pad = Shape1(0); + if (param_.adj.ndim() == 0) param_.adj = Shape1(0); + } else if (param_.kernel.ndim() == 2) { + param_.layout = param_.layout ? 
param_.layout.value() : mshadow::kNCHW; + if (param_.stride.ndim() == 0) param_.stride = Shape2(1, 1); + if (param_.pad.ndim() == 0) param_.pad = Shape2(0, 0); + if (param_.adj.ndim() == 0) param_.adj = Shape2(0, 0); + } else { + CHECK_EQ(param_.kernel.ndim(), 3U) << param_.kernel.ndim() << "D deconvolution not supported"; + param_.layout = param_.layout ? param_.layout.value(): mshadow::kNCDHW; + if (param_.stride.ndim() == 0) param_.stride = Shape3(1, 1, 1); + if (param_.pad.ndim() == 0) param_.pad = Shape3(0, 0, 0); + if (param_.adj.ndim() == 0) param_.adj = Shape3(0, 0, 0); + } } std::map GetParams() const override { @@ -353,54 +397,174 @@ class DeconvolutionProp : public OperatorProperty { bool InferShape(std::vector *in_shape, std::vector *out_shape, std::vector *aux_shape) const override { +#if MXNET_USE_CUDNN == 0 + if (param_.kernel.ndim() != 2) { + LOG(FATAL) << "If not using CUDNN only 2D-Deconvolution is supported"; + return false; + } +#endif // CUDNN + using namespace mshadow; if (!param_.no_bias) { CHECK_EQ(in_shape->size(), 3U) << "Input:[data, weight, bias]"; } else { CHECK_EQ(in_shape->size(), 2U) << "Input:[data, weight]"; } + out_shape->resize(1, TShape()); const TShape &dshape = (*in_shape)[deconv::kData]; if (dshape.ndim() == 0) return false; - CHECK_EQ(dshape.ndim(), 4U) \ + + if (param_.kernel.ndim() == 1) { + CHECK_EQ(dshape.ndim(), 3) \ + << "Input data should be 3D in batch-num_filter-x"; + Shape<3> dshape_ncw = ConvertLayout(dshape.get<3>(), param_.layout.value(), kNCW); + Shape<3> wshape = Shape3(dshape_ncw[1], param_.num_filter / param_.num_group, + param_.kernel[0]); + wshape = ConvertLayout(wshape, kNCW, param_.layout.value()); + SHAPE_ASSIGN_CHECK(*in_shape, deconv::kWeight, wshape); + if (!param_.no_bias) { + SHAPE_ASSIGN_CHECK(*in_shape, deconv::kBias, Shape1(param_.num_filter)); + } + + const index_t ksize_x = static_cast(param_.kernel[0]); + + index_t o_pad[1]; + index_t o_adj[1]; + param_.InferPad(dshape_ncw, o_pad, o_adj); + + 
CHECK_EQ(dshape_ncw[1] % param_.num_group, 0U) \ + << "input num_filter must divide group size"; + CHECK_EQ(param_.num_filter % param_.num_group, 0U) \ + << "output num_filter must divide group size"; + CHECK_GT(param_.kernel.Size(), 0U) \ + << "incorrect kernel size: " << param_.kernel; + CHECK_GT(param_.stride.Size(), 0U) \ + << "incorrect stride size: " << param_.stride; + + CHECK_GE(ksize_x-1, o_adj[0]) << "adj(x) must be smaller than kernel(w)"; + + Shape<3> oshape; + oshape[0] = dshape_ncw[0]; + oshape[1] = param_.num_filter; + oshape[2] = param_.stride[0] * (dshape_ncw[2] - 1) + ksize_x - 2 * o_pad[0] + o_adj[0]; + + if (param_.target_shape[0] > 0) { + CHECK_EQ(param_.target_shape[0], oshape[2]) \ + << "param_.target_shape[0] was not reasonable, please set it carefully"; + } + + SHAPE_ASSIGN_CHECK(*out_shape, 0, ConvertLayout(oshape, kNCW, param_.layout.value())); + + return true; + } else if (param_.kernel.ndim() == 2) { + CHECK_EQ(dshape.ndim(), 4U) \ << "Input data should be 4D in batch-num_filter-y-x"; - SHAPE_ASSIGN_CHECK(*in_shape, - deconv::kWeight, - Shape4(dshape[1], param_.num_filter / param_.num_group, - param_.kernel[0], param_.kernel[1])); - if (!param_.no_bias) { - SHAPE_ASSIGN_CHECK(*in_shape, deconv::kBias, Shape1(param_.num_filter)); - } - out_shape->clear(); - out_shape->push_back(dshape); - // osize = stride * (isize - 1) + ksize - 2 * pad + adj - const index_t ksize_y = static_cast(param_.kernel[0]); - const index_t ksize_x = static_cast(param_.kernel[1]); - index_t pad_y, pad_x, adj_y, adj_x; - param_.InferPad(dshape[2], dshape[3], &pad_y, &pad_x, &adj_y, &adj_x); - CHECK_EQ(dshape[1] % param_.num_group, 0U) \ + Shape<4> dshape_nchw = ConvertLayout(dshape.get<4>(), param_.layout.value(), kNCHW); + Shape<4> wshape = Shape4(dshape_nchw[1], param_.num_filter / param_.num_group, + param_.kernel[0], param_.kernel[1]); + wshape = ConvertLayout(wshape, kNCHW, param_.layout.value()); + SHAPE_ASSIGN_CHECK(*in_shape, deconv::kWeight, wshape); + if 
(!param_.no_bias) { + SHAPE_ASSIGN_CHECK(*in_shape, deconv::kBias, Shape1(param_.num_filter)); + } + + const index_t ksize_y = static_cast(param_.kernel[0]); + const index_t ksize_x = static_cast(param_.kernel[1]); + + index_t o_pad[2]; + index_t o_adj[2]; + param_.InferPad(dshape_nchw, o_pad, o_adj); + + CHECK_EQ(dshape_nchw[1] % param_.num_group, 0U) \ << "input num_filter must divide group size"; - CHECK_EQ(param_.num_filter % param_.num_group, 0U) \ + CHECK_EQ(param_.num_filter % param_.num_group, 0U) \ << "output num_filter must divide group size"; - CHECK_GT(param_.kernel.Size(), 0U) \ + CHECK_GT(param_.kernel.Size(), 0U) \ << "incorrect kernel size: " << param_.kernel; - CHECK_GT(param_.stride.Size(), 0U) \ + CHECK_GT(param_.stride.Size(), 0U) \ << "incorrect stride size: " << param_.stride; - CHECK_GE(ksize_y-1, adj_y) << "adj(y) must be samller than kernel(h)"; - CHECK_GE(ksize_x-1, adj_x) << "adj(x) must be samller than kernel(w)"; - (*out_shape)[deconv::kOut][1] = param_.num_filter; - (*out_shape)[deconv::kOut][2] = param_.stride[0] * (dshape[2] - 1) + - ksize_y - 2 * pad_y + adj_y; - (*out_shape)[deconv::kOut][3] = param_.stride[1] * (dshape[3] - 1) + - ksize_x - 2 * pad_x + adj_x; - if (param_.target_shape[0] > 0) { - CHECK_EQ(param_.target_shape[0], (*out_shape)[deconv::kOut][2]) \ - << "param_.target_shape[0] was not reasonable, pelase set it carefully"; - } - if (param_.target_shape[1] > 0) { - CHECK_EQ(param_.target_shape[1], (*out_shape)[deconv::kOut][3]) \ - << "param_.target_shape[1] was not reasonable, pelase set it carefully"; + + CHECK_GE(ksize_y-1, o_adj[0]) << "adj(y) must be smaller than kernel(h)"; + CHECK_GE(ksize_x-1, o_adj[1]) << "adj(x) must be smaller than kernel(w)"; + + Shape<4> oshape; + oshape[0] = dshape_nchw[0]; + oshape[1] = param_.num_filter; + oshape[2] = param_.stride[0] * (dshape_nchw[2] - 1) + ksize_y - 2 * o_pad[0] + o_adj[0]; + oshape[3] = param_.stride[1] * (dshape_nchw[3] - 1) + ksize_x - 2 * o_pad[1] + o_adj[1]; + + 
if (param_.target_shape[0] > 0) { + CHECK_EQ(param_.target_shape[0], oshape[2]) \ + << "param_.target_shape[0] was not reasonable, please set it carefully"; + } + if (param_.target_shape[1] > 0) { + CHECK_EQ(param_.target_shape[1], oshape[3]) \ + << "param_.target_shape[1] was not reasonable, please set it carefully"; + } + + SHAPE_ASSIGN_CHECK(*out_shape, 0, ConvertLayout(oshape, kNCHW, param_.layout.value())); + + return true; + } else if (param_.kernel.ndim() == 3) { + CHECK_EQ(dshape.ndim(), 5U) \ + << "Input data should be 5D in batch-num_filter-depth-y-x"; + Shape<5> dshape_ncdhw = ConvertLayout(dshape.get<5>(), param_.layout.value(), kNCDHW); + Shape<5> wshape = Shape5(dshape_ncdhw[1], param_.num_filter / param_.num_group, + param_.kernel[0], param_.kernel[1], param_.kernel[2]); + wshape = ConvertLayout(wshape, kNCDHW, param_.layout.value()); + SHAPE_ASSIGN_CHECK(*in_shape, deconv::kWeight, wshape); + if (!param_.no_bias) { + SHAPE_ASSIGN_CHECK(*in_shape, deconv::kBias, Shape1(param_.num_filter)); + } + + const index_t ksize_d = static_cast(param_.kernel[0]); + const index_t ksize_y = static_cast(param_.kernel[1]); + const index_t ksize_x = static_cast(param_.kernel[2]); + + index_t o_pad[3]; + index_t o_adj[3]; + param_.InferPad(dshape_ncdhw, o_pad, o_adj); + + CHECK_EQ(dshape_ncdhw[1] % param_.num_group, 0U) \ + << "input num_filter must divide group size"; + CHECK_EQ(param_.num_filter % param_.num_group, 0U) \ + << "output num_filter must divide group size"; + CHECK_GT(param_.kernel.Size(), 0U) \ + << "incorrect kernel size: " << param_.kernel; + CHECK_GT(param_.stride.Size(), 0U) \ + << "incorrect stride size: " << param_.stride; + + CHECK_GE(ksize_d-1, o_adj[0]) << "adj(d) must be smaller than kernel(d)"; + CHECK_GE(ksize_y-1, o_adj[1]) << "adj(y) must be smaller than kernel(h)"; + CHECK_GE(ksize_x-1, o_adj[2]) << "adj(x) must be smaller than kernel(w)"; + + Shape<5> oshape; + oshape[0] = dshape_ncdhw[0]; + oshape[1] = param_.num_filter; + oshape[2] = 
param_.stride[0] * (dshape_ncdhw[2] - 1) + ksize_d - 2 * o_pad[0] + o_adj[0]; + oshape[3] = param_.stride[1] * (dshape_ncdhw[3] - 1) + ksize_y - 2 * o_pad[1] + o_adj[1]; + oshape[4] = param_.stride[2] * (dshape_ncdhw[4] - 1) + ksize_x - 2 * o_pad[2] + o_adj[2]; + + if (param_.target_shape[0] > 0) { + CHECK_EQ(param_.target_shape[0], oshape[2]) \ + << "param_.target_shape[0] was not reasonable, please set it carefully"; + } + if (param_.target_shape[1] > 0) { + CHECK_EQ(param_.target_shape[1], oshape[3]) \ + << "param_.target_shape[1] was not reasonable, please set it carefully"; + } + if (param_.target_shape[2] > 0) { + CHECK_EQ(param_.target_shape[2], oshape[4]) \ + << "param_.target_shape[2] was not reasonable, please set it carefully"; + } + + SHAPE_ASSIGN_CHECK(*out_shape, 0, ConvertLayout(oshape, kNCDHW, param_.layout.value())); + + return true; + } else { + LOG(FATAL) << "Unknown convolution type"; + return false; } - return true; } bool InferType(std::vector *in_type, diff --git a/src/operator/deconvolution.cc b/src/operator/deconvolution.cc index 61d839bae8d3..5b2d065667a9 100644 --- a/src/operator/deconvolution.cc +++ b/src/operator/deconvolution.cc @@ -10,7 +10,10 @@ namespace mxnet { namespace op { template<> -Operator* CreateOp(DeconvolutionParam param, int dtype) { +Operator* CreateOp(DeconvolutionParam param, int dtype, + std::vector *in_shape, + std::vector *out_shape, + Context ctx) { Operator *op = NULL; MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { op = new DeconvolutionOp(param); @@ -24,7 +27,7 @@ Operator* DeconvolutionProp::CreateOperatorEx(Context ctx, std::vector * std::vector out_type, aux_type; CHECK(InferType(in_type, &out_type, &aux_type)); CHECK(InferShape(in_shape, &out_shape, &aux_shape)); - DO_BIND_DISPATCH(CreateOp, param_, in_type->at(0)); + DO_BIND_DISPATCH(CreateOp, param_, in_type->at(0), in_shape, &out_shape, ctx); } DMLC_REGISTER_PARAMETER(DeconvolutionParam); diff --git a/src/operator/deconvolution.cu b/src/operator/deconvolution.cu 
index eb9c78b1d5a8..a670bc089739 100644 --- a/src/operator/deconvolution.cu +++ b/src/operator/deconvolution.cu @@ -13,11 +13,14 @@ namespace mxnet { namespace op { template<> -Operator* CreateOp(DeconvolutionParam param, int dtype) { +Operator* CreateOp(DeconvolutionParam param, int dtype, + std::vector *in_shape, + std::vector *out_shape, + Context ctx) { Operator *op = NULL; #if MXNET_USE_CUDNN == 1 MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { - op = new CuDNNDeconvolutionOp(param); + op = new CuDNNDeconvolutionOp(param, *in_shape, *out_shape, ctx); }); #else MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { diff --git a/src/operator/upsampling.cc b/src/operator/upsampling.cc index dd35a581e0ca..284afc57e856 100644 --- a/src/operator/upsampling.cc +++ b/src/operator/upsampling.cc @@ -33,8 +33,6 @@ Operator *CreateOp(UpSamplingParam param, int dtype) { p.stride = TShape(shape, shape + 2); shape[0] = shape[1] = pad; p.pad = TShape(shape, shape + 2); - shape[0] = shape[1] = 0; - p.target_shape = TShape(shape, shape + 2); op = new DeconvolutionOp(p); } else { LOG(FATAL) << "Unknown sample type"; diff --git a/src/operator/upsampling.cu b/src/operator/upsampling.cu index 1a96091472d5..95864e430010 100644 --- a/src/operator/upsampling.cu +++ b/src/operator/upsampling.cu @@ -32,8 +32,6 @@ Operator *CreateOp(UpSamplingParam param, int dtype) { p.stride = TShape(shape, shape + 2); shape[0] = shape[1] = pad; p.pad = TShape(shape, shape + 2); - shape[0] = shape[1] = 0; - p.target_shape = TShape(shape, shape + 2); op = new DeconvolutionOp(p); } else { LOG(FATAL) << "Unknown sample type"; diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 6a7bad27da0a..7e62cfcf369f 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -704,9 +704,13 @@ def check_deconvolution_gradient(input_shape, num_filter, pad): def check_deconvolution_target_shape(input_shape, kernel, stride, pad, adj, target_shape=None): 
data = mx.sym.Variable(name="data") - deconv = mx.sym.Deconvolution( - data=data, kernel=kernel, stride=stride, pad=pad, adj=adj, num_filter=5, - target_shape = target_shape if target_shape is not None else (0, 0)) + if target_shape: + deconv = mx.sym.Deconvolution( + data=data, kernel=kernel, stride=stride, pad=pad, adj=adj, num_filter=5, + target_shape = target_shape) + else: + deconv = mx.sym.Deconvolution( + data=data, kernel=kernel, stride=stride, pad=pad, adj=adj, num_filter=5) arg_names = deconv.list_arguments() arg_shapes, out_shapes, _ = deconv.infer_shape(data=input_shape) assert out_shapes[0] == (input_shape[0], 5, 8, 8)