From acf78e032219fab4f010bdfdf381749dac623297 Mon Sep 17 00:00:00 2001 From: Hussein Awala Date: Sat, 4 Mar 2023 00:50:45 +0100 Subject: [PATCH 01/10] Add a failing test to make it pass --- tests/models/test_mappedoperator.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/tests/models/test_mappedoperator.py b/tests/models/test_mappedoperator.py index bdfcf8bc7f809..9b77cc692a098 100644 --- a/tests/models/test_mappedoperator.py +++ b/tests/models/test_mappedoperator.py @@ -85,6 +85,15 @@ def test_task_mapping_default_args(): assert mapped.start_date == pendulum.instance(default_args["start_date"]) +def test_task_mapping_override_default_args(): + default_args = {"retries": 2} + with DAG("test-dag", start_date=DEFAULT_DATE, default_args=default_args): + literal = ["a", "b", "c"] + mapped = MockOperator.partial(task_id="task", retries=1).expand(arg2=literal) + + assert mapped.partial_kwargs["retries"] == 1 + + def test_map_unknown_arg_raises(): with pytest.raises(TypeError, match=r"argument 'file'"): BaseOperator.partial(task_id="a").expand(file=[1, 2, {"a": "b"}]) From 934d128888a92918d26f99366909695bc27c375a Mon Sep 17 00:00:00 2001 From: Hussein Awala Date: Sat, 4 Mar 2023 01:56:53 +0100 Subject: [PATCH 02/10] use partial_kwargs when they are provide and override only None values by dag default values --- airflow/models/baseoperator.py | 169 +++++++++++++++++++++------------ 1 file changed, 108 insertions(+), 61 deletions(-) diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index 6b33ef9ee392d..1f5dfb1ef2cb9 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -192,25 +192,25 @@ def partial( task_group: TaskGroup | None = None, start_date: datetime | None = None, end_date: datetime | None = None, - owner: str = DEFAULT_OWNER, + owner: str | None = None, email: None | str | Iterable[str] = None, params: dict | None = None, resources: dict[str, Any] | None = None, - trigger_rule: str = DEFAULT_TRIGGER_RULE, - depends_on_past: bool = False, - ignore_first_depends_on_past: bool = DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST, - wait_for_past_depends_before_skipping: bool = DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING, - wait_for_downstream: bool = False, - retries: int | None = DEFAULT_RETRIES, - queue: str = DEFAULT_QUEUE, + trigger_rule: str | None = None, + depends_on_past: bool | None = None, + ignore_first_depends_on_past: bool | None = None, + wait_for_past_depends_before_skipping: bool | None = None, + wait_for_downstream: bool | None = None, + retries: int | None = None, + queue: str | None = None, pool: str | None = None, - pool_slots: int = DEFAULT_POOL_SLOTS, - execution_timeout: timedelta | None = DEFAULT_TASK_EXECUTION_TIMEOUT, + pool_slots: int | None = None, + execution_timeout: timedelta | None = None, max_retry_delay: None | timedelta | float = None, - retry_delay: timedelta | float = DEFAULT_RETRY_DELAY, - retry_exponential_backoff: bool = False, - priority_weight: int = DEFAULT_PRIORITY_WEIGHT, - weight_rule: str = DEFAULT_WEIGHT_RULE, + retry_delay: timedelta | float | None = None, + retry_exponential_backoff: bool | None = None, + priority_weight: int | None = None, + weight_rule: str | None = None, sla: timedelta | None = None, max_active_tis_per_dag: int | None = None, on_execute_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None, @@ -240,72 +240,119 @@ def partial( task_id = task_group.child_id(task_id) # Merge DAG and task group level defaults into user-supplied values. - partial_kwargs, partial_params = get_merged_defaults( + default_partial_kwargs, partial_params = get_merged_defaults( dag=dag, task_group=task_group, task_params=params, task_default_args=kwargs.pop("default_args", None), ) - partial_kwargs.update(kwargs) - - # Always fully populate partial kwargs to exclude them from map(). - partial_kwargs.setdefault("dag", dag) - partial_kwargs.setdefault("task_group", task_group) - partial_kwargs.setdefault("task_id", task_id) - partial_kwargs.setdefault("start_date", start_date) - partial_kwargs.setdefault("end_date", end_date) - partial_kwargs.setdefault("owner", owner) - partial_kwargs.setdefault("email", email) - partial_kwargs.setdefault("trigger_rule", trigger_rule) - partial_kwargs.setdefault("depends_on_past", depends_on_past) - partial_kwargs.setdefault("ignore_first_depends_on_past", ignore_first_depends_on_past) - partial_kwargs.setdefault("wait_for_past_depends_before_skipping", wait_for_past_depends_before_skipping) - partial_kwargs.setdefault("wait_for_downstream", wait_for_downstream) - partial_kwargs.setdefault("retries", retries) - partial_kwargs.setdefault("queue", queue) - partial_kwargs.setdefault("pool", pool) - partial_kwargs.setdefault("pool_slots", pool_slots) - partial_kwargs.setdefault("execution_timeout", execution_timeout) - partial_kwargs.setdefault("max_retry_delay", max_retry_delay) - partial_kwargs.setdefault("retry_delay", retry_delay) - partial_kwargs.setdefault("retry_exponential_backoff", retry_exponential_backoff) - partial_kwargs.setdefault("priority_weight", priority_weight) - partial_kwargs.setdefault("weight_rule", weight_rule) - partial_kwargs.setdefault("sla", sla) - partial_kwargs.setdefault("max_active_tis_per_dag", max_active_tis_per_dag) - partial_kwargs.setdefault("on_execute_callback", on_execute_callback) - partial_kwargs.setdefault("on_failure_callback", on_failure_callback) - partial_kwargs.setdefault("on_retry_callback", on_retry_callback) - partial_kwargs.setdefault("on_success_callback", on_success_callback) - partial_kwargs.setdefault("run_as_user", run_as_user) - partial_kwargs.setdefault("executor_config", executor_config) - partial_kwargs.setdefault("inlets", inlets or []) - partial_kwargs.setdefault("outlets", outlets or []) - partial_kwargs.setdefault("resources", resources) - partial_kwargs.setdefault("doc", doc) - partial_kwargs.setdefault("doc_json", doc_json) - partial_kwargs.setdefault("doc_md", doc_md) - partial_kwargs.setdefault("doc_rst", doc_rst) - partial_kwargs.setdefault("doc_yaml", doc_yaml) + + # Create partial_kwargs from args and kwargs + partial_kwargs = { + **kwargs, + "dag": dag, + "task_group": task_group, + "task_id": task_id, + "start_date": start_date, + "end_date": end_date, + "owner": owner, + "email": email, + "trigger_rule": trigger_rule, + "depends_on_past": depends_on_past, + "ignore_first_depends_on_past": ignore_first_depends_on_past, + "wait_for_past_depends_before_skipping": wait_for_past_depends_before_skipping, + "wait_for_downstream": wait_for_downstream, + "retries": retries, + "queue": queue, + "pool": pool, + "pool_slots": pool_slots, + "execution_timeout": execution_timeout, + "max_retry_delay": max_retry_delay, + "retry_delay": retry_delay, + "retry_exponential_backoff": retry_exponential_backoff, + "priority_weight": priority_weight, + "weight_rule": weight_rule, + "sla": sla, + "max_active_tis_per_dag": max_active_tis_per_dag, + "on_execute_callback": on_execute_callback, + "on_failure_callback": on_failure_callback, + "on_retry_callback": on_retry_callback, + "on_success_callback": on_success_callback, + "run_as_user": run_as_user, + "executor_config": executor_config, + "inlets": inlets, + "outlets": outlets, + "resources": resources, + "doc": doc, + "doc_json": doc_json, + "doc_md": doc_md, + "doc_rst": doc_rst, + "doc_yaml": doc_yaml, + } + + # Override None kwargs by dag default values + for k, v in default_partial_kwargs.items(): + if partial_kwargs.get(k) is None: + partial_kwargs[k] = v + + # Override None kwargs which don't have a dag default value by Airflow default value + partial_kwargs["owner"] = partial_kwargs["owner"] or DEFAULT_OWNER + partial_kwargs["trigger_rule"] = partial_kwargs["trigger_rule"] or DEFAULT_TRIGGER_RULE + partial_kwargs["depends_on_past"] = ( + partial_kwargs["depends_on_past"] if partial_kwargs["depends_on_past"] is not None else False + ) + partial_kwargs["ignore_first_depends_on_past"] = ( + partial_kwargs["ignore_first_depends_on_past"] + if partial_kwargs["ignore_first_depends_on_past"] is not None + else DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST + ) + partial_kwargs["wait_for_past_depends_before_skipping"] = ( + partial_kwargs["wait_for_past_depends_before_skipping"] + if partial_kwargs["wait_for_past_depends_before_skipping"] is not None + else DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING + ) + partial_kwargs["wait_for_downstream"] = ( + partial_kwargs["wait_for_downstream"] if partial_kwargs["wait_for_downstream"] is not None else False + ) + partial_kwargs["retries"] = ( + partial_kwargs["retries"] if partial_kwargs["retries"] is not None else DEFAULT_RETRIES + ) + partial_kwargs["queue"] = partial_kwargs["queue"] or DEFAULT_QUEUE + partial_kwargs["pool_slots"] = ( + partial_kwargs["pool_slots"] if partial_kwargs["pool_slots"] is not None else DEFAULT_POOL_SLOTS + ) + partial_kwargs["execution_timeout"] = ( + partial_kwargs["execution_timeout"] or DEFAULT_TASK_EXECUTION_TIMEOUT + ) + partial_kwargs["retry_delay"] = partial_kwargs["retry_delay"] or DEFAULT_RETRY_DELAY + partial_kwargs["retry_exponential_backoff"] = ( + partial_kwargs["retry_exponential_backoff"] if partial_kwargs["retry_exponential_backoff"] else False + ) + partial_kwargs["priority_weight"] = ( + partial_kwargs["priority_weight"] if partial_kwargs["priority_weight"] else DEFAULT_PRIORITY_WEIGHT + ) + partial_kwargs["weight_rule"] = partial_kwargs["weight_rule"] or DEFAULT_WEIGHT_RULE # Post-process arguments. Should be kept in sync with _TaskDecorator.expand(). if "task_concurrency" in kwargs: # Reject deprecated option. raise TypeError("unexpected argument: task_concurrency") if partial_kwargs["wait_for_downstream"]: partial_kwargs["depends_on_past"] = True - partial_kwargs["start_date"] = timezone.convert_to_utc(partial_kwargs["start_date"]) - partial_kwargs["end_date"] = timezone.convert_to_utc(partial_kwargs["end_date"]) + partial_kwargs["start_date"] = timezone.convert_to_utc(partial_kwargs["start_date"]) # type: ignore + partial_kwargs["end_date"] = timezone.convert_to_utc(partial_kwargs["end_date"]) # type: ignore if partial_kwargs["pool"] is None: partial_kwargs["pool"] = Pool.DEFAULT_POOL_NAME partial_kwargs["retries"] = parse_retries(partial_kwargs["retries"]) - partial_kwargs["retry_delay"] = coerce_timedelta(partial_kwargs["retry_delay"], key="retry_delay") + partial_kwargs["retry_delay"] = coerce_timedelta( + partial_kwargs["retry_delay"], key="retry_delay" # type: ignore + ) if partial_kwargs["max_retry_delay"] is not None: partial_kwargs["max_retry_delay"] = coerce_timedelta( - partial_kwargs["max_retry_delay"], + partial_kwargs["max_retry_delay"], # type: ignore key="max_retry_delay", ) partial_kwargs["executor_config"] = partial_kwargs["executor_config"] or {} - partial_kwargs["resources"] = coerce_resources(partial_kwargs["resources"]) + partial_kwargs["resources"] = coerce_resources(partial_kwargs["resources"]) # type: ignore return OperatorPartial( operator_class=operator_class, From 389c6a9bfd435d8a5f8fd76becf856257faa7627 Mon Sep 17 00:00:00 2001 From: Hussein Awala Date: Sat, 4 Mar 2023 02:02:34 +0100 Subject: [PATCH 03/10] update the test and check if the values are filled in the right order --- tests/models/test_mappedoperator.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tests/models/test_mappedoperator.py b/tests/models/test_mappedoperator.py index 9b77cc692a098..7ebc09b5c95ab 100644 --- a/tests/models/test_mappedoperator.py +++ b/tests/models/test_mappedoperator.py @@ -86,12 +86,17 @@ def test_task_mapping_default_args(): def test_task_mapping_override_default_args(): - default_args = {"retries": 2} + default_args = {"retries": 2, "start_date": DEFAULT_DATE.now()} with DAG("test-dag", start_date=DEFAULT_DATE, default_args=default_args): literal = ["a", "b", "c"] mapped = MockOperator.partial(task_id="task", retries=1).expand(arg2=literal) + # retries should be 1 because it is provided as a partial arg assert mapped.partial_kwargs["retries"] == 1 + # start_date should be equal to default_args["start_date"] because it is not provided as partial arg + assert mapped.start_date == pendulum.instance(default_args["start_date"]) + # owner should be equal to Airflow default owner (airflow) because it is not provided at all + assert mapped.owner == "airflow" def test_map_unknown_arg_raises(): From 7725d635f6116ac3dc4e0040ae963e67881ef944 Mon Sep 17 00:00:00 2001 From: Hussein Awala Date: Sat, 4 Mar 2023 03:06:24 +0100 Subject: [PATCH 04/10] fix overriding retry_delay with default value when it is equal to 0 --- airflow/models/baseoperator.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index 1f5dfb1ef2cb9..f6c26d184390b 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -324,7 +324,9 @@ def partial( partial_kwargs["execution_timeout"] = ( partial_kwargs["execution_timeout"] or DEFAULT_TASK_EXECUTION_TIMEOUT ) - partial_kwargs["retry_delay"] = partial_kwargs["retry_delay"] or DEFAULT_RETRY_DELAY + partial_kwargs["retry_delay"] = ( + partial_kwargs["retry_delay"] if partial_kwargs["retry_delay"] is not None else DEFAULT_RETRY_DELAY + ) partial_kwargs["retry_exponential_backoff"] = ( partial_kwargs["retry_exponential_backoff"] if partial_kwargs["retry_exponential_backoff"] else False ) From 5e3b5ff0497f5633f7b8a8e1bc28efe25d9bac9f Mon Sep 17 00:00:00 2001 From: Hussein Awala Date: Sat, 4 Mar 2023 03:10:07 +0100 Subject: [PATCH 05/10] add missing default value for inlets and outlets --- airflow/models/baseoperator.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index f6c26d184390b..f1b07a9f74033 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -334,6 +334,8 @@ def partial( partial_kwargs["priority_weight"] if partial_kwargs["priority_weight"] else DEFAULT_PRIORITY_WEIGHT ) partial_kwargs["weight_rule"] = partial_kwargs["weight_rule"] or DEFAULT_WEIGHT_RULE + partial_kwargs["inlets"] = partial_kwargs["inlets"] or [] + partial_kwargs["outlets"] = partial_kwargs["outlets"] or [] # Post-process arguments. Should be kept in sync with _TaskDecorator.expand(). if "task_concurrency" in kwargs: # Reject deprecated option. From ad1f7b862a439063da11f54420b66427a472bd00 Mon Sep 17 00:00:00 2001 From: Hussein Awala Date: Sat, 11 Mar 2023 02:43:00 +0100 Subject: [PATCH 06/10] set partial_kwargs dict type to dict[str, Any] and remove type ignore comments --- airflow/models/baseoperator.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index f1b07a9f74033..70e1377bbdb68 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -248,7 +248,7 @@ def partial( ) # Create partial_kwargs from args and kwargs - partial_kwargs = { + partial_kwargs: dict[str, Any] = { **kwargs, "dag": dag, "task_group": task_group, @@ -342,21 +342,19 @@ def partial( raise TypeError("unexpected argument: task_concurrency") if partial_kwargs["wait_for_downstream"]: partial_kwargs["depends_on_past"] = True - partial_kwargs["start_date"] = timezone.convert_to_utc(partial_kwargs["start_date"]) # type: ignore - partial_kwargs["end_date"] = timezone.convert_to_utc(partial_kwargs["end_date"]) # type: ignore + partial_kwargs["start_date"] = timezone.convert_to_utc(partial_kwargs["start_date"]) + partial_kwargs["end_date"] = timezone.convert_to_utc(partial_kwargs["end_date"]) if partial_kwargs["pool"] is None: partial_kwargs["pool"] = Pool.DEFAULT_POOL_NAME partial_kwargs["retries"] = parse_retries(partial_kwargs["retries"]) - partial_kwargs["retry_delay"] = coerce_timedelta( - partial_kwargs["retry_delay"], key="retry_delay" # type: ignore - ) + partial_kwargs["retry_delay"] = coerce_timedelta(partial_kwargs["retry_delay"], key="retry_delay") if partial_kwargs["max_retry_delay"] is not None: partial_kwargs["max_retry_delay"] = coerce_timedelta( - partial_kwargs["max_retry_delay"], # type: ignore + partial_kwargs["max_retry_delay"], key="max_retry_delay", ) partial_kwargs["executor_config"] = partial_kwargs["executor_config"] or {} - partial_kwargs["resources"] = coerce_resources(partial_kwargs["resources"]) # type: ignore + partial_kwargs["resources"] = coerce_resources(partial_kwargs["resources"]) return OperatorPartial( operator_class=operator_class, From cb0980afdfc7e1fb2f8c12e02d0f28b56880143e Mon Sep 17 00:00:00 2001 From: Hussein Awala Date: Sat, 11 Mar 2023 03:19:07 +0100 Subject: [PATCH 07/10] create a dict for default values and use NotSet instead of None to support None as accepted value --- airflow/models/baseoperator.py | 152 +++++++++++++++++---------------- 1 file changed, 77 insertions(+), 75 deletions(-) diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index 70e1377bbdb68..5c7fb90750611 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -90,6 +90,7 @@ from airflow.utils.operator_resources import Resources from airflow.utils.session import NEW_SESSION, provide_session from airflow.utils.trigger_rule import TriggerRule +from airflow.utils.types import NOTSET, ArgNotSet from airflow.utils.weight_rule import WeightRule if TYPE_CHECKING: @@ -190,42 +191,42 @@ def partial( task_id: str, dag: DAG | None = None, task_group: TaskGroup | None = None, - start_date: datetime | None = None, - end_date: datetime | None = None, - owner: str | None = None, - email: None | str | Iterable[str] = None, + start_date: datetime | None | ArgNotSet = NOTSET, + end_date: datetime | None | ArgNotSet = NOTSET, + owner: str | None | ArgNotSet = NOTSET, + email: None | str | Iterable[str] | ArgNotSet = NOTSET, params: dict | None = None, - resources: dict[str, Any] | None = None, - trigger_rule: str | None = None, - depends_on_past: bool | None = None, - ignore_first_depends_on_past: bool | None = None, - wait_for_past_depends_before_skipping: bool | None = None, - wait_for_downstream: bool | None = None, - retries: int | None = None, - queue: str | None = None, + resources: dict[str, Any] | None | ArgNotSet = NOTSET, + trigger_rule: str | None | ArgNotSet = NOTSET, + depends_on_past: bool | None | ArgNotSet = NOTSET, + ignore_first_depends_on_past: bool | None | ArgNotSet = NOTSET, + wait_for_past_depends_before_skipping: bool | None | ArgNotSet = NOTSET, + wait_for_downstream: bool | None | ArgNotSet = NOTSET, + retries: int | None | ArgNotSet = NOTSET, + queue: str | None | ArgNotSet = NOTSET, pool: str | None = None, - pool_slots: int | None = None, - execution_timeout: timedelta | None = None, - max_retry_delay: None | timedelta | float = None, - retry_delay: timedelta | float | None = None, - retry_exponential_backoff: bool | None = None, - priority_weight: int | None = None, - weight_rule: str | None = None, - sla: timedelta | None = None, - max_active_tis_per_dag: int | None = None, - on_execute_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None, - on_failure_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None, - on_success_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None, - on_retry_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None, - run_as_user: str | None = None, - executor_config: dict | None = None, - inlets: Any | None = None, - outlets: Any | None = None, - doc: str | None = None, - doc_md: str | None = None, - doc_json: str | None = None, - doc_yaml: str | None = None, - doc_rst: str | None = None, + pool_slots: int | None | ArgNotSet = NOTSET, + execution_timeout: timedelta | None | ArgNotSet = NOTSET, + max_retry_delay: None | timedelta | float | ArgNotSet = NOTSET, + retry_delay: timedelta | float | None | ArgNotSet = NOTSET, + retry_exponential_backoff: bool | None | ArgNotSet = NOTSET, + priority_weight: int | None | ArgNotSet = NOTSET, + weight_rule: str | None | ArgNotSet = NOTSET, + sla: timedelta | None | ArgNotSet = NOTSET, + max_active_tis_per_dag: int | None | ArgNotSet = NOTSET, + on_execute_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] | ArgNotSet = NOTSET, + on_failure_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] | ArgNotSet = NOTSET, + on_success_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] | ArgNotSet = NOTSET, + on_retry_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] | ArgNotSet = NOTSET, + run_as_user: str | None | ArgNotSet = NOTSET, + executor_config: dict | None | ArgNotSet = NOTSET, + inlets: Any | None | ArgNotSet = NOTSET, + outlets: Any | None | ArgNotSet = NOTSET, + doc: str | None | ArgNotSet = NOTSET, + doc_md: str | None | ArgNotSet = NOTSET, + doc_json: str | None | ArgNotSet = NOTSET, + doc_yaml: str | None | ArgNotSet = NOTSET, + doc_rst: str | None | ArgNotSet = NOTSET, **kwargs, ) -> OperatorPartial: from airflow.models.dag import DagContext @@ -290,52 +291,53 @@ def partial( "doc_yaml": doc_yaml, } + DEFAULT_VALUES: dict[str, Any] = { + "task_id": None, + "start_date": None, + "end_date": None, + "owner": DEFAULT_OWNER, + "email": None, + "trigger_rule": DEFAULT_TRIGGER_RULE, + "depends_on_past": False, + "ignore_first_depends_on_past": DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST, + "wait_for_past_depends_before_skipping": DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING, + "wait_for_downstream": False, + "retries": DEFAULT_RETRIES, + "queue": DEFAULT_QUEUE, + "pool_slots": DEFAULT_POOL_SLOTS, + "execution_timeout": DEFAULT_TASK_EXECUTION_TIMEOUT, + "max_retry_delay": None, + "retry_delay": DEFAULT_RETRY_DELAY, + "retry_exponential_backoff": False, + "priority_weight": DEFAULT_PRIORITY_WEIGHT, + "weight_rule": DEFAULT_WEIGHT_RULE, + "sla": None, + "max_active_tis_per_dag": None, + "on_execute_callback": None, + "on_failure_callback": None, + "on_retry_callback": None, + "on_success_callback": None, + "run_as_user": None, + "executor_config": None, + "inlets": [], + "outlets": [], + "resources": None, + "doc": None, + "doc_json": None, + "doc_md": None, + "doc_rst": None, + "doc_yaml": None, + } + # Override None kwargs by dag default values for k, v in default_partial_kwargs.items(): - if partial_kwargs.get(k) is None: + if partial_kwargs.get(k) is NOTSET: partial_kwargs[k] = v # Override None kwargs which don't have a dag default value by Airflow default value - partial_kwargs["owner"] = partial_kwargs["owner"] or DEFAULT_OWNER - partial_kwargs["trigger_rule"] = partial_kwargs["trigger_rule"] or DEFAULT_TRIGGER_RULE - partial_kwargs["depends_on_past"] = ( - partial_kwargs["depends_on_past"] if partial_kwargs["depends_on_past"] is not None else False - ) - partial_kwargs["ignore_first_depends_on_past"] = ( - partial_kwargs["ignore_first_depends_on_past"] - if partial_kwargs["ignore_first_depends_on_past"] is not None - else DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST - ) - partial_kwargs["wait_for_past_depends_before_skipping"] = ( - partial_kwargs["wait_for_past_depends_before_skipping"] - if partial_kwargs["wait_for_past_depends_before_skipping"] is not None - else DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING - ) - partial_kwargs["wait_for_downstream"] = ( - partial_kwargs["wait_for_downstream"] if partial_kwargs["wait_for_downstream"] is not None else False - ) - partial_kwargs["retries"] = ( - partial_kwargs["retries"] if partial_kwargs["retries"] is not None else DEFAULT_RETRIES - ) - partial_kwargs["queue"] = partial_kwargs["queue"] or DEFAULT_QUEUE - partial_kwargs["pool_slots"] = ( - partial_kwargs["pool_slots"] if partial_kwargs["pool_slots"] is not None else DEFAULT_POOL_SLOTS - ) - partial_kwargs["execution_timeout"] = ( - partial_kwargs["execution_timeout"] or DEFAULT_TASK_EXECUTION_TIMEOUT - ) - partial_kwargs["retry_delay"] = ( - partial_kwargs["retry_delay"] if partial_kwargs["retry_delay"] is not None else DEFAULT_RETRY_DELAY - ) - partial_kwargs["retry_exponential_backoff"] = ( - partial_kwargs["retry_exponential_backoff"] if partial_kwargs["retry_exponential_backoff"] else False - ) - partial_kwargs["priority_weight"] = ( - partial_kwargs["priority_weight"] if partial_kwargs["priority_weight"] else DEFAULT_PRIORITY_WEIGHT - ) - partial_kwargs["weight_rule"] = partial_kwargs["weight_rule"] or DEFAULT_WEIGHT_RULE - partial_kwargs["inlets"] = partial_kwargs["inlets"] or [] - partial_kwargs["outlets"] = partial_kwargs["outlets"] or [] + partial_kwargs = { + k: v if v is not NOTSET else DEFAULT_VALUES.get(k, v) for k, v in partial_kwargs.items() + } # Post-process arguments. Should be kept in sync with _TaskDecorator.expand(). if "task_concurrency" in kwargs: # Reject deprecated option. From 01997bd5c092d59c610dcc6940bd451cb15f6f1d Mon Sep 17 00:00:00 2001 From: Hussein Awala Date: Fri, 31 Mar 2023 23:10:16 +0200 Subject: [PATCH 08/10] update partial typing by removing None type from some args and set NotSet for all args --- airflow/models/baseoperator.py | 55 +++++++++++----------------------- 1 file changed, 18 insertions(+), 37 deletions(-) diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index 5c7fb90750611..ece26f6ffc0fb 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -191,27 +191,27 @@ def partial( task_id: str, dag: DAG | None = None, task_group: TaskGroup | None = None, - start_date: datetime | None | ArgNotSet = NOTSET, - end_date: datetime | None | ArgNotSet = NOTSET, - owner: str | None | ArgNotSet = NOTSET, + start_date: datetime | ArgNotSet = NOTSET, + end_date: datetime | ArgNotSet = NOTSET, + owner: str | ArgNotSet = NOTSET, email: None | str | Iterable[str] | ArgNotSet = NOTSET, params: dict | None = None, resources: dict[str, Any] | None | ArgNotSet = NOTSET, - trigger_rule: str | None | ArgNotSet = NOTSET, - depends_on_past: bool | None | ArgNotSet = NOTSET, - ignore_first_depends_on_past: bool | None | ArgNotSet = NOTSET, - wait_for_past_depends_before_skipping: bool | None | ArgNotSet = NOTSET, - wait_for_downstream: bool | None | ArgNotSet = NOTSET, + trigger_rule: str | ArgNotSet = NOTSET, + depends_on_past: bool | ArgNotSet = NOTSET, + ignore_first_depends_on_past: bool | ArgNotSet = NOTSET, + wait_for_past_depends_before_skipping: bool | ArgNotSet = NOTSET, + wait_for_downstream: bool | ArgNotSet = NOTSET, retries: int | None | ArgNotSet = NOTSET, - queue: str | None | ArgNotSet = NOTSET, - pool: str | None = None, - pool_slots: int | None | ArgNotSet = NOTSET, + queue: str | ArgNotSet = NOTSET, + pool: str | ArgNotSet = NOTSET, + pool_slots: int | ArgNotSet = NOTSET, execution_timeout: timedelta | None | ArgNotSet = NOTSET, max_retry_delay: None | timedelta | float | ArgNotSet = NOTSET, - retry_delay: timedelta | float | None | ArgNotSet = NOTSET, - retry_exponential_backoff: bool | None | ArgNotSet = NOTSET, - priority_weight: int | None | ArgNotSet = NOTSET, - weight_rule: str | None | ArgNotSet = NOTSET, + retry_delay: timedelta | float | ArgNotSet = NOTSET, + retry_exponential_backoff: bool | ArgNotSet = NOTSET, + priority_weight: int | ArgNotSet = NOTSET, + weight_rule: str | ArgNotSet = NOTSET, sla: timedelta | None | ArgNotSet = NOTSET, max_active_tis_per_dag: int | None | ArgNotSet = NOTSET, on_execute_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] | ArgNotSet = NOTSET, @@ -292,11 +292,7 @@ def partial( } DEFAULT_VALUES: dict[str, Any] = { - "task_id": None, - "start_date": None, - "end_date": None, "owner": DEFAULT_OWNER, - "email": None, "trigger_rule": DEFAULT_TRIGGER_RULE, "depends_on_past": False, "ignore_first_depends_on_past": DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST, @@ -306,37 +302,22 @@ def partial( "queue": DEFAULT_QUEUE, "pool_slots": DEFAULT_POOL_SLOTS, "execution_timeout": DEFAULT_TASK_EXECUTION_TIMEOUT, - "max_retry_delay": None, "retry_delay": DEFAULT_RETRY_DELAY, "retry_exponential_backoff": False, "priority_weight": DEFAULT_PRIORITY_WEIGHT, "weight_rule": DEFAULT_WEIGHT_RULE, - "sla": None, - "max_active_tis_per_dag": None, - "on_execute_callback": None, - "on_failure_callback": None, - "on_retry_callback": None, - "on_success_callback": None, - "run_as_user": None, - "executor_config": None, "inlets": [], "outlets": [], - "resources": None, - "doc": None, - "doc_json": None, - "doc_md": None, - "doc_rst": None, - "doc_yaml": None, } - # Override None kwargs by dag default values + # Override NOTSET kwargs by dag default values for k, v in default_partial_kwargs.items(): if partial_kwargs.get(k) is NOTSET: partial_kwargs[k] = v - # Override None kwargs which don't have a dag default value by Airflow default value + # Override NOTSET kwargs which don't have a dag default value by Airflow default value or None partial_kwargs = { - k: v if v is not NOTSET else DEFAULT_VALUES.get(k, v) for k, v in partial_kwargs.items() + k: v if v is not NOTSET else DEFAULT_VALUES.get(k, None) for k, v in partial_kwargs.items() } # Post-process arguments. Should be kept in sync with _TaskDecorator.expand(). From cd142b6cc5d87c5b84ef357e9d40b5fbb917e613 Mon Sep 17 00:00:00 2001 From: Tzu-ping Chung Date: Mon, 10 Apr 2023 21:22:27 +0800 Subject: [PATCH 09/10] Tweak kwarg merging slightly This should improve iteration a bit, I think. --- airflow/models/baseoperator.py | 53 ++++++++++++++++------------------ 1 file changed, 25 insertions(+), 28 deletions(-) diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index 4f35ab8a944e9..9e6f5a49f31e0 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -185,6 +185,26 @@ def partial(**kwargs): return self.class_method.__get__(cls, cls) +_PARTIAL_DEFAULTS = { + "owner": DEFAULT_OWNER, + "trigger_rule": DEFAULT_TRIGGER_RULE, + "depends_on_past": False, + "ignore_first_depends_on_past": DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST, + "wait_for_past_depends_before_skipping": DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING, + "wait_for_downstream": False, + "retries": DEFAULT_RETRIES, + "queue": DEFAULT_QUEUE, + "pool_slots": DEFAULT_POOL_SLOTS, + "execution_timeout": DEFAULT_TASK_EXECUTION_TIMEOUT, + "retry_delay": DEFAULT_RETRY_DELAY, + "retry_exponential_backoff": False, + "priority_weight": DEFAULT_PRIORITY_WEIGHT, + "weight_rule": DEFAULT_WEIGHT_RULE, + "inlets": (), + "outlets": (), +} + + # This is what handles the actual mapping. def partial( operator_class: type[BaseOperator], @@ -242,7 +262,7 @@ def partial( task_id = task_group.child_id(task_id) # Merge DAG and task group level defaults into user-supplied values. - default_partial_kwargs, partial_params = get_merged_defaults( + dag_default_args, partial_params = get_merged_defaults( dag=dag, task_group=task_group, task_params=params, @@ -292,34 +312,11 @@ def partial( "doc_yaml": doc_yaml, } - DEFAULT_VALUES: dict[str, Any] = { - "owner": DEFAULT_OWNER, - "trigger_rule": DEFAULT_TRIGGER_RULE, - "depends_on_past": False, - "ignore_first_depends_on_past": DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST, - "wait_for_past_depends_before_skipping": DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING, - "wait_for_downstream": False, - "retries": DEFAULT_RETRIES, - "queue": DEFAULT_QUEUE, - "pool_slots": DEFAULT_POOL_SLOTS, - "execution_timeout": DEFAULT_TASK_EXECUTION_TIMEOUT, - "retry_delay": DEFAULT_RETRY_DELAY, - "retry_exponential_backoff": False, - "priority_weight": DEFAULT_PRIORITY_WEIGHT, - "weight_rule": DEFAULT_WEIGHT_RULE, - "inlets": [], - "outlets": [], - } + # Inject DAG-level default args into args provided to this function. + partial_kwargs.update((k, v) for k, v in dag_default_args.items() if partial_kwargs.get(k) is NOTSET) - # Override NOTSET kwargs by dag default values - for k, v in default_partial_kwargs.items(): - if partial_kwargs.get(k) is NOTSET: - partial_kwargs[k] = v - - # Override NOTSET kwargs which don't have a dag default value by Airflow default value or None - partial_kwargs = { - k: v if v is not NOTSET else DEFAULT_VALUES.get(k, None) for k, v in partial_kwargs.items() - } + # Fill fields not provided by the user with default values. + partial_kwargs = {k: _PARTIAL_DEFAULTS.get(k) if v is NOTSET else v for k, v in partial_kwargs.items()} # Post-process arguments. Should be kept in sync with _TaskDecorator.expand(). if "task_concurrency" in kwargs: # Reject deprecated option. From 0892f25f8824288d8a78f80e08da41932de31107 Mon Sep 17 00:00:00 2001 From: Hussein Awala Date: Fri, 14 Apr 2023 02:46:08 +0200 Subject: [PATCH 10/10] Fix unit tests --- airflow/models/baseoperator.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index 9e6f5a49f31e0..19173d49749fa 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -200,8 +200,8 @@ def partial(**kwargs): "retry_exponential_backoff": False, "priority_weight": DEFAULT_PRIORITY_WEIGHT, "weight_rule": DEFAULT_WEIGHT_RULE, - "inlets": (), - "outlets": (), + "inlets": [], + "outlets": [], }