From 96fde14c19adbc918e05b770c0715ccd62dd090a Mon Sep 17 00:00:00 2001 From: mtfishman Date: Fri, 30 May 2025 09:43:12 -0400 Subject: [PATCH 01/12] Make select_algorithm more agnostic about being in the object or type domain --- Project.toml | 2 +- src/algorithms.jl | 6 +----- src/interface/eig.jl | 6 ++++-- src/interface/eigh.jl | 6 ++++-- src/interface/svd.jl | 6 ++++-- 5 files changed, 14 insertions(+), 12 deletions(-) diff --git a/Project.toml b/Project.toml index 507d81df..6736baa5 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "MatrixAlgebraKit" uuid = "6c742aac-3347-4629-af66-fc926824e5e4" authors = ["Jutho and contributors"] -version = "0.2.1" +version = "0.2.2" [deps] LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" diff --git a/src/algorithms.jl b/src/algorithms.jl index f559b42e..80309d47 100644 --- a/src/algorithms.jl +++ b/src/algorithms.jl @@ -76,10 +76,7 @@ Finally, the same behavior is obtained when the keyword arguments are passed as the third positional argument in the form of a `NamedTuple`. """ select_algorithm -function select_algorithm(f::F, A, alg::Alg=nothing; kwargs...) where {F,Alg} - return select_algorithm(f, typeof(A), alg; kwargs...) -end -function select_algorithm(f::F, ::Type{A}, alg::Alg=nothing; kwargs...) where {F,A,Alg} +function select_algorithm(f::F, A::T, alg::Alg=nothing; kwargs...) where {F,T,Alg} if isnothing(alg) return default_algorithm(f, A; kwargs...) elseif alg isa Symbol @@ -99,7 +96,6 @@ function select_algorithm(f::F, ::Type{A}, alg::Alg=nothing; kwargs...) where {F throw(ArgumentError("Unknown alg $alg")) end - @doc """ MatrixAlgebraKit.default_algorithm(f, A; kwargs...) MatrixAlgebraKit.default_algorithm(f, ::Type{TA}; kwargs...) where {TA} diff --git a/src/interface/eig.jl b/src/interface/eig.jl index 77aa0672..7b6ed7c8 100644 --- a/src/interface/eig.jl +++ b/src/interface/eig.jl @@ -99,8 +99,10 @@ for f in (:eig_full!, :eig_vals!) end end -function select_algorithm(::typeof(eig_trunc!), ::Type{A}, alg; trunc=nothing, - kwargs...) where {A<:YALAPACK.BlasMat} +function select_algorithm(::typeof(eig_trunc!), A, alg; trunc=nothing, kwargs...) alg_eig = select_algorithm(eig_full!, A, alg; kwargs...) return TruncatedAlgorithm(alg_eig, select_truncation(trunc)) end +function select_algorithm(::typeof(eig_trunc), A, alg; trunc=nothing, kwargs...) + return select_algorithm(eig_trunc!, A, alg; trunc, kwargs...) +end diff --git a/src/interface/eigh.jl b/src/interface/eigh.jl index 3ed38789..73eec6da 100644 --- a/src/interface/eigh.jl +++ b/src/interface/eigh.jl @@ -100,8 +100,10 @@ for f in (:eigh_full!, :eigh_vals!) end end -function select_algorithm(::typeof(eigh_trunc!), ::Type{A}, alg; trunc=nothing, - kwargs...) where {A<:YALAPACK.BlasMat} +function select_algorithm(::typeof(eigh_trunc!), A, alg; trunc=nothing, kwargs...) alg_eigh = select_algorithm(eigh_full!, A, alg; kwargs...) return TruncatedAlgorithm(alg_eigh, select_truncation(trunc)) end +function select_algorithm(::typeof(eigh_trunc), A, alg; trunc=nothing, kwargs...) + return select_algorithm(eigh_trunc!, A, alg; trunc, kwargs...) +end diff --git a/src/interface/svd.jl b/src/interface/svd.jl index e6f9021a..a2c90731 100644 --- a/src/interface/svd.jl +++ b/src/interface/svd.jl @@ -104,8 +104,10 @@ for f in (:svd_full!, :svd_compact!, :svd_vals!) end end -function select_algorithm(::typeof(svd_trunc!), ::Type{A}, alg; trunc=nothing, - kwargs...) where {A<:YALAPACK.BlasMat} +function select_algorithm(::typeof(svd_trunc!), A, alg; trunc=nothing, kwargs...) alg_svd = select_algorithm(svd_compact!, A, alg; kwargs...) return TruncatedAlgorithm(alg_svd, select_truncation(trunc)) end +function select_algorithm(::typeof(svd_trunc), A, alg; trunc=nothing, kwargs...) + return select_algorithm(svd_trunc!, A, alg; trunc, kwargs...) +end From 66fb7efba6bd5be0171020e19c6a41c60da0342d Mon Sep 17 00:00:00 2001 From: mtfishman Date: Fri, 30 May 2025 10:04:44 -0400 Subject: [PATCH 02/12] Fix ambiguity errors reported by Aqua --- src/interface/eig.jl | 9 +++++++-- src/interface/eigh.jl | 9 +++++++-- src/interface/lq.jl | 3 +++ src/interface/polar.jl | 14 ++++++++------ src/interface/qr.jl | 6 ++++-- src/interface/schur.jl | 9 +++++++-- src/interface/svd.jl | 9 +++++++-- 7 files changed, 43 insertions(+), 16 deletions(-) diff --git a/src/interface/eig.jl b/src/interface/eig.jl index 7b6ed7c8..5c495b1a 100644 --- a/src/interface/eig.jl +++ b/src/interface/eig.jl @@ -94,8 +94,13 @@ function default_eig_algorithm(::Type{T}; kwargs...) where {T<:YALAPACK.BlasMat} end for f in (:eig_full!, :eig_vals!) - @eval function default_algorithm(::typeof($f), ::Type{A}; kwargs...) where {A} - return default_eig_algorithm(A; kwargs...) + @eval begin + function default_algorithm(::typeof($f), A; kwargs...) + return default_eig_algorithm(A; kwargs...) + end + function default_algorithm(::typeof($f), ::Type{A}; kwargs...) where {A} + return default_eig_algorithm(A; kwargs...) + end end end diff --git a/src/interface/eigh.jl b/src/interface/eigh.jl index 73eec6da..91318aff 100644 --- a/src/interface/eigh.jl +++ b/src/interface/eigh.jl @@ -95,8 +95,13 @@ function default_eigh_algorithm(::Type{T}; kwargs...) where {T<:YALAPACK.BlasMat end for f in (:eigh_full!, :eigh_vals!) - @eval function default_algorithm(::typeof($f), ::Type{A}; kwargs...) where {A} - return default_eigh_algorithm(A; kwargs...) + @eval begin + function default_algorithm(::typeof($f), A; kwargs...) + return default_eigh_algorithm(A; kwargs...) + end + function default_algorithm(::typeof($f), ::Type{A}; kwargs...) where {A} + return default_eigh_algorithm(A; kwargs...) + end end end diff --git a/src/interface/lq.jl b/src/interface/lq.jl index 9de85bae..65b6af1d 100644 --- a/src/interface/lq.jl +++ b/src/interface/lq.jl @@ -78,6 +78,9 @@ end for f in (:lq_full!, :lq_compact!, :lq_null!) @eval begin + function default_algorithm(::typeof($f), A; kwargs...) + return default_lq_algorithm(A; kwargs...) + end function default_algorithm(::typeof($f), ::Type{A}; kwargs...) where {A} return default_lq_algorithm(A; kwargs...) end diff --git a/src/interface/polar.jl b/src/interface/polar.jl index 111a6e3a..a9c1056d 100644 --- a/src/interface/polar.jl +++ b/src/interface/polar.jl @@ -61,15 +61,17 @@ end # Algorithm selection # ------------------- default_polar_algorithm(A; kwargs...) = default_polar_algorithm(typeof(A); kwargs...) -function default_polar_algorithm(T::Type; kwargs...) - throw(MethodError(default_polar_algorithm, (T,))) -end -function default_polar_algorithm(::Type{T}; kwargs...) where {T<:YALAPACK.BlasMat} +function default_polar_algorithm(::Type{T}; kwargs...) where {T} return PolarViaSVD(default_algorithm(svd_compact!, T; kwargs...)) end for f in (:left_polar!, :right_polar!) - @eval function default_algorithm(::typeof($f), ::Type{A}; kwargs...) where {A} - return default_polar_algorithm(A; kwargs...) + @eval begin + function default_algorithm(::typeof($f), A; kwargs...) + return default_polar_algorithm(A; kwargs...) + end + function default_algorithm(::typeof($f), ::Type{A}; kwargs...) where {A} + return default_polar_algorithm(A; kwargs...) + end end end diff --git a/src/interface/qr.jl b/src/interface/qr.jl index a1c12b7a..d1538573 100644 --- a/src/interface/qr.jl +++ b/src/interface/qr.jl @@ -78,8 +78,10 @@ end for f in (:qr_full!, :qr_compact!, :qr_null!) @eval begin - function default_algorithm(::typeof($f), ::Type{A}; - kwargs...) where {A<:YALAPACK.BlasMat} + function default_algorithm(::typeof($f), A; kwargs...) + return default_qr_algorithm(A; kwargs...) + end + function default_algorithm(::typeof($f), ::Type{A}; kwargs...) where {A} return default_qr_algorithm(A; kwargs...) end end diff --git a/src/interface/schur.jl b/src/interface/schur.jl index 19f6dc00..2e17fc52 100644 --- a/src/interface/schur.jl +++ b/src/interface/schur.jl @@ -52,7 +52,12 @@ See also [`eig_full(!)`](@ref eig_full) and [`eig_trunc(!)`](@ref eig_trunc). # Algorithm selection # ------------------- for f in (:schur_full!, :schur_vals!) - @eval function default_algorithm(::typeof($f), ::Type{A}; kwargs...) where {A} - return default_eig_algorithm(A; kwargs...) + @eval begin + function default_algorithm(::typeof($f), A; kwargs...) + return default_eig_algorithm(A; kwargs...) + end + function default_algorithm(::typeof($f), ::Type{A}; kwargs...) where {A} + return default_eig_algorithm(A; kwargs...) + end end end diff --git a/src/interface/svd.jl b/src/interface/svd.jl index a2c90731..f4d5e8f3 100644 --- a/src/interface/svd.jl +++ b/src/interface/svd.jl @@ -99,8 +99,13 @@ function default_svd_algorithm(::Type{T}; kwargs...) where {T<:YALAPACK.BlasMat} end for f in (:svd_full!, :svd_compact!, :svd_vals!) - @eval function default_algorithm(::typeof($f), ::Type{A}; kwargs...) where {A} - return default_svd_algorithm(A; kwargs...) + @eval begin + function default_algorithm(::typeof($f), A; kwargs...) + return default_svd_algorithm(A; kwargs...) + end + function default_algorithm(::typeof($f), ::Type{A}; kwargs...) where {A} + return default_svd_algorithm(A; kwargs...) + end end end From af7c9aff328877812204def31d3f7aa9501e1c33 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Fri, 30 May 2025 13:27:16 -0400 Subject: [PATCH 03/12] Simplify default_algorithm to default_f_algorithm forwarding --- src/interface/eig.jl | 9 ++------- src/interface/eigh.jl | 9 ++------- src/interface/lq.jl | 9 ++------- src/interface/polar.jl | 9 ++------- src/interface/qr.jl | 9 ++------- src/interface/schur.jl | 9 ++------- src/interface/svd.jl | 9 ++------- 7 files changed, 14 insertions(+), 49 deletions(-) diff --git a/src/interface/eig.jl b/src/interface/eig.jl index 5c495b1a..7b6ed7c8 100644 --- a/src/interface/eig.jl +++ b/src/interface/eig.jl @@ -94,13 +94,8 @@ function default_eig_algorithm(::Type{T}; kwargs...) where {T<:YALAPACK.BlasMat} end for f in (:eig_full!, :eig_vals!) - @eval begin - function default_algorithm(::typeof($f), A; kwargs...) - return default_eig_algorithm(A; kwargs...) - end - function default_algorithm(::typeof($f), ::Type{A}; kwargs...) where {A} - return default_eig_algorithm(A; kwargs...) - end + @eval function default_algorithm(::typeof($f), ::Type{A}; kwargs...) where {A} + return default_eig_algorithm(A; kwargs...) end end diff --git a/src/interface/eigh.jl b/src/interface/eigh.jl index 91318aff..73eec6da 100644 --- a/src/interface/eigh.jl +++ b/src/interface/eigh.jl @@ -95,13 +95,8 @@ function default_eigh_algorithm(::Type{T}; kwargs...) where {T<:YALAPACK.BlasMat end for f in (:eigh_full!, :eigh_vals!) - @eval begin - function default_algorithm(::typeof($f), A; kwargs...) - return default_eigh_algorithm(A; kwargs...) - end - function default_algorithm(::typeof($f), ::Type{A}; kwargs...) where {A} - return default_eigh_algorithm(A; kwargs...) - end + @eval function default_algorithm(::typeof($f), ::Type{A}; kwargs...) where {A} + return default_eigh_algorithm(A; kwargs...) end end diff --git a/src/interface/lq.jl b/src/interface/lq.jl index 65b6af1d..6f1ed12f 100644 --- a/src/interface/lq.jl +++ b/src/interface/lq.jl @@ -77,12 +77,7 @@ function default_lq_algorithm(::Type{T}; kwargs...) where {T<:YALAPACK.BlasMat} end for f in (:lq_full!, :lq_compact!, :lq_null!) - @eval begin - function default_algorithm(::typeof($f), A; kwargs...) - return default_lq_algorithm(A; kwargs...) - end - function default_algorithm(::typeof($f), ::Type{A}; kwargs...) where {A} - return default_lq_algorithm(A; kwargs...) - end + @eval function default_algorithm(::typeof($f), ::Type{A}; kwargs...) where {A} + return default_lq_algorithm(A; kwargs...) end end diff --git a/src/interface/polar.jl b/src/interface/polar.jl index a9c1056d..87346ff2 100644 --- a/src/interface/polar.jl +++ b/src/interface/polar.jl @@ -66,12 +66,7 @@ function default_polar_algorithm(::Type{T}; kwargs...) where {T} end for f in (:left_polar!, :right_polar!) - @eval begin - function default_algorithm(::typeof($f), A; kwargs...) - return default_polar_algorithm(A; kwargs...) - end - function default_algorithm(::typeof($f), ::Type{A}; kwargs...) where {A} - return default_polar_algorithm(A; kwargs...) - end + @eval function default_algorithm(::typeof($f), ::Type{A}; kwargs...) where {A} + return default_polar_algorithm(A; kwargs...) end end diff --git a/src/interface/qr.jl b/src/interface/qr.jl index d1538573..62d87080 100644 --- a/src/interface/qr.jl +++ b/src/interface/qr.jl @@ -77,12 +77,7 @@ function default_qr_algorithm(::Type{T}; kwargs...) where {T<:YALAPACK.BlasMat} end for f in (:qr_full!, :qr_compact!, :qr_null!) - @eval begin - function default_algorithm(::typeof($f), A; kwargs...) - return default_qr_algorithm(A; kwargs...) - end - function default_algorithm(::typeof($f), ::Type{A}; kwargs...) where {A} - return default_qr_algorithm(A; kwargs...) - end + @eval function default_algorithm(::typeof($f), ::Type{A}; kwargs...) where {A} + return default_qr_algorithm(A; kwargs...) end end diff --git a/src/interface/schur.jl b/src/interface/schur.jl index 2e17fc52..19f6dc00 100644 --- a/src/interface/schur.jl +++ b/src/interface/schur.jl @@ -52,12 +52,7 @@ See also [`eig_full(!)`](@ref eig_full) and [`eig_trunc(!)`](@ref eig_trunc). # Algorithm selection # ------------------- for f in (:schur_full!, :schur_vals!) - @eval begin - function default_algorithm(::typeof($f), A; kwargs...) - return default_eig_algorithm(A; kwargs...) - end - function default_algorithm(::typeof($f), ::Type{A}; kwargs...) where {A} - return default_eig_algorithm(A; kwargs...) - end + @eval function default_algorithm(::typeof($f), ::Type{A}; kwargs...) where {A} + return default_eig_algorithm(A; kwargs...) end end diff --git a/src/interface/svd.jl b/src/interface/svd.jl index f4d5e8f3..a2c90731 100644 --- a/src/interface/svd.jl +++ b/src/interface/svd.jl @@ -99,13 +99,8 @@ function default_svd_algorithm(::Type{T}; kwargs...) where {T<:YALAPACK.BlasMat} end for f in (:svd_full!, :svd_compact!, :svd_vals!) - @eval begin - function default_algorithm(::typeof($f), A; kwargs...) - return default_svd_algorithm(A; kwargs...) - end - function default_algorithm(::typeof($f), ::Type{A}; kwargs...) where {A} - return default_svd_algorithm(A; kwargs...) - end + @eval function default_algorithm(::typeof($f), ::Type{A}; kwargs...) where {A} + return default_svd_algorithm(A; kwargs...) end end From 8fccedbd84e4e2c641097c3768df206eaeb13d97 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Fri, 30 May 2025 13:31:08 -0400 Subject: [PATCH 04/12] More simplifications --- src/algorithms.jl | 5 ++--- src/interface/eig.jl | 3 --- src/interface/eigh.jl | 3 --- src/interface/svd.jl | 3 --- 4 files changed, 2 insertions(+), 12 deletions(-) diff --git a/src/algorithms.jl b/src/algorithms.jl index 80309d47..bf0bf32d 100644 --- a/src/algorithms.jl +++ b/src/algorithms.jl @@ -190,11 +190,10 @@ macro functiondef(f) end # define fallbacks for algorithm selection - @inline function select_algorithm(::typeof($f), ::Type{A}, alg::Alg; - kwargs...) where {Alg,A} + @inline function select_algorithm(::typeof($f), A, alg::Alg; kwargs...) where {Alg} return select_algorithm($f!, A, alg; kwargs...) end - @inline function default_algorithm(::typeof($f), ::Type{A}; kwargs...) where {A} + @inline function default_algorithm(::typeof($f), A; kwargs...) return default_algorithm($f!, A; kwargs...) end diff --git a/src/interface/eig.jl b/src/interface/eig.jl index 7b6ed7c8..46621cfb 100644 --- a/src/interface/eig.jl +++ b/src/interface/eig.jl @@ -103,6 +103,3 @@ function select_algorithm(::typeof(eig_trunc!), A, alg; trunc=nothing, kwargs... alg_eig = select_algorithm(eig_full!, A, alg; kwargs...) return TruncatedAlgorithm(alg_eig, select_truncation(trunc)) end -function select_algorithm(::typeof(eig_trunc), A, alg; trunc=nothing, kwargs...) - return select_algorithm(eig_trunc!, A, alg; trunc, kwargs...) -end diff --git a/src/interface/eigh.jl b/src/interface/eigh.jl index 73eec6da..48602687 100644 --- a/src/interface/eigh.jl +++ b/src/interface/eigh.jl @@ -104,6 +104,3 @@ function select_algorithm(::typeof(eigh_trunc!), A, alg; trunc=nothing, kwargs.. alg_eigh = select_algorithm(eigh_full!, A, alg; kwargs...) return TruncatedAlgorithm(alg_eigh, select_truncation(trunc)) end -function select_algorithm(::typeof(eigh_trunc), A, alg; trunc=nothing, kwargs...) - return select_algorithm(eigh_trunc!, A, alg; trunc, kwargs...) -end diff --git a/src/interface/svd.jl b/src/interface/svd.jl index a2c90731..48416944 100644 --- a/src/interface/svd.jl +++ b/src/interface/svd.jl @@ -108,6 +108,3 @@ function select_algorithm(::typeof(svd_trunc!), A, alg; trunc=nothing, kwargs... alg_svd = select_algorithm(svd_compact!, A, alg; kwargs...) return TruncatedAlgorithm(alg_svd, select_truncation(trunc)) end -function select_algorithm(::typeof(svd_trunc), A, alg; trunc=nothing, kwargs...) - return select_algorithm(svd_trunc!, A, alg; trunc, kwargs...) -end From 750c096aeae353e94e5ab1ae33a705bddc782db2 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Fri, 30 May 2025 14:31:06 -0400 Subject: [PATCH 05/12] Fix ambiguity error --- src/algorithms.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/algorithms.jl b/src/algorithms.jl index bf0bf32d..45268ebd 100644 --- a/src/algorithms.jl +++ b/src/algorithms.jl @@ -196,6 +196,10 @@ macro functiondef(f) @inline function default_algorithm(::typeof($f), A; kwargs...) return default_algorithm($f!, A; kwargs...) end + # fix ambiguity error + @inline function default_algorithm(::typeof($f), ::Type{A}; kwargs...) where {A} + return default_algorithm($f!, A; kwargs...) + end # copy documentation to both functions Core.@__doc__ $f, $f! From f003e2f94d6a6346ed4d0be6dcd301d45b26d1f9 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Sat, 31 May 2025 21:20:32 -0400 Subject: [PATCH 06/12] Simplify --- src/algorithms.jl | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/algorithms.jl b/src/algorithms.jl index 45268ebd..30637b35 100644 --- a/src/algorithms.jl +++ b/src/algorithms.jl @@ -193,10 +193,6 @@ macro functiondef(f) @inline function select_algorithm(::typeof($f), A, alg::Alg; kwargs...) where {Alg} return select_algorithm($f!, A, alg; kwargs...) end - @inline function default_algorithm(::typeof($f), A; kwargs...) - return default_algorithm($f!, A; kwargs...) - end - # fix ambiguity error @inline function default_algorithm(::typeof($f), ::Type{A}; kwargs...) where {A} return default_algorithm($f!, A; kwargs...) end From acf36f206a2120b9aa57c68b5286534b8c24c611 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Sun, 1 Jun 2025 09:49:16 -0400 Subject: [PATCH 07/12] Bring back generic fallback --- src/algorithms.jl | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/algorithms.jl b/src/algorithms.jl index 30637b35..6ac0193b 100644 --- a/src/algorithms.jl +++ b/src/algorithms.jl @@ -193,6 +193,15 @@ macro functiondef(f) @inline function select_algorithm(::typeof($f), A, alg::Alg; kwargs...) where {Alg} return select_algorithm($f!, A, alg; kwargs...) end + # define default algorithm fallbacks for out-of-place functions + # in terms of the corresponding in-place function + @inline function default_algorithm(::typeof($f), A; kwargs...) + return default_algorithm($f!, A; kwargs...) + end + # define default algorithm fallbacks for out-of-place functions + # in terms of the corresponding in-place function for types, + # in principle this is covered by the definition above but + # it is necessary to avoid ambiguity errors @inline function default_algorithm(::typeof($f), ::Type{A}; kwargs...) where {A} return default_algorithm($f!, A; kwargs...) end From 0cf1820eb57cff4c2cc9cee3d55e03dbffa476a7 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Tue, 3 Jun 2025 10:35:54 -0400 Subject: [PATCH 08/12] Remove unnecessary specialization --- src/algorithms.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/algorithms.jl b/src/algorithms.jl index 6ac0193b..b2365bd0 100644 --- a/src/algorithms.jl +++ b/src/algorithms.jl @@ -76,7 +76,7 @@ Finally, the same behavior is obtained when the keyword arguments are passed as the third positional argument in the form of a `NamedTuple`. """ select_algorithm -function select_algorithm(f::F, A::T, alg::Alg=nothing; kwargs...) where {F,T,Alg} +function select_algorithm(f::F, A, alg::Alg=nothing; kwargs...) where {F,Alg} if isnothing(alg) return default_algorithm(f, A; kwargs...) elseif alg isa Symbol From ab13437f20a12f510ccd18ff3046bf063d29e74a Mon Sep 17 00:00:00 2001 From: mtfishman Date: Tue, 3 Jun 2025 10:59:14 -0400 Subject: [PATCH 09/12] Better handling of TruncatedAlgorithm in select_algorithm --- src/interface/eig.jl | 10 ++++++++-- src/interface/eigh.jl | 10 ++++++++-- src/interface/svd.jl | 10 ++++++++-- test/algorithms.jl | 6 ++++++ test/eig.jl | 14 ++++++++++++++ test/eigh.jl | 16 ++++++++++++++++ test/svd.jl | 15 +++++++++++++++ 7 files changed, 75 insertions(+), 6 deletions(-) diff --git a/src/interface/eig.jl b/src/interface/eig.jl index 46621cfb..fd75193f 100644 --- a/src/interface/eig.jl +++ b/src/interface/eig.jl @@ -100,6 +100,12 @@ for f in (:eig_full!, :eig_vals!) end function select_algorithm(::typeof(eig_trunc!), A, alg; trunc=nothing, kwargs...) - alg_eig = select_algorithm(eig_full!, A, alg; kwargs...) - return TruncatedAlgorithm(alg_eig, select_truncation(trunc)) + if alg isa TruncatedAlgorithm + isnothing(trunc) || + throw(ArgumentError("`trunc` can't be specified when `alg` is a `TruncatedAlgorithm`")) + return alg + else + alg_eig = select_algorithm(eig_full!, A, alg; kwargs...) + return TruncatedAlgorithm(alg_eig, select_truncation(trunc)) + end end diff --git a/src/interface/eigh.jl b/src/interface/eigh.jl index 48602687..a650ca44 100644 --- a/src/interface/eigh.jl +++ b/src/interface/eigh.jl @@ -101,6 +101,12 @@ for f in (:eigh_full!, :eigh_vals!) end function select_algorithm(::typeof(eigh_trunc!), A, alg; trunc=nothing, kwargs...) - alg_eigh = select_algorithm(eigh_full!, A, alg; kwargs...) - return TruncatedAlgorithm(alg_eigh, select_truncation(trunc)) + if alg isa TruncatedAlgorithm + isnothing(trunc) || + throw(ArgumentError("`trunc` can't be specified when `alg` is a `TruncatedAlgorithm`")) + return alg + else + alg_eig = select_algorithm(eigh_full!, A, alg; kwargs...) + return TruncatedAlgorithm(alg_eig, select_truncation(trunc)) + end end diff --git a/src/interface/svd.jl b/src/interface/svd.jl index 48416944..fd4eb5a5 100644 --- a/src/interface/svd.jl +++ b/src/interface/svd.jl @@ -105,6 +105,12 @@ for f in (:svd_full!, :svd_compact!, :svd_vals!) end function select_algorithm(::typeof(svd_trunc!), A, alg; trunc=nothing, kwargs...) - alg_svd = select_algorithm(svd_compact!, A, alg; kwargs...) - return TruncatedAlgorithm(alg_svd, select_truncation(trunc)) + if alg isa TruncatedAlgorithm + isnothing(trunc) || + throw(ArgumentError("`trunc` can't be specified when `alg` is a `TruncatedAlgorithm`")) + return alg + else + alg_svd = select_algorithm(svd_compact!, A, alg; kwargs...) + return TruncatedAlgorithm(alg_svd, select_truncation(trunc)) + end end diff --git a/test/algorithms.jl b/test/algorithms.jl index 49524a89..ee3865c2 100644 --- a/test/algorithms.jl +++ b/test/algorithms.jl @@ -50,6 +50,12 @@ end NoTruncation()) end + alg = TruncatedAlgorithm(LAPACK_Simple(), TruncationKeepBelow(0.1, 0.0)) + for f in (eig_trunc!, eigh_trunc!, svd_trunc!) + @test @constinferred(select_algorithm(eig_trunc!, A, alg)) === alg + @test_throws ArgumentError select_algorithm(eig_trunc!, A, alg; trunc=(; maxrank=2)) + end + @test @constinferred(select_algorithm(svd_compact!, A)) === LAPACK_DivideAndConquer() @test @constinferred(select_algorithm(svd_compact!, A, nothing)) === LAPACK_DivideAndConquer() diff --git a/test/eig.jl b/test/eig.jl index cd73d94e..be5f9f74 100644 --- a/test/eig.jl +++ b/test/eig.jl @@ -57,3 +57,17 @@ end @test V2 * ((V2' * V2) \ (V2' * V1)) ≈ V1 end end + +@testset "eig_trunc! specify truncation algorithm T = $T" for T in + (Float32, Float64, ComplexF32, + ComplexF64) + rng = StableRNG(123) + m = 4 + V = qr_compact(randn(rng, T, m, m))[1] + D = Diagonal([0.9, 0.3, 0.1, 0.01]) + A = V * D * V' + alg = TruncatedAlgorithm(LAPACK_Simple(), truncrank(2)) + D2, V2 = @constinferred eig_trunc(A; alg) + @test diagview(D2) ≈ diagview(D)[1:2] rtol = sqrt(eps(real(T))) + @test_throws ArgumentError eig_trunc(A; alg, trunc=(; maxrank=2)) +end diff --git a/test/eigh.jl b/test/eigh.jl index 5a3c5a8a..4b63ebd2 100644 --- a/test/eigh.jl +++ b/test/eigh.jl @@ -62,3 +62,19 @@ end @test V2 * (V2' * V1) ≈ V1 end end + +@testset "eigh_trunc! specify truncation algorithm T = $T" for T in + (Float32, Float64, + ComplexF32, + ComplexF64) + rng = StableRNG(123) + m = 4 + V = qr_compact(randn(rng, T, m, m))[1] + D = Diagonal([0.9, 0.3, 0.1, 0.01]) + A = V * D * V' + A = (A + A') / 2 + alg = TruncatedAlgorithm(LAPACK_QRIteration(), truncrank(2)) + D2, V2 = @constinferred eigh_trunc(A; alg) + @test diagview(D2) ≈ diagview(D)[1:2] rtol = sqrt(eps(real(T))) + @test_throws ArgumentError eigh_trunc(A; alg, trunc=(; maxrank=2)) +end diff --git a/test/svd.jl b/test/svd.jl index eb6a7805..95851d44 100644 --- a/test/svd.jl +++ b/test/svd.jl @@ -152,3 +152,18 @@ end end end end + +@testset "svd_trunc! specify truncation algorithm T = $T" for T in + (Float32, Float64, ComplexF32, + ComplexF64) + rng = StableRNG(123) + m = 4 + U = qr_compact(randn(rng, T, m, m))[1] + S = Diagonal([0.9, 0.3, 0.1, 0.01]) + Vᴴ = qr_compact(randn(rng, T, m, m))[1] + A = U * S * Vᴴ + alg = TruncatedAlgorithm(LAPACK_DivideAndConquer(), TruncationKeepAbove(0.2, 0.0)) + U2, S2, V2ᴴ = @constinferred svd_trunc(A; alg) + @test diagview(S2) ≈ diagview(S)[1:2] rtol = sqrt(eps(real(T))) + @test_throws ArgumentError svd_trunc(A; alg, trunc=(; maxrank=2)) +end From 5953909260af789846904d53bd442b5136ef23bf Mon Sep 17 00:00:00 2001 From: mtfishman Date: Tue, 3 Jun 2025 11:01:58 -0400 Subject: [PATCH 10/12] Fix namespace issue --- test/algorithms.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/algorithms.jl b/test/algorithms.jl index ee3865c2..9f9c4542 100644 --- a/test/algorithms.jl +++ b/test/algorithms.jl @@ -2,7 +2,7 @@ using MatrixAlgebraKit using Test using TestExtras using MatrixAlgebraKit: LAPACK_SVDAlgorithm, NoTruncation, PolarViaSVD, TruncatedAlgorithm, - default_algorithm, select_algorithm + TruncationKeepBelow, default_algorithm, select_algorithm @testset "default_algorithm" begin A = randn(3, 3) From 58f7ae0366322a294333d8e845d3eec954ce7a88 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Tue, 3 Jun 2025 11:05:38 -0400 Subject: [PATCH 11/12] Another namespace fix --- test/eig.jl | 2 +- test/eigh.jl | 2 +- test/svd.jl | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/test/eig.jl b/test/eig.jl index be5f9f74..3ff0fb90 100644 --- a/test/eig.jl +++ b/test/eig.jl @@ -3,7 +3,7 @@ using Test using TestExtras using StableRNGs using LinearAlgebra: Diagonal -using MatrixAlgebraKit: diagview +using MatrixAlgebraKit: TruncatedAlgorithm, diagview @testset "eig_full! for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64) rng = StableRNG(123) diff --git a/test/eigh.jl b/test/eigh.jl index 4b63ebd2..6e785158 100644 --- a/test/eigh.jl +++ b/test/eigh.jl @@ -3,7 +3,7 @@ using Test using TestExtras using StableRNGs using LinearAlgebra: LinearAlgebra, Diagonal, I -using MatrixAlgebraKit: diagview +using MatrixAlgebraKit: TruncatedAlgorithm, diagview @testset "eigh_full! for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64) rng = StableRNG(123) diff --git a/test/svd.jl b/test/svd.jl index 95851d44..40de0897 100644 --- a/test/svd.jl +++ b/test/svd.jl @@ -3,7 +3,7 @@ using Test using TestExtras using StableRNGs using LinearAlgebra: LinearAlgebra, Diagonal, I, isposdef -using MatrixAlgebraKit: TruncationKeepAbove, diagview +using MatrixAlgebraKit: TruncatedAlgorithm, TruncationKeepAbove, diagview @testset "svd_compact! for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64) rng = StableRNG(123) From 631e53769665fea75646de92f24347b0f3d2bdcf Mon Sep 17 00:00:00 2001 From: mtfishman Date: Wed, 4 Jun 2025 08:12:00 -0400 Subject: [PATCH 12/12] Clarify comment on default_algorithm ambiguity more, make eig test more general --- src/algorithms.jl | 8 +++++++- test/eig.jl | 4 ++-- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/src/algorithms.jl b/src/algorithms.jl index b2365bd0..af183103 100644 --- a/src/algorithms.jl +++ b/src/algorithms.jl @@ -201,7 +201,13 @@ macro functiondef(f) # define default algorithm fallbacks for out-of-place functions # in terms of the corresponding in-place function for types, # in principle this is covered by the definition above but - # it is necessary to avoid ambiguity errors + # it is necessary to avoid ambiguity errors with the generic definitions: + # ```julia + # default_algorithm(f::F, A; kwargs...) where {F} = default_algorithm(f, typeof(A); kwargs...) + # function default_algorithm(f::F, ::Type{T}; kwargs...) where {F,T} + # throw(MethodError(default_algorithm, (f, T))) + # end + # ``` @inline function default_algorithm(::typeof($f), ::Type{A}; kwargs...) where {A} return default_algorithm($f!, A; kwargs...) end diff --git a/test/eig.jl b/test/eig.jl index 3ff0fb90..cdaec9dc 100644 --- a/test/eig.jl +++ b/test/eig.jl @@ -63,9 +63,9 @@ end ComplexF64) rng = StableRNG(123) m = 4 - V = qr_compact(randn(rng, T, m, m))[1] + V = randn(rng, T, m, m) D = Diagonal([0.9, 0.3, 0.1, 0.01]) - A = V * D * V' + A = V * D * inv(V) alg = TruncatedAlgorithm(LAPACK_Simple(), truncrank(2)) D2, V2 = @constinferred eig_trunc(A; alg) @test diagview(D2) ≈ diagview(D)[1:2] rtol = sqrt(eps(real(T)))