From be8e27b31a99e1dec36565de9462111fd85d3c3a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hinrik=20Sn=C3=A6r=20Gu=C3=B0mundsson?= Date: Thu, 15 Feb 2024 13:02:07 -0500 Subject: [PATCH 1/6] updated formatting --- mlx/ops.cpp | 30 +++++++++++++ mlx/ops.h | 5 +++ python/src/array.cpp | 20 ++++++++- python/src/ops.cpp | 60 +++++++++++++++++++++++++ python/tests/test_array.py | 90 ++++++++++++++++++++++++++++++++++++++ python/tests/test_ops.py | 90 ++++++++++++++++++++++++++++++++++++++ tests/ops_tests.cpp | 51 +++++++++++++++++++++ 7 files changed, 345 insertions(+), 1 deletion(-) diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 32af8a0781..97d4a3a2d2 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -3381,4 +3381,34 @@ std::vector depends( shapes, dtypes, std::make_shared(to_stream(s)), all_inputs); } +array atleast_1d(const array& a, StreamOrDevice s /* = {} */) { + if (a.ndim() == 0) { + return reshape(a, {1}, s); + } + return a; +} + +array atleast_2d(const array& a, StreamOrDevice s /* = {} */) { + switch (a.ndim()) { + case 0: + return reshape(a, {1, 1}, s); + case 1: + return reshape(a, {1, static_cast(a.size())}, s); + default: + return a; + } +} + +array atleast_3d(const array& a, StreamOrDevice s /* = {} */) { + switch (a.ndim()) { + case 0: + return reshape(a, {1, 1, 1}, s); + case 1: + return reshape(a, {1, static_cast(a.size()), 1}, s); + case 2: + return reshape(a, {a.shape(0), a.shape(1), 1}, s); + default: + return a; + } +} } // namespace mlx::core diff --git a/mlx/ops.h b/mlx/ops.h index f7036b8c6b..b61224d65d 100644 --- a/mlx/ops.h +++ b/mlx/ops.h @@ -1121,4 +1121,9 @@ std::vector depends( const std::vector& inputs, const std::vector& dependencies); +/** convert an array to an atleast ndim array */ +array atleast_1d(const array& a, StreamOrDevice s = {}); +array atleast_2d(const array& a, StreamOrDevice s = {}); +array atleast_3d(const array& a, StreamOrDevice s = {}); + } // namespace mlx::core diff --git a/python/src/array.cpp b/python/src/array.cpp index 4395d50e60..0acd0c3598 100644 --- a/python/src/array.cpp +++ b/python/src/array.cpp @@ -1506,5 +1506,23 @@ void init_array(py::module_& m) { "stream"_a = none, R"pbdoc( Extract a diagonal or construct a diagonal matrix. - )pbdoc"); + )pbdoc") + .def( + "atleast_1d", + [](const array& a, StreamOrDevice s) { return atleast_1d(a, s); }, + py::kw_only(), + "stream"_a = none, + "See :func:`atleast_1d`.") + .def( + "atleast_2d", + [](const array& a, StreamOrDevice s) { return atleast_2d(a, s); }, + py::kw_only(), + "stream"_a = none, + "See :func:`atleast_2d`.") + .def( + "atleast_3d", + [](const array& a, StreamOrDevice s) { return atleast_3d(a, s); }, + py::kw_only(), + "stream"_a = none, + "See :func:`atleast_3d`."); } diff --git a/python/src/ops.cpp b/python/src/ops.cpp index 8e08e6ca9b..c63ea0706c 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -3636,4 +3636,64 @@ void init_ops(py::module_& m) { Returns: array: The extracted diagonal or the constructed diagonal matrix. )pbdoc"); + m.def( + "atleast_1d", + &atleast_1d, + "a"_a, + py::pos_only(), + py::kw_only(), + "stream"_a = none, + R"pbdoc( + atleast_1d(a: array, stream: Union[None, Stream, Device] = None) -> array + + Convert array to have at least one dimension. + + args: + a (array): Input array + stream (Union[None, Stream, Device], optional): The stream to execute the operation on. + + Returns: + array: An array with at least one dimension. + + )pbdoc"); + m.def( + "atleast_2d", + &atleast_2d, + "a"_a, + py::pos_only(), + py::kw_only(), + "stream"_a = none, + R"pbdoc( + atleast_2d(a: array, stream: Union[None, Stream, Device] = None) -> array + + Convert array to have at least two dimensions. + + args: + a (array): Input array + stream (Union[None, Stream, Device], optional): The stream to execute the operation on. + + Returns: + array: An array with at least two dimensions. + + )pbdoc"); + m.def( + "atleast_3d", + &atleast_3d, + "a"_a, + py::pos_only(), + py::kw_only(), + "stream"_a = none, + R"pbdoc( + atleast_3d(a: array, stream: Union[None, Stream, Device] = None) -> array + + Convert array to contain at least three dimensions. + + args: + a (array): Input array + stream (Union[None, Stream, Device], optional): The stream to execute the operation on. + + Returns: + array: An array with at least three dimensions. + + )pbdoc"); } diff --git a/python/tests/test_array.py b/python/tests/test_array.py index 507675d6ed..5a6efae896 100644 --- a/python/tests/test_array.py +++ b/python/tests/test_array.py @@ -1381,6 +1381,96 @@ def test_inplace(self): b @= a self.assertTrue(mx.array_equal(a, b)) + def test_atleast_1d(self): + def compare_nested_lists(x, y): + if isinstance(x, list) and isinstance(y, list): + if len(x) != len(y): + return False + for i in range(len(x)): + if not compare_nested_lists(x[i], y[i]): + return False + return True + else: + return x == y + + # Test 1D input + arrays = [ + [1], + [1, 2, 3], + [1, 2, 3, 4], + [[1], [2], [3]], + [[1, 2], [3, 4]], + [[1, 2, 3], [4, 5, 6]], + [[[[1]], [[2]], [[3]]]], + ] + + for array in arrays: + mx_res = mx.array(array).atleast_1d() + np_res = np.atleast_1d(np.array(array)) + self.assertTrue(compare_nested_lists(mx_res.tolist(), np_res.tolist())) + self.assertEqual(mx_res.shape, np_res.shape) + self.assertEqual(mx_res.ndim, np_res.ndim) + + def test_atleast_2d(self): + def compare_nested_lists(x, y): + if isinstance(x, list) and isinstance(y, list): + if len(x) != len(y): + return False + for i in range(len(x)): + if not compare_nested_lists(x[i], y[i]): + return False + return True + else: + return x == y + + # Test 1D input + arrays = [ + [1], + [1, 2, 3], + [1, 2, 3, 4], + [[1], [2], [3]], + [[1, 2], [3, 4]], + [[1, 2, 3], [4, 5, 6]], + [[[[1]], [[2]], [[3]]]], + ] + + for array in arrays: + mx_res = mx.array(array).atleast_2d() + np_res = np.atleast_2d(np.array(array)) + self.assertTrue(compare_nested_lists(mx_res.tolist(), np_res.tolist())) + self.assertEqual(mx_res.shape, np_res.shape) + self.assertEqual(mx_res.ndim, np_res.ndim) + + def test_atleast_3d(self): + def compare_nested_lists(x, y): + if isinstance(x, list) and isinstance(y, list): + if len(x) != len(y): + return False + for i in range(len(x)): + if not compare_nested_lists(x[i], y[i]): + return False + return True + else: + return x == y + + # Test 1D input + arrays = [ + [1], + [1, 2, 3], + [1, 2, 3, 4], + [[1], [2], [3]], + [[1, 2], [3, 4]], + [[1, 2, 3], [4, 5, 6]], + [[[[1]], [[2]], [[3]]]], + ] + + for array in arrays: + mx_res = mx.array(array).atleast_3d() + np_res = np.atleast_3d(np.array(array)) + self.assertTrue(compare_nested_lists(mx_res.tolist(), np_res.tolist())) + self.assertEqual(mx_res.shape, np_res.shape) + self.assertEqual(mx_res.ndim, np_res.ndim) + if __name__ == "__main__": unittest.main() diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py index 66e683303a..3401338f8f 100644 --- a/python/tests/test_ops.py +++ b/python/tests/test_ops.py @@ -1883,6 +1883,96 @@ def test_diag(self): expected = mx.array(np.diag(x, k=-1)) self.assertTrue(mx.array_equal(result, expected)) + def test_atleast_1d(self): + def compare_nested_lists(x, y): + if isinstance(x, list) and isinstance(y, list): + if len(x) != len(y): + return False + for i in range(len(x)): + if not compare_nested_lists(x[i], y[i]): + return False + return True + else: + return x == y + + # Test 1D input + arrays = [ + [1], + [1, 2, 3], + [1, 2, 3, 4], + [[1], [2], [3]], + [[1, 2], [3, 4]], + [[1, 2, 3], [4, 5, 6]], + [[[[1]], [[2]], [[3]]]], + ] + + for array in arrays: + mx_res = mx.atleast_1d(mx.array(array)) + np_res = np.atleast_1d(np.array(array)) + self.assertTrue(compare_nested_lists(mx_res.tolist(), np_res.tolist())) + self.assertEqual(mx_res.shape, np_res.shape) + self.assertEqual(mx_res.ndim, np_res.ndim) + + def test_atleast_2d(self): + def compare_nested_lists(x, y): + if isinstance(x, list) and isinstance(y, list): + if len(x) != len(y): + return False + for i in range(len(x)): + if not compare_nested_lists(x[i], y[i]): + return False + return True + else: + return x == y + + # Test 1D input + arrays = [ + [1], + [1, 2, 3], + [1, 2, 3, 4], + [[1], [2], [3]], + [[1, 2], [3, 4]], + [[1, 2, 3], [4, 5, 6]], + [[[[1]], [[2]], [[3]]]], + ] + + for array in arrays: + mx_res = mx.atleast_2d(mx.array(array)) + np_res = np.atleast_2d(np.array(array)) + self.assertTrue(compare_nested_lists(mx_res.tolist(), np_res.tolist())) + self.assertEqual(mx_res.shape, np_res.shape) + self.assertEqual(mx_res.ndim, np_res.ndim) + + def test_atleast_3d(self): + def compare_nested_lists(x, y): + if isinstance(x, list) and isinstance(y, list): + if len(x) != len(y): + return False + for i in range(len(x)): + if not compare_nested_lists(x[i], y[i]): + return False + return True + else: + return x == y + + # Test 1D input + arrays = [ + [1], + [1, 2, 3], + [1, 2, 3, 4], + [[1], [2], [3]], + [[1, 2], [3, 4]], + [[1, 2, 3], [4, 5, 6]], + [[[[1]], [[2]], [[3]]]], + ] + + for array in arrays: + mx_res = mx.atleast_3d(mx.array(array)) + np_res = np.atleast_3d(np.array(array)) + self.assertTrue(compare_nested_lists(mx_res.tolist(), np_res.tolist())) + self.assertEqual(mx_res.shape, np_res.shape) + self.assertEqual(mx_res.ndim, np_res.ndim) + if __name__ == "__main__": unittest.main() diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp index e52c1294f6..2a38b18061 100644 --- a/tests/ops_tests.cpp +++ b/tests/ops_tests.cpp @@ -2716,3 +2716,54 @@ TEST_CASE("test diag") { out = diag(x, -1); CHECK(array_equal(out, array({3, 7}, {2})).item()); } + +TEST_CASE("test atleast_1d") { + auto x = array(1); + auto out = atleast_1d(x); + CHECK_EQ(out.ndim(), 1); + CHECK_EQ(out.shape(), std::vector{1}); + + x = array({1, 2, 3}, {3}); + out = atleast_1d(x); + CHECK_EQ(out.ndim(), 1); + CHECK_EQ(out.shape(), std::vector{3}); + + x = array({1, 2, 3}, {3, 1}); + out = atleast_1d(x); + CHECK_EQ(out.ndim(), 2); + CHECK_EQ(out.shape(), std::vector{3, 1}); +} + +TEST_CASE("test atleast_2d") { + auto x = array(1); + auto out = atleast_2d(x); + CHECK_EQ(out.ndim(), 2); + CHECK_EQ(out.shape(), std::vector{1, 1}); + + x = array({1, 2, 3}, {3}); + out = atleast_2d(x); + CHECK_EQ(out.ndim(), 2); + CHECK_EQ(out.shape(), std::vector{1, 3}); + + x = array({1, 2, 3}, {3, 1}); + out = atleast_2d(x); + CHECK_EQ(out.ndim(), 2); + CHECK_EQ(out.shape(), std::vector{3, 1}); +} + +TEST_CASE("test atleast_3d") { + auto x = array(1); + auto out = atleast_3d(x); + CHECK_EQ(out.ndim(), 3); + CHECK_EQ(out.shape(), std::vector{1, 1, 1}); + + x = array({1, 2, 3}, {3}); + out = atleast_3d(x); + CHECK_EQ(out.ndim(), 3); + CHECK_EQ(out.shape(), std::vector{1, 3, 1}); + + x = array({1, 2, 3}, {3, 1}); + out = atleast_3d(x); + CHECK_EQ(out.ndim(), 3); + CHECK_EQ(out.shape(), std::vector{3, 1, 1}); +} \ No newline at end of file From e0b1d600ddedce96f8d07130cbdbb897a6a47e43 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hinrik=20Sn=C3=A6r=20Gu=C3=B0mundsson?= Date: Sat, 17 Feb 2024 10:37:10 -0500 Subject: [PATCH 2/6] removed array methods --- python/src/array.cpp | 20 +-------- python/tests/test_array.py | 90 -------------------------------------- 2 files changed, 1 insertion(+), 109 deletions(-) diff --git a/python/src/array.cpp b/python/src/array.cpp index 0acd0c3598..4395d50e60 100644 --- a/python/src/array.cpp +++ b/python/src/array.cpp @@ -1506,23 +1506,5 @@ void init_array(py::module_& m) { "stream"_a = none, R"pbdoc( Extract a diagonal or construct a diagonal matrix. - )pbdoc") - .def( - "atleast_1d", - [](const array& a, StreamOrDevice s) { return atleast_1d(a, s); }, - py::kw_only(), - "stream"_a = none, - "See :func:`atleast_1d`.") - .def( - "atleast_2d", - [](const array& a, StreamOrDevice s) { return atleast_2d(a, s); }, - py::kw_only(), - "stream"_a = none, - "See :func:`atleast_2d`.") - .def( - "atleast_3d", - [](const array& a, StreamOrDevice s) { return atleast_3d(a, s); }, - py::kw_only(), - "stream"_a = none, - "See :func:`atleast_3d`."); + )pbdoc"); } diff --git a/python/tests/test_array.py b/python/tests/test_array.py index 5a6efae896..507675d6ed 100644 --- a/python/tests/test_array.py +++ b/python/tests/test_array.py @@ -1381,96 +1381,6 @@ def test_inplace(self): b @= a self.assertTrue(mx.array_equal(a, b)) - def test_atleast_1d(self): - def compare_nested_lists(x, y): - if isinstance(x, list) and isinstance(y, list): - if len(x) != len(y): - return False - for i in range(len(x)): - if not compare_nested_lists(x[i], y[i]): - return False - return True - else: - return x == y - - # Test 1D input - arrays = [ - [1], - [1, 2, 3], - [1, 2, 3, 4], - [[1], [2], [3]], - [[1, 2], [3, 4]], - [[1, 2, 3], [4, 5, 6]], - [[[[1]], [[2]], [[3]]]], - ] - - for array in arrays: - mx_res = mx.array(array).atleast_1d() - np_res = np.atleast_1d(np.array(array)) - self.assertTrue(compare_nested_lists(mx_res.tolist(), np_res.tolist())) - self.assertEqual(mx_res.shape, np_res.shape) - self.assertEqual(mx_res.ndim, np_res.ndim) - - def test_atleast_2d(self): - def compare_nested_lists(x, y): - if isinstance(x, list) and isinstance(y, list): - if len(x) != len(y): - return False - for i in range(len(x)): - if not compare_nested_lists(x[i], y[i]): - return False - return True - else: - return x == y - - # Test 1D input - arrays = [ - [1], - [1, 2, 3], - [1, 2, 3, 4], - [[1], [2], [3]], - [[1, 2], [3, 4]], - [[1, 2, 3], [4, 5, 6]], - [[[[1]], [[2]], [[3]]]], - ] - - for array in arrays: - mx_res = mx.array(array).atleast_2d() - np_res = np.atleast_2d(np.array(array)) - self.assertTrue(compare_nested_lists(mx_res.tolist(), np_res.tolist())) - self.assertEqual(mx_res.shape, np_res.shape) - self.assertEqual(mx_res.ndim, np_res.ndim) - - def test_atleast_3d(self): - def compare_nested_lists(x, y): - if isinstance(x, list) and isinstance(y, list): - if len(x) != len(y): - return False - for i in range(len(x)): - if not compare_nested_lists(x[i], y[i]): - return False - return True - else: - return x == y - - # Test 1D input - arrays = [ - [1], - [1, 2, 3], - [1, 2, 3, 4], - [[1], [2], [3]], - [[1, 2], [3, 4]], - [[1, 2, 3], [4, 5, 6]], - [[[[1]], [[2]], [[3]]]], - ] - - for array in arrays: - mx_res = mx.array(array).atleast_3d() - np_res = np.atleast_3d(np.array(array)) - self.assertTrue(compare_nested_lists(mx_res.tolist(), np_res.tolist())) - self.assertEqual(mx_res.shape, np_res.shape) - self.assertEqual(mx_res.ndim, np_res.ndim) - if __name__ == "__main__": unittest.main() From 87012b22419b2c1a08b30e5fc41c8cd471eb3d53 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hinrik=20Sn=C3=A6r=20Gu=C3=B0mundsson?= Date: Sat, 17 Feb 2024 10:39:37 -0500 Subject: [PATCH 3/6] fixed docs --- docs/src/python/ops.rst | 3 +++ python/src/ops.cpp | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/docs/src/python/ops.rst b/docs/src/python/ops.rst index 09e2d5f71c..d08839e03f 100644 --- a/docs/src/python/ops.rst +++ b/docs/src/python/ops.rst @@ -125,3 +125,6 @@ Operations where zeros zeros_like + atleast_1d + atleast_2d + atleast_3d diff --git a/python/src/ops.cpp b/python/src/ops.cpp index c63ea0706c..2c2dcecfd3 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -3686,7 +3686,7 @@ void init_ops(py::module_& m) { R"pbdoc( atleast_3d(a: array, stream: Union[None, Stream, Device] = None) -> array - Convert array to contain at least three dimensions. + Convert array to have at least three dimensions. args: a (array): Input array From 89e347d7689d05a6246e227ce05fd349f73e1698 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hinrik=20Sn=C3=A6r=20Gu=C3=B0mundsson?= Date: Sat, 17 Feb 2024 10:44:08 -0500 Subject: [PATCH 4/6] Added name to acknowledgements --- ACKNOWLEDGMENTS.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ACKNOWLEDGMENTS.md b/ACKNOWLEDGMENTS.md index 36aedc77a2..91b5300c8d 100644 --- a/ACKNOWLEDGMENTS.md +++ b/ACKNOWLEDGMENTS.md @@ -10,8 +10,9 @@ MLX was developed with contributions from the following individuals: - Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops. Added `LogicalAnd` and `LogicalOR` ops. - Juarez Bochi: Fixed bug in cross attention. - Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example. -- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile`, `StreamContext`, `stream` and safetensor support +- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile`, `StreamContext`, `stream` and safetensor support. - Gabrijel Boduljak: Added `mlx.core.linalg`, implemented `norm` method and `InstanceNorm` layer. Implemented ``MaxPool1d``, ``MaxPool2d``, ``AvgPool1d``, ``AvgPool2d``. +- Hinrik Snær Guðmundsson: Added `atleast_1d`, `atleast_2d`, `atleast_3d` support. From d967a66db3d51a9eac2cc8a4270bfa651b31151c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hinrik=20Sn=C3=A6r=20Gu=C3=B0mundsson?= Date: Sat, 17 Feb 2024 10:44:58 -0500 Subject: [PATCH 5/6] added acknoledgements --- ACKNOWLEDGMENTS.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ACKNOWLEDGMENTS.md b/ACKNOWLEDGMENTS.md index 91b5300c8d..c2cad615ef 100644 --- a/ACKNOWLEDGMENTS.md +++ b/ACKNOWLEDGMENTS.md @@ -12,7 +12,7 @@ MLX was developed with contributions from the following individuals: - Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example. - Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile`, `StreamContext`, `stream` and safetensor support. - Gabrijel Boduljak: Added `mlx.core.linalg`, implemented `norm` method and `InstanceNorm` layer. Implemented ``MaxPool1d``, ``MaxPool2d``, ``AvgPool1d``, ``AvgPool2d``. -- Hinrik Snær Guðmundsson: Added `atleast_1d`, `atleast_2d`, `atleast_3d` support. +- Hinrik Snær Guðmundsson: Added `atleast_1d`, `atleast_2d`, `atleast_3d` ops. From 957c549602ca5c32da200599afb6c1d857cfd2da Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hinrik=20Sn=C3=A6r=20Gu=C3=B0mundsson?= Date: Mon, 19 Feb 2024 11:05:01 -0500 Subject: [PATCH 6/6] fixe docs --- docs/src/python/ops.rst | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/src/python/ops.rst b/docs/src/python/ops.rst index d08839e03f..7ec7defc9c 100644 --- a/docs/src/python/ops.rst +++ b/docs/src/python/ops.rst @@ -25,6 +25,9 @@ Operations argpartition argsort array_equal + atleast_1d + atleast_2d + atleast_3d broadcast_to ceil clip @@ -125,6 +128,3 @@ Operations where zeros zeros_like - atleast_1d - atleast_2d - atleast_3d