From 3eb9987830d0fff9b723f6b669aa8e75afafd44d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hinrik=20Sn=C3=A6r=20Gu=C3=B0mundsson?= Date: Mon, 19 Feb 2024 14:49:07 -0500 Subject: [PATCH 1/5] added atleast list(array) input support --- python/src/ops.cpp | 64 +++++++++++++++++++++++++--------------- python/tests/test_ops.py | 18 +++++++++-- 2 files changed, 55 insertions(+), 27 deletions(-) diff --git a/python/src/ops.cpp b/python/src/ops.cpp index 2c2dcecfd3..4f789beb01 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -3638,62 +3638,78 @@ void init_ops(py::module_& m) { )pbdoc"); m.def( "atleast_1d", - &atleast_1d, - "a"_a, + [](const py::args& args, + StreamOrDevice s) -> std::variant { + py::list result; + for (const auto& arg : args) { + result.append(atleast_1d(arg.cast(), s)); + } + return result.size() == 1 ? py::cast(result[0]) : result; + }, py::pos_only(), py::kw_only(), "stream"_a = none, R"pbdoc( - atleast_1d(a: array, stream: Union[None, Stream, Device] = None) -> array + atleast_1d(*args: array, stream: Union[None, Stream, Device] = None) -> Union[array, List[array]] - Convert array to have at least one dimension. + Convert all arrays to have at least one dimension. - args: - a (array): Input array + Args: + args (array or list(array)): Input array or list of input arrays. stream (Union[None, Stream, Device], optional): The stream to execute the operation on. Returns: - array: An array with at least one dimension. - + array or list(array): An array or list of arrays with at least one dimension. )pbdoc"); m.def( "atleast_2d", - &atleast_2d, - "a"_a, + [](const py::args& args, + StreamOrDevice s) -> std::variant { + py::list result; + for (const auto& arg : args) { + result.append(atleast_2d(arg.cast(), s)); + } + return result.size() == 1 ? py::cast(result[0]) : result; + }, py::pos_only(), py::kw_only(), "stream"_a = none, R"pbdoc( - atleast_2d(a: array, stream: Union[None, Stream, Device] = None) -> array + atleast_2d(*args: array, stream: Union[None, Stream, Device] = None) -> Union[array, List[array]] - Convert array to have at least two dimensions. + Convert all arrays to have at least two dimensions. - args: - a (array): Input array + Args: + args (array or list(array)): Input array or list of input arrays. stream (Union[None, Stream, Device], optional): The stream to execute the operation on. Returns: - array: An array with at least two dimensions. - + array or list(array): An array or list of arrays with at least two dimensions. )pbdoc"); + m.def( "atleast_3d", - &atleast_3d, - "a"_a, + [](const py::args& args, + StreamOrDevice s) -> std::variant { + py::list result; + for (const auto& arg : args) { + result.append(atleast_3d(arg.cast(), s)); + } + return result.size() == 1 ? py::cast(result[0]) : result; + }, py::pos_only(), py::kw_only(), "stream"_a = none, R"pbdoc( - atleast_3d(a: array, stream: Union[None, Stream, Device] = None) -> array + atleast_3d(*args: array, stream: Union[None, Stream, Device] = None) -> Union[array, List[array]] - Convert array to have at least three dimensions. + Convert all arrays to have at least three dimensions. - args: - a (array): Input array + Args: + args (array or list(array)): Input array or list of input arrays. stream (Union[None, Stream, Device], optional): The stream to execute the operation on. Returns: - array: An array with at least three dimensions. - + array or list(array): An array or list of arrays with at least three dimensions. )pbdoc"); } diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index 3401338f8f..f725bffa85 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -1906,12 +1906,16 @@ def compare_nested_lists(x, y): [[[[1]], [[2]], [[3]]]], ] - for array in arrays: + mx_arrays = [mx.atleast_1d(mx.array(x)) for x in arrays] + atleast_arrays = mx.atleast_1d(*mx_arrays) + + for i, array in enumerate(arrays): mx_res = mx.atleast_1d(mx.array(array)) np_res = np.atleast_1d(np.array(array)) self.assertTrue(compare_nested_lists(mx_res.tolist(), np_res.tolist())) self.assertEqual(mx_res.shape, np_res.shape) self.assertEqual(mx_res.ndim, np_res.ndim) + self.assertTrue(mx.all(mx.equal(mx_res, atleast_arrays[i]))) def test_atleast_2d(self): def compare_nested_lists(x, y): @@ -1936,12 +1940,16 @@ def compare_nested_lists(x, y): [[[[1]], [[2]], [[3]]]], ] - for array in arrays: + mx_arrays = [mx.atleast_2d(mx.array(x)) for x in arrays] + atleast_arrays = mx.atleast_2d(*mx_arrays) + + for i, array in enumerate(arrays): mx_res = mx.atleast_2d(mx.array(array)) np_res = np.atleast_2d(np.array(array)) self.assertTrue(compare_nested_lists(mx_res.tolist(), np_res.tolist())) self.assertEqual(mx_res.shape, np_res.shape) self.assertEqual(mx_res.ndim, np_res.ndim) + self.assertTrue(mx.all(mx.equal(mx_res, atleast_arrays[i]))) def test_atleast_3d(self): def compare_nested_lists(x, y): @@ -1966,12 +1974,16 @@ def compare_nested_lists(x, y): [[[[1]], [[2]], [[3]]]], ] - for array in arrays: + mx_arrays = [mx.atleast_3d(mx.array(x)) for x in arrays] + atleast_arrays = mx.atleast_3d(*mx_arrays) + + for i, array in enumerate(arrays): mx_res = mx.atleast_3d(mx.array(array)) np_res = np.atleast_3d(np.array(array)) self.assertTrue(compare_nested_lists(mx_res.tolist(), np_res.tolist())) self.assertEqual(mx_res.shape, np_res.shape) self.assertEqual(mx_res.ndim, np_res.ndim) + self.assertTrue(mx.all(mx.equal(mx_res, atleast_arrays[i]))) if __name__ == "__main__": From c216a3fe17fc54cd6bd38f5bf275b597cd972e0a Mon Sep 17 00:00:00 2001 From: hinriksnaer Date: Wed, 21 Feb 2024 18:00:58 -0500 Subject: [PATCH 2/5] function overloading implemented --- mlx/ops.cpp | 34 ++++++++++++++++++++++++++++++++++ mlx/ops.h | 9 +++++++++ python/src/ops.cpp | 42 ++++++++++++++++++++++++++++++------------ tests/ops_tests.cpp | 39 +++++++++++++++++++++++++++++++++++++++ 4 files changed, 112 insertions(+), 12 deletions(-) diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 97d4a3a2d2..948d97e78c 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -3388,6 +3388,17 @@ array atleast_1d(const array& a, StreamOrDevice s /* = {} */) { return a; } +std::vector atleast_1d( + const std::vector& arrays, + StreamOrDevice s) { + std::vector out; + out.reserve(arrays.size()); + for (const auto& a : arrays) { + out.push_back(atleast_1d(a, s)); + } + return out; +} + array atleast_2d(const array& a, StreamOrDevice s /* = {} */) { switch (a.ndim()) { case 0: @@ -3399,6 +3410,17 @@ array atleast_2d(const array& a, StreamOrDevice s /* = {} */) { } } +std::vector atleast_2d( + const std::vector& arrays, + StreamOrDevice s) { + std::vector out; + out.reserve(arrays.size()); + for (const auto& a : arrays) { + out.push_back(atleast_2d(a, s)); + } + return out; +} + array atleast_3d(const array& a, StreamOrDevice s /* = {} */) { switch (a.ndim()) { case 0: @@ -3411,4 +3433,16 @@ array atleast_3d(const array& a, StreamOrDevice s /* = {} */) { return a; } } + +std::vector atleast_3d( + const std::vector& arrays, + StreamOrDevice s) { + std::vector out; + out.reserve(arrays.size()); + for (const auto& a : arrays) { + out.push_back(atleast_3d(a, s)); + } + return out; +} + } // namespace mlx::core diff --git a/mlx/ops.h b/mlx/ops.h index b61224d65d..a92a4f8c09 100644 --- a/mlx/ops.h +++ b/mlx/ops.h @@ -1123,7 +1123,16 @@ std::vector depends( /** convert an array to an atleast ndim array */ array atleast_1d(const array& a, StreamOrDevice s = {}); +std::vector atleast_1d( + const std::vector& a, + StreamOrDevice s = {}); array atleast_2d(const array& a, StreamOrDevice s = {}); +std::vector atleast_2d( + const std::vector& a, + StreamOrDevice s = {}); array atleast_3d(const array& a, StreamOrDevice s = {}); +std::vector atleast_3d( + const std::vector& a, + StreamOrDevice s = {}); } // namespace mlx::core diff --git a/python/src/ops.cpp b/python/src/ops.cpp index 4f789beb01..25fd122132 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -3640,11 +3640,17 @@ void init_ops(py::module_& m) { "atleast_1d", [](const py::args& args, StreamOrDevice s) -> std::variant { - py::list result; - for (const auto& arg : args) { - result.append(atleast_1d(arg.cast(), s)); + if (args.size() == 1) { + py::object result; + result = py::cast(atleast_1d(args[0].cast(), s)); + return result; } - return result.size() == 1 ? py::cast(result[0]) : result; + + std::vector arrays = args.cast>(); + std::vector result = atleast_1d(arrays, s); + py::list py_result = py::cast(result); + + return py_result; }, py::pos_only(), py::kw_only(), @@ -3665,11 +3671,17 @@ void init_ops(py::module_& m) { "atleast_2d", [](const py::args& args, StreamOrDevice s) -> std::variant { - py::list result; - for (const auto& arg : args) { - result.append(atleast_2d(arg.cast(), s)); + if (args.size() == 1) { + py::object result; + result = py::cast(atleast_2d(args[0].cast(), s)); + return result; } - return result.size() == 1 ? py::cast(result[0]) : result; + + std::vector arrays = args.cast>(); + std::vector result = atleast_2d(arrays, s); + py::list py_result = py::cast(result); + + return py_result; }, py::pos_only(), py::kw_only(), @@ -3691,11 +3703,17 @@ void init_ops(py::module_& m) { "atleast_3d", [](const py::args& args, StreamOrDevice s) -> std::variant { - py::list result; - for (const auto& arg : args) { - result.append(atleast_3d(arg.cast(), s)); + if (args.size() == 1) { + py::object result; + result = py::cast(atleast_3d(args[0].cast(), s)); + return result; } - return result.size() == 1 ? py::cast(result[0]) : result; + + std::vector arrays = args.cast>(); + std::vector result = atleast_3d(arrays, s); + py::list py_result = py::cast(result); + + return py_result; }, py::pos_only(), py::kw_only(), diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp index ba4ab552f8..52f679d9e1 100644 --- a/tests/ops_tests.cpp +++ b/tests/ops_tests.cpp @@ -2734,6 +2734,19 @@ TEST_CASE("test atleast_1d") { CHECK_EQ(out.shape(), std::vector{3, 1}); } +TEST_CASE("test atleast_1d vector") { + auto x = std::vector{ + array(1), array({1, 2, 3}, {3}), array({1, 2, 3}, {3, 1})}; + auto out = atleast_1d(x); + CHECK_EQ(out.size(), 3); + CHECK_EQ(out[0].ndim(), 1); + CHECK_EQ(out[0].shape(), std::vector{1}); + CHECK_EQ(out[1].ndim(), 1); + CHECK_EQ(out[1].shape(), std::vector{3}); + CHECK_EQ(out[2].ndim(), 2); + CHECK_EQ(out[2].shape(), std::vector{3, 1}); +} + TEST_CASE("test atleast_2d") { auto x = array(1); auto out = atleast_2d(x); @@ -2751,6 +2764,19 @@ TEST_CASE("test atleast_2d") { CHECK_EQ(out.shape(), std::vector{3, 1}); } +TEST_CASE("test atleast_2d vector") { + auto x = std::vector{ + array(1), array({1, 2, 3}, {3}), array({1, 2, 3}, {3, 1})}; + auto out = atleast_2d(x); + CHECK_EQ(out.size(), 3); + CHECK_EQ(out[0].ndim(), 2); + CHECK_EQ(out[0].shape(), std::vector{1, 1}); + CHECK_EQ(out[1].ndim(), 2); + CHECK_EQ(out[1].shape(), std::vector{1, 3}); + CHECK_EQ(out[2].ndim(), 2); + CHECK_EQ(out[2].shape(), std::vector{3, 1}); +} + TEST_CASE("test atleast_3d") { auto x = array(1); auto out = atleast_3d(x); @@ -2766,4 +2792,17 @@ TEST_CASE("test atleast_3d") { out = atleast_3d(x); CHECK_EQ(out.ndim(), 3); CHECK_EQ(out.shape(), std::vector{3, 1, 1}); +} + +TEST_CASE("test atleast_3d vector") { + auto x = std::vector{ + array(1), array({1, 2, 3}, {3}), array({1, 2, 3}, {3, 1})}; + auto out = atleast_3d(x); + CHECK_EQ(out.size(), 3); + CHECK_EQ(out[0].ndim(), 3); + CHECK_EQ(out[0].shape(), std::vector{1, 1, 1}); + CHECK_EQ(out[1].ndim(), 3); + CHECK_EQ(out[1].shape(), std::vector{1, 3, 1}); + CHECK_EQ(out[2].ndim(), 3); + CHECK_EQ(out[2].shape(), std::vector{3, 1, 1}); } \ No newline at end of file From 4b4972761e2c1c98cf021dbf7f0350dced1716a1 Mon Sep 17 00:00:00 2001 From: hinriksnaer Date: Mon, 26 Feb 2024 12:50:31 -0500 Subject: [PATCH 3/5] Refactoring --- mlx/ops.cpp | 6 ++--- python/src/ops.cpp | 63 ++++++++++++++++------------------------------ 2 files changed, 24 insertions(+), 45 deletions(-) diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 948d97e78c..2d9295c7fc 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -3390,7 +3390,7 @@ array atleast_1d(const array& a, StreamOrDevice s /* = {} */) { std::vector atleast_1d( const std::vector& arrays, - StreamOrDevice s) { + StreamOrDevice s /* = {} */) { std::vector out; out.reserve(arrays.size()); for (const auto& a : arrays) { @@ -3412,7 +3412,7 @@ array atleast_2d(const array& a, StreamOrDevice s /* = {} */) { std::vector atleast_2d( const std::vector& arrays, - StreamOrDevice s) { + StreamOrDevice s /* = {} */) { std::vector out; out.reserve(arrays.size()); for (const auto& a : arrays) { @@ -3436,7 +3436,7 @@ array atleast_3d(const array& a, StreamOrDevice s /* = {} */) { std::vector atleast_3d( const std::vector& arrays, - StreamOrDevice s) { + StreamOrDevice s /* = {} */) { std::vector out; out.reserve(arrays.size()); for (const auto& a : arrays) { diff --git a/python/src/ops.cpp b/python/src/ops.cpp index 25fd122132..893ef31110 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -3638,30 +3638,23 @@ void init_ops(py::module_& m) { )pbdoc"); m.def( "atleast_1d", - [](const py::args& args, - StreamOrDevice s) -> std::variant { - if (args.size() == 1) { - py::object result; - result = py::cast(atleast_1d(args[0].cast(), s)); - return result; + [](const py::args& arys, + StreamOrDevice s) -> py::object { + if (arys.size() == 1) { + return py::cast(atleast_1d(arys[0].cast(), s)); } - - std::vector arrays = args.cast>(); - std::vector result = atleast_1d(arrays, s); - py::list py_result = py::cast(result); - - return py_result; + return py::cast(atleast_1d(arys.cast>(), s)); }, py::pos_only(), py::kw_only(), "stream"_a = none, R"pbdoc( - atleast_1d(*args: array, stream: Union[None, Stream, Device] = None) -> Union[array, List[array]] + atleast_1d(*arys: array, stream: Union[None, Stream, Device] = None) -> Union[array, List[array]] Convert all arrays to have at least one dimension. Args: - args (array or list(array)): Input array or list of input arrays. + *arys: Input arrays. stream (Union[None, Stream, Device], optional): The stream to execute the operation on. Returns: @@ -3669,30 +3662,23 @@ void init_ops(py::module_& m) { )pbdoc"); m.def( "atleast_2d", - [](const py::args& args, - StreamOrDevice s) -> std::variant { - if (args.size() == 1) { - py::object result; - result = py::cast(atleast_2d(args[0].cast(), s)); - return result; + [](const py::args& arys, + StreamOrDevice s) -> py::object{ + if (arys.size() == 1) { + return py::cast(atleast_2d(arys[0].cast(), s)); } - - std::vector arrays = args.cast>(); - std::vector result = atleast_2d(arrays, s); - py::list py_result = py::cast(result); - - return py_result; + return py::cast(atleast_2d(arys.cast>(), s)); }, py::pos_only(), py::kw_only(), "stream"_a = none, R"pbdoc( - atleast_2d(*args: array, stream: Union[None, Stream, Device] = None) -> Union[array, List[array]] + atleast_2d(*arys: array, stream: Union[None, Stream, Device] = None) -> Union[array, List[array]] Convert all arrays to have at least two dimensions. Args: - args (array or list(array)): Input array or list of input arrays. + *arys: Input arrays. stream (Union[None, Stream, Device], optional): The stream to execute the operation on. Returns: @@ -3701,30 +3687,23 @@ void init_ops(py::module_& m) { m.def( "atleast_3d", - [](const py::args& args, - StreamOrDevice s) -> std::variant { - if (args.size() == 1) { - py::object result; - result = py::cast(atleast_3d(args[0].cast(), s)); - return result; + [](const py::args& arys, + StreamOrDevice s) -> py::object { + if (arys.size() == 1) { + return py::cast(atleast_3d(arys[0].cast(), s)); } - - std::vector arrays = args.cast>(); - std::vector result = atleast_3d(arrays, s); - py::list py_result = py::cast(result); - - return py_result; + return py::cast(atleast_3d(arys.cast>(), s)); }, py::pos_only(), py::kw_only(), "stream"_a = none, R"pbdoc( - atleast_3d(*args: array, stream: Union[None, Stream, Device] = None) -> Union[array, List[array]] + atleast_3d(*arys: array, stream: Union[None, Stream, Device] = None) -> Union[array, List[array]] Convert all arrays to have at least three dimensions. Args: - args (array or list(array)): Input array or list of input arrays. + *arys: Input arrays. stream (Union[None, Stream, Device], optional): The stream to execute the operation on. Returns: From 19cc692f6e8e19a0eee46ee2a733581f25c0daa7 Mon Sep 17 00:00:00 2001 From: hinriksnaer Date: Mon, 26 Feb 2024 12:51:59 -0500 Subject: [PATCH 4/5] fixed formatting --- python/src/ops.cpp | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/python/src/ops.cpp b/python/src/ops.cpp index 893ef31110..c29a06fb49 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -3638,8 +3638,7 @@ void init_ops(py::module_& m) { )pbdoc"); m.def( "atleast_1d", - [](const py::args& arys, - StreamOrDevice s) -> py::object { + [](const py::args& arys, StreamOrDevice s) -> py::object { if (arys.size() == 1) { return py::cast(atleast_1d(arys[0].cast(), s)); } @@ -3662,8 +3661,7 @@ void init_ops(py::module_& m) { )pbdoc"); m.def( "atleast_2d", - [](const py::args& arys, - StreamOrDevice s) -> py::object{ + [](const py::args& arys, StreamOrDevice s) -> py::object { if (arys.size() == 1) { return py::cast(atleast_2d(arys[0].cast(), s)); } @@ -3687,8 +3685,7 @@ void init_ops(py::module_& m) { m.def( "atleast_3d", - [](const py::args& arys, - StreamOrDevice s) -> py::object { + [](const py::args& arys, StreamOrDevice s) -> py::object { if (arys.size() == 1) { return py::cast(atleast_3d(arys[0].cast(), s)); } From 5cac17b6b4466fa25a564cc949d8b60d7f3c5d63 Mon Sep 17 00:00:00 2001 From: hinriksnaer Date: Mon, 26 Feb 2024 13:42:07 -0500 Subject: [PATCH 5/5] removed pos_only --- python/src/ops.cpp | 3 --- 1 file changed, 3 deletions(-) diff --git a/python/src/ops.cpp b/python/src/ops.cpp index c29a06fb49..56a1ac8de1 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -3644,7 +3644,6 @@ void init_ops(py::module_& m) { } return py::cast(atleast_1d(arys.cast>(), s)); }, - py::pos_only(), py::kw_only(), "stream"_a = none, R"pbdoc( @@ -3667,7 +3666,6 @@ void init_ops(py::module_& m) { } return py::cast(atleast_2d(arys.cast>(), s)); }, - py::pos_only(), py::kw_only(), "stream"_a = none, R"pbdoc( @@ -3691,7 +3689,6 @@ void init_ops(py::module_& m) { } return py::cast(atleast_3d(arys.cast>(), s)); }, - py::pos_only(), py::kw_only(), "stream"_a = none, R"pbdoc(