From b68e7062bf13716eff1dd07d0226c7881f1f1fb9 Mon Sep 17 00:00:00 2001 From: Josh Fromm Date: Thu, 28 Jan 2021 13:07:24 -0800 Subject: [PATCH 1/2] Add testing for datatypes and fix related bugs. --- python/tvm/relay/frontend/onnx.py | 6 +++--- tests/python/frontend/onnx/test_forward.py | 1 + 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index b1b01b87f715..9e1162113466 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -1451,7 +1451,7 @@ def _impl_v1(cls, inputs, attr, params): axis = attr.get("axis", 0) keepdims = attr.get("keepdims", True) attr = {"axis": axis, "keepdims": keepdims} - return AttrCvt("argmax")(inputs, attr) + return _op.cast(AttrCvt("argmax")(inputs, attr), 'int64') class ArgMin(OnnxOpConverter): @@ -1462,7 +1462,7 @@ def _impl_v1(cls, inputs, attr, params): axis = attr.get("axis", 0) keepdims = attr.get("keepdims", True) attr = {"axis": axis, "keepdims": keepdims} - return AttrCvt("argmin")(inputs, attr) + return _op.cast(AttrCvt("argmin")(inputs, attr), 'int64') class Softmax(OnnxOpConverter): @@ -2000,7 +2000,7 @@ def _impl_v1(cls, inputs, attr, params): if largest == 0: raise ValueError("TVM only supports finding TopK largest elements") - return _op.topk(inputs[0], inputs[1], axis=axis) + return _op.topk(inputs[0], inputs[1], axis=axis, dtype='int64') class Range(OnnxOpConverter): diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index c666604d0e89..56d1dd5a5265 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -163,6 +163,7 @@ def verify_with_ort_with_inputs( ort_val = scipy.special.softmax(ort_val) tvm_val = scipy.special.softmax(tvm_val) tvm.testing.assert_allclose(ort_val, tvm_val, rtol=rtol, atol=atol) + assert ort_val.dtype == tvm_val.dtype def verify_with_ort( From 2c7f48f8c53ffd820b2fea348bbab1bb4651c97b Mon Sep 17 00:00:00 2001 From: Josh Fromm Date: Thu, 28 Jan 2021 14:48:11 -0800 Subject: [PATCH 2/2] Fix lint issue in onnx. --- python/tvm/relay/frontend/onnx.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 9e1162113466..897c6a022594 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -1451,7 +1451,7 @@ def _impl_v1(cls, inputs, attr, params): axis = attr.get("axis", 0) keepdims = attr.get("keepdims", True) attr = {"axis": axis, "keepdims": keepdims} - return _op.cast(AttrCvt("argmax")(inputs, attr), 'int64') + return _op.cast(AttrCvt("argmax")(inputs, attr), "int64") class ArgMin(OnnxOpConverter): @@ -1462,7 +1462,7 @@ def _impl_v1(cls, inputs, attr, params): axis = attr.get("axis", 0) keepdims = attr.get("keepdims", True) attr = {"axis": axis, "keepdims": keepdims} - return _op.cast(AttrCvt("argmin")(inputs, attr), 'int64') + return _op.cast(AttrCvt("argmin")(inputs, attr), "int64") class Softmax(OnnxOpConverter): @@ -2000,7 +2000,7 @@ def _impl_v1(cls, inputs, attr, params): if largest == 0: raise ValueError("TVM only supports finding TopK largest elements") - return _op.topk(inputs[0], inputs[1], axis=axis, dtype='int64') + return _op.topk(inputs[0], inputs[1], axis=axis, dtype="int64") class Range(OnnxOpConverter):