diff --git a/src/coloring.jl b/src/coloring.jl index 6a95f054..c64aef32 100644 --- a/src/coloring.jl +++ b/src/coloring.jl @@ -328,7 +328,7 @@ function acyclic_coloring(g::Graph, order::AbstractOrder) end end end - return color, TreeSet(forest) + return color, TreeSet(forest, nvertices) end function _prevent_cycle!( @@ -408,5 +408,80 @@ $TYPEDFIELDS """ struct TreeSet "a forest of two-colored trees" - forest::DisjointSets{Tuple{Int,Int}} + trees::Vector{Vector{Tuple{Int,Int}}} + "???" + nodes::Vector{Vector{Int}} + "???" + stored_values::Vector{Float64} +end + +function TreeSet(forest::DisjointSets, nvertices::Int) + # forest is a structure DisjointSets from DataStructures.jl + # - forest.intmap: a dictionary that maps an edge (i, j) to an integer k + # - forest.revmap: a dictionary that does the reverse of intmap, mapping an integer k to an edge (i, j) + # - forest.internal.ngroups: the number of trees in the forest + ntrees = forest.internal.ngroups + + # vector of trees where each tree contains the indices of its edges + trees = [Int[] for i in 1:ntrees] + + # dictionary that maps a tree's root to the index of the tree + roots = Dict{Int,Int}() + + k = 0 + for edge in forest.revmap + root_edge = find_root!(forest, edge) + root = forest.intmap[root_edge] + if !haskey(roots, root) + k += 1 + roots[root] = k + end + index_tree = roots[root] + push!(trees[index_tree], forest.intmap[edge]) + end + + # vector of dictionaries where each dictionary stores the degree of each vertex in a tree + degrees = [Dict{Int,Int}() for k in 1:ntrees] + for k in 1:ntrees + tree = trees[k] + degree = degrees[k] + for edge_index in tree + i, j = forest.revmap[edge_index] + !haskey(degree, i) && (degree[i] = 0) + !haskey(degree, j) && (degree[j] = 0) + degree[i] += 1 + degree[j] += 1 + end + end + + # depth-first search (DFS) traversal order for each tree in the forest + dfs_orders = [Vector{Tuple{Int,Int}}() for k in 1:ntrees] + for k in 1:ntrees + tree = trees[k] + degree = degrees[k] + while sum(values(degree)) != 0 + for (t, edge_index) in enumerate(tree) + if edge_index != 0 + i, j = forest.revmap[edge_index] + if (degree[i] == 1) || (degree[j] == 1) # leaf vertex + if degree[i] > degree[j] # vertex i is the parent of vertex j + i, j = j, i # ensure that i always denotes a leaf vertex + end + degree[i] -= 1 # decrease the degree of vertex i + degree[j] -= 1 # decrease the degree of vertex j + tree[t] = 0 # remove the edge (i,j) + push!(dfs_orders[k], (i, j)) + end + end + end + end + end + + # stored_values holds the sum of edge values for subtrees in a tree. + # For each vertex i, stored_values[i] is the sum of edge values in the subtree rooted at i. + stored_values = Vector{Float64}(undef, nvertices) + + nodes = [[vertex for vertex in keys(degrees[k])] for k = 1:ntrees] + + return TreeSet(dfs_orders, nodes, stored_values) end diff --git a/src/decompression.jl b/src/decompression.jl index 6392cae7..e77cc36c 100644 --- a/src/decompression.jl +++ b/src/decompression.jl @@ -303,72 +303,8 @@ function decompress_aux!( A .= zero(R) S = get_matrix(result) color = column_colors(result) - - # forest is a structure DisjointSets from DataStructures.jl - # - forest.intmap: a dictionary that maps an edge (i, j) to an integer k - # - forest.revmap: a dictionary that does the reverse of intmap, mapping an integer k to an edge (i, j) - # - forest.internal.ngroups: the number of trees in the forest - forest = result.tree_set.forest - ntrees = forest.internal.ngroups - - # vector of trees where each tree contains the indices of its edges - trees = [Int[] for i in 1:ntrees] - - # dictionary that maps a tree's root to the index of the tree - roots = Dict{Int,Int}() - - k = 0 - for edge in forest.revmap - root_edge = find_root!(forest, edge) - root = forest.intmap[root_edge] - if !haskey(roots, root) - k += 1 - roots[root] = k - end - index_tree = roots[root] - push!(trees[index_tree], forest.intmap[edge]) - end - - # vector of dictionaries where each dictionary stores the degree of each vertex in a tree - degrees = [Dict{Int,Int}() for k in 1:ntrees] - for k in 1:ntrees - tree = trees[k] - degree = degrees[k] - for edge_index in tree - i, j = forest.revmap[edge_index] - !haskey(degree, i) && (degree[i] = 0) - !haskey(degree, j) && (degree[j] = 0) - degree[i] += 1 - degree[j] += 1 - end - end - - # depth-first search (DFS) traversal order for each tree in the forest - dfs_orders = [Vector{Tuple{Int,Int}}() for k in 1:ntrees] - for k in 1:ntrees - tree = trees[k] - degree = degrees[k] - while sum(values(degree)) != 0 - for (t, edge_index) in enumerate(tree) - if edge_index != 0 - i, j = forest.revmap[edge_index] - if (degree[i] == 1) || (degree[j] == 1) # leaf vertex - if degree[i] > degree[j] # vertex i is the parent of vertex j - i, j = j, i # ensure that i always denotes a leaf vertex - end - degree[i] -= 1 # decrease the degree of vertex i - degree[j] -= 1 # decrease the degree of vertex j - tree[t] = 0 # remove the edge (i,j) - push!(dfs_orders[k], (i, j)) - end - end - end - end - end - - # stored_values holds the sum of edge values for subtrees in a tree. - # For each vertex i, stored_values[i] is the sum of edge values in the subtree rooted at i. - stored_values = Vector{R}(undef, n) + @compat (; trees, nodes, stored_values) = result.tree_set + ntrees = length(trees) # Recover the diagonal coefficients of A for i in axes(A, 1) @@ -379,12 +315,12 @@ function decompress_aux!( # Recover the off-diagonal coefficients of A for k in 1:ntrees - vertices = keys(degrees[k]) + vertices = nodes[k] for vertex in vertices stored_values[vertex] = zero(R) end - tree = dfs_orders[k] + tree = trees[k] for (i, j) in tree val = B[i, color[j]] - stored_values[i] stored_values[j] = stored_values[j] + val