diff --git a/airflow-core/pyproject.toml b/airflow-core/pyproject.toml index 3cf114af0bfc7..a0ff1a2466bdb 100644 --- a/airflow-core/pyproject.toml +++ b/airflow-core/pyproject.toml @@ -229,8 +229,9 @@ exclude = [ [tool.hatch.build.targets.sdist.force-include] "../shared/configuration/src/airflow_shared/configuration" = "src/airflow/_shared/configuration" -"../shared/module_loading/src/airflow_shared/module_loading" = "src/airflow/_shared/module_loading" +"../shared/dagnode/src/airflow_shared/dagnode" = "src/airflow/_shared/dagnode" "../shared/logging/src/airflow_shared/logging" = "src/airflow/_shared/logging" +"../shared/module_loading/src/airflow_shared/module_loading" = "src/airflow/_shared/module_loading" "../shared/observability/src/airflow_shared/observability" = "src/airflow/_shared/observability" "../shared/secrets_backend/src/airflow_shared/secrets_backend" = "src/airflow/_shared/secrets_backend" "../shared/secrets_masker/src/airflow_shared/secrets_masker" = "src/airflow/_shared/secrets_masker" @@ -303,10 +304,11 @@ apache-airflow-devel-common = { workspace = true } [tool.airflow] shared_distributions = [ "apache-airflow-shared-configuration", + "apache-airflow-shared-dagnode", "apache-airflow-shared-logging", "apache-airflow-shared-module-loading", + "apache-airflow-shared-observability", "apache-airflow-shared-secrets-backend", "apache-airflow-shared-secrets-masker", "apache-airflow-shared-timezones", - "apache-airflow-shared-observability", ] diff --git a/airflow-core/src/airflow/_shared/dagnode b/airflow-core/src/airflow/_shared/dagnode new file mode 120000 index 0000000000000..ad88febb9c031 --- /dev/null +++ b/airflow-core/src/airflow/_shared/dagnode @@ -0,0 +1 @@ +../../../../shared/dagnode/src/airflow_shared/dagnode \ No newline at end of file diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py index c45f585f4eda3..aa0e6ce5c5c9b 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py @@ -77,8 +77,6 @@ if TYPE_CHECKING: from sqlalchemy.sql.dml import Update - from airflow.models.expandinput import SchedulerExpandInput - router = VersionedAPIRouter() ti_id_router = VersionedAPIRouter( @@ -314,7 +312,7 @@ def _get_upstream_map_indexes( except NotFullyPopulated: # Second try: resolve XCom for correct count try: - expand_input = cast("SchedulerExpandInput", upstream_mapped_group._expand_input) + expand_input = upstream_mapped_group._expand_input mapped_ti_count = expand_input.get_total_map_length(ti.run_id, session=session) except NotFullyPopulated: # For these trigger rules, unresolved map indexes are acceptable. 
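The cast on `_expand_input` deleted above, and the `type: ignore[assignment]` overrides deleted in the next file, fall out of the same change: the node base class is now generic over its container types, so each subclass binds its own concrete `dag`/`task_group` types instead of inheriting the SDK ones. A minimal sketch of the pattern, with toy names rather than the real Airflow classes:

from __future__ import annotations

from typing import Generic, TypeVar

DagT = TypeVar("DagT")  # the concrete DAG container type a subclass binds


class Node(Generic[DagT]):
    # The base class only knows the type parameter...
    dag: DagT | None = None

    def get_dag(self) -> DagT | None:
        return self.dag


class SchedulerDag:
    dag_id = "example"


class SchedulerNode(Node[SchedulerDag]):
    # ...and the subclass binds it, so callers need no cast or type-ignore.
    pass


node = SchedulerNode()
node.dag = SchedulerDag()
dag = node.get_dag()  # a type checker infers SchedulerDag | None here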
diff --git a/airflow-core/src/airflow/models/mappedoperator.py b/airflow-core/src/airflow/models/mappedoperator.py index c227bbfc54b5d..758149875f0af 100644 --- a/airflow-core/src/airflow/models/mappedoperator.py +++ b/airflow-core/src/airflow/models/mappedoperator.py @@ -31,8 +31,8 @@ from airflow.exceptions import AirflowException, NotMapped from airflow.sdk import BaseOperator as TaskSDKBaseOperator from airflow.sdk.definitions._internal.abstractoperator import DEFAULT_RETRY_DELAY_MULTIPLIER -from airflow.sdk.definitions._internal.node import DAGNode from airflow.sdk.definitions.mappedoperator import MappedOperator as TaskSDKMappedOperator +from airflow.serialization.definitions.node import DAGNode from airflow.serialization.definitions.param import SerializedParamsDict from airflow.serialization.definitions.taskgroup import SerializedMappedTaskGroup, SerializedTaskGroup from airflow.serialization.enums import DagAttributeTypes @@ -48,6 +48,7 @@ from airflow.models import TaskInstance from airflow.models.expandinput import SchedulerExpandInput from airflow.sdk import BaseOperatorLink, Context + from airflow.sdk.definitions._internal.node import DAGNode as TaskSDKDAGNode from airflow.sdk.definitions.operator_resources import Resources from airflow.serialization.serialized_objects import SerializedDAG from airflow.task.trigger_rule import TriggerRule @@ -83,7 +84,6 @@ def is_mapped(obj: Operator | SerializedTaskGroup) -> TypeGuard[MappedOperator | getstate_setstate=False, repr=False, ) -# TODO (GH-52141): Duplicate DAGNode in the scheduler. class MappedOperator(DAGNode): """Object representing a mapped operator in a DAG.""" @@ -110,11 +110,6 @@ class MappedOperator(DAGNode): start_from_trigger: bool = False _needs_expansion: bool = True - # TODO (GH-52141): These should contain serialized containers, but currently - # this class inherits from an SDK one. - dag: SerializedDAG = attrs.field(init=False) # type: ignore[assignment] - task_group: SerializedTaskGroup = attrs.field(init=False) # type: ignore[assignment] - doc: str | None = attrs.field(init=False) doc_json: str | None = attrs.field(init=False) doc_rst: str | None = attrs.field(init=False) @@ -503,7 +498,7 @@ def expand_start_trigger_args(self, *, context: Context) -> StartTriggerArgs | N @functools.singledispatch -def get_mapped_ti_count(task: DAGNode, run_id: str, *, session: Session) -> int: +def get_mapped_ti_count(task: DAGNode | TaskSDKDAGNode, run_id: str, *, session: Session) -> int: raise NotImplementedError(f"Not implemented for {type(task)}") diff --git a/airflow-core/src/airflow/models/taskinstance.py b/airflow-core/src/airflow/models/taskinstance.py index 0814f438688ac..ac854ac162c4d 100644 --- a/airflow-core/src/airflow/models/taskinstance.py +++ b/airflow-core/src/airflow/models/taskinstance.py @@ -28,7 +28,7 @@ from collections.abc import Collection, Iterable from datetime import datetime, timedelta from functools import cache -from typing import TYPE_CHECKING, Any, cast +from typing import TYPE_CHECKING, Any from urllib.parse import quote import attrs @@ -2332,10 +2332,7 @@ def _visit_relevant_relatives_for_mapped(mapped_tasks: Iterable[tuple[str, int]] # Treat it as a normal task instead. _visit_relevant_relatives_for_normal([task_id]) continue - # TODO (GH-52141): This should return scheduler operator types, but - # currently get_flat_relatives is inherited from SDK DAGNode. 
- relatives = cast("Iterable[Operator]", task.get_flat_relatives(upstream=direction == "upstream")) - for relative in relatives: + for relative in task.get_flat_relatives(upstream=direction == "upstream"): if relative.task_id in visited: continue relative_map_indexes = _get_relevant_map_indexes( diff --git a/airflow-core/src/airflow/serialization/definitions/dag.py b/airflow-core/src/airflow/serialization/definitions/dag.py index f6556dcfd1c0e..238d2d748dcac 100644 --- a/airflow-core/src/airflow/serialization/definitions/dag.py +++ b/airflow-core/src/airflow/serialization/definitions/dag.py @@ -284,9 +284,7 @@ def is_task(obj) -> TypeIs[SerializedOperator]: direct_upstreams: list[SerializedOperator] = [] if include_direct_upstream: for t in itertools.chain(matched_tasks, also_include): - # TODO (GH-52141): This should return scheduler types, but currently we reuse SDK DAGNode. - upstream = (u for u in cast("Iterable[SerializedOperator]", t.upstream_list) if is_task(u)) - direct_upstreams.extend(upstream) + direct_upstreams.extend(u for u in t.upstream_list if is_task(u)) # Make sure to not recursively deepcopy the dag or task_group while copying the task. # task_group is reset later diff --git a/airflow-core/src/airflow/serialization/definitions/node.py b/airflow-core/src/airflow/serialization/definitions/node.py new file mode 100644 index 0000000000000..b17e46234ab59 --- /dev/null +++ b/airflow-core/src/airflow/serialization/definitions/node.py @@ -0,0 +1,51 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import abc +from typing import TYPE_CHECKING + +from airflow._shared.dagnode.node import GenericDAGNode + +if TYPE_CHECKING: + from collections.abc import Sequence + from typing import TypeAlias + + from airflow.models.mappedoperator import MappedOperator + from airflow.serialization.definitions.taskgroup import SerializedTaskGroup # noqa: F401 + from airflow.serialization.serialized_objects import SerializedBaseOperator, SerializedDAG # noqa: F401 + + Operator: TypeAlias = SerializedBaseOperator | MappedOperator + + +class DAGNode(GenericDAGNode["SerializedDAG", "Operator", "SerializedTaskGroup"], metaclass=abc.ABCMeta): + """ + Base class for a node in the graph of a workflow. + + A node may be an operator or task group, either mapped or unmapped. 
+ """ + + @property + @abc.abstractmethod + def roots(self) -> Sequence[DAGNode]: + raise NotImplementedError() + + @property + @abc.abstractmethod + def leaves(self) -> Sequence[DAGNode]: + raise NotImplementedError() diff --git a/airflow-core/src/airflow/serialization/definitions/taskgroup.py b/airflow-core/src/airflow/serialization/definitions/taskgroup.py index 3dcb62aa30f4d..c127353bcfe13 100644 --- a/airflow-core/src/airflow/serialization/definitions/taskgroup.py +++ b/airflow-core/src/airflow/serialization/definitions/taskgroup.py @@ -27,7 +27,7 @@ import attrs import methodtools -from airflow.sdk.definitions._internal.node import DAGNode +from airflow.serialization.definitions.node import DAGNode if TYPE_CHECKING: from collections.abc import Generator, Iterator @@ -45,8 +45,7 @@ class SerializedTaskGroup(DAGNode): group_display_name: str | None = attrs.field() prefix_group_id: bool = attrs.field() parent_group: SerializedTaskGroup | None = attrs.field() - # TODO (GH-52141): Replace DAGNode dependency. - dag: SerializedDAG = attrs.field() # type: ignore[assignment] + dag: SerializedDAG = attrs.field() tooltip: str = attrs.field() default_args: dict[str, Any] = attrs.field(factory=dict) diff --git a/airflow-core/src/airflow/serialization/serialized_objects.py b/airflow-core/src/airflow/serialization/serialized_objects.py index ce136f998cee8..2f6d4b362a89d 100644 --- a/airflow-core/src/airflow/serialization/serialized_objects.py +++ b/airflow-core/src/airflow/serialization/serialized_objects.py @@ -53,7 +53,6 @@ from airflow.models.xcom_arg import SchedulerXComArg, deserialize_xcom_arg from airflow.sdk import DAG, Asset, AssetAlias, BaseOperator, XComArg from airflow.sdk.bases.operator import OPERATOR_DEFAULTS # TODO: Copy this into the scheduler? -from airflow.sdk.definitions._internal.node import DAGNode from airflow.sdk.definitions.asset import ( AssetAliasEvent, AssetAliasUniqueKey, @@ -76,6 +75,7 @@ SerializedAssetUniqueKey, ) from airflow.serialization.definitions.dag import SerializedDAG +from airflow.serialization.definitions.node import DAGNode from airflow.serialization.definitions.param import SerializedParam, SerializedParamsDict from airflow.serialization.definitions.taskgroup import SerializedMappedTaskGroup, SerializedTaskGroup from airflow.serialization.encoders import ( @@ -118,6 +118,7 @@ from airflow.models.taskinstance import TaskInstance from airflow.providers.cncf.kubernetes.pod_generator import PodGenerator # noqa: TC004 from airflow.sdk import BaseOperatorLink + from airflow.sdk.definitions._internal.node import DAGNode as SDKDAGNode from airflow.serialization.json_schema import Validator from airflow.task.trigger_rule import TriggerRule from airflow.ti_deps.deps.base_ti_dep import BaseTIDep @@ -1022,7 +1023,6 @@ def detect_dag_dependencies(dag: DAG | None) -> Iterable[DagDependency]: yield from tt.asset_condition.iter_dag_dependencies(source="", target=dag.dag_id) -# TODO (GH-52141): Duplicate DAGNode in the scheduler. class SerializedBaseOperator(DAGNode, BaseSerialization): """ A JSON serializable representation of operator. @@ -1052,10 +1052,8 @@ class SerializedBaseOperator(DAGNode, BaseSerialization): _task_display_name: str | None _weight_rule: str | PriorityWeightStrategy = "downstream" - # TODO (GH-52141): These should contain serialized containers, but currently - # this class inherits from an SDK one. 
- dag: SerializedDAG | None = None # type: ignore[assignment] - task_group: SerializedTaskGroup | None = None # type: ignore[assignment] + dag: SerializedDAG | None = None + task_group: SerializedTaskGroup | None = None allow_nested_operators: bool = True depends_on_past: bool = False @@ -1159,8 +1157,7 @@ def __repr__(self) -> str: def node_id(self) -> str: return self.task_id - # TODO (GH-52141): Replace DAGNode with a scheduler type. - def get_dag(self) -> SerializedDAG | None: # type: ignore[override] + def get_dag(self) -> SerializedDAG | None: return self.dag @property @@ -1680,7 +1677,7 @@ def _matches_client_defaults(cls, var: Any, attrname: str) -> bool: return False @classmethod - def _is_excluded(cls, var: Any, attrname: str, op: DAGNode): + def _is_excluded(cls, var: Any, attrname: str, op: SDKDAGNode) -> bool: """ Determine if a variable is excluded from the serialized object. diff --git a/airflow-core/src/airflow/ti_deps/deps/trigger_rule_dep.py b/airflow-core/src/airflow/ti_deps/deps/trigger_rule_dep.py index 3c0fc3be5c0c2..2effd1fef6eb7 100644 --- a/airflow-core/src/airflow/ti_deps/deps/trigger_rule_dep.py +++ b/airflow-core/src/airflow/ti_deps/deps/trigger_rule_dep.py @@ -620,12 +620,7 @@ def _evaluate_direct_relatives() -> Iterator[TIDepStatus]: if not task.is_teardown: # a teardown cannot have any indirect setups - relevant_setups: dict[str, MappedOperator | SerializedBaseOperator] = { - # TODO (GH-52141): This should return scheduler types, but - # currently we reuse logic in SDK DAGNode. - t.task_id: t # type: ignore[misc] - for t in task.get_upstreams_only_setups() - } + relevant_setups = {t.task_id: t for t in task.get_upstreams_only_setups()} if relevant_setups: for status, changed in _evaluate_setup_constraint(relevant_setups=relevant_setups): yield status diff --git a/airflow-core/src/airflow/utils/dag_edges.py b/airflow-core/src/airflow/utils/dag_edges.py index 94c6069f91b02..1f3c0fbd2547b 100644 --- a/airflow-core/src/airflow/utils/dag_edges.py +++ b/airflow-core/src/airflow/utils/dag_edges.py @@ -23,8 +23,6 @@ from airflow.serialization.serialized_objects import SerializedBaseOperator, SerializedDAG if TYPE_CHECKING: - from collections.abc import Iterable - from airflow.sdk import DAG Operator: TypeAlias = MappedOperator | SerializedBaseOperator @@ -118,11 +116,7 @@ def collect_edges(task_group): while tasks_to_trace: tasks_to_trace_next: list[Operator] = [] for task in tasks_to_trace: - # TODO (GH-52141): downstream_list on DAGNode needs to be able to - # return scheduler types when used in scheduler, but SDK types when - # used at runtime. This means DAGNode needs to be rewritten as a - # generic class. 
- for child in cast("Iterable[Operator]", task.downstream_list): + for child in task.downstream_list: edge = (task.task_id, child.task_id) if task.is_setup and child.is_teardown: setup_teardown_edges.add(edge) diff --git a/airflow-core/src/airflow/utils/dot_renderer.py b/airflow-core/src/airflow/utils/dot_renderer.py index 586789f1722e2..d0802972980c2 100644 --- a/airflow-core/src/airflow/utils/dot_renderer.py +++ b/airflow-core/src/airflow/utils/dot_renderer.py @@ -36,7 +36,6 @@ import graphviz from airflow.models import TaskInstance - from airflow.models.taskmixin import DependencyMixin from airflow.serialization.dag_dependency import DagDependency else: try: @@ -136,7 +135,7 @@ def _draw_task_group( def _draw_nodes( - node: DependencyMixin, parent_graph: graphviz.Digraph, states_by_task_id: dict[str, str | None] | None + node: object, parent_graph: graphviz.Digraph, states_by_task_id: dict[str, str | None] | None ) -> None: """Draw the node and its children on the given parent_graph recursively.""" if isinstance(node, (BaseOperator, MappedOperator, SerializedBaseOperator, SerializedMappedOperator)): diff --git a/pyproject.toml b/pyproject.toml index ba650152d6015..423326157a4a6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1346,6 +1346,7 @@ apache-airflow-kubernetes-tests = { workspace = true } apache-airflow-providers = { workspace = true } apache-aurflow-docker-stack = { workspace = true } apache-airflow-shared-configuration = { workspace = true } +apache-airflow-shared-dagnode = { workspace = true } apache-airflow-shared-logging = { workspace = true } apache-airflow-shared-module-loading = { workspace = true } apache-airflow-shared-secrets-backend = { workspace = true } diff --git a/shared/dagnode/pyproject.toml b/shared/dagnode/pyproject.toml new file mode 100644 index 0000000000000..d75d1cf3c54af --- /dev/null +++ b/shared/dagnode/pyproject.toml @@ -0,0 +1,48 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +[project] +name = "apache-airflow-shared-dagnode" +description = "Shared DAGNode logic for Airflow distributions" +version = "0.0" +classifiers = [ + "Private :: Do Not Upload", +] + +dependencies = [ + "structlog>=25.4.0", +] + +[dependency-groups] +dev = [ + "apache-airflow-devel-common", +] + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build.targets.wheel] +packages = ["src/airflow_shared"] + +[tool.ruff] +extend = "../../pyproject.toml" +src = ["src"] + +[tool.ruff.lint.per-file-ignores] +# Ignore Doc rules et al for anything outside of tests +"!src/*" = ["D", "S101", "TRY002"] diff --git a/shared/dagnode/src/airflow_shared/dagnode/__init__.py b/shared/dagnode/src/airflow_shared/dagnode/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/shared/dagnode/src/airflow_shared/dagnode/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/shared/dagnode/src/airflow_shared/dagnode/node.py b/shared/dagnode/src/airflow_shared/dagnode/node.py new file mode 100644 index 0000000000000..2f4504818e273 --- /dev/null +++ b/shared/dagnode/src/airflow_shared/dagnode/node.py @@ -0,0 +1,203 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from typing import TYPE_CHECKING, Generic, TypeVar + +import structlog + +if TYPE_CHECKING: + from collections.abc import Collection, Iterable + + from ..logging.types import Logger + +Dag = TypeVar("Dag") +Task = TypeVar("Task") +TaskGroup = TypeVar("TaskGroup") + + +class GenericDAGNode(Generic[Dag, Task, TaskGroup]): + """ + Generic class for a node in the graph of a workflow. + + A node may be an operator or task group, either mapped or unmapped. 
+ """ + + dag: Dag | None + task_group: TaskGroup | None + upstream_task_ids: set[str] + downstream_task_ids: set[str] + + _log_config_logger_name: str | None = None + _logger_name: str | None = None + _cached_logger: Logger | None = None + + def __init__(self): + super().__init__() + self.upstream_task_ids = set() + self.downstream_task_ids = set() + + @property + def log(self) -> Logger: + if self._cached_logger is not None: + return self._cached_logger + + typ = type(self) + + logger_name: str = ( + self._logger_name if self._logger_name is not None else f"{typ.__module__}.{typ.__qualname__}" + ) + + if self._log_config_logger_name: + logger_name = ( + f"{self._log_config_logger_name}.{logger_name}" + if logger_name + else self._log_config_logger_name + ) + + self._cached_logger = structlog.get_logger(logger_name) + return self._cached_logger + + @property + def dag_id(self) -> str: + if self.dag: + return self.dag.dag_id + return "_in_memory_dag_" + + @property + def node_id(self) -> str: + raise NotImplementedError() + + @property + def label(self) -> str | None: + tg = self.task_group + if tg and tg.node_id and tg.prefix_group_id: + # "task_group_id.task_id" -> "task_id" + return self.node_id[len(tg.node_id) + 1 :] + return self.node_id + + @property + def upstream_list(self) -> Iterable[Task]: + if not self.dag: + raise RuntimeError(f"Operator {self} has not been assigned to a Dag yet") + return [self.dag.get_task(tid) for tid in self.upstream_task_ids] + + @property + def downstream_list(self) -> Iterable[Task]: + if not self.dag: + raise RuntimeError(f"Operator {self} has not been assigned to a Dag yet") + return [self.dag.get_task(tid) for tid in self.downstream_task_ids] + + def has_dag(self) -> bool: + return self.dag is not None + + def get_dag(self) -> Dag | None: + return self.dag + + def get_direct_relative_ids(self, upstream: bool = False) -> set[str]: + """Get set of the direct relative ids to the current task, upstream or downstream.""" + if upstream: + return self.upstream_task_ids + return self.downstream_task_ids + + def get_direct_relatives(self, upstream: bool = False) -> Iterable[Task]: + """Get list of the direct relatives to the current task, upstream or downstream.""" + if upstream: + return self.upstream_list + return self.downstream_list + + def get_flat_relative_ids(self, *, upstream: bool = False) -> set[str]: + """ + Get a flat set of relative IDs, upstream or downstream. + + Will recurse each relative found in the direction specified. + + :param upstream: Whether to look for upstream or downstream relatives. + """ + dag = self.get_dag() + if not dag: + return set() + + relatives: set[str] = set() + + # This is intentionally implemented as a loop, instead of calling + # get_direct_relative_ids() recursively, since Python has significant + # limitation on stack level, and a recursive implementation can blow up + # if a DAG contains very long routes. 
+ task_ids_to_trace = self.get_direct_relative_ids(upstream) + while task_ids_to_trace: + task_ids_to_trace_next: set[str] = set() + for task_id in task_ids_to_trace: + if task_id in relatives: + continue + task_ids_to_trace_next.update(dag.task_dict[task_id].get_direct_relative_ids(upstream)) + relatives.add(task_id) + task_ids_to_trace = task_ids_to_trace_next + + return relatives + + def get_flat_relatives(self, upstream: bool = False) -> Collection[Task]: + """Get a flat list of relatives, either upstream or downstream.""" + dag = self.get_dag() + if not dag: + return set() + return [dag.task_dict[task_id] for task_id in self.get_flat_relative_ids(upstream=upstream)] + + def get_upstreams_follow_setups(self) -> Iterable[Task]: + """All upstreams and, for each upstream setup, its respective teardowns.""" + for task in self.get_flat_relatives(upstream=True): + yield task + if task.is_setup: + for t in task.downstream_list: + if t.is_teardown and t != self: + yield t + + def get_upstreams_only_setups_and_teardowns(self) -> Iterable[Task]: + """ + Only *relevant* upstream setups and their teardowns. + + This method is meant to be used when we are clearing the task (non-upstream) and we need + to add in the *relevant* setups and their teardowns. + + Relevant in this case means, the setup has a teardown that is downstream of ``self``, + or the setup has no teardowns. + """ + downstream_teardown_ids = { + x.task_id for x in self.get_flat_relatives(upstream=False) if x.is_teardown + } + for task in self.get_flat_relatives(upstream=True): + if not task.is_setup: + continue + has_no_teardowns = not any(x.is_teardown for x in task.downstream_list) + # if task has no teardowns or has teardowns downstream of self + if has_no_teardowns or task.downstream_task_ids.intersection(downstream_teardown_ids): + yield task + for t in task.downstream_list: + if t.is_teardown and t != self: + yield t + + def get_upstreams_only_setups(self) -> Iterable[Task]: + """ + Return relevant upstream setups. + + This method is meant to be used when we are checking task dependencies where we need + to wait for all the upstream setups to complete before we can run the task. + """ + for task in self.get_upstreams_only_setups_and_teardowns(): + if task.is_setup: + yield task diff --git a/shared/dagnode/tests/__init__.py b/shared/dagnode/tests/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/shared/dagnode/tests/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
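With the shared implementation in place, here is a standalone sketch of how the iterative `get_flat_relative_ids` traversal behaves; `ToyDag` and `ToyNode` are stand-ins invented for illustration, and only the `airflow_shared.dagnode.node` module added in this diff is assumed:

from __future__ import annotations

from airflow_shared.dagnode.node import GenericDAGNode


class ToyDag:
    """Toy container providing the two hooks GenericDAGNode touches."""

    dag_id = "demo"

    def __init__(self) -> None:
        self.task_dict: dict[str, ToyNode] = {}

    def get_task(self, task_id: str) -> ToyNode:
        return self.task_dict[task_id]


class ToyNode(GenericDAGNode["ToyDag", "ToyNode", None]):
    def __init__(self, dag: ToyDag, task_id: str) -> None:
        super().__init__()
        self.dag = dag
        self.task_group = None
        self.task_id = task_id
        dag.task_dict[task_id] = self

    @property
    def node_id(self) -> str:
        return self.task_id


dag = ToyDag()
a, b, c = (ToyNode(dag, task_id) for task_id in "abc")
a.downstream_task_ids.add("b")
b.upstream_task_ids.add("a")
b.downstream_task_ids.add("c")
c.upstream_task_ids.add("b")

# The flat traversal follows edges transitively, without recursion.
assert a.get_flat_relative_ids(upstream=False) == {"b", "c"}
assert c.get_flat_relative_ids(upstream=True) == {"a", "b"}
assert a.dag_id == "demo"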
diff --git a/shared/dagnode/tests/dagnode/__init__.py b/shared/dagnode/tests/dagnode/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/shared/dagnode/tests/dagnode/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/shared/dagnode/tests/dagnode/test_node.py b/shared/dagnode/tests/dagnode/test_node.py new file mode 100644 index 0000000000000..4259ca7555ffe --- /dev/null +++ b/shared/dagnode/tests/dagnode/test_node.py @@ -0,0 +1,83 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +from __future__ import annotations + +from unittest import mock + +import attrs +import pytest + +from airflow_shared.dagnode.node import GenericDAGNode + + +class Task: + """Task type for tests.""" + + +@attrs.define +class TaskGroup: + """Task group type for tests.""" + + node_id: str = attrs.field(init=False, default="test_group_id") + prefix_group_id: bool + + +class Dag: + """Dag type for tests.""" + + dag_id = "test_dag_id" + + +class ConcreteDAGNode(GenericDAGNode[Dag, Task, TaskGroup]): + """Concrete DAGNode variant for tests.""" + + dag = None + task_group = None + + @property + def node_id(self) -> str: + return "test_group_id.test_node_id" + + +class TestDAGNode: + @pytest.fixture + def node(self): + return ConcreteDAGNode() + + def test_log(self, node: ConcreteDAGNode) -> None: + assert node._cached_logger is None + with mock.patch("structlog.get_logger") as mock_get_logger: + log = node.log + assert log is node._cached_logger + assert mock_get_logger.mock_calls == [mock.call("tests.dagnode.test_node.ConcreteDAGNode")] + + def test_dag_id(self, node: ConcreteDAGNode) -> None: + assert node.dag is None + assert node.dag_id == "_in_memory_dag_" + node.dag = Dag() + assert node.dag_id == "test_dag_id" + + @pytest.mark.parametrize( + ("prefix_group_id", "expected_label"), + [(True, "test_node_id"), (False, "test_group_id.test_node_id")], + ) + def test_label(self, node: ConcreteDAGNode, prefix_group_id: bool, expected_label: str) -> None: + assert node.task_group is None + assert node.label == "test_group_id.test_node_id" + node.task_group = TaskGroup(prefix_group_id) + assert node.label == expected_label diff --git a/task-sdk/pyproject.toml b/task-sdk/pyproject.toml index 98d916576343a..7c26f1babb58a 100644 --- a/task-sdk/pyproject.toml +++ b/task-sdk/pyproject.toml @@ -116,8 +116,9 @@ path = "src/airflow/sdk/__init__.py" [tool.hatch.build.targets.sdist.force-include] "../shared/configuration/src/airflow_shared/configuration" = "src/airflow/sdk/_shared/configuration" -"../shared/module_loading/src/airflow_shared/module_loading" = "src/airflow/sdk/_shared/module_loading" +"../shared/dagnode/src/airflow_shared/dagnode" = "src/airflow/sdk/_shared/dagnode" "../shared/logging/src/airflow_shared/logging" = "src/airflow/sdk/_shared/logging" +"../shared/module_loading/src/airflow_shared/module_loading" = "src/airflow/sdk/_shared/module_loading" "../shared/observability/src/airflow_shared/observability" = "src/airflow/sdk/_shared/observability" "../shared/secrets_backend/src/airflow_shared/secrets_backend" = "src/airflow/sdk/_shared/secrets_backend" "../shared/secrets_masker/src/airflow_shared/secrets_masker" = "src/airflow/sdk/_shared/secrets_masker" @@ -264,6 +265,7 @@ tmp_path_retention_policy = "failed" [tool.airflow] shared_distributions = [ "apache-airflow-shared-configuration", + "apache-airflow-shared-dagnode", "apache-airflow-shared-logging", "apache-airflow-shared-module-loading", "apache-airflow-shared-secrets-backend", diff --git a/task-sdk/src/airflow/sdk/_shared/dagnode b/task-sdk/src/airflow/sdk/_shared/dagnode new file mode 120000 index 0000000000000..9455ba69b087b --- /dev/null +++ b/task-sdk/src/airflow/sdk/_shared/dagnode @@ -0,0 +1 @@ +../../../../../shared/dagnode/src/airflow_shared/dagnode \ No newline at end of file diff --git a/task-sdk/src/airflow/sdk/definitions/_internal/mixins.py b/task-sdk/src/airflow/sdk/definitions/_internal/mixins.py index d14d629915987..e186ef97e64dd 100644 --- a/task-sdk/src/airflow/sdk/definitions/_internal/mixins.py +++
b/task-sdk/src/airflow/sdk/definitions/_internal/mixins.py @@ -31,8 +31,6 @@ Operator: TypeAlias = BaseOperator | MappedOperator -# TODO: Should this all just live on DAGNode? - class DependencyMixin: """Mixing implementing common dependency setting methods like >> and <<.""" diff --git a/task-sdk/src/airflow/sdk/definitions/_internal/node.py b/task-sdk/src/airflow/sdk/definitions/_internal/node.py index 86979e442cd5e..b2cb651efe1a8 100644 --- a/task-sdk/src/airflow/sdk/definitions/_internal/node.py +++ b/task-sdk/src/airflow/sdk/definitions/_internal/node.py @@ -19,19 +19,18 @@ import re from abc import ABCMeta, abstractmethod -from collections.abc import Collection, Iterable, Sequence +from collections.abc import Sequence from datetime import datetime from typing import TYPE_CHECKING, Any -import structlog - +from airflow.sdk._shared.dagnode.node import GenericDAGNode from airflow.sdk.definitions._internal.mixins import DependencyMixin if TYPE_CHECKING: from airflow.sdk.definitions.dag import DAG from airflow.sdk.definitions.edges import EdgeModifier - from airflow.sdk.definitions.taskgroup import TaskGroup - from airflow.sdk.types import Logger, Operator + from airflow.sdk.definitions.taskgroup import TaskGroup # noqa: F401 + from airflow.sdk.types import Operator # noqa: F401 from airflow.serialization.enums import DagAttributeTypes @@ -65,84 +64,15 @@ def validate_group_key(k: str, max_length: int = 200): ) -class DAGNode(DependencyMixin, metaclass=ABCMeta): +class DAGNode(GenericDAGNode["DAG", "Operator", "TaskGroup"], DependencyMixin, metaclass=ABCMeta): """ A base class for a node in the graph of a workflow. A node may be an Operator or a Task Group, either mapped or unmapped. """ - dag: DAG | None - task_group: TaskGroup | None - """The task_group that contains this node""" start_date: datetime | None end_date: datetime | None - upstream_task_ids: set[str] - downstream_task_ids: set[str] - - _log_config_logger_name: str | None = None - _logger_name: str | None = None - _cached_logger: Logger | None = None - - def __init__(self): - self.upstream_task_ids = set() - self.downstream_task_ids = set() - super().__init__() - - def get_dag(self) -> DAG | None: - return self.dag - - @property - @abstractmethod - def node_id(self) -> str: - raise NotImplementedError() - - @property - def label(self) -> str | None: - tg = self.task_group - if tg and tg.node_id and tg.prefix_group_id: - # "task_group_id.task_id" -> "task_id" - return self.node_id[len(tg.node_id) + 1 :] - return self.node_id - - def has_dag(self) -> bool: - return self.dag is not None - - @property - def dag_id(self) -> str: - """Returns dag id if it has one or an adhoc/meaningless ID.""" - if self.dag: - return self.dag.dag_id - return "_in_memory_dag_" - - @property - def log(self) -> Logger: - """ - Get a logger for this node. - - The logger name is determined by: - 1. Using _logger_name if provided - 2. Otherwise, using the class's module and qualified name - 3. 
Prefixing with _log_config_logger_name if set - """ - if self._cached_logger is not None: - return self._cached_logger - - typ = type(self) - - logger_name: str = ( - self._logger_name if self._logger_name is not None else f"{typ.__module__}.{typ.__qualname__}" - ) - - if self._log_config_logger_name: - logger_name = ( - f"{self._log_config_logger_name}.{logger_name}" - if logger_name - else self._log_config_logger_name - ) - - self._cached_logger = structlog.get_logger(logger_name) - return self._cached_logger @property @abstractmethod @@ -227,113 +157,6 @@ def set_upstream( """Set a node (or nodes) to be directly upstream from the current node.""" self._set_relatives(task_or_task_list, upstream=True, edge_modifier=edge_modifier) - @property - def downstream_list(self) -> Iterable[Operator]: - """List of nodes directly downstream.""" - if not self.dag: - raise RuntimeError(f"Operator {self} has not been assigned to a Dag yet") - return [self.dag.get_task(tid) for tid in self.downstream_task_ids] - - @property - def upstream_list(self) -> Iterable[Operator]: - """List of nodes directly upstream.""" - if not self.dag: - raise RuntimeError(f"Operator {self} has not been assigned to a Dag yet") - return [self.dag.get_task(tid) for tid in self.upstream_task_ids] - - def get_direct_relative_ids(self, upstream: bool = False) -> set[str]: - """Get set of the direct relative ids to the current task, upstream or downstream.""" - if upstream: - return self.upstream_task_ids - return self.downstream_task_ids - - def get_direct_relatives(self, upstream: bool = False) -> Iterable[Operator]: - """Get list of the direct relatives to the current task, upstream or downstream.""" - if upstream: - return self.upstream_list - return self.downstream_list - - def get_flat_relative_ids(self, *, upstream: bool = False) -> set[str]: - """ - Get a flat set of relative IDs, upstream or downstream. - - Will recurse each relative found in the direction specified. - - :param upstream: Whether to look for upstream or downstream relatives. - """ - dag = self.get_dag() - if not dag: - return set() - - relatives: set[str] = set() - - # This is intentionally implemented as a loop, instead of calling - # get_direct_relative_ids() recursively, since Python has significant - # limitation on stack level, and a recursive implementation can blow up - # if a DAG contains very long routes. - task_ids_to_trace = self.get_direct_relative_ids(upstream) - while task_ids_to_trace: - task_ids_to_trace_next: set[str] = set() - for task_id in task_ids_to_trace: - if task_id in relatives: - continue - task_ids_to_trace_next.update(dag.task_dict[task_id].get_direct_relative_ids(upstream)) - relatives.add(task_id) - task_ids_to_trace = task_ids_to_trace_next - - return relatives - - def get_flat_relatives(self, upstream: bool = False) -> Collection[Operator]: - """Get a flat list of relatives, either upstream or downstream.""" - dag = self.get_dag() - if not dag: - return set() - return [dag.task_dict[task_id] for task_id in self.get_flat_relative_ids(upstream=upstream)] - - def get_upstreams_follow_setups(self) -> Iterable[Operator]: - """All upstreams and, for each upstream setup, its respective teardowns.""" - for task in self.get_flat_relatives(upstream=True): - yield task - if task.is_setup: - for t in task.downstream_list: - if t.is_teardown and t != self: - yield t - - def get_upstreams_only_setups_and_teardowns(self) -> Iterable[Operator]: - """ - Only *relevant* upstream setups and their teardowns. 
- - This method is meant to be used when we are clearing the task (non-upstream) and we need - to add in the *relevant* setups and their teardowns. - - Relevant in this case means, the setup has a teardown that is downstream of ``self``, - or the setup has no teardowns. - """ - downstream_teardown_ids = { - x.task_id for x in self.get_flat_relatives(upstream=False) if x.is_teardown - } - for task in self.get_flat_relatives(upstream=True): - if not task.is_setup: - continue - has_no_teardowns = not any(x.is_teardown for x in task.downstream_list) - # if task has no teardowns or has teardowns downstream of self - if has_no_teardowns or task.downstream_task_ids.intersection(downstream_teardown_ids): - yield task - for t in task.downstream_list: - if t.is_teardown and t != self: - yield t - - def get_upstreams_only_setups(self) -> Iterable[Operator]: - """ - Return relevant upstream setups. - - This method is meant to be used when we are checking task dependencies where we need - to wait for all the upstream setups to complete before we can run the task. - """ - for task in self.get_upstreams_only_setups_and_teardowns(): - if task.is_setup: - yield task - def serialize_for_task_group(self) -> tuple[DagAttributeTypes, Any]: """Serialize a task group's content; used by TaskGroupSerialization.""" raise NotImplementedError() diff --git a/task-sdk/src/airflow/sdk/definitions/dag.py b/task-sdk/src/airflow/sdk/definitions/dag.py index 887493e1ca703..479958b99d783 100644 --- a/task-sdk/src/airflow/sdk/definitions/dag.py +++ b/task-sdk/src/airflow/sdk/definitions/dag.py @@ -915,8 +915,7 @@ def is_task(obj) -> TypeGuard[Operator]: direct_upstreams: list[Operator] = [] if include_direct_upstream: for t in itertools.chain(matched_tasks, also_include): - upstream = (u for u in t.upstream_list if is_task(u)) - direct_upstreams.extend(upstream) + direct_upstreams.extend(u for u in t.upstream_list if is_task(u)) # Make sure to not recursively deepcopy the dag or task_group while copying the task. # task_group is reset later
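A closing note on the task-sdk rewrite above: `GenericDAGNode.__init__` calls `super().__init__()` even though it looks like a root class. Because the SDK's `DAGNode` mixes the shared base with `DependencyMixin`, that call keeps initialization cooperative across the MRO. A toy illustration with stand-in classes, not the real ones:

from __future__ import annotations


class GraphState:
    """Stand-in for GenericDAGNode: owns the relationship-id sets."""

    def __init__(self) -> None:
        super().__init__()  # keep the chain alive for classes later in the MRO
        self.upstream_task_ids: set[str] = set()
        self.downstream_task_ids: set[str] = set()


class AuthoringMixin:
    """Stand-in for DependencyMixin: SDK-side authoring helpers."""

    def __init__(self) -> None:
        super().__init__()
        self.initialized_mixin = True


class Node(GraphState, AuthoringMixin):
    pass


node = Node()  # MRO: Node -> GraphState -> AuthoringMixin -> object
assert node.upstream_task_ids == set()
assert node.initialized_mixin  # both bases ran via cooperative super()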