Skip to content

Commit d4708bc

Browse files
committed
feat: add easier access to source, function and operator pods in pipeline
1 parent db5b5d7 commit d4708bc

File tree

1 file changed

+42
-14
lines changed

1 file changed

+42
-14
lines changed

src/orcapod/pipeline/graph.py

Lines changed: 42 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from orcapod import contexts
55
from orcapod.protocols import core_protocols as cp
66
from orcapod.protocols import database_protocols as dbp
7-
from typing import Any
7+
from typing import Any, cast
88
from collections.abc import Collection
99
import os
1010
import tempfile
@@ -94,11 +94,39 @@ def __init__(
9494
self.results_store_path_prefix = self.name + ("_results",)
9595
self.pipeline_database = pipeline_database
9696
self.results_database = results_database
97-
self.nodes: dict[str, Node] = {}
97+
self._nodes: dict[str, Node] = {}
9898
self.auto_compile = auto_compile
9999
self._dirty = False
100100
self._ordered_nodes = [] # Track order of invocations
101101

102+
@property
103+
def nodes(self) -> dict[str, Node]:
104+
return self._nodes.copy()
105+
106+
@property
107+
def function_pods(self) -> dict[str, cp.Pod]:
108+
return {
109+
label: cast(cp.Pod, node)
110+
for label, node in self._nodes.items()
111+
if getattr(node, "kernel_type") == "function"
112+
}
113+
114+
@property
115+
def source_pods(self) -> dict[str, cp.Source]:
116+
return {
117+
label: node
118+
for label, node in self._nodes.items()
119+
if getattr(node, "kernel_type") == "source"
120+
}
121+
122+
@property
123+
def operator_pods(self) -> dict[str, cp.Kernel]:
124+
return {
125+
label: node
126+
for label, node in self._nodes.items()
127+
if getattr(node, "kernel_type") == "operator"
128+
}
129+
102130
def __exit__(self, exc_type=None, exc_value=None, traceback=None):
103131
"""
104132
Exit the pipeline context, ensuring all nodes are properly closed.
@@ -156,13 +184,13 @@ def compile(self) -> None:
156184
# If there are multiple nodes with the same label, we need to resolve the collision
157185
logger.info(f"Collision detected for label '{label}': {nodes}")
158186
for i, node in enumerate(nodes, start=1):
159-
self.nodes[f"{label}_{i}"] = node
187+
self._nodes[f"{label}_{i}"] = node
160188
node.label = f"{label}_{i}"
161189
else:
162-
self.nodes[label] = nodes[0]
190+
self._nodes[label] = nodes[0]
163191
nodes[0].label = label
164192

165-
self.label_lut = {v: k for k, v in self.nodes.items()}
193+
self.label_lut = {v: k for k, v in self._nodes.items()}
166194

167195
self.graph = node_graph
168196

@@ -172,7 +200,7 @@ def show_graph(self, **kwargs) -> None:
172200
def set_mode(self, mode: str) -> None:
173201
if mode not in ("production", "development"):
174202
raise ValueError("Mode must be either 'production' or 'development'")
175-
for node in self.nodes.values():
203+
for node in self._nodes.values():
176204
if hasattr(node, "set_mode"):
177205
node.set_mode(mode)
178206

@@ -257,27 +285,27 @@ def wrap_invocation(
257285

258286
def __getattr__(self, item: str) -> Any:
259287
"""Allow direct access to pipeline attributes."""
260-
if item in self.nodes:
261-
return self.nodes[item]
288+
if item in self._nodes:
289+
return self._nodes[item]
262290
raise AttributeError(f"Pipeline has no attribute '{item}'")
263291

264292
def __dir__(self) -> list[str]:
265293
"""Return a list of attributes and methods of the pipeline."""
266-
return list(super().__dir__()) + list(self.nodes.keys())
294+
return list(super().__dir__()) + list(self._nodes.keys())
267295

268296
def rename(self, old_name: str, new_name: str) -> None:
269297
"""
270298
Rename a node in the pipeline.
271299
This will update the label and the internal mapping.
272300
"""
273-
if old_name not in self.nodes:
301+
if old_name not in self._nodes:
274302
raise KeyError(f"Node '{old_name}' does not exist in the pipeline.")
275-
if new_name in self.nodes:
303+
if new_name in self._nodes:
276304
raise KeyError(f"Node '{new_name}' already exists in the pipeline.")
277-
node = self.nodes[old_name]
278-
del self.nodes[old_name]
305+
node = self._nodes[old_name]
306+
del self._nodes[old_name]
279307
node.label = new_name
280-
self.nodes[new_name] = node
308+
self._nodes[new_name] = node
281309
logger.info(f"Node '{old_name}' renamed to '{new_name}'")
282310

283311

0 commit comments

Comments
 (0)