Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion ACKNOWLEDGMENTS.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@ MLX was developed with contributions from the following individuals:
- Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops. Added `LogicalAnd` and `LogicalOR` ops.
- Juarez Bochi: Fixed bug in cross attention.
- Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example.
- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile`, `StreamContext`, `stream` and safetensor support
- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile`, `StreamContext`, `stream` and safetensor support.
- Gabrijel Boduljak: Added `mlx.core.linalg`, implemented `norm` method and `InstanceNorm` layer. Implemented ``MaxPool1d``, ``MaxPool2d``, ``AvgPool1d``, ``AvgPool2d``.
- Hinrik Snær Guðmundsson: Added `atleast_1d`, `atleast_2d`, `atleast_3d` ops.

<a href="https://github.com/ml-explore/mlx/graphs/contributors">
<img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />
Expand Down
3 changes: 3 additions & 0 deletions docs/src/python/ops.rst
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ Operations
argpartition
argsort
array_equal
atleast_1d
atleast_2d
atleast_3d
broadcast_to
ceil
clip
Expand Down
30 changes: 30 additions & 0 deletions mlx/ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3381,4 +3381,34 @@ std::vector<array> depends(
shapes, dtypes, std::make_shared<Depends>(to_stream(s)), all_inputs);
}

array atleast_1d(const array& a, StreamOrDevice s /* = {} */) {
if (a.ndim() == 0) {
return reshape(a, {1}, s);
}
return a;
}

array atleast_2d(const array& a, StreamOrDevice s /* = {} */) {
switch (a.ndim()) {
case 0:
return reshape(a, {1, 1}, s);
case 1:
return reshape(a, {1, static_cast<int>(a.size())}, s);
default:
return a;
}
}

array atleast_3d(const array& a, StreamOrDevice s /* = {} */) {
switch (a.ndim()) {
case 0:
return reshape(a, {1, 1, 1}, s);
case 1:
return reshape(a, {1, static_cast<int>(a.size()), 1}, s);
case 2:
return reshape(a, {a.shape(0), a.shape(1), 1}, s);
default:
return a;
}
}
} // namespace mlx::core
5 changes: 5 additions & 0 deletions mlx/ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -1121,4 +1121,9 @@ std::vector<array> depends(
const std::vector<array>& inputs,
const std::vector<array>& dependencies);

/** convert an array to an atleast ndim array */
array atleast_1d(const array& a, StreamOrDevice s = {});
array atleast_2d(const array& a, StreamOrDevice s = {});
array atleast_3d(const array& a, StreamOrDevice s = {});

} // namespace mlx::core
60 changes: 60 additions & 0 deletions python/src/ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3636,4 +3636,64 @@ void init_ops(py::module_& m) {
Returns:
array: The extracted diagonal or the constructed diagonal matrix.
)pbdoc");
m.def(
"atleast_1d",
&atleast_1d,
"a"_a,
py::pos_only(),
py::kw_only(),
"stream"_a = none,
R"pbdoc(
atleast_1d(a: array, stream: Union[None, Stream, Device] = None) -> array

Convert array to have at least one dimension.

args:
a (array): Input array
stream (Union[None, Stream, Device], optional): The stream to execute the operation on.

Returns:
array: An array with at least one dimension.

)pbdoc");
m.def(
"atleast_2d",
&atleast_2d,
"a"_a,
py::pos_only(),
py::kw_only(),
"stream"_a = none,
R"pbdoc(
atleast_2d(a: array, stream: Union[None, Stream, Device] = None) -> array

Convert array to have at least two dimensions.

args:
a (array): Input array
stream (Union[None, Stream, Device], optional): The stream to execute the operation on.

Returns:
array: An array with at least two dimensions.

)pbdoc");
m.def(
"atleast_3d",
&atleast_3d,
"a"_a,
py::pos_only(),
py::kw_only(),
"stream"_a = none,
R"pbdoc(
atleast_3d(a: array, stream: Union[None, Stream, Device] = None) -> array

Convert array to have at least three dimensions.

args:
a (array): Input array
stream (Union[None, Stream, Device], optional): The stream to execute the operation on.

Returns:
array: An array with at least three dimensions.

)pbdoc");
}
90 changes: 90 additions & 0 deletions python/tests/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1883,6 +1883,96 @@ def test_diag(self):
expected = mx.array(np.diag(x, k=-1))
self.assertTrue(mx.array_equal(result, expected))

def test_atleast_1d(self):
def compare_nested_lists(x, y):
if isinstance(x, list) and isinstance(y, list):
if len(x) != len(y):
return False
for i in range(len(x)):
if not compare_nested_lists(x[i], y[i]):
return False
return True
else:
return x == y

# Test 1D input
arrays = [
[1],
[1, 2, 3],
[1, 2, 3, 4],
[[1], [2], [3]],
[[1, 2], [3, 4]],
[[1, 2, 3], [4, 5, 6]],
[[[[1]], [[2]], [[3]]]],
]

for array in arrays:
mx_res = mx.atleast_1d(mx.array(array))
np_res = np.atleast_1d(np.array(array))
self.assertTrue(compare_nested_lists(mx_res.tolist(), np_res.tolist()))
self.assertEqual(mx_res.shape, np_res.shape)
self.assertEqual(mx_res.ndim, np_res.ndim)

def test_atleast_2d(self):
def compare_nested_lists(x, y):
if isinstance(x, list) and isinstance(y, list):
if len(x) != len(y):
return False
for i in range(len(x)):
if not compare_nested_lists(x[i], y[i]):
return False
return True
else:
return x == y

# Test 1D input
arrays = [
[1],
[1, 2, 3],
[1, 2, 3, 4],
[[1], [2], [3]],
[[1, 2], [3, 4]],
[[1, 2, 3], [4, 5, 6]],
[[[[1]], [[2]], [[3]]]],
]

for array in arrays:
mx_res = mx.atleast_2d(mx.array(array))
np_res = np.atleast_2d(np.array(array))
self.assertTrue(compare_nested_lists(mx_res.tolist(), np_res.tolist()))
self.assertEqual(mx_res.shape, np_res.shape)
self.assertEqual(mx_res.ndim, np_res.ndim)

def test_atleast_3d(self):
def compare_nested_lists(x, y):
if isinstance(x, list) and isinstance(y, list):
if len(x) != len(y):
return False
for i in range(len(x)):
if not compare_nested_lists(x[i], y[i]):
return False
return True
else:
return x == y

# Test 1D input
arrays = [
[1],
[1, 2, 3],
[1, 2, 3, 4],
[[1], [2], [3]],
[[1, 2], [3, 4]],
[[1, 2, 3], [4, 5, 6]],
[[[[1]], [[2]], [[3]]]],
]

for array in arrays:
mx_res = mx.atleast_3d(mx.array(array))
np_res = np.atleast_3d(np.array(array))
self.assertTrue(compare_nested_lists(mx_res.tolist(), np_res.tolist()))
self.assertEqual(mx_res.shape, np_res.shape)
self.assertEqual(mx_res.ndim, np_res.ndim)


if __name__ == "__main__":
unittest.main()
51 changes: 51 additions & 0 deletions tests/ops_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2716,3 +2716,54 @@ TEST_CASE("test diag") {
out = diag(x, -1);
CHECK(array_equal(out, array({3, 7}, {2})).item<bool>());
}

TEST_CASE("test atleast_1d") {
auto x = array(1);
auto out = atleast_1d(x);
CHECK_EQ(out.ndim(), 1);
CHECK_EQ(out.shape(), std::vector<int>{1});

x = array({1, 2, 3}, {3});
out = atleast_1d(x);
CHECK_EQ(out.ndim(), 1);
CHECK_EQ(out.shape(), std::vector<int>{3});

x = array({1, 2, 3}, {3, 1});
out = atleast_1d(x);
CHECK_EQ(out.ndim(), 2);
CHECK_EQ(out.shape(), std::vector<int>{3, 1});
}

TEST_CASE("test atleast_2d") {
auto x = array(1);
auto out = atleast_2d(x);
CHECK_EQ(out.ndim(), 2);
CHECK_EQ(out.shape(), std::vector<int>{1, 1});

x = array({1, 2, 3}, {3});
out = atleast_2d(x);
CHECK_EQ(out.ndim(), 2);
CHECK_EQ(out.shape(), std::vector<int>{1, 3});

x = array({1, 2, 3}, {3, 1});
out = atleast_2d(x);
CHECK_EQ(out.ndim(), 2);
CHECK_EQ(out.shape(), std::vector<int>{3, 1});
}

TEST_CASE("test atleast_3d") {
auto x = array(1);
auto out = atleast_3d(x);
CHECK_EQ(out.ndim(), 3);
CHECK_EQ(out.shape(), std::vector<int>{1, 1, 1});

x = array({1, 2, 3}, {3});
out = atleast_3d(x);
CHECK_EQ(out.ndim(), 3);
CHECK_EQ(out.shape(), std::vector<int>{1, 3, 1});

x = array({1, 2, 3}, {3, 1});
out = atleast_3d(x);
CHECK_EQ(out.ndim(), 3);
CHECK_EQ(out.shape(), std::vector<int>{3, 1, 1});
}