66from orcapod .protocols import database_protocols as dbp
77from typing import Any
88from collections .abc import Collection
9+ import os
10+ import tempfile
911import logging
1012import asyncio
13+ from typing import TYPE_CHECKING
14+ from orcapod .utils .lazy_module import LazyModule
15+
16+ if TYPE_CHECKING :
17+ import networkx as nx
18+ else :
19+ nx = LazyModule ("networkx" )
1120
1221
1322def synchronous_run (async_func , * args , ** kwargs ):
@@ -109,11 +118,16 @@ def compile(self) -> None:
109118
110119 invocation_to_stream_lut = {}
111120 G = self .generate_graph ()
121+ node_graph = nx .DiGraph ()
112122 for invocation in nx .topological_sort (G ):
113123 input_streams = [
114124 invocation_to_stream_lut [parent ] for parent in invocation .parents ()
115125 ]
126+
116127 node = self .wrap_invocation (invocation , new_input_streams = input_streams )
128+ for parent in node .upstreams :
129+ node_graph .add_edge (parent .source , node )
130+
117131 invocation_to_stream_lut [invocation ] = node ()
118132 name_candidates .setdefault (node .label , []).append (node )
119133
@@ -127,6 +141,13 @@ def compile(self) -> None:
127141 else :
128142 self .nodes [label ] = nodes [0 ]
129143
144+ self .label_lut = {v : k for k , v in self .nodes .items ()}
145+
146+ self .graph = node_graph
147+
148+ def show_graph (self , ** kwargs ) -> None :
149+ render_graph (self .graph , self .label_lut , ** kwargs )
150+
130151 def run (
131152 self ,
132153 execution_engine : cp .ExecutionEngine | None = None ,
@@ -217,3 +238,199 @@ def rename(self, old_name: str, new_name: str) -> None:
217238 node .label = new_name
218239 self .nodes [new_name ] = node
219240 logger .info (f"Node '{ old_name } ' renamed to '{ new_name } '" )
241+
242+
243+ # import networkx as nx
244+ # # import graphviz
245+ # import matplotlib.pyplot as plt
246+ # import matplotlib.image as mpimg
247+ # import tempfile
248+ # import os
249+
250+
251+ class GraphRenderer :
252+ """Simple renderer for NetworkX graphs using Graphviz DOT format"""
253+
254+ def __init__ (self ):
255+ """Initialize the renderer"""
256+ pass
257+
258+ def _sanitize_node_id (self , node_id : Any ) -> str :
259+ """Convert node_id to a valid DOT identifier using hash"""
260+ return f"node_{ hash (node_id )} "
261+
262+ def _get_node_label (
263+ self , node_id : Any , label_lut : dict [Any , str ] | None = None
264+ ) -> str :
265+ """Get label for a node"""
266+ if label_lut and node_id in label_lut :
267+ return label_lut [node_id ]
268+ return str (node_id )
269+
270+ def generate_dot (
271+ self ,
272+ graph : "nx.DiGraph" ,
273+ label_lut : dict [Any , str ] | None = None ,
274+ rankdir : str = "TB" ,
275+ node_shape : str = "box" ,
276+ node_style : str = "filled" ,
277+ node_color : str = "lightblue" ,
278+ edge_color : str = "black" ,
279+ dpi : int = 150 ,
280+ ) -> str :
281+ """
282+ Generate DOT syntax from NetworkX graph
283+
284+ Args:
285+ graph: NetworkX DiGraph to render
286+ label_lut: Optional dictionary mapping node_id -> display_label
287+ rankdir: Graph direction ('TB', 'BT', 'LR', 'RL')
288+ node_shape: Shape for all nodes
289+ node_style: Style for all nodes
290+ node_color: Fill color for all nodes
291+ edge_color: Color for all edges
292+ dpi: Resolution for rendered image (default 150)
293+
294+ Returns:
295+ DOT format string
296+ """
297+ try :
298+ import graphviz
299+ except ImportError as e :
300+ raise ImportError (
301+ "Graphviz is not installed. Please install graphviz to render graph of the pipeline."
302+ ) from e
303+
304+ dot = graphviz .Digraph (comment = "NetworkX Graph" )
305+
306+ # Set graph attributes
307+ dot .attr (rankdir = rankdir , dpi = str (dpi ))
308+ dot .attr ("node" , shape = node_shape , style = node_style , fillcolor = node_color )
309+ dot .attr ("edge" , color = edge_color )
310+
311+ # Add nodes
312+ for node_id in graph .nodes ():
313+ sanitized_id = self ._sanitize_node_id (node_id )
314+ label = self ._get_node_label (node_id , label_lut )
315+ dot .node (sanitized_id , label = label )
316+
317+ # Add edges
318+ for source , target in graph .edges ():
319+ source_id = self ._sanitize_node_id (source )
320+ target_id = self ._sanitize_node_id (target )
321+ dot .edge (source_id , target_id )
322+
323+ return dot .source
324+
325+ def render_graph (
326+ self ,
327+ graph : nx .DiGraph ,
328+ label_lut : dict [Any , str ] | None = None ,
329+ show : bool = True ,
330+ output_path : str | None = None ,
331+ raw_output : bool = False ,
332+ rankdir : str = "TB" ,
333+ figsize : tuple = (6 , 4 ),
334+ dpi : int = 150 ,
335+ ** style_kwargs ,
336+ ) -> str | None :
337+ """
338+ Render NetworkX graph using Graphviz
339+
340+ Args:
341+ graph: NetworkX DiGraph to render
342+ label_lut: Optional dictionary mapping node_id -> display_label
343+ show: Display the graph using matplotlib
344+ output_path: Save graph to file (e.g., 'graph.png', 'graph.pdf')
345+ raw_output: Return DOT syntax instead of rendering
346+ rankdir: Graph direction ('TB', 'BT', 'LR', 'RL')
347+ figsize: Figure size for matplotlib display
348+ dpi: Resolution for rendered image (default 150)
349+ **style_kwargs: Additional styling (node_color, edge_color, node_shape, etc.)
350+
351+ Returns:
352+ DOT syntax if raw_output=True, None otherwise
353+ """
354+ try :
355+ import graphviz
356+ except ImportError as e :
357+ raise ImportError (
358+ "Graphviz is not installed. Please install graphviz to render graph of the pipeline."
359+ ) from e
360+
361+ if raw_output :
362+ return self .generate_dot (graph , label_lut , rankdir , dpi = dpi , ** style_kwargs )
363+
364+ # Create Graphviz object
365+ dot = graphviz .Digraph (comment = "NetworkX Graph" )
366+ dot .attr (rankdir = rankdir , dpi = str (dpi ))
367+
368+ # Apply styling
369+ node_shape = style_kwargs .get ("node_shape" , "box" )
370+ node_style = style_kwargs .get ("node_style" , "filled" )
371+ node_color = style_kwargs .get ("node_color" , "lightblue" )
372+ edge_color = style_kwargs .get ("edge_color" , "black" )
373+
374+ dot .attr ("node" , shape = node_shape , style = node_style , fillcolor = node_color )
375+ dot .attr ("edge" , color = edge_color )
376+
377+ # Add nodes with labels
378+ for node_id in graph .nodes ():
379+ sanitized_id = self ._sanitize_node_id (node_id )
380+ label = self ._get_node_label (node_id , label_lut )
381+ dot .node (sanitized_id , label = label )
382+
383+ # Add edges
384+ for source , target in graph .edges ():
385+ source_id = self ._sanitize_node_id (source )
386+ target_id = self ._sanitize_node_id (target )
387+ dot .edge (source_id , target_id )
388+
389+ # Handle output
390+ if output_path :
391+ # Save to file
392+ name , ext = os .path .splitext (output_path )
393+ format_type = ext [1 :] if ext else "png"
394+ dot .render (name , format = format_type , cleanup = True )
395+ print (f"Graph saved to { output_path } " )
396+
397+ if show :
398+ # Display with matplotlib
399+ with tempfile .NamedTemporaryFile (suffix = ".png" , delete = False ) as tmp :
400+ dot .render (tmp .name [:- 4 ], format = "png" , cleanup = True )
401+
402+ import matplotlib .pyplot as plt
403+ import matplotlib .image as mpimg
404+
405+ # Display with matplotlib
406+ img = mpimg .imread (tmp .name )
407+ plt .figure (figsize = figsize )
408+ plt .imshow (img )
409+ plt .axis ("off" )
410+ plt .title ("Graph Visualization" )
411+ plt .tight_layout ()
412+ plt .show ()
413+
414+ # Clean up
415+ os .unlink (tmp .name )
416+
417+ return None
418+
419+
420+ # Convenience function for quick rendering
421+ def render_graph (
422+ graph : nx .DiGraph , label_lut : dict [Any , str ] | None = None , ** kwargs
423+ ) -> str | None :
424+ """
425+ Convenience function to quickly render a NetworkX graph
426+
427+ Args:
428+ graph: NetworkX DiGraph to render
429+ label_lut: Optional dictionary mapping node_id -> display_label
430+ **kwargs: All other arguments passed to GraphRenderer.render_graph()
431+
432+ Returns:
433+ DOT syntax if raw_output=True, None otherwise
434+ """
435+ renderer = GraphRenderer ()
436+ return renderer .render_graph (graph , label_lut , ** kwargs )
0 commit comments