diff --git a/airflow/decorators/base.py b/airflow/decorators/base.py index 1a7e717d7bda2..f098bb161e053 100644 --- a/airflow/decorators/base.py +++ b/airflow/decorators/base.py @@ -17,12 +17,28 @@ import functools import inspect +import itertools import re -from inspect import signature -from typing import Any, Callable, Collection, Dict, Mapping, Optional, Sequence, Type, TypeVar, cast - +from typing import ( + Any, + Callable, + Collection, + Dict, + Generic, + Mapping, + Optional, + Sequence, + Set, + Type, + TypeVar, + cast, +) + +import attr + +from airflow.compat.functools import cached_property from airflow.exceptions import AirflowException -from airflow.models import BaseOperator +from airflow.models.baseoperator import BaseOperator, MappedOperator from airflow.models.dag import DAG, DagContext from airflow.models.xcom_arg import XComArg from airflow.utils.context import Context @@ -39,7 +55,7 @@ def validate_python_callable(python_callable): """ if not callable(python_callable): raise TypeError('`python_callable` param must be callable') - if 'self' in signature(python_callable).parameters.keys(): + if 'self' in inspect.signature(python_callable).parameters.keys(): raise AirflowException('@task does not support methods') @@ -127,7 +143,7 @@ def __init__( op_kwargs = op_kwargs or {} # Check that arguments can be binded - signature(python_callable).bind(*op_args, **op_kwargs) + inspect.signature(python_callable).bind(*op_args, **op_kwargs) self.multiple_outputs = multiple_outputs self.op_args = op_args self.op_kwargs = op_kwargs @@ -169,7 +185,7 @@ def _hook_apply_defaults(self, *args, **kwargs): python_callable = kwargs['python_callable'] default_args = kwargs.get('default_args') or {} op_kwargs = kwargs.get('op_kwargs') or {} - f_sig = signature(python_callable) + f_sig = inspect.signature(python_callable) for arg in f_sig.parameters: if arg not in op_kwargs and arg in default_args: op_kwargs[arg] = default_args[arg] @@ -179,11 +195,113 @@ def _hook_apply_defaults(self, *args, **kwargs): T = TypeVar("T", bound=Callable) +OperatorSubclass = TypeVar("OperatorSubclass", bound="BaseOperator") + + +@attr.define(slots=False) +class _TaskDecorator(Generic[T, OperatorSubclass]): + """ + Helper class for providing dynamic task mapping to decorated functions. + + ``task_decorator_factory`` returns an instance of this, instead of just a plain wrapped function. 
+
+    :meta private:
+    """
+
+    function: T = attr.ib(validator=attr.validators.is_callable())
+    operator_class: Type[OperatorSubclass]
+    multiple_outputs: bool = attr.ib()
+    kwargs: Dict[str, Any] = attr.ib(factory=dict)
+
+    decorator_name: str = attr.ib(repr=False, default="task")
+
+    @cached_property
+    def function_signature(self):
+        return inspect.signature(self.function)
+
+    @cached_property
+    def function_arg_names(self) -> Set[str]:
+        return set(self.function_signature.parameters)
+
+    @function.validator
+    def _validate_function(self, _, f):
+        if 'self' in self.function_arg_names:
+            raise TypeError(f'@{self.decorator_name} does not support methods')
+
+    @multiple_outputs.default
+    def _infer_multiple_outputs(self):
+        return_type = self.function_signature.return_annotation
+        ttype = getattr(return_type, "__origin__", None)
+
+        return return_type is not inspect.Signature.empty and ttype in (dict, Dict)
+
+    def __attrs_post_init__(self):
+        self.kwargs.setdefault('task_id', self.function.__name__)
+
+    def __call__(self, *args, **kwargs) -> XComArg:
+        op = self.operator_class(
+            python_callable=self.function,
+            op_args=args,
+            op_kwargs=kwargs,
+            multiple_outputs=self.multiple_outputs,
+            **self.kwargs,
+        )
+        if self.function.__doc__:
+            op.doc_md = self.function.__doc__
+        return XComArg(op)
+
+    def _validate_arg_names(self, funcname: str, kwargs: Dict[str, Any], valid_names: Set[str] = set()):
+        unknown_args = kwargs.copy()
+        for name in itertools.chain(self.function_arg_names, valid_names):
+            unknown_args.pop(name, None)
+
+            if not unknown_args:
+                # If we have no args left to check, we are valid
+                return
+
+        if len(unknown_args) == 1:
+            raise TypeError(f'{funcname} got unexpected keyword argument {next(iter(unknown_args))!r}')
+        else:
+            names = ", ".join(repr(n) for n in unknown_args)
+            raise TypeError(f'{funcname} got unexpected keyword arguments {names}')
+
+    def map(
+        self, *, dag: Optional["DAG"] = None, task_group: Optional["TaskGroup"] = None, **kwargs
+    ) -> XComArg:
+
+        dag = dag or DagContext.get_current_dag()
+        task_group = task_group or TaskGroupContext.get_current_task_group(dag)
+        task_id = get_unique_task_id(self.kwargs['task_id'], dag, task_group)
+
+        self._validate_arg_names("map", kwargs)
+
+        operator = MappedOperator(
+            operator_class=self.operator_class,
+            task_id=task_id,
+            dag=dag,
+            task_group=task_group,
+            partial_kwargs=self.kwargs,
+            # Set them to empty to bypass the validation, as for decorated functions we validate ourselves
+            mapped_kwargs={},
+        )
+        operator.mapped_kwargs.update(kwargs)
+
+        return XComArg(operator=operator)
+
+    def partial(
+        self, *, dag: Optional["DAG"] = None, task_group: Optional["TaskGroup"] = None, **kwargs
+    ) -> "_TaskDecorator[T, OperatorSubclass]":
+        self._validate_arg_names("partial", kwargs, {'task_id'})
+        partial_kwargs = self.kwargs.copy()
+        partial_kwargs.update(kwargs)
+        return attr.evolve(self, kwargs=partial_kwargs)
+
 
 def task_decorator_factory(
     python_callable: Optional[Callable] = None,
+    *,
     multiple_outputs: Optional[bool] = None,
-    decorated_operator_class: Optional[Type[BaseOperator]] = None,
+    decorated_operator_class: Type[BaseOperator],
     **kwargs,
 ) -> Callable[[T], T]:
     """
@@ -202,38 +320,23 @@ def task_decorator_factory(
     :type decorated_operator_class: BaseOperator
     """
-    # try to infer from type annotation
-    if python_callable and multiple_outputs is None:
-        sig = signature(python_callable).return_annotation
-        ttype = getattr(sig, "__origin__", None)
-
-        multiple_outputs = sig != inspect.Signature.empty and ttype in (dict,
Dict) - - def wrapper(f: T): - """ - Python wrapper to generate PythonDecoratedOperator out of simple python functions. - Used for Airflow Decorated interface - """ - validate_python_callable(f) - kwargs.setdefault('task_id', f.__name__) - - @functools.wraps(f) - def factory(*args, **f_kwargs): - op = decorated_operator_class( - python_callable=f, - op_args=args, - op_kwargs=f_kwargs, - multiple_outputs=multiple_outputs, - **kwargs, - ) - if f.__doc__: - op.doc_md = f.__doc__ - return XComArg(op) - - return cast(T, factory) - - if callable(python_callable): - return wrapper(python_callable) + if multiple_outputs is None: + multiple_outputs = cast(bool, attr.NOTHING) + if python_callable: + return _TaskDecorator( # type: ignore + function=python_callable, + multiple_outputs=multiple_outputs, + operator_class=decorated_operator_class, + kwargs=kwargs, + ) elif python_callable is not None: - raise AirflowException('No args allowed while using @task, use kwargs instead') - return wrapper + raise TypeError('No args allowed while using @task, use kwargs instead') + return cast( + "Callable[[T], T]", + functools.partial( + _TaskDecorator, + multiple_outputs=multiple_outputs, + operator_class=decorated_operator_class, + kwargs=kwargs, + ), + ) diff --git a/airflow/decorators/task_group.py b/airflow/decorators/task_group.py index 04ef1614c54c4..e93002384cca9 100644 --- a/airflow/decorators/task_group.py +++ b/airflow/decorators/task_group.py @@ -20,19 +20,101 @@ together when the DAG is displayed graphically. """ import functools +import warnings from inspect import signature -from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, TypeVar, overload +from typing import TYPE_CHECKING, Any, Callable, Dict, Generic, Optional, TypeVar, cast, overload -from airflow.utils.task_group import TaskGroup +import attr + +from airflow.utils.task_group import MappedTaskGroup, TaskGroup if TYPE_CHECKING: from airflow.models import DAG F = TypeVar("F", bound=Callable[..., Any]) +T = TypeVar("T", bound=Callable) +R = TypeVar("R") task_group_sig = signature(TaskGroup.__init__) +@attr.define +class TaskGroupDecorator(Generic[R]): + """:meta private:""" + + function: Callable[..., R] = attr.ib(validator=attr.validators.is_callable()) + kwargs: Dict[str, Any] = attr.ib(factory=dict) + """kwargs for the TaskGroup""" + + @function.validator + def _validate_function(self, _, f): + if 'self' in signature(f).parameters: + raise TypeError('@task_group does not support methods') + + @kwargs.validator + def _validate(self, _, kwargs): + task_group_sig.bind_partial(**kwargs) + + def __attrs_post_init__(self): + self.kwargs.setdefault('group_id', self.function.__name__) + + def _make_task_group(self, **kwargs) -> TaskGroup: + return TaskGroup(**kwargs) + + def __call__(self, *args, **kwargs) -> R: + with self._make_task_group(add_suffix_on_collision=True, **self.kwargs): + # Invoke function to run Tasks inside the TaskGroup + return self.function(*args, **kwargs) + + def partial(self, **kwargs) -> "MappedTaskGroupDecorator[R]": + return MappedTaskGroupDecorator(function=self.function, kwargs=self.kwargs).partial(**kwargs) + + def map(self, **kwargs) -> R: + return MappedTaskGroupDecorator(function=self.function, kwargs=self.kwargs).map(**kwargs) + + +@attr.define +class MappedTaskGroupDecorator(TaskGroupDecorator[R]): + """:meta private:""" + + partial_kwargs: Dict[str, Any] = attr.ib(factory=dict) + """static kwargs for the decorated function""" + mapped_kwargs: Dict[str, Any] = attr.ib(factory=dict) + """kwargs for 
the decorated function""" + + def __call__(self, *args, **kwargs): + raise RuntimeError("A mapped @task_group cannot be called. Use `.map` and `.partial` instead") + + def _make_task_group(self, **kwargs) -> MappedTaskGroup: + tg = MappedTaskGroup(**kwargs) + tg.partial_kwargs = self.partial_kwargs + tg.mapped_kwargs = self.mapped_kwargs + return tg + + def partial(self, **kwargs) -> "MappedTaskGroupDecorator[R]": + if self.partial_kwargs: + raise RuntimeError("Already a partial task group") + self.partial_kwargs.update(kwargs) + return self + + def map(self, **kwargs) -> R: + if self.mapped_kwargs: + raise RuntimeError("Already a mapped task group") + self.mapped_kwargs = kwargs + + call_kwargs = self.partial_kwargs.copy() + duplicated_keys = set(call_kwargs).intersection(kwargs) + if duplicated_keys: + raise RuntimeError(f"Cannot map partial arguments: {', '.join(sorted(duplicated_keys))}") + call_kwargs.update({k: object() for k in kwargs}) + + return super().__call__(**call_kwargs) + + def __del__(self): + if not self.mapped_kwargs: + warnings.warn(f"Partial task group {self.function.__name__} was never mapped!") + + # This covers the @task_group() case. Annotations are copied from the TaskGroup # class, only providing a default to 'group_id' (this is optional for the # decorator and defaults to the decorated function's name). Please keep them in @@ -73,31 +155,6 @@ def task_group(python_callable=None, *tg_args, **tg_kwargs): :param tg_args: Positional arguments for the TaskGroup object. :param tg_kwargs: Keyword arguments for the TaskGroup object. """ - - def wrapper(f): - # Setting group_id as function name if not given in kwarg group_id - if not tg_args and 'group_id' not in tg_kwargs: - tg_kwargs['group_id'] = f.__name__ - task_group_bound_args = task_group_sig.bind_partial(*tg_args, **tg_kwargs) - - @functools.wraps(f) - def factory(*args, **kwargs): - # Generate signature for decorated function and bind the arguments when called - # we do this to extract parameters so we can annotate them on the DAG object. - # In addition, this fails if we are missing any args/kwargs with TypeError as expected. - # Apply defaults to capture default values if set. 
-
-            # Initialize TaskGroup with bound arguments
-            with TaskGroup(
-                *task_group_bound_args.args,
-                add_suffix_on_collision=True,
-                **task_group_bound_args.kwargs,
-            ):
-                # Invoke function to run Tasks inside the TaskGroup
-                return f(*args, **kwargs)
-
-        return factory
-
     if callable(python_callable):
-        return wrapper(python_callable)
-    return wrapper
+        return TaskGroupDecorator(function=python_callable, kwargs=tg_kwargs)
+    return cast("Callable[[T], T]", functools.partial(TaskGroupDecorator, kwargs=tg_kwargs))
diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py
index c3c5be65a9d15..e1eee43317f7d 100644
--- a/airflow/models/baseoperator.py
+++ b/airflow/models/baseoperator.py
@@ -61,7 +61,7 @@
 from airflow.models.param import ParamsDict
 from airflow.models.pool import Pool
 from airflow.models.taskinstance import Context, TaskInstance, clear_task_instances
-from airflow.models.taskmixin import DependencyMixin
+from airflow.models.taskmixin import DAGNode, DependencyMixin
 from airflow.models.xcom import XCOM_RETURN_KEY
 from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
 from airflow.ti_deps.deps.not_in_retry_period_dep import NotInRetryPeriodDep
@@ -70,7 +70,6 @@
 from airflow.ti_deps.deps.trigger_rule_dep import TriggerRuleDep
 from airflow.triggers.base import BaseTrigger
 from airflow.utils import timezone
-from airflow.utils.edgemodifier import EdgeModifier
 from airflow.utils.helpers import render_template_as_native, render_template_to_string, validate_key
 from airflow.utils.log.logging_mixin import LoggingMixin
 from airflow.utils.operator_resources import Resources
@@ -91,6 +90,23 @@
 T = TypeVar('T', bound=FunctionType)
 
 
+class _PartialDescriptor:
+    """A descriptor that guards against ``.partial`` being called on Task objects."""
+
+    class_method = None
+
+    def __get__(
+        self, obj: "BaseOperator", cls: "Optional[Type[BaseOperator]]" = None
+    ) -> Callable[..., "MappedOperator"]:
+        # Call this "partial" so it looks nicer in stack traces
+        def partial(**kwargs):
+            raise TypeError("partial can only be called on Operator classes, not Tasks themselves")
+
+        if obj is not None:
+            return partial
+        return self.class_method.__get__(cls, cls)
+
+
 class BaseOperatorMeta(abc.ABCMeta):
     """Metaclass of BaseOperator."""
 
@@ -110,12 +126,13 @@ def _apply_defaults(cls, func: T) -> T:
         # per decoration, i.e. each function decorated using apply_defaults will
         # have a different sig_cache.
        sig_cache = signature(func)
-        non_optional_args = {
-            name
+        non_variadic_params = {
+            name: param
             for (name, param) in sig_cache.parameters.items()
-            if param.default == param.empty
-            and param.name != 'self'
-            and param.kind not in (param.VAR_POSITIONAL, param.VAR_KEYWORD)
+            if param.name != 'self' and param.kind not in (param.VAR_POSITIONAL, param.VAR_KEYWORD)
+        }
+        non_optional_args = {
+            name for (name, param) in non_variadic_params.items() if param.default == param.empty
         }
 
         class autostacklevel_warn:
@@ -139,7 +156,7 @@ def warn(self, message, category=None, stacklevel=1, source=None):
             func.__globals__['warnings'] = autostacklevel_warn()
 
         @functools.wraps(func)
-        def apply_defaults(self, *args: Any, **kwargs: Any) -> Any:
+        def apply_defaults(self: "BaseOperator", *args: Any, **kwargs: Any) -> Any:
             from airflow.models.dag import DagContext
             from airflow.utils.task_group import TaskGroupContext
 
@@ -159,9 +176,8 @@ def apply_defaults(self, *args: Any, **kwargs: Any) -> Any:
                 params = kwargs.get('params', {}) or {}
                 dag_params.update(params)
 
-            default_args = {}
-            if 'default_args' in kwargs:
-                default_args = kwargs['default_args']
+            default_args = kwargs.pop('default_args', {})
+            if default_args:
                 if 'params' in default_args:
                     dag_params.update(default_args['params'])
                     del default_args['params']
@@ -181,13 +197,17 @@ def apply_defaults(self, *args: Any, **kwargs: Any) -> Any:
             if dag_params:
                 kwargs['params'] = dag_params
 
-            if default_args:
-                kwargs['default_args'] = default_args
+            hook = getattr(self, '_hook_apply_defaults', None)
+            if hook:
+                args, kwargs = hook(**kwargs, default_args=default_args)
+                default_args = kwargs.pop('default_args', {})
 
-            if hasattr(self, '_hook_apply_defaults'):
-                args, kwargs = self._hook_apply_defaults(*args, **kwargs)
+            if not hasattr(self, '_BaseOperator__init_kwargs'):
+                self._BaseOperator__init_kwargs = {}
 
-            result = func(self, *args, **kwargs)
+            result = func(self, **kwargs, default_args=default_args)
+            # Store the args passed to init -- we need them to support task.map serialization!
+            self._BaseOperator__init_kwargs.update(kwargs)  # type: ignore
 
             # Here we set upstream task defined by XComArgs passed to template fields of the operator
             self.set_xcomargs_dependencies()
@@ -196,16 +216,37 @@ def apply_defaults(self, *args: Any, **kwargs: Any) -> Any:
             self._BaseOperator__instantiated = True
             return result
 
+        apply_defaults.__non_optional_args = non_optional_args  # type: ignore
+        apply_defaults.__param_names = set(non_variadic_params)  # type: ignore
+
         return cast(T, apply_defaults)
 
     def __new__(cls, name, bases, namespace, **kwargs):
         new_cls = super().__new__(cls, name, bases, namespace, **kwargs)
+        try:
+            # Update the partial descriptor with the class method so it can call the actual function (but let
+            # subclasses override it if they need to)
+            partial_desc = vars(new_cls)['partial']
+            if isinstance(partial_desc, _PartialDescriptor):
+                actual_partial = cls.partial
+                partial_desc.class_method = classmethod(actual_partial)
+        except KeyError:
+            pass
         new_cls.__init__ = cls._apply_defaults(new_cls.__init__)
         return new_cls
 
+    # The class level partial function. This is what handles the actual mapping
+    def partial(cls, *, task_id: str, dag: Optional["DAG"] = None, **kwargs):
+        operator_class = cast("Type[BaseOperator]", cls)
+        # Validate that the args we passed are known -- at call/DAG parse time, not run time!
+        _validate_kwarg_names_for_mapping(operator_class, "partial", kwargs)
+        return MappedOperator(
+            task_id=task_id, operator_class=operator_class, dag=dag, partial_kwargs=kwargs, mapped_kwargs={}
+        )
+
 
 @functools.total_ordering
-class BaseOperator(Operator, LoggingMixin, DependencyMixin, metaclass=BaseOperatorMeta):
+class BaseOperator(Operator, LoggingMixin, DAGNode, metaclass=BaseOperatorMeta):
     """
     Abstract base class for all operators. Since operators create objects that
     become nodes in the dag, BaseOperator contains many recursive methods for
@@ -452,6 +493,8 @@ class derived from this one results in the creation of a task object,
     # The _serialized_fields are lazily loaded when get_serialized_fields() method is called
     __serialized_fields: Optional[FrozenSet[str]] = None
 
+    partial: Callable[..., "MappedOperator"] = _PartialDescriptor()  # type: ignore
+
     _comps = {
         'task_id',
         'dag_id',
@@ -480,6 +523,9 @@ class derived from this one results in the creation of a task object,
     # If True then the class constructor was called
     __instantiated = False
+    # The kwargs passed to `__init__`, after apply_defaults() has updated them. Used to "recreate" the task
+    # when mapping
+    __init_kwargs: Dict[str, Any]
 
     # Set to True before calling execute method
     _lock_for_execution = False
@@ -490,6 +536,22 @@ class derived from this one results in the creation of a task object,
     # Setting it to None by default as other Operators do not have that field
     subdag: Optional["DAG"] = None
 
+    start_date: Optional[pendulum.DateTime] = None
+    end_date: Optional[pendulum.DateTime] = None
+
+    def __new__(cls, dag: Optional['DAG'] = None, task_group: Optional["TaskGroup"] = None, **kwargs):
+        # If we are creating a new Task _and_ we are in the context of a MappedTaskGroup, then we should only
+        # create mapped operators.
+ from airflow.models.dag import DagContext + from airflow.utils.task_group import MappedTaskGroup, TaskGroupContext + + dag = dag or DagContext.get_current_dag() + task_group = task_group or TaskGroupContext.get_current_task_group(dag) + + if isinstance(task_group, MappedTaskGroup): + return cls.partial(dag=dag, task_group=task_group, **kwargs) + return super().__new__(cls) + def __init__( self, task_id: str, @@ -541,6 +603,8 @@ def __init__( from airflow.models.dag import DagContext from airflow.utils.task_group import TaskGroupContext + self.__init_kwargs = {} + super().__init__() if kwargs: if not conf.getboolean('operators', 'ALLOW_ILLEGAL_ARGUMENTS'): @@ -575,13 +639,11 @@ def __init__( self._pre_execute_hook = pre_execute self._post_execute_hook = post_execute - self.start_date = start_date if start_date and not isinstance(start_date, datetime): self.log.warning("start_date for %s isn't datetime.datetime", self) elif start_date: self.start_date = timezone.convert_to_utc(start_date) - self.end_date = end_date if end_date: self.end_date = timezone.convert_to_utc(end_date) @@ -792,6 +854,8 @@ def __setattr__(self, key, value): if self._lock_for_execution: # Skip any custom behaviour during execute return + if key in self.__init_kwargs: + self.__init_kwargs[key] = value if self.__instantiated and key in self.template_fields: # Resolve upstreams set by assigning an XComArg after initializing # an operator, example: @@ -816,7 +880,11 @@ def get_outlet_defs(self): return self._outlets @property - def dag(self) -> 'DAG': + def node_id(self): + return self.task_id + + @property # type: ignore[override] + def dag(self) -> 'DAG': # type: ignore[override] """Returns the Operator's DAG if set, otherwise raises an error""" if self._dag: return self._dag @@ -1191,21 +1259,11 @@ def resolve_template_files(self) -> None: self.log.exception(e) self.prepare_template() - @property - def upstream_list(self) -> List["BaseOperator"]: - """@property: list of tasks directly upstream""" - return [self.dag.get_task(tid) for tid in self._upstream_task_ids] - @property def upstream_task_ids(self) -> Set[str]: """@property: set of ids of tasks directly upstream""" return self._upstream_task_ids - @property - def downstream_list(self) -> List["BaseOperator"]: - """@property: list of tasks directly downstream""" - return [self.dag.get_task(tid) for tid in self._downstream_task_ids] - @property def downstream_task_ids(self) -> Set[str]: """@property: set of ids of tasks directly downstream""" @@ -1374,7 +1432,7 @@ def get_direct_relative_ids(self, upstream: bool = False) -> Set[str]: else: return self._downstream_task_ids - def get_direct_relatives(self, upstream: bool = False) -> List["BaseOperator"]: + def get_direct_relatives(self, upstream: bool = False) -> Iterable["DAGNode"]: """ Get list of the direct relatives to the current task, upstream or downstream. 
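Reviewer note: the `_PartialDescriptor` / `BaseOperatorMeta.__new__` pairing above is easy to misread, so here is a minimal, self-contained sketch of the same trick outside Airflow (all names in this sketch are illustrative, not part of the patch). Class access to `partial` dispatches to the real classmethod; instance access returns a stub that raises, which is the behaviour `test_partial_on_instance` pins down later in this diff:

```python
# Standalone sketch of the descriptor pattern used by _PartialDescriptor.
from typing import Any, Callable, Optional


class _GuardedPartial:
    """Return the real classmethod on class access, a guard on instance access."""

    class_method: Optional[classmethod] = None

    def __get__(self, obj: Any, cls: Optional[type] = None) -> Callable[..., Any]:
        # Named "partial" so tracebacks read nicely, mirroring the patch.
        def partial(**kwargs: Any) -> None:
            raise TypeError("partial can only be called on Operator classes, not Tasks themselves")

        if obj is not None:  # looked up on an instance -> guard
            return partial
        return self.class_method.__get__(cls, cls)  # looked up on the class -> real impl


def _partial_impl(cls: type, **kwargs: Any) -> dict:
    """Stand-in for the real BaseOperatorMeta.partial."""
    return {"operator_class": cls.__name__, "partial_kwargs": kwargs}


class Op:
    partial = _GuardedPartial()


# BaseOperatorMeta.__new__ performs this wiring for every operator subclass:
vars(Op)["partial"].class_method = classmethod(_partial_impl)

assert Op.partial(task_id="t") == {"operator_class": "Op", "partial_kwargs": {"task_id": "t"}}
# Op().partial(task_id="t") would raise TypeError, as the new tests expect.
```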
@@ -1392,13 +1450,6 @@ def task_type(self) -> str: """@property: type of the task""" return self.__class__.__name__ - def add_only_new(self, item_set: Set[str], item: str, dag_id: str) -> None: - """Adds only new items to item set""" - if item in item_set: - self.log.warning('Dependency %s, %s already registered for DAG: %s', self, item, dag_id) - else: - item_set.add(item) - @property def roots(self) -> List["BaseOperator"]: """Required by TaskMixin""" @@ -1409,90 +1460,6 @@ def leaves(self) -> List["BaseOperator"]: """Required by TaskMixin""" return [self] - def _set_relatives( - self, - task_or_task_list: Union[DependencyMixin, Sequence[DependencyMixin]], - upstream: bool = False, - edge_modifier: Optional[EdgeModifier] = None, - ) -> None: - """Sets relatives for the task or task list.""" - if not isinstance(task_or_task_list, Sequence): - task_or_task_list = [task_or_task_list] - - task_list: List["BaseOperator"] = [] - for task_object in task_or_task_list: - task_object.update_relative(self, not upstream) - relatives = task_object.leaves if upstream else task_object.roots - for task in relatives: - if not isinstance(task, BaseOperator): - raise AirflowException( - f"Relationships can only be set between Operators; received {task.__class__.__name__}" - ) - task_list.append(task) - - # relationships can only be set if the tasks share a single DAG. Tasks - # without a DAG are assigned to that DAG. - dags = { - task._dag.dag_id: task._dag for task in self.roots + task_list if task.has_dag() # type: ignore - } - - if len(dags) > 1: - raise AirflowException( - f'Tried to set relationships between tasks in more than one DAG: {dags.values()}' - ) - elif len(dags) == 1: - dag = dags.popitem()[1] - else: - raise AirflowException( - f"Tried to create relationships between tasks that don't have DAGs yet. " - f"Set the DAG for at least one task and try again: {[self] + task_list}" - ) - - if dag and not self.has_dag(): - # If this task does not yet have a dag, add it to the same dag as the other task and - # put it in the dag's root TaskGroup. - self.dag = dag - self.dag.task_group.add(self) - - for task in task_list: - if dag and not task.has_dag(): - # If the other task does not yet have a dag, add it to the same dag as this task and - # put it in the dag's root TaskGroup. - task.dag = dag - task.dag.task_group.add(task) - if upstream: - task.add_only_new(task.get_direct_relative_ids(upstream=False), self.task_id, self.dag.dag_id) - self.add_only_new(self._upstream_task_ids, task.task_id, task.dag.dag_id) - if edge_modifier: - edge_modifier.add_edge_info(self.dag, task.task_id, self.task_id) - else: - self.add_only_new(self._downstream_task_ids, task.task_id, task.dag.dag_id) - task.add_only_new(task.get_direct_relative_ids(upstream=True), self.task_id, self.dag.dag_id) - if edge_modifier: - edge_modifier.add_edge_info(self.dag, self.task_id, task.task_id) - - def set_downstream( - self, - task_or_task_list: Union[DependencyMixin, Sequence[DependencyMixin]], - edge_modifier: Optional[EdgeModifier] = None, - ) -> None: - """ - Set a task or a task list to be directly downstream from the current - task. Required by TaskMixin. - """ - self._set_relatives(task_or_task_list, upstream=False, edge_modifier=edge_modifier) - - def set_upstream( - self, - task_or_task_list: Union[DependencyMixin, Sequence[DependencyMixin]], - edge_modifier: Optional[EdgeModifier] = None, - ) -> None: - """ - Set a task or a task list to be directly upstream from the current - task. Required by TaskMixin. 
-        """
-        self._set_relatives(task_or_task_list, upstream=True, edge_modifier=edge_modifier)
-
     @property
     def output(self):
         """Returns reference to XCom pushed by current operator"""
@@ -1614,8 +1581,11 @@ def get_serialized_fields(cls):
                     'dag',
                     '_dag',
                     '_BaseOperator__instantiated',
+                    '_BaseOperator__init_kwargs',
                 }
-                | {
+                | {  # Class level defaults need to be added to this list
+                    'start_date',
+                    'end_date',
                     '_task_type',
                     'subdag',
                     'ui_color',
@@ -1659,6 +1629,118 @@ def defer(
         """
         raise TaskDeferred(trigger=trigger, method_name=method_name, kwargs=kwargs, timeout=timeout)
 
+    def map(self, **kwargs) -> "MappedOperator":
+        return MappedOperator.from_operator(self, kwargs)
+
+
+def _validate_kwarg_names_for_mapping(cls: Type[BaseOperator], func_name: str, value: Dict[str, Any]):
+    if isinstance(cls, str):
+        # Serialized version -- would have been validated at parse time
+        return
+
+    # use a dict so order of args is same as code order
+    unknown_args = value.copy()
+    for clazz in cls.mro():
+        # Mypy doesn't like doing `class.__init__`, Error is: Cannot access "__init__" directly
+        init = clazz.__init__  # type: ignore
+
+        if not hasattr(init, '_BaseOperatorMeta__param_names'):
+            continue
+
+        for name in init._BaseOperatorMeta__param_names:
+            unknown_args.pop(name, None)
+
+        if not unknown_args:
+            # If we have no args left to check: stop looking at the MRO chain
+            return
+
+    if len(unknown_args) == 1:
+        raise TypeError(
+            f'{cls.__name__}.{func_name} got unexpected keyword argument {unknown_args.popitem()[0]!r}'
+        )
+    else:
+        names = ", ".join(repr(n) for n in unknown_args)
+        raise TypeError(f'{cls.__name__}.{func_name} got unexpected keyword arguments {names}')
+
+
+@attr.define(kw_only=True)
+class MappedOperator(DAGNode):
+    """Object representing a mapped operator in a DAG"""
+
+    operator_class: Type[BaseOperator] = attr.ib(repr=lambda c: c.__name__)
+    task_id: str
+    partial_kwargs: Dict[str, Any]
+    mapped_kwargs: Dict[str, Any] = attr.ib(
+        validator=lambda self, _, v: _validate_kwarg_names_for_mapping(self.operator_class, "map", v)
+    )
+    dag: Optional["DAG"] = None
+    upstream_task_ids: Set[str] = attr.ib(factory=set, repr=False)
+    downstream_task_ids: Set[str] = attr.ib(factory=set, repr=False)
+
+    task_group: Optional["TaskGroup"] = attr.ib(repr=False)
+    # BaseOperator-like interface -- needed so we can add ourselves to the dag.tasks
+    start_date: Optional[pendulum.DateTime] = attr.ib(repr=False, default=None)
+    end_date: Optional[pendulum.DateTime] = attr.ib(repr=False, default=None)
+
+    @classmethod
+    def from_operator(cls, operator: BaseOperator, mapped_kwargs: Dict[str, Any]) -> "MappedOperator":
+        dag: Optional["DAG"] = getattr(operator, '_dag', None)
+        if dag:
+            # When BaseOperator() was called within a DAG, it would have been added straight away, but now we
+            # are mapped, we want to _remove_ that task from the dag
+            dag._remove_task(operator.task_id)
+
+        return MappedOperator(
+            operator_class=type(operator),
+            task_id=operator.task_id,
+            task_group=getattr(operator, 'task_group', None),
+            dag=getattr(operator, '_dag', None),
+            start_date=operator.start_date,
+            end_date=operator.end_date,
+            partial_kwargs=operator._BaseOperator__init_kwargs,  # type: ignore
+            mapped_kwargs=mapped_kwargs,
+        )
+
+    def __attrs_post_init__(self):
+        if self.task_group:
+            self.task_id = self.task_group.child_id(self.task_id)
+            self.task_group.add(self)
+        if self.dag:
+            self.dag.add_task(self)
+
+    @task_group.default
+    def _default_task_group(self):
+        from airflow.utils.task_group import TaskGroupContext
+
+        return TaskGroupContext.get_current_task_group(self.dag)
+
+    @property
+    def node_id(self):
+        return self.task_id
+
+    def map(self, **kwargs) -> "MappedOperator":
+        """
+        Create a copy of this operator with the given mapping parameters set.
+
+        :return: a new ``MappedOperator`` built with ``attr.evolve``; ``self`` is not mutated
+        """
+        if self.mapped_kwargs:
+            raise RuntimeError("Already a mapped task")
+        return attr.evolve(self, mapped_kwargs=kwargs)
+
+    @property
+    def roots(self) -> List["MappedOperator"]:
+        """Required by TaskMixin"""
+        return [self]
+
+    @property
+    def leaves(self) -> List["MappedOperator"]:
+        """Required by TaskMixin"""
+        return [self]
+
+    def has_dag(self):
+        return self.dag is not None
+
 
 # TODO: Deprecate for Airflow 3.0
 Chainable = Union[DependencyMixin, Sequence[DependencyMixin]]
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index 89b79b98b5390..136ed04ac8d38 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -1755,7 +1755,7 @@ def topological_sort(self, include_subdag_tasks: bool = False):
             acyclic = False
             for node in list(graph_unsorted.values()):
                 for edge in node.upstream_list:
-                    if edge.task_id in graph_unsorted:
+                    if edge.node_id in graph_unsorted:
                         break
                 # no edges in upstream tasks
                 else:
@@ -2075,10 +2075,10 @@ def filter_task_group(group, parent_group):
         # the cut.
         subdag_task_groups = dag.task_group.get_task_group_dict()
         for group in subdag_task_groups.values():
-            group.upstream_group_ids = group.upstream_group_ids.intersection(subdag_task_groups.keys())
-            group.downstream_group_ids = group.downstream_group_ids.intersection(subdag_task_groups.keys())
-            group.upstream_task_ids = group.upstream_task_ids.intersection(dag.task_dict.keys())
-            group.downstream_task_ids = group.downstream_task_ids.intersection(dag.task_dict.keys())
+            group.upstream_group_ids.intersection_update(subdag_task_groups)
+            group.downstream_group_ids.intersection_update(subdag_task_groups)
+            group.upstream_task_ids.intersection_update(dag.task_dict)
+            group.downstream_task_ids.intersection_update(dag.task_dict)
 
         for t in dag.tasks:
             # Removing upstream/downstream references to tasks that did not
@@ -2092,7 +2092,7 @@ def filter_task_group(group, parent_group):
         return dag
 
     def has_task(self, task_id: str):
-        return task_id in (t.task_id for t in self.tasks)
+        return task_id in self.task_dict
 
     def get_task(self, task_id: str, include_subdags: bool = False) -> BaseOperator:
         if task_id in self.task_dict:
@@ -2197,6 +2197,16 @@ def add_tasks(self, tasks):
         for task in tasks:
             self.add_task(task)
 
+    def _remove_task(self, task_id: str) -> None:
+        # This is "private" as removing could leave a hole in dependencies if done incorrectly, and this
+        # doesn't guard against that
+        task = self.task_dict.pop(task_id)
+        tg = getattr(task, 'task_group', None)
+        if tg:
+            tg._remove(task)
+
+        self.task_count = len(self.task_dict)
+
     def run(
         self,
         start_date=None,
diff --git a/airflow/models/skipmixin.py b/airflow/models/skipmixin.py
index a552af8904ac7..20879ef0df65a 100644
--- a/airflow/models/skipmixin.py
+++ b/airflow/models/skipmixin.py
@@ -17,7 +17,7 @@
 # under the License.
import warnings -from typing import TYPE_CHECKING, Iterable, Optional, Sequence, Union +from typing import TYPE_CHECKING, Iterable, List, Optional, Sequence, Union, cast from airflow.models.taskinstance import TaskInstance from airflow.utils import timezone @@ -147,7 +147,8 @@ def skip_all_except(self, ti: TaskInstance, branch_task_ids: Union[str, Iterable task = ti.task dag = task.dag - downstream_tasks = task.downstream_list + # At runtime, the downstream list will only be operators + downstream_tasks = cast("List[BaseOperator]", task.downstream_list) if downstream_tasks: # For a branching workflow that looks like this, when "branch" does skip_all_except("task1"), diff --git a/airflow/models/taskmixin.py b/airflow/models/taskmixin.py index 8942229e7aa45..f352377c464e8 100644 --- a/airflow/models/taskmixin.py +++ b/airflow/models/taskmixin.py @@ -16,8 +16,19 @@ # under the License. import warnings -from abc import abstractmethod -from typing import Sequence, Union +from abc import ABCMeta, abstractmethod +from typing import TYPE_CHECKING, Iterable, List, Optional, Sequence, Set, Union + +import pendulum + +from airflow.exceptions import AirflowException + +if TYPE_CHECKING: + from logging import Logger + + from airflow.models.dag import DAG + from airflow.utils.edgemodifier import EdgeModifier + from airflow.utils.task_group import TaskGroup class DependencyMixin: @@ -88,3 +99,167 @@ def __init_subclass__(cls) -> None: stacklevel=2, ) return super().__init_subclass__() + + +class DAGNode(DependencyMixin, metaclass=ABCMeta): + """ + A base class for a node in the graph of a workflow -- an Operator or a Task Group, either mapped or + unmapped. + """ + + dag: Optional["DAG"] = None + + @property + @abstractmethod + def node_id(self) -> str: + raise NotImplementedError() + + task_group: Optional["TaskGroup"] + """The task_group that contains this node""" + + start_date: Optional[pendulum.DateTime] + end_date: Optional[pendulum.DateTime] + + def has_dag(self) -> bool: + return self.dag is not None + + @property + @abstractmethod + def upstream_task_ids(self) -> Set[str]: + raise NotImplementedError() + + @property + @abstractmethod + def downstream_task_ids(self) -> Set[str]: + raise NotImplementedError() + + @property + def log(self) -> "Logger": + raise NotImplementedError() + + @property + @abstractmethod + def roots(self) -> Sequence["DAGNode"]: + raise NotImplementedError() + + @property + @abstractmethod + def leaves(self) -> Sequence["DAGNode"]: + raise NotImplementedError() + + def _set_relatives( + self, + task_or_task_list: Union[DependencyMixin, Sequence[DependencyMixin]], + upstream: bool = False, + edge_modifier: Optional["EdgeModifier"] = None, + ) -> None: + """Sets relatives for the task or task list.""" + from airflow.models.baseoperator import BaseOperator, MappedOperator + + if not isinstance(task_or_task_list, Sequence): + task_or_task_list = [task_or_task_list] + + task_list: List[DAGNode] = [] + for task_object in task_or_task_list: + task_object.update_relative(self, not upstream) + relatives = task_object.leaves if upstream else task_object.roots + for task in relatives: + if not isinstance(task, (BaseOperator, MappedOperator)): + raise AirflowException( + f"Relationships can only be set between Operators; received {task.__class__.__name__}" + ) + task_list.append(task) + + # relationships can only be set if the tasks share a single DAG. Tasks + # without a DAG are assigned to that DAG. 
+        dags: Set["DAG"] = {task.dag for task in [*self.roots, *task_list] if task.has_dag() and task.dag}
+
+        if len(dags) > 1:
+            raise AirflowException(f'Tried to set relationships between tasks in more than one DAG: {dags}')
+        elif len(dags) == 1:
+            dag = dags.pop()
+        else:
+            raise AirflowException(
+                f"Tried to create relationships between tasks that don't have DAGs yet. "
+                f"Set the DAG for at least one task and try again: {[self, *task_list]}"
+            )
+
+        if not self.has_dag():
+            # If this task does not yet have a dag, add it to the same dag as the other task and
+            # put it in the dag's root TaskGroup.
+            self.dag = dag
+            self.dag.task_group.add(self)
+
+        def add_only_new(obj, item_set: Set[str], item: str) -> None:
+            """Adds only new items to item set"""
+            if item in item_set:
+                self.log.warning('Dependency %s, %s already registered for DAG: %s', obj, item, dag.dag_id)
+            else:
+                item_set.add(item)
+
+        for task in task_list:
+            if dag and not task.has_dag():
+                # If the other task does not yet have a dag, add it to the same dag as this task and
+                # put it in the dag's root TaskGroup.
+                dag.add_task(task)
+                dag.task_group.add(task)
+            if upstream:
+                add_only_new(task, task.downstream_task_ids, self.node_id)
+                add_only_new(self, self.upstream_task_ids, task.node_id)
+                if edge_modifier:
+                    edge_modifier.add_edge_info(self.dag, task.node_id, self.node_id)
+            else:
+                add_only_new(self, self.downstream_task_ids, task.node_id)
+                add_only_new(task, task.upstream_task_ids, self.node_id)
+                if edge_modifier:
+                    edge_modifier.add_edge_info(self.dag, self.node_id, task.node_id)
+
+    def set_downstream(
+        self,
+        task_or_task_list: Union[DependencyMixin, Sequence[DependencyMixin]],
+        edge_modifier: Optional["EdgeModifier"] = None,
+    ) -> None:
+        """Set a node (or nodes) to be directly downstream from the current node."""
+        self._set_relatives(task_or_task_list, upstream=False, edge_modifier=edge_modifier)
+
+    def set_upstream(
+        self,
+        task_or_task_list: Union[DependencyMixin, Sequence[DependencyMixin]],
+        edge_modifier: Optional["EdgeModifier"] = None,
+    ) -> None:
+        """Set a node (or nodes) to be directly upstream of the current node."""
+        self._set_relatives(task_or_task_list, upstream=True, edge_modifier=edge_modifier)
+
+    @property
+    def downstream_list(self) -> Iterable["DAGNode"]:
+        """List of nodes directly downstream"""
+        if not self.dag:
+            raise AirflowException(f'Operator {self} has not been assigned to a DAG yet')
+        return [self.dag.get_task(tid) for tid in self.downstream_task_ids]
+
+    @property
+    def upstream_list(self) -> Iterable["DAGNode"]:
+        """List of nodes directly upstream"""
+        if not self.dag:
+            raise AirflowException(f'Operator {self} has not been assigned to a DAG yet')
+        return [self.dag.get_task(tid) for tid in self.upstream_task_ids]
+
+    def get_direct_relative_ids(self, upstream: bool = False) -> Set[str]:
+        """
+        Get set of the direct relative ids to the current task, upstream or
+        downstream.
+        """
+        if upstream:
+            return self.upstream_task_ids
+        else:
+            return self.downstream_task_ids
+
+    def get_direct_relatives(self, upstream: bool = False) -> Iterable["DAGNode"]:
+        """
+        Get list of the direct relatives to the current task, upstream or
+        downstream.
+        """
+        if upstream:
+            return self.upstream_list
+        else:
+            return self.downstream_list
diff --git a/airflow/models/xcom_arg.py b/airflow/models/xcom_arg.py
index d6f0f290deca8..5fe798b9ee3a1 100644
--- a/airflow/models/xcom_arg.py
+++ b/airflow/models/xcom_arg.py
@@ -14,11 +14,11 @@
 # KIND, either express or implied.
See the License for the # specific language governing permissions and limitations # under the License. -from typing import Any, List, Optional, Sequence, Union +from typing import TYPE_CHECKING, Any, List, Optional, Sequence, Union from airflow.exceptions import AirflowException -from airflow.models.baseoperator import BaseOperator -from airflow.models.taskmixin import DependencyMixin +from airflow.models.baseoperator import BaseOperator, MappedOperator +from airflow.models.taskmixin import DAGNode, DependencyMixin from airflow.models.xcom import XCOM_RETURN_KEY from airflow.utils.context import Context from airflow.utils.edgemodifier import EdgeModifier @@ -59,7 +59,7 @@ class XComArg(DependencyMixin): :type key: str """ - def __init__(self, operator: BaseOperator, key: str = XCOM_RETURN_KEY): + def __init__(self, operator: Union[BaseOperator, MappedOperator], key: str = XCOM_RETURN_KEY): self._operator = operator self._key = key @@ -93,17 +93,17 @@ def __str__(self): return xcom_pull @property - def operator(self) -> BaseOperator: + def operator(self) -> Union[BaseOperator, MappedOperator]: """Returns operator of this XComArg.""" return self._operator @property - def roots(self) -> List[BaseOperator]: + def roots(self) -> List[DAGNode]: """Required by TaskMixin""" return [self._operator] @property - def leaves(self) -> List[BaseOperator]: + def leaves(self) -> List[DAGNode]: """Required by TaskMixin""" return [self._operator] @@ -133,13 +133,10 @@ def resolve(self, context: Context) -> Any: Pull XCom value for the existing arg. This method is run during ``op.execute()`` in respectable context. """ - resolved_value = self.operator.xcom_pull( - context=context, - task_ids=[self.operator.task_id], - key=str(self.key), # xcom_pull supports only key as str - dag_id=self.operator.dag.dag_id, - ) + resolved_value = context['ti'].xcom_pull(task_ids=[self.operator.task_id], key=str(self.key)) if not resolved_value: + if TYPE_CHECKING: + assert self.operator.dag raise AirflowException( f'XComArg result from {self.operator.task_id} at {self.operator.dag.dag_id} ' f'with key="{self.key}"" is not found!' 
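Reviewer note: at this point the user-facing surface of the patch is complete, so a quick sketch of how the pieces compose may help review. This is a hedged example based on the tests further down in this diff; the DAG id, dates, and the use of `BashOperator` are illustrative, not taken from the patch:

```python
import pendulum

from airflow.decorators import task
from airflow.models import DAG
from airflow.operators.bash import BashOperator


@task
def double(number: int) -> int:
    return number * 2


with DAG("mapping_demo", start_date=pendulum.datetime(2021, 1, 1)) as dag:
    # Decorated task: one MappedOperator, expanded per list element at run time.
    doubled = double.map(number=[1, 2, 3])

    # Fix some arguments first, then map over the rest. .partial() returns a
    # *new* decorator via attr.evolve, so `double` itself is untouched.
    named = double.partial(task_id="double_again").map(number=[4, 5, 6])

    # Classic operators: .partial() on the *class* yields a MappedOperator, and
    # unknown kwargs fail early via _validate_kwarg_names_for_mapping.
    echoes = BashOperator.partial(task_id="echo").map(bash_command=["echo 1", "echo 2"])
```

Both `double.map(number=..., typo=1)` and `BashOperator.partial(task_id="x", typo=1)` should raise `TypeError` while the DAG file is being parsed, which is the behaviour the new tests pin down.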
diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py index 561a54bdcdd72..2b95a6a1b0fb8 100644 --- a/airflow/serialization/serialized_objects.py +++ b/airflow/serialization/serialized_objects.py @@ -1002,10 +1002,10 @@ def deserialize_task_group( else SerializedTaskGroup.deserialize_task_group(val, group, task_dict) for label, (_type, val) in encoded_group["children"].items() } - group.upstream_group_ids = set(cls._deserialize(encoded_group["upstream_group_ids"])) - group.downstream_group_ids = set(cls._deserialize(encoded_group["downstream_group_ids"])) - group.upstream_task_ids = set(cls._deserialize(encoded_group["upstream_task_ids"])) - group.downstream_task_ids = set(cls._deserialize(encoded_group["downstream_task_ids"])) + group.upstream_group_ids.update(cls._deserialize(encoded_group["upstream_group_ids"])) + group.downstream_group_ids.update(cls._deserialize(encoded_group["downstream_group_ids"])) + group.upstream_task_ids.update(cls._deserialize(encoded_group["upstream_task_ids"])) + group.downstream_task_ids.update(cls._deserialize(encoded_group["downstream_task_ids"])) return group diff --git a/airflow/utils/task_group.py b/airflow/utils/task_group.py index 84f8f459d129d..bf91c38408307 100644 --- a/airflow/utils/task_group.py +++ b/airflow/utils/task_group.py @@ -21,18 +21,21 @@ """ import copy import re -from typing import TYPE_CHECKING, Dict, Generator, List, Optional, Sequence, Set, Union +import weakref +from typing import TYPE_CHECKING, Any, Dict, Generator, Iterable, List, Optional, Sequence, Set, Union from airflow.exceptions import AirflowException, DuplicateTaskIdFound -from airflow.models.taskmixin import DependencyMixin +from airflow.models.taskmixin import DAGNode, DependencyMixin from airflow.utils.helpers import validate_group_key +from airflow.utils.types import NOTSET if TYPE_CHECKING: from airflow.models.baseoperator import BaseOperator from airflow.models.dag import DAG + from airflow.utils.edgemodifier import EdgeModifier -class TaskGroup(DependencyMixin): +class TaskGroup(DAGNode): """ A collection of tasks. When set_downstream() or set_upstream() are called on the TaskGroup, it is applied across all tasks within the group if necessary. @@ -69,6 +72,8 @@ class TaskGroup(DependencyMixin): :type from_decorator: add_suffix_on_collision """ + used_group_ids: Set[Optional[str]] + def __init__( self, group_id: Optional[str], @@ -86,14 +91,17 @@ def __init__( self.prefix_group_id = prefix_group_id self.default_args = copy.deepcopy(default_args or {}) + dag = dag or DagContext.get_current_dag() + if group_id is None: # This creates a root TaskGroup. if parent_group: raise AirflowException("Root TaskGroup cannot have parent_group") # used_group_ids is shared across all TaskGroups in the same DAG to keep track # of used group_id to avoid duplication. 
-            self.used_group_ids: Set[Optional[str]] = set()
+            self.used_group_ids = set()
             self._parent_group = None
+            self.dag = dag
         else:
             if prefix_group_id:
                 # If group id is used as prefix, it should not contain spaces nor dots
@@ -105,14 +113,17 @@
             if not group_id:
                 raise ValueError("group_id must not be empty")
 
-            dag = dag or DagContext.get_current_dag()
-
             if not parent_group and not dag:
                 raise AirflowException("TaskGroup can only be used inside a dag")
 
             self._parent_group = parent_group or TaskGroupContext.get_current_task_group(dag)
             if not self._parent_group:
                 raise AirflowException("TaskGroup must have a parent_group except for the root TaskGroup")
+            if dag is not self._parent_group.dag:
+                raise RuntimeError(
+                    f"Cannot mix TaskGroups from different DAGs: {dag} and {self._parent_group.dag}"
+                )
+
             self.used_group_ids = self._parent_group.used_group_ids
 
         # if given group_id already used assign suffix by incrementing largest used suffix integer
@@ -121,9 +132,10 @@
         self._check_for_group_id_collisions(add_suffix_on_collision)
 
         self.used_group_ids.add(self.group_id)
-        self.used_group_ids.add(self.downstream_join_id)
-        self.used_group_ids.add(self.upstream_join_id)
-        self.children: Dict[str, Union["BaseOperator", "TaskGroup"]] = {}
+        if self.group_id:
+            self.used_group_ids.add(self.downstream_join_id)
+            self.used_group_ids.add(self.upstream_join_id)
+        self.children: Dict[str, DAGNode] = {}
         if self._parent_group:
             self._parent_group.add(self)
 
@@ -135,8 +147,9 @@
         # so that we can optimize the number of edges when entire TaskGroups depend on each other.
         self.upstream_group_ids: Set[Optional[str]] = set()
         self.downstream_group_ids: Set[Optional[str]] = set()
-        self.upstream_task_ids: Set[Optional[str]] = set()
-        self.downstream_task_ids: Set[Optional[str]] = set()
+        # Since the parent class defines these as read-only properties, we can't just do `self.x = ...`
+        self.__dict__['upstream_task_ids'] = set()
+        self.__dict__['downstream_task_ids'] = set()
 
     def _check_for_group_id_collisions(self, add_suffix_on_collision: bool):
         if self._group_id is None:
@@ -162,11 +175,23 @@ def create_root(cls, dag: "DAG") -> "TaskGroup":
         """Create a root TaskGroup with no group_id or parent."""
         return cls(group_id=None, dag=dag)
 
+    @property
+    def node_id(self):
+        return self.group_id
+
     @property
     def is_root(self) -> bool:
         """Returns True if this TaskGroup is the root TaskGroup. Otherwise False"""
        return not self.group_id
 
+    @property
+    def upstream_task_ids(self) -> Set[str]:
+        return self.__dict__['upstream_task_ids']
+
+    @property
+    def downstream_task_ids(self) -> Set[str]:
+        return self.__dict__['downstream_task_ids']
+
     def __iter__(self):
         for child in self.children.values():
             if isinstance(child, TaskGroup):
@@ -174,18 +199,36 @@
             else:
                 yield child
 
-    def add(self, task: Union["BaseOperator", "TaskGroup"]) -> None:
+    def add(self, task: DAGNode) -> None:
         """Add a task to this TaskGroup."""
-        key = task.group_id if isinstance(task, TaskGroup) else task.task_id
+        key = task.node_id
 
         if key in self.children:
-            raise DuplicateTaskIdFound(f"Task id '{key}' has already been added to the DAG")
+            node_type = "Task" if hasattr(task, 'task_id') else "Task Group"
+            raise DuplicateTaskIdFound(f"{node_type} id '{key}' has already been added to the DAG")
 
         if isinstance(task, TaskGroup):
+            if self.dag:
+                if task.dag is not None and self.dag is not task.dag:
+                    raise RuntimeError(
+                        f"Cannot mix TaskGroups from different DAGs: {self.dag} and {task.dag}"
+                    )
+                task.dag = self.dag
             if task.children:
                 raise AirflowException("Cannot add a non-empty TaskGroup")
 
-        self.children[key] = task  # type: ignore
+        self.children[key] = task
+        task.task_group = weakref.proxy(self)
+
+    def _remove(self, task: DAGNode) -> None:
+        key = task.node_id
+
+        if key not in self.children:
+            raise KeyError(f"Node id {key!r} not part of this task group")
+
+        self.used_group_ids.remove(key)
+        del self.children[key]
+        task.task_group = None
 
     @property
     def group_id(self) -> Optional[str]:
@@ -207,8 +250,6 @@ def update_relative(self, other: DependencyMixin, upstream=True) -> None:
         Update upstream_group_ids/downstream_group_ids/upstream_task_ids/downstream_task_ids
         accordingly so that we can reduce the number of edges when displaying Graph view.
         """
-        from airflow.models.baseoperator import BaseOperator
-
         if isinstance(other, TaskGroup):
             # Handles setting relationship between a TaskGroup and another TaskGroup
             if upstream:
@@ -221,19 +262,22 @@
         else:
             # Handles setting relationship between a TaskGroup and a task
             for task in other.roots:
-                if not isinstance(task, BaseOperator):
+                if not isinstance(task, DAGNode):
                     raise AirflowException(
                         "Relationships can only be set between TaskGroup "
                         f"or operators; received {task.__class__.__name__}"
                     )
 
                 if upstream:
-                    self.upstream_task_ids.add(task.task_id)
+                    self.upstream_task_ids.add(task.node_id)
                 else:
-                    self.downstream_task_ids.add(task.task_id)
+                    self.downstream_task_ids.add(task.node_id)
 
-    def _set_relative(
-        self, task_or_task_list: Union["DependencyMixin", Sequence["DependencyMixin"]], upstream: bool = False
+    def _set_relatives(
+        self,
+        task_or_task_list: Union["DependencyMixin", Sequence["DependencyMixin"]],
+        upstream: bool = False,
+        edge_modifier: Optional["EdgeModifier"] = None,
     ) -> None:
         """
         Call set_upstream/set_downstream for all root/leaf tasks within this TaskGroup.
@@ -252,16 +296,6 @@ def _set_relative( for task_like in task_or_task_list: self.update_relative(task_like, upstream) - def set_downstream( - self, task_or_task_list: Union["DependencyMixin", Sequence["DependencyMixin"]] - ) -> None: - """Set a TaskGroup/task/list of task downstream of this TaskGroup.""" - self._set_relative(task_or_task_list, upstream=False) - - def set_upstream(self, task_or_task_list: Union["DependencyMixin", Sequence["DependencyMixin"]]) -> None: - """Set a TaskGroup/task/list of task upstream of this TaskGroup.""" - self._set_relative(task_or_task_list, upstream=True) - def __enter__(self) -> "TaskGroup": TaskGroupContext.push_context_managed_task_group(self) return self @@ -348,10 +382,38 @@ def build_map(task_group): build_map(self) return task_group_map - def get_child_by_label(self, label: str) -> Union["BaseOperator", "TaskGroup"]: + def get_child_by_label(self, label: str) -> DAGNode: """Get a child task/TaskGroup by its label (i.e. task_id/group_id without the group_id prefix)""" return self.children[self.child_id(label)] + def map(self, arg: Iterable) -> "MappedTaskGroup": + if self.children: + raise RuntimeError("Cannot map a TaskGroup that already has children") + if not self.group_id: + raise RuntimeError("Cannot map a TaskGroup before it has a group_id") + if self._parent_group: + self._parent_group._remove(self) + return MappedTaskGroup(group_id=self._group_id, mapped_arg=arg) + + +class MappedTaskGroup(TaskGroup): + """ + A TaskGroup that is dynamically expanded at run time. + + Do not create instances of this class directly, instead use :meth:`TaskGroup.map` + """ + + mapped_arg: Any = NOTSET + mapped_kwargs: Dict[str, Any] + partial_kwargs: Dict[str, Any] + + def __init__(self, group_id: Optional[str] = None, mapped_arg: Any = NOTSET, **kwargs): + if mapped_arg is not NOTSET: + self.mapped_arg = mapped_arg + self.mapped_kwargs = {} + self.partial_kwargs = {} + super().__init__(group_id=group_id, **kwargs) + class TaskGroupContext: """TaskGroup context is used to keep the current TaskGroup when TaskGroup is used as ContextManager.""" diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index fdad20716ebb5..eebe41da31275 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -1408,6 +1408,7 @@ unicode unittest unittests unix +unmapped unpause unpausing unpredicted diff --git a/tests/decorators/test_python.py b/tests/decorators/test_python.py index 610fa84db7b15..48f6c144ed24a 100644 --- a/tests/decorators/test_python.py +++ b/tests/decorators/test_python.py @@ -25,6 +25,7 @@ from airflow.decorators import task as task_decorator from airflow.exceptions import AirflowException from airflow.models import DAG, DagRun, TaskInstance as TI +from airflow.models.baseoperator import MappedOperator from airflow.models.xcom_arg import XComArg from airflow.utils import timezone from airflow.utils.session import create_session @@ -109,7 +110,7 @@ def test_python_operator_python_callable_is_callable(self): """Tests that @task will only instantiate if the python_callable argument is callable.""" not_callable = {} - with pytest.raises(AirflowException): + with pytest.raises(TypeError): task_decorator(not_callable, dag=self.dag) def test_infer_multiple_outputs_using_typing(self): @@ -201,7 +202,7 @@ def add_number(num: int) -> int: def test_fail_method(self): """Tests that @task will fail if signature is not binding.""" - with pytest.raises(AirflowException): + with pytest.raises(TypeError): class Test: num = 2 @@ -210,8 +211,6 @@ 
class Test:
                 num = 2
 
                 @task_decorator
                 def add_number(self, num: int) -> int:
                     return self.num + num
 
-            Test().add_number(2)
-
     def test_fail_multiple_outputs_key_type(self):
         @task_decorator(multiple_outputs=True)
         def add_number(num: int):
@@ -498,3 +497,64 @@ def add_2(number: int):
 
     ret = add_2(test_number)
     assert ret.operator.doc_md.strip(), "Adds 2 to number."
+
+
+def test_mapped_decorator() -> None:
+    @task_decorator
+    def double(number: int):
+        return number * 2
+
+    with DAG('test_dag', start_date=DEFAULT_DATE):
+        literal = [1, 2, 3]
+        doubled_0 = double.map(number=literal)
+        doubled_1 = double.map(number=literal)
+
+    assert isinstance(doubled_0, XComArg)
+    assert isinstance(doubled_0.operator, MappedOperator)
+    assert doubled_0.operator.task_id == "double"
+    assert doubled_0.operator.mapped_kwargs == {"number": literal}
+
+    assert doubled_1.operator.task_id == "double__1"
+
+
+def test_mapped_decorator_invalid_args() -> None:
+    @task_decorator
+    def double(number: int):
+        return number * 2
+
+    with DAG('test_dag', start_date=DEFAULT_DATE):
+        literal = [1, 2, 3]
+
+        with pytest.raises(TypeError, match="arguments 'other', 'b'"):
+            double.partial(other=1, b='a')
+        with pytest.raises(TypeError, match="argument 'other'"):
+            double.map(number=literal, other=1)
+
+
+def test_partial_mapped_decorator() -> None:
+    @task_decorator
+    def product(number: int, multiple: int):
+        return number * multiple
+
+    with DAG('test_dag', start_date=DEFAULT_DATE) as dag:
+        literal = [1, 2, 3]
+        quadrupled = product.partial(task_id='times_4', multiple=3).map(number=literal)
+        doubled = product.partial(multiple=2).map(number=literal)
+        tripled = product.partial(multiple=3).map(number=literal)
+
+        product.partial(multiple=2)
+
+    assert isinstance(doubled, XComArg)
+    assert isinstance(doubled.operator, MappedOperator)
+    assert doubled.operator.task_id == "product"
+    assert doubled.operator.mapped_kwargs == {"number": literal}
+    assert doubled.operator.partial_kwargs == {"task_id": "product", "multiple": 2}
+
+    assert tripled.operator.task_id == "product__1"
+    assert tripled.operator.partial_kwargs == {"task_id": "product", "multiple": 3}
+
+    assert quadrupled.operator.task_id == "times_4"
+
+    assert doubled.operator is not tripled.operator
+
+    assert [quadrupled.operator, doubled.operator, tripled.operator] == dag.tasks
diff --git a/tests/models/test_baseoperator.py b/tests/models/test_baseoperator.py
index cabe1b58785a5..c0a4593646100 100644
--- a/tests/models/test_baseoperator.py
+++ b/tests/models/test_baseoperator.py
@@ -22,13 +22,20 @@
 from unittest import mock
 
 import jinja2
+import pendulum
 import pytest
 
 from airflow.decorators import task as task_decorator
 from airflow.exceptions import AirflowException
 from airflow.lineage.entities import File
 from airflow.models import DAG
-from airflow.models.baseoperator import BaseOperator, BaseOperatorMeta, chain, cross_downstream
+from airflow.models.baseoperator import (
+    BaseOperator,
+    BaseOperatorMeta,
+    MappedOperator,
+    chain,
+    cross_downstream,
+)
 from airflow.utils.context import Context
 from airflow.utils.edgemodifier import Label
 from airflow.utils.task_group import TaskGroup
@@ -640,3 +647,78 @@ def test_operator_retries(caplog, dag_maker, retries, expected):
         retries=retries,
     )
     assert caplog.record_tuples == expected
+
+
+def test_task_mapping_with_dag():
+    with DAG("test-dag", start_date=DEFAULT_DATE) as dag:
+        task1 = BaseOperator(task_id="op1")
+        literal = ['a', 'b', 'c']
+        mapped = MockOperator(task_id='task_2').map(arg2=literal)
+        finish = MockOperator(task_id="finish")
+
+        task1 >> mapped >> finish
+
+    assert task1.downstream_list == [mapped]
+    assert mapped in dag.tasks
+    # At parse time there should only be three tasks!
+    assert len(dag.tasks) == 3
+
+    assert finish.upstream_list == [mapped]
+    assert mapped.downstream_list == [finish]
+
+
+def test_task_mapping_without_dag_context():
+    with DAG("test-dag", start_date=DEFAULT_DATE) as dag:
+        task1 = BaseOperator(task_id="op1")
+    literal = ['a', 'b', 'c']
+    mapped = MockOperator(task_id='task_2').map(arg2=literal)
+
+    task1 >> mapped
+
+    assert isinstance(mapped, MappedOperator)
+    assert mapped in dag.tasks
+    assert task1.downstream_list == [mapped]
+    # At parse time there should only be two tasks!
+    assert len(dag.tasks) == 2
+
+
+def test_task_mapping_default_args():
+    default_args = {'start_date': DEFAULT_DATE.now(), 'owner': 'test'}
+    with DAG("test-dag", start_date=DEFAULT_DATE, default_args=default_args):
+        task1 = BaseOperator(task_id="op1")
+        literal = ['a', 'b', 'c']
+        mapped = MockOperator(task_id='task_2').map(arg2=literal)
+
+        task1 >> mapped
+
+    assert mapped.partial_kwargs['owner'] == 'test'
+    assert mapped.start_date == pendulum.instance(default_args['start_date'])
+
+
+def test_map_unknown_arg_raises():
+    with pytest.raises(TypeError, match=r"argument 'file'"):
+        BaseOperator(task_id='a').map(file=[1, 2, {'a': 'b'}])
+
+
+def test_partial_on_instance() -> None:
+    """`.partial` on an instance should fail -- it is only designed to be called on classes."""
+    with pytest.raises(TypeError):
+        MockOperator(task_id='a').partial()
+
+
+def test_partial_on_class() -> None:
+    # Test that we accept args for superclasses too
+    op = MockOperator.partial(task_id='a', arg1="a", trigger_rule=TriggerRule.ONE_FAILED)
+    assert op.partial_kwargs == {'arg1': 'a', 'trigger_rule': TriggerRule.ONE_FAILED}
+
+
+def test_partial_on_class_invalid_ctor_args() -> None:
+    """Test that invalid args passed to partial() -- i.e. args not known on the class
+    or any of its parent classes -- raise an error at parse time.
+    """
+    with pytest.raises(TypeError, match=r"arguments 'foo', 'bar'"):
+        MockOperator.partial(task_id='a', foo='bar', bar=2)
diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py
index d1e6e794ef72e..1729f8bee2de4 100644
--- a/tests/serialization/test_dag_serialization.py
+++ b/tests/serialization/test_dag_serialization.py
@@ -1094,14 +1094,12 @@ def test_no_new_fields_added_to_base_operator(self):
         tests should be added for it.
""" base_operator = BaseOperator(task_id="10") - fields = base_operator.__dict__ + fields = {k: v for (k, v) in vars(base_operator).items() if k in BaseOperator.get_serialized_fields()} assert fields == { - '_BaseOperator__instantiated': True, '_downstream_task_ids': set(), '_inlets': [], '_log': base_operator.log, '_outlets': [], - '_upstream_task_ids': set(), '_pre_execute_hook': None, '_post_execute_hook': None, 'depends_on_past': False, @@ -1114,10 +1112,8 @@ def test_no_new_fields_added_to_base_operator(self): 'email': None, 'email_on_failure': True, 'email_on_retry': True, - 'end_date': None, 'execution_timeout': None, 'executor_config': {}, - 'inlets': [], 'label': '10', 'max_active_tis_per_dag': None, 'max_retry_delay': None, @@ -1125,7 +1121,6 @@ def test_no_new_fields_added_to_base_operator(self): 'on_failure_callback': None, 'on_retry_callback': None, 'on_success_callback': None, - 'outlets': [], 'owner': 'airflow', 'params': {}, 'pool': 'default_pool', @@ -1138,7 +1133,6 @@ def test_no_new_fields_added_to_base_operator(self): 'retry_exponential_backoff': False, 'run_as_user': None, 'sla': None, - 'start_date': None, 'task_id': '10', 'trigger_rule': 'all_success', 'wait_for_downstream': False, diff --git a/tests/utils/test_task_group.py b/tests/utils/test_task_group.py index 2dcef95885983..51e0319e49963 100644 --- a/tests/utils/test_task_group.py +++ b/tests/utils/test_task_group.py @@ -21,13 +21,16 @@ from airflow.decorators import dag, task_group as task_group_decorator from airflow.models import DAG +from airflow.models.baseoperator import MappedOperator from airflow.models.xcom_arg import XComArg from airflow.operators.bash import BashOperator from airflow.operators.dummy import DummyOperator from airflow.operators.python import PythonOperator from airflow.utils.dates import days_ago -from airflow.utils.task_group import TaskGroup +from airflow.utils.task_group import MappedTaskGroup, TaskGroup from airflow.www.views import dag_edges, task_group_to_dict +from tests.models import DEFAULT_DATE +from tests.test_utils.mock_operators import MockOperator EXPECTED_JSON = { 'id': None, @@ -998,3 +1001,94 @@ def wrap(): assert isinstance(total_3, XComArg) wrap() + + +def test_map() -> None: + with DAG("test-dag", start_date=DEFAULT_DATE) as dag: + start = MockOperator(task_id="start") + end = MockOperator(task_id="end") + literal = ['a', 'b', 'c'] + with TaskGroup("process_one").map(literal) as process_one: + one = MockOperator(task_id='one') + two = MockOperator(task_id='two') + three = MockOperator(task_id='three') + + one >> two >> three + + start >> process_one >> end + + # check the mapped operators are attached to the task broup + assert isinstance(process_one, MappedTaskGroup) + assert process_one.has_task(one) + assert process_one.mapped_arg is literal + + assert isinstance(one, MappedOperator) + assert start.downstream_list == [one] + assert one in dag.tasks + # At parse time there should only be two tasks! 
+    assert len(dag.tasks) == 5
+
+    assert end.upstream_list == [three]
+    assert three.downstream_list == [end]
+
+
+def test_nested_map() -> None:
+    with DAG("test-dag", start_date=DEFAULT_DATE):
+        start = MockOperator(task_id="start")
+        end = MockOperator(task_id="end")
+        literal = ['a', 'b', 'c']
+        with TaskGroup("process_one").map(literal) as process_one:
+            one = MockOperator(task_id='one')
+
+            with TaskGroup("process_two").map(literal) as process_one_two:
+                two = MockOperator(task_id='two')
+                three = MockOperator(task_id='three')
+                two >> three
+
+            four = MockOperator(task_id='four')
+            one >> process_one_two >> four
+
+        start >> process_one >> end
+
+
+def test_decorator_unknown_args():
+    """Test that unknown args passed to the decorator cause an error at parse time."""
+    with pytest.raises(TypeError):
+
+        @task_group_decorator(b=2)
+        def tg():
+            ...
+
+
+def test_decorator_partial_unmapped():
+    @task_group_decorator
+    def tg():
+        ...
+
+    with pytest.warns(UserWarning, match='was never mapped'):
+        with DAG("test-dag", start_date=DEFAULT_DATE):
+            tg.partial()
+
+
+def test_decorator_map():
+    @task_group_decorator
+    def my_task_group(my_arg_1: str, unmapped: bool):
+        assert unmapped is True
+        assert isinstance(my_arg_1, object)
+        task_1 = DummyOperator(task_id="task_1")
+        task_2 = BashOperator(task_id="task_2", bash_command='echo "${my_arg_1}"', env={'my_arg_1': my_arg_1})
+        task_3 = DummyOperator(task_id="task_3")
+        task_1 >> [task_2, task_3]
+
+        return task_1, task_2, task_3
+
+    with DAG("test-dag", start_date=DEFAULT_DATE) as dag:
+        lines = ["foo", "bar", "baz"]
+
+        (task_1, task_2, task_3) = my_task_group.partial(unmapped=True).map(my_arg_1=lines)
+
+    assert task_1 in dag.tasks
+
+    tg = dag.task_group.get_child_by_label("my_task_group")
+    assert isinstance(tg, MappedTaskGroup)
+    assert "my_arg_1" in tg.mapped_kwargs
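
A minimal usage sketch of the map()/partial() API these tests exercise, assuming only the behaviour shown in this diff; the DAG id, start date, and the double task are illustrative names, not taken from the change itself:

    from datetime import datetime

    from airflow.decorators import task
    from airflow.models import DAG

    with DAG("example_mapping", start_date=datetime(2022, 1, 1)):

        @task
        def double(number: int, offset: int = 0):
            # Runs once per element of the mapped argument at execution time.
            return number * 2 + offset

        # partial() pins keyword arguments shared by every mapped task
        # instance; map() supplies the per-instance argument and returns an
        # XComArg wrapping a single MappedOperator at parse time.
        doubled = double.partial(offset=1).map(number=[1, 2, 3])

    # Classic operators follow the same pattern at the class level, e.g.
    # MyOperator.partial(task_id="t", common=1).map(arg=[...]), where
    # MyOperator is a hypothetical operator class, not part of this diff.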