Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions airflow-core/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -229,8 +229,9 @@ exclude = [

[tool.hatch.build.targets.sdist.force-include]
"../shared/configuration/src/airflow_shared/configuration" = "src/airflow/_shared/configuration"
"../shared/module_loading/src/airflow_shared/module_loading" = "src/airflow/_shared/module_loading"
"../shared/dagnode/src/airflow_shared/dagnode" = "src/airflow/_shared/dagnode"
"../shared/logging/src/airflow_shared/logging" = "src/airflow/_shared/logging"
"../shared/module_loading/src/airflow_shared/module_loading" = "src/airflow/_shared/module_loading"
"../shared/observability/src/airflow_shared/observability" = "src/airflow/_shared/observability"
"../shared/secrets_backend/src/airflow_shared/secrets_backend" = "src/airflow/_shared/secrets_backend"
"../shared/secrets_masker/src/airflow_shared/secrets_masker" = "src/airflow/_shared/secrets_masker"
Expand Down Expand Up @@ -303,10 +304,11 @@ apache-airflow-devel-common = { workspace = true }
[tool.airflow]
shared_distributions = [
"apache-airflow-shared-configuration",
"apache-airflow-shared-dagnode",
"apache-airflow-shared-logging",
"apache-airflow-shared-module-loading",
"apache-airflow-shared-observability",
"apache-airflow-shared-secrets-backend",
"apache-airflow-shared-secrets-masker",
"apache-airflow-shared-timezones",
"apache-airflow-shared-observability",
]
1 change: 1 addition & 0 deletions airflow-core/src/airflow/_shared/dagnode
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,6 @@
if TYPE_CHECKING:
from sqlalchemy.sql.dml import Update

from airflow.models.expandinput import SchedulerExpandInput

router = VersionedAPIRouter()

ti_id_router = VersionedAPIRouter(
Expand Down Expand Up @@ -314,7 +312,7 @@ def _get_upstream_map_indexes(
except NotFullyPopulated:
# Second try: resolve XCom for correct count
try:
expand_input = cast("SchedulerExpandInput", upstream_mapped_group._expand_input)
expand_input = upstream_mapped_group._expand_input
mapped_ti_count = expand_input.get_total_map_length(ti.run_id, session=session)
except NotFullyPopulated:
# For these trigger rules, unresolved map indexes are acceptable.
Expand Down
11 changes: 3 additions & 8 deletions airflow-core/src/airflow/models/mappedoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@
from airflow.exceptions import AirflowException, NotMapped
from airflow.sdk import BaseOperator as TaskSDKBaseOperator
from airflow.sdk.definitions._internal.abstractoperator import DEFAULT_RETRY_DELAY_MULTIPLIER
from airflow.sdk.definitions._internal.node import DAGNode
from airflow.sdk.definitions.mappedoperator import MappedOperator as TaskSDKMappedOperator
from airflow.serialization.definitions.node import DAGNode
from airflow.serialization.definitions.param import SerializedParamsDict
from airflow.serialization.definitions.taskgroup import SerializedMappedTaskGroup, SerializedTaskGroup
from airflow.serialization.enums import DagAttributeTypes
Expand All @@ -48,6 +48,7 @@
from airflow.models import TaskInstance
from airflow.models.expandinput import SchedulerExpandInput
from airflow.sdk import BaseOperatorLink, Context
from airflow.sdk.definitions._internal.node import DAGNode as TaskSDKDAGNode
from airflow.sdk.definitions.operator_resources import Resources
from airflow.serialization.serialized_objects import SerializedDAG
from airflow.task.trigger_rule import TriggerRule
Expand Down Expand Up @@ -83,7 +84,6 @@ def is_mapped(obj: Operator | SerializedTaskGroup) -> TypeGuard[MappedOperator |
getstate_setstate=False,
repr=False,
)
# TODO (GH-52141): Duplicate DAGNode in the scheduler.
class MappedOperator(DAGNode):
"""Object representing a mapped operator in a DAG."""

Expand All @@ -110,11 +110,6 @@ class MappedOperator(DAGNode):
start_from_trigger: bool = False
_needs_expansion: bool = True

# TODO (GH-52141): These should contain serialized containers, but currently
# this class inherits from an SDK one.
dag: SerializedDAG = attrs.field(init=False) # type: ignore[assignment]
task_group: SerializedTaskGroup = attrs.field(init=False) # type: ignore[assignment]

doc: str | None = attrs.field(init=False)
doc_json: str | None = attrs.field(init=False)
doc_rst: str | None = attrs.field(init=False)
Expand Down Expand Up @@ -503,7 +498,7 @@ def expand_start_trigger_args(self, *, context: Context) -> StartTriggerArgs | N


@functools.singledispatch
def get_mapped_ti_count(task: DAGNode, run_id: str, *, session: Session) -> int:
def get_mapped_ti_count(task: DAGNode | TaskSDKDAGNode, run_id: str, *, session: Session) -> int:
raise NotImplementedError(f"Not implemented for {type(task)}")


Expand Down
7 changes: 2 additions & 5 deletions airflow-core/src/airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from collections.abc import Collection, Iterable
from datetime import datetime, timedelta
from functools import cache
from typing import TYPE_CHECKING, Any, cast
from typing import TYPE_CHECKING, Any
from urllib.parse import quote

import attrs
Expand Down Expand Up @@ -2332,10 +2332,7 @@ def _visit_relevant_relatives_for_mapped(mapped_tasks: Iterable[tuple[str, int]]
# Treat it as a normal task instead.
_visit_relevant_relatives_for_normal([task_id])
continue
# TODO (GH-52141): This should return scheduler operator types, but
# currently get_flat_relatives is inherited from SDK DAGNode.
relatives = cast("Iterable[Operator]", task.get_flat_relatives(upstream=direction == "upstream"))
for relative in relatives:
for relative in task.get_flat_relatives(upstream=direction == "upstream"):
if relative.task_id in visited:
continue
relative_map_indexes = _get_relevant_map_indexes(
Expand Down
4 changes: 1 addition & 3 deletions airflow-core/src/airflow/serialization/definitions/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,9 +284,7 @@ def is_task(obj) -> TypeIs[SerializedOperator]:
direct_upstreams: list[SerializedOperator] = []
if include_direct_upstream:
for t in itertools.chain(matched_tasks, also_include):
# TODO (GH-52141): This should return scheduler types, but currently we reuse SDK DAGNode.
upstream = (u for u in cast("Iterable[SerializedOperator]", t.upstream_list) if is_task(u))
direct_upstreams.extend(upstream)
direct_upstreams.extend(u for u in t.upstream_list if is_task(u))

# Make sure to not recursively deepcopy the dag or task_group while copying the task.
# task_group is reset later
Expand Down
51 changes: 51 additions & 0 deletions airflow-core/src/airflow/serialization/definitions/node.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

import abc
from typing import TYPE_CHECKING

from airflow._shared.dagnode.node import GenericDAGNode

if TYPE_CHECKING:
from collections.abc import Sequence
from typing import TypeAlias

from airflow.models.mappedoperator import MappedOperator
from airflow.serialization.definitions.taskgroup import SerializedTaskGroup # noqa: F401
from airflow.serialization.serialized_objects import SerializedBaseOperator, SerializedDAG # noqa: F401

Operator: TypeAlias = SerializedBaseOperator | MappedOperator


class DAGNode(GenericDAGNode["SerializedDAG", "Operator", "SerializedTaskGroup"], metaclass=abc.ABCMeta):
    """
    Base class for a node in the graph of a workflow.

    A node may be an operator or task group, either mapped or unmapped.

    This is the scheduler/serialization-side counterpart of the SDK DAGNode:
    it parameterizes the shared ``GenericDAGNode`` with the serialized
    container types (``SerializedDAG``, ``Operator``,
    ``SerializedTaskGroup``) so traversal helpers return scheduler types
    instead of SDK types.
    """

    @property
    @abc.abstractmethod
    def roots(self) -> Sequence[DAGNode]:
        # Entry node(s) of this node's subgraph. NOTE(review): exact
        # semantics (self for an operator, first tasks for a group) are
        # defined by subclasses — confirm against their implementations.
        raise NotImplementedError()

    @property
    @abc.abstractmethod
    def leaves(self) -> Sequence[DAGNode]:
        # Exit node(s) of this node's subgraph; see note on ``roots``.
        raise NotImplementedError()
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
import attrs
import methodtools

from airflow.sdk.definitions._internal.node import DAGNode
from airflow.serialization.definitions.node import DAGNode

if TYPE_CHECKING:
from collections.abc import Generator, Iterator
Expand All @@ -45,8 +45,7 @@ class SerializedTaskGroup(DAGNode):
group_display_name: str | None = attrs.field()
prefix_group_id: bool = attrs.field()
parent_group: SerializedTaskGroup | None = attrs.field()
# TODO (GH-52141): Replace DAGNode dependency.
dag: SerializedDAG = attrs.field() # type: ignore[assignment]
dag: SerializedDAG = attrs.field()
tooltip: str = attrs.field()
default_args: dict[str, Any] = attrs.field(factory=dict)

Expand Down
15 changes: 6 additions & 9 deletions airflow-core/src/airflow/serialization/serialized_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@
from airflow.models.xcom_arg import SchedulerXComArg, deserialize_xcom_arg
from airflow.sdk import DAG, Asset, AssetAlias, BaseOperator, XComArg
from airflow.sdk.bases.operator import OPERATOR_DEFAULTS # TODO: Copy this into the scheduler?
from airflow.sdk.definitions._internal.node import DAGNode
from airflow.sdk.definitions.asset import (
AssetAliasEvent,
AssetAliasUniqueKey,
Expand All @@ -76,6 +75,7 @@
SerializedAssetUniqueKey,
)
from airflow.serialization.definitions.dag import SerializedDAG
from airflow.serialization.definitions.node import DAGNode
from airflow.serialization.definitions.param import SerializedParam, SerializedParamsDict
from airflow.serialization.definitions.taskgroup import SerializedMappedTaskGroup, SerializedTaskGroup
from airflow.serialization.encoders import (
Expand Down Expand Up @@ -118,6 +118,7 @@
from airflow.models.taskinstance import TaskInstance
from airflow.providers.cncf.kubernetes.pod_generator import PodGenerator # noqa: TC004
from airflow.sdk import BaseOperatorLink
from airflow.sdk.definitions._internal.node import DAGNode as SDKDAGNode
from airflow.serialization.json_schema import Validator
from airflow.task.trigger_rule import TriggerRule
from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
Expand Down Expand Up @@ -1022,7 +1023,6 @@ def detect_dag_dependencies(dag: DAG | None) -> Iterable[DagDependency]:
yield from tt.asset_condition.iter_dag_dependencies(source="", target=dag.dag_id)


# TODO (GH-52141): Duplicate DAGNode in the scheduler.
class SerializedBaseOperator(DAGNode, BaseSerialization):
"""
A JSON serializable representation of operator.
Expand Down Expand Up @@ -1052,10 +1052,8 @@ class SerializedBaseOperator(DAGNode, BaseSerialization):
_task_display_name: str | None
_weight_rule: str | PriorityWeightStrategy = "downstream"

# TODO (GH-52141): These should contain serialized containers, but currently
# this class inherits from an SDK one.
dag: SerializedDAG | None = None # type: ignore[assignment]
task_group: SerializedTaskGroup | None = None # type: ignore[assignment]
dag: SerializedDAG | None = None
task_group: SerializedTaskGroup | None = None

allow_nested_operators: bool = True
depends_on_past: bool = False
Expand Down Expand Up @@ -1159,8 +1157,7 @@ def __repr__(self) -> str:
def node_id(self) -> str:
return self.task_id

# TODO (GH-52141): Replace DAGNode with a scheduler type.
def get_dag(self) -> SerializedDAG | None: # type: ignore[override]
def get_dag(self) -> SerializedDAG | None:
return self.dag

@property
Expand Down Expand Up @@ -1680,7 +1677,7 @@ def _matches_client_defaults(cls, var: Any, attrname: str) -> bool:
return False

@classmethod
def _is_excluded(cls, var: Any, attrname: str, op: DAGNode):
def _is_excluded(cls, var: Any, attrname: str, op: SDKDAGNode) -> bool:
"""
Determine if a variable is excluded from the serialized object.

Expand Down
7 changes: 1 addition & 6 deletions airflow-core/src/airflow/ti_deps/deps/trigger_rule_dep.py
Original file line number Diff line number Diff line change
Expand Up @@ -620,12 +620,7 @@ def _evaluate_direct_relatives() -> Iterator[TIDepStatus]:

if not task.is_teardown:
# a teardown cannot have any indirect setups
relevant_setups: dict[str, MappedOperator | SerializedBaseOperator] = {
# TODO (GH-52141): This should return scheduler types, but
# currently we reuse logic in SDK DAGNode.
t.task_id: t # type: ignore[misc]
for t in task.get_upstreams_only_setups()
}
relevant_setups = {t.task_id: t for t in task.get_upstreams_only_setups()}
if relevant_setups:
for status, changed in _evaluate_setup_constraint(relevant_setups=relevant_setups):
yield status
Expand Down
8 changes: 1 addition & 7 deletions airflow-core/src/airflow/utils/dag_edges.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@
from airflow.serialization.serialized_objects import SerializedBaseOperator, SerializedDAG

if TYPE_CHECKING:
from collections.abc import Iterable

from airflow.sdk import DAG

Operator: TypeAlias = MappedOperator | SerializedBaseOperator
Expand Down Expand Up @@ -118,11 +116,7 @@ def collect_edges(task_group):
while tasks_to_trace:
tasks_to_trace_next: list[Operator] = []
for task in tasks_to_trace:
# TODO (GH-52141): downstream_list on DAGNode needs to be able to
# return scheduler types when used in scheduler, but SDK types when
# used at runtime. This means DAGNode needs to be rewritten as a
# generic class.
for child in cast("Iterable[Operator]", task.downstream_list):
for child in task.downstream_list:
edge = (task.task_id, child.task_id)
if task.is_setup and child.is_teardown:
setup_teardown_edges.add(edge)
Expand Down
3 changes: 1 addition & 2 deletions airflow-core/src/airflow/utils/dot_renderer.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
import graphviz

from airflow.models import TaskInstance
from airflow.models.taskmixin import DependencyMixin
from airflow.serialization.dag_dependency import DagDependency
else:
try:
Expand Down Expand Up @@ -136,7 +135,7 @@ def _draw_task_group(


def _draw_nodes(
node: DependencyMixin, parent_graph: graphviz.Digraph, states_by_task_id: dict[str, str | None] | None
node: object, parent_graph: graphviz.Digraph, states_by_task_id: dict[str, str | None] | None
) -> None:
"""Draw the node and its children on the given parent_graph recursively."""
if isinstance(node, (BaseOperator, MappedOperator, SerializedBaseOperator, SerializedMappedOperator)):
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -1346,6 +1346,7 @@ apache-airflow-kubernetes-tests = { workspace = true }
apache-airflow-providers = { workspace = true }
apache-airflow-docker-stack = { workspace = true }
apache-airflow-shared-configuration = { workspace = true }
apache-airflow-shared-dagnode = { workspace = true }
apache-airflow-shared-logging = { workspace = true }
apache-airflow-shared-module-loading = { workspace = true }
apache-airflow-shared-secrets-backend = { workspace = true }
Expand Down
48 changes: 48 additions & 0 deletions shared/dagnode/pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

[project]
name = "apache-airflow-shared-dagnode"
description = "Shared DAGNode logic for Airflow distributions"
version = "0.0"
classifiers = [
"Private :: Do Not Upload",
]

# NOTE(review): structlog appears to be copied from the shared logging
# distribution's pyproject — confirm the dagnode package actually imports
# structlog before keeping this dependency.
dependencies = [
    "structlog>=25.4.0",
]

[dependency-groups]
dev = [
"apache-airflow-devel-common",
]

[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[tool.hatch.build.targets.wheel]
packages = ["src/airflow_shared"]

[tool.ruff]
extend = "../../pyproject.toml"
src = ["src"]

[tool.ruff.lint.per-file-ignores]
# Ignore Doc rules et al for anything outside of tests
"!src/*" = ["D", "S101", "TRY002"]
16 changes: 16 additions & 0 deletions shared/dagnode/src/airflow_shared/dagnode/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
Loading