|
1 | 1 | import hashlib |
2 | 2 | import logging |
3 | | -import sys |
4 | 3 | from abc import abstractmethod |
5 | 4 | from collections.abc import Callable, Collection, Iterable, Sequence |
6 | 5 | from datetime import datetime, timezone |
7 | | -from typing import TYPE_CHECKING, Any, Literal |
| 6 | +from typing import TYPE_CHECKING, Any, Literal, Protocol, cast |
8 | 7 |
|
9 | 8 | from orcapod import contexts |
10 | 9 | from orcapod.core.datagrams import ( |
@@ -224,54 +223,50 @@ def track_invocation(self, *streams: cp.Stream, label: str | None = None) -> Non |
224 | 223 | self._tracker_manager.record_pod_invocation(self, streams, label=label) |
225 | 224 |
|
226 | 225 |
|
| 226 | +class CallableWithPod(Protocol): |
| 227 | + def __call__(self, *args, **kwargs) -> Any: ... |
| 228 | + |
| 229 | + @property |
| 230 | + def pod(self) -> "FunctionPod": ... |
| 231 | + |
| 232 | + |
227 | 233 | def function_pod( |
228 | 234 | output_keys: str | Collection[str] | None = None, |
229 | 235 | function_name: str | None = None, |
230 | 236 | version: str = "v0.0", |
231 | 237 | label: str | None = None, |
232 | 238 | **kwargs, |
233 | | -) -> Callable[..., "FunctionPod"]: |
| 239 | +) -> Callable[..., CallableWithPod]: |
234 | 240 | """ |
235 | | - Decorator that wraps a function in a FunctionPod instance. |
| 241 | + Decorator that attaches FunctionPod as pod attribute. |
236 | 242 |
|
237 | 243 | Args: |
238 | 244 | output_keys: Keys for the function output(s) |
239 | 245 | function_name: Name of the function pod; if None, defaults to the function name |
240 | 246 | **kwargs: Additional keyword arguments to pass to the FunctionPod constructor. Please refer to the FunctionPod documentation for details. |
241 | 247 |
|
242 | 248 | Returns: |
243 | | - FunctionPod instance wrapping the decorated function |
| 249 | + CallableWithPod: Decorated function with `pod` attribute holding the FunctionPod instance |
244 | 250 | """ |
245 | 251 |
|
246 | | - def decorator(func) -> FunctionPod: |
| 252 | + def decorator(func: Callable) -> CallableWithPod: |
247 | 253 | if func.__name__ == "<lambda>": |
248 | 254 | raise ValueError("Lambda functions cannot be used with function_pod") |
249 | 255 |
|
250 | | - if not hasattr(func, "__module__") or func.__module__ is None: |
251 | | - raise ValueError( |
252 | | - f"Function {func.__name__} must be defined at module level" |
253 | | - ) |
254 | | - |
255 | 256 | # Store the original function in the module for pickling purposes |
256 | 257 | # and make sure to change the name of the function |
257 | | - module = sys.modules[func.__module__] |
258 | | - base_function_name = func.__name__ |
259 | | - new_function_name = f"_original_{func.__name__}" |
260 | | - setattr(module, new_function_name, func) |
261 | | - # rename the function to be consistent and make it pickleable |
262 | | - setattr(func, "__name__", new_function_name) |
263 | | - setattr(func, "__qualname__", new_function_name) |
264 | 258 |
|
265 | 259 | # Create a simple typed function pod |
266 | 260 | pod = FunctionPod( |
267 | 261 | function=func, |
268 | 262 | output_keys=output_keys, |
269 | | - function_name=function_name or base_function_name, |
| 263 | + function_name=function_name or func.__name__, |
270 | 264 | version=version, |
271 | 265 | label=label, |
272 | 266 | **kwargs, |
273 | 267 | ) |
274 | | - return pod |
| 268 | + setattr(func, "pod", pod) |
| 269 | + return cast(CallableWithPod, func) |
275 | 270 |
|
276 | 271 | return decorator |
277 | 272 |
|
|
0 commit comments