diff --git a/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl b/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl index 8a7c2ef2..c8662254 100644 --- a/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl +++ b/ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl @@ -4,9 +4,9 @@ using MatrixAlgebraKit using MatrixAlgebraKit: @algdef, Algorithm, check_input using MatrixAlgebraKit: one!, zero!, uppertriangular!, lowertriangular! using MatrixAlgebraKit: diagview, sign_safe -using MatrixAlgebraKit: LQViaTransposedQR, TruncationStrategy, NoTruncation, TruncationByValue, AbstractAlgorithm +using MatrixAlgebraKit: ROCSOLVER, LQViaTransposedQR, TruncationStrategy, NoTruncation, TruncationByValue, AbstractAlgorithm using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm, default_eigh_algorithm -import MatrixAlgebraKit: _gpu_geqrf!, _gpu_ungqr!, _gpu_unmqr!, _gpu_gesvd!, _gpu_Xgesvdp!, _gpu_gesvdj! +import MatrixAlgebraKit: geqrf!, ungqr!, unmqr!, _gpu_gesvd!, _gpu_Xgesvdp!, _gpu_gesvdj! import MatrixAlgebraKit: _gpu_heevj!, _gpu_heevd!, _gpu_heev!, _gpu_heevx!, _sylvester, svd_rank using AMDGPU using LinearAlgebra @@ -28,10 +28,10 @@ function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; kwargs...) where {T return ROCSOLVER_DivideAndConquer(; kwargs...) end -_gpu_geqrf!(A::StridedROCMatrix) = YArocSOLVER.geqrf!(A) -_gpu_ungqr!(A::StridedROCMatrix, τ::StridedROCVector) = YArocSOLVER.ungqr!(A, τ) -_gpu_unmqr!(side::AbstractChar, trans::AbstractChar, A::StridedROCMatrix, τ::StridedROCVector, C::StridedROCVecOrMat) = - YArocSOLVER.unmqr!(side, trans, A, τ, C) +for f in (:geqrf!, :ungqr!, :unmqr!) + @eval $f(::ROCSOLVER, args...) = YArocSOLVER.$f(args...) +end + _gpu_gesvd!(A::StridedROCMatrix, S::StridedROCVector, U::StridedROCMatrix, Vᴴ::StridedROCMatrix) = YArocSOLVER.gesvd!(A, S, U, Vᴴ) # not yet supported diff --git a/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl b/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl index fb67149e..5b3f91f8 100644 --- a/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl +++ b/ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl @@ -4,9 +4,9 @@ using MatrixAlgebraKit using MatrixAlgebraKit: @algdef, Algorithm, check_input using MatrixAlgebraKit: one!, zero!, uppertriangular!, lowertriangular! using MatrixAlgebraKit: diagview, sign_safe -using MatrixAlgebraKit: LQViaTransposedQR, TruncationByValue, AbstractAlgorithm +using MatrixAlgebraKit: CUSOLVER, LQViaTransposedQR, TruncationByValue, AbstractAlgorithm using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm, default_eig_algorithm, default_eigh_algorithm -import MatrixAlgebraKit: _gpu_geqrf!, _gpu_ungqr!, _gpu_unmqr!, _gpu_gesvd!, _gpu_Xgesvdp!, _gpu_Xgesvdr!, _gpu_gesvdj!, _gpu_geev! +import MatrixAlgebraKit: geqrf!, ungqr!, unmqr!, _gpu_gesvd!, _gpu_Xgesvdp!, _gpu_Xgesvdr!, _gpu_gesvdj!, _gpu_geev! import MatrixAlgebraKit: _gpu_heevj!, _gpu_heevd!, _sylvester, svd_rank using CUDA, CUDA.CUBLAS using CUDA: i32 @@ -32,6 +32,7 @@ function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; kwargs...) where {TT return CUSOLVER_DivideAndConquer(; kwargs...) end + # include for block sector support function MatrixAlgebraKit.default_qr_algorithm(::Type{Base.ReshapedArray{T, 2, SubArray{T, 1, A, Tuple{UnitRange{Int}}, true}, Tuple{}}}; kwargs...) where {T <: BlasFloat, A <: CuVecOrMat{T}} return CUSOLVER_HouseholderQR(; kwargs...) @@ -50,14 +51,12 @@ function MatrixAlgebraKit.default_eigh_algorithm(::Type{Base.ReshapedArray{T, 2, return CUSOLVER_DivideAndConquer(; kwargs...) end +for f in (:geqrf!, :ungqr!, :unmqr!) + @eval $f(::CUSOLVER, args...) = YACUSOLVER.$f(args...) +end + _gpu_geev!(A::StridedCuMatrix, D::StridedCuVector, V::StridedCuMatrix) = YACUSOLVER.Xgeev!(A, D, V) -_gpu_geqrf!(A::StridedCuMatrix) = - YACUSOLVER.geqrf!(A) -_gpu_ungqr!(A::StridedCuMatrix, τ::StridedCuVector) = - YACUSOLVER.ungqr!(A, τ) -_gpu_unmqr!(side::AbstractChar, trans::AbstractChar, A::StridedCuMatrix, τ::StridedCuVector, C::StridedCuVecOrMat) = - YACUSOLVER.unmqr!(side, trans, A, τ, C) _gpu_gesvd!(A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix, Vᴴ::StridedCuMatrix) = YACUSOLVER.gesvd!(A, S, U, Vᴴ) _gpu_Xgesvdp!(A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix, Vᴴ::StridedCuMatrix; kwargs...) = diff --git a/ext/MatrixAlgebraKitGenericLinearAlgebraExt.jl b/ext/MatrixAlgebraKitGenericLinearAlgebraExt.jl index fd4cec35..e9e1e395 100644 --- a/ext/MatrixAlgebraKitGenericLinearAlgebraExt.jl +++ b/ext/MatrixAlgebraKitGenericLinearAlgebraExt.jl @@ -1,8 +1,7 @@ module MatrixAlgebraKitGenericLinearAlgebraExt using MatrixAlgebraKit -using MatrixAlgebraKit: sign_safe, check_input, diagview, gaugefix!, one!, default_fixgauge -using MatrixAlgebraKit: left_orth_alg +using MatrixAlgebraKit: sign_safe, check_input, diagview, gaugefix!, one!, zero!, default_fixgauge using GenericLinearAlgebra: svd!, svdvals!, eigen!, eigvals!, Hermitian, qr! using LinearAlgebra: I, Diagonal, lmul! @@ -57,38 +56,25 @@ function MatrixAlgebraKit.eigh_vals!(A::AbstractMatrix, D, ::GLA_QRIteration) return eigvals!(Hermitian(A); sortby = real) end -function MatrixAlgebraKit.default_qr_algorithm(::Type{T}; kwargs...) where {T <: StridedMatrix{<:Union{Float16, ComplexF16, BigFloat, Complex{BigFloat}}}} - return GLA_HouseholderQR(; kwargs...) -end - -function MatrixAlgebraKit.qr_full!(A::AbstractMatrix, QR, alg::GLA_HouseholderQR) - check_input(qr_full!, A, QR, alg) - Q, R = QR - return _gla_householder_qr!(A, Q, R; alg.kwargs...) -end - -function MatrixAlgebraKit.qr_compact!(A::AbstractMatrix, QR, alg::GLA_HouseholderQR) - check_input(qr_compact!, A, QR, alg) - Q, R = QR - return _gla_householder_qr!(A, Q, R; alg.kwargs...) -end - -function MatrixAlgebraKit.qr_null!(A::AbstractMatrix, N, alg::GLA_HouseholderQR) - check_input(qr_null!, A, N, alg) - return _gla_householder_qr_null!(A, N; alg.kwargs...) -end - -function _gla_householder_qr!(A::AbstractMatrix, Q, R; positive = true, blocksize = 1, pivoted = false) - pivoted && throw(ArgumentError("Only pivoted = false implemented for GLA_HouseholderQR.")) - (blocksize == 1) || throw(ArgumentError("Only blocksize = 1 implemented for GLA_HouseholderQR.")) +function MatrixAlgebraKit.householder_qr!( + driver::MatrixAlgebraKit.GLA, A::AbstractMatrix, Q::AbstractMatrix, R::AbstractMatrix; + positive::Bool = true, pivoted::Bool = false, blocksize::Int = 1 + ) + blocksize == 1 || + throw(ArgumentError(lazy"$driver does not provide a blocked QR decomposition")) + pivoted && + throw(ArgumentError(lazy"$driver does not provide a pivoted QR decomposition")) m, n = size(A) - k = min(m, n) + minmn = min(m, n) + computeR = length(R) > 0 + + # compute QR Q̃, R̃ = qr!(A) lmul!(Q̃, MatrixAlgebraKit.one!(Q)) if positive - @inbounds for j in 1:k + @inbounds for j in 1:minmn s = sign_safe(R̃[j, j]) @simd for i in 1:m Q[i, j] *= s @@ -96,42 +82,39 @@ function _gla_householder_qr!(A::AbstractMatrix, Q, R; positive = true, blocksiz end end - computeR = length(R) > 0 if computeR if positive @inbounds for j in n:-1:1 - @simd for i in 1:min(k, j) + @simd for i in 1:min(minmn, j) R[i, j] = R̃[i, j] * conj(sign_safe(R̃[i, i])) end - @simd for i in (min(k, j) + 1):size(R, 1) + @simd for i in (min(minmn, j) + 1):size(R, 1) R[i, j] = zero(eltype(R)) end end else - R[1:k, :] .= R̃ - MatrixAlgebraKit.zero!(@view(R[(k + 1):end, :])) + R[1:minmn, :] .= R̃ + MatrixAlgebraKit.zero!(@view(R[(minmn + 1):end, :])) end end return Q, R end -function _gla_householder_qr_null!( - A::AbstractMatrix, N::AbstractMatrix; - positive = true, blocksize = 1, pivoted = false +function MatrixAlgebraKit.householder_qr_null!( + driver::MatrixAlgebraKit.GLA, A::AbstractMatrix, N::AbstractMatrix; + positive::Bool = true, pivoted::Bool = false, blocksize::Int = 1 ) - pivoted && throw(ArgumentError("Only pivoted = false implemented for GLA_HouseholderQR.")) - (blocksize == 1) || throw(ArgumentError("Only blocksize = 1 implemented for GLA_HouseholderQR.")) + blocksize == 1 || + throw(ArgumentError(lazy"$driver does not provide a blocked QR decomposition")) + pivoted && + throw(ArgumentError(lazy"$driver does not provide a pivoted QR decomposition")) + m, n = size(A) minmn = min(m, n) - fill!(N, zero(eltype(N))) + zero!(N) one!(view(N, (minmn + 1):m, 1:(m - minmn))) Q̃, = qr!(A) - lmul!(Q̃, N) - return N -end - -function MatrixAlgebraKit.default_lq_algorithm(::Type{T}; kwargs...) where {T <: StridedMatrix{<:Union{Float16, ComplexF16, BigFloat, Complex{BigFloat}}}} - return MatrixAlgebraKit.LQViaTransposedQR(GLA_HouseholderQR(; kwargs...)) + return lmul!(Q̃, N) end MatrixAlgebraKit.left_orth_alg(alg::GLA_HouseholderQR) = MatrixAlgebraKit.LeftOrthViaQR(alg) diff --git a/src/algorithms.jl b/src/algorithms.jl index e9e7b8e8..8ffc22a5 100644 --- a/src/algorithms.jl +++ b/src/algorithms.jl @@ -143,6 +143,59 @@ If this is not possible, for example when the output size is not known a priori this function may return `nothing`. """ initialize_output + +# Drivers +# ------- +""" + abstract type Driver + +Supertype used for customizing various implementations of the same algorithm. +""" +abstract type Driver end + +""" + DefaultDriver <: Driver + +Select a default driver at runtime, based on the input matrix. +""" +struct DefaultDriver <: Driver end + +""" + LAPACK <: Driver + +Driver to select LAPACK as the implementation strategy. +""" +struct LAPACK <: Driver end + +""" + CUSOLVER <: Driver + +Driver to select CUSOLVER as the implementation strategy. +""" +struct CUSOLVER <: Driver end + +""" + ROCSOLVER <: Driver + +Driver to select ROCSOLVER as the implementation strategy. +""" +struct ROCSOLVER <: Driver end + +""" + GLA <: Driver + +Driver to select GenericLinearAlgebra.jl as the implementation strategy. +""" +struct GLA <: Driver end + +""" + Native <: Driver + +Driver to select a native implementation in MatrixAlgebraKit as the implementation strategy. +""" +struct Native <: Driver end + + # Truncation strategy # ------------------- """ diff --git a/src/common/householder.jl b/src/common/householder.jl index 0aac21bc..01567232 100644 --- a/src/common/householder.jl +++ b/src/common/householder.jl @@ -1,32 +1,32 @@ const IndexRange{T <: Integer} = Base.AbstractRange{T} # Elementary Householder reflection -struct Householder{T, V <: AbstractVector, R <: IndexRange} +struct HouseholderReflection{T, V <: AbstractVector, R <: IndexRange} β::T v::V r::R end -Base.adjoint(H::Householder) = Householder(conj(H.β), H.v, H.r) +Base.adjoint(H::HouseholderReflection) = HouseholderReflection(conj(H.β), H.v, H.r) function householder(x::AbstractVector, r::IndexRange = axes(x, 1), k = first(r)) i = findfirst(==(k), r) i == nothing && error("k = $k should be in the range r = $r") β, v, ν = _householder!(x[r], i) - return Householder(β, v, r), ν + return HouseholderReflection(β, v, r), ν end # Householder reflector h that zeros the elements A[r,col] (except for A[k,col]) upon lmul!(h,A) function householder(A::AbstractMatrix, r::IndexRange, col::Int, k = first(r)) i = findfirst(==(k), r) i == nothing && error("k = $k should be in the range r = $r") β, v, ν = _householder!(A[r, col], i) - return Householder(β, v, r), ν + return HouseholderReflection(β, v, r), ν end # Householder reflector that zeros the elements A[row,r] (except for A[row,k]) upon rmul!(A,h') function householder(A::AbstractMatrix, row::Int, r::IndexRange, k = first(r)) i = findfirst(==(k), r) i == nothing && error("k = $k should be in the range r = $r") β, v, ν = _householder!(conj!(A[row, r]), i) - return Householder(β, v, r), ν + return HouseholderReflection(β, v, r), ν end # generate Householder vector based on vector v, such that applying the reflection @@ -66,7 +66,7 @@ function _householder!(v::AbstractVector{T}, i::Int = 1) where {T} return β, v, ν end -function LinearAlgebra.lmul!(H::Householder, x::AbstractVector) +function LinearAlgebra.lmul!(H::HouseholderReflection, x::AbstractVector) v = H.v r = H.r β = H.β @@ -87,7 +87,7 @@ function LinearAlgebra.lmul!(H::Householder, x::AbstractVector) end return x end -function LinearAlgebra.lmul!(H::Householder, A::AbstractMatrix; cols = axes(A, 2)) +function LinearAlgebra.lmul!(H::HouseholderReflection, A::AbstractMatrix; cols = axes(A, 2)) v = H.v r = H.r β = H.β @@ -110,7 +110,7 @@ function LinearAlgebra.lmul!(H::Householder, A::AbstractMatrix; cols = axes(A, 2 end return A end -function LinearAlgebra.rmul!(A::AbstractMatrix, H::Householder; rows = axes(A, 1)) +function LinearAlgebra.rmul!(A::AbstractMatrix, H::HouseholderReflection; rows = axes(A, 1)) v = H.v r = H.r β = H.β diff --git a/src/implementations/lq.jl b/src/implementations/lq.jl index b11533a8..4e2c1150 100644 --- a/src/implementations/lq.jl +++ b/src/implementations/lq.jl @@ -86,65 +86,48 @@ for f! in (:lq_full!, :lq_compact!) end end -# Implementation -# -------------- -function lq_full!(A::AbstractMatrix, LQ, alg::LAPACK_HouseholderLQ) - check_input(lq_full!, A, LQ, alg) - L, Q = LQ - _lapack_lq!(A, L, Q; alg.kwargs...) - return L, Q -end -function lq_compact!(A::AbstractMatrix, LQ, alg::LAPACK_HouseholderLQ) - check_input(lq_compact!, A, LQ, alg) - L, Q = LQ - _lapack_lq!(A, L, Q; alg.kwargs...) - return L, Q -end -function lq_null!(A::AbstractMatrix, Nᴴ, alg::LAPACK_HouseholderLQ) - check_input(lq_null!, A, Nᴴ, alg) - _lapack_lq_null!(A, Nᴴ; alg.kwargs...) - return Nᴴ -end +# ========================== +# IMPLEMENTATIONS +# ========================== -function lq_full!(A::AbstractMatrix, LQ, alg::LQViaTransposedQR) +# Householder +# ----------- +function lq_full!(A, LQ, alg::Householder) check_input(lq_full!, A, LQ, alg) - L, Q = LQ - lq_via_qr!(A, L, Q, alg.qr_alg) - return L, Q + return householder_lq!(alg.driver, A, LQ...; alg.kwargs...) end -function lq_compact!(A::AbstractMatrix, LQ, alg::LQViaTransposedQR) +function lq_compact!(A, LQ, alg::Householder) check_input(lq_compact!, A, LQ, alg) - L, Q = LQ - lq_via_qr!(A, L, Q, alg.qr_alg) - return L, Q + return householder_lq!(alg.driver, A, LQ...; alg.kwargs...) end -function lq_null!(A::AbstractMatrix, Nᴴ, alg::LQViaTransposedQR) +function lq_null!(A, Nᴴ, alg::Householder) check_input(lq_null!, A, Nᴴ, alg) - lq_null_via_qr!(A, Nᴴ, alg.qr_alg) - return Nᴴ + return householder_lq_null!(alg.driver, A, Nᴴ; alg.kwargs...) end -function lq_full!(A::AbstractMatrix, LQ, alg::DiagonalAlgorithm) - check_input(lq_full!, A, LQ, alg) - L, Q = LQ - _diagonal_lq!(A, L, Q; alg.kwargs...) - return L, Q +householder_lq!(::DefaultDriver, A, L, Q; kwargs...) = + householder_lq!(default_householder_driver(A), A, L, Q; kwargs...) +householder_lq_null!(::DefaultDriver, A, Nᴴ; kwargs...) = + householder_lq_null!(default_householder_driver(A), A, Nᴴ; kwargs...) + +# dispatch helpers +for f in (:gelqt!, :gemlqt!, :gelqf!, :unglq!, :unmlq!) + @eval begin + $f(::LAPACK, args...) = YALAPACK.$f(args...) + end end -function lq_compact!(A::AbstractMatrix, LQ, alg::DiagonalAlgorithm) - check_input(lq_compact!, A, LQ, alg) - L, Q = LQ - _diagonal_lq!(A, L, Q; alg.kwargs...) - return L, Q + +function householder_lq!(driver::Union{CUSOLVER, ROCSOLVER, GLA}, A, L, Q; kwargs...) + qr_alg = driver === GLA() ? GLA_HouseholderQR(; kwargs...) : Householder(driver; kwargs...) + return lq_via_qr!(A, L, Q, qr_alg) end -function lq_null!(A::AbstractMatrix, N, alg::DiagonalAlgorithm) - check_input(lq_null!, A, N, alg) - return _diagonal_lq_null!(A, N; alg.kwargs...) +function householder_lq_null!(driver::Union{CUSOLVER, ROCSOLVER, GLA}, A, Nᴴ; kwargs...) + qr_alg = driver === GLA() ? GLA_HouseholderQR(; kwargs...) : Householder(driver; kwargs...) + return lq_null_via_qr!(A, Nᴴ, qr_alg) end -# LAPACK logic -# ------------ -function _lapack_lq!( - A::AbstractMatrix, L::AbstractMatrix, Q::AbstractMatrix; +function householder_lq!( + driver::LAPACK, A::AbstractMatrix, L::AbstractMatrix, Q::AbstractMatrix; positive = true, pivoted = false, blocksize = ((pivoted || A === Q) ? 1 : YALAPACK.default_qr_blocksize(A)) ) @@ -153,31 +136,29 @@ function _lapack_lq!( computeL = length(L) > 0 inplaceQ = Q === A - if pivoted && (blocksize > 1) - throw(ArgumentError("LAPACK does not provide a blocked implementation for a pivoted LQ decomposition")) - end - if inplaceQ && (computeL || positive || blocksize > 1 || n < m) - throw(ArgumentError("inplace Q only supported if matrix is wide (`m <= n`), L is not required, and using the unblocked algorithm (`blocksize=1`) with `positive=false`")) - end + pivoted && (blocksize > 1) && + throw(ArgumentError(lazy"$driver does not provide a blocked pivoted LQ decomposition")) + (inplaceQ && (computeL || positive || blocksize > 1 || n < m)) && + throw(ArgumentError("inplace Q only supported if matrix is wide (`m <= n`), L is not required, and using the unblocked algorithm (`blocksize = 1`) with `positive = false`")) if blocksize > 1 mb = min(minmn, blocksize) if computeL # first use L as space for T - A, T = YALAPACK.gelqt!(A, view(L, 1:mb, 1:minmn)) + A, T = gelqt!(driver, A, view(L, 1:mb, 1:minmn)) else - A, T = YALAPACK.gelqt!(A, similar(A, mb, minmn)) + A, T = gelqt!(driver, A, similar(A, mb, minmn)) end - Q = YALAPACK.gemlqt!('R', 'N', A, T, one!(Q)) + Q = gemlqt!(driver, 'R', 'N', A, T, one!(Q)) else - A, τ = YALAPACK.gelqf!(A) + A, τ = gelqf!(driver, A) if inplaceQ - Q = YALAPACK.unglq!(A, τ) + Q = unglq!(driver, A, τ) else - Q = YALAPACK.unmlq!('R', 'N', A, τ, one!(Q)) + Q = unmlq!(driver, 'R', 'N', A, τ, one!(Q)) end end - if positive # already fix Q even if we do not need R + if positive # already fix Q even if we do not need L @inbounds for j in 1:n @simd for i in 1:minmn s = sign_safe(A[i, i]) @@ -200,28 +181,102 @@ function _lapack_lq!( end return L, Q end - -function _lapack_lq_null!( - A::AbstractMatrix, Nᴴ::AbstractMatrix; +function householder_lq_null!( + driver::LAPACK, A::AbstractMatrix, Nᴴ::AbstractMatrix; positive = true, pivoted = false, blocksize = YALAPACK.default_qr_blocksize(A) ) m, n = size(A) minmn = min(m, n) - fill!(Nᴴ, zero(eltype(Nᴴ))) + zero!(Nᴴ) one!(view(Nᴴ, 1:(n - minmn), (minmn + 1):n)) if blocksize > 1 mb = min(minmn, blocksize) - A, T = YALAPACK.gelqt!(A, similar(A, mb, minmn)) - Nᴴ = YALAPACK.gemlqt!('R', 'N', A, T, Nᴴ) + A, T = gelqt!(driver, A, similar(A, mb, minmn)) + Nᴴ = gemlqt!(driver, 'R', 'N', A, T, Nᴴ) else - A, τ = YALAPACK.gelqf!(A) - Nᴴ = YALAPACK.unmlq!('R', 'N', A, τ, Nᴴ) + A, τ = gelqf!(driver, A) + Nᴴ = unmlq!(driver, 'R', 'N', A, τ, Nᴴ) + end + return Nᴴ +end +function householder_lq!( + ::Native, A::AbstractMatrix, L::AbstractMatrix, Q::AbstractMatrix; + positive::Bool = true # always true regardless of setting + ) + m, n = size(A) + minmn = min(m, n) + @inbounds for i in 1:minmn + for j in 1:(i - 1) + L[i, j] = A[i, j] + end + β, v, L[i, i] = _householder!(conj!(view(A, i, i:n)), 1) + for j in (i + 1):size(L, 2) + L[i, j] = 0 + end + H = HouseholderReflection(conj(β), v, i:n) + rmul!(A, H; rows = (i + 1):m) + # A[i, i] == 1; store β instead + A[i, i] = β + end + # copy remaining rows for m > n + @inbounds for j in 1:size(L, 2) + for i in (minmn + 1):m + L[i, j] = A[i, j] + end + end + # build Q + one!(Q) + @inbounds for i in minmn:-1:1 + β = A[i, i] + A[i, i] = 1 + Hᴴ = HouseholderReflection(β, view(A, i, i:n), i:n) + rmul!(Q, Hᴴ) + end + return L, Q +end +function householder_lq_null!(::Native, A::AbstractMatrix, Nᴴ::AbstractMatrix; positive::Bool = true) + m, n = size(A) + minmn = min(m, n) + @inbounds for i in 1:minmn + β, v, ν = _householder!(conj!(view(A, i, i:n)), 1) + H = HouseholderReflection(conj(β), v, i:n) + rmul!(A, H; rows = (i + 1):m) + # A[i, i] == 1; store β instead + A[i, i] = β + end + # build Nᴴ + zero!(Nᴴ) + one!(view(Nᴴ, 1:(n - minmn), (minmn + 1):n)) + @inbounds for i in minmn:-1:1 + β = A[i, i] + A[i, i] = 1 + Hᴴ = HouseholderReflection(β, view(A, i, i:n), i:n) + rmul!(Nᴴ, Hᴴ) end return Nᴴ end + # LQ via transposition and QR # --------------------------- +function lq_full!(A::AbstractMatrix, LQ, alg::LQViaTransposedQR) + check_input(lq_full!, A, LQ, alg) + L, Q = LQ + lq_via_qr!(A, L, Q, alg.qr_alg) + return L, Q +end +function lq_compact!(A::AbstractMatrix, LQ, alg::LQViaTransposedQR) + check_input(lq_compact!, A, LQ, alg) + L, Q = LQ + lq_via_qr!(A, L, Q, alg.qr_alg) + return L, Q +end +function lq_null!(A::AbstractMatrix, Nᴴ, alg::LQViaTransposedQR) + check_input(lq_null!, A, Nᴴ, alg) + lq_null_via_qr!(A, Nᴴ, alg.qr_alg) + return Nᴴ +end + function lq_via_qr!( A::AbstractMatrix, L::AbstractMatrix, Q::AbstractMatrix, qr_alg::AbstractAlgorithm ) @@ -250,8 +305,26 @@ function lq_null_via_qr!(A::AbstractMatrix, N::AbstractMatrix, qr_alg::AbstractA return N end -# Diagonal logic -# -------------- + +# Diagonal +# -------- +function lq_full!(A::AbstractMatrix, LQ, alg::DiagonalAlgorithm) + check_input(lq_full!, A, LQ, alg) + L, Q = LQ + _diagonal_lq!(A, L, Q; alg.kwargs...) + return L, Q +end +function lq_compact!(A::AbstractMatrix, LQ, alg::DiagonalAlgorithm) + check_input(lq_compact!, A, LQ, alg) + L, Q = LQ + _diagonal_lq!(A, L, Q; alg.kwargs...) + return L, Q +end +function lq_null!(A::AbstractMatrix, N, alg::DiagonalAlgorithm) + check_input(lq_null!, A, N, alg) + return _diagonal_lq_null!(A, N; alg.kwargs...) +end + function _diagonal_lq!( A::AbstractMatrix, L::AbstractMatrix, Q::AbstractMatrix; positive::Bool = true ) @@ -271,84 +344,22 @@ end _diagonal_lq_null!(A::AbstractMatrix, N; positive::Bool = true) = N -# Native logic -# ------------- -function lq_full!(A::AbstractMatrix, LQ, alg::Native_HouseholderLQ) - check_input(lq_full!, A, LQ, alg) - L, Q = LQ - A === Q && - throw(ArgumentError("inplace Q not supported with native LQ implementation")) - _native_lq!(A, L, Q; alg.kwargs...) - return L, Q -end -function lq_compact!(A::AbstractMatrix, LQ, alg::Native_HouseholderLQ) - check_input(lq_compact!, A, LQ, alg) - L, Q = LQ - A === Q && - throw(ArgumentError("inplace Q not supported with native LQ implementation")) - _native_lq!(A, L, Q; alg.kwargs...) - return L, Q -end -function lq_null!(A::AbstractMatrix, N, alg::Native_HouseholderLQ) - check_input(lq_null!, A, N, alg) - _native_lq_null!(A, N; alg.kwargs...) - return N -end - -function _native_lq!( - A::AbstractMatrix, L::AbstractMatrix, Q::AbstractMatrix; - positive::Bool = true # always true regardless of setting - ) - m, n = size(A) - minmn = min(m, n) - @inbounds for i in 1:minmn - for j in 1:(i - 1) - L[i, j] = A[i, j] - end - β, v, L[i, i] = _householder!(conj!(view(A, i, i:n)), 1) - for j in (i + 1):size(L, 2) - L[i, j] = 0 - end - H = Householder(conj(β), v, i:n) - rmul!(A, H; rows = (i + 1):m) - # A[i, i] == 1; store β instead - A[i, i] = β - end - # copy remaining rows for m > n - @inbounds for j in 1:size(L, 2) - for i in (minmn + 1):m - L[i, j] = A[i, j] - end - end - # build Q - one!(Q) - @inbounds for i in minmn:-1:1 - β = A[i, i] - A[i, i] = 1 - Hᴴ = Householder(β, view(A, i, i:n), i:n) - rmul!(Q, Hᴴ) - end - return L, Q -end - -function _native_lq_null!(A::AbstractMatrix, Nᴴ::AbstractMatrix; positive::Bool = true) - m, n = size(A) - minmn = min(m, n) - @inbounds for i in 1:minmn - β, v, ν = _householder!(conj!(view(A, i, i:n)), 1) - H = Householder(conj(β), v, i:n) - rmul!(A, H; rows = (i + 1):m) - # A[i, i] == 1; store β instead - A[i, i] = β - end - # build Nᴴ - fill!(Nᴴ, zero(eltype(Nᴴ))) - one!(view(Nᴴ, 1:(n - minmn), (minmn + 1):n)) - @inbounds for i in minmn:-1:1 - β = A[i, i] - A[i, i] = 1 - Hᴴ = Householder(β, view(A, i, i:n), i:n) - rmul!(Nᴴ, Hᴴ) +# Deprecations +# ------------ +for drivertype in (:LAPACK, :Native) + algtype = Symbol(drivertype, :_HouseholderLQ) + @eval begin + Base.@deprecate( + lq_full!(A, LQ, alg::$algtype), + lq_full!(A, LQ, Householder($drivertype(), alg.kwargs)) + ) + Base.@deprecate( + lq_compact!(A, LQ, alg::$algtype), + lq_compact!(A, LQ, Householder($drivertype(), alg.kwargs)) + ) + Base.@deprecate( + lq_null!(A, Nᴴ, alg::$algtype), + lq_null!(A, Nᴴ, Householder($drivertype(), alg.kwargs)) + ) end - return Nᴴ end diff --git a/src/implementations/qr.jl b/src/implementations/qr.jl index fd7ce01b..ac60cd96 100644 --- a/src/implementations/qr.jl +++ b/src/implementations/qr.jl @@ -86,101 +86,109 @@ for f! in (:qr_full!, :qr_compact!) end end -# Implementation -# -------------- -# actual implementation -function qr_full!(A::AbstractMatrix, QR, alg::LAPACK_HouseholderQR) +# ========================== +# IMPLEMENTATIONS +# ========================== + +# Householder +# ----------- +function qr_full!(A, QR, alg::Householder) check_input(qr_full!, A, QR, alg) - Q, R = QR - _lapack_qr!(A, Q, R; alg.kwargs...) - return Q, R + return householder_qr!(alg.driver, A, QR...; alg.kwargs...) end -function qr_compact!(A::AbstractMatrix, QR, alg::LAPACK_HouseholderQR) +function qr_compact!(A, QR, alg::Householder) check_input(qr_compact!, A, QR, alg) - Q, R = QR - _lapack_qr!(A, Q, R; alg.kwargs...) - return Q, R + return householder_qr!(alg.driver, A, QR...; alg.kwargs...) end -function qr_null!(A::AbstractMatrix, N, alg::LAPACK_HouseholderQR) +function qr_null!(A, N, alg::Householder) check_input(qr_null!, A, N, alg) - _lapack_qr_null!(A, N; alg.kwargs...) - return N + return householder_qr_null!(alg.driver, A, N; alg.kwargs...) end -function qr_full!(A::AbstractMatrix, QR, alg::DiagonalAlgorithm) - check_input(qr_full!, A, QR, alg) - Q, R = QR - _diagonal_qr!(A, Q, R; alg.kwargs...) - return Q, R -end -function qr_compact!(A::AbstractMatrix, QR, alg::DiagonalAlgorithm) - check_input(qr_compact!, A, QR, alg) - Q, R = QR - _diagonal_qr!(A, Q, R; alg.kwargs...) - return Q, R -end -function qr_null!(A::AbstractMatrix, N, alg::DiagonalAlgorithm) - check_input(qr_null!, A, N, alg) - _diagonal_qr_null!(A, N; alg.kwargs...) - return N +householder_qr!(::DefaultDriver, A, Q, R; kwargs...) = + householder_qr!(default_householder_driver(A), A, Q, R; kwargs...) +householder_qr_null!(::DefaultDriver, A, N; kwargs...) = + householder_qr_null!(default_householder_driver(A), A, N; kwargs...) + +# dispatch helpers +for f in (:geqrt!, :gemqrt!, :geqp3!, :geqrf!, :ungqr!, :unmqr!) + @eval begin + $f(driver::Driver, args...) = throw(MethodError($f, (driver, args...))) # make JET not complain + $f(::LAPACK, args...) = YALAPACK.$f(args...) + end end -# LAPACK logic -# ------------ -function _lapack_qr!( - A::AbstractMatrix, Q::AbstractMatrix, R::AbstractMatrix; - positive = true, pivoted = false, - blocksize = ((pivoted || A === Q) ? 1 : YALAPACK.default_qr_blocksize(A)) +function householder_qr!( + driver::Union{LAPACK, CUSOLVER, ROCSOLVER}, A::AbstractMatrix, Q::AbstractMatrix, R::AbstractMatrix; + positive::Bool = true, pivoted::Bool = false, + blocksize::Int = ((driver !== LAPACK() || pivoted || A === Q) ? 1 : YALAPACK.default_qr_blocksize(A)) ) + # error messages for disallowing driver - setting combinations + (blocksize == 1 || driver === LAPACK()) || + throw(ArgumentError(lazy"$driver does not provide a blocked QR decomposition")) + (!pivoted || driver === LAPACK()) || + throw(ArgumentError(lazy"$driver does not provide a pivoted QR decomposition")) + pivoted && (blocksize > 1) && + throw(ArgumentError(lazy"$driver does not provide a blocked pivoted QR decomposition")) + m, n = size(A) minmn = min(m, n) computeR = length(R) > 0 inplaceQ = Q === A - if pivoted && (blocksize > 1) - throw(ArgumentError("LAPACK does not provide a blocked implementation for a pivoted QR decomposition")) - end - if inplaceQ && (computeR || positive || blocksize > 1 || m < n) - throw(ArgumentError("inplace Q only supported if matrix is tall (`m >= n`), R is not required, and using the unblocked algorithm (`blocksize=1`) with `positive=false`")) - end + (inplaceQ && (computeR || positive || blocksize > 1 || m < n)) && + throw(ArgumentError("inplace Q only supported if matrix is tall (`m >= n`), R is not required, and using the unblocked algorithm (`blocksize = 1`) with `positive = false`")) + # Compute QR in packed form if blocksize > 1 nb = min(minmn, blocksize) if computeR # first use R as space for T - A, T = YALAPACK.geqrt!(A, view(R, 1:nb, 1:minmn)) + A, T = geqrt!(driver, A, view(R, 1:nb, 1:minmn)) else - A, T = YALAPACK.geqrt!(A, similar(A, nb, minmn)) + A, T = geqrt!(driver, A, similar(A, nb, minmn)) end - Q = YALAPACK.gemqrt!('L', 'N', A, T, one!(Q)) + Q = gemqrt!(driver, 'L', 'N', A, T, one!(Q)) else if pivoted - A, τ, jpvt = YALAPACK.geqp3!(A) + A, τ, jpvt = geqp3!(driver, A) else - A, τ = YALAPACK.geqrf!(A) + A, τ = geqrf!(driver, A) end if inplaceQ - Q = YALAPACK.ungqr!(A, τ) + Q = ungqr!(driver, A, τ) else - Q = YALAPACK.unmqr!('L', 'N', A, τ, one!(Q)) + Q = unmqr!(driver, 'L', 'N', A, τ, one!(Q)) end end if positive # already fix Q even if we do not need R - @inbounds for j in 1:minmn - s = sign_safe(A[j, j]) - @simd for i in 1:m - Q[i, j] *= s + if driver === LAPACK() + @inbounds for j in 1:minmn + s = sign_safe(A[j, j]) + @simd for i in 1:m + Q[i, j] *= s + end end + else + # guaranteed τ exists and no longer needed + τ .= sign_safe.(diagview(A)) + Qf = view(Q, 1:m, 1:minmn) # first minmn columns of Q + Qf .= Qf .* transpose(τ) end end if computeR R̃ = uppertriangular!(view(A, axes(R)...)) if positive - @inbounds for j in n:-1:1 - @simd for i in 1:min(minmn, j) - R̃[i, j] = R̃[i, j] * conj(sign_safe(R̃[i, i])) + if driver === LAPACK() + @inbounds for j in n:-1:1 + @simd for i in 1:min(minmn, j) + R̃[i, j] = R̃[i, j] * conj(sign_safe(R̃[i, i])) + end end + else + R̃f = view(R̃, 1:minmn, 1:n) # first minmn rows of R + R̃f .= conj.(τ) .* R̃f end end if !pivoted @@ -192,167 +200,17 @@ function _lapack_qr!( end return Q, R end - -function _lapack_qr_null!( - A::AbstractMatrix, N::AbstractMatrix; - positive = true, pivoted = false, blocksize = YALAPACK.default_qr_blocksize(A) - ) - m, n = size(A) - minmn = min(m, n) - fill!(N, zero(eltype(N))) - one!(view(N, (minmn + 1):m, 1:(m - minmn))) - if blocksize > 1 - nb = min(minmn, blocksize) - A, T = YALAPACK.geqrt!(A, similar(A, nb, minmn)) - N = YALAPACK.gemqrt!('L', 'N', A, T, N) - else - A, τ = YALAPACK.geqrf!(A) - N = YALAPACK.unmqr!('L', 'N', A, τ, N) - end - return N -end - -# Diagonal logic -# -------------- -function _diagonal_qr!( - A::AbstractMatrix, Q::AbstractMatrix, R::AbstractMatrix; positive::Bool = true +function householder_qr!( + driver::Native, A::AbstractMatrix, Q::AbstractMatrix, R::AbstractMatrix; + positive::Bool = true, pivoted::Bool = false, blocksize::Int = 1 ) - # note: Ad and Qd might share memory here so order of operations is important - Ad = diagview(A) - Qd = diagview(Q) - Rd = diagview(R) - if positive - @. Rd = abs(Ad) - @. Qd = sign_safe(Ad) - else - Rd .= Ad - one!(Q) - end - return Q, R -end - -_diagonal_qr_null!(A::AbstractMatrix, N; positive::Bool = true) = N - -# GPU logic -# -------------- -# placed here to avoid code duplication since much of the logic is replicable across CUDA and AMDGPU -function MatrixAlgebraKit.qr_full!( - A::AbstractMatrix, QR, alg::Union{CUSOLVER_HouseholderQR, ROCSOLVER_HouseholderQR} - ) - check_input(qr_full!, A, QR, alg) - Q, R = QR - _gpu_qr!(A, Q, R; alg.kwargs...) - return Q, R -end -function MatrixAlgebraKit.qr_compact!( - A::AbstractMatrix, QR, alg::Union{CUSOLVER_HouseholderQR, ROCSOLVER_HouseholderQR} - ) - check_input(qr_compact!, A, QR, alg) - Q, R = QR - _gpu_qr!(A, Q, R; alg.kwargs...) - return Q, R -end -function MatrixAlgebraKit.qr_null!( - A::AbstractMatrix, N, alg::Union{CUSOLVER_HouseholderQR, ROCSOLVER_HouseholderQR} - ) - check_input(qr_null!, A, N, alg) - _gpu_qr_null!(A, N; alg.kwargs...) - return N -end - -_gpu_geqrf!(A::AbstractMatrix) = throw(MethodError(_gpu_geqrf!, (A,))) -_gpu_ungqr!(A::AbstractMatrix, τ::AbstractVector) = throw(MethodError(_gpu_ungqr!, (A, τ))) -function _gpu_unmqr!( - side::AbstractChar, trans::AbstractChar, A::AbstractMatrix, τ::AbstractVector, C - ) - throw(MethodError(_gpu_unmqr!, (side, trans, A, τ, C))) -end - -function _gpu_qr!( - A::AbstractMatrix, Q::AbstractMatrix, R::AbstractMatrix; positive = true, blocksize = 1, pivoted = false - ) - blocksize > 1 && - throw(ArgumentError("CUSOLVER/ROCSOLVER does not provide a blocked implementation for a QR decomposition")) + # error messages for disallowing driver - setting combinations + blocksize == 1 || + throw(ArgumentError(lazy"$driver does not provide a blocked QR decomposition")) pivoted && - throw(ArgumentError("CUSOLVER/ROCSOLVER does not provide a pivoted implementation for a QR decomposition")) - m, n = size(A) - minmn = min(m, n) - computeR = length(R) > 0 - inplaceQ = Q === A - if inplaceQ && (computeR || positive || m < n) - throw(ArgumentError("inplace Q only supported if matrix is tall (`m >= n`), R is not required and using `positive=false`")) - end - - A, τ = _gpu_geqrf!(A) - if inplaceQ - Q = _gpu_ungqr!(A, τ) - else - Q = _gpu_unmqr!('L', 'N', A, τ, one!(Q)) - end - # henceforth, τ is no longer needed and can be reused - - if positive # already fix Q even if we do not need R - # TODO: report that `lmul!` and `rmul!` with `Diagonal` don't work with CUDA - τ .= sign_safe.(diagview(A)) - Qf = view(Q, 1:m, 1:minmn) # first minmn columns of Q - Qf .= Qf .* transpose(τ) - end + throw(ArgumentError(lazy"$driver does not provide a pivoted QR decomposition")) + # positive = true regardless of setting - if computeR - R̃ = uppertriangular!(view(A, axes(R)...)) - if positive - R̃f = view(R̃, 1:minmn, 1:n) # first minmn rows of R - R̃f .= conj.(τ) .* R̃f - end - copyto!(R, R̃) - end - return Q, R -end - -function _gpu_qr_null!( - A::AbstractMatrix, N::AbstractMatrix; positive = true, blocksize = 1, pivoted = false - ) - blocksize > 1 && - throw(ArgumentError("CUSOLVER/ROCSOLVER does not provide a blocked implementation for a QR decomposition")) - pivoted && - throw(ArgumentError("CUSOLVER/ROCSOLVER does not provide a pivoted implementation for a QR decomposition")) - m, n = size(A) - minmn = min(m, n) - fill!(N, zero(eltype(N))) - one!(view(N, (minmn + 1):m, 1:(m - minmn))) - A, τ = _gpu_geqrf!(A) - N = _gpu_unmqr!('L', 'N', A, τ, N) - return N -end - -# Native logic -# -------------- -function qr_full!(A::AbstractMatrix, QR, alg::Native_HouseholderQR) - check_input(qr_full!, A, QR, alg) - Q, R = QR - A === Q && - throw(ArgumentError("inplace Q not supported with native QR implementation")) - _native_qr!(A, Q, R; alg.kwargs...) - return Q, R -end -function qr_compact!(A::AbstractMatrix, QR, alg::Native_HouseholderQR) - check_input(qr_compact!, A, QR, alg) - Q, R = QR - A === Q && - throw(ArgumentError("inplace Q not supported with native QR implementation")) - _native_qr!(A, Q, R; alg.kwargs...) - return Q, R -end -function qr_null!(A::AbstractMatrix, N, alg::Native_HouseholderQR) - check_input(qr_null!, A, N, alg) - _native_qr_null!(A, N; alg.kwargs...) - return N -end - -function _native_qr!( - A::AbstractMatrix, Q::AbstractMatrix, R::AbstractMatrix; - positive::Bool = true # always true regardless of setting - ) m, n = size(A) minmn = min(m, n) @inbounds for j in 1:minmn @@ -363,7 +221,7 @@ function _native_qr!( for i in (j + 1):size(R, 1) R[i, j] = 0 end - H = Householder(β, v, j:m) + H = HouseholderReflection(β, v, j:m) lmul!(H, A; cols = (j + 1):n) # A[j,j] == 1; store β instead A[j, j] = β @@ -379,30 +237,129 @@ function _native_qr!( @inbounds for j in minmn:-1:1 β = A[j, j] A[j, j] = 1 - Hᴴ = Householder(conj(β), view(A, j:m, j), j:m) + Hᴴ = HouseholderReflection(conj(β), view(A, j:m, j), j:m) lmul!(Hᴴ, Q) end return Q, R end -function _native_qr_null!(A::AbstractMatrix, N::AbstractMatrix; positive::Bool = true) +function householder_qr_null!( + driver::Union{LAPACK, CUSOLVER, ROCSOLVER}, A::AbstractMatrix, N::AbstractMatrix; + positive::Bool = true, pivoted::Bool = false, + blocksize::Int = ((driver !== LAPACK() || pivoted) ? 1 : YALAPACK.default_qr_blocksize(A)) + ) + # error messages for disallowing driver - setting combinations + (blocksize == 1 || driver === LAPACK()) || + throw(ArgumentError(lazy"$driver does not provide a blocked QR decomposition")) + (!pivoted || driver === LAPACK()) || + throw(ArgumentError(lazy"$driver does not provide a pivoted QR decomposition")) + pivoted && (blocksize > 1) && + throw(ArgumentError(lazy"$driver does not provide a blocked pivoted QR decomposition")) + + m, n = size(A) + minmn = min(m, n) + zero!(N) + one!(view(N, (minmn + 1):m, 1:(m - minmn))) + + if blocksize > 1 + nb = min(minmn, blocksize) + A, T = geqrt!(driver, A, similar(A, nb, minmn)) + N = gemqrt!(driver, 'L', 'N', A, T, N) + else + A, τ = geqrf!(driver, A) + N = unmqr!(driver, 'L', 'N', A, τ, N) + end + return N +end +function householder_qr_null!( + driver::Native, A::AbstractMatrix, N::AbstractMatrix; + positive::Bool = true, pivoted::Bool = false, blocksize::Int = 1 + ) + # error messages for disallowing driver - setting combinations + blocksize == 1 || + throw(ArgumentError(lazy"$driver does not provide a blocked QR decomposition")) + pivoted && + throw(ArgumentError(lazy"$driver does not provide a pivoted QR decomposition")) + m, n = size(A) minmn = min(m, n) + @inbounds for j in 1:minmn β, v, ν = _householder!(view(A, j:m, j), 1) - H = Householder(β, v, j:m) + H = HouseholderReflection(β, v, j:m) lmul!(H, A; cols = (j + 1):n) - # A[j,j] == 1; store β instead + # A[j, j] == 1; store β instead A[j, j] = β end + # build N - fill!(N, zero(eltype(N))) + zero!(N) one!(view(N, (minmn + 1):m, 1:(m - minmn))) @inbounds for j in minmn:-1:1 β = A[j, j] A[j, j] = 1 - Hᴴ = Householder(conj(β), view(A, j:m, j), j:m) + Hᴴ = HouseholderReflection(conj(β), view(A, j:m, j), j:m) lmul!(Hᴴ, N) end return N end + + +# Diagonal +# -------- +function qr_full!(A, QR, alg::DiagonalAlgorithm) + check_input(qr_full!, A, QR, alg) + Q, R = QR + _diagonal_qr!(A, Q, R; alg.kwargs...) + return Q, R +end +function qr_compact!(A, QR, alg::DiagonalAlgorithm) + check_input(qr_compact!, A, QR, alg) + Q, R = QR + _diagonal_qr!(A, Q, R; alg.kwargs...) + return Q, R +end +function qr_null!(A, N, alg::DiagonalAlgorithm) + check_input(qr_null!, A, N, alg) + _diagonal_qr_null!(A, N; alg.kwargs...) + return N +end + +function _diagonal_qr!( + A::AbstractMatrix, Q::AbstractMatrix, R::AbstractMatrix; positive::Bool = true + ) + # note: Ad and Qd might share memory here so order of operations is important + Ad = diagview(A) + Qd = diagview(Q) + Rd = diagview(R) + if positive + @. Rd = abs(Ad) + @. Qd = sign_safe(Ad) + else + Rd .= Ad + one!(Q) + end + return Q, R +end + +_diagonal_qr_null!(A::AbstractMatrix, N; positive::Bool = true) = N + +# Deprecations +# ------------ +for drivertype in (:LAPACK, :CUSOLVER, :ROCSOLVER, :Native, :GLA) + algtype = Symbol(drivertype, :_HouseholderQR) + @eval begin + Base.@deprecate( + qr_full!(A, QR, alg::$algtype), + qr_full!(A, QR, Householder($drivertype(), alg.kwargs)) + ) + Base.@deprecate( + qr_compact!(A, QR, alg::$algtype), + qr_compact!(A, QR, Householder($drivertype(), alg.kwargs)) + ) + Base.@deprecate( + qr_null!(A, N, alg::$algtype), + qr_null!(A, N, Householder($drivertype(), alg.kwargs)) + ) + end +end diff --git a/src/interface/decompositions.jl b/src/interface/decompositions.jl index 19f20e8e..f7026a09 100644 --- a/src/interface/decompositions.jl +++ b/src/interface/decompositions.jl @@ -64,6 +64,29 @@ of `R` are non-negative. @algdef LAPACK_HouseholderQL @algdef LAPACK_HouseholderRQ +""" + Householder(; [driver], kwargs...) + +Algorithm type to denote the algorithm for computing QR, RQ, QL or LQ decompositions of a matrix using Householder reflectors. +The optional `driver` symbol can be used to choose between different implementations of this algorithm. + +### Keyword arguments + +- `positive::Bool = true` : Fix the gauge of the resulting factors by making the diagonal elements of `L` or `R` non-negative. +- `pivoted::Bool = false` : Use column- or row-pivoting for low-rank input matrices. +- `blocksize::Int` : Use a blocked version of the algorithm if `blocksize > 1`. + +Depending on the driver, various other keywords may be (un)available to customize the implementation. +""" +struct Householder{D <: Driver, KW} <: AbstractAlgorithm + driver::D + kwargs::KW +end +Householder(driver::Driver = DefaultDriver(); kwargs...) = Householder(driver, kwargs) + +default_householder_driver(A) = Native() +default_householder_driver(::YALAPACK.MaybeBlasMat) = LAPACK() + # General Eigenvalue Decomposition # ------------------------------- """ @@ -388,8 +411,8 @@ const GPU_SVDAlgorithm = Union{CUSOLVER_SVDAlgorithm, ROCSOLVER_SVDAlgorithm} const GPU_SVDPolar = Union{CUSOLVER_SVDPolar} const GPU_Randomized = Union{CUSOLVER_Randomized} -const QRAlgorithms = Union{LAPACK_HouseholderQR, CUSOLVER_HouseholderQR, ROCSOLVER_HouseholderQR} -const LQAlgorithms = Union{LAPACK_HouseholderLQ, LQViaTransposedQR} +const QRAlgorithms = Union{Householder, LAPACK_HouseholderQR, Native_HouseholderQR, CUSOLVER_HouseholderQR, ROCSOLVER_HouseholderQR} +const LQAlgorithms = Union{Householder, LAPACK_HouseholderLQ, Native_HouseholderLQ, LQViaTransposedQR} const SVDAlgorithms = Union{LAPACK_SVDAlgorithm, GPU_SVDAlgorithm} const PolarAlgorithms = Union{PolarViaSVD, PolarNewton} diff --git a/src/interface/lq.jl b/src/interface/lq.jl index 8254c826..12b6d814 100644 --- a/src/interface/lq.jl +++ b/src/interface/lq.jl @@ -69,18 +69,13 @@ See also [`qr_full(!)`](@ref lq_full) and [`qr_compact(!)`](@ref lq_compact). # Algorithm selection # ------------------- default_lq_algorithm(A; kwargs...) = default_lq_algorithm(typeof(A); kwargs...) -function default_lq_algorithm(T::Type; kwargs...) + +default_lq_algorithm(T::Type; kwargs...) = throw(MethodError(default_lq_algorithm, (T,))) -end -function default_lq_algorithm(::Type{T}; kwargs...) where {T <: AbstractMatrix} - return Native_HouseholderLQ(; kwargs...) -end -function default_lq_algorithm(::Type{T}; kwargs...) where {T <: YALAPACK.MaybeBlasMat} - return LAPACK_HouseholderLQ(; kwargs...) -end -function default_lq_algorithm(::Type{T}; kwargs...) where {T <: Diagonal} - return DiagonalAlgorithm(; kwargs...) -end +default_lq_algorithm(::Type{T}; kwargs...) where {T <: AbstractMatrix} = + Householder(; kwargs...) +default_lq_algorithm(::Type{T}; kwargs...) where {T <: Diagonal} = + DiagonalAlgorithm(; kwargs...) for f in (:lq_full!, :lq_compact!, :lq_null!) @eval function default_algorithm(::typeof($f), ::Type{A}; kwargs...) where {A} diff --git a/src/interface/qr.jl b/src/interface/qr.jl index c881b5a6..3d03eb86 100644 --- a/src/interface/qr.jl +++ b/src/interface/qr.jl @@ -69,18 +69,13 @@ See also [`lq_full(!)`](@ref lq_full) and [`lq_compact(!)`](@ref lq_compact). # Algorithm selection # ------------------- default_qr_algorithm(A; kwargs...) = default_qr_algorithm(typeof(A); kwargs...) -function default_qr_algorithm(T::Type; kwargs...) + +default_qr_algorithm(T::Type; kwargs...) = throw(MethodError(default_qr_algorithm, (T,))) -end -function default_qr_algorithm(::Type{T}; kwargs...) where {T <: AbstractMatrix} - return Native_HouseholderQR(; kwargs...) -end -function default_qr_algorithm(::Type{T}; kwargs...) where {T <: YALAPACK.MaybeBlasMat} - return LAPACK_HouseholderQR(; kwargs...) -end -function default_qr_algorithm(::Type{T}; kwargs...) where {T <: Diagonal} - return DiagonalAlgorithm(; kwargs...) -end +default_qr_algorithm(::Type{T}; kwargs...) where {T <: AbstractMatrix} = + Householder(; kwargs...) +default_qr_algorithm(::Type{T}; kwargs...) where {T <: Diagonal} = + DiagonalAlgorithm(; kwargs...) for f in (:qr_full!, :qr_compact!, :qr_null!) @eval function default_algorithm(::typeof($f), ::Type{A}; kwargs...) where {A} diff --git a/test/algorithms.jl b/test/algorithms.jl index 0a2ccfec..a2f5ad09 100644 --- a/test/algorithms.jl +++ b/test/algorithms.jl @@ -2,7 +2,7 @@ using MatrixAlgebraKit using Test using TestExtras using MatrixAlgebraKit: LAPACK_SVDAlgorithm, PolarViaSVD, TruncatedAlgorithm, - default_algorithm, select_algorithm + default_algorithm, select_algorithm, Householder @testset "default_algorithm" begin A = randn(3, 3) @@ -17,21 +17,21 @@ using MatrixAlgebraKit: LAPACK_SVDAlgorithm, PolarViaSVD, TruncatedAlgorithm, LAPACK_MultipleRelativelyRobustRepresentations() end for f in (lq_full!, lq_full, lq_compact!, lq_compact, lq_null!, lq_null) - @test @constinferred(default_algorithm(f, A)) == LAPACK_HouseholderLQ() + @test @constinferred(default_algorithm(f, A)) == Householder() end for f in (left_polar!, left_polar, right_polar!, right_polar) @test @constinferred(default_algorithm(f, A)) == PolarViaSVD(LAPACK_DivideAndConquer()) end for f in (qr_full!, qr_full, qr_compact!, qr_compact, qr_null!, qr_null) - @test @constinferred(default_algorithm(f, A)) == LAPACK_HouseholderQR() + @test @constinferred(default_algorithm(f, A)) == Householder() end for f in (schur_full!, schur_full, schur_vals!, schur_vals) @test @constinferred(default_algorithm(f, A)) === LAPACK_Expert() end - @test @constinferred(default_algorithm(qr_compact!, A; blocksize = 2)) === - LAPACK_HouseholderQR(; blocksize = 2) + @test @constinferred(default_algorithm(qr_compact!, A; blocksize = 2)) == + Householder(; blocksize = 2) end @testset "select_algorithm" begin