From 1abbd3b2a5b3c4ac53d8804cf63b3829ece5f5a0 Mon Sep 17 00:00:00 2001 From: Jutho Haegeman Date: Fri, 27 Feb 2026 12:25:45 +0100 Subject: [PATCH] some changes to the ad test utils --- Project.toml | 2 +- src/pullbacks/svd.jl | 4 +- test/testsuite/TestSuite.jl | 6 +- test/testsuite/ad_utils.jl | 361 +++++++++++++++-------------------- test/testsuite/chainrules.jl | 8 +- test/testsuite/enzyme.jl | 4 +- 6 files changed, 161 insertions(+), 224 deletions(-) diff --git a/Project.toml b/Project.toml index e079639a..e851d669 100644 --- a/Project.toml +++ b/Project.toml @@ -34,7 +34,7 @@ Enzyme = "0.13.118" EnzymeTestUtils = "0.2.5" GenericLinearAlgebra = "0.3.19" GenericSchur = "0.5.6" -JET = "0.9, 0.10" +JET = "0.9, 0.10, 0.11" LinearAlgebra = "1" Mooncake = "0.5" ParallelTestRunner = "2" diff --git a/src/pullbacks/svd.jl b/src/pullbacks/svd.jl index 9c131464..01fdc4f7 100644 --- a/src/pullbacks/svd.jl +++ b/src/pullbacks/svd.jl @@ -1,4 +1,4 @@ -svd_rank(S, rank_atol) = searchsortedlast(S, rank_atol; rev = true) +svd_rank(S; rank_atol = default_pullback_rank_atol(S)) = searchsortedlast(S, rank_atol; rev = true) function check_svd_cotangents(aUΔU, Sr, aVΔV; degeneracy_atol = default_pullback_rank_atol(Sr), gauge_atol = default_pullback_gauge_atol(aUΔU, aVΔV)) mask = abs.(Sr' .- Sr) .< degeneracy_atol @@ -43,7 +43,7 @@ function svd_pullback!( minmn = min(m, n) S = diagview(Smat) length(S) == minmn || throw(DimensionMismatch("length of S ($(length(S))) does not matrix minimum dimension of U, Vᴴ ($minmn)")) - r = svd_rank(S, rank_atol) + r = svd_rank(S; rank_atol) Ur = view(U, :, 1:r) Vᴴr = view(Vᴴ, 1:r, :) Sr = view(S, 1:r) diff --git a/test/testsuite/TestSuite.jl b/test/testsuite/TestSuite.jl index 2edd0846..3614437a 100644 --- a/test/testsuite/TestSuite.jl +++ b/test/testsuite/TestSuite.jl @@ -11,7 +11,7 @@ module TestSuite using Test using MatrixAlgebraKit using MatrixAlgebraKit: diagview -using LinearAlgebra: Diagonal, norm, istriu, istril, I +using LinearAlgebra: Diagonal, norm, istriu, istril, I, mul! using Random, StableRNGs using Mooncake using AMDGPU, CUDA @@ -85,9 +85,9 @@ function instantiate_unitary(T, A::ROCMatrix{<:Complex}, sz) end instantiate_unitary(::Type{<:Diagonal}, A, sz) = Diagonal(fill!(similar(parent(A), eltype(A), sz), one(eltype(A)))) -function instantiate_rank_deficient_matrix(T, sz; trunc = trunctol(rtol = 0.5)) +function instantiate_rank_deficient_matrix(T, sz; trunc = truncrank(div(min(sz...), 2))) A = instantiate_matrix(T, sz) - V, C = left_orth!(A; trunc = trunctol(rtol = 0.5)) + V, C = left_orth!(A; trunc) return mul!(A, V, C) end diff --git a/test/testsuite/ad_utils.jl b/test/testsuite/ad_utils.jl index fce118a8..81662f14 100644 --- a/test/testsuite/ad_utils.jl +++ b/test/testsuite/ad_utils.jl @@ -1,9 +1,13 @@ +structured_randn!(A::AbstractMatrix) = randn!(A) +structured_randn!(A::Diagonal) = (randn!(diagview(A)); return A) + """ - remove_eig_gauge_dependence!(ΔV, D, V) + remove_eig_gauge_dependence!(ΔV, D, V; degeneracy_atol = ...) Remove the gauge-dependent part from the cotangent `ΔV` of the eigenvector matrix `V`. The -eigenvectors are only determined up to complex phase (and unitary mixing for degenerate -eigenvalues), so the corresponding components of `ΔV` are projected out. +eigenvectors are only determined up to a scalar factor (or an abitrary linear transformation +across eigenvectors associated with degenerate eigenvalues), so the corresponding components of +`ΔV` are projected out. """ function remove_eig_gauge_dependence!( ΔV, D, V; @@ -16,12 +20,12 @@ function remove_eig_gauge_dependence!( end """ - remove_eigh_gauge_dependence!(ΔV, D, V) + remove_eigh_gauge_dependence!(ΔV, D, V; degeneracy_atol = ...) Remove the gauge-dependent part from the cotangent `ΔV` of the Hermitian eigenvector matrix -`V`. The eigenvectors are only determined up to complex phase (and unitary mixing for -degenerate eigenvalues), so the corresponding anti-Hermitian components of `V' * ΔV` are -projected out. +`V`. The eigenvectors are only determined up to a complex phase (or a unitary transformation +across eigenvectors associated with degenerate eigenvalues), so the corresponding anti-Hermitian +components of `V' * ΔV` are projected out. """ function remove_eigh_gauge_dependence!( ΔV, D, V; @@ -35,47 +39,51 @@ function remove_eigh_gauge_dependence!( end """ - remove_svd_gauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ) + remove_svd_gauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ; degeneracy_atol = ..., rank_atol = ...) Remove the gauge-dependent part from the cotangents `ΔU` and `ΔVᴴ` of the SVD factors. The -singular vectors are only determined up to a common complex phase per singular value (and -unitary mixing for degenerate singular values), so the corresponding anti-Hermitian components -of `U₁' * ΔU₁ + Vᴴ₁ * ΔVᴴ₁'` are projected out. For the full SVD, the extra columns of `U` -and rows of `Vᴴ` beyond `min(m, n)` are additionally zeroed out. +singular vectors are only determined up to a common complex phase per singular value (or a +unitary transformation across singular vectors associated with degenerate singular values), +so the corresponding anti-Hermitian components of `U₁' * ΔU₁ + Vᴴ₁ * ΔVᴴ₁'` are projected out. +For the full SVD, the extra columns of `U` and rows of `Vᴴ` beyond the rank `r` are +additionally zeroed out, where `r = count(diagview(S) .> rank_atol)`. """ function remove_svd_gauge_dependence!( ΔU, ΔVᴴ, U, S, Vᴴ; - degeneracy_atol = MatrixAlgebraKit.default_pullback_gauge_atol(S) + degeneracy_atol = MatrixAlgebraKit.default_pullback_gauge_atol(S), + rank_atol = MatrixAlgebraKit.default_pullback_rank_atol(S) ) - minmn = length(diagview(S)) - U₁ = view(U, :, 1:minmn) - Vᴴ₁ = view(Vᴴ, 1:minmn, :) - ΔU₁ = view(ΔU, :, 1:minmn) - ΔVᴴ₁ = view(ΔVᴴ, 1:minmn, :) + r = MatrixAlgebraKit.svd_rank(diagview(S); rank_atol) + U₁ = view(U, :, 1:r) + Vᴴ₁ = view(Vᴴ, 1:r, :) + ΔU₁ = view(ΔU, :, 1:r) + ΔVᴴ₁ = view(ΔVᴴ, 1:r, :) Sdiag = diagview(S) gaugepart = mul!(U₁' * ΔU₁, Vᴴ₁, ΔVᴴ₁', true, true) gaugepart = project_antihermitian!(gaugepart) gaugepart[abs.(transpose(Sdiag) .- Sdiag) .>= degeneracy_atol] .= 0 mul!(ΔU₁, U₁, gaugepart, -1, 1) - ΔU[:, (minmn + 1):end] .= 0 - ΔVᴴ[(minmn + 1):end, :] .= 0 + ΔU[:, (r + 1):end] .= 0 + ΔVᴴ[(r + 1):end, :] .= 0 return ΔU, ΔVᴴ end """ - remove_qr_gauge_dependence!(ΔQ, ΔR, A, Q, R) + remove_qr_gauge_dependence!(ΔQ, ΔR, A, Q, R; rank_atol = ...) Remove the gauge-dependent part from the cotangents `ΔQ` and `ΔR` of the QR factors `Q` and `R`. For the full QR decomposition, the extra columns of `Q` beyond the rank `r` are not uniquely determined by `A`, so the corresponding part of `ΔQ` is projected to remove this ambiguity. Additionally, rows of `ΔR` beyond the rank are zeroed out. """ -function remove_qr_gauge_dependence!(ΔQ, ΔR, A, Q, R) - r = MatrixAlgebraKit.qr_rank(R) +function remove_qr_gauge_dependence!(ΔQ, ΔR, A, Q, R; rank_atol = MatrixAlgebraKit.default_pullback_rank_atol(R)) + r = MatrixAlgebraKit.qr_rank(R; rank_atol) Q₁ = @view Q[:, 1:r] ΔQ₂ = @view ΔQ[:, (r + 1):end] - Q₁ᴴΔQ₂ = Q₁' * ΔQ₂ - mul!(ΔQ₂, Q₁, Q₁ᴴΔQ₂) + ΔQ₂ .= 0 + # TODO: refine this by differentiating between rank deficiency and qr_full cases + # Q₁ᴴΔQ₂ = Q₁' * ΔQ₂ + # mul!(ΔQ₂, Q₁, Q₁ᴴΔQ₂) view(ΔR, (r + 1):size(ΔR, 1), :) .= 0 return ΔQ, ΔR end @@ -93,19 +101,21 @@ function remove_qr_null_gauge_dependence!(ΔN, A, N) end """ - remove_lq_gauge_dependence!(ΔL, ΔQ, A, L, Q) + remove_lq_gauge_dependence!(ΔL, ΔQ, A, L, Q; rank_atol = ...) Remove the gauge-dependent part from the cotangents `ΔL` and `ΔQ` of the LQ factors `L` and `Q`. For the full LQ decomposition, the extra rows of `Q` beyond the rank `r` are not uniquely determined by `A`, so the corresponding part of `ΔQ` is projected to remove this ambiguity. Additionally, columns of `ΔL` beyond the rank are zeroed out. """ -function remove_lq_gauge_dependence!(ΔL, ΔQ, A, L, Q) - r = MatrixAlgebraKit.lq_rank(L) +function remove_lq_gauge_dependence!(ΔL, ΔQ, A, L, Q; rank_atol = MatrixAlgebraKit.default_pullback_rank_atol(L)) + r = MatrixAlgebraKit.lq_rank(L; rank_atol) Q₁ = @view Q[1:r, :] ΔQ₂ = @view ΔQ[(r + 1):end, :] - ΔQ₂Q₁ᴴ = ΔQ₂ * Q₁' - mul!(ΔQ₂, ΔQ₂Q₁ᴴ, Q₁) + ΔQ₂ .= 0 + # TODO: refine this by differentiating between rank deficiency and lq_full cases + # ΔQ₂Q₁ᴴ = ΔQ₂ * Q₁' + # mul!(ΔQ₂, ΔQ₂Q₁ᴴ, Q₁) view(ΔL, :, (r + 1):size(ΔL, 2)) .= 0 return ΔL, ΔQ end @@ -130,11 +140,7 @@ Remove the gauge-dependent part from the cotangent `ΔN` of the left null space space basis is only determined up to a unitary rotation, so `ΔN` is projected onto the column span of the compact QR factor `Q₁` of `A`. """ -function remove_left_null_gauge_dependence!(ΔN, A, N) - Q, _ = qr_compact(A) - mul!(ΔN, Q, Q' * ΔN) - return ΔN -end +remove_left_null_gauge_dependence!(ΔN, A, N) = remove_qr_null_gauge_dependence!(ΔN, A, N) """ remove_right_null_gauge_dependence!(ΔNᴴ, A, Nᴴ) @@ -143,11 +149,7 @@ Remove the gauge-dependent part from the cotangent `ΔNᴴ` of the right null sp null space basis is only determined up to a unitary rotation, so `ΔNᴴ` is projected onto the row span of the compact LQ factor `Q₁` of `A`. """ -function remove_right_null_gauge_dependence!(ΔNᴴ, A, Nᴴ) - _, Q = lq_compact(A) - mul!(ΔNᴴ, ΔNᴴ * Q', Q) - return ΔNᴴ -end +remove_right_null_gauge_dependence!(ΔNᴴ, A, Nᴴ) = remove_lq_null_gauge_dependence!(ΔNᴴ, A, Nᴴ) """ call_and_zero!(f!, A, alg) @@ -218,21 +220,11 @@ end function ad_qr_compact_setup(A) QR = qr_compact(A) - ΔQR = randn!.(copy.(QR)) - remove_qr_gauge_dependence!(ΔQR..., A, QR...) + ΔQR = structured_randn!.(copy.(QR)) + A isa Diagonal || remove_qr_gauge_dependence!(ΔQR..., A, QR...) return QR, ΔQR end -function ad_qr_compact_setup(A::Diagonal) - m, n = size(A) - minmn = min(m, n) - QR = qr_compact(A) - T = eltype(A) - ΔQ = Diagonal(randn!(similar(A.diag, T, m))) - ΔR = Diagonal(randn!(similar(A.diag, T, m))) - return QR, (ΔQ, ΔR) -end - function ad_qr_null_setup(A) N = qr_null(A) ΔN = randn!(copy(N)) @@ -242,54 +234,51 @@ end function ad_qr_full_setup(A) QR = qr_full(A) - ΔQR = randn!.(copy.(QR)) - remove_qr_gauge_dependence!(ΔQR..., A, QR...) + ΔQR = structured_randn!.(copy.(QR)) + A isa Diagonal || remove_qr_gauge_dependence!(ΔQR..., A, QR...) return QR, ΔQR end -ad_qr_full_setup(A::Diagonal) = ad_qr_compact_setup(A) - -function ad_qr_rank_deficient_compact_setup(A) - m, n = size(A) - minmn = min(m, n) - T = eltype(A) - r = minmn - 5 - Ard = randn!(similar(A, T, m, r)) * randn!(similar(A, T, r, n)) - Q, R = qr_compact(Ard) - QR = (Q, R) - ΔQ = randn!(similar(A, T, m, minmn)) - Q1 = view(Q, 1:m, 1:r) - Q2 = view(Q, 1:m, (r + 1):minmn) - ΔQ2 = view(ΔQ, 1:m, (r + 1):minmn) - MatrixAlgebraKit.zero!(ΔQ2) - ΔR = randn!(similar(A, T, minmn, n)) - view(ΔR, (r + 1):minmn, :) .= 0 - return (Q, R), (ΔQ, ΔR) -end - -function ad_qr_rank_deficient_compact_setup(A::Diagonal) - m, n = size(A) - minmn = min(m, n) - T = eltype(A) - r = minmn - 5 - Ard_ = randn!(similar(A, T, m)) - MatrixAlgebraKit.zero!(view(Ard_, (r + 1):m)) - Ard = Diagonal(Ard_) - Q, R = qr_compact(Ard) - ΔQ = Diagonal(randn!(similar(diagview(A), T, m))) - ΔR = Diagonal(randn!(similar(diagview(A), T, m))) - MatrixAlgebraKit.zero!(view(diagview(ΔQ), (r + 1):m)) - MatrixAlgebraKit.zero!(view(diagview(ΔR), (r + 1):m)) - return (Q, R), (ΔQ, ΔR) -end +# function ad_qr_rank_deficient_compact_setup(A) +# m, n = size(A) +# minmn = min(m, n) +# T = eltype(A) +# r = minmn - 5 +# Ard = randn!(similar(A, T, m, r)) * randn!(similar(A, T, r, n)) +# Q, R = qr_compact(Ard) +# QR = (Q, R) +# ΔQ = randn!(similar(A, T, m, minmn)) +# Q1 = view(Q, 1:m, 1:r) +# Q2 = view(Q, 1:m, (r + 1):minmn) +# ΔQ2 = view(ΔQ, 1:m, (r + 1):minmn) +# MatrixAlgebraKit.zero!(ΔQ2) +# ΔR = randn!(similar(A, T, minmn, n)) +# view(ΔR, (r + 1):minmn, :) .= 0 +# return (Q, R), (ΔQ, ΔR) +# end + +# function ad_qr_rank_deficient_compact_setup(A::Diagonal) +# m, n = size(A) +# minmn = min(m, n) +# T = eltype(A) +# r = minmn - 5 +# Ard_ = randn!(similar(A, T, m)) +# MatrixAlgebraKit.zero!(view(Ard_, (r + 1):m)) +# Ard = Diagonal(Ard_) +# Q, R = qr_compact(Ard) +# ΔQ = Diagonal(randn!(similar(diagview(A), T, m))) +# ΔR = Diagonal(randn!(similar(diagview(A), T, m))) +# MatrixAlgebraKit.zero!(view(diagview(ΔQ), (r + 1):m)) +# MatrixAlgebraKit.zero!(view(diagview(ΔR), (r + 1):m)) +# return (Q, R), (ΔQ, ΔR) +# end function ad_lq_compact_setup(A) LQ = lq_compact(A) - ΔLQ = randn!.(copy.(LQ)) - remove_lq_gauge_dependence!(ΔLQ..., A, LQ...) + ΔLQ = structured_randn!.(copy.(LQ)) + A isa Diagonal || remove_lq_gauge_dependence!(ΔLQ..., A, LQ...) return LQ, ΔLQ end -ad_lq_compact_setup(A::Diagonal) = ad_qr_compact_setup(A) function ad_lq_null_setup(A) T = eltype(A) @@ -301,67 +290,60 @@ end function ad_lq_full_setup(A) LQ = lq_full(A) - ΔLQ = randn!.(copy.(LQ)) - remove_lq_gauge_dependence!(ΔLQ..., A, LQ...) + ΔLQ = structured_randn!.(copy.(LQ)) + A isa Diagonal || remove_lq_gauge_dependence!(ΔLQ..., A, LQ...) return LQ, ΔLQ end -ad_lq_full_setup(A::Diagonal) = ad_qr_full_setup(A) -function ad_lq_rank_deficient_compact_setup(A) - m, n = size(A) - minmn = min(m, n) - T = eltype(A) - r = minmn - 5 - Ard = randn!(similar(A, T, m, r)) * randn!(similar(A, T, r, n)) - L, Q = lq_compact(Ard) - ΔL = randn!(similar(A, T, m, minmn)) - ΔQ = randn!(similar(A, T, minmn, n)) - Q1 = view(Q, 1:r, 1:n) - Q2 = view(Q, (r + 1):minmn, 1:n) - ΔQ2 = view(ΔQ, (r + 1):minmn, 1:n) - ΔQ2 .= 0 - view(ΔL, :, (r + 1):minmn) .= 0 - return (L, Q), (ΔL, ΔQ) -end -ad_lq_rank_deficient_compact_setup(A::Diagonal) = ad_qr_rank_deficient_compact_setup(A) +# function ad_lq_rank_deficient_compact_setup(A) +# m, n = size(A) +# minmn = min(m, n) +# T = eltype(A) +# r = minmn - 5 +# Ard = randn!(similar(A, T, m, r)) * randn!(similar(A, T, r, n)) +# L, Q = lq_compact(Ard) +# ΔL = randn!(similar(A, T, m, minmn)) +# ΔQ = randn!(similar(A, T, minmn, n)) +# Q1 = view(Q, 1:r, 1:n) +# Q2 = view(Q, (r + 1):minmn, 1:n) +# ΔQ2 = view(ΔQ, (r + 1):minmn, 1:n) +# ΔQ2 .= 0 +# view(ΔL, :, (r + 1):minmn) .= 0 +# return (L, Q), (ΔL, ΔQ) +# end +# ad_lq_rank_deficient_compact_setup(A::Diagonal) = ad_qr_rank_deficient_compact_setup(A) function ad_eig_full_setup(A) - m, n = size(A) - T = eltype(A) - DV = eig_full(A) - D, V = DV - ΔV = randn!(similar(A, complex(T), m, m)) + D, V = eig_full(A) + ΔD, ΔV = structured_randn!.(similar.((D, V))) ΔV = remove_eig_gauge_dependence!(ΔV, D, V) - ΔD = Diagonal(randn!(similar(A, complex(T), m))) - return DV, (ΔD, ΔV) + return (D, V), (ΔD, ΔV) end -function ad_eig_full_setup(A::Diagonal) - m, n = size(A) - T = complex(eltype(A)) - DV = eig_full(A) - D, V = DV - ΔV = randn!(similar(A.diag, T, m, m)) - ΔV = remove_eig_gauge_dependence!(ΔV, D, V) - ΔD = Diagonal(randn!(similar(A.diag, T, m))) - return DV, (ΔD, ΔV) -end +# function ad_eig_full_setup(A::Diagonal) +# m, n = size(A) +# T = complex(eltype(A)) +# DV = eig_full(A) +# D, V = DV +# ΔV = randn!(similar(A.diag, T, m, m)) +# ΔV = remove_eig_gauge_dependence!(ΔV, D, V) +# ΔD = Diagonal(randn!(similar(A.diag, T, m))) +# return DV, (ΔD, ΔV) +# end function ad_eig_vals_setup(A) - m, n = size(A) - T = complex(eltype(A)) D = eig_vals(A) - ΔD = randn!(similar(A, complex(T), m)) + ΔD = randn!(similar(D)) return D, ΔD end -function ad_eig_vals_setup(A::Diagonal) - m, n = size(A) - T = complex(eltype(A)) - D = eig_vals(A) - ΔD = randn!(similar(A.diag, T, m)) - return D, ΔD -end +# function ad_eig_vals_setup(A::Diagonal) +# m, n = size(A) +# T = complex(eltype(A)) +# D = eig_vals(A) +# ΔD = randn!(similar(A.diag, T, m)) +# return D, ΔD +# end function ad_eig_trunc_setup(A, truncalg) DV, ΔDV = ad_eig_full_setup(A) @@ -374,21 +356,15 @@ function ad_eig_trunc_setup(A, truncalg) end function ad_eigh_full_setup(A) - m, n = size(A) - T = eltype(A) - DV = eigh_full(A) - D, V = DV - ΔV = randn!(similar(A, T, m, m)) + D, V = eigh_full(A) + ΔD, ΔV = structured_randn!.(similar.((D, V))) ΔV = remove_eigh_gauge_dependence!(ΔV, D, V) - ΔD = Diagonal(randn!(similar(A, real(T), m))) - return DV, (ΔD, ΔV) + return (D, V), (ΔD, ΔV) end function ad_eigh_vals_setup(A) - m, n = size(A) - T = eltype(A) D = eigh_vals(A) - ΔD = randn!(similar(A, real(T), m)) + ΔD = randn!(similar(D)) return D, ΔD end @@ -403,55 +379,39 @@ function ad_eigh_trunc_setup(A, truncalg) end function ad_svd_compact_setup(A) - m, n = size(A) - T = eltype(A) - minmn = min(m, n) - ΔU = randn!(similar(A, T, m, minmn)) - ΔS = Diagonal(randn!(similar(A, real(T), minmn))) - ΔVᴴ = randn!(similar(A, T, minmn, n)) U, S, Vᴴ = svd_compact(A) + ΔU, ΔS, ΔVᴴ = structured_randn!.(similar.((U, S, Vᴴ))) ΔU, ΔVᴴ = remove_svd_gauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ) return (U, S, Vᴴ), (ΔU, ΔS, ΔVᴴ) end -function ad_svd_compact_setup(A::Diagonal) - m, n = size(A) - T = eltype(A) - minmn = min(m, n) - ΔU = randn!(similar(A.diag, T, m, n)) - ΔS = Diagonal(randn!(similar(A.diag, real(T), minmn))) - ΔVᴴ = randn!(similar(A.diag, T, m, n)) - U, S, Vᴴ = svd_compact(A) - ΔU, ΔVᴴ = remove_svd_gauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ) - return (U, S, Vᴴ), (ΔU, ΔS, ΔVᴴ) -end +# function ad_svd_compact_setup(A::Diagonal) +# m, n = size(A) +# T = eltype(A) +# minmn = min(m, n) +# ΔU = randn!(similar(A.diag, T, m, n)) +# ΔS = Diagonal(randn!(similar(A.diag, real(T), minmn))) +# ΔVᴴ = randn!(similar(A.diag, T, m, n)) +# U, S, Vᴴ = svd_compact(A) +# ΔU, ΔVᴴ = remove_svd_gauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ) +# return (U, S, Vᴴ), (ΔU, ΔS, ΔVᴴ) +# end function ad_svd_full_setup(A) - m, n = size(A) - T = eltype(A) - minmn = min(m, n) - (_, _, _), (ΔU, ΔS, ΔVᴴ) = ad_svd_compact_setup(A) - ΔUfull = similar(A, T, m, m) - ΔUfull .= zero(T) - ΔSfull = similar(A, real(T), m, n) - ΔSfull .= zero(real(T)) - ΔVᴴfull = similar(A, T, n, n) - ΔVᴴfull .= zero(T) U, S, Vᴴ = svd_full(A) - view(ΔUfull, :, 1:minmn) .= ΔU - view(ΔVᴴfull, 1:minmn, :) .= ΔVᴴ - diagview(ΔSfull)[1:minmn] .= diagview(ΔS) - return (U, S, Vᴴ), (ΔUfull, ΔSfull, ΔVᴴfull) + ΔU = structured_randn!(similar(U)) + ΔVᴴ = structured_randn!(similar(Vᴴ)) + ΔU, ΔVᴴ = remove_svd_gauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ) + ΔS = zero(S) + randn!(diagview(ΔS)) + return (U, S, Vᴴ), (ΔU, ΔS, ΔVᴴ) end -ad_svd_full_setup(A::Diagonal) = ad_svd_compact_setup(A) +# ad_svd_full_setup(A::Diagonal) = ad_svd_compact_setup(A) function ad_svd_vals_setup(A) - m, n = size(A) - minmn = min(m, n) - T = eltype(A) S = svd_vals(A) - ΔS = randn!(similar(A, real(T), minmn)) + ΔS = randn!(similar(S)) return S, ΔS end @@ -468,48 +428,25 @@ function ad_svd_trunc_setup(A, truncalg) end function ad_left_polar_setup(A) - m, n = size(A) - T = eltype(A) WP = left_polar(A) - ΔWP = (randn!(similar(A, T, m, n)), randn!(similar(A, T, n, n))) - return WP, ΔWP -end - -function ad_left_polar_setup(A::Diagonal) - m, n = size(A) - T = eltype(A) - WP = left_polar(A) - ΔWP = (Diagonal(randn!(similar(A.diag))), randn!(similar(WP[2]))) + ΔWP = structured_randn!.(similar.(WP)) return WP, ΔWP end function ad_right_polar_setup(A) - m, n = size(A) - T = eltype(A) - PWᴴ = right_polar(A) - ΔPWᴴ = (randn!(similar(A, T, m, m)), randn!(similar(A, T, m, n))) - return PWᴴ, ΔPWᴴ -end -function ad_right_polar_setup(A::Diagonal) - m, n = size(A) - T = eltype(A) PWᴴ = right_polar(A) - ΔPWᴴ = (randn!(similar(PWᴴ[1])), Diagonal(randn!(similar(A.diag)))) + ΔPWᴴ = structured_randn!.(similar.(PWᴴ)) return PWᴴ, ΔPWᴴ end function ad_left_orth_setup(A) - m, n = size(A) - T = eltype(A) VC = left_orth(A) - ΔVC = (randn!(similar(A, T, size(VC[1])...)), randn!(similar(A, T, size(VC[2])...))) + ΔVC = structured_randn!.(similar.(VC)) return VC, ΔVC end function ad_left_orth_setup(A::Diagonal) - m, n = size(A) - T = eltype(A) VC = left_orth(A) - ΔVC = (Diagonal(randn!(similar(A.diag, T, m))), Diagonal(randn!(similar(A.diag, T, m)))) + ΔVC = structured_randn!.(similar.(VC)) return VC, ΔVC end diff --git a/test/testsuite/chainrules.jl b/test/testsuite/chainrules.jl index bb94664d..e3d0277a 100644 --- a/test/testsuite/chainrules.jl +++ b/test/testsuite/chainrules.jl @@ -17,14 +17,14 @@ for f in @eval begin function $copy_f(input, alg) if $_hermitian - input = (input + input') / 2 + input = project_hermitian(input) end return $f(input, alg) end function ChainRulesCore.rrule(::typeof($copy_f), input, alg) output = MatrixAlgebraKit.initialize_output($f!, input, alg) if $_hermitian - input = (input + input') / 2 + input = project_hermitian(input) else input = copy(input) end @@ -112,7 +112,7 @@ function test_chainrules_qr( m, n = size(A) r = min(m, n) - 5 Ard = instantiate_matrix(T, (m, r)) * instantiate_matrix(T, (r, n)) - QR, ΔQR = ad_qr_rank_deficient_compact_setup(Ard) + QR, ΔQR = ad_qr_compact_setup(Ard) ΔQ, ΔR = ΔQR test_rrule( cr_copy_qr_compact, Ard, alg ⊢ NoTangent(); @@ -189,7 +189,7 @@ function test_chainrules_lq( m, n = size(A) r = min(m, n) - 5 Ard = instantiate_matrix(T, (m, r)) * instantiate_matrix(T, (r, n)) - LQ, ΔLQ = ad_lq_rank_deficient_compact_setup(Ard) + LQ, ΔLQ = ad_lq_compact_setup(Ard) test_rrule( cr_copy_lq_compact, Ard, alg ⊢ NoTangent(); output_tangent = ΔLQ, atol = atol, rtol = rtol diff --git a/test/testsuite/enzyme.jl b/test/testsuite/enzyme.jl index e9b07a31..ad08a58d 100644 --- a/test/testsuite/enzyme.jl +++ b/test/testsuite/enzyme.jl @@ -146,7 +146,7 @@ function test_enzyme_qr( m, n = size(A) r = min(m, n) - 5 Ard = instantiate_matrix(T, (m, r)) * instantiate_matrix(T, (r, n)) - QR, ΔQR = ad_qr_rank_deficient_compact_setup(Ard) + QR, ΔQR = ad_qr_compact_setup(Ard) eltype(T) <: BlasFloat && test_reverse(qr_compact, RT, (Ard, TA), (alg, Const); atol, rtol, output_tangent = ΔQR, fdm) is_cpu(A) && enz_test_pullbacks_match(rng, qr_compact!, qr_compact, Ard, QR, ΔQR, alg) end @@ -190,7 +190,7 @@ function test_enzyme_lq( m, n = size(A) r = min(m, n) - 5 Ard = instantiate_matrix(T, (m, r)) * instantiate_matrix(T, (r, n)) - LQ, ΔLQ = ad_lq_rank_deficient_compact_setup(Ard) + LQ, ΔLQ = ad_lq_compact_setup(Ard) eltype(T) <: BlasFloat && test_reverse(lq_compact, RT, (Ard, TA), (alg, Const); atol, rtol, output_tangent = ΔLQ, fdm) is_cpu(A) && enz_test_pullbacks_match(rng, lq_compact!, lq_compact, Ard, LQ, ΔLQ, alg) end