From 7d9520d929572eb01646142954becde105adf698 Mon Sep 17 00:00:00 2001
From: Kaxil Naik
Date: Sat, 11 Jan 2025 01:55:51 +0530
Subject: [PATCH 1/2] AIP-72: Support better type-hinting for `Context` dict in SDK

This PR adds a `Context` TypedDict in the Task SDK that is used for
type-hinting the task execution context dictionary. It replaces the `Context`
class previously used in the Scheduler.

---
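Illustration (not part of the diff): a minimal sketch of what the new
TypedDict buys. With `Context` imported under `TYPE_CHECKING`, type checkers
can verify key access inside `execute`; the operator below is hypothetical.

    from __future__ import annotations

    from typing import TYPE_CHECKING

    from airflow.models.baseoperator import BaseOperator

    if TYPE_CHECKING:
        from airflow.sdk.definitions.context import Context


    class MyOperator(BaseOperator):
        def execute(self, context: Context) -> str:
            # Known keys get real types (str here) ...
            run_id: str = context["run_id"]
            ts: str = context["ts"]
            # ... and unknown keys such as context["run_idd"] now fail type checking.
            return f"{run_id} @ {ts}"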
 .pre-commit-config.yaml                       |   2 +-
 .../execution_api/datamodels/taskinstance.py  |   6 +-
 .../execution_api/routes/task_instances.py    |   7 +-
 airflow/dag_processing/processor.py           |   2 +-
 airflow/decorators/base.py                    |   2 +-
 airflow/decorators/bash.py                    |   7 +-
 airflow/decorators/condition.py               |   2 +-
 airflow/example_dags/example_dag_decorator.py |   2 +-
 airflow/example_dags/example_skip_dag.py      |   2 +-
 airflow/executors/workloads.py                |   4 +-
 airflow/lineage/__init__.py                   |   2 +-
 airflow/models/abstractoperator.py            |   6 +-
 airflow/models/mappedoperator.py              |   4 +-
 airflow/models/param.py                       |  11 +-
 airflow/models/skipmixin.py                   |  28 ++--
 airflow/models/taskinstance.py                |  29 +++-
 airflow/notifications/basenotifier.py         |   2 +-
 airflow/operators/branch.py                   |   9 +-
 airflow/operators/email.py                    |   2 +-
 airflow/operators/empty.py                    |   2 +-
 airflow/operators/smooth.py                   |   2 +-
 airflow/sensors/base.py                       |   2 +-
 airflow/utils/context.py                      |  31 ++--
 airflow/utils/context.pyi                     | 145 ------------------
 airflow/utils/helpers.py                      |   2 +-
 airflow/utils/log/file_task_handler.py        |   2 +-
 .../amazon/aws/transfers/google_api_to_s3.py  |  11 +-
 .../providers/apache/hive/operators/hive.py   |   8 +-
 .../cncf/kubernetes/operators/pod.py          |   4 +-
 .../providers/edge/example_dags/win_test.py   |   5 +-
 .../providers/google/cloud/operators/gcs.py   |  26 ++--
 .../providers/standard/operators/bash.py      |   8 +-
 .../standard/operators/latest_only.py         |   2 +-
 .../providers/standard/operators/python.py    |   3 +-
 .../providers/standard/sensors/time_delta.py  |  19 ++-
 .../tests/google/cloud/operators/test_gcs.py  |  31 +---
 .../pre_commit/template_context_key_sync.py   |   9 +-
 task_sdk/src/airflow/sdk/api/client.py        |   1 +
 .../airflow/sdk/api/datamodels/_generated.py  |   4 +-
 .../definitions/_internal/contextmanager.py   |  10 +-
 .../src/airflow/sdk/definitions/context.py    |  60 +++++++-
 task_sdk/src/airflow/sdk/definitions/dag.py   |   2 +-
 .../src/airflow/sdk/definitions/protocols.py  |  67 ++++++++
 .../src/airflow/sdk/execution_time/context.py |   5 +-
 .../airflow/sdk/execution_time/task_runner.py |  17 +-
 task_sdk/tests/conftest.py                    |   3 +-
 task_sdk/tests/execution_time/conftest.py     |   5 +-
 .../tests/execution_time/test_task_runner.py  |   4 +-
 .../routes/test_task_instances.py             |   1 +
 tests/dags/test_on_kill.py                    |   2 +-
 tests/dags/test_parsing_context.py            |   2 +-
 tests/decorators/test_condition.py            |   7 +-
 tests/lineage/test_lineage.py                 |   2 +-
 tests/models/test_baseoperatormeta.py         |   2 +-
 tests/models/test_mappedoperator.py           |   2 +-
 tests/models/test_skipmixin.py                |   4 +-
 tests/models/test_taskinstance.py             |   4 +-
 tests/notifications/test_basenotifier.py      |   2 +-
 tests/sensors/test_base.py                    |   2 +-
 tests/serialization/test_dag_serialization.py |   2 +-
 tests_common/test_utils/mock_operators.py     |   2 +-
 tests_common/test_utils/system_tests.py       |   2 +-
 62 files changed, 337 insertions(+), 318 deletions(-)
 delete mode 100644 airflow/utils/context.pyi
 create mode 100644 task_sdk/src/airflow/sdk/definitions/protocols.py

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 06f028e1b84f8..8e3b2877895f9 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -695,7 +695,7 @@ repos:
         name: Sync template context variable refs
         language: python
         entry: ./scripts/ci/pre_commit/template_context_key_sync.py
-        files: ^airflow/models/taskinstance\.py$|^airflow/utils/context\.pyi?$|^docs/apache-airflow/templates-ref\.rst$
+        files: ^airflow/models/taskinstance\.py$|^task_sdk/src/airflow/sdk/definitions/context\.py$|^docs/apache-airflow/templates-ref\.rst$
       - id: check-base-operator-usage
         language: pygrep
         name: Check BaseOperator core imports
diff --git a/airflow/api_fastapi/execution_api/datamodels/taskinstance.py b/airflow/api_fastapi/execution_api/datamodels/taskinstance.py
index c1bf588c2bbd4..563b32a2693a1 100644
--- a/airflow/api_fastapi/execution_api/datamodels/taskinstance.py
+++ b/airflow/api_fastapi/execution_api/datamodels/taskinstance.py
@@ -163,7 +163,8 @@ class TaskInstance(BaseModel):
     dag_id: str
     run_id: str
     try_number: int
-    map_index: int | None = None
+    map_index: int = -1
+    hostname: str | None = None


 class DagRun(BaseModel):
@@ -190,6 +191,9 @@ class TIRunContext(BaseModel):
     dag_run: DagRun
     """DAG run information for the task instance."""

+    max_tries: int
+    """Maximum number of tries for the task instance (from DB)."""
+
     variables: Annotated[list[VariableResponse], Field(default_factory=list)]
     """Variables that can be accessed by the task instance."""

diff --git a/airflow/api_fastapi/execution_api/routes/task_instances.py b/airflow/api_fastapi/execution_api/routes/task_instances.py
index 6086e1093ce00..ba6ea0c14b63f 100644
--- a/airflow/api_fastapi/execution_api/routes/task_instances.py
+++ b/airflow/api_fastapi/execution_api/routes/task_instances.py
@@ -75,12 +75,14 @@ def ti_run(
     ti_id_str = str(task_instance_id)

     old = (
-        select(TI.state, TI.dag_id, TI.run_id, TI.task_id, TI.map_index, TI.next_method)
+        select(TI.state, TI.dag_id, TI.run_id, TI.task_id, TI.map_index, TI.next_method, TI.max_tries)
         .where(TI.id == ti_id_str)
         .with_for_update()
     )
     try:
-        (previous_state, dag_id, run_id, task_id, map_index, next_method) = session.execute(old).one()
+        (previous_state, dag_id, run_id, task_id, map_index, next_method, max_tries) = session.execute(
+            old
+        ).one()
     except NoResultFound:
         log.error("Task Instance %s not found", ti_id_str)
         raise HTTPException(
@@ -165,6 +167,7 @@ def ti_run(

         return TIRunContext(
             dag_run=DagRun.model_validate(dr, from_attributes=True),
+            max_tries=max_tries,
             # TODO: Add variables and connections that are needed (and has perms) for the task
             variables=[],
             connections=[],
diff --git a/airflow/dag_processing/processor.py b/airflow/dag_processing/processor.py
index 3b849d0e822c5..c175f3f68c726 100644
--- a/airflow/dag_processing/processor.py
+++ b/airflow/dag_processing/processor.py
@@ -39,8 +39,8 @@
 if TYPE_CHECKING:
     from structlog.typing import FilteringBoundLogger

+    from airflow.sdk.definitions.context import Context
     from airflow.typing_compat import Self
-    from airflow.utils.context import Context


 def _parse_file_entrypoint():
diff --git a/airflow/decorators/base.py b/airflow/decorators/base.py
index 0c6ae7b6d55bf..ecea5c957489e 100644
--- a/airflow/decorators/base.py
+++ b/airflow/decorators/base.py
@@ -63,8 +63,8 @@
         OperatorExpandKwargsArgument,
     )
     from airflow.models.mappedoperator import ValidationSource
+    from airflow.sdk.definitions.context import Context
     from airflow.sdk.definitions.dag import DAG
-    from airflow.utils.context import Context
     from airflow.utils.task_group import TaskGroup


diff --git a/airflow/decorators/bash.py b/airflow/decorators/bash.py
index ae5b0a9e0c153..996ac5ffe05de 100644
--- a/airflow/decorators/bash.py
+++ b/airflow/decorators/bash.py
@@ -19,14 +19,17 @@
 import warnings
 from collections.abc import Collection, Mapping, Sequence
-from typing import Any, Callable, ClassVar
+from typing import TYPE_CHECKING, Any, Callable, ClassVar

 from airflow.decorators.base import DecoratedOperator, TaskDecorator, task_decorator_factory
 from airflow.providers.standard.operators.bash import BashOperator
-from airflow.utils.context import Context, context_merge
+from airflow.utils.context import context_merge
 from airflow.utils.operator_helpers import determine_kwargs
 from airflow.utils.types import NOTSET

+if TYPE_CHECKING:
+    from airflow.sdk.definitions.context import Context
+

 class _BashDecoratedOperator(DecoratedOperator, BashOperator):
     """
diff --git a/airflow/decorators/condition.py b/airflow/decorators/condition.py
index 27c92fcc62cd8..e276b9fc7178e 100644
--- a/airflow/decorators/condition.py
+++ b/airflow/decorators/condition.py
@@ -26,7 +26,7 @@
     from typing_extensions import TypeAlias

     from airflow.models.baseoperator import TaskPreExecuteHook
-    from airflow.utils.context import Context
+    from airflow.sdk.definitions.context import Context

     BoolConditionFunc: TypeAlias = Callable[[Context], bool]
     MsgConditionFunc: TypeAlias = "Callable[[Context], tuple[bool, str | None]]"
diff --git a/airflow/example_dags/example_dag_decorator.py b/airflow/example_dags/example_dag_decorator.py
index 447b4471b97e4..0fed70fa2e9f1 100644
--- a/airflow/example_dags/example_dag_decorator.py
+++ b/airflow/example_dags/example_dag_decorator.py
@@ -27,7 +27,7 @@
 from airflow.operators.email import EmailOperator

 if TYPE_CHECKING:
-    from airflow.utils.context import Context
+    from airflow.sdk.definitions.context import Context


 class GetRequestOperator(BaseOperator):
diff --git a/airflow/example_dags/example_skip_dag.py b/airflow/example_dags/example_skip_dag.py
index 2655394c6f6f4..885cbd2e43310 100644
--- a/airflow/example_dags/example_skip_dag.py
+++ b/airflow/example_dags/example_skip_dag.py
@@ -31,7 +31,7 @@
 from airflow.utils.trigger_rule import TriggerRule

 if TYPE_CHECKING:
-    from airflow.utils.context import Context
+    from airflow.sdk.definitions.context import Context


 # Create some placeholder operators
diff --git a/airflow/executors/workloads.py b/airflow/executors/workloads.py
index 1afd6e10c5d48..13331f9b5793a 100644
--- a/airflow/executors/workloads.py
+++ b/airflow/executors/workloads.py
@@ -48,7 +48,7 @@ class TaskInstance(BaseModel):
     dag_id: str
     run_id: str
     try_number: int
-    map_index: int | None = None
+    map_index: int = -1

     pool_slots: int
     queue: str
@@ -64,7 +64,7 @@ def key(self) -> TaskInstanceKey:
             task_id=self.task_id,
             run_id=self.run_id,
             try_number=self.try_number,
-            map_index=-1 if self.map_index is None else self.map_index,
+            map_index=self.map_index,
         )


diff --git a/airflow/lineage/__init__.py b/airflow/lineage/__init__.py
index 4385f3fbaf586..8e581952e6945 100644
--- a/airflow/lineage/__init__.py
+++ b/airflow/lineage/__init__.py
@@ -28,7 +28,7 @@
 from airflow.utils.session import create_session

 if TYPE_CHECKING:
-    from airflow.utils.context import Context
+    from airflow.sdk.definitions.context import Context

 PIPELINE_OUTLETS = "pipeline_outlets"
 PIPELINE_INLETS = "pipeline_inlets"
diff --git a/airflow/models/abstractoperator.py b/airflow/models/abstractoperator.py
index dd386f6274b5b..732180d372946 100644
--- a/airflow/models/abstractoperator.py
+++ b/airflow/models/abstractoperator.py
@@ -19,7 +19,7 @@

 import datetime
 import inspect
-from collections.abc import Iterable, Iterator, Mapping, Sequence
+from collections.abc import Iterable, Iterator, Sequence
 from functools import cached_property
 from typing import TYPE_CHECKING, Any, Callable

@@ -30,7 +30,7 @@
 from airflow.exceptions import AirflowException
 from airflow.models.expandinput import NotFullyPopulated
 from airflow.sdk.definitions._internal.abstractoperator import AbstractOperator as TaskSDKAbstractOperator
-from airflow.utils.context import Context
+from airflow.sdk.definitions.context import Context
 from airflow.utils.db import exists_query
 from airflow.utils.log.logging_mixin import LoggingMixin
 from airflow.utils.setup_teardown import SetupTeardownContext
@@ -512,7 +512,7 @@ def expand_mapped_task(self, run_id: str, *, session: Session) -> tuple[Sequence

     def render_template_fields(
         self,
-        context: Mapping[str, Any],
+        context: Context,
         jinja_env: jinja2.Environment | None = None,
     ) -> None:
         """
diff --git a/airflow/models/mappedoperator.py b/airflow/models/mappedoperator.py
index f4728095037d4..2ca2fe9e0004a 100644
--- a/airflow/models/mappedoperator.py
+++ b/airflow/models/mappedoperator.py
@@ -81,8 +81,8 @@
     from airflow.models.operator import Operator
     from airflow.models.param import ParamsDict
     from airflow.models.xcom_arg import XComArg
+    from airflow.sdk.definitions.context import Context
     from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
-    from airflow.utils.context import Context
     from airflow.utils.operator_resources import Resources
     from airflow.utils.task_group import TaskGroup
     from airflow.utils.trigger_rule import TriggerRule
@@ -869,7 +869,7 @@ def get_mapped_ti_count(self, run_id: str, *, session: Session) -> int:

     def render_template_fields(
         self,
-        context: Mapping[str, Any],
+        context: Context,
         jinja_env: jinja2.Environment | None = None,
     ) -> None:
         """
diff --git a/airflow/models/param.py b/airflow/models/param.py
index c1b47a8f5e453..788c215e00430 100644
--- a/airflow/models/param.py
+++ b/airflow/models/param.py
@@ -28,7 +28,6 @@
 from airflow.utils.types import NOTSET, ArgNotSet

 if TYPE_CHECKING:
-    from airflow.models.dagrun import DagRun
     from airflow.models.operator import Operator
     from airflow.sdk.definitions.dag import DAG

@@ -332,19 +331,21 @@ def deserialize(cls, data: dict, dags: dict) -> DagParam:
 def process_params(
     dag: DAG,
     task: Operator,
-    dag_run: DagRun | None,
+    dagrun_conf: dict[str, Any] | None,
     *,
     suppress_exception: bool,
 ) -> dict[str, Any]:
     """Merge, validate params, and convert them into a simple dict."""
     from airflow.configuration import conf

+    dagrun_conf = dagrun_conf or {}
+
     params = ParamsDict(suppress_exception=suppress_exception)
     with contextlib.suppress(AttributeError):
         params.update(dag.params)
     if task.params:
         params.update(task.params)
-    if conf.getboolean("core", "dag_run_conf_overrides_params") and dag_run and dag_run.conf:
-        logger.debug("Updating task params (%s) with DagRun.conf (%s)", params, dag_run.conf)
-        params.update(dag_run.conf)
+    if conf.getboolean("core", "dag_run_conf_overrides_params") and dagrun_conf:
+        logger.debug("Updating task params (%s) with DagRun.conf (%s)", params, dagrun_conf)
+        params.update(dagrun_conf)

     return params.validate()
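A rough sketch of the reworked `process_params` call site (illustrative;
`dag`, `task`, and `dag_run` are assumed to be in scope):

    from airflow.models.param import process_params

    # Callers now pass the plain conf dict instead of the DagRun ORM object,
    # so the same call works from the scheduler and from runtime code alike.
    validated = process_params(dag, task, dag_run.conf, suppress_exception=False)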
diff --git a/airflow/models/skipmixin.py b/airflow/models/skipmixin.py
index 63564ebbc431d..5e3c47ad3a1cc 100644
--- a/airflow/models/skipmixin.py
+++ b/airflow/models/skipmixin.py
@@ -59,7 +59,8 @@ class SkipMixin(LoggingMixin):

     @staticmethod
     def _set_state_to_skipped(
-        dag_run: DagRun,
+        dag_id: str,
+        run_id: str,
         tasks: Sequence[str] | Sequence[tuple[str, int]],
         session: Session,
     ) -> None:
@@ -71,8 +72,8 @@ def _set_state_to_skipped(
             session.execute(
                 update(TaskInstance)
                 .where(
-                    TaskInstance.dag_id == dag_run.dag_id,
-                    TaskInstance.run_id == dag_run.run_id,
+                    TaskInstance.dag_id == dag_id,
+                    TaskInstance.run_id == run_id,
                     tuple_(TaskInstance.task_id, TaskInstance.map_index).in_(tasks),
                 )
                 .values(state=TaskInstanceState.SKIPPED, start_date=now, end_date=now)
@@ -82,8 +83,8 @@ def _set_state_to_skipped(
             session.execute(
                 update(TaskInstance)
                 .where(
-                    TaskInstance.dag_id == dag_run.dag_id,
-                    TaskInstance.run_id == dag_run.run_id,
+                    TaskInstance.dag_id == dag_id,
+                    TaskInstance.run_id == run_id,
                     TaskInstance.task_id.in_(tasks),
                 )
                 .values(state=TaskInstanceState.SKIPPED, start_date=now, end_date=now)
@@ -93,7 +94,8 @@ def _set_state_to_skipped(
     @provide_session
     def skip(
         self,
-        dag_run: DagRun,
+        dag_id: str,
+        run_id: str,
         tasks: Iterable[DAGNode],
         map_index: int = -1,
         session: Session = NEW_SESSION,
@@ -105,7 +107,8 @@ def skip(
         so that NotPreviouslySkippedDep knows these tasks should be skipped when they
         are cleared.

-        :param dag_run: the DagRun for which to set the tasks to skipped
+        :param dag_id: the dag_id of the dag run for which to set the tasks to skipped
+        :param run_id: the run_id of the dag run for which to set the tasks to skipped
        :param tasks: tasks to skip (not task_ids)
         :param session: db session to use
         :param map_index: map_index of the current task instance
@@ -116,11 +119,8 @@ def skip(
         if not task_list:
             return

-        if dag_run is None:
-            raise ValueError("dag_run is required")
-
         task_ids_list = [d.task_id for d in task_list]
-        SkipMixin._set_state_to_skipped(dag_run, task_ids_list, session)
+        SkipMixin._set_state_to_skipped(dag_id, run_id, task_ids_list, session)
         session.commit()

         if task_id is not None:
@@ -130,8 +130,8 @@ def skip(
                 key=XCOM_SKIPMIXIN_KEY,
                 value={XCOM_SKIPMIXIN_SKIPPED: task_ids_list},
                 task_id=task_id,
-                dag_id=dag_run.dag_id,
-                run_id=dag_run.run_id,
+                dag_id=dag_id,
+                run_id=run_id,
                 map_index=map_index,
                 session=session,
             )
@@ -225,7 +225,7 @@ def skip_all_except(
         follow_task_ids = [t.task_id for t in downstream_tasks if t.task_id in branch_task_id_set]
         log.info("Skipping tasks %s", skip_tasks)
-        SkipMixin._set_state_to_skipped(dag_run, skip_tasks, session=session)
+        SkipMixin._set_state_to_skipped(dag_run.dag_id, dag_run.run_id, skip_tasks, session=session)
         ti.xcom_push(
             key=XCOM_SKIPMIXIN_KEY, value={XCOM_SKIPMIXIN_FOLLOWED: follow_task_ids}, session=session
         )
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index 580d8cb7b8d6d..2f3fa4e8fb4a9 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -163,6 +163,7 @@
     from airflow.models.dagrun import DagRun
     from airflow.models.operator import Operator
     from airflow.sdk.definitions.dag import DAG
+    from airflow.sdk.definitions.protocols import RuntimeTaskInstanceProtocol
     from airflow.timetables.base import DataInterval
     from airflow.typing_compat import Literal, TypeGuard
     from airflow.utils.task_group import TaskGroup
@@ -940,7 +941,7 @@ def _get_template_context(
     dag_run = task_instance.get_dagrun(session)
     data_interval = dag.get_run_data_interval(dag_run)

-    validated_params = process_params(dag, task, dag_run, suppress_exception=ignore_param_exceptions)
+    validated_params = process_params(dag, task, dag_run.conf, suppress_exception=ignore_param_exceptions)

     logical_date: DateTime = timezone.coerce_datetime(task_instance.logical_date)
     ds = logical_date.strftime("%Y-%m-%d")
@@ -1007,10 +1008,10 @@ def get_triggering_events() -> dict[str, list[AssetEvent]]:
         expanded_ti_count = None

     # NOTE: If you add to this dict, make sure to also update the following:
-    # * Context in airflow/utils/context.pyi
+    # * Context in task_sdk/src/airflow/sdk/definitions/context.py
     # * KNOWN_CONTEXT_KEYS in airflow/utils/context.py
     # * Table in docs/apache-airflow/templates-ref.rst
-    context: dict[str, Any] = {
+    context: Context = {
         "dag": dag,
         "dag_run": dag_run,
         "data_interval_end": timezone.coerce_datetime(data_interval.end),
@@ -1048,7 +1049,7 @@ def get_triggering_events() -> dict[str, list[AssetEvent]]:
     }
     # Mypy doesn't like turning existing dicts in to a TypeDict -- and we "lie" in the type stub to say it
     # is one, but in practice it isn't. See https://github.com/python/mypy/issues/8890
-    return Context(context)  # type: ignore
+    return context


 def _is_eligible_to_retry(*, task_instance: TaskInstance):
@@ -1874,6 +1875,22 @@ def operator_name(self) -> str | None:
     def task_display_name(self) -> str:
         return self._task_display_property_value or self.task_id

+    @classmethod
+    def from_runtime_ti(cls, runtime_ti: RuntimeTaskInstanceProtocol) -> TaskInstance:
+        if runtime_ti.map_index is None:
+            runtime_ti.map_index = -1
+        ti = TaskInstance(
+            run_id=runtime_ti.run_id,
+            task=runtime_ti.task,  # type: ignore[arg-type]
+            map_index=runtime_ti.map_index,
+        )
+        ti.refresh_from_db()
+
+        if TYPE_CHECKING:
+            assert ti
+            assert isinstance(ti, TaskInstance)
+        return ti
+
     @staticmethod
     def _command_as_list(
         ti: TaskInstance,
@@ -3276,7 +3293,7 @@ def render_templates(
             assert ti.task

         if ti.task.dag.__class__ is AttributeRemoved:
-            ti.task.dag = self.task.dag
+            ti.task.dag = self.task.dag  # type: ignore[assignment]

         # If self.task is mapped, this call replaces self.task to point to the
         # unmapped BaseOperator created by this function! This is because the
@@ -3284,7 +3301,7 @@ def render_templates(
         # able to access the unmapped task instead.
         original_task.render_template_fields(context, jinja_env)
         if isinstance(self.task, MappedOperator):
-            self.task = context["ti"].task
+            self.task = context["ti"].task  # type: ignore[assignment]

         return original_task
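A hedged sketch of what `from_runtime_ti` is for (illustrative; mirrors the
`BranchMixIn.do_branch` change later in this patch): scheduler-side code still
needs an ORM `TaskInstance`, while the execution context now carries the Task
SDK's runtime TI, so the classmethod bridges the two.

    from airflow.models.taskinstance import TaskInstance

    def resolve_db_ti(context):
        # context["ti"] is a runtime TI satisfying RuntimeTaskInstanceProtocol;
        # from_runtime_ti() rehydrates the DB-backed row for ORM-based APIs.
        return TaskInstance.from_runtime_ti(context["ti"])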
diff --git a/airflow/notifications/basenotifier.py b/airflow/notifications/basenotifier.py
index ae69f07db26a5..db79a0db4482c 100644
--- a/airflow/notifications/basenotifier.py
+++ b/airflow/notifications/basenotifier.py
@@ -29,7 +29,7 @@
     import jinja2

     from airflow import DAG
-    from airflow.utils.context import Context
+    from airflow.sdk.definitions.context import Context


 class BaseNotifier(LoggingMixin, Templater):
diff --git a/airflow/operators/branch.py b/airflow/operators/branch.py
index 81a82e9d12a08..de038c68a6873 100644
--- a/airflow/operators/branch.py
+++ b/airflow/operators/branch.py
@@ -24,10 +24,10 @@

 from airflow.models.baseoperator import BaseOperator
 from airflow.models.skipmixin import SkipMixin
+from airflow.models.taskinstance import TaskInstance

 if TYPE_CHECKING:
-    from airflow.models import TaskInstance
-    from airflow.utils.context import Context
+    from airflow.sdk.definitions.context import Context


 class BranchMixIn(SkipMixin):
@@ -36,8 +36,9 @@ class BranchMixIn(SkipMixin):
     def do_branch(self, context: Context, branches_to_execute: str | Iterable[str]) -> str | Iterable[str]:
         """Implement the handling of branching including logging."""
         self.log.info("Branch into %s", branches_to_execute)
-        branch_task_ids = self._expand_task_group_roots(context["ti"], branches_to_execute)
-        self.skip_all_except(context["ti"], branch_task_ids)
+        ti = TaskInstance.from_runtime_ti(context["ti"])
+        branch_task_ids = self._expand_task_group_roots(ti, branches_to_execute)
+        self.skip_all_except(ti, branch_task_ids)
         return branches_to_execute

     def _expand_task_group_roots(
diff --git a/airflow/operators/email.py b/airflow/operators/email.py
index e2ae26739c10f..85ac709e2bd5e 100644
--- a/airflow/operators/email.py
+++ b/airflow/operators/email.py
@@ -24,7 +24,7 @@
 from airflow.utils.email import send_email

 if TYPE_CHECKING:
-    from airflow.utils.context import Context
+    from airflow.sdk.definitions.context import Context


 class EmailOperator(BaseOperator):
diff --git a/airflow/operators/empty.py b/airflow/operators/empty.py
index fb116ee0f9772..dc906b4b3c17b 100644
--- a/airflow/operators/empty.py
+++ b/airflow/operators/empty.py
@@ -21,7 +21,7 @@
 from airflow.models.baseoperator import BaseOperator

 if TYPE_CHECKING:
-    from airflow.utils.context import Context
+    from airflow.sdk.definitions.context import Context


 class EmptyOperator(BaseOperator):
diff --git a/airflow/operators/smooth.py b/airflow/operators/smooth.py
index 927f33b1a2fe5..55b6b930d5ef2 100644
--- a/airflow/operators/smooth.py
+++ b/airflow/operators/smooth.py
@@ -22,7 +22,7 @@
 from airflow.models.baseoperator import BaseOperator

 if TYPE_CHECKING:
-    from airflow.utils.context import Context
+    from airflow.sdk.definitions.context import Context


 class SmoothOperator(BaseOperator):
diff --git a/airflow/sensors/base.py b/airflow/sensors/base.py
index 39172ce64afd7..6330dae3c47f8 100644
--- a/airflow/sensors/base.py
+++ b/airflow/sensors/base.py
@@ -49,8 +49,8 @@
 from airflow.utils.session import create_session

 if TYPE_CHECKING:
+    from airflow.sdk.definitions.context import Context
     from airflow.typing_compat import Self
-    from airflow.utils.context import Context

 # As documented in https://dev.mysql.com/doc/refman/5.7/en/datetime.html.
 _MYSQL_TIMESTAMP_MAX = datetime.datetime(2038, 1, 19, 3, 14, 7, tzinfo=timezone.utc)
diff --git a/airflow/utils/context.py b/airflow/utils/context.py
index 10cd44585019a..1f453457e4323 100644
--- a/airflow/utils/context.py
+++ b/airflow/utils/context.py
@@ -28,8 +28,8 @@
 from typing import (
     TYPE_CHECKING,
     Any,
-    SupportsIndex,
     Union,
+    cast,
 )

 import attrs
@@ -53,6 +53,7 @@
     AssetUriRef,
     BaseAssetUniqueKey,
 )
+from airflow.sdk.definitions.context import Context
 from airflow.utils.db import LazySelectSequence
 from airflow.utils.session import create_session
 from airflow.utils.types import NOTSET
@@ -65,7 +66,7 @@
     from airflow.models.baseoperator import BaseOperator

 # NOTE: Please keep this in sync with the following:
-# * Context in airflow/utils/context.pyi.
+# * Context in task_sdk/src/airflow/sdk/definitions/context.py
 # * Table in docs/apache-airflow/templates-ref.rst
 KNOWN_CONTEXT_KEYS: set[str] = {
     "conn",
@@ -359,15 +360,7 @@ class AirflowContextDeprecationWarning(RemovedInAirflow3Warning):
     """Warn for usage of deprecated context variables in a task."""


-class Context(dict[str, Any]):
-    """Jinja2 template context for task rendering."""
-
-    def __reduce_ex__(self, protocol: SupportsIndex) -> tuple[Any, ...]:
-        """Pickle the context as a dict."""
-        return dict, (list(self.items()),)
-
-
-def context_merge(context: Mapping[str, Any], *args: Any, **kwargs: Any) -> None:
+def context_merge(context: Context, *args: Any, **kwargs: Any) -> None:
     """
     Merge parameters into an existing context.

@@ -386,7 +379,7 @@ def context_merge(context: Mapping[str, Any], *args: Any, **kwargs: Any) -> None
     context.update(*args, **kwargs)


-def context_update_for_unmapped(context: Mapping[str, Any], task: BaseOperator) -> None:
+def context_update_for_unmapped(context: Context, task: BaseOperator) -> None:
     """
     Update context after task unmapping.

@@ -399,21 +392,19 @@ def context_update_for_unmapped(context: Mapping[str, Any], task: BaseOperator)
     from airflow.models.param import process_params

     context["task"] = context["ti"].task = task
-    context["params"] = process_params(context["dag"], task, context["dag_run"], suppress_exception=False)
+    context["params"] = process_params(
+        context["dag"], task, context["dag_run"].conf, suppress_exception=False
+    )


-def context_copy_partial(source: Mapping[str, Any], keys: Container[str]) -> Context:
+def context_copy_partial(source: Context, keys: Container[str]) -> Context:
     """
     Create a context by copying items under selected keys in ``source``.

-    This is implemented as a free function because the ``Context`` type is
-    "faked" as a ``TypedDict`` in ``context.pyi``, which cannot have custom
-    functions.
-
     :meta private:
     """
-    new = Context({k: v for k, v in source.items() if k in keys})
-    return new
+    new = {k: v for k, v in source.items() if k in keys}
+    return cast(Context, new)


 def context_get_outlet_events(context: Context) -> OutletEventAccessors:
diff --git a/airflow/utils/context.pyi b/airflow/utils/context.pyi
deleted file mode 100644
index 2038f30ff6ab0..0000000000000
--- a/airflow/utils/context.pyi
+++ /dev/null
@@ -1,145 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-# This stub exists to "fake" the Context class as a TypedDict to provide
-# better typehint and editor support.
-#
-# Unfortunately 'conn', 'macros', 'var.json', and 'var.value' need to be
-# annotated as Any and loose discoverability because we don't know what
-# attributes are injected at runtime, and giving them a class would trigger
-# undefined attribute errors from Mypy. Hopefully there will be a mechanism to
-# declare "these are defined, but don't error if others are accessed" someday.
-from __future__ import annotations
-
-from collections.abc import Collection, Container, Iterable, Iterator, Mapping, Sequence
-from typing import Any, TypedDict, overload
-
-from pendulum import DateTime
-from sqlalchemy.orm import Session
-
-from airflow.models.asset import AssetEvent
-from airflow.models.baseoperator import BaseOperator
-from airflow.models.dag import DAG
-from airflow.models.dagrun import DagRun
-from airflow.models.param import ParamsDict
-from airflow.models.taskinstance import TaskInstance
-from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetRef, AssetUniqueKey, BaseAssetUniqueKey
-
-KNOWN_CONTEXT_KEYS: set[str]
-
-class _VariableAccessors(TypedDict):
-    json: Any
-    value: Any
-
-class VariableAccessor:
-    def __init__(self, *, deserialize_json: bool) -> None: ...
-    def get(self, key, default: Any = ...) -> Any: ...
-
-class ConnectionAccessor:
-    def get(self, key: str, default_conn: Any = None) -> Any: ...
-
-class AssetAliasEvent:
-    source_alias_name: str
-    dest_asset_key: AssetUniqueKey
-    extra: dict[str, Any]
-    def __init__(
-        self, source_alias_name: str, dest_asset_key: AssetUniqueKey, extra: dict[str, Any]
-    ) -> None: ...
-
-class OutletEventAccessor:
-    def __init__(
-        self,
-        *,
-        key: BaseAssetUniqueKey,
-        extra: dict[str, Any],
-        asset_alias_events: list[AssetAliasEvent],
-    ) -> None: ...
-    def add(self, asset: Asset, extra: dict[str, Any] | None = None) -> None: ...
-    key: BaseAssetUniqueKey
-    extra: dict[str, Any]
-    asset_alias_events: list[AssetAliasEvent]
-
-class OutletEventAccessors(Mapping[Asset | AssetAlias, OutletEventAccessor]):
-    def __iter__(self) -> Iterator[Asset | AssetAlias]: ...
-    def __len__(self) -> int: ...
-    def __getitem__(self, key: Asset | AssetAlias) -> OutletEventAccessor: ...
-
-class InletEventsAccessor(Sequence[AssetEvent]):
-    @overload
-    def __getitem__(self, key: int) -> AssetEvent: ...
-    @overload
-    def __getitem__(self, key: slice) -> Sequence[AssetEvent]: ...
-    def __len__(self) -> int: ...
-
-class InletEventsAccessors(Mapping[Asset | AssetAlias, InletEventsAccessor]):
-    def __init__(self, inlets: list, *, session: Session) -> None: ...
-    def __iter__(self) -> Iterator[Asset | AssetAlias]: ...
-    def __len__(self) -> int: ...
-    def __getitem__(self, key: int | Asset | AssetAlias | AssetRef) -> InletEventsAccessor: ...
-
-# NOTE: Please keep this in sync with the following:
-# * KNOWN_CONTEXT_KEYS in airflow/utils/context.py
-# * Table in docs/apache-airflow/templates-ref.rst
-class Context(TypedDict, total=False):
-    conn: Any
-    dag: DAG
-    dag_run: DagRun
-    data_interval_end: DateTime
-    data_interval_start: DateTime
-    outlet_events: OutletEventAccessors
-    ds: str
-    ds_nodash: str
-    exception: BaseException | str | None
-    expanded_ti_count: int | None
-    inlets: list
-    inlet_events: InletEventsAccessors
-    logical_date: DateTime
-    macros: Any
-    map_index_template: str
-    outlets: list
-    params: ParamsDict
-    prev_data_interval_start_success: DateTime | None
-    prev_data_interval_end_success: DateTime | None
-    prev_start_date_success: DateTime | None
-    prev_end_date_success: DateTime | None
-    reason: str | None
-    run_id: str
-    task: BaseOperator
-    task_instance: TaskInstance
-    task_instance_key_str: str
-    test_mode: bool
-    templates_dict: Mapping[str, Any] | None
-    ti: TaskInstance
-    triggering_asset_events: Mapping[str, Collection[AssetEvent]]
-    ts: str
-    ts_nodash: str
-    ts_nodash_with_tz: str
-    try_number: int | None
-    var: _VariableAccessors
-
-class AirflowContextDeprecationWarning(DeprecationWarning): ...
-
-@overload
-def context_merge(context: Context, additions: Mapping[str, Any], **kwargs: Any) -> None: ...
-@overload
-def context_merge(context: Context, additions: Iterable[tuple[str, Any]], **kwargs: Any) -> None: ...
-@overload
-def context_merge(context: Context, **kwargs: Any) -> None: ...
-def context_update_for_unmapped(context: Mapping[str, Any], task: BaseOperator) -> None: ...
-def context_copy_partial(source: Context, keys: Container[str]) -> Context: ...
-def context_get_outlet_events(context: Context) -> OutletEventAccessors: ...
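With the `.pyi` stub gone, the helpers in `airflow/utils/context.py` operate
on the SDK `Context` directly. A minimal usage sketch, assuming `context` was
obtained from `get_current_context()`:

    from airflow.utils.context import context_copy_partial, context_merge

    # In-place merge, like dict.update(); used e.g. to fold templates_dict in.
    context_merge(context, templates_dict={"query": "SELECT 1"})

    # Copy selected keys into a new dict; with Context now a TypedDict the
    # result is a plain dict cast to Context rather than a dict subclass.
    trimmed = context_copy_partial(context, keys={"ti", "ts", "run_id"})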
diff --git a/airflow/utils/helpers.py b/airflow/utils/helpers.py
index 30f8dde41af61..0be76b9feebaa 100644
--- a/airflow/utils/helpers.py
+++ b/airflow/utils/helpers.py
@@ -36,7 +36,7 @@
     import jinja2

     from airflow.models.taskinstance import TaskInstance
-    from airflow.utils.context import Context
+    from airflow.sdk.definitions.context import Context

 KEY_REGEX = re.compile(r"^[\w.-]+$")
 GROUP_KEY_REGEX = re.compile(r"^[\w-]+$")
diff --git a/airflow/utils/log/file_task_handler.py b/airflow/utils/log/file_task_handler.py
index 09866de7214ed..a8abfe0ea8638 100644
--- a/airflow/utils/log/file_task_handler.py
+++ b/airflow/utils/log/file_task_handler.py
@@ -34,7 +34,7 @@
 from airflow.configuration import conf
 from airflow.exceptions import AirflowException
 from airflow.executors.executor_loader import ExecutorLoader
-from airflow.utils.context import Context
+from airflow.sdk.definitions.context import Context
 from airflow.utils.helpers import parse_template_string, render_template_to_string
 from airflow.utils.log.logging_mixin import SetContextPropagate
 from airflow.utils.log.non_caching_file_handler import NonCachingRotatingFileHandler
diff --git a/providers/src/airflow/providers/amazon/aws/transfers/google_api_to_s3.py b/providers/src/airflow/providers/amazon/aws/transfers/google_api_to_s3.py
index f1023e91b1653..a3d6bd619ce4b 100644
--- a/providers/src/airflow/providers/amazon/aws/transfers/google_api_to_s3.py
+++ b/providers/src/airflow/providers/amazon/aws/transfers/google_api_to_s3.py
@@ -30,7 +30,10 @@
 from airflow.providers.google.common.hooks.discovery_api import GoogleDiscoveryApiHook

 if TYPE_CHECKING:
-    from airflow.models import TaskInstance
+    try:
+        from airflow.sdk.definitions.protocols import RuntimeTaskInstanceProtocol
+    except ImportError:
+        from airflow.models import TaskInstance as RuntimeTaskInstanceProtocol  # type: ignore[assignment]
     from airflow.utils.context import Context


@@ -174,7 +177,7 @@ def _load_data_to_s3(self, data: dict) -> None:
             replace=self.s3_overwrite,
         )

-    def _update_google_api_endpoint_params_via_xcom(self, task_instance: TaskInstance) -> None:
+    def _update_google_api_endpoint_params_via_xcom(self, task_instance: RuntimeTaskInstanceProtocol) -> None:
         if self.google_api_endpoint_params_via_xcom:
             google_api_endpoint_params = task_instance.xcom_pull(
                 task_ids=self.google_api_endpoint_params_via_xcom_task_ids,
@@ -182,7 +185,9 @@ def _update_google_api_endpoint_params_via_xcom(self, task_instance: TaskInstanc
             )
             self.google_api_endpoint_params.update(google_api_endpoint_params)

-    def _expose_google_api_response_via_xcom(self, task_instance: TaskInstance, data: dict) -> None:
+    def _expose_google_api_response_via_xcom(
+        self, task_instance: RuntimeTaskInstanceProtocol, data: dict
+    ) -> None:
         if sys.getsizeof(data) < MAX_XCOM_SIZE:
             task_instance.xcom_push(key=self.google_api_response_via_xcom or XCOM_RETURN_KEY, value=data)
         else:
diff --git a/providers/src/airflow/providers/apache/hive/operators/hive.py b/providers/src/airflow/providers/apache/hive/operators/hive.py
index e87674a552184..c131f89ee7c60 100644
--- a/providers/src/airflow/providers/apache/hive/operators/hive.py
+++ b/providers/src/airflow/providers/apache/hive/operators/hive.py
@@ -141,13 +141,15 @@ def execute(self, context: Context) -> None:
         # set the mapred_job_name if it's not set with dag, task, execution time info
         if not self.mapred_job_name:
             ti = context["ti"]
-            if ti.logical_date is None:
+            logical_date = context["logical_date"]
+            if logical_date is None:
                 raise RuntimeError("logical_date is None")
+            hostname = ti.hostname or ""

             self.hook.mapred_job_name = self.mapred_job_name_template.format(
                 dag_id=ti.dag_id,
                 task_id=ti.task_id,
-                logical_date=ti.logical_date.isoformat(),
-                hostname=ti.hostname.split(".")[0],
+                logical_date=logical_date.isoformat(),
+                hostname=hostname.split(".")[0],
             )

         if self.hiveconf_jinja_translate:
diff --git a/providers/src/airflow/providers/cncf/kubernetes/operators/pod.py b/providers/src/airflow/providers/cncf/kubernetes/operators/pod.py
index 7adedebc69447..4326d4516f70f 100644
--- a/providers/src/airflow/providers/cncf/kubernetes/operators/pod.py
+++ b/providers/src/airflow/providers/cncf/kubernetes/operators/pod.py
@@ -481,10 +481,10 @@ def _get_ti_pod_labels(context: Context | None = None, include_try_number: bool
         map_index = ti.map_index
         if map_index >= 0:
-            labels["map_index"] = map_index
+            labels["map_index"] = str(map_index)

         if include_try_number:
-            labels.update(try_number=ti.try_number)
+            labels.update(try_number=str(ti.try_number))
         # In the case of sub dags this is just useful
         # TODO: Remove this when the minimum version of Airflow is bumped to 3.0
         if getattr(context["dag"], "parent_dag", False):
diff --git a/providers/src/airflow/providers/edge/example_dags/win_test.py b/providers/src/airflow/providers/edge/example_dags/win_test.py
index b6992ae1f0f35..a2727363d6401 100644
--- a/providers/src/airflow/providers/edge/example_dags/win_test.py
+++ b/providers/src/airflow/providers/edge/example_dags/win_test.py
@@ -45,7 +45,10 @@
 from airflow.utils.types import ArgNotSet

 if TYPE_CHECKING:
-    from airflow.models.taskinstance import TaskInstance
+    try:
+        from airflow.sdk.definitions.protocols import RuntimeTaskInstanceProtocol as TaskInstance
+    except ImportError:
+        from airflow.models import TaskInstance  # type: ignore[assignment]
     from airflow.utils.context import Context

 try:
diff --git a/providers/src/airflow/providers/google/cloud/operators/gcs.py b/providers/src/airflow/providers/google/cloud/operators/gcs.py
index 55835219dfe5a..8fc8eadeaedc6 100644
--- a/providers/src/airflow/providers/google/cloud/operators/gcs.py
+++ b/providers/src/airflow/providers/google/cloud/operators/gcs.py
@@ -787,22 +787,20 @@ def __init__(

     def execute(self, context: Context) -> list[str]:
         # Define intervals and prefixes.
-        try:
-            orig_start = context["data_interval_start"]
-            orig_end = context["data_interval_end"]
-        except KeyError:
-            orig_start = pendulum.instance(context["logical_date"])
-            next_dagrun = context["dag"].next_dagrun_info(last_automated_dagrun=None, restricted=False)
-            if next_dagrun and next_dagrun.data_interval and next_dagrun.data_interval.end:
-                orig_end = next_dagrun.data_interval.end
-            else:
-                orig_end = None
+        orig_start = context["data_interval_start"]
+        orig_end = context["data_interval_end"]
+
+        if orig_start is None or orig_end is None:
+            raise RuntimeError("`data_interval_start` & `data_interval_end` must not be None")
+
+        if not isinstance(orig_start, pendulum.DateTime):
+            orig_start = pendulum.instance(orig_start)
+
+        if not isinstance(orig_end, pendulum.DateTime):
+            orig_end = pendulum.instance(orig_end)

         timespan_start = orig_start
-        if orig_end is None:  # Only possible in Airflow before 2.2.
-            self.log.warning("No following schedule found, setting timespan end to max %s", orig_end)
-            timespan_end = pendulum.instance(datetime.datetime.max)
-        elif orig_start >= orig_end:  # Airflow 2.2 sets start == end for non-perodic schedules.
+        if orig_start >= orig_end:  # Airflow 2.2 sets start == end for non-periodic schedules.
             self.log.warning("DAG schedule not periodic, setting timespan end to max %s", orig_end)
             timespan_end = pendulum.instance(datetime.datetime.max)
         else:
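The GCS change above and the time-delta sensor change below share one pattern:
with `Context` a `total=False` TypedDict whose interval keys are
`datetime | None`, callers must narrow the value before using it. A hedged
sketch of that pattern; `_required_interval` is a hypothetical helper:

    import pendulum

    def _required_interval(context):
        start = context["data_interval_start"]
        end = context["data_interval_end"]
        if start is None or end is None:
            raise RuntimeError("data interval is not defined for this run")
        # Coerce plain datetimes to pendulum.DateTime for arithmetic/comparisons.
        return pendulum.instance(start), pendulum.instance(end)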
diff --git a/providers/src/airflow/providers/standard/operators/bash.py b/providers/src/airflow/providers/standard/operators/bash.py
index 357be4c02823d..bf006c004b2bd 100644
--- a/providers/src/airflow/providers/standard/operators/bash.py
+++ b/providers/src/airflow/providers/standard/operators/bash.py
@@ -34,7 +34,6 @@
 if TYPE_CHECKING:
     from sqlalchemy.orm import Session as SASession

-    from airflow.models.taskinstance import TaskInstance
     from airflow.utils.context import Context


@@ -198,7 +197,7 @@ def subprocess_hook(self):
     # TODO: This should be replaced with Task SDK API call
     @staticmethod
     @provide_session
-    def refresh_bash_command(ti: TaskInstance, session: SASession = NEW_SESSION) -> None:
+    def refresh_bash_command(ti, session: SASession = NEW_SESSION) -> None:
         """
         Rewrite the underlying rendered bash_command value for a task instance in the metadatabase.

@@ -211,11 +210,6 @@ def refresh_bash_command(ti: TaskInstance, session: SASession = NEW_SESSION) ->
         from airflow.models.renderedtifields import RenderedTaskInstanceFields

         """Update rendered task instance fields for cases where runtime evaluated, not templated."""
-        # Note: Need lazy import to break the partly loaded class loop
-        from airflow.models.taskinstance import TaskInstance
-
-        # If called via remote API the DAG needs to be re-loaded
-        TaskInstance.ensure_dag(ti, session=session)
         rtif = RenderedTaskInstanceFields(ti)
         RenderedTaskInstanceFields.write(rtif, session=session)

diff --git a/providers/src/airflow/providers/standard/operators/latest_only.py b/providers/src/airflow/providers/standard/operators/latest_only.py
index ae15ee017b046..a573e05c28d61 100644
--- a/providers/src/airflow/providers/standard/operators/latest_only.py
+++ b/providers/src/airflow/providers/standard/operators/latest_only.py
@@ -52,7 +52,7 @@ def choose_branch(self, context: Context) -> str | Iterable[str]:
             self.log.info("Externally triggered DAG_Run: allowing execution to proceed.")
             return list(context["task"].get_direct_relative_ids(upstream=False))

-        dag: DAG = context["dag"]
+        dag: DAG = context["dag"]  # type: ignore[assignment]
         next_info = dag.next_dagrun_info(dag.get_run_data_interval(dag_run), restricted=False)
         now = pendulum.now("UTC")

diff --git a/providers/src/airflow/providers/standard/operators/python.py b/providers/src/airflow/providers/standard/operators/python.py
index 35bc488860604..5048c4006958e 100644
--- a/providers/src/airflow/providers/standard/operators/python.py
+++ b/providers/src/airflow/providers/standard/operators/python.py
@@ -298,7 +298,8 @@ def get_tasks_to_skip():
         self.log.info("Skipping downstream tasks")
         if AIRFLOW_V_3_0_PLUS:
             self.skip(
-                dag_run=dag_run,
+                dag_id=dag_run.dag_id,
+                run_id=dag_run.run_id,
                 tasks=to_skip,
                 map_index=context["ti"].map_index,
             )
diff --git a/providers/src/airflow/providers/standard/sensors/time_delta.py b/providers/src/airflow/providers/standard/sensors/time_delta.py
index 6b09a361efadf..216a8eda40b57 100644
--- a/providers/src/airflow/providers/standard/sensors/time_delta.py
+++ b/providers/src/airflow/providers/standard/sensors/time_delta.py
@@ -17,7 +17,7 @@
 # under the License.
 from __future__ import annotations

-from datetime import timedelta
+from datetime import datetime, timedelta
 from time import sleep
 from typing import TYPE_CHECKING, Any, NoReturn

@@ -58,8 +58,12 @@ def __init__(self, *, delta, **kwargs):
         self.delta = delta

     def poke(self, context: Context):
-        target_dttm = context["data_interval_end"]
-        target_dttm += self.delta
+        data_interval_end = context["data_interval_end"]
+
+        if not isinstance(data_interval_end, datetime):
+            raise ValueError("`data_interval_end` returned non-datetime object")
+
+        target_dttm: datetime = data_interval_end + self.delta
         self.log.info("Checking if the time (%s) has come", target_dttm)
         return timezone.utcnow() > target_dttm

@@ -84,8 +88,13 @@ def __init__(self, *, end_from_trigger: bool = False, delta, **kwargs) -> None:
         self.end_from_trigger = end_from_trigger

     def execute(self, context: Context) -> bool | NoReturn:
-        target_dttm = context["data_interval_end"]
-        target_dttm += self.delta
+        data_interval_end = context["data_interval_end"]
+
+        if not isinstance(data_interval_end, datetime):
+            raise ValueError("`data_interval_end` returned non-datetime object")
+
+        target_dttm: datetime = data_interval_end + self.delta
+
         if timezone.utcnow() > target_dttm:
             # If the target datetime is in the past, return immediately
             return True
diff --git a/providers/tests/google/cloud/operators/test_gcs.py b/providers/tests/google/cloud/operators/test_gcs.py
index a92de4a7de240..a5eb7ddafe4fe 100644
--- a/providers/tests/google/cloud/operators/test_gcs.py
+++ b/providers/tests/google/cloud/operators/test_gcs.py
@@ -21,7 +21,6 @@
 from pathlib import Path
 from unittest import mock

-import pendulum
 import pytest

 from airflow.providers.common.compat.openlineage.facet import (
@@ -41,7 +40,6 @@
     GCSSynchronizeBucketsOperator,
     GCSTimeSpanFileTransformOperator,
 )
-from airflow.timetables.base import DagRunInfo, DataInterval

 TASK_ID = "test-gcs-operator"
 TEST_BUCKET = "test-bucket"
@@ -396,20 +394,12 @@ def test_execute(self, mock_hook, mock_subprocess, mock_tempdir):
         timespan_start = datetime(2015, 2, 1, 15, 16, 17, 345, tzinfo=timezone.utc)
         timespan_end = timespan_start + timedelta(hours=1)
-        mock_dag = mock.Mock()
-        mock_dag.next_dagrun_info.side_effect = [
-            DagRunInfo(
-                run_after=pendulum.instance(timespan_start),
-                data_interval=DataInterval(
-                    start=pendulum.instance(timespan_start),
-                    end=pendulum.instance(timespan_end),
-                ),
-            ),
-        ]
+        mock_ti = mock.Mock()

         context = dict(
             logical_date=timespan_start,
-            dag=mock_dag,
+            data_interval_start=timespan_start,
+            data_interval_end=timespan_end,
             ti=mock_ti,
         )

@@ -584,19 +574,12 @@ def test_get_openlineage_facets_on_complete(
         file2 = "file2"

         timespan_start = datetime(2015, 2, 1, 15, 16, 17, 345, tzinfo=timezone.utc)
-        mock_dag = mock.Mock()
-        mock_dag.next_dagrun_info.side_effect = [
-            DagRunInfo(
-                run_after=pendulum.instance(timespan_start),
-                data_interval=DataInterval(
-                    start=pendulum.instance(timespan_start),
-                    end=None,
-                ),
-            ),
-        ]
+        timespan_end = timespan_start + timedelta(hours=1)
+
         context = dict(
             logical_date=timespan_start,
-            dag=mock_dag,
+            data_interval_start=timespan_start,
+            data_interval_end=timespan_end,
             ti=mock.Mock(),
         )

diff --git a/scripts/ci/pre_commit/template_context_key_sync.py b/scripts/ci/pre_commit/template_context_key_sync.py
index a26615f9f2d06..2f6c6021b2ed1 100755
--- a/scripts/ci/pre_commit/template_context_key_sync.py
+++ b/scripts/ci/pre_commit/template_context_key_sync.py
@@ -29,7 +29,7 @@

 TASKINSTANCE_PY = ROOT_DIR.joinpath("airflow", "models", "taskinstance.py")
 CONTEXT_PY = ROOT_DIR.joinpath("airflow", "utils", "context.py")
-CONTEXT_PYI = ROOT_DIR.joinpath("airflow", "utils", "context.pyi")
+CONTEXT_HINT = ROOT_DIR.joinpath("task_sdk", "src", "airflow", "sdk", "definitions", "context.py")
 TEMPLATES_REF_RST = ROOT_DIR.joinpath("docs", "apache-airflow", "templates-ref.rst")


@@ -73,13 +73,16 @@ def _iter_template_context_keys_from_declaration() -> typing.Iterator[str]:

 def _iter_template_context_keys_from_type_hints() -> typing.Iterator[str]:
-    context_mod = ast.parse(CONTEXT_PYI.read_text("utf-8"), str(CONTEXT_PYI))
+    context_mod = ast.parse(CONTEXT_HINT.read_text("utf-8"), str(CONTEXT_HINT))
     cls_context = next(
         node
         for node in ast.iter_child_nodes(context_mod)
         if isinstance(node, ast.ClassDef) and node.name == "Context"
     )
     for stmt in cls_context.body:
+        if isinstance(stmt, ast.Expr) and isinstance(stmt.value, ast.Constant):
+            # Skip docstring
+            continue
         if not isinstance(stmt, ast.AnnAssign) or not isinstance(stmt.target, ast.Name):
             raise ValueError("key in 'Context' hint is not an annotated assignment")
         yield stmt.target.id
@@ -112,7 +115,7 @@ def _compare_keys(retn_keys: set[str], decl_keys: set[str], hint_keys: set[str],
     canonical_keys = set.union(*(s for _, s in check_candidates))

     def _check_one(identifier: str, keys: set[str]) -> int:
-        if missing := canonical_keys.difference(retn_keys):
+        if missing := canonical_keys.difference(keys):
             print("Missing template variables from", f"{identifier}:", ", ".join(sorted(missing)))
             return len(missing)

diff --git a/task_sdk/src/airflow/sdk/api/client.py b/task_sdk/src/airflow/sdk/api/client.py
index 0ef3efe95f14e..5ee270591481e 100644
--- a/task_sdk/src/airflow/sdk/api/client.py
+++ b/task_sdk/src/airflow/sdk/api/client.py
@@ -295,6 +295,7 @@ def noop_handler(request: httpx.Request) -> httpx.Response:
                 "start_date": "2021-01-01T00:00:00Z",
                 "run_type": DagRunType.MANUAL,
             },
+            "max_tries": 0,
         },
     )
     return httpx.Response(200, json={"text": "Hello, world!"})
diff --git a/task_sdk/src/airflow/sdk/api/datamodels/_generated.py b/task_sdk/src/airflow/sdk/api/datamodels/_generated.py
index ff4cc588ff564..785037b6bfbaa 100644
--- a/task_sdk/src/airflow/sdk/api/datamodels/_generated.py
+++ b/task_sdk/src/airflow/sdk/api/datamodels/_generated.py
@@ -178,7 +178,8 @@ class TaskInstance(BaseModel):
     dag_id: Annotated[str, Field(title="Dag Id")]
     run_id: Annotated[str, Field(title="Run Id")]
     try_number: Annotated[int, Field(title="Try Number")]
-    map_index: Annotated[int | None, Field(title="Map Index")] = None
+    map_index: Annotated[int, Field(title="Map Index")] = -1
+    hostname: Annotated[str | None, Field(title="Hostname")] = None


 class DagRun(BaseModel):
@@ -207,6 +208,7 @@ class TIRunContext(BaseModel):
     """

     dag_run: DagRun
+    max_tries: Annotated[int, Field(title="Max Tries")]
     variables: Annotated[list[VariableResponse] | None, Field(title="Variables")] = None
     connections: Annotated[list[ConnectionResponse] | None, Field(title="Connections")] = None
diff --git a/task_sdk/src/airflow/sdk/definitions/_internal/contextmanager.py b/task_sdk/src/airflow/sdk/definitions/_internal/contextmanager.py
index cecb1a5c6ec0b..2914799a33a55 100644
--- a/task_sdk/src/airflow/sdk/definitions/_internal/contextmanager.py
+++ b/task_sdk/src/airflow/sdk/definitions/_internal/contextmanager.py
@@ -19,13 +19,15 @@

 import sys
 from collections import deque
-from collections.abc import Mapping
 from types import ModuleType
-from typing import Any, Generic, TypeVar
+from typing import TYPE_CHECKING, Any, Generic, TypeVar

 from airflow.sdk.definitions.dag import DAG
 from airflow.sdk.definitions.taskgroup import TaskGroup

+if TYPE_CHECKING:
+    from airflow.sdk.definitions.context import Context
+
 T = TypeVar("T")

 __all__ = ["DagContext", "TaskGroupContext"]
@@ -34,10 +36,10 @@
 # It is used to push the Context dictionary when Task starts execution
 # and it is used to retrieve the current context in PythonOperator or Taskflow API via
 # the `get_current_context` function.
-_CURRENT_CONTEXT: list[Mapping[str, Any]] = []
+_CURRENT_CONTEXT: list[Context] = []


-def _get_current_context() -> Mapping[str, Any]:
+def _get_current_context() -> Context:
     if not _CURRENT_CONTEXT:
         raise RuntimeError(
             "Current context was requested but no context was found! Are you running within an Airflow task?"
diff --git a/task_sdk/src/airflow/sdk/definitions/context.py b/task_sdk/src/airflow/sdk/definitions/context.py
index 41911143a1f73..a6bbc88ef86f5 100644
--- a/task_sdk/src/airflow/sdk/definitions/context.py
+++ b/task_sdk/src/airflow/sdk/definitions/context.py
@@ -17,11 +17,65 @@
 # under the License.
 from __future__ import annotations

-from collections.abc import Mapping
-from typing import Any
+from typing import TYPE_CHECKING, Any, TypedDict

+if TYPE_CHECKING:
+    # TODO: Should we use pendulum.DateTime instead of datetime like AF 2.x?
+    from datetime import datetime

-def get_current_context() -> Mapping[str, Any]:
+    from airflow.models.operator import Operator
+    from airflow.sdk.definitions.baseoperator import BaseOperator
+    from airflow.sdk.definitions.dag import DAG
+    from airflow.sdk.definitions.protocols import DagRunProtocol, RuntimeTaskInstanceProtocol
+
+
+class Context(TypedDict, total=False):
+    """Jinja2 template context for task rendering."""
+
+    conn: Any
+    dag: DAG
+    dag_run: DagRunProtocol
+    data_interval_end: datetime | None
+    data_interval_start: datetime | None
+    # outlet_events: OutletEventAccessors
+    outlet_events: Any
+    ds: str
+    ds_nodash: str
+    expanded_ti_count: int | None
+    exception: None | str | BaseException
+    inlets: list
+    # inlet_events: InletEventsAccessors
+    inlet_events: Any
+    logical_date: datetime
+    macros: Any
+    map_index_template: str | None
+    outlets: list
+    params: dict[str, Any]
+    prev_data_interval_start_success: datetime | None
+    prev_data_interval_end_success: datetime | None
+    prev_start_date_success: datetime | None
+    prev_end_date_success: datetime | None
+    reason: str | None
+    run_id: str
+    # TODO: Remove Operator from below once we have MappedOperator to the Task SDK
+    # and once we can remove context related code from the Scheduler/models.TaskInstance
+    task: BaseOperator | Operator
+    task_instance: RuntimeTaskInstanceProtocol
+    task_instance_key_str: str
+    # `templates_dict` is only set in PythonOperator
+    templates_dict: dict[str, Any] | None
+    test_mode: bool
+    ti: RuntimeTaskInstanceProtocol
+    # triggering_asset_events: Mapping[str, Collection[AssetEvent | AssetEventPydantic]]
+    triggering_asset_events: Any
+    try_number: int | None
+    ts: str
+    ts_nodash: str
+    ts_nodash_with_tz: str
+    var: Any
+
+
+def get_current_context() -> Context:
     """
     Retrieve the execution context dictionary without altering user method's signature.
diff --git a/task_sdk/src/airflow/sdk/definitions/dag.py b/task_sdk/src/airflow/sdk/definitions/dag.py
index 90b8b74360b60..016de8d3136c0 100644
--- a/task_sdk/src/airflow/sdk/definitions/dag.py
+++ b/task_sdk/src/airflow/sdk/definitions/dag.py
@@ -56,6 +56,7 @@
 from airflow.sdk.definitions._internal.types import NOTSET
 from airflow.sdk.definitions.asset import AssetAll, BaseAsset
 from airflow.sdk.definitions.baseoperator import BaseOperator
+from airflow.sdk.definitions.context import Context
 from airflow.timetables.base import Timetable
 from airflow.timetables.simple import (
     AssetTriggeredTimetable,
@@ -63,7 +64,6 @@
     NullTimetable,
     OnceTimetable,
 )
-from airflow.utils.context import Context
 from airflow.utils.dag_cycle_tester import check_cycle
 from airflow.utils.decorators import fixup_decorator_warning_stack
 from airflow.utils.trigger_rule import TriggerRule
diff --git a/task_sdk/src/airflow/sdk/definitions/protocols.py b/task_sdk/src/airflow/sdk/definitions/protocols.py
new file mode 100644
index 0000000000000..80dba602ff135
--- /dev/null
+++ b/task_sdk/src/airflow/sdk/definitions/protocols.py
@@ -0,0 +1,67 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any, Protocol
+
+if TYPE_CHECKING:
+    from datetime import datetime
+
+    from airflow.sdk.definitions.baseoperator import BaseOperator
+
+
+class DagRunProtocol(Protocol):
+    """Minimal interface for a DAG run available during the execution."""
+
+    dag_id: str
+    run_id: str
+    logical_date: datetime
+    data_interval_start: datetime | None
+    data_interval_end: datetime | None
+    start_date: datetime
+    end_date: datetime | None
+    run_type: Any
+    conf: dict[str, Any] | None
+
+
+class RuntimeTaskInstanceProtocol(Protocol):
+    """Minimal interface for a task instance available during the execution."""
+
+    task: BaseOperator
+    task_id: str
+    dag_id: str
+    run_id: str
+    try_number: int
+    map_index: int
+    max_tries: int
+    hostname: str | None = None
+
+    def xcom_pull(
+        self,
+        task_ids: str | list[str] | None = None,
+        dag_id: str | None = None,
+        key: str = "return_value",
+        # TODO: `include_prior_dates` isn't yet supported in the SDK
+        # include_prior_dates: bool = False,
+        *,
+        map_indexes: int | list[int] | None = None,
+        default: Any = None,
+        run_id: str | None = None,
+    ) -> Any: ...
+
+    def xcom_push(self, key: str, value: Any) -> None: ...
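A hedged illustration of why these are `Protocol`s: structural typing lets
both the ORM `TaskInstance` and the SDK's `RuntimeTaskInstance` satisfy the
same annotation without either importing the other. The function below is
hypothetical.

    from __future__ import annotations

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        from airflow.sdk.definitions.protocols import RuntimeTaskInstanceProtocol


    def ti_key(ti: RuntimeTaskInstanceProtocol) -> str:
        # Any object with the right attributes type-checks here -- the DB-backed
        # TaskInstance and the SDK RuntimeTaskInstance alike; no inheritance needed.
        return f"{ti.dag_id}/{ti.run_id}/{ti.task_id}[{ti.map_index}]"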
diff --git a/task_sdk/src/airflow/sdk/execution_time/context.py b/task_sdk/src/airflow/sdk/execution_time/context.py
index 50cbcf0a99504..cdb3880bb36b3 100644
--- a/task_sdk/src/airflow/sdk/execution_time/context.py
+++ b/task_sdk/src/airflow/sdk/execution_time/context.py
@@ -17,7 +17,7 @@
 from __future__ import annotations

 import contextlib
-from collections.abc import Generator, Mapping
+from collections.abc import Generator
 from typing import TYPE_CHECKING, Any

 import structlog
@@ -28,6 +28,7 @@

 if TYPE_CHECKING:
     from airflow.sdk.definitions.connection import Connection
+    from airflow.sdk.definitions.context import Context
     from airflow.sdk.definitions.variable import Variable
     from airflow.sdk.execution_time.comms import ConnectionResult, VariableResult

@@ -163,7 +164,7 @@ def __eq__(self, other: object) -> bool:


 @contextlib.contextmanager
-def set_current_context(context: Mapping[str, Any]) -> Generator[Mapping[str, Any], None, None]:
+def set_current_context(context: Context) -> Generator[Context, None, None]:
     """
     Set the current execution context to the provided context object.

diff --git a/task_sdk/src/airflow/sdk/execution_time/task_runner.py b/task_sdk/src/airflow/sdk/execution_time/task_runner.py
index e63252efa9552..2beff84e8f8d7 100644
--- a/task_sdk/src/airflow/sdk/execution_time/task_runner.py
+++ b/task_sdk/src/airflow/sdk/execution_time/task_runner.py
@@ -50,11 +50,14 @@
     VariableAccessor,
     set_current_context,
 )
+from airflow.utils.net import get_hostname

 if TYPE_CHECKING:
     import jinja2
     from structlog.typing import FilteringBoundLogger as Logger

+    from airflow.sdk.definitions.context import Context
+

 # TODO: Move this entire class into a separate file:
 #   `airflow/sdk/execution_time/task_instance.py`
@@ -66,12 +69,15 @@ class RuntimeTaskInstance(TaskInstance):
     _ti_context_from_server: Annotated[TIRunContext | None, Field(repr=False)] = None
     """The Task Instance context from the API server, if any."""

-    def get_template_context(self):
+    max_tries: int = 0
+    """The maximum number of retries for the task."""
+
+    def get_template_context(self) -> Context:
         # TODO: Move this to `airflow.sdk.execution_time.context`
         #   once we port the entire context logic from airflow/utils/context.py ?

         # TODO: Assess if we need to it through airflow.utils.timezone.coerce_datetime()
-        context: dict[str, Any] = {
+        context: Context = {
             # From the Task Execution interface
             "dag": self.task.dag,
             "inlets": self.task.inlets,
@@ -111,7 +117,7 @@ def get_template_context(self):
             ts_nodash = logical_date.strftime("%Y%m%dT%H%M%S")
             ts_nodash_with_tz = ts.replace("-", "").replace(":", "")

-            context_from_server = {
+            context_from_server: Context = {
                 # TODO: Assess if we need to pass these through timezone.coerce_datetime
                 "dag_run": dag_run,
                 "data_interval_end": dag_run.data_interval_end,
@@ -125,11 +131,12 @@ def get_template_context(self):
                 "ts_nodash_with_tz": ts_nodash_with_tz,
             }
             context.update(context_from_server)
+            # TODO: We should use/move TypeDict from airflow.utils.context.Context

         return context

     def render_templates(
-        self, context: dict[str, Any] | None = None, jinja_env: jinja2.Environment | None = None
+        self, context: Context | None = None, jinja_env: jinja2.Environment | None = None
     ) -> BaseOperator:
         """
         Render templates in the operator fields.
@@ -316,6 +323,7 @@
         **what.ti.model_dump(exclude_unset=True),
         task=task,
         _ti_context_from_server=what.ti_context,
+        max_tries=what.ti_context.max_tries,
     )


@@ -440,6 +448,7 @@ def run(ti: RuntimeTaskInstance, log: Logger):
     try:
         # TODO: pre execute etc.
         # TODO: Get a real context object
+        ti.hostname = get_hostname()
         ti.task = ti.task.prepare_for_execution()
         context = ti.get_template_context()
         with set_current_context(context):
diff --git a/task_sdk/tests/conftest.py b/task_sdk/tests/conftest.py
index 25d0a1b0061b6..50429d91b018a 100644
--- a/task_sdk/tests/conftest.py
+++ b/task_sdk/tests/conftest.py
@@ -183,7 +183,8 @@ def _make_context(
             data_interval_end=data_interval_end,  # type: ignore
             start_date=start_date,  # type: ignore
             run_type=run_type,  # type: ignore
-        )
+        ),
+        max_tries=0,
     )

     return _make_context
diff --git a/task_sdk/tests/execution_time/conftest.py b/task_sdk/tests/execution_time/conftest.py
index 9ff2f0378959b..507d17aa3508e 100644
--- a/task_sdk/tests/execution_time/conftest.py
+++ b/task_sdk/tests/execution_time/conftest.py
@@ -83,7 +83,10 @@ def set_dag(what: StartupDetails, dag_id: str, task: BaseOperator) -> RuntimeTas
         task.dag = dag
         t = dag.task_dict[task.task_id]
         ti = RuntimeTaskInstance.model_construct(
-            **what.ti.model_dump(exclude_unset=True), task=t, _ti_context_from_server=what.ti_context
+            **what.ti.model_dump(exclude_unset=True),
+            task=t,
+            _ti_context_from_server=what.ti_context,
+            max_tries=what.ti_context.max_tries,
         )
         spy_agency.spy_on(parse, call_fake=lambda _: ti)
         return ti
diff --git a/task_sdk/tests/execution_time/test_task_runner.py b/task_sdk/tests/execution_time/test_task_runner.py
index 3dc06379cdffa..ff0cbff631772 100644
--- a/task_sdk/tests/execution_time/test_task_runner.py
+++ b/task_sdk/tests/execution_time/test_task_runner.py
@@ -89,7 +89,7 @@ def test_recv_StartupDetails(self):
             b'"ti_context":{"dag_run":{"dag_id":"c","run_id":"b","logical_date":"2024-12-01T01:00:00Z",'
             b'"data_interval_start":"2024-12-01T00:00:00Z","data_interval_end":"2024-12-01T01:00:00Z",'
             b'"start_date":"2024-12-01T01:00:00Z","end_date":null,"run_type":"manual","conf":null},'
-            b'"variables":null,"connections":null},"file": "/dev/null", "requests_fd": '
+            b'"max_tries":0,"variables":null,"connections":null},"file": "/dev/null", "requests_fd": '
             + str(w2.fileno()).encode("ascii")
             + b"}\n"
         )
@@ -772,7 +772,7 @@ def execute(self, context):
             dag_id="test_dag",
             run_id="test_run",
             task_id=task_id,
-            map_index=None,
+            map_index=-1,
         ),
     )

diff --git a/tests/api_fastapi/execution_api/routes/test_task_instances.py b/tests/api_fastapi/execution_api/routes/test_task_instances.py
index 3941835b314a7..e6da0f3a19227 100644
--- a/tests/api_fastapi/execution_api/routes/test_task_instances.py
+++ b/tests/api_fastapi/execution_api/routes/test_task_instances.py
@@ -88,6 +88,7 @@ def test_ti_run_state_to_running(self, client, session, create_task_instance, ti
                 "run_type": "manual",
                 "conf": {},
             },
+            "max_tries": 0,
             "variables": [],
             "connections": [],
         }
diff --git a/tests/dags/test_on_kill.py b/tests/dags/test_on_kill.py
index 9b9708bef7d59..93ff1f204bc0c 100644
--- a/tests/dags/test_on_kill.py
+++ b/tests/dags/test_on_kill.py
@@ -25,7 +25,7 @@
 from airflow.utils.timezone import datetime

 if TYPE_CHECKING:
-    from airflow.utils.context import Context
+    from airflow.sdk.definitions.context import Context


 class DummyWithOnKill(EmptyOperator):
diff --git a/tests/dags/test_parsing_context.py b/tests/dags/test_parsing_context.py
index ba3a3491caa3c..c901dbc7062d0 100644
--- a/tests/dags/test_parsing_context.py
+++ b/tests/dags/test_parsing_context.py
@@ -29,7 +29,7 @@
 from airflow.utils.timezone import datetime

 if TYPE_CHECKING:
-    from airflow.utils.context import Context
+    from airflow.sdk.definitions.context import Context


 class DagWithParsingContext(EmptyOperator):
diff --git a/tests/decorators/test_condition.py b/tests/decorators/test_condition.py
index 157e490e620e6..45ea6db236f5f 100644
--- a/tests/decorators/test_condition.py
+++ b/tests/decorators/test_condition.py
@@ -16,6 +16,7 @@
 # under the License.
 from __future__ import annotations

+import typing
 from typing import TYPE_CHECKING

 import pytest
@@ -26,7 +27,7 @@
 if TYPE_CHECKING:
     from airflow.models.dagrun import DagRun
     from airflow.models.taskinstance import TaskInstance
-    from airflow.utils.context import Context
+    from airflow.sdk.definitions.context import Context

 pytestmark = pytest.mark.db_test

@@ -87,6 +88,8 @@ def f(): ...

 def test_skip_if_with_other_pre_execute(dag_maker, session):
     def setup_conf(context: Context) -> None:
+        if typing.TYPE_CHECKING:
+            assert context["dag_run"].conf
         context["dag_run"].conf["some_key"] = "some_value"

     with dag_maker(session=session):
@@ -106,6 +109,8 @@ def f(): ...

 def test_run_if_with_other_pre_execute(dag_maker, session):
     def setup_conf(context: Context) -> None:
+        if typing.TYPE_CHECKING:
+            assert context["dag_run"].conf
         context["dag_run"].conf["some_key"] = "some_value"

     with dag_maker(session=session):
diff --git a/tests/lineage/test_lineage.py b/tests/lineage/test_lineage.py
index f154eb835cb82..d805d3f5ec48e 100644
--- a/tests/lineage/test_lineage.py
+++ b/tests/lineage/test_lineage.py
@@ -27,8 +27,8 @@
 from airflow.lineage.entities import File
 from airflow.models import TaskInstance as TI
 from airflow.operators.empty import EmptyOperator
+from airflow.sdk.definitions.context import Context
 from airflow.utils import timezone
-from airflow.utils.context import Context
 from airflow.utils.types import DagRunType

 from tests_common.test_utils.config import conf_vars
diff --git a/tests/models/test_baseoperatormeta.py b/tests/models/test_baseoperatormeta.py
index b16ea1e68397a..a85e1fa89b703 100644
--- a/tests/models/test_baseoperatormeta.py
+++ b/tests/models/test_baseoperatormeta.py
@@ -33,7 +33,7 @@
 from airflow.utils.state import DagRunState, State

 if TYPE_CHECKING:
-    from airflow.utils.context import Context
+    from airflow.sdk.definitions.context import Context


 class HelloWorldOperator(BaseOperator):
diff --git a/tests/models/test_mappedoperator.py b/tests/models/test_mappedoperator.py
index 1ee81cae1c832..e683a00b963cc 100644
--- a/tests/models/test_mappedoperator.py
+++ b/tests/models/test_mappedoperator.py
@@ -54,7 +54,7 @@
 pytestmark = pytest.mark.db_test

 if TYPE_CHECKING:
-    from airflow.utils.context import Context
+    from airflow.sdk.definitions.context import Context


 def test_task_mapping_with_dag():
diff --git a/tests/models/test_skipmixin.py b/tests/models/test_skipmixin.py
index 383403b9434ee..4c2e23e0ffdb4 100644
--- a/tests/models/test_skipmixin.py
+++ b/tests/models/test_skipmixin.py
@@ -65,7 +65,7 @@ def test_skip(self, mock_now, dag_maker):
                 logical_date=now,
                 state=State.FAILED,
             )
-        SkipMixin().skip(dag_run=dag_run, tasks=tasks)
+        SkipMixin().skip(dag_id=dag_run.dag_id, run_id=dag_run.run_id, tasks=tasks)

         session.query(TI).filter(
             TI.dag_id == "dag",
@@ -77,7 +77,7 @@ def test_skip(self, mock_now, dag_maker):

     def test_skip_none_tasks(self):
         session = Mock()
-        SkipMixin().skip(dag_run=None, tasks=[])
+
SkipMixin().skip(dag_id="test_dag", run_id="test_run", tasks=[]) assert not session.query.called assert not session.commit.called diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py index bb3edc53cdd28..0422ceb39949d 100644 --- a/tests/models/test_taskinstance.py +++ b/tests/models/test_taskinstance.py @@ -2127,7 +2127,7 @@ def test_overwrite_params_with_dag_run_conf(self, create_task_instance): dag_run.conf = {"override": True} ti.task.params = {"override": False} - params = process_params(ti.task.dag, ti.task, dag_run, suppress_exception=False) + params = process_params(ti.task.dag, ti.task, dag_run.conf, suppress_exception=False) assert params["override"] is True def test_overwrite_params_with_dag_run_none(self, create_task_instance): @@ -2142,7 +2142,7 @@ def test_overwrite_params_with_dag_run_conf_none(self, create_task_instance): dag_run = ti.dag_run ti.task.params = {"override": False} - params = process_params(ti.task.dag, ti.task, dag_run, suppress_exception=False) + params = process_params(ti.task.dag, ti.task, dag_run.conf, suppress_exception=False) assert params["override"] is False @pytest.mark.parametrize("use_native_obj", [True, False]) diff --git a/tests/notifications/test_basenotifier.py b/tests/notifications/test_basenotifier.py index b2e3d751a6f30..311272dd55219 100644 --- a/tests/notifications/test_basenotifier.py +++ b/tests/notifications/test_basenotifier.py @@ -30,7 +30,7 @@ if TYPE_CHECKING: - from airflow.utils.context import Context + from airflow.sdk.definitions.context import Context class MockNotifier(BaseNotifier): diff --git a/tests/sensors/test_base.py b/tests/sensors/test_base.py index 348062394fb2f..8d19b3c5615fe 100644 --- a/tests/sensors/test_base.py +++ b/tests/sensors/test_base.py @@ -64,7 +64,7 @@ pytestmark = pytest.mark.db_test if TYPE_CHECKING: - from airflow.utils.context import Context + from airflow.sdk.definitions.context import Context DEFAULT_DATE = datetime(2015, 1, 1) TEST_DAG_ID = "unit_test_dag" diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py index d806e162d0c83..cae0e53c628fd 100644 --- a/tests/serialization/test_dag_serialization.py +++ b/tests/serialization/test_dag_serialization.py @@ -101,7 +101,7 @@ ) if TYPE_CHECKING: - from airflow.utils.context import Context + from airflow.sdk.definitions.context import Context repo_root = Path(airflow.__file__).parent.parent diff --git a/tests_common/test_utils/mock_operators.py b/tests_common/test_utils/mock_operators.py index f48ff17ed93eb..94c129ab2c787 100644 --- a/tests_common/test_utils/mock_operators.py +++ b/tests_common/test_utils/mock_operators.py @@ -29,7 +29,7 @@ if TYPE_CHECKING: import jinja2 - from airflow.utils.context import Context + from airflow.sdk.definitions.context import Context class MockOperator(BaseOperator): diff --git a/tests_common/test_utils/system_tests.py b/tests_common/test_utils/system_tests.py index 6558ae2d1e4cf..9be67c06822ed 100644 --- a/tests_common/test_utils/system_tests.py +++ b/tests_common/test_utils/system_tests.py @@ -25,7 +25,7 @@ from airflow.utils.state import DagRunState if TYPE_CHECKING: - from airflow.utils.context import Context + from airflow.sdk.definitions.context import Context logger = logging.getLogger(__name__) From c16901dd1e99a702110b72d70ea41c28b058831f Mon Sep 17 00:00:00 2001 From: Kaxil Naik Date: Wed, 15 Jan 2025 00:37:18 +0530 Subject: [PATCH 2/2] fixup! 
AIP-72: Support better type-hinting for `Context` dict in SDK --- airflow/models/baseoperator.py | 3 +- .../test_kubernetes_pod_operator.py | 2 +- .../providers/arangodb/operators/arangodb.py | 6 +++- .../providers/arangodb/sensors/arangodb.py | 6 +++- .../providers/asana/operators/asana_tasks.py | 6 +++- .../providers/celery/sensors/celery_queue.py | 6 +++- .../providers/cohere/operators/embedding.py | 6 +++- .../databricks/operators/databricks_repos.py | 6 +++- .../sensors/databricks_partition.py | 6 +++- .../databricks/sensors/databricks_sql.py | 6 +++- .../providers/datadog/sensors/datadog.py | 6 +++- .../providers/dingding/operators/dingding.py | 6 +++- .../discord/operators/discord_webhook.py | 6 +++- .../providers/docker/decorators/docker.py | 7 ++++- .../providers/docker/operators/docker.py | 6 +++- .../docker/operators/docker_swarm.py | 6 +++- .../src/airflow/providers/ftp/sensors/ftp.py | 6 +++- .../providers/github/operators/github.py | 6 +++- .../providers/github/sensors/github.py | 6 +++- .../airflow/providers/grpc/operators/grpc.py | 6 +++- .../airflow/providers/http/operators/http.py | 7 ++++- .../airflow/providers/http/sensors/http.py | 6 +++- .../providers/imap/sensors/imap_attachment.py | 6 +++- .../providers/influxdb/operators/influxdb.py | 6 +++- .../providers/jenkins/sensors/jenkins.py | 6 +++- .../airflow/providers/mongo/sensors/mongo.py | 6 +++- .../mysql/transfers/presto_to_mysql.py | 6 +++- .../providers/mysql/transfers/s3_to_mysql.py | 6 +++- .../mysql/transfers/trino_to_mysql.py | 6 +++- .../mysql/transfers/vertica_to_mysql.py | 6 +++- .../providers/neo4j/operators/neo4j.py | 6 +++- .../providers/openai/operators/openai.py | 6 +++- .../opensearch/operators/opensearch.py | 6 +++- .../opsgenie/notifications/opsgenie.py | 7 ++++- .../providers/opsgenie/operators/opsgenie.py | 6 +++- .../providers/oracle/operators/oracle.py | 6 +++- .../oracle/transfers/oracle_to_oracle.py | 6 +++- .../papermill/operators/papermill.py | 6 +++- .../providers/pinecone/operators/pinecone.py | 7 +++-- .../presto/transfers/gcs_to_presto.py | 6 +++- .../providers/qdrant/operators/qdrant.py | 6 +++- .../redis/operators/redis_publish.py | 6 +++- .../providers/redis/sensors/redis_key.py | 6 +++- .../providers/redis/sensors/redis_pub_sub.py | 6 +++- .../providers/salesforce/operators/bulk.py | 6 +++- .../operators/salesforce_apex_rest.py | 6 +++- .../providers/samba/transfers/gcs_to_samba.py | 6 +++- .../segment/operators/segment_track_event.py | 6 +++- .../airflow/providers/sftp/sensors/sftp.py | 6 +++- .../providers/slack/operators/slack.py | 6 +++- .../slack/operators/slack_webhook.py | 6 +++- .../providers/slack/transfers/sql_to_slack.py | 6 +++- .../slack/transfers/sql_to_slack_webhook.py | 6 +++- .../airflow/providers/smtp/operators/smtp.py | 6 +++- .../snowflake/operators/snowflake.py | 6 +++- .../providers/standard/operators/bash.py | 6 +++- .../providers/standard/operators/datetime.py | 6 +++- .../standard/operators/generic_transfer.py | 6 +++- .../standard/operators/latest_only.py | 7 ++++- .../providers/standard/operators/python.py | 6 +++- .../standard/operators/trigger_dagrun.py | 7 ++++- .../providers/standard/operators/weekday.py | 6 +++- .../providers/standard/sensors/bash.py | 6 +++- .../providers/standard/sensors/date_time.py | 6 +++- .../standard/sensors/external_task.py | 7 ++++- .../providers/standard/sensors/filesystem.py | 6 +++- .../providers/standard/sensors/python.py | 6 +++- .../providers/standard/sensors/time.py | 6 +++- .../providers/standard/sensors/time_delta.py 
| 6 +++- .../providers/standard/sensors/weekday.py | 6 +++- .../providers/tableau/operators/tableau.py | 6 +++- .../providers/tableau/sensors/tableau.py | 6 +++- .../providers/telegram/operators/telegram.py | 6 +++- .../providers/teradata/operators/teradata.py | 6 +++- .../operators/teradata_compute_cluster.py | 12 ++++++-- .../teradata/transfers/s3_to_teradata.py | 6 +++- .../transfers/teradata_to_teradata.py | 6 +++- .../providers/trino/transfers/gcs_to_trino.py | 6 +++- .../providers/weaviate/operators/weaviate.py | 6 +++- .../src/airflow/providers/yandex/links/yq.py | 7 ++++- .../providers/yandex/operators/dataproc.py | 6 +++- .../airflow/providers/yandex/operators/yq.py | 6 +++- .../tests/amazon/aws/sensors/test_emr_base.py | 6 +++- .../tests/microsoft/azure/hooks/test_asb.py | 7 ++++- .../microsoft/azure/operators/test_asb.py | 7 ++++- .../operators/test_container_instances.py | 7 ++++- .../microsoft/azure/operators/test_msgraph.py | 6 +++- providers/tests/microsoft/conftest.py | 7 ++++- providers/tests/mongo/sensors/test_mongo.py | 7 ++++- .../tests/openai/operators/test_openai.py | 7 ++++- .../openlineage/extractors/test_manager.py | 6 +++- .../tests/standard/operators/test_python.py | 28 ++++++++++++++++--- .../tests/system/openlineage/operator.py | 6 +++- 93 files changed, 495 insertions(+), 98 deletions(-) diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index d39586c8a19f3..88a656eb7b1ed 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -80,6 +80,7 @@ # Keeping this file at all is a temp thing as we migrate the repo to the task sdk as the base, but to keep # main working and useful for others to develop against we use the TaskSDK here but keep this file around +from airflow.sdk.definitions.context import Context from airflow.sdk.definitions.dag import DAG, BaseOperator as TaskSDKBaseOperator from airflow.sdk.definitions.edges import EdgeModifier as TaskSDKEdgeModifier from airflow.serialization.enums import DagAttributeTypes @@ -89,7 +90,7 @@ from airflow.ti_deps.deps.prev_dagrun_dep import PrevDagrunDep from airflow.ti_deps.deps.trigger_rule_dep import TriggerRuleDep from airflow.utils import timezone -from airflow.utils.context import Context, context_get_outlet_events +from airflow.utils.context import context_get_outlet_events from airflow.utils.edgemodifier import EdgeModifier from airflow.utils.operator_helpers import ExecutionCallableRunner from airflow.utils.operator_resources import Resources diff --git a/kubernetes_tests/test_kubernetes_pod_operator.py b/kubernetes_tests/test_kubernetes_pod_operator.py index e4f0fd0077b20..18570ec9fc9c2 100644 --- a/kubernetes_tests/test_kubernetes_pod_operator.py +++ b/kubernetes_tests/test_kubernetes_pod_operator.py @@ -39,8 +39,8 @@ from airflow.providers.cncf.kubernetes.hooks.kubernetes import KubernetesHook from airflow.providers.cncf.kubernetes.operators.pod import KubernetesPodOperator from airflow.providers.cncf.kubernetes.utils.pod_manager import OnFinishAction, PodManager +from airflow.sdk.definitions.context import Context from airflow.utils import timezone -from airflow.utils.context import Context from airflow.utils.types import DagRunType from airflow.version import version as airflow_version from kubernetes_tests.test_base import BaseK8STest, StringContainingId diff --git a/providers/src/airflow/providers/arangodb/operators/arangodb.py b/providers/src/airflow/providers/arangodb/operators/arangodb.py index cb8257495e431..6514f756bd208 100644 --- 
a/providers/src/airflow/providers/arangodb/operators/arangodb.py +++ b/providers/src/airflow/providers/arangodb/operators/arangodb.py @@ -25,7 +25,11 @@ from airflow.providers.arangodb.hooks.arangodb import ArangoDBHook if TYPE_CHECKING: - from airflow.utils.context import Context + try: + from airflow.sdk.definitions.context import Context + except ImportError: + # TODO: Remove once provider drops support for Airflow 2 + from airflow.utils.context import Context class AQLOperator(BaseOperator): diff --git a/providers/src/airflow/providers/arangodb/sensors/arangodb.py b/providers/src/airflow/providers/arangodb/sensors/arangodb.py index aaa02ddc6bd43..f5be45650e2fa 100644 --- a/providers/src/airflow/providers/arangodb/sensors/arangodb.py +++ b/providers/src/airflow/providers/arangodb/sensors/arangodb.py @@ -24,7 +24,11 @@ from airflow.sensors.base import BaseSensorOperator if TYPE_CHECKING: - from airflow.utils.context import Context + try: + from airflow.sdk.definitions.context import Context + except ImportError: + # TODO: Remove once provider drops support for Airflow 2 + from airflow.utils.context import Context class AQLSensor(BaseSensorOperator): diff --git a/providers/src/airflow/providers/asana/operators/asana_tasks.py b/providers/src/airflow/providers/asana/operators/asana_tasks.py index cb3ec87eea284..39ba24002cc4b 100644 --- a/providers/src/airflow/providers/asana/operators/asana_tasks.py +++ b/providers/src/airflow/providers/asana/operators/asana_tasks.py @@ -23,7 +23,11 @@ from airflow.providers.asana.hooks.asana import AsanaHook if TYPE_CHECKING: - from airflow.utils.context import Context + try: + from airflow.sdk.definitions.context import Context + except ImportError: + # TODO: Remove once provider drops support for Airflow 2 + from airflow.utils.context import Context class AsanaCreateTaskOperator(BaseOperator): diff --git a/providers/src/airflow/providers/celery/sensors/celery_queue.py b/providers/src/airflow/providers/celery/sensors/celery_queue.py index a985c3006ab03..c8c3b3d131162 100644 --- a/providers/src/airflow/providers/celery/sensors/celery_queue.py +++ b/providers/src/airflow/providers/celery/sensors/celery_queue.py @@ -24,7 +24,11 @@ from airflow.sensors.base import BaseSensorOperator if TYPE_CHECKING: - from airflow.utils.context import Context + try: + from airflow.sdk.definitions.context import Context + except ImportError: + # TODO: Remove once provider drops support for Airflow 2 + from airflow.utils.context import Context class CeleryQueueSensor(BaseSensorOperator): diff --git a/providers/src/airflow/providers/cohere/operators/embedding.py b/providers/src/airflow/providers/cohere/operators/embedding.py index c5de22b9b582f..b06f13ab02194 100644 --- a/providers/src/airflow/providers/cohere/operators/embedding.py +++ b/providers/src/airflow/providers/cohere/operators/embedding.py @@ -28,7 +28,11 @@ from cohere.core.request_options import RequestOptions from cohere.types import EmbedByTypeResponseEmbeddings - from airflow.utils.context import Context + try: + from airflow.sdk.definitions.context import Context + except ImportError: + # TODO: Remove once provider drops support for Airflow 2 + from airflow.utils.context import Context class CohereEmbeddingOperator(BaseOperator): diff --git a/providers/src/airflow/providers/databricks/operators/databricks_repos.py b/providers/src/airflow/providers/databricks/operators/databricks_repos.py index 78ccffe97e9f3..75e6f9f8f9f6e 100644 --- a/providers/src/airflow/providers/databricks/operators/databricks_repos.py +++ 
b/providers/src/airflow/providers/databricks/operators/databricks_repos.py @@ -30,7 +30,11 @@ from airflow.providers.databricks.hooks.databricks import DatabricksHook if TYPE_CHECKING: - from airflow.utils.context import Context + try: + from airflow.sdk.definitions.context import Context + except ImportError: + # TODO: Remove once provider drops support for Airflow 2 + from airflow.utils.context import Context class DatabricksReposCreateOperator(BaseOperator): diff --git a/providers/src/airflow/providers/databricks/sensors/databricks_partition.py b/providers/src/airflow/providers/databricks/sensors/databricks_partition.py index 3577e26a813d0..df9f8cf3fe1d1 100644 --- a/providers/src/airflow/providers/databricks/sensors/databricks_partition.py +++ b/providers/src/airflow/providers/databricks/sensors/databricks_partition.py @@ -33,7 +33,11 @@ from airflow.sensors.base import BaseSensorOperator if TYPE_CHECKING: - from airflow.utils.context import Context + try: + from airflow.sdk.definitions.context import Context + except ImportError: + # TODO: Remove once provider drops support for Airflow 2 + from airflow.utils.context import Context class DatabricksPartitionSensor(BaseSensorOperator): diff --git a/providers/src/airflow/providers/databricks/sensors/databricks_sql.py b/providers/src/airflow/providers/databricks/sensors/databricks_sql.py index bb78d5186f65f..0d0d1871b270f 100644 --- a/providers/src/airflow/providers/databricks/sensors/databricks_sql.py +++ b/providers/src/airflow/providers/databricks/sensors/databricks_sql.py @@ -30,7 +30,11 @@ from airflow.sensors.base import BaseSensorOperator if TYPE_CHECKING: - from airflow.utils.context import Context + try: + from airflow.sdk.definitions.context import Context + except ImportError: + # TODO: Remove once provider drops support for Airflow 2 + from airflow.utils.context import Context class DatabricksSqlSensor(BaseSensorOperator): diff --git a/providers/src/airflow/providers/datadog/sensors/datadog.py b/providers/src/airflow/providers/datadog/sensors/datadog.py index 0eb4d4fb567d5..bf74e5fd86481 100644 --- a/providers/src/airflow/providers/datadog/sensors/datadog.py +++ b/providers/src/airflow/providers/datadog/sensors/datadog.py @@ -26,7 +26,11 @@ from airflow.sensors.base import BaseSensorOperator if TYPE_CHECKING: - from airflow.utils.context import Context + try: + from airflow.sdk.definitions.context import Context + except ImportError: + # TODO: Remove once provider drops support for Airflow 2 + from airflow.utils.context import Context class DatadogSensor(BaseSensorOperator): diff --git a/providers/src/airflow/providers/dingding/operators/dingding.py b/providers/src/airflow/providers/dingding/operators/dingding.py index 254f76fde5210..3f092c83a119f 100644 --- a/providers/src/airflow/providers/dingding/operators/dingding.py +++ b/providers/src/airflow/providers/dingding/operators/dingding.py @@ -24,7 +24,11 @@ from airflow.providers.dingding.hooks.dingding import DingdingHook if TYPE_CHECKING: - from airflow.utils.context import Context + try: + from airflow.sdk.definitions.context import Context + except ImportError: + # TODO: Remove once provider drops support for Airflow 2 + from airflow.utils.context import Context class DingdingOperator(BaseOperator): diff --git a/providers/src/airflow/providers/discord/operators/discord_webhook.py b/providers/src/airflow/providers/discord/operators/discord_webhook.py index 027faca6e376f..0d87ebc4a0b49 100644 --- a/providers/src/airflow/providers/discord/operators/discord_webhook.py +++ 
b/providers/src/airflow/providers/discord/operators/discord_webhook.py @@ -25,7 +25,11 @@ from airflow.providers.http.operators.http import HttpOperator if TYPE_CHECKING: - from airflow.utils.context import Context + try: + from airflow.sdk.definitions.context import Context + except ImportError: + # TODO: Remove once provider drops support for Airflow 2 + from airflow.utils.context import Context class DiscordWebhookOperator(HttpOperator): diff --git a/providers/src/airflow/providers/docker/decorators/docker.py b/providers/src/airflow/providers/docker/decorators/docker.py index 77355ff03b2ca..69ef22e91bb7b 100644 --- a/providers/src/airflow/providers/docker/decorators/docker.py +++ b/providers/src/airflow/providers/docker/decorators/docker.py @@ -31,7 +31,12 @@ from typing import Literal from airflow.decorators.base import TaskDecorator - from airflow.utils.context import Context + + try: + from airflow.sdk.definitions.context import Context + except ImportError: + # TODO: Remove once provider drops support for Airflow 2 + from airflow.utils.context import Context Serializer = Literal["pickle", "dill", "cloudpickle"] diff --git a/providers/src/airflow/providers/docker/operators/docker.py b/providers/src/airflow/providers/docker/operators/docker.py index a7e8375c91e67..8bdad151e258f 100644 --- a/providers/src/airflow/providers/docker/operators/docker.py +++ b/providers/src/airflow/providers/docker/operators/docker.py @@ -48,7 +48,11 @@ from docker import APIClient from docker.types import DeviceRequest - from airflow.utils.context import Context + try: + from airflow.sdk.definitions.context import Context + except ImportError: + # TODO: Remove once provider drops support for Airflow 2 + from airflow.utils.context import Context def stringify(line: str | bytes): diff --git a/providers/src/airflow/providers/docker/operators/docker_swarm.py b/providers/src/airflow/providers/docker/operators/docker_swarm.py index 6b2d99fb034c8..50df76ae12d99 100644 --- a/providers/src/airflow/providers/docker/operators/docker_swarm.py +++ b/providers/src/airflow/providers/docker/operators/docker_swarm.py @@ -32,7 +32,11 @@ from airflow.utils.strings import get_random_string if TYPE_CHECKING: - from airflow.utils.context import Context + try: + from airflow.sdk.definitions.context import Context + except ImportError: + # TODO: Remove once provider drops support for Airflow 2 + from airflow.utils.context import Context class DockerSwarmOperator(DockerOperator): diff --git a/providers/src/airflow/providers/ftp/sensors/ftp.py b/providers/src/airflow/providers/ftp/sensors/ftp.py index 9d384c889c7ba..7b54e925f923a 100644 --- a/providers/src/airflow/providers/ftp/sensors/ftp.py +++ b/providers/src/airflow/providers/ftp/sensors/ftp.py @@ -26,7 +26,11 @@ from airflow.sensors.base import BaseSensorOperator if TYPE_CHECKING: - from airflow.utils.context import Context + try: + from airflow.sdk.definitions.context import Context + except ImportError: + # TODO: Remove once provider drops support for Airflow 2 + from airflow.utils.context import Context class FTPSensor(BaseSensorOperator): diff --git a/providers/src/airflow/providers/github/operators/github.py b/providers/src/airflow/providers/github/operators/github.py index 1aca2625f3993..82996d3ecedb4 100644 --- a/providers/src/airflow/providers/github/operators/github.py +++ b/providers/src/airflow/providers/github/operators/github.py @@ -26,7 +26,11 @@ from airflow.providers.github.hooks.github import GithubHook if TYPE_CHECKING: - from airflow.utils.context import 
Context + try: + from airflow.sdk.definitions.context import Context + except ImportError: + # TODO: Remove once provider drops support for Airflow 2 + from airflow.utils.context import Context class GithubOperator(BaseOperator): diff --git a/providers/src/airflow/providers/github/sensors/github.py b/providers/src/airflow/providers/github/sensors/github.py index f742121b05827..cacaef9e32fb8 100644 --- a/providers/src/airflow/providers/github/sensors/github.py +++ b/providers/src/airflow/providers/github/sensors/github.py @@ -26,7 +26,11 @@ from airflow.sensors.base import BaseSensorOperator if TYPE_CHECKING: - from airflow.utils.context import Context + try: + from airflow.sdk.definitions.context import Context + except ImportError: + # TODO: Remove once provider drops support for Airflow 2 + from airflow.utils.context import Context class GithubSensor(BaseSensorOperator): diff --git a/providers/src/airflow/providers/grpc/operators/grpc.py b/providers/src/airflow/providers/grpc/operators/grpc.py index 261a4286a05f3..020852ccc1730 100644 --- a/providers/src/airflow/providers/grpc/operators/grpc.py +++ b/providers/src/airflow/providers/grpc/operators/grpc.py @@ -24,7 +24,11 @@ from airflow.providers.grpc.hooks.grpc import GrpcHook if TYPE_CHECKING: - from airflow.utils.context import Context + try: + from airflow.sdk.definitions.context import Context + except ImportError: + # TODO: Remove once provider drops support for Airflow 2 + from airflow.utils.context import Context class GrpcOperator(BaseOperator): diff --git a/providers/src/airflow/providers/http/operators/http.py b/providers/src/airflow/providers/http/operators/http.py index e6585f8a73226..b4b7ce7e012c1 100644 --- a/providers/src/airflow/providers/http/operators/http.py +++ b/providers/src/airflow/providers/http/operators/http.py @@ -35,7 +35,12 @@ from requests.auth import AuthBase from airflow.providers.http.hooks.http import HttpHook - from airflow.utils.context import Context + + try: + from airflow.sdk.definitions.context import Context + except ImportError: + # TODO: Remove once provider drops support for Airflow 2 + from airflow.utils.context import Context class HttpOperator(BaseOperator): diff --git a/providers/src/airflow/providers/http/sensors/http.py b/providers/src/airflow/providers/http/sensors/http.py index 23eb033a04faf..822e3937c438d 100644 --- a/providers/src/airflow/providers/http/sensors/http.py +++ b/providers/src/airflow/providers/http/sensors/http.py @@ -28,7 +28,11 @@ from airflow.sensors.base import BaseSensorOperator if TYPE_CHECKING: - from airflow.utils.context import Context + try: + from airflow.sdk.definitions.context import Context + except ImportError: + # TODO: Remove once provider drops support for Airflow 2 + from airflow.utils.context import Context class HttpSensor(BaseSensorOperator): diff --git a/providers/src/airflow/providers/imap/sensors/imap_attachment.py b/providers/src/airflow/providers/imap/sensors/imap_attachment.py index 2109cadcd3060..6a802878323fa 100644 --- a/providers/src/airflow/providers/imap/sensors/imap_attachment.py +++ b/providers/src/airflow/providers/imap/sensors/imap_attachment.py @@ -26,7 +26,11 @@ from airflow.sensors.base import BaseSensorOperator if TYPE_CHECKING: - from airflow.utils.context import Context + try: + from airflow.sdk.definitions.context import Context + except ImportError: + # TODO: Remove once provider drops support for Airflow 2 + from airflow.utils.context import Context class ImapAttachmentSensor(BaseSensorOperator): diff --git 
a/providers/src/airflow/providers/influxdb/operators/influxdb.py b/providers/src/airflow/providers/influxdb/operators/influxdb.py index ba91ea898c53e..6c3cba0b14394 100644 --- a/providers/src/airflow/providers/influxdb/operators/influxdb.py +++ b/providers/src/airflow/providers/influxdb/operators/influxdb.py @@ -24,7 +24,11 @@ from airflow.providers.influxdb.hooks.influxdb import InfluxDBHook if TYPE_CHECKING: - from airflow.utils.context import Context + try: + from airflow.sdk.definitions.context import Context + except ImportError: + # TODO: Remove once provider drops support for Airflow 2 + from airflow.utils.context import Context class InfluxDBOperator(BaseOperator): diff --git a/providers/src/airflow/providers/jenkins/sensors/jenkins.py b/providers/src/airflow/providers/jenkins/sensors/jenkins.py index 75585d66bcde7..29e69b4047afc 100644 --- a/providers/src/airflow/providers/jenkins/sensors/jenkins.py +++ b/providers/src/airflow/providers/jenkins/sensors/jenkins.py @@ -25,7 +25,11 @@ from airflow.sensors.base import BaseSensorOperator if TYPE_CHECKING: - from airflow.utils.context import Context + try: + from airflow.sdk.definitions.context import Context + except ImportError: + # TODO: Remove once provider drops support for Airflow 2 + from airflow.utils.context import Context class JenkinsBuildSensor(BaseSensorOperator): diff --git a/providers/src/airflow/providers/mongo/sensors/mongo.py b/providers/src/airflow/providers/mongo/sensors/mongo.py index edf420f479627..d4505157444d1 100644 --- a/providers/src/airflow/providers/mongo/sensors/mongo.py +++ b/providers/src/airflow/providers/mongo/sensors/mongo.py @@ -24,7 +24,11 @@ from airflow.sensors.base import BaseSensorOperator if TYPE_CHECKING: - from airflow.utils.context import Context + try: + from airflow.sdk.definitions.context import Context + except ImportError: + # TODO: Remove once provider drops support for Airflow 2 + from airflow.utils.context import Context class MongoSensor(BaseSensorOperator): diff --git a/providers/src/airflow/providers/mysql/transfers/presto_to_mysql.py b/providers/src/airflow/providers/mysql/transfers/presto_to_mysql.py index 3849de6f7efb5..0563cf4a9bb65 100644 --- a/providers/src/airflow/providers/mysql/transfers/presto_to_mysql.py +++ b/providers/src/airflow/providers/mysql/transfers/presto_to_mysql.py @@ -25,7 +25,11 @@ from airflow.providers.presto.hooks.presto import PrestoHook if TYPE_CHECKING: - from airflow.utils.context import Context + try: + from airflow.sdk.definitions.context import Context + except ImportError: + # TODO: Remove once provider drops support for Airflow 2 + from airflow.utils.context import Context class PrestoToMySqlOperator(BaseOperator): diff --git a/providers/src/airflow/providers/mysql/transfers/s3_to_mysql.py b/providers/src/airflow/providers/mysql/transfers/s3_to_mysql.py index 930528022c36c..eba18dec251e2 100644 --- a/providers/src/airflow/providers/mysql/transfers/s3_to_mysql.py +++ b/providers/src/airflow/providers/mysql/transfers/s3_to_mysql.py @@ -25,7 +25,11 @@ from airflow.providers.mysql.hooks.mysql import MySqlHook if TYPE_CHECKING: - from airflow.utils.context import Context + try: + from airflow.sdk.definitions.context import Context + except ImportError: + # TODO: Remove once provider drops support for Airflow 2 + from airflow.utils.context import Context class S3ToMySqlOperator(BaseOperator): diff --git a/providers/src/airflow/providers/mysql/transfers/trino_to_mysql.py b/providers/src/airflow/providers/mysql/transfers/trino_to_mysql.py index 
77347db1bb3bc..e36d28390c55e 100644 --- a/providers/src/airflow/providers/mysql/transfers/trino_to_mysql.py +++ b/providers/src/airflow/providers/mysql/transfers/trino_to_mysql.py @@ -25,7 +25,11 @@ from airflow.providers.trino.hooks.trino import TrinoHook if TYPE_CHECKING: - from airflow.utils.context import Context + try: + from airflow.sdk.definitions.context import Context + except ImportError: + # TODO: Remove once provider drops support for Airflow 2 + from airflow.utils.context import Context class TrinoToMySqlOperator(BaseOperator): diff --git a/providers/src/airflow/providers/mysql/transfers/vertica_to_mysql.py b/providers/src/airflow/providers/mysql/transfers/vertica_to_mysql.py index c871ce0dafb75..7fbecc4f9ab1c 100644 --- a/providers/src/airflow/providers/mysql/transfers/vertica_to_mysql.py +++ b/providers/src/airflow/providers/mysql/transfers/vertica_to_mysql.py @@ -38,7 +38,11 @@ from airflow.providers.vertica.hooks.vertica import VerticaHook if TYPE_CHECKING: - from airflow.utils.context import Context + try: + from airflow.sdk.definitions.context import Context + except ImportError: + # TODO: Remove once provider drops support for Airflow 2 + from airflow.utils.context import Context class VerticaToMySqlOperator(BaseOperator): diff --git a/providers/src/airflow/providers/neo4j/operators/neo4j.py b/providers/src/airflow/providers/neo4j/operators/neo4j.py index 7740dfbe2669e..7211b8bab5936 100644 --- a/providers/src/airflow/providers/neo4j/operators/neo4j.py +++ b/providers/src/airflow/providers/neo4j/operators/neo4j.py @@ -24,7 +24,11 @@ from airflow.providers.neo4j.hooks.neo4j import Neo4jHook if TYPE_CHECKING: - from airflow.utils.context import Context + try: + from airflow.sdk.definitions.context import Context + except ImportError: + # TODO: Remove once provider drops support for Airflow 2 + from airflow.utils.context import Context class Neo4jOperator(BaseOperator): diff --git a/providers/src/airflow/providers/openai/operators/openai.py b/providers/src/airflow/providers/openai/operators/openai.py index d38caf8e2af88..63ca8380c8888 100644 --- a/providers/src/airflow/providers/openai/operators/openai.py +++ b/providers/src/airflow/providers/openai/operators/openai.py @@ -29,7 +29,11 @@ from airflow.providers.openai.triggers.openai import OpenAIBatchTrigger if TYPE_CHECKING: - from airflow.utils.context import Context + try: + from airflow.sdk.definitions.context import Context + except ImportError: + # TODO: Remove once provider drops support for Airflow 2 + from airflow.utils.context import Context class OpenAIEmbeddingOperator(BaseOperator): diff --git a/providers/src/airflow/providers/opensearch/operators/opensearch.py b/providers/src/airflow/providers/opensearch/operators/opensearch.py index 4351f2dc1bbf1..9174e03b36936 100644 --- a/providers/src/airflow/providers/opensearch/operators/opensearch.py +++ b/providers/src/airflow/providers/opensearch/operators/opensearch.py @@ -31,7 +31,11 @@ if TYPE_CHECKING: from opensearchpy import Connection as OpenSearchConnectionClass - from airflow.utils.context import Context + try: + from airflow.sdk.definitions.context import Context + except ImportError: + # TODO: Remove once provider drops support for Airflow 2 + from airflow.utils.context import Context class OpenSearchQueryOperator(BaseOperator): diff --git a/providers/src/airflow/providers/opsgenie/notifications/opsgenie.py b/providers/src/airflow/providers/opsgenie/notifications/opsgenie.py index 6d29210364b58..f87e2f9ae51b7 100644 --- 
a/providers/src/airflow/providers/opsgenie/notifications/opsgenie.py +++ b/providers/src/airflow/providers/opsgenie/notifications/opsgenie.py @@ -26,7 +26,12 @@ if TYPE_CHECKING: from airflow.providers.opsgenie.typing.opsgenie import CreateAlertPayload - from airflow.utils.context import Context + + try: + from airflow.sdk.definitions.context import Context + except ImportError: + # TODO: Remove once provider drops support for Airflow 2 + from airflow.utils.context import Context class OpsgenieNotifier(BaseNotifier): diff --git a/providers/src/airflow/providers/opsgenie/operators/opsgenie.py b/providers/src/airflow/providers/opsgenie/operators/opsgenie.py index 43d161b15206c..1ee2f17831787 100644 --- a/providers/src/airflow/providers/opsgenie/operators/opsgenie.py +++ b/providers/src/airflow/providers/opsgenie/operators/opsgenie.py @@ -24,7 +24,11 @@ from airflow.providers.opsgenie.hooks.opsgenie import OpsgenieAlertHook if TYPE_CHECKING: - from airflow.utils.context import Context + try: + from airflow.sdk.definitions.context import Context + except ImportError: + # TODO: Remove once provider drops support for Airflow 2 + from airflow.utils.context import Context class OpsgenieCreateAlertOperator(BaseOperator): diff --git a/providers/src/airflow/providers/oracle/operators/oracle.py b/providers/src/airflow/providers/oracle/operators/oracle.py index 9e1247262ffaf..0eaf32ae6816f 100644 --- a/providers/src/airflow/providers/oracle/operators/oracle.py +++ b/providers/src/airflow/providers/oracle/operators/oracle.py @@ -27,7 +27,11 @@ from airflow.providers.oracle.hooks.oracle import OracleHook if TYPE_CHECKING: - from airflow.utils.context import Context + try: + from airflow.sdk.definitions.context import Context + except ImportError: + # TODO: Remove once provider drops support for Airflow 2 + from airflow.utils.context import Context class OracleStoredProcedureOperator(BaseOperator): diff --git a/providers/src/airflow/providers/oracle/transfers/oracle_to_oracle.py b/providers/src/airflow/providers/oracle/transfers/oracle_to_oracle.py index f97fda2bb3d9e..6e10383609acf 100644 --- a/providers/src/airflow/providers/oracle/transfers/oracle_to_oracle.py +++ b/providers/src/airflow/providers/oracle/transfers/oracle_to_oracle.py @@ -24,7 +24,11 @@ from airflow.providers.oracle.hooks.oracle import OracleHook if TYPE_CHECKING: - from airflow.utils.context import Context + try: + from airflow.sdk.definitions.context import Context + except ImportError: + # TODO: Remove once provider drops support for Airflow 2 + from airflow.utils.context import Context class OracleToOracleOperator(BaseOperator): diff --git a/providers/src/airflow/providers/papermill/operators/papermill.py b/providers/src/airflow/providers/papermill/operators/papermill.py index acc4ccd9b9e0e..b07c83d048f9b 100644 --- a/providers/src/airflow/providers/papermill/operators/papermill.py +++ b/providers/src/airflow/providers/papermill/operators/papermill.py @@ -29,7 +29,11 @@ from airflow.providers.papermill.hooks.kernel import REMOTE_KERNEL_ENGINE, KernelHook if TYPE_CHECKING: - from airflow.utils.context import Context + try: + from airflow.sdk.definitions.context import Context + except ImportError: + # TODO: Remove once provider drops support for Airflow 2 + from airflow.utils.context import Context @attr.s(auto_attribs=True) diff --git a/providers/src/airflow/providers/pinecone/operators/pinecone.py b/providers/src/airflow/providers/pinecone/operators/pinecone.py index 040bef46227e1..b2d363f139b73 100644 --- 
a/providers/src/airflow/providers/pinecone/operators/pinecone.py +++ b/providers/src/airflow/providers/pinecone/operators/pinecone.py @@ -23,12 +23,15 @@ from airflow.models import BaseOperator from airflow.providers.pinecone.hooks.pinecone import PineconeHook -from airflow.utils.context import Context if TYPE_CHECKING: from pinecone import Vector - from airflow.utils.context import Context + try: + from airflow.sdk.definitions.context import Context + except ImportError: + # TODO: Remove once provider drops support for Airflow 2 + from airflow.utils.context import Context class PineconeIngestOperator(BaseOperator): diff --git a/providers/src/airflow/providers/presto/transfers/gcs_to_presto.py b/providers/src/airflow/providers/presto/transfers/gcs_to_presto.py index 876e923c7da1e..d64ecbfbee170 100644 --- a/providers/src/airflow/providers/presto/transfers/gcs_to_presto.py +++ b/providers/src/airflow/providers/presto/transfers/gcs_to_presto.py @@ -30,7 +30,11 @@ from airflow.providers.presto.hooks.presto import PrestoHook if TYPE_CHECKING: - from airflow.utils.context import Context + try: + from airflow.sdk.definitions.context import Context + except ImportError: + # TODO: Remove once provider drops support for Airflow 2 + from airflow.utils.context import Context class GCSToPrestoOperator(BaseOperator): diff --git a/providers/src/airflow/providers/qdrant/operators/qdrant.py b/providers/src/airflow/providers/qdrant/operators/qdrant.py index f747a5809d84d..e22171e9aab84 100644 --- a/providers/src/airflow/providers/qdrant/operators/qdrant.py +++ b/providers/src/airflow/providers/qdrant/operators/qdrant.py @@ -27,7 +27,11 @@ if TYPE_CHECKING: from qdrant_client.models import VectorStruct - from airflow.utils.context import Context + try: + from airflow.sdk.definitions.context import Context + except ImportError: + # TODO: Remove once provider drops support for Airflow 2 + from airflow.utils.context import Context class QdrantIngestOperator(BaseOperator): diff --git a/providers/src/airflow/providers/redis/operators/redis_publish.py b/providers/src/airflow/providers/redis/operators/redis_publish.py index eebaf3e1102df..fef0ff9bdd73d 100644 --- a/providers/src/airflow/providers/redis/operators/redis_publish.py +++ b/providers/src/airflow/providers/redis/operators/redis_publish.py @@ -24,7 +24,11 @@ from airflow.providers.redis.hooks.redis import RedisHook if TYPE_CHECKING: - from airflow.utils.context import Context + try: + from airflow.sdk.definitions.context import Context + except ImportError: + # TODO: Remove once provider drops support for Airflow 2 + from airflow.utils.context import Context class RedisPublishOperator(BaseOperator): diff --git a/providers/src/airflow/providers/redis/sensors/redis_key.py b/providers/src/airflow/providers/redis/sensors/redis_key.py index 3fba123e6026e..c39426658dfce 100644 --- a/providers/src/airflow/providers/redis/sensors/redis_key.py +++ b/providers/src/airflow/providers/redis/sensors/redis_key.py @@ -24,7 +24,11 @@ from airflow.sensors.base import BaseSensorOperator if TYPE_CHECKING: - from airflow.utils.context import Context + try: + from airflow.sdk.definitions.context import Context + except ImportError: + # TODO: Remove once provider drops support for Airflow 2 + from airflow.utils.context import Context class RedisKeySensor(BaseSensorOperator): diff --git a/providers/src/airflow/providers/redis/sensors/redis_pub_sub.py b/providers/src/airflow/providers/redis/sensors/redis_pub_sub.py index 2c0fed4e6a81c..42502f1f92db0 100644 --- 
a/providers/src/airflow/providers/redis/sensors/redis_pub_sub.py +++ b/providers/src/airflow/providers/redis/sensors/redis_pub_sub.py @@ -25,7 +25,11 @@ from airflow.sensors.base import BaseSensorOperator if TYPE_CHECKING: - from airflow.utils.context import Context + try: + from airflow.sdk.definitions.context import Context + except ImportError: + # TODO: Remove once provider drops support for Airflow 2 + from airflow.utils.context import Context class RedisPubSubSensor(BaseSensorOperator): diff --git a/providers/src/airflow/providers/salesforce/operators/bulk.py b/providers/src/airflow/providers/salesforce/operators/bulk.py index 779ab08b056f6..4baff384efc10 100644 --- a/providers/src/airflow/providers/salesforce/operators/bulk.py +++ b/providers/src/airflow/providers/salesforce/operators/bulk.py @@ -26,7 +26,11 @@ from simple_salesforce.bulk import SFBulkHandler from typing_extensions import Literal - from airflow.utils.context import Context + try: + from airflow.sdk.definitions.context import Context + except ImportError: + # TODO: Remove once provider drops support for Airflow 2 + from airflow.utils.context import Context class SalesforceBulkOperator(BaseOperator): diff --git a/providers/src/airflow/providers/salesforce/operators/salesforce_apex_rest.py b/providers/src/airflow/providers/salesforce/operators/salesforce_apex_rest.py index 8411e2320d240..66f731715d2d4 100644 --- a/providers/src/airflow/providers/salesforce/operators/salesforce_apex_rest.py +++ b/providers/src/airflow/providers/salesforce/operators/salesforce_apex_rest.py @@ -22,7 +22,11 @@ from airflow.providers.salesforce.hooks.salesforce import SalesforceHook if TYPE_CHECKING: - from airflow.utils.context import Context + try: + from airflow.sdk.definitions.context import Context + except ImportError: + # TODO: Remove once provider drops support for Airflow 2 + from airflow.utils.context import Context class SalesforceApexRestOperator(BaseOperator): diff --git a/providers/src/airflow/providers/samba/transfers/gcs_to_samba.py b/providers/src/airflow/providers/samba/transfers/gcs_to_samba.py index bfc1f0f2c1097..74120642442c1 100644 --- a/providers/src/airflow/providers/samba/transfers/gcs_to_samba.py +++ b/providers/src/airflow/providers/samba/transfers/gcs_to_samba.py @@ -32,7 +32,11 @@ WILDCARD = "*" if TYPE_CHECKING: - from airflow.utils.context import Context + try: + from airflow.sdk.definitions.context import Context + except ImportError: + # TODO: Remove once provider drops support for Airflow 2 + from airflow.utils.context import Context class GCSToSambaOperator(BaseOperator): diff --git a/providers/src/airflow/providers/segment/operators/segment_track_event.py b/providers/src/airflow/providers/segment/operators/segment_track_event.py index e3274fe375a57..f878addfa2744 100644 --- a/providers/src/airflow/providers/segment/operators/segment_track_event.py +++ b/providers/src/airflow/providers/segment/operators/segment_track_event.py @@ -24,7 +24,11 @@ from airflow.providers.segment.hooks.segment import SegmentHook if TYPE_CHECKING: - from airflow.utils.context import Context + try: + from airflow.sdk.definitions.context import Context + except ImportError: + # TODO: Remove once provider drops support for Airflow 2 + from airflow.utils.context import Context class SegmentTrackEventOperator(BaseOperator): diff --git a/providers/src/airflow/providers/sftp/sensors/sftp.py b/providers/src/airflow/providers/sftp/sensors/sftp.py index 9a5cb14345282..fa6a5219a7e53 100644 --- 
a/providers/src/airflow/providers/sftp/sensors/sftp.py +++ b/providers/src/airflow/providers/sftp/sensors/sftp.py @@ -34,7 +34,11 @@ from airflow.utils.timezone import convert_to_utc, parse if TYPE_CHECKING: - from airflow.utils.context import Context + try: + from airflow.sdk.definitions.context import Context + except ImportError: + # TODO: Remove once provider drops support for Airflow 2 + from airflow.utils.context import Context class SFTPSensor(BaseSensorOperator): diff --git a/providers/src/airflow/providers/slack/operators/slack.py b/providers/src/airflow/providers/slack/operators/slack.py index 440ec377fdc47..6595bb70837dd 100644 --- a/providers/src/airflow/providers/slack/operators/slack.py +++ b/providers/src/airflow/providers/slack/operators/slack.py @@ -30,7 +30,11 @@ if TYPE_CHECKING: from slack_sdk.http_retry import RetryHandler - from airflow.utils.context import Context + try: + from airflow.sdk.definitions.context import Context + except ImportError: + # TODO: Remove once provider drops support for Airflow 2 + from airflow.utils.context import Context class SlackAPIOperator(BaseOperator): diff --git a/providers/src/airflow/providers/slack/operators/slack_webhook.py b/providers/src/airflow/providers/slack/operators/slack_webhook.py index e17c3c706baa8..58c6906c5ea02 100644 --- a/providers/src/airflow/providers/slack/operators/slack_webhook.py +++ b/providers/src/airflow/providers/slack/operators/slack_webhook.py @@ -27,7 +27,11 @@ if TYPE_CHECKING: from slack_sdk.http_retry import RetryHandler - from airflow.utils.context import Context + try: + from airflow.sdk.definitions.context import Context + except ImportError: + # TODO: Remove once provider drops support for Airflow 2 + from airflow.utils.context import Context class SlackWebhookOperator(BaseOperator): diff --git a/providers/src/airflow/providers/slack/transfers/sql_to_slack.py b/providers/src/airflow/providers/slack/transfers/sql_to_slack.py index 2ad3b9f341ac9..82aaf0523b8b1 100644 --- a/providers/src/airflow/providers/slack/transfers/sql_to_slack.py +++ b/providers/src/airflow/providers/slack/transfers/sql_to_slack.py @@ -29,7 +29,11 @@ from airflow.providers.slack.utils import parse_filename if TYPE_CHECKING: - from airflow.utils.context import Context + try: + from airflow.sdk.definitions.context import Context + except ImportError: + # TODO: Remove once provider drops support for Airflow 2 + from airflow.utils.context import Context class SqlToSlackApiFileOperator(BaseSqlToSlackOperator): diff --git a/providers/src/airflow/providers/slack/transfers/sql_to_slack_webhook.py b/providers/src/airflow/providers/slack/transfers/sql_to_slack_webhook.py index 4861a778ea7a3..c840822c609dd 100644 --- a/providers/src/airflow/providers/slack/transfers/sql_to_slack_webhook.py +++ b/providers/src/airflow/providers/slack/transfers/sql_to_slack_webhook.py @@ -26,7 +26,11 @@ from airflow.providers.slack.transfers.base_sql_to_slack import BaseSqlToSlackOperator if TYPE_CHECKING: - from airflow.utils.context import Context + try: + from airflow.sdk.definitions.context import Context + except ImportError: + # TODO: Remove once provider drops support for Airflow 2 + from airflow.utils.context import Context class SqlToSlackWebhookOperator(BaseSqlToSlackOperator): diff --git a/providers/src/airflow/providers/smtp/operators/smtp.py b/providers/src/airflow/providers/smtp/operators/smtp.py index 3c08f8d51fd24..2c097e8aa8b84 100644 --- a/providers/src/airflow/providers/smtp/operators/smtp.py +++ 
b/providers/src/airflow/providers/smtp/operators/smtp.py @@ -24,7 +24,11 @@ from airflow.providers.smtp.hooks.smtp import SmtpHook if TYPE_CHECKING: - from airflow.utils.context import Context + try: + from airflow.sdk.definitions.context import Context + except ImportError: + # TODO: Remove once provider drops support for Airflow 2 + from airflow.utils.context import Context class EmailOperator(BaseOperator): diff --git a/providers/src/airflow/providers/snowflake/operators/snowflake.py b/providers/src/airflow/providers/snowflake/operators/snowflake.py index ffce0c56325a1..902463b18af7d 100644 --- a/providers/src/airflow/providers/snowflake/operators/snowflake.py +++ b/providers/src/airflow/providers/snowflake/operators/snowflake.py @@ -34,7 +34,11 @@ from airflow.providers.snowflake.triggers.snowflake_trigger import SnowflakeSqlApiTrigger if TYPE_CHECKING: - from airflow.utils.context import Context + try: + from airflow.sdk.definitions.context import Context + except ImportError: + # TODO: Remove once provider drops support for Airflow 2 + from airflow.utils.context import Context class SnowflakeCheckOperator(SQLCheckOperator): diff --git a/providers/src/airflow/providers/standard/operators/bash.py b/providers/src/airflow/providers/standard/operators/bash.py index bf006c004b2bd..d3f66380eeaff 100644 --- a/providers/src/airflow/providers/standard/operators/bash.py +++ b/providers/src/airflow/providers/standard/operators/bash.py @@ -34,7 +34,11 @@ if TYPE_CHECKING: from sqlalchemy.orm import Session as SASession - from airflow.utils.context import Context + try: + from airflow.sdk.definitions.context import Context + except ImportError: + # TODO: Remove once provider drops support for Airflow 2 + from airflow.utils.context import Context class BashOperator(BaseOperator): diff --git a/providers/src/airflow/providers/standard/operators/datetime.py b/providers/src/airflow/providers/standard/operators/datetime.py index cb3ce66fce202..b9fcb59f58958 100644 --- a/providers/src/airflow/providers/standard/operators/datetime.py +++ b/providers/src/airflow/providers/standard/operators/datetime.py @@ -25,7 +25,11 @@ from airflow.utils import timezone if TYPE_CHECKING: - from airflow.utils.context import Context + try: + from airflow.sdk.definitions.context import Context + except ImportError: + # TODO: Remove once provider drops support for Airflow 2 + from airflow.utils.context import Context class BranchDateTimeOperator(BaseBranchOperator): diff --git a/providers/src/airflow/providers/standard/operators/generic_transfer.py b/providers/src/airflow/providers/standard/operators/generic_transfer.py index 1cb3448f8a578..0d30c45dba29a 100644 --- a/providers/src/airflow/providers/standard/operators/generic_transfer.py +++ b/providers/src/airflow/providers/standard/operators/generic_transfer.py @@ -24,7 +24,11 @@ from airflow.models import BaseOperator if TYPE_CHECKING: - from airflow.utils.context import Context + try: + from airflow.sdk.definitions.context import Context + except ImportError: + # TODO: Remove once provider drops support for Airflow 2 + from airflow.utils.context import Context class GenericTransfer(BaseOperator): diff --git a/providers/src/airflow/providers/standard/operators/latest_only.py b/providers/src/airflow/providers/standard/operators/latest_only.py index a573e05c28d61..c8b7ce16fb64d 100644 --- a/providers/src/airflow/providers/standard/operators/latest_only.py +++ b/providers/src/airflow/providers/standard/operators/latest_only.py @@ -28,7 +28,12 @@ if TYPE_CHECKING: from 
airflow.models import DAG, DagRun - from airflow.utils.context import Context + + try: + from airflow.sdk.definitions.context import Context + except ImportError: + # TODO: Remove once provider drops support for Airflow 2 + from airflow.utils.context import Context class LatestOnlyOperator(BaseBranchOperator): diff --git a/providers/src/airflow/providers/standard/operators/python.py b/providers/src/airflow/providers/standard/operators/python.py index 5048c4006958e..86f5f0156d243 100644 --- a/providers/src/airflow/providers/standard/operators/python.py +++ b/providers/src/airflow/providers/standard/operators/python.py @@ -60,7 +60,11 @@ from pendulum.datetime import DateTime - from airflow.utils.context import Context + try: + from airflow.sdk.definitions.context import Context + except ImportError: + # TODO: Remove once provider drops support for Airflow 2 + from airflow.utils.context import Context _SerializerTypeDef = Literal["pickle", "cloudpickle", "dill"] diff --git a/providers/src/airflow/providers/standard/operators/trigger_dagrun.py b/providers/src/airflow/providers/standard/operators/trigger_dagrun.py index e5f7aca313f43..920c0bf516e03 100644 --- a/providers/src/airflow/providers/standard/operators/trigger_dagrun.py +++ b/providers/src/airflow/providers/standard/operators/trigger_dagrun.py @@ -54,7 +54,12 @@ from sqlalchemy.orm.session import Session from airflow.models.taskinstancekey import TaskInstanceKey - from airflow.utils.context import Context + + try: + from airflow.sdk.definitions.context import Context + except ImportError: + # TODO: Remove once provider drops support for Airflow 2 + from airflow.utils.context import Context class TriggerDagRunLink(BaseOperatorLink): diff --git a/providers/src/airflow/providers/standard/operators/weekday.py b/providers/src/airflow/providers/standard/operators/weekday.py index 7c6f7f7a9b8bb..cd5c787e0ad0a 100644 --- a/providers/src/airflow/providers/standard/operators/weekday.py +++ b/providers/src/airflow/providers/standard/operators/weekday.py @@ -25,7 +25,11 @@ from airflow.utils.weekday import WeekDay if TYPE_CHECKING: - from airflow.utils.context import Context + try: + from airflow.sdk.definitions.context import Context + except ImportError: + # TODO: Remove once provider drops support for Airflow 2 + from airflow.utils.context import Context class BranchDayOfWeekOperator(BaseBranchOperator): diff --git a/providers/src/airflow/providers/standard/sensors/bash.py b/providers/src/airflow/providers/standard/sensors/bash.py index def3cc688b2ef..023983e7b90dc 100644 --- a/providers/src/airflow/providers/standard/sensors/bash.py +++ b/providers/src/airflow/providers/standard/sensors/bash.py @@ -27,7 +27,11 @@ from airflow.sensors.base import BaseSensorOperator if TYPE_CHECKING: - from airflow.utils.context import Context + try: + from airflow.sdk.definitions.context import Context + except ImportError: + # TODO: Remove once provider drops support for Airflow 2 + from airflow.utils.context import Context class BashSensor(BaseSensorOperator): diff --git a/providers/src/airflow/providers/standard/sensors/date_time.py b/providers/src/airflow/providers/standard/sensors/date_time.py index 050c36239b2d5..d04f524cbabc8 100644 --- a/providers/src/airflow/providers/standard/sensors/date_time.py +++ b/providers/src/airflow/providers/standard/sensors/date_time.py @@ -44,7 +44,11 @@ class StartTriggerArgs: # type: ignore[no-redef] from airflow.utils import timezone if TYPE_CHECKING: - from airflow.utils.context import Context + try: + from 
airflow.sdk.definitions.context import Context + except ImportError: + # TODO: Remove once provider drops support for Airflow 2 + from airflow.utils.context import Context class DateTimeSensor(BaseSensorOperator): diff --git a/providers/src/airflow/providers/standard/sensors/external_task.py b/providers/src/airflow/providers/standard/sensors/external_task.py index ff43ff2f463f7..ad2aa89613fa1 100644 --- a/providers/src/airflow/providers/standard/sensors/external_task.py +++ b/providers/src/airflow/providers/standard/sensors/external_task.py @@ -43,7 +43,12 @@ from airflow.models.baseoperator import BaseOperator from airflow.models.taskinstancekey import TaskInstanceKey - from airflow.utils.context import Context + + try: + from airflow.sdk.definitions.context import Context + except ImportError: + # TODO: Remove once provider drops support for Airflow 2 + from airflow.utils.context import Context class ExternalDagLink(BaseOperatorLink): diff --git a/providers/src/airflow/providers/standard/sensors/filesystem.py b/providers/src/airflow/providers/standard/sensors/filesystem.py index 650787c485e65..5f6f9e5a0fc61 100644 --- a/providers/src/airflow/providers/standard/sensors/filesystem.py +++ b/providers/src/airflow/providers/standard/sensors/filesystem.py @@ -47,7 +47,11 @@ class StartTriggerArgs: # type: ignore[no-redef] if TYPE_CHECKING: - from airflow.utils.context import Context + try: + from airflow.sdk.definitions.context import Context + except ImportError: + # TODO: Remove once provider drops support for Airflow 2 + from airflow.utils.context import Context class FileSensor(BaseSensorOperator): diff --git a/providers/src/airflow/providers/standard/sensors/python.py b/providers/src/airflow/providers/standard/sensors/python.py index 30031f968fb39..28f293135fdb3 100644 --- a/providers/src/airflow/providers/standard/sensors/python.py +++ b/providers/src/airflow/providers/standard/sensors/python.py @@ -25,7 +25,11 @@ from airflow.utils.operator_helpers import determine_kwargs if TYPE_CHECKING: - from airflow.utils.context import Context + try: + from airflow.sdk.definitions.context import Context + except ImportError: + # TODO: Remove once provider drops support for Airflow 2 + from airflow.utils.context import Context class PythonSensor(BaseSensorOperator): diff --git a/providers/src/airflow/providers/standard/sensors/time.py b/providers/src/airflow/providers/standard/sensors/time.py index 273fda9d362cb..ee9dde773fee2 100644 --- a/providers/src/airflow/providers/standard/sensors/time.py +++ b/providers/src/airflow/providers/standard/sensors/time.py @@ -43,7 +43,11 @@ class StartTriggerArgs: # type: ignore[no-redef] from airflow.utils import timezone if TYPE_CHECKING: - from airflow.utils.context import Context + try: + from airflow.sdk.definitions.context import Context + except ImportError: + # TODO: Remove once provider drops support for Airflow 2 + from airflow.utils.context import Context class TimeSensor(BaseSensorOperator): diff --git a/providers/src/airflow/providers/standard/sensors/time_delta.py b/providers/src/airflow/providers/standard/sensors/time_delta.py index 216a8eda40b57..c5e3c6f8aede8 100644 --- a/providers/src/airflow/providers/standard/sensors/time_delta.py +++ b/providers/src/airflow/providers/standard/sensors/time_delta.py @@ -31,7 +31,11 @@ from airflow.utils import timezone if TYPE_CHECKING: - from airflow.utils.context import Context + try: + from airflow.sdk.definitions.context import Context + except ImportError: + # TODO: Remove once provider drops support for 
Airflow 2 + from airflow.utils.context import Context def _get_airflow_version(): diff --git a/providers/src/airflow/providers/standard/sensors/weekday.py b/providers/src/airflow/providers/standard/sensors/weekday.py index 4740022cd0669..29d3442ddb718 100644 --- a/providers/src/airflow/providers/standard/sensors/weekday.py +++ b/providers/src/airflow/providers/standard/sensors/weekday.py @@ -25,7 +25,11 @@ from airflow.utils.weekday import WeekDay if TYPE_CHECKING: - from airflow.utils.context import Context + try: + from airflow.sdk.definitions.context import Context + except ImportError: + # TODO: Remove once provider drops support for Airflow 2 + from airflow.utils.context import Context class DayOfWeekSensor(BaseSensorOperator): diff --git a/providers/src/airflow/providers/tableau/operators/tableau.py b/providers/src/airflow/providers/tableau/operators/tableau.py index 77a6508343446..951ad44535197 100644 --- a/providers/src/airflow/providers/tableau/operators/tableau.py +++ b/providers/src/airflow/providers/tableau/operators/tableau.py @@ -28,7 +28,11 @@ ) if TYPE_CHECKING: - from airflow.utils.context import Context + try: + from airflow.sdk.definitions.context import Context + except ImportError: + # TODO: Remove once provider drops support for Airflow 2 + from airflow.utils.context import Context RESOURCES_METHODS = { "datasources": ["delete", "refresh"], diff --git a/providers/src/airflow/providers/tableau/sensors/tableau.py b/providers/src/airflow/providers/tableau/sensors/tableau.py index 04c990412441d..4aefc55169d4e 100644 --- a/providers/src/airflow/providers/tableau/sensors/tableau.py +++ b/providers/src/airflow/providers/tableau/sensors/tableau.py @@ -27,7 +27,11 @@ from airflow.sensors.base import BaseSensorOperator if TYPE_CHECKING: - from airflow.utils.context import Context + try: + from airflow.sdk.definitions.context import Context + except ImportError: + # TODO: Remove once provider drops support for Airflow 2 + from airflow.utils.context import Context class TableauJobStatusSensor(BaseSensorOperator): diff --git a/providers/src/airflow/providers/telegram/operators/telegram.py b/providers/src/airflow/providers/telegram/operators/telegram.py index 41da0079fe0d2..641a2760839b7 100644 --- a/providers/src/airflow/providers/telegram/operators/telegram.py +++ b/providers/src/airflow/providers/telegram/operators/telegram.py @@ -27,7 +27,11 @@ from airflow.providers.telegram.hooks.telegram import TelegramHook if TYPE_CHECKING: - from airflow.utils.context import Context + try: + from airflow.sdk.definitions.context import Context + except ImportError: + # TODO: Remove once provider drops support for Airflow 2 + from airflow.utils.context import Context class TelegramOperator(BaseOperator): diff --git a/providers/src/airflow/providers/teradata/operators/teradata.py b/providers/src/airflow/providers/teradata/operators/teradata.py index c52904952ffff..beaca171106c6 100644 --- a/providers/src/airflow/providers/teradata/operators/teradata.py +++ b/providers/src/airflow/providers/teradata/operators/teradata.py @@ -25,7 +25,11 @@ from airflow.providers.teradata.hooks.teradata import TeradataHook if TYPE_CHECKING: - from airflow.utils.context import Context + try: + from airflow.sdk.definitions.context import Context + except ImportError: + # TODO: Remove once provider drops support for Airflow 2 + from airflow.utils.context import Context class TeradataOperator(SQLExecuteQueryOperator): diff --git a/providers/src/airflow/providers/teradata/operators/teradata_compute_cluster.py 
b/providers/src/airflow/providers/teradata/operators/teradata_compute_cluster.py index 92e4e4177b9d3..1648d8dd618e5 100644 --- a/providers/src/airflow/providers/teradata/operators/teradata_compute_cluster.py +++ b/providers/src/airflow/providers/teradata/operators/teradata_compute_cluster.py @@ -28,7 +28,11 @@ from airflow.providers.teradata.utils.constants import Constants if TYPE_CHECKING: - from airflow.utils.context import Context + try: + from airflow.sdk.definitions.context import Context + except ImportError: + # TODO: Remove once provider drops support for Airflow 2 + from airflow.utils.context import Context from collections.abc import Sequence from datetime import timedelta @@ -37,7 +41,11 @@ from airflow.providers.teradata.triggers.teradata_compute_cluster import TeradataComputeClusterSyncTrigger if TYPE_CHECKING: - from airflow.utils.context import Context + try: + from airflow.sdk.definitions.context import Context + except ImportError: + # TODO: Remove once provider drops support for Airflow 2 + from airflow.utils.context import Context from airflow.exceptions import AirflowException diff --git a/providers/src/airflow/providers/teradata/transfers/s3_to_teradata.py b/providers/src/airflow/providers/teradata/transfers/s3_to_teradata.py index 6c34efa870452..707de137ab063 100644 --- a/providers/src/airflow/providers/teradata/transfers/s3_to_teradata.py +++ b/providers/src/airflow/providers/teradata/transfers/s3_to_teradata.py @@ -32,7 +32,11 @@ from airflow.providers.teradata.hooks.teradata import TeradataHook if TYPE_CHECKING: - from airflow.utils.context import Context + try: + from airflow.sdk.definitions.context import Context + except ImportError: + # TODO: Remove once provider drops support for Airflow 2 + from airflow.utils.context import Context class S3ToTeradataOperator(BaseOperator): diff --git a/providers/src/airflow/providers/teradata/transfers/teradata_to_teradata.py b/providers/src/airflow/providers/teradata/transfers/teradata_to_teradata.py index 077ce097aed5d..d2465bd91f71e 100644 --- a/providers/src/airflow/providers/teradata/transfers/teradata_to_teradata.py +++ b/providers/src/airflow/providers/teradata/transfers/teradata_to_teradata.py @@ -25,7 +25,11 @@ from airflow.providers.teradata.hooks.teradata import TeradataHook if TYPE_CHECKING: - from airflow.utils.context import Context + try: + from airflow.sdk.definitions.context import Context + except ImportError: + # TODO: Remove once provider drops support for Airflow 2 + from airflow.utils.context import Context class TeradataToTeradataOperator(BaseOperator): diff --git a/providers/src/airflow/providers/trino/transfers/gcs_to_trino.py b/providers/src/airflow/providers/trino/transfers/gcs_to_trino.py index 36fd082c84a91..65e8b2d6d9e59 100644 --- a/providers/src/airflow/providers/trino/transfers/gcs_to_trino.py +++ b/providers/src/airflow/providers/trino/transfers/gcs_to_trino.py @@ -30,7 +30,11 @@ from airflow.providers.trino.hooks.trino import TrinoHook if TYPE_CHECKING: - from airflow.utils.context import Context + try: + from airflow.sdk.definitions.context import Context + except ImportError: + # TODO: Remove once provider drops support for Airflow 2 + from airflow.utils.context import Context class GCSToTrinoOperator(BaseOperator): diff --git a/providers/src/airflow/providers/weaviate/operators/weaviate.py b/providers/src/airflow/providers/weaviate/operators/weaviate.py index dd97dc3c66bc0..1facdf77beb0d 100644 --- a/providers/src/airflow/providers/weaviate/operators/weaviate.py +++ 
b/providers/src/airflow/providers/weaviate/operators/weaviate.py @@ -28,7 +28,11 @@ import pandas as pd from weaviate.types import UUID - from airflow.utils.context import Context + try: + from airflow.sdk.definitions.context import Context + except ImportError: + # TODO: Remove once provider drops support for Airflow 2 + from airflow.utils.context import Context class WeaviateIngestOperator(BaseOperator): diff --git a/providers/src/airflow/providers/yandex/links/yq.py b/providers/src/airflow/providers/yandex/links/yq.py index b168c5b0cf67e..86babfec0f04a 100644 --- a/providers/src/airflow/providers/yandex/links/yq.py +++ b/providers/src/airflow/providers/yandex/links/yq.py @@ -23,7 +23,12 @@ if TYPE_CHECKING: from airflow.models import BaseOperator from airflow.models.taskinstancekey import TaskInstanceKey - from airflow.utils.context import Context + + try: + from airflow.sdk.definitions.context import Context + except ImportError: + # TODO: Remove once provider drops support for Airflow 2 + from airflow.utils.context import Context XCOM_WEBLINK_KEY = "web_link" diff --git a/providers/src/airflow/providers/yandex/operators/dataproc.py b/providers/src/airflow/providers/yandex/operators/dataproc.py index 20f992496da0a..e389f85abe583 100644 --- a/providers/src/airflow/providers/yandex/operators/dataproc.py +++ b/providers/src/airflow/providers/yandex/operators/dataproc.py @@ -24,7 +24,11 @@ from airflow.providers.yandex.hooks.dataproc import DataprocHook if TYPE_CHECKING: - from airflow.utils.context import Context + try: + from airflow.sdk.definitions.context import Context + except ImportError: + # TODO: Remove once provider drops support for Airflow 2 + from airflow.utils.context import Context @dataclass diff --git a/providers/src/airflow/providers/yandex/operators/yq.py b/providers/src/airflow/providers/yandex/operators/yq.py index babe088fed86f..da3890f4adfec 100644 --- a/providers/src/airflow/providers/yandex/operators/yq.py +++ b/providers/src/airflow/providers/yandex/operators/yq.py @@ -25,7 +25,11 @@ from airflow.providers.yandex.links.yq import YQLink if TYPE_CHECKING: - from airflow.utils.context import Context + try: + from airflow.sdk.definitions.context import Context + except ImportError: + # TODO: Remove once provider drops support for Airflow 2 + from airflow.utils.context import Context class YQExecuteQueryOperator(BaseOperator): diff --git a/providers/tests/amazon/aws/sensors/test_emr_base.py b/providers/tests/amazon/aws/sensors/test_emr_base.py index dfeefbb63cad9..5d69969bfbca7 100644 --- a/providers/tests/amazon/aws/sensors/test_emr_base.py +++ b/providers/tests/amazon/aws/sensors/test_emr_base.py @@ -25,7 +25,11 @@ from airflow.providers.amazon.aws.sensors.emr import EmrBaseSensor if TYPE_CHECKING: - from airflow.utils.context import Context + try: + from airflow.sdk.definitions.context import Context + except ImportError: + # TODO: Remove once provider drops support for Airflow 2 + from airflow.utils.context import Context TARGET_STATE = "TARGET_STATE" FAILED_STATE = "FAILED_STATE" diff --git a/providers/tests/microsoft/azure/hooks/test_asb.py b/providers/tests/microsoft/azure/hooks/test_asb.py index 6f9203bd0d1c2..7953facfdb85a 100644 --- a/providers/tests/microsoft/azure/hooks/test_asb.py +++ b/providers/tests/microsoft/azure/hooks/test_asb.py @@ -33,7 +33,12 @@ from airflow.models import Connection from airflow.providers.microsoft.azure.hooks.asb import AdminClientHook, MessageHook -from airflow.utils.context import Context + +try: + from 
airflow.sdk.definitions.context import Context +except ImportError: + # TODO: Remove once provider drops support for Airflow 2 + from airflow.utils.context import Context MESSAGE = "Test Message" MESSAGE_LIST = [f"{MESSAGE} {n}" for n in range(10)] diff --git a/providers/tests/microsoft/azure/operators/test_asb.py b/providers/tests/microsoft/azure/operators/test_asb.py index d887db2fff044..667f40971987e 100644 --- a/providers/tests/microsoft/azure/operators/test_asb.py +++ b/providers/tests/microsoft/azure/operators/test_asb.py @@ -38,7 +38,12 @@ AzureServiceBusTopicDeleteOperator, AzureServiceBusUpdateSubscriptionOperator, ) -from airflow.utils.context import Context + +try: + from airflow.sdk.definitions.context import Context +except ImportError: + # TODO: Remove once provider drops support for Airflow 2 + from airflow.utils.context import Context QUEUE_NAME = "test_queue" MESSAGE = "Test Message" diff --git a/providers/tests/microsoft/azure/operators/test_container_instances.py b/providers/tests/microsoft/azure/operators/test_container_instances.py index 3c1fdbffd4c4e..8985927ecfb18 100644 --- a/providers/tests/microsoft/azure/operators/test_container_instances.py +++ b/providers/tests/microsoft/azure/operators/test_container_instances.py @@ -34,7 +34,12 @@ from airflow.exceptions import AirflowException from airflow.providers.microsoft.azure.operators.container_instances import AzureContainerInstancesOperator -from airflow.utils.context import Context + +try: + from airflow.sdk.definitions.context import Context +except ImportError: + # TODO: Remove once provider drops support for Airflow 2 + from airflow.utils.context import Context def make_mock_cg(container_state, events=None): diff --git a/providers/tests/microsoft/azure/operators/test_msgraph.py b/providers/tests/microsoft/azure/operators/test_msgraph.py index b722c4c9f0e85..cd36ae5b131e5 100644 --- a/providers/tests/microsoft/azure/operators/test_msgraph.py +++ b/providers/tests/microsoft/azure/operators/test_msgraph.py @@ -38,7 +38,11 @@ from tests_common.test_utils.version_compat import AIRFLOW_V_2_10_PLUS if TYPE_CHECKING: - from airflow.utils.context import Context + try: + from airflow.sdk.definitions.context import Context + except ImportError: + # TODO: Remove once provider drops support for Airflow 2 + from airflow.utils.context import Context class TestMSGraphAsyncOperator(Base): diff --git a/providers/tests/microsoft/conftest.py b/providers/tests/microsoft/conftest.py index 240f33e335d7e..d875096402b8b 100644 --- a/providers/tests/microsoft/conftest.py +++ b/providers/tests/microsoft/conftest.py @@ -34,7 +34,12 @@ from airflow.models import Connection from airflow.providers.microsoft.azure.hooks.powerbi import PowerBIHook -from airflow.utils.context import Context + +try: + from airflow.sdk.definitions.context import Context +except ImportError: + # TODO: Remove once provider drops support for Airflow 2 + from airflow.utils.context import Context if TYPE_CHECKING: from sqlalchemy.orm import Session diff --git a/providers/tests/mongo/sensors/test_mongo.py b/providers/tests/mongo/sensors/test_mongo.py index 028e8dba33893..e02c54ee2691c 100644 --- a/providers/tests/mongo/sensors/test_mongo.py +++ b/providers/tests/mongo/sensors/test_mongo.py @@ -23,7 +23,12 @@ from airflow.providers.mongo.hooks.mongo import MongoHook from airflow.providers.mongo.sensors.mongo import MongoSensor from airflow.utils import timezone -from airflow.utils.context import Context + +try: + from airflow.sdk.definitions.context import Context 
+except ImportError: + # TODO: Remove once provider drops support for Airflow 2 + from airflow.utils.context import Context DEFAULT_DATE = timezone.datetime(2017, 1, 1) diff --git a/providers/tests/openai/operators/test_openai.py b/providers/tests/openai/operators/test_openai.py index c0c6c225b015c..6494e88727c91 100644 --- a/providers/tests/openai/operators/test_openai.py +++ b/providers/tests/openai/operators/test_openai.py @@ -27,7 +27,12 @@ from airflow.exceptions import TaskDeferred from airflow.providers.openai.operators.openai import OpenAIEmbeddingOperator, OpenAITriggerBatchOperator from airflow.providers.openai.triggers.openai import OpenAIBatchTrigger -from airflow.utils.context import Context + +try: + from airflow.sdk.definitions.context import Context +except ImportError: + # TODO: Remove once provider drops support for Airflow 2 + from airflow.utils.context import Context TASK_ID = "TaskId" CONN_ID = "test_conn_id" diff --git a/providers/tests/openlineage/extractors/test_manager.py b/providers/tests/openlineage/extractors/test_manager.py index 6b5e0dedd96ec..df64b7d1e75ef 100644 --- a/providers/tests/openlineage/extractors/test_manager.py +++ b/providers/tests/openlineage/extractors/test_manager.py @@ -42,7 +42,11 @@ from tests_common.test_utils.version_compat import AIRFLOW_V_2_10_PLUS if TYPE_CHECKING: - from airflow.utils.context import Context + try: + from airflow.sdk.definitions.context import Context + except ImportError: + # TODO: Remove once provider drops support for Airflow 2 + from airflow.utils.context import Context if AIRFLOW_V_2_10_PLUS: diff --git a/providers/tests/standard/operators/test_python.py b/providers/tests/standard/operators/test_python.py index 3899c89fae9d8..c43c00dd0e814 100644 --- a/providers/tests/standard/operators/test_python.py +++ b/providers/tests/standard/operators/test_python.py @@ -1043,7 +1043,12 @@ def f(): def test_current_context(self): def f(): from airflow.providers.standard.operators.python import get_current_context - from airflow.utils.context import Context + + try: + from airflow.sdk.definitions.context import Context + except ImportError: + # TODO: Remove once provider drops support for Airflow 2 + from airflow.utils.context import Context context = get_current_context() if not isinstance(context, Context): # type: ignore[misc] @@ -1099,7 +1104,12 @@ def f(): def test_use_airflow_context_touch_other_variables(self): def f(): from airflow.providers.standard.operators.python import get_current_context - from airflow.utils.context import Context + + try: + from airflow.sdk.definitions.context import Context + except ImportError: + # TODO: Remove once provider drops support for Airflow 2 + from airflow.utils.context import Context context = get_current_context() if not isinstance(context, Context): # type: ignore[misc] @@ -1477,7 +1487,12 @@ def f( def test_current_context_system_site_packages(self, session): def f(): from airflow.providers.standard.operators.python import get_current_context - from airflow.utils.context import Context + + try: + from airflow.sdk.definitions.context import Context + except ImportError: + # TODO: Remove once provider drops support for Airflow 2 + from airflow.utils.context import Context context = get_current_context() if not isinstance(context, Context): # type: ignore[misc] @@ -1840,7 +1855,12 @@ def default_kwargs(*, python_version=DEFAULT_PYTHON_VERSION, **kwargs): def test_current_context_system_site_packages(self, session): def f(): from airflow.providers.standard.operators.python import 
get_current_context - from airflow.utils.context import Context + + try: + from airflow.sdk.definitions.context import Context + except ImportError: + # TODO: Remove once provider drops support for Airflow 2 + from airflow.utils.context import Context context = get_current_context() if not isinstance(context, Context): # type: ignore[misc] diff --git a/providers/tests/system/openlineage/operator.py b/providers/tests/system/openlineage/operator.py index a305ca7026a7f..28995b1f44eaf 100644 --- a/providers/tests/system/openlineage/operator.py +++ b/providers/tests/system/openlineage/operator.py @@ -32,7 +32,11 @@ from airflow.utils.session import create_session if TYPE_CHECKING: - from airflow.utils.context import Context + try: + from airflow.sdk.definitions.context import Context + except ImportError: + # TODO: Remove once provider drops support for Airflow 2 + from airflow.utils.context import Context log = logging.getLogger(__name__)
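
Every hunk above applies the same compatibility shim: prefer the Task SDK's Context for type hints on Airflow 3, and fall back to the legacy airflow.utils.context.Context on Airflow 2. Distilled out of the diff, the pattern looks like the sketch below; MyOperator is a hypothetical stand-in for any provider operator (real providers subclass BaseOperator), while the import block itself is taken verbatim from the patch.

    from __future__ import annotations

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        try:
            # Airflow 3: Context lives in the Task SDK
            from airflow.sdk.definitions.context import Context
        except ImportError:
            # TODO: Remove once provider drops support for Airflow 2
            from airflow.utils.context import Context


    class MyOperator:  # hypothetical; shown only to illustrate the annotation
        def execute(self, context: Context) -> None:
            # With `from __future__ import annotations` the annotation stays a
            # string at runtime, so neither import is ever executed outside of
            # static type checking.
            ...

Keeping the try/except under TYPE_CHECKING means the fallback costs nothing at runtime; type checkers such as mypy evaluate both branches of a try/except ImportError and bind whichever import resolves in the active environment. Note that the test modules in this patch (test_asb.py, test_python.py, conftest.py, and friends) hoist the same try/except to module level instead, because they use Context at runtime, e.g. in isinstance checks against get_current_context().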