diff --git a/airflow/triggers/external_task.py b/airflow/triggers/external_task.py index a5de817c357c2..cd43d59876e9e 100644 --- a/airflow/triggers/external_task.py +++ b/airflow/triggers/external_task.py @@ -21,16 +21,12 @@ from typing import Any from asgiref.sync import sync_to_async -from deprecated import deprecated from sqlalchemy import func -from airflow.exceptions import RemovedInAirflow3Warning -from airflow.models import DagRun, TaskInstance +from airflow.models import DagRun from airflow.triggers.base import BaseTrigger, TriggerEvent from airflow.utils.sensor_helper import _get_count from airflow.utils.session import NEW_SESSION, provide_session -from airflow.utils.state import TaskInstanceState -from airflow.utils.timezone import utcnow if typing.TYPE_CHECKING: from datetime import datetime @@ -136,121 +132,6 @@ def _get_count(self, states: typing.Iterable[str] | None) -> int: ) -@deprecated( - reason="TaskStateTrigger has been deprecated and will be removed in future.", - category=RemovedInAirflow3Warning, -) -class TaskStateTrigger(BaseTrigger): - """ - Waits asynchronously for a task in a different DAG to complete for a specific logical date. - - :param dag_id: The dag_id that contains the task you want to wait for - :param task_id: The task_id that contains the task you want to - wait for. - :param states: allowed states, default is ``['success']`` - :param execution_dates: task execution time interval - :param poll_interval: The time interval in seconds to check the state. - The default value is 5 sec. - :param trigger_start_time: time in Datetime format when the trigger was started. Is used - to control the execution of trigger to prevent infinite loop in case if specified name - of the dag does not exist in database. It will wait period of time equals _timeout_sec parameter - from the time, when the trigger was started and if the execution lasts more time than expected, - the trigger will terminate with 'timeout' status. - """ - - def __init__( - self, - dag_id: str, - execution_dates: list[datetime], - trigger_start_time: datetime, - states: list[str] | None = None, - task_id: str | None = None, - poll_interval: float = 2.0, - ): - super().__init__() - self.dag_id = dag_id - self.task_id = task_id - self.states = states - self.execution_dates = execution_dates - self.poll_interval = poll_interval - self.trigger_start_time = trigger_start_time - self.states = states or [TaskInstanceState.SUCCESS.value] - self._timeout_sec = 60 - - def serialize(self) -> tuple[str, dict[str, typing.Any]]: - """Serialize TaskStateTrigger arguments and classpath.""" - return ( - "airflow.triggers.external_task.TaskStateTrigger", - { - "dag_id": self.dag_id, - "task_id": self.task_id, - "states": self.states, - "execution_dates": self.execution_dates, - "poll_interval": self.poll_interval, - "trigger_start_time": self.trigger_start_time, - }, - ) - - async def run(self) -> typing.AsyncIterator[TriggerEvent]: - """ - Check periodically in the database to see if the dag exists and is in the running state. - - If found, wait until the task specified will reach one of the expected states. - If dag with specified name was not in the running state after _timeout_sec seconds - after starting execution process of the trigger, terminate with status 'timeout'. - """ - try: - while True: - delta = utcnow() - self.trigger_start_time - if delta.total_seconds() < self._timeout_sec: - # mypy confuses typing here - if await self.count_running_dags() == 0: # type: ignore[call-arg] - self.log.info("Waiting for DAG to start execution...") - await asyncio.sleep(self.poll_interval) - else: - yield TriggerEvent({"status": "timeout"}) - return - # mypy confuses typing here - if await self.count_tasks() == len(self.execution_dates): # type: ignore[call-arg] - yield TriggerEvent({"status": "success"}) - return - self.log.info("Task is still running, sleeping for %s seconds...", self.poll_interval) - await asyncio.sleep(self.poll_interval) - except Exception: - yield TriggerEvent({"status": "failed"}) - - @sync_to_async - @provide_session - def count_running_dags(self, session: Session): - """Count how many dag instances in running state in the database.""" - dags = ( - session.query(func.count("*")) - .filter( - TaskInstance.dag_id == self.dag_id, - TaskInstance.execution_date.in_(self.execution_dates), - TaskInstance.state.in_(["running", "success"]), - ) - .scalar() - ) - return dags - - @sync_to_async - @provide_session - def count_tasks(self, *, session: Session = NEW_SESSION) -> int | None: - """Count how many task instances in the database match our criteria.""" - count = ( - session.query(func.count("*")) # .count() is inefficient - .filter( - TaskInstance.dag_id == self.dag_id, - TaskInstance.task_id == self.task_id, - TaskInstance.state.in_(self.states), - TaskInstance.execution_date.in_(self.execution_dates), - ) - .scalar() - ) - return typing.cast(int, count) - - class DagStateTrigger(BaseTrigger): """ Waits asynchronously for a DAG to complete for a specific logical date. diff --git a/newsfragments/41737.significant.rst b/newsfragments/41737.significant.rst new file mode 100644 index 0000000000000..55704581be9b2 --- /dev/null +++ b/newsfragments/41737.significant.rst @@ -0,0 +1 @@ +Removed deprecated ``TaskStateTrigger`` from ``airflow.triggers.external_task`` module. diff --git a/tests/triggers/test_external_task.py b/tests/triggers/test_external_task.py index 7bb41c34502d6..ced867c4bd6aa 100644 --- a/tests/triggers/test_external_task.py +++ b/tests/triggers/test_external_task.py @@ -17,23 +17,17 @@ from __future__ import annotations import asyncio -import datetime import time from unittest import mock import pytest -from sqlalchemy.exc import SQLAlchemyError -from airflow.exceptions import RemovedInAirflow3Warning from airflow.models.dag import DAG from airflow.models.dagrun import DagRun -from airflow.models.taskinstance import TaskInstance -from airflow.operators.empty import EmptyOperator from airflow.triggers.base import TriggerEvent -from airflow.triggers.external_task import DagStateTrigger, TaskStateTrigger, WorkflowTrigger +from airflow.triggers.external_task import DagStateTrigger, WorkflowTrigger from airflow.utils import timezone -from airflow.utils.state import DagRunState, TaskInstanceState -from airflow.utils.timezone import utcnow +from airflow.utils.state import DagRunState class TestWorkflowTrigger: @@ -222,197 +216,6 @@ def test_serialization(self): } -class TestTaskStateTrigger: - DAG_ID = "external_task" - TASK_ID = "external_task_op" - RUN_ID = "external_task_run_id" - STATES = ["success", "fail"] - - @pytest.mark.skip_if_database_isolation_mode # Test is broken in db isolation mode - @pytest.mark.db_test - @pytest.mark.asyncio - async def test_task_state_trigger_success(self, session): - """ - Asserts that the TaskStateTrigger only goes off on or after a TaskInstance - reaches an allowed state (i.e. SUCCESS). - """ - trigger_start_time = utcnow() - dag = DAG(self.DAG_ID, schedule=None, start_date=timezone.datetime(2022, 1, 1)) - dag_run = DagRun( - dag_id=dag.dag_id, - run_type="manual", - execution_date=timezone.datetime(2022, 1, 1), - run_id=self.RUN_ID, - ) - session.add(dag_run) - session.commit() - - external_task = EmptyOperator(task_id=self.TASK_ID, dag=dag) - instance = TaskInstance(external_task, run_id=self.RUN_ID) - session.add(instance) - session.commit() - - with pytest.warns(RemovedInAirflow3Warning, match="TaskStateTrigger has been deprecated"): - trigger = TaskStateTrigger( - dag_id=dag.dag_id, - task_id=instance.task_id, - states=self.STATES, - execution_dates=[timezone.datetime(2022, 1, 1)], - poll_interval=0.2, - trigger_start_time=trigger_start_time, - ) - - task = asyncio.create_task(trigger.run().__anext__()) - await asyncio.sleep(0.5) - - # It should not have produced a result - assert task.done() is False - - # Progress the task to a "success" state so that run() yields a TriggerEvent - instance.state = TaskInstanceState.SUCCESS - session.commit() - await asyncio.sleep(0.5) - assert task.done() is True - - # Prevents error when task is destroyed while in "pending" state - asyncio.get_event_loop().stop() - - @mock.patch("airflow.triggers.external_task.utcnow") - @pytest.mark.asyncio - async def test_task_state_trigger_timeout(self, mock_utcnow): - trigger_start_time = utcnow() - mock_utcnow.return_value = trigger_start_time + datetime.timedelta(seconds=61) - - with pytest.warns(RemovedInAirflow3Warning, match="TaskStateTrigger has been deprecated"): - trigger = TaskStateTrigger( - dag_id="dag1", - task_id="task1", - states=self.STATES, - execution_dates=[timezone.datetime(2022, 1, 1)], - poll_interval=0.2, - trigger_start_time=trigger_start_time, - ) - - trigger.count_running_dags = mock.AsyncMock() - trigger.count_running_dags.return_value = 0 - - gen = trigger.run() - task = asyncio.create_task(gen.__anext__()) - await task - - result = task.result() - assert isinstance(result, TriggerEvent) - assert result.payload == {"status": "timeout"} - assert task.done() is True - - # test that it returns after yielding - with pytest.raises(StopAsyncIteration): - await gen.__anext__() - - @mock.patch("airflow.triggers.external_task.utcnow") - @mock.patch("airflow.triggers.external_task.asyncio.sleep") - @pytest.mark.asyncio - async def test_task_state_trigger_timeout_sleep_success(self, mock_sleep, mock_utcnow): - trigger_start_time = utcnow() - mock_utcnow.return_value = trigger_start_time + datetime.timedelta(seconds=20) - - with pytest.warns(RemovedInAirflow3Warning, match="TaskStateTrigger has been deprecated"): - trigger = TaskStateTrigger( - dag_id="dag1", - task_id="task1", - states=self.STATES, - execution_dates=[timezone.datetime(2022, 1, 1)], - poll_interval=0.2, - trigger_start_time=trigger_start_time, - ) - - trigger.count_running_dags = mock.AsyncMock() - trigger.count_running_dags.return_value = 0 - - trigger.count_tasks = mock.AsyncMock() - trigger.count_tasks.return_value = 1 - - gen = trigger.run() - task = asyncio.create_task(gen.__anext__()) - await task - - mock_sleep.assert_awaited() - assert mock_sleep.await_count == 1 - - result = task.result() - assert isinstance(result, TriggerEvent) - assert result.payload == {"status": "success"} - assert task.done() is True - - # test that it returns after yielding - with pytest.raises(StopAsyncIteration): - await gen.__anext__() - - @mock.patch("airflow.triggers.external_task.utcnow") - @mock.patch("airflow.triggers.external_task.asyncio.sleep") - @pytest.mark.asyncio - async def test_task_state_trigger_failed_exception(self, mock_sleep, mock_utcnow): - """ - Asserts that the TaskStateTrigger only goes off on or after a TaskInstance - reaches an allowed state (i.e. SUCCESS). - """ - trigger_start_time = utcnow() - mock_utcnow.return_value = +datetime.timedelta(seconds=61) - - mock_utcnow.side_effect = [ - trigger_start_time, - trigger_start_time + datetime.timedelta(seconds=20), - ] - - with pytest.warns(RemovedInAirflow3Warning, match="TaskStateTrigger has been deprecated"): - trigger = TaskStateTrigger( - dag_id="dag1", - task_id="task1", - states=self.STATES, - execution_dates=[timezone.datetime(2022, 1, 1)], - poll_interval=0.2, - trigger_start_time=trigger_start_time, - ) - - trigger.count_running_dags = mock.AsyncMock() - trigger.count_running_dags.side_effect = [SQLAlchemyError] - - gen = trigger.run() - task = asyncio.create_task(gen.__anext__()) - await task - - result = task.result() - assert isinstance(result, TriggerEvent) - assert result.payload == {"status": "failed"} - assert task.done() is True - - def test_serialization(self): - """ - Asserts that the TaskStateTrigger correctly serializes its arguments - and classpath. - """ - trigger_start_time = utcnow() - with pytest.warns(RemovedInAirflow3Warning, match="TaskStateTrigger has been deprecated"): - trigger = TaskStateTrigger( - dag_id=self.DAG_ID, - task_id=self.TASK_ID, - states=self.STATES, - execution_dates=[timezone.datetime(2022, 1, 1)], - poll_interval=5, - trigger_start_time=trigger_start_time, - ) - classpath, kwargs = trigger.serialize() - assert classpath == "airflow.triggers.external_task.TaskStateTrigger" - assert kwargs == { - "dag_id": self.DAG_ID, - "task_id": self.TASK_ID, - "states": self.STATES, - "execution_dates": [timezone.datetime(2022, 1, 1)], - "poll_interval": 5, - "trigger_start_time": trigger_start_time, - } - - class TestDagStateTrigger: DAG_ID = "test_dag_state_trigger" RUN_ID = "external_task_run_id"