diff --git a/src/operator/nn/dnnl/dnnl_base-inl.h b/src/operator/nn/dnnl/dnnl_base-inl.h index 20b8319ac110..52f2da322e7d 100644 --- a/src/operator/nn/dnnl/dnnl_base-inl.h +++ b/src/operator/nn/dnnl/dnnl_base-inl.h @@ -198,6 +198,7 @@ bool SupportDNNLBatchDot(const std::vector& inputs, const NDArray& outp bool SupportDNNLLayerNorm(const LayerNormParam& param, const std::vector& inputs); bool SupportDNNLReshape(const NDArray& input, const NDArray& output); bool SupportDNNLStack(const std::vector& inputs); +bool SupportDNNLBinary(const std::vector& inputs); } // namespace op static int GetTypeSize(int dtype) { diff --git a/src/operator/nn/dnnl/dnnl_binary-inl.h b/src/operator/nn/dnnl/dnnl_binary-inl.h new file mode 100644 index 000000000000..2cf63aa9a405 --- /dev/null +++ b/src/operator/nn/dnnl/dnnl_binary-inl.h @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * \file dnnl_binary-inl.h
+ * \author: Adam Grabowski, adam.grabowski@intel.com
+ */
+
+#ifndef MXNET_OPERATOR_NN_DNNL_DNNL_BINARY_INL_H_
+#define MXNET_OPERATOR_NN_DNNL_DNNL_BINARY_INL_H_
+
+#if MXNET_USE_ONEDNN == 1
+#include "./dnnl_base-inl.h"
+#include "./dnnl_ops-inl.h"
+#include <vector>
+
+#include "../../tensor/elemwise_binary_broadcast_op.h"
+
+namespace mxnet {
+namespace op {
+
+using binary_fwd_t    = dnnl::binary;
+using binary_fwd_pd_t = dnnl::binary::primitive_desc;
+// Holds one oneDNN binary primitive together with the primitive descriptor it was built from.
+class DNNLBinaryOpFwd {
+ public:
+  template <dnnl::algorithm alg>
+  static DNNLBinaryOpFwd& GetBinaryOpForward(const std::vector<NDArray>& inputs,
+                                             const std::vector<NDArray>& outputs);
+  DNNLBinaryOpFwd(const dnnl::algorithm alg,
+                  const std::vector<NDArray>& inputs,
+                  const std::vector<NDArray>& outputs);
+  // Runs the primitive on inputs[0] and inputs[1]; the result goes to outputs[0] according to req[0].
+  void Execute(const std::vector<NDArray>& inputs,
+               const std::vector<OpReqType>& req,
+               const std::vector<NDArray>& outputs);
+
+ private:
+  std::shared_ptr<binary_fwd_t> fwd;
+  std::shared_ptr<binary_fwd_pd_t> fwd_pd;
+};
+// Looks up (or builds and caches) the thread-local primitive matching alg, inputs and outputs[0].
+template <dnnl::algorithm alg>
+DNNLBinaryOpFwd& DNNLBinaryOpFwd::GetBinaryOpForward(const std::vector<NDArray>& inputs,
+                                                     const std::vector<NDArray>& outputs) {
+  using binary_op_fwd_map = std::unordered_map<OpSignature, DNNLBinaryOpFwd, OpHash>;
+#if DMLC_CXX11_THREAD_LOCAL
+  static thread_local binary_op_fwd_map fwds;
+#else
+  static MX_THREAD_LOCAL binary_op_fwd_map fwds;
+#endif
+  OpSignature key;
+  key.AddSign(static_cast<int>(alg));
+  key.AddSign(inputs[0]);
+  key.AddSign(inputs[1]);
+  key.AddSign(outputs[0]);
+
+  auto it = fwds.find(key);
+  if (it == fwds.end()) {
+    const DNNLBinaryOpFwd fwd(alg, inputs, outputs);
+    it = AddToCache(&fwds, key, fwd);
+  }
+  return it->second;
+}
+
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // MXNET_USE_ONEDNN == 1
+#endif  // MXNET_OPERATOR_NN_DNNL_DNNL_BINARY_INL_H_
diff --git a/src/operator/nn/dnnl/dnnl_binary.cc b/src/operator/nn/dnnl/dnnl_binary.cc
new file mode 100644
index 000000000000..b4d526cf1a49
--- /dev/null
+++ b/src/operator/nn/dnnl/dnnl_binary.cc
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.
See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file dnnl_binary.cc
+ * \author: Adam Grabowski, adam.grabowski@intel.com
+ */
+
+#if MXNET_USE_ONEDNN == 1
+#include "./dnnl_binary-inl.h"
+
+namespace mxnet {
+namespace op {
+// Builds the primitive descriptor and primitive from the memory descriptors of both inputs and the output.
+DNNLBinaryOpFwd::DNNLBinaryOpFwd(const dnnl::algorithm alg,
+                                 const std::vector<NDArray>& inputs,
+                                 const std::vector<NDArray>& outputs) {
+  auto src0_desc = inputs[0].GetDNNLData()->get_desc();
+  auto src1_desc = inputs[1].GetDNNLData()->get_desc();
+  auto dst_desc  = outputs[0].GetDNNLData()->get_desc();
+
+  dnnl::binary::desc fwd_desc(alg, src0_desc, src1_desc, dst_desc);
+  fwd_pd = std::make_shared<binary_fwd_pd_t>(fwd_desc, mxnet::CpuEngine::Get()->get_engine());
+  fwd    = std::make_shared<binary_fwd_t>(*fwd_pd);
+}
+// Registers the primitive with its src0/src1/dst arguments, commits the output and submits the stream.
+void DNNLBinaryOpFwd::Execute(const std::vector<NDArray>& inputs,
+                              const std::vector<OpReqType>& req,
+                              const std::vector<NDArray>& outputs) {
+  // NOTE(review): the input passed to CreateDNNLMem is presumably its in-place candidate - confirm.
+  auto src0 = inputs[0].GetDNNLData();
+  auto src1 = inputs[1].GetDNNLData();
+  dnnl_output_t out_mem;
+  if (outputs[0].GetDNNLData()->get_data_handle() == src1->get_data_handle())
+    out_mem = CreateDNNLMem(outputs[0], fwd_pd->dst_desc(), req[0], &inputs[1]);
+  else
+    out_mem = CreateDNNLMem(outputs[0], fwd_pd->dst_desc(), req[0], &inputs[0]);
+
+  dnnl_args_map_t args = {
+      {DNNL_ARG_SRC_0, *src0},
+      {DNNL_ARG_SRC_1, *src1},
+      {DNNL_ARG_DST, *out_mem.second},
+  };
+
+  DNNLStream::Get()->RegisterPrimArgs(*fwd, args);
+  CommitOutput(outputs[0], out_mem);
+  DNNLStream::Get()->Submit();
+}
+// True when both inputs are non-empty float32 arrays with 1 to 6 dimensions each.
+bool SupportDNNLBinary(const std::vector<NDArray>& inputs) {
+  auto dtype  = inputs[0].dtype();
+  auto ndim_0 = inputs[0].shape().ndim();
+  auto ndim_1 = inputs[1].shape().ndim();
+  return ndim_0 >= 1 && ndim_0 <= 6 && ndim_1 >= 1 && ndim_1 <= 6 &&
+         inputs[0].shape().Size() != 0 && inputs[1].shape().Size() != 0 &&
+         dtype == mshadow::kFloat32 && dtype == inputs[1].dtype();
+}
+
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // MXNET_USE_ONEDNN == 1
diff --git a/src/operator/numpy/np_elemwise_broadcast_op.h b/src/operator/numpy/np_elemwise_broadcast_op.h
index fa329bf248d5..3d28ffc020e1 100644
--- a/src/operator/numpy/np_elemwise_broadcast_op.h
+++ b/src/operator/numpy/np_elemwise_broadcast_op.h
@@ -851,6 +851,58 @@ void NumpyBinaryBackwardUseIn(const nnvm::NodeAttrs& attrs,
   }
 }
 
+#if MXNET_USE_ONEDNN == 1
+// Checks for exactly two inputs and one output, then defers to DNNLStorageType.
+inline bool NumpyBinaryBroadcastStorageType(const nnvm::NodeAttrs& attrs,
+                                            const int dev_mask,
+                                            DispatchMode* dispatch_mode,
+                                            std::vector<int>* in_attrs,
+                                            std::vector<int>* out_attrs) {
+  CHECK_EQ(in_attrs->size(), 2);
+  CHECK_EQ(out_attrs->size(), 1);
+
+  return DNNLStorageType(attrs, dev_mask, true, dispatch_mode, in_attrs, out_attrs);
+}
+
+// Non-template entry point to the true-divide kernel; defined in np_true_divide.cc.
+void NumpyDivideBroadcastComputeCPU(const nnvm::NodeAttrs& attrs,
+                                    const OpContext& ctx,
+                                    const std::vector<TBlob>& inputs,
+                                    const std::vector<OpReqType>& req,
+                                    const std::vector<TBlob>& outputs);
+
+// FComputeEx for the numpy add/subtract/multiply/true_divide operators on CPU: runs the oneDNN
+// binary primitive when SupportDNNLBinary(inputs) holds, otherwise the regular kernel for OP on
+// the arrays' TBlobs. OP must be one of the four functors tested below.
+template <typename OP>
+void NumpyBinaryOperatorComputeExCPU(const nnvm::NodeAttrs& attrs,
+                                     const OpContext& ctx,
+                                     const std::vector<mxnet::NDArray>& inputs,
+                                     const std::vector<OpReqType>& req,
+                                     const std::vector<mxnet::NDArray>& outputs) {
+  if (SupportDNNLBinary(inputs)) {
+    const dnnl::algorithm alg = DNNLAlgorithm<OP>::value;
+    DNNLRun(DNNLBinaryOpForward<alg>, attrs, ctx, inputs, req, outputs);
+    return;
+  }
+  using namespace op::mshadow_op;
+  std::vector<mxnet::TBlob> in_data  = {inputs[0].data(), inputs[1].data()};
+  std::vector<mxnet::TBlob> out_data = {outputs[0].data()};
+  if (std::is_same<OP, plus>::value) {
+    NumpyBinaryBroadcastComputeWithBool<cpu, OP, mixed_plus, mixed_plus>(
+        attrs, ctx, in_data, req, out_data);
+  } else if (std::is_same<OP, minus>::value) {
+    NumpyBinaryBroadcastCompute<cpu, OP, mixed_minus, mixed_rminus>(
+        attrs, ctx, in_data, req, out_data);
+  } else if (std::is_same<OP, mul>::value) {
+    NumpyBinaryBroadcastComputeWithBool<cpu, OP, mixed_mul, mixed_mul>(
+        attrs, ctx, in_data, req, out_data);
+  } else if (std::is_same<OP, div>::value) {
+    NumpyDivideBroadcastComputeCPU(attrs, ctx, in_data, req, out_data);
+  }
+}
+#endif  // MXNET_USE_ONEDNN
+
 #define MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(name) \
   NNVM_REGISTER_OP(name) \
       .set_num_inputs(1) \
diff --git a/src/operator/numpy/np_elemwise_broadcast_op_add.cc b/src/operator/numpy/np_elemwise_broadcast_op_add.cc
index 50a79ab5dc2f..69fc12bfa7e6 100644
--- a/src/operator/numpy/np_elemwise_broadcast_op_add.cc
+++ b/src/operator/numpy/np_elemwise_broadcast_op_add.cc
@@ -33,6 +33,10 @@ MXNET_OPERATOR_REGISTER_NP_BINARY_MIXED_PRECISION(_npi_add)
                                                             op::mshadow_op::plus,
                                                             op::mshadow_op::mixed_plus,
                                                             op::mshadow_op::mixed_plus>)
+#if MXNET_USE_ONEDNN == 1
+    .set_attr<FComputeEx>("FComputeEx<cpu>", NumpyBinaryOperatorComputeExCPU<op::mshadow_op::plus>)
+    .set_attr<FInferStorageType>("FInferStorageType", NumpyBinaryBroadcastStorageType)
+#endif  // MXNET_USE_ONEDNN
     .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_npi_broadcast_add"});
 
 NNVM_REGISTER_OP(_backward_npi_broadcast_add)
diff --git a/src/operator/numpy/np_elemwise_broadcast_op_mul.cc b/src/operator/numpy/np_elemwise_broadcast_op_mul.cc
index 3e627c8c7e10..b450b816c39e 100644
--- a/src/operator/numpy/np_elemwise_broadcast_op_mul.cc
+++ b/src/operator/numpy/np_elemwise_broadcast_op_mul.cc
@@ -33,6 +33,10 @@ MXNET_OPERATOR_REGISTER_NP_BINARY_MIXED_PRECISION(_npi_multiply)
                                                             op::mshadow_op::mul,
                                                             op::mshadow_op::mixed_mul,
                                                             op::mshadow_op::mixed_mul>)
+#if MXNET_USE_ONEDNN == 1
+    .set_attr<FComputeEx>("FComputeEx<cpu>", NumpyBinaryOperatorComputeExCPU<op::mshadow_op::mul>)
+    .set_attr<FInferStorageType>("FInferStorageType", NumpyBinaryBroadcastStorageType)
+#endif  // MXNET_USE_ONEDNN
     .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_npi_broadcast_mul"});
 
NNVM_REGISTER_OP(_backward_npi_broadcast_mul) diff --git a/src/operator/numpy/np_elemwise_broadcast_op_sub.cc b/src/operator/numpy/np_elemwise_broadcast_op_sub.cc index 5f3ba7653549..018b7a76c2ad 100644 --- a/src/operator/numpy/np_elemwise_broadcast_op_sub.cc +++ b/src/operator/numpy/np_elemwise_broadcast_op_sub.cc @@ -33,6 +33,10 @@ MXNET_OPERATOR_REGISTER_NP_BINARY_MIXED_PRECISION(_npi_subtract) op::mshadow_op::minus, op::mshadow_op::mixed_minus, op::mshadow_op::mixed_rminus>) +#if MXNET_USE_ONEDNN == 1 + .set_attr("FComputeEx", NumpyBinaryOperatorComputeExCPU) + .set_attr("FInferStorageType", NumpyBinaryBroadcastStorageType) +#endif // MXNET_USE_ONEDNN .set_attr("FGradient", ElemwiseGradUseIn{"_backward_npi_broadcast_sub"}); NNVM_REGISTER_OP(_backward_npi_broadcast_sub) diff --git a/src/operator/numpy/np_true_divide.cc b/src/operator/numpy/np_true_divide.cc index 639379d36cd0..3ef93c9d356b 100644 --- a/src/operator/numpy/np_true_divide.cc +++ b/src/operator/numpy/np_true_divide.cc @@ -61,6 +61,16 @@ bool TrueDivideType(const nnvm::NodeAttrs& attrs, return true; } +#if MXNET_USE_ONEDNN == 1 +void NumpyDivideBroadcastComputeCPU(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + TrueDivideBroadcastCompute(attrs, ctx, inputs, req, outputs); +} +#endif // MXNET_USE_ONEDNN + NNVM_REGISTER_OP(_npi_true_divide) .set_num_inputs(2) .set_num_outputs(1) @@ -79,6 +89,10 @@ NNVM_REGISTER_OP(_npi_true_divide) return std::vector{ResourceRequest::kTempSpace}; }) .set_attr("FCompute", TrueDivideBroadcastCompute) +#if MXNET_USE_ONEDNN == 1 + .set_attr("FComputeEx", NumpyBinaryOperatorComputeExCPU) + .set_attr("FInferStorageType", NumpyBinaryBroadcastStorageType) +#endif // MXNET_USE_ONEDNN .set_attr("FGradient", ElemwiseGradUseIn{"_backward_npi_broadcast_div"}) .add_argument("lhs", "NDArray-or-Symbol", "Dividend array") .add_argument("rhs", "NDArray-or-Symbol", "Divisor array"); diff --git 
a/src/operator/tensor/elemwise_binary_broadcast_op.h b/src/operator/tensor/elemwise_binary_broadcast_op.h
index 20d874dbd826..1c4d84d8909e 100644
--- a/src/operator/tensor/elemwise_binary_broadcast_op.h
+++ b/src/operator/tensor/elemwise_binary_broadcast_op.h
@@ -91,8 +91,16 @@ inline bool BinaryBroadcastMulStorageType(const nnvm::NodeAttrs& attrs,
   int& out_stype = out_attrs->at(0);
   bool dispatched = false;
   if (!dispatched && common::ContainsOnlyStorage(*in_attrs, kDefaultStorage)) {
+#if MXNET_USE_ONEDNN == 1
+    if (dev_mask == mshadow::cpu::kDevMask && DNNLEnvSet())
+      dispatched = storage_type_assign(
+          &out_stype, kDefaultStorage, dispatch_mode, DispatchMode::kFComputeEx);
+#endif  // MXNET_USE_ONEDNN == 1
-    dispatched =
-        storage_type_assign(&out_stype, kDefaultStorage, dispatch_mode, DispatchMode::kFCompute);
+    // Build without oneDNN, or the oneDNN branch above did not apply: use the plain dense kernel.
+    if (!dispatched) {
+      dispatched = storage_type_assign(
+          &out_stype, kDefaultStorage, dispatch_mode, DispatchMode::kFCompute);
+    }
   }
   if (!dispatched && lhs_stype == kCSRStorage && rhs_stype == kDefaultStorage) {
     dispatched =
@@ -116,8 +124,16 @@ inline bool BinaryBroadcastAddStorageType(const nnvm::NodeAttrs& attrs,
   int& out_stype = out_attrs->at(0);
   bool dispatched = false;
   if (!dispatched && common::ContainsOnlyStorage(*in_attrs, kDefaultStorage)) {
+#if MXNET_USE_ONEDNN == 1
+    if (dev_mask == mshadow::cpu::kDevMask && DNNLEnvSet())
+      dispatched = storage_type_assign(
+          &out_stype, kDefaultStorage, dispatch_mode, DispatchMode::kFComputeEx);
+#endif  // MXNET_USE_ONEDNN == 1
-    dispatched =
-        storage_type_assign(&out_stype, kDefaultStorage, dispatch_mode, DispatchMode::kFCompute);
+    // Build without oneDNN, or the oneDNN branch above did not apply: use the plain dense kernel.
+    if (!dispatched) {
+      dispatched = storage_type_assign(
+          &out_stype, kDefaultStorage, dispatch_mode, DispatchMode::kFCompute);
+    }
   }
   if (!dispatched && ((lhs_stype == kCSRStorage && rhs_stype == kDefaultStorage) ||
                       (lhs_stype == kDefaultStorage && rhs_stype == kCSRStorage))) {
@@ -788,6 +804,41 @@ void BinaryBroadcastBackwardUseIn(const nnvm::NodeAttrs& attrs,
   }
 }
 
+#if MXNET_USE_ONEDNN == 1
+/*!
+ * \brief Forward pass of a dense binary operator through the oneDNN binary primitive.
+ *        Declared here so operator sources can hand it to DNNLRun; the definition lives in
+ *        elemwise_binary_broadcast_op_basic.cc.
+ * \tparam alg oneDNN binary algorithm, normally obtained from DNNLAlgorithm<OP>::value
+ */
+template <dnnl::algorithm alg>
+void DNNLBinaryOpForward(const nnvm::NodeAttrs& attrs,
+                         const OpContext& ctx,
+                         const std::vector<NDArray>& inputs,
+                         const std::vector<OpReqType>& req,
+                         const std::vector<NDArray>& outputs);
+
+// Trait mapping an op::mshadow_op functor to the matching dnnl::algorithm.
+template <typename OP>
+struct DNNLAlgorithm {};
+template <>
+struct DNNLAlgorithm<op::mshadow_op::plus> {
+  static const dnnl::algorithm value = dnnl::algorithm::binary_add;
+};
+template <>
+struct DNNLAlgorithm<op::mshadow_op::minus> {
+  static const dnnl::algorithm value = dnnl::algorithm::binary_sub;
+};
+template <>
+struct DNNLAlgorithm<op::mshadow_op::mul> {
+  static const dnnl::algorithm value = dnnl::algorithm::binary_mul;
+};
+template <>
+struct DNNLAlgorithm<op::mshadow_op::div> {
+  static const dnnl::algorithm value = dnnl::algorithm::binary_div;
+};
+#endif  // MXNET_USE_ONEDNN == 1
+
 #define MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(name) \
   NNVM_REGISTER_OP(name) \
       .set_num_inputs(2) \
diff --git a/src/operator/tensor/elemwise_binary_broadcast_op_basic.cc b/src/operator/tensor/elemwise_binary_broadcast_op_basic.cc
index 9d0f107aa760..cc66a1e59931 100644
--- a/src/operator/tensor/elemwise_binary_broadcast_op_basic.cc
+++ b/src/operator/tensor/elemwise_binary_broadcast_op_basic.cc
@@ -24,9 +24,76 @@
 #include "./elemwise_unary_op.h"
 #include "./elemwise_binary_op-inl.h"
 #include "./elemwise_binary_broadcast_op.h"
+#if MXNET_USE_ONEDNN == 1
+#include "../nn/dnnl/dnnl_binary-inl.h"
+#endif  // MXNET_USE_ONEDNN == 1
 
 namespace mxnet {
 namespace op {
+// oneDNN forward for dense binary ops: reshapes the arrays when needed, then runs the cached primitive.
+#if MXNET_USE_ONEDNN == 1
+template <dnnl::algorithm alg>
+void DNNLBinaryOpForward(const nnvm::NodeAttrs& attrs,
+                         const OpContext& ctx,
+                         const std::vector<NDArray>& inputs,
+                         const std::vector<OpReqType>& req,
+                         const std::vector<NDArray>& outputs) {
+  mxnet::TShape new_lshape, new_rshape, new_oshape;
+  int ndim_diff = BinaryBroadcastShapeCompact(inputs[0].shape(),
+                                              inputs[1].shape(),
+                                              outputs[0].shape(),
+                                              &new_lshape,
+                                              &new_rshape,
+                                              &new_oshape);
+  std::vector<NDArray> new_inputs;
+  std::vector<NDArray> new_outputs;
+  if (ndim_diff) {
+    new_inputs  = {inputs[0].Reshape(new_lshape), inputs[1].Reshape(new_rshape)};
+    new_outputs = {outputs[0].Reshape(new_oshape)};
+  } else if (inputs[0].shape().Size() == 1 && inputs[1].shape().Size() == 1) {
+    // BinaryBroadcastShapeCompact does not reshape single-element tensors such as (1,1,...,1)
+    // into shape (1), but the oneDNN primitive needs that reshape, so it is done here.
+    mxnet::TShape one_shape = mxnet::TShape(1, 1);
+    new_inputs              = {inputs[0].Reshape(one_shape), inputs[1].Reshape(one_shape)};
+    new_outputs             = {outputs[0].Reshape(one_shape)};
+  } else {
+    new_inputs  = {inputs[0], inputs[1]};
+    new_outputs = {outputs[0]};
+  }
+
+  DNNLBinaryOpFwd& fwd = DNNLBinaryOpFwd::GetBinaryOpForward<alg>(new_inputs, new_outputs);
+  fwd.Execute(new_inputs, req, new_outputs);
+}
+#endif  // MXNET_USE_ONEDNN == 1
+// FComputeEx for broadcast_add/sub/mul/div: all-dense inputs use oneDNN or the dense kernel (oneDNN builds).
+template <typename OP>
+static void BinaryOperatorComputeExCPU(const nnvm::NodeAttrs& attrs,
+                                       const OpContext& ctx,
+                                       const std::vector<NDArray>& inputs,
+                                       const std::vector<OpReqType>& req,
+                                       const std::vector<NDArray>& outputs) {
+#if MXNET_USE_ONEDNN == 1
+  if (common::ContainsOnlyStorage(inputs, kDefaultStorage)) {
+    if (SupportDNNLBinary(inputs)) {
+      const dnnl::algorithm alg = DNNLAlgorithm<OP>::value;
+      DNNLRun(DNNLBinaryOpForward<alg>, attrs, ctx, inputs, req, outputs);
+    } else {
+      std::vector<mxnet::TBlob> in_data  = {inputs[0].data(), inputs[1].data()};
+      std::vector<mxnet::TBlob> out_data = {outputs[0].data()};
+      BinaryBroadcastCompute<cpu, OP>(attrs, ctx, in_data, req, out_data);
+    }
+    return;
+  }
+#endif  // MXNET_USE_ONEDNN == 1
+  if (std::is_same<OP, op::mshadow_op::plus>::value ||
+      std::is_same<OP, op::mshadow_op::minus>::value) {
+    BinaryBroadcastComputeDenseEx<cpu, OP>(attrs, ctx, inputs, req, outputs);
+  } else if (std::is_same<OP, op::mshadow_op::mul>::value ||
+             std::is_same<OP, op::mshadow_op::div>::value) {
+    BinaryBroadcastComputeSparseEx<cpu, OP>(attrs, ctx, inputs, req, outputs);
+  }
+}
+
 MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(broadcast_add)
 MXNET_ADD_SPARSE_OP_ALIAS(broadcast_add)
 MXNET_ADD_SPARSE_OP_ALIAS(broadcast_plus)
@@ -56,8 +123,7 @@ Supported sparse operations:
 
 )code" ADD_FILELINE)
     .set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, op::mshadow_op::plus>)
-    .set_attr<FComputeEx>("FComputeEx<cpu>",
-                          BinaryBroadcastComputeDenseEx<cpu, op::mshadow_op::plus>)
+    .set_attr<FComputeEx>("FComputeEx<cpu>", BinaryOperatorComputeExCPU<op::mshadow_op::plus>)
     .set_attr<FInferStorageType>("FInferStorageType", BinaryBroadcastAddStorageType)
     .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_broadcast_add"});
 
@@ -106,8 +172,7 @@ Supported sparse operations:
 
 )code" ADD_FILELINE)
     .set_attr("FCompute",
BinaryBroadcastCompute) - .set_attr("FComputeEx", - BinaryBroadcastComputeDenseEx) + .set_attr("FComputeEx", BinaryOperatorComputeExCPU) .set_attr("FInferStorageType", BinaryBroadcastAddStorageType) .set_attr("FGradient", ElemwiseGradUseNone{"_backward_broadcast_sub"}); @@ -148,8 +213,7 @@ Supported sparse operations: )code" ADD_FILELINE) .set_attr("FCompute", BinaryBroadcastCompute) - .set_attr("FComputeEx", - BinaryBroadcastComputeSparseEx) + .set_attr("FComputeEx", BinaryOperatorComputeExCPU) .set_attr("FInferStorageType", BinaryBroadcastMulStorageType) .set_attr("FGradient", ElemwiseGradUseIn{"_backward_broadcast_mul"}); @@ -189,8 +253,7 @@ Supported sparse operations: )code" ADD_FILELINE) .set_attr("FCompute", BinaryBroadcastCompute) - .set_attr("FComputeEx", - BinaryBroadcastComputeSparseEx) + .set_attr("FComputeEx", BinaryOperatorComputeExCPU) .set_attr("FInferStorageType", BinaryBroadcastMulStorageType) .set_attr("FGradient", ElemwiseGradUseIn{"_backward_broadcast_div"}); diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 7203212a0448..5f290318824b 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -927,9 +927,9 @@ def test_sign(): assert_almost_equal(out, npout) out_grad = mx.nd.empty(shape) - out_grad[:] = 2; + out_grad[:] = 2 npout_grad = out_grad.asnumpy() - npout_grad = 0; + npout_grad = 0 exe_test.backward(out_grad) assert_almost_equal(arr_grad, npout_grad) @@ -1076,7 +1076,7 @@ def test_abs(): assert_almost_equal(out, npout) out_grad = mx.nd.empty(shape) - out_grad[:] = 2; + out_grad[:] = 2 npout_grad = out_grad.asnumpy() npout_grad = npout_grad * np.sign(data_tmp) exe_test.backward(out_grad) @@ -1915,7 +1915,12 @@ def gen_broadcast_data(idx): [[1, 1, 65, 2, 22], [1, 1, 65, 1, 1]], [[1, 24, 103, 17, 18], [1, 24, 1, 1, 1]], [[1, 1, 1, 1, 2], [1, 24, 194, 50, 1]], - [[1, 1, 107, 84, 9], [1, 1, 1, 1, 1]]]) + [[1, 1, 107, 84, 9], [1, 1, 1, 1, 1]], + [[8, 1, 
6, 1], [7, 1, 5]], [[5, 4], [1]], + [[256, 256, 3], [3]], [[5, 4], [4]], + [[15, 3, 5], [3, 5]], [[15, 3, 5], [1, 5]], + [[15, 3, 5], [3, 1]], [[1,1,1,1], [1,1]], + [[15,3], [4, 1, 3]], [[7, 1, 5], [8, 1, 6, 1]]]) if idx < binary_op_data_shape.shape[0]: l_shape = binary_op_data_shape[idx][0] r_shape = binary_op_data_shape[idx][1] @@ -1939,7 +1944,7 @@ def gen_broadcast_data(idx): def gen_broadcast_data_int(idx): - d = gen_broadcast_data(idx); + d = gen_broadcast_data(idx) return [np.round(d[0]*100).astype(int), np.round(d[1]*100).astype(int)] @@ -1951,7 +1956,7 @@ def gen_binary_data(dummy): def gen_binary_data_int(dummy): - d = gen_binary_data(dummy); + d = gen_binary_data(dummy) return [np.round(d[0]*100).astype(int), np.round(d[1]*100).astype(int)] @@ -2012,10 +2017,16 @@ def reduce_op(shape, x): if shape == x.shape: return x keepdims_shape = list(x.shape) + # calculate difference between output and input ndims + # to include cases where inputs' ndims are not equal + ndim_diff = len(x.shape) - len(shape) + for i in range(ndim_diff): + keepdims_shape[i] = 1 + x = np.sum(x, axis=i).reshape(keepdims_shape) for i in range(len(shape)): - if x.shape[i] != shape[i]: - keepdims_shape[i] = 1 - x = np.sum(x, axis=i).reshape(keepdims_shape) + if x.shape[ndim_diff + i] != shape[i]: + keepdims_shape[ndim_diff + i] = 1 + x = np.sum(x, axis=ndim_diff + i).reshape(keepdims_shape) return x baseline_grad1, baseline_grad2 = baseline(out, d[0], d[1])