136 changes: 76 additions & 60 deletions src/operator/tensor/indexing_op-inl.cuh
@@ -38,7 +38,7 @@ namespace mxnet {
namespace op {
const int kWarpSize = 32;

template<int SZ, typename DType, typename IdxType>
template<int SZ, bool lookup, typename DType, typename IdxType>
__global__ void AddTakeGradLargeBatchKernel(DType* dst,
// If idx_start == NULL, then in-kernel edge
// detection is used
@@ -47,7 +47,9 @@ __global__ void AddTakeGradLargeBatchKernel(DType* dst,
const int* idx_start_size_ptr,
const IdxType *sorted, const IdxType *index,
const DType *src,
int ymax, int xmax) {
int ymax, int xmax,
// table to look up positions of row_ids in dst
const nnvm::dim_t *lookup_table) {
// Size of the shared memory is [blockDim.x*SZ*blockDim.y]*sizeof(DType)
extern __shared__ char sh_grad_weight_char[];
DType* sh_grad_weight = (DType*)sh_grad_weight_char;
@@ -125,7 +127,8 @@ __global__ void AddTakeGradLargeBatchKernel(DType* dst,
}

const int start_feature = threadIdx.x + blockIdx.x * blockDim.x * SZ;
const int dst_row = sorted_value * xmax;
// Lookup inclusive prefix sum table if necessary
const int dst_row = (lookup ? (lookup_table[sorted_value] - 1) : sorted_value) * xmax;

int num_idx = idx_end - idx_begin;
int idx0 = idx_begin + threadIdx.y*num_idx/blockDim.y;
@@ -179,7 +182,6 @@ __global__ void AddTakeGradLargeBatchKernel(DType* dst,
}
}
}

}
}

@@ -199,6 +201,73 @@ AddTakeGradLargeBatchWorkspaceSize(size_t num_keys) {
return (unique_bytes + counts_bytes + num_runs_bytes + temporary_bytes);
}

template<bool lookup, typename IndexType, typename DType>
inline void AddTakeGradLargeBatchKernelLaunch(mshadow::Tensor<gpu, 2, DType> dst,
const mshadow::Tensor<gpu, 1, IndexType>& sorted,
const mshadow::Tensor<gpu, 1, IndexType>& index,
const mshadow::Tensor<gpu, 2, DType> &src,
IndexType* sum_counts_ptr,
int* num_runs_ptr,
const nnvm::dim_t* lookup_table) {
cudaStream_t stream = mshadow::Stream<gpu>::GetStream(dst.stream_);
const int num_unique_est = min(dst.size(0), src.size(0));
const int max_nthread = 128;
const int num_y = max(src.size(0)/num_unique_est, 1);
const int block_dim_x = kWarpSize;
const int block_dim_y = min(num_y, max_nthread/block_dim_x);
const int SZ = min((src.size(1) + block_dim_x - 1) / block_dim_x, 4);
const int grid_dim_x = (src.size(1) + block_dim_x * SZ - 1) / (block_dim_x * SZ);
const int grid_dim_y = min(num_unique_est, mshadow::cuda::kBaseGridNum);
dim3 dimBlock(block_dim_x, block_dim_y);
dim3 dimGrid(grid_dim_x, grid_dim_y);
// Maximum shared memory usage: 128*4*sizeof(DType), which is 4K for 64bit DType elements
int shmem_size = dimBlock.x*SZ*dimBlock.y*sizeof(DType);

CHECK_EQ(dst.size(1), src.size(1)) << "AddTakeGradLargeBatch: shape mismatch";
CHECK_EQ(index.size(0), src.size(0)) << "AddTakeGradLargeBatch: shape mismatch";
mshadow::cuda::CheckLaunchParam(dimGrid, dimBlock, "AddTakeGradLargeBatch");

switch (SZ) {
case 1:
AddTakeGradLargeBatchKernel<1, lookup, DType>
<<<dimGrid, dimBlock, shmem_size, stream>>>
(dst.dptr_, sum_counts_ptr, num_runs_ptr,
sorted.dptr_, index.dptr_, src.dptr_,
static_cast<int>(src.size(0)),
static_cast<int>(src.size(1)), lookup_table);
break;
case 2:
AddTakeGradLargeBatchKernel<2, lookup, DType>
<<<dimGrid, dimBlock, shmem_size, stream>>>
(dst.dptr_, sum_counts_ptr, num_runs_ptr,
sorted.dptr_, index.dptr_, src.dptr_,
static_cast<int>(src.size(0)),
static_cast<int>(src.size(1)), lookup_table);
break;
case 3:
AddTakeGradLargeBatchKernel<3, lookup, DType>
<<<dimGrid, dimBlock, shmem_size, stream>>>
(dst.dptr_, sum_counts_ptr, num_runs_ptr,
sorted.dptr_, index.dptr_, src.dptr_,
static_cast<int>(src.size(0)),
static_cast<int>(src.size(1)), lookup_table);
break;
case 4:
AddTakeGradLargeBatchKernel<4, lookup, DType>
<<<dimGrid, dimBlock, shmem_size, stream>>>
(dst.dptr_, sum_counts_ptr, num_runs_ptr,
sorted.dptr_, index.dptr_, src.dptr_,
static_cast<int>(src.size(0)),
static_cast<int>(src.size(1)), lookup_table);
break;
default:
LOG(FATAL) << "AddTakeGradLargeBatch, incorrect value SZ " << SZ;
break;
}
MSHADOW_CUDA_POST_KERNEL_CHECK(AddTakeGradLargeBatchKernel);
}


template<typename IndexType, typename DType>
inline void AddTakeGradLargeBatch(mshadow::Tensor<gpu, 2, DType> dst,
const mshadow::Tensor<gpu, 1, IndexType>& sorted,
@@ -249,62 +318,9 @@ inline void AddTakeGradLargeBatch(mshadow::Tensor<gpu, 2, DType> dst,
(temporary_storage, temporary_bytes, counts_out_ptr, sum_counts_ptr,
sorted.size(0), stream);
}

const int num_unique_est = min(dst.size(0), src.size(0));
const int max_nthread = 128;
const int num_y = max(src.size(0)/num_unique_est, 1);
const int block_dim_x = kWarpSize;
const int block_dim_y = min(num_y, max_nthread/block_dim_x);
const int SZ = min((src.size(1) + block_dim_x - 1) / block_dim_x, 4);
const int grid_dim_x = (src.size(1) + block_dim_x * SZ - 1) / (block_dim_x * SZ);
const int grid_dim_y = min(num_unique_est, mshadow::cuda::kBaseGridNum);
dim3 dimBlock(block_dim_x, block_dim_y);
dim3 dimGrid(grid_dim_x, grid_dim_y);
// Maximum shared memory usage: 128*4*sizeof(DType), which is 4K for 64bit DType elements
int shmem_size = dimBlock.x*SZ*dimBlock.y*sizeof(DType);

CHECK_EQ(dst.size(1), src.size(1)) << "AddTakeGradLargeBatch: shape mismatch";
CHECK_EQ(index.size(0), src.size(0)) << "AddTakeGradLargeBatch: shape mismatch";
mshadow::cuda::CheckLaunchParam(dimGrid, dimBlock, "AddTakeGradLargeBatch");

switch (SZ) {
case 1:
AddTakeGradLargeBatchKernel<1, DType>
<<<dimGrid, dimBlock, shmem_size, stream>>>
(dst.dptr_, sum_counts_ptr, num_runs_ptr,
sorted.dptr_, index.dptr_, src.dptr_,
static_cast<int>(src.size(0)),
static_cast<int>(src.size(1)));
break;
case 2:
AddTakeGradLargeBatchKernel<2, DType>
<<<dimGrid, dimBlock, shmem_size, stream>>>
(dst.dptr_, sum_counts_ptr, num_runs_ptr,
sorted.dptr_, index.dptr_, src.dptr_,
static_cast<int>(src.size(0)),
static_cast<int>(src.size(1)));
break;
case 3:
AddTakeGradLargeBatchKernel<3, DType>
<<<dimGrid, dimBlock, shmem_size, stream>>>
(dst.dptr_, sum_counts_ptr, num_runs_ptr,
sorted.dptr_, index.dptr_, src.dptr_,
static_cast<int>(src.size(0)),
static_cast<int>(src.size(1)));
break;
case 4:
AddTakeGradLargeBatchKernel<4, DType>
<<<dimGrid, dimBlock, shmem_size, stream>>>
(dst.dptr_, sum_counts_ptr, num_runs_ptr,
sorted.dptr_, index.dptr_, src.dptr_,
static_cast<int>(src.size(0)),
static_cast<int>(src.size(1)));
break;
default:
LOG(FATAL) << "AddTakeGradLargeBatch, incorrect value SZ " << SZ;
break;
}
MSHADOW_CUDA_POST_KERNEL_CHECK(AddTakeGradLargeBatchKernel);
nnvm::dim_t* lookup_table = nullptr;
AddTakeGradLargeBatchKernelLaunch<false>(dst, sorted, index, src, sum_counts_ptr,
num_runs_ptr, lookup_table);
}

} // namespace op
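Note (illustrative, not part of the patch): the new `lookup` template parameter makes the kernel map each sorted row id through an inclusive prefix sum over row-occupancy flags, so `lookup_table[row_id] - 1` is the compacted destination row, matching the `dst_row` computation above. A minimal host-side C++ sketch of that mapping, with hypothetical names and data, assuming the same convention as the kernel:

#include <cstdint>
#include <iostream>
#include <vector>

// Illustrative only: mirrors
//   dst_row = (lookup ? lookup_table[sorted_value] - 1 : sorted_value) * xmax
// using a host-side inclusive prefix sum over row flags.
int main() {
  const int64_t num_rows = 6;
  // Row ids that actually occur in the (sorted) input indices.
  std::vector<int64_t> row_ids = {1, 1, 3, 4, 4, 4};

  // Mark which rows occur at least once.
  std::vector<int64_t> row_flg(num_rows, 0);
  for (int64_t r : row_ids) row_flg[r] = 1;

  // Inclusive prefix sum (what cub::DeviceScan::InclusiveSum produces on device).
  std::vector<int64_t> lookup_table(num_rows, 0);
  int64_t running = 0;
  for (int64_t i = 0; i < num_rows; ++i) {
    running += row_flg[i];
    lookup_table[i] = running;
  }

  // lookup_table[r] - 1 is the compacted row that row id r maps to.
  for (int64_t r : {1, 3, 4})
    std::cout << "row " << r << " -> compacted row " << lookup_table[r] - 1 << "\n";
  // Prints: row 1 -> 0, row 3 -> 1, row 4 -> 2
  return 0;
}

This is why the sparse backward path in indexing_op.cu below can pass its `prefix_sum` buffer as the lookup table and launch the same kernel with `lookup = true`, while the dense path keeps `lookup = false` and an identity mapping.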
109 changes: 66 additions & 43 deletions src/operator/tensor/indexing_op.cu
@@ -41,25 +41,6 @@ struct is_valid_check {
}
};


struct AddTakeGradRspGPUKernel {
template<typename DType, typename IType>
__device__ __forceinline__ static void Map(int tid,
DType* out,
const nnvm::dim_t* prefix_sum,
const IType* data,
const DType* ograd,
const nnvm::dim_t row_length) {
using nnvm::dim_t;
const dim_t data_i = tid / row_length;
const dim_t grad_i = tid % row_length;
const dim_t irow = static_cast<dim_t>(data[data_i]);
const dim_t rsp_row = prefix_sum[irow] - 1;
const DType val = ograd[data_i * row_length + grad_i];
atomicAdd(static_cast<DType *>(&(out[rsp_row*row_length+grad_i])), val);
}
};

template<>
void SparseEmbeddingOpForwardRspImpl<gpu>(const OpContext& ctx,
const TBlob& data,
@@ -103,7 +84,6 @@ void SparseEmbeddingOpForwardRspImpl<gpu>(const OpContext& ctx,
}
}


template<>
inline void SparseEmbeddingOpBackwardRspImpl<gpu>(const OpContext& ctx,
const TBlob& ograd,
@@ -125,55 +105,98 @@ inline void SparseEmbeddingOpBackwardRspImpl<gpu>(const OpContext& ctx,
dim_t row_length = output.shape()[1];
dim_t data_size = static_cast<dim_t>(data.shape_.Size());
dim_t num_threads;

if (data_size == 0) {
FillZerosRspImpl(s, output);
return;
}
MSHADOW_TYPE_SWITCH(data.type_flag_, IType, {
MSHADOW_SGL_DBL_TYPE_SWITCH(ograd.type_flag_, DType, {
MSHADOW_TYPE_SWITCH(ograd.type_flag_, DType, {
MSHADOW_IDX_TYPE_SWITCH(output.aux_type(kIdx), RType, {
dim_t* prefix_sum = NULL;
void* d_temp_storage = NULL;
size_t temp_storage_bytes = 0;
cub::DeviceScan::InclusiveSum(d_temp_storage,
temp_storage_bytes,
void* temp_storage = NULL;
dim_t* sorted_data = NULL;
dim_t* original_idx = NULL;
// calculate resource bytes
size_t row_flg_storage_bytes = num_rows * sizeof(dim_t);
size_t sorted_data_storage_bytes = data_size * sizeof(dim_t);
size_t original_idx_storage_bytes = data_size * sizeof(dim_t);
size_t sum_workspace_bytes = 0;
size_t sort_workspace_size = SortByKeyWorkspaceSize<dim_t, dim_t, gpu>(data_size);
cub::DeviceScan::InclusiveSum(temp_storage,
sum_workspace_bytes,
prefix_sum,
prefix_sum,
num_rows,
Stream<gpu>::GetStream(s));
// temp_workspace is shared by inclusive sum and sort
size_t temp_workspace_bytes = std::max(sum_workspace_bytes, sort_workspace_size);
size_t total_storage_bytes = row_flg_storage_bytes + sorted_data_storage_bytes +
original_idx_storage_bytes + temp_workspace_bytes;

// request resource and split it. layout =
// row_flg/prefixsum, sorted_data, original_idx, temp_storage
Tensor<gpu, 1, char> workspace = ctx.requested[0]
.get_space_typed<gpu, 1, char>(Shape1(num_rows * sizeof(dim_t) +
temp_storage_bytes), s);
.get_space_typed<gpu, 1, char>(Shape1(total_storage_bytes), s);
prefix_sum = reinterpret_cast<dim_t*>(workspace.dptr_);
d_temp_storage = workspace.dptr_ + num_rows*sizeof(dim_t);
sorted_data = reinterpret_cast<dim_t*>(workspace.dptr_ + row_flg_storage_bytes);
original_idx = reinterpret_cast<dim_t*>(workspace.dptr_ + row_flg_storage_bytes +
sorted_data_storage_bytes);
temp_storage = workspace.dptr_ + total_storage_bytes - temp_workspace_bytes;
// compute row flags and prefix sum
num_threads = num_rows;
Fill<false>(s, TBlob(prefix_sum, Shape1(num_threads), gpu::kDevMask), kWriteTo, 0);
Kernel<MarkRowFlgKernel, gpu>::Launch(s, data_size, prefix_sum, data.dptr<IType>());

cub::DeviceScan::InclusiveSum(d_temp_storage,
temp_storage_bytes,
cub::DeviceScan::InclusiveSum(temp_storage,
temp_workspace_bytes,
prefix_sum,
prefix_sum,
num_rows,
mshadow::Stream<gpu>::GetStream(s));
// retrieve nnr and allocate output
dim_t nnr = 0;
CUDA_CALL(cudaMemcpy(&nnr, &prefix_sum[num_rows-1], sizeof(dim_t),
cudaMemcpyDeviceToHost));

if (nnr == 0) {
FillZerosRspImpl(s, output);
return;
}
output.CheckAndAlloc({Shape1(nnr)});
RType* grad_row_idx = output.aux_data(kIdx).dptr<RType>();
// fill row_idx array of output matrix, using the row_flg values
RType* grad_row_idx = output.aux_data(kIdx).dptr<RType>();
Kernel<FillRspRowIdxKernel, gpu>::Launch(s, num_rows,
grad_row_idx, prefix_sum, num_rows);
// prefill with zeros

// make a copy of the data, to be sorted
TBlob sorted_data_blob(sorted_data, Shape1(data_size), gpu::kDevMask);
auto sorted_data_tensor = sorted_data_blob.FlatTo1D<gpu, dim_t>(s);
mxnet_op::copy(s, sorted_data_blob, data);

// generate original idx
Tensor<gpu, 1, dim_t> original_idx_tensor(original_idx, Shape1(data_size), s);
Kernel<range_fwd, gpu>::Launch(s, data_size, 1, static_cast<dim_t>(0),
static_cast<dim_t>(1), kWriteTo, original_idx);
// sort data with its original idx
int num_bits = ilog2(num_rows - 1);
char* temp_storage_ptr = reinterpret_cast<char*>(temp_storage);
Tensor<gpu, 1, char> temp_storage_tensor(temp_storage_ptr,
Shape1(sort_workspace_size), s);
SortByKey(sorted_data_tensor, original_idx_tensor, true,
&temp_storage_tensor, 0, num_bits);
// accumulate gradients
DType* grad_data = output.data().dptr<DType>();
Fill<false>(s, TBlob(grad_data, Shape1(nnr * row_length), gpu::kDevMask),
kWriteTo, 0);
// add the final gradients
num_threads = row_length * data_size;
Kernel<AddTakeGradRspGPUKernel, gpu>::Launch(s, num_threads, grad_data, prefix_sum,
data.dptr<IType>(), ograd.dptr<DType>(), row_length);

// reuse dense op backward kernel
{
dim_t* sum_counts_ptr = NULL;
int* num_runs_ptr = NULL;
mshadow::Tensor<gpu, 2, DType> dst = output.data().get<gpu, 2, DType>(s);
mshadow::Tensor<gpu, 1, dim_t> sorted = sorted_data_tensor;
mshadow::Tensor<gpu, 1, dim_t> index = original_idx_tensor;
const auto oshape = ograd.shape_;
mshadow::Tensor<gpu, 2, DType> src = ograd.get_with_shape<gpu, 2, DType>(
Shape2(oshape.ProdShape(0, oshape.ndim()-1), oshape[oshape.ndim()-1]), s);
nnvm::dim_t* lookup_table = prefix_sum;
AddTakeGradLargeBatchKernelLaunch<true>(dst, sorted, index, src, sum_counts_ptr,
num_runs_ptr, lookup_table);
}
});
});
});
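Note (illustrative, not part of the patch): the backward path above carves one requested workspace into four regions in the order row_flg/prefix_sum, sorted_data, original_idx, temp_storage, with the temp region sized as the maximum of the inclusive-sum and sort workspaces. A simplified host-side sketch of that partitioning, using placeholder sizes where the patch queries cub::DeviceScan::InclusiveSum and SortByKeyWorkspaceSize:

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <vector>

// Simplified illustration of the single-workspace layout used above:
//   [ row_flg/prefix_sum | sorted_data | original_idx | temp_storage ]
// On device the buffer comes from ctx.requested[0]; here it is a host vector.
int main() {
  using dim_t = int64_t;
  const size_t num_rows = 1024, data_size = 4096;

  const size_t row_flg_bytes      = num_rows  * sizeof(dim_t);
  const size_t sorted_data_bytes  = data_size * sizeof(dim_t);
  const size_t original_idx_bytes = data_size * sizeof(dim_t);
  // Placeholder values; the patch obtains these from cub and
  // SortByKeyWorkspaceSize at run time.
  const size_t sum_workspace_bytes = 2048, sort_workspace_bytes = 8192;
  const size_t temp_bytes = std::max(sum_workspace_bytes, sort_workspace_bytes);

  const size_t total = row_flg_bytes + sorted_data_bytes +
                       original_idx_bytes + temp_bytes;
  std::vector<char> workspace(total);

  // Carve the buffer into typed regions, in layout order.
  dim_t* prefix_sum   = reinterpret_cast<dim_t*>(workspace.data());
  dim_t* sorted_data  = reinterpret_cast<dim_t*>(workspace.data() + row_flg_bytes);
  dim_t* original_idx = reinterpret_cast<dim_t*>(workspace.data() + row_flg_bytes +
                                                 sorted_data_bytes);
  void*  temp_storage = workspace.data() + total - temp_bytes;

  (void)prefix_sum; (void)sorted_data; (void)original_idx; (void)temp_storage;
  return 0;
}

Sharing one allocation keeps the inclusive sum and the radix sort from each requesting their own temporary buffer, since the two never run at the same time.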
4 changes: 2 additions & 2 deletions tests/python/unittest/test_sparse_operator.py
@@ -1634,7 +1634,7 @@ def check_sparse_elementwise_sum_with_shape(stype, shape, n):

@with_seed()
def test_sparse_embedding():
''' test sparse embedding op on cpu '''
''' test sparse embedding operator '''
def check_sparse_embedding(executor, weight_ref, data_onehot, grad, density):
# update weight based on density
weight[:] = rand_ndarray(weight.shape, 'row_sparse', density=density)
@@ -1665,7 +1665,7 @@ def check_sparse_embedding(executor, weight_ref, data_onehot, grad, density):
arg_map["data"][:] = np_data
# init grad
np_grad = np.random.uniform(-1, 1, exe_test.outputs[0].shape)
grad = mx.nd.sparse.zeros('row_sparse', np_grad.shape)
grad = mx.nd.zeros(np_grad.shape)
grad[:] = np_grad
# weight
weight = arg_map["embed_weight"]