Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions src/operator/mxnet_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,23 @@ using std::isnan;
#endif


#ifdef __CUDACC__
/*!
 * \brief Grid-stride loop over indices [0, n).
 * Each CUDA thread starts at its global index
 * (blockIdx.x * blockDim.x + threadIdx.x) and advances by the total
 * number of threads in the grid (blockDim.x * gridDim.x), so all n
 * elements are covered even when the launch has fewer threads than n.
 */
#define CUDA_KERNEL_LOOP(i, n) \
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
       i < (n); \
       i += blockDim.x * gridDim.x)


/*!
 * \brief Get the number of blocks for cuda kernel given N
 * \param N total number of elements to process (one logical work item each)
 * \return ceil(N / kBaseThreadNum), capped at kMaxGridNum; when the cap is
 *         hit the kernel is expected to loop (e.g. via CUDA_KERNEL_LOOP)
 *         to cover the remaining elements
 */
inline int cuda_get_num_blocks(const int N) {
  using namespace mshadow::cuda;
  return std::min(kMaxGridNum, (N + kBaseThreadNum - 1) / kBaseThreadNum);
}
#endif  // __CUDACC__


/*! \brief operator request type switch */
#define MXNET_ASSIGN_REQ_SWITCH(req, ReqType, ...) \
switch (req) { \
Expand Down Expand Up @@ -139,6 +156,22 @@ MSHADOW_XINLINE Shape<ndim> calc_stride(const Shape<ndim>& shape) {
}


/*! \brief Kernel operator that broadcasts a constant value into a buffer. */
struct fill {
  /*!
   * \brief Write \p val at position \p i of \p out.
   * \param i   element index supplied by the Kernel launcher
   * \param out destination buffer
   * \param val constant to store
   */
  template<typename DType>
  MSHADOW_XINLINE static void Map(int i, DType* out, const DType val) {
    *(out + i) = val;
  }
};


/*! \brief Kernel operator that zero-initializes a buffer element-wise. */
struct set_zero {
  /*!
   * \brief Store a zero of type DType at position \p i of \p out.
   * \param i   element index supplied by the Kernel launcher
   * \param out destination buffer
   */
  template<typename DType>
  MSHADOW_XINLINE static void Map(int i, DType* out) {
    const DType zero = static_cast<DType>(0);
    out[i] = zero;
  }
};


/*!
 * \brief Forward declaration of the Kernel launcher, specialized per device
 * type (xpu) elsewhere in this file. NOTE(review): presumably OP supplies a
 * static Map(int i, ...) invoked once per index, as in fill/set_zero above —
 * confirm against the cpu/gpu specializations.
 */
template<typename OP, typename xpu>
struct Kernel;

Expand Down
18 changes: 4 additions & 14 deletions src/operator/nn/im2col.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -71,20 +71,6 @@
namespace mxnet {
namespace op {

// CUDA: grid stride looping
#define CUDA_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
i < (n); \
i += blockDim.x * gridDim.x)

/*!
* \brief Get the number of blocks for cuda kernel given N
*/
inline int cuda_get_num_blocks(const int N) {
using namespace mshadow::cuda;
return std::min(kMaxGridNum, (N + kBaseThreadNum - 1) / kBaseThreadNum);
}

/*!
* \brief im2col gpu kernel.
* DO NOT call this directly. Use wrapper function im2col() instead;
Expand Down Expand Up @@ -141,6 +127,7 @@ inline void im2col_gpu(mshadow::Stream<gpu>* s,
int width_col = (width + 2 * pad_w -
(dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
int num_kernels = channels * height_col * width_col;
using namespace mxnet_op;
// NOLINT_NEXT_LINE(whitespace/operators)
im2col_gpu_kernel<DType><<<cuda_get_num_blocks(num_kernels), mshadow::cuda::kBaseThreadNum,
0, mshadow::Stream<gpu>::GetStream(s)>>>(
Expand Down Expand Up @@ -303,6 +290,7 @@ inline void im2col(mshadow::Stream<gpu>* s,
index_t num_spatial_axes = kernel_shape.ndim();
CHECK_LT(num_spatial_axes, mshadow::cuda::kBaseThreadNum);
index_t num_kernels = im_shape[1] * col_shape.ProdShape(1, col_shape.ndim());
using namespace mxnet_op;
switch (num_spatial_axes) {
case 1:
im2col_nd_gpu_kernel<DType, 1> // NOLINT_NEXT_LINE(whitespace/operators)
Expand Down Expand Up @@ -347,6 +335,7 @@ inline void col2im_gpu(mshadow::Stream<gpu>* s, const DType* data_col, const int
int width_col = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) /
stride_w + 1;
int num_kernels = channels * height * width;
using namespace mxnet_op;
// To avoid involving atomic operations, we will launch one kernel per
// bottom dimension, and then in the kernel add up the top dimensions.
// NOLINT_NEXT_LINE(whitespace/operators)
Expand Down Expand Up @@ -487,6 +476,7 @@ inline void col2im(mshadow::Stream<gpu>* s,
index_t im_size = im_shape.ProdShape(1, im_shape.ndim());
// num_axes should be smaller than block size
CHECK_LT(num_spatial_axes, mshadow::cuda::kBaseThreadNum);
using namespace mxnet_op;
switch (num_spatial_axes) {
case 1:
col2im_nd_gpu_kernel<DType, 1> // NOLINT_NEXT_LINE(whitespace/operators)
Expand Down
Loading