diff --git a/include/mxnet/ndarray.h b/include/mxnet/ndarray.h index ae96fd87b0db..faffe1bdea99 100644 --- a/include/mxnet/ndarray.h +++ b/include/mxnet/ndarray.h @@ -156,7 +156,7 @@ class NDArray { } /* \brief Check whether the two arrays are the same array */ - inline bool IsSame(const NDArray& other) { + inline bool IsSame(const NDArray& other) const { return ptr_ == other.ptr_ && shape_ == other.shape_ && byte_offset_ == other.byte_offset_ && diff --git a/src/operator/tensor/elemwise_binary_op-inl.h b/src/operator/tensor/elemwise_binary_op-inl.h index 911c369b3e69..878dfb218f9b 100644 --- a/src/operator/tensor/elemwise_binary_op-inl.h +++ b/src/operator/tensor/elemwise_binary_op-inl.h @@ -31,22 +31,6 @@ namespace mxnet { namespace op { -template -void ElemwiseBinaryOp::RspRspOp(mshadow::Stream *s, - const nnvm::NodeAttrs &attrs, - const OpContext &ctx, - const NDArray &lhs, - const NDArray &rhs, - const OpReqType req, - const NDArray &output, - const bool lhs_may_be_dense, - const bool rhs_may_be_dense, - const bool allow_inplace, - const bool scatter) { - LOG(FATAL) << "GPU not supported for RspRspOp"; -} - - /*! \brief binary op handling for the following row sparse inputs/outputs rsp, rsp -> rsp, dns, rsp -> rsp, @@ -622,7 +606,7 @@ void ElemwiseBinaryOp::DnsRspDnsOp(mshadow::Stream *s, const bool reverse) { using namespace mshadow; using namespace mxnet_op; - CHECK_EQ(dns.storage_type(), kDefaultStorage); + CHECK(dns.storage_type() == kDefaultStorage || dns.storage_type() == kRowSparseStorage); CHECK_EQ(rsp.storage_type(), kRowSparseStorage); CHECK_EQ(output.data().Size(), dns.data().Size()); CHECK(req != kAddTo); diff --git a/src/operator/tensor/elemwise_binary_op.h b/src/operator/tensor/elemwise_binary_op.h index ad4b3e7cc4a3..fbd79bb3dd69 100644 --- a/src/operator/tensor/elemwise_binary_op.h +++ b/src/operator/tensor/elemwise_binary_op.h @@ -420,7 +420,7 @@ class ElemwiseBinaryOp : public OpBase { if (!dispatched && rsp && ContainsOnlyStorage(*in_attrs, kRowSparseStorage)) { // rsp, rsp, ... -> rsp dispatched = storage_type_assign(out_attrs, kRowSparseStorage, - dispatch_mode, dispatch_ex); + dispatch_mode, DispatchMode::kFComputeEx); } if (!dispatched && csr && ContainsOnlyStorage(*in_attrs, kCSRStorage)) { // csr, csr, ... -> csr diff --git a/src/operator/tensor/elemwise_binary_op_basic.cu b/src/operator/tensor/elemwise_binary_op_basic.cu index 5cdd8947dd49..ea8c1fb9c65d 100644 --- a/src/operator/tensor/elemwise_binary_op_basic.cu +++ b/src/operator/tensor/elemwise_binary_op_basic.cu @@ -22,12 +22,146 @@ * \file elemwise_binary_scalar_op.cu * \brief GPU Implementation of unary function. */ +#include #include "./elemwise_binary_op.h" #include "./elemwise_binary_op-inl.h" namespace mxnet { namespace op { +template +struct RspElemwiseKernel { + template + static MSHADOW_XINLINE void Map(int i, DType* out, const IType* lookup_table, + const DType* data, const IType* indices, + const nnvm::dim_t nz_rows, const nnvm::dim_t num_cols) { + if (i < nz_rows * num_cols) { + const nnvm::dim_t row = i / num_cols; + const nnvm::dim_t col = i % num_cols; + const nnvm::dim_t out_row = lookup_table[indices[row]] - 1; + const nnvm::dim_t out_idx = out_row * num_cols + col; + out[out_idx] = OP::Map(out[out_idx], data[i]); + } + } +}; + +template +void ElemwiseBinaryOp::RspRspOp(mshadow::Stream *s, + const nnvm::NodeAttrs &attrs, + const OpContext &ctx, + const NDArray &lhs, + const NDArray &rhs, + const OpReqType req, + const NDArray &output, + const bool lhs_may_be_dense, + const bool rhs_may_be_dense, + const bool allow_inplace, + const bool scatter) { + using namespace mshadow; + using namespace mxnet_op; + using namespace mshadow::expr; + using namespace rowsparse; + + if (req == kNullOp) return; + + CHECK(!scatter) << "scatter is not supported in RspRspOp on GPU yet..."; + CHECK(lhs.storage_type() == kRowSparseStorage && rhs.storage_type() == kRowSparseStorage); + CHECK(output.storage_type() == kRowSparseStorage); + CHECK(req != kAddTo); + + const nnvm::dim_t num_rows = output.shape()[0]; + MSHADOW_TYPE_SWITCH(lhs.data().type_flag_, DType, { + MSHADOW_IDX_TYPE_SWITCH(lhs.aux_data(kIdx).type_flag_, IType, { + if (lhs.storage_initialized() && rhs.storage_initialized()) { + const nnvm::dim_t lhs_nz_rows = lhs.storage_shape()[0]; + const nnvm::dim_t rhs_nz_rows = rhs.storage_shape()[0]; + const nnvm::dim_t num_cols = lhs.data().Size() / lhs_nz_rows; + // Optimize for the case where one of the rsps is actually dense + if ((lhs_nz_rows == num_rows || rhs_nz_rows == num_rows) && req == kWriteInplace) { + const NDArray& dns = (output.IsSame(lhs)) ? lhs : rhs; + const NDArray& rsp = (output.IsSame(lhs)) ? rhs : lhs; + const bool reverse = !(lhs_nz_rows == num_rows); + ElemwiseBinaryOp::DnsRspDnsOp(s, attrs, ctx, dns, rsp, req, output, reverse); + return; + } + CHECK(req == kWriteTo) << "Should be kWriteTo but got " << req; + const TBlob& lhs_indices = lhs.aux_data(kIdx); + const TBlob& rhs_indices = rhs.aux_data(kIdx); + size_t common_row_table_bytes = num_rows * sizeof(IType); + IType* common_row_table = NULL; + void* temp_storage_ptr = NULL; + size_t temp_storage_bytes = 0; + cub::DeviceScan::InclusiveSum(temp_storage_ptr, + temp_storage_bytes, + common_row_table, + common_row_table, + num_rows, + mshadow::Stream::GetStream(s)); + size_t workspace_bytes = common_row_table_bytes + temp_storage_bytes; + Tensor workspace = + ctx.requested[0].get_space_typed(Shape1(workspace_bytes), s); + common_row_table = reinterpret_cast(workspace.dptr_); + temp_storage_ptr = workspace.dptr_ + common_row_table_bytes; + mxnet_op::Kernel::Launch(s, num_rows, common_row_table); + Kernel::Launch( + s, lhs_nz_rows, common_row_table, lhs_indices.dptr(), lhs_nz_rows); + Kernel::Launch( + s, rhs_nz_rows, common_row_table, rhs_indices.dptr(), rhs_nz_rows); + cub::DeviceScan::InclusiveSum(temp_storage_ptr, + temp_storage_bytes, + common_row_table, + common_row_table, + num_rows, + mshadow::Stream::GetStream(s)); + nnvm::dim_t nnr_out = 0; + CUDA_CALL(cudaMemcpy(&nnr_out, &common_row_table[num_rows-1], sizeof(nnvm::dim_t), + cudaMemcpyDeviceToHost)); + output.CheckAndAlloc({mshadow::Shape1(nnr_out)}); + Kernel::Launch( + s, num_rows, output.aux_data(kIdx).dptr(), common_row_table, num_rows); + Kernel::Launch(s, nnr_out * num_cols, output.data().dptr()); + Kernel, gpu>::Launch( + s, lhs_nz_rows * num_cols, output.data().dptr(), common_row_table, + lhs.data().dptr(), lhs_indices.dptr(), lhs_nz_rows, num_cols); + Kernel, gpu>::Launch( + s, rhs_nz_rows * num_cols, output.data().dptr(), common_row_table, + rhs.data().dptr(), rhs_indices.dptr(), rhs_nz_rows, num_cols); + } else { + if (lhs.storage_initialized()) { + if (req == kWriteTo) { + output.CheckAndAlloc({lhs.aux_shape(kIdx)}); + Copy(output.data().FlatTo1D(), + lhs.data().FlatTo1D(), s); + Copy(output.aux_data(kIdx).FlatTo1D(), + lhs.aux_data(kIdx).FlatTo1D(), s); + } else if (req == kWriteInplace && rhs.IsSame(output)) { + LOG(FATAL) << "Inplace on an empty rhs is not supported"; + } + } else if (rhs.storage_initialized()) { + if (req == kWriteTo) { + output.CheckAndAlloc({rhs.aux_shape(kIdx)}); + } else if (req == kWriteInplace && lhs.IsSame(output)) { + LOG(FATAL) << "Inplace on an empty lhs is not supported"; + } + if (std::is_same::value) { + Kernel, gpu>::Launch( + s, rhs.data().Size(), output.data().dptr(), rhs.data().dptr()); + } else if (req == kWriteTo) { + Copy(output.data().FlatTo1D(), + rhs.data().FlatTo1D(), s); + } + if (req == kWriteTo) { + Copy(output.aux_data(kIdx).FlatTo1D(), + rhs.aux_data(kIdx).FlatTo1D(), s); + } + } else { + FillZerosRspImpl(s, output); + } + } + }); + }); +} + NNVM_REGISTER_OP(elemwise_add) .set_attr("FCompute", ElemwiseBinaryOp::ComputeWithHalf2) .set_attr("FComputeEx", ElemwiseBinaryOp::ComputeEx); diff --git a/tests/python/unittest/test_sparse_operator.py b/tests/python/unittest/test_sparse_operator.py index b2ff0fecb5a7..87a341586ce1 100644 --- a/tests/python/unittest/test_sparse_operator.py +++ b/tests/python/unittest/test_sparse_operator.py @@ -360,7 +360,8 @@ def check_elemwise_binary_ops(lhs_stype, rhs_stype, shape, verbose=False) if ((lhs_stype is 'default' and rhs_stype is 'row_sparse') or - (lhs_stype is 'default' and rhs_stype is 'csr')): + (lhs_stype is 'default' and rhs_stype is 'csr') or + (lhs_stype is 'row_sparse' and rhs_stype is 'row_sparse') and (rhs_density == 0.0)): test_elemwise_binary_op("elemwise_add", lhs_stype, rhs_stype, shape, lambda l, r: mx.sym.sparse.elemwise_add(l, r, out=l), lambda l, r: l + r, @@ -371,6 +372,38 @@ def check_elemwise_binary_ops(lhs_stype, rhs_stype, shape, force_grad_overlap=force_grad_overlap, lhs_density=lhs_density, rhs_density=rhs_density, verbose=False) + test_elemwise_binary_op("elemwise_sub", lhs_stype, rhs_stype, shape, + lambda l, r: mx.sym.sparse.elemwise_sub(l, r, out=l), + lambda l, r: l - r, + lambda outg, l, r: (outg, -outg), + lhs_grad_stype, rhs_grad_stype, + ograd_density=ograd_density, + force_lr_overlap=force_lr_overlap, + force_grad_overlap=force_grad_overlap, + lhs_density=lhs_density, rhs_density=rhs_density, + verbose=False) + + if ((lhs_stype is 'row_sparse' and rhs_stype is 'row_sparse') and (lhs_density == 0.0)): + test_elemwise_binary_op("elemwise_add", lhs_stype, rhs_stype, shape, + lambda l, r: mx.sym.sparse.elemwise_add(l, r, out=r), + lambda l, r: l + r, + lambda outg, l, r: (outg, outg), + lhs_grad_stype, rhs_grad_stype, + ograd_density=ograd_density, + force_lr_overlap=force_lr_overlap, + force_grad_overlap=force_grad_overlap, + lhs_density=lhs_density, rhs_density=rhs_density, + verbose=False) + test_elemwise_binary_op("elemwise_sub", lhs_stype, rhs_stype, shape, + lambda l, r: mx.sym.sparse.elemwise_sub(l, r, out=l), + lambda l, r: l - r, + lambda outg, l, r: (outg, -outg), + lhs_grad_stype, rhs_grad_stype, + ograd_density=ograd_density, + force_lr_overlap=force_lr_overlap, + force_grad_overlap=force_grad_overlap, + lhs_density=lhs_density, rhs_density=rhs_density, + verbose=False) test_elemwise_binary_op("elemwise_sub", lhs_stype, rhs_stype, shape, lambda l, r: mx.sym.sparse.elemwise_sub(l, r),