From 57b72e5a7c09e9da1b9cf74cc49390318ede9801 Mon Sep 17 00:00:00 2001 From: Joe Evans Date: Tue, 2 Feb 2021 17:44:48 +0000 Subject: [PATCH 1/2] Add unit test for onnx export of sum operator. --- tests/python-pytest/onnx/test_operators.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/tests/python-pytest/onnx/test_operators.py b/tests/python-pytest/onnx/test_operators.py index d2c3884f4775..f4012b9cbfb8 100644 --- a/tests/python-pytest/onnx/test_operators.py +++ b/tests/python-pytest/onnx/test_operators.py @@ -720,3 +720,17 @@ def test_onnx_export_batch_dot(tmp_path, dtype, transpose_a, transpose_b): y2 = mx.nd.random.normal(0, 10, (2, 3, 4, 5, 5), dtype=dtype) M2 = def_model('batch_dot', transpose_a=transpose_a, transpose_b=transpose_b) op_export_test('batch_dot2', M2, [x2, y2], tmp_path) + + +@pytest.mark.parametrize('dtype', ['int32', 'int64', 'float16', 'float32', 'float64']) +@pytest.mark.parametrize('axis', [None, 1, [1,2], -1]) +def test_onnx_export_sum(tmp_path, dtype, axis): + if 'int' in dtype: + x = mx.nd.random.randint(0, 10, (5, 6, 7, 8), dtype=dtype) + else: + x = mx.nd.random.normal(0, 10, (5, 6, 7, 8), dtype=dtype) + if axis is not None: + M = def_model('sum', axis=axis) + else: + M = def_model('sum') + op_export_test('sum', M, [x], tmp_path) From dc31acdca19531487da60243050a72976d935d58 Mon Sep 17 00:00:00 2001 From: Joe Evans Date: Tue, 2 Feb 2021 17:47:19 +0000 Subject: [PATCH 2/2] Add unit test for onnx export of broadcast_mul operator. --- tests/python-pytest/onnx/test_operators.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tests/python-pytest/onnx/test_operators.py b/tests/python-pytest/onnx/test_operators.py index f4012b9cbfb8..b03f473aef95 100644 --- a/tests/python-pytest/onnx/test_operators.py +++ b/tests/python-pytest/onnx/test_operators.py @@ -734,3 +734,11 @@ def test_onnx_export_sum(tmp_path, dtype, axis): else: M = def_model('sum') op_export_test('sum', M, [x], tmp_path) + + +@pytest.mark.parametrize('dtype', ['float16', 'float32', 'float64', 'int32', 'int64']) +def test_onnx_export_broadcast_mul(tmp_path, dtype): + M = def_model('broadcast_mul') + x = mx.nd.array([[1,2,3],[4,5,6]], dtype=dtype) + y = mx.nd.array([[0],[3]], dtype=dtype) + op_export_test('broadcast_mul', M, [x, y], tmp_path)