From a0cc92fe46f1e7b57aaeca651e62e32c0bcaf1fa Mon Sep 17 00:00:00 2001 From: David Blain Date: Mon, 7 Apr 2025 12:38:22 +0200 Subject: [PATCH 01/40] refactor: Implemented filter operation on XCom --- airflow-core/src/airflow/models/xcom_arg.py | 29 +++-- .../src/airflow/sdk/definitions/xcom_arg.py | 100 ++++++++++++++++++ .../task_sdk/definitions/test_xcom_arg.py | 6 +- 3 files changed, 126 insertions(+), 9 deletions(-) diff --git a/airflow-core/src/airflow/models/xcom_arg.py b/airflow-core/src/airflow/models/xcom_arg.py index cfda9295cec26..b762d1405d7d1 100644 --- a/airflow-core/src/airflow/models/xcom_arg.py +++ b/airflow-core/src/airflow/models/xcom_arg.py @@ -93,15 +93,26 @@ def _deserialize(cls, data: dict[str, Any], dag: SchedulerDAG) -> Self: @attrs.define class SchedulerZipXComArg(SchedulerXComArg): - args: Sequence[SchedulerXComArg] - fillvalue: Any + args: SchedulerXComArg + callables: Sequence[str] @classmethod def _deserialize(cls, data: dict[str, Any], dag: SchedulerDAG) -> Self: - return cls( - [deserialize_xcom_arg(arg, dag) for arg in data["args"]], - fillvalue=data.get("fillvalue", NOTSET), - ) + # We are deliberately NOT deserializing the callables. These are shown + # in the UI, and displaying a function object is useless. + return cls(deserialize_xcom_arg(data["arg"], dag), data["callables"]) + + +@attrs.define +class SchedulerFilterXComArg(SchedulerXComArg): + arg: SchedulerXComArg + callables: Sequence[str] + + @classmethod + def _deserialize(cls, data: dict[str, Any], dag: SchedulerDAG) -> Self: + # We are deliberately NOT deserializing the callables. These are shown + # in the UI, and displaying a function object is useless. 
+ return cls(deserialize_xcom_arg(data["arg"], dag), data["callables"]) @singledispatch @@ -178,6 +189,11 @@ def _(xcom_arg: SchedulerConcatXComArg, run_id: str, *, session: Session): return sum(ready_lengths) +@get_task_map_length.register +def _(xcom_arg: SchedulerFilterXComArg, run_id: str, *, session: Session): + return get_task_map_length(xcom_arg.arg, run_id, session=session) + + def deserialize_xcom_arg(data: dict[str, Any], dag: SchedulerDAG): """DAG serialization interface.""" klass = _XCOM_ARG_TYPES[data.get("type", "")] @@ -187,6 +203,7 @@ def deserialize_xcom_arg(data: dict[str, Any], dag: SchedulerDAG): _XCOM_ARG_TYPES: dict[str, type[SchedulerXComArg]] = { "": SchedulerPlainXComArg, "concat": SchedulerConcatXComArg, + "filter": SchedulerFilterXComArg, "map": SchedulerMapXComArg, "zip": SchedulerZipXComArg, } diff --git a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py index d5411b9dd8e46..414dfcaa2c96d 100644 --- a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py +++ b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py @@ -42,6 +42,7 @@ # the user, but deserialize them into strings in a serialized XComArg for # safety (those callables are arbitrary user code). 
MapCallables = Sequence[Callable[[Any], Any]] +FilterCallables = Sequence[Callable[[Any], bool]] class XComArg(ResolveMixin, DependencyMixin): @@ -175,6 +176,9 @@ def zip(self, *others: XComArg, fillvalue: Any = NOTSET) -> ZipXComArg: def concat(self, *others: XComArg) -> ConcatXComArg: return ConcatXComArg([self, *others]) + def filter(self, f: Callable[[Any], Any] | None) -> FilterXComArg: + return FilterXComArg(self, [f] if f else []) + def resolve(self, context: Mapping[str, Any]) -> Any: raise NotImplementedError() @@ -567,9 +571,105 @@ def resolve(self, context: Mapping[str, Any]) -> Any: return _ConcatResult(values) +class _FilterResult(Sequence, Iterable): + def __init__(self, value: Sequence | Iterable, callables: list) -> None: + self.value = value + self.callables = callables + self.length: int | None = None + + def __getitem__(self, index: int) -> Any: + if not (0 <= index < len(self)): + raise IndexError + + value = self.value[index] + if self._apply_callables(value): + return value + return None + + def __len__(self) -> int: + # Calculating the length of an iterable can be a heavy operation, so we cache the result after first attempt + if not self.length: + if isinstance(self.value, Iterable): + self.length = sum(1 for _ in self.value) + else: + self.length = len(self.value) + return self.length + + def __iter__(self) -> Iterator: + for item in iter(self.value): + if self._apply_callables(item): + yield item + + def _apply_callables(self, value) -> bool: + for func in self.callables: + if not func(value): + return False + return True + + +class FilterXComArg(XComArg): + """ + An XCom reference with ``filter()`` call(s) applied. + + This is based on an XComArg, but also applies a series of "filters" that + filters the pulled XCom value. 
+ + :meta private: + """ + + def __init__( + self, + arg: XComArg, + callables: FilterCallables, + ) -> None: + self.arg = arg + + if not callables: + callables = [self.none_filter] + else: + for c in callables: + if getattr(c, "_airflow_is_task_decorator", False): + raise ValueError( + "filter() argument must be a plain function, not a @task operator" + ) + self.callables = callables + + @classmethod + def none_filter(cls, value) -> bool: + return value if True else False + + def __repr__(self) -> str: + map_calls = "".join(f".filter({_get_callable_name(f)})" for f in self.callables) + return f"{self.arg!r}{map_calls}" + + def _serialize(self) -> dict[str, Any]: + return { + "arg": serialize_xcom_arg(self.arg), + "callables": [ + inspect.getsource(c) if callable(c) else c for c in self.callables + ], + } + + def iter_references(self) -> Iterator[tuple[Operator, str]]: + yield from self.arg.iter_references() + + def filter(self, f: Callable[[Any], Any]) -> FilterXComArg: + # Filter arg.filter(f1).filter(f2) into one FilterXComArg. 
+ return FilterXComArg(self.arg, [*self.callables, f if f else self.none_filter]) + + def resolve(self, context: Mapping[str, Any]) -> Any: + value = self.arg.resolve(context) + if not isinstance(value, (Sequence, dict)): + raise ValueError( + f"XCom filter expects sequence or dict, not {type(value).__name__}" + ) + return _FilterResult(value, self.callables) + + _XCOM_ARG_TYPES: Mapping[str, type[XComArg]] = { "": PlainXComArg, "concat": ConcatXComArg, + "filter": FilterXComArg, "map": MapXComArg, "zip": ZipXComArg, } diff --git a/task-sdk/tests/task_sdk/definitions/test_xcom_arg.py b/task-sdk/tests/task_sdk/definitions/test_xcom_arg.py index 9a16be08a352c..0704079c9302b 100644 --- a/task-sdk/tests/task_sdk/definitions/test_xcom_arg.py +++ b/task-sdk/tests/task_sdk/definitions/test_xcom_arg.py @@ -60,7 +60,7 @@ def pull(value): assert results == {"aa", "bb", "cc"} -def test_xcom_map_transform_to_none(run_ti: RunTI, mock_supervisor_comms): +def test_xcom_map_transform_to_none_and_filter(run_ti: RunTI, mock_supervisor_comms): results = set() with DAG("test") as dag: @@ -78,7 +78,7 @@ def c_to_none(v): return None return v - pull.expand(value=push().map(c_to_none)) + pull.expand(value=push().map(c_to_none).filter(None)) # Mock xcom result from push task mock_supervisor_comms.get_message.return_value = XComResult(key="return_value", value=["a", "b", "c"]) @@ -87,7 +87,7 @@ def c_to_none(v): for map_index in range(3): assert run_ti(dag, "pull", map_index) == TerminalTIState.SUCCESS - assert results == {"a", "b", None} + assert results == {"a", "b"} def test_xcom_convert_to_kwargs_fails_task(run_ti: RunTI, mock_supervisor_comms, captured_logs): From c858c9c76abf304c176b6d1bdb2c7cbefca6e557 Mon Sep 17 00:00:00 2001 From: David Blain Date: Mon, 7 Apr 2025 13:19:42 +0200 Subject: [PATCH 02/40] refactor: Fixed some static checks --- airflow-core/src/airflow/models/xcom_arg.py | 1 - task-sdk/src/airflow/sdk/definitions/xcom_arg.py | 12 +++--------- 2 files changed, 3 
insertions(+), 10 deletions(-) diff --git a/airflow-core/src/airflow/models/xcom_arg.py b/airflow-core/src/airflow/models/xcom_arg.py index b762d1405d7d1..d2ab23cd23a20 100644 --- a/airflow-core/src/airflow/models/xcom_arg.py +++ b/airflow-core/src/airflow/models/xcom_arg.py @@ -32,7 +32,6 @@ ) from airflow.utils.db import exists_query from airflow.utils.state import State -from airflow.utils.types import NOTSET from airflow.utils.xcom import XCOM_RETURN_KEY __all__ = ["XComArg", "get_task_map_length"] diff --git a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py index 414dfcaa2c96d..cf5de8d32a0a8 100644 --- a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py +++ b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py @@ -629,9 +629,7 @@ def __init__( else: for c in callables: if getattr(c, "_airflow_is_task_decorator", False): - raise ValueError( - "filter() argument must be a plain function, not a @task operator" - ) + raise ValueError("filter() argument must be a plain function, not a @task operator") self.callables = callables @classmethod @@ -645,9 +643,7 @@ def __repr__(self) -> str: def _serialize(self) -> dict[str, Any]: return { "arg": serialize_xcom_arg(self.arg), - "callables": [ - inspect.getsource(c) if callable(c) else c for c in self.callables - ], + "callables": [inspect.getsource(c) if callable(c) else c for c in self.callables], } def iter_references(self) -> Iterator[tuple[Operator, str]]: @@ -660,9 +656,7 @@ def filter(self, f: Callable[[Any], Any]) -> FilterXComArg: def resolve(self, context: Mapping[str, Any]) -> Any: value = self.arg.resolve(context) if not isinstance(value, (Sequence, dict)): - raise ValueError( - f"XCom filter expects sequence or dict, not {type(value).__name__}" - ) + raise ValueError(f"XCom filter expects sequence or dict, not {type(value).__name__}") return _FilterResult(value, self.callables) From cc5e7be1bffcdcafaf87c3cf11edfad371760ecc Mon Sep 17 00:00:00 2001 From: David 
Blain Date: Mon, 7 Apr 2025 13:24:00 +0200 Subject: [PATCH 03/40] refactor: Fixed some mypy issues --- task-sdk/src/airflow/sdk/definitions/xcom_arg.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py index cf5de8d32a0a8..804b9268105dc 100644 --- a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py +++ b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py @@ -577,7 +577,7 @@ def __init__(self, value: Sequence | Iterable, callables: list) -> None: self.callables = callables self.length: int | None = None - def __getitem__(self, index: int) -> Any: + def __getitem__(self, index: Any) -> Any: if not (0 <= index < len(self)): raise IndexError @@ -620,7 +620,7 @@ class FilterXComArg(XComArg): def __init__( self, arg: XComArg, - callables: FilterCallables, + callables: FilterCallables | None, ) -> None: self.arg = arg @@ -649,7 +649,7 @@ def _serialize(self) -> dict[str, Any]: def iter_references(self) -> Iterator[tuple[Operator, str]]: yield from self.arg.iter_references() - def filter(self, f: Callable[[Any], Any]) -> FilterXComArg: + def filter(self, f: Callable[[Any], Any] | None) -> FilterXComArg: # Filter arg.filter(f1).filter(f2) into one FilterXComArg. 
return FilterXComArg(self.arg, [*self.callables, f if f else self.none_filter]) From 48310f3bc350ffd48b8f5e87608b3996547d4810 Mon Sep 17 00:00:00 2001 From: David Blain Date: Mon, 7 Apr 2025 13:26:22 +0200 Subject: [PATCH 04/40] refactor: Added filter to PlainXComArg --- task-sdk/src/airflow/sdk/definitions/xcom_arg.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py index 804b9268105dc..1a077afbd361e 100644 --- a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py +++ b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py @@ -336,6 +336,11 @@ def concat(self, *others: XComArg) -> ConcatXComArg: raise ValueError("cannot concatenate non-return XCom") return super().concat(*others) + def filter(self, *others: XComArg) -> ConcatXComArg: + if self.key != XCOM_RETURN_KEY: + raise ValueError("cannot filter non-return XCom") + return super().filter(*others) + def resolve(self, context: Mapping[str, Any]) -> Any: ti = context["ti"] task_id = self.operator.task_id From 567c84bb9fd60531b533901381de6db73b91dcf7 Mon Sep 17 00:00:00 2001 From: David Blain Date: Mon, 7 Apr 2025 13:54:45 +0200 Subject: [PATCH 05/40] refactor: Reverted SchedulerZipXComArg back to orginal --- airflow-core/src/airflow/models/xcom_arg.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/airflow-core/src/airflow/models/xcom_arg.py b/airflow-core/src/airflow/models/xcom_arg.py index d2ab23cd23a20..c7a7c75ba7140 100644 --- a/airflow-core/src/airflow/models/xcom_arg.py +++ b/airflow-core/src/airflow/models/xcom_arg.py @@ -32,6 +32,7 @@ ) from airflow.utils.db import exists_query from airflow.utils.state import State +from airflow.utils.types import NOTSET from airflow.utils.xcom import XCOM_RETURN_KEY __all__ = ["XComArg", "get_task_map_length"] @@ -92,14 +93,15 @@ def _deserialize(cls, data: dict[str, Any], dag: SchedulerDAG) -> Self: @attrs.define class 
SchedulerZipXComArg(SchedulerXComArg): - args: SchedulerXComArg - callables: Sequence[str] + args: Sequence[SchedulerXComArg] + fillvalue: Any @classmethod def _deserialize(cls, data: dict[str, Any], dag: SchedulerDAG) -> Self: - # We are deliberately NOT deserializing the callables. These are shown - # in the UI, and displaying a function object is useless. - return cls(deserialize_xcom_arg(data["arg"], dag), data["callables"]) + return cls( + [deserialize_xcom_arg(arg, dag) for arg in data["args"]], + fillvalue=data.get("fillvalue", NOTSET), + ) @attrs.define From 49f247995eed8ac49199399880aae287a81b6dea Mon Sep 17 00:00:00 2001 From: David Blain Date: Mon, 7 Apr 2025 14:12:07 +0200 Subject: [PATCH 06/40] refactor: Fixed signature of filter method in PlainXComArg --- task-sdk/src/airflow/sdk/definitions/xcom_arg.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py index 1a077afbd361e..9bb3ade7d2762 100644 --- a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py +++ b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py @@ -336,10 +336,10 @@ def concat(self, *others: XComArg) -> ConcatXComArg: raise ValueError("cannot concatenate non-return XCom") return super().concat(*others) - def filter(self, *others: XComArg) -> ConcatXComArg: + def filter(self, f: Callable[[Any], Any]) -> FilterXComArg: if self.key != XCOM_RETURN_KEY: raise ValueError("cannot filter non-return XCom") - return super().filter(*others) + return super().filter(f) def resolve(self, context: Mapping[str, Any]) -> Any: ti = context["ti"] From cfeb4a1ecc86ff2dbf09f468456ec02ad13f0696 Mon Sep 17 00:00:00 2001 From: David Blain Date: Mon, 7 Apr 2025 14:35:47 +0200 Subject: [PATCH 07/40] refactor: Fixed method signature filter --- task-sdk/src/airflow/sdk/definitions/xcom_arg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py index 9bb3ade7d2762..d22cc76e17e95 100644 --- a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py +++ b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py @@ -336,7 +336,7 @@ def concat(self, *others: XComArg) -> ConcatXComArg: raise ValueError("cannot concatenate non-return XCom") return super().concat(*others) - def filter(self, f: Callable[[Any], Any]) -> FilterXComArg: + def filter(self, f: Callable[[Any], Any] | None) -> FilterXComArg: if self.key != XCOM_RETURN_KEY: raise ValueError("cannot filter non-return XCom") return super().filter(f) From 4c4ee72aa1373f3cd2235150832618debe05bf42 Mon Sep 17 00:00:00 2001 From: David Blain Date: Mon, 7 Apr 2025 14:37:38 +0200 Subject: [PATCH 08/40] refactor: Fixed callables type in _FilterResult --- task-sdk/src/airflow/sdk/definitions/xcom_arg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py index d22cc76e17e95..eed5204f23b47 100644 --- a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py +++ b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py @@ -577,7 +577,7 @@ def resolve(self, context: Mapping[str, Any]) -> Any: class _FilterResult(Sequence, Iterable): - def __init__(self, value: Sequence | Iterable, callables: list) -> None: + def __init__(self, value: Sequence | Iterable, callables: FilterCallables) -> None: self.value = value self.callables = callables self.length: int | None = None From e5358aaa4f9a8ffe4c426dd68c7eedf398419f22 Mon Sep 17 00:00:00 2001 From: David Blain Date: Mon, 7 Apr 2025 14:53:10 +0200 Subject: [PATCH 09/40] refactor: raise TypeError if getitem is called on iterable --- task-sdk/src/airflow/sdk/definitions/xcom_arg.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py 
b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py index eed5204f23b47..97d9890bd9d99 100644 --- a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py +++ b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py @@ -586,10 +586,12 @@ def __getitem__(self, index: Any) -> Any: if not (0 <= index < len(self)): raise IndexError - value = self.value[index] - if self._apply_callables(value): - return value - return None + if isinstance(self.value, Sequence): + value = self.value[index] + if self._apply_callables(value): + return value + return None + raise TypeError("XComArg filter does not support indexing on non-sequence values") def __len__(self) -> int: # Calculating the length of an iterable can be a heavy operation, so we cache the result after first attempt From faccd4b3a0a76464f7de50d2f380cea831fac7b1 Mon Sep 17 00:00:00 2001 From: David Blain Date: Mon, 7 Apr 2025 18:38:43 +0200 Subject: [PATCH 10/40] refactor: Refactored _MapResult to support iterables --- .../src/airflow/sdk/definitions/xcom_arg.py | 25 +++++++++++++------ 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py index 97d9890bd9d99..a1a9e67a40898 100644 --- a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py +++ b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py @@ -392,20 +392,31 @@ def _get_callable_name(f: Callable | str) -> str: class _MapResult(Sequence): - def __init__(self, value: Sequence | dict, callables: MapCallables) -> None: + def __init__(self, value: Iterable | Sequence | dict, callables: MapCallables) -> None: self.value = value self.callables = callables - def __getitem__(self, index: Any) -> Any: - value = self.value[index] + def __getitem__(self, index: int) -> Any: + if not (0 <= index < len(self)): + raise IndexError - for f in self.callables: - value = f(value) - return value + if hasattr(self.value, '__getitem__'): + value = self.value[index] + return 
self._apply_callables(value) + raise TypeError("XComArg map does not support indexing on non-sequence values") def __len__(self) -> int: return len(self.value) + def __iter__(self) -> Iterator: + for item in iter(self.value): + yield self._apply_callables(item) + + def _apply_callables(self, value): + for func in self.callables: + value = func(value) + return value + class MapXComArg(XComArg): """ @@ -586,7 +597,7 @@ def __getitem__(self, index: Any) -> Any: if not (0 <= index < len(self)): raise IndexError - if isinstance(self.value, Sequence): + if hasattr(self.value, '__getitem__'): value = self.value[index] if self._apply_callables(value): return value From bd204e078ef22df621d1c922fbd6c5c10ed044b6 Mon Sep 17 00:00:00 2001 From: David Blain Date: Mon, 7 Apr 2025 18:48:57 +0200 Subject: [PATCH 11/40] refactor: Register task_map_length on SchedulerFilterXComArg --- task-sdk/src/airflow/sdk/definitions/xcom_arg.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py index a1a9e67a40898..d02db94003f53 100644 --- a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py +++ b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py @@ -736,3 +736,8 @@ def _(xcom_arg: ConcatXComArg, resolved_val: Sized, upstream_map_indexes: dict[s if len(ready_lengths) != len(xcom_arg.args): return None # If any of the referenced XComs is not ready, we are not ready either. 
return sum(ready_lengths) + + +@get_task_map_length.register +def _(xcom_arg: FilterXComArg, resolved_val: Sized, upstream_map_indexes: dict[str, int]): + return get_task_map_length(xcom_arg.arg, resolved_val, upstream_map_indexes) From bd88a315d16ddcf378ef6a93421835bdb736b77d Mon Sep 17 00:00:00 2001 From: David Blain Date: Tue, 8 Apr 2025 07:47:33 +0200 Subject: [PATCH 12/40] refactor: Fixed signature of __getitem__ in _MapResult --- task-sdk/src/airflow/sdk/definitions/xcom_arg.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py index d02db94003f53..4a95e60ccb49c 100644 --- a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py +++ b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py @@ -396,11 +396,11 @@ def __init__(self, value: Iterable | Sequence | dict, callables: MapCallables) - self.value = value self.callables = callables - def __getitem__(self, index: int) -> Any: + def __getitem__(self, index: Any) -> Any: if not (0 <= index < len(self)): raise IndexError - if hasattr(self.value, '__getitem__'): + if hasattr(self.value, "__getitem__"): value = self.value[index] return self._apply_callables(value) raise TypeError("XComArg map does not support indexing on non-sequence values") From b171639756df478c0df28305a8b9b0ab772cb2ba Mon Sep 17 00:00:00 2001 From: David Blain Date: Tue, 8 Apr 2025 07:50:43 +0200 Subject: [PATCH 13/40] refactor: Print filter callable result --- task-sdk/src/airflow/sdk/definitions/xcom_arg.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py index 4a95e60ccb49c..0b3534b895466 100644 --- a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py +++ b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py @@ -599,7 +599,9 @@ def __getitem__(self, index: Any) -> Any: if hasattr(self.value, '__getitem__'): value = 
self.value[index] - if self._apply_callables(value): + result = self._apply_callables(value) + print("filter {} is {}".format(value, result)) + if result: return value return None raise TypeError("XComArg filter does not support indexing on non-sequence values") From ce8959eaca2a9ba82778068062620a2b1aef86fb Mon Sep 17 00:00:00 2001 From: David Blain Date: Tue, 8 Apr 2025 08:43:22 +0200 Subject: [PATCH 14/40] refactor: Changed calculation of length in FilterResult --- task-sdk/src/airflow/sdk/definitions/xcom_arg.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py index 0b3534b895466..632208f06df35 100644 --- a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py +++ b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py @@ -609,10 +609,7 @@ def __getitem__(self, index: Any) -> Any: def __len__(self) -> int: # Calculating the length of an iterable can be a heavy operation, so we cache the result after first attempt if not self.length: - if isinstance(self.value, Iterable): - self.length = sum(1 for _ in self.value) - else: - self.length = len(self.value) + self.length = sum(1 for _ in self) return self.length def __iter__(self) -> Iterator: From 500df61c6eafbb80b06e1a6108b8482f807c178b Mon Sep 17 00:00:00 2001 From: David Blain Date: Tue, 8 Apr 2025 08:55:00 +0200 Subject: [PATCH 15/40] refactor: Fixed __getitem__ method of FilterResult --- task-sdk/src/airflow/sdk/definitions/xcom_arg.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py index 632208f06df35..a03cd9c055dbe 100644 --- a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py +++ b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py @@ -600,10 +600,9 @@ def __getitem__(self, index: Any) -> Any: if hasattr(self.value, '__getitem__'): value = self.value[index] result = 
self._apply_callables(value) - print("filter {} is {}".format(value, result)) if result: return value - return None + return self.__getitem__(index + 1) raise TypeError("XComArg filter does not support indexing on non-sequence values") def __len__(self) -> int: From d4adc17a7288ac73ed051451391a03e553cb2514 Mon Sep 17 00:00:00 2001 From: David Blain Date: Tue, 8 Apr 2025 08:56:56 +0200 Subject: [PATCH 16/40] refactor: No need to store result of apply_callables in __getitem__ method of FilterResult --- task-sdk/src/airflow/sdk/definitions/xcom_arg.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py index a03cd9c055dbe..29f7852fd3fe0 100644 --- a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py +++ b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py @@ -599,8 +599,7 @@ def __getitem__(self, index: Any) -> Any: if hasattr(self.value, '__getitem__'): value = self.value[index] - result = self._apply_callables(value) - if result: + if self._apply_callables(value): return value return self.__getitem__(index + 1) raise TypeError("XComArg filter does not support indexing on non-sequence values") From 800c087875bc90af6d965bd9ffaebde1f3947f1c Mon Sep 17 00:00:00 2001 From: David Blain Date: Tue, 8 Apr 2025 09:01:46 +0200 Subject: [PATCH 17/40] refactor: Cache result of filtered values in FilterResult --- task-sdk/src/airflow/sdk/definitions/xcom_arg.py | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py index 29f7852fd3fe0..d80a61e82e99e 100644 --- a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py +++ b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py @@ -590,25 +590,20 @@ def resolve(self, context: Mapping[str, Any]) -> Any: class _FilterResult(Sequence, Iterable): def __init__(self, value: Sequence | Iterable, callables: 
FilterCallables) -> None: self.value = value + self.filtered_values: list | None = None self.callables = callables - self.length: int | None = None def __getitem__(self, index: Any) -> Any: if not (0 <= index < len(self)): raise IndexError - if hasattr(self.value, '__getitem__'): - value = self.value[index] - if self._apply_callables(value): - return value - return self.__getitem__(index + 1) - raise TypeError("XComArg filter does not support indexing on non-sequence values") + return self.filtered_values[index] def __len__(self) -> int: # Calculating the length of an iterable can be a heavy operation, so we cache the result after first attempt - if not self.length: - self.length = sum(1 for _ in self) - return self.length + if not self.filtered_values: + self.filtered_values = list(self) + return len(self.filtered_values) def __iter__(self) -> Iterator: for item in iter(self.value): From d57c288b877d9cd3a92582c97bb87abf511a253a Mon Sep 17 00:00:00 2001 From: David Blain Date: Tue, 8 Apr 2025 09:43:29 +0200 Subject: [PATCH 18/40] refactor: Refactored the _FilterResult with cache and lazy evaluation if possible --- .../src/airflow/sdk/definitions/xcom_arg.py | 49 ++++++++++++++----- 1 file changed, 38 insertions(+), 11 deletions(-) diff --git a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py index d80a61e82e99e..7093bf87eefd3 100644 --- a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py +++ b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py @@ -588,27 +588,54 @@ def resolve(self, context: Mapping[str, Any]) -> Any: class _FilterResult(Sequence, Iterable): - def __init__(self, value: Sequence | Iterable, callables: FilterCallables) -> None: + def __init__(self, value: Sequence | Iterable, callables: list) -> None: self.value = value - self.filtered_values: list | None = None self.callables = callables + self._cache: list = [] + self._iterator = iter(value) + self._exhausted = False + + def _next_filtered(self) -> 
Any: + """Returns the next item from the iterator that passes all filters.""" + while not self._exhausted: + try: + item = next(self._iterator) + if self._apply_callables(item): + self._cache.append(item) + return item + except StopIteration: + self._exhausted = True + raise StopIteration def __getitem__(self, index: Any) -> Any: - if not (0 <= index < len(self)): + if index < 0: raise IndexError - return self.filtered_values[index] + while len(self._cache) <= index: + try: + self._next_filtered() + except StopIteration: + raise IndexError + + return self._cache[index] def __len__(self) -> int: - # Calculating the length of an iterable can be a heavy operation, so we cache the result after first attempt - if not self.filtered_values: - self.filtered_values = list(self) - return len(self.filtered_values) + # Force full evaluation to determine total length + while not self._exhausted: + try: + self._next_filtered() + except StopIteration: + break + return len(self._cache) def __iter__(self) -> Iterator: - for item in iter(self.value): - if self._apply_callables(item): - yield item + yield from self._cache + + while not self._exhausted: + try: + yield self._next_filtered() + except StopIteration: + break def _apply_callables(self, value) -> bool: for func in self.callables: From d33dd083a484f2a3620adb7bc60a2e24bf838f26 Mon Sep 17 00:00:00 2001 From: David Blain Date: Tue, 8 Apr 2025 10:55:34 +0200 Subject: [PATCH 19/40] refactor: Harmonize the callables typing for filter and map methods --- .../src/airflow/sdk/definitions/xcom_arg.py | 58 ++++++++++++++----- 1 file changed, 43 insertions(+), 15 deletions(-) diff --git a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py index 7093bf87eefd3..2a95e8fc43f72 100644 --- a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py +++ b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py @@ -167,7 +167,7 @@ def _serialize(self) -> dict[str, Any]: """ raise NotImplementedError() - def 
map(self, f: Callable[[Any], Any]) -> MapXComArg: + def map(self, f: MapCallables) -> MapXComArg: return MapXComArg(self, [f]) def zip(self, *others: XComArg, fillvalue: Any = NOTSET) -> ZipXComArg: @@ -176,7 +176,7 @@ def zip(self, *others: XComArg, fillvalue: Any = NOTSET) -> ZipXComArg: def concat(self, *others: XComArg) -> ConcatXComArg: return ConcatXComArg([self, *others]) - def filter(self, f: Callable[[Any], Any] | None) -> FilterXComArg: + def filter(self, f: FilterCallables | None) -> FilterXComArg: return FilterXComArg(self, [f] if f else []) def resolve(self, context: Mapping[str, Any]) -> Any: @@ -321,7 +321,7 @@ def as_teardown( def iter_references(self) -> Iterator[tuple[Operator, str]]: yield self.operator, self.key - def map(self, f: Callable[[Any], Any]) -> MapXComArg: + def map(self, f: MapCallables) -> MapXComArg: if self.key != XCOM_RETURN_KEY: raise ValueError("cannot map against non-return XCom") return super().map(f) @@ -336,7 +336,7 @@ def concat(self, *others: XComArg) -> ConcatXComArg: raise ValueError("cannot concatenate non-return XCom") return super().concat(*others) - def filter(self, f: Callable[[Any], Any] | None) -> FilterXComArg: + def filter(self, f: FilterCallables | None) -> FilterXComArg: if self.key != XCOM_RETURN_KEY: raise ValueError("cannot filter non-return XCom") return super().filter(f) @@ -392,25 +392,53 @@ def _get_callable_name(f: Callable | str) -> str: class _MapResult(Sequence): - def __init__(self, value: Iterable | Sequence | dict, callables: MapCallables) -> None: + def __init__(self, value: Iterable | Sequence | dict, callables: list) -> None: self.value = value self.callables = callables + self._cache: list = [] + self._iterator = iter(value) + self._exhausted = False + + def _next_mapped(self) -> Any: + """Returns the next transformed item from the iterator.""" + while not self._exhausted: + try: + item = next(self._iterator) + result = self._apply_callables(item) + self._cache.append(result) + return result 
+ except StopIteration: + self._exhausted = True + raise StopIteration def __getitem__(self, index: Any) -> Any: - if not (0 <= index < len(self)): + if index < 0: raise IndexError - if hasattr(self.value, "__getitem__"): - value = self.value[index] - return self._apply_callables(value) - raise TypeError("XComArg map does not support indexing on non-sequence values") + while len(self._cache) <= index: + try: + self._next_mapped() + except StopIteration: + raise IndexError + return self._cache[index] def __len__(self) -> int: - return len(self.value) + # Fully consume the iterator to get accurate length + while not self._exhausted: + try: + self._next_mapped() + except StopIteration: + break + return len(self._cache) def __iter__(self) -> Iterator: - for item in iter(self.value): - yield self._apply_callables(item) + yield from self._cache + + while not self._exhausted: + try: + yield self._next_mapped() + except StopIteration: + break def _apply_callables(self, value): for func in self.callables: @@ -448,7 +476,7 @@ def _serialize(self) -> dict[str, Any]: def iter_references(self) -> Iterator[tuple[Operator, str]]: yield from self.arg.iter_references() - def map(self, f: Callable[[Any], Any]) -> MapXComArg: + def map(self, f: MapCallables) -> MapXComArg: # Flatten arg.map(f1).map(f2) into one MapXComArg. return MapXComArg(self.arg, [*self.callables, f]) @@ -686,7 +714,7 @@ def _serialize(self) -> dict[str, Any]: def iter_references(self) -> Iterator[tuple[Operator, str]]: yield from self.arg.iter_references() - def filter(self, f: Callable[[Any], Any] | None) -> FilterXComArg: + def filter(self, f: FilterCallables | None) -> FilterXComArg: # Filter arg.filter(f1).filter(f2) into one FilterXComArg. 
return FilterXComArg(self.arg, [*self.callables, f if f else self.none_filter]) From 0ae23e096617a0812d27283439035e6c83dfd684 Mon Sep 17 00:00:00 2001 From: David Blain Date: Tue, 8 Apr 2025 13:23:42 +0200 Subject: [PATCH 20/40] refactor: Fixed some typings for Map and FilterXComArgs --- task-sdk/src/airflow/sdk/definitions/xcom_arg.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py index 2a95e8fc43f72..ff8504ec0dbed 100644 --- a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py +++ b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py @@ -167,7 +167,7 @@ def _serialize(self) -> dict[str, Any]: """ raise NotImplementedError() - def map(self, f: MapCallables) -> MapXComArg: + def map(self, f: Callable[[Any], Any]) -> MapXComArg: return MapXComArg(self, [f]) def zip(self, *others: XComArg, fillvalue: Any = NOTSET) -> ZipXComArg: @@ -176,7 +176,7 @@ def zip(self, *others: XComArg, fillvalue: Any = NOTSET) -> ZipXComArg: def concat(self, *others: XComArg) -> ConcatXComArg: return ConcatXComArg([self, *others]) - def filter(self, f: FilterCallables | None) -> FilterXComArg: + def filter(self, f: Callable[[Any], bool] | None) -> FilterXComArg: return FilterXComArg(self, [f] if f else []) def resolve(self, context: Mapping[str, Any]) -> Any: @@ -321,7 +321,7 @@ def as_teardown( def iter_references(self) -> Iterator[tuple[Operator, str]]: yield self.operator, self.key - def map(self, f: MapCallables) -> MapXComArg: + def map(self, f: Callable[[Any], Any]) -> MapXComArg: if self.key != XCOM_RETURN_KEY: raise ValueError("cannot map against non-return XCom") return super().map(f) @@ -336,7 +336,7 @@ def concat(self, *others: XComArg) -> ConcatXComArg: raise ValueError("cannot concatenate non-return XCom") return super().concat(*others) - def filter(self, f: FilterCallables | None) -> FilterXComArg: + def filter(self, f: Callable[[Any], bool] | 
None) -> FilterXComArg: if self.key != XCOM_RETURN_KEY: raise ValueError("cannot filter non-return XCom") return super().filter(f) @@ -392,7 +392,7 @@ def _get_callable_name(f: Callable | str) -> str: class _MapResult(Sequence): - def __init__(self, value: Iterable | Sequence | dict, callables: list) -> None: + def __init__(self, value: Iterable | Sequence | dict, callables: MapCallables) -> None: self.value = value self.callables = callables self._cache: list = [] @@ -476,7 +476,7 @@ def _serialize(self) -> dict[str, Any]: def iter_references(self) -> Iterator[tuple[Operator, str]]: yield from self.arg.iter_references() - def map(self, f: MapCallables) -> MapXComArg: + def map(self, f: Callable[[Any], Any]) -> MapXComArg: # Flatten arg.map(f1).map(f2) into one MapXComArg. return MapXComArg(self.arg, [*self.callables, f]) @@ -616,7 +616,7 @@ def resolve(self, context: Mapping[str, Any]) -> Any: class _FilterResult(Sequence, Iterable): - def __init__(self, value: Sequence | Iterable, callables: list) -> None: + def __init__(self, value: Sequence | Iterable, callables: FilterCallables) -> None: self.value = value self.callables = callables self._cache: list = [] @@ -714,7 +714,7 @@ def _serialize(self) -> dict[str, Any]: def iter_references(self) -> Iterator[tuple[Operator, str]]: yield from self.arg.iter_references() - def filter(self, f: FilterCallables | None) -> FilterXComArg: + def filter(self, f: Callable[[Any], bool] | None) -> FilterXComArg: # Filter arg.filter(f1).filter(f2) into one FilterXComArg. 
return FilterXComArg(self.arg, [*self.callables, f if f else self.none_filter]) From ab1e9b6dd2131f09be7ff35f36d4977e241ade18 Mon Sep 17 00:00:00 2001 From: David Blain Date: Tue, 8 Apr 2025 14:00:19 +0200 Subject: [PATCH 21/40] refactor: Fixed docstrings --- task-sdk/src/airflow/sdk/definitions/xcom_arg.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py index ff8504ec0dbed..293ee335c7057 100644 --- a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py +++ b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py @@ -400,7 +400,7 @@ def __init__(self, value: Iterable | Sequence | dict, callables: MapCallables) - self._exhausted = False def _next_mapped(self) -> Any: - """Returns the next transformed item from the iterator.""" + """Return the next transformed item from the iterator.""" while not self._exhausted: try: item = next(self._iterator) @@ -624,7 +624,7 @@ def __init__(self, value: Sequence | Iterable, callables: FilterCallables) -> No self._exhausted = False def _next_filtered(self) -> Any: - """Returns the next item from the iterator that passes all filters.""" + """Return the next item from the iterator that passes all filters.""" while not self._exhausted: try: item = next(self._iterator) From 34d91769ec854d36f6827121f57324398331ec6d Mon Sep 17 00:00:00 2001 From: David Blain Date: Tue, 15 Apr 2025 10:05:30 +0200 Subject: [PATCH 22/40] refactor: Introduced CallableResultMixin and split _MapResult into specialized _LazyMapResult --- .../src/airflow/sdk/definitions/xcom_arg.py | 71 ++++++++++--------- .../task_sdk/definitions/test_xcom_arg.py | 5 +- 2 files changed, 42 insertions(+), 34 deletions(-) diff --git a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py index 293ee335c7057..afd93e48321df 100644 --- a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py +++ 
b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py @@ -391,11 +391,29 @@ def _get_callable_name(f: Callable | str) -> str: return "" -class _MapResult(Sequence): +class CallableResultMixin: def __init__(self, value: Iterable | Sequence | dict, callables: MapCallables) -> None: self.value = value self.callables = callables - self._cache: list = [] + + def _apply_callables(self, value): + for func in self.callables: + value = func(value) + return value + + +class _MapResult(CallableResultMixin, Sequence): + def __getitem__(self, index: Any) -> Any: + value = self._apply_callables(self.value[index]) + return value + + def __len__(self) -> int: + return len(self.value) + + +class _LazyMapResult(CallableResultMixin, Sequence): + def __init__(self, value: Iterable, callables: MapCallables) -> None: + super().__init__([], callables) self._iterator = iter(value) self._exhausted = False @@ -405,7 +423,7 @@ def _next_mapped(self) -> Any: try: item = next(self._iterator) result = self._apply_callables(item) - self._cache.append(result) + self.value.append(result) return result except StopIteration: self._exhausted = True @@ -415,12 +433,12 @@ def __getitem__(self, index: Any) -> Any: if index < 0: raise IndexError - while len(self._cache) <= index: + while len(self.value) <= index: try: self._next_mapped() except StopIteration: raise IndexError - return self._cache[index] + return self.value[index] def __len__(self) -> int: # Fully consume the iterator to get accurate length @@ -429,10 +447,10 @@ def __len__(self) -> int: self._next_mapped() except StopIteration: break - return len(self._cache) + return len(self.value) def __iter__(self) -> Iterator: - yield from self._cache + yield from self.value while not self._exhausted: try: @@ -440,11 +458,6 @@ def __iter__(self) -> Iterator: except StopIteration: break - def _apply_callables(self, value): - for func in self.callables: - value = func(value) - return value - class MapXComArg(XComArg): """ @@ -482,9 +495,11 @@ def 
map(self, f: Callable[[Any], Any]) -> MapXComArg: def resolve(self, context: Mapping[str, Any]) -> Any: value = self.arg.resolve(context) - if not isinstance(value, (Sequence, dict)): - raise ValueError(f"XCom map expects sequence or dict, not {type(value).__name__}") - return _MapResult(value, self.callables) + if isinstance(value, (Sequence, dict)): + return _MapResult(value, self.callables) + elif isinstance(value, Iterable): + return _LazyMapResult(value, self.callables) + raise ValueError(f"XCom map expects sequence or dict, not {type(value).__name__}") class _ZipResult(Sequence): @@ -615,11 +630,9 @@ def resolve(self, context: Mapping[str, Any]) -> Any: return _ConcatResult(values) -class _FilterResult(Sequence, Iterable): - def __init__(self, value: Sequence | Iterable, callables: FilterCallables) -> None: - self.value = value - self.callables = callables - self._cache: list = [] +class _FilterResult(CallableResultMixin, Sequence): + def __init__(self, value: Iterable | Sequence | dict, callables: FilterCallables) -> None: + super().__init__([], callables) self._iterator = iter(value) self._exhausted = False @@ -629,7 +642,7 @@ def _next_filtered(self) -> Any: try: item = next(self._iterator) if self._apply_callables(item): - self._cache.append(item) + self.value.append(item) return item except StopIteration: self._exhausted = True @@ -639,13 +652,13 @@ def __getitem__(self, index: Any) -> Any: if index < 0: raise IndexError - while len(self._cache) <= index: + while len(self.value) <= index: try: self._next_filtered() except StopIteration: raise IndexError - return self._cache[index] + return self.value[index] def __len__(self) -> int: # Force full evaluation to determine total length @@ -654,10 +667,10 @@ def __len__(self) -> int: self._next_filtered() except StopIteration: break - return len(self._cache) + return len(self.value) def __iter__(self) -> Iterator: - yield from self._cache + yield from self.value while not self._exhausted: try: @@ -665,12 
+678,6 @@ def __iter__(self) -> Iterator: except StopIteration: break - def _apply_callables(self, value) -> bool: - for func in self.callables: - if not func(value): - return False - return True - class FilterXComArg(XComArg): """ @@ -720,7 +727,7 @@ def filter(self, f: Callable[[Any], bool] | None) -> FilterXComArg: def resolve(self, context: Mapping[str, Any]) -> Any: value = self.arg.resolve(context) - if not isinstance(value, (Sequence, dict)): + if not isinstance(value, (Iterable, Sequence, dict)): raise ValueError(f"XCom filter expects sequence or dict, not {type(value).__name__}") return _FilterResult(value, self.callables) diff --git a/task-sdk/tests/task_sdk/definitions/test_xcom_arg.py b/task-sdk/tests/task_sdk/definitions/test_xcom_arg.py index 0704079c9302b..bc0a654e9a649 100644 --- a/task-sdk/tests/task_sdk/definitions/test_xcom_arg.py +++ b/task-sdk/tests/task_sdk/definitions/test_xcom_arg.py @@ -62,12 +62,13 @@ def pull(value): def test_xcom_map_transform_to_none_and_filter(run_ti: RunTI, mock_supervisor_comms): results = set() + values = ["a", "b", "c"] with DAG("test") as dag: @dag.task() def push(): - return ["a", "b", "c"] + return values @dag.task() def pull(value): @@ -81,7 +82,7 @@ def c_to_none(v): pull.expand(value=push().map(c_to_none).filter(None)) # Mock xcom result from push task - mock_supervisor_comms.get_message.return_value = XComResult(key="return_value", value=["a", "b", "c"]) + mock_supervisor_comms.get_message.return_value = XComResult(key="return_value", value=values) # Run "pull". This should automatically convert "c" to None. 
for map_index in range(3): From e708941198e969241840983ccda3ef1aca3cd4a6 Mon Sep 17 00:00:00 2001 From: David Blain Date: Tue, 15 Apr 2025 15:56:30 +0200 Subject: [PATCH 23/40] refactor: Refactored CallableResultMixin and made it abstract class --- .../src/airflow/sdk/definitions/xcom_arg.py | 22 +++++++++++-------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py index afd93e48321df..316a2a54d9d98 100644 --- a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py +++ b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py @@ -20,9 +20,10 @@ import contextlib import inspect import itertools +from abc import ABCMeta from collections.abc import Iterable, Iterator, Mapping, Sequence, Sized from functools import singledispatch -from typing import TYPE_CHECKING, Any, Callable, overload +from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, Union, overload from airflow.exceptions import AirflowException, XComNotFound from airflow.sdk.definitions._internal.abstractoperator import AbstractOperator @@ -43,6 +44,7 @@ # safety (those callables are arbitrary user code). 
MapCallables = Sequence[Callable[[Any], Any]] FilterCallables = Sequence[Callable[[Any], bool]] +T = TypeVar("T", bound=Sequence) class XComArg(ResolveMixin, DependencyMixin): @@ -391,9 +393,9 @@ def _get_callable_name(f: Callable | str) -> str: return "" -class CallableResultMixin: - def __init__(self, value: Iterable | Sequence | dict, callables: MapCallables) -> None: - self.value = value +class CallableResultMixin(Generic[T], Sequence, metaclass=ABCMeta): + def __init__(self, value: T, callables: MapCallables) -> None: + self.value = list(value.items()) if isinstance(value, dict) else value self.callables = callables def _apply_callables(self, value): @@ -402,7 +404,7 @@ def _apply_callables(self, value): return value -class _MapResult(CallableResultMixin, Sequence): +class _MapResult(CallableResultMixin[Sequence]): def __getitem__(self, index: Any) -> Any: value = self._apply_callables(self.value[index]) return value @@ -411,7 +413,7 @@ def __len__(self) -> int: return len(self.value) -class _LazyMapResult(CallableResultMixin, Sequence): +class _LazyMapResult(CallableResultMixin[list]): def __init__(self, value: Iterable, callables: MapCallables) -> None: super().__init__([], callables) self._iterator = iter(value) @@ -438,7 +440,8 @@ def __getitem__(self, index: Any) -> Any: self._next_mapped() except StopIteration: raise IndexError - return self.value[index] + value = self.value[index] + return value def __len__(self) -> int: # Fully consume the iterator to get accurate length @@ -630,7 +633,7 @@ def resolve(self, context: Mapping[str, Any]) -> Any: return _ConcatResult(values) -class _FilterResult(CallableResultMixin, Sequence): +class _FilterResult(CallableResultMixin[list]): def __init__(self, value: Iterable | Sequence | dict, callables: FilterCallables) -> None: super().__init__([], callables) self._iterator = iter(value) @@ -658,7 +661,8 @@ def __getitem__(self, index: Any) -> Any: except StopIteration: raise IndexError - return self.value[index] + 
value = self.value[index] + return value def __len__(self) -> int: # Force full evaluation to determine total length From 3a5a199e3cefd1fdc6bffbc77a1434c7b6cc7a1f Mon Sep 17 00:00:00 2001 From: David Blain Date: Tue, 15 Apr 2025 15:56:52 +0200 Subject: [PATCH 24/40] refactor: Added test case when value is dict instead of list --- .../task_sdk/definitions/test_xcom_arg.py | 39 +++++++++++++++++-- 1 file changed, 35 insertions(+), 4 deletions(-) diff --git a/task-sdk/tests/task_sdk/definitions/test_xcom_arg.py b/task-sdk/tests/task_sdk/definitions/test_xcom_arg.py index bc0a654e9a649..408a1d512b9f8 100644 --- a/task-sdk/tests/task_sdk/definitions/test_xcom_arg.py +++ b/task-sdk/tests/task_sdk/definitions/test_xcom_arg.py @@ -60,7 +60,7 @@ def pull(value): assert results == {"aa", "bb", "cc"} -def test_xcom_map_transform_to_none_and_filter(run_ti: RunTI, mock_supervisor_comms): +def test_xcom_map_transform_to_none_and_filter_on_list(run_ti: RunTI, mock_supervisor_comms): results = set() values = ["a", "b", "c"] @@ -74,10 +74,10 @@ def push(): def pull(value): results.add(value) - def c_to_none(v): - if v == "c": + def c_to_none(value): + if value == "c": return None - return v + return value pull.expand(value=push().map(c_to_none).filter(None)) @@ -91,6 +91,37 @@ def c_to_none(v): assert results == {"a", "b"} +def test_xcom_map_transform_to_none_and_filter_on_dict(run_ti: RunTI, mock_supervisor_comms): + results = set() + values = {"a": "alpha", "b": "beta", "c": "charly"} + + with DAG("test") as dag: + + @dag.task() + def push(): + return values + + @dag.task() + def pull(value): + results.add(value) + + def c_to_none(value): + if "c" in value: + return None + return value + + pull.expand(value=push().map(c_to_none).filter(None)) + + # Mock xcom result from push task + mock_supervisor_comms.get_message.return_value = XComResult(key="return_value", value=values) + + # Run "pull". This should automatically convert "c" to None. 
+ for map_index in range(3): + assert run_ti(dag, "pull", map_index) == TerminalTIState.SUCCESS + + assert dict(results) == {"a": "alpha", "b": "beta"} + + def test_xcom_convert_to_kwargs_fails_task(run_ti: RunTI, mock_supervisor_comms, captured_logs): results = set() From f83d2acfe232f5d52a2528fcff6731891d4b6397 Mon Sep 17 00:00:00 2001 From: David Blain Date: Wed, 16 Apr 2025 08:08:48 +0200 Subject: [PATCH 25/40] refactor: Fixed conversion of dict to list in CallableResultMixin --- task-sdk/src/airflow/sdk/definitions/xcom_arg.py | 2 +- task-sdk/tests/task_sdk/definitions/test_xcom_arg.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py index 316a2a54d9d98..e73fdd98abe75 100644 --- a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py +++ b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py @@ -395,7 +395,7 @@ def _get_callable_name(f: Callable | str) -> str: class CallableResultMixin(Generic[T], Sequence, metaclass=ABCMeta): def __init__(self, value: T, callables: MapCallables) -> None: - self.value = list(value.items()) if isinstance(value, dict) else value + self.value = list(value) if isinstance(value, dict) else value self.callables = callables def _apply_callables(self, value): diff --git a/task-sdk/tests/task_sdk/definitions/test_xcom_arg.py b/task-sdk/tests/task_sdk/definitions/test_xcom_arg.py index 408a1d512b9f8..231e687f6b227 100644 --- a/task-sdk/tests/task_sdk/definitions/test_xcom_arg.py +++ b/task-sdk/tests/task_sdk/definitions/test_xcom_arg.py @@ -119,7 +119,7 @@ def c_to_none(value): for map_index in range(3): assert run_ti(dag, "pull", map_index) == TerminalTIState.SUCCESS - assert dict(results) == {"a": "alpha", "b": "beta"} + assert results == {"a", "b"} def test_xcom_convert_to_kwargs_fails_task(run_ti: RunTI, mock_supervisor_comms, captured_logs): From f828b7f2bc04b9d7407b40247894180766fdf089 Mon Sep 17 00:00:00 2001 From: David 
Blain Date: Wed, 16 Apr 2025 08:09:53 +0200 Subject: [PATCH 26/40] refactor: Added mapping as bounded type --- task-sdk/src/airflow/sdk/definitions/xcom_arg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py index e73fdd98abe75..d8eb3a5429835 100644 --- a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py +++ b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py @@ -44,7 +44,7 @@ # safety (those callables are arbitrary user code). MapCallables = Sequence[Callable[[Any], Any]] FilterCallables = Sequence[Callable[[Any], bool]] -T = TypeVar("T", bound=Sequence) +T = TypeVar("T", bound=Union[Sequence[Any], Mapping[Any, Any]]) class XComArg(ResolveMixin, DependencyMixin): From 7fac45230bc278d2c84e0584b3d59575f0763dbe Mon Sep 17 00:00:00 2001 From: David Blain Date: Wed, 16 Apr 2025 08:14:50 +0200 Subject: [PATCH 27/40] refactor: Ignore type check on XComResult --- task-sdk/tests/task_sdk/definitions/test_xcom_arg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/task-sdk/tests/task_sdk/definitions/test_xcom_arg.py b/task-sdk/tests/task_sdk/definitions/test_xcom_arg.py index 231e687f6b227..c8ee524c10981 100644 --- a/task-sdk/tests/task_sdk/definitions/test_xcom_arg.py +++ b/task-sdk/tests/task_sdk/definitions/test_xcom_arg.py @@ -113,7 +113,7 @@ def c_to_none(value): pull.expand(value=push().map(c_to_none).filter(None)) # Mock xcom result from push task - mock_supervisor_comms.get_message.return_value = XComResult(key="return_value", value=values) + mock_supervisor_comms.get_message.return_value = XComResult(key="return_value", value=values) # type: ignore # Run "pull". This should automatically convert "c" to None. 
for map_index in range(3): From 04fb8077239673e25b03cdfe87ace507dea7233f Mon Sep 17 00:00:00 2001 From: David Blain Date: Wed, 16 Apr 2025 08:16:29 +0200 Subject: [PATCH 28/40] refactor: Changed elif to if in resolve method of MapXComArg --- task-sdk/src/airflow/sdk/definitions/xcom_arg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py index d8eb3a5429835..5efb23f01b6fe 100644 --- a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py +++ b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py @@ -500,7 +500,7 @@ def resolve(self, context: Mapping[str, Any]) -> Any: value = self.arg.resolve(context) if isinstance(value, (Sequence, dict)): return _MapResult(value, self.callables) - elif isinstance(value, Iterable): + if isinstance(value, Iterable): return _LazyMapResult(value, self.callables) raise ValueError(f"XCom map expects sequence or dict, not {type(value).__name__}") From ec273bf94a38c1ad9fb4fb96603bb890adbf0e02 Mon Sep 17 00:00:00 2001 From: David Blain Date: Wed, 16 Apr 2025 08:34:24 +0200 Subject: [PATCH 29/40] refactor: Explicitly convert value to Sequence in CallableResultMixin --- task-sdk/src/airflow/sdk/definitions/xcom_arg.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py index 3139d4195e5c2..7dd93888e465a 100644 --- a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py +++ b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py @@ -395,9 +395,15 @@ def _get_callable_name(f: Callable | str) -> str: class CallableResultMixin(Generic[T], Sequence, metaclass=ABCMeta): def __init__(self, value: T, callables: MapCallables) -> None: - self.value = list(value) if isinstance(value, dict) else value + self.value = self._convert(value) self.callables = callables + @classmethod + def _convert(cls, value: T) -> Sequence: + if 
isinstance(value, Mapping): + return list(value) + return value + def _apply_callables(self, value): for func in self.callables: value = func(value) From 800e8ea8502af91923ce90b1f75952d6c32c7d75 Mon Sep 17 00:00:00 2001 From: David Blain Date: Wed, 16 Apr 2025 09:02:50 +0200 Subject: [PATCH 30/40] refactor: Simplified _LazyMapResult and _FilterResult --- .../src/airflow/sdk/definitions/xcom_arg.py | 99 +++++++------------ 1 file changed, 38 insertions(+), 61 deletions(-) diff --git a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py index 7dd93888e465a..4b482f127509d 100644 --- a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py +++ b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py @@ -17,11 +17,11 @@ from __future__ import annotations -import contextlib import inspect import itertools from abc import ABCMeta -from collections.abc import Iterable, Iterator, Mapping, Sequence, Sized +from collections.abc import Iterable, Iterator, Mapping, Set, Sequence, Sized +from contextlib import suppress from functools import singledispatch from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, Union, overload @@ -44,7 +44,6 @@ # safety (those callables are arbitrary user code). MapCallables = Sequence[Callable[[Any], Any]] FilterCallables = Sequence[Callable[[Any], bool]] -T = TypeVar("T", bound=Union[Sequence[Any], Mapping[Any, Any]]) class XComArg(ResolveMixin, DependencyMixin): @@ -386,31 +385,38 @@ def _get_callable_name(f: Callable | str) -> str: return f.__name__ # Parse the source to find whatever is behind "def". For safety, we don't # want to evaluate the code in any meaningful way! 
- with contextlib.suppress(Exception): + with suppress(Exception): kw, name, _ = f.lstrip().split(None, 2) if kw == "def": return name return "" -class CallableResultMixin(Generic[T], Sequence, metaclass=ABCMeta): - def __init__(self, value: T, callables: MapCallables) -> None: +class CallableResultMixin(Sequence, metaclass=ABCMeta): + def __init__(self, value: list | set | dict, callables: MapCallables) -> None: self.value = self._convert(value) self.callables = callables @classmethod - def _convert(cls, value: T) -> Sequence: - if isinstance(value, Mapping): + def _convert(cls, value: list | set | dict) -> Sequence: + if isinstance(value, dict): return list(value) return value + def append(self, value: list | set) -> Any: + if isinstance(self.value, list): + self.value.append(value) + elif isinstance(self.value, set): + self.value.add(value) + return value + def _apply_callables(self, value): for func in self.callables: value = func(value) return value -class _MapResult(CallableResultMixin[Sequence]): +class _MapResult(CallableResultMixin): def __getitem__(self, index: Any) -> Any: value = self._apply_callables(self.value[index]) return value @@ -419,23 +425,15 @@ def __len__(self) -> int: return len(self.value) -class _LazyMapResult(CallableResultMixin[list]): +class _LazyMapResult(CallableResultMixin): def __init__(self, value: Iterable, callables: MapCallables) -> None: super().__init__([], callables) self._iterator = iter(value) - self._exhausted = False def _next_mapped(self) -> Any: - """Return the next transformed item from the iterator.""" - while not self._exhausted: - try: - item = next(self._iterator) - result = self._apply_callables(item) - self.value.append(result) - return result - except StopIteration: - self._exhausted = True - raise StopIteration + item = next(self._iterator) + result = self.append(self._apply_callables(item)) + return result def __getitem__(self, index: Any) -> Any: if index < 0: @@ -446,26 +444,20 @@ def __getitem__(self, 
index: Any) -> Any: self._next_mapped() except StopIteration: raise IndexError - value = self.value[index] - return value + return self.value[index] def __len__(self) -> int: - # Fully consume the iterator to get accurate length - while not self._exhausted: - try: + with suppress(StopIteration): + while True: self._next_mapped() - except StopIteration: - break return len(self.value) def __iter__(self) -> Iterator: yield from self.value - while not self._exhausted: - try: + with suppress(StopIteration): + while True: yield self._next_mapped() - except StopIteration: - break class MapXComArg(XComArg): @@ -639,54 +631,39 @@ def resolve(self, context: Mapping[str, Any]) -> Any: return _ConcatResult(values) -class _FilterResult(CallableResultMixin[list]): - def __init__(self, value: Iterable | Sequence | dict, callables: FilterCallables) -> None: +class _FilterResult(CallableResultMixin): + def __init__(self, value: Iterable | list | set | dict, callables: FilterCallables) -> None: super().__init__([], callables) self._iterator = iter(value) - self._exhausted = False def _next_filtered(self) -> Any: """Return the next item from the iterator that passes all filters.""" - while not self._exhausted: - try: - item = next(self._iterator) - if self._apply_callables(item): - self.value.append(item) - return item - except StopIteration: - self._exhausted = True - raise StopIteration + while True: + item = next(self._iterator) + if self._apply_callables(item): + return self.append(item) - def __getitem__(self, index: Any) -> Any: + def __getitem__(self, index: int) -> Any: if index < 0: raise IndexError while len(self.value) <= index: - try: - self._next_filtered() - except StopIteration: - raise IndexError - - value = self.value[index] - return value + self._next_filtered() + return self.value[index] def __len__(self) -> int: - # Force full evaluation to determine total length - while not self._exhausted: - try: + # Fully consume the iterator to get total length + with 
suppress(StopIteration): + while True: self._next_filtered() - except StopIteration: - break return len(self.value) def __iter__(self) -> Iterator: yield from self.value - while not self._exhausted: - try: + with suppress(StopIteration): + while True: yield self._next_filtered() - except StopIteration: - break class FilterXComArg(XComArg): From eb28ffe9f9848aec7854f4958b17b0e1d50552b6 Mon Sep 17 00:00:00 2001 From: David Blain Date: Wed, 16 Apr 2025 09:08:18 +0200 Subject: [PATCH 31/40] refactor: Re-used common values variable where possible in test xcom args --- .../task_sdk/definitions/test_xcom_arg.py | 25 ++++++++++++------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/task-sdk/tests/task_sdk/definitions/test_xcom_arg.py b/task-sdk/tests/task_sdk/definitions/test_xcom_arg.py index c8ee524c10981..4f0df29fc5d7d 100644 --- a/task-sdk/tests/task_sdk/definitions/test_xcom_arg.py +++ b/task-sdk/tests/task_sdk/definitions/test_xcom_arg.py @@ -36,6 +36,8 @@ def test_xcom_map(run_ti: RunTI, mock_supervisor_comms): results = set() + values = ["a", "b", "c"] + with DAG("test") as dag: @dag.task @@ -52,7 +54,7 @@ def pull(value): assert set(dag.task_dict) == {"push", "pull"} # Mock xcom result from push task - mock_supervisor_comms.get_message.return_value = XComResult(key="return_value", value=["a", "b", "c"]) + mock_supervisor_comms.get_message.return_value = XComResult(key="return_value", value=values) for map_index in range(3): assert run_ti(dag, "pull", map_index) == TerminalTIState.SUCCESS @@ -124,12 +126,13 @@ def c_to_none(value): def test_xcom_convert_to_kwargs_fails_task(run_ti: RunTI, mock_supervisor_comms, captured_logs): results = set() + values = ["a", "b", "c"] with DAG("test") as dag: @dag.task() def push(): - return ["a", "b", "c"] + return values @dag.task() def pull(value): @@ -143,7 +146,7 @@ def c_to_none(v): pull.expand_kwargs(push().map(c_to_none)) # Mock xcom result from push task - mock_supervisor_comms.get_message.return_value = 
XComResult(key="return_value", value=["a", "b", "c"]) + mock_supervisor_comms.get_message.return_value = XComResult(key="return_value", value=values) # The first two "pull" tis should succeed. for map_index in range(2): @@ -177,11 +180,13 @@ def c_to_none(v): def test_xcom_map_error_fails_task(mock_supervisor_comms, run_ti, captured_logs): + values = ["a", "b", "c"] + with DAG("test") as dag: @dag.task() def push(): - return ["a", "b", "c"] + return values @dag.task() def pull(value): @@ -195,7 +200,7 @@ def does_not_work_with_c(v): pull.expand_kwargs(push().map(does_not_work_with_c)) # Mock xcom result from push task - mock_supervisor_comms.get_message.return_value = XComResult(key="return_value", value=["a", "b", "c"]) + mock_supervisor_comms.get_message.return_value = XComResult(key="return_value", value=values) # The third one (for "c") will fail. assert run_ti(dag, "pull", 2) == TerminalTIState.FAILED @@ -222,12 +227,13 @@ def does_not_work_with_c(v): def test_xcom_map_nest(mock_supervisor_comms, run_ti): results = set() + values = ["a", "b", "c"] with DAG("test") as dag: @dag.task() def push(): - return ["a", "b", "c"] + return values @dag.task() def pull(value): @@ -237,7 +243,7 @@ def pull(value): pull.expand_kwargs(converted) # Mock xcom result from push task - mock_supervisor_comms.get_message.return_value = XComResult(key="return_value", value=["a", "b", "c"]) + mock_supervisor_comms.get_message.return_value = XComResult(key="return_value", value=values) # Now "pull" should apply the mapping functions in order. 
for map_index in range(3): @@ -295,12 +301,13 @@ def xcom_get(): def test_xcom_map_raise_to_skip(run_ti, mock_supervisor_comms): result = [] + values = ["a", "b", "c"] with DAG("test") as dag: @dag.task() def push(): - return ["a", "b", "c"] + return values @dag.task() def forward(value): @@ -314,7 +321,7 @@ def skip_c(v): forward.expand_kwargs(push().map(skip_c)) # Mock xcom result from push task - mock_supervisor_comms.get_message.return_value = XComResult(key="return_value", value=["a", "b", "c"]) + mock_supervisor_comms.get_message.return_value = XComResult(key="return_value", value=values) # Run "forward". This should automatically skip "c". states = [run_ti(dag, "forward", map_index) for map_index in range(3)] From 844b785209979df2bd2fbab2e059df9744b2ff6a Mon Sep 17 00:00:00 2001 From: David Blain Date: Wed, 16 Apr 2025 10:41:21 +0200 Subject: [PATCH 32/40] refactor: Changed types of results classes --- task-sdk/src/airflow/sdk/definitions/xcom_arg.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py index 4b482f127509d..3455da0927f82 100644 --- a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py +++ b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py @@ -393,17 +393,17 @@ def _get_callable_name(f: Callable | str) -> str: class CallableResultMixin(Sequence, metaclass=ABCMeta): - def __init__(self, value: list | set | dict, callables: MapCallables) -> None: + def __init__(self, value: Sequence | dict, callables: MapCallables) -> None: self.value = self._convert(value) self.callables = callables @classmethod - def _convert(cls, value: list | set | dict) -> Sequence: + def _convert(cls, value: Sequence | dict) -> list | set: if isinstance(value, dict): return list(value) return value - def append(self, value: list | set) -> Any: + def append(self, value: Sequence) -> Any: if isinstance(self.value, list): self.value.append(value) elif 
isinstance(self.value, set): @@ -632,7 +632,7 @@ def resolve(self, context: Mapping[str, Any]) -> Any: class _FilterResult(CallableResultMixin): - def __init__(self, value: Iterable | list | set | dict, callables: FilterCallables) -> None: + def __init__(self, value: Iterable, callables: FilterCallables) -> None: super().__init__([], callables) self._iterator = iter(value) From 4e2bda92562079c52a755f45ce40dfe94cd720c9 Mon Sep 17 00:00:00 2001 From: David Blain Date: Wed, 16 Apr 2025 11:28:18 +0200 Subject: [PATCH 33/40] refactor: Raise a ValueError if value isn't a list, set or dict --- task-sdk/src/airflow/sdk/definitions/xcom_arg.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py index 3455da0927f82..49d6578916daa 100644 --- a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py +++ b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py @@ -401,7 +401,11 @@ def __init__(self, value: Sequence | dict, callables: MapCallables) -> None: def _convert(cls, value: Sequence | dict) -> list | set: if isinstance(value, dict): return list(value) - return value + if isinstance(value, (list, set)): + return value + raise ValueError( + f"XCom filter expects sequence or dict, not {type(value).__name__}" + ) def append(self, value: Sequence) -> Any: if isinstance(self.value, list): From 68a8329677a7d4f26af5470aa3eb2717a2c61839 Mon Sep 17 00:00:00 2001 From: David Blain Date: Wed, 16 Apr 2025 12:27:34 +0200 Subject: [PATCH 34/40] refactor: Sets must also be converted to lists, otherwise they are not indexable --- task-sdk/src/airflow/sdk/definitions/xcom_arg.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py index 49d6578916daa..6bdf28dd71c57 100644 --- a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py +++ 
b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py @@ -398,10 +398,10 @@ def __init__(self, value: Sequence | dict, callables: MapCallables) -> None: self.callables = callables @classmethod - def _convert(cls, value: Sequence | dict) -> list | set: - if isinstance(value, dict): + def _convert(cls, value: Sequence | dict) -> list: + if isinstance(value, (dict, set)): return list(value) - if isinstance(value, (list, set)): + if isinstance(value, list): return value raise ValueError( f"XCom filter expects sequence or dict, not {type(value).__name__}" From d0835a61f62fe30bee99abc695b73a3f69ece42f Mon Sep 17 00:00:00 2001 From: David Blain Date: Wed, 16 Apr 2025 13:12:56 +0200 Subject: [PATCH 35/40] refactor: Try except the StopIteration when yielding instead of suppress --- .../src/airflow/sdk/definitions/xcom_arg.py | 24 ++++++++++++------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py index 6bdf28dd71c57..daa7e3ef63314 100644 --- a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py +++ b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py @@ -451,17 +451,21 @@ def __getitem__(self, index: Any) -> Any: return self.value[index] def __len__(self) -> int: - with suppress(StopIteration): - while True: + while True: + try: self._next_mapped() + except StopIteration: + break return len(self.value) def __iter__(self) -> Iterator: yield from self.value - with suppress(StopIteration): - while True: + while True: + try: yield self._next_mapped() + except StopIteration: + break class MapXComArg(XComArg): @@ -657,17 +661,21 @@ def __getitem__(self, index: int) -> Any: def __len__(self) -> int: # Fully consume the iterator to get total length - with suppress(StopIteration): - while True: + while True: + try: self._next_filtered() + except StopIteration: + break return len(self.value) def __iter__(self) -> Iterator: yield from self.value - with suppress(StopIteration): - 
while True: + while True: + try: yield self._next_filtered() + except StopIteration: + break class FilterXComArg(XComArg): From 91adef40b1dfec8815899fd883531864029cfde5 Mon Sep 17 00:00:00 2001 From: David Blain Date: Wed, 16 Apr 2025 13:29:10 +0200 Subject: [PATCH 36/40] refactor: Fixed __getitem__ magic method of _FilterResult --- task-sdk/src/airflow/sdk/definitions/xcom_arg.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py index daa7e3ef63314..c6ec18ce7c0ac 100644 --- a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py +++ b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py @@ -656,7 +656,10 @@ def __getitem__(self, index: int) -> Any: raise IndexError while len(self.value) <= index: - self._next_filtered() + try: + self._next_filtered() + except StopIteration: + break return self.value[index] def __len__(self) -> int: From 40191f8aa57df3462926557c7f00efd2a7bf70c6 Mon Sep 17 00:00:00 2001 From: David Blain Date: Thu, 17 Apr 2025 12:05:48 +0200 Subject: [PATCH 37/40] refactor: Renamed CallableResultMixin to _MappableResult and refactored _LazyMapResult and _FilterResult --- .../src/airflow/sdk/definitions/xcom_arg.py | 31 +++++++++---------- 1 file changed, 15 insertions(+), 16 deletions(-) diff --git a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py index c6ec18ce7c0ac..1413da654a85a 100644 --- a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py +++ b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py @@ -23,7 +23,7 @@ from collections.abc import Iterable, Iterator, Mapping, Set, Sequence, Sized from contextlib import suppress from functools import singledispatch -from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, Union, overload +from typing import TYPE_CHECKING, Any, Callable, overload from airflow.exceptions import AirflowException, XComNotFound from 
airflow.sdk.definitions._internal.abstractoperator import AbstractOperator @@ -392,8 +392,8 @@ def _get_callable_name(f: Callable | str) -> str: return "" -class CallableResultMixin(Sequence, metaclass=ABCMeta): - def __init__(self, value: Sequence | dict, callables: MapCallables) -> None: +class _MappableResult(Sequence, metaclass=ABCMeta): + def __init__(self, value: Sequence | dict, callables: FilterCallables | MapCallables) -> None: self.value = self._convert(value) self.callables = callables @@ -407,20 +407,13 @@ def _convert(cls, value: Sequence | dict) -> list: f"XCom filter expects sequence or dict, not {type(value).__name__}" ) - def append(self, value: Sequence) -> Any: - if isinstance(self.value, list): - self.value.append(value) - elif isinstance(self.value, set): - self.value.add(value) - return value - def _apply_callables(self, value): for func in self.callables: value = func(value) return value -class _MapResult(CallableResultMixin): +class _MapResult(_MappableResult): def __getitem__(self, index: Any) -> Any: value = self._apply_callables(self.value[index]) return value @@ -429,14 +422,14 @@ def __len__(self) -> int: return len(self.value) -class _LazyMapResult(CallableResultMixin): +class _LazyMapResult(_MappableResult): def __init__(self, value: Iterable, callables: MapCallables) -> None: super().__init__([], callables) self._iterator = iter(value) def _next_mapped(self) -> Any: item = next(self._iterator) - result = self.append(self._apply_callables(item)) + result = self.value.append(self._apply_callables(item)) return result def __getitem__(self, index: Any) -> Any: @@ -451,6 +444,7 @@ def __getitem__(self, index: Any) -> Any: return self.value[index] def __len__(self) -> int: + # Fully consume the iterator to get total length while True: try: self._next_mapped() @@ -639,17 +633,16 @@ def resolve(self, context: Mapping[str, Any]) -> Any: return _ConcatResult(values) -class _FilterResult(CallableResultMixin): +class 
_FilterResult(_MappableResult): def __init__(self, value: Iterable, callables: FilterCallables) -> None: super().__init__([], callables) self._iterator = iter(value) def _next_filtered(self) -> Any: - """Return the next item from the iterator that passes all filters.""" while True: item = next(self._iterator) if self._apply_callables(item): - return self.append(item) + return self.value.append(item) def __getitem__(self, index: int) -> Any: if index < 0: @@ -680,6 +673,12 @@ def __iter__(self) -> Iterator: except StopIteration: break + def _apply_callables(self, value) -> bool: + for func in self.callables: + if not func(value): + return False + return True + class FilterXComArg(XComArg): """ From dec8eb8bdb8ae6319c7de625e65e42553cd97365 Mon Sep 17 00:00:00 2001 From: David Blain Date: Thu, 17 Apr 2025 13:40:27 +0200 Subject: [PATCH 38/40] refactor: Check if result in PlainXComArg needs runtime resolution --- task-sdk/src/airflow/sdk/definitions/xcom_arg.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py index 1413da654a85a..d8b405a66c8b8 100644 --- a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py +++ b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py @@ -34,6 +34,8 @@ from airflow.utils.trigger_rule import TriggerRule from airflow.utils.xcom import XCOM_RETURN_KEY +from airflow.sdk.definitions._internal.expandinput import _needs_run_time_resolution + if TYPE_CHECKING: from airflow.sdk.bases.operator import BaseOperator from airflow.sdk.definitions.edges import EdgeModifier @@ -366,6 +368,8 @@ def resolve(self, context: Mapping[str, Any]) -> Any: default=NOTSET, ) if not isinstance(result, ArgNotSet): + if _needs_run_time_resolution(result): + result = result.resolve(context) return result if self.key == XCOM_RETURN_KEY: return None From e586ed00e1a05ed578abe1c5eba0ac7ad4172022 Mon Sep 17 00:00:00 2001 From: David Blain Date: Thu, 17 Apr 2025 21:16:54 +0200 
Subject: [PATCH 39/40] refactor: Refactored _LazyMapResult and _FilterResult with __next__ magic method --- .../src/airflow/sdk/definitions/xcom_arg.py | 58 ++++++++++--------- 1 file changed, 30 insertions(+), 28 deletions(-) diff --git a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py index d8b405a66c8b8..dcc5da01d995c 100644 --- a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py +++ b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py @@ -19,11 +19,11 @@ import inspect import itertools -from abc import ABCMeta -from collections.abc import Iterable, Iterator, Mapping, Set, Sequence, Sized +from abc import ABCMeta, abstractmethod +from collections.abc import Iterable, Iterator, Mapping, Sequence, Sized from contextlib import suppress from functools import singledispatch -from typing import TYPE_CHECKING, Any, Callable, overload +from typing import TYPE_CHECKING, Any, Callable, overload, _T_co from airflow.exceptions import AirflowException, XComNotFound from airflow.sdk.definitions._internal.abstractoperator import AbstractOperator @@ -34,8 +34,6 @@ from airflow.utils.trigger_rule import TriggerRule from airflow.utils.xcom import XCOM_RETURN_KEY -from airflow.sdk.definitions._internal.expandinput import _needs_run_time_resolution - if TYPE_CHECKING: from airflow.sdk.bases.operator import BaseOperator from airflow.sdk.definitions.edges import EdgeModifier @@ -368,7 +366,7 @@ def resolve(self, context: Mapping[str, Any]) -> Any: default=NOTSET, ) if not isinstance(result, ArgNotSet): - if _needs_run_time_resolution(result): + if isinstance(result, ResolveMixin): result = result.resolve(context) return result if self.key == XCOM_RETURN_KEY: @@ -396,13 +394,19 @@ def _get_callable_name(f: Callable | str) -> str: return "" -class _MappableResult(Sequence, metaclass=ABCMeta): +class _MappableResult(Sequence): def __init__(self, value: Sequence | dict, callables: FilterCallables | MapCallables) -> None: self.value = 
self._convert(value) self.callables = callables - @classmethod - def _convert(cls, value: Sequence | dict) -> list: + def __getitem__(self, index: Any) -> Any: + raise NotImplementedError + + def __len__(self) -> int: + raise NotImplementedError + + @staticmethod + def _convert(value: Sequence | dict) -> list: if isinstance(value, (dict, set)): return list(value) if isinstance(value, list): @@ -411,7 +415,7 @@ def _convert(cls, value: Sequence | dict) -> list: f"XCom filter expects sequence or dict, not {type(value).__name__}" ) - def _apply_callables(self, value): + def _apply_callables(self, value) -> Any: for func in self.callables: value = func(value) return value @@ -431,10 +435,10 @@ def __init__(self, value: Iterable, callables: MapCallables) -> None: super().__init__([], callables) self._iterator = iter(value) - def _next_mapped(self) -> Any: - item = next(self._iterator) - result = self.value.append(self._apply_callables(item)) - return result + def __next__(self) -> Any: + value = self._apply_callables(next(self._iterator)) + self.value.append(value) + return value def __getitem__(self, index: Any) -> Any: if index < 0: @@ -442,26 +446,24 @@ def __getitem__(self, index: Any) -> Any: while len(self.value) <= index: try: - self._next_mapped() + next(self) except StopIteration: raise IndexError return self.value[index] def __len__(self) -> int: - # Fully consume the iterator to get total length while True: try: - self._next_mapped() + next(self) except StopIteration: break return len(self.value) def __iter__(self) -> Iterator: yield from self.value - while True: try: - yield self._next_mapped() + yield next(self) except StopIteration: break @@ -642,11 +644,12 @@ def __init__(self, value: Iterable, callables: FilterCallables) -> None: super().__init__([], callables) self._iterator = iter(value) - def _next_filtered(self) -> Any: + def __next__(self) -> Any: while True: - item = next(self._iterator) - if self._apply_callables(item): - return 
self.value.append(item) + value = next(self._iterator) + if self._apply_callables(value): + self.value.append(value) + return value def __getitem__(self, index: int) -> Any: if index < 0: @@ -654,26 +657,25 @@ def __getitem__(self, index: int) -> Any: while len(self.value) <= index: try: - self._next_filtered() + next(self) except StopIteration: break + return self.value[index] def __len__(self) -> int: - # Fully consume the iterator to get total length while True: try: - self._next_filtered() + next(self) except StopIteration: break return len(self.value) def __iter__(self) -> Iterator: yield from self.value - while True: try: - yield self._next_filtered() + yield next(self) except StopIteration: break From 7a9c48758485a162f336cf82abee3e46932f20e7 Mon Sep 17 00:00:00 2001 From: David Blain Date: Thu, 17 Apr 2025 21:28:37 +0200 Subject: [PATCH 40/40] refactor: Changed non_filter method of FilterXComArg to staticmethod --- task-sdk/src/airflow/sdk/definitions/xcom_arg.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py index dcc5da01d995c..95d99a957d9f3 100644 --- a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py +++ b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py @@ -711,8 +711,8 @@ def __init__( raise ValueError("filter() argument must be a plain function, not a @task operator") self.callables = callables - @classmethod - def none_filter(cls, value) -> bool: + @staticmethod + def none_filter(value) -> bool: return value if True else False def __repr__(self) -> str: