diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 9f89c5ee9476..29221884702c 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -1899,6 +1899,23 @@ def _impl_v1(cls, inputs, attr, params): ) return _op.reshape(onehot, shape_of(inputs[0])) + @classmethod + def _impl_v13(cls, inputs, attr, params) -> relay.Expr: + inferred_type = infer_type(inputs[0]) + dtype = inferred_type.checked_type.dtype + ndim = len(inferred_type.checked_type.shape) + axis = attr.get("axis", -1) % ndim + + argmax = _op.argmax(inputs[0], axis=axis) + return _op.one_hot( + argmax, + _op.const(1.0, dtype), + _op.const(0.0, dtype), + fold_constant(_op.take(shape_of(inputs[0]), _op.const([axis], "int64"))), + axis, + dtype, + ) + class OneHot(OnnxOpConverter): """Operator converter for OneHot.""" diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 8849a4e2717b..afca9a337066 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -4722,9 +4722,6 @@ def verify_eyelike(indata): "test_einsum_transpose", "test_greater_equal", "test_greater_equal_bcast", - "test_hardmax_axis_0", - "test_hardmax_axis_1", - "test_hardmax_default_axis", "test_if_seq", "test_less_equal", "test_less_equal_bcast",