From 423bf9b7ccdaa55f0b5d1faaf24dcb6a5f2582bc Mon Sep 17 00:00:00 2001 From: Ash Berlin-Taylor Date: Thu, 2 Dec 2021 12:55:11 +0000 Subject: [PATCH 01/17] Map and Partial DAG authoring interface for Operators --- airflow/models/baseoperator.py | 290 ++++++++++-------- airflow/models/dag.py | 17 +- airflow/models/skipmixin.py | 5 +- airflow/models/taskmixin.py | 157 +++++++++- airflow/serialization/serialized_objects.py | 8 +- airflow/utils/task_group.py | 61 ++-- docs/spelling_wordlist.txt | 1 + tests/models/test_baseoperator.py | 79 ++++- tests/serialization/test_dag_serialization.py | 8 +- 9 files changed, 453 insertions(+), 173 deletions(-) diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index c3c5be65a9d15..b2a26ac28dbfe 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -61,7 +61,7 @@ from airflow.models.param import ParamsDict from airflow.models.pool import Pool from airflow.models.taskinstance import Context, TaskInstance, clear_task_instances -from airflow.models.taskmixin import DependencyMixin +from airflow.models.taskmixin import DAGNode, DependencyMixin from airflow.models.xcom import XCOM_RETURN_KEY from airflow.ti_deps.deps.base_ti_dep import BaseTIDep from airflow.ti_deps.deps.not_in_retry_period_dep import NotInRetryPeriodDep @@ -70,7 +70,6 @@ from airflow.ti_deps.deps.trigger_rule_dep import TriggerRuleDep from airflow.triggers.base import BaseTrigger from airflow.utils import timezone -from airflow.utils.edgemodifier import EdgeModifier from airflow.utils.helpers import render_template_as_native, render_template_to_string, validate_key from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.operator_resources import Resources @@ -91,6 +90,23 @@ T = TypeVar('T', bound=FunctionType) +class _PartialDescriptor: + """A descriptor that guards against ``.partial`` being called on Task objects.""" + + class_method = None + + def __get__( + self, obj: "BaseOperator", cls: "Optional[Type[BaseOperator]]" = None + ) -> Callable[..., "MappedOperator"]: + # Call this "partial" so it looks nicer in stack traces + def partial(*, task_id: str, **kwargs): + raise TypeError("partial can only be called on Operator classes, not Tasks themselves") + + if obj is not None: + return partial + return self.class_method.__get__(cls, cls) + + class BaseOperatorMeta(abc.ABCMeta): """Metaclass of BaseOperator.""" @@ -110,12 +126,13 @@ def _apply_defaults(cls, func: T) -> T: # per decoration, i.e. each function decorated using apply_defaults will # have a different sig_cache. 
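+        # Two views of the signature are captured up front: every non-variadic
+        # parameter (recorded on the wrapper below so that ``partial`` can
+        # validate its kwargs at parse time), and the subset without defaults
+        # (the pre-existing check for missing required arguments).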
            sig_cache = signature(func)
-        non_optional_args = {
-            name
+        non_variadic_params = {
+            name: param
             for (name, param) in sig_cache.parameters.items()
-            if param.default == param.empty
-            and param.name != 'self'
-            and param.kind not in (param.VAR_POSITIONAL, param.VAR_KEYWORD)
+            if param.name != 'self' and param.kind not in (param.VAR_POSITIONAL, param.VAR_KEYWORD)
+        }
+        non_optional_args = {
+            name for (name, param) in non_variadic_params.items() if param.default == param.empty
         }

         class autostacklevel_warn:
@@ -139,7 +156,7 @@ def warn(self, message, category=None, stacklevel=1, source=None):
         func.__globals__['warnings'] = autostacklevel_warn()

         @functools.wraps(func)
-        def apply_defaults(self, *args: Any, **kwargs: Any) -> Any:
+        def apply_defaults(self: "BaseOperator", *args: Any, **kwargs: Any) -> Any:
             from airflow.models.dag import DagContext
             from airflow.utils.task_group import TaskGroupContext

@@ -159,9 +176,8 @@ def apply_defaults(self, *args: Any, **kwargs: Any) -> Any:
             params = kwargs.get('params', {}) or {}
             dag_params.update(params)

-            default_args = {}
-            if 'default_args' in kwargs:
-                default_args = kwargs['default_args']
+            default_args = kwargs.pop('default_args', {})
+            if default_args:
                 if 'params' in default_args:
                     dag_params.update(default_args['params'])
                     del default_args['params']
@@ -181,13 +197,17 @@ def apply_defaults(self, *args: Any, **kwargs: Any) -> Any:
             if dag_params:
                 kwargs['params'] = dag_params

-            if default_args:
-                kwargs['default_args'] = default_args
+            hook = getattr(self, '_hook_apply_defaults', None)
+            if hook:
+                args, kwargs = hook(**kwargs, default_args=default_args)
+                default_args = kwargs.pop('default_args', {})

-            if hasattr(self, '_hook_apply_defaults'):
-                args, kwargs = self._hook_apply_defaults(*args, **kwargs)
+            if not hasattr(self, '_BaseOperator__init_kwargs'):
+                self._BaseOperator__init_kwargs = {}

-            result = func(self, *args, **kwargs)
+            result = func(self, **kwargs, default_args=default_args)
+            # Store the args passed to init -- we need them to support task.map serialization!
+            self._BaseOperator__init_kwargs.update(kwargs)  # type: ignore

             # Here we set upstream task defined by XComArgs passed to template fields of the operator
             self.set_xcomargs_dependencies()
@@ -196,16 +216,59 @@ def apply_defaults(self, *args: Any, **kwargs: Any) -> Any:
             self._BaseOperator__instantiated = True
             return result

+        apply_defaults.__non_optional_args = non_optional_args  # type: ignore
+        apply_defaults.__param_names = set(non_variadic_params.keys())  # type: ignore
+
         return cast(T, apply_defaults)

     def __new__(cls, name, bases, namespace, **kwargs):
         new_cls = super().__new__(cls, name, bases, namespace, **kwargs)
+        try:
+            # Update the partial descriptor with the class method so it can call the actual function (but
+            # let subclasses override it if they need to)
+            partial_desc = vars(new_cls)['partial']
+            if isinstance(partial_desc, _PartialDescriptor):
+                actual_partial = cls.partial
+                partial_desc.class_method = classmethod(actual_partial)
+        except KeyError:
+            pass
         new_cls.__init__ = cls._apply_defaults(new_cls.__init__)
         return new_cls

    # The class level partial function. This is what handles the actual mapping
    def partial(cls, *, task_id: str, **kwargs):
        unknown_args = set(kwargs.keys())
        # Validate that the args we passed are known -- at call/DAG parse time, not run time!
        #
        # This loop _assumes_ that all unknown args from a class are passed to the superclass's __init__, but
        # there is no way for us to validate that this is actually what operators do.
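        #
        # For example (illustrative only -- ``MyOperator`` is a hypothetical
        # subclass whose ``__init__`` accepts ``my_arg``):
        #
        #     MyOperator.partial(task_id="t", my_arg=1)  # ok: validated against __init__
        #     MyOperator.partial(task_id="t", nope=1)    # TypeError raised at parse time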
+        for clazz in cls.mro():
+            # Mypy doesn't like doing `clazz.__init__`, Error is: Cannot access "__init__" directly
+            init = clazz.__init__  # type: ignore
+
+            if not hasattr(init, '_BaseOperatorMeta__param_names'):
+                continue
+            unknown_args.difference_update(init.__param_names)
+            if not unknown_args:
+                # If we have no args left to check: stop looking at the MRO chain
+                break
+
+        if unknown_args:
+            if len(unknown_args) == 1:
+                raise TypeError(
+                    f'{cls.__name__}.partial got unexpected keyword argument {unknown_args.pop()!r}'
+                )
+            else:
+                names = ", ".join(repr(n) for n in unknown_args)
+                raise TypeError(f'{cls.__name__}.partial got unexpected keyword arguments {names}')
+        operator_class = cast("Type[BaseOperator]", cls)
+        return MappedOperator(
+            task_id=task_id, operator_class=operator_class, partial_kwargs=kwargs, mapped_kwargs={}
+        )
+

 @functools.total_ordering
-class BaseOperator(Operator, LoggingMixin, DependencyMixin, metaclass=BaseOperatorMeta):
+class BaseOperator(Operator, LoggingMixin, DAGNode, metaclass=BaseOperatorMeta):
     """
     Abstract base class for all operators. Since operators create objects that
     become nodes in the dag, BaseOperator contains many recursive methods for
@@ -452,6 +515,8 @@ class derived from this one results in the creation of a task object,
     # The _serialized_fields are lazily loaded when get_serialized_fields() method is called
     __serialized_fields: Optional[FrozenSet[str]] = None

+    partial: Callable[..., "MappedOperator"] = _PartialDescriptor()  # type: ignore
+
     _comps = {
         'task_id',
         'dag_id',
@@ -480,6 +545,9 @@ class derived from this one results in the creation of a task object,
     # If True then the class constructor was called
     __instantiated = False
+    # Kwargs as passed to `__init__()`, after apply_defaults() has updated them. 
Used to "recreate" the task + # when mapping + __init_kwargs: Dict[str, Any] # Set to True before calling execute method _lock_for_execution = False @@ -490,6 +558,9 @@ class derived from this one results in the creation of a task object, # Setting it to None by default as other Operators do not have that field subdag: Optional["DAG"] = None + start_date: Optional[pendulum.DateTime] = None + end_date: Optional[pendulum.DateTime] = None + def __init__( self, task_id: str, @@ -541,6 +612,8 @@ def __init__( from airflow.models.dag import DagContext from airflow.utils.task_group import TaskGroupContext + self.__init_kwargs = {} + super().__init__() if kwargs: if not conf.getboolean('operators', 'ALLOW_ILLEGAL_ARGUMENTS'): @@ -575,13 +648,11 @@ def __init__( self._pre_execute_hook = pre_execute self._post_execute_hook = post_execute - self.start_date = start_date if start_date and not isinstance(start_date, datetime): self.log.warning("start_date for %s isn't datetime.datetime", self) elif start_date: self.start_date = timezone.convert_to_utc(start_date) - self.end_date = end_date if end_date: self.end_date = timezone.convert_to_utc(end_date) @@ -792,6 +863,8 @@ def __setattr__(self, key, value): if self._lock_for_execution: # Skip any custom behaviour during execute return + if key in self.__init_kwargs: + self.__init_kwargs[key] = value if self.__instantiated and key in self.template_fields: # Resolve upstreams set by assigning an XComArg after initializing # an operator, example: @@ -816,7 +889,11 @@ def get_outlet_defs(self): return self._outlets @property - def dag(self) -> 'DAG': + def node_id(self): + return self.task_id + + @property # type: ignore[override] + def dag(self) -> 'DAG': # type: ignore[override] """Returns the Operator's DAG if set, otherwise raises an error""" if self._dag: return self._dag @@ -1191,21 +1268,11 @@ def resolve_template_files(self) -> None: self.log.exception(e) self.prepare_template() - @property - def upstream_list(self) -> List["BaseOperator"]: - """@property: list of tasks directly upstream""" - return [self.dag.get_task(tid) for tid in self._upstream_task_ids] - @property def upstream_task_ids(self) -> Set[str]: """@property: set of ids of tasks directly upstream""" return self._upstream_task_ids - @property - def downstream_list(self) -> List["BaseOperator"]: - """@property: list of tasks directly downstream""" - return [self.dag.get_task(tid) for tid in self._downstream_task_ids] - @property def downstream_task_ids(self) -> Set[str]: """@property: set of ids of tasks directly downstream""" @@ -1374,7 +1441,7 @@ def get_direct_relative_ids(self, upstream: bool = False) -> Set[str]: else: return self._downstream_task_ids - def get_direct_relatives(self, upstream: bool = False) -> List["BaseOperator"]: + def get_direct_relatives(self, upstream: bool = False) -> Iterable["DAGNode"]: """ Get list of the direct relatives to the current task, upstream or downstream. 
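Taken together, the hunks above record every constructor argument on the task (``__init_kwargs``), which is what lets an already-constructed operator be turned into a mapped one. A minimal sketch of the resulting authoring interface (``MyOperator`` is a hypothetical subclass; the pattern matches the tests added later in this patch):

    from datetime import datetime

    from airflow.models.baseoperator import BaseOperator
    from airflow.models.dag import DAG

    class MyOperator(BaseOperator):  # hypothetical operator, for illustration only
        def __init__(self, my_arg=None, **kwargs):
            super().__init__(**kwargs)
            self.my_arg = my_arg

    with DAG("example", start_date=datetime(2021, 1, 1)):
        # .map() swaps the just-created unmapped task out of the DAG for a
        # MappedOperator; the recorded __init_kwargs become its partial_kwargs.
        MyOperator(task_id="t").map(my_arg=["a", "b", "c"])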
@@ -1392,13 +1459,6 @@ def task_type(self) -> str: """@property: type of the task""" return self.__class__.__name__ - def add_only_new(self, item_set: Set[str], item: str, dag_id: str) -> None: - """Adds only new items to item set""" - if item in item_set: - self.log.warning('Dependency %s, %s already registered for DAG: %s', self, item, dag_id) - else: - item_set.add(item) - @property def roots(self) -> List["BaseOperator"]: """Required by TaskMixin""" @@ -1409,90 +1469,6 @@ def leaves(self) -> List["BaseOperator"]: """Required by TaskMixin""" return [self] - def _set_relatives( - self, - task_or_task_list: Union[DependencyMixin, Sequence[DependencyMixin]], - upstream: bool = False, - edge_modifier: Optional[EdgeModifier] = None, - ) -> None: - """Sets relatives for the task or task list.""" - if not isinstance(task_or_task_list, Sequence): - task_or_task_list = [task_or_task_list] - - task_list: List["BaseOperator"] = [] - for task_object in task_or_task_list: - task_object.update_relative(self, not upstream) - relatives = task_object.leaves if upstream else task_object.roots - for task in relatives: - if not isinstance(task, BaseOperator): - raise AirflowException( - f"Relationships can only be set between Operators; received {task.__class__.__name__}" - ) - task_list.append(task) - - # relationships can only be set if the tasks share a single DAG. Tasks - # without a DAG are assigned to that DAG. - dags = { - task._dag.dag_id: task._dag for task in self.roots + task_list if task.has_dag() # type: ignore - } - - if len(dags) > 1: - raise AirflowException( - f'Tried to set relationships between tasks in more than one DAG: {dags.values()}' - ) - elif len(dags) == 1: - dag = dags.popitem()[1] - else: - raise AirflowException( - f"Tried to create relationships between tasks that don't have DAGs yet. " - f"Set the DAG for at least one task and try again: {[self] + task_list}" - ) - - if dag and not self.has_dag(): - # If this task does not yet have a dag, add it to the same dag as the other task and - # put it in the dag's root TaskGroup. - self.dag = dag - self.dag.task_group.add(self) - - for task in task_list: - if dag and not task.has_dag(): - # If the other task does not yet have a dag, add it to the same dag as this task and - # put it in the dag's root TaskGroup. - task.dag = dag - task.dag.task_group.add(task) - if upstream: - task.add_only_new(task.get_direct_relative_ids(upstream=False), self.task_id, self.dag.dag_id) - self.add_only_new(self._upstream_task_ids, task.task_id, task.dag.dag_id) - if edge_modifier: - edge_modifier.add_edge_info(self.dag, task.task_id, self.task_id) - else: - self.add_only_new(self._downstream_task_ids, task.task_id, task.dag.dag_id) - task.add_only_new(task.get_direct_relative_ids(upstream=True), self.task_id, self.dag.dag_id) - if edge_modifier: - edge_modifier.add_edge_info(self.dag, self.task_id, task.task_id) - - def set_downstream( - self, - task_or_task_list: Union[DependencyMixin, Sequence[DependencyMixin]], - edge_modifier: Optional[EdgeModifier] = None, - ) -> None: - """ - Set a task or a task list to be directly downstream from the current - task. Required by TaskMixin. - """ - self._set_relatives(task_or_task_list, upstream=False, edge_modifier=edge_modifier) - - def set_upstream( - self, - task_or_task_list: Union[DependencyMixin, Sequence[DependencyMixin]], - edge_modifier: Optional[EdgeModifier] = None, - ) -> None: - """ - Set a task or a task list to be directly upstream from the current - task. Required by TaskMixin. 
-        """
-        self._set_relatives(task_or_task_list, upstream=True, edge_modifier=edge_modifier)
-
     @property
     def output(self):
         """Returns reference to XCom pushed by current operator"""
@@ -1614,8 +1590,11 @@ def get_serialized_fields(cls):
                     'dag',
                     '_dag',
                     '_BaseOperator__instantiated',
+                    '_BaseOperator__init_kwargs',
                 }
-                | {
+                | {  # Class level defaults need to be added to this list
+                    'start_date',
+                    'end_date',
                     '_task_type',
                     'subdag',
                     'ui_color',
@@ -1659,6 +1638,71 @@ def defer(
         """
         raise TaskDeferred(trigger=trigger, method_name=method_name, kwargs=kwargs, timeout=timeout)

+    def map(self, **kwargs) -> "MappedOperator":
+        return MappedOperator(
+            operator_class=type(self),
+            operator=self,
+            task_id=self.task_id,
+            task_group=getattr(self, 'task_group', None),
+            dag=getattr(self, '_dag', None),
+            start_date=self.start_date,
+            end_date=self.end_date,
+            partial_kwargs=self.__init_kwargs,
+            mapped_kwargs=kwargs,
+        )
+
+
+@attr.define(kw_only=True)
+class MappedOperator(DAGNode):
+    """Object representing a mapped operator in a DAG"""
+
+    operator_class: Type[BaseOperator] = attr.ib(repr=lambda c: c.__name__)
+    task_id: str
+    partial_kwargs: Dict[str, Any]
+    mapped_kwargs: Dict[str, Any]
+    operator: Optional[BaseOperator] = None
+    dag: Optional["DAG"] = None
+    upstream_task_ids: Set[str] = attr.ib(factory=set, repr=False)
+    downstream_task_ids: Set[str] = attr.ib(factory=set, repr=False)
+
+    task_group: Optional["TaskGroup"] = attr.ib(repr=False, default=None)
+    # BaseOperator-like interface -- needed so we can add ourselves to the dag.tasks
+    start_date: Optional[pendulum.DateTime] = attr.ib(repr=False, default=None)
+    end_date: Optional[pendulum.DateTime] = attr.ib(repr=False, default=None)
+
+    def __attrs_post_init__(self):
+        if self.dag and self.operator:
+            # As soon as we are a mapped task, replace the unmapped version in the list of tasks
+            self.dag.remove_task(self.task_id)
+            self.dag.add_task(self)
+
+    @property
+    def node_id(self):
+        return self.task_id
+
+    def map(self, **kwargs) -> "MappedOperator":
+        """
+        Merge ``kwargs`` into the mapping parameters.
+
+        :return: a new ``MappedOperator`` with the updated ``mapped_kwargs``, for easier method chaining
+        """
+        mapped_kwargs = self.mapped_kwargs.copy()
+        mapped_kwargs.update(kwargs)
+        return attr.evolve(self, mapped_kwargs=mapped_kwargs)
+
+    @property
+    def roots(self) -> List["MappedOperator"]:
+        """Required by TaskMixin"""
+        return [self]
+
+    @property
+    def leaves(self) -> List["MappedOperator"]:
+        """Required by TaskMixin"""
+        return [self]
+
+    def has_dag(self):
+        return self.dag is not None
+

 # TODO: Deprecate for Airflow 3.0
 Chainable = Union[DependencyMixin, Sequence[DependencyMixin]]
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index 89b79b98b5390..7fc6f5a09572e 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -1755,7 +1755,7 @@ def topological_sort(self, include_subdag_tasks: bool = False):
             acyclic = False
             for node in list(graph_unsorted.values()):
                 for edge in node.upstream_list:
-                    if edge.task_id in graph_unsorted:
+                    if edge.node_id in graph_unsorted:
                         break
                 # no edges in upstream tasks
                 else:
@@ -2075,10 +2075,10 @@ def filter_task_group(group, parent_group):
         # the cut.
subdag_task_groups = dag.task_group.get_task_group_dict() for group in subdag_task_groups.values(): - group.upstream_group_ids = group.upstream_group_ids.intersection(subdag_task_groups.keys()) - group.downstream_group_ids = group.downstream_group_ids.intersection(subdag_task_groups.keys()) - group.upstream_task_ids = group.upstream_task_ids.intersection(dag.task_dict.keys()) - group.downstream_task_ids = group.downstream_task_ids.intersection(dag.task_dict.keys()) + group.upstream_group_ids.intersection_update(subdag_task_groups.keys()) + group.downstream_group_ids.intersection_update(subdag_task_groups.keys()) + group.upstream_task_ids.intersection_update(dag.task_dict.keys()) + group.downstream_task_ids.intersection_update(dag.task_dict.keys()) for t in dag.tasks: # Removing upstream/downstream references to tasks that did not @@ -2092,7 +2092,7 @@ def filter_task_group(group, parent_group): return dag def has_task(self, task_id: str): - return task_id in (t.task_id for t in self.tasks) + return task_id in self.task_dict def get_task(self, task_id: str, include_subdags: bool = False) -> BaseOperator: if task_id in self.task_dict: @@ -2197,6 +2197,11 @@ def add_tasks(self, tasks): for task in tasks: self.add_task(task) + def remove_task(self, task_id: str) -> None: + del self.task_dict[task_id] + self._task_group.used_group_ids.remove(task_id) + self.task_count = len(self.task_dict) + def run( self, start_date=None, diff --git a/airflow/models/skipmixin.py b/airflow/models/skipmixin.py index a552af8904ac7..20879ef0df65a 100644 --- a/airflow/models/skipmixin.py +++ b/airflow/models/skipmixin.py @@ -17,7 +17,7 @@ # under the License. import warnings -from typing import TYPE_CHECKING, Iterable, Optional, Sequence, Union +from typing import TYPE_CHECKING, Iterable, List, Optional, Sequence, Union, cast from airflow.models.taskinstance import TaskInstance from airflow.utils import timezone @@ -147,7 +147,8 @@ def skip_all_except(self, ti: TaskInstance, branch_task_ids: Union[str, Iterable task = ti.task dag = task.dag - downstream_tasks = task.downstream_list + # At runtime, the downstream list will only be operators + downstream_tasks = cast("List[BaseOperator]", task.downstream_list) if downstream_tasks: # For a branching workflow that looks like this, when "branch" does skip_all_except("task1"), diff --git a/airflow/models/taskmixin.py b/airflow/models/taskmixin.py index 8942229e7aa45..3fbbf57bc4ff6 100644 --- a/airflow/models/taskmixin.py +++ b/airflow/models/taskmixin.py @@ -17,7 +17,18 @@ import warnings from abc import abstractmethod -from typing import Sequence, Union +from typing import TYPE_CHECKING, Iterable, List, Optional, Sequence, Set, Union + +import pendulum + +from airflow.exceptions import AirflowException + +if TYPE_CHECKING: + from logging import Logger + + from airflow.models.dag import DAG + from airflow.utils.edgemodifier import EdgeModifier + from airflow.utils.task_group import TaskGroup class DependencyMixin: @@ -88,3 +99,147 @@ def __init_subclass__(cls) -> None: stacklevel=2, ) return super().__init_subclass__() + + +class DAGNode(DependencyMixin): + """ + A base class for a node in the graph of a workflow -- an Operator or a Task Group, either mapped or + unmapped. 
+ """ + + dag: Optional["DAG"] = None + + @property + @abstractmethod + def node_id(self) -> str: + raise NotImplementedError() + + task_group: Optional["TaskGroup"] + """The task_group that contains this node""" + + start_date: Optional[pendulum.DateTime] + end_date: Optional[pendulum.DateTime] + + def has_dag(self) -> bool: + return self.dag is not None + + @property + @abstractmethod + def upstream_task_ids(self) -> Set[str]: + raise NotImplementedError() + + @property + @abstractmethod + def downstream_task_ids(self) -> Set[str]: + raise NotImplementedError() + + @property + def log(self) -> "Logger": + raise NotImplementedError() + + @property + @abstractmethod + def roots(self) -> Sequence["DAGNode"]: + raise NotImplementedError() + + @property + @abstractmethod + def leaves(self) -> Sequence["DAGNode"]: + raise NotImplementedError() + + def _set_relatives( + self, + task_or_task_list: Union[DependencyMixin, Sequence[DependencyMixin]], + upstream: bool = False, + edge_modifier: Optional["EdgeModifier"] = None, + ) -> None: + """Sets relatives for the task or task list.""" + from airflow.models.baseoperator import BaseOperator, MappedOperator + + if not isinstance(task_or_task_list, Sequence): + task_or_task_list = [task_or_task_list] + + task_list: List[DAGNode] = [] + for task_object in task_or_task_list: + task_object.update_relative(self, not upstream) + relatives = task_object.leaves if upstream else task_object.roots + for task in relatives: + if not isinstance(task, (BaseOperator, MappedOperator)): + raise AirflowException( + f"Relationships can only be set between Operators; received {task.__class__.__name__}" + ) + task_list.append(task) + + # relationships can only be set if the tasks share a single DAG. Tasks + # without a DAG are assigned to that DAG. + dags: Set["DAG"] = {task.dag for task in [*self.roots, *task_list] if task.has_dag() and task.dag} + + if len(dags) > 1: + raise AirflowException(f'Tried to set relationships between tasks in more than one DAG: {dags}') + elif len(dags) == 1: + dag = dags.pop() + else: + raise AirflowException( + f"Tried to create relationships between tasks that don't have DAGs yet. " + f"Set the DAG for at least one task and try again: {[self, *task_list]}" + ) + + if not self.has_dag(): + # If this task does not yet have a dag, add it to the same dag as the other task and + # put it in the dag's root TaskGroup. + self.dag = dag + self.dag.task_group.add(self) + + def add_only_new(obj, item_set: Set[str], item: str) -> None: + """Adds only new items to item set""" + if item in item_set: + self.log.warning('Dependency %s, %s already registered for DAG: %s', obj, item, dag.dag_id) + else: + item_set.add(item) + + for task in task_list: + if dag and not task.has_dag(): + # If the other task does not yet have a dag, add it to the same dag as this task and + # put it in the dag's root TaskGroup. 
+ dag.add_task(task) + dag.task_group.add(task) + if upstream: + add_only_new(task, task.downstream_task_ids, self.node_id) + add_only_new(self, self.upstream_task_ids, task.node_id) + if edge_modifier: + edge_modifier.add_edge_info(self.dag, task.node_id, self.node_id) + else: + add_only_new(self, self.downstream_task_ids, task.node_id) + add_only_new(task, task.upstream_task_ids, self.node_id) + if edge_modifier: + edge_modifier.add_edge_info(self.dag, self.node_id, task.node_id) + + def set_downstream( + self, + task_or_task_list: Union[DependencyMixin, Sequence[DependencyMixin]], + edge_modifier: Optional["EdgeModifier"] = None, + ) -> None: + """Set a node (or nodes) to be directly downstream from the current node.""" + self._set_relatives(task_or_task_list, upstream=False, edge_modifier=edge_modifier) + + def set_upstream( + self, + task_or_task_list: Union[DependencyMixin, Sequence[DependencyMixin]], + edge_modifier: Optional["EdgeModifier"] = None, + ) -> None: + """Set a node (or nodes) to be directly downstream from the current node.""" + self._set_relatives(task_or_task_list, upstream=True, edge_modifier=edge_modifier) + + @property + def downstream_list(self) -> Iterable["DAGNode"]: + """List of nodes directly downstream""" + if not self.dag: + raise AirflowException(f'Operator {self} has not been assigned to a DAG yet') + return [self.dag.get_task(tid) for tid in self.downstream_task_ids] + + @property + def upstream_list(self) -> Iterable["DAGNode"]: + """List of nodes directly upstream""" + if not self.dag: + raise AirflowException(f'Operator {self} has not been assigned to a DAG yet') + return [self.dag.get_task(tid) for tid in self.upstream_task_ids] diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py index 561a54bdcdd72..2b95a6a1b0fb8 100644 --- a/airflow/serialization/serialized_objects.py +++ b/airflow/serialization/serialized_objects.py @@ -1002,10 +1002,10 @@ def deserialize_task_group( else SerializedTaskGroup.deserialize_task_group(val, group, task_dict) for label, (_type, val) in encoded_group["children"].items() } - group.upstream_group_ids = set(cls._deserialize(encoded_group["upstream_group_ids"])) - group.downstream_group_ids = set(cls._deserialize(encoded_group["downstream_group_ids"])) - group.upstream_task_ids = set(cls._deserialize(encoded_group["upstream_task_ids"])) - group.downstream_task_ids = set(cls._deserialize(encoded_group["downstream_task_ids"])) + group.upstream_group_ids.update(cls._deserialize(encoded_group["upstream_group_ids"])) + group.downstream_group_ids.update(cls._deserialize(encoded_group["downstream_group_ids"])) + group.upstream_task_ids.update(cls._deserialize(encoded_group["upstream_task_ids"])) + group.downstream_task_ids.update(cls._deserialize(encoded_group["downstream_task_ids"])) return group diff --git a/airflow/utils/task_group.py b/airflow/utils/task_group.py index 84f8f459d129d..72c043c1a43cf 100644 --- a/airflow/utils/task_group.py +++ b/airflow/utils/task_group.py @@ -21,18 +21,20 @@ """ import copy import re +import weakref from typing import TYPE_CHECKING, Dict, Generator, List, Optional, Sequence, Set, Union from airflow.exceptions import AirflowException, DuplicateTaskIdFound -from airflow.models.taskmixin import DependencyMixin +from airflow.models.taskmixin import DAGNode, DependencyMixin from airflow.utils.helpers import validate_group_key if TYPE_CHECKING: from airflow.models.baseoperator import BaseOperator from airflow.models.dag import DAG + from 
airflow.utils.edgemodifier import EdgeModifier -class TaskGroup(DependencyMixin): +class TaskGroup(DAGNode): """ A collection of tasks. When set_downstream() or set_upstream() are called on the TaskGroup, it is applied across all tasks within the group if necessary. @@ -69,6 +71,8 @@ class TaskGroup(DependencyMixin): :type from_decorator: add_suffix_on_collision """ + used_group_ids: Set[Optional[str]] + def __init__( self, group_id: Optional[str], @@ -92,7 +96,7 @@ def __init__( raise AirflowException("Root TaskGroup cannot have parent_group") # used_group_ids is shared across all TaskGroups in the same DAG to keep track # of used group_id to avoid duplication. - self.used_group_ids: Set[Optional[str]] = set() + self.used_group_ids = set() self._parent_group = None else: if prefix_group_id: @@ -121,9 +125,10 @@ def __init__( self._check_for_group_id_collisions(add_suffix_on_collision) self.used_group_ids.add(self.group_id) - self.used_group_ids.add(self.downstream_join_id) - self.used_group_ids.add(self.upstream_join_id) - self.children: Dict[str, Union["BaseOperator", "TaskGroup"]] = {} + if self.group_id: + self.used_group_ids.add(self.downstream_join_id) + self.used_group_ids.add(self.upstream_join_id) + self.children: Dict[str, DAGNode] = {} if self._parent_group: self._parent_group.add(self) @@ -135,8 +140,9 @@ def __init__( # so that we can optimize the number of edges when entire TaskGroups depend on each other. self.upstream_group_ids: Set[Optional[str]] = set() self.downstream_group_ids: Set[Optional[str]] = set() - self.upstream_task_ids: Set[Optional[str]] = set() - self.downstream_task_ids: Set[Optional[str]] = set() + # Since the parent class defines these as read-only properties, we can 't just do `self.x = ...` + self.__dict__['upstream_task_ids'] = set() + self.__dict__['downstream_task_ids'] = set() def _check_for_group_id_collisions(self, add_suffix_on_collision: bool): if self._group_id is None: @@ -162,6 +168,10 @@ def create_root(cls, dag: "DAG") -> "TaskGroup": """Create a root TaskGroup with no group_id or parent.""" return cls(group_id=None, dag=dag) + @property + def node_id(self): + return self.group_id + @property def is_root(self) -> bool: """Returns True if this TaskGroup is the root TaskGroup. Otherwise False""" @@ -174,18 +184,20 @@ def __iter__(self): else: yield child - def add(self, task: Union["BaseOperator", "TaskGroup"]) -> None: + def add(self, task: DAGNode) -> None: """Add a task to this TaskGroup.""" - key = task.group_id if isinstance(task, TaskGroup) else task.task_id + key = task.node_id if key in self.children: - raise DuplicateTaskIdFound(f"Task id '{key}' has already been added to the DAG") + node_type = "Task" if hasattr(task, 'task_id') else "Task Group" + raise DuplicateTaskIdFound(f"{node_type} id '{key}' has already been added to the DAG") if isinstance(task, TaskGroup): if task.children: raise AirflowException("Cannot add a non-empty TaskGroup") - self.children[key] = task # type: ignore + self.children[key] = task + task.task_group = weakref.proxy(self) @property def group_id(self) -> Optional[str]: @@ -207,8 +219,6 @@ def update_relative(self, other: DependencyMixin, upstream=True) -> None: Update upstream_group_ids/downstream_group_ids/upstream_task_ids/downstream_task_ids accordingly so that we can reduce the number of edges when displaying Graph view. 
""" - from airflow.models.baseoperator import BaseOperator - if isinstance(other, TaskGroup): # Handles setting relationship between a TaskGroup and another TaskGroup if upstream: @@ -221,19 +231,22 @@ def update_relative(self, other: DependencyMixin, upstream=True) -> None: else: # Handles setting relationship between a TaskGroup and a task for task in other.roots: - if not isinstance(task, BaseOperator): + if not isinstance(task, DAGNode): raise AirflowException( "Relationships can only be set between TaskGroup " f"or operators; received {task.__class__.__name__}" ) if upstream: - self.upstream_task_ids.add(task.task_id) + self.upstream_task_ids.add(task.node_id) else: - self.downstream_task_ids.add(task.task_id) + self.downstream_task_ids.add(task.node_id) def _set_relative( - self, task_or_task_list: Union["DependencyMixin", Sequence["DependencyMixin"]], upstream: bool = False + self, + task_or_task_list: Union["DependencyMixin", Sequence["DependencyMixin"]], + upstream: bool = False, + edge_modifier: Optional["EdgeModifier"] = None, ) -> None: """ Call set_upstream/set_downstream for all root/leaf tasks within this TaskGroup. @@ -252,16 +265,6 @@ def _set_relative( for task_like in task_or_task_list: self.update_relative(task_like, upstream) - def set_downstream( - self, task_or_task_list: Union["DependencyMixin", Sequence["DependencyMixin"]] - ) -> None: - """Set a TaskGroup/task/list of task downstream of this TaskGroup.""" - self._set_relative(task_or_task_list, upstream=False) - - def set_upstream(self, task_or_task_list: Union["DependencyMixin", Sequence["DependencyMixin"]]) -> None: - """Set a TaskGroup/task/list of task upstream of this TaskGroup.""" - self._set_relative(task_or_task_list, upstream=True) - def __enter__(self) -> "TaskGroup": TaskGroupContext.push_context_managed_task_group(self) return self @@ -348,7 +351,7 @@ def build_map(task_group): build_map(self) return task_group_map - def get_child_by_label(self, label: str) -> Union["BaseOperator", "TaskGroup"]: + def get_child_by_label(self, label: str) -> DAGNode: """Get a child task/TaskGroup by its label (i.e. 
task_id/group_id without the group_id prefix)""" return self.children[self.child_id(label)] diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index fdad20716ebb5..eebe41da31275 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -1408,6 +1408,7 @@ unicode unittest unittests unix +unmapped unpause unpausing unpredicted diff --git a/tests/models/test_baseoperator.py b/tests/models/test_baseoperator.py index cabe1b58785a5..d0373806df6cf 100644 --- a/tests/models/test_baseoperator.py +++ b/tests/models/test_baseoperator.py @@ -22,13 +22,20 @@ from unittest import mock import jinja2 +import pendulum import pytest from airflow.decorators import task as task_decorator from airflow.exceptions import AirflowException from airflow.lineage.entities import File from airflow.models import DAG -from airflow.models.baseoperator import BaseOperator, BaseOperatorMeta, chain, cross_downstream +from airflow.models.baseoperator import ( + BaseOperator, + BaseOperatorMeta, + MappedOperator, + chain, + cross_downstream, +) from airflow.utils.context import Context from airflow.utils.edgemodifier import Label from airflow.utils.task_group import TaskGroup @@ -640,3 +647,73 @@ def test_operator_retries(caplog, dag_maker, retries, expected): retries=retries, ) assert caplog.record_tuples == expected + + +def test_task_mapping_with_dag(): + with DAG("test-dag", start_date=DEFAULT_DATE) as dag: + task1 = BaseOperator(task_id="op1") + literal = ['a', 'b', 'c'] + mapped = MockOperator(task_id='task_2').map(arg2=literal) + finish = MockOperator(task_id="finish") + + task1 >> mapped >> finish + + assert task1.downstream_list == [mapped] + assert mapped in dag.tasks + # At parse time there should only be three tasks! + assert len(dag.tasks) == 3 + + assert finish.upstream_list == [mapped] + assert mapped.downstream_list == [finish] + + +def test_task_mapping_without_dag_context(): + with DAG("test-dag", start_date=DEFAULT_DATE) as dag: + task1 = BaseOperator(task_id="op1") + literal = ['a', 'b', 'c'] + mapped = MockOperator(task_id='task_2').map(arg2=literal) + + task1 >> mapped + + assert isinstance(mapped, MappedOperator) + assert mapped.operator + assert task1.downstream_list == [mapped] + assert mapped in dag.tasks + # At parse time there should only be two tasks! + assert len(dag.tasks) == 2 + + +def test_task_mapping_default_args(): + default_args = {'start_date': DEFAULT_DATE.now(), 'owner': 'test'} + with DAG("test-dag", start_date=DEFAULT_DATE, default_args=default_args): + task1 = BaseOperator(task_id="op1") + literal = ['a', 'b', 'c'] + mapped = MockOperator(task_id='task_2').map(arg2=literal) + + task1 >> mapped + + assert mapped.partial_kwargs['owner'] == 'test' + assert mapped.start_date == pendulum.instance(default_args['start_date']) + + +def test_partial_on_instance() -> None: + """`.partial` on an instance should fail -- it's only designed to be called on classes""" + with pytest.raises(TypeError): + MockOperator( + task_id='a', + ).partial() + + +def test_partial_on_class() -> None: + # Test that we accept args for superclasses too + op = MockOperator.partial(task_id='a', arg1="a", trigger_rule=TriggerRule.ONE_FAILED) + assert op.partial_kwargs == {'arg1': 'a', 'trigger_rule': TriggerRule.ONE_FAILED} + + +def test_partial_on_class_invalid_ctor_args() -> None: + """Test that when we pass invalid args to partial(). + + I.e. 
if an arg is not known on the class or any of its parent classes we error at parse time + """ + with pytest.raises(TypeError): + MockOperator.partial(task_id='a', foo='bar', bar=2) diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py index d1e6e794ef72e..1729f8bee2de4 100644 --- a/tests/serialization/test_dag_serialization.py +++ b/tests/serialization/test_dag_serialization.py @@ -1094,14 +1094,12 @@ def test_no_new_fields_added_to_base_operator(self): tests should be added for it. """ base_operator = BaseOperator(task_id="10") - fields = base_operator.__dict__ + fields = {k: v for (k, v) in vars(base_operator).items() if k in BaseOperator.get_serialized_fields()} assert fields == { - '_BaseOperator__instantiated': True, '_downstream_task_ids': set(), '_inlets': [], '_log': base_operator.log, '_outlets': [], - '_upstream_task_ids': set(), '_pre_execute_hook': None, '_post_execute_hook': None, 'depends_on_past': False, @@ -1114,10 +1112,8 @@ def test_no_new_fields_added_to_base_operator(self): 'email': None, 'email_on_failure': True, 'email_on_retry': True, - 'end_date': None, 'execution_timeout': None, 'executor_config': {}, - 'inlets': [], 'label': '10', 'max_active_tis_per_dag': None, 'max_retry_delay': None, @@ -1125,7 +1121,6 @@ def test_no_new_fields_added_to_base_operator(self): 'on_failure_callback': None, 'on_retry_callback': None, 'on_success_callback': None, - 'outlets': [], 'owner': 'airflow', 'params': {}, 'pool': 'default_pool', @@ -1138,7 +1133,6 @@ def test_no_new_fields_added_to_base_operator(self): 'retry_exponential_backoff': False, 'run_as_user': None, 'sla': None, - 'start_date': None, 'task_id': '10', 'trigger_rule': 'all_success', 'wait_for_downstream': False, From da478ef34d364cf24c29947b838098d2154b33aa Mon Sep 17 00:00:00 2001 From: Ash Berlin-Taylor Date: Wed, 8 Dec 2021 15:19:18 +0000 Subject: [PATCH 02/17] Add TaskGroup mapping DAG author interface --- airflow/models/baseoperator.py | 35 ++++++++++++++++++---- airflow/models/dag.py | 11 +++++-- airflow/models/taskmixin.py | 20 +++++++++++++ airflow/utils/task_group.py | 55 +++++++++++++++++++++++++++++++--- tests/utils/test_task_group.py | 49 ++++++++++++++++++++++++++++++ 5 files changed, 158 insertions(+), 12 deletions(-) diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index b2a26ac28dbfe..dd8eb0aec4e17 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -236,7 +236,7 @@ def __new__(cls, name, bases, namespace, **kwargs): return new_cls # The class level partial function. This is what handles the actual mapping - def partial(cls, *, task_id: str, **kwargs): + def partial(cls, *, task_id: str, dag: Optional["DAG"] = None, **kwargs): unknown_args = set(kwargs.keys()) # Validate that the args we passed are known -- at call/DAG parse time, not run time! 
# @@ -263,7 +263,7 @@ def partial(cls, *, task_id: str, **kwargs): raise TypeError(f'{cls.__name__}.partial got unexpected keyword arguments {names}') operator_class = cast("Type[BaseOperator]", cls) return MappedOperator( - task_id=task_id, operator_class=operator_class, partial_kwargs=kwargs, mapped_kwargs={} + task_id=task_id, operator_class=operator_class, dag=dag, partial_kwargs=kwargs, mapped_kwargs={} ) @@ -561,6 +561,19 @@ class derived from this one results in the creation of a task object, start_date: Optional[pendulum.DateTime] = None end_date: Optional[pendulum.DateTime] = None + def __new__(cls, dag: Optional['DAG'] = None, task_group: Optional["TaskGroup"] = None, **kwargs): + # If we are creating a new Task _and_ we are in the context of a MappedTaskGroup, then we should only + # create mapped operators. + from airflow.models.dag import DagContext + from airflow.utils.task_group import MappedTaskGroup, TaskGroupContext + + dag = dag or DagContext.get_current_dag() + task_group = task_group or TaskGroupContext.get_current_task_group(dag) + + if isinstance(task_group, MappedTaskGroup): + return cls.partial(dag=dag, task_group=task_group, **kwargs) + return super().__new__(cls) + def __init__( self, task_id: str, @@ -1665,17 +1678,29 @@ class MappedOperator(DAGNode): upstream_task_ids: Set[str] = attr.ib(factory=set, repr=False) downstream_task_ids: Set[str] = attr.ib(factory=set, repr=False) - task_group: Optional["TaskGroup"] = attr.ib(repr=False, default=None) + task_group: Optional["TaskGroup"] = attr.ib(repr=False) # BaseOperator-like interface -- needed so we can add oursleves to the dag.tasks start_date: Optional[pendulum.DateTime] = attr.ib(repr=False, default=None) end_date: Optional[pendulum.DateTime] = attr.ib(repr=False, default=None) def __attrs_post_init__(self): if self.dag and self.operator: - # As soon as we are a mapped task, replace the unmapped version in the list of tasks - self.dag.remove_task(self.task_id) + # When BaseOperator() was called with a DAG, it would have been added straight away, but now we + # are mapped, we want to _remove_ that task (`self.operator`) from the dag + self.dag._remove_task(self.task_id) + + if self.task_group: + self.task_id = self.task_group.child_id(self.task_id) + self.task_group.add(self) + if self.dag: self.dag.add_task(self) + @task_group.default + def _default_task_group(self): + from airflow.utils.task_group import TaskGroupContext + + return TaskGroupContext.get_current_task_group(self.dag) + @property def node_id(self): return self.task_id diff --git a/airflow/models/dag.py b/airflow/models/dag.py index 7fc6f5a09572e..db06440a0bc9a 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -2197,9 +2197,14 @@ def add_tasks(self, tasks): for task in tasks: self.add_task(task) - def remove_task(self, task_id: str) -> None: - del self.task_dict[task_id] - self._task_group.used_group_ids.remove(task_id) + def _remove_task(self, task_id: str) -> None: + # This is "private" as removing could leave a whole in dependencies if done incorrectly, and this + # doesn't guard against that + task = self.task_dict.pop(task_id) + tg = getattr(task, 'task_group', None) + if tg: + tg._remove(task) + self.task_count = len(self.task_dict) def run( diff --git a/airflow/models/taskmixin.py b/airflow/models/taskmixin.py index 3fbbf57bc4ff6..d608942642e7c 100644 --- a/airflow/models/taskmixin.py +++ b/airflow/models/taskmixin.py @@ -243,3 +243,23 @@ def upstream_list(self) -> Iterable["DAGNode"]: if not self.dag: raise 
AirflowException(f'Operator {self} has not been assigned to a DAG yet') return [self.dag.get_task(tid) for tid in self.upstream_task_ids] + + def get_direct_relative_ids(self, upstream: bool = False) -> Set[str]: + """ + Get set of the direct relative ids to the current task, upstream or + downstream. + """ + if upstream: + return self.upstream_task_ids + else: + return self.downstream_task_ids + + def get_direct_relatives(self, upstream: bool = False) -> Iterable["DAGNode"]: + """ + Get list of the direct relatives to the current task, upstream or + downstream. + """ + if upstream: + return self.upstream_list + else: + return self.downstream_list diff --git a/airflow/utils/task_group.py b/airflow/utils/task_group.py index 72c043c1a43cf..f456e88fc619e 100644 --- a/airflow/utils/task_group.py +++ b/airflow/utils/task_group.py @@ -22,7 +22,7 @@ import copy import re import weakref -from typing import TYPE_CHECKING, Dict, Generator, List, Optional, Sequence, Set, Union +from typing import TYPE_CHECKING, Dict, Generator, Iterable, List, Optional, Sequence, Set, Union from airflow.exceptions import AirflowException, DuplicateTaskIdFound from airflow.models.taskmixin import DAGNode, DependencyMixin @@ -90,6 +90,8 @@ def __init__( self.prefix_group_id = prefix_group_id self.default_args = copy.deepcopy(default_args or {}) + dag = dag or DagContext.get_current_dag() + if group_id is None: # This creates a root TaskGroup. if parent_group: @@ -98,6 +100,7 @@ def __init__( # of used group_id to avoid duplication. self.used_group_ids = set() self._parent_group = None + self.dag = dag else: if prefix_group_id: # If group id is used as prefix, it should not contain spaces nor dots @@ -109,14 +112,17 @@ def __init__( if not group_id: raise ValueError("group_id must not be empty") - dag = dag or DagContext.get_current_dag() - if not parent_group and not dag: raise AirflowException("TaskGroup can only be used inside a dag") self._parent_group = parent_group or TaskGroupContext.get_current_task_group(dag) if not self._parent_group: raise AirflowException("TaskGroup must have a parent_group except for the root TaskGroup") + if dag is not self._parent_group.dag: + raise RuntimeError( + "Cannot mix TaskGroups from different DAGs: %s and %s", dag, self._parent_group.dag + ) + self.used_group_ids = self._parent_group.used_group_ids # if given group_id already used assign suffix by incrementing largest used suffix integer @@ -177,6 +183,14 @@ def is_root(self) -> bool: """Returns True if this TaskGroup is the root TaskGroup. 
Otherwise False""" return not self.group_id + @property + def upstream_task_ids(self) -> Set[str]: + return self.__dict__['upstream_task_ids'] + + @property + def downstream_task_ids(self) -> Set[str]: + return self.__dict__['downstream_task_ids'] + def __iter__(self): for child in self.children.values(): if isinstance(child, TaskGroup): @@ -193,12 +207,28 @@ def add(self, task: DAGNode) -> None: raise DuplicateTaskIdFound(f"{node_type} id '{key}' has already been added to the DAG") if isinstance(task, TaskGroup): + if self.dag: + if task.dag is not None and self.dag is not task.dag: + raise RuntimeError( + "Cannot mix TaskGroups from different DAGs: %s and %s", self.dag, task.dag + ) + task.dag = self.dag if task.children: raise AirflowException("Cannot add a non-empty TaskGroup") self.children[key] = task task.task_group = weakref.proxy(self) + def _remove(self, task: DAGNode) -> None: + key = task.node_id + + if key not in self.children: + raise KeyError(f"Node id {key!r} not part of this task group") + + self.used_group_ids.remove(key) + del self.children[key] + task.task_group = None + @property def group_id(self) -> Optional[str]: """group_id of this TaskGroup.""" @@ -242,7 +272,7 @@ def update_relative(self, other: DependencyMixin, upstream=True) -> None: else: self.downstream_task_ids.add(task.node_id) - def _set_relative( + def _set_relatives( self, task_or_task_list: Union["DependencyMixin", Sequence["DependencyMixin"]], upstream: bool = False, @@ -355,6 +385,23 @@ def get_child_by_label(self, label: str) -> DAGNode: """Get a child task/TaskGroup by its label (i.e. task_id/group_id without the group_id prefix)""" return self.children[self.child_id(label)] + def map(self, arg: Iterable) -> "MappedTaskGroup": + if self.children: + raise RuntimeError("Cannot map a TaskGroup that already has children") + if not self.group_id: + raise RuntimeError("Cannot map a TaskGroup before it has a group_id") + if self._parent_group: + self._parent_group._remove(self) + return MappedTaskGroup(self._group_id) + + +class MappedTaskGroup(TaskGroup): + """ + A TaskGroup that is dynamically expanded at run time. 
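+
+    A usage sketch (``MyOperator`` stands in for any operator class)::
+
+        with TaskGroup("process").map(["a", "b", "c"]) as group:
+            MyOperator(task_id="inner")  # created as a MappedOperator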
+
+    Do not create instances of this class directly, instead use :meth:`TaskGroup.map`
+    """
+

 class TaskGroupContext:
     """TaskGroup context is used to keep the current TaskGroup when TaskGroup is used as ContextManager."""
diff --git a/tests/utils/test_task_group.py b/tests/utils/test_task_group.py
index 2dcef95885983..841e0113e30e1 100644
--- a/tests/utils/test_task_group.py
+++ b/tests/utils/test_task_group.py
@@ -21,6 +21,7 @@

 from airflow.decorators import dag, task_group as task_group_decorator
 from airflow.models import DAG
+from airflow.models.baseoperator import MappedOperator
 from airflow.models.xcom_arg import XComArg
 from airflow.operators.bash import BashOperator
 from airflow.operators.dummy import DummyOperator
@@ -28,6 +29,8 @@
 from airflow.utils.dates import days_ago
 from airflow.utils.task_group import TaskGroup
 from airflow.www.views import dag_edges, task_group_to_dict
+from tests.models import DEFAULT_DATE
+from tests.test_utils.mock_operators import MockOperator

 EXPECTED_JSON = {
     'id': None,
@@ -998,3 +1001,49 @@ def wrap():

     assert isinstance(total_3, XComArg)
     wrap()
+
+
+def test_map() -> None:
+    with DAG("test-dag", start_date=DEFAULT_DATE) as dag:
+        start = MockOperator(task_id="start")
+        end = MockOperator(task_id="end")
+        literal = ['a', 'b', 'c']
+        with TaskGroup("process_one").map(literal) as process_one:
+            one = MockOperator(task_id='one')
+            two = MockOperator(task_id='two')
+            three = MockOperator(task_id='three')
+
+            one >> two >> three
+
+        start >> process_one >> end
+
+    # check the mapped operators are attached to the task group
+    assert process_one.has_task(one)
+
+    assert isinstance(one, MappedOperator)
+    assert start.downstream_list == [one]
+    assert one in dag.tasks
+    # At parse time there should only be five tasks!
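+    # (start and end, plus the three not-yet-expanded mapped tasks one/two/three)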
+ assert len(dag.tasks) == 5 + + assert end.upstream_list == [three] + assert three.downstream_list == [end] + + +def test_nested_map() -> None: + with DAG("test-dag", start_date=DEFAULT_DATE): + start = MockOperator(task_id="start") + end = MockOperator(task_id="end") + literal = ['a', 'b', 'c'] + with TaskGroup("process_one").map(literal) as process_one: + one = MockOperator(task_id='one') + + with TaskGroup("process_two").map(literal) as process_one_two: + two = MockOperator(task_id='two') + three = MockOperator(task_id='three') + two >> three + + four = MockOperator(task_id='four') + one >> process_one_two >> four + + start >> process_one >> end From 69fd94a0de647894aa51bc111ef3b9cbb1140f33 Mon Sep 17 00:00:00 2001 From: Ash Berlin-Taylor Date: Thu, 9 Dec 2021 15:13:40 +0000 Subject: [PATCH 03/17] Add mapping and partial support to TaskFlow tasks --- airflow/decorators/base.py | 138 ++++++++++++++++++++++++-------- airflow/models/xcom_arg.py | 21 +++-- tests/decorators/test_python.py | 54 ++++++++++++- 3 files changed, 162 insertions(+), 51 deletions(-) diff --git a/airflow/decorators/base.py b/airflow/decorators/base.py index 1a7e717d7bda2..a294e4a55657d 100644 --- a/airflow/decorators/base.py +++ b/airflow/decorators/base.py @@ -21,8 +21,10 @@ from inspect import signature from typing import Any, Callable, Collection, Dict, Mapping, Optional, Sequence, Type, TypeVar, cast +import attr + from airflow.exceptions import AirflowException -from airflow.models import BaseOperator +from airflow.models.baseoperator import BaseOperator, MappedOperator from airflow.models.dag import DAG, DagContext from airflow.models.xcom_arg import XComArg from airflow.utils.context import Context @@ -179,6 +181,87 @@ def _hook_apply_defaults(self, *args, **kwargs): T = TypeVar("T", bound=Callable) +OperatorSubclass = TypeVar("OperatorSubclass", bound=BaseOperator) + + +@attr.define +class OperatorWrapper(Generic[T, OperatorSubclass]): + """ + Helper class for providing dynamic task mapping to decorated functions. + + ``task_decorator_factory`` returns an instance of this, instead of just a plain wrapped function. + + :meta private: + """ + function: T = attr.ib(validator=attr.validators.is_callable()) + operator_class: Type[BaseOperator] + multiple_outputs: bool = attr.ib() + kwargs: Dict[str, Any] = attr.ib(factory=dict) + + function_arg_names: Set[str] = attr.ib(repr=False) + + @function_arg_names.default + def _get_arg_names(self): + return set(signature(self.function).parameters) + + @function.validator + def _validate_function(self, _, f): + if not callable(f): + # Not likely to be hit -- checked by the decorator first! 
+ raise TypeError('`python_callable` param must be callable') + if 'self' in self.function_arg_names: + raise TypeError('@task does not support methods') + + @multiple_outputs.default + def _infer_multiple_outputs(self): + sig = signature(self.function).return_annotation + ttype = getattr(sig, "__origin__", None) + + return sig is not inspect.Signature.empty and ttype in (dict, Dict) + + def __attrs_post_init__(self): + self.kwargs.setdefault('task_id', self.function.__name__) + + def __call__(self, *args, **kwargs) -> XComArg: + op = self.operator_class( + python_callable=self.function, + op_args=args, + op_kwargs=kwargs, + multiple_outputs=self.multiple_outputs, + **self.kwargs, + ) + if self.function.__doc__: + op.doc_md = self.function.__doc__ + return XComArg(op) + + def map( + self, *args, dag: Optional["DAG"] = None, task_group: Optional["TaskGroup"] = None, **kwargs + ) -> XComArg: + + dag = dag or DagContext.get_current_dag() + task_group = task_group or TaskGroupContext.get_current_task_group(dag) + task_id = get_unique_task_id(self.kwargs['task_id'], dag, task_group) + + mapped_arg = XComArg( + operator=MappedOperator( + operator_class=self.operator_class, + task_id=task_id, + dag=dag, + task_group=task_group, + partial_kwargs=self.kwargs, + mapped_kwargs=kwargs, + ), + ) + + return mapped_arg + + def partial( + self, *args, dag: Optional["DAG"] = None, task_group: Optional["TaskGroup"] = None, **kwargs + ) -> "OperatorWrapper[T, OperatorSubclass]": + partial_kwargs = self.kwargs.copy() + partial_kwargs.update(kwargs) + return attr.evolve(self, kwargs=partial_kwargs) + def task_decorator_factory( python_callable: Optional[Callable] = None, @@ -202,38 +285,23 @@ def task_decorator_factory( :type decorated_operator_class: BaseOperator """ - # try to infer from type annotation - if python_callable and multiple_outputs is None: - sig = signature(python_callable).return_annotation - ttype = getattr(sig, "__origin__", None) - - multiple_outputs = sig != inspect.Signature.empty and ttype in (dict, Dict) - - def wrapper(f: T): - """ - Python wrapper to generate PythonDecoratedOperator out of simple python functions. 
- Used for Airflow Decorated interface - """ - validate_python_callable(f) - kwargs.setdefault('task_id', f.__name__) - - @functools.wraps(f) - def factory(*args, **f_kwargs): - op = decorated_operator_class( - python_callable=f, - op_args=args, - op_kwargs=f_kwargs, - multiple_outputs=multiple_outputs, - **kwargs, - ) - if f.__doc__: - op.doc_md = f.__doc__ - return XComArg(op) - - return cast(T, factory) - - if callable(python_callable): - return wrapper(python_callable) + if multiple_outputs is None: + multiple_outputs = attr.NOTHING + if python_callable: + return OperatorWrapper( + function=python_callable, + multiple_outputs=multiple_outputs, + operator_class=decorated_operator_class, + kwargs=kwargs, + ) # type: ignore elif python_callable is not None: - raise AirflowException('No args allowed while using @task, use kwargs instead') - return wrapper + raise TypeError('No args allowed while using @task, use kwargs instead') + return cast( + "Callable[[T], T]", + functools.partial( + OperatorWrapper, + multiple_outputs=multiple_outputs, + operator_class=decorated_operator_class, + kwargs=kwargs, + ), + ) diff --git a/airflow/models/xcom_arg.py b/airflow/models/xcom_arg.py index d6f0f290deca8..3f3d5d6ef2f1d 100644 --- a/airflow/models/xcom_arg.py +++ b/airflow/models/xcom_arg.py @@ -17,8 +17,8 @@ from typing import Any, List, Optional, Sequence, Union from airflow.exceptions import AirflowException -from airflow.models.baseoperator import BaseOperator -from airflow.models.taskmixin import DependencyMixin +from airflow.models.baseoperator import BaseOperator, MappedOperator +from airflow.models.taskmixin import DAGNode, DependencyMixin from airflow.models.xcom import XCOM_RETURN_KEY from airflow.utils.context import Context from airflow.utils.edgemodifier import EdgeModifier @@ -59,7 +59,7 @@ class XComArg(DependencyMixin): :type key: str """ - def __init__(self, operator: BaseOperator, key: str = XCOM_RETURN_KEY): + def __init__(self, operator: Union[BaseOperator, MappedOperator], key: str = XCOM_RETURN_KEY): self._operator = operator self._key = key @@ -93,17 +93,17 @@ def __str__(self): return xcom_pull @property - def operator(self) -> BaseOperator: + def operator(self) -> Union[BaseOperator, MappedOperator]: """Returns operator of this XComArg.""" return self._operator @property - def roots(self) -> List[BaseOperator]: + def roots(self) -> List[DAGNode]: """Required by TaskMixin""" return [self._operator] @property - def leaves(self) -> List[BaseOperator]: + def leaves(self) -> List[DAGNode]: """Required by TaskMixin""" return [self._operator] @@ -133,13 +133,10 @@ def resolve(self, context: Context) -> Any: Pull XCom value for the existing arg. This method is run during ``op.execute()`` in respectable context. """ - resolved_value = self.operator.xcom_pull( - context=context, - task_ids=[self.operator.task_id], - key=str(self.key), # xcom_pull supports only key as str - dag_id=self.operator.dag.dag_id, - ) + resolved_value = context['ti'].xcom_pull(task_ids=[self.operator.task_id], key=str(self.key)) if not resolved_value: + if TYPE_CHECKING: + assert self.operator.dag raise AirflowException( f'XComArg result from {self.operator.task_id} at {self.operator.dag.dag_id} ' f'with key="{self.key}"" is not found!' 
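For reference, a condensed sketch of the TaskFlow interface the tests below exercise (the DAG name, function, and dates are illustrative only):

    import datetime

    from airflow.decorators import task
    from airflow.models.dag import DAG

    @task
    def product(number: int, multiple: int):
        return number * multiple

    with DAG("example", start_date=datetime.datetime(2021, 1, 1)):
        # partial() returns a new OperatorWrapper with ``multiple`` pre-bound;
        # map() then creates a MappedOperator that expands over ``number``.
        product.partial(multiple=2).map(number=[1, 2, 3])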
diff --git a/tests/decorators/test_python.py b/tests/decorators/test_python.py
index 610fa84db7b15..d0840cf4b6418 100644
--- a/tests/decorators/test_python.py
+++ b/tests/decorators/test_python.py
@@ -25,6 +25,7 @@
 from airflow.decorators import task as task_decorator
 from airflow.exceptions import AirflowException
 from airflow.models import DAG, DagRun, TaskInstance as TI
+from airflow.models.baseoperator import MappedOperator
 from airflow.models.xcom_arg import XComArg
 from airflow.utils import timezone
 from airflow.utils.session import create_session
@@ -109,7 +110,7 @@ def test_python_operator_python_callable_is_callable(self):
         """Tests that @task will only instantiate if
         the python_callable argument is callable."""
         not_callable = {}
-        with pytest.raises(AirflowException):
+        with pytest.raises(TypeError):
             task_decorator(not_callable, dag=self.dag)
 
     def test_infer_multiple_outputs_using_typing(self):
@@ -201,7 +202,7 @@ def add_number(num: int) -> int:
 
     def test_fail_method(self):
         """Tests that @task will fail if signature is not binding."""
-        with pytest.raises(AirflowException):
+        with pytest.raises(TypeError):
 
             class Test:
                 num = 2
@@ -210,8 +211,6 @@ class Test:
                 def add_number(self, num: int) -> int:
                     return self.num + num
 
-        Test().add_number(2)
-
     def test_fail_multiple_outputs_key_type(self):
         @task_decorator(multiple_outputs=True)
         def add_number(num: int):
@@ -498,3 +497,50 @@ def add_2(number: int):
 
     ret = add_2(test_number)
     assert ret.operator.doc_md.strip(), "Adds 2 to number."
+
+
+def test_mapped_decorator() -> None:
+    @task_decorator
+    def double(number: int):
+        return number * 2
+
+    with DAG('test_dag', start_date=DEFAULT_DATE):
+        literal = [1, 2, 3]
+        doubled_0 = double.map(number=literal)
+        doubled_1 = double.map(number=literal)
+
+    assert isinstance(doubled_0, XComArg)
+    assert isinstance(doubled_0.operator, MappedOperator)
+    assert doubled_0.operator.task_id == "double"
+    assert doubled_0.operator.mapped_kwargs == {"number": literal}
+
+    assert doubled_1.operator.task_id == "double__1"
+
+
+def test_partial_mapped_decorator() -> None:
+    @task_decorator
+    def product(number: int, multiple: int):
+        return number * multiple
+
+    with DAG('test_dag', start_date=DEFAULT_DATE) as dag:
+        literal = [1, 2, 3]
+        quadrupled = product.partial(task_id='times_4', multiple=4).map(number=literal)
+        doubled = product.partial(multiple=2).map(number=literal)
+        tripled = product.partial(multiple=3).map(number=literal)
+
+        product.partial(multiple=2)
+
+    assert isinstance(doubled, XComArg)
+    assert isinstance(doubled.operator, MappedOperator)
+    assert doubled.operator.task_id == "product"
+    assert doubled.operator.mapped_kwargs == {"number": literal}
+    assert doubled.operator.partial_kwargs == {"task_id": "product", "multiple": 2}
+
+    assert tripled.operator.task_id == "product__1"
+    assert tripled.operator.partial_kwargs == {"task_id": "product", "multiple": 3}
+
+    assert quadrupled.operator.task_id == "times_4"
+
+    assert doubled.operator is not tripled.operator
+
+    assert [quadrupled.operator, doubled.operator, tripled.operator] == dag.tasks

From 7256a373354f17f1c5bdf6b51ac30eae6961a258 Mon Sep 17 00:00:00 2001
From: Ash Berlin-Taylor
Date: Tue, 4 Jan 2022 14:24:08 +0000
Subject: [PATCH 04/17] fixup!
Add mapping and partial support to TaskFlow tasks --- airflow/decorators/base.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/airflow/decorators/base.py b/airflow/decorators/base.py index a294e4a55657d..c6f43c386b9f5 100644 --- a/airflow/decorators/base.py +++ b/airflow/decorators/base.py @@ -19,7 +19,20 @@ import inspect import re from inspect import signature -from typing import Any, Callable, Collection, Dict, Mapping, Optional, Sequence, Type, TypeVar, cast +from typing import ( + Any, + Callable, + Collection, + Dict, + Generic, + Mapping, + Optional, + Sequence, + Set, + Type, + TypeVar, + cast, +) import attr @@ -193,6 +206,7 @@ class OperatorWrapper(Generic[T, OperatorSubclass]): :meta private: """ + function: T = attr.ib(validator=attr.validators.is_callable()) operator_class: Type[BaseOperator] multiple_outputs: bool = attr.ib() From 80a0d921169513c0d40f07f3992c5707b2f934ed Mon Sep 17 00:00:00 2001 From: Ash Berlin-Taylor Date: Fri, 10 Dec 2021 17:27:02 +0000 Subject: [PATCH 05/17] Add mapping support to `@task_group` decorator --- airflow/decorators/base.py | 19 +++--- airflow/decorators/task_group.py | 112 ++++++++++++++++++++++--------- airflow/utils/task_group.py | 13 +++- tests/utils/test_task_group.py | 47 ++++++++++++- 4 files changed, 148 insertions(+), 43 deletions(-) diff --git a/airflow/decorators/base.py b/airflow/decorators/base.py index c6f43c386b9f5..6171e3d84a333 100644 --- a/airflow/decorators/base.py +++ b/airflow/decorators/base.py @@ -194,7 +194,7 @@ def _hook_apply_defaults(self, *args, **kwargs): T = TypeVar("T", bound=Callable) -OperatorSubclass = TypeVar("OperatorSubclass", bound=BaseOperator) +OperatorSubclass = TypeVar("OperatorSubclass", bound="BaseOperator") @attr.define @@ -208,10 +208,11 @@ class OperatorWrapper(Generic[T, OperatorSubclass]): """ function: T = attr.ib(validator=attr.validators.is_callable()) - operator_class: Type[BaseOperator] + operator_class: Type[OperatorSubclass] multiple_outputs: bool = attr.ib() kwargs: Dict[str, Any] = attr.ib(factory=dict) + decorator_name: str = attr.ib(repr=False, default="task") function_arg_names: Set[str] = attr.ib(repr=False) @function_arg_names.default @@ -220,11 +221,8 @@ def _get_arg_names(self): @function.validator def _validate_function(self, _, f): - if not callable(f): - # Not likely to be hit -- checked by the decorator first! 
- raise TypeError('`python_callable` param must be callable') if 'self' in self.function_arg_names: - raise TypeError('@task does not support methods') + raise TypeError(f'@{self.decorator_name} does not support methods') @multiple_outputs.default def _infer_multiple_outputs(self): @@ -279,8 +277,9 @@ def partial( def task_decorator_factory( python_callable: Optional[Callable] = None, + *, multiple_outputs: Optional[bool] = None, - decorated_operator_class: Optional[Type[BaseOperator]] = None, + decorated_operator_class: Type[BaseOperator], **kwargs, ) -> Callable[[T], T]: """ @@ -300,14 +299,14 @@ def task_decorator_factory( """ if multiple_outputs is None: - multiple_outputs = attr.NOTHING + multiple_outputs = cast(bool, attr.NOTHING) if python_callable: - return OperatorWrapper( + return OperatorWrapper( # type: ignore function=python_callable, multiple_outputs=multiple_outputs, operator_class=decorated_operator_class, kwargs=kwargs, - ) # type: ignore + ) elif python_callable is not None: raise TypeError('No args allowed while using @task, use kwargs instead') return cast( diff --git a/airflow/decorators/task_group.py b/airflow/decorators/task_group.py index 04ef1614c54c4..936b05de589a4 100644 --- a/airflow/decorators/task_group.py +++ b/airflow/decorators/task_group.py @@ -20,19 +20,97 @@ together when the DAG is displayed graphically. """ import functools +import warnings from inspect import signature -from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, TypeVar, overload +from typing import TYPE_CHECKING, Any, Callable, Dict, Generic, Optional, TypeVar, cast, overload -from airflow.utils.task_group import TaskGroup +import attr + +from airflow.utils.task_group import MappedTaskGroup, TaskGroup if TYPE_CHECKING: from airflow.models import DAG F = TypeVar("F", bound=Callable[..., Any]) +T = TypeVar("T", bound=Callable) +R = TypeVar("R") task_group_sig = signature(TaskGroup.__init__) +@attr.define +class TaskGroupDecorator(Generic[R]): + """:meta private:""" + + function: Callable[..., R] = attr.ib(validator=attr.validators.is_callable()) + kwargs: Dict[str, Any] = attr.ib(factory=dict) + """kwargs for the TaskGroup""" + + @function.validator + def _validate_function(self, _, f): + if 'self' in signature(f).parameters: + raise TypeError('@task_group does not support methods') + + @kwargs.validator + def _validate(self, _, kwargs): + task_group_sig.bind_partial(**kwargs) + + def __attrs_post_init__(self): + self.kwargs.setdefault('group_id', self.function.__name__) + + def _make_task_group(self, **kwargs) -> TaskGroup: + return TaskGroup(**kwargs) + + def __call__(self, *args, **kwargs) -> R: + with self._make_task_group(add_suffix_on_collision=True, **self.kwargs): + # Invoke function to run Tasks inside the TaskGroup + return self.function(*args, **kwargs) + + def partial(self, **kwargs) -> "MappedTaskGroupDecorator[R]": + return MappedTaskGroupDecorator(function=self.function, kwargs=self.kwargs).partial(**kwargs) + + def map(self, **kwargs) -> R: + return MappedTaskGroupDecorator(function=self.function, kwargs=self.kwargs).map(**kwargs) + + +@attr.define +class MappedTaskGroupDecorator(TaskGroupDecorator[R]): + """:meta private:""" + + partial_kwargs: Dict[str, Any] = attr.ib(factory=dict) + """static kwargs for the decorated function""" + mapped_kwargs: Dict[str, Any] = attr.ib(factory=dict) + """kwargs for the decorated function""" + + _invoked: bool = attr.ib(init=False, default=False, repr=False) + + def __call__(self, *args, **kwargs): + raise RuntimeError("Mapped 
@task_group's cannot be called. Use `.map` and `.partial` instead") + + def _make_task_group(self, **kwargs) -> MappedTaskGroup: + tg = MappedTaskGroup(**kwargs) + tg.partial_kwargs = self.partial_kwargs + tg.mapped_kwargs = self.mapped_kwargs + return tg + + def partial(self, **kwargs) -> "MappedTaskGroupDecorator[R]": + self.partial_kwargs.update(kwargs) + return self + + def map(self, **kwargs) -> R: + self.mapped_kwargs.update(kwargs) + + call_kwargs = self.partial_kwargs.copy() + call_kwargs.update({k: object() for k in self.mapped_kwargs}) + + self._invoked = True + return super().__call__(**call_kwargs) + + def __del__(self): + if not self._invoked: + warnings.warn(f"Partial task group {self.function.__name__} was never mapped!") + + # This covers the @task_group() case. Annotations are copied from the TaskGroup # class, only providing a default to 'group_id' (this is optional for the # decorator and defaults to the decorated function's name). Please keep them in @@ -54,7 +132,6 @@ def task_group( ) -> Callable[[F], F]: ... - # This covers the @task_group case (no parentheses). @overload def task_group(python_callable: F) -> F: @@ -73,31 +150,6 @@ def task_group(python_callable=None, *tg_args, **tg_kwargs): :param tg_args: Positional arguments for the TaskGroup object. :param tg_kwargs: Keyword arguments for the TaskGroup object. """ - - def wrapper(f): - # Setting group_id as function name if not given in kwarg group_id - if not tg_args and 'group_id' not in tg_kwargs: - tg_kwargs['group_id'] = f.__name__ - task_group_bound_args = task_group_sig.bind_partial(*tg_args, **tg_kwargs) - - @functools.wraps(f) - def factory(*args, **kwargs): - # Generate signature for decorated function and bind the arguments when called - # we do this to extract parameters so we can annotate them on the DAG object. - # In addition, this fails if we are missing any args/kwargs with TypeError as expected. - # Apply defaults to capture default values if set. 
- - # Initialize TaskGroup with bound arguments - with TaskGroup( - *task_group_bound_args.args, - add_suffix_on_collision=True, - **task_group_bound_args.kwargs, - ): - # Invoke function to run Tasks inside the TaskGroup - return f(*args, **kwargs) - - return factory - if callable(python_callable): - return wrapper(python_callable) - return wrapper + return TaskGroupDecorator(function=python_callable, kwargs=tg_kwargs) + return cast("Callable[[T], T]", functools.partial(TaskGroupDecorator, kwargs=tg_kwargs)) diff --git a/airflow/utils/task_group.py b/airflow/utils/task_group.py index f456e88fc619e..84702aea3bf08 100644 --- a/airflow/utils/task_group.py +++ b/airflow/utils/task_group.py @@ -22,11 +22,12 @@ import copy import re import weakref -from typing import TYPE_CHECKING, Dict, Generator, Iterable, List, Optional, Sequence, Set, Union +from typing import TYPE_CHECKING, Any, Dict, Generator, Iterable, List, Optional, Sequence, Set, Union from airflow.exceptions import AirflowException, DuplicateTaskIdFound from airflow.models.taskmixin import DAGNode, DependencyMixin from airflow.utils.helpers import validate_group_key +from airflow.utils.types import NOTSET if TYPE_CHECKING: from airflow.models.baseoperator import BaseOperator @@ -392,7 +393,11 @@ def map(self, arg: Iterable) -> "MappedTaskGroup": raise RuntimeError("Cannot map a TaskGroup before it has a group_id") if self._parent_group: self._parent_group._remove(self) - return MappedTaskGroup(self._group_id) + tg = MappedTaskGroup(self._group_id) + tg.mapped_arg = arg + tg.mapped_kwargs = {} + tg.partial_kwargs = {} + return tg class MappedTaskGroup(TaskGroup): @@ -402,6 +407,10 @@ class MappedTaskGroup(TaskGroup): Do not create instances of this class directly, instead use :meth:`TaskGroup.map` """ + mapped_arg: Any = NOTSET + mapped_kwargs: Dict[str, Any] + partial_kwargs: Dict[str, Any] + class TaskGroupContext: """TaskGroup context is used to keep the current TaskGroup when TaskGroup is used as ContextManager.""" diff --git a/tests/utils/test_task_group.py b/tests/utils/test_task_group.py index 841e0113e30e1..51e0319e49963 100644 --- a/tests/utils/test_task_group.py +++ b/tests/utils/test_task_group.py @@ -27,7 +27,7 @@ from airflow.operators.dummy import DummyOperator from airflow.operators.python import PythonOperator from airflow.utils.dates import days_ago -from airflow.utils.task_group import TaskGroup +from airflow.utils.task_group import MappedTaskGroup, TaskGroup from airflow.www.views import dag_edges, task_group_to_dict from tests.models import DEFAULT_DATE from tests.test_utils.mock_operators import MockOperator @@ -1018,7 +1018,9 @@ def test_map() -> None: start >> process_one >> end # check the mapped operators are attached to the task broup + assert isinstance(process_one, MappedTaskGroup) assert process_one.has_task(one) + assert process_one.mapped_arg is literal assert isinstance(one, MappedOperator) assert start.downstream_list == [one] @@ -1047,3 +1049,46 @@ def test_nested_map() -> None: one >> process_one_two >> four start >> process_one >> end + + +def test_decorator_unknown_args(): + """Test that unknown args passed to the decorator cause an error at parse time""" + with pytest.raises(TypeError): + + @task_group_decorator(b=2) + def tg(): + ... + + +def test_decorator_partial_unmapped(): + @task_group_decorator + def tg(): + ... 
+ + with pytest.warns(UserWarning, match='was never mapped'): + with DAG("test-dag", start_date=DEFAULT_DATE): + tg.partial() + + +def test_decorator_map(): + @task_group_decorator + def my_task_group(my_arg_1: str, unmapped: bool): + assert unmapped is True + assert isinstance(my_arg_1, object) + task_1 = DummyOperator(task_id="task_1") + task_2 = BashOperator(task_id="task_2", bash_command='echo "${my_arg_1}"', env={'my_arg_1': my_arg_1}) + task_3 = DummyOperator(task_id="task_3") + task_1 >> [task_2, task_3] + + return task_1, task_2, task_3 + + with DAG("test-dag", start_date=DEFAULT_DATE) as dag: + lines = ["foo", "bar", "baz"] + + (task_1, task_2, task_3) = my_task_group.partial(unmapped=True).map(my_arg_1=lines) + + assert task_1 in dag.tasks + + tg = dag.task_group.get_child_by_label("my_task_group") + assert isinstance(tg, MappedTaskGroup) + assert "my_arg_1" in tg.mapped_kwargs From 273f32cec1ba732329c92afc3db89a3ef66e29e1 Mon Sep 17 00:00:00 2001 From: Ash Berlin-Taylor Date: Fri, 17 Dec 2021 13:54:06 +0000 Subject: [PATCH 06/17] Improve validation of map and partial argument names --- airflow/decorators/base.py | 50 +++++++++++++++++--------- airflow/models/baseoperator.py | 60 ++++++++++++++++++------------- tests/decorators/test_python.py | 14 ++++++++ tests/models/test_baseoperator.py | 7 +++- 4 files changed, 89 insertions(+), 42 deletions(-) diff --git a/airflow/decorators/base.py b/airflow/decorators/base.py index 6171e3d84a333..99e9602b53b8a 100644 --- a/airflow/decorators/base.py +++ b/airflow/decorators/base.py @@ -17,8 +17,8 @@ import functools import inspect +import itertools import re -from inspect import signature from typing import ( Any, Callable, @@ -54,7 +54,7 @@ def validate_python_callable(python_callable): """ if not callable(python_callable): raise TypeError('`python_callable` param must be callable') - if 'self' in signature(python_callable).parameters.keys(): + if 'self' in inspect.signature(python_callable).parameters.keys(): raise AirflowException('@task does not support methods') @@ -142,7 +142,7 @@ def __init__( op_kwargs = op_kwargs or {} # Check that arguments can be binded - signature(python_callable).bind(*op_args, **op_kwargs) + inspect.signature(python_callable).bind(*op_args, **op_kwargs) self.multiple_outputs = multiple_outputs self.op_args = op_args self.op_kwargs = op_kwargs @@ -184,7 +184,7 @@ def _hook_apply_defaults(self, *args, **kwargs): python_callable = kwargs['python_callable'] default_args = kwargs.get('default_args') or {} op_kwargs = kwargs.get('op_kwargs') or {} - f_sig = signature(python_callable) + f_sig = inspect.signature(python_callable) for arg in f_sig.parameters: if arg not in op_kwargs and arg in default_args: op_kwargs[arg] = default_args[arg] @@ -217,7 +217,7 @@ class OperatorWrapper(Generic[T, OperatorSubclass]): @function_arg_names.default def _get_arg_names(self): - return set(signature(self.function).parameters) + return set(inspect.signature(self.function).parameters) @function.validator def _validate_function(self, _, f): @@ -226,7 +226,7 @@ def _validate_function(self, _, f): @multiple_outputs.default def _infer_multiple_outputs(self): - sig = signature(self.function).return_annotation + sig = inspect.signature(self.function).return_annotation ttype = getattr(sig, "__origin__", None) return sig is not inspect.Signature.empty and ttype in (dict, Dict) @@ -246,6 +246,21 @@ def __call__(self, *args, **kwargs) -> XComArg: op.doc_md = self.function.__doc__ return XComArg(op) + def _validate_arg_names(self, funcname: 
str, kwargs: Dict[str, Any], valid_names: Set[str] = set()):
+        unknown_args = kwargs.copy()
+        for name in itertools.chain(self.function_arg_names, valid_names):
+            unknown_args.pop(name, None)
+
+        if not unknown_args:
+            # If we have no args left to check, we are valid
+            return
+
+        if len(unknown_args) == 1:
+            raise TypeError(f'{funcname} got unexpected keyword argument {unknown_args.popitem()[0]!r}')
+        else:
+            names = ", ".join(repr(n) for n in unknown_args)
+            raise TypeError(f'{funcname} got unexpected keyword arguments {names}')
+
     def map(
         self, *args, dag: Optional["DAG"] = None, task_group: Optional["TaskGroup"] = None, **kwargs
     ) -> XComArg:
@@ -254,22 +269,25 @@ def map(
         task_group = task_group or TaskGroupContext.get_current_task_group(dag)
         task_id = get_unique_task_id(self.kwargs['task_id'], dag, task_group)
 
-        mapped_arg = XComArg(
-            operator=MappedOperator(
-                operator_class=self.operator_class,
-                task_id=task_id,
-                dag=dag,
-                task_group=task_group,
-                partial_kwargs=self.kwargs,
-                mapped_kwargs=kwargs,
-            ),
+        self._validate_arg_names("map", kwargs)
+
+        operator = MappedOperator(
+            operator_class=self.operator_class,
+            task_id=task_id,
+            dag=dag,
+            task_group=task_group,
+            partial_kwargs=self.kwargs,
+            # Set them to empty to bypass the validation, as for decorated stuff we validate ourselves
+            mapped_kwargs={},
         )
+        operator.mapped_kwargs.update(kwargs)
 
-        return mapped_arg
+        return XComArg(operator=operator)
 
     def partial(
         self, *args, dag: Optional["DAG"] = None, task_group: Optional["TaskGroup"] = None, **kwargs
     ) -> "OperatorWrapper[T, OperatorSubclass]":
+        self._validate_arg_names("partial", kwargs, {'task_id'})
         partial_kwargs = self.kwargs.copy()
         partial_kwargs.update(kwargs)
         return attr.evolve(self, kwargs=partial_kwargs)
diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py
index dd8eb0aec4e17..5759ec5a1457d 100644
--- a/airflow/models/baseoperator.py
+++ b/airflow/models/baseoperator.py
@@ -237,31 +237,9 @@ def __new__(cls, name, bases, namespace, **kwargs):
 
     # The class level partial function. This is what handles the actual mapping
    def partial(cls, *, task_id: str, dag: Optional["DAG"] = None, **kwargs):
-        unknown_args = set(kwargs.keys())
-        # Validate that the args we passed are known -- at call/DAG parse time, not run time!
-        #
-        # This loop _assumes_ that all unknown args from a class are passed to the superclass's __init__, but
-        # there is no way for us to validate that this is actually what operators do.
-        for clazz in cls.mro():
-            # Mypy doesn't like doing `clas.__init__`, Error is: Cannot access "__init__" directly
-            init = clazz.__init__  # type: ignore
-
-            if not hasattr(init, '_BaseOperatorMeta__param_names'):
-                continue
-            unknown_args.difference_update(init.__param_names)
-            if not unknown_args:
-                # If we have no args left ot check: stop looking at the MRO chian
-                break
-
-        if unknown_args:
-            if len(unknown_args) == 1:
-                raise TypeError(
-                    f'{cls.__name__}.partial got unexpected keyword argument {unknown_args.pop()!r}'
-                )
-            else:
-                names = ", ".join(repr(n) for n in unknown_args)
-                raise TypeError(f'{cls.__name__}.partial got unexpected keyword arguments {names}')
         operator_class = cast("Type[BaseOperator]", cls)
+        # Validate that the args we passed are known -- at call/DAG parse time, not run time!
+        _validate_kwarg_names_for_mapping(operator_class, "partial", kwargs)
         return MappedOperator(
             task_id=task_id, operator_class=operator_class, dag=dag, partial_kwargs=kwargs, mapped_kwargs={}
         )
@@ -1665,6 +1643,36 @@ def map(self, **kwargs) -> "MappedOperator":
         )
 
 
+def _validate_kwarg_names_for_mapping(cls: Type[BaseOperator], func_name: str, value: Dict[str, Any]):
+    if isinstance(cls, str):
+        # Serialized version -- would have been validated at parse time
+        return
+
+    # use a dict so order of args is same as code order
+    unknown_args = value.copy()
+    for clazz in cls.mro():
+        # Mypy doesn't like doing `clas.__init__`, Error is: Cannot access "__init__" directly
+        init = clazz.__init__  # type: ignore
+
+        if not hasattr(init, '_BaseOperatorMeta__param_names'):
+            continue
+
+        for name in init._BaseOperatorMeta__param_names:
+            unknown_args.pop(name, None)
+
+        if not unknown_args:
+            # If we have no args left to check: stop looking at the MRO chain
+            return
+
+    if len(unknown_args) == 1:
+        raise TypeError(
+            f'{cls.__name__}.{func_name} got unexpected keyword argument {unknown_args.popitem()[0]!r}'
+        )
+    else:
+        names = ", ".join(repr(n) for n in unknown_args)
+        raise TypeError(f'{cls.__name__}.{func_name} got unexpected keyword arguments {names}')
+
+
 @attr.define(kw_only=True)
 class MappedOperator(DAGNode):
     """Object representing a mapped operator in a DAG"""
@@ -1672,7 +1680,9 @@ class MappedOperator(DAGNode):
     operator_class: Type[BaseOperator] = attr.ib(repr=lambda c: c.__name__)
     task_id: str
     partial_kwargs: Dict[str, Any]
-    mapped_kwargs: Dict[str, Any]
+    mapped_kwargs: Dict[str, Any] = attr.ib(
+        validator=lambda self, _, v: _validate_kwarg_names_for_mapping(self.operator_class, "map", v)
+    )
     operator: Optional[BaseOperator] = None
     dag: Optional["DAG"] = None
     upstream_task_ids: Set[str] = attr.ib(factory=set, repr=False)
diff --git a/tests/decorators/test_python.py b/tests/decorators/test_python.py
index d0840cf4b6418..48f6c144ed24a 100644
--- a/tests/decorators/test_python.py
+++ b/tests/decorators/test_python.py
@@ -517,6 +517,20 @@ def double(number: int):
     assert doubled_1.operator.task_id == "double__1"
 
 
+def test_mapped_decorator_invalid_args() -> None:
+    @task_decorator
+    def double(number: int):
+        return number * 2
+
+    with DAG('test_dag', start_date=DEFAULT_DATE):
+        literal = [1, 2, 3]
+
+        with pytest.raises(TypeError, match="arguments 'other', 'b'"):
+            double.partial(other=1, b='a')
+        with pytest.raises(TypeError, match="argument 'other'"):
+            double.map(number=literal, other=1)
+
+
 def test_partial_mapped_decorator() -> None:
     @task_decorator
     def product(number: int, multiple: int):
diff --git a/tests/models/test_baseoperator.py b/tests/models/test_baseoperator.py
index d0373806df6cf..fc6815ce77114 100644
--- a/tests/models/test_baseoperator.py
+++ b/tests/models/test_baseoperator.py
@@ -696,6 +696,11 @@ def test_task_mapping_default_args():
     assert mapped.start_date == pendulum.instance(default_args['start_date'])
 
 
+def test_map_unknown_arg_raises():
+    with pytest.raises(TypeError, match=r"argument 'file'"):
+        BaseOperator(task_id='a').map(file=[1, 2, {'a': 'b'}])
+
+
 def test_partial_on_instance() -> None:
     """`.partial` on an instance should fail -- it's only designed to be called on classes"""
     with pytest.raises(TypeError):
@@ -715,5 +720,5 @@ def test_partial_on_class_invalid_ctor_args() -> None:
     I.e. if an arg is not known on the class or any of its parent classes we error at parse time
     """
-    with pytest.raises(TypeError):
+    with pytest.raises(TypeError, match=r"arguments 'foo', 'bar'"):
         MockOperator.partial(task_id='a', foo='bar', bar=2)
 

From 25a7fdd744224a7ad8bfb5ec5edc0cae99bed95a Mon Sep 17 00:00:00 2001
From: Ash Berlin-Taylor
Date: Fri, 17 Dec 2021 13:57:09 +0000
Subject: [PATCH 07/17] Remove args from map and partial decorators

---
 airflow/decorators/base.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/airflow/decorators/base.py b/airflow/decorators/base.py
index 99e9602b53b8a..1956da12aee90 100644
--- a/airflow/decorators/base.py
+++ b/airflow/decorators/base.py
@@ -262,7 +262,7 @@ def _validate_arg_names(self, funcname: str, kwargs: Dict[str, Any], valid_names
             raise TypeError(f'{funcname} got unexpected keyword arguments {names}')
 
     def map(
-        self, *args, dag: Optional["DAG"] = None, task_group: Optional["TaskGroup"] = None, **kwargs
+        self, *, dag: Optional["DAG"] = None, task_group: Optional["TaskGroup"] = None, **kwargs
     ) -> XComArg:
 
         dag = dag or DagContext.get_current_dag()
@@ -285,7 +285,7 @@ def map(
 
     def partial(
-        self, *args, dag: Optional["DAG"] = None, task_group: Optional["TaskGroup"] = None, **kwargs
+        self, *, dag: Optional["DAG"] = None, task_group: Optional["TaskGroup"] = None, **kwargs
     ) -> "OperatorWrapper[T, OperatorSubclass]":
         self._validate_arg_names("partial", kwargs, {'task_id'})
         partial_kwargs = self.kwargs.copy()
         partial_kwargs.update(kwargs)
         return attr.evolve(self, kwargs=partial_kwargs)

From ad8f047fdd26f015103d0b9fb62c2c6d5635cad7 Mon Sep 17 00:00:00 2001
From: Ash Berlin-Taylor
Date: Fri, 17 Dec 2021 14:17:52 +0000
Subject: [PATCH 08/17] Update airflow/models/dag.py

Co-authored-by: Kaxil Naik
---
 airflow/models/dag.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index db06440a0bc9a..a750b58fc4a79 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -2198,7 +2198,7 @@ def add_tasks(self, tasks):
             self.add_task(task)
 
     def _remove_task(self, task_id: str) -> None:
-        # This is "private" as removing could leave a whole in dependencies if done incorrectly, and this
+        # This is "private" as removing could leave a hole in dependencies if done incorrectly, and this
         # doesn't guard against that
         task = self.task_dict.pop(task_id)
         tg = getattr(task, 'task_group', None)

From 27b12cdfa6eea44ff2921f328755a672ffa0876d Mon Sep 17 00:00:00 2001
From: Ash Berlin-Taylor
Date: Fri, 17 Dec 2021 16:09:04 +0000
Subject: [PATCH 09/17] Address some code review comments

---
 airflow/decorators/base.py        | 30 +++++++++++++-----------
 airflow/models/baseoperator.py    | 39 +++++++++++++++++--------------
 airflow/models/dag.py             |  8 +++----
 tests/models/test_baseoperator.py |  2 +-
 4 files changed, 43 insertions(+), 36 deletions(-)

diff --git a/airflow/decorators/base.py b/airflow/decorators/base.py
index 1956da12aee90..f098bb161e053 100644
--- a/airflow/decorators/base.py
+++ b/airflow/decorators/base.py
@@ -36,6 +36,7 @@
 
 import attr
 
+from airflow.compat.functools import cached_property
 from airflow.exceptions import AirflowException
 from airflow.models.baseoperator import BaseOperator, MappedOperator
 from airflow.models.dag import DAG, DagContext
@@ -197,8 +198,8 @@ def _hook_apply_defaults(self, *args, **kwargs):
 OperatorSubclass = TypeVar("OperatorSubclass", bound="BaseOperator")
 
 
-@attr.define
-class OperatorWrapper(Generic[T, OperatorSubclass]):
+@attr.define(slots=False)
+class _TaskDecorator(Generic[T, OperatorSubclass]):
""" Helper class for providing dynamic task mapping to decorated functions. @@ -213,11 +214,14 @@ class OperatorWrapper(Generic[T, OperatorSubclass]): kwargs: Dict[str, Any] = attr.ib(factory=dict) decorator_name: str = attr.ib(repr=False, default="task") - function_arg_names: Set[str] = attr.ib(repr=False) - @function_arg_names.default - def _get_arg_names(self): - return set(inspect.signature(self.function).parameters) + @cached_property + def function_signature(self): + return inspect.signature(self.function) + + @cached_property + def function_arg_names(self) -> Set[str]: + return set(self.function_signature.parameters) @function.validator def _validate_function(self, _, f): @@ -226,10 +230,10 @@ def _validate_function(self, _, f): @multiple_outputs.default def _infer_multiple_outputs(self): - sig = inspect.signature(self.function).return_annotation - ttype = getattr(sig, "__origin__", None) + return_type = self.function_signature.return_annotation + ttype = getattr(return_type, "__origin__", None) - return sig is not inspect.Signature.empty and ttype in (dict, Dict) + return return_type is not inspect.Signature.empty and ttype in (dict, Dict) def __attrs_post_init__(self): self.kwargs.setdefault('task_id', self.function.__name__) @@ -256,7 +260,7 @@ def _validate_arg_names(self, funcname: str, kwargs: Dict[str, Any], valid_names return if len(unknown_args) == 1: - raise TypeError(f'{funcname} got unexpected keyword argument {unknown_args.popitem()[0]!r}') + raise TypeError(f'{funcname} got unexpected keyword argument {next(iter(unknown_args))!r}') else: names = ", ".join(repr(n) for n in unknown_args) raise TypeError(f'{funcname} got unexpected keyword arguments {names}') @@ -286,7 +290,7 @@ def map( def partial( self, *, dag: Optional["DAG"] = None, task_group: Optional["TaskGroup"] = None, **kwargs - ) -> "OperatorWrapper[T, OperatorSubclass]": + ) -> "_TaskDecorator[T, OperatorSubclass]": self._validate_arg_names("partial", kwargs, {'task_id'}) partial_kwargs = self.kwargs.copy() partial_kwargs.update(kwargs) @@ -319,7 +323,7 @@ def task_decorator_factory( if multiple_outputs is None: multiple_outputs = cast(bool, attr.NOTHING) if python_callable: - return OperatorWrapper( # type: ignore + return _TaskDecorator( # type: ignore function=python_callable, multiple_outputs=multiple_outputs, operator_class=decorated_operator_class, @@ -330,7 +334,7 @@ def task_decorator_factory( return cast( "Callable[[T], T]", functools.partial( - OperatorWrapper, + _TaskDecorator, multiple_outputs=multiple_outputs, operator_class=decorated_operator_class, kwargs=kwargs, diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index 5759ec5a1457d..03ef46e5157e2 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -1630,17 +1630,7 @@ def defer( raise TaskDeferred(trigger=trigger, method_name=method_name, kwargs=kwargs, timeout=timeout) def map(self, **kwargs) -> "MappedOperator": - return MappedOperator( - operator_class=type(self), - operator=self, - task_id=self.task_id, - task_group=getattr(self, 'task_group', None), - dag=getattr(self, '_dag', None), - start_date=self.start_date, - end_date=self.end_date, - partial_kwargs=self.__init_kwargs, - mapped_kwargs=kwargs, - ) + return MappedOperator.from_operator(self, kwargs) def _validate_kwarg_names_for_mapping(cls: Type[BaseOperator], func_name: str, value: Dict[str, Any]): @@ -1651,7 +1641,7 @@ def _validate_kwarg_names_for_mapping(cls: Type[BaseOperator], func_name: str, v # use a dict so order of 
args is same as code order unknown_args = value.copy() for clazz in cls.mro(): - # Mypy doesn't like doing `clas.__init__`, Error is: Cannot access "__init__" directly + # Mypy doesn't like doing `class.__init__`, Error is: Cannot access "__init__" directly init = clazz.__init__ # type: ignore if not hasattr(init, '_BaseOperatorMeta__param_names'): @@ -1683,7 +1673,6 @@ class MappedOperator(DAGNode): mapped_kwargs: Dict[str, Any] = attr.ib( validator=lambda self, _, v: _validate_kwarg_names_for_mapping(self.operator_class, "map", v) ) - operator: Optional[BaseOperator] = None dag: Optional["DAG"] = None upstream_task_ids: Set[str] = attr.ib(factory=set, repr=False) downstream_task_ids: Set[str] = attr.ib(factory=set, repr=False) @@ -1693,12 +1682,26 @@ class MappedOperator(DAGNode): start_date: Optional[pendulum.DateTime] = attr.ib(repr=False, default=None) end_date: Optional[pendulum.DateTime] = attr.ib(repr=False, default=None) - def __attrs_post_init__(self): - if self.dag and self.operator: - # When BaseOperator() was called with a DAG, it would have been added straight away, but now we - # are mapped, we want to _remove_ that task (`self.operator`) from the dag - self.dag._remove_task(self.task_id) + @classmethod + def from_operator(cls, operator: BaseOperator, mapped_kwargs: Dict[str, Any]) -> "MappedOperator": + dag: Optional["DAG"] = getattr(operator, '_dag', None) + if dag: + # When BaseOperator() was called within a DAG, it would have been added straight away, but now we + # are mapped, we want to _remove_ that task from the dag + dag._remove_task(operator.task_id) + return MappedOperator( + operator_class=type(operator), + task_id=operator.task_id, + task_group=getattr(operator, 'task_group', None), + dag=getattr(operator, '_dag', None), + start_date=operator.start_date, + end_date=operator.end_date, + partial_kwargs=operator._BaseOperator__init_kwargs, # type: ignore + mapped_kwargs=mapped_kwargs, + ) + + def __attrs_post_init__(self): if self.task_group: self.task_id = self.task_group.child_id(self.task_id) self.task_group.add(self) diff --git a/airflow/models/dag.py b/airflow/models/dag.py index a750b58fc4a79..136ed04ac8d38 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -2075,10 +2075,10 @@ def filter_task_group(group, parent_group): # the cut. subdag_task_groups = dag.task_group.get_task_group_dict() for group in subdag_task_groups.values(): - group.upstream_group_ids.intersection_update(subdag_task_groups.keys()) - group.downstream_group_ids.intersection_update(subdag_task_groups.keys()) - group.upstream_task_ids.intersection_update(dag.task_dict.keys()) - group.downstream_task_ids.intersection_update(dag.task_dict.keys()) + group.upstream_group_ids.intersection_update(subdag_task_groups) + group.downstream_group_ids.intersection_update(subdag_task_groups) + group.upstream_task_ids.intersection_update(dag.task_dict) + group.downstream_task_ids.intersection_update(dag.task_dict) for t in dag.tasks: # Removing upstream/downstream references to tasks that did not diff --git a/tests/models/test_baseoperator.py b/tests/models/test_baseoperator.py index fc6815ce77114..c0a4593646100 100644 --- a/tests/models/test_baseoperator.py +++ b/tests/models/test_baseoperator.py @@ -676,7 +676,7 @@ def test_task_mapping_without_dag_context(): task1 >> mapped assert isinstance(mapped, MappedOperator) - assert mapped.operator + assert mapped in dag.tasks assert task1.downstream_list == [mapped] assert mapped in dag.tasks # At parse time there should only be two tasks! 
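The `from_operator` refactor above also pins down what `.map()` on an operator *instance* means: the instance registers itself on the DAG when constructed, and mapping then removes it again in favour of the `MappedOperator`. A sketch of the resulting parse-time behaviour (the operator subclass and its `my_arg` parameter are illustrative, not from the diff):

    import datetime

    from airflow.models import DAG
    from airflow.models.baseoperator import BaseOperator

    class MyOperator(BaseOperator):
        def __init__(self, my_arg=None, **kwargs):
            super().__init__(**kwargs)
            self.my_arg = my_arg

    with DAG("map_sketch", start_date=datetime.datetime(2022, 1, 1)) as dag:
        op = MyOperator(task_id="op")      # added to dag.tasks immediately
        mapped = op.map(my_arg=[1, 2, 3])  # from_operator() calls dag._remove_task("op")
        assert mapped in dag.tasks and op not in dag.tasks

`my_arg` passes `_validate_kwarg_names_for_mapping` because it is a named parameter of `MyOperator.__init__`; an unknown name would raise `TypeError` at parse time, as the tests above assert.
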
From e2fe0c3f28cebdef78c86dde980716d664bbc7a0 Mon Sep 17 00:00:00 2001 From: Ash Berlin-Taylor Date: Fri, 17 Dec 2021 16:51:32 +0000 Subject: [PATCH 10/17] Code review --- airflow/models/baseoperator.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index 03ef46e5157e2..9bd8f47ed9e6c 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -99,7 +99,7 @@ def __get__( self, obj: "BaseOperator", cls: "Optional[Type[BaseOperator]]" = None ) -> Callable[..., "MappedOperator"]: # Call this "partial" so it looks nicer in stack traces - def partial(*, task_id: str, **kwargs): + def partial(**kwargs): raise TypeError("partial can only be called on Operator classes, not Tasks themselves") if obj is not None: @@ -217,7 +217,7 @@ def apply_defaults(self: "BaseOperator", *args: Any, **kwargs: Any) -> Any: return result apply_defaults.__non_optional_args = non_optional_args # type: ignore - apply_defaults.__param_names = set(non_varaidc_params.keys()) # type: ignore + apply_defaults.__param_names = set(non_varaidc_params) # type: ignore return cast(T, apply_defaults) From b246c7b6af0b8b76e136331e68216c916f0cdbcf Mon Sep 17 00:00:00 2001 From: Ash Berlin-Taylor Date: Fri, 17 Dec 2021 17:41:57 +0000 Subject: [PATCH 11/17] Move initializers to MappedTaskGroup constructor --- airflow/utils/task_group.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/airflow/utils/task_group.py b/airflow/utils/task_group.py index 84702aea3bf08..bf91c38408307 100644 --- a/airflow/utils/task_group.py +++ b/airflow/utils/task_group.py @@ -393,11 +393,7 @@ def map(self, arg: Iterable) -> "MappedTaskGroup": raise RuntimeError("Cannot map a TaskGroup before it has a group_id") if self._parent_group: self._parent_group._remove(self) - tg = MappedTaskGroup(self._group_id) - tg.mapped_arg = arg - tg.mapped_kwargs = {} - tg.partial_kwargs = {} - return tg + return MappedTaskGroup(group_id=self._group_id, mapped_arg=arg) class MappedTaskGroup(TaskGroup): @@ -411,6 +407,13 @@ class MappedTaskGroup(TaskGroup): mapped_kwargs: Dict[str, Any] partial_kwargs: Dict[str, Any] + def __init__(self, group_id: Optional[str] = None, mapped_arg: Any = NOTSET, **kwargs): + if mapped_arg is not NOTSET: + self.mapped_arg = mapped_arg + self.mapped_kwargs = {} + self.partial_kwargs = {} + super().__init__(group_id=group_id, **kwargs) + class TaskGroupContext: """TaskGroup context is used to keep the current TaskGroup when TaskGroup is used as ContextManager.""" From 2362c78d762b9510db745343d73cf57242c8e956 Mon Sep 17 00:00:00 2001 From: Ash Berlin-Taylor Date: Fri, 17 Dec 2021 20:11:10 +0000 Subject: [PATCH 12/17] Make DAGNode a proper Abstract Base Class --- airflow/models/taskmixin.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/airflow/models/taskmixin.py b/airflow/models/taskmixin.py index d608942642e7c..f352377c464e8 100644 --- a/airflow/models/taskmixin.py +++ b/airflow/models/taskmixin.py @@ -16,7 +16,7 @@ # under the License. 
import warnings -from abc import abstractmethod +from abc import ABCMeta, abstractmethod from typing import TYPE_CHECKING, Iterable, List, Optional, Sequence, Set, Union import pendulum @@ -101,7 +101,7 @@ def __init_subclass__(cls) -> None: return super().__init_subclass__() -class DAGNode(DependencyMixin): +class DAGNode(DependencyMixin, metaclass=ABCMeta): """ A base class for a node in the graph of a workflow -- an Operator or a Task Group, either mapped or unmapped. From 6c2182053f97ae441c48b519c1b716a97d0f6b6a Mon Sep 17 00:00:00 2001 From: Ash Berlin-Taylor Date: Fri, 17 Dec 2021 20:12:46 +0000 Subject: [PATCH 13/17] Apply suggestions from code review Co-authored-by: Jarek Potiuk --- airflow/models/baseoperator.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index 9bd8f47ed9e6c..db2d5b062478a 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -126,13 +126,13 @@ def _apply_defaults(cls, func: T) -> T: # per decoration, i.e. each function decorated using apply_defaults will # have a different sig_cache. sig_cache = signature(func) - non_varaidc_params = { + non_variadic_params = { name: param for (name, param) in sig_cache.parameters.items() if param.name != 'self' and param.kind not in (param.VAR_POSITIONAL, param.VAR_KEYWORD) } non_optional_args = { - name for (name, param) in non_varaidc_params.items() if param.default == param.empty + name for (name, param) in non_variadic_params.items() if param.default == param.empty } class autostacklevel_warn: @@ -217,7 +217,7 @@ def apply_defaults(self: "BaseOperator", *args: Any, **kwargs: Any) -> Any: return result apply_defaults.__non_optional_args = non_optional_args # type: ignore - apply_defaults.__param_names = set(non_varaidc_params) # type: ignore + apply_defaults.__param_names = set(non_variadic_params) # type: ignore return cast(T, apply_defaults) From efbc0b27e8a216242c3f200da82f5d6cb08ee7a3 Mon Sep 17 00:00:00 2001 From: Tzu-ping Chung Date: Tue, 4 Jan 2022 19:33:56 +0800 Subject: [PATCH 14/17] Prevent mapping an already mapped Task/TaskGroup Also prevent calls like .partial(...).partial(...). It is uncertain whether these kinds of repeated partial/map calls have utility, so let's disable them entirely for now to simplify implementation. We can always add them if they are proven useful. 
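
For example, with the guards below in place (the decorated names here are
illustrative, not part of this change):

    @task_group
    def tg(my_arg):
        ...

    pg = tg.partial(my_arg=1)
    pg.partial(my_arg=2)    # RuntimeError: Already a partial task group

    mapped = BaseOperator(task_id="t").map(retries=[1, 2])
    mapped.map(retries=[3, 4])  # RuntimeError: Already a mapped task
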
--- airflow/decorators/base.py | 6 +++--- airflow/decorators/task_group.py | 18 +++++++++++------- airflow/models/baseoperator.py | 6 +++--- 3 files changed, 17 insertions(+), 13 deletions(-) diff --git a/airflow/decorators/base.py b/airflow/decorators/base.py index f098bb161e053..07798bf0a2c14 100644 --- a/airflow/decorators/base.py +++ b/airflow/decorators/base.py @@ -291,10 +291,10 @@ def map( def partial( self, *, dag: Optional["DAG"] = None, task_group: Optional["TaskGroup"] = None, **kwargs ) -> "_TaskDecorator[T, OperatorSubclass]": + if self.kwargs: + raise RuntimeError("Already a partial task") self._validate_arg_names("partial", kwargs, {'task_id'}) - partial_kwargs = self.kwargs.copy() - partial_kwargs.update(kwargs) - return attr.evolve(self, kwargs=partial_kwargs) + return attr.evolve(self, kwargs=kwargs) def task_decorator_factory( diff --git a/airflow/decorators/task_group.py b/airflow/decorators/task_group.py index 936b05de589a4..6036a46d73925 100644 --- a/airflow/decorators/task_group.py +++ b/airflow/decorators/task_group.py @@ -82,10 +82,8 @@ class MappedTaskGroupDecorator(TaskGroupDecorator[R]): mapped_kwargs: Dict[str, Any] = attr.ib(factory=dict) """kwargs for the decorated function""" - _invoked: bool = attr.ib(init=False, default=False, repr=False) - def __call__(self, *args, **kwargs): - raise RuntimeError("Mapped @task_group's cannot be called. Use `.map` and `.partial` instead") + raise RuntimeError("A mapped @task_group cannot be called. Use `.map` and `.partial` instead") def _make_task_group(self, **kwargs) -> MappedTaskGroup: tg = MappedTaskGroup(**kwargs) @@ -94,20 +92,26 @@ def _make_task_group(self, **kwargs) -> MappedTaskGroup: return tg def partial(self, **kwargs) -> "MappedTaskGroupDecorator[R]": + if self.partial_kwargs: + raise RuntimeError("Already a partial task group") self.partial_kwargs.update(kwargs) return self def map(self, **kwargs) -> R: - self.mapped_kwargs.update(kwargs) + if self.mapped_kwargs: + raise RuntimeError("Already a mapped task group") + self.mapped_kwargs = kwargs call_kwargs = self.partial_kwargs.copy() - call_kwargs.update({k: object() for k in self.mapped_kwargs}) + duplicated_keys = set(call_kwargs).intersection(kwargs) + if duplicated_keys: + raise RuntimeError(f"Cannot map partial arguments: {', '.join(sorted(duplicated_keys))}") + call_kwargs.update({k: object() for k in kwargs}) - self._invoked = True return super().__call__(**call_kwargs) def __del__(self): - if not self._invoked: + if not self.mapped_kwargs: warnings.warn(f"Partial task group {self.function.__name__} was never mapped!") diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index db2d5b062478a..e1eee43317f7d 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -1724,9 +1724,9 @@ def map(self, **kwargs) -> "MappedOperator": :return: ``self`` for easier method chaining """ - mapped_kwargs = self.mapped_kwargs.copy() - mapped_kwargs.update(kwargs) - return attr.evolve(self, mapped_kwargs=mapped_kwargs) + if self.mapped_kwargs: + raise RuntimeError("Already a mapped task") + return attr.evolve(self, mapped_kwargs=kwargs) @property def roots(self) -> List["MappedOperator"]: From deb54df1c354d1223e096b104aa904683be3ad67 Mon Sep 17 00:00:00 2001 From: Ash Berlin-Taylor Date: Tue, 4 Jan 2022 14:31:23 +0000 Subject: [PATCH 15/17] fixup! 
Prevent mapping an already mapped Task/TaskGroup --- airflow/decorators/base.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/airflow/decorators/base.py b/airflow/decorators/base.py index 07798bf0a2c14..f098bb161e053 100644 --- a/airflow/decorators/base.py +++ b/airflow/decorators/base.py @@ -291,10 +291,10 @@ def map( def partial( self, *, dag: Optional["DAG"] = None, task_group: Optional["TaskGroup"] = None, **kwargs ) -> "_TaskDecorator[T, OperatorSubclass]": - if self.kwargs: - raise RuntimeError("Already a partial task") self._validate_arg_names("partial", kwargs, {'task_id'}) - return attr.evolve(self, kwargs=kwargs) + partial_kwargs = self.kwargs.copy() + partial_kwargs.update(kwargs) + return attr.evolve(self, kwargs=partial_kwargs) def task_decorator_factory( From 16df4abeda17b4262f9bb44f25c7fad0a3edf324 Mon Sep 17 00:00:00 2001 From: Ash Berlin-Taylor Date: Tue, 4 Jan 2022 14:52:25 +0000 Subject: [PATCH 16/17] fixup! Add mapping and partial support to TaskFlow tasks --- airflow/models/xcom_arg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/airflow/models/xcom_arg.py b/airflow/models/xcom_arg.py index 3f3d5d6ef2f1d..5fe798b9ee3a1 100644 --- a/airflow/models/xcom_arg.py +++ b/airflow/models/xcom_arg.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import Any, List, Optional, Sequence, Union +from typing import TYPE_CHECKING, Any, List, Optional, Sequence, Union from airflow.exceptions import AirflowException from airflow.models.baseoperator import BaseOperator, MappedOperator From 364843ffe338e0d10814bd0a30e07e9cdc427456 Mon Sep 17 00:00:00 2001 From: Ash Berlin-Taylor Date: Tue, 4 Jan 2022 14:53:49 +0000 Subject: [PATCH 17/17] fixup! Add mapping support to `@task_group` decorator --- airflow/decorators/task_group.py | 1 + 1 file changed, 1 insertion(+) diff --git a/airflow/decorators/task_group.py b/airflow/decorators/task_group.py index 6036a46d73925..e93002384cca9 100644 --- a/airflow/decorators/task_group.py +++ b/airflow/decorators/task_group.py @@ -136,6 +136,7 @@ def task_group( ) -> Callable[[F], F]: ... + # This covers the @task_group case (no parentheses). @overload def task_group(python_callable: F) -> F: