diff --git a/src/operator/tensor/indexing_op-inl.cuh b/src/operator/tensor/indexing_op-inl.cuh
index 4458151f1782..34cc26302548 100644
--- a/src/operator/tensor/indexing_op-inl.cuh
+++ b/src/operator/tensor/indexing_op-inl.cuh
@@ -200,57 +200,15 @@ AddTakeGradLargeBatchWorkspaceSize(size_t num_keys) {
 }
 
 template<typename IndexType, typename DType>
-inline void AddTakeGradLargeBatch(mshadow::Tensor<gpu, 2, DType> dst,
-                                  const mshadow::Tensor<gpu, 1, IndexType>& sorted,
-                                  const mshadow::Tensor<gpu, 1, IndexType>& index,
-                                  const mshadow::Tensor<gpu, 2, DType> &src,
-                                  mshadow::Tensor<gpu, 1, char>* workspace) {
-  CHECK_EQ(dst.CheckContiguous(), true);
-  CHECK_EQ(sorted.CheckContiguous(), true);
-  CHECK_EQ(index.CheckContiguous(), true);
-  CHECK_EQ(src.CheckContiguous(), true);
-  // const int kWarpBits = kMemUnitBits;
+inline void AddTakeGradLargeBatchKernelLaunch(mshadow::Tensor<gpu, 2, DType> dst,
+                                              const mshadow::Tensor<gpu, 1, IndexType>& sorted,
+                                              const mshadow::Tensor<gpu, 1, IndexType>& index,
+                                              const mshadow::Tensor<gpu, 2, DType> &src,
+                                              IndexType* sum_counts_ptr,
+                                              int* num_runs_ptr,
+                                              const mshadow::index_t num_rows) {
   cudaStream_t stream = mshadow::Stream<gpu>::GetStream(dst.stream_);
-  IndexType* sum_counts_ptr = NULL;
-  int* num_runs_ptr = NULL;
-  if (dst.size(0)*4 < src.size(0) && workspace != NULL) {
-    // Workspace given and potentially loops at least 4 times, use CUB to create sum_counts
-    CHECK_EQ(workspace->CheckContiguous(), true);
-    // workspace = [unique_out, counts_out, temporary_storage]
-    size_t unique_bytes = sorted.size(0)*sizeof(IndexType);
-    size_t counts_bytes = sorted.size(0)*sizeof(IndexType);
-    size_t num_runs_bytes = 1*sizeof(int);
-
-    size_t encode_bytes = 0;
-    cub::DeviceRunLengthEncode::Encode<IndexType*, IndexType*, IndexType*, int*>
-      (NULL, encode_bytes, NULL, NULL, NULL, NULL, sorted.size(0), stream);
-    size_t exclusivesum_bytes = 0;
-    cub::DeviceScan::ExclusiveSum<IndexType*, IndexType*>
-      (NULL, exclusivesum_bytes, NULL, NULL, sorted.size(0), stream);
-    size_t temporary_bytes = std::max(encode_bytes, exclusivesum_bytes);
-
-    // Check that we have enough storage
-    CHECK_GE(workspace->size(0), unique_bytes + counts_bytes +
-             num_runs_bytes + temporary_bytes);
-
-    IndexType* unique_out_ptr = reinterpret_cast<IndexType*>(workspace->dptr_);
-    IndexType* counts_out_ptr = reinterpret_cast<IndexType*>(workspace->dptr_ + unique_bytes);
-    num_runs_ptr = reinterpret_cast<int*>(workspace->dptr_ + unique_bytes +
-                                          counts_bytes);
-    void* temporary_storage = reinterpret_cast<void*>(workspace->dptr_ + unique_bytes +
-                                                      counts_bytes + num_runs_bytes);
-
-    cub::DeviceRunLengthEncode::Encode<IndexType*, IndexType*, IndexType*, int*>
-      (temporary_storage, temporary_bytes, sorted.dptr_, unique_out_ptr, counts_out_ptr,
-       num_runs_ptr, sorted.size(0), stream);
-
-    sum_counts_ptr = unique_out_ptr;
-    cub::DeviceScan::ExclusiveSum<IndexType*, IndexType*>
-      (temporary_storage, temporary_bytes, counts_out_ptr, sum_counts_ptr,
-       sorted.size(0), stream);
-  }
-
-  const int num_unique_est = min(dst.size(0), src.size(0));
+  const int num_unique_est = min(num_rows, src.size(0));
   const int max_nthread = 128;
   const int num_y = max(src.size(0)/num_unique_est, 1);
   const int block_dim_x = kWarpSize;
@@ -307,6 +265,61 @@ inline void AddTakeGradLargeBatch(mshadow::Tensor<gpu, 2, DType> dst,
   MSHADOW_CUDA_POST_KERNEL_CHECK(AddTakeGradLargeBatchKernel);
 }
 
+
+template<typename IndexType, typename DType>
+inline void AddTakeGradLargeBatch(mshadow::Tensor<gpu, 2, DType> dst,
+                                  const mshadow::Tensor<gpu, 1, IndexType>& sorted,
+                                  const mshadow::Tensor<gpu, 1, IndexType>& index,
+                                  const mshadow::Tensor<gpu, 2, DType> &src,
+                                  mshadow::Tensor<gpu, 1, char>* workspace) {
+  CHECK_EQ(dst.CheckContiguous(), true);
+  CHECK_EQ(sorted.CheckContiguous(), true);
+  CHECK_EQ(index.CheckContiguous(), true);
+  CHECK_EQ(src.CheckContiguous(), true);
+  // const int kWarpBits = kMemUnitBits;
+  cudaStream_t stream = mshadow::Stream<gpu>::GetStream(dst.stream_);
+  IndexType* sum_counts_ptr = NULL;
+  int* num_runs_ptr = NULL;
+  if (dst.size(0)*4 < src.size(0) && workspace != NULL) {
+    // Workspace given and potentially loops at least 4 times, use CUB to create sum_counts
+    CHECK_EQ(workspace->CheckContiguous(), true);
+    // workspace = [unique_out, counts_out, temporary_storage]
+    size_t unique_bytes = sorted.size(0)*sizeof(IndexType);
+    size_t counts_bytes = sorted.size(0)*sizeof(IndexType);
+    size_t num_runs_bytes = 1*sizeof(int);
+
+    size_t encode_bytes = 0;
+    cub::DeviceRunLengthEncode::Encode<IndexType*, IndexType*, IndexType*, int*>
+      (NULL, encode_bytes, NULL, NULL, NULL, NULL, sorted.size(0), stream);
+    size_t exclusivesum_bytes = 0;
+    cub::DeviceScan::ExclusiveSum<IndexType*, IndexType*>
+      (NULL, exclusivesum_bytes, NULL, NULL, sorted.size(0), stream);
+    size_t temporary_bytes = std::max(encode_bytes, exclusivesum_bytes);
+
+    // Check that we have enough storage
+    CHECK_GE(workspace->size(0), unique_bytes + counts_bytes +
+             num_runs_bytes + temporary_bytes);
+
+    IndexType* unique_out_ptr = reinterpret_cast<IndexType*>(workspace->dptr_);
+    IndexType* counts_out_ptr = reinterpret_cast<IndexType*>(workspace->dptr_ + unique_bytes);
+    num_runs_ptr = reinterpret_cast<int*>(workspace->dptr_ + unique_bytes +
+                                          counts_bytes);
+    void* temporary_storage = reinterpret_cast<void*>(workspace->dptr_ + unique_bytes +
+                                                      counts_bytes + num_runs_bytes);
+
+    cub::DeviceRunLengthEncode::Encode<IndexType*, IndexType*, IndexType*, int*>
+      (temporary_storage, temporary_bytes, sorted.dptr_, unique_out_ptr, counts_out_ptr,
+       num_runs_ptr, sorted.size(0), stream);
+
+    sum_counts_ptr = unique_out_ptr;
+    cub::DeviceScan::ExclusiveSum<IndexType*, IndexType*>
+      (temporary_storage, temporary_bytes, counts_out_ptr, sum_counts_ptr,
+       sorted.size(0), stream);
+  }
+  AddTakeGradLargeBatchKernelLaunch(dst, sorted, index, src, sum_counts_ptr,
+                                    num_runs_ptr, dst.size(0));
+}
+
 }  // namespace op
 }  // namespace mxnet
 #endif  // MXNET_OPERATOR_TENSOR_INDEXING_OP_CUH_
diff --git a/src/operator/tensor/indexing_op.cc b/src/operator/tensor/indexing_op.cc
index cce4537ae3a2..bb65419a79c8 100644
--- a/src/operator/tensor/indexing_op.cc
+++ b/src/operator/tensor/indexing_op.cc
@@ -70,7 +70,8 @@ void SparseEmbeddingOpForwardRspImpl<cpu>(const OpContext& ctx,
 
 
 template<>
-inline void SparseEmbeddingOpBackwardRspImpl<cpu>(const OpContext& ctx,
+inline void SparseEmbeddingOpBackwardRspImpl<cpu>(const SparseEmbeddingParam& param,
+                                                  const OpContext& ctx,
                                                   const TBlob& ograd,
                                                   const TBlob& data,
                                                   const OpReqType req,
@@ -178,6 +179,7 @@ GatherNDBackwardImpl(int N, int M, int K,
 }
 
 DMLC_REGISTER_PARAMETER(EmbeddingParam);
+DMLC_REGISTER_PARAMETER(SparseEmbeddingParam);
 DMLC_REGISTER_PARAMETER(TakeParam);
 DMLC_REGISTER_PARAMETER(OneHotParam);
 DMLC_REGISTER_PARAMETER(ScatterNDParam);
@@ -230,8 +232,8 @@ Examples::
   [](const NodeAttrs& attrs) {
     return std::vector<std::string>{"data", "weight"};
   })
-.set_attr<nnvm::FInferShape>("FInferShape", EmbeddingOpShape)
-.set_attr<nnvm::FInferType>("FInferType", EmbeddingOpType)
+.set_attr<nnvm::FInferShape>("FInferShape", EmbeddingOpShape<EmbeddingParam>)
+.set_attr<nnvm::FInferType>("FInferType", EmbeddingOpType<EmbeddingParam>)
 .set_attr<FResourceRequest>("FResourceRequest",
   [](const NodeAttrs& attrs) {
     return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
@@ -268,6 +270,11 @@ The storage type of weight must be `row_sparse`, and the gradient of the weight
 `SparseEmbedding` is designed for the use case where `input_dim` is very large (e.g. 100k).
 The operator is available on both CPU and GPU.
 
+When `deterministic` is set to `True`, the accumulation of gradients follows a
+deterministic order if a feature appears multiple times in the input. However, the
+accumulation is usually slower when the order is enforced.
+When the operator is used in recurrent neural network models on the GPU,
+the recommended value for `deterministic` is `True`.
 
 Examples::
 
@@ -294,7 +301,7 @@
 )code" ADD_FILELINE)
 .set_num_inputs(2)
 .set_num_outputs(1)
-.set_attr_parser(ParamParser<EmbeddingParam>)
+.set_attr_parser(ParamParser<SparseEmbeddingParam>)
 .set_attr<nnvm::FListInputNames>("FListInputNames",
   [](const NodeAttrs& attrs) {
     return std::vector<std::string>{"data", "weight"};
@@ -303,8 +310,8 @@ Examples::
   [](const NodeAttrs& attrs) {
     return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
   })
-.set_attr<nnvm::FInferShape>("FInferShape", EmbeddingOpShape)
-.set_attr<nnvm::FInferType>("FInferType", EmbeddingOpType)
+.set_attr<nnvm::FInferShape>("FInferShape", EmbeddingOpShape<SparseEmbeddingParam>)
+.set_attr<nnvm::FInferType>("FInferType", EmbeddingOpType<SparseEmbeddingParam>)
 .set_attr<FInferStorageType>("FInferStorageType", SparseEmbeddingOpForwardStorageType)
 .set_attr<FComputeEx>("FComputeEx<cpu>", SparseEmbeddingOpForwardEx<cpu>)
 .set_attr<nnvm::FGradient>("FGradient",
@@ -327,6 +334,7 @@ NNVM_REGISTER_OP(_backward_Embedding)
 .set_attr<FCompute>("FCompute<cpu>", EmbeddingOpBackward<cpu>);
 
 NNVM_REGISTER_OP(_backward_SparseEmbedding)
+.set_attr_parser(ParamParser<SparseEmbeddingParam>)
 .set_num_inputs(2)
 .set_num_outputs(2)
 .set_attr<FResourceRequest>("FResourceRequest",
diff --git a/src/operator/tensor/indexing_op.cu b/src/operator/tensor/indexing_op.cu
index 762d8fd64c2b..5cdf5060aec4 100644
--- a/src/operator/tensor/indexing_op.cu
+++ b/src/operator/tensor/indexing_op.cu
@@ -60,6 +60,75 @@ struct AddTakeGradRspGPUKernel {
   }
 };
 
+/*
+ * \brief kernel for backward computation for take, executed with deterministic order
+ * \param thread_id the thread id
+ * \param out the output gradient data
+ * \param lookup_table the table to lookup the position of an id in gradient array
+ * \param sorted_data the sorted data input
+ * \param original_idx the original indices of the sorted data input
+ * \param ograd head gradient
+ * \param row_length the output dimension
+ * \param num_threads_per_row the number of threads to process a row together
+ * \param SZ the number of features a thread is responsible for
+ */
+template<int SZ>
+struct AddTakeGradRspDeterministicKernel {
+  template<typename DType>
+  __device__ __forceinline__ static void Map(int thread_id,
+                                             DType* out,
+                                             const nnvm::dim_t* lookup_table,
+                                             const nnvm::dim_t* sorted_data,
+                                             const nnvm::dim_t data_size,
+                                             const nnvm::dim_t* original_idx,
+                                             const DType* ograd,
+                                             const nnvm::dim_t row_length,
+                                             const nnvm::dim_t num_threads_per_row) {
+    using nnvm::dim_t;
+    int tid = thread_id / num_threads_per_row;
+    const int feature_start = thread_id % num_threads_per_row * SZ;
+    int num_features = SZ;
+    if (feature_start + num_features > row_length) {
+      num_features = row_length - feature_start;
+    }
+    if (tid == 0 || sorted_data[tid - 1] != sorted_data[tid]) {
+      DType acc[SZ];
+      #pragma unroll
+      for (int i = 0; i < SZ; i++) {
+        acc[i] = 0;
+      }
+      const dim_t data = sorted_data[tid];
+      const dim_t row_id = lookup_table[data];
+      const dim_t out_offset = row_id * row_length + feature_start;
+      do {
+        const dim_t idx = original_idx[tid];
+        const dim_t ograd_offset = idx * row_length + feature_start;
+        for (int i = 0; i < num_features; i++) {
+          acc[i] += ograd[ograd_offset + i];
+        }
+        tid++;
+      } while (tid < data_size && sorted_data[tid - 1] == sorted_data[tid]);
+      for (int i = 0; i < num_features; i++) {
+        out[out_offset + i] += acc[i];
+      }
+    }
+  }
+};
+
+/*
+ * \brief the kernel to generate a lookup table for positions of row ids
+ * \param i thread id
+ * \param out output table
+ * \param data the input row id in sorted order
+ */
+struct mark_lookup_table {
+  template<typename IType, typename DType>
+  MSHADOW_XINLINE static void Map(int i, IType* out, const DType* data) {
+    out[static_cast<nnvm::dim_t>(data[i])] = i;
+  }
+};
+
+
 template<>
 void SparseEmbeddingOpForwardRspImpl<gpu>(const OpContext& ctx,
                                           const TBlob& data,
@@ -103,13 +172,138 @@ void SparseEmbeddingOpForwardRspImpl<gpu>(const OpContext& ctx,
   }
 }
 
+template<typename IType, typename DType, typename RType>
+void SparseEmbeddingDeterministicKernelLaunch(const OpContext& ctx,
+                                              const TBlob& ograd,
+                                              const TBlob& data,
+                                              const OpReqType req,
+                                              const NDArray& output) {
+  using namespace mshadow;
+  using namespace mxnet_op;
+  using namespace expr;
+  using namespace rowsparse;
+  using nnvm::dim_t;
+  mshadow::Stream<gpu> *s = ctx.get_stream<gpu>();
+  const dim_t num_rows = output.shape()[0];
+  const dim_t row_length = output.shape()[1];
+  const dim_t data_size = static_cast<dim_t>(data.shape_.Size());
+  // temp resource declarations
+  dim_t* lookup_table = NULL;
+  void* temp_storage = NULL;
+  dim_t* sorted_data = NULL;
+  dim_t* original_idx = NULL;
+  // calculate number of bytes for temp resources
+  size_t lookup_table_bytes = num_rows * sizeof(dim_t);
+  size_t sorted_data_storage_bytes = data_size * sizeof(dim_t);
+  size_t original_idx_storage_bytes = data_size * sizeof(dim_t);
+  size_t sort_workspace_size = SortByKeyWorkspaceSize<dim_t, dim_t, gpu>(data_size);
+  size_t unique_workspace_bytes = 0;
+  // estimate unique temp space
+  IType* data_ptr = data.dptr<IType>();
+  size_t *null_ptr = nullptr;
+  cub::DeviceSelect::Unique(NULL, unique_workspace_bytes, data_ptr, data_ptr,
+                            null_ptr, data_size, Stream<gpu>::GetStream(s));
+  // One more space reserved for unique count
+  size_t temp_workspace_bytes = std::max(unique_workspace_bytes,
+                                         sort_workspace_size);
+  size_t total_storage_bytes = lookup_table_bytes + sorted_data_storage_bytes +
+                               original_idx_storage_bytes + temp_workspace_bytes;
+
+  // request resource and split it. layout is:
+  // lookup_table, sorted_data, original_idx, temp_storage
+  Tensor<gpu, 1, char> workspace = ctx.requested[0]
+      .get_space_typed<gpu, 1, char>(Shape1(total_storage_bytes), s);
+  lookup_table = reinterpret_cast<dim_t*>(workspace.dptr_);
+  sorted_data = reinterpret_cast<dim_t*>(workspace.dptr_ + lookup_table_bytes);
+  original_idx = reinterpret_cast<dim_t*>(workspace.dptr_ + lookup_table_bytes +
+                                          sorted_data_storage_bytes);
+  temp_storage = workspace.dptr_ + total_storage_bytes - temp_workspace_bytes;
+
+  // make a copy of the data, to be sorted
+  TBlob sorted_data_blob(sorted_data, Shape1(data_size), gpu::kDevMask);
+  auto sorted_data_tensor = sorted_data_blob.FlatTo1D<gpu, dim_t>(s);
+  mxnet_op::copy(s, sorted_data_blob, data);
+
+  // generate original idx
+  Tensor<gpu, 1, dim_t> original_idx_tensor(original_idx, Shape1(data_size), s);
+  Kernel<range_fwd, gpu>::Launch(s, data_size, 1, static_cast<dim_t>(0),
+                                 static_cast<dim_t>(1), kWriteTo, original_idx);
+  // sort data with its original idx
+  int num_bits = ilog2(num_rows - 1);
+  char* temp_storage_ptr = reinterpret_cast<char*>(temp_storage);
+  Tensor<gpu, 1, char> temp_storage_tensor(temp_storage_ptr,
+                                           Shape1(sort_workspace_size), s);
+  SortByKey(sorted_data_tensor, original_idx_tensor, true,
+            &temp_storage_tensor, 0, num_bits);
+
+  // compute unique row ids based on sorted values.
+  output.CheckAndAllocAuxData(kIdx, Shape1(data_size + 1));
+
+  // fill row_idx array of output matrix, using the row_flg values
+  RType* grad_row_idx = output.aux_data(kIdx).dptr<RType>();
+  cub::DeviceSelect::Unique(temp_storage_ptr, unique_workspace_bytes, sorted_data,
+                            grad_row_idx, grad_row_idx + data_size, data_size,
+                            Stream<gpu>::GetStream(s));
+
+  dim_t nnr = 0;
+  CUDA_CALL(cudaMemcpy(&nnr, grad_row_idx + data_size, sizeof(RType),
+      cudaMemcpyDeviceToHost));
+  CHECK_EQ(output.shape().ndim(), 2) << "Unexpected ndim";
+  output.CheckAndAllocData(Shape2(nnr, output.shape()[1]));
+  output.set_aux_shape(kIdx, Shape1(nnr));
+
+  // generate lookup table
+  Kernel<mark_lookup_table, gpu>::Launch(s, nnr, lookup_table, grad_row_idx);
+
+  // accumulate gradients
+  DType* grad_data = output.data().dptr<DType>();
+  Fill<false>(s, TBlob(grad_data, Shape1(nnr * row_length), gpu::kDevMask),
+              kWriteTo, 0);
+  const int SZ = 4;
+  const nnvm::dim_t num_threads_per_row = (row_length + SZ - 1) / SZ;
+  Kernel<AddTakeGradRspDeterministicKernel<SZ>, gpu>::Launch(s, data_size * num_threads_per_row,
+      grad_data, lookup_table, sorted_data, data_size, original_idx,
+      ograd.dptr<DType>(), row_length, num_threads_per_row);
+}
+
+inline void SparseEmbeddingOpBackwardDeterministicRspImpl(const OpContext& ctx,
+                                                          const TBlob& ograd,
+                                                          const TBlob& data,
+                                                          const OpReqType req,
+                                                          const NDArray& output) {
+  using nnvm::dim_t;
+  if (req == kNullOp) return;
+  CHECK_EQ(req, kWriteTo) << "SparseEmbedding layer doesn't support "
+                          << "weight gradient calculation with req != write";
+
+  mshadow::Stream<gpu> *s = ctx.get_stream<gpu>();
+  const dim_t data_size = static_cast<dim_t>(data.shape_.Size());
+  if (data_size == 0) {
+    FillZerosRspImpl(s, output);
+    return;
+  }
+
+  MSHADOW_TYPE_SWITCH(data.type_flag_, IType, {
+    MSHADOW_TYPE_SWITCH(ograd.type_flag_, DType, {
+      MSHADOW_IDX_TYPE_SWITCH(output.aux_type(rowsparse::kIdx), RType, {
+        SparseEmbeddingDeterministicKernelLaunch<IType, DType, RType>(ctx, ograd, data,
+                                                                      req, output);
+      });
+    });
+  });
+}
+
 template<>
-inline void SparseEmbeddingOpBackwardRspImpl<gpu>(const OpContext& ctx,
+inline void SparseEmbeddingOpBackwardRspImpl<gpu>(const SparseEmbeddingParam& param,
+                                                  const OpContext& ctx,
                                                   const TBlob& ograd,
                                                   const TBlob& data,
                                                   const OpReqType req,
                                                   const NDArray& output) {
+  if (param.deterministic) {
+    SparseEmbeddingOpBackwardDeterministicRspImpl(ctx, ograd, data, req, output);
+    return;
+  }
   using namespace mshadow;
   using namespace mxnet_op;
   using namespace mshadow::expr;
@@ -156,7 +350,6 @@ inline void SparseEmbeddingOpBackwardRspImpl<gpu>(const OpContext& ctx,
   dim_t nnr = 0;
   CUDA_CALL(cudaMemcpy(&nnr, &prefix_sum[num_rows-1], sizeof(dim_t),
       cudaMemcpyDeviceToHost));
-
   if (nnr == 0) {
     FillZerosRspImpl(s, output);
     return;
diff --git a/src/operator/tensor/indexing_op.h b/src/operator/tensor/indexing_op.h
index 1888a4179729..45bf45f14fcd 100644
--- a/src/operator/tensor/indexing_op.h
+++ b/src/operator/tensor/indexing_op.h
@@ -57,6 +57,29 @@ enum EmbeddingOpResource {kTempSpace};
 }  // namespace embedding
 
+struct SparseEmbeddingParam: public dmlc::Parameter<SparseEmbeddingParam> {
+  int input_dim;
+  int output_dim;
+  int dtype;
+  bool deterministic;
+  DMLC_DECLARE_PARAMETER(SparseEmbeddingParam) {
+    DMLC_DECLARE_FIELD(input_dim).set_lower_bound(1)
+    .describe("Vocabulary size of the input indices.");
+    DMLC_DECLARE_FIELD(output_dim).set_lower_bound(1)
+    .describe("Dimension of the embedding vectors.");
+    DMLC_DECLARE_FIELD(dtype).set_default(mshadow::kFloat32)
+    .add_enum("float32", mshadow::kFloat32)
+    .add_enum("float64", mshadow::kFloat64)
+    .add_enum("float16", mshadow::kFloat16)
+    .add_enum("uint8", mshadow::kUint8)
+    .add_enum("int32", mshadow::kInt32)
+    .describe("Data type of weight.");
+    DMLC_DECLARE_FIELD(deterministic).set_default(false)
+    .describe("Force the backward gradient calculation to be executed based on a deterministic"
+              " order at the cost of slower speed.");
+  }
+};
+
 struct EmbeddingParam: public dmlc::Parameter<EmbeddingParam> {
   int input_dim;
   int output_dim;
@@ -130,14 +153,14 @@ inline void AddTakeGradLargeBatch(mshadow::Tensor<gpu, 2, DType> dst,
                                   const mshadow::Tensor<gpu, 1, IndexType>& index,
                                   const mshadow::Tensor<gpu, 2, DType> &src,
                                   mshadow::Tensor<gpu, 1, char>* workspace = NULL);
-
+template<typename ParamType>
 inline bool EmbeddingOpShape(const nnvm::NodeAttrs& attrs,
                              std::vector<TShape> *in_attrs,
                              std::vector<TShape> *out_attrs) {
   using namespace mshadow;
   const TShape &dshape = (*in_attrs)[embedding::kData];
   if (dshape.ndim() ==  0) return false;
-  const EmbeddingParam& param = nnvm::get<EmbeddingParam>(attrs.parsed);
+  const ParamType& param = nnvm::get<ParamType>(attrs.parsed);
   SHAPE_ASSIGN_CHECK(*in_attrs, embedding::kWeight, Shape2(param.input_dim,
                                                            param.output_dim));
   out_attrs->clear();
@@ -152,10 +175,11 @@ inline bool EmbeddingOpShape(const nnvm::NodeAttrs& attrs,
   return true;
 }
 
+template<typename ParamType>
 inline bool EmbeddingOpType(const nnvm::NodeAttrs& attrs,
                             std::vector<int> *in_type,
                             std::vector<int> *out_type) {
-  const EmbeddingParam& param = nnvm::get<EmbeddingParam>(attrs.parsed);
+  const ParamType& param = nnvm::get<ParamType>(attrs.parsed);
   CHECK_EQ(in_type->size(), 2U);
   CHECK_GE(out_type->size(), 1U);
   int itype = (*in_type)[0];
@@ -219,6 +243,11 @@ inline bool SparseEmbeddingOpBackwardStorageType(const nnvm::NodeAttrs& attrs,
       dispatched = true;
     }
   }
+  const SparseEmbeddingParam& param = nnvm::get<SparseEmbeddingParam>(attrs.parsed);
+  if (param.deterministic) {
+    common::LogOnce("_SparseEmbedding_backward with deterministic=True may reduce "
+                    "speed significantly");
+  }
   return dispatched;
 }
 /*! \brief name the struct Take instead of take
@@ -560,7 +589,8 @@ struct AddTakeGradRspKernel {
 };
 
 template<typename xpu>
-inline void SparseEmbeddingOpBackwardRspImpl(const OpContext& ctx,
+inline void SparseEmbeddingOpBackwardRspImpl(const SparseEmbeddingParam& param,
+                                             const OpContext& ctx,
                                              const TBlob& ograd,
                                              const TBlob& data,
                                              const OpReqType req,
@@ -582,9 +612,10 @@ void SparseEmbeddingOpBackwardEx(const nnvm::NodeAttrs& attrs,
   // check req
   CHECK_EQ(req[embedding::kData], kNullOp)
           << "SparseEmbedding layer doesn't support calculate data gradient";
+  const SparseEmbeddingParam& param = nnvm::get<SparseEmbeddingParam>(attrs.parsed);
   if (data.storage_type() == kDefaultStorage && ograd.storage_type() == kDefaultStorage &&
       weight_grad.storage_type() == kRowSparseStorage) {
-    SparseEmbeddingOpBackwardRspImpl<xpu>(ctx, ograd.data(), data.data(),
+    SparseEmbeddingOpBackwardRspImpl<xpu>(param, ctx, ograd.data(), data.data(),
                                           req[embedding::kWeight], weight_grad);
   } else {
     LogUnimplementedOp(attrs, ctx, inputs, req, outputs);
diff --git a/tests/python/unittest/test_sparse_operator.py b/tests/python/unittest/test_sparse_operator.py
index 84dfc5878c20..e0d25da03449 100644
--- a/tests/python/unittest/test_sparse_operator.py
+++ b/tests/python/unittest/test_sparse_operator.py
@@ -1615,43 +1615,45 @@ def check_sparse_elementwise_sum_with_shape(stype, shape, n):
 
 
 def test_sparse_embedding():
-    ''' test sparse embedding op on cpu '''
-    def check_sparse_embedding(executor, weight_ref, data_onehot, grad, density):
-        # update weight based on density
-        weight[:] = rand_ndarray(weight.shape, 'row_sparse', density=density)
-        # check forward
-        executor.forward(is_train=True)
-        assert_almost_equal(executor.outputs[0].asnumpy(), np.dot(data_onehot, weight.asnumpy()))
-        # check backward
-        executor.backward([grad])
-        assert_almost_equal(grad_map["embed_weight"].asnumpy(), np.dot(data_onehot.T, grad.asnumpy()))
+    ''' test sparse embedding operator '''
+    def check_sparse_embedding(in_dim, out_dim, batch, densities, deterministic):
+        # init executor
+        data = mx.sym.Variable("data")
+        weight = mx.sym.Variable("embed_weight", stype='row_sparse')
+        embed = mx.sym.contrib.SparseEmbedding(data=data, weight=weight, input_dim=in_dim,
+                                               output_dim=out_dim, deterministic=deterministic,
+                                               name="embed")
+        grad_req = {'data': 'null', 'embed_weight': 'write'}
+        exe_test = embed.simple_bind(default_context(), grad_req=grad_req, data=(batch,))
+        arg_map = dict(zip(embed.list_arguments(), exe_test.arg_arrays))
+        grad_map = dict(zip(embed.list_arguments(), exe_test.grad_arrays))
+        # init data
+        np_data = np.random.randint(low=0, high=in_dim, size=batch)
+        np_onehot = np.zeros((batch, in_dim)).astype(np.float32)
+        np_onehot[np.arange(batch), np_data] = 1.0
+        arg_map["data"][:] = np_data
+        # init grad
+        np_grad = np.random.uniform(-1, 1, exe_test.outputs[0].shape)
+        grad = mx.nd.zeros(np_grad.shape)
+        grad[:] = np_grad
+        # weight
+        weight = arg_map["embed_weight"]
+        for density in densities:
+            # update weight based on density
+            weight[:] = rand_ndarray(weight.shape, 'row_sparse', density=density)
+            # check forward
+            exe_test.forward(is_train=True)
+            assert_almost_equal(exe_test.outputs[0].asnumpy(), np.dot(np_onehot, weight.asnumpy()))
+            # check backward
+            exe_test.backward([grad])
+            assert_almost_equal(grad_map["embed_weight"].asnumpy(), np.dot(np_onehot.T, grad.asnumpy()))
 
     densities = [0, 0.5, 1]
     in_dim = 50
     out_dim = 3
     batch = 8
-    # init executor
-    data = mx.sym.Variable("data")
-    weight = mx.sym.Variable("embed_weight", stype='row_sparse')
-    embed = mx.sym.contrib.SparseEmbedding(data=data, weight=weight, input_dim=in_dim,
-                                           output_dim=out_dim, name="embed")
-    grad_req = {'data': 'null', 'embed_weight': 'write'}
-    exe_test = embed.simple_bind(default_context(), grad_req=grad_req, data=(batch,))
-    arg_map = dict(zip(embed.list_arguments(), exe_test.arg_arrays))
-    grad_map = dict(zip(embed.list_arguments(), exe_test.grad_arrays))
-    # init data
-    np_data = np.random.randint(low=0, high=in_dim, size=batch)
-    np_onehot = np.zeros((batch, in_dim))
-    np_onehot[np.arange(batch), np_data] = 1.0
-    arg_map["data"][:] = np_data
-    # init grad
-    np_grad = np.random.uniform(-1, 1, exe_test.outputs[0].shape)
-    grad = mx.nd.sparse.zeros('row_sparse', np_grad.shape)
-    grad[:] = np_grad
-    # weight
-    weight = arg_map["embed_weight"]
-    for density in densities:
-        check_sparse_embedding(exe_test, weight, np_onehot, grad, density)
+    check_sparse_embedding(in_dim, out_dim, batch, densities, True)
+    check_sparse_embedding(in_dim, out_dim, batch, densities, False)
 
 
 def test_scatter_ops():
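
Usage note (not part of the patch): a minimal sketch of how the new `deterministic` flag is passed
through the contrib symbol API, mirroring the updated unit test above. The sizes, context, and weight
initialization here are illustrative assumptions; on a GPU context the flag switches the weight
gradient to the deterministic accumulation path added in this change::

    import mxnet as mx
    import numpy as np

    # Illustrative sizes only; not taken from the patch.
    in_dim, out_dim, batch = 50, 3, 8

    data = mx.sym.Variable("data")
    weight = mx.sym.Variable("embed_weight", stype='row_sparse')
    # deterministic=True requests a fixed gradient accumulation order in the
    # backward pass; the operator docs above recommend it for RNN models on GPU.
    embed = mx.sym.contrib.SparseEmbedding(data=data, weight=weight,
                                           input_dim=in_dim, output_dim=out_dim,
                                           deterministic=True, name="embed")

    # mx.cpu() is used only so the sketch runs anywhere; the flag matters most on GPU.
    exe = embed.simple_bind(mx.cpu(), grad_req={'data': 'null', 'embed_weight': 'write'},
                            data=(batch,))
    exe.arg_dict['data'][:] = np.random.randint(0, in_dim, size=batch)
    exe.arg_dict['embed_weight'][:] = mx.nd.ones((in_dim, out_dim)).tostype('row_sparse')

    exe.forward(is_train=True)
    exe.backward([mx.nd.ones(exe.outputs[0].shape)])
    print(exe.grad_dict['embed_weight'])  # row_sparse gradient of the weight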