Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 36 additions & 23 deletions src/coloring.jl
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,8 @@ $TYPEDFIELDS
struct TreeSet{T}
reverse_bfs_orders::Vector{Tuple{T,T}}
is_star::Vector{Bool}
num_edges_per_tree::Vector{T}
tree_edge_indices::Vector{T}
nt::T
end

function TreeSet(
Expand All @@ -388,20 +389,20 @@ function TreeSet(
S = pattern(g)
edge_to_index = edge_indices(g)
nv = nb_vertices(g)
nt = forest.num_trees
(; nt, ranks) = forest

# root_to_tree is a vector that maps a tree's root to the index of the tree
# We can recycle forest.ranks because we don't need it anymore to merge trees
root_to_tree = forest.ranks
# We can recycle the vector "ranks" because we don't need it anymore to merge trees
root_to_tree = ranks
fill!(root_to_tree, zero(T))

# Contains the number of edges per tree
num_edges_per_tree = zeros(T, nt)
# vector specifying the starting and ending indices of edges for each tree
tree_edge_indices = zeros(T, nt + 1)

# vector of dictionaries where each dictionary stores the neighbors of each vertex in a tree
trees = [Dict{T,Vector{T}}() for i in 1:nt]

# current number of roots found
# number of roots found
nr = 0

rvS = rowvals(S)
Expand All @@ -418,18 +419,20 @@ function TreeSet(
root_to_tree[root] = nr
end

# index of the tree T that contains this edge
# index of the tree that contains this edge
index_tree = root_to_tree[root]
num_edges_per_tree[index_tree] += 1

# Update the neighbors of i in the tree T
# Update the number of edges for the current tree (shifted by 1 to facilitate the final cumsum)
tree_edge_indices[index_tree + 1] += 1

# Update the neighbors of i in the current tree
if !haskey(trees[index_tree], i)
trees[index_tree][i] = [j]
else
push!(trees[index_tree][i], j)
end

# Update the neighbors of j in the tree T
# Update the neighbors of j in the current tree
if !haskey(trees[index_tree], j)
trees[index_tree][j] = [i]
else
Expand All @@ -439,6 +442,12 @@ function TreeSet(
end
end

# Compute a shifted cumulative sum of tree_edge_indices, starting from one
tree_edge_indices[1] = one(T)
for k in 2:(nt + 1)
tree_edge_indices[k] += tree_edge_indices[k - 1]
end

# degrees is a vector of integers that stores the degree of each vertex in a tree
degrees = buffer

Expand Down Expand Up @@ -529,7 +538,7 @@ function TreeSet(
is_star[k] = bool_star
end

return TreeSet(reverse_bfs_orders, is_star, num_edges_per_tree)
return TreeSet(reverse_bfs_orders, is_star, tree_edge_indices, nt)
end

## Postprocessing, mirrors decompression code
Expand Down Expand Up @@ -597,15 +606,17 @@ function postprocess!(
end
else
# only the colors of non-leaf vertices are used
(; reverse_bfs_orders, is_star, num_edges_per_tree) = star_or_tree_set
(; reverse_bfs_orders, is_star, tree_edge_indices, nt) = star_or_tree_set
nb_trivial_trees = 0

# Index of the first edge in reverse_bfs_orders for the current tree
first = 1

# Iterate through all non-trivial trees
for k in eachindex(num_edges_per_tree)
ne_tree = num_edges_per_tree[k]
for k in 1:nt
# Position of the first edge in the tree
first = tree_edge_indices[k]

# Total number of edges in the tree
ne_tree = tree_edge_indices[k + 1] - first

# Check if we have more than one edge in the tree (non-trivial tree)
if ne_tree > 1
# Determine if the tree is a star
Expand All @@ -622,14 +633,17 @@ function postprocess!(
else
nb_trivial_trees += 1
end
first += ne_tree
end

# Process the trivial trees (if any)
if nb_trivial_trees > 0
first = 1
for k in eachindex(num_edges_per_tree)
ne_tree = num_edges_per_tree[k]
for k in 1:nt
# Position of the first edge in the tree
first = tree_edge_indices[k]

# Total number of edges in the tree
ne_tree = tree_edge_indices[k + 1] - first

# Check if we have exactly one edge in the tree
if ne_tree == 1
(i, j) = reverse_bfs_orders[first]
Expand All @@ -642,7 +656,6 @@ function postprocess!(
color_used[color[j]] = true
end
end
first += ne_tree
end
end
end
Expand Down
27 changes: 11 additions & 16 deletions src/decompression.jl
Original file line number Diff line number Diff line change
Expand Up @@ -517,7 +517,7 @@ end
function decompress!(
A::AbstractMatrix, B::AbstractMatrix, result::TreeSetColoringResult, uplo::Symbol=:F
)
(; ag, color, reverse_bfs_orders, num_edges_per_tree, buffer) = result
(; ag, color, reverse_bfs_orders, tree_edge_indices, nt, buffer) = result
(; S) = ag
uplo == :F && check_same_pattern(A, S)
R = eltype(A)
Expand All @@ -538,13 +538,11 @@ function decompress!(
end
end

# Index of the first edge in reverse_bfs_orders for the current tree
first = 1

# Recover the off-diagonal coefficients of A
for k in eachindex(num_edges_per_tree)
ne_tree = num_edges_per_tree[k]
last = first + ne_tree - 1
for k in 1:nt
# Positions of the edges for each tree
first = tree_edge_indices[k]
last = tree_edge_indices[k + 1] - 1

# Reset the buffer to zero for all vertices in a tree (except the root)
for pos in first:last
Expand All @@ -567,7 +565,6 @@ function decompress!(
A[j, i] = val
end
end
first += ne_tree
end
return A
end
Expand All @@ -582,7 +579,8 @@ function decompress!(
ag,
color,
reverse_bfs_orders,
num_edges_per_tree,
tree_edge_indices,
nt,
diagonal_indices,
diagonal_nzind,
lower_triangle_offsets,
Expand Down Expand Up @@ -622,16 +620,14 @@ function decompress!(
end
end

# Index of the first edge in reverse_bfs_orders for the current tree
first = 1

# Index of offsets in lower_triangle_offsets and upper_triangle_offsets
counter = 0

# Recover the off-diagonal coefficients of A
for k in eachindex(num_edges_per_tree)
ne_tree = num_edges_per_tree[k]
last = first + ne_tree - 1
for k in 1:nt
# Positions of the edges for each tree
first = tree_edge_indices[k]
last = tree_edge_indices[k + 1] - 1

# Reset the buffer to zero for all vertices in a tree (except the root)
for pos in first:last
Expand Down Expand Up @@ -683,7 +679,6 @@ function decompress!(
end
#! format: on
end
first += ne_tree
end
return A
end
Expand Down
8 changes: 4 additions & 4 deletions src/forest.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,18 @@ $TYPEDFIELDS
"""
mutable struct Forest{T<:Integer}
"current number of distinct trees in the forest"
num_trees::T
nt::T
"vector storing the index of a parent in the tree for each edge, used in union-find operations"
parents::Vector{T}
"vector approximating the depth of each tree to optimize path compression"
ranks::Vector{T}
end

function Forest{T}(n::Integer) where {T<:Integer}
num_trees = T(n)
nt = T(n)
parents = collect(Base.OneTo(T(n)))
ranks = zeros(T, T(n))
return Forest{T}(num_trees, parents, ranks)
return Forest{T}(nt, parents, ranks)
end

function _find_root!(parents::Vector{T}, index_edge::T) where {T<:Integer}
Expand All @@ -49,6 +49,6 @@ function root_union!(forest::Forest{T}, index_edge1::T, index_edge2::T) where {T
rks[index_edge1] += one(T)
end
parents[index_edge2] = index_edge1
forest.num_trees -= one(T)
forest.nt -= one(T)
return nothing
end
19 changes: 9 additions & 10 deletions src/result.jl
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,8 @@ struct TreeSetColoringResult{
color::Vector{T}
group::GT
reverse_bfs_orders::Vector{Tuple{T,T}}
num_edges_per_tree::Vector{T}
tree_edge_indices::Vector{T}
nt::T
diagonal_indices::Vector{T}
diagonal_nzind::Vector{T}
lower_triangle_offsets::Vector{T}
Expand All @@ -329,7 +330,7 @@ function TreeSetColoringResult(
tree_set::TreeSet{<:Integer},
decompression_eltype::Type{R},
) where {T<:Integer,R}
(; reverse_bfs_orders, num_edges_per_tree) = tree_set
(; reverse_bfs_orders, tree_edge_indices, nt) = tree_set
(; S) = ag
nvertices = length(color)
group = group_by_color(T, color)
Expand Down Expand Up @@ -358,15 +359,13 @@ function TreeSetColoringResult(
lower_triangle_offsets = Vector{T}(undef, nedges)
upper_triangle_offsets = Vector{T}(undef, nedges)

# Index of the first edge in reverse_bfs_orders for the current tree
first = 1

# Index in lower_triangle_offsets and upper_triangle_offsets
index_offsets = 0

for k in eachindex(num_edges_per_tree)
ne_tree = num_edges_per_tree[k]
last = first + ne_tree - 1
for k in 1:nt
# Positions of the edges for each tree
first = tree_edge_indices[k]
last = tree_edge_indices[k + 1] - 1

for pos in first:last
(leaf, neighbor) = reverse_bfs_orders[pos]
Expand Down Expand Up @@ -400,7 +399,6 @@ function TreeSetColoringResult(
end
#! format: on
end
first += ne_tree
end

# buffer holds the sum of edge values for subtrees in a tree.
Expand All @@ -413,7 +411,8 @@ function TreeSetColoringResult(
color,
group,
reverse_bfs_orders,
num_edges_per_tree,
tree_edge_indices,
nt,
diagonal_indices,
diagonal_nzind,
lower_triangle_offsets,
Expand Down
6 changes: 3 additions & 3 deletions test/forest.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ using Test

@testset "Constructor Forest" begin
forest = Forest{Int}(5)
@test forest.num_trees == 5
@test forest.nt == 5
@test length(forest.parents) == 5
@test all(forest.parents .== 1:5)
@test all(forest.ranks .== 0)
Expand All @@ -27,7 +27,7 @@ end
@test forest.parents[3] == 1
@test forest.ranks[1] == 1
@test forest.ranks[3] == 0
@test forest.num_trees == 4
@test forest.nt == 4

root1 = find_root!(forest, 1)
root2 = find_root!(forest, 2)
Expand All @@ -39,5 +39,5 @@ end
@test forest.parents[2] == 1
@test forest.ranks[1] == 1
@test forest.ranks[2] == 0
@test forest.num_trees == 3
@test forest.nt == 3
end