Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 1 addition & 120 deletions airflow/triggers/external_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,12 @@
from typing import Any

from asgiref.sync import sync_to_async
from deprecated import deprecated
from sqlalchemy import func

from airflow.exceptions import RemovedInAirflow3Warning
from airflow.models import DagRun, TaskInstance
from airflow.models import DagRun
from airflow.triggers.base import BaseTrigger, TriggerEvent
from airflow.utils.sensor_helper import _get_count
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.state import TaskInstanceState
from airflow.utils.timezone import utcnow

if typing.TYPE_CHECKING:
from datetime import datetime
Expand Down Expand Up @@ -136,121 +132,6 @@ def _get_count(self, states: typing.Iterable[str] | None) -> int:
)


@deprecated(
    reason="TaskStateTrigger has been deprecated and will be removed in future.",
    category=RemovedInAirflow3Warning,
)
class TaskStateTrigger(BaseTrigger):
    """
    Waits asynchronously for a task in a different DAG to complete for a specific logical date.

    :param dag_id: The dag_id that contains the task you want to wait for
    :param task_id: The task_id that contains the task you want to
        wait for.
    :param states: allowed states, default is ``['success']``
    :param execution_dates: task execution time interval
    :param poll_interval: The time interval in seconds to check the state.
        The default value is 2 sec.
    :param trigger_start_time: time in Datetime format when the trigger was started. Is used
        to control the execution of trigger to prevent infinite loop in case if specified name
        of the dag does not exist in database. Once ``_timeout_sec`` seconds have elapsed since
        this time, the trigger terminates with 'timeout' status.
    """

    def __init__(
        self,
        dag_id: str,
        execution_dates: list[datetime],
        trigger_start_time: datetime,
        states: list[str] | None = None,
        task_id: str | None = None,
        poll_interval: float = 2.0,
    ):
        super().__init__()
        self.dag_id = dag_id
        self.task_id = task_id
        self.execution_dates = execution_dates
        self.poll_interval = poll_interval
        self.trigger_start_time = trigger_start_time
        # An empty or missing ``states`` falls back to waiting for success only.
        self.states = states or [TaskInstanceState.SUCCESS.value]
        # Seconds after ``trigger_start_time`` at which ``run`` gives up with 'timeout'.
        self._timeout_sec = 60

    def serialize(self) -> tuple[str, dict[str, typing.Any]]:
        """Serialize TaskStateTrigger arguments and classpath."""
        return (
            "airflow.triggers.external_task.TaskStateTrigger",
            {
                "dag_id": self.dag_id,
                "task_id": self.task_id,
                "states": self.states,
                "execution_dates": self.execution_dates,
                "poll_interval": self.poll_interval,
                "trigger_start_time": self.trigger_start_time,
            },
        )

    async def run(self) -> typing.AsyncIterator[TriggerEvent]:
        """
        Poll the database until the watched task reaches one of the expected states.

        Each iteration:

        * While fewer than ``_timeout_sec`` seconds have passed since ``trigger_start_time``,
          check whether any matching task instance is running or successful; if none is,
          log and sleep for ``poll_interval`` before continuing.
        * Once ``_timeout_sec`` seconds have passed, yield ``{"status": "timeout"}`` and stop.
          NOTE(review): this branch does not re-check whether the DAG has started, so the
          timeout applies even to a DAG that is already running -- confirm this is intended.
        * If the number of matching task instances equals ``len(execution_dates)``, yield
          ``{"status": "success"}`` and stop; otherwise sleep for ``poll_interval`` and repeat.

        Any exception raised while polling is logged and reported as ``{"status": "failed"}``.
        """
        try:
            while True:
                delta = utcnow() - self.trigger_start_time
                if delta.total_seconds() < self._timeout_sec:
                    # mypy confuses typing here
                    if await self.count_running_dags() == 0:  # type: ignore[call-arg]
                        self.log.info("Waiting for DAG to start execution...")
                        await asyncio.sleep(self.poll_interval)
                else:
                    yield TriggerEvent({"status": "timeout"})
                    return
                # mypy confuses typing here
                if await self.count_tasks() == len(self.execution_dates):  # type: ignore[call-arg]
                    yield TriggerEvent({"status": "success"})
                    return
                self.log.info("Task is still running, sleeping for %s seconds...", self.poll_interval)
                await asyncio.sleep(self.poll_interval)
        except Exception:
            # Record the cause before collapsing it into a generic "failed" event;
            # otherwise the underlying error is lost entirely.
            self.log.exception("TaskStateTrigger failed while polling task state")
            yield TriggerEvent({"status": "failed"})

    @sync_to_async
    @provide_session
    def count_running_dags(self, session: Session = NEW_SESSION) -> int | None:
        """
        Count task instances of the DAG that are in the running or success state.

        Only task instances whose ``dag_id`` matches and whose ``execution_date`` is one of
        ``execution_dates`` are counted; ``task_id`` is not part of the filter.
        """
        dags = (
            session.query(func.count("*"))
            .filter(
                TaskInstance.dag_id == self.dag_id,
                TaskInstance.execution_date.in_(self.execution_dates),
                TaskInstance.state.in_(["running", "success"]),
            )
            .scalar()
        )
        return dags

    @sync_to_async
    @provide_session
    def count_tasks(self, *, session: Session = NEW_SESSION) -> int | None:
        """Count how many task instances in the database match our criteria."""
        count = (
            session.query(func.count("*"))  # .count() is inefficient
            .filter(
                TaskInstance.dag_id == self.dag_id,
                TaskInstance.task_id == self.task_id,
                TaskInstance.state.in_(self.states),
                TaskInstance.execution_date.in_(self.execution_dates),
            )
            .scalar()
        )
        return typing.cast(int, count)


class DagStateTrigger(BaseTrigger):
"""
Waits asynchronously for a DAG to complete for a specific logical date.
Expand Down
1 change: 1 addition & 0 deletions newsfragments/41737.significant.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Removed deprecated ``TaskStateTrigger`` from ``airflow.triggers.external_task`` module.
201 changes: 2 additions & 199 deletions tests/triggers/test_external_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,23 +17,17 @@
from __future__ import annotations

import asyncio
import datetime
import time
from unittest import mock

import pytest
from sqlalchemy.exc import SQLAlchemyError

from airflow.exceptions import RemovedInAirflow3Warning
from airflow.models.dag import DAG
from airflow.models.dagrun import DagRun
from airflow.models.taskinstance import TaskInstance
from airflow.operators.empty import EmptyOperator
from airflow.triggers.base import TriggerEvent
from airflow.triggers.external_task import DagStateTrigger, TaskStateTrigger, WorkflowTrigger
from airflow.triggers.external_task import DagStateTrigger, WorkflowTrigger
from airflow.utils import timezone
from airflow.utils.state import DagRunState, TaskInstanceState
from airflow.utils.timezone import utcnow
from airflow.utils.state import DagRunState


class TestWorkflowTrigger:
Expand Down Expand Up @@ -222,197 +216,6 @@ def test_serialization(self):
}


class TestTaskStateTrigger:
    """Tests for the deprecated ``TaskStateTrigger``."""

    DAG_ID = "external_task"
    TASK_ID = "external_task_op"
    RUN_ID = "external_task_run_id"
    STATES = ["success", "fail"]

    @pytest.mark.skip_if_database_isolation_mode  # Test is broken in db isolation mode
    @pytest.mark.db_test
    @pytest.mark.asyncio
    async def test_task_state_trigger_success(self, session):
        """
        Asserts that the TaskStateTrigger only goes off on or after a TaskInstance
        reaches an allowed state (i.e. SUCCESS).
        """
        trigger_start_time = utcnow()
        dag = DAG(self.DAG_ID, schedule=None, start_date=timezone.datetime(2022, 1, 1))
        dag_run = DagRun(
            dag_id=dag.dag_id,
            run_type="manual",
            execution_date=timezone.datetime(2022, 1, 1),
            run_id=self.RUN_ID,
        )
        session.add(dag_run)
        session.commit()

        external_task = EmptyOperator(task_id=self.TASK_ID, dag=dag)
        instance = TaskInstance(external_task, run_id=self.RUN_ID)
        session.add(instance)
        session.commit()

        with pytest.warns(RemovedInAirflow3Warning, match="TaskStateTrigger has been deprecated"):
            trigger = TaskStateTrigger(
                dag_id=dag.dag_id,
                task_id=instance.task_id,
                states=self.STATES,
                execution_dates=[timezone.datetime(2022, 1, 1)],
                poll_interval=0.2,
                trigger_start_time=trigger_start_time,
            )

        task = asyncio.create_task(trigger.run().__anext__())
        await asyncio.sleep(0.5)

        # It should not have produced a result
        assert task.done() is False

        # Progress the task to a "success" state so that run() yields a TriggerEvent
        instance.state = TaskInstanceState.SUCCESS
        session.commit()
        await asyncio.sleep(0.5)
        assert task.done() is True

        # Prevents error when task is destroyed while in "pending" state
        asyncio.get_event_loop().stop()

    @mock.patch("airflow.triggers.external_task.utcnow")
    @pytest.mark.asyncio
    async def test_task_state_trigger_timeout(self, mock_utcnow):
        """Asserts that the trigger yields a 'timeout' event once more than 60 seconds have elapsed."""
        trigger_start_time = utcnow()
        mock_utcnow.return_value = trigger_start_time + datetime.timedelta(seconds=61)

        with pytest.warns(RemovedInAirflow3Warning, match="TaskStateTrigger has been deprecated"):
            trigger = TaskStateTrigger(
                dag_id="dag1",
                task_id="task1",
                states=self.STATES,
                execution_dates=[timezone.datetime(2022, 1, 1)],
                poll_interval=0.2,
                trigger_start_time=trigger_start_time,
            )

        trigger.count_running_dags = mock.AsyncMock()
        trigger.count_running_dags.return_value = 0

        gen = trigger.run()
        task = asyncio.create_task(gen.__anext__())
        await task

        result = task.result()
        assert isinstance(result, TriggerEvent)
        assert result.payload == {"status": "timeout"}
        assert task.done() is True

        # test that it returns after yielding
        with pytest.raises(StopAsyncIteration):
            await gen.__anext__()

    @mock.patch("airflow.triggers.external_task.utcnow")
    @mock.patch("airflow.triggers.external_task.asyncio.sleep")
    @pytest.mark.asyncio
    async def test_task_state_trigger_timeout_sleep_success(self, mock_sleep, mock_utcnow):
        """
        Asserts that, before the timeout, the trigger sleeps once while no DAG is running
        and then yields 'success' when the task count matches.
        """
        trigger_start_time = utcnow()
        mock_utcnow.return_value = trigger_start_time + datetime.timedelta(seconds=20)

        with pytest.warns(RemovedInAirflow3Warning, match="TaskStateTrigger has been deprecated"):
            trigger = TaskStateTrigger(
                dag_id="dag1",
                task_id="task1",
                states=self.STATES,
                execution_dates=[timezone.datetime(2022, 1, 1)],
                poll_interval=0.2,
                trigger_start_time=trigger_start_time,
            )

        trigger.count_running_dags = mock.AsyncMock()
        trigger.count_running_dags.return_value = 0

        trigger.count_tasks = mock.AsyncMock()
        trigger.count_tasks.return_value = 1

        gen = trigger.run()
        task = asyncio.create_task(gen.__anext__())
        await task

        mock_sleep.assert_awaited()
        assert mock_sleep.await_count == 1

        result = task.result()
        assert isinstance(result, TriggerEvent)
        assert result.payload == {"status": "success"}
        assert task.done() is True

        # test that it returns after yielding
        with pytest.raises(StopAsyncIteration):
            await gen.__anext__()

    @mock.patch("airflow.triggers.external_task.utcnow")
    @mock.patch("airflow.triggers.external_task.asyncio.sleep")
    @pytest.mark.asyncio
    async def test_task_state_trigger_failed_exception(self, mock_sleep, mock_utcnow):
        """
        Asserts that the TaskStateTrigger yields a 'failed' event when querying the
        database raises an exception.
        """
        trigger_start_time = utcnow()
        # Stay inside the 60 second window so run() reaches count_running_dags().
        # A plain return_value (rather than a finite side_effect list) cannot run out
        # and raise StopIteration, which the trigger would also report as "failed".
        mock_utcnow.return_value = trigger_start_time + datetime.timedelta(seconds=20)

        with pytest.warns(RemovedInAirflow3Warning, match="TaskStateTrigger has been deprecated"):
            trigger = TaskStateTrigger(
                dag_id="dag1",
                task_id="task1",
                states=self.STATES,
                execution_dates=[timezone.datetime(2022, 1, 1)],
                poll_interval=0.2,
                trigger_start_time=trigger_start_time,
            )

        trigger.count_running_dags = mock.AsyncMock()
        trigger.count_running_dags.side_effect = [SQLAlchemyError]

        gen = trigger.run()
        task = asyncio.create_task(gen.__anext__())
        await task

        result = task.result()
        assert isinstance(result, TriggerEvent)
        assert result.payload == {"status": "failed"}
        assert task.done() is True

    def test_serialization(self):
        """
        Asserts that the TaskStateTrigger correctly serializes its arguments
        and classpath.
        """
        trigger_start_time = utcnow()
        with pytest.warns(RemovedInAirflow3Warning, match="TaskStateTrigger has been deprecated"):
            trigger = TaskStateTrigger(
                dag_id=self.DAG_ID,
                task_id=self.TASK_ID,
                states=self.STATES,
                execution_dates=[timezone.datetime(2022, 1, 1)],
                poll_interval=5,
                trigger_start_time=trigger_start_time,
            )
        classpath, kwargs = trigger.serialize()
        assert classpath == "airflow.triggers.external_task.TaskStateTrigger"
        assert kwargs == {
            "dag_id": self.DAG_ID,
            "task_id": self.TASK_ID,
            "states": self.STATES,
            "execution_dates": [timezone.datetime(2022, 1, 1)],
            "poll_interval": 5,
            "trigger_start_time": trigger_start_time,
        }


class TestDagStateTrigger:
DAG_ID = "test_dag_state_trigger"
RUN_ID = "external_task_run_id"
Expand Down