diff --git a/tests/python-pytest/onnx/test_operators.py b/tests/python-pytest/onnx/test_operators.py index d2c3884f4775..b03f473aef95 100644 --- a/tests/python-pytest/onnx/test_operators.py +++ b/tests/python-pytest/onnx/test_operators.py @@ -720,3 +720,25 @@ def test_onnx_export_batch_dot(tmp_path, dtype, transpose_a, transpose_b): y2 = mx.nd.random.normal(0, 10, (2, 3, 4, 5, 5), dtype=dtype) M2 = def_model('batch_dot', transpose_a=transpose_a, transpose_b=transpose_b) op_export_test('batch_dot2', M2, [x2, y2], tmp_path) + + +@pytest.mark.parametrize('dtype', ['int32', 'int64', 'float16', 'float32', 'float64']) +@pytest.mark.parametrize('axis', [None, 1, [1,2], -1]) +def test_onnx_export_sum(tmp_path, dtype, axis): + if 'int' in dtype: + x = mx.nd.random.randint(0, 10, (5, 6, 7, 8), dtype=dtype) + else: + x = mx.nd.random.normal(0, 10, (5, 6, 7, 8), dtype=dtype) + if axis is not None: + M = def_model('sum', axis=axis) + else: + M = def_model('sum') + op_export_test('sum', M, [x], tmp_path) + + +@pytest.mark.parametrize('dtype', ['float16', 'float32', 'float64', 'int32', 'int64']) +def test_onnx_export_broadcast_mul(tmp_path, dtype): + M = def_model('broadcast_mul') + x = mx.nd.array([[1,2,3],[4,5,6]], dtype=dtype) + y = mx.nd.array([[0],[3]], dtype=dtype) + op_export_test('broadcast_mul', M, [x, y], tmp_path)