diff --git a/src/operator/numpy/np_elemwise_broadcast_op.cc b/src/operator/numpy/np_elemwise_broadcast_op.cc index 6409d4322a27..41c9ea70eb61 100644 --- a/src/operator/numpy/np_elemwise_broadcast_op.cc +++ b/src/operator/numpy/np_elemwise_broadcast_op.cc @@ -75,6 +75,10 @@ bool NumpyBinaryMixedPrecisionType(const nnvm::NodeAttrs& attrs, [](const NodeAttrs& attrs){ \ return std::vector >{{0, 0}, {1, 0}}; \ }) \ + .set_attr("FResourceRequest", \ + [](const NodeAttrs& attrs) { \ + return std::vector{ResourceRequest::kTempSpace}; \ + }) \ .add_argument("lhs", "NDArray-or-Symbol", "First input to the function") \ .add_argument("rhs", "NDArray-or-Symbol", "Second input to the function") #else diff --git a/src/operator/numpy/np_elemwise_broadcast_op.h b/src/operator/numpy/np_elemwise_broadcast_op.h index d58245a798e5..a0e204318839 100644 --- a/src/operator/numpy/np_elemwise_broadcast_op.h +++ b/src/operator/numpy/np_elemwise_broadcast_op.h @@ -153,7 +153,6 @@ void MixedBinaryElemwiseCompute(const nnvm::NodeAttrs& attrs, const TBlob& lhs = inputs[0]; const TBlob& rhs = inputs[1]; const TBlob& out = outputs[0]; - if (common::is_float(lhs.type_flag_) && common::is_float(rhs.type_flag_)) { if (lhs.type_flag_ == out.type_flag_) { MixedAllRealBinaryElemwiseCompute(attrs.op->name, ctx, lhs, rhs, out, req[0]); @@ -252,7 +251,6 @@ void MixedBinaryBroadcastCompute(const nnvm::NodeAttrs& attrs, mxnet::TShape new_lshape, new_rshape, new_oshape; int ndim = BinaryBroadcastShapeCompact(lhs.shape_, rhs.shape_, out.shape_, &new_lshape, &new_rshape, &new_oshape); - if (!ndim) { MixedBinaryElemwiseCompute(attrs, ctx, inputs, req, outputs); } else { @@ -290,6 +288,27 @@ void MixedBinaryBroadcastCompute(const nnvm::NodeAttrs& attrs, }); } }); + } else if (!common::is_float(lhs.type_flag_) && !common::is_float(rhs.type_flag_)) { + TBlob temp_tblob; + if (lhs.type_flag_ == out.type_flag_) { + MXNET_INT_TYPE_SWITCH(lhs.type_flag_, LType, { + Tensor temp_tensor = + ctx.requested[0].get_space_typed(Shape1(rhs.Size()), s); + temp_tblob = TBlob(temp_tensor); + }); + CastCompute(attrs, ctx, {rhs}, {kWriteTo}, {temp_tblob}); + BinaryBroadcastCompute( + attrs, ctx, {lhs, temp_tblob.reshape(rhs.shape_)}, req, outputs); + } else { + MXNET_INT_TYPE_SWITCH(rhs.type_flag_, RType, { + Tensor temp_tensor = + ctx.requested[0].get_space_typed(Shape1(lhs.Size()), s); + temp_tblob = TBlob(temp_tensor); + }); + CastCompute(attrs, ctx, {lhs}, {kWriteTo}, {temp_tblob}); + BinaryBroadcastCompute( + attrs, ctx, {temp_tblob.reshape(lhs.shape_), rhs}, req, outputs); + } } else { PrintErrorMessage(attrs.op->name, lhs.type_flag_, rhs.type_flag_); } @@ -320,6 +339,27 @@ void MixedBinaryBroadcastCompute(const nnvm::NodeAttrs& attrs, BinaryBroadcastCompute( attrs, ctx, {temp_tblob.reshape(lhs.shape_), rhs}, req, outputs); } + } else if (!common::is_float(lhs.type_flag_) && !common::is_float(rhs.type_flag_)) { + TBlob temp_tblob; + if (lhs.type_flag_ == out.type_flag_) { + MXNET_INT_TYPE_SWITCH(lhs.type_flag_, LType, { + Tensor temp_tensor = + ctx.requested[0].get_space_typed(Shape1(rhs.Size()), s); + temp_tblob = TBlob(temp_tensor); + }); + CastCompute(attrs, ctx, {rhs}, {kWriteTo}, {temp_tblob}); + BinaryBroadcastCompute( + attrs, ctx, {lhs, temp_tblob.reshape(rhs.shape_)}, req, outputs); + } else { + MXNET_INT_TYPE_SWITCH(rhs.type_flag_, RType, { + Tensor temp_tensor = + ctx.requested[0].get_space_typed(Shape1(lhs.Size()), s); + temp_tblob = TBlob(temp_tensor); + }); + CastCompute(attrs, ctx, {lhs}, {kWriteTo}, {temp_tblob}); + BinaryBroadcastCompute( + attrs, ctx, {temp_tblob.reshape(lhs.shape_), rhs}, req, outputs); + } } else { PrintErrorMessage(attrs.op->name, lhs.type_flag_, rhs.type_flag_); } @@ -384,6 +424,30 @@ void NumpyBinaryBroadcastComputeWithBool(const nnvm::NodeAttrs& attrs, BinaryBroadcastComputeWithBool(attrs, ctx, inputs, req, outputs); return; } + if (!common::is_float(lhs.type_flag_) && !common::is_float(rhs.type_flag_)) { + Stream *s = ctx.get_stream(); + TBlob temp_tblob; + if (lhs.type_flag_ == out.type_flag_) { + MXNET_INT_TYPE_SWITCH(lhs.type_flag_, LType, { + Tensor temp_tensor = + ctx.requested[0].get_space_typed(Shape1(rhs.Size()), s); + temp_tblob = TBlob(temp_tensor); + }); + CastCompute(attrs, ctx, {rhs}, {kWriteTo}, {temp_tblob}); + BinaryBroadcastCompute( + attrs, ctx, {lhs, temp_tblob.reshape(rhs.shape_)}, req, outputs); + } else { + MXNET_INT_TYPE_SWITCH(rhs.type_flag_, RType, { + Tensor temp_tensor = + ctx.requested[0].get_space_typed(Shape1(lhs.Size()), s); + temp_tblob = TBlob(temp_tensor); + }); + CastCompute(attrs, ctx, {lhs}, {kWriteTo}, {temp_tblob}); + BinaryBroadcastCompute( + attrs, ctx, {temp_tblob.reshape(lhs.shape_), rhs}, req, outputs); + } + return; + } #ifndef _WIN32 MixedBinaryBroadcastCompute(attrs, ctx, inputs, req, outputs); diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index 98a0396853b4..ea1f09bf40e1 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -2424,6 +2424,8 @@ def hybrid_forward(self, F, a, b, *args, **kwargs): use_broadcast=False, equal_nan=True) if lgrad: + if (ltype in itypes) and (rtype in itypes): + continue y.backward() if ltype not in itypes: assert_almost_equal(mx_test_x1.grad.asnumpy(), @@ -2478,6 +2480,13 @@ def hybrid_forward(self, F, a, b, *args, **kwargs): continue check_mixed_precision_binary_func(func, low, high, lshape, rshape, lgrad, rgrad, type1, type2) + if func == 'subtract': + continue + for type1, type2 in itertools.product(itypes, itypes): + if type1 == type2: + continue + check_mixed_precision_binary_func(func, low, high, lshape, rshape, lgrad, rgrad, type1, type2) + @with_seed() @use_np