From 6dd34925ddcf10971da336cb22dbdd5fc192bb53 Mon Sep 17 00:00:00 2001 From: Michal Piszczek Date: Fri, 3 Sep 2021 21:11:43 +0200 Subject: [PATCH 1/2] Add opset 13 impl for hardmax --- python/tvm/relay/frontend/onnx.py | 16 ++++++++++++++++ tests/python/frontend/onnx/test_forward.py | 3 --- 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 9f89c5ee9476..c973ce3a71b3 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -1899,6 +1899,22 @@ def _impl_v1(cls, inputs, attr, params): ) return _op.reshape(onehot, shape_of(inputs[0])) + @classmethod + def _impl_v13(cls, inputs, attr, params) -> relay.Expr: + inferred_type = infer_type(inputs[0]) + dtype = inferred_type.checked_type.dtype + ndim = len(inferred_type.checked_type.shape) + axis = attr.get("axis", -1) % ndim + + argmax = _op.argmax(inputs[0], axis=axis) + return _op.one_hot( + argmax, + _op.const(1.0, dtype), + _op.const(0.0, dtype), + fold_constant(_op.take(shape_of(inputs[0]), _op.const([axis], "int64"))), + axis, + dtype, + ) class OneHot(OnnxOpConverter): """Operator converter for OneHot.""" diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 8849a4e2717b..afca9a337066 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -4722,9 +4722,6 @@ def verify_eyelike(indata): "test_einsum_transpose", "test_greater_equal", "test_greater_equal_bcast", - "test_hardmax_axis_0", - "test_hardmax_axis_1", - "test_hardmax_default_axis", "test_if_seq", "test_less_equal", "test_less_equal_bcast", From 9c4b301d7b101f0a52bce987387b852549331f3e Mon Sep 17 00:00:00 2001 From: Michal Piszczek Date: Fri, 3 Sep 2021 21:15:31 +0200 Subject: [PATCH 2/2] Format --- python/tvm/relay/frontend/onnx.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index c973ce3a71b3..29221884702c 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -1916,6 +1916,7 @@ def _impl_v13(cls, inputs, attr, params) -> relay.Expr: dtype, ) + class OneHot(OnnxOpConverter): """Operator converter for OneHot."""