From f92e1906ea15ff2723a9514c71afc432ff19a141 Mon Sep 17 00:00:00 2001 From: Kaxil Naik Date: Fri, 28 Mar 2025 14:54:21 +0530 Subject: [PATCH] Move bases classes to `airflow.sdk.bases` These class aren't needed in DAG definition -- not atleast by DAG Authors This PR moves core base classes to a dedicated `airflow.sdk.bases` module to clarify their intended usage. These classes are internal implementation details and not meant for direct use by DAG authors. They are meant to be used by Providers developer or developer of custom operators. Key Changes: - Created new `airflow.sdk.bases` module - Moved base classes: - `BaseOperator` -> `airflow.sdk.bases.baseoperator` - `BaseSensorOperator` -> `airflow.sdk.bases.sensor` - `BaseOperatorLink` -> `airflow.sdk.bases.operatorlink` - `BaseNotifier` -> `airflow.sdk.bases.notifier` - `XCom` -> `airflow.sdk.bases.xcom` --- airflow-core/src/airflow/decorators/base.py | 2 +- .../src/airflow/decorators/condition.py | 2 +- airflow-core/src/airflow/models/__init__.py | 11 +- .../src/airflow/models/abstractoperator.py | 2 +- .../src/airflow/models/baseoperator.py | 6 +- .../src/airflow/models/baseoperatorlink.py | 2 +- .../src/airflow/models/taskinstance.py | 2 +- airflow-core/src/airflow/models/taskmap.py | 2 +- airflow-core/src/airflow/models/xcom.py | 14 +- airflow-core/src/airflow/sensors/base.py | 2 +- .../serialization/serialized_objects.py | 2 +- airflow-core/src/airflow/utils/task_group.py | 2 +- .../core_api/routes/public/test_xcom.py | 3 +- .../tests/unit/models/test_taskinstance.py | 2 +- airflow-core/tests/unit/models/test_xcom.py | 3 +- .../serialization/test_dag_serialization.py | 4 +- dev/mypy/plugin/outputs.py | 2 +- .../src/tests_common/pytest_plugin.py | 2 +- .../common/compat/notifier/__init__.py | 4 +- .../providers/common/io/xcom/backend.py | 2 +- .../cloud/operators/dataproc_metastore.py | 1 - .../openlineage/utils/selective_enable.py | 2 +- .../openlineage/extractors/test_manager.py | 2 +- .../check_base_operator_partial_arguments.py | 2 +- .../sdk/definitions/sensors => }/__init__.py | 1 - task-sdk/src/airflow/sdk/__init__.py | 20 +- task-sdk/src/airflow/sdk/bases/__init__.py | 16 + .../{definitions => bases}/baseoperator.py | 0 .../sdk/{definitions => bases}/notifier.py | 0 .../operatorlink.py} | 0 .../sensors/base.py => bases/sensor.py} | 2 +- task-sdk/src/airflow/sdk/bases/xcom.py | 311 ++++++++++++++++++ .../definitions/_internal/abstractoperator.py | 4 +- .../sdk/definitions/_internal/mixins.py | 2 +- .../airflow/sdk/definitions/_internal/node.py | 2 +- .../src/airflow/sdk/definitions/context.py | 2 +- task-sdk/src/airflow/sdk/definitions/dag.py | 2 +- .../airflow/sdk/definitions/mappedoperator.py | 4 +- .../src/airflow/sdk/definitions/taskgroup.py | 2 +- .../src/airflow/sdk/definitions/xcom_arg.py | 2 +- .../src/airflow/sdk/execution_time/context.py | 2 +- .../airflow/sdk/execution_time/task_runner.py | 2 +- .../src/airflow/sdk/execution_time/xcom.py | 296 +---------------- task-sdk/src/airflow/sdk/types.py | 2 +- task-sdk/tests/task_sdk/bases/__init__.py | 16 + .../notifier/test_notifier.txt | 0 .../test_baseoperator.py | 2 +- .../{definitions => bases}/test_notifier.py | 2 +- .../test_base.py => bases/test_sensor.py} | 2 +- task-sdk/tests/task_sdk/dags/super_basic.py | 2 +- .../tests/task_sdk/dags/super_basic_run.py | 2 +- .../tests/task_sdk/definitions/test_dag.py | 2 +- .../definitions/test_mappedoperator.py | 2 +- .../tests/task_sdk/definitions/test_mixins.py | 2 +- 54 files changed, 420 insertions(+), 362 deletions(-) rename task-sdk/{src/airflow/sdk/definitions/sensors => }/__init__.py (99%) create mode 100644 task-sdk/src/airflow/sdk/bases/__init__.py rename task-sdk/src/airflow/sdk/{definitions => bases}/baseoperator.py (100%) rename task-sdk/src/airflow/sdk/{definitions => bases}/notifier.py (100%) rename task-sdk/src/airflow/sdk/{definitions/baseoperatorlink.py => bases/operatorlink.py} (100%) rename task-sdk/src/airflow/sdk/{definitions/sensors/base.py => bases/sensor.py} (99%) create mode 100644 task-sdk/src/airflow/sdk/bases/xcom.py create mode 100644 task-sdk/tests/task_sdk/bases/__init__.py rename task-sdk/tests/task_sdk/{definitions => bases}/notifier/test_notifier.txt (100%) rename task-sdk/tests/task_sdk/{definitions => bases}/test_baseoperator.py (99%) rename task-sdk/tests/task_sdk/{definitions => bases}/test_notifier.py (98%) rename task-sdk/tests/task_sdk/{definitions/sensors/test_base.py => bases/test_sensor.py} (99%) diff --git a/airflow-core/src/airflow/decorators/base.py b/airflow-core/src/airflow/decorators/base.py index 00a61dcd2965f..d0c8bdcf6da31 100644 --- a/airflow-core/src/airflow/decorators/base.py +++ b/airflow-core/src/airflow/decorators/base.py @@ -41,9 +41,9 @@ ListOfDictsExpandInput, is_mappable, ) +from airflow.sdk.bases.baseoperator import BaseOperator as TaskSDKBaseOperator from airflow.sdk.definitions._internal.contextmanager import DagContext, TaskGroupContext from airflow.sdk.definitions.asset import Asset -from airflow.sdk.definitions.baseoperator import BaseOperator as TaskSDKBaseOperator from airflow.sdk.definitions.mappedoperator import MappedOperator, ensure_xcomarg_return_value from airflow.sdk.definitions.xcom_arg import XComArg from airflow.typing_compat import ParamSpec diff --git a/airflow-core/src/airflow/decorators/condition.py b/airflow-core/src/airflow/decorators/condition.py index 06fd01391f28b..a38b2ab24197b 100644 --- a/airflow-core/src/airflow/decorators/condition.py +++ b/airflow-core/src/airflow/decorators/condition.py @@ -25,7 +25,7 @@ if TYPE_CHECKING: from typing_extensions import TypeAlias - from airflow.sdk.definitions.baseoperator import TaskPreExecuteHook + from airflow.sdk.bases.baseoperator import TaskPreExecuteHook from airflow.sdk.definitions.context import Context BoolConditionFunc: TypeAlias = Callable[[Context], bool] diff --git a/airflow-core/src/airflow/models/__init__.py b/airflow-core/src/airflow/models/__init__.py index 20e1c65df8af7..9274ae7a79f3f 100644 --- a/airflow-core/src/airflow/models/__init__.py +++ b/airflow-core/src/airflow/models/__init__.py @@ -26,6 +26,7 @@ "Base", "BaseOperator", "BaseOperatorLink", + "BaseXCom", "Connection", "DagBag", "DagWarning", @@ -44,7 +45,7 @@ "TaskReschedule", "Trigger", "Variable", - "XComModel", + "XCom", "clear_task_instances", ] @@ -65,6 +66,7 @@ def import_all_models(): import airflow.models.serialized_dag import airflow.models.taskinstancehistory import airflow.models.tasklog + import airflow.models.xcom def __getattr__(name): @@ -88,7 +90,8 @@ def __getattr__(name): "ID_LEN": "airflow.models.base", "Base": "airflow.models.base", "BaseOperator": "airflow.models.baseoperator", - "BaseOperatorLink": "airflow.sdk.definitions.baseoperatorlink", + "BaseOperatorLink": "airflow.sdk.bases.operatorlink", + "BaseXCom": "airflow.sdk.bases.xcom", "Connection": "airflow.models.connection", "DagBag": "airflow.models.dagbag", "DagModel": "airflow.models.dag", @@ -114,7 +117,6 @@ def __getattr__(name): if TYPE_CHECKING: # I was unable to get mypy to respect a airflow/models/__init__.pyi, so # having to resort back to this hacky method - from airflow.jobs.job import Job from airflow.models.base import ID_LEN, Base from airflow.models.baseoperator import BaseOperator from airflow.models.connection import Connection @@ -135,6 +137,7 @@ def __getattr__(name): from airflow.models.taskreschedule import TaskReschedule from airflow.models.trigger import Trigger from airflow.models.variable import Variable - from airflow.sdk import BaseOperatorLink + from airflow.sdk.bases.operatorlink import BaseOperatorLink + from airflow.sdk.bases.xcom import BaseXCom from airflow.sdk.definitions.param import Param from airflow.sdk.execution_time.xcom import XCom diff --git a/airflow-core/src/airflow/models/abstractoperator.py b/airflow-core/src/airflow/models/abstractoperator.py index 007e43aa5254c..9dc8f32bdfe39 100644 --- a/airflow-core/src/airflow/models/abstractoperator.py +++ b/airflow-core/src/airflow/models/abstractoperator.py @@ -33,7 +33,7 @@ if TYPE_CHECKING: from sqlalchemy.orm import Session - from airflow.sdk.definitions.baseoperator import BaseOperator + from airflow.sdk.bases.baseoperator import BaseOperator from airflow.task.priority_strategy import PriorityWeightStrategy from airflow.triggers.base import StartTriggerArgs diff --git a/airflow-core/src/airflow/models/baseoperator.py b/airflow-core/src/airflow/models/baseoperator.py index 51cd83a6fecbd..23181f97d58f7 100644 --- a/airflow-core/src/airflow/models/baseoperator.py +++ b/airflow-core/src/airflow/models/baseoperator.py @@ -47,15 +47,15 @@ NotMapped, ) from airflow.models.taskinstance import TaskInstance, clear_task_instances -from airflow.sdk.definitions._internal.abstractoperator import AbstractOperator as TaskSDKAbstractOperator -from airflow.sdk.definitions.baseoperator import ( +from airflow.sdk.bases.baseoperator import ( + BaseOperator as TaskSDKBaseOperator, # Re-export for compat chain as chain, chain_linear as chain_linear, cross_downstream as cross_downstream, get_merged_defaults as get_merged_defaults, ) -from airflow.sdk.definitions.dag import BaseOperator as TaskSDKBaseOperator +from airflow.sdk.definitions._internal.abstractoperator import AbstractOperator as TaskSDKAbstractOperator from airflow.sdk.definitions.mappedoperator import MappedOperator from airflow.sdk.definitions.taskgroup import MappedTaskGroup, TaskGroup from airflow.serialization.enums import DagAttributeTypes diff --git a/airflow-core/src/airflow/models/baseoperatorlink.py b/airflow-core/src/airflow/models/baseoperatorlink.py index 3f95e162d8f7b..09d21f868515d 100644 --- a/airflow-core/src/airflow/models/baseoperatorlink.py +++ b/airflow-core/src/airflow/models/baseoperatorlink.py @@ -19,4 +19,4 @@ from __future__ import annotations -from airflow.sdk.definitions.baseoperatorlink import BaseOperatorLink as BaseOperatorLink +from airflow.sdk.bases.operatorlink import BaseOperatorLink as BaseOperatorLink diff --git a/airflow-core/src/airflow/models/taskinstance.py b/airflow-core/src/airflow/models/taskinstance.py index 0e1bcf4a3bd32..bb1d4016fdb4d 100644 --- a/airflow-core/src/airflow/models/taskinstance.py +++ b/airflow-core/src/airflow/models/taskinstance.py @@ -605,7 +605,7 @@ def _execute_task(task_instance: TaskInstance, context: Context, task_orig: Oper :meta private: """ - from airflow.sdk.definitions.baseoperator import ExecutorSafeguard + from airflow.sdk.bases.baseoperator import ExecutorSafeguard from airflow.sdk.definitions.mappedoperator import MappedOperator task_to_execute = task_instance.task diff --git a/airflow-core/src/airflow/models/taskmap.py b/airflow-core/src/airflow/models/taskmap.py index 04bd6d974e139..12f386c3b45a4 100644 --- a/airflow-core/src/airflow/models/taskmap.py +++ b/airflow-core/src/airflow/models/taskmap.py @@ -133,7 +133,7 @@ def expand_mapped_task(cls, task, run_id: str, *, session: Session) -> tuple[Seq from airflow.models.baseoperator import BaseOperator as DBBaseOperator from airflow.models.expandinput import NotFullyPopulated from airflow.models.taskinstance import TaskInstance - from airflow.sdk.definitions.baseoperator import BaseOperator + from airflow.sdk.bases.baseoperator import BaseOperator from airflow.sdk.definitions.mappedoperator import MappedOperator from airflow.settings import task_instance_mutation_hook diff --git a/airflow-core/src/airflow/models/xcom.py b/airflow-core/src/airflow/models/xcom.py index 88549d65eb5f7..09bfcff7b02b8 100644 --- a/airflow-core/src/airflow/models/xcom.py +++ b/airflow-core/src/airflow/models/xcom.py @@ -398,12 +398,16 @@ def _process_row(row: Row) -> Any: def __getattr__(name: str): - if name == "BaseXCom" or name == "XCom": - from airflow.sdk.execution_time import xcom + if name == "BaseXCom": + from airflow.sdk.bases.xcom import BaseXCom - val = getattr(xcom, name) + globals()[name] = BaseXCom + return BaseXCom - globals()[name] = val - return val + if name == "XCom": + from airflow.sdk.execution_time.xcom import XCom + + globals()[name] = XCom + return XCom raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/airflow-core/src/airflow/sensors/base.py b/airflow-core/src/airflow/sensors/base.py index 9ac7c28c2d53f..71ae006f53437 100644 --- a/airflow-core/src/airflow/sensors/base.py +++ b/airflow-core/src/airflow/sensors/base.py @@ -17,7 +17,7 @@ # under the License. from __future__ import annotations -from airflow.sdk.definitions.sensors.base import ( +from airflow.sdk.bases.sensor import ( BaseSensorOperator as BaseSensorOperator, PokeReturnValue as PokeReturnValue, poke_mode_only as poke_mode_only, diff --git a/airflow-core/src/airflow/serialization/serialized_objects.py b/airflow-core/src/airflow/serialization/serialized_objects.py index d2af3540b8038..8dd1ca10ec737 100644 --- a/airflow-core/src/airflow/serialization/serialized_objects.py +++ b/airflow-core/src/airflow/serialization/serialized_objects.py @@ -51,6 +51,7 @@ from airflow.models.xcom import XComModel from airflow.models.xcom_arg import SchedulerXComArg, deserialize_xcom_arg from airflow.providers_manager import ProvidersManager +from airflow.sdk.bases.baseoperator import BaseOperator as TaskSDKBaseOperator from airflow.sdk.definitions.asset import ( Asset, AssetAlias, @@ -63,7 +64,6 @@ AssetWatcher, BaseAsset, ) -from airflow.sdk.definitions.baseoperator import BaseOperator as TaskSDKBaseOperator from airflow.sdk.definitions.mappedoperator import MappedOperator from airflow.sdk.definitions.param import Param, ParamsDict from airflow.sdk.definitions.taskgroup import MappedTaskGroup, TaskGroup diff --git a/airflow-core/src/airflow/utils/task_group.py b/airflow-core/src/airflow/utils/task_group.py index 0712b3cb3aef5..c264274e7167e 100644 --- a/airflow-core/src/airflow/utils/task_group.py +++ b/airflow-core/src/airflow/utils/task_group.py @@ -32,8 +32,8 @@ def task_group_to_dict(task_item_or_group, parent_group_is_mapped=False): """Create a nested dict representation of this TaskGroup and its children used to construct the Graph.""" + from airflow.sdk.bases.baseoperator import BaseOperator from airflow.sdk.definitions._internal.abstractoperator import AbstractOperator - from airflow.sdk.definitions.baseoperator import BaseOperator from airflow.sdk.definitions.mappedoperator import MappedOperator if isinstance(task := task_item_or_group, AbstractOperator): diff --git a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_xcom.py b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_xcom.py index 61ff52b6bb272..95590bfbe13da 100644 --- a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_xcom.py +++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_xcom.py @@ -27,7 +27,8 @@ from airflow.models.taskinstance import TaskInstance from airflow.models.xcom import XComModel from airflow.providers.standard.operators.empty import EmptyOperator -from airflow.sdk.execution_time.xcom import BaseXCom, resolve_xcom_backend +from airflow.sdk.bases.xcom import BaseXCom +from airflow.sdk.execution_time.xcom import resolve_xcom_backend from airflow.utils import timezone from airflow.utils.session import provide_session from airflow.utils.types import DagRunType diff --git a/airflow-core/tests/unit/models/test_taskinstance.py b/airflow-core/tests/unit/models/test_taskinstance.py index 70029e56a05c3..ca99aff9d153e 100644 --- a/airflow-core/tests/unit/models/test_taskinstance.py +++ b/airflow-core/tests/unit/models/test_taskinstance.py @@ -77,8 +77,8 @@ from airflow.providers.standard.operators.python import PythonOperator from airflow.providers.standard.sensors.python import PythonSensor from airflow.sdk.api.datamodels._generated import AssetEventResponse, AssetResponse +from airflow.sdk.bases.notifier import BaseNotifier from airflow.sdk.definitions.asset import Asset, AssetAlias -from airflow.sdk.definitions.notifier import BaseNotifier from airflow.sdk.definitions.param import process_params from airflow.sdk.execution_time.comms import ( AssetEventsResult, diff --git a/airflow-core/tests/unit/models/test_xcom.py b/airflow-core/tests/unit/models/test_xcom.py index 35391d6a6d9de..73ea34bf85289 100644 --- a/airflow-core/tests/unit/models/test_xcom.py +++ b/airflow-core/tests/unit/models/test_xcom.py @@ -29,7 +29,8 @@ from airflow.models.taskinstance import TaskInstance from airflow.models.xcom import XComModel from airflow.providers.standard.operators.empty import EmptyOperator -from airflow.sdk.execution_time.xcom import BaseXCom, resolve_xcom_backend +from airflow.sdk.bases.xcom import BaseXCom +from airflow.sdk.execution_time.xcom import resolve_xcom_backend from airflow.settings import json from airflow.utils import timezone from airflow.utils.session import create_session diff --git a/airflow-core/tests/unit/serialization/test_dag_serialization.py b/airflow-core/tests/unit/serialization/test_dag_serialization.py index 04b246d6af477..c435aa436a3dd 100644 --- a/airflow-core/tests/unit/serialization/test_dag_serialization.py +++ b/airflow-core/tests/unit/serialization/test_dag_serialization.py @@ -2015,7 +2015,7 @@ def test_edge_info_serialization(self): @pytest.mark.db_test @pytest.mark.parametrize("mode", ["poke", "reschedule"]) def test_serialize_sensor(self, mode): - from airflow.sdk.definitions.sensors.base import BaseSensorOperator + from airflow.sdk.bases.sensor import BaseSensorOperator class DummySensor(BaseSensorOperator): def poke(self, context: Context): @@ -2032,7 +2032,7 @@ def poke(self, context: Context): @pytest.mark.parametrize("mode", ["poke", "reschedule"]) def test_serialize_mapped_sensor_has_reschedule_dep(self, mode): - from airflow.sdk.definitions.sensors.base import BaseSensorOperator + from airflow.sdk.bases.sensor import BaseSensorOperator from airflow.ti_deps.deps.ready_to_reschedule import ReadyToRescheduleDep class DummySensor(BaseSensorOperator): diff --git a/dev/mypy/plugin/outputs.py b/dev/mypy/plugin/outputs.py index a3ba7351f556d..485a50cca0e1e 100644 --- a/dev/mypy/plugin/outputs.py +++ b/dev/mypy/plugin/outputs.py @@ -25,7 +25,7 @@ OUTPUT_PROPERTIES = { "airflow.models.baseoperator.BaseOperator.output", "airflow.models.mappedoperator.MappedOperator.output", - "airflow.sdk.definitions.baseoperator.BaseOperator.output", + "airflow.sdk.bases.baseoperator.BaseOperator.output", } TASK_CALL_FUNCTIONS = { diff --git a/devel-common/src/tests_common/pytest_plugin.py b/devel-common/src/tests_common/pytest_plugin.py index 3b5aa6571337d..d645a100d042b 100644 --- a/devel-common/src/tests_common/pytest_plugin.py +++ b/devel-common/src/tests_common/pytest_plugin.py @@ -46,7 +46,7 @@ from airflow.models.taskinstance import TaskInstance from airflow.providers.standard.operators.empty import EmptyOperator from airflow.sdk.api.datamodels._generated import IntermediateTIState, TerminalTIState - from airflow.sdk.definitions.baseoperator import BaseOperator as TaskSDKBaseOperator + from airflow.sdk.bases.baseoperator import BaseOperator as TaskSDKBaseOperator from airflow.sdk.execution_time.comms import StartupDetails, ToSupervisor from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance from airflow.timetables.base import DataInterval diff --git a/providers/common/compat/src/airflow/providers/common/compat/notifier/__init__.py b/providers/common/compat/src/airflow/providers/common/compat/notifier/__init__.py index 91ce1cdb71d96..e58c7bc7c7329 100644 --- a/providers/common/compat/src/airflow/providers/common/compat/notifier/__init__.py +++ b/providers/common/compat/src/airflow/providers/common/compat/notifier/__init__.py @@ -22,9 +22,9 @@ from airflow.providers.common.compat.version_compat import AIRFLOW_V_3_0_PLUS if TYPE_CHECKING: - from airflow.sdk.definitions.notifier import BaseNotifier + from airflow.sdk.bases.notifier import BaseNotifier elif AIRFLOW_V_3_0_PLUS: - from airflow.sdk.definitions.notifier import BaseNotifier + from airflow.sdk.bases.notifier import BaseNotifier else: from airflow.notifications.basenotifier import BaseNotifier diff --git a/providers/common/io/src/airflow/providers/common/io/xcom/backend.py b/providers/common/io/src/airflow/providers/common/io/xcom/backend.py index c318698af3d4e..79bedc3977eb5 100644 --- a/providers/common/io/src/airflow/providers/common/io/xcom/backend.py +++ b/providers/common/io/src/airflow/providers/common/io/xcom/backend.py @@ -37,7 +37,7 @@ from airflow.sdk.execution_time.comms import XComResult if AIRFLOW_V_3_0_PLUS: - from airflow.sdk.execution_time.xcom import BaseXCom + from airflow.sdk.bases.xcom import BaseXCom else: from airflow.models.xcom import BaseXCom # type: ignore[no-redef] diff --git a/providers/google/src/airflow/providers/google/cloud/operators/dataproc_metastore.py b/providers/google/src/airflow/providers/google/cloud/operators/dataproc_metastore.py index 72d4e7fc160a7..d99ba3d160a4e 100644 --- a/providers/google/src/airflow/providers/google/cloud/operators/dataproc_metastore.py +++ b/providers/google/src/airflow/providers/google/cloud/operators/dataproc_metastore.py @@ -31,7 +31,6 @@ from google.cloud.metastore_v1.types.metastore import DatabaseDumpSpec, Restore from airflow.exceptions import AirflowException -from airflow.models import BaseOperator, BaseOperatorLink from airflow.providers.google.cloud.hooks.dataproc_metastore import DataprocMetastoreHook from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator from airflow.providers.google.common.links.storage import StorageLink diff --git a/providers/openlineage/src/airflow/providers/openlineage/utils/selective_enable.py b/providers/openlineage/src/airflow/providers/openlineage/utils/selective_enable.py index 1317e84869c5c..794681b413573 100644 --- a/providers/openlineage/src/airflow/providers/openlineage/utils/selective_enable.py +++ b/providers/openlineage/src/airflow/providers/openlineage/utils/selective_enable.py @@ -38,7 +38,7 @@ T = TypeVar("T", bound="DAG | Operator") if TYPE_CHECKING: - from airflow.sdk.definitions.baseoperator import BaseOperator as SdkBaseOperator + from airflow.sdk.bases.baseoperator import BaseOperator as SdkBaseOperator log = logging.getLogger(__name__) diff --git a/providers/openlineage/tests/unit/openlineage/extractors/test_manager.py b/providers/openlineage/tests/unit/openlineage/extractors/test_manager.py index 0445964fee39d..14255f28127b5 100644 --- a/providers/openlineage/tests/unit/openlineage/extractors/test_manager.py +++ b/providers/openlineage/tests/unit/openlineage/extractors/test_manager.py @@ -76,7 +76,7 @@ def hook_lineage_collector(): if AIRFLOW_V_3_0_PLUS: from airflow.sdk.api.datamodels._generated import BundleInfo, TaskInstance as SDKTaskInstance - from airflow.sdk.definitions.baseoperator import BaseOperator + from airflow.sdk.bases.baseoperator import BaseOperator from airflow.sdk.execution_time import task_runner from airflow.sdk.execution_time.comms import StartupDetails from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance, parse diff --git a/scripts/ci/pre_commit/check_base_operator_partial_arguments.py b/scripts/ci/pre_commit/check_base_operator_partial_arguments.py index 970d5d50aa766..83a901bae99d5 100755 --- a/scripts/ci/pre_commit/check_base_operator_partial_arguments.py +++ b/scripts/ci/pre_commit/check_base_operator_partial_arguments.py @@ -30,7 +30,7 @@ from common_precommit_utils import AIRFLOW_CORE_SOURCES_PATH, AIRFLOW_TASK_SDK_SOURCES_PATH, console BASEOPERATOR_PY = AIRFLOW_CORE_SOURCES_PATH / "airflow" / "models" / "baseoperator.py" -SDK_BASEOPERATOR_PY = AIRFLOW_TASK_SDK_SOURCES_PATH / "airflow" / "sdk" / "definitions" / "baseoperator.py" +SDK_BASEOPERATOR_PY = AIRFLOW_TASK_SDK_SOURCES_PATH / "airflow" / "sdk" / "bases" / "baseoperator.py" SDK_MAPPEDOPERATOR_PY = ( AIRFLOW_TASK_SDK_SOURCES_PATH / "airflow" / "sdk" / "definitions" / "mappedoperator.py" ) diff --git a/task-sdk/src/airflow/sdk/definitions/sensors/__init__.py b/task-sdk/__init__.py similarity index 99% rename from task-sdk/src/airflow/sdk/definitions/sensors/__init__.py rename to task-sdk/__init__.py index 217e5db960782..13a83393a9124 100644 --- a/task-sdk/src/airflow/sdk/definitions/sensors/__init__.py +++ b/task-sdk/__init__.py @@ -1,4 +1,3 @@ -# # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information diff --git a/task-sdk/src/airflow/sdk/__init__.py b/task-sdk/src/airflow/sdk/__init__.py index 25ef125a8c010..a7c768531b0da 100644 --- a/task-sdk/src/airflow/sdk/__init__.py +++ b/task-sdk/src/airflow/sdk/__init__.py @@ -52,18 +52,18 @@ __version__ = "1.0.0.alpha1" if TYPE_CHECKING: + from airflow.sdk.bases.baseoperator import BaseOperator, chain, chain_linear, cross_downstream + from airflow.sdk.bases.notifier import BaseNotifier + from airflow.sdk.bases.operatorlink import BaseOperatorLink + from airflow.sdk.bases.sensor import BaseSensorOperator from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetAll, AssetAny, AssetWatcher from airflow.sdk.definitions.asset.decorators import asset from airflow.sdk.definitions.asset.metadata import Metadata - from airflow.sdk.definitions.baseoperator import BaseOperator, chain, chain_linear, cross_downstream - from airflow.sdk.definitions.baseoperatorlink import BaseOperatorLink from airflow.sdk.definitions.connection import Connection from airflow.sdk.definitions.context import Context, get_current_context, get_parsing_context from airflow.sdk.definitions.dag import DAG, dag from airflow.sdk.definitions.edges import EdgeModifier, Label - from airflow.sdk.definitions.notifier import BaseNotifier from airflow.sdk.definitions.param import Param - from airflow.sdk.definitions.sensors.base import BaseSensorOperator from airflow.sdk.definitions.taskgroup import TaskGroup from airflow.sdk.definitions.template import literal from airflow.sdk.definitions.variable import Variable @@ -76,9 +76,9 @@ "AssetAny": ".definitions.asset", "AssetWatcher": ".definitions.asset", "BaseNotifier": ".definitions.notifier", - "BaseOperator": ".definitions.baseoperator", - "BaseOperatorLink": ".definitions.baseoperatorlink", - "BaseSensorOperator": ".definitions.sensors.base", + "BaseOperator": ".bases.baseoperator", + "BaseOperatorLink": ".bases.operatorlink", + "BaseSensorOperator": ".bases.sensor", "Connection": ".definitions.connection", "Context": ".definitions.context", "DAG": ".definitions.dag", @@ -90,9 +90,9 @@ "Variable": ".definitions.variable", "XComArg": ".definitions.xcom_arg", "asset": ".definitions.asset.decorators", - "chain": ".definitions.baseoperator", - "chain_linear": ".definitions.baseoperator", - "cross_downstream": ".definitions.baseoperator", + "chain": ".bases.baseoperator", + "chain_linear": ".bases.baseoperator", + "cross_downstream": ".bases.baseoperator", "dag": ".definitions.dag", "get_current_context": ".definitions.context", "get_parsing_context": ".definitions.context", diff --git a/task-sdk/src/airflow/sdk/bases/__init__.py b/task-sdk/src/airflow/sdk/bases/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/task-sdk/src/airflow/sdk/bases/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/task-sdk/src/airflow/sdk/definitions/baseoperator.py b/task-sdk/src/airflow/sdk/bases/baseoperator.py similarity index 100% rename from task-sdk/src/airflow/sdk/definitions/baseoperator.py rename to task-sdk/src/airflow/sdk/bases/baseoperator.py diff --git a/task-sdk/src/airflow/sdk/definitions/notifier.py b/task-sdk/src/airflow/sdk/bases/notifier.py similarity index 100% rename from task-sdk/src/airflow/sdk/definitions/notifier.py rename to task-sdk/src/airflow/sdk/bases/notifier.py diff --git a/task-sdk/src/airflow/sdk/definitions/baseoperatorlink.py b/task-sdk/src/airflow/sdk/bases/operatorlink.py similarity index 100% rename from task-sdk/src/airflow/sdk/definitions/baseoperatorlink.py rename to task-sdk/src/airflow/sdk/bases/operatorlink.py diff --git a/task-sdk/src/airflow/sdk/definitions/sensors/base.py b/task-sdk/src/airflow/sdk/bases/sensor.py similarity index 99% rename from task-sdk/src/airflow/sdk/definitions/sensors/base.py rename to task-sdk/src/airflow/sdk/bases/sensor.py index 7e89e2550b20d..e896fb9f538bb 100644 --- a/task-sdk/src/airflow/sdk/definitions/sensors/base.py +++ b/task-sdk/src/airflow/sdk/bases/sensor.py @@ -37,7 +37,7 @@ TaskDeferralTimeout, ) from airflow.executors.executor_loader import ExecutorLoader -from airflow.sdk.definitions.baseoperator import BaseOperator +from airflow.sdk.bases.baseoperator import BaseOperator from airflow.utils import timezone if TYPE_CHECKING: diff --git a/task-sdk/src/airflow/sdk/bases/xcom.py b/task-sdk/src/airflow/sdk/bases/xcom.py new file mode 100644 index 0000000000000..3376355ff0745 --- /dev/null +++ b/task-sdk/src/airflow/sdk/bases/xcom.py @@ -0,0 +1,311 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from typing import Any + +import structlog + +from airflow.sdk.execution_time.comms import DeleteXCom, GetXCom, SetXCom, XComResult + +log = structlog.get_logger(logger_name="task") + + +class BaseXCom: + """BaseXcom is an interface now to interact with XCom backends.""" + + @classmethod + def set( + cls, + key: str, + value: Any, + *, + dag_id: str, + task_id: str, + run_id: str, + map_index: int = -1, + _mapped_length: int | None = None, + ) -> None: + """ + Store an XCom value. + + :param key: Key to store the XCom. + :param value: XCom value to store. + :param dag_id: DAG ID. + :param task_id: Task ID. + :param run_id: DAG run ID for the task. + :param map_index: Optional map index to assign XCom for a mapped task. + The default is ``-1`` (set for a non-mapped task). + """ + from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS + + value = cls.serialize_value( + value=value, + key=key, + task_id=task_id, + dag_id=dag_id, + run_id=run_id, + map_index=map_index, + ) + + SUPERVISOR_COMMS.send_request( + log=log, + msg=SetXCom( + key=key, + value=value, + dag_id=dag_id, + task_id=task_id, + run_id=run_id, + map_index=map_index, + mapped_length=_mapped_length, + ), + ) + + @classmethod + def _set_xcom_in_db( + cls, + key: str, + value: Any, + *, + dag_id: str, + task_id: str, + run_id: str, + map_index: int = -1, + ) -> None: + """ + Store an XCom value directly in the metadata database. + + :param key: Key to store the XCom. + :param value: XCom value to store. + :param dag_id: DAG ID. + :param task_id: Task ID. + :param run_id: DAG run ID for the task. + :param map_index: Optional map index to assign XCom for a mapped task. + The default is ``-1`` (set for a non-mapped task). + """ + from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS + + SUPERVISOR_COMMS.send_request( + log=log, + msg=SetXCom( + key=key, + value=value, + dag_id=dag_id, + task_id=task_id, + run_id=run_id, + map_index=map_index, + ), + ) + + @classmethod + def get_value( + cls, + *, + ti_key: Any, + key: str, + ) -> Any: + """ + Retrieve an XCom value for a task instance. + + This method returns "full" XCom values (i.e. uses ``deserialize_value`` + from the XCom backend). Use :meth:`get_many` if you want the "shortened" + value via ``orm_deserialize_value``. + + If there are no results, *None* is returned. If multiple XCom entries + match the criteria, an arbitrary one is returned. + + :param ti_key: The TaskInstanceKey to look up the XCom for. + :param key: A key for the XCom. If provided, only XCom with matching + keys will be returned. Pass *None* (default) to remove the filter. + """ + return cls.get_one( + key=key, + task_id=ti_key.task_id, + dag_id=ti_key.dag_id, + run_id=ti_key.run_id, + map_index=ti_key.map_index, + ) + + @classmethod + def _get_xcom_db_ref( + cls, + *, + key: str, + dag_id: str, + task_id: str, + run_id: str, + map_index: int | None = None, + ) -> XComResult: + """ + Retrieve an XCom value, optionally meeting certain criteria. + + This method returns "full" XCom values (i.e. uses ``deserialize_value`` + from the XCom backend). Use :meth:`get_many` if you want the "shortened" + value via ``orm_deserialize_value``. + + If there are no results, *None* is returned. If multiple XCom entries + match the criteria, an arbitrary one is returned. + + .. seealso:: ``get_value()`` is a convenience function if you already + have a structured TaskInstance or TaskInstanceKey object available. + + :param run_id: DAG run ID for the task. + :param dag_id: Only pull XCom from this DAG. Pass *None* (default) to + remove the filter. + :param task_id: Only XCom from task with matching ID will be pulled. + Pass *None* (default) to remove the filter. + :param map_index: Only XCom from task with matching ID will be pulled. + Pass *None* (default) to remove the filter. + :param key: A key for the XCom. If provided, only XCom with matching + keys will be returned. Pass *None* (default) to remove the filter. + """ + from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS + + SUPERVISOR_COMMS.send_request( + log=log, + msg=GetXCom( + key=key, + dag_id=dag_id, + task_id=task_id, + run_id=run_id, + map_index=map_index, + ), + ) + + msg = SUPERVISOR_COMMS.get_message() + if not isinstance(msg, XComResult): + raise TypeError(f"Expected XComResult, received: {type(msg)} {msg}") + + return msg + + @classmethod + def get_one( + cls, + *, + key: str, + dag_id: str, + task_id: str, + run_id: str, + map_index: int | None = None, + include_prior_dates: bool = False, + ) -> Any | None: + """ + Retrieve an XCom value, optionally meeting certain criteria. + + This method returns "full" XCom values (i.e. uses ``deserialize_value`` + from the XCom backend). Use :meth:`get_many` if you want the "shortened" + value via ``orm_deserialize_value``. + + If there are no results, *None* is returned. If multiple XCom entries + match the criteria, an arbitrary one is returned. + + .. seealso:: ``get_value()`` is a convenience function if you already + have a structured TaskInstance or TaskInstanceKey object available. + + :param run_id: DAG run ID for the task. + :param dag_id: Only pull XCom from this DAG. Pass *None* (default) to + remove the filter. + :param task_id: Only XCom from task with matching ID will be pulled. + Pass *None* (default) to remove the filter. + :param map_index: Only XCom from task with matching ID will be pulled. + Pass *None* (default) to remove the filter. + :param key: A key for the XCom. If provided, only XCom with matching + keys will be returned. Pass *None* (default) to remove the filter. + :param include_prior_dates: If *False* (default), only XCom from the + specified DAG run is returned. If *True*, the latest matching XCom is + returned regardless of the run it belongs to. + """ + from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS + + SUPERVISOR_COMMS.send_request( + log=log, + msg=GetXCom( + key=key, + dag_id=dag_id, + task_id=task_id, + run_id=run_id, + map_index=map_index, + include_prior_dates=include_prior_dates, + ), + ) + + msg = SUPERVISOR_COMMS.get_message() + if not isinstance(msg, XComResult): + raise TypeError(f"Expected XComResult, received: {type(msg)} {msg}") + + if msg.value is not None: + return cls.deserialize_value(msg) + return None + + @staticmethod + def serialize_value( + value: Any, + *, + key: str | None = None, + task_id: str | None = None, + dag_id: str | None = None, + run_id: str | None = None, + map_index: int | None = None, + ) -> str: + """Serialize XCom value to JSON str.""" + from airflow.serialization.serde import serialize + + # return back the value for BaseXCom, custom backends will implement this + return serialize(value) # type: ignore[return-value] + + @staticmethod + def deserialize_value(result) -> Any: + """Deserialize XCom value from str objects.""" + from airflow.serialization.serde import deserialize + + return deserialize(result.value) + + @classmethod + def purge(cls, xcom: XComResult, *args) -> None: + """Purge an XCom entry from underlying storage implementations.""" + pass + + @classmethod + def delete( + cls, + key: str, + task_id: str, + dag_id: str, + run_id: str, + map_index: int | None = None, + ) -> None: + """Delete an Xcom entry, for custom xcom backends, it gets the path associated with the data on the backend and purges it.""" + from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS + + xcom_result = cls._get_xcom_db_ref( + key=key, + dag_id=dag_id, + task_id=task_id, + run_id=run_id, + map_index=map_index, + ) + cls.purge(xcom_result) # type: ignore[call-arg] + SUPERVISOR_COMMS.send_request( + log=log, + msg=DeleteXCom( + key=key, + dag_id=dag_id, + task_id=task_id, + run_id=run_id, + ), + ) diff --git a/task-sdk/src/airflow/sdk/definitions/_internal/abstractoperator.py b/task-sdk/src/airflow/sdk/definitions/_internal/abstractoperator.py index 944813bce1c24..b5b68d56330ab 100644 --- a/task-sdk/src/airflow/sdk/definitions/_internal/abstractoperator.py +++ b/task-sdk/src/airflow/sdk/definitions/_internal/abstractoperator.py @@ -41,8 +41,8 @@ if TYPE_CHECKING: import jinja2 - from airflow.sdk.definitions.baseoperator import BaseOperator - from airflow.sdk.definitions.baseoperatorlink import BaseOperatorLink + from airflow.sdk.bases.baseoperator import BaseOperator + from airflow.sdk.bases.operatorlink import BaseOperatorLink from airflow.sdk.definitions.context import Context from airflow.sdk.definitions.dag import DAG from airflow.sdk.definitions.mappedoperator import MappedOperator diff --git a/task-sdk/src/airflow/sdk/definitions/_internal/mixins.py b/task-sdk/src/airflow/sdk/definitions/_internal/mixins.py index 93fd9431cbe38..71b88dc400028 100644 --- a/task-sdk/src/airflow/sdk/definitions/_internal/mixins.py +++ b/task-sdk/src/airflow/sdk/definitions/_internal/mixins.py @@ -22,8 +22,8 @@ from typing import TYPE_CHECKING, Any if TYPE_CHECKING: + from airflow.sdk.bases.baseoperator import BaseOperator from airflow.sdk.definitions._internal.abstractoperator import Operator - from airflow.sdk.definitions.baseoperator import BaseOperator from airflow.sdk.definitions.context import Context from airflow.sdk.definitions.edges import EdgeModifier diff --git a/task-sdk/src/airflow/sdk/definitions/_internal/node.py b/task-sdk/src/airflow/sdk/definitions/_internal/node.py index 93140f0a07c96..7fab2f1919b39 100644 --- a/task-sdk/src/airflow/sdk/definitions/_internal/node.py +++ b/task-sdk/src/airflow/sdk/definitions/_internal/node.py @@ -122,7 +122,7 @@ def _set_relatives( edge_modifier: EdgeModifier | None = None, ) -> None: """Set relatives for the task or task list.""" - from airflow.sdk.definitions.baseoperator import BaseOperator + from airflow.sdk.bases.baseoperator import BaseOperator from airflow.sdk.definitions.mappedoperator import MappedOperator if not isinstance(task_or_task_list, Sequence): diff --git a/task-sdk/src/airflow/sdk/definitions/context.py b/task-sdk/src/airflow/sdk/definitions/context.py index ee6a48cd904b9..03fbb6a7763fa 100644 --- a/task-sdk/src/airflow/sdk/definitions/context.py +++ b/task-sdk/src/airflow/sdk/definitions/context.py @@ -25,7 +25,7 @@ from datetime import datetime from airflow.models.operator import Operator - from airflow.sdk.definitions.baseoperator import BaseOperator + from airflow.sdk.bases.baseoperator import BaseOperator from airflow.sdk.definitions.dag import DAG from airflow.sdk.execution_time.context import InletEventsAccessors from airflow.sdk.types import ( diff --git a/task-sdk/src/airflow/sdk/definitions/dag.py b/task-sdk/src/airflow/sdk/definitions/dag.py index 64a076346a001..dd2f4cef71dc4 100644 --- a/task-sdk/src/airflow/sdk/definitions/dag.py +++ b/task-sdk/src/airflow/sdk/definitions/dag.py @@ -51,10 +51,10 @@ ParamValidationError, TaskNotFound, ) +from airflow.sdk.bases.baseoperator import BaseOperator from airflow.sdk.definitions._internal.abstractoperator import AbstractOperator from airflow.sdk.definitions._internal.types import NOTSET from airflow.sdk.definitions.asset import AssetAll, BaseAsset -from airflow.sdk.definitions.baseoperator import BaseOperator from airflow.sdk.definitions.context import Context from airflow.sdk.definitions.param import DagParam, ParamsDict from airflow.timetables.base import Timetable diff --git a/task-sdk/src/airflow/sdk/definitions/mappedoperator.py b/task-sdk/src/airflow/sdk/definitions/mappedoperator.py index 06bf309f6ce15..7416967d42cb0 100644 --- a/task-sdk/src/airflow/sdk/definitions/mappedoperator.py +++ b/task-sdk/src/airflow/sdk/definitions/mappedoperator.py @@ -69,8 +69,8 @@ OperatorExpandKwargsArgument, ) from airflow.models.xcom_arg import XComArg - from airflow.sdk.definitions.baseoperator import BaseOperator - from airflow.sdk.definitions.baseoperatorlink import BaseOperatorLink + from airflow.sdk.bases.baseoperator import BaseOperator + from airflow.sdk.bases.operatorlink import BaseOperatorLink from airflow.sdk.definitions.dag import DAG from airflow.sdk.definitions.param import ParamsDict from airflow.sdk.types import Operator diff --git a/task-sdk/src/airflow/sdk/definitions/taskgroup.py b/task-sdk/src/airflow/sdk/definitions/taskgroup.py index 8152b76363ff8..c03c50cc4cd52 100644 --- a/task-sdk/src/airflow/sdk/definitions/taskgroup.py +++ b/task-sdk/src/airflow/sdk/definitions/taskgroup.py @@ -41,9 +41,9 @@ if TYPE_CHECKING: from airflow.models.expandinput import ExpandInput + from airflow.sdk.bases.baseoperator import BaseOperator from airflow.sdk.definitions._internal.abstractoperator import AbstractOperator from airflow.sdk.definitions._internal.mixins import DependencyMixin - from airflow.sdk.definitions.baseoperator import BaseOperator from airflow.sdk.definitions.dag import DAG from airflow.sdk.definitions.edges import EdgeModifier from airflow.sdk.types import Operator diff --git a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py index d080cc7ff3b13..c5c124cf3c81b 100644 --- a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py +++ b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py @@ -34,7 +34,7 @@ from airflow.utils.xcom import XCOM_RETURN_KEY if TYPE_CHECKING: - from airflow.sdk.definitions.baseoperator import BaseOperator + from airflow.sdk.bases.baseoperator import BaseOperator from airflow.sdk.definitions.edges import EdgeModifier from airflow.sdk.types import Operator diff --git a/task-sdk/src/airflow/sdk/execution_time/context.py b/task-sdk/src/airflow/sdk/execution_time/context.py index d9b865c18b1fc..5ce223962a240 100644 --- a/task-sdk/src/airflow/sdk/execution_time/context.py +++ b/task-sdk/src/airflow/sdk/execution_time/context.py @@ -43,7 +43,7 @@ from uuid import UUID from airflow.sdk import Variable - from airflow.sdk.definitions.baseoperator import BaseOperator + from airflow.sdk.bases.baseoperator import BaseOperator from airflow.sdk.definitions.connection import Connection from airflow.sdk.definitions.context import Context from airflow.sdk.execution_time.comms import ( diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index f64582f5cbcb1..4d5e9dcd753a9 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -46,10 +46,10 @@ TerminalTIState, TIRunContext, ) +from airflow.sdk.bases.baseoperator import BaseOperator, ExecutorSafeguard from airflow.sdk.definitions._internal.dag_parsing_context import _airflow_parsing_context_manager from airflow.sdk.definitions._internal.types import NOTSET, ArgNotSet from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetNameRef, AssetUniqueKey, AssetUriRef -from airflow.sdk.definitions.baseoperator import BaseOperator, ExecutorSafeguard from airflow.sdk.definitions.mappedoperator import MappedOperator from airflow.sdk.definitions.param import process_params from airflow.sdk.exceptions import ErrorType diff --git a/task-sdk/src/airflow/sdk/execution_time/xcom.py b/task-sdk/src/airflow/sdk/execution_time/xcom.py index abb964907f196..536d10e22c4bd 100644 --- a/task-sdk/src/airflow/sdk/execution_time/xcom.py +++ b/task-sdk/src/airflow/sdk/execution_time/xcom.py @@ -14,302 +14,10 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - from __future__ import annotations -from typing import Any - -import structlog - from airflow.configuration import conf -from airflow.sdk.execution_time.comms import DeleteXCom, GetXCom, SetXCom, XComResult - -log = structlog.get_logger(logger_name="task") - - -class BaseXCom: - """BaseXcom is an interface now to interact with XCom backends.""" - - @classmethod - def set( - cls, - key: str, - value: Any, - *, - dag_id: str, - task_id: str, - run_id: str, - map_index: int = -1, - _mapped_length: int | None = None, - ) -> None: - """ - Store an XCom value. - - :param key: Key to store the XCom. - :param value: XCom value to store. - :param dag_id: DAG ID. - :param task_id: Task ID. - :param run_id: DAG run ID for the task. - :param map_index: Optional map index to assign XCom for a mapped task. - The default is ``-1`` (set for a non-mapped task). - """ - from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS - - value = cls.serialize_value( - value=value, - key=key, - task_id=task_id, - dag_id=dag_id, - run_id=run_id, - map_index=map_index, - ) - - SUPERVISOR_COMMS.send_request( - log=log, - msg=SetXCom( - key=key, - value=value, - dag_id=dag_id, - task_id=task_id, - run_id=run_id, - map_index=map_index, - mapped_length=_mapped_length, - ), - ) - - @classmethod - def _set_xcom_in_db( - cls, - key: str, - value: Any, - *, - dag_id: str, - task_id: str, - run_id: str, - map_index: int = -1, - ) -> None: - """ - Store an XCom value directly in the metadata database. - - :param key: Key to store the XCom. - :param value: XCom value to store. - :param dag_id: DAG ID. - :param task_id: Task ID. - :param run_id: DAG run ID for the task. - :param map_index: Optional map index to assign XCom for a mapped task. - The default is ``-1`` (set for a non-mapped task). - """ - from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS - - SUPERVISOR_COMMS.send_request( - log=log, - msg=SetXCom( - key=key, - value=value, - dag_id=dag_id, - task_id=task_id, - run_id=run_id, - map_index=map_index, - ), - ) - - @classmethod - def get_value( - cls, - *, - ti_key: Any, - key: str, - ) -> Any: - """ - Retrieve an XCom value for a task instance. - - This method returns "full" XCom values (i.e. uses ``deserialize_value`` - from the XCom backend). Use :meth:`get_many` if you want the "shortened" - value via ``orm_deserialize_value``. - - If there are no results, *None* is returned. If multiple XCom entries - match the criteria, an arbitrary one is returned. - - :param ti_key: The TaskInstanceKey to look up the XCom for. - :param key: A key for the XCom. If provided, only XCom with matching - keys will be returned. Pass *None* (default) to remove the filter. - """ - return cls.get_one( - key=key, - task_id=ti_key.task_id, - dag_id=ti_key.dag_id, - run_id=ti_key.run_id, - map_index=ti_key.map_index, - ) - - @classmethod - def _get_xcom_db_ref( - cls, - *, - key: str, - dag_id: str, - task_id: str, - run_id: str, - map_index: int | None = None, - ) -> XComResult: - """ - Retrieve an XCom value, optionally meeting certain criteria. - - This method returns "full" XCom values (i.e. uses ``deserialize_value`` - from the XCom backend). Use :meth:`get_many` if you want the "shortened" - value via ``orm_deserialize_value``. - - If there are no results, *None* is returned. If multiple XCom entries - match the criteria, an arbitrary one is returned. - - .. seealso:: ``get_value()`` is a convenience function if you already - have a structured TaskInstance or TaskInstanceKey object available. - - :param run_id: DAG run ID for the task. - :param dag_id: Only pull XCom from this DAG. Pass *None* (default) to - remove the filter. - :param task_id: Only XCom from task with matching ID will be pulled. - Pass *None* (default) to remove the filter. - :param map_index: Only XCom from task with matching ID will be pulled. - Pass *None* (default) to remove the filter. - :param key: A key for the XCom. If provided, only XCom with matching - keys will be returned. Pass *None* (default) to remove the filter. - """ - from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS - - SUPERVISOR_COMMS.send_request( - log=log, - msg=GetXCom( - key=key, - dag_id=dag_id, - task_id=task_id, - run_id=run_id, - map_index=map_index, - ), - ) - - msg = SUPERVISOR_COMMS.get_message() - if not isinstance(msg, XComResult): - raise TypeError(f"Expected XComResult, received: {type(msg)} {msg}") - - return msg - - @classmethod - def get_one( - cls, - *, - key: str, - dag_id: str, - task_id: str, - run_id: str, - map_index: int | None = None, - include_prior_dates: bool = False, - ) -> Any | None: - """ - Retrieve an XCom value, optionally meeting certain criteria. - - This method returns "full" XCom values (i.e. uses ``deserialize_value`` - from the XCom backend). Use :meth:`get_many` if you want the "shortened" - value via ``orm_deserialize_value``. - - If there are no results, *None* is returned. If multiple XCom entries - match the criteria, an arbitrary one is returned. - - .. seealso:: ``get_value()`` is a convenience function if you already - have a structured TaskInstance or TaskInstanceKey object available. - - :param run_id: DAG run ID for the task. - :param dag_id: Only pull XCom from this DAG. Pass *None* (default) to - remove the filter. - :param task_id: Only XCom from task with matching ID will be pulled. - Pass *None* (default) to remove the filter. - :param map_index: Only XCom from task with matching ID will be pulled. - Pass *None* (default) to remove the filter. - :param key: A key for the XCom. If provided, only XCom with matching - keys will be returned. Pass *None* (default) to remove the filter. - :param include_prior_dates: If *False* (default), only XCom from the - specified DAG run is returned. If *True*, the latest matching XCom is - returned regardless of the run it belongs to. - """ - from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS - - SUPERVISOR_COMMS.send_request( - log=log, - msg=GetXCom( - key=key, - dag_id=dag_id, - task_id=task_id, - run_id=run_id, - map_index=map_index, - include_prior_dates=include_prior_dates, - ), - ) - - msg = SUPERVISOR_COMMS.get_message() - if not isinstance(msg, XComResult): - raise TypeError(f"Expected XComResult, received: {type(msg)} {msg}") - - if msg.value is not None: - return cls.deserialize_value(msg) - return None - - @staticmethod - def serialize_value( - value: Any, - *, - key: str | None = None, - task_id: str | None = None, - dag_id: str | None = None, - run_id: str | None = None, - map_index: int | None = None, - ) -> str: - """Serialize XCom value to JSON str.""" - from airflow.serialization.serde import serialize - - # return back the value for BaseXCom, custom backends will implement this - return serialize(value) # type: ignore[return-value] - - @staticmethod - def deserialize_value(result) -> Any: - """Deserialize XCom value from str objects.""" - from airflow.serialization.serde import deserialize - - return deserialize(result.value) - - @classmethod - def purge(cls, xcom: XComResult, *args) -> None: - """Purge an XCom entry from underlying storage implementations.""" - pass - - @classmethod - def delete( - cls, - key: str, - task_id: str, - dag_id: str, - run_id: str, - map_index: int | None = None, - ) -> None: - """Delete an Xcom entry, for custom xcom backends, it gets the path associated with the data on the backend and purges it.""" - from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS - - xcom_result = cls._get_xcom_db_ref( - key=key, - dag_id=dag_id, - task_id=task_id, - run_id=run_id, - map_index=map_index, - ) - cls.purge(xcom_result) # type: ignore[call-arg] - SUPERVISOR_COMMS.send_request( - log=log, - msg=DeleteXCom( - key=key, - dag_id=dag_id, - task_id=task_id, - run_id=run_id, - ), - ) +from airflow.sdk.bases.xcom import BaseXCom def resolve_xcom_backend(): @@ -318,7 +26,7 @@ def resolve_xcom_backend(): :returns: returns the custom XCom class if configured. """ - clazz = conf.getimport("core", "xcom_backend", fallback="airflow.sdk.execution_time.xcom.BaseXCom") + clazz = conf.getimport("core", "xcom_backend", fallback="airflow.sdk.bases.xcom.BaseXCom") if not clazz: return BaseXCom if not issubclass(clazz, BaseXCom): diff --git a/task-sdk/src/airflow/sdk/types.py b/task-sdk/src/airflow/sdk/types.py index 6c3c40f2ab93d..e7e2036aa5c2b 100644 --- a/task-sdk/src/airflow/sdk/types.py +++ b/task-sdk/src/airflow/sdk/types.py @@ -29,8 +29,8 @@ from collections.abc import Iterator from datetime import datetime + from airflow.sdk.bases.baseoperator import BaseOperator from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetAliasEvent, AssetRef, BaseAssetUniqueKey - from airflow.sdk.definitions.baseoperator import BaseOperator from airflow.sdk.definitions.context import Context from airflow.sdk.definitions.mappedoperator import MappedOperator diff --git a/task-sdk/tests/task_sdk/bases/__init__.py b/task-sdk/tests/task_sdk/bases/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/task-sdk/tests/task_sdk/bases/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/task-sdk/tests/task_sdk/definitions/notifier/test_notifier.txt b/task-sdk/tests/task_sdk/bases/notifier/test_notifier.txt similarity index 100% rename from task-sdk/tests/task_sdk/definitions/notifier/test_notifier.txt rename to task-sdk/tests/task_sdk/bases/notifier/test_notifier.txt diff --git a/task-sdk/tests/task_sdk/definitions/test_baseoperator.py b/task-sdk/tests/task_sdk/bases/test_baseoperator.py similarity index 99% rename from task-sdk/tests/task_sdk/definitions/test_baseoperator.py rename to task-sdk/tests/task_sdk/bases/test_baseoperator.py index f87c82e6d4e83..969030a22ca86 100644 --- a/task-sdk/tests/task_sdk/definitions/test_baseoperator.py +++ b/task-sdk/tests/task_sdk/bases/test_baseoperator.py @@ -29,7 +29,7 @@ import structlog from airflow.decorators import task as task_decorator -from airflow.sdk.definitions.baseoperator import ( +from airflow.sdk.bases.baseoperator import ( BaseOperator, BaseOperatorMeta, ExecutorSafeguard, diff --git a/task-sdk/tests/task_sdk/definitions/test_notifier.py b/task-sdk/tests/task_sdk/bases/test_notifier.py similarity index 98% rename from task-sdk/tests/task_sdk/definitions/test_notifier.py rename to task-sdk/tests/task_sdk/bases/test_notifier.py index cc8b5f9659b19..b8cedaa518831 100644 --- a/task-sdk/tests/task_sdk/definitions/test_notifier.py +++ b/task-sdk/tests/task_sdk/bases/test_notifier.py @@ -24,8 +24,8 @@ import pytest from airflow.providers.standard.operators.empty import EmptyOperator +from airflow.sdk.bases.notifier import BaseNotifier from airflow.sdk.definitions.dag import DAG -from airflow.sdk.definitions.notifier import BaseNotifier if TYPE_CHECKING: from airflow.sdk.definitions.context import Context diff --git a/task-sdk/tests/task_sdk/definitions/sensors/test_base.py b/task-sdk/tests/task_sdk/bases/test_sensor.py similarity index 99% rename from task-sdk/tests/task_sdk/definitions/sensors/test_base.py rename to task-sdk/tests/task_sdk/bases/test_sensor.py index ebd815d44bf0a..2c0a82783d4d6 100644 --- a/task-sdk/tests/task_sdk/definitions/sensors/test_base.py +++ b/task-sdk/tests/task_sdk/bases/test_sensor.py @@ -34,8 +34,8 @@ ) from airflow.models.trigger import TriggerFailureReason from airflow.providers.standard.operators.empty import EmptyOperator +from airflow.sdk.bases.sensor import BaseSensorOperator, PokeReturnValue, poke_mode_only from airflow.sdk.definitions.dag import DAG -from airflow.sdk.definitions.sensors.base import BaseSensorOperator, PokeReturnValue, poke_mode_only from airflow.sdk.execution_time.comms import RescheduleTask, TaskRescheduleStartDate from airflow.utils import timezone from airflow.utils.state import State diff --git a/task-sdk/tests/task_sdk/dags/super_basic.py b/task-sdk/tests/task_sdk/dags/super_basic.py index b5a50785bcebf..2cccb9ab4c647 100644 --- a/task-sdk/tests/task_sdk/dags/super_basic.py +++ b/task-sdk/tests/task_sdk/dags/super_basic.py @@ -17,7 +17,7 @@ from __future__ import annotations -from airflow.sdk.definitions.baseoperator import BaseOperator +from airflow.sdk.bases.baseoperator import BaseOperator from airflow.sdk.definitions.dag import dag diff --git a/task-sdk/tests/task_sdk/dags/super_basic_run.py b/task-sdk/tests/task_sdk/dags/super_basic_run.py index 87d2a6820226b..e178e7acc8935 100644 --- a/task-sdk/tests/task_sdk/dags/super_basic_run.py +++ b/task-sdk/tests/task_sdk/dags/super_basic_run.py @@ -17,7 +17,7 @@ from __future__ import annotations -from airflow.sdk.definitions.baseoperator import BaseOperator +from airflow.sdk.bases.baseoperator import BaseOperator from airflow.sdk.definitions.dag import dag diff --git a/task-sdk/tests/task_sdk/definitions/test_dag.py b/task-sdk/tests/task_sdk/definitions/test_dag.py index 71d888d3af10e..07a99221d496a 100644 --- a/task-sdk/tests/task_sdk/definitions/test_dag.py +++ b/task-sdk/tests/task_sdk/definitions/test_dag.py @@ -23,7 +23,7 @@ import pytest from airflow.exceptions import DuplicateTaskIdFound -from airflow.sdk.definitions.baseoperator import BaseOperator +from airflow.sdk.bases.baseoperator import BaseOperator from airflow.sdk.definitions.dag import DAG, dag as dag_decorator from airflow.sdk.definitions.param import DagParam, Param, ParamsDict diff --git a/task-sdk/tests/task_sdk/definitions/test_mappedoperator.py b/task-sdk/tests/task_sdk/definitions/test_mappedoperator.py index a75cfacf6cb46..3c780551c6aea 100644 --- a/task-sdk/tests/task_sdk/definitions/test_mappedoperator.py +++ b/task-sdk/tests/task_sdk/definitions/test_mappedoperator.py @@ -25,7 +25,7 @@ import pytest from airflow.sdk.api.datamodels._generated import TerminalTIState -from airflow.sdk.definitions.baseoperator import BaseOperator +from airflow.sdk.bases.baseoperator import BaseOperator from airflow.sdk.definitions.dag import DAG from airflow.sdk.definitions.mappedoperator import MappedOperator from airflow.sdk.definitions.xcom_arg import XComArg diff --git a/task-sdk/tests/task_sdk/definitions/test_mixins.py b/task-sdk/tests/task_sdk/definitions/test_mixins.py index 83b4d6eabefdf..b7ffa974758ee 100644 --- a/task-sdk/tests/task_sdk/definitions/test_mixins.py +++ b/task-sdk/tests/task_sdk/definitions/test_mixins.py @@ -22,7 +22,7 @@ import pytest from airflow.decorators import setup, task, teardown -from airflow.sdk.definitions.baseoperator import BaseOperator +from airflow.sdk.bases.baseoperator import BaseOperator from airflow.sdk.definitions.dag import DAG