From a9e33903e0ea53e7820b12435fc52d61b2c5aba7 Mon Sep 17 00:00:00 2001 From: Tzu-ping Chung Date: Mon, 31 Jan 2022 01:43:43 +0800 Subject: [PATCH 1/6] Clear XCom after end-to-end DAG tests --- tests/jobs/test_backfill_job.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/tests/jobs/test_backfill_job.py b/tests/jobs/test_backfill_job.py index 0878f63ddffc6..d78413ca60acb 100644 --- a/tests/jobs/test_backfill_job.py +++ b/tests/jobs/test_backfill_job.py @@ -47,7 +47,13 @@ from airflow.utils.timeout import timeout from airflow.utils.types import DagRunType from tests.models import TEST_DAGS_FOLDER -from tests.test_utils.db import clear_db_dags, clear_db_pools, clear_db_runs, set_default_pool_slots +from tests.test_utils.db import ( + clear_db_dags, + clear_db_pools, + clear_db_runs, + clear_db_xcom, + set_default_pool_slots, +) from tests.test_utils.mock_executor import MockExecutor from tests.test_utils.timetables import cron_timetable @@ -66,6 +72,7 @@ class TestBackfillJob: def clean_db(): clear_db_dags() clear_db_runs() + clear_db_xcom() clear_db_pools() @pytest.fixture(autouse=True) @@ -1512,7 +1519,7 @@ def test_backfill_has_job_id(self): job.run() assert executor.job_id is not None - def test_mapped_dag(self, dag_maker): + def test_mapped_dag(self): """End-to-end test of a simple mapped dag""" # Use SequentialExecutor for more predictable test behaviour from airflow.executors.sequential_executor import SequentialExecutor From e96ac4d5b3a88f0cfa5b7b383c87794926a7a825 Mon Sep 17 00:00:00 2001 From: Tzu-ping Chung Date: Mon, 31 Jan 2022 04:05:43 +0800 Subject: [PATCH 2/6] Add test for taskflow task-mapping --- tests/dags/test_mapped_taskflow.py | 31 ++++++++++++++++++++++++++++++ tests/jobs/test_backfill_job.py | 7 ++++--- 2 files changed, 35 insertions(+), 3 deletions(-) create mode 100644 tests/dags/test_mapped_taskflow.py diff --git a/tests/dags/test_mapped_taskflow.py b/tests/dags/test_mapped_taskflow.py new file mode 100644 
index 0000000000000..f21a9a5e8a42d --- /dev/null +++ b/tests/dags/test_mapped_taskflow.py @@ -0,0 +1,31 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from airflow import DAG +from airflow.utils.dates import days_ago + +with DAG(dag_id='test_mapped_taskflow', start_date=days_ago(2)) as dag: + + @dag.task + def make_list(): + return [1, 2, {'a': 'b'}] + + @dag.task + def consumer(value): + print(repr(value)) + + consumer.map(value=make_list()) diff --git a/tests/jobs/test_backfill_job.py b/tests/jobs/test_backfill_job.py index d78413ca60acb..40593d526a328 100644 --- a/tests/jobs/test_backfill_job.py +++ b/tests/jobs/test_backfill_job.py @@ -1519,13 +1519,14 @@ def test_backfill_has_job_id(self): job.run() assert executor.job_id is not None - def test_mapped_dag(self): + @pytest.mark.parametrize("dag_id", ["test_mapped_classic", "test_mapped_taskflow"]) + def test_mapped_dag(self, dag_id): """End-to-end test of a simple mapped dag""" # Use SequentialExecutor for more predictable test behaviour from airflow.executors.sequential_executor import SequentialExecutor - self.dagbag.process_file(str(TEST_DAGS_FOLDER / 'test_mapped_classic.py')) - dag = self.dagbag.get_dag('test_mapped_classic') + 
self.dagbag.process_file(str(TEST_DAGS_FOLDER / f'{dag_id}.py')) + dag = self.dagbag.get_dag(dag_id) # This needs a real executor to run, so that the `make_list` task can write out the TaskMap From 69ae74e2bbc4f7c985b949d6fb75e04492409ec9 Mon Sep 17 00:00:00 2001 From: Tzu-ping Chung Date: Mon, 31 Jan 2022 04:09:34 +0800 Subject: [PATCH 3/6] Rewrite decorated task mapping Mapping a traditional operator (where arguments go to the operator) and a task flow operator (where arguments go to the *function*) have very different semantics, so we need some special code for them. --- airflow/decorators/base.py | 74 +++++++++++++++++++++++++-------- airflow/models/baseoperator.py | 53 ++++------------------- airflow/models/taskinstance.py | 2 +- tests/decorators/test_python.py | 40 +++++++++++------- 4 files changed, 92 insertions(+), 77 deletions(-) diff --git a/airflow/decorators/base.py b/airflow/decorators/base.py index 9cf423fb69e11..ee439442b1769 100644 --- a/airflow/decorators/base.py +++ b/airflow/decorators/base.py @@ -280,30 +280,70 @@ def _validate_arg_names(self, funcname: str, kwargs: Dict[str, Any], valid_names names = ", ".join(repr(n) for n in unknown_args) raise TypeError(f'{funcname} got unexpected keyword arguments {names}') - def map( - self, *, dag: Optional["DAG"] = None, task_group: Optional["TaskGroup"] = None, **kwargs - ) -> XComArg: + def map(self, *args, **kwargs) -> XComArg: self._validate_arg_names("map", kwargs) - dag = dag or DagContext.get_current_dag() - task_group = task_group or TaskGroupContext.get_current_task_group(dag) - task_id = get_unique_task_id(self.kwargs['task_id'], dag, task_group) - operator = MappedOperator.from_decorator( - decorator=self, + partial_kwargs: Dict[str, Any] = { + **self.kwargs, + "python_callable": self.function, + "multiple_outputs": self.multiple_outputs, + } + + dag = partial_kwargs.pop("dag", DagContext.get_current_dag()) + task_group = partial_kwargs.pop("task_group", 
TaskGroupContext.get_current_task_group(dag)) + task_id = get_unique_task_id(partial_kwargs.pop("task_id"), dag, task_group) + + operator = MappedOperator( + operator_class=self.operator_class, + partial_kwargs=partial_kwargs, + mapped_kwargs={}, + task_id=task_id, dag=dag, task_group=task_group, - task_id=task_id, - mapped_kwargs=kwargs, + deps=MappedOperator._deps(self.operator_class.deps), ) + operator.mapped_kwargs["op_args"] = list(args) + operator.mapped_kwargs["op_kwargs"] = kwargs + + for arg in itertools.chain(args, kwargs.values()): + XComArg.apply_upstream_relationship(operator, arg) return XComArg(operator=operator) - def partial( - self, *, dag: Optional["DAG"] = None, task_group: Optional["TaskGroup"] = None, **kwargs - ) -> "_TaskDecorator[Function, OperatorSubclass]": - self._validate_arg_names("partial", kwargs, {'task_id'}) - partial_kwargs = self.kwargs.copy() - partial_kwargs.update(kwargs) - return attr.evolve(self, kwargs=partial_kwargs) + def partial(self, *args, **kwargs) -> "_TaskDecorator[Function, OperatorSubclass]": + self._validate_arg_names("partial", kwargs) + + op_args = self.kwargs.get("op_args", []) + op_args.extend(args) + + op_kwargs = self.kwargs.get("op_kwargs", {}) + duplicated_keys = set(op_kwargs).intersection(kwargs) + if len(duplicated_keys) == 1: + raise TypeError(f"duplicated partial argument: {duplicated_keys.pop()}") + elif duplicated_keys: + duplicated_keys_display = ", ".join(sorted(duplicated_keys)) + raise TypeError(f"duplicated partial arguments: {duplicated_keys_display}") + op_kwargs.update(kwargs) + + return attr.evolve(self, kwargs={**self.kwargs, "op_args": op_args, "op_kwargs": op_kwargs}) + + +class Task(Generic[Function]): + """Declaration of a @task-decorated callable for type-checking. 
+ + An instance of this type inherits the call signature of the decorated + function wrapped in it (not *exactly* since it actually returns an XComArg, + but there's no way to express that right now), and provides two additional + methods for task-mapping. + + This type is implemented by ``_TaskDecorator`` at runtime. + """ + + __call__: Function + + function: Function + + map: Callable[..., XComArg] + partial: Callable[..., "Task[Function]"] class Task(Generic[Function]): diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index 34c84128391dc..5a219992843f4 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -82,7 +82,6 @@ from airflow.utils.weight_rule import WeightRule if TYPE_CHECKING: - from airflow.decorators.base import _TaskDecorator from airflow.models.dag import DAG from airflow.utils.task_group import TaskGroup @@ -243,7 +242,7 @@ def __new__(cls, name, bases, namespace, **kwargs): return new_cls # The class level partial function. This is what handles the actual mapping - def partial(cls, *, task_id: str, dag: Optional["DAG"] = None, **kwargs): + def partial(cls, *, task_id: str, dag: Optional["DAG"] = None, **kwargs) -> "MappedOperator": operator_class = cast("Type[BaseOperator]", cls) # Validate that the args we passed are known -- at call/DAG parse time, not run time! 
_validate_kwarg_names_for_mapping(operator_class, "partial", kwargs) @@ -1632,7 +1631,7 @@ def from_operator(cls, operator: BaseOperator, mapped_kwargs: Dict[str, Any]) -> dag._remove_task(operator.task_id) operator_init_kwargs: dict = operator._BaseOperator__init_kwargs # type: ignore - return MappedOperator( + return cls( operator_class=type(operator), task_id=operator.task_id, task_group=task_group, @@ -1648,37 +1647,6 @@ def from_operator(cls, operator: BaseOperator, mapped_kwargs: Dict[str, Any]) -> deps=cls._deps(operator.deps), ) - @classmethod - def from_decorator( - cls, - *, - decorator: "_TaskDecorator", - dag: Optional["DAG"], - task_group: Optional["TaskGroup"], - task_id: str, - mapped_kwargs: Dict[str, Any], - ) -> "MappedOperator": - """Create a mapped operator from a task decorator. - - Different from ``from_operator``, this DOES NOT validate ``mapped_kwargs``. - The task decorator calling this should be responsible for validation. - """ - from airflow.models.xcom_arg import XComArg - - operator = MappedOperator( - operator_class=decorator.operator_class, - partial_kwargs=decorator.kwargs, - mapped_kwargs={}, - task_id=task_id, - dag=dag, - task_group=task_group, - deps=cls._deps(decorator.operator_class.deps), - ) - operator.mapped_kwargs.update(mapped_kwargs) - for arg in mapped_kwargs.values(): - XComArg.apply_upstream_relationship(operator, arg) - return operator - @classmethod def _deps(cls, deps: Iterable[BaseTIDep]): if deps is BaseOperator.deps: @@ -1902,22 +1870,17 @@ def expand_mapped_task( return ret - def unmap(self) -> BaseOperator: - """Get the "normal" Operator after applying the current mapping""" + def create_unmapped_operator(self, dag: "DAG", kwargs: Dict[str, Any]) -> BaseOperator: assert not isinstance(self.operator_class, str) + return self.operator_class(dag=dag, task_id=self.task_id, **kwargs) + def unmap(self) -> BaseOperator: + """Get the "normal" Operator after applying the current mapping""" dag = self.get_dag() if not 
dag: - raise RuntimeError("Cannot unmapp a task unless it has a dag") - - args = { - **self.partial_kwargs, - **self.mapped_kwargs, - } + raise RuntimeError("Cannot unmap a task unless it has a DAG") dag._remove_task(self.task_id) - task = self.operator_class(task_id=self.task_id, dag=self.dag, **args) - - return task + return self.create_unmapped_operator(dag, {**self.partial_kwargs, **self.mapped_kwargs}) # TODO: Deprecate for Airflow 3.0 diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index 4996b9a7db073..f10032dfc2964 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -1713,7 +1713,7 @@ def handle_failure( test_mode: Optional[bool] = None, force_fail: bool = False, error_file: Optional[str] = None, - session=NEW_SESSION, + session: Session = NEW_SESSION, ) -> None: """Handle Failure for the TaskInstance""" if test_mode is None: diff --git a/tests/decorators/test_python.py b/tests/decorators/test_python.py index 0c93b49e1fe00..02860551b1c05 100644 --- a/tests/decorators/test_python.py +++ b/tests/decorators/test_python.py @@ -490,7 +490,7 @@ def double(number: int): assert isinstance(doubled_0, XComArg) assert isinstance(doubled_0.operator, MappedOperator) assert doubled_0.operator.task_id == "double" - assert doubled_0.operator.mapped_kwargs == {"number": literal} + assert doubled_0.operator.mapped_kwargs == {"op_args": [], "op_kwargs": {"number": literal}} assert doubled_1.operator.task_id == "double__1" @@ -514,25 +514,37 @@ def test_partial_mapped_decorator() -> None: def product(number: int, multiple: int): return number * multiple + literal = [1, 2, 3] + with DAG('test_dag', start_date=DEFAULT_DATE) as dag: - literal = [1, 2, 3] - quadrupled = product.partial(task_id='times_4', multiple=3).map(number=literal) + quadrupled = product.partial(multiple=3).map(number=literal) doubled = product.partial(multiple=2).map(number=literal) trippled = product.partial(multiple=3).map(number=literal) - 
product.partial(multiple=2) + product.partial(multiple=2) # No operator is actually created. + + assert dag.task_dict == { + "product": quadrupled.operator, + "product__1": doubled.operator, + "product__2": trippled.operator, + } assert isinstance(doubled, XComArg) assert isinstance(doubled.operator, MappedOperator) - assert doubled.operator.task_id == "product" - assert doubled.operator.mapped_kwargs == {"number": literal} - assert doubled.operator.partial_kwargs == {"task_id": "product", "multiple": 2} - - assert trippled.operator.task_id == "product__1" - assert trippled.operator.partial_kwargs == {"task_id": "product", "multiple": 3} - - assert quadrupled.operator.task_id == "times_4" + assert doubled.operator.mapped_kwargs == {"op_args": [], "op_kwargs": {"number": literal}} + assert doubled.operator.partial_kwargs == { + "python_callable": product.function, + "multiple_outputs": False, + "op_args": [], + "op_kwargs": {"multiple": 2}, + } + + assert isinstance(trippled.operator, MappedOperator) # For type-checking on partial_kwargs. + assert trippled.operator.partial_kwargs == { + "python_callable": product.function, + "multiple_outputs": False, + "op_args": [], + "op_kwargs": {"multiple": 3}, + } assert doubled.operator is not trippled.operator - - assert [quadrupled.operator, doubled.operator, trippled.operator] == dag.tasks From 69889520721ee5bc90927351038f60d70121489f Mon Sep 17 00:00:00 2001 From: Tzu-ping Chung Date: Wed, 9 Feb 2022 15:41:34 +0800 Subject: [PATCH 4/6] Merge decorated mapped task's op args and kwargs Previously we were merging partial_kwargs and mapped_kwargs too naively and did not correctly handle op_args and op_kwargs; those need special logic due to the mapping semantics of decorated tasks. 
--- airflow/decorators/base.py | 56 ++++++++++++++++++++------------- airflow/models/baseoperator.py | 6 ++-- tests/decorators/test_python.py | 19 ++++++++++- 3 files changed, 56 insertions(+), 25 deletions(-) diff --git a/airflow/decorators/base.py b/airflow/decorators/base.py index ee439442b1769..5d0fd12c0416a 100644 --- a/airflow/decorators/base.py +++ b/airflow/decorators/base.py @@ -293,7 +293,7 @@ def map(self, *args, **kwargs) -> XComArg: task_group = partial_kwargs.pop("task_group", TaskGroupContext.get_current_task_group(dag)) task_id = get_unique_task_id(partial_kwargs.pop("task_id"), dag, task_group) - operator = MappedOperator( + operator = DecoratedMappedOperator( operator_class=self.operator_class, partial_kwargs=partial_kwargs, mapped_kwargs={}, @@ -316,34 +316,48 @@ def partial(self, *args, **kwargs) -> "_TaskDecorator[Function, OperatorSubclass op_args.extend(args) op_kwargs = self.kwargs.get("op_kwargs", {}) - duplicated_keys = set(op_kwargs).intersection(kwargs) - if len(duplicated_keys) == 1: - raise TypeError(f"duplicated partial argument: {duplicated_keys.pop()}") - elif duplicated_keys: - duplicated_keys_display = ", ".join(sorted(duplicated_keys)) - raise TypeError(f"duplicated partial arguments: {duplicated_keys_display}") - op_kwargs.update(kwargs) + op_kwargs = _merge_kwargs(op_kwargs, kwargs, fail_reason="duplicate partial") return attr.evolve(self, kwargs={**self.kwargs, "op_args": op_args, "op_kwargs": op_kwargs}) -class Task(Generic[Function]): - """Declaration of a @task-decorated callable for type-checking. - - An instance of this type inherits the call signature of the decorated - function wrapped in it (not *exactly* since it actually returns an XComArg, - but there's no way to express that right now), and provides two additional - methods for task-mapping. 
+def _merge_kwargs( + kwargs1: Dict[str, XComArg], + kwargs2: Dict[str, XComArg], + *, + fail_reason: str, +) -> Dict[str, XComArg]: + duplicated_keys = set(kwargs1).intersection(kwargs2) + if len(duplicated_keys) == 1: + raise TypeError(f"{fail_reason} argument: {duplicated_keys.pop()}") + elif duplicated_keys: + duplicated_keys_display = ", ".join(sorted(duplicated_keys)) + raise TypeError(f"{fail_reason} arguments: {duplicated_keys_display}") + return {**kwargs1, **kwargs2} - This type is implemented by ``_TaskDecorator`` at runtime. - """ - __call__: Function +class DecoratedMappedOperator(MappedOperator): + """MappedOperator implementation for @task-decorated task function. - function: Function + This has special logic to merge op_args and op_kwargs. + """ - map: Callable[..., XComArg] - partial: Callable[..., "Task[Function]"] + def create_unmapped_operator(self, dag: "DAG") -> BaseOperator: + assert not isinstance(self.operator_class, str) + op_args = self.partial_kwargs.pop("op_args", []) + self.mapped_kwargs.pop("op_args", []) + op_kwargs = _merge_kwargs( + self.partial_kwargs.pop("op_kwargs", {}), + self.mapped_kwargs.pop("op_kwargs", {}), + fail_reason="mapping already partial", + ) + return self.operator_class( + dag=dag, + task_id=self.task_id, + op_args=op_args, + op_kwargs=op_kwargs, + **self.partial_kwargs, + **self.mapped_kwargs, + ) class Task(Generic[Function]): diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index 5a219992843f4..32fe145d5cbf7 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -1870,9 +1870,9 @@ def expand_mapped_task( return ret - def create_unmapped_operator(self, dag: "DAG", kwargs: Dict[str, Any]) -> BaseOperator: + def create_unmapped_operator(self, dag: "DAG") -> BaseOperator: assert not isinstance(self.operator_class, str) - return self.operator_class(dag=dag, task_id=self.task_id, **kwargs) + return self.operator_class(dag=dag, task_id=self.task_id, 
**self.partial_kwargs, **self.mapped_kwargs) def unmap(self) -> BaseOperator: """Get the "normal" Operator after applying the current mapping""" @@ -1880,7 +1880,7 @@ def unmap(self) -> BaseOperator: if not dag: raise RuntimeError("Cannot unmap a task unless it has a DAG") dag._remove_task(self.task_id) - return self.create_unmapped_operator(dag, {**self.partial_kwargs, **self.mapped_kwargs}) + return self.create_unmapped_operator(dag) # TODO: Deprecate for Airflow 3.0 diff --git a/tests/decorators/test_python.py b/tests/decorators/test_python.py index 02860551b1c05..52f5e8767edba 100644 --- a/tests/decorators/test_python.py +++ b/tests/decorators/test_python.py @@ -17,7 +17,7 @@ # under the License. import sys from collections import namedtuple -from datetime import date, timedelta +from datetime import date, datetime, timedelta from typing import Dict # noqa: F401 # This is used by annotation tests. from typing import Tuple @@ -548,3 +548,20 @@ def product(number: int, multiple: int): } assert doubled.operator is not trippled.operator + + +def test_mapped_decorator_unmap_merge_op_kwargs(): + with DAG("test-dag", start_date=datetime(2020, 1, 1)) as dag: + + @task_decorator + def task1(): + ... + + @task_decorator + def task2(arg1, arg2): + ... + + task2.partial(arg1=1).map(arg2=task1()) + + unmapped = dag.get_task("task2").unmap() + assert set(unmapped.op_kwargs) == {"arg1", "arg2"} From db0274ce89fbb6c9381ef366916cdbdf88cd4134 Mon Sep 17 00:00:00 2001 From: Tzu-ping Chung Date: Wed, 9 Feb 2022 17:15:38 +0800 Subject: [PATCH 5/6] Revise serialization for mapped decorated task Some attributes are removed from serialization to match the format of the (unmapped) _PythonDecoratedOperator. Some simplification is implemented to op_kwargs to save some space.
--- airflow/decorators/base.py | 24 +++++---- airflow/models/baseoperator.py | 2 +- airflow/serialization/serialized_objects.py | 27 ++++++++-- tests/decorators/test_python.py | 14 +---- tests/serialization/test_dag_serialization.py | 53 +++++++++++++++++++ 5 files changed, 92 insertions(+), 28 deletions(-) diff --git a/airflow/decorators/base.py b/airflow/decorators/base.py index 5d0fd12c0416a..53a12c62d2318 100644 --- a/airflow/decorators/base.py +++ b/airflow/decorators/base.py @@ -283,17 +283,15 @@ def _validate_arg_names(self, funcname: str, kwargs: Dict[str, Any], valid_names def map(self, *args, **kwargs) -> XComArg: self._validate_arg_names("map", kwargs) - partial_kwargs: Dict[str, Any] = { - **self.kwargs, - "python_callable": self.function, - "multiple_outputs": self.multiple_outputs, - } - + partial_kwargs = self.kwargs.copy() dag = partial_kwargs.pop("dag", DagContext.get_current_dag()) task_group = partial_kwargs.pop("task_group", TaskGroupContext.get_current_task_group(dag)) task_id = get_unique_task_id(partial_kwargs.pop("task_id"), dag, task_group) - operator = DecoratedMappedOperator( + # Unfortunately attrs's type hinting support does not work well with + # subclassing; it complains that arguments forwarded to the superclass + # are "unexpected" (they are fine at runtime). + operator = cast(Any, DecoratedMappedOperator)( operator_class=self.operator_class, partial_kwargs=partial_kwargs, mapped_kwargs={}, @@ -301,7 +299,10 @@ def map(self, *args, **kwargs) -> XComArg: dag=dag, task_group=task_group, deps=MappedOperator._deps(self.operator_class.deps), + multiple_outputs=self.multiple_outputs, + python_callable=self.function, ) + operator.mapped_kwargs["op_args"] = list(args) operator.mapped_kwargs["op_kwargs"] = kwargs @@ -336,11 +337,12 @@ def _merge_kwargs( return {**kwargs1, **kwargs2} +@attr.define(kw_only=True) class DecoratedMappedOperator(MappedOperator): - """MappedOperator implementation for @task-decorated task function. 
+ """MappedOperator implementation for @task-decorated task function.""" - This has special logic to merge op_args and op_kwargs. - """ + multiple_outputs: bool + python_callable: Callable def create_unmapped_operator(self, dag: "DAG") -> BaseOperator: assert not isinstance(self.operator_class, str) @@ -355,6 +357,8 @@ def create_unmapped_operator(self, dag: "DAG") -> BaseOperator: task_id=self.task_id, op_args=op_args, op_kwargs=op_kwargs, + multiple_outputs=self.multiple_outputs, + python_callable=self.python_callable, **self.partial_kwargs, **self.mapped_kwargs, ) diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index 32fe145d5cbf7..ce2277eed56e0 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -1717,7 +1717,7 @@ def inherits_from_dummy_operator(self): @classmethod def get_serialized_fields(cls): if cls.__serialized_fields is None: - fields_dict = attr.fields_dict(cls) + fields_dict = attr.fields_dict(MappedOperator) cls.__serialized_fields = frozenset( fields_dict.keys() - { diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py index d6abda7c74899..017f2276964ca 100644 --- a/airflow/serialization/serialized_objects.py +++ b/airflow/serialization/serialized_objects.py @@ -16,6 +16,7 @@ # under the License. """Serialized DAG and BaseOperator""" +import contextlib import datetime import enum import logging @@ -168,7 +169,7 @@ def _decode_timetable(var: Dict[str, Any]) -> Timetable: return timetable_class.deserialize(var[Encoding.VAR]) -class _XcomRef(NamedTuple): +class _XComRef(NamedTuple): """ Used to store info needed to create XComArg when deserializing MappedOperator. 
@@ -497,8 +498,8 @@ def _serialize_xcomarg(cls, arg: XComArg) -> dict: return {"key": arg.key, "task_id": arg.operator.task_id} @classmethod - def _deserialize_xcomref(cls, encoded: dict) -> _XcomRef: - return _XcomRef(key=encoded['key'], task_id=encoded['task_id']) + def _deserialize_xcomref(cls, encoded: dict) -> _XComRef: + return _XComRef(key=encoded['key'], task_id=encoded['task_id']) class DependencyDetector: @@ -566,9 +567,19 @@ def task_type(self, task_type: str): @classmethod def serialize_mapped_operator(cls, op: MappedOperator) -> Dict[str, Any]: - stock_deps = op.deps is MappedOperator.DEFAULT_DEPS serialize_op = cls._serialize_node(op, include_deps=not stock_deps) + + # Simplify op_kwargs format. It must be a dict, so we flatten it. + with contextlib.suppress(KeyError): + op_kwargs = serialize_op["mapped_kwargs"]["op_kwargs"] + assert op_kwargs[Encoding.TYPE] == DAT.DICT + serialize_op["mapped_kwargs"]["op_kwargs"] = op_kwargs[Encoding.VAR] + with contextlib.suppress(KeyError): + op_kwargs = serialize_op["partial_kwargs"]["op_kwargs"] + assert op_kwargs[Encoding.TYPE] == DAT.DICT + serialize_op["partial_kwargs"]["op_kwargs"] = op_kwargs[Encoding.VAR] + # It must be a class at this point for it to work, not a string assert isinstance(op.operator_class, type) serialize_op['_task_type'] = op.operator_class.__name__ @@ -715,7 +726,13 @@ def deserialize_operator(cls, encoded_op: Dict[str, Any]) -> Union[BaseOperator, elif k == "params": v = cls._deserialize_params_dict(v) elif k in ("mapped_kwargs", "partial_kwargs"): + if "op_kwargs" not in v: + op_kwargs: Optional[dict] = None + else: + op_kwargs = {arg: cls._deserialize(value) for arg, value in v.pop("op_kwargs").items()} v = {arg: cls._deserialize(value) for arg, value in v.items()} + if op_kwargs is not None: + v["op_kwargs"] = op_kwargs elif k in cls._decorated_fields or k not in op.get_serialized_fields(): v = cls._deserialize(v) # else use v as it is @@ -1002,7 +1019,7 @@ def deserialize_dag(cls, 
encoded_dag: Dict[str, Any]) -> 'SerializedDAG': if isinstance(task, MappedOperator): for d in (task.mapped_kwargs, task.partial_kwargs): for k, v in d.items(): - if not isinstance(v, _XcomRef): + if not isinstance(v, _XComRef): continue d[k] = XComArg(operator=dag.get_task(v.task_id), key=v.key) diff --git a/tests/decorators/test_python.py b/tests/decorators/test_python.py index 52f5e8767edba..235a8297144a6 100644 --- a/tests/decorators/test_python.py +++ b/tests/decorators/test_python.py @@ -532,20 +532,10 @@ def product(number: int, multiple: int): assert isinstance(doubled, XComArg) assert isinstance(doubled.operator, MappedOperator) assert doubled.operator.mapped_kwargs == {"op_args": [], "op_kwargs": {"number": literal}} - assert doubled.operator.partial_kwargs == { - "python_callable": product.function, - "multiple_outputs": False, - "op_args": [], - "op_kwargs": {"multiple": 2}, - } + assert doubled.operator.partial_kwargs == {"op_args": [], "op_kwargs": {"multiple": 2}} assert isinstance(trippled.operator, MappedOperator) # For type-checking on partial_kwargs. 
- assert trippled.operator.partial_kwargs == { - "python_callable": product.function, - "multiple_outputs": False, - "op_args": [], - "op_kwargs": {"multiple": 3}, - } + assert trippled.operator.partial_kwargs == {"op_args": [], "op_kwargs": {"multiple": 3}} assert doubled.operator is not trippled.operator diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py index 447b1732a78b0..1e8d510fd7205 100644 --- a/tests/serialization/test_dag_serialization.py +++ b/tests/serialization/test_dag_serialization.py @@ -1654,6 +1654,59 @@ def test_mapped_operator_xcomarg_serde(): assert xcom_arg.operator is serialized_dag.task_dict['op1'] +def test_mapped_decorator_serde(): + from airflow.decorators import task + from airflow.models.xcom_arg import XComArg + from airflow.serialization.serialized_objects import _XComRef + + with DAG("test-dag", start_date=datetime(2020, 1, 1)) as dag: + op1 = BaseOperator(task_id="op1") + xcomarg = XComArg(op1, "my_key") + + @task(retry_delay=30) + def x(arg1, arg2, arg3, arg4): + print(arg1, arg2, arg3, arg4) + + x.partial("foo", arg3=[1, 2, {"a": "b"}]).map({"a": 1, "b": 2}, arg4=xcomarg) + + original = dag.get_task("x") + + serialized = SerializedBaseOperator._serialize(original) + assert serialized == { + '_is_dummy': False, + '_is_mapped': True, + '_task_module': 'airflow.decorators.python', + '_task_type': '_PythonDecoratedOperator', + 'downstream_task_ids': [], + 'partial_kwargs': { + 'op_args': ["foo"], + 'op_kwargs': {'arg3': [1, 2, {"__type": "dict", "__var": {'a': 'b'}}]}, + 'retry_delay': 30, + }, + 'mapped_kwargs': { + 'op_args': [{"__type": "dict", "__var": {'a': 1, 'b': 2}}], + 'op_kwargs': {'arg4': {'__type': 'xcomref', '__var': {'task_id': 'op1', 'key': 'my_key'}}}, + }, + 'task_id': 'x', + 'template_ext': [], + 'template_fields': ['op_args', 'op_kwargs'], + } + + deserialized = SerializedBaseOperator.deserialize_operator(serialized) + assert isinstance(deserialized, 
MappedOperator) + assert deserialized.deps is MappedOperator.DEFAULT_DEPS + + assert deserialized.mapped_kwargs == { + "op_args": [{"a": 1, "b": 2}], + "op_kwargs": {"arg4": _XComRef("op1", "my_key")}, + } + assert deserialized.partial_kwargs == { + "retry_delay": 30, + "op_args": ["foo"], + "op_kwargs": {"arg3": [1, 2, {"a": "b"}]}, + } + + def test_mapped_task_group_serde(): execution_date = datetime(2020, 1, 1) From 0d3c6245a7b91d392702b849732758d6c462b564 Mon Sep 17 00:00:00 2001 From: Tzu-ping Chung Date: Wed, 9 Feb 2022 21:17:06 +0800 Subject: [PATCH 6/6] Add test to ensure unmapping converts partial kwargs --- tests/decorators/test_python.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/tests/decorators/test_python.py b/tests/decorators/test_python.py index 235a8297144a6..ee94fde610d7a 100644 --- a/tests/decorators/test_python.py +++ b/tests/decorators/test_python.py @@ -555,3 +555,27 @@ def task2(arg1, arg2): unmapped = dag.get_task("task2").unmap() assert set(unmapped.op_kwargs) == {"arg1", "arg2"} + + +def test_mapped_decorator_unmap_converts_partial_kwargs(): + with DAG("test-dag", start_date=datetime(2020, 1, 1)) as dag: + + @task_decorator + def task1(arg): + ... + + @task_decorator(retry_delay=30) + def task2(arg1, arg2): + ... + + task2.partial(arg1=1).map(arg2=task1.map(arg=[1, 2])) + + # Arguments to the task decorator are stored in partial_kwargs, and + # converted into their intended form after the task is unmapped. + mapped_task2 = dag.get_task("task2") + assert mapped_task2.partial_kwargs["retry_delay"] == 30 + assert mapped_task2.unmap().retry_delay == timedelta(seconds=30) + + mapped_task1 = dag.get_task("task1") + assert "retry_delay" not in mapped_task1.partial_kwargs + assert mapped_task1.unmap().retry_delay == timedelta(seconds=300) # Operator default.