Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions src/constant.jl
Original file line number Diff line number Diff line change
Expand Up @@ -75,17 +75,17 @@ end
function ConstantColoringAlgorithm{:column}(
matrix_template::AbstractMatrix, color::Vector{Int}
)
S = convert(SparseMatrixCSC, matrix_template)
result = ColumnColoringResult(S, color)
bg = BipartiteGraph(matrix_template)
result = ColumnColoringResult(matrix_template, bg, color)
M, R = typeof(matrix_template), typeof(result)
return ConstantColoringAlgorithm{:column,M,R}(matrix_template, color, result)
end

function ConstantColoringAlgorithm{:row}(
matrix_template::AbstractMatrix, color::Vector{Int}
)
S = convert(SparseMatrixCSC, matrix_template)
result = RowColoringResult(S, color)
bg = BipartiteGraph(matrix_template)
result = RowColoringResult(matrix_template, bg, color)
M, R = typeof(matrix_template), typeof(result)
return ConstantColoringAlgorithm{:row,M,R}(matrix_template, color, result)
end
Expand Down
116 changes: 55 additions & 61 deletions src/decompression.jl
Original file line number Diff line number Diff line change
Expand Up @@ -115,9 +115,8 @@ true
- [`ColoringProblem`](@ref)
- [`AbstractColoringResult`](@ref)
"""
function decompress(B::AbstractMatrix{R}, result::AbstractColoringResult) where {R<:Real}
@compat (; S) = result
A = respectful_similar(S, R)
function decompress(B::AbstractMatrix, result::AbstractColoringResult)
A = respectful_similar(result.A, eltype(B))
return decompress!(A, B, result)
end

Expand Down Expand Up @@ -264,12 +263,11 @@ end

## ColumnColoringResult

function decompress!(
A::AbstractMatrix{R}, B::AbstractMatrix{R}, result::ColumnColoringResult
) where {R<:Real}
@compat (; S, color) = result
function decompress!(A::AbstractMatrix, B::AbstractMatrix, result::ColumnColoringResult)
@compat (; color) = result
S = result.bg.S2
check_same_pattern(A, S)
A .= zero(R)
fill!(A, zero(eltype(A)))
rvS = rowvals(S)
for j in axes(S, 2)
cj = color[j]
Expand All @@ -282,9 +280,10 @@ function decompress!(
end

function decompress_single_color!(
A::AbstractMatrix{R}, b::AbstractVector{R}, c::Integer, result::ColumnColoringResult
) where {R<:Real}
@compat (; S, group) = result
A::AbstractMatrix, b::AbstractVector, c::Integer, result::ColumnColoringResult
)
@compat (; group) = result
S = result.bg.S2
check_same_pattern(A, S)
rvS = rowvals(S)
for j in group[c]
Expand All @@ -296,10 +295,9 @@ function decompress_single_color!(
return A
end

function decompress!(
A::SparseMatrixCSC{R}, B::AbstractMatrix{R}, result::ColumnColoringResult
) where {R<:Real}
@compat (; S, compressed_indices) = result
function decompress!(A::SparseMatrixCSC, B::AbstractMatrix, result::ColumnColoringResult)
@compat (; compressed_indices) = result
S = result.bg.S2
check_same_pattern(A, S)
nzA = nonzeros(A)
for k in eachindex(nzA, compressed_indices)
Expand All @@ -309,9 +307,10 @@ function decompress!(
end

function decompress_single_color!(
A::SparseMatrixCSC{R}, b::AbstractVector{R}, c::Integer, result::ColumnColoringResult
) where {R<:Real}
@compat (; S, group) = result
A::SparseMatrixCSC, b::AbstractVector, c::Integer, result::ColumnColoringResult
)
@compat (; group) = result
S = result.bg.S2
check_same_pattern(A, S)
rvS = rowvals(S)
nzA = nonzeros(A)
Expand All @@ -326,12 +325,11 @@ end

## RowColoringResult

function decompress!(
A::AbstractMatrix{R}, B::AbstractMatrix{R}, result::RowColoringResult
) where {R<:Real}
@compat (; S, color) = result
function decompress!(A::AbstractMatrix, B::AbstractMatrix, result::RowColoringResult)
@compat (; color) = result
S = result.bg.S2
check_same_pattern(A, S)
A .= zero(R)
fill!(A, zero(eltype(A)))
rvS = rowvals(S)
for j in axes(S, 2)
for k in nzrange(S, j)
Expand All @@ -344,9 +342,10 @@ function decompress!(
end

function decompress_single_color!(
A::AbstractMatrix{R}, b::AbstractVector{R}, c::Integer, result::RowColoringResult
) where {R<:Real}
@compat (; S, Sᵀ, group) = result
A::AbstractMatrix, b::AbstractVector, c::Integer, result::RowColoringResult
)
@compat (; group) = result
S, Sᵀ = result.bg.S2, result.bg.S1
check_same_pattern(A, S)
rvSᵀ = rowvals(Sᵀ)
for i in group[c]
Expand All @@ -358,10 +357,9 @@ function decompress_single_color!(
return A
end

function decompress!(
A::SparseMatrixCSC{R}, B::AbstractMatrix{R}, result::RowColoringResult
) where {R<:Real}
@compat (; S, compressed_indices) = result
function decompress!(A::SparseMatrixCSC, B::AbstractMatrix, result::RowColoringResult)
@compat (; compressed_indices) = result
S = result.bg.S2
check_same_pattern(A, S)
nzA = nonzeros(A)
for k in eachindex(nzA, compressed_indices)
Expand All @@ -373,15 +371,13 @@ end
## StarSetColoringResult

function decompress!(
A::AbstractMatrix{R},
B::AbstractMatrix{R},
result::StarSetColoringResult,
uplo::Symbol=:F,
) where {R<:Real}
@compat (; S, color, star_set) = result
A::AbstractMatrix, B::AbstractMatrix, result::StarSetColoringResult, uplo::Symbol=:F
)
@compat (; color, star_set) = result
@compat (; star, hub, spokes) = star_set
S = result.ag.S
uplo == :F && check_same_pattern(A, S)
A .= zero(R)
fill!(A, zero(eltype(A)))
for i in axes(A, 1)
if !iszero(S[i, i])
A[i, i] = B[i, color[i]]
Expand All @@ -403,14 +399,15 @@ function decompress!(
end

function decompress_single_color!(
A::AbstractMatrix{R},
b::AbstractVector{R},
A::AbstractMatrix,
b::AbstractVector,
c::Integer,
result::StarSetColoringResult,
uplo::Symbol=:F,
) where {R<:Real}
@compat (; S, color, group, star_set) = result
)
@compat (; color, group, star_set) = result
@compat (; hub, spokes) = star_set
S = result.ag.S
uplo == :F && check_same_pattern(A, S)
for i in axes(A, 1)
if !iszero(S[i, i]) && color[i] == c
Expand All @@ -434,12 +431,10 @@ function decompress_single_color!(
end

function decompress!(
A::SparseMatrixCSC{R},
B::AbstractMatrix{R},
result::StarSetColoringResult,
uplo::Symbol=:F,
) where {R<:Real}
@compat (; S, compressed_indices) = result
A::SparseMatrixCSC, B::AbstractMatrix, result::StarSetColoringResult, uplo::Symbol=:F
)
@compat (; compressed_indices) = result
S = result.ag.S
nzA = nonzeros(A)
if uplo == :F
check_same_pattern(A, S)
Expand Down Expand Up @@ -468,14 +463,13 @@ end
# TODO: add method for A::SparseMatrixCSC

function decompress!(
A::AbstractMatrix{R},
B::AbstractMatrix{R},
result::TreeSetColoringResult,
uplo::Symbol=:F,
) where {R<:Real}
@compat (; S, color, vertices_by_tree, reverse_bfs_orders, buffer) = result
A::AbstractMatrix, B::AbstractMatrix, result::TreeSetColoringResult, uplo::Symbol=:F
)
@compat (; color, vertices_by_tree, reverse_bfs_orders, buffer) = result
S = result.ag.S
uplo == :F && check_same_pattern(A, S)
A .= zero(R)
R = eltype(A)
fill!(A, zero(R))

if eltype(buffer) == R
buffer_right_type = buffer
Expand Down Expand Up @@ -513,19 +507,19 @@ end
## MatrixInverseColoringResult

function decompress!(
A::AbstractMatrix{R},
B::AbstractMatrix{R},
A::AbstractMatrix,
B::AbstractMatrix,
result::LinearSystemColoringResult,
uplo::Symbol=:F,
) where {R<:Real}
@compat (;
S, color, strict_upper_nonzero_inds, T_factorization, strict_upper_nonzeros_A
) = result
)
@compat (; color, strict_upper_nonzero_inds, T_factorization, strict_upper_nonzeros_A) =
result
S = result.ag.S
uplo == :F && check_same_pattern(A, S)

# TODO: for some reason I cannot use ldiv! with a sparse QR
strict_upper_nonzeros_A = T_factorization \ vec(B)
A .= zero(R)
fill!(A, zero(eltype(A)))
for i in axes(A, 1)
if !iszero(S[i, i])
A[i, i] = B[i, color[i]]
Expand Down
17 changes: 17 additions & 0 deletions src/graph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ end
SparsityPatternCSC(A::SparseMatrixCSC) = SparsityPatternCSC(A.m, A.n, A.colptr, A.rowval)

Base.size(S::SparsityPatternCSC) = (S.m, S.n)
Base.size(S::SparsityPatternCSC, d) = d::Integer <= 2 ? size(S)[d] : 1
Base.axes(S::SparsityPatternCSC, d::Integer) = Base.OneTo(size(S, d))

SparseArrays.nnz(S::SparsityPatternCSC) = length(S.rowval)
SparseArrays.rowvals(S::SparsityPatternCSC) = S.rowval
SparseArrays.nzrange(S::SparsityPatternCSC, j::Integer) = S.colptr[j]:(S.colptr[j + 1] - 1)
Expand Down Expand Up @@ -81,6 +84,15 @@ function Base.transpose(S::SparsityPatternCSC{T}) where {T}
return SparsityPatternCSC{T}(n, m, B_colptr, B_rowval)
end

# copied from SparseArrays.jl
function Base.getindex(S::SparsityPatternCSC, i0::Integer, i1::Integer)
r1 = Int(S.colptr[i1])
r2 = Int(S.colptr[i1 + 1] - 1)
(r1 > r2) && return false
r1 = searchsortedfirst(rowvals(S), i0, r1, r2, Base.Order.Forward)
return ((r1 > r2) || (rowvals(S)[r1] != i0)) ? false : true
end

## Adjacency graph

"""
Expand Down Expand Up @@ -109,6 +121,7 @@ struct AdjacencyGraph{T}
S::SparsityPatternCSC{T}
end

AdjacencyGraph(A::AbstractMatrix) = AdjacencyGraph(SparseMatrixCSC(A))
AdjacencyGraph(A::SparseMatrixCSC) = AdjacencyGraph(SparsityPatternCSC(A))

pattern(g::AdjacencyGraph) = g.S
Expand Down Expand Up @@ -183,6 +196,10 @@ struct BipartiteGraph{T<:Integer}
S2::SparsityPatternCSC{T}
end

function BipartiteGraph(A::AbstractMatrix; symmetric_pattern::Bool=false)
return BipartiteGraph(SparseMatrixCSC(A); symmetric_pattern)
end

function BipartiteGraph(A::SparseMatrixCSC; symmetric_pattern::Bool=false)
S2 = SparsityPatternCSC(A) # columns to rows
if symmetric_pattern
Expand Down
29 changes: 11 additions & 18 deletions src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -180,12 +180,11 @@ function coloring(
decompression_eltype::Type=Float64,
symmetric_pattern::Bool=false,
)
S = convert(SparseMatrixCSC, A)
bg = BipartiteGraph(
S; symmetric_pattern=symmetric_pattern || A isa Union{Symmetric,Hermitian}
A; symmetric_pattern=symmetric_pattern || A isa Union{Symmetric,Hermitian}
)
color = partial_distance2_coloring(bg, Val(2), algo.order)
return ColumnColoringResult(S, color)
return ColumnColoringResult(A, bg, color)
end

function coloring(
Expand All @@ -195,12 +194,11 @@ function coloring(
decompression_eltype::Type=Float64,
symmetric_pattern::Bool=false,
)
S = convert(SparseMatrixCSC, A)
bg = BipartiteGraph(
S; symmetric_pattern=symmetric_pattern || A isa Union{Symmetric,Hermitian}
A; symmetric_pattern=symmetric_pattern || A isa Union{Symmetric,Hermitian}
)
color = partial_distance2_coloring(bg, Val(1), algo.order)
return RowColoringResult(S, color)
return RowColoringResult(A, bg, color)
end

function coloring(
Expand All @@ -209,10 +207,9 @@ function coloring(
algo::GreedyColoringAlgorithm{:direct};
decompression_eltype::Type=Float64,
)
S = convert(SparseMatrixCSC, A)
ag = AdjacencyGraph(S)
ag = AdjacencyGraph(A)
color, star_set = star_coloring(ag, algo.order)
return StarSetColoringResult(S, color, star_set)
return StarSetColoringResult(A, ag, color, star_set)
end

function coloring(
Expand All @@ -221,31 +218,27 @@ function coloring(
algo::GreedyColoringAlgorithm{:substitution};
decompression_eltype::Type=Float64,
)
S = convert(SparseMatrixCSC, A)
ag = AdjacencyGraph(S)
ag = AdjacencyGraph(A)
color, tree_set = acyclic_coloring(ag, algo.order)
return TreeSetColoringResult(S, color, tree_set, decompression_eltype)
return TreeSetColoringResult(A, ag, color, tree_set, decompression_eltype)
end

## ADTypes interface

function ADTypes.column_coloring(A::AbstractMatrix, algo::GreedyColoringAlgorithm)
S = convert(SparseMatrixCSC, A)
bg = BipartiteGraph(S; symmetric_pattern=A isa Union{Symmetric,Hermitian})
bg = BipartiteGraph(A; symmetric_pattern=A isa Union{Symmetric,Hermitian})
color = partial_distance2_coloring(bg, Val(2), algo.order)
return color
end

function ADTypes.row_coloring(A::AbstractMatrix, algo::GreedyColoringAlgorithm)
S = convert(SparseMatrixCSC, A)
bg = BipartiteGraph(S; symmetric_pattern=A isa Union{Symmetric,Hermitian})
bg = BipartiteGraph(A; symmetric_pattern=A isa Union{Symmetric,Hermitian})
color = partial_distance2_coloring(bg, Val(1), algo.order)
return color
end

function ADTypes.symmetric_coloring(A::AbstractMatrix, algo::GreedyColoringAlgorithm)
S = convert(SparseMatrixCSC, A)
ag = AdjacencyGraph(S)
ag = AdjacencyGraph(A)
color, star_set = star_coloring(ag, algo.order)
return color
end
13 changes: 7 additions & 6 deletions src/matrices.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,22 +62,23 @@ function respectful_similar(A::Union{Symmetric,Hermitian}, ::Type{T}) where {T}
end

"""
same_pattern(A::AbstractMatrix, B::AbstractMatrix)
same_pattern(A, B)

Perform a partial equality check on the sparsity patterns of `A` and `B`:

- if the return is `true`, they might have the same sparsity pattern but we're not sure
- if the return is `false`, they definitely don't have the same sparsity pattern
"""
function same_pattern(A::AbstractMatrix, B::AbstractMatrix)
return size(A) == size(B)
end
same_pattern(A, B) = size(A) == size(B)

function same_pattern(A::SparseMatrixCSC, B::SparseMatrixCSC)
function same_pattern(
A::Union{SparseMatrixCSC,SparsityPatternCSC},
B::Union{SparseMatrixCSC,SparsityPatternCSC},
)
return size(A) == size(B) && nnz(A) == nnz(B)
end

function check_same_pattern(A::AbstractMatrix, S::AbstractMatrix)
function check_same_pattern(A, S)
if !same_pattern(A, S)
throw(DimensionMismatch("`A` and `S` must have the same sparsity pattern."))
end
Expand Down
Loading