From dc03ce447cba567c8ff24ed01ecee5e7110e66f5 Mon Sep 17 00:00:00 2001 From: JiangZhaoh Date: Fri, 27 Mar 2020 03:45:07 +0000 Subject: [PATCH 1/3] resolution --- .../numpy/np_elemwise_broadcast_op.cc | 4 ++ src/operator/numpy/np_elemwise_broadcast_op.h | 72 ++++++++++++++++++- tests/python/unittest/test_numpy_op.py | 9 +++ 3 files changed, 83 insertions(+), 2 deletions(-) diff --git a/src/operator/numpy/np_elemwise_broadcast_op.cc b/src/operator/numpy/np_elemwise_broadcast_op.cc index 6409d4322a27..41c9ea70eb61 100644 --- a/src/operator/numpy/np_elemwise_broadcast_op.cc +++ b/src/operator/numpy/np_elemwise_broadcast_op.cc @@ -75,6 +75,10 @@ bool NumpyBinaryMixedPrecisionType(const nnvm::NodeAttrs& attrs, [](const NodeAttrs& attrs){ \ return std::vector >{{0, 0}, {1, 0}}; \ }) \ + .set_attr("FResourceRequest", \ + [](const NodeAttrs& attrs) { \ + return std::vector{ResourceRequest::kTempSpace}; \ + }) \ .add_argument("lhs", "NDArray-or-Symbol", "First input to the function") \ .add_argument("rhs", "NDArray-or-Symbol", "Second input to the function") #else diff --git a/src/operator/numpy/np_elemwise_broadcast_op.h b/src/operator/numpy/np_elemwise_broadcast_op.h index d58245a798e5..d15328e6e261 100644 --- a/src/operator/numpy/np_elemwise_broadcast_op.h +++ b/src/operator/numpy/np_elemwise_broadcast_op.h @@ -51,6 +51,10 @@ inline void PrintErrorMessage(const std::string& op_name, const int dtype1, cons << " yet..."; } +inline bool is_integer(const int dtype) { + return dtype == mshadow::kBool || dtype == mshadow::kInt8 || dtype == mshadow::kInt32 || dtype == mshadow::kInt64; +} + #ifndef _WIN32 template void MixedAllRealBinaryElemwiseCompute(const std::string& op_name, @@ -153,7 +157,6 @@ void MixedBinaryElemwiseCompute(const nnvm::NodeAttrs& attrs, const TBlob& lhs = inputs[0]; const TBlob& rhs = inputs[1]; const TBlob& out = outputs[0]; - if (common::is_float(lhs.type_flag_) && common::is_float(rhs.type_flag_)) { if (lhs.type_flag_ == out.type_flag_) { MixedAllRealBinaryElemwiseCompute(attrs.op->name, ctx, lhs, rhs, out, req[0]); @@ -252,7 +255,6 @@ void MixedBinaryBroadcastCompute(const nnvm::NodeAttrs& attrs, mxnet::TShape new_lshape, new_rshape, new_oshape; int ndim = BinaryBroadcastShapeCompact(lhs.shape_, rhs.shape_, out.shape_, &new_lshape, &new_rshape, &new_oshape); - if (!ndim) { MixedBinaryElemwiseCompute(attrs, ctx, inputs, req, outputs); } else { @@ -290,6 +292,27 @@ void MixedBinaryBroadcastCompute(const nnvm::NodeAttrs& attrs, }); } }); + } else if (is_integer(lhs.type_flag_) && is_integer(rhs.type_flag_)) { + TBlob temp_tblob; + if (lhs.type_flag_ == out.type_flag_) { + MXNET_INT_TYPE_SWITCH(lhs.type_flag_, LType, { + Tensor temp_tensor = + ctx.requested[0].get_space_typed(Shape1(rhs.Size()), s); + temp_tblob = TBlob(temp_tensor); + }); + CastCompute(attrs, ctx, {rhs}, {kWriteTo}, {temp_tblob}); + BinaryBroadcastCompute( + attrs, ctx, {lhs, temp_tblob.reshape(rhs.shape_)}, req, outputs); + } else { + MXNET_INT_TYPE_SWITCH(rhs.type_flag_, RType, { + Tensor temp_tensor = + ctx.requested[0].get_space_typed(Shape1(lhs.Size()), s); + temp_tblob = TBlob(temp_tensor); + }); + CastCompute(attrs, ctx, {lhs}, {kWriteTo}, {temp_tblob}); + BinaryBroadcastCompute( + attrs, ctx, {temp_tblob.reshape(lhs.shape_), rhs}, req, outputs); + } } else { PrintErrorMessage(attrs.op->name, lhs.type_flag_, rhs.type_flag_); } @@ -320,6 +343,27 @@ void MixedBinaryBroadcastCompute(const nnvm::NodeAttrs& attrs, BinaryBroadcastCompute( attrs, ctx, {temp_tblob.reshape(lhs.shape_), rhs}, req, outputs); } + } else if (is_integer(lhs.type_flag_) && is_integer(rhs.type_flag_)) { + TBlob temp_tblob; + if (lhs.type_flag_ == out.type_flag_) { + MXNET_INT_TYPE_SWITCH(lhs.type_flag_, LType, { + Tensor temp_tensor = + ctx.requested[0].get_space_typed(Shape1(rhs.Size()), s); + temp_tblob = TBlob(temp_tensor); + }); + CastCompute(attrs, ctx, {rhs}, {kWriteTo}, {temp_tblob}); + BinaryBroadcastCompute( + attrs, ctx, {lhs, temp_tblob.reshape(rhs.shape_)}, req, outputs); + } else { + MXNET_INT_TYPE_SWITCH(rhs.type_flag_, RType, { + Tensor temp_tensor = + ctx.requested[0].get_space_typed(Shape1(lhs.Size()), s); + temp_tblob = TBlob(temp_tensor); + }); + CastCompute(attrs, ctx, {lhs}, {kWriteTo}, {temp_tblob}); + BinaryBroadcastCompute( + attrs, ctx, {temp_tblob.reshape(lhs.shape_), rhs}, req, outputs); + } } else { PrintErrorMessage(attrs.op->name, lhs.type_flag_, rhs.type_flag_); } @@ -384,6 +428,30 @@ void NumpyBinaryBroadcastComputeWithBool(const nnvm::NodeAttrs& attrs, BinaryBroadcastComputeWithBool(attrs, ctx, inputs, req, outputs); return; } + if (is_integer(lhs.type_flag_) && is_integer(rhs.type_flag_)) { + Stream *s = ctx.get_stream(); + TBlob temp_tblob; + if (lhs.type_flag_ == out.type_flag_) { + MXNET_INT_TYPE_SWITCH(lhs.type_flag_, LType, { + Tensor temp_tensor = + ctx.requested[0].get_space_typed(Shape1(rhs.Size()), s); + temp_tblob = TBlob(temp_tensor); + }); + CastCompute(attrs, ctx, {rhs}, {kWriteTo}, {temp_tblob}); + BinaryBroadcastCompute( + attrs, ctx, {lhs, temp_tblob.reshape(rhs.shape_)}, req, outputs); + } else { + MXNET_INT_TYPE_SWITCH(rhs.type_flag_, RType, { + Tensor temp_tensor = + ctx.requested[0].get_space_typed(Shape1(lhs.Size()), s); + temp_tblob = TBlob(temp_tensor); + }); + CastCompute(attrs, ctx, {lhs}, {kWriteTo}, {temp_tblob}); + BinaryBroadcastCompute( + attrs, ctx, {temp_tblob.reshape(lhs.shape_), rhs}, req, outputs); + } + return; + } #ifndef _WIN32 MixedBinaryBroadcastCompute(attrs, ctx, inputs, req, outputs); diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index 98a0396853b4..ea1f09bf40e1 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -2424,6 +2424,8 @@ def hybrid_forward(self, F, a, b, *args, **kwargs): use_broadcast=False, equal_nan=True) if lgrad: + if (ltype in itypes) and (rtype in itypes): + continue y.backward() if ltype not in itypes: assert_almost_equal(mx_test_x1.grad.asnumpy(), @@ -2478,6 +2480,13 @@ def hybrid_forward(self, F, a, b, *args, **kwargs): continue check_mixed_precision_binary_func(func, low, high, lshape, rshape, lgrad, rgrad, type1, type2) + if func == 'subtract': + continue + for type1, type2 in itertools.product(itypes, itypes): + if type1 == type2: + continue + check_mixed_precision_binary_func(func, low, high, lshape, rshape, lgrad, rgrad, type1, type2) + @with_seed() @use_np From 34ab9d7822b40ffe0ce1e5184df4b0ede29a391f Mon Sep 17 00:00:00 2001 From: JiangZhaoh Date: Fri, 27 Mar 2020 06:46:27 +0000 Subject: [PATCH 2/3] fix sanity error --- src/operator/numpy/np_elemwise_broadcast_op.h | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/operator/numpy/np_elemwise_broadcast_op.h b/src/operator/numpy/np_elemwise_broadcast_op.h index d15328e6e261..eefc312504eb 100644 --- a/src/operator/numpy/np_elemwise_broadcast_op.h +++ b/src/operator/numpy/np_elemwise_broadcast_op.h @@ -52,7 +52,8 @@ inline void PrintErrorMessage(const std::string& op_name, const int dtype1, cons } inline bool is_integer(const int dtype) { - return dtype == mshadow::kBool || dtype == mshadow::kInt8 || dtype == mshadow::kInt32 || dtype == mshadow::kInt64; + return dtype == mshadow::kBool || dtype == mshadow::kInt8 || + dtype == mshadow::kInt32 || dtype == mshadow::kInt64; } #ifndef _WIN32 From fcf37882315f1ff6c0795d1f61c0c13e0c36a265 Mon Sep 17 00:00:00 2001 From: JiangZhaoh Date: Wed, 8 Apr 2020 03:42:57 +0000 Subject: [PATCH 3/3] remove func 'is_integer' --- src/operator/numpy/np_elemwise_broadcast_op.h | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/src/operator/numpy/np_elemwise_broadcast_op.h b/src/operator/numpy/np_elemwise_broadcast_op.h index eefc312504eb..a0e204318839 100644 --- a/src/operator/numpy/np_elemwise_broadcast_op.h +++ b/src/operator/numpy/np_elemwise_broadcast_op.h @@ -51,11 +51,6 @@ inline void PrintErrorMessage(const std::string& op_name, const int dtype1, cons << " yet..."; } -inline bool is_integer(const int dtype) { - return dtype == mshadow::kBool || dtype == mshadow::kInt8 || - dtype == mshadow::kInt32 || dtype == mshadow::kInt64; -} - #ifndef _WIN32 template void MixedAllRealBinaryElemwiseCompute(const std::string& op_name, @@ -293,7 +288,7 @@ void MixedBinaryBroadcastCompute(const nnvm::NodeAttrs& attrs, }); } }); - } else if (is_integer(lhs.type_flag_) && is_integer(rhs.type_flag_)) { + } else if (!common::is_float(lhs.type_flag_) && !common::is_float(rhs.type_flag_)) { TBlob temp_tblob; if (lhs.type_flag_ == out.type_flag_) { MXNET_INT_TYPE_SWITCH(lhs.type_flag_, LType, { @@ -344,7 +339,7 @@ void MixedBinaryBroadcastCompute(const nnvm::NodeAttrs& attrs, BinaryBroadcastCompute( attrs, ctx, {temp_tblob.reshape(lhs.shape_), rhs}, req, outputs); } - } else if (is_integer(lhs.type_flag_) && is_integer(rhs.type_flag_)) { + } else if (!common::is_float(lhs.type_flag_) && !common::is_float(rhs.type_flag_)) { TBlob temp_tblob; if (lhs.type_flag_ == out.type_flag_) { MXNET_INT_TYPE_SWITCH(lhs.type_flag_, LType, { @@ -429,7 +424,7 @@ void NumpyBinaryBroadcastComputeWithBool(const nnvm::NodeAttrs& attrs, BinaryBroadcastComputeWithBool(attrs, ctx, inputs, req, outputs); return; } - if (is_integer(lhs.type_flag_) && is_integer(rhs.type_flag_)) { + if (!common::is_float(lhs.type_flag_) && !common::is_float(rhs.type_flag_)) { Stream *s = ctx.get_stream(); TBlob temp_tblob; if (lhs.type_flag_ == out.type_flag_) {