Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions mlx/ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3414,6 +3414,17 @@ array atleast_1d(const array& a, StreamOrDevice s /* = {} */) {
return a;
}

std::vector<array> atleast_1d(
const std::vector<array>& arrays,
StreamOrDevice s /* = {} */) {
std::vector<array> out;
out.reserve(arrays.size());
for (const auto& a : arrays) {
out.push_back(atleast_1d(a, s));
}
return out;
}

array atleast_2d(const array& a, StreamOrDevice s /* = {} */) {
switch (a.ndim()) {
case 0:
Expand All @@ -3425,6 +3436,17 @@ array atleast_2d(const array& a, StreamOrDevice s /* = {} */) {
}
}

std::vector<array> atleast_2d(
const std::vector<array>& arrays,
StreamOrDevice s /* = {} */) {
std::vector<array> out;
out.reserve(arrays.size());
for (const auto& a : arrays) {
out.push_back(atleast_2d(a, s));
}
return out;
}

array atleast_3d(const array& a, StreamOrDevice s /* = {} */) {
switch (a.ndim()) {
case 0:
Expand All @@ -3437,4 +3459,16 @@ array atleast_3d(const array& a, StreamOrDevice s /* = {} */) {
return a;
}
}

std::vector<array> atleast_3d(
const std::vector<array>& arrays,
StreamOrDevice s /* = {} */) {
std::vector<array> out;
out.reserve(arrays.size());
for (const auto& a : arrays) {
out.push_back(atleast_3d(a, s));
}
return out;
}

} // namespace mlx::core
9 changes: 9 additions & 0 deletions mlx/ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -1123,7 +1123,16 @@ std::vector<array> depends(

/** convert an array to an atleast ndim array */
array atleast_1d(const array& a, StreamOrDevice s = {});
std::vector<array> atleast_1d(
const std::vector<array>& a,
StreamOrDevice s = {});
array atleast_2d(const array& a, StreamOrDevice s = {});
std::vector<array> atleast_2d(
const std::vector<array>& a,
StreamOrDevice s = {});
array atleast_3d(const array& a, StreamOrDevice s = {});
std::vector<array> atleast_3d(
const std::vector<array>& a,
StreamOrDevice s = {});

} // namespace mlx::core
61 changes: 34 additions & 27 deletions python/src/ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3638,62 +3638,69 @@ void init_ops(py::module_& m) {
)pbdoc");
m.def(
"atleast_1d",
&atleast_1d,
"a"_a,
py::pos_only(),
[](const py::args& arys, StreamOrDevice s) -> py::object {
if (arys.size() == 1) {
return py::cast(atleast_1d(arys[0].cast<array>(), s));
}
return py::cast(atleast_1d(arys.cast<std::vector<array>>(), s));
},
py::kw_only(),
"stream"_a = none,
R"pbdoc(
atleast_1d(a: array, stream: Union[None, Stream, Device] = None) -> array
atleast_1d(*arys: array, stream: Union[None, Stream, Device] = None) -> Union[array, List[array]]

Convert array to have at least one dimension.
Convert all arrays to have at least one dimension.

args:
a (array): Input array
Args:
*arys: Input arrays.
stream (Union[None, Stream, Device], optional): The stream to execute the operation on.

Returns:
array: An array with at least one dimension.

array or list(array): An array or list of arrays with at least one dimension.
)pbdoc");
m.def(
"atleast_2d",
&atleast_2d,
"a"_a,
py::pos_only(),
[](const py::args& arys, StreamOrDevice s) -> py::object {
if (arys.size() == 1) {
return py::cast(atleast_2d(arys[0].cast<array>(), s));
}
return py::cast(atleast_2d(arys.cast<std::vector<array>>(), s));
},
py::kw_only(),
"stream"_a = none,
R"pbdoc(
atleast_2d(a: array, stream: Union[None, Stream, Device] = None) -> array
atleast_2d(*arys: array, stream: Union[None, Stream, Device] = None) -> Union[array, List[array]]

Convert array to have at least two dimensions.
Convert all arrays to have at least two dimensions.

args:
a (array): Input array
Args:
*arys: Input arrays.
stream (Union[None, Stream, Device], optional): The stream to execute the operation on.

Returns:
array: An array with at least two dimensions.

array or list(array): An array or list of arrays with at least two dimensions.
)pbdoc");

m.def(
"atleast_3d",
&atleast_3d,
"a"_a,
py::pos_only(),
[](const py::args& arys, StreamOrDevice s) -> py::object {
if (arys.size() == 1) {
return py::cast(atleast_3d(arys[0].cast<array>(), s));
}
return py::cast(atleast_3d(arys.cast<std::vector<array>>(), s));
},
py::kw_only(),
"stream"_a = none,
R"pbdoc(
atleast_3d(a: array, stream: Union[None, Stream, Device] = None) -> array
atleast_3d(*arys: array, stream: Union[None, Stream, Device] = None) -> Union[array, List[array]]

Convert array to have at least three dimensions.
Convert all arrays to have at least three dimensions.

args:
a (array): Input array
Args:
*arys: Input arrays.
stream (Union[None, Stream, Device], optional): The stream to execute the operation on.

Returns:
array: An array with at least three dimensions.

array or list(array): An array or list of arrays with at least three dimensions.
)pbdoc");
}
18 changes: 15 additions & 3 deletions python/tests/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1932,12 +1932,16 @@ def compare_nested_lists(x, y):
[[[[1]], [[2]], [[3]]]],
]

for array in arrays:
mx_arrays = [mx.atleast_1d(mx.array(x)) for x in arrays]
atleast_arrays = mx.atleast_1d(*mx_arrays)

for i, array in enumerate(arrays):
mx_res = mx.atleast_1d(mx.array(array))
np_res = np.atleast_1d(np.array(array))
self.assertTrue(compare_nested_lists(mx_res.tolist(), np_res.tolist()))
self.assertEqual(mx_res.shape, np_res.shape)
self.assertEqual(mx_res.ndim, np_res.ndim)
self.assertTrue(mx.all(mx.equal(mx_res, atleast_arrays[i])))

def test_atleast_2d(self):
def compare_nested_lists(x, y):
Expand All @@ -1962,12 +1966,16 @@ def compare_nested_lists(x, y):
[[[[1]], [[2]], [[3]]]],
]

for array in arrays:
mx_arrays = [mx.atleast_2d(mx.array(x)) for x in arrays]
atleast_arrays = mx.atleast_2d(*mx_arrays)

for i, array in enumerate(arrays):
mx_res = mx.atleast_2d(mx.array(array))
np_res = np.atleast_2d(np.array(array))
self.assertTrue(compare_nested_lists(mx_res.tolist(), np_res.tolist()))
self.assertEqual(mx_res.shape, np_res.shape)
self.assertEqual(mx_res.ndim, np_res.ndim)
self.assertTrue(mx.all(mx.equal(mx_res, atleast_arrays[i])))

def test_atleast_3d(self):
def compare_nested_lists(x, y):
Expand All @@ -1992,12 +2000,16 @@ def compare_nested_lists(x, y):
[[[[1]], [[2]], [[3]]]],
]

for array in arrays:
mx_arrays = [mx.atleast_3d(mx.array(x)) for x in arrays]
atleast_arrays = mx.atleast_3d(*mx_arrays)

for i, array in enumerate(arrays):
mx_res = mx.atleast_3d(mx.array(array))
np_res = np.atleast_3d(np.array(array))
self.assertTrue(compare_nested_lists(mx_res.tolist(), np_res.tolist()))
self.assertEqual(mx_res.shape, np_res.shape)
self.assertEqual(mx_res.ndim, np_res.ndim)
self.assertTrue(mx.all(mx.equal(mx_res, atleast_arrays[i])))


if __name__ == "__main__":
Expand Down
39 changes: 39 additions & 0 deletions tests/ops_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2787,6 +2787,19 @@ TEST_CASE("test atleast_1d") {
CHECK_EQ(out.shape(), std::vector<int>{3, 1});
}

TEST_CASE("test atleast_1d vector") {
auto x = std::vector<array>{
array(1), array({1, 2, 3}, {3}), array({1, 2, 3}, {3, 1})};
auto out = atleast_1d(x);
CHECK_EQ(out.size(), 3);
CHECK_EQ(out[0].ndim(), 1);
CHECK_EQ(out[0].shape(), std::vector<int>{1});
CHECK_EQ(out[1].ndim(), 1);
CHECK_EQ(out[1].shape(), std::vector<int>{3});
CHECK_EQ(out[2].ndim(), 2);
CHECK_EQ(out[2].shape(), std::vector<int>{3, 1});
}

TEST_CASE("test atleast_2d") {
auto x = array(1);
auto out = atleast_2d(x);
Expand All @@ -2804,6 +2817,19 @@ TEST_CASE("test atleast_2d") {
CHECK_EQ(out.shape(), std::vector<int>{3, 1});
}

TEST_CASE("test atleast_2d vector") {
auto x = std::vector<array>{
array(1), array({1, 2, 3}, {3}), array({1, 2, 3}, {3, 1})};
auto out = atleast_2d(x);
CHECK_EQ(out.size(), 3);
CHECK_EQ(out[0].ndim(), 2);
CHECK_EQ(out[0].shape(), std::vector<int>{1, 1});
CHECK_EQ(out[1].ndim(), 2);
CHECK_EQ(out[1].shape(), std::vector<int>{1, 3});
CHECK_EQ(out[2].ndim(), 2);
CHECK_EQ(out[2].shape(), std::vector<int>{3, 1});
}

TEST_CASE("test atleast_3d") {
auto x = array(1);
auto out = atleast_3d(x);
Expand All @@ -2820,3 +2846,16 @@ TEST_CASE("test atleast_3d") {
CHECK_EQ(out.ndim(), 3);
CHECK_EQ(out.shape(), std::vector<int>{3, 1, 1});
}

TEST_CASE("test atleast_3d vector") {
auto x = std::vector<array>{
array(1), array({1, 2, 3}, {3}), array({1, 2, 3}, {3, 1})};
auto out = atleast_3d(x);
CHECK_EQ(out.size(), 3);
CHECK_EQ(out[0].ndim(), 3);
CHECK_EQ(out[0].shape(), std::vector<int>{1, 1, 1});
CHECK_EQ(out[1].ndim(), 3);
CHECK_EQ(out[1].shape(), std::vector<int>{1, 3, 1});
CHECK_EQ(out[2].ndim(), 3);
CHECK_EQ(out[2].shape(), std::vector<int>{3, 1, 1});
}