From 9e80d58dbfdb94e81922434c819ca497fe01db93 Mon Sep 17 00:00:00 2001
From: Haibin Lin
Date: Wed, 21 Feb 2018 17:37:14 +0800
Subject: [PATCH 1/2] Fix embedding (#194)

* refactor embed backward kernel caller
* pass unit test
* refactor
* fix dim bug
* add unique impl
* remove old op
* remove unused kernel
---
 src/operator/tensor/indexing_op-inl.cuh       | 136 ++++++++++--------
 src/operator/tensor/indexing_op.cu            | 109 ++++++++------
 tests/python/unittest/test_sparse_operator.py |   4 +-
 3 files changed, 144 insertions(+), 105 deletions(-)

diff --git a/src/operator/tensor/indexing_op-inl.cuh b/src/operator/tensor/indexing_op-inl.cuh
index 4458151f1782..4df1fd451ec5 100644
--- a/src/operator/tensor/indexing_op-inl.cuh
+++ b/src/operator/tensor/indexing_op-inl.cuh
@@ -38,7 +38,7 @@ namespace mxnet {
 namespace op {
 const int kWarpSize = 32;
 
-template<int SZ, typename DType, typename IdxType>
+template<int SZ, bool lookup, typename DType, typename IdxType>
 __global__ void AddTakeGradLargeBatchKernel(DType* dst,
                                             // If idx_start == NULL, then in-kernel edge
                                             // detection is used
@@ -47,7 +47,9 @@ __global__ void AddTakeGradLargeBatchKernel(DType* dst,
                                             const int* idx_start_size_ptr,
                                             const IdxType *sorted,
                                             const IdxType *index,
                                             const DType *src,
-                                            int ymax, int xmax) {
+                                            int ymax, int xmax,
+                                            // table to look up positions of row_ids in dst
+                                            const nnvm::dim_t *lookup_table) {
   // Size of the shared memory is [blockDim.x*SZ*blockDim.y]*sizeof(DType)
   extern __shared__ char sh_grad_weight_char[];
   DType* sh_grad_weight = (DType*)sh_grad_weight_char;
@@ -125,7 +127,8 @@ __global__ void AddTakeGradLargeBatchKernel(DType* dst,
   }
   const int start_feature = threadIdx.x + blockIdx.x * blockDim.x * SZ;
-  const int dst_row = sorted_value * xmax;
+  // Lookup inclusive prefix sum table if necessary
+  const int dst_row = (lookup ? (lookup_table[sorted_value] - 1) : sorted_value) * xmax;
 
   int num_idx = idx_end - idx_begin;
   int idx0 = idx_begin + threadIdx.y*num_idx/blockDim.y;
@@ -179,7 +182,6 @@ __global__ void AddTakeGradLargeBatchKernel(DType* dst,
       }
     }
   }
-  }
 }
 
@@ -199,6 +201,73 @@ AddTakeGradLargeBatchWorkspaceSize(size_t num_keys) {
   return (unique_bytes + counts_bytes + num_runs_bytes + temporary_bytes);
 }
 
+template<bool lookup, typename IndexType, typename DType>
+inline void AddTakeGradLargeBatchKernelLaunch(mshadow::Tensor<gpu, 2, DType> dst,
+                                              const mshadow::Tensor<gpu, 1, IndexType>& sorted,
+                                              const mshadow::Tensor<gpu, 1, IndexType>& index,
+                                              const mshadow::Tensor<gpu, 2, DType> &src,
+                                              IndexType* sum_counts_ptr,
+                                              int* num_runs_ptr,
+                                              const nnvm::dim_t* lookup_table) {
+  cudaStream_t stream = mshadow::Stream<gpu>::GetStream(dst.stream_);
+  const int num_unique_est = min(dst.size(0), src.size(0));
+  const int max_nthread = 128;
+  const int num_y = max(src.size(0)/num_unique_est, 1);
+  const int block_dim_x = kWarpSize;
+  const int block_dim_y = min(num_y, max_nthread/block_dim_x);
+  const int SZ = min((src.size(1) + block_dim_x - 1) / block_dim_x, 4);
+  const int grid_dim_x = (src.size(1) + block_dim_x * SZ - 1) / (block_dim_x * SZ);
+  const int grid_dim_y = min(num_unique_est, mshadow::cuda::kBaseGridNum);
+  dim3 dimBlock(block_dim_x, block_dim_y);
+  dim3 dimGrid(grid_dim_x, grid_dim_y);
+  // Maximum shared memory usage: 128*4*sizeof(DType), which is 4K for 64bit DType elements
+  int shmem_size = dimBlock.x*SZ*dimBlock.y*sizeof(DType);
+
+  CHECK_EQ(dst.size(1), src.size(1)) << "AddTakeGradLargeBatch: shape mismatch";
+  CHECK_EQ(index.size(0), src.size(0)) << "AddTakeGradLargeBatch: shape mismatch";
+  mshadow::cuda::CheckLaunchParam(dimGrid, dimBlock, "AddTakeGradLargeBatch");
+
+  switch (SZ) {
+    case 1:
+      AddTakeGradLargeBatchKernel<1, lookup, DType>
+          <<<dimGrid, dimBlock, shmem_size, stream>>>
+          (dst.dptr_, sum_counts_ptr, num_runs_ptr,
+           sorted.dptr_, index.dptr_, src.dptr_,
+           static_cast<int>(src.size(0)),
+           static_cast<int>(src.size(1)), lookup_table);
+      break;
+    case 2:
+      AddTakeGradLargeBatchKernel<2, lookup, DType>
+          <<<dimGrid, dimBlock, shmem_size, stream>>>
+          (dst.dptr_, sum_counts_ptr, num_runs_ptr,
+           sorted.dptr_, index.dptr_, src.dptr_,
+           static_cast<int>(src.size(0)),
+           static_cast<int>(src.size(1)), lookup_table);
+      break;
+    case 3:
+      AddTakeGradLargeBatchKernel<3, lookup, DType>
+          <<<dimGrid, dimBlock, shmem_size, stream>>>
+          (dst.dptr_, sum_counts_ptr, num_runs_ptr,
+           sorted.dptr_, index.dptr_, src.dptr_,
+           static_cast<int>(src.size(0)),
+           static_cast<int>(src.size(1)), lookup_table);
+      break;
+    case 4:
+      AddTakeGradLargeBatchKernel<4, lookup, DType>
+          <<<dimGrid, dimBlock, shmem_size, stream>>>
+          (dst.dptr_, sum_counts_ptr, num_runs_ptr,
+           sorted.dptr_, index.dptr_, src.dptr_,
+           static_cast<int>(src.size(0)),
+           static_cast<int>(src.size(1)), lookup_table);
+      break;
+    default:
+      LOG(FATAL) << "AddTakeGradLargeBatch, incorrect value SZ " << SZ;
+      break;
+  }
+  MSHADOW_CUDA_POST_KERNEL_CHECK(AddTakeGradLargeBatchKernel);
+}
+
+
 template<typename IndexType, typename DType>
 inline void AddTakeGradLargeBatch(mshadow::Tensor<gpu, 2, DType> dst,
                                   const mshadow::Tensor<gpu, 1, IndexType>& sorted,
                                   const mshadow::Tensor<gpu, 1, IndexType>& index,
                                   const mshadow::Tensor<gpu, 2, DType> &src,
@@ -249,62 +318,9 @@ inline void AddTakeGradLargeBatch(mshadow::Tensor<gpu, 2, DType> dst,
     (temporary_storage, temporary_bytes, counts_out_ptr, sum_counts_ptr,
      sorted.size(0), stream);
   }
-
-  const int num_unique_est = min(dst.size(0), src.size(0));
-  const int max_nthread = 128;
-  const int num_y = max(src.size(0)/num_unique_est, 1);
-  const int block_dim_x = kWarpSize;
-  const int block_dim_y = min(num_y, max_nthread/block_dim_x);
-  const int SZ = min((src.size(1) + block_dim_x - 1) / block_dim_x, 4);
-  const int grid_dim_x = (src.size(1) + block_dim_x * SZ - 1) / (block_dim_x * SZ);
-  const int grid_dim_y = min(num_unique_est, mshadow::cuda::kBaseGridNum);
-  dim3 dimBlock(block_dim_x, block_dim_y);
-  dim3 dimGrid(grid_dim_x, grid_dim_y);
-  // Maximum shared memory usage: 128*4*sizeof(DType), which is 4K for 64bit DType elements
-  int shmem_size = dimBlock.x*SZ*dimBlock.y*sizeof(DType);
-
-  CHECK_EQ(dst.size(1), src.size(1)) << "AddTakeGradLargeBatch: shape mismatch";
-  CHECK_EQ(index.size(0), src.size(0)) << "AddTakeGradLargeBatch: shape mismatch";
-  mshadow::cuda::CheckLaunchParam(dimGrid, dimBlock, "AddTakeGradLargeBatch");
-
-  switch (SZ) {
-    case 1:
-      AddTakeGradLargeBatchKernel<1, DType>
-          <<<dimGrid, dimBlock, shmem_size, stream>>>
-          (dst.dptr_, sum_counts_ptr, num_runs_ptr,
-           sorted.dptr_, index.dptr_, src.dptr_,
-           static_cast<int>(src.size(0)),
-           static_cast<int>(src.size(1)));
-      break;
-    case 2:
-      AddTakeGradLargeBatchKernel<2, DType>
-          <<<dimGrid, dimBlock, shmem_size, stream>>>
-          (dst.dptr_, sum_counts_ptr, num_runs_ptr,
-           sorted.dptr_, index.dptr_, src.dptr_,
-           static_cast<int>(src.size(0)),
-           static_cast<int>(src.size(1)));
-      break;
-    case 3:
-      AddTakeGradLargeBatchKernel<3, DType>
-          <<<dimGrid, dimBlock, shmem_size, stream>>>
-          (dst.dptr_, sum_counts_ptr, num_runs_ptr,
-           sorted.dptr_, index.dptr_, src.dptr_,
-           static_cast<int>(src.size(0)),
-           static_cast<int>(src.size(1)));
-      break;
-    case 4:
-      AddTakeGradLargeBatchKernel<4, DType>
-          <<<dimGrid, dimBlock, shmem_size, stream>>>
-          (dst.dptr_, sum_counts_ptr, num_runs_ptr,
-           sorted.dptr_, index.dptr_, src.dptr_,
-           static_cast<int>(src.size(0)),
-           static_cast<int>(src.size(1)));
-      break;
-    default:
-      LOG(FATAL) << "AddTakeGradLargeBatch, incorrect value SZ " << SZ;
-      break;
-  }
-  MSHADOW_CUDA_POST_KERNEL_CHECK(AddTakeGradLargeBatchKernel);
+  nnvm::dim_t* lookup_table = nullptr;
+  AddTakeGradLargeBatchKernelLaunch<false>(dst, sorted, index, src, sum_counts_ptr,
+                                           num_runs_ptr, lookup_table);
 }
 
 }  // namespace op
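The new lookup template parameter is what lets the dense batch-gradient kernel write into
a compacted row_sparse output: lookup_table holds an inclusive prefix sum over per-row
flags, so lookup_table[row] - 1 is the position of row among the non-zero rows. A minimal
host-side sketch of that arithmetic (row ids and sizes are made up for illustration, and
std::partial_sum stands in for the cub::DeviceScan::InclusiveSum device call):

    #include <cstdint>
    #include <numeric>
    #include <vector>

    int main() {
      const int64_t num_rows = 5;
      const std::vector<int64_t> row_ids = {1, 3, 1, 4};  // hypothetical input indices

      // Flag the rows that receive a gradient, then scan the flags in place.
      std::vector<int64_t> lookup_table(num_rows, 0);
      for (int64_t row : row_ids) lookup_table[row] = 1;
      std::partial_sum(lookup_table.begin(), lookup_table.end(), lookup_table.begin());
      // flags {0,1,0,1,1} scan to {0,1,1,2,3}

      const int64_t nnr = lookup_table[num_rows - 1];  // 3 non-zero output rows
      for (int64_t row : row_ids) {
        const int64_t dst_row = lookup_table[row] - 1;  // compacted destination row
        (void)dst_row;  // the kernel accumulates into dst[dst_row * xmax + feature]
      }
      return nnr == 3 ? 0 : 1;
    }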
diff --git a/src/operator/tensor/indexing_op.cu b/src/operator/tensor/indexing_op.cu
index 762d8fd64c2b..87633efaff84 100644
--- a/src/operator/tensor/indexing_op.cu
+++ b/src/operator/tensor/indexing_op.cu
@@ -41,25 +41,6 @@ struct is_valid_check {
   }
 };
 
-
-struct AddTakeGradRspGPUKernel {
-  template<typename DType, typename IType>
-  __device__ __forceinline__ static void Map(int tid,
-                                             DType* out,
-                                             const nnvm::dim_t* prefix_sum,
-                                             const IType* data,
-                                             const DType* ograd,
-                                             const nnvm::dim_t row_length) {
-    using nnvm::dim_t;
-    const dim_t data_i = tid / row_length;
-    const dim_t grad_i = tid % row_length;
-    const dim_t irow = static_cast<dim_t>(data[data_i]);
-    const dim_t rsp_row = prefix_sum[irow] - 1;
-    const DType val = ograd[data_i * row_length + grad_i];
-    atomicAdd(static_cast<DType *>(&(out[rsp_row*row_length+grad_i])), val);
-  }
-};
-
 template<>
 void SparseEmbeddingOpForwardRspImpl<gpu>(const OpContext& ctx,
                                           const TBlob& data,
@@ -103,7 +84,6 @@ void SparseEmbeddingOpForwardRspImpl<gpu>(const OpContext& ctx,
   }
 }
 
-
 template<>
 inline void SparseEmbeddingOpBackwardRspImpl<gpu>(const OpContext& ctx,
                                                   const TBlob& ograd,
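For reference, the kernel removed above performed the same accumulation with one
atomicAdd per (input element, feature) pair. A sequential restatement of its logic in
plain C++ (names mirror the deleted code; this is an illustration, not MXNet source):

    #include <cstdint>
    #include <vector>

    // out has nnr * row_length entries; ograd has data.size() * row_length.
    void AddTakeGradRspReference(std::vector<float>* out,
                                 const std::vector<int64_t>& prefix_sum,
                                 const std::vector<int64_t>& data,
                                 const std::vector<float>& ograd,
                                 int64_t row_length) {
      const int64_t num_threads = static_cast<int64_t>(data.size()) * row_length;
      for (int64_t tid = 0; tid < num_threads; ++tid) {
        const int64_t data_i = tid / row_length;               // which input element
        const int64_t grad_i = tid % row_length;               // which feature column
        const int64_t rsp_row = prefix_sum[data[data_i]] - 1;  // compacted output row
        (*out)[rsp_row * row_length + grad_i] += ograd[data_i * row_length + grad_i];
      }
    }

On the GPU each loop iteration was a thread, so duplicate row ids issued concurrent
atomicAdd calls to the same addresses; the rewrite below instead sorts duplicates into
contiguous runs and reduces each run once.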
@@ -125,55 +105,98 @@ inline void SparseEmbeddingOpBackwardRspImpl<gpu>(const OpContext& ctx,
   dim_t row_length = output.shape()[1];
   dim_t data_size = static_cast<dim_t>(data.shape_.Size());
   dim_t num_threads;
-
+  if (data_size == 0) {
+    FillZerosRspImpl(s, output);
+    return;
+  }
   MSHADOW_TYPE_SWITCH(data.type_flag_, IType, {
-    MSHADOW_SGL_DBL_TYPE_SWITCH(ograd.type_flag_, DType, {
+    MSHADOW_TYPE_SWITCH(ograd.type_flag_, DType, {
       MSHADOW_IDX_TYPE_SWITCH(output.aux_type(kIdx), RType, {
         dim_t* prefix_sum = NULL;
-        void* d_temp_storage = NULL;
-        size_t temp_storage_bytes = 0;
-        cub::DeviceScan::InclusiveSum(d_temp_storage,
-                                      temp_storage_bytes,
+        void* temp_storage = NULL;
+        dim_t* sorted_data = NULL;
+        dim_t* original_idx = NULL;
+        // calculate resource bytes
+        size_t row_flg_storage_bytes = num_rows * sizeof(dim_t);
+        size_t sorted_data_storage_bytes = data_size * sizeof(dim_t);
+        size_t original_idx_storage_bytes = data_size * sizeof(dim_t);
+        size_t sum_workspace_bytes = 0;
+        size_t sort_workspace_size = SortByKeyWorkspaceSize<dim_t, dim_t, gpu>(data_size);
+        cub::DeviceScan::InclusiveSum(temp_storage,
+                                      sum_workspace_bytes,
                                       prefix_sum,
                                       prefix_sum,
                                       num_rows,
                                       Stream<gpu>::GetStream(s));
+        // temp_workspace is shared by inclusive sum and sort
+        size_t temp_workspace_bytes = std::max(sum_workspace_bytes, sort_workspace_size);
+        size_t total_storage_bytes = row_flg_storage_bytes + sorted_data_storage_bytes +
+                                     original_idx_storage_bytes + temp_workspace_bytes;
+
+        // request resource and split it. layout =
+        // row_flg/prefixsum, sorted_data, original_idx, temp_storage
         Tensor<gpu, 1, char> workspace = ctx.requested[0]
-            .get_space_typed<gpu, 1, char>(Shape1(num_rows * sizeof(dim_t) +
-                                           temp_storage_bytes), s);
+            .get_space_typed<gpu, 1, char>(Shape1(total_storage_bytes), s);
         prefix_sum = reinterpret_cast<dim_t*>(workspace.dptr_);
-        d_temp_storage = workspace.dptr_ + num_rows*sizeof(dim_t);
+        sorted_data = reinterpret_cast<dim_t*>(workspace.dptr_ + row_flg_storage_bytes);
+        original_idx = reinterpret_cast<dim_t*>(workspace.dptr_ + row_flg_storage_bytes +
+                                                sorted_data_storage_bytes);
+        temp_storage = workspace.dptr_ + total_storage_bytes - temp_workspace_bytes;
+        // compute row flags and prefix sum
         num_threads = num_rows;
         Fill<false>(s, TBlob(prefix_sum, Shape1(num_threads), gpu::kDevMask), kWriteTo, 0);
         Kernel<MarkRowFlgKernel, gpu>::Launch(s, data_size, prefix_sum, data.dptr<IType>());
-
-        cub::DeviceScan::InclusiveSum(d_temp_storage,
-                                      temp_storage_bytes,
+        cub::DeviceScan::InclusiveSum(temp_storage,
+                                      temp_workspace_bytes,
                                       prefix_sum,
                                       prefix_sum,
                                       num_rows,
                                       mshadow::Stream<gpu>::GetStream(s));
+        // retrieve nnr and allocate output
         dim_t nnr = 0;
         CUDA_CALL(cudaMemcpy(&nnr, &prefix_sum[num_rows-1], sizeof(dim_t),
             cudaMemcpyDeviceToHost));
-
-        if (nnr == 0) {
-          FillZerosRspImpl(s, output);
-          return;
-        }
         output.CheckAndAlloc({Shape1(nnr)});
-        RType* grad_row_idx = output.aux_data(kIdx).dptr<RType>();
         // fill row_idx array of output matrix, using the row_flg values
+        RType* grad_row_idx = output.aux_data(kIdx).dptr<RType>();
         Kernel<FillRspRowIdxKernel, gpu>::Launch(s, num_rows, grad_row_idx, prefix_sum, num_rows);
-        // prefill with zeros
+
+        // make a copy of the data, to be sorted
+        TBlob sorted_data_blob(sorted_data, Shape1(data_size), gpu::kDevMask);
+        auto sorted_data_tensor = sorted_data_blob.FlatTo1D<gpu, dim_t>(s);
+        mxnet_op::copy(s, sorted_data_blob, data);
+
+        // generate original idx
+        Tensor<gpu, 1, dim_t> original_idx_tensor(original_idx, Shape1(data_size), s);
+        Kernel<range_fwd, gpu>::Launch(s, data_size, 1, static_cast<dim_t>(0), static_cast<dim_t>(1),
+                                       kWriteTo, original_idx);
+        // sort data with its original idx
+        int num_bits = ilog2(num_rows - 1);
+        char* temp_storage_ptr = reinterpret_cast<char*>(temp_storage);
+        Tensor<gpu, 1, char> temp_storage_tensor(temp_storage_ptr,
+                                                 Shape1(sort_workspace_size), s);
+        SortByKey(sorted_data_tensor, original_idx_tensor, true,
+                  &temp_storage_tensor, 0, num_bits);
+        // accumulate gradients
         DType* grad_data = output.data().dptr<DType>();
         Fill<false>(s, TBlob(grad_data, Shape1(nnr * row_length), gpu::kDevMask),
             kWriteTo, 0);
-        // add the final gradients
-        num_threads = row_length * data_size;
-        Kernel<AddTakeGradRspGPUKernel, gpu>::Launch(s, num_threads, grad_data, prefix_sum,
-            data.dptr<IType>(), ograd.dptr<DType>(), row_length);
+
+        // reuse dense op backward kernel
+        {
+          dim_t* sum_counts_ptr = NULL;
+          int* num_runs_ptr = NULL;
+          mshadow::Tensor<gpu, 2, DType> dst = output.data().get<gpu, 2, DType>(s);
+          mshadow::Tensor<gpu, 1, dim_t> sorted = sorted_data_tensor;
+          mshadow::Tensor<gpu, 1, dim_t> index = original_idx_tensor;
+          const auto oshape = ograd.shape_;
+          mshadow::Tensor<gpu, 2, DType> src = ograd.get_with_shape<gpu, 2, DType>(
+              Shape2(oshape.ProdShape(0, oshape.ndim()-1), oshape[oshape.ndim()-1]), s);
+          nnvm::dim_t* lookup_table = prefix_sum;
+          AddTakeGradLargeBatchKernelLaunch<true>(dst, sorted, index, src, sum_counts_ptr,
+                                                  num_runs_ptr, lookup_table);
+        }
       });
     });
   });
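The workspace bookkeeping above packs four regions into a single resource request. A
compilable sketch of the same carving, with int64_t standing in for nnvm::dim_t and the
scan/sort byte counts passed in rather than queried from cub and SortByKeyWorkspaceSize:

    #include <algorithm>
    #include <cstddef>
    #include <cstdint>

    struct BackwardWorkspace {
      int64_t* prefix_sum;    // row_flg, scanned in place
      int64_t* sorted_data;   // copy of the input row ids, sorted later
      int64_t* original_idx;  // positions 0..data_size-1, permuted by the sort
      void* temp_storage;     // shared by the scan and the sort
      size_t total_bytes;
    };

    BackwardWorkspace CarveWorkspace(char* base, size_t num_rows, size_t data_size,
                                     size_t scan_bytes, size_t sort_bytes) {
      BackwardWorkspace ws;
      const size_t row_flg_bytes = num_rows * sizeof(int64_t);
      const size_t sorted_bytes = data_size * sizeof(int64_t);
      const size_t idx_bytes = data_size * sizeof(int64_t);
      // The scan and the sort never run concurrently, so one region serves both.
      const size_t temp_bytes = std::max(scan_bytes, sort_bytes);
      ws.total_bytes = row_flg_bytes + sorted_bytes + idx_bytes + temp_bytes;
      ws.prefix_sum = reinterpret_cast<int64_t*>(base);
      ws.sorted_data = reinterpret_cast<int64_t*>(base + row_flg_bytes);
      ws.original_idx = reinterpret_cast<int64_t*>(base + row_flg_bytes + sorted_bytes);
      ws.temp_storage = base + ws.total_bytes - temp_bytes;
      return ws;
    }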
diff --git a/tests/python/unittest/test_sparse_operator.py b/tests/python/unittest/test_sparse_operator.py
index 54809d9419d3..9441cc7f64ef 100644
--- a/tests/python/unittest/test_sparse_operator.py
+++ b/tests/python/unittest/test_sparse_operator.py
@@ -1634,7 +1634,7 @@ def check_sparse_elementwise_sum_with_shape(stype, shape, n):
 
 @with_seed()
 def test_sparse_embedding():
-    ''' test sparse embedding op on cpu '''
+    ''' test sparse embedding operator '''
     def check_sparse_embedding(executor, weight_ref, data_onehot, grad, density):
         # update weight based on density
         weight[:] = rand_ndarray(weight.shape, 'row_sparse', density=density)
@@ -1665,7 +1665,7 @@ def check_sparse_embedding(executor, weight_ref, data_onehot, grad, density):
         arg_map["data"][:] = np_data
         # init grad
         np_grad = np.random.uniform(-1, 1, exe_test.outputs[0].shape)
-        grad = mx.nd.sparse.zeros('row_sparse', np_grad.shape)
+        grad = mx.nd.zeros(np_grad.shape)
         grad[:] = np_grad
         # weight
         weight = arg_map["embed_weight"]
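Patch 2 below only re-wraps the range_fwd launch for lint, but that call is half of the
sort-by-key idiom the backward pass relies on: pair every row id with its original
position, sort by id, and duplicate ids become contiguous runs that the dense kernel
reduces once each. A host-side analogue with made-up row ids:

    #include <algorithm>
    #include <cstdint>
    #include <utility>
    #include <vector>

    int main() {
      const std::vector<int64_t> row_ids = {4, 1, 3, 1};  // hypothetical input

      // (key, original position) pairs; range_fwd generates the positions on device.
      std::vector<std::pair<int64_t, int64_t>> kv;
      for (int64_t i = 0; i < static_cast<int64_t>(row_ids.size()); ++i)
        kv.emplace_back(row_ids[i], i);

      // SortByKey performs this step on device with a radix sort over the keys.
      std::stable_sort(kv.begin(), kv.end());
      // kv is now {(1,1), (1,3), (3,2), (4,0)}: the run of key 1 tells the kernel
      // to sum ograd rows 1 and 3 into a single output row.
      return kv[0].first == 1 && kv[0].second == 1 ? 0 : 1;
    }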
From c0b42ec71467cc60a8613fcd4d96bc4d910824f8 Mon Sep 17 00:00:00 2001
From: Haibin Lin
Date: Wed, 21 Feb 2018 17:54:57 +0800
Subject: [PATCH 2/2] fix lint

---
 src/operator/tensor/indexing_op.cu | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/operator/tensor/indexing_op.cu b/src/operator/tensor/indexing_op.cu
index 87633efaff84..223a12303650 100644
--- a/src/operator/tensor/indexing_op.cu
+++ b/src/operator/tensor/indexing_op.cu
@@ -169,8 +169,8 @@ inline void SparseEmbeddingOpBackwardRspImpl<gpu>(const OpContext& ctx,
 
         // generate original idx
         Tensor<gpu, 1, dim_t> original_idx_tensor(original_idx, Shape1(data_size), s);
-        Kernel<range_fwd, gpu>::Launch(s, data_size, 1, static_cast<dim_t>(0), static_cast<dim_t>(1),
-                                       kWriteTo, original_idx);
+        Kernel<range_fwd, gpu>::Launch(s, data_size, 1, static_cast<dim_t>(0),
+                                       static_cast<dim_t>(1), kWriteTo, original_idx);
         // sort data with its original idx
         int num_bits = ilog2(num_rows - 1);
         char* temp_storage_ptr = reinterpret_cast<char*>(temp_storage);
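One detail worth noting in the sorted path: SortByKey is told to consider only
num_bits = ilog2(num_rows - 1) bits, since every key is a row id in [0, num_rows).
A sketch of that bound; this ilog2 is a stand-in with the bit-counting behavior the
call above assumes, not a copy of MXNet's helper:

    #include <cstdint>

    // Bit count of a: for example, ilog2(999) == 10, enough for keys 0..999.
    int ilog2(int64_t a) {
      int k = 1;
      while (a >>= 1) ++k;
      return k;
    }

    int main() {
      const int64_t num_rows = 1000;
      const int num_bits = ilog2(num_rows - 1);  // radix sort touches only 10 bits
      return num_bits == 10 ? 0 : 1;
    }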