Skip to content

Commit 3eb5cee

Browse files
committed
feat: enhance function_pod decorator to attach FunctionPod as a callable attribute
1 parent 08a82e4 commit 3eb5cee

File tree

1 file changed

+15
-20
lines changed

1 file changed

+15
-20
lines changed

src/orcapod/core/pods.py

Lines changed: 15 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
11
import hashlib
22
import logging
3-
import sys
43
from abc import abstractmethod
54
from collections.abc import Callable, Collection, Iterable, Sequence
65
from datetime import datetime, timezone
7-
from typing import TYPE_CHECKING, Any, Literal
6+
from typing import TYPE_CHECKING, Any, Literal, Protocol, cast
87

98
from orcapod import contexts
109
from orcapod.core.datagrams import (
@@ -224,54 +223,50 @@ def track_invocation(self, *streams: cp.Stream, label: str | None = None) -> Non
224223
self._tracker_manager.record_pod_invocation(self, streams, label=label)
225224

226225

226+
class CallableWithPod(Protocol):
227+
def __call__(self, *args, **kwargs) -> Any: ...
228+
229+
@property
230+
def pod(self) -> "FunctionPod": ...
231+
232+
227233
def function_pod(
228234
output_keys: str | Collection[str] | None = None,
229235
function_name: str | None = None,
230236
version: str = "v0.0",
231237
label: str | None = None,
232238
**kwargs,
233-
) -> Callable[..., "FunctionPod"]:
239+
) -> Callable[..., CallableWithPod]:
234240
"""
235-
Decorator that wraps a function in a FunctionPod instance.
241+
Decorator that attaches FunctionPod as pod attribute.
236242
237243
Args:
238244
output_keys: Keys for the function output(s)
239245
function_name: Name of the function pod; if None, defaults to the function name
240246
**kwargs: Additional keyword arguments to pass to the FunctionPod constructor. Please refer to the FunctionPod documentation for details.
241247
242248
Returns:
243-
FunctionPod instance wrapping the decorated function
249+
CallableWithPod: Decorated function with `pod` attribute holding the FunctionPod instance
244250
"""
245251

246-
def decorator(func) -> FunctionPod:
252+
def decorator(func: Callable) -> CallableWithPod:
247253
if func.__name__ == "<lambda>":
248254
raise ValueError("Lambda functions cannot be used with function_pod")
249255

250-
if not hasattr(func, "__module__") or func.__module__ is None:
251-
raise ValueError(
252-
f"Function {func.__name__} must be defined at module level"
253-
)
254-
255256
# Store the original function in the module for pickling purposes
256257
# and make sure to change the name of the function
257-
module = sys.modules[func.__module__]
258-
base_function_name = func.__name__
259-
new_function_name = f"_original_{func.__name__}"
260-
setattr(module, new_function_name, func)
261-
# rename the function to be consistent and make it pickleable
262-
setattr(func, "__name__", new_function_name)
263-
setattr(func, "__qualname__", new_function_name)
264258

265259
# Create a simple typed function pod
266260
pod = FunctionPod(
267261
function=func,
268262
output_keys=output_keys,
269-
function_name=function_name or base_function_name,
263+
function_name=function_name or func.__name__,
270264
version=version,
271265
label=label,
272266
**kwargs,
273267
)
274-
return pod
268+
setattr(func, "pod", pod)
269+
return cast(CallableWithPod, func)
275270

276271
return decorator
277272

0 commit comments

Comments
 (0)