From aeab5d9894db69e434da7d31d0576ed8fe5543f4 Mon Sep 17 00:00:00 2001 From: Hao Jin Date: Mon, 2 Apr 2018 17:49:39 +0000 Subject: [PATCH 1/7] add support for dot(dns, csr) = dns and dot(dns, csr.T) = dns on GPU --- src/operator/tensor/dot-inl.cuh | 219 ++++++++++++++++++++++++++++++++ src/operator/tensor/dot-inl.h | 121 +++++------------- src/operator/tensor/dot.cc | 99 +++++++++++++++ src/operator/tensor/dot.cu | 1 + 4 files changed, 349 insertions(+), 91 deletions(-) diff --git a/src/operator/tensor/dot-inl.cuh b/src/operator/tensor/dot-inl.cuh index c546c4351a28..fcf9440a77d2 100644 --- a/src/operator/tensor/dot-inl.cuh +++ b/src/operator/tensor/dot-inl.cuh @@ -442,6 +442,100 @@ struct DotCsrRspDnsScalarKernel { } }; +/*! + * \brief GPU Kernel to re-arrange nnz elements to csc order + * Parallelization by output elements: 1 thread/row of csr + */ +struct CscDataIndicesKernel { + template + __device__ __forceinline__ static void Map(int tid, + const DType* csr_data, + const IType* csr_indices, + const CType* csr_indptr, + DType* csc_data, + int* csc_indices, + int* csc_indptr, + int* workspace, + const nnvm::dim_t num_rows, + const nnvm::dim_t num_cols) { + if (tid < num_rows) { + printf("%d, %d ", tid, csc_indptr[tid]); + for (CType i = csr_indptr[tid]; i < csr_indptr[tid + 1]; ++i) { + // target column + IType target_col = csr_indices[i]; + int target_offset = atomicAdd(&workspace[target_col], 1); + int new_pos = csc_indptr[target_col] + target_offset; + csc_data[new_pos] = csr_data[i]; + csc_indices[new_pos] = tid; + } + } + } +}; + +/*! + * \brief GPU Kernel of getting count for every column + * Parallelization by output elements: 1 thread/element + */ +struct CsrTransHistogramKernel { + /*! + * \brief + * \param tid global thread id + * \param in_indices csr matrix column indices + * \param out_indptr csr matrix row pointer + * \param nnz number of non-zero elements in csr + */ + template + __device__ __forceinline__ static void Map(int tid, + const IType* in_indices, + int* out_indptr, + const nnvm::dim_t nnz) { + if (tid < nnz) { + atomicAdd(&out_indptr[in_indices[tid]], 1); + } + } +}; + +/*! + * \brief GPU Kernel of dot(dns, csr.T) = dns + * Parallelization by output elements: 1 thread/element + */ +struct DotDnsCsrTransDnsKernel { + /*! + * \brief + * \param tid global thread id + * \param lhs_data lhs dense matrix data + * \param rhs_data csr matrix data + * \param rhs_indices csr matrix column indices + * \param rhs_indptr csr matrix row pointer + * \param out output matrix data + * \param lhs_num_cols lhs dns matrix number of columns + * \param out_num_rows output dns matrix number of rows + * \param out_num_cols output dns matrix number of columns + */ + template + __device__ __forceinline__ static void Map(int tid, + const DType* lhs_data, + const DType* rhs_data, + const IType* rhs_indices, + const CType* rhs_indptr, + DType* out, + const nnvm::dim_t lhs_num_cols, + const nnvm::dim_t out_num_rows, + const nnvm::dim_t out_num_cols) { + using nnvm::dim_t; + if (tid < out_num_rows*out_num_cols) { + const dim_t i = static_cast(tid) / out_num_cols; // i = row this thread computes + const dim_t k = static_cast(tid) % out_num_cols; // k = col this thread computes + // Compute inner product of i-th row and k-th col + DType sum = 0; + for (CType col_id = rhs_indptr[k]; col_id < rhs_indptr[k + 1]; ++col_id) { + sum += lhs_data[i * lhs_num_cols + rhs_indices[col_id]] * rhs_data[col_id]; + } + out[tid] = sum; + } + } +}; + /*! 
* \brief GPU Impl of dot(csr, dns1) = dns2 and dot(csr.T, dns1) = dns2 */ @@ -895,6 +989,131 @@ inline void DotCsrRspDnsImpl(const OpContext& ctx, }); } +/* + * \brief GPU Impl of dot(dns, csr) = csr + */ +template +inline void DotDnsCsrCsrImpl(const OpContext& ctx, + const TBlob& lhs, const NDArray& rhs, + const OpReqType req, NDArray* ret) { + LOG(FATAL) << "dot(dense, csr) = csr is not implemented on GPU"; +} + +/* + * \brief GPU Impl of dot(dns, csr) = dns and dot(dns, csr.T) = dns + */ +template +inline void DotDnsCsrDnsImpl(const OpContext& ctx, + const TBlob& dns, const NDArray& rhs, + const OpReqType req, NDArray* ret, + const bool transpose_b) { + CHECK_EQ(req, kWriteTo); + CHECK_EQ(rhs.storage_type(), kCSRStorage); + + using namespace mshadow; + using namespace mshadow::expr; + using nnvm::dim_t; + + /* Initialize data structures */ + mshadow::Stream* s = ctx.get_stream(); + TBlob csr_data = rhs.data(); + TBlob csr_indices = rhs.aux_data(csr::kIdx); + TBlob csr_indptr = rhs.aux_data(csr::kIndPtr); + if (!rhs.storage_initialized()) { + FillZerosCsrImpl(s, *ret); + return; + } + + // if dot(dense, csr) = dns, transform to csc first + if (!transpose_b) { + // LOG(FATAL) << "dot(dns, csr) = dns not implemented yet"; + const nnvm::dim_t csr_rows = rhs.shape()[0]; + const nnvm::dim_t csr_cols = rhs.shape()[1]; + const nnvm::dim_t dns_rows = dns.shape_[0]; + const nnvm::dim_t nnz = rhs.storage_shape().Size(); + + MSHADOW_SGL_DBL_TYPE_SWITCH(csr_data.type_flag_, DType, { + MSHADOW_IDX_TYPE_SWITCH(csr_indices.type_flag_, IType, { + MSHADOW_IDX_TYPE_SWITCH(csr_indptr.type_flag_, CType, { + DType* csc_data_ptr = NULL; + int* csc_indices_ptr = NULL; + int* csc_indptr_ptr = NULL; + int* col_counters = NULL; + void* temp_storage = NULL; + size_t temp_storage_bytes = 0; + CType out_num_rows = ret->shape()[0]; + CType out_num_cols = ret->shape()[1]; + // Get necessary temporary storage amount + cub::DeviceScan::ExclusiveSum(NULL, + temp_storage_bytes, + csc_indices_ptr, + csc_indices_ptr, + csr_cols+1, + Stream::GetStream(s)); + Tensor workspace = + ctx.requested[0].get_space_typed( + Shape1(nnz*sizeof(DType) + nnz*sizeof(int) + + (csr_cols + 1)*sizeof(int) + + (csr_cols + 1)*sizeof(int) + + temp_storage_bytes), + s); + csc_data_ptr = reinterpret_cast(workspace.dptr_); + csc_indices_ptr = reinterpret_cast(workspace.dptr_ + nnz*sizeof(DType)); + csc_indptr_ptr = reinterpret_cast(workspace.dptr_ + nnz*sizeof(DType) + + nnz*sizeof(int)); + col_counters = reinterpret_cast(workspace.dptr_ + nnz*sizeof(DType) + + nnz*sizeof(int) + (csr_cols+1)*sizeof(int)); + temp_storage = reinterpret_cast(workspace.dptr_ + nnz*sizeof(DType) + + nnz*sizeof(int) + (csr_cols+1)*sizeof(int) + + (csr_cols + 1)*sizeof(int)); + mxnet_op::Kernel::Launch( + s, dns_rows*csr_cols, ret->data().dptr()); + // Reset values for indptr, ready for histogramming + mxnet_op::Kernel::Launch( + s, csr_cols + 1, csc_indptr_ptr); + // Histogramming on col id + mxnet_op::Kernel::Launch( + s, nnz, csr_indices.dptr(), csc_indptr_ptr, nnz); + cub::DeviceScan::ExclusiveSum(temp_storage, + temp_storage_bytes, + csc_indptr_ptr, + csc_indptr_ptr, + csr_cols+1, + Stream::GetStream(s)); + // Reset values for col_counter, ready for the final transform + mxnet_op::Kernel::Launch( + s, csr_cols+1, col_counters); + // Transform to CSC + mxnet_op::Kernel::Launch( + s, csr_rows, csr_data.dptr(), csr_indices.dptr(), + csr_indptr.dptr(), csc_data_ptr, csc_indices_ptr, + csc_indptr_ptr, col_counters, csr_rows, csr_cols); + + mxnet_op::Kernel::Launch( + s, 
out_num_rows * out_num_cols, dns.dptr(), + csc_data_ptr, csc_indices_ptr, csc_indptr_ptr, + ret->data().dptr(), dns.shape_[1], + out_num_rows, out_num_cols); + }); + }); + }); + } else { + MSHADOW_SGL_DBL_TYPE_SWITCH(csr_data.type_flag_, DType, { // data type + MSHADOW_IDX_TYPE_SWITCH(csr_indices.type_flag_, IType, { // indptr type + MSHADOW_IDX_TYPE_SWITCH(csr_indptr.type_flag_, CType, { // colidx type + CType out_num_rows = ret->shape()[0]; + CType out_num_cols = ret->shape()[1]; + mxnet_op::Kernel::Launch( + s, out_num_rows * out_num_cols, dns.dptr(), + csr_data.dptr(), csr_indices.dptr(), + csr_indptr.dptr(), ret->data().dptr(), + dns.shape_[1], out_num_rows, out_num_cols); + }); + }); + }); + } +} + } // namespace op } // namespace mxnet diff --git a/src/operator/tensor/dot-inl.h b/src/operator/tensor/dot-inl.h index 83571d9e4d2c..bf0c59150dcb 100644 --- a/src/operator/tensor/dot-inl.h +++ b/src/operator/tensor/dot-inl.h @@ -236,13 +236,21 @@ inline bool DotForwardInferStorageType(const nnvm::NodeAttrs& attrs, DispatchMode::kFComputeEx); } if (!dispatched && lhs_stype == kDefaultStorage && rhs_stype == kCSRStorage && - !param.transpose_a && !param.transpose_b) { + !param.transpose_a) { // dns, csr -> csr - const bool invalid_ctx = dev_mask != mshadow::cpu::kDevMask; - const auto dispatch_ex = invalid_ctx ? DispatchMode::kFComputeFallback - : DispatchMode::kFComputeEx; - dispatched = storage_type_assign(&out_stype, kCSRStorage, dispatch_mode, - dispatch_ex); + // const bool invalid_ctx = dev_mask != mshadow::cpu::kDevMask; + // const auto dispatch_ex = invalid_ctx ? DispatchMode::kFComputeFallback + // : DispatchMode::kFComputeEx; + // dispatched = storage_type_assign(&out_stype, kCSRStorage, dispatch_mode, + // dispatch_ex); + if (dev_mask == mshadow::cpu::kDevMask) { + CHECK(!param.transpose_b) << "transposing rhs of the sparse dot op is not supported"; + dispatched = storage_type_assign(&out_stype, kCSRStorage, dispatch_mode, + DispatchMode::kFComputeEx); + } else if (dev_mask == mshadow::gpu::kDevMask) { + dispatched = storage_type_assign(&out_stype, kDefaultStorage, dispatch_mode, + DispatchMode::kFComputeEx); + } } if (!dispatched) { dispatched = dispatch_fallback(out_attrs, dispatch_mode); @@ -897,94 +905,21 @@ inline void DotCsrRspRspImpl(const OpContext& ctx, } /* - * \brief CPU Impl of dot(dns, csr) = csr + * \brief Impl of dot(dns, csr) = csr */ template inline void DotDnsCsrCsrImpl(const OpContext& ctx, - const TBlob& lhs, const NDArray& rhs, - const OpReqType req, NDArray* ret) { - if (kNullOp == req) return; - - CHECK_EQ(req, kWriteTo); - CHECK_EQ(rhs.storage_type(), kCSRStorage); - - using namespace mshadow; - using namespace mshadow::expr; - using nnvm::dim_t; - - /* Initialize data structures */ - mshadow::Stream* s = ctx.get_stream(); - const NDArray& out = *ret; - const TBlob data_l = lhs; - const TBlob data_r = rhs.data(); - const TBlob indptr_r = rhs.aux_data(csr::kIndPtr); - const TBlob col_idx_r = rhs.aux_data(csr::kIdx); - if (!rhs.storage_initialized()) { - FillZerosCsrImpl(s, *ret); - return; - } - - MSHADOW_SGL_DBL_TYPE_SWITCH(data_r.type_flag_, DType, { // data type - MSHADOW_IDX_TYPE_SWITCH(indptr_r.type_flag_, IType, { // indptr type - MSHADOW_IDX_TYPE_SWITCH(col_idx_r.type_flag_, CType, { // colidx type - /* Allocate workspace */ - CType num_cols_out = out.shape()[1]; - CType rhs_data_size = static_cast(col_idx_r.shape_.Size()); - size_t workspace_size = 2 * num_cols_out * sizeof(CType); - Tensor workspace = - ctx.requested[0].get_space_typed( - 
Shape1(workspace_size), s); - CType* col_flg = reinterpret_cast(workspace.dptr_); + const TBlob& dns, const NDArray& rhs, + const OpReqType req, NDArray* ret); - CType* prefix_sum = col_flg; - CType* nnc_idx = prefix_sum + num_cols_out; - - /* Set the column flags for nnz columns */ - mxnet_op::Kernel::Launch(s, num_cols_out, - col_flg); - mxnet_op::Kernel::Launch( - s, rhs_data_size, col_flg, col_idx_r.dptr()); - - /* 1. Calculate prefix sum from col flgs - * 2. Storage all non zero column indexes in nnc_idx - */ - CType cur = 0; - prefix_sum[0] = col_flg[0]; - if (prefix_sum[0]) nnc_idx[cur++] = 0; - for (CType i = 1; i < num_cols_out; i++) { - prefix_sum[i] = prefix_sum[i - 1] + col_flg[i]; - if (prefix_sum[i] > prefix_sum[i - 1]) nnc_idx[cur++] = i; - } - - /* Allocate aux data for out */ - IType num_rows_l = lhs.shape_[0]; - dim_t nnc = prefix_sum[num_cols_out - 1]; - dim_t nnz = nnc * num_rows_l; - out.CheckAndAllocAuxData(csr::kIndPtr, Shape1(num_rows_l + 1)); - out.CheckAndAllocAuxData(csr::kIdx, Shape1(nnz)); - out.CheckAndAllocData(Shape1(nnz)); - - /* Set csr indptr and index according to nnc_idx*/ - IType* indptr_out = out.aux_data(csr::kIndPtr).dptr(); - CType* col_idx_out = out.aux_data(csr::kIdx).dptr(); - DType* data_out = out.data().dptr(); - mxnet_op::Kernel::Launch( - s, num_rows_l, nnc_idx, indptr_out, col_idx_out, nnc, num_rows_l); - mxnet_op::Kernel::Launch(s, nnz, data_out); - - const dim_t num_threads = mxnet_op::get_num_threads(num_rows_l); - const dim_t seg_len = (num_rows_l + num_threads - 1) / num_threads; - - IType num_rows_r = rhs.shape()[0]; - mxnet_op::Kernel::Launch( - s, num_threads, data_out, data_l.dptr(), - indptr_r.dptr(), col_idx_r.dptr(), - data_r.dptr(), seg_len, num_rows_r, num_rows_l, num_cols_out, - nnc, prefix_sum); - }); - }); - }); -} +/* + * \brief Impl of dot(dns, csr) = dense (GPU only) + */ +template +inline void DotDnsCsrDnsImpl(const OpContext& ctx, + const TBlob& dns, const NDArray& rhs, + const OpReqType req, NDArray* ret, + const bool transpose_b); inline bool DotShape(const nnvm::NodeAttrs& attrs, std::vector *in_attrs, @@ -1039,7 +974,7 @@ void DotForwardEx(const nnvm::NodeAttrs& attrs, CHECK_EQ(outputs.size(), 1U); CHECK_EQ(req.size(), 1U); const DotParam& param = nnvm::get(attrs.parsed); - CHECK(!param.transpose_b) << "transposing rhs of the sparse dot op is not supported"; + // CHECK(!param.transpose_b) << "transposing rhs of the sparse dot op is not supported"; CHECK_EQ(inputs[0].shape().ndim(), 2) << "sparse dot only supports 2 dimensional lhs"; CHECK_EQ(inputs[1].shape().ndim(), 2) << "sparse dot only supports 2 dimensional rhs"; auto lhs_stype = inputs[0].storage_type(); @@ -1066,6 +1001,10 @@ void DotForwardEx(const nnvm::NodeAttrs& attrs, !(param.transpose_a || param.transpose_b)) { NDArray ret = outputs[0]; DotDnsCsrCsrImpl(ctx, inputs[0].data(), inputs[1], req[0], &ret); + } else if (lhs_stype == kDefaultStorage && rhs_stype == kCSRStorage && + out_stype == kDefaultStorage && !(param.transpose_a)) { + NDArray ret = outputs[0]; + DotDnsCsrDnsImpl(ctx, inputs[0].data(), inputs[1], req[0], &ret, param.transpose_b); } else { LogUnimplementedOp(attrs, ctx, inputs, req, outputs); } diff --git a/src/operator/tensor/dot.cc b/src/operator/tensor/dot.cc index 834b559b86f6..f563d7411d66 100644 --- a/src/operator/tensor/dot.cc +++ b/src/operator/tensor/dot.cc @@ -28,6 +28,105 @@ namespace mxnet { namespace op { DMLC_REGISTER_PARAMETER(DotParam); +/* + * \brief CPU Impl of dot(dns, csr) = csr + */ +template +inline void 
DotDnsCsrCsrImpl(const OpContext& ctx, + const TBlob& lhs, const NDArray& rhs, + const OpReqType req, NDArray* ret) { + if (kNullOp == req) return; + + CHECK_EQ(req, kWriteTo); + CHECK_EQ(rhs.storage_type(), kCSRStorage); + + using namespace mshadow; + using namespace mshadow::expr; + using nnvm::dim_t; + + /* Initialize data structures */ + mshadow::Stream* s = ctx.get_stream(); + const NDArray& out = *ret; + const TBlob data_l = lhs; + const TBlob data_r = rhs.data(); + const TBlob indptr_r = rhs.aux_data(csr::kIndPtr); + const TBlob col_idx_r = rhs.aux_data(csr::kIdx); + if (!rhs.storage_initialized()) { + FillZerosCsrImpl(s, *ret); + return; + } + + MSHADOW_SGL_DBL_TYPE_SWITCH(data_r.type_flag_, DType, { // data type + MSHADOW_IDX_TYPE_SWITCH(indptr_r.type_flag_, IType, { // indptr type + MSHADOW_IDX_TYPE_SWITCH(col_idx_r.type_flag_, CType, { // colidx type + /* Allocate workspace */ + CType num_cols_out = out.shape()[1]; + CType rhs_data_size = static_cast(col_idx_r.shape_.Size()); + size_t workspace_size = 2 * num_cols_out * sizeof(CType); + Tensor workspace = + ctx.requested[0].get_space_typed( + Shape1(workspace_size), s); + CType* col_flg = reinterpret_cast(workspace.dptr_); + + CType* prefix_sum = col_flg; + CType* nnc_idx = prefix_sum + num_cols_out; + + /* Set the column flags for nnz columns */ + mxnet_op::Kernel::Launch(s, num_cols_out, + col_flg); + mxnet_op::Kernel::Launch( + s, rhs_data_size, col_flg, col_idx_r.dptr()); + + /* 1. Calculate prefix sum from col flgs + * 2. Storage all non zero column indexes in nnc_idx + */ + CType cur = 0; + prefix_sum[0] = col_flg[0]; + if (prefix_sum[0]) nnc_idx[cur++] = 0; + for (CType i = 1; i < num_cols_out; i++) { + prefix_sum[i] = prefix_sum[i - 1] + col_flg[i]; + if (prefix_sum[i] > prefix_sum[i - 1]) nnc_idx[cur++] = i; + } + + /* Allocate aux data for out */ + IType num_rows_l = lhs.shape_[0]; + dim_t nnc = prefix_sum[num_cols_out - 1]; + dim_t nnz = nnc * num_rows_l; + out.CheckAndAllocAuxData(csr::kIndPtr, Shape1(num_rows_l + 1)); + out.CheckAndAllocAuxData(csr::kIdx, Shape1(nnz)); + out.CheckAndAllocData(Shape1(nnz)); + + /* Set csr indptr and index according to nnc_idx*/ + IType* indptr_out = out.aux_data(csr::kIndPtr).dptr(); + CType* col_idx_out = out.aux_data(csr::kIdx).dptr(); + DType* data_out = out.data().dptr(); + mxnet_op::Kernel::Launch( + s, num_rows_l, nnc_idx, indptr_out, col_idx_out, nnc, num_rows_l); + mxnet_op::Kernel::Launch(s, nnz, data_out); + + const dim_t num_threads = mxnet_op::get_num_threads(num_rows_l); + const dim_t seg_len = (num_rows_l + num_threads - 1) / num_threads; + + IType num_rows_r = rhs.shape()[0]; + mxnet_op::Kernel::Launch( + s, num_threads, data_out, data_l.dptr(), + indptr_r.dptr(), col_idx_r.dptr(), + data_r.dptr(), seg_len, num_rows_r, num_rows_l, num_cols_out, + nnc, prefix_sum); + }); + }); + }); +} + + +template +inline void DotDnsCsrDnsImpl(const OpContext& ctx, + const TBlob& dns, const NDArray& rhs, + const OpReqType req, NDArray* ret, + const bool transpose_b) { + LOG(FATAL) << "dot(dense, csr) = dense is not implemented on CPU"; +} + NNVM_REGISTER_OP(dot) .add_alias("_sparse_dot") // alias for op registration under mxnet.ndarray.sparse .describe(R"doc(Dot product of two arrays. 
diff --git a/src/operator/tensor/dot.cu b/src/operator/tensor/dot.cu index 8ee2e2832fbb..ac514f49df8f 100644 --- a/src/operator/tensor/dot.cu +++ b/src/operator/tensor/dot.cu @@ -23,6 +23,7 @@ */ #include "./dot-inl.h" +#include namespace mxnet { namespace op { From ab3461c8cadb975d205cd258a59a6849151ed423 Mon Sep 17 00:00:00 2001 From: Hao Jin Date: Thu, 5 Apr 2018 23:33:08 +0000 Subject: [PATCH 2/7] add unit test for new op and forward_stype_hint parameter to dot --- src/operator/tensor/dot-inl.cuh | 39 ++++++----- src/operator/tensor/dot-inl.h | 67 +++++++++++++------ tests/python/unittest/test_sparse_operator.py | 33 ++++++++- 3 files changed, 98 insertions(+), 41 deletions(-) diff --git a/src/operator/tensor/dot-inl.cuh b/src/operator/tensor/dot-inl.cuh index fcf9440a77d2..a65936840662 100644 --- a/src/operator/tensor/dot-inl.cuh +++ b/src/operator/tensor/dot-inl.cuh @@ -453,13 +453,12 @@ struct CscDataIndicesKernel { const IType* csr_indices, const CType* csr_indptr, DType* csc_data, - int* csc_indices, - int* csc_indptr, - int* workspace, + unsigned long long* csc_indices, + unsigned long long* csc_indptr, + unsigned long long* workspace, const nnvm::dim_t num_rows, const nnvm::dim_t num_cols) { if (tid < num_rows) { - printf("%d, %d ", tid, csc_indptr[tid]); for (CType i = csr_indptr[tid]; i < csr_indptr[tid + 1]; ++i) { // target column IType target_col = csr_indices[i]; @@ -487,7 +486,7 @@ struct CsrTransHistogramKernel { template __device__ __forceinline__ static void Map(int tid, const IType* in_indices, - int* out_indptr, + unsigned long long* out_indptr, const nnvm::dim_t nnz) { if (tid < nnz) { atomicAdd(&out_indptr[in_indices[tid]], 1); @@ -1036,9 +1035,10 @@ inline void DotDnsCsrDnsImpl(const OpContext& ctx, MSHADOW_IDX_TYPE_SWITCH(csr_indices.type_flag_, IType, { MSHADOW_IDX_TYPE_SWITCH(csr_indptr.type_flag_, CType, { DType* csc_data_ptr = NULL; - int* csc_indices_ptr = NULL; - int* csc_indptr_ptr = NULL; - int* col_counters = NULL; + unsigned long long* csc_indices_ptr = NULL; + unsigned long long* csc_indptr_ptr = NULL; + unsigned long long* col_counters = NULL; + size_t ull_mem_size = sizeof(unsigned long long); void* temp_storage = NULL; size_t temp_storage_bytes = 0; CType out_num_rows = ret->shape()[0]; @@ -1050,22 +1050,22 @@ inline void DotDnsCsrDnsImpl(const OpContext& ctx, csc_indices_ptr, csr_cols+1, Stream::GetStream(s)); + temp_storage_bytes += (ull_mem_size - (temp_storage_bytes % ull_mem_size)); Tensor workspace = ctx.requested[0].get_space_typed( - Shape1(nnz*sizeof(DType) + nnz*sizeof(int) + - (csr_cols + 1)*sizeof(int) + - (csr_cols + 1)*sizeof(int) + + Shape1(nnz*sizeof(DType) + nnz*ull_mem_size + + 2*(csr_cols + 1)*ull_mem_size + temp_storage_bytes), s); - csc_data_ptr = reinterpret_cast(workspace.dptr_); - csc_indices_ptr = reinterpret_cast(workspace.dptr_ + nnz*sizeof(DType)); - csc_indptr_ptr = reinterpret_cast(workspace.dptr_ + nnz*sizeof(DType) + - nnz*sizeof(int)); - col_counters = reinterpret_cast(workspace.dptr_ + nnz*sizeof(DType) + - nnz*sizeof(int) + (csr_cols+1)*sizeof(int)); + csc_indices_ptr = reinterpret_cast(workspace.dptr_); + csc_indptr_ptr = reinterpret_cast( + workspace.dptr_ + nnz*ull_mem_size); + col_counters = reinterpret_cast( + workspace.dptr_ + nnz*ull_mem_size + (csr_cols+1)*ull_mem_size); + csc_data_ptr = reinterpret_cast(workspace.dptr_ + nnz*ull_mem_size + + 2*(csr_cols+1)*ull_mem_size); temp_storage = reinterpret_cast(workspace.dptr_ + nnz*sizeof(DType) + - nnz*sizeof(int) + (csr_cols+1)*sizeof(int) + - (csr_cols + 
1)*sizeof(int)); + nnz*ull_mem_size + 2*(csr_cols+1)*ull_mem_size); mxnet_op::Kernel::Launch( s, dns_rows*csr_cols, ret->data().dptr()); // Reset values for indptr, ready for histogramming @@ -1088,7 +1088,6 @@ inline void DotDnsCsrDnsImpl(const OpContext& ctx, s, csr_rows, csr_data.dptr(), csr_indices.dptr(), csr_indptr.dptr(), csc_data_ptr, csc_indices_ptr, csc_indptr_ptr, col_counters, csr_rows, csr_cols); - mxnet_op::Kernel::Launch( s, out_num_rows * out_num_cols, dns.dptr(), csc_data_ptr, csc_indices_ptr, csc_indptr_ptr, diff --git a/src/operator/tensor/dot-inl.h b/src/operator/tensor/dot-inl.h index bf0c59150dcb..92773bee7de0 100644 --- a/src/operator/tensor/dot-inl.h +++ b/src/operator/tensor/dot-inl.h @@ -45,6 +45,7 @@ namespace op { struct DotParam : public dmlc::Parameter { bool transpose_a; bool transpose_b; + dmlc::optional forward_stype_hint; DMLC_DECLARE_PARAMETER(DotParam) { DMLC_DECLARE_FIELD(transpose_a) .describe("If true then transpose the first input before dot.") @@ -52,6 +53,12 @@ struct DotParam : public dmlc::Parameter { DMLC_DECLARE_FIELD(transpose_b) .describe("If true then transpose the second input before dot.") .set_default(false); + DMLC_DECLARE_FIELD(forward_stype_hint) + .describe("Desired storage type of the forward output.") + .add_enum("default", kDefaultStorage) + .add_enum("row_sparse", kRowSparseStorage) + .add_enum("csr", kCSRStorage) + .set_default(dmlc::optional()); } }; @@ -217,39 +224,59 @@ inline bool DotForwardInferStorageType(const nnvm::NodeAttrs& attrs, bool only_lhs_transpose = param.transpose_a && !param.transpose_b; bool rhs_rsp_or_dns = rhs_stype == kRowSparseStorage || rhs_stype == kDefaultStorage; + NDArrayStorageType target_stype; if (!dispatched && lhs_stype == kDefaultStorage && rhs_stype == kDefaultStorage) { // dns, dns -> dns - dispatched = storage_type_assign(&out_stype, kDefaultStorage, dispatch_mode, - DispatchMode::kFCompute); + target_stype = (param.forward_stype_hint.has_value())? + static_cast(param.forward_stype_hint.value()) : + kDefaultStorage; + if (target_stype == kDefaultStorage) { + dispatched = storage_type_assign(&out_stype, kDefaultStorage, dispatch_mode, + DispatchMode::kFCompute); + } } - if (!dispatched && lhs_stype == kCSRStorage && only_lhs_transpose && - (rhs_stype == kRowSparseStorage || rhs_stype == kDefaultStorage)) { + if (!dispatched && lhs_stype == kCSRStorage && only_lhs_transpose && rhs_rsp_or_dns) { // csr.T, rsp/dns -> rsp - dispatched = storage_type_assign(&out_stype, kRowSparseStorage, - dispatch_mode, DispatchMode::kFComputeEx); + target_stype = (param.forward_stype_hint.has_value())? + static_cast(param.forward_stype_hint.value()) : + kRowSparseStorage; + if (target_stype == kRowSparseStorage) { + dispatched = storage_type_assign(&out_stype, kRowSparseStorage, + dispatch_mode, DispatchMode::kFComputeEx); + } } if (!dispatched && lhs_stype == kCSRStorage && rhs_rsp_or_dns && !param.transpose_a && !param.transpose_b) { // csr, rsp/dns -> dns - dispatched = storage_type_assign(&out_stype, kDefaultStorage, dispatch_mode, - DispatchMode::kFComputeEx); + target_stype = (param.forward_stype_hint.has_value())? 
+ static_cast(param.forward_stype_hint.value()) : + kDefaultStorage; + if (target_stype == kDefaultStorage) { + dispatched = storage_type_assign(&out_stype, kDefaultStorage, dispatch_mode, + DispatchMode::kFComputeEx); + } } if (!dispatched && lhs_stype == kDefaultStorage && rhs_stype == kCSRStorage && !param.transpose_a) { - // dns, csr -> csr - // const bool invalid_ctx = dev_mask != mshadow::cpu::kDevMask; - // const auto dispatch_ex = invalid_ctx ? DispatchMode::kFComputeFallback - // : DispatchMode::kFComputeEx; - // dispatched = storage_type_assign(&out_stype, kCSRStorage, dispatch_mode, - // dispatch_ex); - if (dev_mask == mshadow::cpu::kDevMask) { - CHECK(!param.transpose_b) << "transposing rhs of the sparse dot op is not supported"; - dispatched = storage_type_assign(&out_stype, kCSRStorage, dispatch_mode, - DispatchMode::kFComputeEx); + // dns, csr -> csr on CPU + if (dev_mask == mshadow::cpu::kDevMask && !param.transpose_b) { + target_stype = (param.forward_stype_hint.has_value())? + static_cast(param.forward_stype_hint.value()) : + kCSRStorage; + if (target_stype == kCSRStorage) { + dispatched = storage_type_assign(&out_stype, kCSRStorage, dispatch_mode, + DispatchMode::kFComputeEx); + } + // dns, csr/csr.T -> dns on GPU } else if (dev_mask == mshadow::gpu::kDevMask) { - dispatched = storage_type_assign(&out_stype, kDefaultStorage, dispatch_mode, - DispatchMode::kFComputeEx); + target_stype = (param.forward_stype_hint.has_value())? + static_cast(param.forward_stype_hint.value()) : + kDefaultStorage; + if (target_stype == kDefaultStorage) { + dispatched = storage_type_assign(&out_stype, kDefaultStorage, dispatch_mode, + DispatchMode::kFComputeEx); + } } } if (!dispatched) { diff --git a/tests/python/unittest/test_sparse_operator.py b/tests/python/unittest/test_sparse_operator.py index 484c98643d91..7dcc670519b6 100644 --- a/tests/python/unittest/test_sparse_operator.py +++ b/tests/python/unittest/test_sparse_operator.py @@ -1208,6 +1208,27 @@ def check_cast_storage(shape, density, from_stype, to_stype, check_numeric_grad= @with_seed() def test_sparse_dot(): + def test_infer_forward_stype(lhs_shape, rhs_shape, lhs_density, rhs_density, trans_a, trans_b): + all_stypes = ["default", "csr", "row_sparse"] + lhs_nd = rand_ndarray(lhs_shape, 'default', density=lhs_density) + rhs_nd = rand_ndarray(rhs_shape, 'default', density=rhs_density) + out_nd = mx.nd.dot(lhs_nd, rhs_nd, transpose_a=trans_a, transpose_b=trans_b) + out_np = out_nd.asnumpy() + for lhs_stype in all_stypes: + for rhs_stype in all_stypes: + for forward_stype in all_stypes: + lhs = lhs_nd.tostype(lhs_stype) + rhs = rhs_nd.tostype(rhs_stype) + out = mx.nd.dot(lhs, rhs, forward_stype_hint=forward_stype, + transpose_a=trans_a, transpose_b=trans_b) + assert_almost_equal(out.tostype('default').asnumpy(), out_np, rtol=1e-4, atol=1e-5) + lhs_var = mx.symbol.Variable('lhs', stype=lhs_stype) + rhs_var = mx.symbol.Variable('rhs', stype=rhs_stype) + out = mx.symbol.sparse.dot(lhs_var, rhs_var, + forward_stype_hint=forward_stype, + transpose_a=trans_a, transpose_b=trans_b) + location = {'lhs': lhs, 'rhs': rhs} + check_symbolic_forward(out, location, [out_np], rtol=1e-3, atol=1e-4) def test_dot_csr(lhs_shape, rhs_shape, rhs_stype, trans_lhs, lhs_density, rhs_density): lhs_nd = rand_ndarray(lhs_shape, 'csr', density=lhs_density, shuffle_csr_indices=False) lhs_dns = lhs_nd.tostype('default') @@ -1254,6 +1275,8 @@ def test_dot_dns_csr(lhs_shape, rhs_shape, lhs_density, rhs_density, trans_lhs=F # test symbolic backward backward_trans = 
not trans_lhs rhs_backward_grad = mx.nd.dot(lhs_nd, out_dns, transpose_a=backward_trans).asnumpy() + if trans_rhs is True: + rhs_backward_grad = rhs_backward_grad.T expected = {'rhs': rhs_backward_grad} check_symbolic_backward(out, location, [out_np], expected, grad_req={'lhs': 'null', 'rhs': 'write'}, @@ -1285,10 +1308,18 @@ def test_sparse_dot_zero_output(lhs_shape, trans_lhs, rhs_num_cols): test_dot_csr(lhs_shape, (lhs_shape[1], rnd.randint(5, 10)), 'default', False, lhs_d, rhs_d) # test gpu SpMM test_dot_csr(lhs_shape, (lhs_shape[0], rnd.randint(5, 10)), 'default', True, lhs_d, rhs_d) # (scalar kernel) test_dot_dns_csr(lhs_shape, (lhs_shape[1], rnd.randint(50, 200)), lhs_d, lhs_d) + test_dot_dns_csr(lhs_shape, (rnd.randint(50, 200), lhs_shape[1]), lhs_d, lhs_d, trans_rhs=True) for rhs_d in density: test_dot_csr(lhs_shape, (lhs_shape[1], rnd.randint(1, 10)), 'row_sparse', False, lhs_d, rhs_d) test_dot_csr(lhs_shape, (lhs_shape[0], rnd.randint(1, 10)), 'row_sparse', True, lhs_d, rhs_d) - + test_infer_forward_stype(lhs_shape, (lhs_shape[1], rnd.randint(10, 20)), + lhs_d, rhs_d, False, False) + test_infer_forward_stype(lhs_shape, (rnd.randint(10, 20), lhs_shape[1]), + lhs_d, rhs_d, False, True) + test_infer_forward_stype(lhs_shape, (lhs_shape[0], rnd.randint(10, 20)), + lhs_d, rhs_d, True, False) + test_infer_forward_stype(lhs_shape, (rnd.randint(10, 20), lhs_shape[0]), + lhs_d, rhs_d, True, True) test_sparse_dot_zero_output(rand_shape_2d(50, 200), False, 40) test_sparse_dot_zero_output(rand_shape_2d(50, 200), True, 40) From 15100e5410ce219983717df14e9e114c6aff8f41 Mon Sep 17 00:00:00 2001 From: Hao Jin Date: Sat, 7 Apr 2018 17:30:16 +0000 Subject: [PATCH 3/7] update documentation for dot --- src/operator/tensor/dot-inl.h | 1 - src/operator/tensor/dot.cc | 8 ++++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/src/operator/tensor/dot-inl.h index 92773bee7de0..14dba5a6da28 100644 --- a/src/operator/tensor/dot-inl.h +++ b/src/operator/tensor/dot-inl.h @@ -1001,7 +1001,6 @@ void DotForwardEx(const nnvm::NodeAttrs& attrs, CHECK_EQ(outputs.size(), 1U); CHECK_EQ(req.size(), 1U); const DotParam& param = nnvm::get(attrs.parsed); - // CHECK(!param.transpose_b) << "transposing rhs of the sparse dot op is not supported"; CHECK_EQ(inputs[0].shape().ndim(), 2) << "sparse dot only supports 2 dimensional lhs"; CHECK_EQ(inputs[1].shape().ndim(), 2) << "sparse dot only supports 2 dimensional rhs"; auto lhs_stype = inputs[0].storage_type(); diff --git a/src/operator/tensor/dot.cc index f563d7411d66..d7108085fe30 100644 --- a/src/operator/tensor/dot.cc +++ b/src/operator/tensor/dot.cc @@ -150,12 +150,16 @@ NNVM_REGISTER_OP(dot) dot(x,y)[0,0,1,1] = 0 sum(x[0,0,:]*y[:,1,1]) = 0 -The storage type of ``dot`` output depends on storage types of inputs and transpose options: +The storage type of ``dot`` output depends on storage types of inputs, transpose options and the given +hint for output storage type: +Implemented sparse operations include: - dot(csr, default) = default - dot(csr.T, default) = row_sparse - dot(csr, row_sparse) = default -- dot(default, csr) = csr +- dot(default, csr) = csr on CPU only +- dot(default, csr) = dense on GPU only +- dot(default, csr.T) = dense on GPU only - otherwise, ``dot`` generates output with default storage )doc" ADD_FILELINE) From d5a2c5cfec0eea0526cbffdf1c1b3312422724ac Mon Sep 17 00:00:00 2001 From: Hao Jin Date: Fri, 13 Apr 2018 01:07:47 +0000 Subject: [PATCH 4/7] address code reviews ---
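
For illustration, the storage-type table documented in PATCH 3 above can be exercised from Python roughly as follows. This is a minimal sketch, not part of the patch series: the shapes and the GPU context are illustrative only, and at this point in the series the hint parameter is still named forward_stype_hint (PATCH 6 below renames it to forward_stype).

import mxnet as mx

# dot(default, csr) = dense is listed above as a GPU-only path
lhs = mx.nd.random.uniform(shape=(64, 128), ctx=mx.gpu())
rhs = mx.nd.random.uniform(shape=(128, 256), ctx=mx.gpu()).tostype('csr')
out = mx.nd.sparse.dot(lhs, rhs, forward_stype_hint='default')
assert out.stype == 'default'
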
src/operator/tensor/dot-inl.cuh | 108 ++++++++++++++++---------------- src/operator/tensor/dot-inl.h | 15 +++-- src/operator/tensor/dot.cc | 8 ++- 3 files changed, 68 insertions(+), 63 deletions(-) diff --git a/src/operator/tensor/dot-inl.cuh b/src/operator/tensor/dot-inl.cuh index a65936840662..19cac543bf50 100644 --- a/src/operator/tensor/dot-inl.cuh +++ b/src/operator/tensor/dot-inl.cuh @@ -30,6 +30,8 @@ #include "./util/tensor_util-inl.h" #include "./util/tensor_util-inl.cuh" +typedef unsigned long long AtomicIType; + namespace mxnet { namespace op { @@ -453,17 +455,17 @@ struct CscDataIndicesKernel { const IType* csr_indices, const CType* csr_indptr, DType* csc_data, - unsigned long long* csc_indices, - unsigned long long* csc_indptr, - unsigned long long* workspace, + AtomicIType* csc_indices, + AtomicIType* csc_indptr, + AtomicIType* col_counters, const nnvm::dim_t num_rows, const nnvm::dim_t num_cols) { if (tid < num_rows) { for (CType i = csr_indptr[tid]; i < csr_indptr[tid + 1]; ++i) { // target column - IType target_col = csr_indices[i]; - int target_offset = atomicAdd(&workspace[target_col], 1); - int new_pos = csc_indptr[target_col] + target_offset; + const IType target_col = csr_indices[i]; + const int target_offset = atomicAdd(&col_counters[target_col], 1); + const int new_pos = csc_indptr[target_col] + target_offset; csc_data[new_pos] = csr_data[i]; csc_indices[new_pos] = tid; } @@ -486,7 +488,7 @@ struct CsrTransHistogramKernel { template __device__ __forceinline__ static void Map(int tid, const IType* in_indices, - unsigned long long* out_indptr, + AtomicIType* out_indptr, const nnvm::dim_t nnz) { if (tid < nnz) { atomicAdd(&out_indptr[in_indices[tid]], 1); @@ -1023,54 +1025,60 @@ inline void DotDnsCsrDnsImpl(const OpContext& ctx, return; } - // if dot(dense, csr) = dns, transform to csc first - if (!transpose_b) { - // LOG(FATAL) << "dot(dns, csr) = dns not implemented yet"; - const nnvm::dim_t csr_rows = rhs.shape()[0]; - const nnvm::dim_t csr_cols = rhs.shape()[1]; - const nnvm::dim_t dns_rows = dns.shape_[0]; - const nnvm::dim_t nnz = rhs.storage_shape().Size(); - - MSHADOW_SGL_DBL_TYPE_SWITCH(csr_data.type_flag_, DType, { - MSHADOW_IDX_TYPE_SWITCH(csr_indices.type_flag_, IType, { - MSHADOW_IDX_TYPE_SWITCH(csr_indptr.type_flag_, CType, { + MSHADOW_SGL_DBL_TYPE_SWITCH(csr_data.type_flag_, DType, { // data type + MSHADOW_IDX_TYPE_SWITCH(csr_indices.type_flag_, IType, { // indptr type + MSHADOW_IDX_TYPE_SWITCH(csr_indptr.type_flag_, CType, { // colidx type + const CType out_num_rows = ret->shape()[0]; + const CType out_num_cols = ret->shape()[1]; + // if dot(dense, csr) = dns, transform to csc first + if (!transpose_b) { + const nnvm::dim_t num_csr_rows = rhs.shape()[0]; + const nnvm::dim_t num_csr_cols = rhs.shape()[1]; + const nnvm::dim_t num_dns_rows = dns.shape_[0]; + const nnvm::dim_t nnz = rhs.storage_shape().Size(); + DType* csc_data_ptr = NULL; - unsigned long long* csc_indices_ptr = NULL; - unsigned long long* csc_indptr_ptr = NULL; - unsigned long long* col_counters = NULL; - size_t ull_mem_size = sizeof(unsigned long long); + AtomicIType* csc_indices_ptr = NULL; + AtomicIType* csc_indptr_ptr = NULL; + AtomicIType* col_counters = NULL; + size_t ull_num_bytes = sizeof(AtomicIType); void* temp_storage = NULL; size_t temp_storage_bytes = 0; - CType out_num_rows = ret->shape()[0]; - CType out_num_cols = ret->shape()[1]; + // Get necessary temporary storage amount cub::DeviceScan::ExclusiveSum(NULL, temp_storage_bytes, csc_indices_ptr, csc_indices_ptr, - csr_cols+1, 
+ num_csr_cols + 1, Stream::GetStream(s)); - temp_storage_bytes += (ull_mem_size - (temp_storage_bytes % ull_mem_size)); + // Align to multiple of ull_num_bytes + temp_storage_bytes += (ull_num_bytes - (temp_storage_bytes % ull_num_bytes)); + size_t csc_data_size = nnz*sizeof(DType); + size_t csc_indices_size = nnz*ull_num_bytes; + size_t csc_indptr_size = (num_csr_cols+1)*ull_num_bytes; + size_t col_counters_size = (num_csr_cols+1)*ull_num_bytes; Tensor workspace = ctx.requested[0].get_space_typed( - Shape1(nnz*sizeof(DType) + nnz*ull_mem_size + - 2*(csr_cols + 1)*ull_mem_size + + Shape1(csc_data_size + csc_indices_size + + csc_indptr_size + col_counters_size + temp_storage_bytes), s); - csc_indices_ptr = reinterpret_cast(workspace.dptr_); - csc_indptr_ptr = reinterpret_cast( - workspace.dptr_ + nnz*ull_mem_size); - col_counters = reinterpret_cast( - workspace.dptr_ + nnz*ull_mem_size + (csr_cols+1)*ull_mem_size); - csc_data_ptr = reinterpret_cast(workspace.dptr_ + nnz*ull_mem_size + - 2*(csr_cols+1)*ull_mem_size); - temp_storage = reinterpret_cast(workspace.dptr_ + nnz*sizeof(DType) + - nnz*ull_mem_size + 2*(csr_cols+1)*ull_mem_size); + csc_indices_ptr = reinterpret_cast(workspace.dptr_); + csc_indptr_ptr = reinterpret_cast( + workspace.dptr_ + csc_indices_size); + col_counters = reinterpret_cast( + workspace.dptr_ + csc_indices_size + csc_indptr_size); + csc_data_ptr = reinterpret_cast(workspace.dptr_ + csc_indices_size + + csc_indptr_size + col_counters_size); + temp_storage = reinterpret_cast(workspace.dptr_ + csc_data_size + + csc_indices_size + csc_indptr_size + + col_counters_size); mxnet_op::Kernel::Launch( - s, dns_rows*csr_cols, ret->data().dptr()); + s, num_dns_rows*num_csr_cols, ret->data().dptr()); // Reset values for indptr, ready for histogramming mxnet_op::Kernel::Launch( - s, csr_cols + 1, csc_indptr_ptr); + s, num_csr_cols+1, csc_indptr_ptr); // Histogramming on col id mxnet_op::Kernel::Launch( s, nnz, csr_indices.dptr(), csc_indptr_ptr, nnz); @@ -1078,39 +1086,31 @@ inline void DotDnsCsrDnsImpl(const OpContext& ctx, temp_storage_bytes, csc_indptr_ptr, csc_indptr_ptr, - csr_cols+1, + num_csr_cols + 1, Stream::GetStream(s)); // Reset values for col_counter, ready for the final transform mxnet_op::Kernel::Launch( - s, csr_cols+1, col_counters); + s, num_csr_cols+1, col_counters); // Transform to CSC mxnet_op::Kernel::Launch( - s, csr_rows, csr_data.dptr(), csr_indices.dptr(), + s, num_csr_rows, csr_data.dptr(), csr_indices.dptr(), csr_indptr.dptr(), csc_data_ptr, csc_indices_ptr, - csc_indptr_ptr, col_counters, csr_rows, csr_cols); + csc_indptr_ptr, col_counters, num_csr_rows, num_csr_cols); mxnet_op::Kernel::Launch( s, out_num_rows * out_num_cols, dns.dptr(), csc_data_ptr, csc_indices_ptr, csc_indptr_ptr, ret->data().dptr(), dns.shape_[1], out_num_rows, out_num_cols); - }); - }); - }); - } else { - MSHADOW_SGL_DBL_TYPE_SWITCH(csr_data.type_flag_, DType, { // data type - MSHADOW_IDX_TYPE_SWITCH(csr_indices.type_flag_, IType, { // indptr type - MSHADOW_IDX_TYPE_SWITCH(csr_indptr.type_flag_, CType, { // colidx type - CType out_num_rows = ret->shape()[0]; - CType out_num_cols = ret->shape()[1]; + } else { mxnet_op::Kernel::Launch( s, out_num_rows * out_num_cols, dns.dptr(), csr_data.dptr(), csr_indices.dptr(), csr_indptr.dptr(), ret->data().dptr(), dns.shape_[1], out_num_rows, out_num_cols); - }); + } }); }); - } + }); } } // namespace op diff --git a/src/operator/tensor/dot-inl.h b/src/operator/tensor/dot-inl.h index 14dba5a6da28..c8a3f3954d8c 100644 --- 
a/src/operator/tensor/dot-inl.h +++ b/src/operator/tensor/dot-inl.h @@ -54,7 +54,9 @@ struct DotParam : public dmlc::Parameter { .describe("If true then transpose the second input before dot.") .set_default(false); DMLC_DECLARE_FIELD(forward_stype_hint) - .describe("Desired storage type of the forward output.") + .describe("Hint on the desired storage type of the forward output given by user," + "if the combination of input storage types and this hint does not match" + "any implemented ones, the dot operator will perform fallback operation.") .add_enum("default", kDefaultStorage) .add_enum("row_sparse", kRowSparseStorage) .add_enum("csr", kCSRStorage) @@ -225,10 +227,11 @@ inline bool DotForwardInferStorageType(const nnvm::NodeAttrs& attrs, bool rhs_rsp_or_dns = rhs_stype == kRowSparseStorage || rhs_stype == kDefaultStorage; NDArrayStorageType target_stype; + bool hint_has_value = param.forward_stype_hint.has_value(); if (!dispatched && lhs_stype == kDefaultStorage && rhs_stype == kDefaultStorage) { // dns, dns -> dns - target_stype = (param.forward_stype_hint.has_value())? + target_stype = hint_has_value ? static_cast(param.forward_stype_hint.value()) : kDefaultStorage; if (target_stype == kDefaultStorage) { @@ -238,7 +241,7 @@ inline bool DotForwardInferStorageType(const nnvm::NodeAttrs& attrs, } if (!dispatched && lhs_stype == kCSRStorage && only_lhs_transpose && rhs_rsp_or_dns) { // csr.T, rsp/dns -> rsp - target_stype = (param.forward_stype_hint.has_value())? + target_stype = hint_has_value ? static_cast(param.forward_stype_hint.value()) : kRowSparseStorage; if (target_stype == kRowSparseStorage) { @@ -249,7 +252,7 @@ inline bool DotForwardInferStorageType(const nnvm::NodeAttrs& attrs, if (!dispatched && lhs_stype == kCSRStorage && rhs_rsp_or_dns && !param.transpose_a && !param.transpose_b) { // csr, rsp/dns -> dns - target_stype = (param.forward_stype_hint.has_value())? + target_stype = hint_has_value ? static_cast(param.forward_stype_hint.value()) : kDefaultStorage; if (target_stype == kDefaultStorage) { @@ -261,7 +264,7 @@ inline bool DotForwardInferStorageType(const nnvm::NodeAttrs& attrs, !param.transpose_a) { // dns, csr -> csr on CPU if (dev_mask == mshadow::cpu::kDevMask && !param.transpose_b) { - target_stype = (param.forward_stype_hint.has_value())? + target_stype = hint_has_value ? static_cast(param.forward_stype_hint.value()) : kCSRStorage; if (target_stype == kCSRStorage) { @@ -270,7 +273,7 @@ inline bool DotForwardInferStorageType(const nnvm::NodeAttrs& attrs, } // dns, csr/csr.T -> dns on GPU } else if (dev_mask == mshadow::gpu::kDevMask) { - target_stype = (param.forward_stype_hint.has_value())? + target_stype = hint_has_value ? 
static_cast(param.forward_stype_hint.value()) : kDefaultStorage; if (target_stype == kDefaultStorage) { dispatched = storage_type_assign(&out_stype, kDefaultStorage, dispatch_mode, DispatchMode::kFComputeEx); } } } if (!dispatched) { diff --git a/src/operator/tensor/dot.cc index d7108085fe30..3d3e91cb082c 100644 --- a/src/operator/tensor/dot.cc +++ b/src/operator/tensor/dot.cc @@ -155,12 +155,14 @@ hint for output storage type: Implemented sparse operations include: - dot(csr, default) = default -- dot(csr.T, default) = row_sparse +- dot(csr, default, transpose_a=True) = row_sparse - dot(csr, row_sparse) = default - dot(default, csr) = csr on CPU only - dot(default, csr) = dense on GPU only -- dot(default, csr.T) = dense on GPU only +- dot(default, csr, transpose_b=True) = dense on GPU only -- otherwise, ``dot`` generates output with default storage +- if the combination of input storage types and forward_stype_hint + does not match any of the above patterns, + dot will generate output with default storage )doc" ADD_FILELINE) .set_num_inputs(2) From 9349af781a55bc0d84796794a68df6d0643928d8 Mon Sep 17 00:00:00 2001 From: Hao Jin Date: Fri, 13 Apr 2018 02:12:53 +0000 Subject: [PATCH 5/7] fix flaky test_gluon:test_lambda through loosening the atol --- tests/python/unittest/test_gluon.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/python/unittest/test_gluon.py index 0a5bda831d9c..abb27de1dc71 100644 --- a/tests/python/unittest/test_gluon.py +++ b/tests/python/unittest/test_gluon.py @@ -738,8 +738,8 @@ def test_lambda(): input_data = mx.nd.random.uniform(shape=(2, 3, 5, 7)) out1, out2, out3 = net1(input_data), net2(input_data), net3(input_data) - assert_almost_equal(out1.asnumpy(), out2.asnumpy(), rtol=1e-3) - assert_almost_equal(out1.asnumpy(), out3.asnumpy(), rtol=1e-3) + assert_almost_equal(out1.asnumpy(), out2.asnumpy(), rtol=1e-3, atol=1e-3) + assert_almost_equal(out1.asnumpy(), out3.asnumpy(), rtol=1e-3, atol=1e-3) @with_seed() From 7caf5f5c6365a516a3e6e85610ca938738658367 Mon Sep 17 00:00:00 2001 From: Hao Jin Date: Tue, 17 Apr 2018 00:19:05 -0700 Subject: [PATCH 6/7] switch dot(dns, csr) case to a deterministic algorithm with unit test for determinism --- src/operator/tensor/dot-inl.cuh | 208 ++++++++++-------- src/operator/tensor/dot-inl.h | 46 ++-- src/operator/tensor/dot.cc | 4 +- src/operator/tensor/dot.cu | 1 - src/operator/tensor/util/tensor_util-inl.cuh | 16 ++ tests/python/unittest/test_sparse_operator.py | 45 +++- 6 files changed, 193 insertions(+), 127 deletions(-) diff --git a/src/operator/tensor/dot-inl.cuh index 19cac543bf50..10ef7546b464 100644 --- a/src/operator/tensor/dot-inl.cuh +++ b/src/operator/tensor/dot-inl.cuh @@ -27,11 +27,12 @@ #include #include +#include "./indexing_op.h" +#include "./init_op.h" +#include "./sort_op.h" #include "./util/tensor_util-inl.h" #include "./util/tensor_util-inl.cuh" -typedef unsigned long long AtomicIType; - namespace mxnet { namespace op { @@ -445,53 +446,53 @@ struct DotCsrRspDnsScalarKernel { };
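
The rewritten kernels below replace the atomicAdd-based CSC construction from PATCH 1 with a deterministic pipeline: scatter a row id to every stored element, stable-sort the column ids while carrying along the original positions, build the new column pointers with a histogram plus an exclusive scan, and finally gather data and row ids in sorted order. A minimal NumPy sketch of the same pipeline, assuming a standard 3-array CSR input (the function and variable names here are illustrative, not part of the patch):

import numpy as np

def csr_to_csc(data, indices, indptr, num_rows, num_cols):
    # CsrRowScatterKernel: row id for every stored element
    csr_rows = np.repeat(np.arange(num_rows), np.diff(indptr))
    # SortByKey: stable sort by column id; ties keep csr order,
    # which makes the layout (and later the summation order) deterministic
    original_idx = np.argsort(indices, kind='stable')
    # HistogramKernel + cub::DeviceScan::ExclusiveSum: new column pointers
    counts = np.bincount(indices, minlength=num_cols)
    csc_indptr = np.concatenate(([0], np.cumsum(counts)))
    # CscDataIndicesKernel: gather data and row ids in sorted order
    return data[original_idx], csr_rows[original_idx], csc_indptr

def dot_dns_csc(dns, csc_data, csc_indices, csc_indptr):
    # DotDnsCsrTransDnsKernel: out[i, k] = inner product of row i of dns
    # with column k of rhs
    num_cols = csc_indptr.size - 1
    out = np.zeros((dns.shape[0], num_cols), dtype=dns.dtype)
    for k in range(num_cols):
        lo, hi = csc_indptr[k], csc_indptr[k + 1]
        out[:, k] = dns[:, csc_indices[lo:hi]].dot(csc_data[lo:hi])
    return out

For dot(dns, csr.T) no conversion is needed: the same inner-product kernel runs directly on the csr arrays of rhs, because the rows of rhs are exactly the columns of rhs.T.
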
/*! - * \brief GPU Kernel to re-arrange nnz elements to csc order - * Parallelization by output elements: 1 thread/row of csr + * \brief GPU Kernel to scatter row id to corresponding entries + * \param tid global thread id + * \param csr_indptr indptr array of csr + * \param csr_rows array of row id of csr elements + * \param num_rows total number of rows in csr matrix + * Parallelization by output elements: 1 thread/row */ -struct CscDataIndicesKernel { - template +struct CsrRowScatterKernel { + template __device__ __forceinline__ static void Map(int tid, - const DType* csr_data, - const IType* csr_indices, const CType* csr_indptr, - DType* csc_data, - AtomicIType* csc_indices, - AtomicIType* csc_indptr, - AtomicIType* col_counters, - const nnvm::dim_t num_rows, - const nnvm::dim_t num_cols) { + CType* csr_rows, + const nnvm::dim_t num_rows) { if (tid < num_rows) { - for (CType i = csr_indptr[tid]; i < csr_indptr[tid + 1]; ++i) { - // target column - const IType target_col = csr_indices[i]; - const int target_offset = atomicAdd(&col_counters[target_col], 1); - const int new_pos = csc_indptr[target_col] + target_offset; - csc_data[new_pos] = csr_data[i]; - csc_indices[new_pos] = tid; + for (CType i = csr_indptr[tid]; i < csr_indptr[tid+1]; ++i) { + csr_rows[i] = tid; } } } }; -/*! - * \brief GPU Kernel of getting count for every column - * Parallelization by output elements: 1 thread/element - */ -struct CsrTransHistogramKernel { +struct CscDataIndicesKernel { /*! * \brief * \param tid global thread id - * \param in_indices csr matrix column indices - * \param out_indptr csr matrix row pointer - * \param nnz number of non-zero elements in csr + * \param original_idx_ptr original csr positions of the elements after the stable sort by column id + * \param csr_data_ptr csr matrix data + * \param csr_rows_ptr row id of every csr element + * \param csc_data_ptr csc matrix data to be filled + * \param csc_indices_ptr csc matrix row indices to be filled + * \param nnz number of non-zero elements */ - template + template __device__ __forceinline__ static void Map(int tid, - const IType* in_indices, - AtomicIType* out_indptr, + const IType* original_idx_ptr, + const DType* csr_data_ptr, + const CType* csr_rows_ptr, + DType* csc_data_ptr, + IType* csc_indices_ptr, const nnvm::dim_t nnz) { + using nnvm::dim_t; if (tid < nnz) { - atomicAdd(&out_indptr[in_indices[tid]], 1); + const IType origin = original_idx_ptr[tid]; + csc_data_ptr[tid] = csr_data_ptr[origin]; + csc_indices_ptr[tid] = csr_rows_ptr[origin]; } } }; @@ -525,14 +526,14 @@ struct DotDnsCsrTransDnsKernel { const nnvm::dim_t out_num_cols) { using nnvm::dim_t; if (tid < out_num_rows*out_num_cols) { - const dim_t i = static_cast(tid) / out_num_cols; // i = row this thread computes - const dim_t k = static_cast(tid) % out_num_cols; // k = col this thread computes + const dim_t i = static_cast(tid) % out_num_rows; // i = row this thread computes + const dim_t k = static_cast(tid) / out_num_rows; // k = col this thread computes // Compute inner product of i-th row and k-th col DType sum = 0; for (CType col_id = rhs_indptr[k]; col_id < rhs_indptr[k + 1]; ++col_id) { sum += lhs_data[i * lhs_num_cols + rhs_indices[col_id]] * rhs_data[col_id]; } - out[tid] = sum; + out[i * out_num_cols + k] = sum; } } }; @@ -1028,8 +1029,8 @@ inline void DotDnsCsrDnsImpl(const OpContext& ctx, MSHADOW_SGL_DBL_TYPE_SWITCH(csr_data.type_flag_, DType, { // data type MSHADOW_IDX_TYPE_SWITCH(csr_indices.type_flag_, IType, { // indptr type
MSHADOW_IDX_TYPE_SWITCH(csr_indptr.type_flag_, CType, { // colidx type - const CType out_num_rows = ret->shape()[0]; - const CType out_num_cols = ret->shape()[1]; + const nnvm::dim_t out_num_rows = ret->shape()[0]; + const nnvm::dim_t out_num_cols = ret->shape()[1]; // if dot(dense, csr) = dns, transform to csc first if (!transpose_b) { const nnvm::dim_t num_csr_rows = rhs.shape()[0]; @@ -1037,65 +1038,86 @@ inline void DotDnsCsrDnsImpl(const OpContext& ctx, const nnvm::dim_t num_dns_rows = dns.shape_[0]; const nnvm::dim_t nnz = rhs.storage_shape().Size(); - DType* csc_data_ptr = NULL; - AtomicIType* csc_indices_ptr = NULL; - AtomicIType* csc_indptr_ptr = NULL; - AtomicIType* col_counters = NULL; - size_t ull_num_bytes = sizeof(AtomicIType); - void* temp_storage = NULL; - size_t temp_storage_bytes = 0; - - // Get necessary temporary storage amount - cub::DeviceScan::ExclusiveSum(NULL, - temp_storage_bytes, - csc_indices_ptr, - csc_indices_ptr, - num_csr_cols + 1, - Stream::GetStream(s)); - // Align to multiple of ull_num_bytes - temp_storage_bytes += (ull_num_bytes - (temp_storage_bytes % ull_num_bytes)); - size_t csc_data_size = nnz*sizeof(DType); - size_t csc_indices_size = nnz*ull_num_bytes; - size_t csc_indptr_size = (num_csr_cols+1)*ull_num_bytes; - size_t col_counters_size = (num_csr_cols+1)*ull_num_bytes; - Tensor workspace = - ctx.requested[0].get_space_typed( - Shape1(csc_data_size + csc_indices_size + - csc_indptr_size + col_counters_size + - temp_storage_bytes), - s); - csc_indices_ptr = reinterpret_cast(workspace.dptr_); - csc_indptr_ptr = reinterpret_cast( - workspace.dptr_ + csc_indices_size); - col_counters = reinterpret_cast( - workspace.dptr_ + csc_indices_size + csc_indptr_size); - csc_data_ptr = reinterpret_cast(workspace.dptr_ + csc_indices_size + - csc_indptr_size + col_counters_size); - temp_storage = reinterpret_cast(workspace.dptr_ + csc_data_size + - csc_indices_size + csc_indptr_size + - col_counters_size); - mxnet_op::Kernel::Launch( - s, num_dns_rows*num_csr_cols, ret->data().dptr()); - // Reset values for indptr, ready for histogramming - mxnet_op::Kernel::Launch( - s, num_csr_cols+1, csc_indptr_ptr); - // Histogramming on col id - mxnet_op::Kernel::Launch( - s, nnz, csr_indices.dptr(), csc_indptr_ptr, nnz); - cub::DeviceScan::ExclusiveSum(temp_storage, + IType* original_idx_ptr = nullptr; + IType* csc_indices_ptr = nullptr; + IType* csc_cols_ptr = nullptr; + CType* csr_rows_ptr = nullptr; + CType* csc_indptr_ptr = nullptr; + DType* csc_data_ptr = nullptr; + char* temp_storage_ptr = nullptr; + size_t original_idx_bytes = nnz*sizeof(IType); + size_t csc_indices_bytes = nnz*sizeof(IType); + size_t csc_cols_bytes = nnz*sizeof(IType); + size_t csr_rows_bytes = nnz*sizeof(CType); + size_t csc_indptr_bytes = (num_csr_cols+1)*sizeof(CType); + size_t csc_data_bytes = nnz*sizeof(DType); + size_t scan_temp_storage_bytes = 0; + size_t temp_storage_bytes = SortByKeyWorkspaceSize(nnz); + IType* csr_indices_ptr = csr_indices.dptr(); + cub::DeviceScan::ExclusiveSum(temp_storage_ptr, + scan_temp_storage_bytes, + csc_indptr_ptr, + csc_indptr_ptr, + num_csr_cols+1, + mshadow::Stream::GetStream(s)); + temp_storage_bytes = std::max(temp_storage_bytes, scan_temp_storage_bytes); + temp_storage_bytes += (sizeof(dim_t) - temp_storage_bytes % sizeof(dim_t)); + size_t total_workspace_bytes = + original_idx_bytes + csc_indices_bytes + csc_cols_bytes + csr_rows_bytes + + csc_indptr_bytes + csc_data_bytes + temp_storage_bytes; + total_workspace_bytes += (sizeof(IType) - 
total_workspace_bytes % sizeof(IType)); + Tensor workspace = ctx.requested[0] + .get_space_typed(Shape1(total_workspace_bytes), s); + original_idx_ptr = reinterpret_cast(workspace.dptr_); + csc_indices_ptr = reinterpret_cast(workspace.dptr_ + original_idx_bytes); + csc_cols_ptr = reinterpret_cast(workspace.dptr_ + original_idx_bytes + + csc_indices_bytes); + csr_rows_ptr = reinterpret_cast(workspace.dptr_ + original_idx_bytes + + csc_indices_bytes + csc_cols_bytes); + csc_indptr_ptr = reinterpret_cast(workspace.dptr_ + original_idx_bytes + + csc_indices_bytes + csc_cols_bytes + + csr_rows_bytes); + temp_storage_ptr = workspace.dptr_ + original_idx_bytes + csc_indices_bytes + + csc_cols_bytes + csr_rows_bytes + csc_indptr_bytes; + csc_data_ptr = reinterpret_cast( + workspace.dptr_ + total_workspace_bytes - csc_data_bytes); + + // Fill original_idx + mxnet_op::Kernel::Launch( + s, nnz, 1, IType(0), IType(1), kWriteTo, original_idx_ptr); + // Fill csc_cols with copy of csr_indices + mxnet_op::Kernel, gpu>::Launch( + s, nnz, csc_cols_ptr, csr_indices_ptr); + // Allocate the tensors needed for SortByKey + Tensor original_idx(original_idx_ptr, Shape1(nnz), s); + Tensor csc_cols(csc_cols_ptr, Shape1(nnz), s); + Tensor temp_storage(temp_storage_ptr, Shape1(temp_storage_bytes), s); + + int num_bits = 1; + unsigned int a = num_csr_cols - 1; + while (a >>= 1) num_bits++; + SortByKey(csc_cols, original_idx, true, &temp_storage, 0, num_bits); + + // Scatter csr indptr to row id + mxnet_op::Kernel::Launch( + s, num_csr_rows, csr_indptr.dptr(), csr_rows_ptr, num_csr_rows); + // Reset indptr to zero + mxnet_op::Kernel::Launch(s, num_csr_cols+1, csc_indptr_ptr); + // Histogram on the sorted cols + mxnet_op::Kernel::Launch( + s, nnz, csc_indptr_ptr, csc_cols_ptr, nnz); + // Scan the bin counts for every column to get csc_indptr + cub::DeviceScan::ExclusiveSum(temp_storage_ptr, temp_storage_bytes, csc_indptr_ptr, csc_indptr_ptr, - num_csr_cols + 1, - Stream::GetStream(s)); - // Reset values for col_counter, ready for the final transform - mxnet_op::Kernel::Launch( - s, num_csr_cols+1, col_counters); - // Transform to CSC + num_csr_cols+1, + mshadow::Stream::GetStream(s)); + // Assign data to csc matrix arrays mxnet_op::Kernel::Launch( - s, num_csr_rows, csr_data.dptr(), csr_indices.dptr(), - csr_indptr.dptr(), csc_data_ptr, csc_indices_ptr, - csc_indptr_ptr, col_counters, num_csr_rows, num_csr_cols); + s, nnz, original_idx_ptr, csr_data.dptr(), csr_rows_ptr, csc_data_ptr, + csc_indices_ptr, nnz); + mxnet_op::Kernel::Launch( s, out_num_rows * out_num_cols, dns.dptr(), csc_data_ptr, csc_indices_ptr, csc_indptr_ptr, diff --git a/src/operator/tensor/dot-inl.h b/src/operator/tensor/dot-inl.h index c8a3f3954d8c..69430ca55a68 100644 --- a/src/operator/tensor/dot-inl.h +++ b/src/operator/tensor/dot-inl.h @@ -45,7 +45,7 @@ namespace op { struct DotParam : public dmlc::Parameter { bool transpose_a; bool transpose_b; - dmlc::optional forward_stype_hint; + dmlc::optional forward_stype; DMLC_DECLARE_PARAMETER(DotParam) { DMLC_DECLARE_FIELD(transpose_a) .describe("If true then transpose the first input before dot.") @@ -53,10 +53,11 @@ struct DotParam : public dmlc::Parameter { DMLC_DECLARE_FIELD(transpose_b) .describe("If true then transpose the second input before dot.") .set_default(false); - DMLC_DECLARE_FIELD(forward_stype_hint) - .describe("Hint on the desired storage type of the forward output given by user," - "if the combination of input storage types and this hint does not match" - "any implemented ones, the dot 
operator will perform fallback operation.") + .describe("The desired storage type of the forward output given by user, if the " + "combination of input storage types and this hint does not match " + "any implemented ones, the dot operator will perform fallback operation " + "and still produce an output of the desired storage type.") .add_enum("default", kDefaultStorage) .add_enum("row_sparse", kRowSparseStorage) .add_enum("csr", kCSRStorage) @@ -226,14 +227,14 @@ inline bool DotForwardInferStorageType(const nnvm::NodeAttrs& attrs, bool only_lhs_transpose = param.transpose_a && !param.transpose_b; bool rhs_rsp_or_dns = rhs_stype == kRowSparseStorage || rhs_stype == kDefaultStorage; - NDArrayStorageType target_stype; - bool hint_has_value = param.forward_stype_hint.has_value(); + bool hint_has_value = param.forward_stype.has_value(); + NDArrayStorageType target_stype = hint_has_value ? + static_cast(param.forward_stype.value()) : + kUndefinedStorage; if (!dispatched && lhs_stype == kDefaultStorage && rhs_stype == kDefaultStorage) { // dns, dns -> dns - target_stype = hint_has_value ? - static_cast(param.forward_stype_hint.value()) : - kDefaultStorage; + target_stype = hint_has_value ? target_stype : kDefaultStorage; if (target_stype == kDefaultStorage) { dispatched = storage_type_assign(&out_stype, kDefaultStorage, dispatch_mode, DispatchMode::kFCompute); @@ -241,9 +242,7 @@ inline bool DotForwardInferStorageType(const nnvm::NodeAttrs& attrs, } if (!dispatched && lhs_stype == kCSRStorage && only_lhs_transpose && rhs_rsp_or_dns) { // csr.T, rsp/dns -> rsp - target_stype = hint_has_value ? - static_cast(param.forward_stype_hint.value()) : - kRowSparseStorage; + target_stype = hint_has_value ? target_stype : kRowSparseStorage; if (target_stype == kRowSparseStorage) { dispatched = storage_type_assign(&out_stype, kRowSparseStorage, dispatch_mode, DispatchMode::kFComputeEx); @@ -252,9 +251,7 @@ inline bool DotForwardInferStorageType(const nnvm::NodeAttrs& attrs, if (!dispatched && lhs_stype == kCSRStorage && rhs_rsp_or_dns && !param.transpose_a && !param.transpose_b) { // csr, rsp/dns -> dns - target_stype = hint_has_value ? - static_cast(param.forward_stype_hint.value()) : - kDefaultStorage; + target_stype = hint_has_value ? target_stype : kDefaultStorage; if (target_stype == kDefaultStorage) { dispatched = storage_type_assign(&out_stype, kDefaultStorage, dispatch_mode, DispatchMode::kFComputeEx); @@ -262,20 +259,15 @@ inline bool DotForwardInferStorageType(const nnvm::NodeAttrs& attrs, } if (!dispatched && lhs_stype == kDefaultStorage && rhs_stype == kCSRStorage && !param.transpose_a) { + target_stype = hint_has_value ? target_stype : kCSRStorage; // dns, csr -> csr on CPU if (dev_mask == mshadow::cpu::kDevMask && !param.transpose_b) { - target_stype = hint_has_value ? - static_cast(param.forward_stype_hint.value()) : - kCSRStorage; if (target_stype == kCSRStorage) { dispatched = storage_type_assign(&out_stype, kCSRStorage, dispatch_mode, DispatchMode::kFComputeEx); } // dns, csr/csr.T -> dns on GPU } else if (dev_mask == mshadow::gpu::kDevMask) { - target_stype = hint_has_value ?
- static_cast(param.forward_stype_hint.value()) : - kDefaultStorage; if (target_stype == kDefaultStorage) { dispatched = storage_type_assign(&out_stype, kDefaultStorage, dispatch_mode, DispatchMode::kFComputeEx); @@ -283,7 +275,9 @@ inline bool DotForwardInferStorageType(const nnvm::NodeAttrs& attrs, } } if (!dispatched) { - dispatched = dispatch_fallback(out_attrs, dispatch_mode); + target_stype = (target_stype == kUndefinedStorage)? kDefaultStorage : target_stype; + dispatched = storage_type_assign(&out_stype, target_stype, dispatch_mode, + DispatchMode::kFComputeFallback); } return dispatched; } @@ -1048,8 +1042,10 @@ void DotBackwardEx(const nnvm::NodeAttrs& attrs, CHECK_EQ(inputs.size(), 3U); CHECK_EQ(outputs.size(), 2U); CHECK_EQ(req.size(), 2U); - CHECK_EQ(kNullOp, req[0]) - << "sparse dot does not support computing the gradient of the csr/lhs"; + CHECK(!(req[0] != kNullOp && outputs[0].storage_type() == kCSRStorage)) + << "sparse dot does not support computing the gradient of csr"; + CHECK(!(req[1] != kNullOp && outputs[1].storage_type() == kCSRStorage)) + << "sparse dot does not support computing the gradient of csr"; CHECK_NE(req[1], kWriteInplace) << "DotBackwardEx does not support WriteInplace"; const DotParam& param = nnvm::get(attrs.parsed); diff --git a/src/operator/tensor/dot.cc b/src/operator/tensor/dot.cc index 3d3e91cb082c..89985e3a3f09 100644 --- a/src/operator/tensor/dot.cc +++ b/src/operator/tensor/dot.cc @@ -158,8 +158,8 @@ Implemented sprase operations include: - dot(csr, default, transpose_a=True) = row_sparse - dot(csr, row_sparse) = default - dot(default, csr) = csr on CPU only -- dot(default, csr) = dense on GPU only -- dot(default, csr, transpose_b=True) = dense on GPU only +- dot(default, csr) = default on GPU only +- dot(default, csr, transpose_b=True) = default on GPU only - if the combination of input storage types and forward_stype_hint - does not match any of the above patterns, - dot will generate output with default storage diff --git a/src/operator/tensor/dot.cu b/src/operator/tensor/dot.cu index ac514f49df8f..8ee2e2832fbb 100644 --- a/src/operator/tensor/dot.cu +++ b/src/operator/tensor/dot.cu @@ -23,7 +23,6 @@ */ #include "./dot-inl.h" -#include namespace mxnet { namespace op { diff --git a/src/operator/tensor/util/tensor_util-inl.cuh b/src/operator/tensor/util/tensor_util-inl.cuh index f38e8e117c94..c9ee625af0c8 100644 --- a/src/operator/tensor/util/tensor_util-inl.cuh +++ b/src/operator/tensor/util/tensor_util-inl.cuh @@ -231,6 +231,22 @@ struct MarkCsrColWarpKernel { } }; +/*! 
+ * \brief GPU Kernel to perform histogram (input types should be integer types) + * Parallelization by output elements: 1 thread/input element + */ +struct HistogramKernel { + template + __device__ __forceinline__ static void Map(int tid, + IType* target, + const CType* source, + const nnvm::dim_t num_elems) { + if (tid < num_elems) { + atomicAdd(&target[source[tid]], 1); + } + } +}; + } // namespace op } // namespace mxnet diff --git a/tests/python/unittest/test_sparse_operator.py b/tests/python/unittest/test_sparse_operator.py index 7dcc670519b6..5382b70e3d8f 100644 --- a/tests/python/unittest/test_sparse_operator.py +++ b/tests/python/unittest/test_sparse_operator.py @@ -1219,13 +1219,13 @@ def test_infer_forward_stype(lhs_shape, rhs_shape, lhs_density, rhs_density, tra for forward_stype in all_stypes: lhs = lhs_nd.tostype(lhs_stype) rhs = rhs_nd.tostype(rhs_stype) - out = mx.nd.dot(lhs, rhs, forward_stype_hint=forward_stype, + out = mx.nd.dot(lhs, rhs, forward_stype=forward_stype, transpose_a=trans_a, transpose_b=trans_b) assert_almost_equal(out.tostype('default').asnumpy(), out_np, rtol=1e-4, atol=1e-5) lhs_var = mx.symbol.Variable('lhs', stype=lhs_stype) rhs_var = mx.symbol.Variable('rhs', stype=rhs_stype) out = mx.symbol.sparse.dot(lhs_var, rhs_var, - forward_stype_hint=forward_stype, + forward_stype=forward_stype, transpose_a=trans_a, transpose_b=trans_b) location = {'lhs': lhs, 'rhs': rhs} check_symbolic_forward(out, location, [out_np], rtol=1e-3, atol=1e-4) @@ -1260,15 +1260,19 @@ def test_dot_dns_csr(lhs_shape, rhs_shape, lhs_density, rhs_density, trans_lhs=F rhs_nd = rand_ndarray(rhs_shape, stype='csr', density=rhs_density) rhs_dns = rhs_nd.tostype('default') - out = mx.nd.sparse.dot(lhs_nd, rhs_nd, transpose_a=trans_lhs, transpose_b=trans_rhs) - out_dns = mx.nd.dot(lhs_nd, rhs_dns, transpose_a=trans_lhs, transpose_b=trans_rhs) + if default_context() == mx.cpu(): + forward_stype = 'csr' + else: + forward_stype = 'default' + out = mx.nd.sparse.dot(lhs_nd, rhs_nd, transpose_a=trans_lhs, transpose_b=trans_rhs, forward_stype=forward_stype) + out_dns = mx.nd.dot(lhs_nd, rhs_dns, transpose_a=trans_lhs, transpose_b=trans_rhs, forward_stype=forward_stype) out_np = out_dns.asnumpy() assert_almost_equal(out.asnumpy(), out_np, rtol=1e-4, atol=1e-5) # test symbolic forward lhs = mx.symbol.Variable('lhs', stype='default') rhs = mx.symbol.Variable('rhs', stype='csr') - out = mx.symbol.sparse.dot(lhs, rhs, transpose_a=trans_lhs, transpose_b=trans_rhs) + out = mx.symbol.sparse.dot(lhs, rhs, transpose_a=trans_lhs, transpose_b=trans_rhs, forward_stype=forward_stype) location = {'lhs': lhs_nd, 'rhs': rhs_nd} check_symbolic_forward(out, location, [out_np], rtol=1e-3, atol=1e-4) @@ -1299,7 +1303,7 @@ def test_sparse_dot_zero_output(lhs_shape, trans_lhs, rhs_num_cols): sps_out = mx.nd.sparse.dot(lhs.tostype('csr'), rhs.tostype('row_sparse'), transpose_a=trans_lhs) assert same(dns_out.asnumpy(), sps_out.asnumpy()) - density = [1.00, 0.50, 0.01] + density = [1.00, 0.5, 0.01] for lhs_d in density: lhs_shape = rand_shape_2d(50, 200) rhs_d = 1 @@ -1325,6 +1329,35 @@ def test_sparse_dot_zero_output(lhs_shape, trans_lhs, rhs_num_cols): test_sparse_dot_zero_output(rand_shape_2d(50, 200), True, 40) +@with_seed() +def test_sparse_dot_determinism(): + def test_dot_determinism(lhs_stype, rhs_stype, lhs_density, rhs_density, transpose_a, transpose_b): + lhs_row = rnd.randint(200, 400) + lhs_col = rnd.randint(200, 400) + if transpose_a: + if transpose_b: + rhs_shape = (rnd.randint(200, 400), lhs_row) + else: + 
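
HistogramKernel together with the exclusive scan yields the CSC column pointer: per-column counts first, then running offsets. A tiny worked example, standalone and with hypothetical data only:

#include <cstdio>

int main() {
  const int nnz = 4, num_cols = 4;
  const int cols[nnz] = {0, 2, 2, 3};
  int indptr[num_cols + 1] = {0};                   // zeroed, as with set_zero
  for (int i = 0; i < nnz; ++i) ++indptr[cols[i]];  // histogram: {1, 0, 2, 1, 0}
  // Exclusive scan in place: column c then owns [indptr[c], indptr[c+1]).
  int running = 0;
  for (int c = 0; c <= num_cols; ++c) {
    const int count = indptr[c];
    indptr[c] = running;
    running += count;
  }
  for (int c = 0; c <= num_cols; ++c) printf("%d ", indptr[c]);  // 0 1 1 3 4
  return 0;
}
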
rhs_shape = (lhs_row, rnd.randint(200, 400)) + else: + if transpose_b: + rhs_shape = (rnd.randint(200, 400), lhs_col) + else: + rhs_shape = (lhs_col, rnd.randint(200, 400)) + if default_context() == mx.cpu(): + forward_stype = 'csr' + else: + forward_stype = 'default' + lhs_shape = (lhs_row, lhs_col) + lhs = rand_ndarray(lhs_shape, lhs_stype, density=lhs_density) + rhs = rand_ndarray(rhs_shape, rhs_stype, density=rhs_density) + res1 = mx.nd.sparse.dot(lhs, rhs, transpose_a=transpose_a, transpose_b=transpose_b, forward_stype=forward_stype) + res2 = mx.nd.sparse.dot(lhs, rhs, transpose_a=transpose_a, transpose_b=transpose_b, forward_stype=forward_stype) + assert_almost_equal(res1.asnumpy(), res2.asnumpy(), rtol=0.0, atol=0.0) + test_dot_determinism('default', 'csr', 1.0, 0.1, False, False) + test_dot_determinism('default', 'csr', 1.0, 0.1, False, True) + + @with_seed() def test_sparse_slice(): def check_csr_slice(shape, slice_input): From 5ab47fd6bc86c727bf462a868a6107435ee4148c Mon Sep 17 00:00:00 2001 From: Hao Jin Date: Mon, 23 Apr 2018 23:31:17 +0000 Subject: [PATCH 7/7] address code reviews and add backward --- src/operator/tensor/dot-inl.cuh | 23 ++-- src/operator/tensor/dot-inl.h | 116 ++++++++++++++++-- src/operator/tensor/dot.cc | 103 +--------------- tests/python/unittest/test_sparse_operator.py | 38 +++--- 4 files changed, 145 insertions(+), 135 deletions(-) diff --git a/src/operator/tensor/dot-inl.cuh b/src/operator/tensor/dot-inl.cuh index 10ef7546b464..86df5801c73c 100644 --- a/src/operator/tensor/dot-inl.cuh +++ b/src/operator/tensor/dot-inl.cuh @@ -27,7 +27,6 @@ #include #include -#include "./indexing_op.h" #include "./init_op.h" #include "./sort_op.h" #include "./util/tensor_util-inl.h" @@ -991,11 +990,17 @@ inline void DotCsrRspDnsImpl(const OpContext& ctx, }); } +// Returns integer log2(a) rounded up +inline int log2i(size_t a) { + int k = 1; + while (a >>= 1) k++; + return k; +} + /* * \brief GPU Impl of dot(dns, csr) = csr */ -template -inline void DotDnsCsrCsrImpl(const OpContext& ctx, +inline void DotDnsCsrCsrImpl(const OpContext& ctx, const gpu& gpu_dev, const TBlob& lhs, const NDArray& rhs, const OpReqType req, NDArray* ret) { LOG(FATAL) << "dot(dense, csr) = csr is not implemented on GPU"; @@ -1004,11 +1009,13 @@ inline void DotDnsCsrCsrImpl(const OpContext& ctx, /* * \brief GPU Impl of dot(dns, csr) = dns and dot(dns, csr.T) = dns */ -template -inline void DotDnsCsrDnsImpl(const OpContext& ctx, +inline void DotDnsCsrDnsImpl(const OpContext& ctx, const gpu& gpu_dev, const TBlob& dns, const NDArray& rhs, const OpReqType req, NDArray* ret, const bool transpose_b) { + if (req == kNullOp) { + return; + } CHECK_EQ(req, kWriteTo); CHECK_EQ(rhs.storage_type(), kCSRStorage); @@ -1052,7 +1059,7 @@ inline void DotDnsCsrDnsImpl(const OpContext& ctx, size_t csc_indptr_bytes = (num_csr_cols+1)*sizeof(CType); size_t csc_data_bytes = nnz*sizeof(DType); size_t scan_temp_storage_bytes = 0; - size_t temp_storage_bytes = SortByKeyWorkspaceSize(nnz); + size_t temp_storage_bytes = SortByKeyWorkspaceSize(nnz); IType* csr_indices_ptr = csr_indices.dptr(); cub::DeviceScan::ExclusiveSum(temp_storage_ptr, scan_temp_storage_bytes, @@ -1093,9 +1100,7 @@ inline void DotDnsCsrDnsImpl(const OpContext& ctx, Tensor csc_cols(csc_cols_ptr, Shape1(nnz), s); Tensor temp_storage(temp_storage_ptr, Shape1(temp_storage_bytes), s); - int num_bits = 1; - unsigned int a = num_csr_cols - 1; - while (a >>= 1) num_bits++; + int num_bits = log2i(num_csr_cols - 1); SortByKey(csc_cols, original_idx, true, 
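
One nit on the comment: log2i as written returns the bit width of its argument, floor(log2 a) + 1, which equals ceil(log2 a) everywhere except at exact powers of two. That is the right quantity here, since SortByKey needs enough radix bits to cover every key up to num_csr_cols - 1. A standalone check:

#include <cassert>
#include <cstddef>

// Bit width of a: the smallest b with a < 2^b, for a >= 1.
inline int log2i(size_t a) {
  int k = 1;
  while (a >>= 1) k++;
  return k;
}

int main() {
  assert(log2i(1) == 1);    // keys {0, 1} fit in 1 bit
  assert(log2i(4) == 3);    // 100 in binary: 3 bits (ceil(log2 4) would be 2)
  assert(log2i(299) == 9);  // a max column id of 299 needs 9 bits
  // The call site sorts column ids in [0, num_csr_cols) with
  // num_bits = log2i(num_csr_cols - 1).
  return 0;
}
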
&temp_storage, 0, num_bits); // Scatter csr indptr to row id diff --git a/src/operator/tensor/dot-inl.h b/src/operator/tensor/dot-inl.h index 69430ca55a68..2c9a483567f8 100644 --- a/src/operator/tensor/dot-inl.h +++ b/src/operator/tensor/dot-inl.h @@ -323,6 +323,15 @@ inline bool DotBackwardInferStorageType(const nnvm::NodeAttrs& attrs, dispatched = true; } } + if (!dispatched && dev_mask == mshadow::gpu::kDevMask && !param.transpose_a && + lhs_stype == kDefaultStorage && rhs_stype == kCSRStorage && + ograd_stype == kDefaultStorage) { + if (type_assign(&lhs_grad_stype, kDefaultStorage) && + type_assign(&rhs_grad_stype, kDefaultStorage)) { + DISPATCH_MODE_ASSIGN_CHECK(dispatch_mode, 0, DispatchMode::kFComputeEx); + dispatched = true; + } + } if (!dispatched) { dispatched = dispatch_fallback(out_attrs, dispatch_mode); } @@ -931,19 +940,101 @@ inline void DotCsrRspRspImpl(const OpContext& ctx, /* * \brief Impl of dot(dns, csr) = csr */ -template -inline void DotDnsCsrCsrImpl(const OpContext& ctx, - const TBlob& dns, const NDArray& rhs, - const OpReqType req, NDArray* ret); +inline void DotDnsCsrCsrImpl(const OpContext& ctx, const cpu& cpu_dev, + const TBlob& lhs, const NDArray& rhs, + const OpReqType req, NDArray* ret) { + if (kNullOp == req) return; + + CHECK_EQ(req, kWriteTo); + CHECK_EQ(rhs.storage_type(), kCSRStorage); + + using namespace mshadow; + using namespace mshadow::expr; + using nnvm::dim_t; + + /* Initialize data structures */ + mshadow::Stream* s = ctx.get_stream(); + const NDArray& out = *ret; + const TBlob data_l = lhs; + const TBlob data_r = rhs.data(); + const TBlob indptr_r = rhs.aux_data(csr::kIndPtr); + const TBlob col_idx_r = rhs.aux_data(csr::kIdx); + if (!rhs.storage_initialized()) { + FillZerosCsrImpl(s, *ret); + return; + } + + MSHADOW_SGL_DBL_TYPE_SWITCH(data_r.type_flag_, DType, { // data type + MSHADOW_IDX_TYPE_SWITCH(indptr_r.type_flag_, IType, { // indptr type + MSHADOW_IDX_TYPE_SWITCH(col_idx_r.type_flag_, CType, { // colidx type + /* Allocate workspace */ + CType num_cols_out = out.shape()[1]; + CType rhs_data_size = static_cast(col_idx_r.shape_.Size()); + size_t workspace_size = 2 * num_cols_out * sizeof(CType); + Tensor workspace = + ctx.requested[0].get_space_typed( + Shape1(workspace_size), s); + CType* col_flg = reinterpret_cast(workspace.dptr_); + + CType* prefix_sum = col_flg; + CType* nnc_idx = prefix_sum + num_cols_out; + + /* Set the column flags for nnz columns */ + mxnet_op::Kernel::Launch(s, num_cols_out, + col_flg); + mxnet_op::Kernel::Launch( + s, rhs_data_size, col_flg, col_idx_r.dptr()); + + /* 1. Calculate prefix sum from col flgs + * 2. 
Store all non-zero column indices in nnc_idx
+           */
+          CType cur = 0;
+          prefix_sum[0] = col_flg[0];
+          if (prefix_sum[0]) nnc_idx[cur++] = 0;
+          for (CType i = 1; i < num_cols_out; i++) {
+            prefix_sum[i] = prefix_sum[i - 1] + col_flg[i];
+            if (prefix_sum[i] > prefix_sum[i - 1]) nnc_idx[cur++] = i;
+          }
+
+          /* Allocate aux data for out */
+          IType num_rows_l = lhs.shape_[0];
+          dim_t nnc = prefix_sum[num_cols_out - 1];
+          dim_t nnz = nnc * num_rows_l;
+          out.CheckAndAllocAuxData(csr::kIndPtr, Shape1(num_rows_l + 1));
+          out.CheckAndAllocAuxData(csr::kIdx, Shape1(nnz));
+          out.CheckAndAllocData(Shape1(nnz));
+
+          /* Set csr indptr and index according to nnc_idx */
+          IType* indptr_out = out.aux_data(csr::kIndPtr).dptr();
+          CType* col_idx_out = out.aux_data(csr::kIdx).dptr();
+          DType* data_out = out.data().dptr();
+          mxnet_op::Kernel::Launch(
+            s, num_rows_l, nnc_idx, indptr_out, col_idx_out, nnc, num_rows_l);
+          mxnet_op::Kernel::Launch(s, nnz, data_out);
+
+          const dim_t num_threads = mxnet_op::get_num_threads(num_rows_l);
+          const dim_t seg_len = (num_rows_l + num_threads - 1) / num_threads;
+
+          IType num_rows_r = rhs.shape()[0];
+          mxnet_op::Kernel::Launch(
+            s, num_threads, data_out, data_l.dptr(),
+            indptr_r.dptr(), col_idx_r.dptr(),
+            data_r.dptr(), seg_len, num_rows_r, num_rows_l, num_cols_out,
+            nnc, prefix_sum);
+      });
+    });
+  });
+}

 /*
  * \brief Impl of dot(dns, csr) = dense (GPU only)
  */
-template
-inline void DotDnsCsrDnsImpl(const OpContext& ctx,
+inline void DotDnsCsrDnsImpl(const OpContext& ctx, const cpu& cpu_dev,
                              const TBlob& dns, const NDArray& rhs,
                              const OpReqType req, NDArray* ret,
-                             const bool transpose_b);
+                             const bool transpose_b) {
+  LOG(FATAL) << "dot(dense, csr) = dense is not implemented on CPU";
+}

 inline bool DotShape(const nnvm::NodeAttrs& attrs,
                      std::vector *in_attrs,
@@ -1023,11 +1114,11 @@ void DotForwardEx(const nnvm::NodeAttrs& attrs,
         out_stype == kCSRStorage &&
         !(param.transpose_a || param.transpose_b)) {
       NDArray ret = outputs[0];
-      DotDnsCsrCsrImpl(ctx, inputs[0].data(), inputs[1], req[0], &ret);
+      DotDnsCsrCsrImpl(ctx, xpu(), inputs[0].data(), inputs[1], req[0], &ret);
     } else if (lhs_stype == kDefaultStorage && rhs_stype == kCSRStorage &&
                out_stype == kDefaultStorage && !(param.transpose_a)) {
       NDArray ret = outputs[0];
-      DotDnsCsrDnsImpl(ctx, inputs[0].data(), inputs[1], req[0], &ret, param.transpose_b);
+      DotDnsCsrDnsImpl(ctx, xpu(), inputs[0].data(), inputs[1], req[0], &ret, param.transpose_b);
     } else {
       LogUnimplementedOp(attrs, ctx, inputs, req, outputs);
     }
@@ -1049,11 +1140,11 @@ void DotBackwardEx(const nnvm::NodeAttrs& attrs,
   CHECK_NE(req[1], kWriteInplace) << "DotBackwardEx does not support WriteInplace";
   const DotParam& param = nnvm::get(attrs.parsed);
-  CHECK(!param.transpose_b) << "sparse dot only supports dot(A, X) and dot(A.T(), X)";
   CHECK_EQ(inputs[0].shape().ndim(), 2) << "sparse dot only supports 2 dimensional lhs";
   CHECK_EQ(inputs[1].shape().ndim(), 2) << "sparse dot only supports 2 dimensional rhs";
   const auto ograd_stype = inputs[0].storage_type();
   const auto lhs_stype = inputs[1].storage_type();
+  const auto rhs_stype = inputs[2].storage_type();
   const auto grad_rhs_stype = outputs[1].storage_type();
   if (ograd_stype == kDefaultStorage  // ograd dns format
       && lhs_stype == kCSRStorage  // csr input lhs of the op
@@ -1072,6 +1163,11 @@ void DotBackwardEx(const nnvm::NodeAttrs& attrs,
       && grad_rhs_stype == kDefaultStorage && !param.transpose_b) {
     TBlob ret = outputs[1].data();
     DotCsrRspDnsImpl(ctx, xpu(), inputs[1], inputs[0], req[1], !param.transpose_a, &ret);
+ 
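
The CPU path above leans on a structural shortcut: with a dense lhs, every occupied column of the csr rhs is treated as occupied in every output row, so the output pattern is just nnc flagged columns replicated across num_rows_l rows, giving nnz = nnc * num_rows_l. A minimal standalone sketch of that counting step, illustration only:

#include <cstdint>
#include <vector>

// Mirror of the col_flg / prefix-sum logic: count nonzero columns of a CSR
// matrix and collect their indices in ascending order.
int64_t CountNnzCols(const std::vector<int64_t>& csr_col_idx, int64_t num_cols,
                     std::vector<int64_t>* nnc_idx) {
  std::vector<int> col_flg(num_cols, 0);
  for (int64_t c : csr_col_idx) col_flg[c] = 1;  // flag occupied columns
  int64_t nnc = 0;
  for (int64_t c = 0; c < num_cols; ++c) {
    if (col_flg[c]) {
      nnc_idx->push_back(c);
      ++nnc;
    }
  }
  return nnc;  // output nnz is then nnc * num_rows_lhs, as allocated above
}

int main() {
  std::vector<int64_t> nnc_idx;
  // csr with nonzeros in columns {1, 3, 3}: nnc == 2, columns {1, 3}.
  const int64_t nnc = CountNnzCols({1, 3, 3}, 4, &nnc_idx);
  return (nnc == 2 && nnc_idx.size() == 2) ? 0 : 1;
}
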
} else if (ograd_stype == kDefaultStorage && // ograd dns format + lhs_stype == kDefaultStorage && // lhs dns format + rhs_stype == kCSRStorage && !param.transpose_a) { + NDArray ret = outputs[0]; + DotDnsCsrDnsImpl(ctx, xpu(), inputs[0].data(), inputs[2], req[0], &ret, !param.transpose_b); } else { LogUnimplementedOp(attrs, ctx, inputs, req, outputs); } diff --git a/src/operator/tensor/dot.cc b/src/operator/tensor/dot.cc index 89985e3a3f09..9d62d0daa391 100644 --- a/src/operator/tensor/dot.cc +++ b/src/operator/tensor/dot.cc @@ -28,105 +28,6 @@ namespace mxnet { namespace op { DMLC_REGISTER_PARAMETER(DotParam); -/* - * \brief CPU Impl of dot(dns, csr) = csr - */ -template -inline void DotDnsCsrCsrImpl(const OpContext& ctx, - const TBlob& lhs, const NDArray& rhs, - const OpReqType req, NDArray* ret) { - if (kNullOp == req) return; - - CHECK_EQ(req, kWriteTo); - CHECK_EQ(rhs.storage_type(), kCSRStorage); - - using namespace mshadow; - using namespace mshadow::expr; - using nnvm::dim_t; - - /* Initialize data structures */ - mshadow::Stream* s = ctx.get_stream(); - const NDArray& out = *ret; - const TBlob data_l = lhs; - const TBlob data_r = rhs.data(); - const TBlob indptr_r = rhs.aux_data(csr::kIndPtr); - const TBlob col_idx_r = rhs.aux_data(csr::kIdx); - if (!rhs.storage_initialized()) { - FillZerosCsrImpl(s, *ret); - return; - } - - MSHADOW_SGL_DBL_TYPE_SWITCH(data_r.type_flag_, DType, { // data type - MSHADOW_IDX_TYPE_SWITCH(indptr_r.type_flag_, IType, { // indptr type - MSHADOW_IDX_TYPE_SWITCH(col_idx_r.type_flag_, CType, { // colidx type - /* Allocate workspace */ - CType num_cols_out = out.shape()[1]; - CType rhs_data_size = static_cast(col_idx_r.shape_.Size()); - size_t workspace_size = 2 * num_cols_out * sizeof(CType); - Tensor workspace = - ctx.requested[0].get_space_typed( - Shape1(workspace_size), s); - CType* col_flg = reinterpret_cast(workspace.dptr_); - - CType* prefix_sum = col_flg; - CType* nnc_idx = prefix_sum + num_cols_out; - - /* Set the column flags for nnz columns */ - mxnet_op::Kernel::Launch(s, num_cols_out, - col_flg); - mxnet_op::Kernel::Launch( - s, rhs_data_size, col_flg, col_idx_r.dptr()); - - /* 1. Calculate prefix sum from col flgs - * 2. 
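
The new backward branch uses the standard identity: for Y = dot(X, W) the lhs gradient is dot(dY, W.T), and for Y = dot(X, W.T) it is dot(dY, W), which is why the call above passes !param.transpose_b and reuses the forward dns-times-csr kernel. A dense reference sketch for verifying the shapes, illustration only:

#include <cstdio>
#include <vector>

// Dense reference for the lhs gradient of Y = dot(X, W):
// dX = dot(dY, W.T), i.e. dX[i][j] = sum_c dY[i][c] * W[j][c].
std::vector<std::vector<double>> GradLhs(
    const std::vector<std::vector<double>>& dY,   // m x n
    const std::vector<std::vector<double>>& W) {  // k x n, so dX is m x k
  const size_t m = dY.size(), n = dY[0].size(), k = W.size();
  std::vector<std::vector<double>> dX(m, std::vector<double>(k, 0.0));
  for (size_t i = 0; i < m; ++i)
    for (size_t j = 0; j < k; ++j)
      for (size_t c = 0; c < n; ++c)
        dX[i][j] += dY[i][c] * W[j][c];
  return dX;
}

int main() {
  // dY: 1x1, W: 2x1 => dX: 1x2 equals dot(dY, W.T).
  const auto dX = GradLhs({{2.0}}, {{3.0}, {5.0}});
  printf("%g %g\n", dX[0][0], dX[0][1]);  // 6 10
  return 0;
}

With transpose_b == true in the forward pass the same shape logic applies with the roles of W and W.T swapped, matching the !param.transpose_b flip in the branch above.
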
Storage all non zero column indexes in nnc_idx - */ - CType cur = 0; - prefix_sum[0] = col_flg[0]; - if (prefix_sum[0]) nnc_idx[cur++] = 0; - for (CType i = 1; i < num_cols_out; i++) { - prefix_sum[i] = prefix_sum[i - 1] + col_flg[i]; - if (prefix_sum[i] > prefix_sum[i - 1]) nnc_idx[cur++] = i; - } - - /* Allocate aux data for out */ - IType num_rows_l = lhs.shape_[0]; - dim_t nnc = prefix_sum[num_cols_out - 1]; - dim_t nnz = nnc * num_rows_l; - out.CheckAndAllocAuxData(csr::kIndPtr, Shape1(num_rows_l + 1)); - out.CheckAndAllocAuxData(csr::kIdx, Shape1(nnz)); - out.CheckAndAllocData(Shape1(nnz)); - - /* Set csr indptr and index according to nnc_idx*/ - IType* indptr_out = out.aux_data(csr::kIndPtr).dptr(); - CType* col_idx_out = out.aux_data(csr::kIdx).dptr(); - DType* data_out = out.data().dptr(); - mxnet_op::Kernel::Launch( - s, num_rows_l, nnc_idx, indptr_out, col_idx_out, nnc, num_rows_l); - mxnet_op::Kernel::Launch(s, nnz, data_out); - - const dim_t num_threads = mxnet_op::get_num_threads(num_rows_l); - const dim_t seg_len = (num_rows_l + num_threads - 1) / num_threads; - - IType num_rows_r = rhs.shape()[0]; - mxnet_op::Kernel::Launch( - s, num_threads, data_out, data_l.dptr(), - indptr_r.dptr(), col_idx_r.dptr(), - data_r.dptr(), seg_len, num_rows_r, num_rows_l, num_cols_out, - nnc, prefix_sum); - }); - }); - }); -} - - -template -inline void DotDnsCsrDnsImpl(const OpContext& ctx, - const TBlob& dns, const NDArray& rhs, - const OpReqType req, NDArray* ret, - const bool transpose_b) { - LOG(FATAL) << "dot(dense, csr) = dense is not implemented on CPU"; -} - NNVM_REGISTER_OP(dot) .add_alias("_sparse_dot") // alias for op registration under mxnet.ndarray.sparse .describe(R"doc(Dot product of two arrays. @@ -158,8 +59,8 @@ Implemented sprase operations include: - dot(csr, default, transpose_a=True) = row_sparse - dot(csr, row_sparse) = default - dot(default, csr) = csr on CPU only -- dot(default, csr) = default on GPU only -- dot(default, csr, transpose_b=True) = default on GPU only +- dot(default, csr, forward_stype='default') = default on GPU only +- dot(default, csr, transpose_b=True, forward_stype='default') = default on GPU only - if the combination of input storage types and forward_stype_hint - does not match any of the above patterns, - dot will generate output with default storage diff --git a/tests/python/unittest/test_sparse_operator.py b/tests/python/unittest/test_sparse_operator.py index 5382b70e3d8f..16b52f60ceb9 100644 --- a/tests/python/unittest/test_sparse_operator.py +++ b/tests/python/unittest/test_sparse_operator.py @@ -1276,15 +1276,23 @@ def test_dot_dns_csr(lhs_shape, rhs_shape, lhs_density, rhs_density, trans_lhs=F location = {'lhs': lhs_nd, 'rhs': rhs_nd} check_symbolic_forward(out, location, [out_np], rtol=1e-3, atol=1e-4) - # test symbolic backward - backward_trans = not trans_lhs - rhs_backward_grad = mx.nd.dot(lhs_nd, out_dns, transpose_a=backward_trans).asnumpy() - if trans_rhs is True: - rhs_backward_grad = rhs_backward_grad.T - expected = {'rhs': rhs_backward_grad} - check_symbolic_backward(out, location, [out_np], expected, - grad_req={'lhs': 'null', 'rhs': 'write'}, - rtol=1e-3, atol=1e-4) + if default_context() == mx.cpu(): + # test symbolic backward + backward_trans = not trans_lhs + rhs_backward_grad = mx.nd.dot(lhs_nd, out_dns, transpose_a=backward_trans).asnumpy() + if trans_rhs is True: + rhs_backward_grad = rhs_backward_grad.T + expected = {'rhs': rhs_backward_grad} + check_symbolic_backward(out, location, [out_np], expected, + grad_req={'lhs': 
'null', 'rhs': 'write'}, + rtol=1e-3, atol=1e-4) + else: + transpose_b = not trans_rhs + lhs_backward_grad = mx.nd.dot(out_dns, rhs_dns, transpose_b=transpose_b) + expected = {'lhs': lhs_backward_grad.asnumpy()} + check_symbolic_backward(out, location, [out_np], expected, + grad_req={'lhs': 'write', 'rhs': 'null'}, + rtol=1e-3, atol=1e-4) def test_sparse_dot_zero_output(lhs_shape, trans_lhs, rhs_num_cols): """Test for nnr_out = 0. Before the fix, the test would fail.""" @@ -1332,18 +1340,18 @@ def test_sparse_dot_zero_output(lhs_shape, trans_lhs, rhs_num_cols): @with_seed() def test_sparse_dot_determinism(): def test_dot_determinism(lhs_stype, rhs_stype, lhs_density, rhs_density, transpose_a, transpose_b): - lhs_row = rnd.randint(200, 400) - lhs_col = rnd.randint(200, 400) + lhs_row = rnd.randint(50, 100) + lhs_col = rnd.randint(50, 100) if transpose_a: if transpose_b: - rhs_shape = (rnd.randint(200, 400), lhs_row) + rhs_shape = (rnd.randint(50, 100), lhs_row) else: - rhs_shape = (lhs_row, rnd.randint(200, 400)) + rhs_shape = (lhs_row, rnd.randint(50, 100)) else: if transpose_b: - rhs_shape = (rnd.randint(200, 400), lhs_col) + rhs_shape = (rnd.randint(50, 100), lhs_col) else: - rhs_shape = (lhs_col, rnd.randint(200, 400)) + rhs_shape = (lhs_col, rnd.randint(50, 100)) if default_context() == mx.cpu(): forward_stype = 'csr' else:
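
The determinism test asserts bitwise equality (rtol=0.0, atol=0.0). That can only pass because each output element is reduced by a single thread in a fixed index order, as the per-output-element dns/csr kernels in this patch do; a scheme that scattered partial products into the output with floating-point atomicAdd would depend on scheduling order across runs. A toy model of the fixed-order reduction, with hypothetical names and data:

#include <cstdio>

// Deterministic: one thread owns one output element and accumulates its
// inner product in index order, so repeated runs are bitwise identical.
float RowDotSparseCol(const float* lhs_row, const float* vals,
                      const int* row_ids, int begin, int end) {
  float sum = 0.f;
  for (int i = begin; i < end; ++i) sum += lhs_row[row_ids[i]] * vals[i];
  return sum;
}

int main() {
  const float lhs_row[4] = {1.f, 2.f, 3.f, 4.f};
  const float vals[3] = {0.5f, 0.25f, 0.125f};
  const int row_ids[3] = {0, 2, 3};
  // Fixed traversal order => identical result on every run (1.75).
  printf("%f\n", RowDotSparseCol(lhs_row, vals, row_ids, 0, 3));
  return 0;
}
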