Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 75 additions & 17 deletions airflow/decorators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,30 +280,88 @@ def _validate_arg_names(self, funcname: str, kwargs: Dict[str, Any], valid_names
names = ", ".join(repr(n) for n in unknown_args)
raise TypeError(f'{funcname} got unexpected keyword arguments {names}')

def map(
self, *, dag: Optional["DAG"] = None, task_group: Optional["TaskGroup"] = None, **kwargs
) -> XComArg:
def map(self, *args, **kwargs) -> XComArg:
self._validate_arg_names("map", kwargs)
dag = dag or DagContext.get_current_dag()
task_group = task_group or TaskGroupContext.get_current_task_group(dag)
task_id = get_unique_task_id(self.kwargs['task_id'], dag, task_group)

operator = MappedOperator.from_decorator(
decorator=self,
partial_kwargs = self.kwargs.copy()
dag = partial_kwargs.pop("dag", DagContext.get_current_dag())
task_group = partial_kwargs.pop("task_group", TaskGroupContext.get_current_task_group(dag))
task_id = get_unique_task_id(partial_kwargs.pop("task_id"), dag, task_group)

# Unfortunately attrs's type hinting support does not work well with
# subclassing; it complains that arguments forwarded to the superclass
# are "unexpected" (they are fine at runtime).
operator = cast(Any, DecoratedMappedOperator)(
operator_class=self.operator_class,
partial_kwargs=partial_kwargs,
mapped_kwargs={},
task_id=task_id,
dag=dag,
task_group=task_group,
task_id=task_id,
mapped_kwargs=kwargs,
deps=MappedOperator._deps(self.operator_class.deps),
multiple_outputs=self.multiple_outputs,
python_callable=self.function,
)

operator.mapped_kwargs["op_args"] = list(args)
operator.mapped_kwargs["op_kwargs"] = kwargs

for arg in itertools.chain(args, kwargs.values()):
XComArg.apply_upstream_relationship(operator, arg)
return XComArg(operator=operator)

def partial(
self, *, dag: Optional["DAG"] = None, task_group: Optional["TaskGroup"] = None, **kwargs
) -> "_TaskDecorator[Function, OperatorSubclass]":
self._validate_arg_names("partial", kwargs, {'task_id'})
partial_kwargs = self.kwargs.copy()
partial_kwargs.update(kwargs)
return attr.evolve(self, kwargs=partial_kwargs)
def partial(self, *args, **kwargs) -> "_TaskDecorator[Function, OperatorSubclass]":
self._validate_arg_names("partial", kwargs)

op_args = self.kwargs.get("op_args", [])
op_args.extend(args)

op_kwargs = self.kwargs.get("op_kwargs", {})
op_kwargs = _merge_kwargs(op_kwargs, kwargs, fail_reason="duplicate partial")

return attr.evolve(self, kwargs={**self.kwargs, "op_args": op_args, "op_kwargs": op_kwargs})


def _merge_kwargs(
kwargs1: Dict[str, XComArg],
kwargs2: Dict[str, XComArg],
*,
fail_reason: str,
) -> Dict[str, XComArg]:
duplicated_keys = set(kwargs1).intersection(kwargs2)
if len(duplicated_keys) == 1:
raise TypeError(f"{fail_reason} argument: {duplicated_keys.pop()}")
elif duplicated_keys:
duplicated_keys_display = ", ".join(sorted(duplicated_keys))
raise TypeError(f"{fail_reason} arguments: {duplicated_keys_display}")
return {**kwargs1, **kwargs2}


@attr.define(kw_only=True)
class DecoratedMappedOperator(MappedOperator):
"""MappedOperator implementation for @task-decorated task function."""

multiple_outputs: bool
python_callable: Callable

def create_unmapped_operator(self, dag: "DAG") -> BaseOperator:
assert not isinstance(self.operator_class, str)
op_args = self.partial_kwargs.pop("op_args", []) + self.mapped_kwargs.pop("op_args", [])
op_kwargs = _merge_kwargs(
self.partial_kwargs.pop("op_kwargs", {}),
self.mapped_kwargs.pop("op_kwargs", {}),
fail_reason="mapping already partial",
)
return self.operator_class(
dag=dag,
task_id=self.task_id,
op_args=op_args,
op_kwargs=op_kwargs,
multiple_outputs=self.multiple_outputs,
python_callable=self.python_callable,
**self.partial_kwargs,
**self.mapped_kwargs,
)


class Task(Generic[Function]):
Expand Down
55 changes: 9 additions & 46 deletions airflow/models/baseoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,6 @@
from airflow.utils.weight_rule import WeightRule

if TYPE_CHECKING:
from airflow.decorators.base import _TaskDecorator
from airflow.models.dag import DAG
from airflow.utils.task_group import TaskGroup

Expand Down Expand Up @@ -243,7 +242,7 @@ def __new__(cls, name, bases, namespace, **kwargs):
return new_cls

# The class level partial function. This is what handles the actual mapping
def partial(cls, *, task_id: str, dag: Optional["DAG"] = None, **kwargs):
def partial(cls, *, task_id: str, dag: Optional["DAG"] = None, **kwargs) -> "MappedOperator":
operator_class = cast("Type[BaseOperator]", cls)
# Validate that the args we passed are known -- at call/DAG parse time, not run time!
_validate_kwarg_names_for_mapping(operator_class, "partial", kwargs)
Expand Down Expand Up @@ -1632,7 +1631,7 @@ def from_operator(cls, operator: BaseOperator, mapped_kwargs: Dict[str, Any]) ->
dag._remove_task(operator.task_id)

operator_init_kwargs: dict = operator._BaseOperator__init_kwargs # type: ignore
return MappedOperator(
return cls(
operator_class=type(operator),
task_id=operator.task_id,
task_group=task_group,
Expand All @@ -1648,37 +1647,6 @@ def from_operator(cls, operator: BaseOperator, mapped_kwargs: Dict[str, Any]) ->
deps=cls._deps(operator.deps),
)

@classmethod
def from_decorator(
cls,
*,
decorator: "_TaskDecorator",
dag: Optional["DAG"],
task_group: Optional["TaskGroup"],
task_id: str,
mapped_kwargs: Dict[str, Any],
) -> "MappedOperator":
"""Create a mapped operator from a task decorator.

Different from ``from_operator``, this DOES NOT validate ``mapped_kwargs``.
The task decorator calling this should be responsible for validation.
"""
from airflow.models.xcom_arg import XComArg

operator = MappedOperator(
operator_class=decorator.operator_class,
partial_kwargs=decorator.kwargs,
mapped_kwargs={},
task_id=task_id,
dag=dag,
task_group=task_group,
deps=cls._deps(decorator.operator_class.deps),
)
operator.mapped_kwargs.update(mapped_kwargs)
for arg in mapped_kwargs.values():
XComArg.apply_upstream_relationship(operator, arg)
return operator

@classmethod
def _deps(cls, deps: Iterable[BaseTIDep]):
if deps is BaseOperator.deps:
Expand Down Expand Up @@ -1749,7 +1717,7 @@ def inherits_from_dummy_operator(self):
@classmethod
def get_serialized_fields(cls):
if cls.__serialized_fields is None:
fields_dict = attr.fields_dict(cls)
fields_dict = attr.fields_dict(MappedOperator)
cls.__serialized_fields = frozenset(
fields_dict.keys()
- {
Expand Down Expand Up @@ -1902,22 +1870,17 @@ def expand_mapped_task(

return ret

def unmap(self) -> BaseOperator:
"""Get the "normal" Operator after applying the current mapping"""
def create_unmapped_operator(self, dag: "DAG") -> BaseOperator:
assert not isinstance(self.operator_class, str)
return self.operator_class(dag=dag, task_id=self.task_id, **self.partial_kwargs, **self.mapped_kwargs)

def unmap(self) -> BaseOperator:
"""Get the "normal" Operator after applying the current mapping"""
dag = self.get_dag()
if not dag:
raise RuntimeError("Cannot unmapp a task unless it has a dag")

args = {
**self.partial_kwargs,
**self.mapped_kwargs,
}
raise RuntimeError("Cannot unmap a task unless it has a DAG")
dag._remove_task(self.task_id)
task = self.operator_class(task_id=self.task_id, dag=self.dag, **args)

return task
return self.create_unmapped_operator(dag)


# TODO: Deprecate for Airflow 3.0
Expand Down
2 changes: 1 addition & 1 deletion airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -1713,7 +1713,7 @@ def handle_failure(
test_mode: Optional[bool] = None,
force_fail: bool = False,
error_file: Optional[str] = None,
session=NEW_SESSION,
session: Session = NEW_SESSION,
) -> None:
"""Handle Failure for the TaskInstance"""
if test_mode is None:
Expand Down
27 changes: 22 additions & 5 deletions airflow/serialization/serialized_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.

"""Serialized DAG and BaseOperator"""
import contextlib
import datetime
import enum
import logging
Expand Down Expand Up @@ -168,7 +169,7 @@ def _decode_timetable(var: Dict[str, Any]) -> Timetable:
return timetable_class.deserialize(var[Encoding.VAR])


class _XcomRef(NamedTuple):
class _XComRef(NamedTuple):
"""
Used to store info needed to create XComArg when deserializing MappedOperator.

Expand Down Expand Up @@ -497,8 +498,8 @@ def _serialize_xcomarg(cls, arg: XComArg) -> dict:
return {"key": arg.key, "task_id": arg.operator.task_id}

@classmethod
def _deserialize_xcomref(cls, encoded: dict) -> _XcomRef:
return _XcomRef(key=encoded['key'], task_id=encoded['task_id'])
def _deserialize_xcomref(cls, encoded: dict) -> _XComRef:
return _XComRef(key=encoded['key'], task_id=encoded['task_id'])


class DependencyDetector:
Expand Down Expand Up @@ -566,9 +567,19 @@ def task_type(self, task_type: str):

@classmethod
def serialize_mapped_operator(cls, op: MappedOperator) -> Dict[str, Any]:

stock_deps = op.deps is MappedOperator.DEFAULT_DEPS
serialize_op = cls._serialize_node(op, include_deps=not stock_deps)

# Simplify op_kwargs format. It must be a dict, so we flatten it.
with contextlib.suppress(KeyError):
op_kwargs = serialize_op["mapped_kwargs"]["op_kwargs"]
assert op_kwargs[Encoding.TYPE] == DAT.DICT
serialize_op["mapped_kwargs"]["op_kwargs"] = op_kwargs[Encoding.VAR]
with contextlib.suppress(KeyError):
op_kwargs = serialize_op["partial_kwargs"]["op_kwargs"]
assert op_kwargs[Encoding.TYPE] == DAT.DICT
serialize_op["partial_kwargs"]["op_kwargs"] = op_kwargs[Encoding.VAR]

# It must be a class at this point for it to work, not a string
assert isinstance(op.operator_class, type)
serialize_op['_task_type'] = op.operator_class.__name__
Expand Down Expand Up @@ -715,7 +726,13 @@ def deserialize_operator(cls, encoded_op: Dict[str, Any]) -> Union[BaseOperator,
elif k == "params":
v = cls._deserialize_params_dict(v)
elif k in ("mapped_kwargs", "partial_kwargs"):
if "op_kwargs" not in v:
op_kwargs: Optional[dict] = None
else:
op_kwargs = {arg: cls._deserialize(value) for arg, value in v.pop("op_kwargs").items()}
v = {arg: cls._deserialize(value) for arg, value in v.items()}
if op_kwargs is not None:
v["op_kwargs"] = op_kwargs
elif k in cls._decorated_fields or k not in op.get_serialized_fields():
v = cls._deserialize(v)
# else use v as it is
Expand Down Expand Up @@ -1002,7 +1019,7 @@ def deserialize_dag(cls, encoded_dag: Dict[str, Any]) -> 'SerializedDAG':
if isinstance(task, MappedOperator):
for d in (task.mapped_kwargs, task.partial_kwargs):
for k, v in d.items():
if not isinstance(v, _XcomRef):
if not isinstance(v, _XComRef):
continue

d[k] = XComArg(operator=dag.get_task(v.task_id), key=v.key)
Expand Down
31 changes: 31 additions & 0 deletions tests/dags/test_mapped_taskflow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from airflow import DAG
from airflow.utils.dates import days_ago

with DAG(dag_id='test_mapped_taskflow', start_date=days_ago(2)) as dag:

@dag.task
def make_list():
return [1, 2, {'a': 'b'}]

@dag.task
def consumer(value):
print(repr(value))

consumer.map(value=make_list())
Loading