From ba85a2b75d21519562c8f8505048b18e1e7b0ac6 Mon Sep 17 00:00:00 2001 From: Bartlomiej Gawrych Date: Wed, 1 Dec 2021 15:18:35 +0100 Subject: [PATCH 1/8] Where operator enabled in oneDNN --- src/operator/nn/dnnl/dnnl_ops-inl.h | 7 + src/operator/nn/dnnl/dnnl_where-inl.h | 72 ++++++++ src/operator/nn/dnnl/dnnl_where.cc | 208 ++++++++++++++++++++++ src/operator/numpy/np_where_forward_op.cc | 39 ++++ 4 files changed, 326 insertions(+) create mode 100644 src/operator/nn/dnnl/dnnl_where-inl.h create mode 100644 src/operator/nn/dnnl/dnnl_where.cc diff --git a/src/operator/nn/dnnl/dnnl_ops-inl.h b/src/operator/nn/dnnl/dnnl_ops-inl.h index 40e944939bea..06ed1e0f2625 100644 --- a/src/operator/nn/dnnl/dnnl_ops-inl.h +++ b/src/operator/nn/dnnl/dnnl_ops-inl.h @@ -210,6 +210,13 @@ void DNNLReshapeForward(const nnvm::NodeAttrs& attrs, const NDArray& input, const OpReqType& req, const NDArray& output); + +void DNNLWhereForward(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs); + } // namespace op } // namespace mxnet diff --git a/src/operator/nn/dnnl/dnnl_where-inl.h b/src/operator/nn/dnnl/dnnl_where-inl.h new file mode 100644 index 000000000000..4360b5e76a5c --- /dev/null +++ b/src/operator/nn/dnnl/dnnl_where-inl.h @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file dnnl_where-inl.h + */ + +#ifndef MXNET_OPERATOR_NN_DNNL_DNNL_WHERE_INL_H_ +#define MXNET_OPERATOR_NN_DNNL_DNNL_WHERE_INL_H_ + +#if MXNET_USE_ONEDNN == 1 +#include + +#include "./dnnl_base-inl.h" +#include "./dnnl_ops-inl.h" + +namespace mxnet { +namespace op { + +class DNNLWhereFwd { + public: + struct Tensors { + Tensors(const std::vector& inputs, const NDArray& output); + const NDArray& condition; + const NDArray& left; + const NDArray& right; + const NDArray& output; + }; + + static DNNLWhereFwd GetCached(const Tensors& tensors); + + DNNLWhereFwd(const Tensors& tensors); + + void Execute(const Tensors& tensors, + const std::vector& req, + const OpContext& ctx) const; + + private: + dnnl::binary::primitive_desc binary_eq_zero_pd; + dnnl::binary::primitive_desc binary_ne_zero_pd; + dnnl::binary::primitive_desc binary_mul_l_pd; + dnnl::binary::primitive_desc binary_mul_r_pd; + dnnl::binary::primitive_desc binary_sum_pd; + dnnl::binary binary_eq_zero; + dnnl::binary binary_ne_zero; + dnnl::binary binary_mul_l; + dnnl::binary binary_mul_r; + dnnl::binary binary_sum; +}; + +bool SupportDNNLWhere(const std::vector& inputs); + +} // namespace op +} // namespace mxnet +#endif +#endif // MXNET_OPERATOR_NN_DNNL_DNNL_WHERE_INL_H_ \ No newline at end of file diff --git a/src/operator/nn/dnnl/dnnl_where.cc b/src/operator/nn/dnnl/dnnl_where.cc new file mode 100644 index 000000000000..e28905b30634 --- /dev/null +++ b/src/operator/nn/dnnl/dnnl_where.cc @@ -0,0 +1,208 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * 
or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file dnnl_where.cc + */ + +#if MXNET_USE_ONEDNN == 1 + +#include "./dnnl_where-inl.h" +#include "../../operator_common.h" + +namespace mxnet { +namespace op { + +bool SupportDNNLWhere(const std::vector& inputs) { + static const std::set supported_dtypes = { + mshadow::kFloat32, mshadow::kBfloat16, mshadow::kInt8, mshadow::kUint8}; + for (int i = 0; i < inputs.size(); ++i) { + if (!supported_dtypes.count(inputs[i].dtype()) || inputs[i].shape().Size() <= 0 || + inputs[i].shape().ndim() <= 0) { + return false; + } + } + return true; +} + +void DNNLWhereForward(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + TmpMemMgr::Get()->Init(ctx.requested[0]); + const auto tensors = DNNLWhereFwd::Tensors(inputs, outputs[0]); + const auto fwd = DNNLWhereFwd::GetCached(tensors); + fwd.Execute(tensors, req, ctx); +} + +DNNLWhereFwd::Tensors::Tensors(const std::vector& inputs, const NDArray& output) + : condition(inputs[0]), left(inputs[1]), right(inputs[2]), output(output) {} + +DNNLWhereFwd DNNLWhereFwd::GetCached(const Tensors& tensors) { + using where_op_fwd_map = std::unordered_map; +#if DMLC_CXX11_THREAD_LOCAL + static thread_local 
where_op_fwd_map fwds; +#else + static MX_THREAD_LOCAL where_op_fwd_map fwds; +#endif + + OpSignature key; + key.AddSign(tensors.condition); + key.AddSign(tensors.left); + key.AddSign(tensors.right); + key.AddSign(tensors.output); + + auto it = fwds.find(key); + if (it == fwds.end()) { + DNNLWhereFwd fwd(tensors); + it = AddToCache(&fwds, key, fwd); + } + return it->second; +} + +static mxnet::TShape GetBroadcastableShape(const mxnet::TShape& in_shape, + const mxnet::TShape& out_shape) { + if (in_shape == out_shape) { + return in_shape; + } + + mxnet::TShape broadcastable_in_shape(out_shape.ndim(), -1); + const int lack_dims = out_shape.ndim() - in_shape.ndim(); + for (int i = 0; i < out_shape.ndim(); ++i) { + int y = 1; + if (i >= lack_dims) { + y = in_shape[i - lack_dims]; + } + broadcastable_in_shape[i] = y; + } + return broadcastable_in_shape; +} + +DNNLWhereFwd::DNNLWhereFwd(const Tensors& tensors) { + const auto cpu_engine = CpuEngine::Get()->get_engine(); + + const auto cnd = tensors.condition; + const auto lhs = tensors.left; + const auto rhs = tensors.right; + const auto out = tensors.output; + + const auto cnd_shape = GetBroadcastableShape(cnd.shape(), out.shape()); + const auto lhs_shape = GetBroadcastableShape(lhs.shape(), out.shape()); + const auto rhs_shape = GetBroadcastableShape(rhs.shape(), out.shape()); + + const auto& cnd_dtype = get_dnnl_type(cnd.dtype()); + const auto& inp_dtype = get_dnnl_type(lhs.dtype()); + const auto& def_ft = static_cast(GetDefaultFormat(lhs_shape.ndim())); + + const auto& cnd_dims = dnnl::memory::dims(cnd_shape.begin(), cnd_shape.end()); + const auto& lhs_dims = dnnl::memory::dims(lhs_shape.begin(), lhs_shape.end()); + const auto& rhs_dims = dnnl::memory::dims(rhs_shape.begin(), rhs_shape.end()); + const auto& out_dims = dnnl::memory::dims(out.shape().begin(), out.shape().end()); + const auto& scalar_dims = dnnl::memory::dims(cnd_shape.ndim(), 1); // broadcastable scalar + + auto cnd_md = dnnl::memory::desc(cnd_dims, 
cnd_dtype, def_ft); + auto lhs_md = dnnl::memory::desc(lhs_dims, inp_dtype, def_ft); + auto rhs_md = dnnl::memory::desc(rhs_dims, inp_dtype, def_ft); + auto out_md = dnnl::memory::desc(out_dims, inp_dtype, def_ft); + auto scalar_md = dnnl::memory::desc(scalar_dims, cnd_dtype, def_ft); + + binary_eq_zero_pd = dnnl::binary::primitive_desc( + dnnl::binary::desc(dnnl::algorithm::binary_ne, cnd_md, scalar_md, cnd_md), cpu_engine); + binary_ne_zero_pd = dnnl::binary::primitive_desc( + dnnl::binary::desc(dnnl::algorithm::binary_eq, cnd_md, scalar_md, cnd_md), cpu_engine); + + // if broadcast is needed output must be larger in size + auto lmask_dim = lhs_shape.Size() > cnd_shape.Size() ? lhs_dims : cnd_dims; + auto lmask_md = dnnl::memory::desc(lmask_dim, inp_dtype, def_ft); + binary_mul_l_pd = dnnl::binary::primitive_desc( + dnnl::binary::desc(dnnl::algorithm::binary_mul, lhs_md, cnd_md, lmask_md), cpu_engine); + + auto rmask_dim = rhs_shape.Size() > cnd_shape.Size() ? rhs_dims : cnd_dims; + auto rmask_md = dnnl::memory::desc(rmask_dim, inp_dtype, def_ft); + binary_mul_r_pd = dnnl::binary::primitive_desc( + dnnl::binary::desc(dnnl::algorithm::binary_mul, rhs_md, cnd_md, rmask_md), cpu_engine); + + binary_sum_pd = dnnl::binary::primitive_desc( + dnnl::binary::desc(dnnl::algorithm::binary_add, lmask_md, rmask_md, out_md), cpu_engine); + + binary_eq_zero = dnnl::binary(binary_eq_zero_pd); + binary_ne_zero = dnnl::binary(binary_ne_zero_pd); + binary_mul_l = dnnl::binary(binary_mul_l_pd); + binary_mul_r = dnnl::binary(binary_mul_r_pd); + binary_sum = dnnl::binary(binary_sum_pd); +} + +void DNNLWhereFwd::Execute(const Tensors& tensors, + const std::vector& req, + const OpContext& ctx) const { + const auto& cpu_engine = CpuEngine::Get()->get_engine(); + const auto& cpu_stream = ctx.get_stream(); + + const auto& cnd_tensor = tensors.condition.GetDNNLDataReorder(binary_eq_zero_pd.src0_desc()); + const auto& lhs_tensor = tensors.left.GetDNNLDataReorder(binary_mul_l_pd.src0_desc()); 
+ const auto& rhs_tensor = tensors.right.GetDNNLDataReorder(binary_mul_r_pd.src0_desc()); + + mxnet::dnnl_output_t out_mem = CreateDNNLMem(tensors.output, binary_sum_pd.dst_desc(), req[0]); + + const auto& ishape = tensors.left.shape(); + const int dtype_size = GetTypeSize(tensors.output.dtype()); + + // allocate temporary memory for 4 additional tensors + mshadow::Tensor tmp_workspace = ctx.requested[0].get_space( + mshadow::Shape1(tensors.output.shape().Size() * dtype_size * 4), cpu_stream); + char* workspace_ptr = reinterpret_cast(tmp_workspace.dptr_); + const int offset_size = tensors.output.shape().Size() * dtype_size; + + dnnl::memory cnd_lhs(binary_eq_zero_pd.dst_desc(), cpu_engine, workspace_ptr); + dnnl::memory cnd_rhs(binary_ne_zero_pd.dst_desc(), cpu_engine, workspace_ptr + offset_size); + dnnl::memory masked_lhs(binary_mul_l_pd.dst_desc(), cpu_engine, workspace_ptr + 2 * offset_size); + dnnl::memory masked_rhs(binary_mul_r_pd.dst_desc(), cpu_engine, workspace_ptr + 3 * offset_size); + + double zero{0}; + dnnl::memory zero_scalar(binary_ne_zero_pd.src1_desc(), cpu_engine, &zero); + + DNNLStream::Get()->RegisterPrimArgs( + binary_eq_zero, + {{DNNL_ARG_SRC_0, *cnd_tensor}, {DNNL_ARG_SRC_1, zero_scalar}, {DNNL_ARG_DST, cnd_lhs}}); + + DNNLStream::Get()->RegisterPrimArgs( + binary_ne_zero, + {{DNNL_ARG_SRC_0, *cnd_tensor}, {DNNL_ARG_SRC_1, zero_scalar}, {DNNL_ARG_DST, cnd_rhs}}); + + DNNLStream::Get()->RegisterPrimArgs( + binary_mul_l, + {{DNNL_ARG_SRC_0, *lhs_tensor}, {DNNL_ARG_SRC_1, cnd_lhs}, {DNNL_ARG_DST, masked_lhs}}); + + DNNLStream::Get()->RegisterPrimArgs( + binary_mul_r, + {{DNNL_ARG_SRC_0, *rhs_tensor}, {DNNL_ARG_SRC_1, cnd_rhs}, {DNNL_ARG_DST, masked_rhs}}); + + DNNLStream::Get()->RegisterPrimArgs(binary_sum, + {{DNNL_ARG_SRC_0, masked_lhs}, + {DNNL_ARG_SRC_1, masked_rhs}, + {DNNL_ARG_DST, *out_mem.second}}); + + CommitOutput(tensors.output, out_mem); + DNNLStream::Get()->Submit(); +} + +} // namespace op +} // namespace mxnet +#endif \ No 
newline at end of file diff --git a/src/operator/numpy/np_where_forward_op.cc b/src/operator/numpy/np_where_forward_op.cc index bef9b19b0c94..c60b533c19ab 100644 --- a/src/operator/numpy/np_where_forward_op.cc +++ b/src/operator/numpy/np_where_forward_op.cc @@ -23,6 +23,7 @@ */ #include "np_where_op-inl.h" +#include "../nn/dnnl/dnnl_where-inl.h" namespace mxnet { namespace op { @@ -89,6 +90,35 @@ inline bool NumpyWhereScalarOpType(const nnvm::NodeAttrs& attrs, DMLC_REGISTER_PARAMETER(NumpyWhereScalarParam); DMLC_REGISTER_PARAMETER(NumpyWhereScalar2Param); +#if MXNET_USE_ONEDNN == 1 +static void WhereForwardEx(const nnvm::NodeAttrs& attrs, + const OpContext& op_ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + CHECK(!inputs.empty()); + if (req[0] == kNullOp) { + return; + } + if (SupportDNNLWhere(inputs)) { + DNNL_OPCHECK_INIT(/*is backward*/ false, outputs.size(), inputs, outputs); + DNNLRun(DNNLWhereForward, attrs, op_ctx, inputs, req, outputs); + DNNL_OPCHECK_RUN(NumpyWhereOpForward, attrs, op_ctx, inputs, req, outputs); + } else { + FallBackCompute(NumpyWhereOpForward, attrs, op_ctx, inputs, req, outputs); + } +} + +inline static bool WhereInferStorageType(const nnvm::NodeAttrs& attrs, + const int dev_mask, + DispatchMode* dispatch_mode, + std::vector* in_attrs, + std::vector* out_attrs) { + return DNNLStorageType( + attrs, dev_mask, /*support onednn*/ true, dispatch_mode, in_attrs, out_attrs); +} +#endif // MXNET_USE_ONEDNN == 1 + NNVM_REGISTER_OP(_npi_where) .set_num_inputs(3) .set_num_outputs(1) @@ -103,6 +133,15 @@ NNVM_REGISTER_OP(_npi_where) return std::vector >{{1, 0}, {2, 0}}; }) .set_attr("FCompute", NumpyWhereOpForward) +#if MXNET_USE_ONEDNN == 1 + .set_attr("FResourceRequest", + [](const NodeAttrs& n) { + return std::vector{ResourceRequest::kTempSpace}; + }) + .set_attr("FComputeEx", WhereForwardEx) + .set_attr("TIsDNNL", true) + .set_attr("FInferStorageType", WhereInferStorageType) +#endif .set_attr( 
"FGradient", // Use the following lambda function instead of ElemwiseGradUseIn From bc9a8458b462f763f19f1f6492fdc89f1a415a30 Mon Sep 17 00:00:00 2001 From: Bartlomiej Gawrych Date: Mon, 31 Jan 2022 14:33:34 +0100 Subject: [PATCH 2/8] Fix bug & refactor --- src/operator/nn/dnnl/dnnl_where.cc | 12 +++++++----- src/operator/numpy/np_where_forward_op.cc | 12 ++++++++---- 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/src/operator/nn/dnnl/dnnl_where.cc b/src/operator/nn/dnnl/dnnl_where.cc index e28905b30634..7ee45e588dd5 100644 --- a/src/operator/nn/dnnl/dnnl_where.cc +++ b/src/operator/nn/dnnl/dnnl_where.cc @@ -23,8 +23,8 @@ #if MXNET_USE_ONEDNN == 1 -#include "./dnnl_where-inl.h" -#include "../../operator_common.h" +#include "dnnl_where-inl.h" +#include "src/operator/operator_common.h" namespace mxnet { namespace op { @@ -161,12 +161,14 @@ void DNNLWhereFwd::Execute(const Tensors& tensors, mxnet::dnnl_output_t out_mem = CreateDNNLMem(tensors.output, binary_sum_pd.dst_desc(), req[0]); - const auto& ishape = tensors.left.shape(); - const int dtype_size = GetTypeSize(tensors.output.dtype()); + const auto& ishape = tensors.left.shape(); + + const int dtype_size = + std::max(GetTypeSize(tensors.condition.dtype()), GetTypeSize(tensors.left.dtype())); // allocate temporary memory for 4 additional tensors mshadow::Tensor tmp_workspace = ctx.requested[0].get_space( - mshadow::Shape1(tensors.output.shape().Size() * dtype_size * 4), cpu_stream); + mshadow::Shape1(tensors.output.shape().Size() * 4 * dtype_size, cpu_stream); char* workspace_ptr = reinterpret_cast(tmp_workspace.dptr_); const int offset_size = tensors.output.shape().Size() * dtype_size; diff --git a/src/operator/numpy/np_where_forward_op.cc b/src/operator/numpy/np_where_forward_op.cc index c60b533c19ab..901814ee6137 100644 --- a/src/operator/numpy/np_where_forward_op.cc +++ b/src/operator/numpy/np_where_forward_op.cc @@ -114,8 +114,12 @@ inline static bool WhereInferStorageType(const nnvm::NodeAttrs& 
attrs, DispatchMode* dispatch_mode, std::vector* in_attrs, std::vector* out_attrs) { - return DNNLStorageType( - attrs, dev_mask, /*support onednn*/ true, dispatch_mode, in_attrs, out_attrs); + return DNNLStorageType(attrs, + dev_mask, + /*support onednn*/ true, + dispatch_mode, + in_attrs, + out_attrs); } #endif // MXNET_USE_ONEDNN == 1 @@ -145,8 +149,8 @@ NNVM_REGISTER_OP(_npi_where) .set_attr( "FGradient", // Use the following lambda function instead of ElemwiseGradUseIn - // for best efficiency. grad[condition] = 0; to calculate grad[x] and grad[y] - // we need only condition from input. + // for best efficiency. grad[condition] = 0; to calculate grad[x] and + // grad[y] we need only condition from input. [](const nnvm::ObjectPtr& n, const std::vector& ograds) { std::vector ret; // make zero grad node for grad[condition] From 02d4af06229e28a8bf1a7f2ae67972b513e22121 Mon Sep 17 00:00:00 2001 From: Bartlomiej Gawrych Date: Tue, 1 Feb 2022 11:02:36 +0100 Subject: [PATCH 3/8] fix sanity --- src/operator/nn/dnnl/dnnl_where-inl.h | 7 ++++--- src/operator/nn/dnnl/dnnl_where.cc | 9 ++++++--- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/src/operator/nn/dnnl/dnnl_where-inl.h b/src/operator/nn/dnnl/dnnl_where-inl.h index 4360b5e76a5c..7d619a356b4e 100644 --- a/src/operator/nn/dnnl/dnnl_where-inl.h +++ b/src/operator/nn/dnnl/dnnl_where-inl.h @@ -25,8 +25,9 @@ #define MXNET_OPERATOR_NN_DNNL_DNNL_WHERE_INL_H_ #if MXNET_USE_ONEDNN == 1 +#include +#include #include - #include "./dnnl_base-inl.h" #include "./dnnl_ops-inl.h" @@ -45,7 +46,7 @@ class DNNLWhereFwd { static DNNLWhereFwd GetCached(const Tensors& tensors); - DNNLWhereFwd(const Tensors& tensors); + explicit DNNLWhereFwd(const Tensors& tensors); void Execute(const Tensors& tensors, const std::vector& req, @@ -69,4 +70,4 @@ bool SupportDNNLWhere(const std::vector& inputs); } // namespace op } // namespace mxnet #endif -#endif // MXNET_OPERATOR_NN_DNNL_DNNL_WHERE_INL_H_ \ No newline at end of file 
+#endif // MXNET_OPERATOR_NN_DNNL_DNNL_WHERE_INL_H_ diff --git a/src/operator/nn/dnnl/dnnl_where.cc b/src/operator/nn/dnnl/dnnl_where.cc index 7ee45e588dd5..0e1350ec2b9c 100644 --- a/src/operator/nn/dnnl/dnnl_where.cc +++ b/src/operator/nn/dnnl/dnnl_where.cc @@ -23,8 +23,11 @@ #if MXNET_USE_ONEDNN == 1 +#include +#include +#include #include "dnnl_where-inl.h" -#include "src/operator/operator_common.h" +#include "operator/operator_common.h" namespace mxnet { namespace op { @@ -168,7 +171,7 @@ void DNNLWhereFwd::Execute(const Tensors& tensors, // allocate temporary memory for 4 additional tensors mshadow::Tensor tmp_workspace = ctx.requested[0].get_space( - mshadow::Shape1(tensors.output.shape().Size() * 4 * dtype_size, cpu_stream); + mshadow::Shape1(tensors.output.shape().Size() * 4 * dtype_size), cpu_stream); char* workspace_ptr = reinterpret_cast(tmp_workspace.dptr_); const int offset_size = tensors.output.shape().Size() * dtype_size; @@ -207,4 +210,4 @@ void DNNLWhereFwd::Execute(const Tensors& tensors, } // namespace op } // namespace mxnet -#endif \ No newline at end of file +#endif From 6219ec77d1dbe4eca8a8f68fa255206c6ab3c924 Mon Sep 17 00:00:00 2001 From: Bartlomiej Gawrych Date: Thu, 3 Feb 2022 14:17:29 +0100 Subject: [PATCH 4/8] apply review --- src/operator/nn/dnnl/dnnl_where-inl.h | 4 ++-- src/operator/nn/dnnl/dnnl_where.cc | 8 ++------ src/operator/numpy/np_where_forward_op.cc | 5 ++--- 3 files changed, 6 insertions(+), 11 deletions(-) diff --git a/src/operator/nn/dnnl/dnnl_where-inl.h b/src/operator/nn/dnnl/dnnl_where-inl.h index 7d619a356b4e..badc754cd872 100644 --- a/src/operator/nn/dnnl/dnnl_where-inl.h +++ b/src/operator/nn/dnnl/dnnl_where-inl.h @@ -28,8 +28,8 @@ #include #include #include -#include "./dnnl_base-inl.h" -#include "./dnnl_ops-inl.h" +#include "dnnl_base-inl.h" +#include "dnnl_ops-inl.h" namespace mxnet { namespace op { diff --git a/src/operator/nn/dnnl/dnnl_where.cc b/src/operator/nn/dnnl/dnnl_where.cc index 
0e1350ec2b9c..2c74b6acc1ee 100644 --- a/src/operator/nn/dnnl/dnnl_where.cc +++ b/src/operator/nn/dnnl/dnnl_where.cc @@ -88,12 +88,8 @@ static mxnet::TShape GetBroadcastableShape(const mxnet::TShape& in_shape, mxnet::TShape broadcastable_in_shape(out_shape.ndim(), -1); const int lack_dims = out_shape.ndim() - in_shape.ndim(); - for (int i = 0; i < out_shape.ndim(); ++i) { - int y = 1; - if (i >= lack_dims) { - y = in_shape[i - lack_dims]; - } - broadcastable_in_shape[i] = y; + for (int i = lack_dims; i < out_shape.ndim(); ++i) { + broadcastable_in_shape[i] = in_shape[i - lack_dims]; } return broadcastable_in_shape; } diff --git a/src/operator/numpy/np_where_forward_op.cc b/src/operator/numpy/np_where_forward_op.cc index 901814ee6137..6caa58d197ac 100644 --- a/src/operator/numpy/np_where_forward_op.cc +++ b/src/operator/numpy/np_where_forward_op.cc @@ -148,9 +148,8 @@ NNVM_REGISTER_OP(_npi_where) #endif .set_attr( "FGradient", - // Use the following lambda function instead of ElemwiseGradUseIn - // for best efficiency. grad[condition] = 0; to calculate grad[x] and - // grad[y] we need only condition from input. + // Use the following lambda function instead of ElemwiseGradUseIn for best efficiency. + // grad[condition] = 0; to calculate grad[x] and grad[y] we need only condition from input. 
[](const nnvm::ObjectPtr& n, const std::vector& ograds) { std::vector ret; // make zero grad node for grad[condition] From 96f27cedd065eb42f10d19d83b59eed128a6c51c Mon Sep 17 00:00:00 2001 From: Bartlomiej Gawrych Date: Thu, 10 Feb 2022 09:15:14 +0100 Subject: [PATCH 5/8] Fix get_broadcastable_shape function --- src/operator/nn/dnnl/dnnl_where.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/operator/nn/dnnl/dnnl_where.cc b/src/operator/nn/dnnl/dnnl_where.cc index 2c74b6acc1ee..f590ff838c4f 100644 --- a/src/operator/nn/dnnl/dnnl_where.cc +++ b/src/operator/nn/dnnl/dnnl_where.cc @@ -86,7 +86,7 @@ static mxnet::TShape GetBroadcastableShape(const mxnet::TShape& in_shape, return in_shape; } - mxnet::TShape broadcastable_in_shape(out_shape.ndim(), -1); + mxnet::TShape broadcastable_in_shape(out_shape.ndim(), 1); const int lack_dims = out_shape.ndim() - in_shape.ndim(); for (int i = lack_dims; i < out_shape.ndim(); ++i) { broadcastable_in_shape[i] = in_shape[i - lack_dims]; From a8a83671dc0609fccd3e880785ca08a4552d0561 Mon Sep 17 00:00:00 2001 From: Bartlomiej Gawrych Date: Thu, 24 Feb 2022 16:49:22 +0100 Subject: [PATCH 6/8] Apply review --- src/operator/nn/dnnl/dnnl_where-inl.h | 2 +- src/operator/nn/dnnl/dnnl_where.cc | 39 +++++++++++++++++++-------- 2 files changed, 29 insertions(+), 12 deletions(-) diff --git a/src/operator/nn/dnnl/dnnl_where-inl.h b/src/operator/nn/dnnl/dnnl_where-inl.h index badc754cd872..bfda68466892 100644 --- a/src/operator/nn/dnnl/dnnl_where-inl.h +++ b/src/operator/nn/dnnl/dnnl_where-inl.h @@ -37,7 +37,7 @@ namespace op { class DNNLWhereFwd { public: struct Tensors { - Tensors(const std::vector& inputs, const NDArray& output); + Tensors(const std::vector& inputs, const std::vector& outputs); const NDArray& condition; const NDArray& left; const NDArray& right; diff --git a/src/operator/nn/dnnl/dnnl_where.cc b/src/operator/nn/dnnl/dnnl_where.cc index f590ff838c4f..99d44c04e7f3 100644 --- 
a/src/operator/nn/dnnl/dnnl_where.cc +++ b/src/operator/nn/dnnl/dnnl_where.cc @@ -50,13 +50,14 @@ void DNNLWhereForward(const nnvm::NodeAttrs& attrs, const std::vector& req, const std::vector& outputs) { TmpMemMgr::Get()->Init(ctx.requested[0]); - const auto tensors = DNNLWhereFwd::Tensors(inputs, outputs[0]); + const auto tensors = DNNLWhereFwd::Tensors(inputs, outputs); const auto fwd = DNNLWhereFwd::GetCached(tensors); fwd.Execute(tensors, req, ctx); } -DNNLWhereFwd::Tensors::Tensors(const std::vector& inputs, const NDArray& output) - : condition(inputs[0]), left(inputs[1]), right(inputs[2]), output(output) {} +DNNLWhereFwd::Tensors::Tensors(const std::vector& inputs, + const std::vector& outputs) + : condition(inputs[0]), left(inputs[1]), right(inputs[2]), output(outputs[0]) {} DNNLWhereFwd DNNLWhereFwd::GetCached(const Tensors& tensors) { using where_op_fwd_map = std::unordered_map; @@ -80,6 +81,13 @@ DNNLWhereFwd DNNLWhereFwd::GetCached(const Tensors& tensors) { return it->second; } +/*! + * \brief Align number of input dimensions to output. It is done by prepending shape with ones. + * oneDNN requires shapes to have same number of dimension even if they are broadcastable. 
+ * \param in_shape input shape which should be broadcastable with output + * \param out_shape output shape to which number of dimensions of input should be aligned + * \return input shape with extended number of dimensions by one + */ static mxnet::TShape GetBroadcastableShape(const mxnet::TShape& in_shape, const mxnet::TShape& out_shape) { if (in_shape == out_shape) { @@ -122,9 +130,9 @@ DNNLWhereFwd::DNNLWhereFwd(const Tensors& tensors) { auto out_md = dnnl::memory::desc(out_dims, inp_dtype, def_ft); auto scalar_md = dnnl::memory::desc(scalar_dims, cnd_dtype, def_ft); - binary_eq_zero_pd = dnnl::binary::primitive_desc( - dnnl::binary::desc(dnnl::algorithm::binary_ne, cnd_md, scalar_md, cnd_md), cpu_engine); binary_ne_zero_pd = dnnl::binary::primitive_desc( + dnnl::binary::desc(dnnl::algorithm::binary_ne, cnd_md, scalar_md, cnd_md), cpu_engine); + binary_eq_zero_pd = dnnl::binary::primitive_desc( dnnl::binary::desc(dnnl::algorithm::binary_eq, cnd_md, scalar_md, cnd_md), cpu_engine); // if broadcast is needed output must be larger in size @@ -141,13 +149,22 @@ DNNLWhereFwd::DNNLWhereFwd(const Tensors& tensors) { binary_sum_pd = dnnl::binary::primitive_desc( dnnl::binary::desc(dnnl::algorithm::binary_add, lmask_md, rmask_md, out_md), cpu_engine); - binary_eq_zero = dnnl::binary(binary_eq_zero_pd); binary_ne_zero = dnnl::binary(binary_ne_zero_pd); + binary_eq_zero = dnnl::binary(binary_eq_zero_pd); binary_mul_l = dnnl::binary(binary_mul_l_pd); binary_mul_r = dnnl::binary(binary_mul_r_pd); binary_sum = dnnl::binary(binary_sum_pd); } +/*! + * \brief + * Execute where operator by oneDNN primitives. + * 1. Create tensor cnd_lhs = condition != 0 ==> convert all non-zero values to 1 + * 2. Create tensor cnd_rhs = condition == 0 ==> convert 0 to 1 and all other values to 0 + * 3. Mask lhs tensor by cnd_lhs => mask_lhs = lhs * cnd_lhs + * 4. Mask rhs tensor by cnd_rhs => mask_rhs = rhs * cnd_rhs + * 5.
output = mask_lhs + mask_rhs + */ void DNNLWhereFwd::Execute(const Tensors& tensors, const std::vector& req, const OpContext& ctx) const { @@ -171,20 +188,20 @@ void DNNLWhereFwd::Execute(const Tensors& tensors, char* workspace_ptr = reinterpret_cast(tmp_workspace.dptr_); const int offset_size = tensors.output.shape().Size() * dtype_size; - dnnl::memory cnd_lhs(binary_eq_zero_pd.dst_desc(), cpu_engine, workspace_ptr); - dnnl::memory cnd_rhs(binary_ne_zero_pd.dst_desc(), cpu_engine, workspace_ptr + offset_size); + dnnl::memory cnd_lhs(binary_ne_zero_pd.dst_desc(), cpu_engine, workspace_ptr); + dnnl::memory cnd_rhs(binary_eq_zero_pd.dst_desc(), cpu_engine, workspace_ptr + offset_size); dnnl::memory masked_lhs(binary_mul_l_pd.dst_desc(), cpu_engine, workspace_ptr + 2 * offset_size); dnnl::memory masked_rhs(binary_mul_r_pd.dst_desc(), cpu_engine, workspace_ptr + 3 * offset_size); double zero{0}; - dnnl::memory zero_scalar(binary_ne_zero_pd.src1_desc(), cpu_engine, &zero); + dnnl::memory zero_scalar(binary_eq_zero_pd.src1_desc(), cpu_engine, &zero); DNNLStream::Get()->RegisterPrimArgs( - binary_eq_zero, + binary_ne_zero, {{DNNL_ARG_SRC_0, *cnd_tensor}, {DNNL_ARG_SRC_1, zero_scalar}, {DNNL_ARG_DST, cnd_lhs}}); DNNLStream::Get()->RegisterPrimArgs( - binary_ne_zero, + binary_eq_zero, {{DNNL_ARG_SRC_0, *cnd_tensor}, {DNNL_ARG_SRC_1, zero_scalar}, {DNNL_ARG_DST, cnd_rhs}}); DNNLStream::Get()->RegisterPrimArgs( From 7eea3bdb8e3a88f9ec03117d77fdf215e630c37a Mon Sep 17 00:00:00 2001 From: Bartlomiej Gawrych Date: Fri, 25 Feb 2022 09:38:24 +0100 Subject: [PATCH 7/8] Remove unused variable --- src/operator/nn/dnnl/dnnl_where.cc | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/operator/nn/dnnl/dnnl_where.cc b/src/operator/nn/dnnl/dnnl_where.cc index 99d44c04e7f3..71b053e02ebf 100644 --- a/src/operator/nn/dnnl/dnnl_where.cc +++ b/src/operator/nn/dnnl/dnnl_where.cc @@ -177,8 +177,6 @@ void DNNLWhereFwd::Execute(const Tensors& tensors, mxnet::dnnl_output_t out_mem = 
CreateDNNLMem(tensors.output, binary_sum_pd.dst_desc(), req[0]); - const auto& ishape = tensors.left.shape(); - const int dtype_size = std::max(GetTypeSize(tensors.condition.dtype()), GetTypeSize(tensors.left.dtype())); From 1c1ed599b446691b0432213105edd60369128283 Mon Sep 17 00:00:00 2001 From: bgawrych Date: Thu, 3 Mar 2022 09:38:44 +0100 Subject: [PATCH 8/8] Apply suggestions from code review Co-authored-by: bartekkuncer --- src/operator/nn/dnnl/dnnl_where.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/operator/nn/dnnl/dnnl_where.cc b/src/operator/nn/dnnl/dnnl_where.cc index 71b053e02ebf..c2335b9c8d63 100644 --- a/src/operator/nn/dnnl/dnnl_where.cc +++ b/src/operator/nn/dnnl/dnnl_where.cc @@ -83,10 +83,10 @@ DNNLWhereFwd DNNLWhereFwd::GetCached(const Tensors& tensors) { /*! * \brief Align number of input dimensions to output. It is done by prepending shape with ones. - * oneDNN requires shapes to have same number of dimension even if they are broadcastable. + * oneDNN requires shapes to have same number of dimensions even if they are broadcastable. * \param in_shape input shape which should be broadcastable with output * \param out_shape output shape to which number of dimensions of input should be aligned - * \return input shape with extended number of dimensions by one + * \return input shape extended with ones to match number of dimensions of output */ static mxnet::TShape GetBroadcastableShape(const mxnet::TShape& in_shape, const mxnet::TShape& out_shape) {