This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Still too strict dtype requirements for broadcast_like #19343

Description
mx.nd.broadcast_like has unnecessarily strict dtype requirements for its two data inputs. PR #17977 aimed to relax them, but in MXNet 1.7 I still get the following error from this minimal example:
```python
import mxnet as mx

a = mx.nd.ones((96, 1), dtype='float32')
b = mx.nd.ones((96, 32, 32), dtype='float16')
mx.nd.broadcast_like(a, b, lhs_axes=(1,), rhs_axes=(1,))
```

Error Message
```
---------------------------------------------------------------------------
MXNetError                                Traceback (most recent call last)
<ipython-input-17-3cde0a810695> in <module>
----> 1 mx.nd.broadcast_like(a, b, lhs_axes=(1,), rhs_axes=(1,))

~/miniconda3/lib/python3.7/site-packages/mxnet/ndarray/register.py in broadcast_like(lhs, rhs, lhs_axes, rhs_axes, out, name, **kwargs)

~/miniconda3/lib/python3.7/site-packages/mxnet/_ctypes/ndarray.py in _imperative_invoke(handle, ndargs, keys, vals, out, is_np_op, output_is_list)
     89                 c_str_array(keys),
     90                 c_str_array([str(s) for s in vals]),
---> 91                 ctypes.byref(out_stypes)))
     92
     93         create_ndarray_fn = _global_var._np_ndarray_cls if is_np_op else _global_var._ndarray_cls

~/miniconda3/lib/python3.7/site-packages/mxnet/base.py in check_call(ret)
    244     """
    245     if ret != 0:
--> 246         raise get_last_ffi_error()
    247
    248

MXNetError: Traceback (most recent call last):
  File "/home/centos/mxnet/3rdparty/mshadow/../../src/operator/tensor/../elemwise_op_common.h", line 135
MXNetError: Check failed: assign(&dattr, vec.at(i)): Incompatible attr in node at 1-th input: expected float32, got float16
```

To Reproduce
```python
import mxnet as mx

a = mx.nd.ones((96, 1), dtype='float32')
b = mx.nd.ones((96, 32, 32), dtype='float16')
mx.nd.broadcast_like(a, b, lhs_axes=(1,), rhs_axes=(1,))
```

What have you tried to solve it?
Aligning the dtypes by casting either a or b to the other's dtype resolves the issue. But this is not a practical solution in my AMP use case: both tensors are fairly large and I want to avoid a copy.
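For comparison, a NumPy sketch of the same broadcast (NumPy here is just a stand-in to illustrate the semantics, not part of the MXNet API under discussion). Only the shape of b is needed to compute the result, which is why the dtype check on the second input looks unnecessary, and `np.broadcast_to` even does it as a zero-copy view:

```python
import numpy as np

# Same setup as the reproduce snippet, but in NumPy.
a = np.ones((96, 1), dtype=np.float32)
b = np.ones((96, 32, 32), dtype=np.float16)

# Mirror lhs_axes=(1,), rhs_axes=(1,): stretch a's axis 1 (size 1)
# to the size of b's axis 1 (32). Only b's shape is consulted;
# its float16 dtype plays no role in the result.
out = np.broadcast_to(a, (a.shape[0], b.shape[1]))

assert out.shape == (96, 32)
assert out.dtype == np.float32  # a's dtype is preserved
assert out.base is not None     # a read-only view, no copy made
```

This supports the point above: the operation conceptually depends only on the like-tensor's shape, so mixed dtypes should be fine without forcing a cast.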
Environment
mxnet-cu92 1.7.0