
Commit 35153e2

fix: time update and upstream reference
1 parent e6c7ec9 commit 35153e2

5 files changed: +77 −26 lines changed


src/orcapod/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -9,6 +9,7 @@
 from .pipeline import Pipeline


+
 no_tracking = DEFAULT_TRACKER_MANAGER.no_tracking

 __all__ = [

src/orcapod/core/streams/pod_node_stream.py

Lines changed: 42 additions & 13 deletions
@@ -38,10 +38,10 @@ def __init__(self, pod_node: pp.PodNode, input_stream: cp.Stream, **kwargs):
         super().__init__(source=pod_node, upstreams=(input_stream,), **kwargs)
         self.pod_node = pod_node
         self.input_stream = input_stream
-        self._set_modified_time()  # set modified time to when we obtain the iterator
-        # capture the immutable iterator from the input stream

+        # capture the immutable iterator from the input stream
         self._prepared_stream_iterator = input_stream.iter_packets()
+        self._set_modified_time()  # set modified time to when we obtain the iterator

         # Packet-level caching (from your PodStream)
         self._cached_output_packets: list[tuple[cp.Tag, cp.Packet | None]] | None = None
@@ -134,7 +134,7 @@ def run(
         cached_results = []

         # identify all entries in the input stream for which we still have not computed packets
-        if filter is not None:
+        if len(args) > 0 or len(kwargs) > 0:
             input_stream_used = self.input_stream.polars_filter(*args, **kwargs)
         else:
             input_stream_used = self.input_stream
@@ -194,6 +194,7 @@ def run(

         if existing is not None and existing.num_rows > 0:
             # If there are existing entries, we can cache them
+            # TODO: cache them based on the record ID
             existing_stream = TableStream(existing, tag_columns=tag_keys)
             for tag, packet in existing_stream.iter_packets():
                 cached_results.append((tag, packet))
@@ -232,6 +233,14 @@ def run(

         self._cached_output_packets = cached_results
         self._set_modified_time()
+        self.pod_node.flush()
+        # TODO: evaluate proper handling of cache here
+        self.clear_cache()
+
+    def clear_cache(self) -> None:
+        self._cached_output_packets = None
+        self._cached_output_table = None
+        self._cached_content_hash_column = None

     def iter_packets(
         self, execution_engine: cp.ExecutionEngine | None = None
@@ -423,21 +432,41 @@ def as_table(

         converter = self.data_context.type_converter

-        struct_packets = converter.python_dicts_to_struct_dicts(all_packets)
-        all_tags_as_tables: pa.Table = pa.Table.from_pylist(
-            all_tags, schema=tag_schema
-        )
-        all_packets_as_tables: pa.Table = pa.Table.from_pylist(
-            struct_packets, schema=packet_schema
-        )
+        if len(all_tags) == 0:
+            tag_types, packet_types = self.pod_node.output_types(
+                include_system_tags=True
+            )
+            tag_schema = converter.python_schema_to_arrow_schema(tag_types)
+            source_entries = {
+                f"{constants.SOURCE_PREFIX}{c}": str for c in packet_types.keys()
+            }
+            packet_types.update(source_entries)
+            packet_types[constants.CONTEXT_KEY] = str
+            packet_schema = converter.python_schema_to_arrow_schema(packet_types)
+            total_schema = arrow_utils.join_arrow_schemas(tag_schema, packet_schema)
+            # return an empty table with the right schema
+            self._cached_output_table = pa.Table.from_pylist(
+                [], schema=total_schema
+            )
+        else:
+            struct_packets = converter.python_dicts_to_struct_dicts(all_packets)

-        self._cached_output_table = arrow_utils.hstack_tables(
-            all_tags_as_tables, all_packets_as_tables
-        )
+            all_tags_as_tables: pa.Table = pa.Table.from_pylist(
+                all_tags, schema=tag_schema
+            )
+            all_packets_as_tables: pa.Table = pa.Table.from_pylist(
+                struct_packets, schema=packet_schema
+            )
+
+            self._cached_output_table = arrow_utils.hstack_tables(
+                all_tags_as_tables, all_packets_as_tables
+            )
         assert self._cached_output_table is not None, (
             "_cached_output_table should not be None here."
         )

+        if self._cached_output_table.num_rows == 0:
+            return self._cached_output_table
         drop_columns = []
         if not include_source:
             drop_columns.extend(f"{constants.SOURCE_PREFIX}{c}" for c in self.keys()[1])
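
The new empty-result branch in as_table matters because pa.Table.from_pylist([]) without an explicit schema yields a zero-column table, which would break the column selection and drops that follow. A minimal standalone sketch of the idea in plain pyarrow; the column names are hypothetical and pa.unify_schemas stands in for arrow_utils.join_arrow_schemas:

import pyarrow as pa

# hypothetical tag and packet schemas for a pod with one tag and one output column
tag_schema = pa.schema([("sample_id", pa.string())])
packet_schema = pa.schema([
    ("result_path", pa.string()),
    ("_source_result_path", pa.string()),  # source column, akin to SOURCE_PREFIX-prefixed columns
    ("_context", pa.string()),             # context column, akin to CONTEXT_KEY
])

# join the schemas, then build an empty table that still carries every column
total_schema = pa.unify_schemas([tag_schema, packet_schema])
empty = pa.Table.from_pylist([], schema=total_schema)

assert empty.num_rows == 0
assert empty.column_names == [f.name for f in total_schema]
# downstream code can still select or drop columns without special-casing
print(empty.select(["sample_id", "result_path"]))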

src/orcapod/pipeline/graph.py

Lines changed: 5 additions & 4 deletions
@@ -192,7 +192,8 @@ def run(
         Current implementation uses a simple traversal through all nodes. Future versions
         may implement more efficient graph traversal algorithms.
         """
-        for node in self.nodes.values():
+        import networkx as nx
+        for node in nx.topological_sort(self.graph):
             if run_async:
                 synchronous_run(node.run_async, execution_engine=execution_engine)
             else:
@@ -215,7 +216,7 @@ def wrap_invocation(
                 pipeline_database=self.pipeline_database,
                 pipeline_path_prefix=self.pipeline_store_path_prefix,
                 label=invocation.label,
-                kernel_type="pod",
+                kernel_type="function",
             )
         elif invocation in self.invocation_to_source_lut:
             source = self.invocation_to_source_lut[invocation]
@@ -306,7 +307,7 @@ class GraphRenderer:
             "style": "filled",
             "typefontcolor": "#3A3737",  # dark gray
         },
-        "pod": {
+        "function": {
             "fillcolor": "#f5f5f5",  # off white
             "shape": "cylinder",
             "fontcolor": "#090271",  # darker navy blue
@@ -633,7 +634,7 @@ def create_custom_rules(
             "style": "filled",
             "type_font_color": operator_type_fcolor,
         },
-        "pod": {
+        "function": {
             "fillcolor": pod_bg,
             "shape": "box",
             "fontcolor": pod_main_fcolor,

src/orcapod/pipeline/nodes.py

Lines changed: 16 additions & 9 deletions
@@ -91,16 +91,15 @@ def pipeline_path(self) -> tuple[str, ...]:
         ...

     def validate_inputs(self, *streams: cp.Stream) -> None:
-        """Sources take no input streams."""
-        if len(streams) > 0:
-            raise NotImplementedError(
-                "At this moment, Node does not yet support handling additional input streams."
-            )
+        return

-    def forward(self, *streams: cp.Stream) -> cp.Stream:
-        # TODO: re-evaluate the use here -- consider semi joining with input streams
-        # super().validate_inputs(*self.input_streams)
-        return super().forward(*self.upstreams)  # type: ignore[return-value]
+    # def forward(self, *streams: cp.Stream) -> cp.Stream:
+    #     # TODO: re-evaluate the use here -- consider semi joining with input streams
+    #     # super().validate_inputs(*self.input_streams)
+    #     return super().forward(*self.upstreams)  # type: ignore[return-value]
+
+    def pre_kernel_processing(self, *streams: cp.Stream) -> tuple[cp.Stream, ...]:
+        return self.upstreams

     def kernel_output_types(
         self, *streams: cp.Stream, include_system_tags: bool = False
@@ -128,6 +127,9 @@ def get_all_records(
         """
         raise NotImplementedError("This method should be implemented by subclasses.")

+    def flush(self):
+        self.pipeline_database.flush()
+

 class KernelNode(NodeBase, WrappedKernel):
     """
@@ -264,6 +266,11 @@ def __init__(
             **kwargs,
         )

+    def flush(self):
+        self.pipeline_database.flush()
+        if self.result_database is not None:
+            self.result_database.flush()
+
     @property
     def contained_kernel(self) -> cp.Kernel:
         return self.pod

src/orcapod/protocols/pipeline_protocols.py

Lines changed: 13 additions & 0 deletions
@@ -35,6 +35,19 @@ def get_all_records(
         """
         ...

+    def flush(self):
+        """
+        Flush any in-memory data to persistent storage.
+
+        This method ensures that all buffered data is written to the underlying
+        storage system, making it durable and consistent. It is useful for:
+        - Ensuring data integrity before shutdown or restart
+        - Committing changes after a batch of operations
+        - Reducing memory usage by clearing buffers
+
+        """
+        ...
+
     def add_pipeline_record(
         self,
         tag: cp.Tag,
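
For context, a sketch of what an object satisfying this flush() protocol might look like; BufferedRecordStore and its JSON-lines backing file are hypothetical and only illustrate the buffer-then-persist contract the docstring describes:

import json
from pathlib import Path

class BufferedRecordStore:
    """Hypothetical store that buffers records in memory until flush()."""

    def __init__(self, path: str) -> None:
        self._path = Path(path)
        self._buffer: list[dict] = []

    def add_record(self, record: dict) -> None:
        # records stay in memory until flush() is called
        self._buffer.append(record)

    def flush(self) -> None:
        # write buffered records to persistent storage, then clear the buffer
        if not self._buffer:
            return
        with self._path.open("a") as f:
            for record in self._buffer:
                f.write(json.dumps(record) + "\n")
        self._buffer.clear()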
