187 changes: 108 additions & 79 deletions airflow/models/baseoperator.py
@@ -90,6 +90,7 @@
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.setup_teardown import SetupTeardownContext
from airflow.utils.trigger_rule import TriggerRule
from airflow.utils.types import NOTSET, ArgNotSet
from airflow.utils.weight_rule import WeightRule
from airflow.utils.xcom import XCOM_RETURN_KEY

@@ -184,50 +185,70 @@ def partial(**kwargs):
return self.class_method.__get__(cls, cls)


_PARTIAL_DEFAULTS = {
"owner": DEFAULT_OWNER,
"trigger_rule": DEFAULT_TRIGGER_RULE,
"depends_on_past": False,
"ignore_first_depends_on_past": DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST,
"wait_for_past_depends_before_skipping": DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING,
"wait_for_downstream": False,
"retries": DEFAULT_RETRIES,
"queue": DEFAULT_QUEUE,
"pool_slots": DEFAULT_POOL_SLOTS,
"execution_timeout": DEFAULT_TASK_EXECUTION_TIMEOUT,
"retry_delay": DEFAULT_RETRY_DELAY,
"retry_exponential_backoff": False,
"priority_weight": DEFAULT_PRIORITY_WEIGHT,
"weight_rule": DEFAULT_WEIGHT_RULE,
"inlets": [],
"outlets": [],
}


# This is what handles the actual mapping.
def partial(
operator_class: type[BaseOperator],
*,
task_id: str,
dag: DAG | None = None,
task_group: TaskGroup | None = None,
start_date: datetime | None = None,
end_date: datetime | None = None,
owner: str = DEFAULT_OWNER,
email: None | str | Iterable[str] = None,
start_date: datetime | ArgNotSet = NOTSET,
end_date: datetime | ArgNotSet = NOTSET,
owner: str | ArgNotSet = NOTSET,
email: None | str | Iterable[str] | ArgNotSet = NOTSET,
params: collections.abc.MutableMapping | None = None,
resources: dict[str, Any] | None = None,
trigger_rule: str = DEFAULT_TRIGGER_RULE,
depends_on_past: bool = False,
ignore_first_depends_on_past: bool = DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST,
wait_for_past_depends_before_skipping: bool = DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING,
wait_for_downstream: bool = False,
retries: int | None = DEFAULT_RETRIES,
queue: str = DEFAULT_QUEUE,
pool: str | None = None,
pool_slots: int = DEFAULT_POOL_SLOTS,
execution_timeout: timedelta | None = DEFAULT_TASK_EXECUTION_TIMEOUT,
max_retry_delay: None | timedelta | float = None,
retry_delay: timedelta | float = DEFAULT_RETRY_DELAY,
retry_exponential_backoff: bool = False,
priority_weight: int = DEFAULT_PRIORITY_WEIGHT,
weight_rule: str = DEFAULT_WEIGHT_RULE,
sla: timedelta | None = None,
max_active_tis_per_dag: int | None = None,
max_active_tis_per_dagrun: int | None = None,
on_execute_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None,
on_failure_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None,
on_success_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None,
on_retry_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None,
run_as_user: str | None = None,
executor_config: dict | None = None,
inlets: Any | None = None,
outlets: Any | None = None,
doc: str | None = None,
doc_md: str | None = None,
doc_json: str | None = None,
doc_yaml: str | None = None,
doc_rst: str | None = None,
resources: dict[str, Any] | None | ArgNotSet = NOTSET,
trigger_rule: str | ArgNotSet = NOTSET,
depends_on_past: bool | ArgNotSet = NOTSET,
ignore_first_depends_on_past: bool | ArgNotSet = NOTSET,
wait_for_past_depends_before_skipping: bool | ArgNotSet = NOTSET,
wait_for_downstream: bool | ArgNotSet = NOTSET,
retries: int | None | ArgNotSet = NOTSET,
queue: str | ArgNotSet = NOTSET,
pool: str | ArgNotSet = NOTSET,
pool_slots: int | ArgNotSet = NOTSET,
execution_timeout: timedelta | None | ArgNotSet = NOTSET,
max_retry_delay: None | timedelta | float | ArgNotSet = NOTSET,
retry_delay: timedelta | float | ArgNotSet = NOTSET,
retry_exponential_backoff: bool | ArgNotSet = NOTSET,
priority_weight: int | ArgNotSet = NOTSET,
weight_rule: str | ArgNotSet = NOTSET,
sla: timedelta | None | ArgNotSet = NOTSET,
max_active_tis_per_dag: int | None | ArgNotSet = NOTSET,
max_active_tis_per_dagrun: int | None | ArgNotSet = NOTSET,
on_execute_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] | ArgNotSet = NOTSET,
on_failure_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] | ArgNotSet = NOTSET,
on_success_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] | ArgNotSet = NOTSET,
on_retry_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] | ArgNotSet = NOTSET,
run_as_user: str | None | ArgNotSet = NOTSET,
executor_config: dict | None | ArgNotSet = NOTSET,
inlets: Any | None | ArgNotSet = NOTSET,
outlets: Any | None | ArgNotSet = NOTSET,
doc: str | None | ArgNotSet = NOTSET,
doc_md: str | None | ArgNotSet = NOTSET,
doc_json: str | None | ArgNotSet = NOTSET,
doc_yaml: str | None | ArgNotSet = NOTSET,
doc_rst: str | None | ArgNotSet = NOTSET,
**kwargs,
) -> OperatorPartial:
from airflow.models.dag import DagContext
@@ -242,54 +263,62 @@ def partial(
task_id = task_group.child_id(task_id)

# Merge DAG and task group level defaults into user-supplied values.
partial_kwargs, partial_params = get_merged_defaults(
dag_default_args, partial_params = get_merged_defaults(
dag=dag,
task_group=task_group,
task_params=params,
task_default_args=kwargs.pop("default_args", None),
)
partial_kwargs.update(kwargs)

# Always fully populate partial kwargs to exclude them from map().
partial_kwargs.setdefault("dag", dag)
partial_kwargs.setdefault("task_group", task_group)
partial_kwargs.setdefault("task_id", task_id)
partial_kwargs.setdefault("start_date", start_date)
partial_kwargs.setdefault("end_date", end_date)
partial_kwargs.setdefault("owner", owner)
partial_kwargs.setdefault("email", email)
partial_kwargs.setdefault("trigger_rule", trigger_rule)
partial_kwargs.setdefault("depends_on_past", depends_on_past)
partial_kwargs.setdefault("ignore_first_depends_on_past", ignore_first_depends_on_past)
partial_kwargs.setdefault("wait_for_past_depends_before_skipping", wait_for_past_depends_before_skipping)
partial_kwargs.setdefault("wait_for_downstream", wait_for_downstream)
partial_kwargs.setdefault("retries", retries)
partial_kwargs.setdefault("queue", queue)
partial_kwargs.setdefault("pool", pool)
partial_kwargs.setdefault("pool_slots", pool_slots)
partial_kwargs.setdefault("execution_timeout", execution_timeout)
partial_kwargs.setdefault("max_retry_delay", max_retry_delay)
partial_kwargs.setdefault("retry_delay", retry_delay)
partial_kwargs.setdefault("retry_exponential_backoff", retry_exponential_backoff)
partial_kwargs.setdefault("priority_weight", priority_weight)
partial_kwargs.setdefault("weight_rule", weight_rule)
partial_kwargs.setdefault("sla", sla)
partial_kwargs.setdefault("max_active_tis_per_dag", max_active_tis_per_dag)
partial_kwargs.setdefault("max_active_tis_per_dagrun", max_active_tis_per_dagrun)
partial_kwargs.setdefault("on_execute_callback", on_execute_callback)
partial_kwargs.setdefault("on_failure_callback", on_failure_callback)
partial_kwargs.setdefault("on_retry_callback", on_retry_callback)
partial_kwargs.setdefault("on_success_callback", on_success_callback)
partial_kwargs.setdefault("run_as_user", run_as_user)
partial_kwargs.setdefault("executor_config", executor_config)
partial_kwargs.setdefault("inlets", inlets or [])
partial_kwargs.setdefault("outlets", outlets or [])
partial_kwargs.setdefault("resources", resources)
partial_kwargs.setdefault("doc", doc)
partial_kwargs.setdefault("doc_json", doc_json)
partial_kwargs.setdefault("doc_md", doc_md)
partial_kwargs.setdefault("doc_rst", doc_rst)
partial_kwargs.setdefault("doc_yaml", doc_yaml)

# Create partial_kwargs from args and kwargs
partial_kwargs: dict[str, Any] = {
**kwargs,
"dag": dag,
"task_group": task_group,
"task_id": task_id,
"start_date": start_date,
"end_date": end_date,
"owner": owner,
"email": email,
"trigger_rule": trigger_rule,
"depends_on_past": depends_on_past,
"ignore_first_depends_on_past": ignore_first_depends_on_past,
"wait_for_past_depends_before_skipping": wait_for_past_depends_before_skipping,
"wait_for_downstream": wait_for_downstream,
"retries": retries,
"queue": queue,
"pool": pool,
"pool_slots": pool_slots,
"execution_timeout": execution_timeout,
"max_retry_delay": max_retry_delay,
"retry_delay": retry_delay,
"retry_exponential_backoff": retry_exponential_backoff,
"priority_weight": priority_weight,
"weight_rule": weight_rule,
"sla": sla,
"max_active_tis_per_dag": max_active_tis_per_dag,
"max_active_tis_per_dagrun": max_active_tis_per_dagrun,
"on_execute_callback": on_execute_callback,
"on_failure_callback": on_failure_callback,
"on_retry_callback": on_retry_callback,
"on_success_callback": on_success_callback,
"run_as_user": run_as_user,
"executor_config": executor_config,
"inlets": inlets,
"outlets": outlets,
"resources": resources,
"doc": doc,
"doc_json": doc_json,
"doc_md": doc_md,
"doc_rst": doc_rst,
"doc_yaml": doc_yaml,
}

# Inject DAG-level default args into args provided to this function.
partial_kwargs.update((k, v) for k, v in dag_default_args.items() if partial_kwargs.get(k) is NOTSET)

# Fill fields not provided by the user with default values.
partial_kwargs = {k: _PARTIAL_DEFAULTS.get(k) if v is NOTSET else v for k, v in partial_kwargs.items()}

# Post-process arguments. Should be kept in sync with _TaskDecorator.expand().
if "task_concurrency" in kwargs: # Reject deprecated option.
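For context, here is a minimal, standalone sketch of the three-level precedence the reworked partial() implements: an argument passed explicitly to partial() wins, otherwise the DAG- or task-group-level default_args value is injected, and anything still unset falls back to a built-in default. The sentinel class, resolve_partial_kwargs, and BUILTIN_DEFAULTS below are illustrative stand-ins for Airflow's internal NOTSET, the body of partial(), and _PARTIAL_DEFAULTS; they are not part of the actual change.

```python
# Illustrative stand-ins only -- not Airflow internals.
class _NotSet:
    """Sentinel that distinguishes "not passed" from an explicit None."""

NOTSET = _NotSet()

# Stand-in for the _PARTIAL_DEFAULTS mapping in the diff above.
BUILTIN_DEFAULTS = {"owner": "airflow", "queue": "default", "retries": 0}


def resolve_partial_kwargs(explicit: dict, dag_default_args: dict) -> dict:
    """Mirror the merge order used in partial(): explicit > default_args > built-in."""
    merged = dict(explicit)
    # Inject DAG-level defaults only where the caller left the value unset.
    merged.update((k, v) for k, v in dag_default_args.items() if merged.get(k) is NOTSET)
    # Fill anything still unset with the built-in default.
    return {k: BUILTIN_DEFAULTS.get(k) if v is NOTSET else v for k, v in merged.items()}


# retries is passed explicitly and wins; owner comes from default_args; queue falls back.
print(resolve_partial_kwargs(
    {"retries": 1, "owner": NOTSET, "queue": NOTSET},
    {"owner": "data-team", "retries": 2},
))
# {'retries': 1, 'owner': 'data-team', 'queue': 'default'}
```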
14 changes: 14 additions & 0 deletions tests/models/test_mappedoperator.py
@@ -85,6 +85,20 @@ def test_task_mapping_default_args():
assert mapped.start_date == pendulum.instance(default_args["start_date"])


def test_task_mapping_override_default_args():
default_args = {"retries": 2, "start_date": DEFAULT_DATE.now()}
with DAG("test-dag", start_date=DEFAULT_DATE, default_args=default_args):
literal = ["a", "b", "c"]
mapped = MockOperator.partial(task_id="task", retries=1).expand(arg2=literal)

# retries should be 1 because it is provided as a partial arg
assert mapped.partial_kwargs["retries"] == 1
# start_date should be equal to default_args["start_date"] because it is not provided as a partial arg
assert mapped.start_date == pendulum.instance(default_args["start_date"])
# owner should be equal to Airflow default owner (airflow) because it is not provided at all
assert mapped.owner == "airflow"


def test_map_unknown_arg_raises():
with pytest.raises(TypeError, match=r"argument 'file'"):
BaseOperator.partial(task_id="a").expand(file=[1, 2, {"a": "b"}])
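From the DAG author's side, the behaviour the new test asserts looks roughly like the hypothetical example below (the DAG id, owner, and commands are made up for illustration; this is not part of the change itself): a value passed to .partial() overrides the DAG-level default, while anything omitted is inherited from default_args or Airflow's built-in defaults.

```python
import pendulum
from airflow import DAG
from airflow.operators.bash import BashOperator

with DAG(
    dag_id="mapped_defaults_demo",  # hypothetical DAG
    start_date=pendulum.datetime(2023, 1, 1, tz="UTC"),
    default_args={"retries": 2, "owner": "data-team"},
):
    # retries=1 overrides the DAG-level default; owner is inherited as "data-team",
    # and anything not set anywhere falls back to Airflow's built-in defaults.
    BashOperator.partial(task_id="echo", retries=1).expand(
        bash_command=["echo a", "echo b", "echo c"]
    )
```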