Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 0 additions & 30 deletions test/enzyme.jl

This file was deleted.

19 changes: 19 additions & 0 deletions test/enzyme/eig.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
using MatrixAlgebraKit
using Test
using LinearAlgebra: Diagonal
using CUDA, AMDGPU

BLASFloats = (Float32, ComplexF64) # full suite is too expensive on CI
GenericFloats = ()
@isdefined(TestSuite) || include("../testsuite/TestSuite.jl")
using .TestSuite

is_buildkite = get(ENV, "BUILDKITE", "false") == "true"

m = 19
for T in (BLASFloats..., GenericFloats...)
TestSuite.seed_rng!(123)
if !is_buildkite
TestSuite.test_enzyme_eig(T, (m, m); atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T))
end
end
19 changes: 19 additions & 0 deletions test/enzyme/eigh.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
using MatrixAlgebraKit
using Test
using LinearAlgebra: Diagonal
using CUDA, AMDGPU

BLASFloats = (Float32, ComplexF64) # full suite is too expensive on CI
GenericFloats = ()
@isdefined(TestSuite) || include("../testsuite/TestSuite.jl")
using .TestSuite

is_buildkite = get(ENV, "BUILDKITE", "false") == "true"

m = 19
for T in (BLASFloats..., GenericFloats...)
TestSuite.seed_rng!(123)
if !is_buildkite
TestSuite.test_enzyme_eigh(T, (m, m); atol = m * m * TestSuite.precision(T), rtol = m * m * TestSuite.precision(T))
end
end
19 changes: 19 additions & 0 deletions test/enzyme/lq.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
using MatrixAlgebraKit
using Test
using LinearAlgebra: Diagonal
using CUDA, AMDGPU

BLASFloats = (Float32, ComplexF64) # full suite is too expensive on CI
GenericFloats = ()
@isdefined(TestSuite) || include("../testsuite/TestSuite.jl")
using .TestSuite

is_buildkite = get(ENV, "BUILDKITE", "false") == "true"

m = 19
for T in (BLASFloats..., GenericFloats...), n in (17, m, 23)
TestSuite.seed_rng!(123)
if !is_buildkite
TestSuite.test_enzyme_lq(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T))
end
end
19 changes: 19 additions & 0 deletions test/enzyme/orthnull.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
using MatrixAlgebraKit
using Test
using LinearAlgebra: Diagonal
using CUDA, AMDGPU

BLASFloats = (Float32, ComplexF64) # full suite is too expensive on CI
GenericFloats = ()
@isdefined(TestSuite) || include("../testsuite/TestSuite.jl")
using .TestSuite

is_buildkite = get(ENV, "BUILDKITE", "false") == "true"

m = 19
for T in (BLASFloats..., GenericFloats...), n in (17, m, 23)
TestSuite.seed_rng!(123)
if !is_buildkite
TestSuite.test_enzyme_orthnull(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T))
end
end
19 changes: 19 additions & 0 deletions test/enzyme/polar.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
using MatrixAlgebraKit
using Test
using LinearAlgebra: Diagonal
using CUDA, AMDGPU

BLASFloats = (Float32, ComplexF64) # full suite is too expensive on CI
GenericFloats = ()
@isdefined(TestSuite) || include("../testsuite/TestSuite.jl")
using .TestSuite

is_buildkite = get(ENV, "BUILDKITE", "false") == "true"

m = 19
for T in (BLASFloats..., GenericFloats...), n in (17, m, 23)
TestSuite.seed_rng!(123)
if !is_buildkite
TestSuite.test_enzyme_polar(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T))
end
end
19 changes: 19 additions & 0 deletions test/enzyme/qr.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
using MatrixAlgebraKit
using Test
using LinearAlgebra: Diagonal
using CUDA, AMDGPU

BLASFloats = (Float32, ComplexF64) # full suite is too expensive on CI
GenericFloats = ()
@isdefined(TestSuite) || include("../testsuite/TestSuite.jl")
using .TestSuite

is_buildkite = get(ENV, "BUILDKITE", "false") == "true"

m = 19
for T in (BLASFloats..., GenericFloats...), n in (17, m, 23)
TestSuite.seed_rng!(123)
if !is_buildkite
TestSuite.test_enzyme_qr(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T))
end
end
19 changes: 19 additions & 0 deletions test/enzyme/svd.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
using MatrixAlgebraKit
using Test
using LinearAlgebra: Diagonal
using CUDA, AMDGPU

BLASFloats = (Float32, ComplexF64) # full suite is too expensive on CI
GenericFloats = ()
@isdefined(TestSuite) || include("../testsuite/TestSuite.jl")
using .TestSuite

is_buildkite = get(ENV, "BUILDKITE", "false") == "true"

m = 19
for T in (BLASFloats..., GenericFloats...), n in (17, m, 23)
TestSuite.seed_rng!(1234)
if !is_buildkite
TestSuite.test_enzyme_svd(T, (m, n); atol = m * n * TestSuite.precision(T), rtol = m * n * TestSuite.precision(T))
end
end
3 changes: 1 addition & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,10 @@ if filter_tests!(testsuite, args)
else
is_apple_ci = Sys.isapple() && get(ENV, "CI", "false") == "true"
if is_apple_ci
delete!(testsuite, "enzyme")
filter!(p -> !startswith(first(p), "mooncake/"), testsuite)
delete!(testsuite, "chainrules")
end
Sys.iswindows() && delete!(testsuite, "enzyme")
(Sys.iswindows() || is_apple_ci) && filter!(p -> !startswith(first(p), "enzyme/"), testsuite)
end
end

Expand Down
12 changes: 11 additions & 1 deletion test/testsuite/TestSuite.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ using LinearAlgebra: Diagonal, norm, istriu, istril, I
using Random, StableRNGs
using Mooncake
using AMDGPU, CUDA
using Enzyme, EnzymeTestUtils

const tests = Dict()

Expand Down Expand Up @@ -117,7 +118,16 @@ include("mooncake/svd.jl")
include("mooncake/polar.jl")
include("mooncake/orthnull.jl")

include("enzyme.jl")
include("chainrules.jl")

# Enzyme
# ------
include("enzyme/eig.jl")
include("enzyme/eigh.jl")
include("enzyme/qr.jl")
include("enzyme/lq.jl")
include("enzyme/svd.jl")
include("enzyme/polar.jl")
include("enzyme/orthnull.jl")

end
34 changes: 29 additions & 5 deletions test/testsuite/ad_utils.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,21 @@
"""
remove_svd_gauge_dependence!(ΔV, D, V)

Remove the gauge-dependent part from the cotangents `ΔU` and ΔVᴴ` of the singular vector matrices `U`
and `Vᴴ`. The singular vectors are only determined up to complex phase (and unitary mixing for degenerate
eigenvalues), so the corresponding components of `ΔU` and `ΔVᴴ` are projected out.
"""
function remove_svd_gauge_dependence!(
ΔU, ΔVᴴ, U, S, Vᴴ;
degeneracy_atol = MatrixAlgebraKit.default_pullback_gauge_atol(S)
)
gaugepart = mul!(U' * ΔU, Vᴴ, ΔVᴴ', true, true)
gaugepart = project_antihermitian!(gaugepart)
gaugepart[abs.(transpose(diagview(S)) .- diagview(S)) .>= degeneracy_atol] .= 0
mul!(ΔU, U, gaugepart, -1, 1)
return ΔU, ΔVᴴ
end

"""
remove_eig_gauge_dependence!(ΔV, D, V)

Expand Down Expand Up @@ -163,6 +181,8 @@ function call_and_zero!(f!, A, alg)
return F′
end

is_cpu(A) = typeof(parent(A)) <: Array

"""
eigh_wrapper(f, A, alg)

Expand Down Expand Up @@ -234,6 +254,11 @@ function ad_qr_compact_setup(A::Diagonal)
end

function ad_qr_null_setup(A)
m, n = size(A)
minmn = min(m, n)
Q, R = qr_compact(A)
T = eltype(A)
ΔN = Q * randn!(similar(A, T, minmn, max(0, m - minmn)))
N = qr_null(A)
ΔN = randn!(copy(N))
remove_qr_null_gauge_dependence!(ΔN, A, N)
Expand All @@ -246,7 +271,6 @@ function ad_qr_full_setup(A)
remove_qr_gauge_dependence!(ΔQR..., A, QR...)
return QR, ΔQR
end

ad_qr_full_setup(A::Diagonal) = ad_qr_compact_setup(A)

function ad_qr_rank_deficient_compact_setup(A)
Expand Down Expand Up @@ -516,8 +540,8 @@ end
function ad_left_null_setup(A)
m, n = size(A)
T = eltype(A)
N = left_orth(A; alg = :qr)[1] * randn!(similar(A, T, min(m, n), m - min(m, n)))
ΔN = left_orth(A; alg = :qr)[1] * randn!(similar(A, T, min(m, n), m - min(m, n)))
N = left_orth(A)[1] * randn!(similar(A, T, min(m, n), m - min(m, n)))
ΔN = left_orth(A)[1] * randn!(similar(A, T, min(m, n), m - min(m, n)))
return N, ΔN
end

Expand All @@ -533,7 +557,7 @@ ad_right_orth_setup(A::Diagonal) = ad_left_orth_setup(A)
function ad_right_null_setup(A)
m, n = size(A)
T = eltype(A)
Nᴴ = randn!(similar(A, T, n - min(m, n), min(m, n))) * right_orth(A; alg = :lq)[2]
ΔNᴴ = randn!(similar(A, T, n - min(m, n), min(m, n))) * right_orth(A; alg = :lq)[2]
Nᴴ = randn!(similar(A, T, n - min(m, n), min(m, n))) * right_orth(A)[2]
ΔNᴴ = randn!(similar(A, T, n - min(m, n), min(m, n))) * right_orth(A)[2]
return Nᴴ, ΔNᴴ
end
Loading
Loading