Skip to content

added atleast *args input support#710

Merged
awni merged 6 commits intoml-explore:mainfrom
hinriksnaer:array_atleast
Feb 26, 2024
Merged

added atleast *args input support#710
awni merged 6 commits intoml-explore:mainfrom
hinriksnaer:array_atleast

Conversation

@hinriksnaer
Copy link
Contributor

Proposed changes

Following up with atleast feature request

Swapped out array to *args for atleast_1d, atleast_2d, atleast_3d.

The function now returns either an array or list(array) depending on the amount of inputs, similar to the Numpy and PyTorch implementation of the same function. The corresponding tests have also been updated.

The current implementation simply iterates over the argument list and calls the original atleast_nd c++ function. I struggled a bit figuring out how to get Pybind11 to output one of two different datatypes but this seems to do the trick. I also considered making a separate std::vector<array> output variation of the atleast_nd functions and have the Pybind11 simply call the correct variation depending on the input length. I wasn't able to find a previously implemented function to use as a reference so I settled on simply iterating over the arguments in the pybind11 definition using a lambda function.

Please let me know if this breaks any conventions within the codebase or if an alternative approach is more desirable.

Checklist

Put an x in the boxes that apply.

  • I have read the CONTRIBUTING document
  • I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes
  • I have added tests that prove my fix is effective or that my feature works
  • I have updated the necessary documentation (if needed)

Copy link
Member

@awni awni left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the follow up!

In the interest of keeping our C++ and Python APIs in sync, could you make this change on the C++ side and then dispatch to the appropriate function from Python?

Concretely something like:

  • Have overloads in C++ for single array or std::vector<array>
  • Have the second one call the first in a loop
  • Dispatch in Python to the appropriate function

@hinriksnaer
Copy link
Contributor Author

Sure thing! I will look right into that. In the interest of clarity, do you have anything in mind I could use for reference?

@awni
Copy link
Member

awni commented Feb 20, 2024

I wasn't referring to anything specific above. If you run into issues implementing just post here and I'm happy to help.

@hinriksnaer hinriksnaer changed the title added atleast list(array) input support added atleast *args input support Feb 21, 2024
@hinriksnaer
Copy link
Contributor Author

hinriksnaer commented Feb 21, 2024

Hey there @awni, this new batch of changes should be better aligned with what you described. Please let me know if there is anything you would like me to change/improve.

mlx/ops.cpp Outdated

std::vector<array> atleast_2d(
const std::vector<array>& arrays,
StreamOrDevice s) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
StreamOrDevice s) {
StreamOrDevice s /* = {} */) {

args:
a (array): Input array
Args:
args (array or list(array)): Input array or list of input arrays.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think saying or list(array) here isn't quite right since the function won't accept a list. I'm looking to see if there is a standard for *args

Comment on lines +3644 to +3646
py::object result;
result = py::cast(atleast_1d(args[0].cast<array>(), s));
return result;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
py::object result;
result = py::cast(atleast_1d(args[0].cast<array>(), s));
return result;
return py::cast(atleast_1d(args[0].cast<array>(), s));

Does that work?

Copy link
Member

@awni awni left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I left a few comments on atleast_1d which apply to the others as well. Could you check and address the ones that make sense?

PS thanks for the contributions!!

@hinriksnaer
Copy link
Contributor Author

Here is a set of changes based on your feedback @awni.

  • added /* = {} */
  • refactored pybind11 defenition
  • fixed docstring

The only thing I am uncertain about is how to rename the resulting python function to contain *arys instead of *args using pybind11.

@awni
Copy link
Member

awni commented Feb 26, 2024

The only thing I am uncertain about is how to rename the resulting python function to contain *arys instead of *args using pybind11.

I think you did that in the docstring + function signature already? That's the only change need as far as I can tell.

@hinriksnaer hinriksnaer requested a review from awni February 26, 2024 18:54
@hinriksnaer
Copy link
Contributor Author

hinriksnaer commented Feb 26, 2024

Pushed the requested changes.

Copy link
Member

@awni awni left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks!!

@awni awni merged commit 08226ab into ml-explore:main Feb 26, 2024
awni pushed a commit that referenced this pull request Feb 27, 2024
* added atleast list(array) input support

* function overloading implemented

* Refactoring

* fixed formatting

* removed pos_only
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants