[MXNET-117] [WIP] [DO NOT MERGE] Sparse operator broadcast_mul/div(csr, dense) = csr #10150
haojin2 wants to merge 2 commits into apache:master
Conversation
Force-pushed from f0f77bf to 20ba3d0
  }
};

template<typename DType, typename CType, typename RType, int req, typename OP>
Can the template types be at function level and be inferred automatically from the arguments passed, or are you going for type checking?
Agree with applying these as the function's template arguments instead of the class's.
Will make that change soon.
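For illustration, a minimal sketch of that suggested change (hypothetical code, not the PR's actual diff): req and OP stay on the struct because they cannot be deduced from the call, while the data types move onto Map so the compiler infers them from the pointer arguments.

template<int req, typename OP>
struct csr_dns_csr_broadcast_kernel {
  template<typename DType, typename CType, typename RType>
  MSHADOW_XINLINE static void Map(int row, const DType* csr_data,
                                  const CType* csr_indices,
                                  const RType* csr_indptr,
                                  const DType* dns, DType* out) {
    // DType/CType/RType are deduced from csr_data/csr_indices/csr_indptr,
    // so the launch site only spells out req and OP.
  }
};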
MSHADOW_IDX_TYPE_SWITCH(output.aux_type(kIdx), CType, {
  MSHADOW_IDX_TYPE_SWITCH(output.aux_type(kIndPtr), RType, {
    MXNET_ASSIGN_REQ_SWITCH(req, req_type, {
      Kernel<csr_dns_csr_broadcast_kernel<DType, CType, RType, req_type, OP>, xpu>::Launch(
Yeah, you can probably get by without passing all of these template parameters.
bool col_vec = (dns.shape()[0] == csr.shape()[0]) ? true : false;
if (!csr.storage_initialized()) {
  FillZerosCsrImpl(s, output);
  return;
Would just an else block rather than a return be more readable?
Sure, will do that.
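A sketch of the suggested restructuring (names follow the quoted snippet; the else body is elided):

if (!csr.storage_initialized()) {
  // an all-zero csr input yields an all-zero csr output for mul/div
  FillZerosCsrImpl(s, output);
} else {
  // ... main broadcast computation on the initialized csr input ...
}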
//            out_stype == kDefaultStorage) {
//   BinaryBroadCastCsrDnsDnsImpl(ctx, inputs[0], inputs[1], req[0], outputs[0]);
} else {
  LogUnimplementedOp(attrs, ctx, inputs, req, outputs);
Would catching this in the storage type inference and then doing a fallback not work for this case?
If we get really big sparse matrices as inputs, then falling back may not be a better choice than throwing?
Usually in the other cases, it falls back and prints a warning the first time. Are there other cases where it just throws an error rather than falling back?
I agree that throwing an error is not desirable and blocks users from what they want to do. The problem is that FInferStorageType is not aware of shape and dtype, and dispatches only based on dev_mask and storage types. And for this sparse broadcast operator it's a lot of work to implement cases for 2-D and 3-D.
Maybe a temporary workaround is to fall back inside the operator..
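As a rough sketch of the fallback idea, assuming the storage_type_assign and dispatch_fallback helpers from operator_common.h (the exact dispatch conditions here are assumptions, and the shape of the dense input is not visible at this stage, which is the limitation noted above):

inline bool BinaryBroadcastMulStorageType(const nnvm::NodeAttrs& attrs,
                                          const int dev_mask,
                                          DispatchMode* dispatch_mode,
                                          std::vector<int>* in_attrs,
                                          std::vector<int>* out_attrs) {
  const int lhs_stype = in_attrs->at(0);
  const int rhs_stype = in_attrs->at(1);
  bool dispatched = false;
  if (lhs_stype == kCSRStorage && rhs_stype == kDefaultStorage) {
    // csr (*|/) dense -> csr, handled by the sparse FComputeEx kernel
    dispatched = storage_type_assign(out_attrs, kCSRStorage,
                                     dispatch_mode, DispatchMode::kFComputeEx);
  }
  if (!dispatched) {
    // everything else falls back to dense with the usual one-time warning
    dispatched = dispatch_fallback(out_attrs, dispatch_mode);
  }
  return dispatched;
}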
  }
};

template<typename xpu, typename OP>
void BinaryBroadCastCsrDnsCsrImpl(const OpContext& ctx,
  return true;
}

inline bool BinaryBroadcastStorageTypeCsr(const nnvm::NodeAttrs& attrs,
No need to put Csr in the name, since we may add row_sparse to the same function in the future. Also, the name is confusing because this is only for mul/div.
Changed to BinaryBroadcastMulStorageType
                                          std::vector<int>* out_attrs) {
  CHECK_EQ(in_attrs->size(), 2U);
  CHECK_EQ(out_attrs->size(), 1U);
  const int in1_stype = in_attrs->at(0);
I'd think left/right-hand side (lhs/rhs) are better names than in1/in2.
| @@ -122,6 +122,8 @@ Example:: | |||
|
|
|||
Need to clarify what's supported/not supported in the doc, like https://mxnet.incubator.apache.org/versions/master/api/python/ndarray/sparse.html#mxnet.ndarray.sparse.dot
Please only add the docs for broadcast_mul here.
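For reference, a sketch of what such a note might look like, following the style of the linked sparse.dot doc (the exact supported combinations listed here are assumptions based on this PR):

The storage type of ``broadcast_mul`` output depends upon storage types of inputs:

- broadcast_mul(csr, dense(1D)) = csr
- otherwise, ``broadcast_mul`` generates output with default storage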
using namespace mxnet_op;
using namespace csr;
CHECK_EQ(dns.shape().ndim(), 1) << "input dense should be a vector";
mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
Need to check req != kAddTo / kWriteInplace
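A minimal sketch of such a guard, assuming the standard OpReqType values (this kernel writes the csr output buffers directly, so kAddTo and kWriteInplace cannot be honored):

CHECK_NE(req, kAddTo) << "broadcast(csr, dense) = csr does not support kAddTo";
CHECK_NE(req, kWriteInplace) << "broadcast(csr, dense) = csr does not support kWriteInplace";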
Force-pushed from 20ba3d0 to c4f21dc
Force-pushed from ddfa5b8 to 4626999
MSHADOW_XINLINE static void Map(int row, const DType *csr_data, const CType *csr_indices,
                                const RType *csr_indptr, const DType *dns,
                                DType *out, const nnvm::dim_t row_length, bool col_vec) {
  nnvm::dim_t curr_row_i = csr_indptr[row];
template <typename DType, typename CType, typename RType>
MSHADOW_XINLINE static void Map(int row, const DType *csr_data, const CType *csr_indices,
                                const RType *csr_indptr, const DType *dns,
                                DType *out, const nnvm::dim_t row_length, bool col_vec) {
col_vec could be part of the template
This flag may actually be removed later; will do so if it still exists then.
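If the flag does survive, a hypothetical version of the kernel with col_vec lifted into the template could look like this (the loop body is an assumption based on the signature quoted above):

template<int req, typename OP, bool col_vec>
struct csr_dns_csr_broadcast_kernel {
  template<typename DType, typename CType, typename RType>
  MSHADOW_XINLINE static void Map(int row, const DType* csr_data,
                                  const CType* csr_indices,
                                  const RType* csr_indptr,
                                  const DType* dns, DType* out) {
    const nnvm::dim_t curr_row_i = csr_indptr[row];
    const nnvm::dim_t next_row_i = csr_indptr[row + 1];
    for (nnvm::dim_t iter = curr_row_i; iter < next_row_i; ++iter) {
      // col_vec is a compile-time constant, so the branch folds away:
      // broadcast dns down the rows (column vector) or across the columns.
      KERNEL_ASSIGN(out[iter], req,
                    OP::Map(csr_data[iter],
                            col_vec ? dns[row] : dns[csr_indices[iter]]));
    }
  }
};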
@@ -156,6 +161,8 @@ Example::
Please update the doc for broadcast_div.
Force-pushed from d086e34 to edc6e28
Force-pushed from edc6e28 to e748187
Description
Add a sparse operator on CPU that supports broadcast_mul/div(csr, dense) = csr operations.
Comments
Example for broadcast_mul/div(csr, 1-D dense) = csr:
import mxnet as mx
a = mx.nd.array([[0,0,3],[0,2,0],[1,0,0]]).tostype('csr')
b = mx.nd.array([1,2,3])
mx.nd.broadcast_mul(a,b).asnumpy()
array([[ 0.,  0.,  3.],
       [ 0.,  4.,  0.],
       [ 3.,  0.,  0.]], dtype=float32)