src/operator/tensor/indexing_op-inl.cuh — 63 additions & 50 deletions
@@ -200,57 +200,15 @@ AddTakeGradLargeBatchWorkspaceSize(size_t num_keys) {
}

template<typename IndexType, typename DType>
-inline void AddTakeGradLargeBatch(mshadow::Tensor<gpu, 2, DType> dst,
-                                  const mshadow::Tensor<gpu, 1, IndexType>& sorted,
-                                  const mshadow::Tensor<gpu, 1, IndexType>& index,
-                                  const mshadow::Tensor<gpu, 2, DType> &src,
-                                  mshadow::Tensor<gpu, 1, char>* workspace) {
-  CHECK_EQ(dst.CheckContiguous(), true);
-  CHECK_EQ(sorted.CheckContiguous(), true);
-  CHECK_EQ(index.CheckContiguous(), true);
-  CHECK_EQ(src.CheckContiguous(), true);
-  // const int kWarpBits = kMemUnitBits;
+inline void AddTakeGradLargeBatchKernelLaunch(mshadow::Tensor<gpu, 2, DType> dst,
+                                              const mshadow::Tensor<gpu, 1, IndexType>& sorted,
+                                              const mshadow::Tensor<gpu, 1, IndexType>& index,
+                                              const mshadow::Tensor<gpu, 2, DType> &src,
+                                              IndexType* sum_counts_ptr,
+                                              int* num_runs_ptr,
+                                              const mshadow::index_t num_rows) {
-  cudaStream_t stream = mshadow::Stream<gpu>::GetStream(dst.stream_);
-  IndexType* sum_counts_ptr = NULL;
-  int* num_runs_ptr = NULL;
-  if (dst.size(0)*4 < src.size(0) && workspace != NULL) {
-    // Workspace given and potentially loops at least 4 times, use CUB to create sum_counts
-    CHECK_EQ(workspace->CheckContiguous(), true);
-    // workspace = [unique_out, counts_out, temporary_storage]
-    size_t unique_bytes = sorted.size(0)*sizeof(IndexType);
-    size_t counts_bytes = sorted.size(0)*sizeof(IndexType);
-    size_t num_runs_bytes = 1*sizeof(int);
-
-    size_t encode_bytes = 0;
-    cub::DeviceRunLengthEncode::Encode<IndexType*, IndexType*, IndexType*, int*>
-      (NULL, encode_bytes, NULL, NULL, NULL, NULL, sorted.size(0), stream);
-    size_t exclusivesum_bytes = 0;
-    cub::DeviceScan::ExclusiveSum<IndexType*, IndexType*>
-      (NULL, exclusivesum_bytes, NULL, NULL, sorted.size(0), stream);
-    size_t temporary_bytes = std::max(encode_bytes, exclusivesum_bytes);
-
-    // Check that we have enough storage
-    CHECK_GE(workspace->size(0), unique_bytes + counts_bytes +
-             num_runs_bytes + temporary_bytes);
-
-    IndexType* unique_out_ptr = reinterpret_cast<IndexType*>(workspace->dptr_);
-    IndexType* counts_out_ptr = reinterpret_cast<IndexType*>(workspace->dptr_ + unique_bytes);
-    num_runs_ptr = reinterpret_cast<int*>(workspace->dptr_ + unique_bytes +
-                                          counts_bytes);
-    void* temporary_storage = reinterpret_cast<void *>(workspace->dptr_ + unique_bytes +
-                                                       counts_bytes + num_runs_bytes);
-
-    cub::DeviceRunLengthEncode::Encode<IndexType*, IndexType*, IndexType*, int*>
-      (temporary_storage, temporary_bytes, sorted.dptr_, unique_out_ptr, counts_out_ptr,
-       num_runs_ptr, sorted.size(0), stream);
-
-    sum_counts_ptr = unique_out_ptr;
-    cub::DeviceScan::ExclusiveSum<IndexType*, IndexType*>
-      (temporary_storage, temporary_bytes, counts_out_ptr, sum_counts_ptr,
-       sorted.size(0), stream);
-  }

-  const int num_unique_est = min(dst.size(0), src.size(0));
+  const int num_unique_est = min(num_rows, src.size(0));
const int max_nthread = 128;
const int num_y = max(src.size(0)/num_unique_est, 1);
const int block_dim_x = kWarpSize;
@@ -307,6 +265,61 @@ inline void AddTakeGradLargeBatch(mshadow::Tensor<gpu, 2, DType> dst,
MSHADOW_CUDA_POST_KERNEL_CHECK(AddTakeGradLargeBatchKernel);
}


+template<typename IndexType, typename DType>
+inline void AddTakeGradLargeBatch(mshadow::Tensor<gpu, 2, DType> dst,
+                                  const mshadow::Tensor<gpu, 1, IndexType>& sorted,
+                                  const mshadow::Tensor<gpu, 1, IndexType>& index,
+                                  const mshadow::Tensor<gpu, 2, DType> &src,
+                                  mshadow::Tensor<gpu, 1, char>* workspace) {
+  CHECK_EQ(dst.CheckContiguous(), true);
+  CHECK_EQ(sorted.CheckContiguous(), true);
+  CHECK_EQ(index.CheckContiguous(), true);
+  CHECK_EQ(src.CheckContiguous(), true);
+  // const int kWarpBits = kMemUnitBits;
+  cudaStream_t stream = mshadow::Stream<gpu>::GetStream(dst.stream_);
+  IndexType* sum_counts_ptr = NULL;
+  int* num_runs_ptr = NULL;
+  if (dst.size(0)*4 < src.size(0) && workspace != NULL) {
+    // Workspace given and potentially loops at least 4 times, use CUB to create sum_counts
+    CHECK_EQ(workspace->CheckContiguous(), true);
+    // workspace = [unique_out, counts_out, temporary_storage]
+    size_t unique_bytes = sorted.size(0)*sizeof(IndexType);
+    size_t counts_bytes = sorted.size(0)*sizeof(IndexType);
+    size_t num_runs_bytes = 1*sizeof(int);
+
+    size_t encode_bytes = 0;
+    cub::DeviceRunLengthEncode::Encode<IndexType*, IndexType*, IndexType*, int*>
+      (NULL, encode_bytes, NULL, NULL, NULL, NULL, sorted.size(0), stream);
+    size_t exclusivesum_bytes = 0;
+    cub::DeviceScan::ExclusiveSum<IndexType*, IndexType*>
+      (NULL, exclusivesum_bytes, NULL, NULL, sorted.size(0), stream);
+    size_t temporary_bytes = std::max(encode_bytes, exclusivesum_bytes);
+
+    // Check that we have enough storage
+    CHECK_GE(workspace->size(0), unique_bytes + counts_bytes +
+             num_runs_bytes + temporary_bytes);
+
+    IndexType* unique_out_ptr = reinterpret_cast<IndexType*>(workspace->dptr_);
+    IndexType* counts_out_ptr = reinterpret_cast<IndexType*>(workspace->dptr_ + unique_bytes);
+    num_runs_ptr = reinterpret_cast<int*>(workspace->dptr_ + unique_bytes +
+                                          counts_bytes);
+    void* temporary_storage = reinterpret_cast<void *>(workspace->dptr_ + unique_bytes +
+                                                       counts_bytes + num_runs_bytes);
+
+    cub::DeviceRunLengthEncode::Encode<IndexType*, IndexType*, IndexType*, int*>
+      (temporary_storage, temporary_bytes, sorted.dptr_, unique_out_ptr, counts_out_ptr,
+       num_runs_ptr, sorted.size(0), stream);
+
+    sum_counts_ptr = unique_out_ptr;
+    cub::DeviceScan::ExclusiveSum<IndexType*, IndexType*>
+      (temporary_storage, temporary_bytes, counts_out_ptr, sum_counts_ptr,
+       sorted.size(0), stream);
+  }
+  AddTakeGradLargeBatchKernelLaunch(dst, sorted, index, src, sum_counts_ptr,
+                                    num_runs_ptr, dst.size(0));
+}

} // namespace op
} // namespace mxnet
#endif // MXNET_OPERATOR_TENSOR_INDEXING_OP_CUH_
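A note on the CUB idiom above, for reviewers unfamiliar with it: each device-wide CUB primitive is invoked twice — first with a NULL temporary-storage pointer, which makes the call only report the scratch size it needs, then again with real storage to do the work. That is why `Encode` and `ExclusiveSum` each appear twice, and why the shared scratch area is sized with `std::max` (the two calls run back to back and can reuse it). Below is a minimal standalone sketch of the same Encode-then-ExclusiveSum sequence, with a hypothetical input, plain `int` indices, and no error handling — not MXNet code:

```cpp
#include <cub/cub.cuh>
#include <algorithm>
#include <cstdio>

int main() {
  const int n = 6;
  const int h_sorted[n] = {0, 0, 0, 2, 2, 5};  // sorted row indices, as the gradient pass sees them
  int *d_sorted, *d_unique, *d_counts, *d_offsets, *d_num_runs;
  cudaMalloc(&d_sorted, n * sizeof(int));
  cudaMalloc(&d_unique, n * sizeof(int));
  cudaMalloc(&d_counts, n * sizeof(int));
  cudaMalloc(&d_offsets, n * sizeof(int));
  cudaMalloc(&d_num_runs, sizeof(int));
  cudaMemcpy(d_sorted, h_sorted, n * sizeof(int), cudaMemcpyHostToDevice);

  // Pass 1: NULL storage pointer -> each call only writes its required scratch size.
  void* d_temp = NULL;
  size_t encode_bytes = 0, scan_bytes = 0;
  cub::DeviceRunLengthEncode::Encode(d_temp, encode_bytes, d_sorted,
                                     d_unique, d_counts, d_num_runs, n);
  cub::DeviceScan::ExclusiveSum(d_temp, scan_bytes, d_counts, d_offsets, n);
  size_t temp_bytes = std::max(encode_bytes, scan_bytes);  // calls run sequentially, scratch is shared
  cudaMalloc(&d_temp, temp_bytes);

  // Pass 2: real work. Encode collapses runs of equal indices into (unique, count)
  // pairs; the exclusive sum over the counts then gives, for each unique row,
  // the offset of its segment in the sorted index array.
  cub::DeviceRunLengthEncode::Encode(d_temp, temp_bytes, d_sorted,
                                     d_unique, d_counts, d_num_runs, n);
  cub::DeviceScan::ExclusiveSum(d_temp, temp_bytes, d_counts, d_offsets, n);

  int h_offsets[n], h_num_runs;
  cudaMemcpy(h_offsets, d_offsets, n * sizeof(int), cudaMemcpyDeviceToHost);
  cudaMemcpy(&h_num_runs, d_num_runs, sizeof(int), cudaMemcpyDeviceToHost);
  for (int i = 0; i < h_num_runs; ++i)
    printf("segment %d starts at %d\n", i, h_offsets[i]);  // prints offsets 0, 3, 5
  return 0;
}
```

Unlike this sketch, the PR writes the exclusive sum back into the `unique_out` buffer (`sum_counts_ptr = unique_out_ptr`), saving one workspace slot — which works because the unique values themselves appear not to be read afterwards.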
src/operator/tensor/indexing_op.cc — 14 additions & 6 deletions
@@ -70,7 +70,8 @@ void SparseEmbeddingOpForwardRspImpl<cpu>(const OpContext& ctx,


template<>
-inline void SparseEmbeddingOpBackwardRspImpl<cpu>(const OpContext& ctx,
+inline void SparseEmbeddingOpBackwardRspImpl<cpu>(const SparseEmbeddingParam& param,
+                                                  const OpContext& ctx,
const TBlob& ograd,
const TBlob& data,
const OpReqType req,
@@ -178,6 +179,7 @@ GatherNDBackwardImpl(int N, int M, int K,
}

DMLC_REGISTER_PARAMETER(EmbeddingParam);
+DMLC_REGISTER_PARAMETER(SparseEmbeddingParam);
DMLC_REGISTER_PARAMETER(TakeParam);
DMLC_REGISTER_PARAMETER(OneHotParam);
DMLC_REGISTER_PARAMETER(ScatterNDParam);
@@ -230,8 +232,8 @@ Examples::
[](const NodeAttrs& attrs) {
return std::vector<std::string>{"data", "weight"};
})
-.set_attr<nnvm::FInferShape>("FInferShape", EmbeddingOpShape)
-.set_attr<nnvm::FInferType>("FInferType", EmbeddingOpType)
+.set_attr<nnvm::FInferShape>("FInferShape", EmbeddingOpShape<EmbeddingParam>)
+.set_attr<nnvm::FInferType>("FInferType", EmbeddingOpType<EmbeddingParam>)
.set_attr<FResourceRequest>("FResourceRequest",
[](const NodeAttrs& attrs) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
@@ -268,6 +270,11 @@ The storage type of weight must be `row_sparse`, and the gradient of the weight

`SparseEmbedding` is designed for the use case where `input_dim` is very large (e.g. 100k).
The operator is available on both CPU and GPU.
+When `deterministic` is set to `True`, the accumulation of gradients follows a
+deterministic order if a feature appears multiple times in the input. However, the
+accumulation is usually slower when the order is enforced.
+When the operator is used in recurrent neural network models on the GPU,
+the recommended value for `deterministic` is `True`.

Examples::

@@ -294,7 +301,7 @@ Examples::
)code" ADD_FILELINE)
.set_num_inputs(2)
.set_num_outputs(1)
-.set_attr_parser(ParamParser<EmbeddingParam>)
+.set_attr_parser(ParamParser<SparseEmbeddingParam>)
.set_attr<nnvm::FListInputNames>("FListInputNames",
[](const NodeAttrs& attrs) {
return std::vector<std::string>{"data", "weight"};
@@ -303,8 +310,8 @@
[](const NodeAttrs& attrs) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
-.set_attr<nnvm::FInferShape>("FInferShape", EmbeddingOpShape)
-.set_attr<nnvm::FInferType>("FInferType", EmbeddingOpType)
+.set_attr<nnvm::FInferShape>("FInferShape", EmbeddingOpShape<SparseEmbeddingParam>)
+.set_attr<nnvm::FInferType>("FInferType", EmbeddingOpType<SparseEmbeddingParam>)
.set_attr<FInferStorageType>("FInferStorageType", SparseEmbeddingOpForwardStorageType)
.set_attr<FComputeEx>("FComputeEx<cpu>", SparseEmbeddingOpForwardEx<cpu>)
.set_attr<nnvm::FGradient>("FGradient",
@@ -327,6 +334,7 @@ NNVM_REGISTER_OP(_backward_Embedding)
.set_attr<FCompute>("FCompute<cpu>", EmbeddingOpBackward<cpu>);

NNVM_REGISTER_OP(_backward_SparseEmbedding)
+.set_attr_parser(ParamParser<SparseEmbeddingParam>)
.set_num_inputs(2)
.set_num_outputs(2)
.set_attr<FResourceRequest>("FResourceRequest",
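On the new `deterministic` documentation: an accumulation order needs enforcing at all because floating-point addition is not associative, so when duplicate indices cause the same weight row to receive several gradient contributions through parallel atomics, the low-order bits of the sum can differ from run to run. A toy host-side illustration (plain C++, not MXNet code):

```cpp
#include <cstdio>

int main() {
  // The same three addends, accumulated in two different orders.
  float a = 1e8f, b = 1.0f, c = -1e8f;
  float order1 = (a + b) + c;  // 1e8f absorbs the 1.0f (below one ulp): 0.0f
  float order2 = (a + c) + b;  // exact cancellation happens first:      1.0f
  printf("%g vs %g\n", order1, order2);  // prints "0 vs 1"
  return 0;
}
```

Fixing one accumulation order makes the gradient bit-reproducible across runs — presumably why the doc text recommends `deterministic=True` for recurrent models on the GPU — at the cost of the serialization the same text warns about.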