Skip to content
12 changes: 6 additions & 6 deletions ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@ using MatrixAlgebraKit
using MatrixAlgebraKit: @algdef, Algorithm, check_input
using MatrixAlgebraKit: one!, zero!, uppertriangular!, lowertriangular!
using MatrixAlgebraKit: diagview, sign_safe
using MatrixAlgebraKit: LQViaTransposedQR, TruncationStrategy, NoTruncation, TruncationByValue, AbstractAlgorithm
using MatrixAlgebraKit: ROCSOLVER, LQViaTransposedQR, TruncationStrategy, NoTruncation, TruncationByValue, AbstractAlgorithm
using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm, default_eigh_algorithm
import MatrixAlgebraKit: _gpu_geqrf!, _gpu_ungqr!, _gpu_unmqr!, _gpu_gesvd!, _gpu_Xgesvdp!, _gpu_gesvdj!
import MatrixAlgebraKit: geqrf!, ungqr!, unmqr!, _gpu_gesvd!, _gpu_Xgesvdp!, _gpu_gesvdj!
import MatrixAlgebraKit: _gpu_heevj!, _gpu_heevd!, _gpu_heev!, _gpu_heevx!, _sylvester, svd_rank
using AMDGPU
using LinearAlgebra
Expand All @@ -28,10 +28,10 @@ function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; kwargs...) where {T
return ROCSOLVER_DivideAndConquer(; kwargs...)
end

_gpu_geqrf!(A::StridedROCMatrix) = YArocSOLVER.geqrf!(A)
_gpu_ungqr!(A::StridedROCMatrix, τ::StridedROCVector) = YArocSOLVER.ungqr!(A, τ)
_gpu_unmqr!(side::AbstractChar, trans::AbstractChar, A::StridedROCMatrix, τ::StridedROCVector, C::StridedROCVecOrMat) =
YArocSOLVER.unmqr!(side, trans, A, τ, C)
for f in (:geqrf!, :ungqr!, :unmqr!)
@eval $f(::ROCSOLVER, args...) = YArocSOLVER.$f(args...)
end

_gpu_gesvd!(A::StridedROCMatrix, S::StridedROCVector, U::StridedROCMatrix, Vᴴ::StridedROCMatrix) =
YArocSOLVER.gesvd!(A, S, U, Vᴴ)
# not yet supported
Expand Down
15 changes: 7 additions & 8 deletions ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@ using MatrixAlgebraKit
using MatrixAlgebraKit: @algdef, Algorithm, check_input
using MatrixAlgebraKit: one!, zero!, uppertriangular!, lowertriangular!
using MatrixAlgebraKit: diagview, sign_safe
using MatrixAlgebraKit: LQViaTransposedQR, TruncationByValue, AbstractAlgorithm
using MatrixAlgebraKit: CUSOLVER, LQViaTransposedQR, TruncationByValue, AbstractAlgorithm
using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm, default_eig_algorithm, default_eigh_algorithm
import MatrixAlgebraKit: _gpu_geqrf!, _gpu_ungqr!, _gpu_unmqr!, _gpu_gesvd!, _gpu_Xgesvdp!, _gpu_Xgesvdr!, _gpu_gesvdj!, _gpu_geev!
import MatrixAlgebraKit: geqrf!, ungqr!, unmqr!, _gpu_gesvd!, _gpu_Xgesvdp!, _gpu_Xgesvdr!, _gpu_gesvdj!, _gpu_geev!
import MatrixAlgebraKit: _gpu_heevj!, _gpu_heevd!, _sylvester, svd_rank
using CUDA, CUDA.CUBLAS
using CUDA: i32
Expand All @@ -32,6 +32,7 @@ function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; kwargs...) where {TT
return CUSOLVER_DivideAndConquer(; kwargs...)
end


# include for block sector support
function MatrixAlgebraKit.default_qr_algorithm(::Type{Base.ReshapedArray{T, 2, SubArray{T, 1, A, Tuple{UnitRange{Int}}, true}, Tuple{}}}; kwargs...) where {T <: BlasFloat, A <: CuVecOrMat{T}}
return CUSOLVER_HouseholderQR(; kwargs...)
Expand All @@ -50,14 +51,12 @@ function MatrixAlgebraKit.default_eigh_algorithm(::Type{Base.ReshapedArray{T, 2,
return CUSOLVER_DivideAndConquer(; kwargs...)
end

for f in (:geqrf!, :ungqr!, :unmqr!)
@eval $f(::CUSOLVER, args...) = YACUSOLVER.$f(args...)
end

_gpu_geev!(A::StridedCuMatrix, D::StridedCuVector, V::StridedCuMatrix) =
YACUSOLVER.Xgeev!(A, D, V)
_gpu_geqrf!(A::StridedCuMatrix) =
YACUSOLVER.geqrf!(A)
_gpu_ungqr!(A::StridedCuMatrix, τ::StridedCuVector) =
YACUSOLVER.ungqr!(A, τ)
_gpu_unmqr!(side::AbstractChar, trans::AbstractChar, A::StridedCuMatrix, τ::StridedCuVector, C::StridedCuVecOrMat) =
YACUSOLVER.unmqr!(side, trans, A, τ, C)
_gpu_gesvd!(A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix, Vᴴ::StridedCuMatrix) =
YACUSOLVER.gesvd!(A, S, U, Vᴴ)
_gpu_Xgesvdp!(A::StridedCuMatrix, S::StridedCuVector, U::StridedCuMatrix, Vᴴ::StridedCuMatrix; kwargs...) =
Expand Down
73 changes: 28 additions & 45 deletions ext/MatrixAlgebraKitGenericLinearAlgebraExt.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
module MatrixAlgebraKitGenericLinearAlgebraExt

using MatrixAlgebraKit
using MatrixAlgebraKit: sign_safe, check_input, diagview, gaugefix!, one!, default_fixgauge
using MatrixAlgebraKit: left_orth_alg
using MatrixAlgebraKit: sign_safe, check_input, diagview, gaugefix!, one!, zero!, default_fixgauge
using GenericLinearAlgebra: svd!, svdvals!, eigen!, eigvals!, Hermitian, qr!
using LinearAlgebra: I, Diagonal, lmul!

Expand Down Expand Up @@ -57,81 +56,65 @@ function MatrixAlgebraKit.eigh_vals!(A::AbstractMatrix, D, ::GLA_QRIteration)
return eigvals!(Hermitian(A); sortby = real)
end

function MatrixAlgebraKit.default_qr_algorithm(::Type{T}; kwargs...) where {T <: StridedMatrix{<:Union{Float16, ComplexF16, BigFloat, Complex{BigFloat}}}}
return GLA_HouseholderQR(; kwargs...)
end

function MatrixAlgebraKit.qr_full!(A::AbstractMatrix, QR, alg::GLA_HouseholderQR)
check_input(qr_full!, A, QR, alg)
Q, R = QR
return _gla_householder_qr!(A, Q, R; alg.kwargs...)
end

function MatrixAlgebraKit.qr_compact!(A::AbstractMatrix, QR, alg::GLA_HouseholderQR)
check_input(qr_compact!, A, QR, alg)
Q, R = QR
return _gla_householder_qr!(A, Q, R; alg.kwargs...)
end

function MatrixAlgebraKit.qr_null!(A::AbstractMatrix, N, alg::GLA_HouseholderQR)
check_input(qr_null!, A, N, alg)
return _gla_householder_qr_null!(A, N; alg.kwargs...)
end

function _gla_householder_qr!(A::AbstractMatrix, Q, R; positive = true, blocksize = 1, pivoted = false)
pivoted && throw(ArgumentError("Only pivoted = false implemented for GLA_HouseholderQR."))
(blocksize == 1) || throw(ArgumentError("Only blocksize = 1 implemented for GLA_HouseholderQR."))
function MatrixAlgebraKit.householder_qr!(
driver::MatrixAlgebraKit.GLA, A::AbstractMatrix, Q::AbstractMatrix, R::AbstractMatrix;
positive::Bool = true, pivoted::Bool = false, blocksize::Int = 1
)
blocksize == 1 ||
throw(ArgumentError(lazy"$driver does not provide a blocked QR decomposition"))
pivoted &&
throw(ArgumentError(lazy"$driver does not provide a pivoted QR decomposition"))

m, n = size(A)
k = min(m, n)
minmn = min(m, n)
computeR = length(R) > 0

# compute QR
Q̃, R̃ = qr!(A)
lmul!(Q̃, MatrixAlgebraKit.one!(Q))

if positive
@inbounds for j in 1:k
@inbounds for j in 1:minmn
s = sign_safe(R̃[j, j])
@simd for i in 1:m
Q[i, j] *= s
end
end
end

computeR = length(R) > 0
if computeR
if positive
@inbounds for j in n:-1:1
@simd for i in 1:min(k, j)
@simd for i in 1:min(minmn, j)
R[i, j] = R̃[i, j] * conj(sign_safe(R̃[i, i]))
end
@simd for i in (min(k, j) + 1):size(R, 1)
@simd for i in (min(minmn, j) + 1):size(R, 1)
R[i, j] = zero(eltype(R))
end
end
else
R[1:k, :] .= R̃
MatrixAlgebraKit.zero!(@view(R[(k + 1):end, :]))
R[1:minmn, :] .= R̃
MatrixAlgebraKit.zero!(@view(R[(minmn + 1):end, :]))
end
end
return Q, R
end

function _gla_householder_qr_null!(
A::AbstractMatrix, N::AbstractMatrix;
positive = true, blocksize = 1, pivoted = false
function MatrixAlgebraKit.householder_qr_null!(
driver::MatrixAlgebraKit.GLA, A::AbstractMatrix, N::AbstractMatrix;
positive::Bool = true, pivoted::Bool = false, blocksize::Int = 1
)
pivoted && throw(ArgumentError("Only pivoted = false implemented for GLA_HouseholderQR."))
(blocksize == 1) || throw(ArgumentError("Only blocksize = 1 implemented for GLA_HouseholderQR."))
blocksize == 1 ||
throw(ArgumentError(lazy"$driver does not provide a blocked QR decomposition"))
pivoted &&
throw(ArgumentError(lazy"$driver does not provide a pivoted QR decomposition"))

m, n = size(A)
minmn = min(m, n)
fill!(N, zero(eltype(N)))
zero!(N)
one!(view(N, (minmn + 1):m, 1:(m - minmn)))
Q̃, = qr!(A)
lmul!(Q̃, N)
return N
end

function MatrixAlgebraKit.default_lq_algorithm(::Type{T}; kwargs...) where {T <: StridedMatrix{<:Union{Float16, ComplexF16, BigFloat, Complex{BigFloat}}}}
return MatrixAlgebraKit.LQViaTransposedQR(GLA_HouseholderQR(; kwargs...))
return lmul!(Q̃, N)
end

MatrixAlgebraKit.left_orth_alg(alg::GLA_HouseholderQR) = MatrixAlgebraKit.LeftOrthViaQR(alg)
Expand Down
53 changes: 53 additions & 0 deletions src/algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,59 @@ If this is not possible, for example when the output size is not known a priori
this function may return `nothing`.
""" initialize_output


# Drivers
# -------
"""
abstract type Driver

Supertype used for customizing various implementations of the same algorithm.
"""
abstract type Driver end

"""
DefaultDriver <: Driver

Select a default driver at runtime, based on the input matrix.
"""
struct DefaultDriver <: Driver end

"""
LAPACK <: Driver

Driver to select LAPACK as the implementation strategy.
"""
struct LAPACK <: Driver end

"""
CUSOLVER <: Driver

Driver to select CUSOLVER as the implementation strategy.
"""
struct CUSOLVER <: Driver end

"""
ROCSOLVER <: Driver

Driver to select ROCSOLVER as the implementation strategy.
"""
struct ROCSOLVER <: Driver end

"""
GLA <: Driver

Driver to select GenericLinearAlgebra.jl as the implementation strategy.
"""
struct GLA <: Driver end

"""
Native <: Driver

Driver to select a native implementation in MatrixAlgebraKit as the implementation strategy.
"""
struct Native <: Driver end


# Truncation strategy
# -------------------
"""
Expand Down
16 changes: 8 additions & 8 deletions src/common/householder.jl
Original file line number Diff line number Diff line change
@@ -1,32 +1,32 @@
const IndexRange{T <: Integer} = Base.AbstractRange{T}

# Elementary Householder reflection
struct Householder{T, V <: AbstractVector, R <: IndexRange}
struct HouseholderReflection{T, V <: AbstractVector, R <: IndexRange}
β::T
v::V
r::R
end
Base.adjoint(H::Householder) = Householder(conj(H.β), H.v, H.r)
Base.adjoint(H::HouseholderReflection) = HouseholderReflection(conj(H.β), H.v, H.r)

function householder(x::AbstractVector, r::IndexRange = axes(x, 1), k = first(r))
i = findfirst(==(k), r)
i == nothing && error("k = $k should be in the range r = $r")
β, v, ν = _householder!(x[r], i)
return Householder(β, v, r), ν
return HouseholderReflection(β, v, r), ν
end
# Householder reflector h that zeros the elements A[r,col] (except for A[k,col]) upon lmul!(h,A)
function householder(A::AbstractMatrix, r::IndexRange, col::Int, k = first(r))
i = findfirst(==(k), r)
i == nothing && error("k = $k should be in the range r = $r")
β, v, ν = _householder!(A[r, col], i)
return Householder(β, v, r), ν
return HouseholderReflection(β, v, r), ν
end
# Householder reflector that zeros the elements A[row,r] (except for A[row,k]) upon rmul!(A,h')
function householder(A::AbstractMatrix, row::Int, r::IndexRange, k = first(r))
i = findfirst(==(k), r)
i == nothing && error("k = $k should be in the range r = $r")
β, v, ν = _householder!(conj!(A[row, r]), i)
return Householder(β, v, r), ν
return HouseholderReflection(β, v, r), ν
end

# generate Householder vector based on vector v, such that applying the reflection
Expand Down Expand Up @@ -66,7 +66,7 @@ function _householder!(v::AbstractVector{T}, i::Int = 1) where {T}
return β, v, ν
end

function LinearAlgebra.lmul!(H::Householder, x::AbstractVector)
function LinearAlgebra.lmul!(H::HouseholderReflection, x::AbstractVector)
v = H.v
r = H.r
β = H.β
Expand All @@ -87,7 +87,7 @@ function LinearAlgebra.lmul!(H::Householder, x::AbstractVector)
end
return x
end
function LinearAlgebra.lmul!(H::Householder, A::AbstractMatrix; cols = axes(A, 2))
function LinearAlgebra.lmul!(H::HouseholderReflection, A::AbstractMatrix; cols = axes(A, 2))
v = H.v
r = H.r
β = H.β
Expand All @@ -110,7 +110,7 @@ function LinearAlgebra.lmul!(H::Householder, A::AbstractMatrix; cols = axes(A, 2
end
return A
end
function LinearAlgebra.rmul!(A::AbstractMatrix, H::Householder; rows = axes(A, 1))
function LinearAlgebra.rmul!(A::AbstractMatrix, H::HouseholderReflection; rows = axes(A, 1))
v = H.v
r = H.r
β = H.β
Expand Down
Loading