forked from torch/graph
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathNode.lua
More file actions
105 lines (95 loc) · 2.42 KB
/
Node.lua
File metadata and controls
105 lines (95 loc) · 2.42 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
--[[
Node class. This class is generally used together with Edge to add edges to a graph:
graph:add(graph.Edge(graph.Node(),graph.Node()))
However, one can also use the Node class on its own to create a graph. Each node registers
all of its edges in its children table, so the graph can be parsed starting from any given node.
The drawback is that there is no global edge table or node table, which are mostly useful
for running algorithms on graphs. If all you need is a data structure to store data and
to run DFS or BFS over the graph, then this method is also quick and nice.
--]]
-- Declare the class under the name 'graph.Node' via torch.class; the methods
-- defined below are attached to the value it returns, kept here as `Node`.
local Node = torch.class('graph.Node')
-- Constructor.
--   d : payload stored on the node as `self.data`.
--   p : accepted for compatibility but not used by this constructor.
-- A fresh node has id 0, an empty children table and both traversal
-- flags cleared.
function Node:__init(d,p)
   -- Traversal bookkeeping: `visited` is consulted by visit(), `marked`
   -- by bfs_dirty().
   self.marked = false
   self.visited = false
   -- Holds the child nodes; add() fills it.
   self.children = {}
   self.id = 0
   self.data = d
end
-- Register one child, or a list of children, on this node.
--   child : either a single child, or a table for which torch.typename
--           returns a false value (presumably a plain Lua list rather than
--           a torch object -- confirm against torch docs), in which case
--           every ipairs entry is added recursively.
-- `self.children` is used both as a sequence (children in insertion order)
-- and as a lookup (children[child] = position), so a child that is already
-- registered is not appended a second time.
function Node:add(child)
   local is_list = type(child) == 'table' and not torch.typename(child)
   if is_list then
      for _, entry in ipairs(child) do
         self:add(entry)
      end
      return
   end

   local kids = self.children
   if kids[child] then
      -- already registered: nothing to do
      return
   end
   table.insert(kids, child)
   -- remember the slot the child was appended at
   kids[child] = #kids
end
-- Visitor: recursive depth-first walk starting at this node.
--   pre_func  : optional; called with the node before its children are walked.
--   post_func : optional; called with the node after its children are walked.
-- A node whose `visited` field is truthy is skipped entirely. This method
-- never sets `visited` itself; that is left to the callbacks.
function Node:visit(pre_func,post_func)
   if self.visited then
      return
   end
   if pre_func then
      pre_func(self)
   end
   for _, kid in ipairs(self.children) do
      kid:visit(pre_func, post_func)
   end
   if post_func then
      post_func(self)
   end
end
-- Return the string form of this node's data, as produced by tostring.
function Node:label()
   local text = tostring(self.data)
   return text
end
-- Create a graph from the Node traversal.
-- Walks the nodes reached from this node with bfs() and, for each of them,
-- adds one graph.Edge(node, child) per child to a new graph.Graph.
-- Returns the resulting graph.
function Node:graph()
   local result = graph.Graph()
   self:bfs(function(parent)
      for _, kid in ipairs(parent.children) do
         result:add(graph.Edge(parent, kid))
      end
   end)
   return result
end
-- Depth-first traversal that leaves the `visited` flags set ("dirty").
--   func : called with each reached node after that node's children have
--          been walked (it is passed to visit() as the post callback).
-- Returns the list of nodes in the order `func` was applied to them.
-- The caller is responsible for clearing `visited` on the returned nodes.
function Node:dfs_dirty(func)
   local finished = {}

   -- pre callback: flag the node so visit() will not enter it again
   local function flag_visited(node)
      node.visited = true
   end

   -- post callback: apply the user function, then record the node
   local function on_leave(node)
      func(node)
      finished[#finished + 1] = node
   end

   self:visit(flag_visited, on_leave)
   return finished
end
-- Depth-first traversal that cleans up after itself.
--   func : applied to every reached node by dfs_dirty().
-- Once the traversal has returned, the `visited` flag of every node it
-- reported is reset to false. If `func` raises, the error propagates and
-- this reset does not run.
function Node:dfs(func)
   local touched = self:dfs_dirty(func)
   for idx = 1, #touched do
      touched[idx].visited = false
   end
end
-- Breadth-first traversal that leaves the `marked` flags set ("dirty").
--   func : called once with each reached node, in breadth-first order,
--          before that node's children are enqueued.
-- Returns the list of nodes in the order they were processed. The caller is
-- responsible for clearing `marked` on the returned nodes.
--
-- The queue is a single growing list read through a head index instead of
-- being popped with table.remove(queue, 1), which shifts every remaining
-- element on each pop. Nodes are processed in exactly the order they are
-- enqueued, so once the loop ends the queue is also the visited list.
function Node:bfs_dirty(func)
   local queue = { self }
   self.marked = true

   local head = 1
   while head <= #queue do
      local node = queue[head]
      head = head + 1

      func(node)
      -- children are read after func has run, so changes made by func are seen
      for _, child in ipairs(node.children) do
         if not child.marked then
            -- mark on enqueue so a node is never queued twice
            child.marked = true
            queue[#queue + 1] = child
         end
      end
   end

   return queue
end
-- Breadth-first traversal that cleans up after itself.
--   func : applied to every reached node by bfs_dirty().
-- Once the traversal has returned, the `marked` flag of every node it
-- reported is reset to false. If `func` raises, the error propagates and
-- this reset does not run.
function Node:bfs(func)
   local touched = self:bfs_dirty(func)
   for idx = 1, #touched do
      touched[idx].marked = false
   end
end