Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ Enzyme = "0.13.118"
EnzymeTestUtils = "0.2.5"
GenericLinearAlgebra = "0.3.19"
GenericSchur = "0.5.6"
JET = "0.9, 0.10"
JET = "0.9, 0.10, 0.11"
LinearAlgebra = "1"
Mooncake = "0.5"
ParallelTestRunner = "2"
Expand Down
4 changes: 2 additions & 2 deletions src/pullbacks/svd.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
svd_rank(S, rank_atol) = searchsortedlast(S, rank_atol; rev = true)
svd_rank(S; rank_atol = default_pullback_rank_atol(S)) = searchsortedlast(S, rank_atol; rev = true)

function check_svd_cotangents(aUΔU, Sr, aVΔV; degeneracy_atol = default_pullback_rank_atol(Sr), gauge_atol = default_pullback_gauge_atol(aUΔU, aVΔV))
mask = abs.(Sr' .- Sr) .< degeneracy_atol
Expand Down Expand Up @@ -43,7 +43,7 @@ function svd_pullback!(
minmn = min(m, n)
S = diagview(Smat)
length(S) == minmn || throw(DimensionMismatch("length of S ($(length(S))) does not matrix minimum dimension of U, Vᴴ ($minmn)"))
r = svd_rank(S, rank_atol)
r = svd_rank(S; rank_atol)
Ur = view(U, :, 1:r)
Vᴴr = view(Vᴴ, 1:r, :)
Sr = view(S, 1:r)
Expand Down
6 changes: 3 additions & 3 deletions test/testsuite/TestSuite.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ module TestSuite
using Test
using MatrixAlgebraKit
using MatrixAlgebraKit: diagview
using LinearAlgebra: Diagonal, norm, istriu, istril, I
using LinearAlgebra: Diagonal, norm, istriu, istril, I, mul!
using Random, StableRNGs
using Mooncake
using AMDGPU, CUDA
Expand Down Expand Up @@ -85,9 +85,9 @@ function instantiate_unitary(T, A::ROCMatrix{<:Complex}, sz)
end
instantiate_unitary(::Type{<:Diagonal}, A, sz) = Diagonal(fill!(similar(parent(A), eltype(A), sz), one(eltype(A))))

function instantiate_rank_deficient_matrix(T, sz; trunc = trunctol(rtol = 0.5))
function instantiate_rank_deficient_matrix(T, sz; trunc = truncrank(div(min(sz...), 2)))
A = instantiate_matrix(T, sz)
V, C = left_orth!(A; trunc = trunctol(rtol = 0.5))
V, C = left_orth!(A; trunc)
return mul!(A, V, C)
end

Expand Down
Loading
Loading