diff --git a/ext/MatrixAlgebraKitChainRulesCoreExt.jl b/ext/MatrixAlgebraKitChainRulesCoreExt.jl index 400b2a79..acfe0d83 100644 --- a/ext/MatrixAlgebraKitChainRulesCoreExt.jl +++ b/ext/MatrixAlgebraKitChainRulesCoreExt.jl @@ -274,4 +274,22 @@ function ChainRulesCore.rrule(::typeof(right_polar!), A, PWᴴ, alg) return PWᴴ, right_polar_pullback end +function ChainRulesCore.rrule(::typeof(project_hermitian), A, alg) + Aₕ = project_hermitian(A, alg) + function project_hermitian_pullback(ΔAₕ) + ΔA = project_hermitian(unthunk(ΔAₕ)) + return NoTangent(), ΔA, NoTangent() + end + return Aₕ, project_hermitian_pullback +end + +function ChainRulesCore.rrule(::typeof(project_antihermitian), A, alg) + Aₐ = project_antihermitian(A, alg) + function project_antihermitian_pullback(ΔAₐ) + ΔA = project_antihermitian(unthunk(ΔAₐ)) + return NoTangent(), ΔA, NoTangent() + end + return Aₐ, project_antihermitian_pullback +end + end diff --git a/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl b/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl index f32e4258..ef32d6de 100644 --- a/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl +++ b/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl @@ -778,4 +778,51 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc_no_error)}, A_dA::CoDual, al return USVᴴtrunc_dUSVᴴtrunc, svd_trunc_adjoint end +# single-output projections: project_hermitian!, project_antihermitian! +for (f!, f, adj) in ( + (:project_hermitian!, :project_hermitian, :project_hermitian_adjoint), + (:project_antihermitian!, :project_antihermitian, :project_antihermitian_adjoint), + ) + @eval begin + @is_primitive DefaultCtx Mooncake.ReverseMode Tuple{typeof($f!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm} + function Mooncake.rrule!!(f_df::CoDual{typeof($f!)}, A_dA::CoDual, arg_darg::CoDual, alg_dalg::CoDual{<:MatrixAlgebraKit.AbstractAlgorithm}) + A, dA = arrayify(A_dA) + arg, darg = A_dA === arg_darg ? (A, dA) : arrayify(arg_darg) + + # don't need to copy/restore A since projections don't mutate input + argc = copy(arg) + arg = $f!(A, arg, Mooncake.primal(alg_dalg)) + + function $adj(::NoRData) + $f!(darg) + if dA !== darg + dA .+= darg + zero!(darg) + end + copy!(arg, argc) + return ntuple(Returns(NoRData()), 4) + end + + return arg_darg, $adj + end + + @is_primitive DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), Any, MatrixAlgebraKit.AbstractAlgorithm} + function Mooncake.rrule!!(f_df::CoDual{typeof($f)}, A_dA::CoDual, alg_dalg::CoDual{<:MatrixAlgebraKit.AbstractAlgorithm}) + A, dA = arrayify(A_dA) + output = $f(A, Mooncake.primal(alg_dalg)) + output_doutput = Mooncake.zero_fcodual(output) + + doutput = last(arrayify(output_doutput)) + function $adj(::NoRData) + # TODO: need accumulating projection to avoid intermediate here + dA .+= $f(doutput) + zero!(doutput) + return ntuple(Returns(NoRData()), 3) + end + + return output_doutput, $adj + end + end +end + end diff --git a/src/implementations/projections.jl b/src/implementations/projections.jl index 857366ed..0623fea5 100644 --- a/src/implementations/projections.jl +++ b/src/implementations/projections.jl @@ -65,11 +65,11 @@ end function project_hermitian_native!(A::Diagonal, B::Diagonal, ::Val{anti}; kwargs...) where {anti} if anti - diagview(A) .= _imimag.(diagview(B)) + diagview(B) .= _imimag.(diagview(A)) else - diagview(A) .= real.(diagview(B)) + diagview(B) .= real.(diagview(A)) end - return A + return B end function project_hermitian_native!(A::AbstractMatrix, B::AbstractMatrix, anti::Val; blocksize = 32) diff --git a/src/pullbacks/polar.jl b/src/pullbacks/polar.jl index 4d498da0..a549321e 100644 --- a/src/pullbacks/polar.jl +++ b/src/pullbacks/polar.jl @@ -17,7 +17,7 @@ function left_polar_pullback!(ΔA::AbstractMatrix, A, WP, ΔWP; kwargs...) !iszerotangent(ΔW) && mul!(M, W', ΔW, 1, 1) !iszerotangent(ΔP) && mul!(M, ΔP, P, -1, 1) C = _sylvester(P, P, M' - M) - C .+= ΔP + !iszerotangent(ΔP) && (C .+= ΔP) ΔA = mul!(ΔA, W, C, 1, 1) if !iszerotangent(ΔW) ΔWP = ΔW / P @@ -47,7 +47,7 @@ function right_polar_pullback!(ΔA::AbstractMatrix, A, PWᴴ, ΔPWᴴ; kwargs... !iszerotangent(ΔWᴴ) && mul!(M, ΔWᴴ, Wᴴ', 1, 1) !iszerotangent(ΔP) && mul!(M, P, ΔP, -1, 1) C = _sylvester(P, P, M' - M) - C .+= ΔP + !iszerotangent(ΔP) && (C .+= ΔP) ΔA = mul!(ΔA, C, Wᴴ, 1, 1) if !iszerotangent(ΔWᴴ) PΔWᴴ = P \ ΔWᴴ diff --git a/test/mooncake/projections.jl b/test/mooncake/projections.jl new file mode 100644 index 00000000..b33967e9 --- /dev/null +++ b/test/mooncake/projections.jl @@ -0,0 +1,21 @@ +using MatrixAlgebraKit +using Test +using LinearAlgebra: Diagonal +using CUDA, AMDGPU + +BLASFloats = (Float32, ComplexF64) # full suite is too expensive on CI +GenericFloats = () +@isdefined(TestSuite) || include("../testsuite/TestSuite.jl") +using .TestSuite + +is_buildkite = get(ENV, "BUILDKITE", "false") == "true" + +m = 19 +for T in (BLASFloats..., GenericFloats...) + TestSuite.seed_rng!(123) + atol = rtol = m * m * TestSuite.precision(T) + if !is_buildkite + TestSuite.test_mooncake_projections(T, (m, m); atol, rtol) + TestSuite.test_mooncake_projections(Diagonal{T, Vector{T}}, (m, m); atol, rtol) + end +end diff --git a/test/testsuite/TestSuite.jl b/test/testsuite/TestSuite.jl index 2edd0846..361f125b 100644 --- a/test/testsuite/TestSuite.jl +++ b/test/testsuite/TestSuite.jl @@ -116,6 +116,7 @@ include("mooncake/eigh.jl") include("mooncake/svd.jl") include("mooncake/polar.jl") include("mooncake/orthnull.jl") +include("mooncake/projections.jl") include("enzyme.jl") include("chainrules.jl") diff --git a/test/testsuite/chainrules.jl b/test/testsuite/chainrules.jl index bb94664d..52806750 100644 --- a/test/testsuite/chainrules.jl +++ b/test/testsuite/chainrules.jl @@ -46,6 +46,7 @@ function test_chainrules(T::Type, sz; kwargs...) test_chainrules_svd(T, sz; kwargs...) test_chainrules_polar(T, sz; kwargs...) test_chainrules_orthnull(T, sz; kwargs...) + test_chainrules_projections(T, sz; kwargs...) end end @@ -587,3 +588,25 @@ function test_chainrules_orthnull( ) end end + +function test_chainrules_projections( + T::Type, sz; + atol::Real = 0, rtol::Real = precision(T), + kwargs... + ) + summary_str = testargs_summary(T, sz) + return @testset "Projections Chainrules AD rules $summary_str" begin + A = instantiate_matrix(T, sz) + m, n = size(A) + if m == n + @testset "project_hermitian" begin + alg = MatrixAlgebraKit.default_hermitian_algorithm(A) + test_rrule(project_hermitian, A, alg; atol, rtol) + end + @testset "project_antihermitian" begin + alg = MatrixAlgebraKit.default_hermitian_algorithm(A) + test_rrule(project_antihermitian, A, alg; atol, rtol) + end + end + end +end diff --git a/test/testsuite/mooncake/projections.jl b/test/testsuite/mooncake/projections.jl new file mode 100644 index 00000000..74e22a44 --- /dev/null +++ b/test/testsuite/mooncake/projections.jl @@ -0,0 +1,69 @@ +""" + test_mooncake_projections(T, sz; kwargs...) + +Run all Mooncake AD tests for hermitian and anti-hermitian projections of element type `T` +and size `sz`. +""" +function test_mooncake_projections(T::Type, sz; kwargs...) + summary_str = testargs_summary(T, sz) + return @testset "Mooncake projection $summary_str" begin + test_mooncake_project_hermitian(T, sz; kwargs...) + test_mooncake_project_antihermitian(T, sz; kwargs...) + end +end + +""" + test_mooncake_project_hermitian(T, sz; rng, atol, rtol) + +Test the Mooncake reverse-mode AD rule for `project_hermitian` and its in-place variant. +""" +function test_mooncake_project_hermitian( + T, sz; + rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T) + ) + return @testset "project_hermitian" begin + A = instantiate_matrix(T, sz) + B = instantiate_matrix(T, sz) + alg = MatrixAlgebraKit.select_algorithm(project_hermitian, A) + Mooncake.TestUtils.test_rule( + rng, project_hermitian, A, alg; + mode = Mooncake.ReverseMode, atol, rtol + ) + Mooncake.TestUtils.test_rule( + rng, project_hermitian!, A, A, alg; + mode = Mooncake.ReverseMode, atol, rtol + ) + Mooncake.TestUtils.test_rule( + rng, project_hermitian!, A, B, alg; + mode = Mooncake.ReverseMode, atol, rtol + ) + end +end + +""" + test_mooncake_project_antihermitian(T, sz; rng, atol, rtol) + +Test the Mooncake reverse-mode AD rule for `project_antihermitian` and its in-place variant. +""" +function test_mooncake_project_antihermitian( + T, sz; + rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T) + ) + return @testset "project_antihermitian" begin + A = instantiate_matrix(T, sz) + B = instantiate_matrix(T, sz) + alg = MatrixAlgebraKit.select_algorithm(project_hermitian, A) + Mooncake.TestUtils.test_rule( + rng, project_antihermitian, A, alg; + mode = Mooncake.ReverseMode, atol, rtol + ) + Mooncake.TestUtils.test_rule( + rng, project_antihermitian!, A, A, alg; + mode = Mooncake.ReverseMode, atol, rtol + ) + Mooncake.TestUtils.test_rule( + rng, project_antihermitian!, A, B, alg; + mode = Mooncake.ReverseMode, atol, rtol + ) + end +end diff --git a/test/testsuite/projections.jl b/test/testsuite/projections.jl index 3147084a..82091d01 100644 --- a/test/testsuite/projections.jl +++ b/test/testsuite/projections.jl @@ -40,6 +40,14 @@ function test_project_antihermitian( @test Ba === Ac @test isantihermitian(Ba) @test Ba ≈ Aa + + # can we supply a destination + Ac = deepcopy(A) + Ba = instantiate_matrix(T, sz) + Ba₂ = project_antihermitian!(Ac, Ba) + @test A == Ac + @test Ba₂ === Ba + @test Ba₂ ≈ Aa end # test approximate error calculation @@ -79,6 +87,7 @@ function test_project_hermitian( Aa = (A - A') / 2 Bh = project_hermitian(A; blocksize = 16) + @test ishermitian(Bh) @test Bh ≈ Ah @test A == Ac @@ -91,6 +100,14 @@ function test_project_hermitian( @test Bh === Ac @test ishermitian(Bh) @test Bh ≈ Ah + + # can we supply a destination + Ac = deepcopy(A) + Bh = instantiate_matrix(T, sz) + Bh₂ = project_hermitian!(Ac, Bh) + @test A == Ac + @test Bh₂ === Bh + @test Bh₂ ≈ Ah end # test approximate error calculation