Skip to content

Commit f9e8c5b

Browse files
committed
feat: add execution engine options to various stream and source methods
- Enhanced InvocationBase and SourceBase classes to accept `execution_engine_opts` in methods like `iter_packets`, `as_table`, `flow`, `run`, and `run_async`.
- Updated StatefulStreamBase and its subclasses to support `execution_engine_opts` in their method signatures and implementations.
- Modified CachedPodStream, KernelStream, LazyPodResultStream, PodNodeStream, TableStream, and WrappedStream to propagate `execution_engine_opts` through their respective methods.
- Updated RayEngine to handle function submission with engine-specific options.
- Enhanced the Pipeline class to pass `execution_engine_opts` during node execution.
- Updated core protocols to include `execution_engine_opts` in relevant method signatures for Pods and Streams.

This change improves the flexibility and configurability of execution engines across the Orcapod framework.
1 parent 7512b37 commit f9e8c5b

File tree

15 files changed

+436
-66
lines changed

15 files changed

+436
-66
lines changed

src/orcapod/core/pods.py

Lines changed: 40 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,7 @@ def call(
207207
packet: cp.Packet,
208208
record_id: str | None = None,
209209
execution_engine: cp.ExecutionEngine | None = None,
210+
execution_engine_opts: dict[str, Any] | None = None,
210211
) -> tuple[cp.Tag, cp.Packet | None]: ...
211212

212213
@abstractmethod
@@ -216,6 +217,7 @@ async def async_call(
216217
packet: cp.Packet,
217218
record_id: str | None = None,
218219
execution_engine: cp.ExecutionEngine | None = None,
220+
execution_engine_opts: dict[str, Any] | None = None,
219221
) -> tuple[cp.Tag, cp.Packet | None]: ...
220222

221223
def track_invocation(self, *streams: cp.Stream, label: str | None = None) -> None:
@@ -408,6 +410,7 @@ def call(
408410
packet: cp.Packet,
409411
record_id: str | None = None,
410412
execution_engine: cp.ExecutionEngine | None = None,
413+
execution_engine_opts: dict[str, Any] | None = None,
411414
) -> tuple[cp.Tag, DictPacket | None]:
412415
if not self.is_active():
413416
logger.info(
@@ -426,7 +429,11 @@ def call(
426429
with self._tracker_manager.no_tracking():
427430
if execution_engine is not None:
428431
# use the provided execution engine to run the function
429-
values = execution_engine.submit_sync(self.function, **input_dict)
432+
values = execution_engine.submit_sync(
433+
self.function,
434+
fn_kwargs=input_dict,
435+
engine_opts=execution_engine_opts,
436+
)
430437
else:
431438
values = self.function(**input_dict)
432439

@@ -458,6 +465,7 @@ async def async_call(
458465
packet: cp.Packet,
459466
record_id: str | None = None,
460467
execution_engine: cp.ExecutionEngine | None = None,
468+
execution_engine_opts: dict[str, Any] | None = None,
461469
) -> tuple[cp.Tag, cp.Packet | None]:
462470
"""
463471
Asynchronous call to the function pod. This is a placeholder for future implementation.
@@ -481,7 +489,9 @@ async def async_call(
481489
input_dict = packet
482490
if execution_engine is not None:
483491
# use the provided execution engine to run the function
484-
values = await execution_engine.submit_async(self.function, **input_dict)
492+
values = await execution_engine.submit_async(
493+
self.function, fn_kwargs=input_dict, engine_opts=execution_engine_opts
494+
)
485495
else:
486496
values = self.function(**input_dict)
487497

@@ -607,9 +617,14 @@ def call(
607617
packet: cp.Packet,
608618
record_id: str | None = None,
609619
execution_engine: cp.ExecutionEngine | None = None,
620+
execution_engine_opts: dict[str, Any] | None = None,
610621
) -> tuple[cp.Tag, cp.Packet | None]:
611622
return self.pod.call(
612-
tag, packet, record_id=record_id, execution_engine=execution_engine
623+
tag,
624+
packet,
625+
record_id=record_id,
626+
execution_engine=execution_engine,
627+
execution_engine_opts=execution_engine_opts,
613628
)
614629

615630
async def async_call(
@@ -618,9 +633,14 @@ async def async_call(
618633
packet: cp.Packet,
619634
record_id: str | None = None,
620635
execution_engine: cp.ExecutionEngine | None = None,
636+
execution_engine_opts: dict[str, Any] | None = None,
621637
) -> tuple[cp.Tag, cp.Packet | None]:
622638
return await self.pod.async_call(
623-
tag, packet, record_id=record_id, execution_engine=execution_engine
639+
tag,
640+
packet,
641+
record_id=record_id,
642+
execution_engine=execution_engine,
643+
execution_engine_opts=execution_engine_opts,
624644
)
625645

626646
def kernel_identity_structure(
@@ -683,6 +703,7 @@ def call(
683703
packet: cp.Packet,
684704
record_id: str | None = None,
685705
execution_engine: cp.ExecutionEngine | None = None,
706+
execution_engine_opts: dict[str, Any] | None = None,
686707
skip_cache_lookup: bool = False,
687708
skip_cache_insert: bool = False,
688709
) -> tuple[cp.Tag, cp.Packet | None]:
@@ -700,7 +721,11 @@ def call(
700721
print(f"Cache hit for {packet}!")
701722
if output_packet is None:
702723
tag, output_packet = super().call(
703-
tag, packet, record_id=record_id, execution_engine=execution_engine
724+
tag,
725+
packet,
726+
record_id=record_id,
727+
execution_engine=execution_engine,
728+
execution_engine_opts=execution_engine_opts,
704729
)
705730
if (
706731
output_packet is not None
@@ -717,6 +742,7 @@ async def async_call(
717742
packet: cp.Packet,
718743
record_id: str | None = None,
719744
execution_engine: cp.ExecutionEngine | None = None,
745+
execution_engine_opts: dict[str, Any] | None = None,
720746
skip_cache_lookup: bool = False,
721747
skip_cache_insert: bool = False,
722748
) -> tuple[cp.Tag, cp.Packet | None]:
@@ -732,14 +758,19 @@ async def async_call(
732758
output_packet = self.get_cached_output_for_packet(packet)
733759
if output_packet is None:
734760
tag, output_packet = await super().async_call(
735-
tag, packet, record_id=record_id, execution_engine=execution_engine
761+
tag,
762+
packet,
763+
record_id=record_id,
764+
execution_engine=execution_engine,
765+
execution_engine_opts=execution_engine_opts,
736766
)
737767
if output_packet is not None and not skip_cache_insert:
738768
self.record_packet(
739769
packet,
740770
output_packet,
741771
record_id=record_id,
742772
execution_engine=execution_engine,
773+
execution_engine_opts=execution_engine_opts,
743774
)
744775

745776
return tag, output_packet
@@ -754,11 +785,14 @@ def record_packet(
754785
output_packet: cp.Packet,
755786
record_id: str | None = None,
756787
execution_engine: cp.ExecutionEngine | None = None,
788+
execution_engine_opts: dict[str, Any] | None = None,
757789
skip_duplicates: bool = False,
758790
) -> cp.Packet:
759791
"""
760792
Record the output packet against the input packet in the result store.
761793
"""
794+
795+
# TODO: consider incorporating execution_engine_opts into the record
762796
data_table = output_packet.as_table(include_context=True, include_source=True)
763797

764798
for i, (k, v) in enumerate(self.tiered_pod_id.items()):

src/orcapod/core/sources/base.py

Lines changed: 56 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -119,9 +119,13 @@ def __iter__(self) -> Iterator[tuple[cp.Tag, cp.Packet]]:
119119
def iter_packets(
120120
self,
121121
execution_engine: cp.ExecutionEngine | None = None,
122+
execution_engine_opts: dict[str, Any] | None = None,
122123
) -> Iterator[tuple[cp.Tag, cp.Packet]]:
123124
"""Delegate to the cached KernelStream."""
124-
return self().iter_packets(execution_engine=execution_engine)
125+
return self().iter_packets(
126+
execution_engine=execution_engine,
127+
execution_engine_opts=execution_engine_opts,
128+
)
125129

126130
def as_table(
127131
self,
@@ -131,6 +135,7 @@ def as_table(
131135
include_content_hash: bool | str = False,
132136
sort_by_tags: bool = True,
133137
execution_engine: cp.ExecutionEngine | None = None,
138+
execution_engine_opts: dict[str, Any] | None = None,
134139
) -> "pa.Table":
135140
"""Delegate to the cached KernelStream."""
136141
return self().as_table(
@@ -140,39 +145,57 @@ def as_table(
140145
include_content_hash=include_content_hash,
141146
sort_by_tags=sort_by_tags,
142147
execution_engine=execution_engine,
148+
execution_engine_opts=execution_engine_opts,
143149
)
144150

145151
def flow(
146-
self, execution_engine: cp.ExecutionEngine | None = None
152+
self,
153+
execution_engine: cp.ExecutionEngine | None = None,
154+
execution_engine_opts: dict[str, Any] | None = None,
147155
) -> Collection[tuple[cp.Tag, cp.Packet]]:
148156
"""Delegate to the cached KernelStream."""
149-
return self().flow(execution_engine=execution_engine)
157+
return self().flow(
158+
execution_engine=execution_engine,
159+
execution_engine_opts=execution_engine_opts,
160+
)
150161

151162
def run(
152163
self,
153164
*args: Any,
154165
execution_engine: cp.ExecutionEngine | None = None,
166+
execution_engine_opts: dict[str, Any] | None = None,
155167
**kwargs: Any,
156168
) -> None:
157169
"""
158170
Run the source node, executing the contained source.
159171
160172
This is a no-op for sources since they are not executed like pods.
161173
"""
162-
self().run(*args, execution_engine=execution_engine, **kwargs)
174+
self().run(
175+
*args,
176+
execution_engine=execution_engine,
177+
execution_engine_opts=execution_engine_opts,
178+
**kwargs,
179+
)
163180

164181
async def run_async(
165182
self,
166183
*args: Any,
167184
execution_engine: cp.ExecutionEngine | None = None,
185+
execution_engine_opts: dict[str, Any] | None = None,
168186
**kwargs: Any,
169187
) -> None:
170188
"""
171189
Run the source node asynchronously, executing the contained source.
172190
173191
This is a no-op for sources since they are not executed like pods.
174192
"""
175-
await self().run_async(*args, execution_engine=execution_engine, **kwargs)
193+
await self().run_async(
194+
*args,
195+
execution_engine=execution_engine,
196+
execution_engine_opts=execution_engine_opts,
197+
**kwargs,
198+
)
176199

177200
# ==================== LiveStream Protocol (Delegation) ====================
178201

@@ -339,9 +362,13 @@ def __iter__(self) -> Iterator[tuple[cp.Tag, cp.Packet]]:
339362
def iter_packets(
340363
self,
341364
execution_engine: cp.ExecutionEngine | None = None,
365+
execution_engine_opts: dict[str, Any] | None = None,
342366
) -> Iterator[tuple[cp.Tag, cp.Packet]]:
343367
"""Delegate to the cached KernelStream."""
344-
return self().iter_packets(execution_engine=execution_engine)
368+
return self().iter_packets(
369+
execution_engine=execution_engine,
370+
execution_engine_opts=execution_engine_opts,
371+
)
345372

346373
def as_table(
347374
self,
@@ -351,6 +378,7 @@ def as_table(
351378
include_content_hash: bool | str = False,
352379
sort_by_tags: bool = True,
353380
execution_engine: cp.ExecutionEngine | None = None,
381+
execution_engine_opts: dict[str, Any] | None = None,
354382
) -> "pa.Table":
355383
"""Delegate to the cached KernelStream."""
356384
return self().as_table(
@@ -360,39 +388,57 @@ def as_table(
360388
include_content_hash=include_content_hash,
361389
sort_by_tags=sort_by_tags,
362390
execution_engine=execution_engine,
391+
execution_engine_opts=execution_engine_opts,
363392
)
364393

365394
def flow(
366-
self, execution_engine: cp.ExecutionEngine | None = None
395+
self,
396+
execution_engine: cp.ExecutionEngine | None = None,
397+
execution_engine_opts: dict[str, Any] | None = None,
367398
) -> Collection[tuple[cp.Tag, cp.Packet]]:
368399
"""Delegate to the cached KernelStream."""
369-
return self().flow(execution_engine=execution_engine)
400+
return self().flow(
401+
execution_engine=execution_engine,
402+
execution_engine_opts=execution_engine_opts,
403+
)
370404

371405
def run(
372406
self,
373407
*args: Any,
374408
execution_engine: cp.ExecutionEngine | None = None,
409+
execution_engine_opts: dict[str, Any] | None = None,
375410
**kwargs: Any,
376411
) -> None:
377412
"""
378413
Run the source node, executing the contained source.
379414
380415
This is a no-op for sources since they are not executed like pods.
381416
"""
382-
self().run(*args, execution_engine=execution_engine, **kwargs)
417+
self().run(
418+
*args,
419+
execution_engine=execution_engine,
420+
execution_engine_opts=execution_engine_opts,
421+
**kwargs,
422+
)
383423

384424
async def run_async(
385425
self,
386426
*args: Any,
387427
execution_engine: cp.ExecutionEngine | None = None,
428+
execution_engine_opts: dict[str, Any] | None = None,
388429
**kwargs: Any,
389430
) -> None:
390431
"""
391432
Run the source node asynchronously, executing the contained source.
392433
393434
This is a no-op for sources since they are not executed like pods.
394435
"""
395-
await self().run_async(*args, execution_engine=execution_engine, **kwargs)
436+
await self().run_async(
437+
*args,
438+
execution_engine=execution_engine,
439+
execution_engine_opts=execution_engine_opts,
440+
**kwargs,
441+
)
396442

397443
# ==================== LiveStream Protocol (Delegation) ====================
398444

0 commit comments

Comments
 (0)