Skip to content

Commit f2472bf

Browse files
committed
feat: add pipeline dag plotting
1 parent b476274 commit f2472bf

File tree

3 files changed

+229
-0
lines changed

3 files changed

+229
-0
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ dependencies = [
1818
"beartype>=0.21.0",
1919
"deltalake>=1.0.2",
2020
"selection-pipeline",
21+
"graphviz>=0.21",
2122
]
2223
readme = "README.md"
2324
requires-python = ">=3.11.0"

src/orcapod/pipeline/graph.py

Lines changed: 217 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,17 @@
66
from orcapod.protocols import database_protocols as dbp
77
from typing import Any
88
from collections.abc import Collection
9+
import os
10+
import tempfile
911
import logging
1012
import asyncio
13+
from typing import TYPE_CHECKING
14+
from orcapod.utils.lazy_module import LazyModule
15+
16+
if TYPE_CHECKING:
17+
import networkx as nx
18+
else:
19+
nx = LazyModule("networkx")
1120

1221

1322
def synchronous_run(async_func, *args, **kwargs):
@@ -109,11 +118,16 @@ def compile(self) -> None:
109118

110119
invocation_to_stream_lut = {}
111120
G = self.generate_graph()
121+
node_graph = nx.DiGraph()
112122
for invocation in nx.topological_sort(G):
113123
input_streams = [
114124
invocation_to_stream_lut[parent] for parent in invocation.parents()
115125
]
126+
116127
node = self.wrap_invocation(invocation, new_input_streams=input_streams)
128+
for parent in node.upstreams:
129+
node_graph.add_edge(parent.source, node)
130+
117131
invocation_to_stream_lut[invocation] = node()
118132
name_candidates.setdefault(node.label, []).append(node)
119133

@@ -127,6 +141,13 @@ def compile(self) -> None:
127141
else:
128142
self.nodes[label] = nodes[0]
129143

144+
self.label_lut = {v: k for k, v in self.nodes.items()}
145+
146+
self.graph = node_graph
147+
148+
def show_graph(self, **kwargs) -> None:
149+
render_graph(self.graph, self.label_lut, **kwargs)
150+
130151
def run(
131152
self,
132153
execution_engine: cp.ExecutionEngine | None = None,
@@ -217,3 +238,199 @@ def rename(self, old_name: str, new_name: str) -> None:
217238
node.label = new_name
218239
self.nodes[new_name] = node
219240
logger.info(f"Node '{old_name}' renamed to '{new_name}'")
241+
242+
243+
# import networkx as nx
244+
# # import graphviz
245+
# import matplotlib.pyplot as plt
246+
# import matplotlib.image as mpimg
247+
# import tempfile
248+
# import os
249+
250+
251+
class GraphRenderer:
252+
"""Simple renderer for NetworkX graphs using Graphviz DOT format"""
253+
254+
def __init__(self):
255+
"""Initialize the renderer"""
256+
pass
257+
258+
def _sanitize_node_id(self, node_id: Any) -> str:
259+
"""Convert node_id to a valid DOT identifier using hash"""
260+
return f"node_{hash(node_id)}"
261+
262+
def _get_node_label(
263+
self, node_id: Any, label_lut: dict[Any, str] | None = None
264+
) -> str:
265+
"""Get label for a node"""
266+
if label_lut and node_id in label_lut:
267+
return label_lut[node_id]
268+
return str(node_id)
269+
270+
def generate_dot(
271+
self,
272+
graph: "nx.DiGraph",
273+
label_lut: dict[Any, str] | None = None,
274+
rankdir: str = "TB",
275+
node_shape: str = "box",
276+
node_style: str = "filled",
277+
node_color: str = "lightblue",
278+
edge_color: str = "black",
279+
dpi: int = 150,
280+
) -> str:
281+
"""
282+
Generate DOT syntax from NetworkX graph
283+
284+
Args:
285+
graph: NetworkX DiGraph to render
286+
label_lut: Optional dictionary mapping node_id -> display_label
287+
rankdir: Graph direction ('TB', 'BT', 'LR', 'RL')
288+
node_shape: Shape for all nodes
289+
node_style: Style for all nodes
290+
node_color: Fill color for all nodes
291+
edge_color: Color for all edges
292+
dpi: Resolution for rendered image (default 150)
293+
294+
Returns:
295+
DOT format string
296+
"""
297+
try:
298+
import graphviz
299+
except ImportError as e:
300+
raise ImportError(
301+
"Graphviz is not installed. Please install graphviz to render graph of the pipeline."
302+
) from e
303+
304+
dot = graphviz.Digraph(comment="NetworkX Graph")
305+
306+
# Set graph attributes
307+
dot.attr(rankdir=rankdir, dpi=str(dpi))
308+
dot.attr("node", shape=node_shape, style=node_style, fillcolor=node_color)
309+
dot.attr("edge", color=edge_color)
310+
311+
# Add nodes
312+
for node_id in graph.nodes():
313+
sanitized_id = self._sanitize_node_id(node_id)
314+
label = self._get_node_label(node_id, label_lut)
315+
dot.node(sanitized_id, label=label)
316+
317+
# Add edges
318+
for source, target in graph.edges():
319+
source_id = self._sanitize_node_id(source)
320+
target_id = self._sanitize_node_id(target)
321+
dot.edge(source_id, target_id)
322+
323+
return dot.source
324+
325+
def render_graph(
326+
self,
327+
graph: nx.DiGraph,
328+
label_lut: dict[Any, str] | None = None,
329+
show: bool = True,
330+
output_path: str | None = None,
331+
raw_output: bool = False,
332+
rankdir: str = "TB",
333+
figsize: tuple = (6, 4),
334+
dpi: int = 150,
335+
**style_kwargs,
336+
) -> str | None:
337+
"""
338+
Render NetworkX graph using Graphviz
339+
340+
Args:
341+
graph: NetworkX DiGraph to render
342+
label_lut: Optional dictionary mapping node_id -> display_label
343+
show: Display the graph using matplotlib
344+
output_path: Save graph to file (e.g., 'graph.png', 'graph.pdf')
345+
raw_output: Return DOT syntax instead of rendering
346+
rankdir: Graph direction ('TB', 'BT', 'LR', 'RL')
347+
figsize: Figure size for matplotlib display
348+
dpi: Resolution for rendered image (default 150)
349+
**style_kwargs: Additional styling (node_color, edge_color, node_shape, etc.)
350+
351+
Returns:
352+
DOT syntax if raw_output=True, None otherwise
353+
"""
354+
try:
355+
import graphviz
356+
except ImportError as e:
357+
raise ImportError(
358+
"Graphviz is not installed. Please install graphviz to render graph of the pipeline."
359+
) from e
360+
361+
if raw_output:
362+
return self.generate_dot(graph, label_lut, rankdir, dpi=dpi, **style_kwargs)
363+
364+
# Create Graphviz object
365+
dot = graphviz.Digraph(comment="NetworkX Graph")
366+
dot.attr(rankdir=rankdir, dpi=str(dpi))
367+
368+
# Apply styling
369+
node_shape = style_kwargs.get("node_shape", "box")
370+
node_style = style_kwargs.get("node_style", "filled")
371+
node_color = style_kwargs.get("node_color", "lightblue")
372+
edge_color = style_kwargs.get("edge_color", "black")
373+
374+
dot.attr("node", shape=node_shape, style=node_style, fillcolor=node_color)
375+
dot.attr("edge", color=edge_color)
376+
377+
# Add nodes with labels
378+
for node_id in graph.nodes():
379+
sanitized_id = self._sanitize_node_id(node_id)
380+
label = self._get_node_label(node_id, label_lut)
381+
dot.node(sanitized_id, label=label)
382+
383+
# Add edges
384+
for source, target in graph.edges():
385+
source_id = self._sanitize_node_id(source)
386+
target_id = self._sanitize_node_id(target)
387+
dot.edge(source_id, target_id)
388+
389+
# Handle output
390+
if output_path:
391+
# Save to file
392+
name, ext = os.path.splitext(output_path)
393+
format_type = ext[1:] if ext else "png"
394+
dot.render(name, format=format_type, cleanup=True)
395+
print(f"Graph saved to {output_path}")
396+
397+
if show:
398+
# Display with matplotlib
399+
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
400+
dot.render(tmp.name[:-4], format="png", cleanup=True)
401+
402+
import matplotlib.pyplot as plt
403+
import matplotlib.image as mpimg
404+
405+
# Display with matplotlib
406+
img = mpimg.imread(tmp.name)
407+
plt.figure(figsize=figsize)
408+
plt.imshow(img)
409+
plt.axis("off")
410+
plt.title("Graph Visualization")
411+
plt.tight_layout()
412+
plt.show()
413+
414+
# Clean up
415+
os.unlink(tmp.name)
416+
417+
return None
418+
419+
420+
# Convenience function for quick rendering
421+
def render_graph(
422+
graph: nx.DiGraph, label_lut: dict[Any, str] | None = None, **kwargs
423+
) -> str | None:
424+
"""
425+
Convenience function to quickly render a NetworkX graph
426+
427+
Args:
428+
graph: NetworkX DiGraph to render
429+
label_lut: Optional dictionary mapping node_id -> display_label
430+
**kwargs: All other arguments passed to GraphRenderer.render_graph()
431+
432+
Returns:
433+
DOT syntax if raw_output=True, None otherwise
434+
"""
435+
renderer = GraphRenderer()
436+
return renderer.render_graph(graph, label_lut, **kwargs)

uv.lock

Lines changed: 11 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)