diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index d81d45dc7e92a..37106c580f16b 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -90,6 +90,7 @@ from airflow.utils.session import NEW_SESSION, provide_session from airflow.utils.setup_teardown import SetupTeardownContext from airflow.utils.trigger_rule import TriggerRule +from airflow.utils.types import NOTSET, ArgNotSet from airflow.utils.weight_rule import WeightRule from airflow.utils.xcom import XCOM_RETURN_KEY @@ -184,6 +185,26 @@ def partial(**kwargs): return self.class_method.__get__(cls, cls) +_PARTIAL_DEFAULTS = { + "owner": DEFAULT_OWNER, + "trigger_rule": DEFAULT_TRIGGER_RULE, + "depends_on_past": False, + "ignore_first_depends_on_past": DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST, + "wait_for_past_depends_before_skipping": DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING, + "wait_for_downstream": False, + "retries": DEFAULT_RETRIES, + "queue": DEFAULT_QUEUE, + "pool_slots": DEFAULT_POOL_SLOTS, + "execution_timeout": DEFAULT_TASK_EXECUTION_TIMEOUT, + "retry_delay": DEFAULT_RETRY_DELAY, + "retry_exponential_backoff": False, + "priority_weight": DEFAULT_PRIORITY_WEIGHT, + "weight_rule": DEFAULT_WEIGHT_RULE, + "inlets": [], + "outlets": [], +} + + # This is what handles the actual mapping. def partial( operator_class: type[BaseOperator], @@ -191,43 +212,43 @@ def partial( task_id: str, dag: DAG | None = None, task_group: TaskGroup | None = None, - start_date: datetime | None = None, - end_date: datetime | None = None, - owner: str = DEFAULT_OWNER, - email: None | str | Iterable[str] = None, + start_date: datetime | ArgNotSet = NOTSET, + end_date: datetime | ArgNotSet = NOTSET, + owner: str | ArgNotSet = NOTSET, + email: None | str | Iterable[str] | ArgNotSet = NOTSET, params: collections.abc.MutableMapping | None = None, - resources: dict[str, Any] | None = None, - trigger_rule: str = DEFAULT_TRIGGER_RULE, - depends_on_past: bool = False, - ignore_first_depends_on_past: bool = DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST, - wait_for_past_depends_before_skipping: bool = DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING, - wait_for_downstream: bool = False, - retries: int | None = DEFAULT_RETRIES, - queue: str = DEFAULT_QUEUE, - pool: str | None = None, - pool_slots: int = DEFAULT_POOL_SLOTS, - execution_timeout: timedelta | None = DEFAULT_TASK_EXECUTION_TIMEOUT, - max_retry_delay: None | timedelta | float = None, - retry_delay: timedelta | float = DEFAULT_RETRY_DELAY, - retry_exponential_backoff: bool = False, - priority_weight: int = DEFAULT_PRIORITY_WEIGHT, - weight_rule: str = DEFAULT_WEIGHT_RULE, - sla: timedelta | None = None, - max_active_tis_per_dag: int | None = None, - max_active_tis_per_dagrun: int | None = None, - on_execute_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None, - on_failure_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None, - on_success_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None, - on_retry_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None, - run_as_user: str | None = None, - executor_config: dict | None = None, - inlets: Any | None = None, - outlets: Any | None = None, - doc: str | None = None, - doc_md: str | None = None, - doc_json: str | None = None, - doc_yaml: str | None = None, - doc_rst: str | None = None, + resources: dict[str, Any] | None | ArgNotSet = NOTSET, + trigger_rule: str | ArgNotSet = NOTSET, + depends_on_past: bool | ArgNotSet = NOTSET, + ignore_first_depends_on_past: bool | ArgNotSet = NOTSET, + wait_for_past_depends_before_skipping: bool | ArgNotSet = NOTSET, + wait_for_downstream: bool | ArgNotSet = NOTSET, + retries: int | None | ArgNotSet = NOTSET, + queue: str | ArgNotSet = NOTSET, + pool: str | ArgNotSet = NOTSET, + pool_slots: int | ArgNotSet = NOTSET, + execution_timeout: timedelta | None | ArgNotSet = NOTSET, + max_retry_delay: None | timedelta | float | ArgNotSet = NOTSET, + retry_delay: timedelta | float | ArgNotSet = NOTSET, + retry_exponential_backoff: bool | ArgNotSet = NOTSET, + priority_weight: int | ArgNotSet = NOTSET, + weight_rule: str | ArgNotSet = NOTSET, + sla: timedelta | None | ArgNotSet = NOTSET, + max_active_tis_per_dag: int | None | ArgNotSet = NOTSET, + max_active_tis_per_dagrun: int | None | ArgNotSet = NOTSET, + on_execute_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] | ArgNotSet = NOTSET, + on_failure_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] | ArgNotSet = NOTSET, + on_success_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] | ArgNotSet = NOTSET, + on_retry_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] | ArgNotSet = NOTSET, + run_as_user: str | None | ArgNotSet = NOTSET, + executor_config: dict | None | ArgNotSet = NOTSET, + inlets: Any | None | ArgNotSet = NOTSET, + outlets: Any | None | ArgNotSet = NOTSET, + doc: str | None | ArgNotSet = NOTSET, + doc_md: str | None | ArgNotSet = NOTSET, + doc_json: str | None | ArgNotSet = NOTSET, + doc_yaml: str | None | ArgNotSet = NOTSET, + doc_rst: str | None | ArgNotSet = NOTSET, **kwargs, ) -> OperatorPartial: from airflow.models.dag import DagContext @@ -242,54 +263,62 @@ def partial( task_id = task_group.child_id(task_id) # Merge DAG and task group level defaults into user-supplied values. - partial_kwargs, partial_params = get_merged_defaults( + dag_default_args, partial_params = get_merged_defaults( dag=dag, task_group=task_group, task_params=params, task_default_args=kwargs.pop("default_args", None), ) - partial_kwargs.update(kwargs) - - # Always fully populate partial kwargs to exclude them from map(). - partial_kwargs.setdefault("dag", dag) - partial_kwargs.setdefault("task_group", task_group) - partial_kwargs.setdefault("task_id", task_id) - partial_kwargs.setdefault("start_date", start_date) - partial_kwargs.setdefault("end_date", end_date) - partial_kwargs.setdefault("owner", owner) - partial_kwargs.setdefault("email", email) - partial_kwargs.setdefault("trigger_rule", trigger_rule) - partial_kwargs.setdefault("depends_on_past", depends_on_past) - partial_kwargs.setdefault("ignore_first_depends_on_past", ignore_first_depends_on_past) - partial_kwargs.setdefault("wait_for_past_depends_before_skipping", wait_for_past_depends_before_skipping) - partial_kwargs.setdefault("wait_for_downstream", wait_for_downstream) - partial_kwargs.setdefault("retries", retries) - partial_kwargs.setdefault("queue", queue) - partial_kwargs.setdefault("pool", pool) - partial_kwargs.setdefault("pool_slots", pool_slots) - partial_kwargs.setdefault("execution_timeout", execution_timeout) - partial_kwargs.setdefault("max_retry_delay", max_retry_delay) - partial_kwargs.setdefault("retry_delay", retry_delay) - partial_kwargs.setdefault("retry_exponential_backoff", retry_exponential_backoff) - partial_kwargs.setdefault("priority_weight", priority_weight) - partial_kwargs.setdefault("weight_rule", weight_rule) - partial_kwargs.setdefault("sla", sla) - partial_kwargs.setdefault("max_active_tis_per_dag", max_active_tis_per_dag) - partial_kwargs.setdefault("max_active_tis_per_dagrun", max_active_tis_per_dagrun) - partial_kwargs.setdefault("on_execute_callback", on_execute_callback) - partial_kwargs.setdefault("on_failure_callback", on_failure_callback) - partial_kwargs.setdefault("on_retry_callback", on_retry_callback) - partial_kwargs.setdefault("on_success_callback", on_success_callback) - partial_kwargs.setdefault("run_as_user", run_as_user) - partial_kwargs.setdefault("executor_config", executor_config) - partial_kwargs.setdefault("inlets", inlets or []) - partial_kwargs.setdefault("outlets", outlets or []) - partial_kwargs.setdefault("resources", resources) - partial_kwargs.setdefault("doc", doc) - partial_kwargs.setdefault("doc_json", doc_json) - partial_kwargs.setdefault("doc_md", doc_md) - partial_kwargs.setdefault("doc_rst", doc_rst) - partial_kwargs.setdefault("doc_yaml", doc_yaml) + + # Create partial_kwargs from args and kwargs + partial_kwargs: dict[str, Any] = { + **kwargs, + "dag": dag, + "task_group": task_group, + "task_id": task_id, + "start_date": start_date, + "end_date": end_date, + "owner": owner, + "email": email, + "trigger_rule": trigger_rule, + "depends_on_past": depends_on_past, + "ignore_first_depends_on_past": ignore_first_depends_on_past, + "wait_for_past_depends_before_skipping": wait_for_past_depends_before_skipping, + "wait_for_downstream": wait_for_downstream, + "retries": retries, + "queue": queue, + "pool": pool, + "pool_slots": pool_slots, + "execution_timeout": execution_timeout, + "max_retry_delay": max_retry_delay, + "retry_delay": retry_delay, + "retry_exponential_backoff": retry_exponential_backoff, + "priority_weight": priority_weight, + "weight_rule": weight_rule, + "sla": sla, + "max_active_tis_per_dag": max_active_tis_per_dag, + "max_active_tis_per_dagrun": max_active_tis_per_dagrun, + "on_execute_callback": on_execute_callback, + "on_failure_callback": on_failure_callback, + "on_retry_callback": on_retry_callback, + "on_success_callback": on_success_callback, + "run_as_user": run_as_user, + "executor_config": executor_config, + "inlets": inlets, + "outlets": outlets, + "resources": resources, + "doc": doc, + "doc_json": doc_json, + "doc_md": doc_md, + "doc_rst": doc_rst, + "doc_yaml": doc_yaml, + } + + # Inject DAG-level default args into args provided to this function. + partial_kwargs.update((k, v) for k, v in dag_default_args.items() if partial_kwargs.get(k) is NOTSET) + + # Fill fields not provided by the user with default values. + partial_kwargs = {k: _PARTIAL_DEFAULTS.get(k) if v is NOTSET else v for k, v in partial_kwargs.items()} # Post-process arguments. Should be kept in sync with _TaskDecorator.expand(). if "task_concurrency" in kwargs: # Reject deprecated option. diff --git a/tests/models/test_mappedoperator.py b/tests/models/test_mappedoperator.py index 931d262d24c7a..84ddd9fb66e86 100644 --- a/tests/models/test_mappedoperator.py +++ b/tests/models/test_mappedoperator.py @@ -85,6 +85,20 @@ def test_task_mapping_default_args(): assert mapped.start_date == pendulum.instance(default_args["start_date"]) +def test_task_mapping_override_default_args(): + default_args = {"retries": 2, "start_date": DEFAULT_DATE.now()} + with DAG("test-dag", start_date=DEFAULT_DATE, default_args=default_args): + literal = ["a", "b", "c"] + mapped = MockOperator.partial(task_id="task", retries=1).expand(arg2=literal) + + # retries should be 1 because it is provided as a partial arg + assert mapped.partial_kwargs["retries"] == 1 + # start_date should be equal to default_args["start_date"] because it is not provided as partial arg + assert mapped.start_date == pendulum.instance(default_args["start_date"]) + # owner should be equal to Airflow default owner (airflow) because it is not provided at all + assert mapped.owner == "airflow" + + def test_map_unknown_arg_raises(): with pytest.raises(TypeError, match=r"argument 'file'"): BaseOperator.partial(task_id="a").expand(file=[1, 2, {"a": "b"}])