-
Notifications
You must be signed in to change notification settings - Fork 54
Add specification for returning the least-squares solution to a linear matrix equation (linalg: lstsq) #119
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
The JAX docs say that behaviour matches NumPy, empty array can be returned: jax.readthedocs.io/en/latest/_autosummary/jax.numpy.linalg.lstsq.html. The in-progress PR for |
|
@rgommers Re: JAX. Sorry, I should have clarified. JAX's default behavior does not match NumPy's.
I've updated the OP accordingly. |
|
Renamed |
|
Renamed |
|
The PR looks fine to me, just a few high-level design questions:
I feel it's not very convenient to always request a matrix and forbid vector inputs. I understand we can always broadcast if x1.ndim == x2.ndim + 1:
x2 = x2[..., None] # or use newaxis
assert x1.ndim == x2.ndim
No, NumPy and CuPy return a tuple. Given that we didn't return namedtuple in SVD, perhaps we shouldn't do it here either to be consistent? |
Ah OK, thanks Athan! I missed that and thought |
|
@leofang Re: matrix/vector input. I've updated the proposal to include support for an ordinate vector. |
leofang
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM! Thanks @kgryte!
0607525 to
138e963
Compare
|
Thanks, @leofang, for the review! This PR is ready for merge... |
This PR
Notes
Only TF allows for providing a stack of matrices. Torch, MXNet, CuPy, NumPy, and JAX do not. This proposal follows TF and ensures consistency with other linalg interfaces which currently support stacks.
TF supports
l2_regularizerandfastkeyword arguments and is alone in doing so.Neither Dask, Torch, nor TF support an
rcondkeyword argument. This proposal includes anrtolargument (note:rtolis renamed fromrcondto unify tolerance keywords acrosspinv,lstsq, andmatrix_rank), similar to the pinv proposal.Similar to pinv, the
rcondargument can either be afloator anarrayand have default values determined by type promotion rules.NumPy, MXNet, CuPy, and JAX all support
bbeing specified as either a vector or matrix. TF requires an(..., M,K)matrix. This PR follows NumPy.Return results:
rankfield which is an array.rankfield which is an integer.residualsfield which is empty for low-rank or over-determined solutions. JAX always returns residuals for JIT purposes, unless one setsnumpy_resid=True.This proposal returns a namedtuple with a
rankfield which is an array due to support for providing stacks of matrices and also returns that theresidualsfield always be returned, following JAX.