This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Still too strict dtype requirements for broadcast_like #19343

@fhieber

Description

mx.nd.broadcast_like has unnecessarily strict dtype requirements for its two data inputs. PR #17977 aimed to relax them, but in MXNet 1.7 I still get the following error from this minimal example:

a = mx.nd.ones((96, 1), dtype='float32')
b = mx.nd.ones((96, 32, 32), dtype='float16')
mx.nd.broadcast_like(a, b, lhs_axes=(1,), rhs_axes=(1,))
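For context, broadcast_like only needs the shape of its second input. The sketch below is a NumPy stand-in for the semantics as I understand them, assuming that each axis listed in lhs_axes takes the size of the matching axis in rhs_axes and that the output keeps lhs's dtype. `broadcast_like_sketch` is a hypothetical helper, not an MXNet API; it just makes explicit that rhs's dtype is never read.

```python
import numpy as np


def broadcast_like_sketch(lhs, rhs, lhs_axes=None, rhs_axes=None):
    """Hypothetical NumPy stand-in for broadcast_like.

    Only rhs.shape is consulted; rhs's dtype and values are never read,
    so a dtype mismatch between lhs and rhs should not matter.
    """
    if (lhs_axes is None) != (rhs_axes is None):
        raise ValueError("pass both lhs_axes and rhs_axes, or neither")
    if lhs_axes is None:
        # Broadcast lhs to rhs's full shape.
        target = rhs.shape
    else:
        if len(lhs_axes) != len(rhs_axes):
            raise ValueError("lhs_axes and rhs_axes must have the same length")
        # Replace each selected lhs axis with the size of the matching rhs axis.
        target = list(lhs.shape)
        for la, ra in zip(lhs_axes, rhs_axes):
            target[la] = rhs.shape[ra]
        target = tuple(target)
    # np.broadcast_to returns a read-only view, so lhs is not copied.
    return np.broadcast_to(lhs, target)


a = np.ones((96, 1), dtype=np.float32)
b = np.ones((96, 32, 32), dtype=np.float16)
out = broadcast_like_sketch(a, b, lhs_axes=(1,), rhs_axes=(1,))
print(out.shape, out.dtype)  # (96, 32) float32
```

Under these assumptions the result has shape (96, 32) and dtype float32, and it shares memory with a, which is the no-copy behavior I am after.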

Error Message

---------------------------------------------------------------------------
MXNetError                                Traceback (most recent call last)
<ipython-input-17-3cde0a810695> in <module>
----> 1 mx.nd.broadcast_like(a, b, lhs_axes=(1,), rhs_axes=(1,))

~/miniconda3/lib/python3.7/site-packages/mxnet/ndarray/register.py in broadcast_like(lhs, rhs, lhs_axes, rhs_axes, out, name, **kwargs)

~/miniconda3/lib/python3.7/site-packages/mxnet/_ctypes/ndarray.py in _imperative_invoke(handle, ndargs, keys, vals, out, is_np_op, output_is_list)
     89         c_str_array(keys),
     90         c_str_array([str(s) for s in vals]),
---> 91         ctypes.byref(out_stypes)))
     92
     93     create_ndarray_fn = _global_var._np_ndarray_cls if is_np_op else _global_var._ndarray_cls

~/miniconda3/lib/python3.7/site-packages/mxnet/base.py in check_call(ret)
    244     """
    245     if ret != 0:
--> 246         raise get_last_ffi_error()
    247
    248

MXNetError: Traceback (most recent call last):
  File "/home/centos/mxnet/3rdparty/mshadow/../../src/operator/tensor/../elemwise_op_common.h", line 135
MXNetError: Check failed: assign(&dattr, vec.at(i)): Incompatible attr in node  at 1-th input: expected float32, got float16
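The check fails in elemwise_op_common.h, which suggests the operator still uses element-wise type inference: every input (and the output) is forced to one shared dtype. The sketch below contrasts that rule with a relaxed one in which rhs contributes only its shape. Both functions are hypothetical Python illustrations written for this issue, not MXNet's actual C++ inference code.

```python
def infer_type_strict(in_dtypes):
    """Hypothetical mirror of element-wise type inference: one dtype for all inputs."""
    dattr = None
    for i, dt in enumerate(in_dtypes):
        if dattr is None:
            dattr = dt  # the first input fixes the shared dtype
        elif dt != dattr:
            raise TypeError(
                f"Incompatible attr at {i}-th input: expected {dattr}, got {dt}")
    return dattr  # output dtype equals the shared input dtype


def infer_type_relaxed(lhs_dtype, rhs_dtype):
    """Hypothetical relaxed rule: rhs only supplies a shape, so its dtype is unconstrained."""
    return lhs_dtype  # output follows lhs regardless of rhs_dtype


try:
    infer_type_strict(["float32", "float16"])
except TypeError as err:
    print(err)  # Incompatible attr at 1-th input: expected float32, got float16

print(infer_type_relaxed("float32", "float16"))  # float32
```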

To Reproduce

import mxnet as mx

a = mx.nd.ones((96, 1), dtype='float32')
b = mx.nd.ones((96, 32, 32), dtype='float16')
mx.nd.broadcast_like(a, b, lhs_axes=(1,), rhs_axes=(1,))

What have you tried to solve it?

Aligning the dtypes by casting either a or b to the other's dtype resolves the issue. However, this is not a practical solution in my AMP (automatic mixed precision) use case, because both tensors are fairly large and I want to avoid a copy.
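To put a number on the cost of the cast workaround: a cast allocates a new buffer of size × target itemsize bytes. The sketch below uses NumPy arrays as a stand-in for mx.nd arrays (an assumption; `cast_copy_bytes` is a hypothetical helper) with the shapes from the example above. In this small example, casting a is cheap, but in the real use case both tensors are large, so either cast is a large extra allocation.

```python
import numpy as np


def cast_copy_bytes(arr, dtype):
    """Bytes newly allocated when `arr` is cast to `dtype` (hypothetical helper)."""
    return arr.size * np.dtype(dtype).itemsize


a = np.ones((96, 1), dtype=np.float32)
b = np.ones((96, 32, 32), dtype=np.float16)

print(cast_copy_bytes(a, np.float16))  # 96 elements * 2 bytes = 192
print(cast_copy_bytes(b, np.float32))  # 98304 elements * 4 bytes = 393216

# astype copies by default, so the cast result never aliases the original buffer.
a16 = a.astype(np.float16)
print(np.shares_memory(a, a16))  # False
```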

Environment

mxnet-cu92             1.7.0
