From 9d6bd035986854810e4d7a73a7aab58c4127111c Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Sat, 14 Feb 2026 16:11:46 -0500 Subject: [PATCH 01/15] Fix zero tangent guard in polar pullback MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Guard `C .+= ΔP` with `!iszerotangent(ΔP)` in both `left_polar_pullback!` and `right_polar_pullback!` to handle the case where ΔP is `nothing`. Co-Authored-By: Claude Opus 4.6 --- src/pullbacks/polar.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/pullbacks/polar.jl b/src/pullbacks/polar.jl index 4d498da0..a549321e 100644 --- a/src/pullbacks/polar.jl +++ b/src/pullbacks/polar.jl @@ -17,7 +17,7 @@ function left_polar_pullback!(ΔA::AbstractMatrix, A, WP, ΔWP; kwargs...) !iszerotangent(ΔW) && mul!(M, W', ΔW, 1, 1) !iszerotangent(ΔP) && mul!(M, ΔP, P, -1, 1) C = _sylvester(P, P, M' - M) - C .+= ΔP + !iszerotangent(ΔP) && (C .+= ΔP) ΔA = mul!(ΔA, W, C, 1, 1) if !iszerotangent(ΔW) ΔWP = ΔW / P @@ -47,7 +47,7 @@ function right_polar_pullback!(ΔA::AbstractMatrix, A, PWᴴ, ΔPWᴴ; kwargs... !iszerotangent(ΔWᴴ) && mul!(M, ΔWᴴ, Wᴴ', 1, 1) !iszerotangent(ΔP) && mul!(M, P, ΔP, -1, 1) C = _sylvester(P, P, M' - M) - C .+= ΔP + !iszerotangent(ΔP) && (C .+= ΔP) ΔA = mul!(ΔA, C, Wᴴ, 1, 1) if !iszerotangent(ΔWᴴ) PΔWᴴ = P \ ΔWᴴ From eeaf9a959ef6c7e4e3782adb57ce107ec1e6d171 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Mon, 16 Feb 2026 17:33:29 -0500 Subject: [PATCH 02/15] Use eigendecomposition-based Sylvester solver for symmetric case Replace LAPACK trsyl!-based solver with a direct eigendecomposition approach when both arguments are the same Hermitian matrix (as in polar pullbacks). This avoids LAPACKException(1) for close eigenvalues. 
Co-Authored-By: Claude Opus 4.6 --- src/common/pullbacks.jl | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/src/common/pullbacks.jl b/src/common/pullbacks.jl index 4fe853cd..41834f5d 100644 --- a/src/common/pullbacks.jl +++ b/src/common/pullbacks.jl @@ -11,5 +11,22 @@ function iszerotangent end iszerotangent(::Any) = false iszerotangent(::Nothing) = true -# fallback -_sylvester(A, B, C) = LinearAlgebra.sylvester(A, B, C) +# Solve the Sylvester equation A*X + X*B + C = 0. +# When A === B (same Hermitian PD matrix, as in polar pullbacks), use an +# eigendecomposition-based solver to avoid LAPACK's trsyl! failing with +# LAPACKException(1) for close eigenvalues. +function _sylvester(A, B, C) + if A === B + return _sylvester_symm(A, C) + end + return LinearAlgebra.sylvester(A, B, C) +end + +function _sylvester_symm(P, C) + D, Q = LinearAlgebra.eigen(LinearAlgebra.Hermitian(P)) + Y = Q' * C * Q + @inbounds for j in axes(Y, 2), i in axes(Y, 1) + Y[i, j] = -Y[i, j] / (D[i] + D[j]) + end + return Q * Y * Q' +end From e2dce34facf91fdf0c69f3c880cef2ccc7153791 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Tue, 17 Feb 2026 08:47:48 -0500 Subject: [PATCH 03/15] Add AD rules for projection methods MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add rrules/pullbacks for `project_hermitian!`, `project_antihermitian!`, and `project_isometric!` directly in each AD backend extension (ChainRulesCore, Enzyme, Mooncake). The hermitian/antihermitian pullbacks are self-adjoint, while the isometric pullback delegates to `left_polar_pullback!` with zero ΔP. Includes test utilities and tests for all three backends. 
Co-Authored-By: Claude Opus 4.6 --- ext/MatrixAlgebraKitChainRulesCoreExt.jl | 42 ++ .../MatrixAlgebraKitEnzymeExt.jl | 87 +++ .../MatrixAlgebraKitMooncakeExt.jl | 78 +++ test/testsuite/ad_utils.jl | 24 + test/testsuite/chainrules.jl | 57 ++ test/testsuite/enzyme.jl | 39 ++ test/testsuite/mooncake.jl | 574 ++++++++++++++++++ 7 files changed, 901 insertions(+) create mode 100644 test/testsuite/mooncake.jl diff --git a/ext/MatrixAlgebraKitChainRulesCoreExt.jl b/ext/MatrixAlgebraKitChainRulesCoreExt.jl index 400b2a79..ead8ebf7 100644 --- a/ext/MatrixAlgebraKitChainRulesCoreExt.jl +++ b/ext/MatrixAlgebraKitChainRulesCoreExt.jl @@ -274,4 +274,46 @@ function ChainRulesCore.rrule(::typeof(right_polar!), A, PWᴴ, alg) return PWᴴ, right_polar_pullback end +function ChainRulesCore.rrule(::typeof(project_hermitian!), A, Aₕ, alg) + Ac = copy_input(project_hermitian, A) + Aₕ = project_hermitian!(Ac, Aₕ, alg) + function project_hermitian_pullback(ΔAₕ) + ΔA = project_hermitian(unthunk(ΔAₕ)) + return NoTangent(), ΔA, ZeroTangent(), NoTangent() + end + function project_hermitian_pullback(::ZeroTangent) + return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent() + end + return Aₕ, project_hermitian_pullback +end + +function ChainRulesCore.rrule(::typeof(project_antihermitian!), A, Aₐ, alg) + Ac = copy_input(project_antihermitian, A) + Aₐ = project_antihermitian!(Ac, Aₐ, alg) + function project_antihermitian_pullback(ΔAₐ) + ΔA = project_antihermitian(unthunk(ΔAₐ)) + return NoTangent(), ΔA, ZeroTangent(), NoTangent() + end + function project_antihermitian_pullback(::ZeroTangent) + return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent() + end + return Aₐ, project_antihermitian_pullback +end + +function ChainRulesCore.rrule(::typeof(project_isometric!), A, W, alg) + Ac = copy_input(project_isometric, A) + # Compute the full polar decomposition to cache P for the pullback + WP = left_polar!(Ac, (similar(W), similar(W, size(W, 2), size(W, 2))), alg) + W_out = copy!(W, WP[1]) + 
function project_isometric_pullback(ΔW) + ΔA = zero(A) + MatrixAlgebraKit.left_polar_pullback!(ΔA, A, WP, (unthunk(ΔW), nothing)) + return NoTangent(), ΔA, ZeroTangent(), NoTangent() + end + function project_isometric_pullback(::ZeroTangent) + return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent() + end + return W_out, project_isometric_pullback +end + end diff --git a/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl b/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl index 24a1fa5e..4b5e01e5 100644 --- a/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl +++ b/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl @@ -454,4 +454,91 @@ function EnzymeRules.reverse( return (nothing, nothing, nothing) end +# single-output projections: project_hermitian!, project_antihermitian! +# single-output projections: project_hermitian!, project_antihermitian! +for (f!, project_f) in ( + (project_hermitian!, project_hermitian), + (project_antihermitian!, project_antihermitian), + ) + @eval begin + function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof($f!)}, + ::Type{RT}, + A::Annotation, + arg::Annotation{TA}, + alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}, + ) where {RT, TA} + ret = func.val(A.val, arg.val, alg.val) + cache_arg = (arg.val !== ret) || EnzymeRules.overwritten(config)[3] ? copy(ret) : nothing + dret = if EnzymeRules.needs_shadow(config) + (TA == Nothing || isa(arg, Const)) ? zero(ret) : arg.dval + else + nothing + end + primal = EnzymeRules.needs_primal(config) ? 
ret : nothing + return EnzymeRules.AugmentedReturn(primal, dret, (cache_arg, dret)) + end + function EnzymeRules.reverse( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof($f!)}, + ::Type{RT}, + cache, + A::Annotation, + arg::Annotation, + alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}, + ) where {RT} + cache_arg, darg = cache + argdval = something(darg, arg.dval) + if !isa(A, Const) + A.dval .+= $project_f(argdval) + end + !isa(arg, Const) && make_zero!(arg.dval) + return (nothing, nothing, nothing) + end + end +end + +# project_isometric! needs special handling: compute full polar decomposition +function EnzymeRules.augmented_primal( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(project_isometric!)}, + ::Type{RT}, + A::Annotation, + W::Annotation{TW}, + alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}, + ) where {RT, TW} + # Compute the full polar decomposition; always tape a copy of (W, P): the reverse pass indexes the cached P unconditionally and W may be overwritten + Ac = copy(A.val) + m, n = size(A.val) + P = similar(A.val, n, n) + WP = left_polar!(Ac, (W.val, P), alg.val) + cache_WP = copy.(WP) + dret = if EnzymeRules.needs_shadow(config) + (TW == Nothing || isa(W, Const)) ? zero(WP[1]) : W.dval + else + nothing + end + primal = EnzymeRules.needs_primal(config) ? 
WP[1] : nothing + return EnzymeRules.AugmentedReturn(primal, dret, (cache_WP, dret)) +end +function EnzymeRules.reverse( + config::EnzymeRules.RevConfigWidth{1}, + func::Const{typeof(project_isometric!)}, + ::Type{RT}, + cache, + A::Annotation, + W::Annotation, + alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}, + ) where {RT} + cache_WP, dW = cache + Aval = nothing + WPval = something(cache_WP, (W.val, cache_WP[2])) + if !isa(A, Const) + left_polar_pullback!(A.dval, Aval, WPval, (dW, nothing)) + end + !isa(W, Const) && make_zero!(W.dval) + return (nothing, nothing, nothing) +end + end diff --git a/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl b/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl index f32e4258..e6cc72a8 100644 --- a/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl +++ b/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl @@ -778,4 +778,82 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc_no_error)}, A_dA::CoDual, al return USVᴴtrunc_dUSVᴴtrunc, svd_trunc_adjoint end +# single-output projections: project_hermitian!, project_antihermitian! +# single-output projections: project_hermitian!, project_antihermitian! 
+for (f!, f, adj) in ( + (:project_hermitian!, :project_hermitian, :project_hermitian_adjoint), + (:project_antihermitian!, :project_antihermitian, :project_antihermitian_adjoint), + ) + @eval begin + @is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm} + function Mooncake.rrule!!(f_df::CoDual{typeof($f!)}, A_dA::CoDual, arg_darg::CoDual, alg_dalg::CoDual{<:MatrixAlgebraKit.AbstractAlgorithm}) + A, dA = arrayify(A_dA) + Ac = copy(A) + arg, darg = arrayify(arg_darg) + argc = copy(arg) + $f!(A, arg, Mooncake.primal(alg_dalg)) + function $adj(::NoRData) + copy!(A, Ac) + dA .+= $f(darg) + copy!(arg, argc) + zero!(darg) + return NoRData(), NoRData(), NoRData(), NoRData() + end + return arg_darg, $adj + end + @is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), Any, MatrixAlgebraKit.AbstractAlgorithm} + function Mooncake.rrule!!(f_df::CoDual{typeof($f)}, A_dA::CoDual, alg_dalg::CoDual{<:MatrixAlgebraKit.AbstractAlgorithm}) + A, dA = arrayify(A_dA) + output = $f(A, Mooncake.primal(alg_dalg)) + output_codual = CoDual(output, Mooncake.zero_tangent(output)) + function $adj(::NoRData) + arg, darg = arrayify(output_codual) + dA .+= $f(darg) + zero!(darg) + return NoRData(), NoRData(), NoRData() + end + return output_codual, $adj + end + end +end + +# project_isometric! 
needs special handling: compute full polar decomposition +@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(project_isometric!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm} +function Mooncake.rrule!!(f_df::CoDual{typeof(project_isometric!)}, A_dA::CoDual, W_dW::CoDual, alg_dalg::CoDual{<:MatrixAlgebraKit.AbstractAlgorithm}) + A, dA = arrayify(A_dA) + W, dW = arrayify(W_dW) + Ac = copy(A) + Wc = copy(W) + # Compute the full polar decomposition for the pullback + m, n = size(A) + P = similar(A, n, n) + WP = left_polar!(copy(A), (copy(W), P), Mooncake.primal(alg_dalg)) + copy!(W, WP[1]) + function project_isometric_adjoint(::NoRData) + copy!(A, Ac) + left_polar_pullback!(dA, A, WP, (dW, nothing)) + copy!(W, Wc) + zero!(dW) + return NoRData(), NoRData(), NoRData(), NoRData() + end + return W_dW, project_isometric_adjoint +end + +@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(project_isometric), Any, MatrixAlgebraKit.AbstractAlgorithm} +function Mooncake.rrule!!(f_df::CoDual{typeof(project_isometric)}, A_dA::CoDual, alg_dalg::CoDual{<:MatrixAlgebraKit.AbstractAlgorithm}) + A, dA = arrayify(A_dA) + alg = Mooncake.primal(alg_dalg) + # Compute the full polar decomposition for the pullback + WP = left_polar(A, alg) + W_out = WP[1] + output_codual = CoDual(W_out, Mooncake.zero_tangent(W_out)) + function project_isometric_adjoint(::NoRData) + W, dW = arrayify(output_codual) + left_polar_pullback!(dA, A, WP, (dW, nothing)) + zero!(dW) + return NoRData(), NoRData(), NoRData() + end + return output_codual, project_isometric_adjoint +end + end diff --git a/test/testsuite/ad_utils.jl b/test/testsuite/ad_utils.jl index fce118a8..c2a953ec 100644 --- a/test/testsuite/ad_utils.jl +++ b/test/testsuite/ad_utils.jl @@ -537,3 +537,27 @@ function ad_right_null_setup(A) ΔNᴴ = randn!(similar(A, T, n - min(m, n), min(m, n))) * right_orth(A; alg = :lq)[2] return Nᴴ, ΔNᴴ end + +function ad_project_hermitian_setup(A) + m, n = size(A) + T = eltype(A) + Aₕ 
= project_hermitian(A) + ΔAₕ = randn!(similar(A, T, m, n)) + return Aₕ, ΔAₕ +end + +function ad_project_antihermitian_setup(A) + m, n = size(A) + T = eltype(A) + Aₐ = project_antihermitian(A) + ΔAₐ = randn!(similar(A, T, m, n)) + return Aₐ, ΔAₐ +end + +function ad_project_isometric_setup(A) + m, n = size(A) + T = eltype(A) + W = project_isometric(A) + ΔW = randn!(similar(A, T, m, n)) + return W, ΔW +end diff --git a/test/testsuite/chainrules.jl b/test/testsuite/chainrules.jl index bb94664d..3bab06b6 100644 --- a/test/testsuite/chainrules.jl +++ b/test/testsuite/chainrules.jl @@ -10,6 +10,7 @@ for f in :eig_trunc_no_error, :eigh_trunc_no_error, :svd_compact, :svd_trunc, :svd_trunc_no_error, :svd_vals, :left_polar, :right_polar, + :project_hermitian, :project_antihermitian, :project_isometric, ) copy_f = Symbol(:cr_copy_, f) f! = Symbol(f, '!') @@ -46,6 +47,7 @@ function test_chainrules(T::Type, sz; kwargs...) test_chainrules_svd(T, sz; kwargs...) test_chainrules_polar(T, sz; kwargs...) test_chainrules_orthnull(T, sz; kwargs...) + test_chainrules_projections(T, sz; kwargs...) end end @@ -587,3 +589,58 @@ function test_chainrules_orthnull( ) end end + +function test_chainrules_projections( + T::Type, sz; + atol::Real = 0, rtol::Real = precision(T), + kwargs... 
+ ) + summary_str = testargs_summary(T, sz) + return @testset "Projections Chainrules AD rules $summary_str" begin + A = instantiate_matrix(T, sz) + m, n = size(A) + config = Zygote.ZygoteRuleConfig() + if m == n + alg_h = MatrixAlgebraKit.default_hermitian_algorithm(A) + @testset "project_hermitian" begin + Aₕ, ΔAₕ = ad_project_hermitian_setup(A) + test_rrule( + cr_copy_project_hermitian, A, alg_h ⊢ NoTangent(); + output_tangent = ΔAₕ, atol = atol, rtol = rtol + ) + test_rrule( + config, project_hermitian, A; + output_tangent = ΔAₕ, + atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + end + @testset "project_antihermitian" begin + Aₐ, ΔAₐ = ad_project_antihermitian_setup(A) + test_rrule( + cr_copy_project_antihermitian, A, alg_h ⊢ NoTangent(); + output_tangent = ΔAₐ, atol = atol, rtol = rtol + ) + test_rrule( + config, project_antihermitian, A; + output_tangent = ΔAₐ, + atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + end + end + if m > n + @testset "project_isometric" begin + W, ΔW = ad_project_isometric_setup(A) + alg_iso = MatrixAlgebraKit.default_polar_algorithm(A) + test_rrule( + cr_copy_project_isometric, A, alg_iso ⊢ NoTangent(); + output_tangent = ΔW, atol = atol, rtol = rtol + ) + test_rrule( + config, project_isometric, A; + output_tangent = ΔW, + atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false + ) + end + end + end +end diff --git a/test/testsuite/enzyme.jl b/test/testsuite/enzyme.jl index e9b07a31..413a4171 100644 --- a/test/testsuite/enzyme.jl +++ b/test/testsuite/enzyme.jl @@ -105,6 +105,7 @@ function test_enzyme(T::Type, sz; kwargs...) test_enzyme_polar(T, sz; kwargs...) test_enzyme_orthnull(T, sz; kwargs...) end + test_enzyme_projections(T, sz; kwargs...) end end @@ -462,3 +463,41 @@ function test_enzyme_orthnull( end end end + +function test_enzyme_projections( + T::Type, sz; + atol::Real = 0, rtol::Real = precision(T), + kwargs... 
+ ) + summary_str = testargs_summary(T, sz) + return @testset "Projections Enzyme AD rules $summary_str" begin + A = instantiate_matrix(T, sz) + m, n = size(A) + fdm = eltype(T) <: Union{Float32, ComplexF32} ? EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range = 1.0e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1) + if m == n + @testset "project_hermitian" begin + @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + Aₕ, ΔAₕ = ad_project_hermitian_setup(A) + eltype(T) <: BlasFloat && test_reverse(project_hermitian, RT, (A, TA); atol, rtol, output_tangent = ΔAₕ, fdm) + is_cpu(A) && enz_test_pullbacks_match(rng, project_hermitian!, project_hermitian, A, Aₕ, ΔAₕ) + end + end + @testset "project_antihermitian" begin + @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + Aₐ, ΔAₐ = ad_project_antihermitian_setup(A) + eltype(T) <: BlasFloat && test_reverse(project_antihermitian, RT, (A, TA); atol, rtol, output_tangent = ΔAₐ, fdm) + is_cpu(A) && enz_test_pullbacks_match(rng, project_antihermitian!, project_antihermitian, A, Aₐ, ΔAₐ) + end + end + end + if m > n + @testset "project_isometric" begin + @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) + W, ΔW = ad_project_isometric_setup(A) + eltype(T) <: BlasFloat && test_reverse(project_isometric, RT, (A, TA); atol, rtol, output_tangent = ΔW, fdm) + is_cpu(A) && enz_test_pullbacks_match(rng, project_isometric!, project_isometric, A, W, ΔW) + end + end + end + end +end diff --git a/test/testsuite/mooncake.jl b/test/testsuite/mooncake.jl new file mode 100644 index 00000000..6a437c9b --- /dev/null +++ b/test/testsuite/mooncake.jl @@ -0,0 +1,574 @@ +using TestExtras +using MatrixAlgebraKit +using Mooncake, Mooncake.TestUtils +using Mooncake: rrule!! 
+using MatrixAlgebraKit: diagview, TruncatedAlgorithm, PolarViaSVD, eigh_trunc +using LinearAlgebra: BlasFloat +using GenericLinearAlgebra + +function mc_copy_eigh_full(A; kwargs...) + A = (A + A') / 2 + return eigh_full(A; kwargs...) +end + +function mc_copy_eigh_full!(A, DV; kwargs...) + A = (A + A') / 2 + return eigh_full!(A, DV; kwargs...) +end + +function mc_copy_eigh_vals(A; kwargs...) + A = (A + A') / 2 + return eigh_vals(A; kwargs...) +end + +function mc_copy_eigh_vals!(A, D; kwargs...) + A = (A + A') / 2 + return eigh_vals!(A, D; kwargs...) +end + +function mc_copy_eigh_trunc(A, alg; kwargs...) + A = (A + A') / 2 + return eigh_trunc(A, alg; kwargs...) +end + +function mc_copy_eigh_trunc!(A, DV, alg; kwargs...) + A = (A + A') / 2 + return eigh_trunc!(A, DV, alg; kwargs...) +end + +function mc_copy_eigh_trunc_no_error(A, alg; kwargs...) + A = (A + A') / 2 + return eigh_trunc_no_error(A, alg; kwargs...) +end + +function mc_copy_eigh_trunc_no_error!(A, DV, alg; kwargs...) + A = (A + A') / 2 + return eigh_trunc_no_error!(A, DV, alg; kwargs...) +end + +MatrixAlgebraKit.copy_input(::typeof(mc_copy_eigh_full), A) = MatrixAlgebraKit.copy_input(eigh_full, A) +MatrixAlgebraKit.copy_input(::typeof(mc_copy_eigh_vals), A) = MatrixAlgebraKit.copy_input(eigh_vals, A) +MatrixAlgebraKit.copy_input(::typeof(mc_copy_eigh_trunc), A) = MatrixAlgebraKit.copy_input(eigh_trunc, A) +MatrixAlgebraKit.copy_input(::typeof(mc_copy_eigh_trunc_no_error), A) = MatrixAlgebraKit.copy_input(eigh_trunc, A) + +make_mooncake_tangent(ΔAelem::T) where {T <: Number} = ΔAelem +make_mooncake_tangent(ΔA::AbstractMatrix) = ΔA +make_mooncake_tangent(ΔA::AbstractVector) = ΔA +make_mooncake_tangent(ΔD::Diagonal) = Mooncake.build_tangent(typeof(ΔD), diagview(ΔD)) + +make_mooncake_tangent(T::Tuple) = Mooncake.build_tangent(typeof(T), make_mooncake_tangent.(T)...) 
+ +make_mooncake_fdata(x) = make_mooncake_tangent(x) +make_mooncake_fdata(x::Diagonal) = Mooncake.FData((diag = make_mooncake_tangent(x.diag),)) +make_mooncake_fdata(x::Tuple) = map(make_mooncake_fdata, x) + +# copies a preset tangent into a Mooncake CoDual +# for use in the pullback. +function copy_tangent(var::Mooncake.CoDual, Δargs) + dargs = make_mooncake_fdata(deepcopy(Δargs)) + copyto!(Mooncake.tangent(var), dargs) + return +end + +function copy_tangent(var::Mooncake.CoDual, Δargs::Tuple) + dargs = make_mooncake_fdata.(deepcopy(Δargs)) + for (var_tangent, darg) in zip(Mooncake.tangent(var), dargs) + if var_tangent isa Mooncake.FData + for (var_f, darg_f) in zip(Mooncake._fields(var_tangent), Mooncake._fields(darg)) + copyto!(var_f, darg_f) + end + else + copyto!(var_tangent, darg) + end + end + return +end + +# no `alg` argument +function _get_copying_derivative(f, rrule, A, ΔA, args, Δargs, ::Nothing, rdata) + dA_copy = make_mooncake_fdata(copy(ΔA)) + A_copy = copy(A) + A_dA = Mooncake.CoDual(A_copy, dA_copy) + copy_out, copy_pb!! = rrule(Mooncake.CoDual(f, Mooncake.NoFData()), A_dA) + # copy Δargs into tangent of the output variable for the pullback check + copy_tangent(copy_out, Δargs) + copy_pb!!(rdata) + @test Mooncake.primal(A_dA) == A + return dA_copy, Mooncake.tangent(copy_out) +end + +# `alg` argument +function _get_copying_derivative(f, rrule, A, ΔA, args, Δargs, alg, rdata) + dA_copy = make_mooncake_fdata(copy(ΔA)) + A_copy = copy(A) + A_dA = Mooncake.CoDual(A_copy, dA_copy) + copy_out, copy_pb!! 
= rrule(Mooncake.CoDual(f, Mooncake.NoFData()), A_dA, Mooncake.CoDual(alg, Mooncake.NoFData())) + # copy Δargs into tangent of the output variable for the pullback check + copy_tangent(copy_out, Δargs) + copy_pb!!(rdata) + @test Mooncake.primal(A_dA) == A + return dA_copy, Mooncake.tangent(copy_out) +end + +function _get_inplace_derivative(f!, A, ΔA, args, Δargs, ::Nothing, rdata; ȳ = Δargs) + dA_inplace = make_mooncake_fdata(copy(ΔA)) + A_inplace = copy(A) + args_copy = deepcopy(args) + dargs_inplace = make_mooncake_fdata(deepcopy(Δargs)) + # not every f! has a handwritten rrule!! + inplace_sig = Tuple{typeof(f!), typeof(A), typeof(args)} + has_handwritten_rule = hasmethod(Mooncake.rrule!!, inplace_sig) + A_dA = Mooncake.CoDual(A_inplace, dA_inplace) + args_dargs = Mooncake.CoDual(args_copy, dargs_inplace) + if has_handwritten_rule + inplace_out, inplace_pb!! = Mooncake.rrule!!(Mooncake.CoDual(f!, Mooncake.NoFData()), A_dA, args_dargs) + else + inplace_sig = Tuple{typeof(f!), typeof(A), typeof(args)} + rvs_interp = Mooncake.get_interpreter(Mooncake.ReverseMode) + inplace_rrule = Mooncake.build_rrule(rvs_interp, inplace_sig) + inplace_out, inplace_pb!! = inplace_rrule(Mooncake.CoDual(f!, Mooncake.NoFData()), A_dA, args_dargs) + end + # copy reference derivative of output ȳ into inplace_out + # needed for inplace methods like svd_trunc! that generate + # new output variables + copy_tangent(inplace_out, ȳ) + inplace_pb!!(rdata) + @test Mooncake.primal(A_dA) == A + @test Mooncake.primal(args_dargs) == args_copy + return dA_inplace, Mooncake.tangent(inplace_out) +end + +function _get_inplace_derivative(f!, A, ΔA, args, Δargs, alg, rdata; ȳ = Δargs) + dA_inplace = make_mooncake_fdata(copy(ΔA)) + A_inplace = copy(A) + args_copy = deepcopy(args) + dargs_inplace = make_mooncake_fdata(deepcopy(Δargs)) + # not every f! has a handwritten rrule!! 
+ inplace_sig = Tuple{typeof(f!), typeof(A), typeof(args), typeof(alg)} + has_handwritten_rule = hasmethod(Mooncake.rrule!!, inplace_sig) + A_dA = Mooncake.CoDual(A_inplace, dA_inplace) + args_dargs = Mooncake.CoDual(args_copy, dargs_inplace) + if has_handwritten_rule + inplace_out, inplace_pb!! = Mooncake.rrule!!(Mooncake.CoDual(f!, Mooncake.NoFData()), A_dA, args_dargs, Mooncake.CoDual(alg, Mooncake.NoFData())) + else + inplace_sig = Tuple{typeof(f!), typeof(A), typeof(args), typeof(alg)} + rvs_interp = Mooncake.get_interpreter(Mooncake.ReverseMode) + inplace_rrule = Mooncake.build_rrule(rvs_interp, inplace_sig) + inplace_out, inplace_pb!! = inplace_rrule(Mooncake.CoDual(f!, Mooncake.NoFData()), A_dA, args_dargs, Mooncake.CoDual(alg, Mooncake.NoFData())) + end + # copy reference derivative of output ȳ into inplace_out + # needed for inplace methods like svd_trunc! that generate + # new output variables + copy_tangent(inplace_out, ȳ) + inplace_pb!!(rdata) + @test Mooncake.primal(A_dA) == A + @test Mooncake.primal(args_dargs) == args_copy + return dA_inplace, Mooncake.tangent(inplace_out) +end + +""" + test_pullbacks_match(f!, f, A, args, Δargs, alg = nothing; rdata = Mooncake.NoRData()) + +Compare the result of running the *in-place, mutating* function `f!`'s reverse rule +with the result of running its *non-mutating* partner function `f`'s reverse rule. +We must compare directly because many of the mutating functions modify `A` as a +scratch workspace, making testing `f!` against finite differences infeasible. + +The arguments to this function are: + - `f!` the mutating, in-place version of the function (accepts `args` for the function result) + - `f` the non-mutating version of the function (does not accept `args` for the function result) + - `A` the input matrix to factorize + - `args` preallocated output for `f!` (e.g. 
`Q` and `R` matrices for `qr_compact!`) + - `Δargs` precomputed derivatives of `args` for pullbacks of `f` and `f!`, to ensure they receive the same input + - `alg` optional algorithm keyword argument + - `rdata` Mooncake reverse data to supply to the pullback, in case `f` and `f!` return scalar results (as truncating functions do) +""" +function test_pullbacks_match(f!, f, A, args, Δargs, alg = nothing; rdata = Mooncake.NoRData(), ȳ = deepcopy(Δargs)) + sig = isnothing(alg) ? Tuple{typeof(f), typeof(A)} : Tuple{typeof(f), typeof(A), typeof(alg)} + rvs_interp = Mooncake.get_interpreter(Mooncake.ReverseMode) + rrule = Mooncake.build_rrule(rvs_interp, sig) + ΔA = randn(rng, eltype(A), size(A)) + + copy_args = isa(args, Tuple) ? copy.(args) : copy(args) + inplace_args = isa(args, Tuple) ? copy.(args) : copy(args) + dA_copy, dargs_copy = _get_copying_derivative(f, rrule, A, ΔA, copy_args, ȳ, alg, rdata) + dA_inplace, dargs_inplace = _get_inplace_derivative(f!, A, ΔA, inplace_args, Δargs, alg, rdata; ȳ) + + dA_inplace_ = Mooncake.arrayify(A, dA_inplace)[2] + dA_copy_ = Mooncake.arrayify(A, dA_copy)[2] + @test dA_inplace_ ≈ dA_copy_ + @test copy_args == inplace_args + if dargs_copy isa Tuple + for (darg_copy_, darg_inplace_) in zip(dargs_copy, dargs_inplace) + if darg_copy_ isa Mooncake.FData + for (c_f, i_f) in zip(Mooncake._fields(darg_copy_), Mooncake._fields(darg_inplace_)) + @test c_f == i_f + end + else + @test darg_copy_ == darg_inplace_ + end + end + else + @test dargs_copy == dargs_inplace + end + return +end + +function test_mooncake(T::Type, sz; kwargs...) + summary_str = testargs_summary(T, sz) + return @testset "Mooncake AD $summary_str" begin + test_mooncake_qr(T, sz; kwargs...) + test_mooncake_lq(T, sz; kwargs...) + if length(sz) == 1 || sz[1] == sz[2] + test_mooncake_eig(T, sz; kwargs...) + test_mooncake_eigh(T, sz; kwargs...) + end + test_mooncake_svd(T, sz; kwargs...) + test_mooncake_polar(T, sz; kwargs...) + # doesn't work for Diagonals yet? 
+ if T <: Number + test_mooncake_orthnull(T, sz; kwargs...) + end + test_mooncake_projections(T, sz; kwargs...) + end +end + +function test_mooncake_qr( + T::Type, sz; + atol::Real = 0, rtol::Real = precision(T), + kwargs... + ) + summary_str = testargs_summary(T, sz) + return @testset "QR Mooncake AD rules $summary_str" begin + A = instantiate_matrix(T, sz) + @testset "qr_compact" begin + QR, ΔQR = ad_qr_compact_setup(A) + dQR = make_mooncake_tangent(ΔQR) + Mooncake.TestUtils.test_rule(rng, qr_compact, A; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dQR, atol, rtol) + test_pullbacks_match(qr_compact!, qr_compact, A, QR, ΔQR) + end + @testset "qr_null" begin + N, ΔN = ad_qr_null_setup(A) + dN = make_mooncake_tangent(copy(ΔN)) + Mooncake.TestUtils.test_rule(rng, qr_null, A; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dN, atol, rtol) + test_pullbacks_match(qr_null!, qr_null, A, N, ΔN) + end + @testset "qr_full" begin + QR, ΔQR = ad_qr_full_setup(A) + dQR = make_mooncake_tangent(ΔQR) + Mooncake.TestUtils.test_rule(rng, qr_full, A; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dQR, atol, rtol) + test_pullbacks_match(qr_full!, qr_full, A, QR, ΔQR) + end + @testset "qr_compact - rank-deficient A" begin + m, n = size(A) + r = min(m, n) - 5 + Ard = instantiate_matrix(T, (m, r)) * instantiate_matrix(T, (r, n)) + QR, ΔQR = ad_qr_rank_deficient_compact_setup(Ard) + dQR = make_mooncake_tangent(ΔQR) + Mooncake.TestUtils.test_rule(rng, qr_compact, Ard; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dQR, atol, rtol) + test_pullbacks_match(qr_compact!, qr_compact, Ard, QR, ΔQR) + end + end +end + +function test_mooncake_lq( + T::Type, sz; + atol::Real = 0, rtol::Real = precision(T), + kwargs... 
+ ) + summary_str = testargs_summary(T, sz) + return @testset "LQ Mooncake AD rules $summary_str" begin + A = instantiate_matrix(T, sz) + @testset "lq_compact" begin + LQ, ΔLQ = ad_lq_compact_setup(A) + Mooncake.TestUtils.test_rule(rng, lq_compact, A; is_primitive = false, mode = Mooncake.ReverseMode, atol, rtol) + test_pullbacks_match(lq_compact!, lq_compact, A, LQ, ΔLQ) + end + @testset "lq_null" begin + Nᴴ, ΔNᴴ = ad_lq_null_setup(A) + dNᴴ = make_mooncake_tangent(ΔNᴴ) + Mooncake.TestUtils.test_rule(rng, lq_null, A; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dNᴴ, atol, rtol) + test_pullbacks_match(lq_null!, lq_null, A, Nᴴ, ΔNᴴ) + end + @testset "lq_full" begin + LQ, ΔLQ = ad_lq_full_setup(A) + dLQ = make_mooncake_tangent(ΔLQ) + Mooncake.TestUtils.test_rule(rng, lq_full, A; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dLQ, atol, rtol) + test_pullbacks_match(lq_full!, lq_full, A, LQ, ΔLQ) + end + @testset "lq_compact - rank-deficient A" begin + m, n = size(A) + r = min(m, n) - 5 + Ard = instantiate_matrix(T, (m, r)) * instantiate_matrix(T, (r, n)) + LQ, ΔLQ = ad_lq_rank_deficient_compact_setup(Ard) + dLQ = make_mooncake_tangent(ΔLQ) + Mooncake.TestUtils.test_rule(rng, lq_compact, Ard; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dLQ, atol, rtol) + test_pullbacks_match(lq_compact!, lq_compact, Ard, LQ, ΔLQ) + end + end +end + +function test_mooncake_eig( + T::Type, sz; + atol::Real = 0, rtol::Real = precision(T), + kwargs... 
+ ) + summary_str = testargs_summary(T, sz) + return @testset "EIG Mooncake AD rules $summary_str" begin + A = make_eig_matrix(T, sz) + m = size(A, 1) + @testset "eig_full" begin + DV, ΔDV, ΔD2V = ad_eig_full_setup(A) + dDV = make_mooncake_tangent(ΔD2V) + Mooncake.TestUtils.test_rule(rng, eig_full, A; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dDV, atol, rtol) + test_pullbacks_match(eig_full!, eig_full, A, DV, ΔD2V) + end + @testset "eig_vals" begin + D, ΔD = ad_eig_vals_setup(A) + dD = make_mooncake_tangent(ΔD) + Mooncake.TestUtils.test_rule(rng, eig_vals, A; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dD, atol, rtol) + test_pullbacks_match(eig_vals!, eig_vals, A, D, ΔD) + end + @testset "eig_trunc" begin + for r in 1:4:m + truncalg = TruncatedAlgorithm(MatrixAlgebraKit.default_eig_algorithm(A), truncrank(r; by = abs)) + DV, _, ΔDV, ΔDVtrunc = ad_eig_trunc_setup(A, truncalg) + ϵ = zero(real(eltype(T))) + dDVerr = make_mooncake_tangent((copy.(ΔDVtrunc)..., ϵ)) + Mooncake.TestUtils.test_rule(rng, eig_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVerr, atol, rtol) + test_pullbacks_match(eig_trunc!, eig_trunc, A, DV, ΔDV, truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), zero(real(eltype(T)))), ȳ = ΔDVtrunc) + dDVtrunc = make_mooncake_tangent(ΔDVtrunc) + Mooncake.TestUtils.test_rule(rng, eig_trunc_no_error, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol, rtol) + test_pullbacks_match(eig_trunc_no_error!, eig_trunc_no_error, A, DV, ΔDV, truncalg; ȳ = ΔDVtrunc) + end + truncalg = TruncatedAlgorithm(MatrixAlgebraKit.default_eig_algorithm(A), truncrank(5; by = real)) + DV, _, ΔDV, ΔDVtrunc = ad_eig_trunc_setup(A, truncalg) + ϵ = zero(real(eltype(T))) + dDVerr = make_mooncake_tangent((ΔDVtrunc..., ϵ)) + Mooncake.TestUtils.test_rule(rng, eig_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVerr, atol, rtol) + test_pullbacks_match(eig_trunc!, 
eig_trunc, A, DV, ΔDV, truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), zero(real(eltype(T)))), ȳ = ΔDVtrunc) + dDVtrunc = make_mooncake_tangent(ΔDVtrunc) + Mooncake.TestUtils.test_rule(rng, eig_trunc_no_error, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol, rtol) + test_pullbacks_match(eig_trunc_no_error!, eig_trunc_no_error, A, DV, ΔDV, truncalg; ȳ = ΔDVtrunc) + end + end +end + +function test_mooncake_eigh( + T::Type, sz; + atol::Real = 0, rtol::Real = precision(T), + kwargs... + ) + summary_str = testargs_summary(T, sz) + return @testset "EIGH Mooncake AD rules $summary_str" begin + A = make_eigh_matrix(T, sz) + m = size(A, 1) + @testset "eigh_full" begin + DV, ΔDV, ΔD2V = ad_eigh_full_setup(A) + dDV = make_mooncake_tangent(ΔD2V) + Mooncake.TestUtils.test_rule(rng, mc_copy_eigh_full, A; mode = Mooncake.ReverseMode, output_tangent = dDV, is_primitive = false, atol, rtol) + test_pullbacks_match(mc_copy_eigh_full!, mc_copy_eigh_full, A, DV, ΔD2V) + end + @testset "eigh_vals" begin + D, ΔD = ad_eigh_vals_setup(A) + dD = make_mooncake_tangent(ΔD) + Mooncake.TestUtils.test_rule(rng, mc_copy_eigh_vals, A; mode = Mooncake.ReverseMode, output_tangent = dD, is_primitive = false, atol, rtol) + test_pullbacks_match(mc_copy_eigh_vals!, mc_copy_eigh_vals, A, D, ΔD) + end + @testset "eigh_trunc" begin + for r in 1:4:m + truncalg = TruncatedAlgorithm(MatrixAlgebraKit.default_eigh_algorithm(A), truncrank(r; by = abs)) + DV, _, ΔDV, ΔDVtrunc = ad_eigh_trunc_setup(A, truncalg) + ϵ = zero(real(eltype(T))) + dDVerr = make_mooncake_tangent((ΔDVtrunc..., ϵ)) + Mooncake.TestUtils.test_rule(rng, mc_copy_eigh_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVerr, atol, rtol, is_primitive = false) + test_pullbacks_match(mc_copy_eigh_trunc!, mc_copy_eigh_trunc, A, DV, ΔDV, truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), zero(real(eltype(T)))), ȳ = ΔDVtrunc) + dDVtrunc = make_mooncake_tangent(ΔDVtrunc) + 
Mooncake.TestUtils.test_rule(rng, mc_copy_eigh_trunc_no_error, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol, rtol, is_primitive = false) + test_pullbacks_match(mc_copy_eigh_trunc_no_error!, mc_copy_eigh_trunc_no_error, A, DV, ΔDV, truncalg; ȳ = ΔDVtrunc) + end + D = eigh_vals(A / 2) + truncalg = TruncatedAlgorithm(MatrixAlgebraKit.default_eigh_algorithm(A), trunctol(; atol = maximum(abs, D) / 2)) + DV, _, ΔDV, ΔDVtrunc = ad_eigh_trunc_setup(A, truncalg) + ϵ = zero(real(eltype(T))) + dDVerr = make_mooncake_tangent((ΔDVtrunc..., ϵ)) + Mooncake.TestUtils.test_rule(rng, mc_copy_eigh_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVerr, atol, rtol, is_primitive = false) + test_pullbacks_match(mc_copy_eigh_trunc!, mc_copy_eigh_trunc, A, DV, ΔDV, truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), zero(real(eltype(T)))), ȳ = ΔDVtrunc) + dDVtrunc = make_mooncake_tangent(ΔDVtrunc) + Mooncake.TestUtils.test_rule(rng, mc_copy_eigh_trunc_no_error, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol, rtol, is_primitive = false) + test_pullbacks_match(mc_copy_eigh_trunc_no_error!, mc_copy_eigh_trunc_no_error, A, DV, ΔDV, truncalg; ȳ = ΔDVtrunc) + end + end +end + +function test_mooncake_svd( + T::Type, sz; + atol::Real = 0, rtol::Real = precision(T), + kwargs... + ) + summary_str = testargs_summary(T, sz) + return @testset "SVD Mooncake AD rules $summary_str" begin + A = instantiate_matrix(T, sz) + minmn = min(size(A)...) 
+ @testset "svd_compact" begin + USVᴴ, _, ΔUSVᴴ = ad_svd_compact_setup(A) + dUSVᴴ = make_mooncake_tangent(ΔUSVᴴ) + Mooncake.TestUtils.test_rule(rng, svd_compact, A; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dUSVᴴ, atol, rtol) + test_pullbacks_match(svd_compact!, svd_compact, A, USVᴴ, ΔUSVᴴ) + end + @testset "svd_full" begin + USVᴴ, ΔUSVᴴ = ad_svd_full_setup(A) + dUSVᴴ = make_mooncake_tangent(ΔUSVᴴ) + Mooncake.TestUtils.test_rule(rng, svd_full, A; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dUSVᴴ, atol, rtol) + test_pullbacks_match(svd_full!, svd_full, A, USVᴴ, ΔUSVᴴ) + end + @testset "svd_vals" begin + S, ΔS = ad_svd_vals_setup(A) + Mooncake.TestUtils.test_rule(rng, svd_vals, A; is_primitive = false, mode = Mooncake.ReverseMode, atol, rtol) + test_pullbacks_match(svd_vals!, svd_vals, A, S, ΔS) + end + @testset "svd_trunc" begin + @testset for r in 1:4:minmn + truncalg = TruncatedAlgorithm(MatrixAlgebraKit.default_svd_algorithm(A), truncrank(r)) + USVᴴ, ΔUSVᴴ, ΔUSVᴴtrunc = ad_svd_trunc_setup(A, truncalg) + ϵ = zero(real(eltype(T))) + dUSVᴴerr = make_mooncake_tangent((ΔUSVᴴtrunc..., ϵ)) + Mooncake.TestUtils.test_rule(rng, svd_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dUSVᴴerr, atol, rtol) + test_pullbacks_match(svd_trunc!, svd_trunc, A, USVᴴ, ΔUSVᴴ, truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), zero(real(eltype(T)))), ȳ = ΔUSVᴴtrunc) + dUSVᴴ = make_mooncake_tangent(ΔUSVᴴtrunc) + Mooncake.TestUtils.test_rule(rng, svd_trunc_no_error, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dUSVᴴ, atol, rtol) + test_pullbacks_match(svd_trunc_no_error!, svd_trunc_no_error, A, USVᴴ, ΔUSVᴴ, truncalg; ȳ = ΔUSVᴴtrunc) + end + @testset "trunctol" begin + A = instantiate_matrix(T, sz) + S, ΔS = ad_svd_vals_setup(A) + truncalg = TruncatedAlgorithm(MatrixAlgebraKit.default_svd_algorithm(A), trunctol(atol = S[1, 1] / 2)) + USVᴴ, ΔUSVᴴ, ΔUSVᴴtrunc = ad_svd_trunc_setup(A, 
truncalg) + ϵ = zero(real(eltype(T))) + dUSVᴴerr = make_mooncake_tangent((ΔUSVᴴtrunc..., ϵ)) + Mooncake.TestUtils.test_rule(rng, svd_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dUSVᴴerr, atol, rtol) + test_pullbacks_match(svd_trunc!, svd_trunc, A, USVᴴ, ΔUSVᴴ, truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), zero(real(eltype(T)))), ȳ = ΔUSVᴴtrunc) + dUSVᴴ = make_mooncake_tangent(ΔUSVᴴtrunc) + Mooncake.TestUtils.test_rule(rng, svd_trunc_no_error, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dUSVᴴ, atol, rtol) + test_pullbacks_match(svd_trunc_no_error!, svd_trunc_no_error, A, USVᴴ, ΔUSVᴴ, truncalg; ȳ = ΔUSVᴴtrunc) + end + end + end +end + +function test_mooncake_polar( + T::Type, sz; + atol::Real = 0, rtol::Real = precision(T), + kwargs... + ) + summary_str = testargs_summary(T, sz) + return @testset "Polar Mooncake AD rules $summary_str" begin + A = instantiate_matrix(T, sz) + m, n = size(A) + @testset "left_polar" begin + if m >= n + WP, ΔWP = ad_left_polar_setup(A) + Mooncake.TestUtils.test_rule(rng, left_polar, A; is_primitive = false, mode = Mooncake.ReverseMode, atol, rtol) + test_pullbacks_match(left_polar!, left_polar, A, WP, ΔWP) + end + end + @testset "right_polar" begin + if m <= n + PWᴴ, ΔPWᴴ = ad_right_polar_setup(A) + Mooncake.TestUtils.test_rule(rng, right_polar, A; is_primitive = false, mode = Mooncake.ReverseMode, atol, rtol) + test_pullbacks_match(right_polar!, right_polar, A, PWᴴ, ΔPWᴴ) + end + end + end +end + +left_orth_qr(X) = left_orth(X; alg = :qr) +left_orth_polar(X) = left_orth(X; alg = :polar) +left_null_qr(X) = left_null(X; alg = :qr) +right_orth_lq(X) = right_orth(X; alg = :lq) +right_orth_polar(X) = right_orth(X; alg = :polar) +right_null_lq(X) = right_null(X; alg = :lq) + +MatrixAlgebraKit.copy_input(::typeof(left_orth_qr), A) = MatrixAlgebraKit.copy_input(left_orth, A) +MatrixAlgebraKit.copy_input(::typeof(left_orth_polar), A) = MatrixAlgebraKit.copy_input(left_orth, A) 
+MatrixAlgebraKit.copy_input(::typeof(left_null_qr), A) = MatrixAlgebraKit.copy_input(left_null, A) +MatrixAlgebraKit.copy_input(::typeof(right_orth_lq), A) = MatrixAlgebraKit.copy_input(right_orth, A) +MatrixAlgebraKit.copy_input(::typeof(right_orth_polar), A) = MatrixAlgebraKit.copy_input(right_orth, A) +MatrixAlgebraKit.copy_input(::typeof(right_null_lq), A) = MatrixAlgebraKit.copy_input(right_null, A) + +function test_mooncake_orthnull( + T::Type, sz; + atol::Real = 0, rtol::Real = precision(T), + kwargs... + ) + summary_str = testargs_summary(T, sz) + return @testset "Orthnull Mooncake AD rules $summary_str" begin + A = instantiate_matrix(T, sz) + m, n = size(A) + VC, ΔVC = ad_left_orth_setup(A) + CVᴴ, ΔCVᴴ = ad_right_orth_setup(A) + Mooncake.TestUtils.test_rule(rng, left_orth, A; mode = Mooncake.ReverseMode, atol, rtol, is_primitive = false) + test_pullbacks_match(left_orth!, left_orth, A, VC, ΔVC) + Mooncake.TestUtils.test_rule(rng, right_orth, A; mode = Mooncake.ReverseMode, atol, rtol, is_primitive = false) + test_pullbacks_match(right_orth!, right_orth, A, CVᴴ, ΔCVᴴ) + + Mooncake.TestUtils.test_rule(rng, left_orth_qr, A; mode = Mooncake.ReverseMode, atol, rtol, is_primitive = false) + test_pullbacks_match(((X, VC) -> left_orth!(X, VC; alg = :qr)), left_orth_qr, A, VC, ΔVC) + if m >= n + Mooncake.TestUtils.test_rule(rng, left_orth_polar, A; mode = Mooncake.ReverseMode, atol, rtol, is_primitive = false) + test_pullbacks_match(((X, VC) -> left_orth!(X, VC; alg = :polar)), left_orth_polar, A, VC, ΔVC) + end + + N, ΔN = ad_left_null_setup(A) + dN = make_mooncake_tangent(ΔN) + Mooncake.TestUtils.test_rule(rng, left_null_qr, A; mode = Mooncake.ReverseMode, atol, rtol, is_primitive = false, output_tangent = dN) + test_pullbacks_match(((X, N) -> left_null!(X, N; alg = :qr)), left_null_qr, A, N, ΔN) + + Mooncake.TestUtils.test_rule(rng, right_orth_lq, A; mode = Mooncake.ReverseMode, atol, rtol, is_primitive = false) + test_pullbacks_match(((X, CVᴴ) -> 
right_orth!(X, CVᴴ; alg = :lq)), right_orth_lq, A, CVᴴ, ΔCVᴴ) + + if m <= n + Mooncake.TestUtils.test_rule(rng, right_orth_polar, A; mode = Mooncake.ReverseMode, atol, rtol, is_primitive = false) + test_pullbacks_match(((X, CVᴴ) -> right_orth!(X, CVᴴ; alg = :polar)), right_orth_polar, A, CVᴴ, ΔCVᴴ) + end + + Nᴴ, ΔNᴴ = ad_right_null_setup(A) + dNᴴ = make_mooncake_tangent(ΔNᴴ) + Mooncake.TestUtils.test_rule(rng, right_null_lq, A; mode = Mooncake.ReverseMode, atol, rtol, is_primitive = false, output_tangent = dNᴴ) + test_pullbacks_match(((X, Nᴴ) -> right_null!(X, Nᴴ; alg = :lq)), right_null_lq, A, Nᴴ, ΔNᴴ) + end +end + +function test_mooncake_projections( + T::Type, sz; + atol::Real = 0, rtol::Real = precision(T), + kwargs... + ) + summary_str = testargs_summary(T, sz) + return @testset "Projections Mooncake AD rules $summary_str" begin + A = instantiate_matrix(T, sz) + m, n = size(A) + if m == n + @testset "project_hermitian" begin + Aₕ, ΔAₕ = ad_project_hermitian_setup(A) + dAₕ = make_mooncake_tangent(ΔAₕ) + Mooncake.TestUtils.test_rule(rng, project_hermitian, A; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dAₕ, atol, rtol) + test_pullbacks_match(project_hermitian!, project_hermitian, A, Aₕ, ΔAₕ) + end + @testset "project_antihermitian" begin + Aₐ, ΔAₐ = ad_project_antihermitian_setup(A) + dAₐ = make_mooncake_tangent(ΔAₐ) + Mooncake.TestUtils.test_rule(rng, project_antihermitian, A; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dAₐ, atol, rtol) + test_pullbacks_match(project_antihermitian!, project_antihermitian, A, Aₐ, ΔAₐ) + end + end + if m > n + @testset "project_isometric" begin + W, ΔW = ad_project_isometric_setup(A) + dW = make_mooncake_tangent(ΔW) + Mooncake.TestUtils.test_rule(rng, project_isometric, A; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dW, atol, rtol) + test_pullbacks_match(project_isometric!, project_isometric, A, W, ΔW) + end + end + end +end From 
fda5129645bacdad1ecd82e70807898ab1d9a807 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Tue, 17 Feb 2026 09:48:34 -0500 Subject: [PATCH 04/15] simplify implementations --- ext/MatrixAlgebraKitChainRulesCoreExt.jl | 36 ++-------- .../MatrixAlgebraKitMooncakeExt.jl | 66 ++++--------------- 2 files changed, 20 insertions(+), 82 deletions(-) diff --git a/ext/MatrixAlgebraKitChainRulesCoreExt.jl b/ext/MatrixAlgebraKitChainRulesCoreExt.jl index ead8ebf7..acfe0d83 100644 --- a/ext/MatrixAlgebraKitChainRulesCoreExt.jl +++ b/ext/MatrixAlgebraKitChainRulesCoreExt.jl @@ -274,46 +274,22 @@ function ChainRulesCore.rrule(::typeof(right_polar!), A, PWᴴ, alg) return PWᴴ, right_polar_pullback end -function ChainRulesCore.rrule(::typeof(project_hermitian!), A, Aₕ, alg) - Ac = copy_input(project_hermitian, A) - Aₕ = project_hermitian!(Ac, Aₕ, alg) +function ChainRulesCore.rrule(::typeof(project_hermitian), A, alg) + Aₕ = project_hermitian(A, alg) function project_hermitian_pullback(ΔAₕ) ΔA = project_hermitian(unthunk(ΔAₕ)) - return NoTangent(), ΔA, ZeroTangent(), NoTangent() - end - function project_hermitian_pullback(::ZeroTangent) - return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent() + return NoTangent(), ΔA, NoTangent() end return Aₕ, project_hermitian_pullback end -function ChainRulesCore.rrule(::typeof(project_antihermitian!), A, Aₐ, alg) - Ac = copy_input(project_antihermitian, A) - Aₐ = project_antihermitian!(Ac, Aₐ, alg) +function ChainRulesCore.rrule(::typeof(project_antihermitian), A, alg) + Aₐ = project_antihermitian(A, alg) function project_antihermitian_pullback(ΔAₐ) ΔA = project_antihermitian(unthunk(ΔAₐ)) - return NoTangent(), ΔA, ZeroTangent(), NoTangent() - end - function project_antihermitian_pullback(::ZeroTangent) - return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent() + return NoTangent(), ΔA, NoTangent() end return Aₐ, project_antihermitian_pullback end -function ChainRulesCore.rrule(::typeof(project_isometric!), A, W, alg) - Ac = 
copy_input(project_isometric, A) - # Compute the full polar decomposition to cache P for the pullback - WP = left_polar!(Ac, (similar(W), similar(W, size(W, 2), size(W, 2))), alg) - W_out = copy!(W, WP[1]) - function project_isometric_pullback(ΔW) - ΔA = zero(A) - MatrixAlgebraKit.left_polar_pullback!(ΔA, A, WP, (unthunk(ΔW), nothing)) - return NoTangent(), ΔA, ZeroTangent(), NoTangent() - end - function project_isometric_pullback(::ZeroTangent) - return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent() - end - return W_out, project_isometric_pullback -end - end diff --git a/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl b/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl index e6cc72a8..93400128 100644 --- a/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl +++ b/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl @@ -778,82 +778,44 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc_no_error)}, A_dA::CoDual, al return USVᴴtrunc_dUSVᴴtrunc, svd_trunc_adjoint end -# single-output projections: project_hermitian!, project_antihermitian! # single-output projections: project_hermitian!, project_antihermitian! for (f!, f, adj) in ( (:project_hermitian!, :project_hermitian, :project_hermitian_adjoint), (:project_antihermitian!, :project_antihermitian, :project_antihermitian_adjoint), ) @eval begin - @is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm} + @is_primitive DefaultCtx Mooncake.ReverseMode Tuple{typeof($f!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm} function Mooncake.rrule!!(f_df::CoDual{typeof($f!)}, A_dA::CoDual, arg_darg::CoDual, alg_dalg::CoDual{<:MatrixAlgebraKit.AbstractAlgorithm}) A, dA = arrayify(A_dA) - Ac = copy(A) - arg, darg = arrayify(arg_darg) + arg, darg = A_dA === arg_darg ? 
(A, dA) : arrayify(arg_darg) argc = copy(arg) - $f!(A, arg, Mooncake.primal(alg_dalg)) + arg = $f!(A, arg, Mooncake.primal(alg_dalg)) + function $adj(::NoRData) - copy!(A, Ac) dA .+= $f(darg) + dA === darg || zero!(darg) copy!(arg, argc) - zero!(darg) return NoRData(), NoRData(), NoRData(), NoRData() end return arg_darg, $adj end - @is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), Any, MatrixAlgebraKit.AbstractAlgorithm} + + @is_primitive DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), Any, MatrixAlgebraKit.AbstractAlgorithm} function Mooncake.rrule!!(f_df::CoDual{typeof($f)}, A_dA::CoDual, alg_dalg::CoDual{<:MatrixAlgebraKit.AbstractAlgorithm}) A, dA = arrayify(A_dA) output = $f(A, Mooncake.primal(alg_dalg)) - output_codual = CoDual(output, Mooncake.zero_tangent(output)) + output_doutput = Mooncake.zero_fcodual(output) + + doutput = last(arrayify(output_doutput)) function $adj(::NoRData) - arg, darg = arrayify(output_codual) - dA .+= $f(darg) - zero!(darg) - return NoRData(), NoRData(), NoRData() + # TODO: need accumulating projection to avoid intermediate here + dA .+= $f(doutput) + return ntuple(Returns(NoRData(), 3)) end + return output_codual, $adj end end end -# project_isometric! 
needs special handling: compute full polar decomposition -@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(project_isometric!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm} -function Mooncake.rrule!!(f_df::CoDual{typeof(project_isometric!)}, A_dA::CoDual, W_dW::CoDual, alg_dalg::CoDual{<:MatrixAlgebraKit.AbstractAlgorithm}) - A, dA = arrayify(A_dA) - W, dW = arrayify(W_dW) - Ac = copy(A) - Wc = copy(W) - # Compute the full polar decomposition for the pullback - m, n = size(A) - P = similar(A, n, n) - WP = left_polar!(copy(A), (copy(W), P), Mooncake.primal(alg_dalg)) - copy!(W, WP[1]) - function project_isometric_adjoint(::NoRData) - copy!(A, Ac) - left_polar_pullback!(dA, A, WP, (dW, nothing)) - copy!(W, Wc) - zero!(dW) - return NoRData(), NoRData(), NoRData(), NoRData() - end - return W_dW, project_isometric_adjoint -end - -@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(project_isometric), Any, MatrixAlgebraKit.AbstractAlgorithm} -function Mooncake.rrule!!(f_df::CoDual{typeof(project_isometric)}, A_dA::CoDual, alg_dalg::CoDual{<:MatrixAlgebraKit.AbstractAlgorithm}) - A, dA = arrayify(A_dA) - alg = Mooncake.primal(alg_dalg) - # Compute the full polar decomposition for the pullback - WP = left_polar(A, alg) - W_out = WP[1] - output_codual = CoDual(W_out, Mooncake.zero_tangent(W_out)) - function project_isometric_adjoint(::NoRData) - W, dW = arrayify(output_codual) - left_polar_pullback!(dA, A, WP, (dW, nothing)) - zero!(dW) - return NoRData(), NoRData(), NoRData() - end - return output_codual, project_isometric_adjoint -end - end From 3e73c0c356d0f1225dbb70654f67c5786a6c5224 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Tue, 17 Feb 2026 09:53:44 -0500 Subject: [PATCH 05/15] remove enzyme --- .../MatrixAlgebraKitEnzymeExt.jl | 87 ------------------- 1 file changed, 87 deletions(-) diff --git a/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl b/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl index 
4b5e01e5..24a1fa5e 100644 --- a/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl +++ b/ext/MatrixAlgebraKitEnzymeExt/MatrixAlgebraKitEnzymeExt.jl @@ -454,91 +454,4 @@ function EnzymeRules.reverse( return (nothing, nothing, nothing) end -# single-output projections: project_hermitian!, project_antihermitian! -# single-output projections: project_hermitian!, project_antihermitian! -for (f!, project_f) in ( - (project_hermitian!, project_hermitian), - (project_antihermitian!, project_antihermitian), - ) - @eval begin - function EnzymeRules.augmented_primal( - config::EnzymeRules.RevConfigWidth{1}, - func::Const{typeof($f!)}, - ::Type{RT}, - A::Annotation, - arg::Annotation{TA}, - alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}, - ) where {RT, TA} - ret = func.val(A.val, arg.val, alg.val) - cache_arg = (arg.val !== ret) || EnzymeRules.overwritten(config)[3] ? copy(ret) : nothing - dret = if EnzymeRules.needs_shadow(config) - (TA == Nothing || isa(arg, Const)) ? zero(ret) : arg.dval - else - nothing - end - primal = EnzymeRules.needs_primal(config) ? ret : nothing - return EnzymeRules.AugmentedReturn(primal, dret, (cache_arg, dret)) - end - function EnzymeRules.reverse( - config::EnzymeRules.RevConfigWidth{1}, - func::Const{typeof($f!)}, - ::Type{RT}, - cache, - A::Annotation, - arg::Annotation, - alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}, - ) where {RT} - cache_arg, darg = cache - argdval = something(darg, arg.dval) - if !isa(A, Const) - A.dval .+= $project_f(argdval) - end - !isa(arg, Const) && make_zero!(arg.dval) - return (nothing, nothing, nothing) - end - end -end - -# project_isometric! 
needs special handling: compute full polar decomposition -function EnzymeRules.augmented_primal( - config::EnzymeRules.RevConfigWidth{1}, - func::Const{typeof(project_isometric!)}, - ::Type{RT}, - A::Annotation, - W::Annotation{TW}, - alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}, - ) where {RT, TW} - # Compute the full polar decomposition for the pullback - Ac = copy(A.val) - m, n = size(A.val) - P = similar(A.val, n, n) - WP = left_polar!(Ac, (W.val, P), alg.val) - cache_WP = EnzymeRules.overwritten(config)[3] ? copy.(WP) : nothing - dret = if EnzymeRules.needs_shadow(config) - (TW == Nothing || isa(W, Const)) ? zero(WP[1]) : W.dval - else - nothing - end - primal = EnzymeRules.needs_primal(config) ? WP[1] : nothing - return EnzymeRules.AugmentedReturn(primal, dret, (cache_WP, dret)) -end -function EnzymeRules.reverse( - config::EnzymeRules.RevConfigWidth{1}, - func::Const{typeof(project_isometric!)}, - ::Type{RT}, - cache, - A::Annotation, - W::Annotation, - alg::Const{<:MatrixAlgebraKit.AbstractAlgorithm}, - ) where {RT} - cache_WP, dW = cache - Aval = nothing - WPval = something(cache_WP, (W.val, cache_WP[2])) - if !isa(A, Const) - left_polar_pullback!(A.dval, Aval, WPval, (dW, nothing)) - end - !isa(W, Const) && make_zero!(W.dval) - return (nothing, nothing, nothing) -end - end From 219040edecee6c0a391037ba9e2f580cd1780cad Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Tue, 17 Feb 2026 09:58:26 -0500 Subject: [PATCH 06/15] simplify chainrules tests --- test/testsuite/ad_utils.jl | 24 --------------------- test/testsuite/chainrules.jl | 42 ++++-------------------------------- 2 files changed, 4 insertions(+), 62 deletions(-) diff --git a/test/testsuite/ad_utils.jl b/test/testsuite/ad_utils.jl index c2a953ec..fce118a8 100644 --- a/test/testsuite/ad_utils.jl +++ b/test/testsuite/ad_utils.jl @@ -537,27 +537,3 @@ function ad_right_null_setup(A) ΔNᴴ = randn!(similar(A, T, n - min(m, n), min(m, n))) * right_orth(A; alg = :lq)[2] return Nᴴ, ΔNᴴ end - -function 
ad_project_hermitian_setup(A) - m, n = size(A) - T = eltype(A) - Aₕ = project_hermitian(A) - ΔAₕ = randn!(similar(A, T, m, n)) - return Aₕ, ΔAₕ -end - -function ad_project_antihermitian_setup(A) - m, n = size(A) - T = eltype(A) - Aₐ = project_antihermitian(A) - ΔAₐ = randn!(similar(A, T, m, n)) - return Aₐ, ΔAₐ -end - -function ad_project_isometric_setup(A) - m, n = size(A) - T = eltype(A) - W = project_isometric(A) - ΔW = randn!(similar(A, T, m, n)) - return W, ΔW -end diff --git a/test/testsuite/chainrules.jl b/test/testsuite/chainrules.jl index 3bab06b6..52806750 100644 --- a/test/testsuite/chainrules.jl +++ b/test/testsuite/chainrules.jl @@ -10,7 +10,6 @@ for f in :eig_trunc_no_error, :eigh_trunc_no_error, :svd_compact, :svd_trunc, :svd_trunc_no_error, :svd_vals, :left_polar, :right_polar, - :project_hermitian, :project_antihermitian, :project_isometric, ) copy_f = Symbol(:cr_copy_, f) f! = Symbol(f, '!') @@ -599,47 +598,14 @@ function test_chainrules_projections( return @testset "Projections Chainrules AD rules $summary_str" begin A = instantiate_matrix(T, sz) m, n = size(A) - config = Zygote.ZygoteRuleConfig() if m == n - alg_h = MatrixAlgebraKit.default_hermitian_algorithm(A) @testset "project_hermitian" begin - Aₕ, ΔAₕ = ad_project_hermitian_setup(A) - test_rrule( - cr_copy_project_hermitian, A, alg_h ⊢ NoTangent(); - output_tangent = ΔAₕ, atol = atol, rtol = rtol - ) - test_rrule( - config, project_hermitian, A; - output_tangent = ΔAₕ, - atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) + alg = MatrixAlgebraKit.default_hermitian_algorithm(A) + test_rrule(project_hermitian, A, alg; atol, rtol) end @testset "project_antihermitian" begin - Aₐ, ΔAₐ = ad_project_antihermitian_setup(A) - test_rrule( - cr_copy_project_antihermitian, A, alg_h ⊢ NoTangent(); - output_tangent = ΔAₐ, atol = atol, rtol = rtol - ) - test_rrule( - config, project_antihermitian, A; - output_tangent = ΔAₐ, - atol = atol, rtol = rtol, rrule_f = rrule_via_ad, 
check_inferred = false - ) - end - end - if m > n - @testset "project_isometric" begin - W, ΔW = ad_project_isometric_setup(A) - alg_iso = MatrixAlgebraKit.default_polar_algorithm(A) - test_rrule( - cr_copy_project_isometric, A, alg_iso ⊢ NoTangent(); - output_tangent = ΔW, atol = atol, rtol = rtol - ) - test_rrule( - config, project_isometric, A; - output_tangent = ΔW, - atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false - ) + alg = MatrixAlgebraKit.default_hermitian_algorithm(A) + test_rrule(project_antihermitian, A, alg; atol, rtol) end end end From 5d92319684f7de4e33af531ff3f083ba8655a4e3 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Tue, 17 Feb 2026 10:15:49 -0500 Subject: [PATCH 07/15] simplify mooncake tests --- .../MatrixAlgebraKitMooncakeExt.jl | 7 ++++--- test/testsuite/mooncake.jl | 20 ++++++------------- 2 files changed, 10 insertions(+), 17 deletions(-) diff --git a/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl b/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl index 93400128..ca98a4fb 100644 --- a/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl +++ b/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl @@ -795,7 +795,7 @@ for (f!, f, adj) in ( dA .+= $f(darg) dA === darg || zero!(darg) copy!(arg, argc) - return NoRData(), NoRData(), NoRData(), NoRData() + return ntuple(Returns(NoRData()), 4) end return arg_darg, $adj end @@ -810,10 +810,11 @@ for (f!, f, adj) in ( function $adj(::NoRData) # TODO: need accumulating projection to avoid intermediate here dA .+= $f(doutput) - return ntuple(Returns(NoRData(), 3)) + zero!(doutput) + return ntuple(Returns(NoRData()), 3) end - return output_codual, $adj + return output_doutput, $adj end end end diff --git a/test/testsuite/mooncake.jl b/test/testsuite/mooncake.jl index 6a437c9b..9ca869d2 100644 --- a/test/testsuite/mooncake.jl +++ b/test/testsuite/mooncake.jl @@ -550,25 +550,17 @@ function test_mooncake_projections( m, n = size(A) 
if m == n @testset "project_hermitian" begin - Aₕ, ΔAₕ = ad_project_hermitian_setup(A) - dAₕ = make_mooncake_tangent(ΔAₕ) - Mooncake.TestUtils.test_rule(rng, project_hermitian, A; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dAₕ, atol, rtol) + Aₕ = project_hermitian(A) + ΔAₕ = make_mooncake_tangent(Aₕ) + Mooncake.TestUtils.test_rule(rng, project_hermitian, A; is_primitive = false, mode = Mooncake.ReverseMode, atol, rtol) test_pullbacks_match(project_hermitian!, project_hermitian, A, Aₕ, ΔAₕ) end @testset "project_antihermitian" begin - Aₐ, ΔAₐ = ad_project_antihermitian_setup(A) - dAₐ = make_mooncake_tangent(ΔAₐ) - Mooncake.TestUtils.test_rule(rng, project_antihermitian, A; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dAₐ, atol, rtol) + Aₐ = project_antihermitian(A) + ΔAₐ = make_mooncake_tangent(Aₐ) + Mooncake.TestUtils.test_rule(rng, project_antihermitian, A; is_primitive = false, mode = Mooncake.ReverseMode, atol, rtol) test_pullbacks_match(project_antihermitian!, project_antihermitian, A, Aₐ, ΔAₐ) end end - if m > n - @testset "project_isometric" begin - W, ΔW = ad_project_isometric_setup(A) - dW = make_mooncake_tangent(ΔW) - Mooncake.TestUtils.test_rule(rng, project_isometric, A; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dW, atol, rtol) - test_pullbacks_match(project_isometric!, project_isometric, A, W, ΔW) - end - end end end From 3b282ea90feaf68c1a3250ca2c3837150b775d64 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Tue, 17 Feb 2026 10:17:31 -0500 Subject: [PATCH 08/15] revert changes --- src/common/pullbacks.jl | 21 ++------------------- test/testsuite/enzyme.jl | 39 --------------------------------------- 2 files changed, 2 insertions(+), 58 deletions(-) diff --git a/src/common/pullbacks.jl b/src/common/pullbacks.jl index 41834f5d..4fe853cd 100644 --- a/src/common/pullbacks.jl +++ b/src/common/pullbacks.jl @@ -11,22 +11,5 @@ function iszerotangent end iszerotangent(::Any) = false 
iszerotangent(::Nothing) = true -# Solve the Sylvester equation A*X + X*B + C = 0. -# When A === B (same Hermitian PD matrix, as in polar pullbacks), use an -# eigendecomposition-based solver to avoid LAPACK's trsyl! failing with -# LAPACKException(1) for close eigenvalues. -function _sylvester(A, B, C) - if A === B - return _sylvester_symm(A, C) - end - return LinearAlgebra.sylvester(A, B, C) -end - -function _sylvester_symm(P, C) - D, Q = LinearAlgebra.eigen(LinearAlgebra.Hermitian(P)) - Y = Q' * C * Q - @inbounds for j in axes(Y, 2), i in axes(Y, 1) - Y[i, j] = -Y[i, j] / (D[i] + D[j]) - end - return Q * Y * Q' -end +# fallback +_sylvester(A, B, C) = LinearAlgebra.sylvester(A, B, C) diff --git a/test/testsuite/enzyme.jl b/test/testsuite/enzyme.jl index 413a4171..e9b07a31 100644 --- a/test/testsuite/enzyme.jl +++ b/test/testsuite/enzyme.jl @@ -105,7 +105,6 @@ function test_enzyme(T::Type, sz; kwargs...) test_enzyme_polar(T, sz; kwargs...) test_enzyme_orthnull(T, sz; kwargs...) end - test_enzyme_projections(T, sz; kwargs...) end end @@ -463,41 +462,3 @@ function test_enzyme_orthnull( end end end - -function test_enzyme_projections( - T::Type, sz; - atol::Real = 0, rtol::Real = precision(T), - kwargs... - ) - summary_str = testargs_summary(T, sz) - return @testset "Projections Enzyme AD rules $summary_str" begin - A = instantiate_matrix(T, sz) - m, n = size(A) - fdm = eltype(T) <: Union{Float32, ComplexF32} ? 
EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1, max_range = 1.0e-2) : EnzymeTestUtils.FiniteDifferences.central_fdm(5, 1) - if m == n - @testset "project_hermitian" begin - @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) - Aₕ, ΔAₕ = ad_project_hermitian_setup(A) - eltype(T) <: BlasFloat && test_reverse(project_hermitian, RT, (A, TA); atol, rtol, output_tangent = ΔAₕ, fdm) - is_cpu(A) && enz_test_pullbacks_match(rng, project_hermitian!, project_hermitian, A, Aₕ, ΔAₕ) - end - end - @testset "project_antihermitian" begin - @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) - Aₐ, ΔAₐ = ad_project_antihermitian_setup(A) - eltype(T) <: BlasFloat && test_reverse(project_antihermitian, RT, (A, TA); atol, rtol, output_tangent = ΔAₐ, fdm) - is_cpu(A) && enz_test_pullbacks_match(rng, project_antihermitian!, project_antihermitian, A, Aₐ, ΔAₐ) - end - end - end - if m > n - @testset "project_isometric" begin - @testset "reverse: RT $RT, TA $TA" for RT in (Duplicated,), TA in (Duplicated,) - W, ΔW = ad_project_isometric_setup(A) - eltype(T) <: BlasFloat && test_reverse(project_isometric, RT, (A, TA); atol, rtol, output_tangent = ΔW, fdm) - is_cpu(A) && enz_test_pullbacks_match(rng, project_isometric!, project_isometric, A, W, ΔW) - end - end - end - end -end From f3bfc541a2780a76ce531d4b7d41414c4d247d9f Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Wed, 18 Feb 2026 07:33:46 -0500 Subject: [PATCH 09/15] possibly fix implementation --- .../MatrixAlgebraKitMooncakeExt.jl | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl b/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl index ca98a4fb..d605a33e 100644 --- a/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl +++ b/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl @@ -792,11 +792,15 @@ for (f!, f, adj) in ( arg = $f!(A, arg, 
Mooncake.primal(alg_dalg)) function $adj(::NoRData) - dA .+= $f(darg) - dA === darg || zero!(darg) + $f!(darg) + if dA !== darg + dA .+= darg + zero!(darg) + end copy!(arg, argc) return ntuple(Returns(NoRData()), 4) end + return arg_darg, $adj end From 89bed774b8434fe56401b5f4287bed4ee5f8084c Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Sun, 22 Feb 2026 16:17:49 -0500 Subject: [PATCH 10/15] refactor projections tests --- test/mooncake/projections.jl | 19 + test/testsuite/TestSuite.jl | 3 +- test/testsuite/mooncake.jl | 566 ------------------------- test/testsuite/mooncake/projections.jl | 61 +++ 4 files changed, 81 insertions(+), 568 deletions(-) create mode 100644 test/mooncake/projections.jl delete mode 100644 test/testsuite/mooncake.jl create mode 100644 test/testsuite/mooncake/projections.jl diff --git a/test/mooncake/projections.jl b/test/mooncake/projections.jl new file mode 100644 index 00000000..547edd84 --- /dev/null +++ b/test/mooncake/projections.jl @@ -0,0 +1,19 @@ +using MatrixAlgebraKit +using Test +using LinearAlgebra: Diagonal +using CUDA, AMDGPU + +BLASFloats = (Float32, ComplexF64) # full suite is too expensive on CI +GenericFloats = () +@isdefined(TestSuite) || include("../testsuite/TestSuite.jl") +using .TestSuite + +is_buildkite = get(ENV, "BUILDKITE", "false") == "true" + +m = 19 +for T in (BLASFloats..., GenericFloats...) 
+ TestSuite.seed_rng!(123) + if !is_buildkite + TestSuite.test_mooncake_projections(T, (m, m); atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T)) + end +end diff --git a/test/testsuite/TestSuite.jl b/test/testsuite/TestSuite.jl index 2edd0846..d3c98e9a 100644 --- a/test/testsuite/TestSuite.jl +++ b/test/testsuite/TestSuite.jl @@ -93,8 +93,6 @@ end include("ad_utils.jl") -include("projections.jl") - # Decompositions # -------------- include("decompositions/qr.jl") @@ -116,6 +114,7 @@ include("mooncake/eigh.jl") include("mooncake/svd.jl") include("mooncake/polar.jl") include("mooncake/orthnull.jl") +include("mooncake/projections.jl") include("enzyme.jl") include("chainrules.jl") diff --git a/test/testsuite/mooncake.jl b/test/testsuite/mooncake.jl deleted file mode 100644 index 9ca869d2..00000000 --- a/test/testsuite/mooncake.jl +++ /dev/null @@ -1,566 +0,0 @@ -using TestExtras -using MatrixAlgebraKit -using Mooncake, Mooncake.TestUtils -using Mooncake: rrule!! -using MatrixAlgebraKit: diagview, TruncatedAlgorithm, PolarViaSVD, eigh_trunc -using LinearAlgebra: BlasFloat -using GenericLinearAlgebra - -function mc_copy_eigh_full(A; kwargs...) - A = (A + A') / 2 - return eigh_full(A; kwargs...) -end - -function mc_copy_eigh_full!(A, DV; kwargs...) - A = (A + A') / 2 - return eigh_full!(A, DV; kwargs...) -end - -function mc_copy_eigh_vals(A; kwargs...) - A = (A + A') / 2 - return eigh_vals(A; kwargs...) -end - -function mc_copy_eigh_vals!(A, D; kwargs...) - A = (A + A') / 2 - return eigh_vals!(A, D; kwargs...) -end - -function mc_copy_eigh_trunc(A, alg; kwargs...) - A = (A + A') / 2 - return eigh_trunc(A, alg; kwargs...) -end - -function mc_copy_eigh_trunc!(A, DV, alg; kwargs...) - A = (A + A') / 2 - return eigh_trunc!(A, DV, alg; kwargs...) -end - -function mc_copy_eigh_trunc_no_error(A, alg; kwargs...) - A = (A + A') / 2 - return eigh_trunc_no_error(A, alg; kwargs...) -end - -function mc_copy_eigh_trunc_no_error!(A, DV, alg; kwargs...) 
- A = (A + A') / 2 - return eigh_trunc_no_error!(A, DV, alg; kwargs...) -end - -MatrixAlgebraKit.copy_input(::typeof(mc_copy_eigh_full), A) = MatrixAlgebraKit.copy_input(eigh_full, A) -MatrixAlgebraKit.copy_input(::typeof(mc_copy_eigh_vals), A) = MatrixAlgebraKit.copy_input(eigh_vals, A) -MatrixAlgebraKit.copy_input(::typeof(mc_copy_eigh_trunc), A) = MatrixAlgebraKit.copy_input(eigh_trunc, A) -MatrixAlgebraKit.copy_input(::typeof(mc_copy_eigh_trunc_no_error), A) = MatrixAlgebraKit.copy_input(eigh_trunc, A) - -make_mooncake_tangent(ΔAelem::T) where {T <: Number} = ΔAelem -make_mooncake_tangent(ΔA::AbstractMatrix) = ΔA -make_mooncake_tangent(ΔA::AbstractVector) = ΔA -make_mooncake_tangent(ΔD::Diagonal) = Mooncake.build_tangent(typeof(ΔD), diagview(ΔD)) - -make_mooncake_tangent(T::Tuple) = Mooncake.build_tangent(typeof(T), make_mooncake_tangent.(T)...) - -make_mooncake_fdata(x) = make_mooncake_tangent(x) -make_mooncake_fdata(x::Diagonal) = Mooncake.FData((diag = make_mooncake_tangent(x.diag),)) -make_mooncake_fdata(x::Tuple) = map(make_mooncake_fdata, x) - -# copies a preset tangent into a Mooncake CoDual -# for use in the pullback. -function copy_tangent(var::Mooncake.CoDual, Δargs) - dargs = make_mooncake_fdata(deepcopy(Δargs)) - copyto!(Mooncake.tangent(var), dargs) - return -end - -function copy_tangent(var::Mooncake.CoDual, Δargs::Tuple) - dargs = make_mooncake_fdata.(deepcopy(Δargs)) - for (var_tangent, darg) in zip(Mooncake.tangent(var), dargs) - if var_tangent isa Mooncake.FData - for (var_f, darg_f) in zip(Mooncake._fields(var_tangent), Mooncake._fields(darg)) - copyto!(var_f, darg_f) - end - else - copyto!(var_tangent, darg) - end - end - return -end - -# no `alg` argument -function _get_copying_derivative(f, rrule, A, ΔA, args, Δargs, ::Nothing, rdata) - dA_copy = make_mooncake_fdata(copy(ΔA)) - A_copy = copy(A) - A_dA = Mooncake.CoDual(A_copy, dA_copy) - copy_out, copy_pb!! 
= rrule(Mooncake.CoDual(f, Mooncake.NoFData()), A_dA) - # copy Δargs into tangent of the output variable for the pullback check - copy_tangent(copy_out, Δargs) - copy_pb!!(rdata) - @test Mooncake.primal(A_dA) == A - return dA_copy, Mooncake.tangent(copy_out) -end - -# `alg` argument -function _get_copying_derivative(f, rrule, A, ΔA, args, Δargs, alg, rdata) - dA_copy = make_mooncake_fdata(copy(ΔA)) - A_copy = copy(A) - A_dA = Mooncake.CoDual(A_copy, dA_copy) - copy_out, copy_pb!! = rrule(Mooncake.CoDual(f, Mooncake.NoFData()), A_dA, Mooncake.CoDual(alg, Mooncake.NoFData())) - # copy Δargs into tangent of the output variable for the pullback check - copy_tangent(copy_out, Δargs) - copy_pb!!(rdata) - @test Mooncake.primal(A_dA) == A - return dA_copy, Mooncake.tangent(copy_out) -end - -function _get_inplace_derivative(f!, A, ΔA, args, Δargs, ::Nothing, rdata; ȳ = Δargs) - dA_inplace = make_mooncake_fdata(copy(ΔA)) - A_inplace = copy(A) - args_copy = deepcopy(args) - dargs_inplace = make_mooncake_fdata(deepcopy(Δargs)) - # not every f! has a handwritten rrule!! - inplace_sig = Tuple{typeof(f!), typeof(A), typeof(args)} - has_handwritten_rule = hasmethod(Mooncake.rrule!!, inplace_sig) - A_dA = Mooncake.CoDual(A_inplace, dA_inplace) - args_dargs = Mooncake.CoDual(args_copy, dargs_inplace) - if has_handwritten_rule - inplace_out, inplace_pb!! = Mooncake.rrule!!(Mooncake.CoDual(f!, Mooncake.NoFData()), A_dA, args_dargs) - else - inplace_sig = Tuple{typeof(f!), typeof(A), typeof(args)} - rvs_interp = Mooncake.get_interpreter(Mooncake.ReverseMode) - inplace_rrule = Mooncake.build_rrule(rvs_interp, inplace_sig) - inplace_out, inplace_pb!! = inplace_rrule(Mooncake.CoDual(f!, Mooncake.NoFData()), A_dA, args_dargs) - end - # copy reference derivative of output ȳ into inplace_out - # needed for inplace methods like svd_trunc! 
that generate - # new output variables - copy_tangent(inplace_out, ȳ) - inplace_pb!!(rdata) - @test Mooncake.primal(A_dA) == A - @test Mooncake.primal(args_dargs) == args_copy - return dA_inplace, Mooncake.tangent(inplace_out) -end - -function _get_inplace_derivative(f!, A, ΔA, args, Δargs, alg, rdata; ȳ = Δargs) - dA_inplace = make_mooncake_fdata(copy(ΔA)) - A_inplace = copy(A) - args_copy = deepcopy(args) - dargs_inplace = make_mooncake_fdata(deepcopy(Δargs)) - # not every f! has a handwritten rrule!! - inplace_sig = Tuple{typeof(f!), typeof(A), typeof(args), typeof(alg)} - has_handwritten_rule = hasmethod(Mooncake.rrule!!, inplace_sig) - A_dA = Mooncake.CoDual(A_inplace, dA_inplace) - args_dargs = Mooncake.CoDual(args_copy, dargs_inplace) - if has_handwritten_rule - inplace_out, inplace_pb!! = Mooncake.rrule!!(Mooncake.CoDual(f!, Mooncake.NoFData()), A_dA, args_dargs, Mooncake.CoDual(alg, Mooncake.NoFData())) - else - inplace_sig = Tuple{typeof(f!), typeof(A), typeof(args), typeof(alg)} - rvs_interp = Mooncake.get_interpreter(Mooncake.ReverseMode) - inplace_rrule = Mooncake.build_rrule(rvs_interp, inplace_sig) - inplace_out, inplace_pb!! = inplace_rrule(Mooncake.CoDual(f!, Mooncake.NoFData()), A_dA, args_dargs, Mooncake.CoDual(alg, Mooncake.NoFData())) - end - # copy reference derivative of output ȳ into inplace_out - # needed for inplace methods like svd_trunc! that generate - # new output variables - copy_tangent(inplace_out, ȳ) - inplace_pb!!(rdata) - @test Mooncake.primal(A_dA) == A - @test Mooncake.primal(args_dargs) == args_copy - return dA_inplace, Mooncake.tangent(inplace_out) -end - -""" - test_pullbacks_match(f!, f, A, args, Δargs, alg = nothing; rdata = Mooncake.NoRData()) - -Compare the result of running the *in-place, mutating* function `f!`'s reverse rule -with the result of running its *non-mutating* partner function `f`'s reverse rule. 
-We must compare directly because many of the mutating functions modify `A` as a -scratch workspace, making testing `f!` against finite differences infeasible. - -The arguments to this function are: - - `f!` the mutating, in-place version of the function (accepts `args` for the function result) - - `f` the non-mutating version of the function (does not accept `args` for the function result) - - `A` the input matrix to factorize - - `args` preallocated output for `f!` (e.g. `Q` and `R` matrices for `qr_compact!`) - - `Δargs` precomputed derivatives of `args` for pullbacks of `f` and `f!`, to ensure they receive the same input - - `alg` optional algorithm keyword argument - - `rdata` Mooncake reverse data to supply to the pullback, in case `f` and `f!` return scalar results (as truncating functions do) -""" -function test_pullbacks_match(f!, f, A, args, Δargs, alg = nothing; rdata = Mooncake.NoRData(), ȳ = deepcopy(Δargs)) - sig = isnothing(alg) ? Tuple{typeof(f), typeof(A)} : Tuple{typeof(f), typeof(A), typeof(alg)} - rvs_interp = Mooncake.get_interpreter(Mooncake.ReverseMode) - rrule = Mooncake.build_rrule(rvs_interp, sig) - ΔA = randn(rng, eltype(A), size(A)) - - copy_args = isa(args, Tuple) ? copy.(args) : copy(args) - inplace_args = isa(args, Tuple) ? 
copy.(args) : copy(args) - dA_copy, dargs_copy = _get_copying_derivative(f, rrule, A, ΔA, copy_args, ȳ, alg, rdata) - dA_inplace, dargs_inplace = _get_inplace_derivative(f!, A, ΔA, inplace_args, Δargs, alg, rdata; ȳ) - - dA_inplace_ = Mooncake.arrayify(A, dA_inplace)[2] - dA_copy_ = Mooncake.arrayify(A, dA_copy)[2] - @test dA_inplace_ ≈ dA_copy_ - @test copy_args == inplace_args - if dargs_copy isa Tuple - for (darg_copy_, darg_inplace_) in zip(dargs_copy, dargs_inplace) - if darg_copy_ isa Mooncake.FData - for (c_f, i_f) in zip(Mooncake._fields(darg_copy_), Mooncake._fields(darg_inplace_)) - @test c_f == i_f - end - else - @test darg_copy_ == darg_inplace_ - end - end - else - @test dargs_copy == dargs_inplace - end - return -end - -function test_mooncake(T::Type, sz; kwargs...) - summary_str = testargs_summary(T, sz) - return @testset "Mooncake AD $summary_str" begin - test_mooncake_qr(T, sz; kwargs...) - test_mooncake_lq(T, sz; kwargs...) - if length(sz) == 1 || sz[1] == sz[2] - test_mooncake_eig(T, sz; kwargs...) - test_mooncake_eigh(T, sz; kwargs...) - end - test_mooncake_svd(T, sz; kwargs...) - test_mooncake_polar(T, sz; kwargs...) - # doesn't work for Diagonals yet? - if T <: Number - test_mooncake_orthnull(T, sz; kwargs...) - end - test_mooncake_projections(T, sz; kwargs...) - end -end - -function test_mooncake_qr( - T::Type, sz; - atol::Real = 0, rtol::Real = precision(T), - kwargs... 
- ) - summary_str = testargs_summary(T, sz) - return @testset "QR Mooncake AD rules $summary_str" begin - A = instantiate_matrix(T, sz) - @testset "qr_compact" begin - QR, ΔQR = ad_qr_compact_setup(A) - dQR = make_mooncake_tangent(ΔQR) - Mooncake.TestUtils.test_rule(rng, qr_compact, A; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dQR, atol, rtol) - test_pullbacks_match(qr_compact!, qr_compact, A, QR, ΔQR) - end - @testset "qr_null" begin - N, ΔN = ad_qr_null_setup(A) - dN = make_mooncake_tangent(copy(ΔN)) - Mooncake.TestUtils.test_rule(rng, qr_null, A; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dN, atol, rtol) - test_pullbacks_match(qr_null!, qr_null, A, N, ΔN) - end - @testset "qr_full" begin - QR, ΔQR = ad_qr_full_setup(A) - dQR = make_mooncake_tangent(ΔQR) - Mooncake.TestUtils.test_rule(rng, qr_full, A; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dQR, atol, rtol) - test_pullbacks_match(qr_full!, qr_full, A, QR, ΔQR) - end - @testset "qr_compact - rank-deficient A" begin - m, n = size(A) - r = min(m, n) - 5 - Ard = instantiate_matrix(T, (m, r)) * instantiate_matrix(T, (r, n)) - QR, ΔQR = ad_qr_rank_deficient_compact_setup(Ard) - dQR = make_mooncake_tangent(ΔQR) - Mooncake.TestUtils.test_rule(rng, qr_compact, Ard; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dQR, atol, rtol) - test_pullbacks_match(qr_compact!, qr_compact, Ard, QR, ΔQR) - end - end -end - -function test_mooncake_lq( - T::Type, sz; - atol::Real = 0, rtol::Real = precision(T), - kwargs... 
- ) - summary_str = testargs_summary(T, sz) - return @testset "LQ Mooncake AD rules $summary_str" begin - A = instantiate_matrix(T, sz) - @testset "lq_compact" begin - LQ, ΔLQ = ad_lq_compact_setup(A) - Mooncake.TestUtils.test_rule(rng, lq_compact, A; is_primitive = false, mode = Mooncake.ReverseMode, atol, rtol) - test_pullbacks_match(lq_compact!, lq_compact, A, LQ, ΔLQ) - end - @testset "lq_null" begin - Nᴴ, ΔNᴴ = ad_lq_null_setup(A) - dNᴴ = make_mooncake_tangent(ΔNᴴ) - Mooncake.TestUtils.test_rule(rng, lq_null, A; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dNᴴ, atol, rtol) - test_pullbacks_match(lq_null!, lq_null, A, Nᴴ, ΔNᴴ) - end - @testset "lq_full" begin - LQ, ΔLQ = ad_lq_full_setup(A) - dLQ = make_mooncake_tangent(ΔLQ) - Mooncake.TestUtils.test_rule(rng, lq_full, A; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dLQ, atol, rtol) - test_pullbacks_match(lq_full!, lq_full, A, LQ, ΔLQ) - end - @testset "lq_compact - rank-deficient A" begin - m, n = size(A) - r = min(m, n) - 5 - Ard = instantiate_matrix(T, (m, r)) * instantiate_matrix(T, (r, n)) - LQ, ΔLQ = ad_lq_rank_deficient_compact_setup(Ard) - dLQ = make_mooncake_tangent(ΔLQ) - Mooncake.TestUtils.test_rule(rng, lq_compact, Ard; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dLQ, atol, rtol) - test_pullbacks_match(lq_compact!, lq_compact, Ard, LQ, ΔLQ) - end - end -end - -function test_mooncake_eig( - T::Type, sz; - atol::Real = 0, rtol::Real = precision(T), - kwargs... 
- ) - summary_str = testargs_summary(T, sz) - return @testset "EIG Mooncake AD rules $summary_str" begin - A = make_eig_matrix(T, sz) - m = size(A, 1) - @testset "eig_full" begin - DV, ΔDV, ΔD2V = ad_eig_full_setup(A) - dDV = make_mooncake_tangent(ΔD2V) - Mooncake.TestUtils.test_rule(rng, eig_full, A; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dDV, atol, rtol) - test_pullbacks_match(eig_full!, eig_full, A, DV, ΔD2V) - end - @testset "eig_vals" begin - D, ΔD = ad_eig_vals_setup(A) - dD = make_mooncake_tangent(ΔD) - Mooncake.TestUtils.test_rule(rng, eig_vals, A; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dD, atol, rtol) - test_pullbacks_match(eig_vals!, eig_vals, A, D, ΔD) - end - @testset "eig_trunc" begin - for r in 1:4:m - truncalg = TruncatedAlgorithm(MatrixAlgebraKit.default_eig_algorithm(A), truncrank(r; by = abs)) - DV, _, ΔDV, ΔDVtrunc = ad_eig_trunc_setup(A, truncalg) - ϵ = zero(real(eltype(T))) - dDVerr = make_mooncake_tangent((copy.(ΔDVtrunc)..., ϵ)) - Mooncake.TestUtils.test_rule(rng, eig_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVerr, atol, rtol) - test_pullbacks_match(eig_trunc!, eig_trunc, A, DV, ΔDV, truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), zero(real(eltype(T)))), ȳ = ΔDVtrunc) - dDVtrunc = make_mooncake_tangent(ΔDVtrunc) - Mooncake.TestUtils.test_rule(rng, eig_trunc_no_error, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol, rtol) - test_pullbacks_match(eig_trunc_no_error!, eig_trunc_no_error, A, DV, ΔDV, truncalg; ȳ = ΔDVtrunc) - end - truncalg = TruncatedAlgorithm(MatrixAlgebraKit.default_eig_algorithm(A), truncrank(5; by = real)) - DV, _, ΔDV, ΔDVtrunc = ad_eig_trunc_setup(A, truncalg) - ϵ = zero(real(eltype(T))) - dDVerr = make_mooncake_tangent((ΔDVtrunc..., ϵ)) - Mooncake.TestUtils.test_rule(rng, eig_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVerr, atol, rtol) - test_pullbacks_match(eig_trunc!, 
eig_trunc, A, DV, ΔDV, truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), zero(real(eltype(T)))), ȳ = ΔDVtrunc) - dDVtrunc = make_mooncake_tangent(ΔDVtrunc) - Mooncake.TestUtils.test_rule(rng, eig_trunc_no_error, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol, rtol) - test_pullbacks_match(eig_trunc_no_error!, eig_trunc_no_error, A, DV, ΔDV, truncalg; ȳ = ΔDVtrunc) - end - end -end - -function test_mooncake_eigh( - T::Type, sz; - atol::Real = 0, rtol::Real = precision(T), - kwargs... - ) - summary_str = testargs_summary(T, sz) - return @testset "EIGH Mooncake AD rules $summary_str" begin - A = make_eigh_matrix(T, sz) - m = size(A, 1) - @testset "eigh_full" begin - DV, ΔDV, ΔD2V = ad_eigh_full_setup(A) - dDV = make_mooncake_tangent(ΔD2V) - Mooncake.TestUtils.test_rule(rng, mc_copy_eigh_full, A; mode = Mooncake.ReverseMode, output_tangent = dDV, is_primitive = false, atol, rtol) - test_pullbacks_match(mc_copy_eigh_full!, mc_copy_eigh_full, A, DV, ΔD2V) - end - @testset "eigh_vals" begin - D, ΔD = ad_eigh_vals_setup(A) - dD = make_mooncake_tangent(ΔD) - Mooncake.TestUtils.test_rule(rng, mc_copy_eigh_vals, A; mode = Mooncake.ReverseMode, output_tangent = dD, is_primitive = false, atol, rtol) - test_pullbacks_match(mc_copy_eigh_vals!, mc_copy_eigh_vals, A, D, ΔD) - end - @testset "eigh_trunc" begin - for r in 1:4:m - truncalg = TruncatedAlgorithm(MatrixAlgebraKit.default_eigh_algorithm(A), truncrank(r; by = abs)) - DV, _, ΔDV, ΔDVtrunc = ad_eigh_trunc_setup(A, truncalg) - ϵ = zero(real(eltype(T))) - dDVerr = make_mooncake_tangent((ΔDVtrunc..., ϵ)) - Mooncake.TestUtils.test_rule(rng, mc_copy_eigh_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVerr, atol, rtol, is_primitive = false) - test_pullbacks_match(mc_copy_eigh_trunc!, mc_copy_eigh_trunc, A, DV, ΔDV, truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), zero(real(eltype(T)))), ȳ = ΔDVtrunc) - dDVtrunc = make_mooncake_tangent(ΔDVtrunc) - 
Mooncake.TestUtils.test_rule(rng, mc_copy_eigh_trunc_no_error, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol, rtol, is_primitive = false) - test_pullbacks_match(mc_copy_eigh_trunc_no_error!, mc_copy_eigh_trunc_no_error, A, DV, ΔDV, truncalg; ȳ = ΔDVtrunc) - end - D = eigh_vals(A / 2) - truncalg = TruncatedAlgorithm(MatrixAlgebraKit.default_eigh_algorithm(A), trunctol(; atol = maximum(abs, D) / 2)) - DV, _, ΔDV, ΔDVtrunc = ad_eigh_trunc_setup(A, truncalg) - ϵ = zero(real(eltype(T))) - dDVerr = make_mooncake_tangent((ΔDVtrunc..., ϵ)) - Mooncake.TestUtils.test_rule(rng, mc_copy_eigh_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVerr, atol, rtol, is_primitive = false) - test_pullbacks_match(mc_copy_eigh_trunc!, mc_copy_eigh_trunc, A, DV, ΔDV, truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), zero(real(eltype(T)))), ȳ = ΔDVtrunc) - dDVtrunc = make_mooncake_tangent(ΔDVtrunc) - Mooncake.TestUtils.test_rule(rng, mc_copy_eigh_trunc_no_error, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol, rtol, is_primitive = false) - test_pullbacks_match(mc_copy_eigh_trunc_no_error!, mc_copy_eigh_trunc_no_error, A, DV, ΔDV, truncalg; ȳ = ΔDVtrunc) - end - end -end - -function test_mooncake_svd( - T::Type, sz; - atol::Real = 0, rtol::Real = precision(T), - kwargs... - ) - summary_str = testargs_summary(T, sz) - return @testset "SVD Mooncake AD rules $summary_str" begin - A = instantiate_matrix(T, sz) - minmn = min(size(A)...) 
- @testset "svd_compact" begin - USVᴴ, _, ΔUSVᴴ = ad_svd_compact_setup(A) - dUSVᴴ = make_mooncake_tangent(ΔUSVᴴ) - Mooncake.TestUtils.test_rule(rng, svd_compact, A; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dUSVᴴ, atol, rtol) - test_pullbacks_match(svd_compact!, svd_compact, A, USVᴴ, ΔUSVᴴ) - end - @testset "svd_full" begin - USVᴴ, ΔUSVᴴ = ad_svd_full_setup(A) - dUSVᴴ = make_mooncake_tangent(ΔUSVᴴ) - Mooncake.TestUtils.test_rule(rng, svd_full, A; is_primitive = false, mode = Mooncake.ReverseMode, output_tangent = dUSVᴴ, atol, rtol) - test_pullbacks_match(svd_full!, svd_full, A, USVᴴ, ΔUSVᴴ) - end - @testset "svd_vals" begin - S, ΔS = ad_svd_vals_setup(A) - Mooncake.TestUtils.test_rule(rng, svd_vals, A; is_primitive = false, mode = Mooncake.ReverseMode, atol, rtol) - test_pullbacks_match(svd_vals!, svd_vals, A, S, ΔS) - end - @testset "svd_trunc" begin - @testset for r in 1:4:minmn - truncalg = TruncatedAlgorithm(MatrixAlgebraKit.default_svd_algorithm(A), truncrank(r)) - USVᴴ, ΔUSVᴴ, ΔUSVᴴtrunc = ad_svd_trunc_setup(A, truncalg) - ϵ = zero(real(eltype(T))) - dUSVᴴerr = make_mooncake_tangent((ΔUSVᴴtrunc..., ϵ)) - Mooncake.TestUtils.test_rule(rng, svd_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dUSVᴴerr, atol, rtol) - test_pullbacks_match(svd_trunc!, svd_trunc, A, USVᴴ, ΔUSVᴴ, truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), zero(real(eltype(T)))), ȳ = ΔUSVᴴtrunc) - dUSVᴴ = make_mooncake_tangent(ΔUSVᴴtrunc) - Mooncake.TestUtils.test_rule(rng, svd_trunc_no_error, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dUSVᴴ, atol, rtol) - test_pullbacks_match(svd_trunc_no_error!, svd_trunc_no_error, A, USVᴴ, ΔUSVᴴ, truncalg; ȳ = ΔUSVᴴtrunc) - end - @testset "trunctol" begin - A = instantiate_matrix(T, sz) - S, ΔS = ad_svd_vals_setup(A) - truncalg = TruncatedAlgorithm(MatrixAlgebraKit.default_svd_algorithm(A), trunctol(atol = S[1, 1] / 2)) - USVᴴ, ΔUSVᴴ, ΔUSVᴴtrunc = ad_svd_trunc_setup(A, 
truncalg) - ϵ = zero(real(eltype(T))) - dUSVᴴerr = make_mooncake_tangent((ΔUSVᴴtrunc..., ϵ)) - Mooncake.TestUtils.test_rule(rng, svd_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dUSVᴴerr, atol, rtol) - test_pullbacks_match(svd_trunc!, svd_trunc, A, USVᴴ, ΔUSVᴴ, truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), zero(real(eltype(T)))), ȳ = ΔUSVᴴtrunc) - dUSVᴴ = make_mooncake_tangent(ΔUSVᴴtrunc) - Mooncake.TestUtils.test_rule(rng, svd_trunc_no_error, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dUSVᴴ, atol, rtol) - test_pullbacks_match(svd_trunc_no_error!, svd_trunc_no_error, A, USVᴴ, ΔUSVᴴ, truncalg; ȳ = ΔUSVᴴtrunc) - end - end - end -end - -function test_mooncake_polar( - T::Type, sz; - atol::Real = 0, rtol::Real = precision(T), - kwargs... - ) - summary_str = testargs_summary(T, sz) - return @testset "Polar Mooncake AD rules $summary_str" begin - A = instantiate_matrix(T, sz) - m, n = size(A) - @testset "left_polar" begin - if m >= n - WP, ΔWP = ad_left_polar_setup(A) - Mooncake.TestUtils.test_rule(rng, left_polar, A; is_primitive = false, mode = Mooncake.ReverseMode, atol, rtol) - test_pullbacks_match(left_polar!, left_polar, A, WP, ΔWP) - end - end - @testset "right_polar" begin - if m <= n - PWᴴ, ΔPWᴴ = ad_right_polar_setup(A) - Mooncake.TestUtils.test_rule(rng, right_polar, A; is_primitive = false, mode = Mooncake.ReverseMode, atol, rtol) - test_pullbacks_match(right_polar!, right_polar, A, PWᴴ, ΔPWᴴ) - end - end - end -end - -left_orth_qr(X) = left_orth(X; alg = :qr) -left_orth_polar(X) = left_orth(X; alg = :polar) -left_null_qr(X) = left_null(X; alg = :qr) -right_orth_lq(X) = right_orth(X; alg = :lq) -right_orth_polar(X) = right_orth(X; alg = :polar) -right_null_lq(X) = right_null(X; alg = :lq) - -MatrixAlgebraKit.copy_input(::typeof(left_orth_qr), A) = MatrixAlgebraKit.copy_input(left_orth, A) -MatrixAlgebraKit.copy_input(::typeof(left_orth_polar), A) = MatrixAlgebraKit.copy_input(left_orth, A) 
-MatrixAlgebraKit.copy_input(::typeof(left_null_qr), A) = MatrixAlgebraKit.copy_input(left_null, A) -MatrixAlgebraKit.copy_input(::typeof(right_orth_lq), A) = MatrixAlgebraKit.copy_input(right_orth, A) -MatrixAlgebraKit.copy_input(::typeof(right_orth_polar), A) = MatrixAlgebraKit.copy_input(right_orth, A) -MatrixAlgebraKit.copy_input(::typeof(right_null_lq), A) = MatrixAlgebraKit.copy_input(right_null, A) - -function test_mooncake_orthnull( - T::Type, sz; - atol::Real = 0, rtol::Real = precision(T), - kwargs... - ) - summary_str = testargs_summary(T, sz) - return @testset "Orthnull Mooncake AD rules $summary_str" begin - A = instantiate_matrix(T, sz) - m, n = size(A) - VC, ΔVC = ad_left_orth_setup(A) - CVᴴ, ΔCVᴴ = ad_right_orth_setup(A) - Mooncake.TestUtils.test_rule(rng, left_orth, A; mode = Mooncake.ReverseMode, atol, rtol, is_primitive = false) - test_pullbacks_match(left_orth!, left_orth, A, VC, ΔVC) - Mooncake.TestUtils.test_rule(rng, right_orth, A; mode = Mooncake.ReverseMode, atol, rtol, is_primitive = false) - test_pullbacks_match(right_orth!, right_orth, A, CVᴴ, ΔCVᴴ) - - Mooncake.TestUtils.test_rule(rng, left_orth_qr, A; mode = Mooncake.ReverseMode, atol, rtol, is_primitive = false) - test_pullbacks_match(((X, VC) -> left_orth!(X, VC; alg = :qr)), left_orth_qr, A, VC, ΔVC) - if m >= n - Mooncake.TestUtils.test_rule(rng, left_orth_polar, A; mode = Mooncake.ReverseMode, atol, rtol, is_primitive = false) - test_pullbacks_match(((X, VC) -> left_orth!(X, VC; alg = :polar)), left_orth_polar, A, VC, ΔVC) - end - - N, ΔN = ad_left_null_setup(A) - dN = make_mooncake_tangent(ΔN) - Mooncake.TestUtils.test_rule(rng, left_null_qr, A; mode = Mooncake.ReverseMode, atol, rtol, is_primitive = false, output_tangent = dN) - test_pullbacks_match(((X, N) -> left_null!(X, N; alg = :qr)), left_null_qr, A, N, ΔN) - - Mooncake.TestUtils.test_rule(rng, right_orth_lq, A; mode = Mooncake.ReverseMode, atol, rtol, is_primitive = false) - test_pullbacks_match(((X, CVᴴ) -> 
right_orth!(X, CVᴴ; alg = :lq)), right_orth_lq, A, CVᴴ, ΔCVᴴ) - - if m <= n - Mooncake.TestUtils.test_rule(rng, right_orth_polar, A; mode = Mooncake.ReverseMode, atol, rtol, is_primitive = false) - test_pullbacks_match(((X, CVᴴ) -> right_orth!(X, CVᴴ; alg = :polar)), right_orth_polar, A, CVᴴ, ΔCVᴴ) - end - - Nᴴ, ΔNᴴ = ad_right_null_setup(A) - dNᴴ = make_mooncake_tangent(ΔNᴴ) - Mooncake.TestUtils.test_rule(rng, right_null_lq, A; mode = Mooncake.ReverseMode, atol, rtol, is_primitive = false, output_tangent = dNᴴ) - test_pullbacks_match(((X, Nᴴ) -> right_null!(X, Nᴴ; alg = :lq)), right_null_lq, A, Nᴴ, ΔNᴴ) - end -end - -function test_mooncake_projections( - T::Type, sz; - atol::Real = 0, rtol::Real = precision(T), - kwargs... - ) - summary_str = testargs_summary(T, sz) - return @testset "Projections Mooncake AD rules $summary_str" begin - A = instantiate_matrix(T, sz) - m, n = size(A) - if m == n - @testset "project_hermitian" begin - Aₕ = project_hermitian(A) - ΔAₕ = make_mooncake_tangent(Aₕ) - Mooncake.TestUtils.test_rule(rng, project_hermitian, A; is_primitive = false, mode = Mooncake.ReverseMode, atol, rtol) - test_pullbacks_match(project_hermitian!, project_hermitian, A, Aₕ, ΔAₕ) - end - @testset "project_antihermitian" begin - Aₐ = project_antihermitian(A) - ΔAₐ = make_mooncake_tangent(Aₐ) - Mooncake.TestUtils.test_rule(rng, project_antihermitian, A; is_primitive = false, mode = Mooncake.ReverseMode, atol, rtol) - test_pullbacks_match(project_antihermitian!, project_antihermitian, A, Aₐ, ΔAₐ) - end - end - end -end diff --git a/test/testsuite/mooncake/projections.jl b/test/testsuite/mooncake/projections.jl new file mode 100644 index 00000000..6927a186 --- /dev/null +++ b/test/testsuite/mooncake/projections.jl @@ -0,0 +1,61 @@ +""" + test_mooncake_projections(T, sz; kwargs...) + +Run all Mooncake AD tests for hermitian and anti-hermitian projections of element type `T` +and size `sz`. +""" +function test_mooncake_projections(T::Type, sz; kwargs...) 
+ summary_str = testargs_summary(T, sz) + return @testset "Mooncake projection $summary_str" begin + test_mooncake_project_hermitian(T, sz; kwargs...) + test_mooncake_project_antihermitian(T, sz; kwargs...) + end +end + +""" + test_mooncake_project_hermitian(T, sz; rng, atol, rtol) + +Test the Mooncake reverse-mode AD rule for `project_hermitian` and its in-place variant. +""" +function test_mooncake_project_hermitian( + T, sz; + rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T) + ) + return @testset "project_hermitian" begin + A = instantiate_matrix(T, sz) + B = instantiate_matrix(T, sz) + alg = MatrixAlgebraKit.select_algorithm(project_hermitian, A) + Mooncake.TestUtils.test_rule( + rng, project_hermitian, A, alg; + mode = Mooncake.ReverseMode, atol, rtol + ) + Mooncake.TestUtils.test_rule( + rng, project_hermitian!, A, B, alg; + mode = Mooncake.ReverseMode, atol, rtol + ) + end +end + +""" + test_mooncake_project_antihermitian(T, sz; rng, atol, rtol) + +Test the Mooncake reverse-mode AD rule for `project_antihermitian` and its in-place variant. 
+""" +function test_mooncake_project_antihermitian( + T, sz; + rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T) + ) + return @testset "project_antihermitian" begin + A = instantiate_matrix(T, sz) + B = instantiate_matrix(T, sz) + alg = MatrixAlgebraKit.select_algorithm(project_hermitian, A) + Mooncake.TestUtils.test_rule( + rng, project_antihermitian, A, alg; + mode = Mooncake.ReverseMode, atol, rtol + ) + Mooncake.TestUtils.test_rule( + rng, project_antihermitian!, A, B, alg; + mode = Mooncake.ReverseMode, atol, rtol + ) + end +end From b14a8f55fcc9c1882c7bbe4f2eea12f16bece0c6 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Thu, 26 Feb 2026 10:36:29 -0500 Subject: [PATCH 11/15] revert accidental deletion of projections file --- test/testsuite/TestSuite.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/testsuite/TestSuite.jl b/test/testsuite/TestSuite.jl index d3c98e9a..361f125b 100644 --- a/test/testsuite/TestSuite.jl +++ b/test/testsuite/TestSuite.jl @@ -93,6 +93,8 @@ end include("ad_utils.jl") +include("projections.jl") + # Decompositions # -------------- include("decompositions/qr.jl") From b3180189762e683b95a4ef23d87414fc70d6c1eb Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Thu, 26 Feb 2026 10:36:29 -0500 Subject: [PATCH 12/15] comment about copy/restoring A --- ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl b/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl index d605a33e..ef32d6de 100644 --- a/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl +++ b/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl @@ -788,6 +788,8 @@ for (f!, f, adj) in ( function Mooncake.rrule!!(f_df::CoDual{typeof($f!)}, A_dA::CoDual, arg_darg::CoDual, alg_dalg::CoDual{<:MatrixAlgebraKit.AbstractAlgorithm}) A, dA = arrayify(A_dA) arg, darg = A_dA === arg_darg ? 
(A, dA) : arrayify(arg_darg) + + # don't need to copy/restore A since projections don't mutate input argc = copy(arg) arg = $f!(A, arg, Mooncake.primal(alg_dalg)) From 0988f9f5fd3387da9eff1eb0fd4c2e2ac37fa662 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Thu, 26 Feb 2026 14:39:04 -0500 Subject: [PATCH 13/15] test aliasing mooncake rules --- test/testsuite/mooncake/projections.jl | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/test/testsuite/mooncake/projections.jl b/test/testsuite/mooncake/projections.jl index 6927a186..74e22a44 100644 --- a/test/testsuite/mooncake/projections.jl +++ b/test/testsuite/mooncake/projections.jl @@ -29,6 +29,10 @@ function test_mooncake_project_hermitian( rng, project_hermitian, A, alg; mode = Mooncake.ReverseMode, atol, rtol ) + Mooncake.TestUtils.test_rule( + rng, project_hermitian!, A, A, alg; + mode = Mooncake.ReverseMode, atol, rtol + ) Mooncake.TestUtils.test_rule( rng, project_hermitian!, A, B, alg; mode = Mooncake.ReverseMode, atol, rtol @@ -53,6 +57,10 @@ function test_mooncake_project_antihermitian( rng, project_antihermitian, A, alg; mode = Mooncake.ReverseMode, atol, rtol ) + Mooncake.TestUtils.test_rule( + rng, project_antihermitian!, A, A, alg; + mode = Mooncake.ReverseMode, atol, rtol + ) Mooncake.TestUtils.test_rule( rng, project_antihermitian!, A, B, alg; mode = Mooncake.ReverseMode, atol, rtol From 426938b24b0c9078d74b692715c8011b8f543c31 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Fri, 27 Feb 2026 09:24:51 -0500 Subject: [PATCH 14/15] test mooncake diagonal projections and fix oopsie --- src/implementations/projections.jl | 6 +++--- test/mooncake/projections.jl | 4 +++- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/implementations/projections.jl b/src/implementations/projections.jl index 857366ed..0623fea5 100644 --- a/src/implementations/projections.jl +++ b/src/implementations/projections.jl @@ -65,11 +65,11 @@ end function project_hermitian_native!(A::Diagonal, B::Diagonal, 
::Val{anti}; kwargs...) where {anti} if anti - diagview(A) .= _imimag.(diagview(B)) + diagview(B) .= _imimag.(diagview(A)) else - diagview(A) .= real.(diagview(B)) + diagview(B) .= real.(diagview(A)) end - return A + return B end function project_hermitian_native!(A::AbstractMatrix, B::AbstractMatrix, anti::Val; blocksize = 32) diff --git a/test/mooncake/projections.jl b/test/mooncake/projections.jl index 547edd84..b33967e9 100644 --- a/test/mooncake/projections.jl +++ b/test/mooncake/projections.jl @@ -13,7 +13,9 @@ is_buildkite = get(ENV, "BUILDKITE", "false") == "true" m = 19 for T in (BLASFloats..., GenericFloats...) TestSuite.seed_rng!(123) + atol = rtol = m * m * TestSuite.precision(T) if !is_buildkite - TestSuite.test_mooncake_projections(T, (m, m); atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T)) + TestSuite.test_mooncake_projections(T, (m, m); atol, rtol) + TestSuite.test_mooncake_projections(Diagonal{T, Vector{T}}, (m, m); atol, rtol) end end From 4254861f7f868ff648902c4d3b85b34729bba361 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Fri, 27 Feb 2026 09:33:32 -0500 Subject: [PATCH 15/15] Test supplying a destination to projections --- test/testsuite/projections.jl | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/test/testsuite/projections.jl b/test/testsuite/projections.jl index 3147084a..82091d01 100644 --- a/test/testsuite/projections.jl +++ b/test/testsuite/projections.jl @@ -40,6 +40,14 @@ function test_project_antihermitian( @test Ba === Ac @test isantihermitian(Ba) @test Ba ≈ Aa + + # can we supply a destination + Ac = deepcopy(A) + Ba = instantiate_matrix(T, sz) + Ba₂ = project_antihermitian!(Ac, Ba) + @test A == Ac + @test Ba₂ === Ba + @test Ba₂ ≈ Aa end # test approximate error calculation @@ -79,6 +87,7 @@ function test_project_hermitian( Aa = (A - A') / 2 Bh = project_hermitian(A; blocksize = 16) + @test ishermitian(Bh) @test Bh ≈ Ah @test A == Ac @@ -91,6 +100,14 @@ function 
test_project_hermitian( @test Bh === Ac @test ishermitian(Bh) @test Bh ≈ Ah + + # can we supply a destination + Ac = deepcopy(A) + Bh = instantiate_matrix(T, sz) + Bh₂ = project_hermitian!(Ac, Bh) + @test A == Ac + @test Bh₂ === Bh + @test Bh₂ ≈ Ah end # test approximate error calculation