Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
187 changes: 145 additions & 42 deletions airflow/decorators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,28 @@

import functools
import inspect
import itertools
import re
from inspect import signature
from typing import Any, Callable, Collection, Dict, Mapping, Optional, Sequence, Type, TypeVar, cast

from typing import (
Any,
Callable,
Collection,
Dict,
Generic,
Mapping,
Optional,
Sequence,
Set,
Type,
TypeVar,
cast,
)

import attr

from airflow.compat.functools import cached_property
from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
from airflow.models.baseoperator import BaseOperator, MappedOperator
from airflow.models.dag import DAG, DagContext
from airflow.models.xcom_arg import XComArg
from airflow.utils.context import Context
Expand All @@ -39,7 +55,7 @@ def validate_python_callable(python_callable):
"""
if not callable(python_callable):
raise TypeError('`python_callable` param must be callable')
if 'self' in signature(python_callable).parameters.keys():
if 'self' in inspect.signature(python_callable).parameters.keys():
raise AirflowException('@task does not support methods')


Expand Down Expand Up @@ -127,7 +143,7 @@ def __init__(
op_kwargs = op_kwargs or {}

# Check that arguments can be binded
signature(python_callable).bind(*op_args, **op_kwargs)
inspect.signature(python_callable).bind(*op_args, **op_kwargs)
self.multiple_outputs = multiple_outputs
self.op_args = op_args
self.op_kwargs = op_kwargs
Expand Down Expand Up @@ -169,7 +185,7 @@ def _hook_apply_defaults(self, *args, **kwargs):
python_callable = kwargs['python_callable']
default_args = kwargs.get('default_args') or {}
op_kwargs = kwargs.get('op_kwargs') or {}
f_sig = signature(python_callable)
f_sig = inspect.signature(python_callable)
for arg in f_sig.parameters:
if arg not in op_kwargs and arg in default_args:
op_kwargs[arg] = default_args[arg]
Expand All @@ -179,11 +195,113 @@ def _hook_apply_defaults(self, *args, **kwargs):

T = TypeVar("T", bound=Callable)

OperatorSubclass = TypeVar("OperatorSubclass", bound="BaseOperator")


@attr.define(slots=False)
class _TaskDecorator(Generic[T, OperatorSubclass]):
"""
Helper class for providing dynamic task mapping to decorated functions.

``task_decorator_factory`` returns an instance of this, instead of just a plain wrapped function.

:meta private:
"""

function: T = attr.ib(validator=attr.validators.is_callable())
operator_class: Type[OperatorSubclass]
multiple_outputs: bool = attr.ib()
kwargs: Dict[str, Any] = attr.ib(factory=dict)

decorator_name: str = attr.ib(repr=False, default="task")

@cached_property
def function_signature(self):
return inspect.signature(self.function)

@cached_property
def function_arg_names(self) -> Set[str]:
return set(self.function_signature.parameters)

@function.validator
def _validate_function(self, _, f):
if 'self' in self.function_arg_names:
raise TypeError(f'@{self.decorator_name} does not support methods')

@multiple_outputs.default
def _infer_multiple_outputs(self):
return_type = self.function_signature.return_annotation
ttype = getattr(return_type, "__origin__", None)

return return_type is not inspect.Signature.empty and ttype in (dict, Dict)

def __attrs_post_init__(self):
self.kwargs.setdefault('task_id', self.function.__name__)

def __call__(self, *args, **kwargs) -> XComArg:
op = self.operator_class(
python_callable=self.function,
op_args=args,
op_kwargs=kwargs,
multiple_outputs=self.multiple_outputs,
**self.kwargs,
)
if self.function.__doc__:
op.doc_md = self.function.__doc__
return XComArg(op)

def _validate_arg_names(self, funcname: str, kwargs: Dict[str, Any], valid_names: Set[str] = set()):
unknown_args = kwargs.copy()
for name in itertools.chain(self.function_arg_names, valid_names):
unknown_args.pop(name, None)

if not unknown_args:
# If we have no args left ot check, we are valid
return

if len(unknown_args) == 1:
raise TypeError(f'{funcname} got unexpected keyword argument {next(iter(unknown_args))!r}')
else:
names = ", ".join(repr(n) for n in unknown_args)
raise TypeError(f'{funcname} got unexpected keyword arguments {names}')

def map(
self, *, dag: Optional["DAG"] = None, task_group: Optional["TaskGroup"] = None, **kwargs
) -> XComArg:

dag = dag or DagContext.get_current_dag()
task_group = task_group or TaskGroupContext.get_current_task_group(dag)
task_id = get_unique_task_id(self.kwargs['task_id'], dag, task_group)

self._validate_arg_names("map", kwargs)

operator = MappedOperator(
operator_class=self.operator_class,
task_id=task_id,
dag=dag,
task_group=task_group,
partial_kwargs=self.kwargs,
# Set them to empty to bypass the validation, as for decorated stuff we validate ourselves
mapped_kwargs={},
)
operator.mapped_kwargs.update(kwargs)

return XComArg(operator=operator)

def partial(
self, *, dag: Optional["DAG"] = None, task_group: Optional["TaskGroup"] = None, **kwargs
) -> "_TaskDecorator[T, OperatorSubclass]":
self._validate_arg_names("partial", kwargs, {'task_id'})
partial_kwargs = self.kwargs.copy()
partial_kwargs.update(kwargs)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This makes me wonder, what should MyOperator.map(x=something).map(x=another) do? If I understand this correctly, this would currently discard something and just map to another. We should likely add something in to prevent this from happening, perhaps in _validate_arg_names?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, that's an unspecified point of the API. I think I'm leaning towards .map().map() being an error on general grounds.

But yes, in my head updating the params was my intent.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I also think map().map() should be an error. We already agreed on suporting map(arg1, arg2) and:

There should be one-- and preferably only one --obvious way to do it.

Copy link
Member Author

@ashb ashb Dec 17, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm unliekly to be able to make this change before stopping for Christmas, so either someone else can make this, or we can merge it with this and fix it later.

(I think that having a .map() function that returns an error would be clearer than having no map method, similar to how I have .partial() on a Task object throw an error still.)

return attr.evolve(self, kwargs=partial_kwargs)


def task_decorator_factory(
python_callable: Optional[Callable] = None,
*,
multiple_outputs: Optional[bool] = None,
decorated_operator_class: Optional[Type[BaseOperator]] = None,
decorated_operator_class: Type[BaseOperator],
**kwargs,
) -> Callable[[T], T]:
"""
Expand All @@ -202,38 +320,23 @@ def task_decorator_factory(
:type decorated_operator_class: BaseOperator

"""
# try to infer from type annotation
if python_callable and multiple_outputs is None:
sig = signature(python_callable).return_annotation
ttype = getattr(sig, "__origin__", None)

multiple_outputs = sig != inspect.Signature.empty and ttype in (dict, Dict)

def wrapper(f: T):
"""
Python wrapper to generate PythonDecoratedOperator out of simple python functions.
Used for Airflow Decorated interface
"""
validate_python_callable(f)
kwargs.setdefault('task_id', f.__name__)

@functools.wraps(f)
def factory(*args, **f_kwargs):
op = decorated_operator_class(
python_callable=f,
op_args=args,
op_kwargs=f_kwargs,
multiple_outputs=multiple_outputs,
**kwargs,
)
if f.__doc__:
op.doc_md = f.__doc__
return XComArg(op)

return cast(T, factory)

if callable(python_callable):
return wrapper(python_callable)
if multiple_outputs is None:
multiple_outputs = cast(bool, attr.NOTHING)
if python_callable:
return _TaskDecorator( # type: ignore
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm.. _TaskDecorator is bound to callable so ignore should not be needed I guess? Is the type: ignore here about the generics

function=python_callable,
multiple_outputs=multiple_outputs,
operator_class=decorated_operator_class,
kwargs=kwargs,
)
elif python_callable is not None:
raise AirflowException('No args allowed while using @task, use kwargs instead')
return wrapper
raise TypeError('No args allowed while using @task, use kwargs instead')
return cast(
"Callable[[T], T]",
functools.partial(
_TaskDecorator,
multiple_outputs=multiple_outputs,
operator_class=decorated_operator_class,
kwargs=kwargs,
),
)
115 changes: 86 additions & 29 deletions airflow/decorators/task_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,19 +20,101 @@
together when the DAG is displayed graphically.
"""
import functools
import warnings
from inspect import signature
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, TypeVar, overload
from typing import TYPE_CHECKING, Any, Callable, Dict, Generic, Optional, TypeVar, cast, overload

from airflow.utils.task_group import TaskGroup
import attr

from airflow.utils.task_group import MappedTaskGroup, TaskGroup

if TYPE_CHECKING:
from airflow.models import DAG

F = TypeVar("F", bound=Callable[..., Any])
T = TypeVar("T", bound=Callable)
R = TypeVar("R")

task_group_sig = signature(TaskGroup.__init__)


@attr.define
class TaskGroupDecorator(Generic[R]):
""":meta private:"""

function: Callable[..., R] = attr.ib(validator=attr.validators.is_callable())
kwargs: Dict[str, Any] = attr.ib(factory=dict)
"""kwargs for the TaskGroup"""

@function.validator
def _validate_function(self, _, f):
if 'self' in signature(f).parameters:
raise TypeError('@task_group does not support methods')

@kwargs.validator
def _validate(self, _, kwargs):
task_group_sig.bind_partial(**kwargs)

def __attrs_post_init__(self):
self.kwargs.setdefault('group_id', self.function.__name__)

def _make_task_group(self, **kwargs) -> TaskGroup:
return TaskGroup(**kwargs)

def __call__(self, *args, **kwargs) -> R:
with self._make_task_group(add_suffix_on_collision=True, **self.kwargs):
# Invoke function to run Tasks inside the TaskGroup
return self.function(*args, **kwargs)

def partial(self, **kwargs) -> "MappedTaskGroupDecorator[R]":
return MappedTaskGroupDecorator(function=self.function, kwargs=self.kwargs).partial(**kwargs)

def map(self, **kwargs) -> R:
return MappedTaskGroupDecorator(function=self.function, kwargs=self.kwargs).map(**kwargs)


@attr.define
class MappedTaskGroupDecorator(TaskGroupDecorator[R]):
""":meta private:"""

partial_kwargs: Dict[str, Any] = attr.ib(factory=dict)
"""static kwargs for the decorated function"""
mapped_kwargs: Dict[str, Any] = attr.ib(factory=dict)
"""kwargs for the decorated function"""

def __call__(self, *args, **kwargs):
raise RuntimeError("A mapped @task_group cannot be called. Use `.map` and `.partial` instead")

def _make_task_group(self, **kwargs) -> MappedTaskGroup:
tg = MappedTaskGroup(**kwargs)
tg.partial_kwargs = self.partial_kwargs
tg.mapped_kwargs = self.mapped_kwargs
return tg

def partial(self, **kwargs) -> "MappedTaskGroupDecorator[R]":
if self.partial_kwargs:
raise RuntimeError("Already a partial task group")
self.partial_kwargs.update(kwargs)
return self

def map(self, **kwargs) -> R:
if self.mapped_kwargs:
raise RuntimeError("Already a mapped task group")
self.mapped_kwargs = kwargs

call_kwargs = self.partial_kwargs.copy()
duplicated_keys = set(call_kwargs).intersection(kwargs)
if duplicated_keys:
raise RuntimeError(f"Cannot map partial arguments: {', '.join(sorted(duplicated_keys))}")
call_kwargs.update({k: object() for k in kwargs})

return super().__call__(**call_kwargs)

def __del__(self):
if not self.mapped_kwargs:
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Edge case (that I think is fine): This would mean that tg.partial(a=2).map() wouldn't work -- as the mapped_kwargs would be an empty dict and evaluate to false.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we should even disable map() and partial() without any arguments (because how’d they make sense…?) but let’s not bother.

warnings.warn(f"Partial task group {self.function.__name__} was never mapped!")


# This covers the @task_group() case. Annotations are copied from the TaskGroup
# class, only providing a default to 'group_id' (this is optional for the
# decorator and defaults to the decorated function's name). Please keep them in
Expand Down Expand Up @@ -73,31 +155,6 @@ def task_group(python_callable=None, *tg_args, **tg_kwargs):
:param tg_args: Positional arguments for the TaskGroup object.
:param tg_kwargs: Keyword arguments for the TaskGroup object.
"""

def wrapper(f):
# Setting group_id as function name if not given in kwarg group_id
if not tg_args and 'group_id' not in tg_kwargs:
tg_kwargs['group_id'] = f.__name__
task_group_bound_args = task_group_sig.bind_partial(*tg_args, **tg_kwargs)

@functools.wraps(f)
def factory(*args, **kwargs):
# Generate signature for decorated function and bind the arguments when called
# we do this to extract parameters so we can annotate them on the DAG object.
# In addition, this fails if we are missing any args/kwargs with TypeError as expected.
# Apply defaults to capture default values if set.

# Initialize TaskGroup with bound arguments
with TaskGroup(
*task_group_bound_args.args,
add_suffix_on_collision=True,
**task_group_bound_args.kwargs,
):
# Invoke function to run Tasks inside the TaskGroup
return f(*args, **kwargs)

return factory

if callable(python_callable):
return wrapper(python_callable)
return wrapper
return TaskGroupDecorator(function=python_callable, kwargs=tg_kwargs)
return cast("Callable[[T], T]", functools.partial(TaskGroupDecorator, kwargs=tg_kwargs))
Loading