diff --git a/src/operator/tensor/indexing_op.h b/src/operator/tensor/indexing_op.h
index 28827db0e635..d257e53a1e94 100644
--- a/src/operator/tensor/indexing_op.h
+++ b/src/operator/tensor/indexing_op.h
@@ -43,7 +43,6 @@
 #include "../mxnet_op.h"
 #include "./sort_op.h"
 #include "./init_op.h"
-#include "./matrix_op-inl.h"
 #include "../../engine/openmp.h"
 
 namespace mxnet {
diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h
index c46233c367fe..555b9cc7cb59 100644
--- a/src/operator/tensor/matrix_op-inl.h
+++ b/src/operator/tensor/matrix_op-inl.h
@@ -706,12 +706,42 @@ inline bool SliceOpShape(const nnvm::NodeAttrs& attrs,
   return oshape.ndim() != 0 && oshape.Size() != 0;
 }
 
-template<int ndim>
-struct slice_forward {
+template<int ndim, int req, typename xpu>
+struct slice_forward;
+
+template<int ndim, int req>
+struct slice_forward<ndim, req, gpu> {
+  // i is the i-th row after flattening out into 2D tensor
+  template<typename DType>
+  MSHADOW_XINLINE static void Map(int i, DType* out, const DType* data,
+                                  const mshadow::Shape<ndim> dshape,
+                                  const mshadow::Shape<ndim> oshape,
+                                  const common::StaticArray<index_t, ndim> begin,
+                                  const common::StaticArray<index_t, ndim> step) {
+    const int data_last_dim_size = dshape[ndim-1];
+    const int out_last_dim_size = oshape[ndim-1];
+    const int step_last_dim = step[ndim-1];
+    const int begin_last_dim = begin[ndim-1];
+    const int j = i % out_last_dim_size;
+    int irow = 0;  // row id of flattend 2D data
+    int stride = 1;
+    int idx = i / out_last_dim_size;
+    #pragma unroll
+    for (int k = ndim - 2; k >= 0; --k) {
+      irow += stride * ((idx % oshape[k]) * step[k] + begin[k]);
+      idx /= oshape[k];
+      stride *= dshape[k];
+    }
+    KERNEL_ASSIGN(out[i], req,
+                  data[irow * data_last_dim_size + j * step_last_dim + begin_last_dim]);
+  }
+};
+
+template<int ndim, int req>
+struct slice_forward<ndim, req, cpu> {
   // i is the i-th row after flattening out into 2D tensor
   template<typename DType>
   MSHADOW_XINLINE static void Map(int i, DType* out, const DType* data,
-                                  const OpReqType req,
                                   const mshadow::Shape<ndim> dshape,
                                   const mshadow::Shape<ndim> oshape,
                                   const common::StaticArray<index_t, ndim> begin,
@@ -756,19 +786,27 @@ void SliceOpForward(const nnvm::NodeAttrs& attrs,
     common::StaticArray<index_t, ndim> begin, end, step;
     GetIndexRange(data.shape_, param.begin, param.end, param.step, &begin, &end, &step);
     MSHADOW_TYPE_SWITCH(out.type_flag_, DType, {
-      mxnet_op::Kernel<slice_forward<ndim>, xpu>::Launch(s, out.shape_.FlatTo2D()[0],
-          out.dptr<DType>(), data.dptr<DType>(), req[0],
-          data.shape_.get<ndim>(), out.shape_.get<ndim>(), begin, step);
+      MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
+        int num_threads = out.shape_.FlatTo2D()[0];
+        if (std::is_same<xpu, gpu>::value) {
+          num_threads *= out.shape_.get<ndim>()[ndim - 1];
+        }
+        mxnet_op::Kernel<slice_forward<ndim, Req, xpu>, xpu>::Launch(s, num_threads,
+            out.dptr<DType>(), data.dptr<DType>(),
+            data.shape_.get<ndim>(), out.shape_.get<ndim>(), begin, step);
+      })
     })
   })
 }
 
-template<int ndim>
-struct slice_assign {
+template<int ndim, int req, typename xpu>
+struct slice_assign;
+
+template<int ndim, int req>
+struct slice_assign<ndim, req, cpu> {
   // i is the i-th row after flattening out into 2D tensor
   template<typename DType>
   MSHADOW_XINLINE static void Map(int i, DType* out, const DType* val,
-                                  const OpReqType req,
                                   const mshadow::Shape<ndim> oshape,
                                   const mshadow::Shape<ndim> vshape,
                                   const common::StaticArray<index_t, ndim> begin,
@@ -794,6 +832,34 @@ struct slice_assign {
   }
 };
 
+template<int ndim, int req>
+struct slice_assign<ndim, req, gpu> {
+  // i is the i-th row after flattening out into 2D tensor
+  template<typename DType>
+  MSHADOW_XINLINE static void Map(int i, DType* out, const DType* val,
+                                  const mshadow::Shape<ndim> oshape,
+                                  const mshadow::Shape<ndim> vshape,
+                                  const common::StaticArray<index_t, ndim> begin,
+                                  const common::StaticArray<index_t, ndim> step) {
+    const int data_last_dim_size = oshape[ndim-1];
+    const int out_last_dim_size = vshape[ndim-1];
+    const int step_last_dim = step[ndim-1];
+    const int begin_last_dim = begin[ndim-1];
+    const int j = i % out_last_dim_size;
+    int irow = 0;  // row id of flattend 2D out
+    int stride = 1;
+    int idx = i / out_last_dim_size;
+    #pragma unroll
+    for (int k = ndim - 2; k >= 0; --k) {
+      irow += stride * ((idx % vshape[k]) * step[k] + begin[k]);
+      idx /= vshape[k];
+      stride *= oshape[k];
+    }
+    KERNEL_ASSIGN(out[irow * data_last_dim_size + j * step_last_dim + begin_last_dim],
+                  req, val[i]);
+  }
+};
+
 template<typename xpu>
 void SliceOpBackward(const nnvm::NodeAttrs& attrs,
                      const OpContext& ctx,
@@ -818,9 +884,15 @@ void SliceOpBackward(const nnvm::NodeAttrs& attrs,
     common::StaticArray<index_t, ndim> begin, end, step;
     GetIndexRange(igrad.shape_, param.begin, param.end, param.step, &begin, &end, &step);
     MSHADOW_TYPE_SWITCH(ograd.type_flag_, DType, {
-      mxnet_op::Kernel<slice_assign<ndim>, xpu>::Launch(s, ograd.shape_.FlatTo2D()[0],
-          igrad.dptr<DType>(), ograd.dptr<DType>(), req[0],
+      MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
+        int num_threads = ograd.shape_.FlatTo2D()[0];
+        if (std::is_same<xpu, gpu>::value) {
+          num_threads *= ograd.shape_.get<ndim>()[ndim - 1];
+        }
+        mxnet_op::Kernel<slice_assign<ndim, Req, xpu>, xpu>::Launch(s, num_threads,
+          igrad.dptr<DType>(), ograd.dptr<DType>(),
           igrad.shape_.get<ndim>(), ograd.shape_.get<ndim>(), begin, step);
+      })
     })
   })
 }
@@ -876,9 +948,15 @@ void SliceAssignOpForward(const nnvm::NodeAttrs& attrs,
     common::StaticArray<index_t, ndim> begin, end, step;
     GetIndexRange(data.shape_, param.begin, param.end, param.step, &begin, &end, &step);
     MSHADOW_TYPE_SWITCH(out.type_flag_, DType, {
-      mxnet_op::Kernel<slice_assign<ndim>, xpu>::Launch(s, val.shape_.FlatTo2D()[0],
-          out.dptr<DType>(), val.dptr<DType>(), req[0],
-          out.shape_.get<ndim>(), val.shape_.get<ndim>(), begin, step);
+      MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
+        int num_threads = val.shape_.FlatTo2D()[0];
+        if (std::is_same<xpu, gpu>::value) {
+          num_threads *= val.shape_.get<ndim>()[ndim - 1];
+        }
+        mxnet_op::Kernel<slice_assign<ndim, Req, xpu>, xpu>::Launch(s, num_threads,
+            out.dptr<DType>(), val.dptr<DType>(),
+            out.shape_.get<ndim>(), val.shape_.get<ndim>(), begin, step);
+      })
     })
   })
 }
@@ -1242,9 +1320,15 @@ void SliceLikeForward(const nnvm::NodeAttrs& attrs,
     common::StaticArray<index_t, ndim> begin, end, step;
     GetIndexRange(data.shape_, param_begin, param_end, param_step, &begin, &end, &step);
     MSHADOW_TYPE_SWITCH(out.type_flag_, DType, {
-      mxnet_op::Kernel<slice_forward<ndim>, xpu>::Launch(s, out.shape_.FlatTo2D()[0],
-          out.dptr<DType>(), data.dptr<DType>(), req[0],
-          data.shape_.get<ndim>(), out.shape_.get<ndim>(), begin, step);
+      MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
+        int num_threads = out.shape_.FlatTo2D()[0];
+        if (std::is_same<xpu, gpu>::value) {
+          num_threads *= out.shape_.get<ndim>()[ndim - 1];
+        }
+        mxnet_op::Kernel<slice_forward<ndim, Req, xpu>, xpu>::Launch(s,
+            num_threads, out.dptr<DType>(), data.dptr<DType>(),
+            data.shape_.get<ndim>(), out.shape_.get<ndim>(), begin, step);
+      })
     })
   })
 }
@@ -1282,9 +1366,15 @@ void SliceLikeBackward(const nnvm::NodeAttrs& attrs,
     common::StaticArray<index_t, ndim> begin, end, step;
     GetIndexRange(ograd.shape_, param_begin, param_end, param_step, &begin, &end, &step);
     MSHADOW_TYPE_SWITCH(ograd.type_flag_, DType, {
-      mxnet_op::Kernel<slice_assign<ndim>, xpu>::Launch(s, ograd.shape_.FlatTo2D()[0],
-          igrad.dptr<DType>(), ograd.dptr<DType>(), req[0],
-          igrad.shape_.get<ndim>(), ograd.shape_.get<ndim>(), begin, step);
+      MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
+        int num_threads = ograd.shape_.FlatTo2D()[0];
+        if (std::is_same<xpu, gpu>::value) {
+          num_threads *= ograd.shape_.get<ndim>()[ndim - 1];
+        }
+        mxnet_op::Kernel<slice_assign<ndim, Req, xpu>, xpu>::Launch(s, num_threads,
+            igrad.dptr<DType>(), ograd.dptr<DType>(),
+            igrad.shape_.get<ndim>(), ograd.shape_.get<ndim>(), begin, step);
+      })
    })
  })
}