99 changes: 69 additions & 30 deletions src/operator/tensor/indexing_op.cc
@@ -60,6 +60,51 @@ struct TakeZeroAxisCPU {
}
};

template <bool clip = true>
struct TakeNonzeroAxisCPU {
  /*!
   * \brief Map function for take operator
   * \param i global thread id
   * \param out_data ptr to output buffer
   * \param in_data ptr to input buffer
   * \param indices ptr to indices buffer
   * \param outer_dim_stride stride of dimension before axis
   * \param axis_dim_stride stride of axis dimension
   * \param idx_size size of the indices tensor
   * \param axis_dim dim size of the axis dimension
   * \param axis axis id
   */
  template <typename DType, typename IType>
  MSHADOW_XINLINE static void Map(index_t i,
                                  DType* out_data,
                                  const DType* in_data,
                                  const IType* indices,
                                  const index_t outer_dim_stride,
                                  const index_t axis_dim_stride,
                                  const int idx_size,
                                  const int axis_dim,
                                  const int axis) {
    for (index_t j = 0; j < static_cast<index_t>(idx_size); ++j) {
      int index = indices[j];
      if (clip) {
        index = std::max(index, 0);
        index = std::min(axis_dim - 1, index);
      } else {
        index %= axis_dim;
        index += (index < 0) ? axis_dim : 0;
      }
      size_t in_offset = i * outer_dim_stride + index * axis_dim_stride;
      size_t out_offset = (i * idx_size + j) * axis_dim_stride;
#pragma GCC diagnostic push
#if __GNUC__ >= 8
#pragma GCC diagnostic ignored "-Wclass-memaccess"
#endif
      std::memcpy(out_data + out_offset, in_data + in_offset, axis_dim_stride * sizeof(DType));
#pragma GCC diagnostic pop
    }
  }
};
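
For readers skimming the kernel, here is a minimal standalone sketch (not part of the patch; the helper names are illustrative only) of the two index-normalization modes that `Map` applies before the per-slice `memcpy`. Clip mode corresponds to `mode='clip'` of the take operator, wrap mode to `mode='wrap'`:

```cpp
#include <algorithm>

// Clip mode: out-of-range indices saturate into [0, axis_dim - 1].
inline int normalize_clip(int index, int axis_dim) {
  return std::min(axis_dim - 1, std::max(index, 0));
}

// Wrap mode: indices are reduced modulo axis_dim, then negative remainders
// are shifted back into range (mirroring the else branch of Map above).
inline int normalize_wrap(int index, int axis_dim) {
  index %= axis_dim;
  return index < 0 ? index + axis_dim : index;
}
```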

/*
* \brief returns true if all indices are between [min, max]
* \param data_ptr the indices to check
@@ -323,6 +368,7 @@ void TakeOpForward<cpu>(const nnvm::NodeAttrs& attrs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
using namespace mxnet_op;

if (req[take_::kOut] == kNullOp)
return;
const TakeParam& param = nnvm::get<TakeParam>(attrs.parsed);
@@ -375,39 +421,32 @@ void TakeOpForward<cpu>(const nnvm::NodeAttrs& attrs,
for (int i = arrshape.ndim() - 1; i >= 0; stride *= arrshape[i], --i) {
in_strides[i] = stride;
}
mshadow::Shape<10> out_strides;
stride = 1;
for (int i = oshape.ndim() - 1; i >= 0; stride *= oshape[i], --i) {
out_strides[i] = stride;
int outer_dimensions = 1;
for (int i = 0; i < actual_axis; i++) {
outer_dimensions *= oshape[i];
}
if (param.mode == take_::kClip) {
Kernel<TakeNonzeroAxis<true>, cpu>::Launch(s,
oshape.Size(),
outputs[take_::kOut].dptr<DType>(),
inputs[take_::kArr].dptr<DType>(),
inputs[take_::kIdx].dptr<IType>(),
out_strides[actual_axis - 1],
in_strides[actual_axis - 1],
in_strides[actual_axis],
arrshape.ndim(),
oshape.ndim(),
idxshape.ndim(),
arrshape[actual_axis],
actual_axis);
Kernel<TakeNonzeroAxisCPU<true>, cpu>::Launch(s,
outer_dimensions,
outputs[take_::kOut].dptr<DType>(),
inputs[take_::kArr].dptr<DType>(),
inputs[take_::kIdx].dptr<IType>(),
in_strides[actual_axis - 1],
in_strides[actual_axis],
idxshape.Size(),
arrshape[actual_axis],
actual_axis);
} else {
Kernel<TakeNonzeroAxis<false>, cpu>::Launch(s,
oshape.Size(),
outputs[take_::kOut].dptr<DType>(),
inputs[take_::kArr].dptr<DType>(),
inputs[take_::kIdx].dptr<IType>(),
out_strides[actual_axis - 1],
in_strides[actual_axis - 1],
in_strides[actual_axis],
arrshape.ndim(),
oshape.ndim(),
idxshape.ndim(),
arrshape[actual_axis],
actual_axis);
Kernel<TakeNonzeroAxisCPU<false>, cpu>::Launch(s,
outer_dimensions,
outputs[take_::kOut].dptr<DType>(),
inputs[take_::kArr].dptr<DType>(),
inputs[take_::kIdx].dptr<IType>(),
in_strides[actual_axis - 1],
in_strides[actual_axis],
idxshape.Size(),
arrshape[actual_axis],
actual_axis);
}
}
});
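
To make the new launch arguments concrete, here is a small self-contained sketch (assumed shapes, not part of the patch) that reproduces the stride and outer-dimension arithmetic of `TakeOpForward<cpu>` for a hypothetical `take` along axis 1:

```cpp
// Illustrative sketch: how the launch arguments above would come out for a
// hypothetical call take(arr, indices, axis=1) with arr.shape = (2, 3, 4)
// and indices.shape = (5).
#include <cstdio>
#include <vector>

int main() {
  std::vector<int> arrshape = {2, 3, 4};
  std::vector<int> idxshape = {5};
  const int axis = 1;

  // Row-major strides of the input, mirroring the in_strides loop above.
  std::vector<int> in_strides(arrshape.size());
  int stride = 1;
  for (int i = static_cast<int>(arrshape.size()) - 1; i >= 0; stride *= arrshape[i], --i) {
    in_strides[i] = stride;  // -> {12, 4, 1}
  }

  // Product of the dims before the axis; the kernel launches one Map() call
  // per such "outer" row (dims before the axis agree between input and output).
  int outer_dimensions = 1;
  for (int i = 0; i < axis; ++i) {
    outer_dimensions *= arrshape[i];  // -> 2
  }

  int idx_size = 1;
  for (int d : idxshape) idx_size *= d;  // -> 5

  std::printf("outer_dimensions=%d outer_dim_stride=%d axis_dim_stride=%d "
              "idx_size=%d axis_dim=%d\n",
              outer_dimensions, in_strides[axis - 1], in_strides[axis],
              idx_size, arrshape[axis]);
  // Expected: outer_dimensions=2 outer_dim_stride=12 axis_dim_stride=4 idx_size=5 axis_dim=3
  // Each Map(i, ...) then copies idx_size slices of axis_dim_stride elements,
  // so every output row holds 5 * 4 = 20 elements (output shape (2, 5, 4)).
  return 0;
}
```
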
5 changes: 3 additions & 2 deletions src/operator/tensor/indexing_op.h
@@ -215,8 +215,9 @@ inline bool EmbeddingOpBackwardStorageType(const nnvm::NodeAttrs& attrs,
return dispatched;
}

/*! \brief name the struct TakeNonzeroAxis for general take when
* axis is not zero, use TakeZeroAxisGPU or TakeZeroAxisCPU for axis zero
/*! \brief TakeNonzeroAxis implements the general take when axis is not zero.
 * For the CPU-optimized version use TakeNonzeroAxisCPU; for axis zero use
 * TakeZeroAxisGPU or TakeZeroAxisCPU.
*/
template <bool clip = true>
struct TakeNonzeroAxis {