From 7d77d399d24ef47cacb9ee87b3d57bf006ec0021 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Thu, 22 May 2025 16:05:58 -0400 Subject: [PATCH 01/11] avoid defining 0-arg functions for docs --- src/algorithms.jl | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/src/algorithms.jl b/src/algorithms.jl index 4a6bfe58..ea5d3e5c 100644 --- a/src/algorithms.jl +++ b/src/algorithms.jl @@ -73,8 +73,7 @@ automatically with [`MatrixAlgebraKit.default_algorithm`](@ref) and the keyword arguments in `kwargs` will be passed to the algorithm constructor. Finally, the same behavior is obtained when the keyword arguments are passed as the third positional argument in the form of a `NamedTuple`. -""" -function select_algorithm end +""" select_algorithm function select_algorithm(f::F, A, alg::Alg=nothing; kwargs...) where {F,Alg} return _select_algorithm(f, A, alg; kwargs...) @@ -109,8 +108,7 @@ end Select the default algorithm for a given factorization function `f` and input `A`. In general, this is called by [`select_algorithm`](@ref) if no algorithm is specified explicitly. -""" -function default_algorithm end +""" default_algorithm @doc """ copy_input(f, A) @@ -118,8 +116,7 @@ function default_algorithm end Preprocess the input `A` for a given function, such that it may be handled correctly later. This may include a copy whenever the implementation would destroy the original matrix, or a change of element type to something that is supported. -""" -function copy_input end +""" copy_input @doc """ initialize_output(f, A, alg) @@ -127,8 +124,7 @@ function copy_input end Whenever possible, allocate the destination for applying a given algorithm in-place. If this is not possible, for example when the output size is not known a priori or immutable, this function may return `nothing`. -""" -function initialize_output end +""" initialize_output # Utility macros # -------------- From 3b16fc070794827ab7c5a2c2e96b4c66f139b622 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Thu, 22 May 2025 16:19:25 -0400 Subject: [PATCH 02/11] Rewrite `select` and `default_algorithm` in the type domain --- src/algorithms.jl | 60 +++++++++++++++++++++++++++-------------------- 1 file changed, 35 insertions(+), 25 deletions(-) diff --git a/src/algorithms.jl b/src/algorithms.jl index ea5d3e5c..aee55283 100644 --- a/src/algorithms.jl +++ b/src/algorithms.jl @@ -61,6 +61,7 @@ end MatrixAlgebraKit.select_algorithm(f, A, (; kwargs...)) Decide on an algorithm to use for implementing the function `f` on inputs of type `A`. +This can be obtained both for values `A` or types `A`. If `alg` is an `AbstractAlgorithm` instance, it will be returned as-is. @@ -76,39 +77,47 @@ passed as the third positional argument in the form of a `NamedTuple`. """ select_algorithm function select_algorithm(f::F, A, alg::Alg=nothing; kwargs...) where {F,Alg} + return select_algorithm(f, typeof(A), alg; kwargs...) +end +function select_algorithm(f::F, ::Type{A}, alg::Alg=nothing; kwargs...) where {F,A,Alg} return _select_algorithm(f, A, alg; kwargs...) end -function _select_algorithm(f::F, A, alg::Nothing; kwargs...) where {F} +function _select_algorithm(f::F, ::Type{A}, alg::Nothing; kwargs...) where {F,A} return default_algorithm(f, A; kwargs...) end -function _select_algorithm(f::F, A, alg::Symbol; kwargs...) where {F} +function _select_algorithm(f::F, ::Type{A}, alg::Symbol; kwargs...) where {F,A} return Algorithm{alg}(; kwargs...) end -function _select_algorithm(f::F, A, ::Type{Alg}; kwargs...) where {F,Alg} +function _select_algorithm(f::F, ::Type{A}, ::Type{Alg}; kwargs...) where {F,A,Alg} return Alg(; kwargs...) end -function _select_algorithm(f::F, A, alg::NamedTuple; kwargs...) where {F} +function _select_algorithm(f::F, ::Type{A}, alg::NamedTuple; kwargs...) where {F,A} isempty(kwargs) || throw(ArgumentError("Additional keyword arguments are not allowed when algorithm parameters are specified.")) return default_algorithm(f, A; alg...) end -function _select_algorithm(f::F, A, alg::AbstractAlgorithm; kwargs...) where {F} +function _select_algorithm(f::F, ::Type{A}, alg::AbstractAlgorithm; kwargs...) where {F,A} isempty(kwargs) || throw(ArgumentError("Additional keyword arguments are not allowed when an algorithm is specified.")) return alg end -function _select_algorithm(f::F, A, alg; kwargs...) where {F} +function _select_algorithm(f::F, ::Type{A}, alg; kwargs...) where {F,A} return throw(ArgumentError("Unknown alg $alg")) end @doc """ MatrixAlgebraKit.default_algorithm(f, A; kwargs...) + MatrixAlgebraKit.default_algorithm(f, ::Type{TA}; kwargs...) where {TA} Select the default algorithm for a given factorization function `f` and input `A`. In general, this is called by [`select_algorithm`](@ref) if no algorithm is specified explicitly. +New types should prefer to register their default algorithms in the type domain. """ default_algorithm +default_algorithm(f::F, A; kwargs...) where {F} = default_algorithm(f, typeof(A); kwargs...) +# avoid infinite recursion: +default_algorithm(f, T::Type; kwargs...) = throw(MethodError(default_algorithm, (f, T))) @doc """ copy_input(f, A) @@ -172,25 +181,26 @@ macro functiondef(f) f isa Symbol || throw(ArgumentError("Unsupported usage of `@functiondef`")) f! = Symbol(f, :!) - return esc(quote - # out of place to inplace - $f(A; kwargs...) = $f!(copy_input($f, A); kwargs...) - $f(A, alg::AbstractAlgorithm) = $f!(copy_input($f, A), alg) - - # fill in arguments - function $f!(A; alg=nothing, kwargs...) - return $f!(A, select_algorithm($f!, A, alg; kwargs...)) - end - function $f!(A, out; alg=nothing, kwargs...) - return $f!(A, out, select_algorithm($f!, A, alg; kwargs...)) - end - function $f!(A, alg::AbstractAlgorithm) - return $f!(A, initialize_output($f!, A, alg), alg) - end - - # copy documentation to both functions - Core.@__doc__ $f, $f! - end) + ex = quote + # out of place to inplace + $f(A; kwargs...) = $f!(copy_input($f, A); kwargs...) + $f(A, alg::AbstractAlgorithm) = $f!(copy_input($f, A), alg) + + # fill in arguments + function $f!(A; alg=nothing, kwargs...) + return $f!(A, select_algorithm($f!, A, alg; kwargs...)) + end + function $f!(A, out; alg=nothing, kwargs...) + return $f!(A, out, select_algorithm($f!, A, alg; kwargs...)) + end + function $f!(A, alg::AbstractAlgorithm) + return $f!(A, initialize_output($f!, A, alg), alg) + end + + # copy documentation to both functions + Core.@__doc__ $f, $f! + end + return esc(ex) end """ From d4af0aff55221d7ca8eee51288ae06adac1ff713 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Fri, 23 May 2025 09:48:12 -0400 Subject: [PATCH 03/11] Add algorithm selection in terms of inplace function --- src/algorithms.jl | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/algorithms.jl b/src/algorithms.jl index aee55283..b057cf6a 100644 --- a/src/algorithms.jl +++ b/src/algorithms.jl @@ -197,6 +197,14 @@ macro functiondef(f) return $f!(A, initialize_output($f!, A, alg), alg) end + # define fallbacks for algorithm selection + @inline function select_algorithm(::typeof($f), A, alg::Alg; kwargs...) where {Alg} + return select_algorithm($f!, A, alg; kwargs...) + end + @inline function default_algorithm(::typeof($f), A; kwargs...) + return default_algorithm($f!, A; kwargs...) + end + # copy documentation to both functions Core.@__doc__ $f, $f! end From 6b1fa572b00a4a29ce50b9988890e15b584ed115 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Tue, 27 May 2025 09:46:36 -0400 Subject: [PATCH 04/11] change default function interfaces --- src/interface/eig.jl | 29 ++++++++++------------------- src/interface/eigh.jl | 29 ++++++++++------------------- src/interface/lq.jl | 16 ++++------------ src/interface/polar.jl | 20 ++++++-------------- src/interface/qr.jl | 16 ++++------------ src/interface/schur.jl | 17 +++++++---------- src/interface/svd.jl | 25 +++++++------------------ 7 files changed, 48 insertions(+), 104 deletions(-) diff --git a/src/interface/eig.jl b/src/interface/eig.jl index 9071a657..51a999aa 100644 --- a/src/interface/eig.jl +++ b/src/interface/eig.jl @@ -87,27 +87,18 @@ See also [`eig_full(!)`](@ref eig_full) and [`eig_trunc(!)`](@ref eig_trunc). # Algorithm selection # ------------------- -for f in (:eig_full, :eig_vals) - f! = Symbol(f, :!) - @eval begin - function default_algorithm(::typeof($f), A; kwargs...) - return default_algorithm($f!, A; kwargs...) - end - function default_algorithm(::typeof($f!), A; kwargs...) - return default_eig_algorithm(A; kwargs...) - end - end +# Default to LAPACK for `StridedMatrix{<:BlasFloat}` +function default_algorithm(::typeof(eig_full!), ::Type{A}; + kwargs...) where {A<:StridedMatrix{<:BlasFloat}} + return LAPACK_Expert(; kwargs...) end - -function select_algorithm(::typeof(eig_trunc), A, alg; kwargs...) - return select_algorithm(eig_trunc!, A, alg; kwargs...) +function default_algorithm(::typeof(eig_vals!), ::Type{A}; + kwargs...) where {A<:StridedMatrix{<:BlasFloat}} + return LAPACK_Expert(; kwargs...) end -function select_algorithm(::typeof(eig_trunc!), A, alg; trunc=nothing, kwargs...) + +function select_algorithm(::typeof(eig_trunc!), ::Type{A}, alg; trunc=nothing, + kwargs...) where {A<:StridedMatrix{<:BlasFloat}} alg_eig = select_algorithm(eig_full!, A, alg; kwargs...) return TruncatedAlgorithm(alg_eig, select_truncation(trunc)) end - -# Default to LAPACK -function default_eig_algorithm(A::StridedMatrix{<:BlasFloat}; kwargs...) - return LAPACK_Expert(; kwargs...) -end diff --git a/src/interface/eigh.jl b/src/interface/eigh.jl index b092795c..9eef8754 100644 --- a/src/interface/eigh.jl +++ b/src/interface/eigh.jl @@ -86,27 +86,18 @@ See also [`eigh_full(!)`](@ref eigh_full) and [`eigh_trunc(!)`](@ref eigh_trunc) # Algorithm selection # ------------------- -for f in (:eigh_full, :eigh_vals) - f! = Symbol(f, :!) - @eval begin - function default_algorithm(::typeof($f), A; kwargs...) - return default_algorithm($f!, A; kwargs...) - end - function default_algorithm(::typeof($f!), A; kwargs...) - return default_eigh_algorithm(A; kwargs...) - end - end +# Default to LAPACK for `StridedMatrix{<:BlasFloat}` +function default_algorithm(::typeof(eigh_full!), ::Type{A}; + kwargs...) where {A<:StridedMatrix{<:BlasFloat}} + return LAPACK_MultipleRelativelyRobustRepresentations(; kwargs...) end - -function select_algorithm(::typeof(eigh_trunc), A, alg; kwargs...) - return select_algorithm(eigh_trunc!, A, alg; kwargs...) +function default_algorithm(::typeof(eigh_vals!), ::Type{A}; + kwargs...) where {A<:StridedMatrix{<:BlasFloat}} + return LAPACK_MultipleRelativelyRobustRepresentations(; kwargs...) end -function select_algorithm(::typeof(eigh_trunc!), A, alg; trunc=nothing, kwargs...) + +function select_algorithm(::typeof(eigh_trunc!), ::Type{A}, alg; trunc=nothing, + kwargs...) where {A<:StridedMatrix{<:BlasFloat}} alg_eigh = select_algorithm(eigh_full!, A, alg; kwargs...) return TruncatedAlgorithm(alg_eigh, select_truncation(trunc)) end - -# Default to LAPACK -function default_eigh_algorithm(A::StridedMatrix{T}; kwargs...) where {T<:BlasFloat} - return LAPACK_MultipleRelativelyRobustRepresentations(; kwargs...) -end diff --git a/src/interface/lq.jl b/src/interface/lq.jl index e98223f1..acff0522 100644 --- a/src/interface/lq.jl +++ b/src/interface/lq.jl @@ -68,19 +68,11 @@ See also [`qr_full(!)`](@ref lq_full) and [`qr_compact(!)`](@ref lq_compact). # Algorithm selection # ------------------- -for f in (:lq_full, :lq_compact, :lq_null) - f! = Symbol(f, :!) +for f in (:lq_full!, :lq_compact!, :lq_null!) @eval begin - function default_algorithm(::typeof($f), A; kwargs...) - return default_algorithm($f!, A; kwargs...) - end - function default_algorithm(::typeof($f!), A; kwargs...) - return default_lq_algorithm(A; kwargs...) + function default_algorithm(::typeof($f), ::Type{A}; + kwargs...) where {A<:StridedMatrix{<:BlasFloat}} + return LAPACK_HouseholderLQ(; kwargs...) end end end - -# Default to LAPACK -function default_lq_algorithm(A::StridedMatrix{<:BlasFloat}; kwargs...) - return LAPACK_HouseholderLQ(; kwargs...) -end diff --git a/src/interface/polar.jl b/src/interface/polar.jl index b209a327..2cea4ab6 100644 --- a/src/interface/polar.jl +++ b/src/interface/polar.jl @@ -60,19 +60,11 @@ end # Algorithm selection # ------------------- -for f in (:left_polar, :right_polar) - f! = Symbol(f, :!) - @eval begin - function default_algorithm(::typeof($f), A; kwargs...) - return default_algorithm($f!, A; kwargs...) +function default_algorithm(::typeof(left_polar!), ::Type{A}; + kwargs...) where {A<:StridedMatrix{<:BlasFloat}} + return PolarViaSVD(default_algorithm(svd_compact!, A; kwargs...)) end - function default_algorithm(::typeof($f!), A; kwargs...) - return default_polar_algorithm(A; kwargs...) - end - end -end - -# Default to LAPACK SDD for `StridedMatrix{<:BlasFloat}` -function default_polar_algorithm(A::StridedMatrix{<:BlasFloat}; kwargs...) - return PolarViaSVD(default_svd_algorithm(A; kwargs...)) +function default_algorithm(::typeof(right_polar!), ::Type{A}; + kwargs...) where {A<:StridedMatrix{<:BlasFloat}} + return PolarViaSVD(default_algorithm(svd_compact!, A; kwargs...)) end diff --git a/src/interface/qr.jl b/src/interface/qr.jl index cbded32d..bff4926e 100644 --- a/src/interface/qr.jl +++ b/src/interface/qr.jl @@ -68,19 +68,11 @@ See also [`lq_full(!)`](@ref lq_full) and [`lq_compact(!)`](@ref lq_compact). # Algorithm selection # ------------------- -for f in (:qr_full, :qr_compact, :qr_null) - f! = Symbol(f, :!) +for f in (:qr_full!, :qr_compact!, :qr_null!) @eval begin - function default_algorithm(::typeof($f), A; kwargs...) - return default_algorithm($f!, A; kwargs...) - end - function default_algorithm(::typeof($f!), A; kwargs...) - return default_qr_algorithm(A; kwargs...) + function default_algorithm(::typeof($f), ::Type{A}; + kwargs...) where {A<:StridedMatrix{<:BlasFloat}} + return LAPACK_HouseholderQR(; kwargs...) end end end - -# Default to LAPACK -function default_qr_algorithm(A::StridedMatrix{<:BlasFloat}; kwargs...) - return LAPACK_HouseholderQR(; kwargs...) -end diff --git a/src/interface/schur.jl b/src/interface/schur.jl index c49e6a2b..64cfe32d 100644 --- a/src/interface/schur.jl +++ b/src/interface/schur.jl @@ -51,14 +51,11 @@ See also [`eig_full(!)`](@ref eig_full) and [`eig_trunc(!)`](@ref eig_trunc). # Algorithm selection # ------------------- -for f in (:schur_full, :schur_vals) - f! = Symbol(f, :!) - @eval begin - function default_algorithm(::typeof($f), A; kwargs...) - return default_algorithm($f!, A; kwargs...) - end - function default_algorithm(::typeof($f!), A; kwargs...) - return default_eig_algorithm(A; kwargs...) - end - end +function default_algorithm(::typeof(schur_full!), ::Type{A}; + kwargs...) where {A<:StridedMatrix{<:BlasFloat}} + return default_algorithm(eig_full!, A; kwargs...) +end +function default_algorithm(::typeof(schur_vals!), ::Type{A}; + kwargs...) where {A<:StridedMatrix{<:BlasFloat}} + return default_algorithm(eig_vals!, A; kwargs...) end diff --git a/src/interface/svd.jl b/src/interface/svd.jl index 1c5d7e3a..90fe7241 100644 --- a/src/interface/svd.jl +++ b/src/interface/svd.jl @@ -90,27 +90,16 @@ See also [`svd_full(!)`](@ref svd_full), [`svd_compact(!)`](@ref svd_compact) an # Algorithm selection # ------------------- -for f in (:svd_full, :svd_compact, :svd_vals) - f! = Symbol(f, :!) - @eval begin - function default_algorithm(::typeof($f), A; kwargs...) - return default_algorithm($f!, A; kwargs...) - end - function default_algorithm(::typeof($f!), A; kwargs...) - return default_svd_algorithm(A; kwargs...) - end +for f in (:svd_full!, :svd_compact!, :svd_vals!) + # Default to LAPACK SDD for `StridedMatrix{<:BlasFloat}` + @eval function default_algorithm(::typeof($f), ::Type{A}; + kwargs...) where {A<:StridedMatrix{<:BlasFloat}} + return LAPACK_DivideAndConquer(; kwargs...) end end -function select_algorithm(::typeof(svd_trunc), A, alg; kwargs...) - return select_algorithm(svd_trunc!, A, alg; kwargs...) -end -function select_algorithm(::typeof(svd_trunc!), A, alg; trunc=nothing, kwargs...) +function select_algorithm(::typeof(svd_trunc!), ::Type{A}, alg; trunc=nothing, + kwargs...) where {A<:StridedMatrix{<:BlasFloat}} alg_svd = select_algorithm(svd_compact!, A, alg; kwargs...) return TruncatedAlgorithm(alg_svd, select_truncation(trunc)) end - -# Default to LAPACK SDD for `StridedMatrix{<:BlasFloat}` -function default_svd_algorithm(A::StridedMatrix{<:BlasFloat}; kwargs...) - return LAPACK_DivideAndConquer(; kwargs...) -end From f7c324bc15cf689dfe2f9f9217afa224e3170252 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Tue, 27 May 2025 09:46:44 -0400 Subject: [PATCH 05/11] fix tests --- test/orthnull.jl | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/test/orthnull.jl b/test/orthnull.jl index 7f497eaf..59554e54 100644 --- a/test/orthnull.jl +++ b/test/orthnull.jl @@ -4,7 +4,7 @@ using TestExtras using StableRNGs using LinearAlgebra: LinearAlgebra, I, mul! using MatrixAlgebraKit: TruncationKeepAbove, TruncationKeepBelow -using MatrixAlgebraKit: LAPACK_SVDAlgorithm, check_input, copy_input, default_svd_algorithm, +using MatrixAlgebraKit: LAPACK_SVDAlgorithm, check_input, copy_input, default_algorithm, initialize_output # Used to test non-AbstractMatrix codepaths. @@ -39,8 +39,9 @@ end function MatrixAlgebraKit.check_input(::typeof(right_orth!), A::LinearMap, VC) return check_input(right_orth!, parent(A), parent.(VC)) end -function MatrixAlgebraKit.default_svd_algorithm(A::LinearMap) - return default_svd_algorithm(parent(A)) +function MatrixAlgebraKit.default_algorithm(::typeof(svd_compact!), + ::Type{LinearMap{A}}) where {A} + return default_algorithm(svd_compact!, A) end function MatrixAlgebraKit.initialize_output(::typeof(svd_compact!), A::LinearMap, alg::LAPACK_SVDAlgorithm) From dbce8f470238b3fdda81c7250e7fa6827becafbb Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Tue, 27 May 2025 09:46:54 -0400 Subject: [PATCH 06/11] resolve method ambiguities --- src/algorithms.jl | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/algorithms.jl b/src/algorithms.jl index b057cf6a..749f2fee 100644 --- a/src/algorithms.jl +++ b/src/algorithms.jl @@ -117,7 +117,9 @@ New types should prefer to register their default algorithms in the type domain. """ default_algorithm default_algorithm(f::F, A; kwargs...) where {F} = default_algorithm(f, typeof(A); kwargs...) # avoid infinite recursion: -default_algorithm(f, T::Type; kwargs...) = throw(MethodError(default_algorithm, (f, T))) +function default_algorithm(f::F, ::Type{T}; kwargs...) where {F,T} + throw(MethodError(default_algorithm, (f, T))) +end @doc """ copy_input(f, A) @@ -198,10 +200,11 @@ macro functiondef(f) end # define fallbacks for algorithm selection - @inline function select_algorithm(::typeof($f), A, alg::Alg; kwargs...) where {Alg} + @inline function select_algorithm(::typeof($f), ::Type{A}, alg::Alg; + kwargs...) where {Alg,A} return select_algorithm($f!, A, alg; kwargs...) end - @inline function default_algorithm(::typeof($f), A; kwargs...) + @inline function default_algorithm(::typeof($f), ::Type{A}; kwargs...) where {A} return default_algorithm($f!, A; kwargs...) end From 396e7b01a6954f74780121a55e7b6791bf861e42 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Tue, 27 May 2025 09:53:44 -0400 Subject: [PATCH 07/11] Add type alias for `BlasMat` --- src/interface/eig.jl | 8 ++++---- src/interface/eigh.jl | 8 ++++---- src/interface/lq.jl | 2 +- src/interface/polar.jl | 4 ++-- src/interface/qr.jl | 2 +- src/interface/schur.jl | 4 ++-- src/interface/svd.jl | 6 +++--- src/yalapack.jl | 3 +++ 8 files changed, 20 insertions(+), 17 deletions(-) diff --git a/src/interface/eig.jl b/src/interface/eig.jl index 51a999aa..ae0f2ade 100644 --- a/src/interface/eig.jl +++ b/src/interface/eig.jl @@ -87,18 +87,18 @@ See also [`eig_full(!)`](@ref eig_full) and [`eig_trunc(!)`](@ref eig_trunc). # Algorithm selection # ------------------- -# Default to LAPACK for `StridedMatrix{<:BlasFloat}` +# Default to LAPACK for `YALAPACK.BlasMat` function default_algorithm(::typeof(eig_full!), ::Type{A}; - kwargs...) where {A<:StridedMatrix{<:BlasFloat}} + kwargs...) where {A<:YALAPACK.BlasMat} return LAPACK_Expert(; kwargs...) end function default_algorithm(::typeof(eig_vals!), ::Type{A}; - kwargs...) where {A<:StridedMatrix{<:BlasFloat}} + kwargs...) where {A<:YALAPACK.BlasMat} return LAPACK_Expert(; kwargs...) end function select_algorithm(::typeof(eig_trunc!), ::Type{A}, alg; trunc=nothing, - kwargs...) where {A<:StridedMatrix{<:BlasFloat}} + kwargs...) where {A<:YALAPACK.BlasMat} alg_eig = select_algorithm(eig_full!, A, alg; kwargs...) return TruncatedAlgorithm(alg_eig, select_truncation(trunc)) end diff --git a/src/interface/eigh.jl b/src/interface/eigh.jl index 9eef8754..c49d5c01 100644 --- a/src/interface/eigh.jl +++ b/src/interface/eigh.jl @@ -86,18 +86,18 @@ See also [`eigh_full(!)`](@ref eigh_full) and [`eigh_trunc(!)`](@ref eigh_trunc) # Algorithm selection # ------------------- -# Default to LAPACK for `StridedMatrix{<:BlasFloat}` +# Default to LAPACK for `YALAPACK.BlasMat` function default_algorithm(::typeof(eigh_full!), ::Type{A}; - kwargs...) where {A<:StridedMatrix{<:BlasFloat}} + kwargs...) where {A<:YALAPACK.BlasMat} return LAPACK_MultipleRelativelyRobustRepresentations(; kwargs...) end function default_algorithm(::typeof(eigh_vals!), ::Type{A}; - kwargs...) where {A<:StridedMatrix{<:BlasFloat}} + kwargs...) where {A<:YALAPACK.BlasMat} return LAPACK_MultipleRelativelyRobustRepresentations(; kwargs...) end function select_algorithm(::typeof(eigh_trunc!), ::Type{A}, alg; trunc=nothing, - kwargs...) where {A<:StridedMatrix{<:BlasFloat}} + kwargs...) where {A<:YALAPACK.BlasMat} alg_eigh = select_algorithm(eigh_full!, A, alg; kwargs...) return TruncatedAlgorithm(alg_eigh, select_truncation(trunc)) end diff --git a/src/interface/lq.jl b/src/interface/lq.jl index acff0522..be254cf7 100644 --- a/src/interface/lq.jl +++ b/src/interface/lq.jl @@ -71,7 +71,7 @@ See also [`qr_full(!)`](@ref lq_full) and [`qr_compact(!)`](@ref lq_compact). for f in (:lq_full!, :lq_compact!, :lq_null!) @eval begin function default_algorithm(::typeof($f), ::Type{A}; - kwargs...) where {A<:StridedMatrix{<:BlasFloat}} + kwargs...) where {A<:YALAPACK.BlasMat} return LAPACK_HouseholderLQ(; kwargs...) end end diff --git a/src/interface/polar.jl b/src/interface/polar.jl index 2cea4ab6..32489570 100644 --- a/src/interface/polar.jl +++ b/src/interface/polar.jl @@ -61,10 +61,10 @@ end # Algorithm selection # ------------------- function default_algorithm(::typeof(left_polar!), ::Type{A}; - kwargs...) where {A<:StridedMatrix{<:BlasFloat}} + kwargs...) where {A<:YALAPACK.BlasMat} return PolarViaSVD(default_algorithm(svd_compact!, A; kwargs...)) end function default_algorithm(::typeof(right_polar!), ::Type{A}; - kwargs...) where {A<:StridedMatrix{<:BlasFloat}} + kwargs...) where {A<:YALAPACK.BlasMat} return PolarViaSVD(default_algorithm(svd_compact!, A; kwargs...)) end diff --git a/src/interface/qr.jl b/src/interface/qr.jl index bff4926e..6542be26 100644 --- a/src/interface/qr.jl +++ b/src/interface/qr.jl @@ -71,7 +71,7 @@ See also [`lq_full(!)`](@ref lq_full) and [`lq_compact(!)`](@ref lq_compact). for f in (:qr_full!, :qr_compact!, :qr_null!) @eval begin function default_algorithm(::typeof($f), ::Type{A}; - kwargs...) where {A<:StridedMatrix{<:BlasFloat}} + kwargs...) where {A<:YALAPACK.BlasMat} return LAPACK_HouseholderQR(; kwargs...) end end diff --git a/src/interface/schur.jl b/src/interface/schur.jl index 64cfe32d..7acea4db 100644 --- a/src/interface/schur.jl +++ b/src/interface/schur.jl @@ -52,10 +52,10 @@ See also [`eig_full(!)`](@ref eig_full) and [`eig_trunc(!)`](@ref eig_trunc). # Algorithm selection # ------------------- function default_algorithm(::typeof(schur_full!), ::Type{A}; - kwargs...) where {A<:StridedMatrix{<:BlasFloat}} + kwargs...) where {A<:YALAPACK.BlasMat} return default_algorithm(eig_full!, A; kwargs...) end function default_algorithm(::typeof(schur_vals!), ::Type{A}; - kwargs...) where {A<:StridedMatrix{<:BlasFloat}} + kwargs...) where {A<:YALAPACK.BlasMat} return default_algorithm(eig_vals!, A; kwargs...) end diff --git a/src/interface/svd.jl b/src/interface/svd.jl index 90fe7241..593d770a 100644 --- a/src/interface/svd.jl +++ b/src/interface/svd.jl @@ -91,15 +91,15 @@ See also [`svd_full(!)`](@ref svd_full), [`svd_compact(!)`](@ref svd_compact) an # Algorithm selection # ------------------- for f in (:svd_full!, :svd_compact!, :svd_vals!) - # Default to LAPACK SDD for `StridedMatrix{<:BlasFloat}` + # Default to LAPACK SDD for `YALAPACK.BlasMat` @eval function default_algorithm(::typeof($f), ::Type{A}; - kwargs...) where {A<:StridedMatrix{<:BlasFloat}} + kwargs...) where {A<:YALAPACK.BlasMat} return LAPACK_DivideAndConquer(; kwargs...) end end function select_algorithm(::typeof(svd_trunc!), ::Type{A}, alg; trunc=nothing, - kwargs...) where {A<:StridedMatrix{<:BlasFloat}} + kwargs...) where {A<:YALAPACK.BlasMat} alg_svd = select_algorithm(svd_compact!, A, alg; kwargs...) return TruncatedAlgorithm(alg_svd, select_truncation(trunc)) end diff --git a/src/yalapack.jl b/src/yalapack.jl index 2bc08946..e8f4b685 100644 --- a/src/yalapack.jl +++ b/src/yalapack.jl @@ -16,6 +16,9 @@ using LinearAlgebra: BlasFloat, BlasReal, BlasComplex, BlasInt, Char, LAPACK, using LinearAlgebra.BLAS: @blasfunc, libblastrampoline using LinearAlgebra.LAPACK: chkfinite, chktrans, chkside, chkuplofinite, chklapackerror +# type alias for matrices that are definitely supported by YALAPACK +const BlasMat{T<:BlasFloat} = StridedMatrix{T} + # LU factorisation for (getrf, getrs, elty) in ((:dgetrf_, :dgetrs_, :Float64), (:sgetrf_, :sgetrs_, :Float32), From 2c1b114b7b8c1b11e7fae4f4cabfe56643d16d1b Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Tue, 27 May 2025 10:04:34 -0400 Subject: [PATCH 08/11] Bump version --- Project.toml | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index 53bac9a5..507d81df 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "MatrixAlgebraKit" uuid = "6c742aac-3347-4629-af66-fc926824e5e4" authors = ["Jutho and contributors"] -version = "0.2.0" +version = "0.2.1" [deps] LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" @@ -36,4 +36,5 @@ TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["Aqua", "JET", "SafeTestsets", "Test", "TestExtras","ChainRulesCore", "ChainRulesTestUtils", "StableRNGs", "Zygote"] +test = ["Aqua", "JET", "SafeTestsets", "Test", "TestExtras", "ChainRulesCore", + "ChainRulesTestUtils", "StableRNGs", "Zygote"] From a6fce548ea17457e3376231f194ee47da82b4462 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Tue, 27 May 2025 11:47:57 -0400 Subject: [PATCH 09/11] Condense algorithm selection --- src/algorithms.jl | 40 +++++++++++++++++----------------------- 1 file changed, 17 insertions(+), 23 deletions(-) diff --git a/src/algorithms.jl b/src/algorithms.jl index 749f2fee..f559b42e 100644 --- a/src/algorithms.jl +++ b/src/algorithms.jl @@ -80,32 +80,26 @@ function select_algorithm(f::F, A, alg::Alg=nothing; kwargs...) where {F,Alg} return select_algorithm(f, typeof(A), alg; kwargs...) end function select_algorithm(f::F, ::Type{A}, alg::Alg=nothing; kwargs...) where {F,A,Alg} - return _select_algorithm(f, A, alg; kwargs...) -end + if isnothing(alg) + return default_algorithm(f, A; kwargs...) + elseif alg isa Symbol + return Algorithm{alg}(; kwargs...) + elseif alg isa Type + return alg(; kwargs...) + elseif alg isa NamedTuple + isempty(kwargs) || + throw(ArgumentError("Additional keyword arguments are not allowed when algorithm parameters are specified.")) + return default_algorithm(f, A; alg...) + elseif alg isa AbstractAlgorithm + isempty(kwargs) || + throw(ArgumentError("Additional keyword arguments are not allowed when algorithm parameters are specified.")) + return alg + end -function _select_algorithm(f::F, ::Type{A}, alg::Nothing; kwargs...) where {F,A} - return default_algorithm(f, A; kwargs...) -end -function _select_algorithm(f::F, ::Type{A}, alg::Symbol; kwargs...) where {F,A} - return Algorithm{alg}(; kwargs...) -end -function _select_algorithm(f::F, ::Type{A}, ::Type{Alg}; kwargs...) where {F,A,Alg} - return Alg(; kwargs...) -end -function _select_algorithm(f::F, ::Type{A}, alg::NamedTuple; kwargs...) where {F,A} - isempty(kwargs) || - throw(ArgumentError("Additional keyword arguments are not allowed when algorithm parameters are specified.")) - return default_algorithm(f, A; alg...) -end -function _select_algorithm(f::F, ::Type{A}, alg::AbstractAlgorithm; kwargs...) where {F,A} - isempty(kwargs) || - throw(ArgumentError("Additional keyword arguments are not allowed when an algorithm is specified.")) - return alg -end -function _select_algorithm(f::F, ::Type{A}, alg; kwargs...) where {F,A} - return throw(ArgumentError("Unknown alg $alg")) + throw(ArgumentError("Unknown alg $alg")) end + @doc """ MatrixAlgebraKit.default_algorithm(f, A; kwargs...) MatrixAlgebraKit.default_algorithm(f, ::Type{TA}; kwargs...) where {TA} From bd63c8fe6e559ac989a064b73b87bc34db41a034 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Tue, 27 May 2025 20:29:36 -0400 Subject: [PATCH 10/11] Reinstate `default_f_algorithm` functions --- src/interface/eig.jl | 14 ++++++++------ src/interface/eigh.jl | 16 ++++++++++------ src/interface/lq.jl | 13 ++++++++++--- src/interface/polar.jl | 19 ++++++++++++------- src/interface/qr.jl | 10 +++++++++- src/interface/schur.jl | 11 ++++------- src/interface/svd.jl | 14 ++++++++++---- 7 files changed, 63 insertions(+), 34 deletions(-) diff --git a/src/interface/eig.jl b/src/interface/eig.jl index ae0f2ade..77aa0672 100644 --- a/src/interface/eig.jl +++ b/src/interface/eig.jl @@ -87,14 +87,16 @@ See also [`eig_full(!)`](@ref eig_full) and [`eig_trunc(!)`](@ref eig_trunc). # Algorithm selection # ------------------- -# Default to LAPACK for `YALAPACK.BlasMat` -function default_algorithm(::typeof(eig_full!), ::Type{A}; - kwargs...) where {A<:YALAPACK.BlasMat} +default_eig_algorithm(A; kwargs...) = default_eig_algorithm(typeof(A); kwargs...) +default_eig_algorithm(T::Type; kwargs...) = throw(MethodError(default_eig_algorithm, (T,))) +function default_eig_algorithm(::Type{T}; kwargs...) where {T<:YALAPACK.BlasMat} return LAPACK_Expert(; kwargs...) end -function default_algorithm(::typeof(eig_vals!), ::Type{A}; - kwargs...) where {A<:YALAPACK.BlasMat} - return LAPACK_Expert(; kwargs...) + +for f in (:eig_full!, :eig_vals!) + @eval function default_algorithm(::typeof($f), ::Type{A}; kwargs...) where {A} + return default_eig_algorithm(A; kwargs...) + end end function select_algorithm(::typeof(eig_trunc!), ::Type{A}, alg; trunc=nothing, diff --git a/src/interface/eigh.jl b/src/interface/eigh.jl index c49d5c01..3ed38789 100644 --- a/src/interface/eigh.jl +++ b/src/interface/eigh.jl @@ -86,16 +86,20 @@ See also [`eigh_full(!)`](@ref eigh_full) and [`eigh_trunc(!)`](@ref eigh_trunc) # Algorithm selection # ------------------- -# Default to LAPACK for `YALAPACK.BlasMat` -function default_algorithm(::typeof(eigh_full!), ::Type{A}; - kwargs...) where {A<:YALAPACK.BlasMat} - return LAPACK_MultipleRelativelyRobustRepresentations(; kwargs...) +default_eigh_algorithm(A; kwargs...) = default_eigh_algorithm(typeof(A); kwargs...) +function default_eigh_algorithm(T::Type; kwargs...) + throw(MethodError(default_eigh_algorithm, (T,))) end -function default_algorithm(::typeof(eigh_vals!), ::Type{A}; - kwargs...) where {A<:YALAPACK.BlasMat} +function default_eigh_algorithm(::Type{T}; kwargs...) where {T<:YALAPACK.BlasMat} return LAPACK_MultipleRelativelyRobustRepresentations(; kwargs...) end +for f in (:eigh_full!, :eigh_vals!) + @eval function default_algorithm(::typeof($f), ::Type{A}; kwargs...) where {A} + return default_eigh_algorithm(A; kwargs...) + end +end + function select_algorithm(::typeof(eigh_trunc!), ::Type{A}, alg; trunc=nothing, kwargs...) where {A<:YALAPACK.BlasMat} alg_eigh = select_algorithm(eigh_full!, A, alg; kwargs...) diff --git a/src/interface/lq.jl b/src/interface/lq.jl index be254cf7..9de85bae 100644 --- a/src/interface/lq.jl +++ b/src/interface/lq.jl @@ -68,11 +68,18 @@ See also [`qr_full(!)`](@ref lq_full) and [`qr_compact(!)`](@ref lq_compact). # Algorithm selection # ------------------- +default_lq_algorithm(A; kwargs...) = default_lq_algorithm(typeof(A); kwargs...) +function default_lq_algorithm(T::Type; kwargs...) + throw(MethodError(default_lq_algorithm, (T,))) +end +function default_lq_algorithm(::Type{T}; kwargs...) where {T<:YALAPACK.BlasMat} + return LAPACK_HouseholderLQ(; kwargs...) +end + for f in (:lq_full!, :lq_compact!, :lq_null!) @eval begin - function default_algorithm(::typeof($f), ::Type{A}; - kwargs...) where {A<:YALAPACK.BlasMat} - return LAPACK_HouseholderLQ(; kwargs...) + function default_algorithm(::typeof($f), ::Type{A}; kwargs...) where {A} + return default_lq_algorithm(A; kwargs...) end end end diff --git a/src/interface/polar.jl b/src/interface/polar.jl index 32489570..111a6e3a 100644 --- a/src/interface/polar.jl +++ b/src/interface/polar.jl @@ -60,11 +60,16 @@ end # Algorithm selection # ------------------- -function default_algorithm(::typeof(left_polar!), ::Type{A}; - kwargs...) where {A<:YALAPACK.BlasMat} - return PolarViaSVD(default_algorithm(svd_compact!, A; kwargs...)) - end -function default_algorithm(::typeof(right_polar!), ::Type{A}; - kwargs...) where {A<:YALAPACK.BlasMat} - return PolarViaSVD(default_algorithm(svd_compact!, A; kwargs...)) +default_polar_algorithm(A; kwargs...) = default_polar_algorithm(typeof(A); kwargs...) +function default_polar_algorithm(T::Type; kwargs...) + throw(MethodError(default_polar_algorithm, (T,))) +end +function default_polar_algorithm(::Type{T}; kwargs...) where {T<:YALAPACK.BlasMat} + return PolarViaSVD(default_algorithm(svd_compact!, T; kwargs...)) +end + +for f in (:left_polar!, :right_polar!) + @eval function default_algorithm(::typeof($f), ::Type{A}; kwargs...) where {A} + return default_polar_algorithm(A; kwargs...) + end end diff --git a/src/interface/qr.jl b/src/interface/qr.jl index 6542be26..a1c12b7a 100644 --- a/src/interface/qr.jl +++ b/src/interface/qr.jl @@ -68,11 +68,19 @@ See also [`lq_full(!)`](@ref lq_full) and [`lq_compact(!)`](@ref lq_compact). # Algorithm selection # ------------------- +default_qr_algorithm(A; kwargs...) = default_qr_algorithm(typeof(A); kwargs...) +function default_qr_algorithm(T::Type; kwargs...) + throw(MethodError(default_qr_algorithm, (T,))) +end +function default_qr_algorithm(::Type{T}; kwargs...) where {T<:YALAPACK.BlasMat} + return LAPACK_HouseholderQR(; kwargs...) +end + for f in (:qr_full!, :qr_compact!, :qr_null!) @eval begin function default_algorithm(::typeof($f), ::Type{A}; kwargs...) where {A<:YALAPACK.BlasMat} - return LAPACK_HouseholderQR(; kwargs...) + return default_qr_algorithm(A; kwargs...) end end end diff --git a/src/interface/schur.jl b/src/interface/schur.jl index 7acea4db..19f6dc00 100644 --- a/src/interface/schur.jl +++ b/src/interface/schur.jl @@ -51,11 +51,8 @@ See also [`eig_full(!)`](@ref eig_full) and [`eig_trunc(!)`](@ref eig_trunc). # Algorithm selection # ------------------- -function default_algorithm(::typeof(schur_full!), ::Type{A}; - kwargs...) where {A<:YALAPACK.BlasMat} - return default_algorithm(eig_full!, A; kwargs...) -end -function default_algorithm(::typeof(schur_vals!), ::Type{A}; - kwargs...) where {A<:YALAPACK.BlasMat} - return default_algorithm(eig_vals!, A; kwargs...) +for f in (:schur_full!, :schur_vals!) + @eval function default_algorithm(::typeof($f), ::Type{A}; kwargs...) where {A} + return default_eig_algorithm(A; kwargs...) + end end diff --git a/src/interface/svd.jl b/src/interface/svd.jl index 593d770a..e6f9021a 100644 --- a/src/interface/svd.jl +++ b/src/interface/svd.jl @@ -90,11 +90,17 @@ See also [`svd_full(!)`](@ref svd_full), [`svd_compact(!)`](@ref svd_compact) an # Algorithm selection # ------------------- +default_svd_algorithm(A; kwargs...) = default_svd_algorithm(typeof(A); kwargs...) +function default_svd_algorithm(T::Type; kwargs...) + throw(MethodError(default_svd_algorithm, (T,))) +end +function default_svd_algorithm(::Type{T}; kwargs...) where {T<:YALAPACK.BlasMat} + return LAPACK_DivideAndConquer(; kwargs...) +end + for f in (:svd_full!, :svd_compact!, :svd_vals!) - # Default to LAPACK SDD for `YALAPACK.BlasMat` - @eval function default_algorithm(::typeof($f), ::Type{A}; - kwargs...) where {A<:YALAPACK.BlasMat} - return LAPACK_DivideAndConquer(; kwargs...) + @eval function default_algorithm(::typeof($f), ::Type{A}; kwargs...) where {A} + return default_svd_algorithm(A; kwargs...) end end From 917c7fa29c53164984231a1f566e11662dcc91ea Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Tue, 27 May 2025 20:33:22 -0400 Subject: [PATCH 11/11] refix tests --- test/orthnull.jl | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/test/orthnull.jl b/test/orthnull.jl index 59554e54..b0004739 100644 --- a/test/orthnull.jl +++ b/test/orthnull.jl @@ -4,7 +4,7 @@ using TestExtras using StableRNGs using LinearAlgebra: LinearAlgebra, I, mul! using MatrixAlgebraKit: TruncationKeepAbove, TruncationKeepBelow -using MatrixAlgebraKit: LAPACK_SVDAlgorithm, check_input, copy_input, default_algorithm, +using MatrixAlgebraKit: LAPACK_SVDAlgorithm, check_input, copy_input, default_svd_algorithm, initialize_output # Used to test non-AbstractMatrix codepaths. @@ -39,9 +39,8 @@ end function MatrixAlgebraKit.check_input(::typeof(right_orth!), A::LinearMap, VC) return check_input(right_orth!, parent(A), parent.(VC)) end -function MatrixAlgebraKit.default_algorithm(::typeof(svd_compact!), - ::Type{LinearMap{A}}) where {A} - return default_algorithm(svd_compact!, A) +function MatrixAlgebraKit.default_svd_algorithm(::Type{LinearMap{A}}; kwargs...) where {A} + return default_svd_algorithm(A; kwargs...) end function MatrixAlgebraKit.initialize_output(::typeof(svd_compact!), A::LinearMap, alg::LAPACK_SVDAlgorithm)