-
Notifications
You must be signed in to change notification settings - Fork 16.4k
Map and Partial DAG authoring interface for Dynamic Task Mapping #19965
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
423bf9b
da478ef
69fd94a
7256a37
80a0d92
273f32c
25a7fdd
ad8f047
27b12cd
e2fe0c3
b246c7b
2362c78
6c21820
efbc0b2
deb54df
16df4ab
364843f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,12 +17,28 @@ | |
|
|
||
| import functools | ||
| import inspect | ||
| import itertools | ||
| import re | ||
| from inspect import signature | ||
| from typing import Any, Callable, Collection, Dict, Mapping, Optional, Sequence, Type, TypeVar, cast | ||
|
|
||
| from typing import ( | ||
| Any, | ||
| Callable, | ||
| Collection, | ||
| Dict, | ||
| Generic, | ||
| Mapping, | ||
| Optional, | ||
| Sequence, | ||
| Set, | ||
| Type, | ||
| TypeVar, | ||
| cast, | ||
| ) | ||
|
|
||
| import attr | ||
|
|
||
| from airflow.compat.functools import cached_property | ||
| from airflow.exceptions import AirflowException | ||
| from airflow.models import BaseOperator | ||
| from airflow.models.baseoperator import BaseOperator, MappedOperator | ||
| from airflow.models.dag import DAG, DagContext | ||
| from airflow.models.xcom_arg import XComArg | ||
| from airflow.utils.context import Context | ||
|
|
@@ -39,7 +55,7 @@ def validate_python_callable(python_callable): | |
| """ | ||
| if not callable(python_callable): | ||
| raise TypeError('`python_callable` param must be callable') | ||
| if 'self' in signature(python_callable).parameters.keys(): | ||
| if 'self' in inspect.signature(python_callable).parameters.keys(): | ||
| raise AirflowException('@task does not support methods') | ||
|
|
||
|
|
||
|
|
@@ -127,7 +143,7 @@ def __init__( | |
| op_kwargs = op_kwargs or {} | ||
|
|
||
| # Check that arguments can be binded | ||
| signature(python_callable).bind(*op_args, **op_kwargs) | ||
| inspect.signature(python_callable).bind(*op_args, **op_kwargs) | ||
| self.multiple_outputs = multiple_outputs | ||
| self.op_args = op_args | ||
| self.op_kwargs = op_kwargs | ||
|
|
@@ -169,7 +185,7 @@ def _hook_apply_defaults(self, *args, **kwargs): | |
| python_callable = kwargs['python_callable'] | ||
| default_args = kwargs.get('default_args') or {} | ||
| op_kwargs = kwargs.get('op_kwargs') or {} | ||
| f_sig = signature(python_callable) | ||
| f_sig = inspect.signature(python_callable) | ||
| for arg in f_sig.parameters: | ||
| if arg not in op_kwargs and arg in default_args: | ||
| op_kwargs[arg] = default_args[arg] | ||
|
|
@@ -179,11 +195,113 @@ def _hook_apply_defaults(self, *args, **kwargs): | |
|
|
||
| T = TypeVar("T", bound=Callable) | ||
|
|
||
| OperatorSubclass = TypeVar("OperatorSubclass", bound="BaseOperator") | ||
|
|
||
|
|
||
| @attr.define(slots=False) | ||
| class _TaskDecorator(Generic[T, OperatorSubclass]): | ||
| """ | ||
| Helper class for providing dynamic task mapping to decorated functions. | ||
|
|
||
| ``task_decorator_factory`` returns an instance of this, instead of just a plain wrapped function. | ||
|
|
||
| :meta private: | ||
| """ | ||
|
|
||
| function: T = attr.ib(validator=attr.validators.is_callable()) | ||
| operator_class: Type[OperatorSubclass] | ||
| multiple_outputs: bool = attr.ib() | ||
| kwargs: Dict[str, Any] = attr.ib(factory=dict) | ||
|
|
||
| decorator_name: str = attr.ib(repr=False, default="task") | ||
|
|
||
| @cached_property | ||
| def function_signature(self): | ||
| return inspect.signature(self.function) | ||
|
|
||
| @cached_property | ||
| def function_arg_names(self) -> Set[str]: | ||
| return set(self.function_signature.parameters) | ||
|
|
||
| @function.validator | ||
| def _validate_function(self, _, f): | ||
| if 'self' in self.function_arg_names: | ||
| raise TypeError(f'@{self.decorator_name} does not support methods') | ||
|
|
||
| @multiple_outputs.default | ||
| def _infer_multiple_outputs(self): | ||
| return_type = self.function_signature.return_annotation | ||
| ttype = getattr(return_type, "__origin__", None) | ||
|
|
||
| return return_type is not inspect.Signature.empty and ttype in (dict, Dict) | ||
|
|
||
| def __attrs_post_init__(self): | ||
| self.kwargs.setdefault('task_id', self.function.__name__) | ||
|
|
||
| def __call__(self, *args, **kwargs) -> XComArg: | ||
| op = self.operator_class( | ||
| python_callable=self.function, | ||
| op_args=args, | ||
| op_kwargs=kwargs, | ||
| multiple_outputs=self.multiple_outputs, | ||
| **self.kwargs, | ||
| ) | ||
| if self.function.__doc__: | ||
| op.doc_md = self.function.__doc__ | ||
| return XComArg(op) | ||
|
|
||
| def _validate_arg_names(self, funcname: str, kwargs: Dict[str, Any], valid_names: Set[str] = set()): | ||
| unknown_args = kwargs.copy() | ||
| for name in itertools.chain(self.function_arg_names, valid_names): | ||
| unknown_args.pop(name, None) | ||
|
|
||
| if not unknown_args: | ||
| # If we have no args left ot check, we are valid | ||
| return | ||
|
|
||
| if len(unknown_args) == 1: | ||
| raise TypeError(f'{funcname} got unexpected keyword argument {next(iter(unknown_args))!r}') | ||
| else: | ||
| names = ", ".join(repr(n) for n in unknown_args) | ||
| raise TypeError(f'{funcname} got unexpected keyword arguments {names}') | ||
|
|
||
| def map( | ||
| self, *, dag: Optional["DAG"] = None, task_group: Optional["TaskGroup"] = None, **kwargs | ||
| ) -> XComArg: | ||
|
|
||
| dag = dag or DagContext.get_current_dag() | ||
| task_group = task_group or TaskGroupContext.get_current_task_group(dag) | ||
| task_id = get_unique_task_id(self.kwargs['task_id'], dag, task_group) | ||
|
|
||
| self._validate_arg_names("map", kwargs) | ||
|
|
||
| operator = MappedOperator( | ||
| operator_class=self.operator_class, | ||
| task_id=task_id, | ||
| dag=dag, | ||
| task_group=task_group, | ||
| partial_kwargs=self.kwargs, | ||
| # Set them to empty to bypass the validation, as for decorated stuff we validate ourselves | ||
| mapped_kwargs={}, | ||
| ) | ||
| operator.mapped_kwargs.update(kwargs) | ||
|
|
||
| return XComArg(operator=operator) | ||
|
|
||
| def partial( | ||
| self, *, dag: Optional["DAG"] = None, task_group: Optional["TaskGroup"] = None, **kwargs | ||
| ) -> "_TaskDecorator[T, OperatorSubclass]": | ||
| self._validate_arg_names("partial", kwargs, {'task_id'}) | ||
| partial_kwargs = self.kwargs.copy() | ||
| partial_kwargs.update(kwargs) | ||
| return attr.evolve(self, kwargs=partial_kwargs) | ||
|
|
||
|
|
||
| def task_decorator_factory( | ||
| python_callable: Optional[Callable] = None, | ||
| *, | ||
| multiple_outputs: Optional[bool] = None, | ||
| decorated_operator_class: Optional[Type[BaseOperator]] = None, | ||
| decorated_operator_class: Type[BaseOperator], | ||
| **kwargs, | ||
| ) -> Callable[[T], T]: | ||
| """ | ||
|
|
@@ -202,38 +320,23 @@ def task_decorator_factory( | |
| :type decorated_operator_class: BaseOperator | ||
|
|
||
| """ | ||
| # try to infer from type annotation | ||
| if python_callable and multiple_outputs is None: | ||
| sig = signature(python_callable).return_annotation | ||
| ttype = getattr(sig, "__origin__", None) | ||
|
|
||
| multiple_outputs = sig != inspect.Signature.empty and ttype in (dict, Dict) | ||
|
|
||
| def wrapper(f: T): | ||
| """ | ||
| Python wrapper to generate PythonDecoratedOperator out of simple python functions. | ||
| Used for Airflow Decorated interface | ||
| """ | ||
| validate_python_callable(f) | ||
| kwargs.setdefault('task_id', f.__name__) | ||
|
|
||
| @functools.wraps(f) | ||
| def factory(*args, **f_kwargs): | ||
| op = decorated_operator_class( | ||
| python_callable=f, | ||
| op_args=args, | ||
| op_kwargs=f_kwargs, | ||
| multiple_outputs=multiple_outputs, | ||
| **kwargs, | ||
| ) | ||
| if f.__doc__: | ||
| op.doc_md = f.__doc__ | ||
| return XComArg(op) | ||
|
|
||
| return cast(T, factory) | ||
|
|
||
| if callable(python_callable): | ||
| return wrapper(python_callable) | ||
| if multiple_outputs is None: | ||
| multiple_outputs = cast(bool, attr.NOTHING) | ||
| if python_callable: | ||
| return _TaskDecorator( # type: ignore | ||
|
||
| function=python_callable, | ||
| multiple_outputs=multiple_outputs, | ||
| operator_class=decorated_operator_class, | ||
| kwargs=kwargs, | ||
| ) | ||
| elif python_callable is not None: | ||
| raise AirflowException('No args allowed while using @task, use kwargs instead') | ||
| return wrapper | ||
| raise TypeError('No args allowed while using @task, use kwargs instead') | ||
| return cast( | ||
| "Callable[[T], T]", | ||
| functools.partial( | ||
| _TaskDecorator, | ||
| multiple_outputs=multiple_outputs, | ||
| operator_class=decorated_operator_class, | ||
| kwargs=kwargs, | ||
| ), | ||
| ) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -20,19 +20,101 @@ | |
| together when the DAG is displayed graphically. | ||
| """ | ||
| import functools | ||
| import warnings | ||
| from inspect import signature | ||
| from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, TypeVar, overload | ||
| from typing import TYPE_CHECKING, Any, Callable, Dict, Generic, Optional, TypeVar, cast, overload | ||
|
|
||
| from airflow.utils.task_group import TaskGroup | ||
| import attr | ||
|
|
||
| from airflow.utils.task_group import MappedTaskGroup, TaskGroup | ||
|
|
||
| if TYPE_CHECKING: | ||
| from airflow.models import DAG | ||
|
|
||
| F = TypeVar("F", bound=Callable[..., Any]) | ||
| T = TypeVar("T", bound=Callable) | ||
| R = TypeVar("R") | ||
|
|
||
| task_group_sig = signature(TaskGroup.__init__) | ||
|
|
||
|
|
||
| @attr.define | ||
| class TaskGroupDecorator(Generic[R]): | ||
| """:meta private:""" | ||
|
|
||
| function: Callable[..., R] = attr.ib(validator=attr.validators.is_callable()) | ||
| kwargs: Dict[str, Any] = attr.ib(factory=dict) | ||
| """kwargs for the TaskGroup""" | ||
|
|
||
| @function.validator | ||
| def _validate_function(self, _, f): | ||
| if 'self' in signature(f).parameters: | ||
| raise TypeError('@task_group does not support methods') | ||
|
|
||
| @kwargs.validator | ||
| def _validate(self, _, kwargs): | ||
| task_group_sig.bind_partial(**kwargs) | ||
|
|
||
| def __attrs_post_init__(self): | ||
| self.kwargs.setdefault('group_id', self.function.__name__) | ||
|
|
||
| def _make_task_group(self, **kwargs) -> TaskGroup: | ||
| return TaskGroup(**kwargs) | ||
|
|
||
| def __call__(self, *args, **kwargs) -> R: | ||
| with self._make_task_group(add_suffix_on_collision=True, **self.kwargs): | ||
| # Invoke function to run Tasks inside the TaskGroup | ||
| return self.function(*args, **kwargs) | ||
|
|
||
| def partial(self, **kwargs) -> "MappedTaskGroupDecorator[R]": | ||
| return MappedTaskGroupDecorator(function=self.function, kwargs=self.kwargs).partial(**kwargs) | ||
|
|
||
| def map(self, **kwargs) -> R: | ||
| return MappedTaskGroupDecorator(function=self.function, kwargs=self.kwargs).map(**kwargs) | ||
|
|
||
|
|
||
| @attr.define | ||
| class MappedTaskGroupDecorator(TaskGroupDecorator[R]): | ||
| """:meta private:""" | ||
|
|
||
| partial_kwargs: Dict[str, Any] = attr.ib(factory=dict) | ||
| """static kwargs for the decorated function""" | ||
| mapped_kwargs: Dict[str, Any] = attr.ib(factory=dict) | ||
| """kwargs for the decorated function""" | ||
|
|
||
| def __call__(self, *args, **kwargs): | ||
| raise RuntimeError("A mapped @task_group cannot be called. Use `.map` and `.partial` instead") | ||
|
|
||
| def _make_task_group(self, **kwargs) -> MappedTaskGroup: | ||
| tg = MappedTaskGroup(**kwargs) | ||
| tg.partial_kwargs = self.partial_kwargs | ||
| tg.mapped_kwargs = self.mapped_kwargs | ||
| return tg | ||
|
|
||
| def partial(self, **kwargs) -> "MappedTaskGroupDecorator[R]": | ||
| if self.partial_kwargs: | ||
| raise RuntimeError("Already a partial task group") | ||
| self.partial_kwargs.update(kwargs) | ||
| return self | ||
|
|
||
| def map(self, **kwargs) -> R: | ||
| if self.mapped_kwargs: | ||
| raise RuntimeError("Already a mapped task group") | ||
| self.mapped_kwargs = kwargs | ||
|
|
||
| call_kwargs = self.partial_kwargs.copy() | ||
| duplicated_keys = set(call_kwargs).intersection(kwargs) | ||
| if duplicated_keys: | ||
| raise RuntimeError(f"Cannot map partial arguments: {', '.join(sorted(duplicated_keys))}") | ||
| call_kwargs.update({k: object() for k in kwargs}) | ||
|
|
||
| return super().__call__(**call_kwargs) | ||
|
|
||
| def __del__(self): | ||
| if not self.mapped_kwargs: | ||
|
||
| warnings.warn(f"Partial task group {self.function.__name__} was never mapped!") | ||
|
|
||
|
|
||
| # This covers the @task_group() case. Annotations are copied from the TaskGroup | ||
| # class, only providing a default to 'group_id' (this is optional for the | ||
| # decorator and defaults to the decorated function's name). Please keep them in | ||
|
|
@@ -73,31 +155,6 @@ def task_group(python_callable=None, *tg_args, **tg_kwargs): | |
| :param tg_args: Positional arguments for the TaskGroup object. | ||
| :param tg_kwargs: Keyword arguments for the TaskGroup object. | ||
| """ | ||
|
|
||
| def wrapper(f): | ||
| # Setting group_id as function name if not given in kwarg group_id | ||
| if not tg_args and 'group_id' not in tg_kwargs: | ||
| tg_kwargs['group_id'] = f.__name__ | ||
| task_group_bound_args = task_group_sig.bind_partial(*tg_args, **tg_kwargs) | ||
|
|
||
| @functools.wraps(f) | ||
| def factory(*args, **kwargs): | ||
| # Generate signature for decorated function and bind the arguments when called | ||
| # we do this to extract parameters so we can annotate them on the DAG object. | ||
| # In addition, this fails if we are missing any args/kwargs with TypeError as expected. | ||
| # Apply defaults to capture default values if set. | ||
|
|
||
| # Initialize TaskGroup with bound arguments | ||
| with TaskGroup( | ||
| *task_group_bound_args.args, | ||
| add_suffix_on_collision=True, | ||
| **task_group_bound_args.kwargs, | ||
| ): | ||
| # Invoke function to run Tasks inside the TaskGroup | ||
| return f(*args, **kwargs) | ||
|
|
||
| return factory | ||
|
|
||
| if callable(python_callable): | ||
| return wrapper(python_callable) | ||
| return wrapper | ||
| return TaskGroupDecorator(function=python_callable, kwargs=tg_kwargs) | ||
| return cast("Callable[[T], T]", functools.partial(TaskGroupDecorator, kwargs=tg_kwargs)) | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This makes me wonder, what should
MyOperator.map(x=something).map(x=another)do? If I understand this correctly, this would currently discardsomethingand just map toanother. We should likely add something in to prevent this from happening, perhaps in_validate_arg_names?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, that's an unspecified point of the API. I think I'm leaning towards
.map().map()being an error on general grounds.But yes, in my head updating the params was my intent.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I also think
map().map()should be an error. We already agreed on suportingmap(arg1, arg2)and:Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm unliekly to be able to make this change before stopping for Christmas, so either someone else can make this, or we can merge it with this and fix it later.
(I think that having a
.map()function that returns an error would be clearer than having no map method, similar to how I have.partial()on a Task object throw an error still.)