added atleast *args input support#710
Conversation
awni
left a comment
There was a problem hiding this comment.
Thanks for the follow up!
In the interest of keeping our C++ and Python APIs in sync, could you make this change on the C++ side and then dispatch to the appropriate function from Python?
Concretely something like:
- Have overloads in C++ for single
arrayorstd::vector<array> - Have the second one call the first in a loop
- Dispatch in Python to the appropriate function
|
Sure thing! I will look right into that. In the interest of clarity, do you have anything in mind I could use for reference? |
|
I wasn't referring to anything specific above. If you run into issues implementing just post here and I'm happy to help. |
f0a56f3 to
c216a3f
Compare
|
Hey there @awni, this new batch of changes should be better aligned with what you described. Please let me know if there is anything you would like me to change/improve. |
mlx/ops.cpp
Outdated
|
|
||
| std::vector<array> atleast_2d( | ||
| const std::vector<array>& arrays, | ||
| StreamOrDevice s) { |
There was a problem hiding this comment.
| StreamOrDevice s) { | |
| StreamOrDevice s /* = {} */) { |
python/src/ops.cpp
Outdated
| args: | ||
| a (array): Input array | ||
| Args: | ||
| args (array or list(array)): Input array or list of input arrays. |
There was a problem hiding this comment.
I think saying or list(array) here isn't quite right since the function won't accept a list. I'm looking to see if there is a standard for *args
python/src/ops.cpp
Outdated
| py::object result; | ||
| result = py::cast(atleast_1d(args[0].cast<array>(), s)); | ||
| return result; |
There was a problem hiding this comment.
| py::object result; | |
| result = py::cast(atleast_1d(args[0].cast<array>(), s)); | |
| return result; | |
| return py::cast(atleast_1d(args[0].cast<array>(), s)); |
Does that work?
|
Here is a set of changes based on your feedback @awni.
The only thing I am uncertain about is how to rename the resulting python function to contain *arys instead of *args using pybind11. |
I think you did that in the docstring + function signature already? That's the only change need as far as I can tell. |
|
Pushed the requested changes. |
* added atleast list(array) input support * function overloading implemented * Refactoring * fixed formatting * removed pos_only
Proposed changes
Following up with atleast feature request
Swapped out
arrayto*argsfor atleast_1d, atleast_2d, atleast_3d.The function now returns either an
arrayorlist(array)depending on the amount of inputs, similar to the Numpy and PyTorch implementation of the same function. The corresponding tests have also been updated.The current implementation simply iterates over the argument list and calls the original
atleast_ndc++ function. I struggled a bit figuring out how to get Pybind11 to output one of two different datatypes but this seems to do the trick. I also considered making a separatestd::vector<array>output variation of theatleast_ndfunctions and have the Pybind11 simply call the correct variation depending on the input length. I wasn't able to find a previously implemented function to use as a reference so I settled on simply iterating over the arguments in the pybind11 definition using a lambda function.Please let me know if this breaks any conventions within the codebase or if an alternative approach is more desirable.
Checklist
Put an
xin the boxes that apply.pre-commit run --all-filesto format my code / installed pre-commit prior to committing changes