Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 117 additions & 0 deletions src/operator/tensor/elemwise_binary_broadcast_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#define MXNET_OPERATOR_TENSOR_ELEMWISE_BINARY_BROADCAST_OP_H_

#include <mxnet/operator_util.h>
#include <mxnet/op_attr_types.h>
#include <algorithm>
#include <vector>
#include <string>
Expand Down Expand Up @@ -76,6 +77,31 @@ inline bool BinaryBroadcastShape(const nnvm::NodeAttrs& attrs,
return true;
}

/*!
 * \brief Infer output storage type and dispatch mode for a two-input broadcast operator.
 *
 * Tried in order:
 *   - dense, dense -> dense output, dispatched via FCompute
 *   - csr,   dense -> csr output,   dispatched via FComputeEx
 *   - anything else (or a failed assignment above) -> dispatch_fallback
 *
 * \param attrs         node attributes (not read here)
 * \param dev_mask      device mask (not read here)
 * \param dispatch_mode dispatch mode to be assigned
 * \param in_attrs      storage types of the two inputs
 * \param out_attrs     storage type of the single output, assigned in place
 * \return true if one of the attempts above succeeded, otherwise the result of dispatch_fallback
 */
inline bool BinaryBroadcastMulStorageType(const nnvm::NodeAttrs& attrs,
                                          const int dev_mask,
                                          DispatchMode* dispatch_mode,
                                          std::vector<int>* in_attrs,
                                          std::vector<int>* out_attrs) {
  CHECK_EQ(in_attrs->size(), 2U);
  CHECK_EQ(out_attrs->size(), 1U);
  const int first_stype = in_attrs->at(0);
  const int second_stype = in_attrs->at(1);
  int& result_stype = out_attrs->at(0);
  const bool second_is_dense = (second_stype == kDefaultStorage);

  // dense, dense -> dense
  if (second_is_dense && first_stype == kDefaultStorage &&
      storage_type_assign(&result_stype, kDefaultStorage,
                          dispatch_mode, DispatchMode::kFCompute)) {
    return true;
  }
  // csr, dense -> csr
  if (second_is_dense && first_stype == kCSRStorage &&
      storage_type_assign(&result_stype, kCSRStorage,
                          dispatch_mode, DispatchMode::kFComputeEx)) {
    return true;
  }
  // No specialised path applied: let the fallback decide.
  return dispatch_fallback(out_attrs, dispatch_mode);
}

#define BROADCAST_NDIM_SWITCH(ndim, NDim, ...) \
if (ndim <= 2) { \
const int NDim = 2; \
Expand Down Expand Up @@ -155,6 +181,22 @@ struct binary_broadcast_kernel {
}
}
};

/*!
 * \brief Per-row kernel applying a binary operator between the stored values of a CSR matrix
 *        and the entries of a dense vector; results are written at the same positions as the
 *        stored CSR values.
 * \tparam req write request type forwarded to KERNEL_ASSIGN
 * \tparam OP  binary operator; OP::Map(csr_value, dense_value) produces the output value
 */
template<int req, typename OP>
struct csr_dns_csr_broadcast_kernel {
  /*!
   * \param row         row handled by this invocation
   * \param csr_data    stored values of the CSR input
   * \param csr_indices column index of every stored value
   * \param csr_indptr  csr_indptr[row] .. csr_indptr[row + 1] delimits the stored values of row
   * \param dns         dense vector operand
   * \param out         output values, indexed like csr_data
   * \param row_length  not read by this kernel; kept so existing Launch calls stay valid
   * \param col_vec     true: every value in the row is paired with dns[row];
   *                    false: each value is paired with dns[its column index]
   */
  template <typename DType, typename CType, typename RType>
  MSHADOW_XINLINE static void Map(int row, const DType *csr_data, const CType *csr_indices,
                                  const RType *csr_indptr, const DType *dns,
                                  DType *out, const nnvm::dim_t row_length, bool col_vec) {
    // TODO: col_vec could become a template parameter (review suggestion); deferred because
    // the flag may be removed later.
    const nnvm::dim_t row_begin = csr_indptr[row];
    const nnvm::dim_t row_end = csr_indptr[row + 1];
    for (nnvm::dim_t pos = row_begin; pos < row_end; ++pos) {
      const DType dns_val = col_vec ? dns[row] : dns[csr_indices[pos]];
      KERNEL_ASSIGN(out[pos], req, OP::Map(csr_data[pos], dns_val));
    }
  }
};

} // namespace mxnet_op

template<typename xpu, typename OP>
Expand Down Expand Up @@ -185,6 +227,81 @@ void BinaryBroadcastCompute(const nnvm::NodeAttrs& attrs,
}
}

template<typename xpu, typename OP>
void BinaryBroadcastCsrDnsCsrImpl(const OpContext& ctx,
const NDArray& csr,
const NDArray& dns,
const OpReqType req,
const NDArray& output) {
using namespace mshadow;
using namespace mxnet_op;
using namespace csr;
CHECK(req != kAddTo && req != kWriteInplace);
CHECK_EQ(dns.shape().ndim(), 1) << "input dense should be a vector";
mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
bool col_vec = (dns.shape()[0] == csr.shape()[0])? true : false;
if (csr.storage_initialized()) {
const nnvm::dim_t nnz = csr.storage_shape()[0];
const nnvm::dim_t num_rows = output.shape()[0];
output.CheckAndAlloc({Shape1(num_rows + 1), Shape1(nnz)});

MSHADOW_TYPE_SWITCH(output.dtype(), DType, {
MSHADOW_IDX_TYPE_SWITCH(output.aux_type(kIdx), CType, {
MSHADOW_IDX_TYPE_SWITCH(output.aux_type(kIndPtr), RType, {
MXNET_ASSIGN_REQ_SWITCH(req, req_type, {
Kernel<csr_dns_csr_broadcast_kernel<req_type, OP>, xpu>::Launch(
s, num_rows, csr.data().dptr<DType>(), csr.aux_data(kIdx).dptr<CType>(),
csr.aux_data(kIndPtr).dptr<RType>(), dns.data().dptr<DType>(),
output.data().dptr<DType>(), csr.shape()[1], col_vec);
Copy(output.aux_data(kIdx).FlatTo1D<xpu, CType>(),
csr.aux_data(kIdx).FlatTo1D<xpu, CType>());
Copy(output.aux_data(kIndPtr).FlatTo1D<xpu, RType>(),
csr.aux_data(kIndPtr).FlatTo1D<xpu, RType>());
});
});
});
});
// If input csr is an empty matrix, fill zeros and return
} else {
FillZerosCsrImpl(s, output);
return;
}
}

/*!
 * \brief FComputeEx entry point for broadcast binary operators with a CSR left-hand side.
 *
 * Only broadcast(csr, dense) = csr is handled, by forwarding to BinaryBroadcastCsrDnsCsrImpl;
 * every other storage combination is reported through LogUnimplementedOp. Nothing is done when
 * the write request is kNullOp.
 *
 * \param attrs   node attributes, forwarded to LogUnimplementedOp
 * \param ctx     operator context
 * \param inputs  exactly two inputs: lhs and rhs
 * \param req     exactly one write request
 * \param outputs exactly one output
 */
template<typename xpu, typename OP>
void BinaryBroadcastComputeCsrEx(const nnvm::NodeAttrs& attrs,
                                 const OpContext& ctx,
                                 const std::vector<NDArray>& inputs,
                                 const std::vector<OpReqType>& req,
                                 const std::vector<NDArray>& outputs) {
  CHECK_EQ(inputs.size(), 2U);
  CHECK_EQ(outputs.size(), 1U);
  CHECK_EQ(req.size(), 1U);
  // The check admits up to two dimensions, so the message says "at most", not "less than".
  CHECK_LE(inputs[1].shape().ndim(), 2U)
    << "input dense matrix should have at most 2 dimensions";
  if (req[0] == kNullOp) {
    return;
  }
  const NDArray& lhs = inputs[0];
  const NDArray& rhs = inputs[1];
  const NDArray& out = outputs[0];
  const auto lhs_stype = lhs.storage_type();
  const auto rhs_stype = rhs.storage_type();
  const auto out_stype = out.storage_type();
  // TODO: broadcast(CSR, Dense(1D)) = Dense and the non-vector dense case are not implemented.
  if (lhs_stype == kCSRStorage && rhs_stype == kDefaultStorage && out_stype == kCSRStorage) {
    // broadcast(CSR, Dense(1D)) = CSR
    BinaryBroadcastCsrDnsCsrImpl<xpu, OP>(ctx, lhs, rhs, req[0], out);
  } else {
    LogUnimplementedOp(attrs, ctx, inputs, req, outputs);
  }
}

template<typename xpu, typename LOP, typename ROP>
void BinaryBroadcastBackwardUseNone(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
Expand Down
10 changes: 10 additions & 0 deletions src/operator/tensor/elemwise_binary_broadcast_op_basic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,13 @@ Example::
broadcast_mul(x, y) = [[ 0., 0., 0.],
[ 1., 1., 1.]]

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please only add the docs for broadcast_mul here.

Supported sparse operations:
broadcast_mul(csr, dense(1D)) = csr

)code" ADD_FILELINE)
.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, op::mshadow_op::mul>)
.set_attr<FComputeEx>("FComputeEx<cpu>", BinaryBroadcastComputeCsrEx<cpu, op::mshadow_op::mul>)
.set_attr<FInferStorageType>("FInferStorageType", BinaryBroadcastMulStorageType)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_broadcast_mul"});


Expand Down Expand Up @@ -154,8 +159,13 @@ Example::
broadcast_div(x, y) = [[ 3., 3., 3.],
[ 2., 2., 2.]]

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please update doc for broadcast_div

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

Supported sparse operations:
broadcast_div(csr, dense(1D)) = csr

)code" ADD_FILELINE)
.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, op::mshadow_op::div>)
.set_attr<FComputeEx>("FComputeEx<cpu>", BinaryBroadcastComputeCsrEx<cpu, op::mshadow_op::div>)
.set_attr<FInferStorageType>("FInferStorageType", BinaryBroadcastMulStorageType)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_broadcast_div"});

NNVM_REGISTER_OP(_backward_broadcast_div)
Expand Down