Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion src/rulesets/LinearAlgebra/structured.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# Structured matrices
using LinearAlgebra: AbstractTriangular
using SparseInverseSubset

# Matrix wrapper types that we know are square and are thus potentially invertible. For
# these we can use simpler definitions for `/` and `\`.
Expand Down
163 changes: 84 additions & 79 deletions src/rulesets/SparseArrays/sparsematrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,94 +50,99 @@ function rrule(::typeof(findnz), v::AbstractSparseVector)
return (I, V), findnz_pullback
end

if VERSION < v"1.7"
#=
The method below for `logabsdet(F::UmfpackLU)` is required to calculate the (log)
determinants of sparse matrices, but was not defined prior to Julia v1.7. In order
for the rrules for the determinants of sparse matrices below to work, they need to be
able to compute the primals as well, so this import from the future is included. For
more recent versions of Julia, this definition lives in:
julia/stdlib/SuiteSparse/src/umfpack.jl
=#
using SuiteSparse.UMFPACK: UmfpackLU

# compute the sign/parity of a permutation
function _signperm(p)
n = length(p)
result = 0
todo = trues(n)
while any(todo)
k = findfirst(todo)
todo[k] = false
result += 1 # increment element count
j = p[k]
while j != k
if Base.USE_GPL_LIBS # Don't define rrules for sparse determinants if we don't have CHOLMOD from SuiteSparse.jl

if VERSION < v"1.7"
#=
The method below for `logabsdet(F::UmfpackLU)` is required to calculate the (log)
determinants of sparse matrices, but was not defined prior to Julia v1.7. In order
for the rrules for the determinants of sparse matrices below to work, they need to be
able to compute the primals as well, so this import from the future is included. For
more recent versions of Julia, this definition lives in:
julia/stdlib/SuiteSparse/src/umfpack.jl
=#
using SuiteSparse.UMFPACK: UmfpackLU

# compute the sign/parity of a permutation
function _signperm(p)
n = length(p)
result = 0
todo = trues(n)
while any(todo)
k = findfirst(todo)
todo[k] = false
result += 1 # increment element count
todo[j] = false
j = p[j]
j = p[k]
while j != k
result += 1 # increment element count
todo[j] = false
j = p[j]
end
result += 1 # increment cycle count
end
result += 1 # increment cycle count
return ifelse(isodd(result), -1, 1)
end
return ifelse(isodd(result), -1, 1)
end

function LinearAlgebra.logabsdet(F::UmfpackLU{T, TI}) where {T<:Union{Float64,ComplexF64},TI<:Union{Int32, Int64}}
n = checksquare(F)
issuccess(F) || return log(zero(real(T))), zero(T)
U = F.U
Rs = F.Rs
p = F.p
q = F.q
s = _signperm(p)*_signperm(q)*one(real(T))
P = one(T)
abs_det = zero(real(T))
@inbounds for i in 1:n
dg_ii = U[i, i] / Rs[i]
P *= sign(dg_ii)
abs_det += log(abs(dg_ii))
using SparseInverseSubset

function LinearAlgebra.logabsdet(F::UmfpackLU{T, TI}) where {T<:Union{Float64,ComplexF64},TI<:Union{Int32, Int64}}
n = checksquare(F)
issuccess(F) || return log(zero(real(T))), zero(T)
U = F.U
Rs = F.Rs
p = F.p
q = F.q
s = _signperm(p)*_signperm(q)*one(real(T))
P = one(T)
abs_det = zero(real(T))
@inbounds for i in 1:n
dg_ii = U[i, i] / Rs[i]
P *= sign(dg_ii)
abs_det += log(abs(dg_ii))
end
return abs_det, s * P
end
return abs_det, s * P
end
end


function rrule(::typeof(logabsdet), x::SparseMatrixCSC)
F = cholesky(x)
L, D, U, P = SparseInverseSubset.get_ldup(F)
Ω = logabsdet(D)
function logabsdet_pullback(ΔΩ)
(Δy, Δsigny) = ΔΩ
(_, signy) = Ω
f = signy' * Δsigny
imagf = f - real(f)
g = real(Δy) + imagf
Z, P = sparseinv(F, depermute=true)
∂x = g * Z'
return (NoTangent(), ∂x)


function rrule(::typeof(logabsdet), x::SparseMatrixCSC)
F = cholesky(x)
L, D, U, P = SparseInverseSubset.get_ldup(F)
Ω = logabsdet(D)
function logabsdet_pullback(ΔΩ)
(Δy, Δsigny) = ΔΩ
(_, signy) = Ω
f = signy' * Δsigny
imagf = f - real(f)
g = real(Δy) + imagf
Z, P = sparseinv(F, depermute=true)
∂x = g * Z'
return (NoTangent(), ∂x)
end
return Ω, logabsdet_pullback
end
return Ω, logabsdet_pullback
end

function rrule(::typeof(logdet), x::SparseMatrixCSC)
Ω = logdet(x)
function logdet_pullback(ΔΩ)
Z, p = sparseinv(x, depermute=true)
∂x = ΔΩ * Z'
return (NoTangent(), ∂x)

function rrule(::typeof(logdet), x::SparseMatrixCSC)
Ω = logdet(x)
function logdet_pullback(ΔΩ)
Z, p = sparseinv(x, depermute=true)
∂x = ΔΩ * Z'
return (NoTangent(), ∂x)
end
return Ω, logdet_pullback
end
return Ω, logdet_pullback
end

function rrule(::typeof(det), x::SparseMatrixCSC)
Ω = det(x)
function det_pullback(ΔΩ)
Z, _ = sparseinv(x, depermute=true)
∂x = Z' * dot(Ω, ΔΩ)
return (NoTangent(), ∂x)

function rrule(::typeof(det), x::SparseMatrixCSC)
Ω = det(x)
function det_pullback(ΔΩ)
Z, _ = sparseinv(x, depermute=true)
∂x = Z' * dot(Ω, ΔΩ)
return (NoTangent(), ∂x)
end
return Ω, det_pullback
end
return Ω, det_pullback
end


end # rrules that depend on CHOLMOD

function rrule(::typeof(spdiagm), m::Integer, n::Integer, kv::Pair{<:Integer,<:AbstractVector}...)

Expand Down
18 changes: 10 additions & 8 deletions test/rulesets/SparseArrays/sparsematrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,12 +79,14 @@ end
test_rrule(findnz, v ⊢ dv, output_tangent=(zeros(length(I)), V̄))
end

@testset "[log[abs[det]]] SparseMatrixCSC" begin
ii = [1:5; 2; 4]
jj = [1:5; 4; 2]
x = [ones(5); 0.1; 0.1]
A = sparse(ii, jj, x)
test_rrule(logabsdet, A)
test_rrule(logdet, A)
test_rrule(det, A)
if Base.USE_GPL_LIBS # these rrules don't work without CHOLMOD from SuiteSparse.jl
@testset "[log[abs[det]]] SparseMatrixCSC" begin
ii = [1:5; 2; 4]
jj = [1:5; 4; 2]
x = [ones(5); 0.1; 0.1]
A = sparse(ii, jj, x)
test_rrule(logabsdet, A)
test_rrule(logdet, A)
test_rrule(det, A)
end
end