Added support for atleast_1d, atleast_2d, atleast_3d#694
Added support for atleast_1d, atleast_2d, atleast_3d#694angeloskath merged 6 commits intoml-explore:mainfrom
Conversation
|
I suppose your current implementation can only take an array, but I think it should be able to handle a list of arrays, similar to Jax, NumPy, and Torch. |
|
I agree, that would be a good next step. I decided to start small to get some initial feedback for the implementation. I am not sure if they will want to hold the merge until that functionality is available. In either case, I will start working on it once I have wrapped up any necessary changes that may arise from the initial feedback. |
python/src/array.cpp
Outdated
| .def( | ||
| "atleast_1d", | ||
| [](const array& a, StreamOrDevice s) { return atleast_1d(a, s); }, | ||
| py::kw_only(), | ||
| "stream"_a = none, | ||
| "See :func:`atleast_1d`.") | ||
| .def( | ||
| "atleast_2d", | ||
| [](const array& a, StreamOrDevice s) { return atleast_2d(a, s); }, | ||
| py::kw_only(), | ||
| "stream"_a = none, | ||
| "See :func:`atleast_2d`.") | ||
| .def( | ||
| "atleast_3d", | ||
| [](const array& a, StreamOrDevice s) { return atleast_3d(a, s); }, | ||
| py::kw_only(), | ||
| "stream"_a = none, | ||
| "See :func:`atleast_3d`."); |
There was a problem hiding this comment.
Numpy does not have these methods so I don't think we need to add them either..
python/src/ops.cpp
Outdated
| R"pbdoc( | ||
| atleast_3d(a: array, stream: Union[None, Stream, Device] = None) -> array | ||
|
|
||
| Convert array to contain at least three dimensions. |
There was a problem hiding this comment.
| Convert array to contain at least three dimensions. | |
| Convert array to have at least three dimensions. |
awni
left a comment
There was a problem hiding this comment.
I think this is very nicely done! Just some very minor comments, can you take a look? Otherwise LGTM. We can merge it as is or if you prefer to extend the API to support multiple arrays in this PR that is fine too!
The only change is to add the new ops to the list of ops in the docs. |
|
Thanks for the great feedback @awni as usual, I believe that I have made all of the necessary changes for the current functionality. It would be great if you could merge the current changes, and I will include the array support as a separate push request. |
docs/src/python/ops.rst
Outdated
| atleast_1d | ||
| atleast_2d | ||
| atleast_3d |
There was a problem hiding this comment.
Sorry one more nit. We are trying to keep these in alphabetical order. Could you move them to the appropriate spot?
awni
left a comment
There was a problem hiding this comment.
Awesome thanks! Just one more nit for the docs then we can merge!
Proposed changes
Second go at the atleast functionality based on the feedback here. It is now implemented on the C++ side. This feature is based on the following request.
The functionality is implemented as both a core and array function. I am still trying to get more familiar with the stream implementation along with when a custom primitive is needed. I am guessing that custom primitives are not necessary since the functionality includes no additional arithmetic and only utilizes existing transformations?
I am also not sure if the documentation will be automatically updated or if I should make some modifications on my own.
Finally, the current functionality only works on a single mlx.core.array, if the feedback is good and if people are interested, I can extend it to allow for a list of mlx.core.array.
Since I am still relatively new to the framework, especially the C++ side of things, please let me know if there is something I missed. This feature has been a great introduction to the framework.
Checklist
Put an
xin the boxes that apply.pre-commit run --all-filesto format my code / installed pre-commit prior to committing changes