BUG: Fix np.dot to allow user dtypes#30931
Conversation
01363f7 to
f057091
Compare
PyArray_InnerProduct to allow user dtypesnp.dot to allow user dtypes
f057091 to
9481f81
Compare
|
ping @SwayamInSync this seems like something you might be interested in |
…ypes in PyArray_MatrixProduct2
seberg
left a comment
There was a problem hiding this comment.
Thanks, I am very happy to do this PyArray_DTypeFromObject but I think it makes sense here to apply it accross all functions that end up using new_array_for_sum.
seberg
left a comment
There was a problem hiding this comment.
There is still a segfault, hopefully something small. But I also noticed 2 smaller things.
Thanks for looking at all of it now!
Co-authored-by: Sebastian Berg <sebastian@sipsolutions.net>
970355f to
0f8f781
Compare
SwayamInSync
left a comment
There was a problem hiding this comment.
I'm sorry coming late, this looks good to me.
Thanks @MaanasArora for working on this
seberg
left a comment
There was a problem hiding this comment.
Thanks for the final fixes, let's get this in. (It wouldn't hurt to add more byte-swap related tests, but I don't want to block on that -- all modified code paths include NPY_DT_CALL_ensure_canonical.)
Brings the maxlag2025 branch up to date with current numpy/numpy main (843 commits since branch base). Conflict resolution in numpy/_core/src/multiarray/multiarraymodule.c: - Adopt upstream's _pyarray_correlate signature change from `int typenum` to `PyArray_Descr *typec` (PR numpy#30931). - Keep our refactor: the function still takes (minlag, maxlag, lagstep) instead of (mode, *inverted), and handles array swap, negative-step normalization, and output reversal internally. - Update PyArray_Correlate, PyArray_Correlate2, and PyArray_CorrelateLags to use upstream's PyArray_DTypeFromObject + NPY_DT_CALL_ensure_canonical pattern for type resolution, with proper Py_DECREF(typec) cleanup. - Update the new array_correlatelags argument parser to use upstream's brace-syntax for npy_parse_arguments. Add NPY_2_6_API_VERSION 0x00000016 and matching version-string clause in numpy/_core/include/numpy/numpyconfig.h to support the bumped C API version that registers PyArray_CorrelateLags at slot 369. Tests: numpy/_core/tests/test_numeric.py: 512 passed, 1 skipped.
This is a start to fix
PyArray_MatrixProduct2for user-defined dtypes, as a bug was reported in jax-ml/ml_dtypes#360. This is a draft as I've not tested this outside of a dummy dtype. I will try to use the linked PR with this, but wanted to get this out if anyone has comments. ping @seberg - thanks!There also seem to be other methods with the same issue, so we might need a helper function long run. Hopefully this stays minimal on more testing, though not sure if even this is small enough to backport.