diff --git a/cpp-package/scripts/OpWrapperGenerator.py b/cpp-package/scripts/OpWrapperGenerator.py index 0c000d9955ff..8facde168408 100644 --- a/cpp-package/scripts/OpWrapperGenerator.py +++ b/cpp-package/scripts/OpWrapperGenerator.py @@ -77,6 +77,7 @@ def GetConvertEnumVariableToString(self, variable=''): class Arg: typeDict = {'boolean':'bool',\ + 'boolean or None':'dmlc::optional',\ 'Shape(tuple)':'Shape',\ 'Symbol':'Symbol',\ 'NDArray':'Symbol',\ diff --git a/python/mxnet/gluon/nn/conv_layers.py b/python/mxnet/gluon/nn/conv_layers.py index 2fbf7d8786d6..24f30270ad64 100644 --- a/python/mxnet/gluon/nn/conv_layers.py +++ b/python/mxnet/gluon/nn/conv_layers.py @@ -675,7 +675,7 @@ def __init__(self, channels, kernel_size, strides=(1, 1, 1), padding=(0, 0, 0), class _Pooling(HybridBlock): """Abstract class for different pooling layers.""" def __init__(self, pool_size, strides, padding, ceil_mode, global_pool, - pool_type, **kwargs): + pool_type, count_include_pad=None, **kwargs): super(_Pooling, self).__init__(**kwargs) if strides is None: strides = pool_size @@ -687,6 +687,8 @@ def __init__(self, pool_size, strides, padding, ceil_mode, global_pool, 'kernel': pool_size, 'stride': strides, 'pad': padding, 'global_pool': global_pool, 'pool_type': pool_type, 'pooling_convention': 'full' if ceil_mode else 'valid'} + if count_include_pad is not None: + self._kwargs['count_include_pad'] = count_include_pad def _alias(self): return 'pool' @@ -863,6 +865,8 @@ class AvgPool1D(_Pooling): respectively. padding is applied on 'W' dimension. ceil_mode : bool, default False When `True`, will use ceil instead of floor to compute the output shape. + count_include_pad : bool, default True + When 'False', will exclude padding elements when computing the average value. Inputs: @@ -879,13 +883,13 @@ class AvgPool1D(_Pooling): equation. """ def __init__(self, pool_size=2, strides=None, padding=0, layout='NCW', - ceil_mode=False, **kwargs): + ceil_mode=False, count_include_pad=True, **kwargs): assert layout == 'NCW', "Only supports 'NCW' layout for now" if isinstance(pool_size, numeric_types): pool_size = (pool_size,) assert len(pool_size) == 1, "pool_size must be a number or a list of 1 ints" super(AvgPool1D, self).__init__( - pool_size, strides, padding, ceil_mode, False, 'avg', **kwargs) + pool_size, strides, padding, ceil_mode, False, 'avg', count_include_pad, **kwargs) class AvgPool2D(_Pooling): @@ -907,6 +911,8 @@ class AvgPool2D(_Pooling): dimensions respectively. padding is applied on 'H' and 'W' dimension. ceil_mode : bool, default False When True, will use ceil instead of floor to compute the output shape. + count_include_pad : bool, default True + When 'False', will exclude padding elements when computing the average value. Inputs: @@ -926,13 +932,13 @@ class AvgPool2D(_Pooling): equation. """ def __init__(self, pool_size=(2, 2), strides=None, padding=0, - ceil_mode=False, layout='NCHW', **kwargs): + ceil_mode=False, layout='NCHW', count_include_pad=True, **kwargs): assert layout == 'NCHW', "Only supports 'NCHW' layout for now" if isinstance(pool_size, numeric_types): pool_size = (pool_size,)*2 assert len(pool_size) == 2, "pool_size must be a number or a list of 2 ints" super(AvgPool2D, self).__init__( - pool_size, strides, padding, ceil_mode, False, 'avg', **kwargs) + pool_size, strides, padding, ceil_mode, False, 'avg', count_include_pad, **kwargs) class AvgPool3D(_Pooling): @@ -955,6 +961,8 @@ class AvgPool3D(_Pooling): dimension. ceil_mode : bool, default False When True, will use ceil instead of floor to compute the output shape. + count_include_pad : bool, default True + When 'False', will exclude padding elements when computing the average value. Inputs: @@ -975,13 +983,13 @@ class AvgPool3D(_Pooling): equation. """ def __init__(self, pool_size=(2, 2, 2), strides=None, padding=0, - ceil_mode=False, layout='NCDHW', **kwargs): + ceil_mode=False, layout='NCDHW', count_include_pad=True, **kwargs): assert layout == 'NCDHW', "Only supports 'NCDHW' layout for now" if isinstance(pool_size, numeric_types): pool_size = (pool_size,)*3 assert len(pool_size) == 3, "pool_size must be a number or a list of 3 ints" super(AvgPool3D, self).__init__( - pool_size, strides, padding, ceil_mode, False, 'avg', **kwargs) + pool_size, strides, padding, ceil_mode, False, 'avg', count_include_pad, **kwargs) class GlobalMaxPool1D(_Pooling): diff --git a/src/operator/nn/cudnn/cudnn_pooling-inl.h b/src/operator/nn/cudnn/cudnn_pooling-inl.h index 84cf64030434..bc3ee366007c 100644 --- a/src/operator/nn/cudnn/cudnn_pooling-inl.h +++ b/src/operator/nn/cudnn/cudnn_pooling-inl.h @@ -51,7 +51,11 @@ class CuDNNPoolingOp { mode_ = CUDNN_POOLING_MAX; break; case pool_enum::kAvgPooling: - mode_ = CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING; + if (param_.count_include_pad.has_value() && !param_.count_include_pad.value()) { + mode_ = CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING; + } else { + mode_ = CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING; + } break; default: LOG(FATAL) << "Not implmented"; @@ -263,7 +267,7 @@ class CuDNNPoolingOp { &(pad_vec[0]), &(stride_vec[0]))); #else - LOG(FATAL) << "3D pooling only support CUDNN v5 and abouve"; + LOG(FATAL) << "3D pooling only support CUDNN v5 and above"; #endif } } diff --git a/src/operator/nn/mkldnn/mkldnn_pooling.cc b/src/operator/nn/mkldnn/mkldnn_pooling.cc index 259af2b94025..9fd88a13c465 100644 --- a/src/operator/nn/mkldnn/mkldnn_pooling.cc +++ b/src/operator/nn/mkldnn/mkldnn_pooling.cc @@ -121,7 +121,11 @@ mkldnn::algorithm GetMKLDNNPoolAlgo(const PoolingParam ¶m) { return mkldnn::algorithm::pooling_max; break; case pool_enum::kAvgPooling: - return mkldnn::algorithm::pooling_avg_include_padding; + if (param.count_include_pad.has_value() && !param.count_include_pad.value()) { + return mkldnn::algorithm::pooling_avg_exclude_padding; + } else { + return mkldnn::algorithm::pooling_avg_include_padding; + } break; default: LOG(FATAL) << "MKLDNN Pooling: Unknown pooling method."; diff --git a/src/operator/nn/pool.cuh b/src/operator/nn/pool.cuh index 9d004d295bed..976aacf63a55 100644 --- a/src/operator/nn/pool.cuh +++ b/src/operator/nn/pool.cuh @@ -214,16 +214,19 @@ template __global__ void pool_sum_1d_gpu_kernel(const int nthreads, const DType* in_data, const int channels, const int width, const int pooled_width, const int kernel_w, const int stride_w, const int pad_w, DType* out_data, - const bool getAvg = false) { + const bool get_avg = false, const bool count_include_pad = true) { CUDA_KERNEL_LOOP(index, nthreads) { const int pw = index % pooled_width; const int c = (index / pooled_width) % channels; const int n = index / pooled_width / channels; int wstart = pw * stride_w - pad_w; int wend = min(wstart + kernel_w, width + pad_w); - const int pool_size = (getAvg? (wend - wstart) : 1); + int pool_size = (get_avg? (wend - wstart) : 1); wstart = max(wstart, 0); wend = min(wend, width); + if (get_avg && !count_include_pad) { + pool_size = (wend - wstart); + } DType sum = 0; const DType* out_slice = in_data + (n * channels + c) * width; for (int w = wstart; w < wend; ++w) { @@ -244,7 +247,8 @@ __global__ void pool_sum_2d_gpu_kernel(const int nthreads, const DType* in_data, const int kernel_h, const int kernel_w, const int stride_h, const int stride_w, const int pad_h, const int pad_w, DType* out_data, - const bool getAvg = false) { + const bool get_avg = false, + const bool count_include_pad = true) { CUDA_KERNEL_LOOP(index, nthreads) { const int pw = index % pooled_width; const int ph = (index / pooled_width) % pooled_height; @@ -254,11 +258,14 @@ __global__ void pool_sum_2d_gpu_kernel(const int nthreads, const DType* in_data, int wstart = pw * stride_w - pad_w; int hend = min(hstart + kernel_h, height + pad_h); int wend = min(wstart + kernel_w, width + pad_w); - const int pool_size = (getAvg? (hend - hstart) * (wend - wstart) : 1); + int pool_size = (get_avg? (hend - hstart) * (wend - wstart) : 1); hstart = max(hstart, 0); wstart = max(wstart, 0); hend = min(hend, height); wend = min(wend, width); + if (get_avg && !count_include_pad) { + pool_size = (hend - hstart) * (wend - wstart); + } DType sum = 0; const DType* out_slice = in_data + (n * channels + c) * height * width; for (int h = hstart; h < hend; ++h) { @@ -282,7 +289,8 @@ __global__ void pool_sum_3d_gpu_kernel(const int nthreads, const DType* in_data, const int kernel_h, const int kernel_w, const int stride_d, const int stride_h, const int stride_w, const int pad_d, const int pad_h, const int pad_w, - DType* out_data, const bool getAvg = false) { + DType* out_data, const bool get_avg = false, + const bool count_include_pad = true) { CUDA_KERNEL_LOOP(index, nthreads) { const int pw = index % pooled_width; const int ph = (index / pooled_width) % pooled_height; @@ -295,13 +303,16 @@ __global__ void pool_sum_3d_gpu_kernel(const int nthreads, const DType* in_data, int dend = min(dstart + kernel_d, depth + pad_d); int hend = min(hstart + kernel_h, height + pad_h); int wend = min(wstart + kernel_w, width + pad_w); - const int pool_size = (getAvg? (dend - dstart) * (hend - hstart) * (wend - wstart) : 1); + int pool_size = (get_avg? (dend - dstart) * (hend - hstart) * (wend - wstart) : 1); dstart = max(dstart, 0); hstart = max(hstart, 0); wstart = max(wstart, 0); dend = min(dend, depth); hend = min(hend, height); wend = min(wend, width); + if (get_avg && !count_include_pad) { + pool_size = (dend - dstart) * (hend - hstart) * (wend - wstart); + } DType sum = 0; const DType* out_slice = in_data + (n * channels + c) * depth * height * width; for (int d = dstart; d < dend; ++d) { @@ -311,7 +322,9 @@ __global__ void pool_sum_3d_gpu_kernel(const int nthreads, const DType* in_data, } } } - out_data[index] = a_root_p::Map(sum); + out_data[index] = (pool_size == 0) ? + DType(nanf("")) : + a_root_p::Map(sum); } } @@ -487,7 +500,8 @@ __global__ void unpool_sum_1d_gpu_kernel(const int nthreads, const DType* out_gr const int channels, const int width, const int pooled_width, const int kernel_w, const int stride_w, const int pad_w, DType* in_grad, - const bool isAvg = false) { + const bool is_avg = false, + const bool count_include_pad = true) { // index is the input image index in NCW CUDA_KERNEL_LOOP(index, nthreads) { // find out the local index @@ -506,7 +520,12 @@ __global__ void unpool_sum_1d_gpu_kernel(const int nthreads, const DType* out_gr // figure out the pooling size int wstart = pw * stride_w - pad_w; int wend = min(wstart + kernel_w, width + pad_w); - int pool_size = (isAvg? (wend - wstart) : 1); + int pool_size = (is_avg? (wend - wstart) : 1); + if (is_avg && !count_include_pad) { + wstart = max(wstart, 0); + wend = min(wend, width); + pool_size = (wend - wstart); + } gradient += lp_grad::Map(out_grad_slice[pw], in_data[index], out_data_slice[pw]) / pool_size; } @@ -528,7 +547,8 @@ __global__ void unpool_sum_2d_gpu_kernel(const int nthreads, const DType* out_gr const int kernel_h, const int kernel_w, const int stride_h, const int stride_w, const int pad_h, const int pad_w, DType* in_grad, - const bool isAvg = false) { + const bool is_avg = false, + const bool count_include_pad = true) { // index is the input image index in NCHW CUDA_KERNEL_LOOP(index, nthreads) { // find out the local index @@ -553,8 +573,15 @@ __global__ void unpool_sum_2d_gpu_kernel(const int nthreads, const DType* out_gr int wstart = pw * stride_w - pad_w; int hend = min(hstart + kernel_h, height + pad_h); int wend = min(wstart + kernel_w, width + pad_w); - int pool_size = (isAvg? (hend - hstart) * (wend - wstart) : 1); + int pool_size = (is_avg? (hend - hstart) * (wend - wstart) : 1); int out_index = ph * pooled_width + pw; + if (is_avg && !count_include_pad) { + hstart = max(hstart, 0); + wstart = max(wstart, 0); + hend = min(hend, height); + wend = min(wend, width); + pool_size = (hend - hstart) * (wend - wstart); + } gradient += lp_grad::Map(out_grad_slice[out_index], in_data[index], @@ -580,7 +607,8 @@ __global__ void unpool_sum_3d_gpu_kernel(const int nthreads, const DType* out_gr const int kernel_d, const int kernel_h, const int kernel_w, const int stride_d, const int stride_h, const int stride_w, const int pad_d, const int pad_h, - const int pad_w, DType* in_grad, const bool isAvg = false) { + const int pad_w, DType* in_grad, const bool is_avg = false, + const bool count_include_pad = true) { // index is the input image index in NCDHW CUDA_KERNEL_LOOP(index, nthreads) { // find out the local index @@ -611,8 +639,17 @@ __global__ void unpool_sum_3d_gpu_kernel(const int nthreads, const DType* out_gr int dend = min(dstart + kernel_d, depth + pad_d); int hend = min(hstart + kernel_h, height + pad_h); int wend = min(wstart + kernel_w, width + pad_w); - int pool_size = (isAvg? (dend - dstart) * (hend - hstart) * (wend - wstart) : 1); + int pool_size = (is_avg? (dend - dstart) * (hend - hstart) * (wend - wstart) : 1); int out_index = (pd * pooled_height + ph) * pooled_width + pw; + if (is_avg && !count_include_pad) { + dstart = max(dstart, 0); + hstart = max(hstart, 0); + wstart = max(wstart, 0); + dend = min(dend, depth); + hend = min(hend, height); + wend = min(wend, width); + pool_size = (dend - dstart) * (hend - hstart) * (wend - wstart); + } gradient += lp_grad::Map(out_grad_slice[out_index], in_data[index], out_data_slice[out_index]) / pool_size; @@ -643,7 +680,7 @@ template inline void pool(mshadow::Stream* s, const DType* in_data, const TShape& ishape, const TShape& oshape, const TShape& kernel, const TShape& pad, const TShape& stride, const int pool_type, OpReqType req_type, - DType* out_data) { + DType* out_data, const bool count_include_pad) { CHECK_EQ(req_type, kWriteTo) << "Only support req=kWriteTo in pooling operations"; using namespace mxnet_op; if (kernel.ndim() == 1) { @@ -659,7 +696,8 @@ inline void pool(mshadow::Stream* s, const DType* in_data, const TShape& is pool_sum_1d_gpu_kernel<<::GetStream(s)>>>( oshape.Size(), in_data, ishape[1], ishape[2], oshape[2], - kernel[0], stride[0], pad[0], out_data, true); + kernel[0], stride[0], pad[0], out_data, + true, count_include_pad); MSHADOW_CUDA_POST_KERNEL_CHECK(pool_sum_1d_gpu_kernel); } else if (pool_enum::kSumPooling == pool_type) { // NOLINT_NEXT_LINE(whitespace/operators) @@ -693,7 +731,8 @@ inline void pool(mshadow::Stream* s, const DType* in_data, const TShape& is 0, mshadow::Stream::GetStream(s)>>>( oshape.Size(), in_data, ishape[1], ishape[2], ishape[3], oshape[2], oshape[3], kernel[0], kernel[1], - stride[0], stride[1], pad[0], pad[1], out_data, true); + stride[0], stride[1], pad[0], pad[1], out_data, + true, count_include_pad); MSHADOW_CUDA_POST_KERNEL_CHECK(pool_sum_2d_gpu_kernel); } else if (pool_enum::kSumPooling == pool_type) { // NOLINT_NEXT_LINE(whitespace/operators) @@ -731,7 +770,7 @@ inline void pool(mshadow::Stream* s, const DType* in_data, const TShape& is oshape.Size(), in_data, ishape[1], ishape[2], ishape[3], ishape[4], oshape[2], oshape[3], oshape[4], kernel[0], kernel[1], kernel[2], stride[0], stride[1], stride[2], - pad[0], pad[1], pad[2], out_data, true); + pad[0], pad[1], pad[2], out_data, true, count_include_pad); MSHADOW_CUDA_POST_KERNEL_CHECK(pool_sum_3d_gpu_kernel); } else if (pool_enum::kSumPooling == pool_type) { // NOLINT_NEXT_LINE(whitespace/operators) @@ -777,7 +816,8 @@ template inline void unpool(mshadow::Stream* s, const DType* out_grad, const DType* in_data, const DType* out_data, const TShape& ishape, const TShape& oshape, const TShape& kernel, const TShape& pad, const TShape& stride, - const int pool_type, OpReqType req_type, DType* in_grad) { + const int pool_type, OpReqType req_type, DType* in_grad, + const bool count_include_pad) { if (mxnet::kNullOp == req_type) return; if (mxnet::kAddTo != req_type) { mxnet_op::Kernel::Launch(s, ishape.Size(), in_grad); @@ -798,7 +838,7 @@ inline void unpool(mshadow::Stream* s, const DType* out_grad, const DType* 0, mshadow::Stream::GetStream(s)>>>( ishape.Size(), out_grad, in_data, out_data, ishape[1], ishape[2], oshape[2], kernel[0], - stride[0], pad[0], in_grad, true); + stride[0], pad[0], in_grad, true, count_include_pad); MSHADOW_CUDA_POST_KERNEL_CHECK(unpool_sum_1d_gpu_kernel); } else if (pool_enum::kSumPooling == pool_type) { // NOLINT_NEXT_LINE(whitespace/operators) @@ -836,7 +876,8 @@ inline void unpool(mshadow::Stream* s, const DType* out_grad, const DType* ishape.Size(), out_grad, in_data, out_data, ishape[1], ishape[2], ishape[3], oshape[2], oshape[3], kernel[0], kernel[1], - stride[0], stride[1], pad[0], pad[1], in_grad, true); + stride[0], stride[1], pad[0], pad[1], in_grad, + true, count_include_pad); MSHADOW_CUDA_POST_KERNEL_CHECK(unpool_sum_2d_gpu_kernel); } else if (pool_enum::kSumPooling == pool_type) { // NOLINT_NEXT_LINE(whitespace/operators) @@ -878,7 +919,7 @@ inline void unpool(mshadow::Stream* s, const DType* out_grad, const DType* ishape[1], ishape[2], ishape[3], ishape[4], oshape[2], oshape[3], oshape[4], kernel[0], kernel[1], kernel[2], stride[0], stride[1], stride[2], pad[0], pad[1], - pad[2], in_grad, true); + pad[2], in_grad, true, count_include_pad); MSHADOW_CUDA_POST_KERNEL_CHECK(unpool_sum_3d_gpu_kernel); } else if (pool_enum::kSumPooling == pool_type) { // NOLINT_NEXT_LINE(whitespace/operators) diff --git a/src/operator/nn/pool.h b/src/operator/nn/pool.h index 9fe43b2bd468..8f7a5edc8324 100644 --- a/src/operator/nn/pool.h +++ b/src/operator/nn/pool.h @@ -216,7 +216,8 @@ inline void pool_max_3d_cpu(const DType* in_data, const TShape& ishape, const TS template inline void pool_sum_1d_cpu(const DType* in_data, const TShape& ishape, const TShape& oshape, const TShape& kernel, const TShape& pad, const TShape& stride, - DType* out_data, const bool getAvg = false) { + DType* out_data, + const bool get_avg = false, const bool count_include_pad = true) { const int width = ishape[2]; const int pooled_width = oshape[2]; const int kernel_w = kernel[0]; @@ -229,9 +230,12 @@ inline void pool_sum_1d_cpu(const DType* in_data, const TShape& ishape, const TS for (int pw = 0; pw < pooled_width; ++pw) { int wstart = pw * stride_w - pad_w; int wend = std::min(wstart + kernel_w, width + pad_w); - int pool_size = (getAvg ? (wend - wstart) : 1); + int pool_size = (get_avg ? (wend - wstart) : 1); wstart = std::max(wstart, 0); wend = std::min(wend, width); + if (get_avg && !count_include_pad) { + pool_size = (wend - wstart); + } DType sum = 0; for (int w = wstart; w < wend; ++w) { sum += a_pow_p::Map(in_data[w]) / pool_size; @@ -251,7 +255,8 @@ inline void pool_sum_1d_cpu(const DType* in_data, const TShape& ishape, const TS template inline void pool_sum_2d_cpu(const DType* in_data, const TShape& ishape, const TShape& oshape, const TShape& kernel, const TShape& pad, const TShape& stride, - DType* out_data, const bool getAvg = false) { + DType* out_data, + const bool get_avg = false, const bool count_include_pad = true) { const int height = ishape[2], width = ishape[3]; const int pooled_height = oshape[2], pooled_width = oshape[3]; const int kernel_h = kernel[0], kernel_w = kernel[1]; @@ -267,11 +272,14 @@ inline void pool_sum_2d_cpu(const DType* in_data, const TShape& ishape, const TS int wstart = pw * stride_w - pad_w; int hend = std::min(hstart + kernel_h, height + pad_h); int wend = std::min(wstart + kernel_w, width + pad_w); - int pool_size = (getAvg ? (hend - hstart) * (wend - wstart) : 1); + int pool_size = (get_avg ? (hend - hstart) * (wend - wstart) : 1); hstart = std::max(hstart, 0); wstart = std::max(wstart, 0); hend = std::min(hend, height); wend = std::min(wend, width); + if (get_avg && !count_include_pad) { + pool_size = (hend - hstart) * (wend - wstart); + } DType sum = 0; for (int h = hstart; h < hend; ++h) { for (int w = wstart; w < wend; ++w) { @@ -294,7 +302,8 @@ inline void pool_sum_2d_cpu(const DType* in_data, const TShape& ishape, const TS template inline void pool_sum_3d_cpu(const DType* in_data, const TShape& ishape, const TShape& oshape, const TShape& kernel, const TShape& pad, const TShape& stride, - DType* out_data, const bool getAvg = false) { + DType* out_data, + const bool get_avg = false, const bool count_include_pad = true) { const int depth = ishape[2], height = ishape[3], width = ishape[4]; const int pooled_depth = oshape[2], pooled_height = oshape[3], pooled_width = oshape[4]; const int kernel_d = kernel[0], kernel_h = kernel[1], kernel_w = kernel[2]; @@ -313,13 +322,16 @@ inline void pool_sum_3d_cpu(const DType* in_data, const TShape& ishape, const TS int dend = std::min(dstart + kernel_d, depth + pad_d); int hend = std::min(hstart + kernel_h, height + pad_h); int wend = std::min(wstart + kernel_w, width + pad_w); - int pool_size = (getAvg ? (dend - dstart) * (hend - hstart) * (wend - wstart) : 1); + int pool_size = (get_avg ? (dend - dstart) * (hend - hstart) * (wend - wstart) : 1); dstart = std::max(dstart, 0); hstart = std::max(hstart, 0); wstart = std::max(wstart, 0); dend = std::min(dend, depth); hend = std::min(hend, height); wend = std::min(wend, width); + if (get_avg && !count_include_pad) { + pool_size = (dend - dstart) * (hend - hstart) * (wend - wstart); + } DType sum = 0; for (int d = dstart; d < dend; ++d) { for (int h = hstart; h < hend; ++h) { @@ -328,7 +340,9 @@ inline void pool_sum_3d_cpu(const DType* in_data, const TShape& ishape, const TS } } } - out_data[(pd*pooled_height+ph)*pooled_width+pw] = a_root_p::Map(sum); + out_data[(pd*pooled_height+ph)*pooled_width+pw] = (pool_size == 0) ? + DType(nanf("")) : + a_root_p::Map(sum); } } } @@ -509,8 +523,8 @@ inline void unpool_max_3d_cpu(const DType* out_grad, const DType* in_data, template inline void unpool_sum_1d_cpu(const DType* out_grad, const DType* in_data, const DType* out_data, const TShape& ishape, const TShape& oshape, const TShape& kernel, - const TShape& pad, const TShape& stride, - DType* in_grad, const bool isAvg = false) { + const TShape& pad, const TShape& stride, DType* in_grad, + const bool is_avg = false, const bool count_include_pad = true) { const int width = ishape[2]; const int pooled_width = oshape[2]; const int kernel_w = kernel[0]; @@ -523,9 +537,12 @@ inline void unpool_sum_1d_cpu(const DType* out_grad, const DType* in_data, const for (int pw = 0; pw < pooled_width; ++pw) { int wstart = pw * stride_w - pad_w; int wend = std::min(wstart + kernel_w, width + pad_w); - int pool_size = (isAvg ? (wend - wstart) : 1); + int pool_size = (is_avg ? (wend - wstart) : 1); wstart = std::max(wstart, 0); wend = std::min(wend, width); + if (is_avg && !count_include_pad) { + pool_size = (wend - wstart); + } for (int w = wstart; w < wend; ++w) { in_grad[w] += lp_grad::Map(out_grad[pw], in_data[w], out_data[pw]) / pool_size; } @@ -545,8 +562,8 @@ inline void unpool_sum_1d_cpu(const DType* out_grad, const DType* in_data, const template inline void unpool_sum_2d_cpu(const DType* out_grad, const DType* in_data, const DType* out_data, const TShape& ishape, const TShape& oshape, const TShape& kernel, - const TShape& pad, const TShape& stride, - DType* in_grad, const bool isAvg = false) { + const TShape& pad, const TShape& stride, DType* in_grad, + const bool is_avg = false, const bool count_include_pad = true) { const int height = ishape[2], width = ishape[3]; const int pooled_height = oshape[2], pooled_width = oshape[3]; const int kernel_h = kernel[0], kernel_w = kernel[1]; @@ -562,11 +579,14 @@ inline void unpool_sum_2d_cpu(const DType* out_grad, const DType* in_data, const int wstart = pw * stride_w - pad_w; int hend = std::min(hstart + kernel_h, height + pad_h); int wend = std::min(wstart + kernel_w, width + pad_w); - int pool_size = (isAvg ? (hend - hstart) * (wend - wstart) : 1); + int pool_size = (is_avg ? (hend - hstart) * (wend - wstart) : 1); hstart = std::max(hstart, 0); wstart = std::max(wstart, 0); hend = std::min(hend, height); wend = std::min(wend, width); + if (is_avg && !count_include_pad) { + pool_size = (hend - hstart) * (wend - wstart); + } const int pool_index = ph * pooled_width + pw; for (int h = hstart; h < hend; ++h) { for (int w = wstart; w < wend; ++w) { @@ -593,8 +613,8 @@ inline void unpool_sum_2d_cpu(const DType* out_grad, const DType* in_data, const template inline void unpool_sum_3d_cpu(const DType* out_grad, const DType* in_data, const DType* out_data, const TShape& ishape, const TShape& oshape, const TShape& kernel, - const TShape& pad, const TShape& stride, - DType* in_grad, const bool isAvg = false) { + const TShape& pad, const TShape& stride, DType* in_grad, + const bool is_avg = false, const bool count_include_pad = true) { const int depth = ishape[2], height = ishape[3], width = ishape[4]; const int pooled_depth = oshape[2], pooled_height = oshape[3], pooled_width = oshape[4]; const int kernel_d = kernel[0], kernel_h = kernel[1], kernel_w = kernel[2]; @@ -613,13 +633,16 @@ inline void unpool_sum_3d_cpu(const DType* out_grad, const DType* in_data, const int dend = std::min(dstart + kernel_d, depth + pad_d); int hend = std::min(hstart + kernel_h, height + pad_h); int wend = std::min(wstart + kernel_w, width + pad_w); - int pool_size = (isAvg ? (dend - dstart) * (hend - hstart) * (wend - wstart) : 1); + int pool_size = (is_avg ? (dend - dstart) * (hend - hstart) * (wend - wstart) : 1); dstart = std::max(dstart, 0); hstart = std::max(hstart, 0); wstart = std::max(wstart, 0); dend = std::min(dend, depth); hend = std::min(hend, height); wend = std::min(wend, width); + if (is_avg && !count_include_pad) { + pool_size = (dend - dstart) * (hend - hstart) * (wend - wstart); + } const int pool_index = (pd * pooled_height + ph) * pooled_width + pw; for (int d = dstart; d < dend; ++d) { for (int h = hstart; h < hend; ++h) { @@ -660,13 +683,14 @@ template inline void pool(mshadow::Stream* s, const DType* in_data, const TShape& ishape, const TShape& oshape, const TShape& kernel, const TShape& pad, const TShape& stride, const int pool_type, OpReqType req_type, - DType* out_data) { + DType* out_data, const bool count_include_pad) { CHECK_EQ(req_type, kWriteTo) << "Only support req=kWriteTo in pooling operations"; if (kernel.ndim() == 1) { if (pool_enum::kMaxPooling == pool_type) { pool_max_1d_cpu(in_data, ishape, oshape, kernel, pad, stride, out_data); } else if (pool_enum::kAvgPooling == pool_type) { - pool_sum_1d_cpu(in_data, ishape, oshape, kernel, pad, stride, out_data, true); + pool_sum_1d_cpu(in_data, ishape, oshape, kernel, pad, stride, out_data, + true, count_include_pad); } else if (pool_enum::kSumPooling == pool_type) { pool_sum_1d_cpu(in_data, ishape, oshape, kernel, pad, stride, out_data); } else if (pool_enum::kLpPooling == pool_type) { @@ -678,7 +702,8 @@ inline void pool(mshadow::Stream* s, const DType* in_data, const TShape& is if (pool_enum::kMaxPooling == pool_type) { pool_max_2d_cpu(in_data, ishape, oshape, kernel, pad, stride, out_data); } else if (pool_enum::kAvgPooling == pool_type) { - pool_sum_2d_cpu(in_data, ishape, oshape, kernel, pad, stride, out_data, true); + pool_sum_2d_cpu(in_data, ishape, oshape, kernel, pad, stride, out_data, + true, count_include_pad); } else if (pool_enum::kSumPooling == pool_type) { pool_sum_2d_cpu(in_data, ishape, oshape, kernel, pad, stride, out_data); } else if (pool_enum::kLpPooling == pool_type) { @@ -690,7 +715,8 @@ inline void pool(mshadow::Stream* s, const DType* in_data, const TShape& is if (pool_enum::kMaxPooling == pool_type) { pool_max_3d_cpu(in_data, ishape, oshape, kernel, pad, stride, out_data); } else if (pool_enum::kAvgPooling == pool_type) { - pool_sum_3d_cpu(in_data, ishape, oshape, kernel, pad, stride, out_data, true); + pool_sum_3d_cpu(in_data, ishape, oshape, kernel, pad, stride, out_data, + true, count_include_pad); } else if (pool_enum::kSumPooling == pool_type) { pool_sum_3d_cpu(in_data, ishape, oshape, kernel, pad, stride, out_data); } else if (pool_enum::kLpPooling == pool_type) { @@ -723,7 +749,8 @@ template inline void unpool(mshadow::Stream* s, const DType* out_grad, const DType* in_data, const DType* out_data, const TShape& ishape, const TShape& oshape, const TShape& kernel, const TShape& pad, const TShape& stride, - const int pool_type, OpReqType req_type, DType* in_grad, const int p_value = 2) { + const int pool_type, OpReqType req_type, DType* in_grad, + const bool count_include_pad) { if (mxnet::kNullOp == req_type) return; if (mxnet::kAddTo != req_type) { mxnet_op::Kernel::Launch(s, ishape.Size(), in_grad); @@ -733,7 +760,7 @@ inline void unpool(mshadow::Stream* s, const DType* out_grad, const DType* unpool_max_1d_cpu(out_grad, in_data, out_data, ishape, oshape, kernel, pad, stride, in_grad); } else if (pool_enum::kAvgPooling == pool_type) { unpool_sum_1d_cpu(out_grad, in_data, out_data, ishape, oshape, kernel, pad, stride, in_grad, - true); + true, count_include_pad); } else if (pool_enum::kSumPooling == pool_type) { unpool_sum_1d_cpu(out_grad, in_data, out_data, ishape, oshape, kernel, pad, stride, in_grad); } else if (pool_enum::kLpPooling == pool_type) { @@ -747,7 +774,7 @@ inline void unpool(mshadow::Stream* s, const DType* out_grad, const DType* unpool_max_2d_cpu(out_grad, in_data, out_data, ishape, oshape, kernel, pad, stride, in_grad); } else if (pool_enum::kAvgPooling == pool_type) { unpool_sum_2d_cpu(out_grad, in_data, out_data, ishape, oshape, kernel, pad, stride, in_grad, - true); + true, count_include_pad); } else if (pool_enum::kSumPooling == pool_type) { unpool_sum_2d_cpu(out_grad, in_data, out_data, ishape, oshape, kernel, pad, stride, in_grad); } else if (pool_enum::kLpPooling == pool_type) { @@ -761,7 +788,7 @@ inline void unpool(mshadow::Stream* s, const DType* out_grad, const DType* unpool_max_3d_cpu(out_grad, in_data, out_data, ishape, oshape, kernel, pad, stride, in_grad); } else if (pool_enum::kAvgPooling == pool_type) { unpool_sum_3d_cpu(out_grad, in_data, out_data, ishape, oshape, kernel, pad, stride, in_grad, - true); + true, count_include_pad); } else if (pool_enum::kSumPooling == pool_type) { unpool_sum_3d_cpu(out_grad, in_data, out_data, ishape, oshape, kernel, pad, stride, in_grad); } else if (pool_enum::kLpPooling == pool_type) { diff --git a/src/operator/nn/pooling-inl.h b/src/operator/nn/pooling-inl.h index a4770b49e857..0c4acf9d3184 100644 --- a/src/operator/nn/pooling-inl.h +++ b/src/operator/nn/pooling-inl.h @@ -50,6 +50,7 @@ struct PoolingParam : public dmlc::Parameter { bool global_pool; bool cudnn_off; dmlc::optional p_value; + dmlc::optional count_include_pad; DMLC_DECLARE_PARAMETER(PoolingParam) { DMLC_DECLARE_FIELD(kernel).set_default(TShape()) // add default value here .enforce_nonzero() @@ -81,7 +82,13 @@ struct PoolingParam : public dmlc::Parameter { .describe("Pad for pooling: (y, x) or (d, y, x). Defaults to no padding."); DMLC_DECLARE_FIELD(p_value).set_default(dmlc::optional()) - .describe("Value of p for Lp pooling, can be 1 or 2, required for Lp Pooling"); + .describe("Value of p for Lp pooling, can be 1 or 2, required for Lp Pooling."); + + DMLC_DECLARE_FIELD(count_include_pad).set_default(dmlc::optional()) + .describe("Only used for AvgPool, specify whether to count padding elements for average" + "calculation. For example, with a 5*5 kernel on a 3*3 corner of a image," + "the sum of the 9 valid elements will be divided by 25 if this is set to true," + "or it will be divided by 9 if this is set to false. Defaults to true."); } bool operator==(const PoolingParam& other) const { @@ -92,7 +99,8 @@ struct PoolingParam : public dmlc::Parameter { this->pooling_convention == other.pooling_convention && this->global_pool == other.global_pool && this->cudnn_off == other.cudnn_off && - this->p_value == other.p_value; + this->p_value == other.p_value && + this->count_include_pad == other.count_include_pad; } }; @@ -112,6 +120,7 @@ struct hash { ret = dmlc::HashCombine(ret, val.global_pool); ret = dmlc::HashCombine(ret, val.cudnn_off); ret = dmlc::HashCombine(ret, val.p_value); + ret = dmlc::HashCombine(ret, val.count_include_pad); return ret; } }; @@ -153,27 +162,29 @@ class PoolingOp { } const int p_value = (param_.pool_type == pool_enum::kLpPooling && param_.p_value.has_value()) ? param_.p_value.value() : 1; + const bool count_include_pad = (param_.count_include_pad.has_value()) ? + param_.count_include_pad.value() : true; switch (p_value) { case 1: pool(s, in_data.dptr(), in_data.shape_, out_data.shape_, kernel, padding, stride, - param_.pool_type, req, out_data.dptr()); + param_.pool_type, req, out_data.dptr(), count_include_pad); break; case 2: pool(s, in_data.dptr(), in_data.shape_, out_data.shape_, kernel, padding, stride, - param_.pool_type, req, out_data.dptr()); + param_.pool_type, req, out_data.dptr(), count_include_pad); break; case 3: pool(s, in_data.dptr(), in_data.shape_, out_data.shape_, kernel, padding, stride, - param_.pool_type, req, out_data.dptr()); + param_.pool_type, req, out_data.dptr(), count_include_pad); break; default: LOG(FATAL) << "p value of " << p_value << " is not supported yet..."; @@ -201,6 +212,8 @@ class PoolingOp { const int p_value = (param_.pool_type == pool_enum::kLpPooling && param_.p_value.has_value()) ? param_.p_value.value() : 1; + const bool count_include_pad = (param_.count_include_pad.has_value()) ? + param_.count_include_pad.value() : true; switch (p_value) { case 1: unpool(s, out_grad.dptr(), in_data.dptr(), out_data.dptr(), @@ -208,7 +221,7 @@ class PoolingOp { kernel, padding, stride, - param_.pool_type, req, in_grad.dptr()); + param_.pool_type, req, in_grad.dptr(), count_include_pad); break; case 2: unpool(s, out_grad.dptr(), in_data.dptr(), out_data.dptr(), @@ -216,7 +229,7 @@ class PoolingOp { kernel, padding, stride, - param_.pool_type, req, in_grad.dptr()); + param_.pool_type, req, in_grad.dptr(), count_include_pad); break; case 3: unpool(s, out_grad.dptr(), in_data.dptr(), out_data.dptr(), @@ -224,7 +237,7 @@ class PoolingOp { kernel, padding, stride, - param_.pool_type, req, in_grad.dptr()); + param_.pool_type, req, in_grad.dptr(), count_include_pad); break; default: LOG(FATAL) << "p value of " << p_value << " is not supported yet..."; diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py index 7c3d670ba222..1c6785a57025 100644 --- a/tests/python/gpu/test_operator_gpu.py +++ b/tests/python/gpu/test_operator_gpu.py @@ -740,8 +740,8 @@ def test_pooling_with_type(): @with_seed() def test_pooling_versions(): - def test_pooling_versions_helper(pool_op_list, data, kernel, pool_type, pad, stride, - pooling_convention='valid', global_pool=False, p_value=2): + def test_pooling_versions_helper(pool_op_list, data, kernel, pool_type, pad, stride, pooling_convention='valid', + global_pool=False, p_value=2, count_include_pad=True, tol=None): ctx_list = [] sym_list = [] # PoolingV1 cpu @@ -765,61 +765,69 @@ def test_pooling_versions_helper(pool_op_list, data, kernel, pool_type, pad, str ctx_list.append({'ctx': mx.cpu(0), 'pool_data': data, 'type_dict': {'pool_data': np.float32}}) if not global_pool: sym_list.append(mx.sym.Pooling(kernel=kernel, pad=pad, stride=stride, pool_type=pool_type, - pooling_convention=pooling_convention, name='pool', p_value=p_value)) + pooling_convention=pooling_convention, name='pool', + p_value=p_value, count_include_pad=count_include_pad)) else: - sym_list.append(mx.sym.Pooling(kernel=kernel, pool_type=pool_type, global_pool=True, name='pool', p_value=p_value)) + sym_list.append(mx.sym.Pooling(kernel=kernel, pool_type=pool_type, global_pool=True, name='pool', + p_value=p_value, count_include_pad=count_include_pad)) # Pooling gpu if 'pool_gpu' in pool_op_list: ctx_list.append({'ctx': mx.gpu(0), 'pool_data': data, 'type_dict': {'pool_data': np.float32}}) if not global_pool: sym_list.append(mx.sym.Pooling(kernel=kernel, pad=pad, stride=stride, pool_type=pool_type, - pooling_convention=pooling_convention, cudnn_off=True, name='pool', p_value=p_value)) + pooling_convention=pooling_convention, cudnn_off=True, name='pool', + p_value=p_value, count_include_pad=count_include_pad)) else: sym_list.append(mx.sym.Pooling(kernel=kernel, pool_type=pool_type, global_pool=True, cudnn_off=True, - name='pool', p_value=p_value)) + name='pool', p_value=p_value, count_include_pad=count_include_pad)) # CuDNNPooling if 'pool_cudnn' in pool_op_list: ctx_list.append({'ctx': mx.gpu(0), 'pool_data': data, 'type_dict': {'pool_data': np.float32}}) if not global_pool: sym_list.append(mx.sym.Pooling(kernel=kernel, pad=pad, stride=stride, pool_type=pool_type, - pooling_convention=pooling_convention, p_value=p_value, cudnn_off=False, name='pool')) + pooling_convention=pooling_convention, p_value=p_value, cudnn_off=False, + name='pool', count_include_pad=count_include_pad)) else: sym_list.append(mx.sym.Pooling(kernel=kernel, pool_type=pool_type, global_pool=True, p_value=p_value, - cudnn_off=False, name='pool')) - check_consistency(sym_list, ctx_list) + cudnn_off=False, name='pool', count_include_pad=count_include_pad)) + check_consistency(sym_list, ctx_list, equal_nan=(not count_include_pad), tol=tol) - def test_1d_pooling(pool_type, p_value=2): + def test_1d_pooling(pool_type, p_value=2, count_include_pad=True): data = (2, 3, 20) kernel = (4,) pad = (0,) stride = (1,) test_pooling_versions_helper(pool_op_list=['pool_cpu', 'pool_gpu'], data=data, kernel=kernel, pad=pad, stride=stride, pool_type=pool_type, - pooling_convention='valid', global_pool=False, p_value=p_value) + pooling_convention='valid', global_pool=False, p_value=p_value, + count_include_pad=count_include_pad) pad = (2,) stride = (2,) test_pooling_versions_helper(pool_op_list=['pool_cpu', 'pool_gpu'], data=data, kernel=kernel, pad=pad, stride=stride, pool_type=pool_type, - pooling_convention='valid', global_pool=False, p_value=p_value) + pooling_convention='valid', global_pool=False, p_value=p_value, + count_include_pad=count_include_pad) pad = (0,) stride = (1,) test_pooling_versions_helper(pool_op_list=['pool_cpu', 'pool_gpu'], data=data, kernel=kernel, pad=pad, stride=stride, pool_type=pool_type, - pooling_convention='full', global_pool=False, p_value=p_value) + pooling_convention='full', global_pool=False, p_value=p_value, + count_include_pad=count_include_pad) pad = (2,) stride = (2,) test_pooling_versions_helper(pool_op_list=['pool_cpu', 'pool_gpu'], data=data, kernel=kernel, pad=pad, stride=stride, pool_type=pool_type, - pooling_convention='full', global_pool=False, p_value=p_value) + pooling_convention='full', global_pool=False, p_value=p_value, + count_include_pad=count_include_pad) test_pooling_versions_helper(pool_op_list=['pool_cpu', 'pool_gpu'], data=data, kernel=kernel, pad=pad, stride=stride, pool_type=pool_type, - global_pool=True, p_value=p_value) + global_pool=True, p_value=p_value, count_include_pad=count_include_pad) - def test_2d_pooling(pool_type, p_value=2): + def test_2d_pooling(pool_type, p_value=2, count_include_pad=True): data = (2, 3, 20, 20) kernel = (4, 5) pad = (0, 0) @@ -831,14 +839,15 @@ def test_2d_pooling(pool_type, p_value=2): else: test_pooling_versions_helper(pool_op_list=['pool_v1_cpu', 'pool_v1_gpu', 'pool_cpu', 'pool_gpu', 'pool_cudnn'], data=data, kernel=kernel, pad=pad, stride=stride, pool_type=pool_type, - pooling_convention='valid', global_pool=False) + pooling_convention='valid', global_pool=False, count_include_pad=count_include_pad) # pool_v1 has bugs when pad is not 0, do not test PoolingV1 here pad = (2, 3) stride = (2, 3) test_pooling_versions_helper(pool_op_list=['pool_cpu', 'pool_gpu', 'pool_cudnn'], data=data, kernel=kernel, pad=pad, stride=stride, pool_type=pool_type, - pooling_convention='valid', global_pool=False, p_value=p_value) + pooling_convention='valid', global_pool=False, p_value=p_value, + count_include_pad=count_include_pad) pad = (0, 0) stride = (1, 1) @@ -847,16 +856,24 @@ def test_2d_pooling(pool_type, p_value=2): data=data, kernel=kernel, pad=pad, stride=stride, pool_type=pool_type, pooling_convention='full', global_pool=False, p_value=p_value) else: - test_pooling_versions_helper(pool_op_list=['pool_v1_cpu', 'pool_v1_gpu', 'pool_cpu', 'pool_gpu', 'pool_cudnn'], - data=data, kernel=kernel, pad=pad, stride=stride, pool_type=pool_type, - pooling_convention='full', global_pool=False) + if count_include_pad: + test_pooling_versions_helper(pool_op_list=['pool_v1_cpu', 'pool_v1_gpu', 'pool_cpu', 'pool_gpu', 'pool_cudnn'], + data=data, kernel=kernel, pad=pad, stride=stride, pool_type=pool_type, + pooling_convention='full', global_pool=False, + count_include_pad=count_include_pad) + else: + test_pooling_versions_helper(pool_op_list=['pool_cpu', 'pool_gpu', 'pool_cudnn'], + data=data, kernel=kernel, pad=pad, stride=stride, pool_type=pool_type, + pooling_convention='full', global_pool=False, + count_include_pad=count_include_pad) # pool_v1 has bugs when pad is not 0, do not test PoolingV1 here pad = (2, 3) stride = (2, 3) test_pooling_versions_helper(pool_op_list=['pool_cpu', 'pool_gpu', 'pool_cudnn'], data=data, kernel=kernel, pad=pad, stride=stride, pool_type=pool_type, - pooling_convention='full', global_pool=False, p_value=p_value) + pooling_convention='full', global_pool=False, p_value=p_value, + count_include_pad=count_include_pad) if pool_type == 'lp': test_pooling_versions_helper(pool_op_list=['pool_cpu', 'pool_gpu', 'pool_cudnn'], @@ -865,55 +882,62 @@ def test_2d_pooling(pool_type, p_value=2): else: test_pooling_versions_helper(pool_op_list=['pool_v1_cpu', 'pool_v1_gpu', 'pool_cpu', 'pool_gpu', 'pool_cudnn'], data=data, kernel=kernel, pad=pad, stride=stride, pool_type=pool_type, - global_pool=True) + global_pool=True, count_include_pad=count_include_pad) - def test_3d_pooling(pool_type, p_value=2): + def test_3d_pooling(pool_type, p_value=2, count_include_pad=True): data = (2, 3, 20, 20, 20) kernel = (4, 5, 3) pad = (0, 0, 0) stride = (1, 1, 1) test_pooling_versions_helper(pool_op_list=['pool_cpu', 'pool_gpu', 'pool_cudnn'], data=data, kernel=kernel, pad=pad, stride=stride, pool_type=pool_type, - pooling_convention='valid', global_pool=False, p_value=p_value) + pooling_convention='valid', global_pool=False, p_value=p_value, + count_include_pad=count_include_pad) pad = (2, 3, 3) stride = (2, 3, 1) test_pooling_versions_helper(pool_op_list=['pool_cpu', 'pool_gpu', 'pool_cudnn'], data=data, kernel=kernel, pad=pad, stride=stride, pool_type=pool_type, - pooling_convention='valid', global_pool=False, p_value=p_value) + pooling_convention='valid', global_pool=False, p_value=p_value, + count_include_pad=count_include_pad) pad = (0, 0, 0) stride = (1, 1, 1) test_pooling_versions_helper(pool_op_list=['pool_cpu', 'pool_gpu', 'pool_cudnn'], data=data, kernel=kernel, pad=pad, stride=stride, pool_type=pool_type, - pooling_convention='full', global_pool=False, p_value=p_value) + pooling_convention='full', global_pool=False, p_value=p_value, + count_include_pad=count_include_pad) pad = (2, 3, 3) stride = (2, 3, 1) test_pooling_versions_helper(pool_op_list=['pool_cpu', 'pool_gpu', 'pool_cudnn'], data=data, kernel=kernel, pad=pad, stride=stride, pool_type=pool_type, - pooling_convention='full', global_pool=False, p_value=p_value) + pooling_convention='full', global_pool=False, p_value=p_value, + count_include_pad=count_include_pad) test_pooling_versions_helper(pool_op_list=['pool_cpu', 'pool_gpu', 'pool_cudnn'], data=data, kernel=kernel, pad=pad, stride=stride, pool_type=pool_type, - global_pool=True, p_value=p_value) + global_pool=True, p_value=p_value, count_include_pad=count_include_pad) test_1d_pooling('max') - test_1d_pooling('avg') + test_1d_pooling('avg', count_include_pad=True) + test_1d_pooling('avg', count_include_pad=False) test_1d_pooling('sum') test_1d_pooling('lp', p_value=1) test_1d_pooling('lp', p_value=2) test_1d_pooling('lp', p_value=3) test_2d_pooling('max') - test_2d_pooling('avg') + test_2d_pooling('avg', count_include_pad=True) + test_2d_pooling('avg', count_include_pad=False) test_2d_pooling('sum') test_2d_pooling('lp', p_value=1) test_2d_pooling('lp', p_value=2) test_2d_pooling('lp', p_value=3) test_3d_pooling('max') - test_3d_pooling('avg') + test_3d_pooling('avg', count_include_pad=True) + test_3d_pooling('avg', count_include_pad=False) test_3d_pooling('sum') test_3d_pooling('lp', p_value=1) test_3d_pooling('lp', p_value=2) diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py index bf1e0deb200b..50ecd5c8809b 100644 --- a/tests/python/unittest/test_gluon.py +++ b/tests/python/unittest/test_gluon.py @@ -470,6 +470,7 @@ def test_pool(): nn.MaxPool1D(3), nn.MaxPool1D(3, 2), nn.AvgPool1D(), + nn.AvgPool1D(count_include_pad=False), nn.GlobalAvgPool1D(), ] for layer in layers1d: @@ -481,6 +482,7 @@ def test_pool(): nn.MaxPool2D((3, 3)), nn.MaxPool2D(3, 2), nn.AvgPool2D(), + nn.AvgPool2D(count_include_pad=False), nn.GlobalAvgPool2D(), ] for layer in layers2d: @@ -491,6 +493,7 @@ def test_pool(): nn.MaxPool3D((3, 3, 3)), nn.MaxPool3D(3, 2), nn.AvgPool3D(), + nn.AvgPool3D(count_include_pad=False), nn.GlobalAvgPool3D(), ] for layer in layers3d: