diff --git a/airflow/utils/dag_processing.py b/airflow/utils/dag_processing.py
index 68f178c1ad380..0118661cd233c 100644
--- a/airflow/utils/dag_processing.py
+++ b/airflow/utils/dag_processing.py
@@ -766,13 +766,10 @@ def __init__(self,
         self._last_runtime = {}
         # Map from file path to the last finish time
         self._last_finish_time = {}
-        self._last_zombie_query_time = timezone.utcnow()
         # Last time that the DAG dir was traversed to look for files
         self.last_dag_dir_refresh_time = timezone.utcnow()
         # Last time stats were printed
         self.last_stat_print_time = timezone.datetime(2000, 1, 1)
-        # TODO: Remove magic number
-        self._zombie_query_interval = 10
         # Map from file path to the number of runs
         self._run_count = defaultdict(int)
         # Manager heartbeat key.
@@ -1243,35 +1240,31 @@ def _find_zombies(self, session):
         Find zombie task instances, which are tasks haven't heartbeated for too long.
         :return: Zombie task instances in SimpleTaskInstance format.
         """
-        now = timezone.utcnow()
+        # to avoid circular imports
+        from airflow.jobs import LocalTaskJob as LJ
+        self.log.info("Finding 'running' jobs without a recent heartbeat")
+        TI = airflow.models.TaskInstance
+        limit_dttm = timezone.utcnow() - timedelta(
+            seconds=self._zombie_threshold_secs)
+        self.log.info("Failing jobs without heartbeat after %s", limit_dttm)
+
+        tis = (
+            session.query(TI)
+            .join(LJ, TI.job_id == LJ.id)
+            .filter(TI.state == State.RUNNING)
+            .filter(
+                or_(
+                    LJ.state != State.RUNNING,
+                    LJ.latest_heartbeat < limit_dttm,
+                )
+            ).all()
+        )
         zombies = []
-        if (now - self._last_zombie_query_time).total_seconds() \
-                > self._zombie_query_interval:
-            # to avoid circular imports
-            from airflow.jobs import LocalTaskJob as LJ
-            self.log.info("Finding 'running' jobs without a recent heartbeat")
-            TI = airflow.models.TaskInstance
-            limit_dttm = timezone.utcnow() - timedelta(
-                seconds=self._zombie_threshold_secs)
-            self.log.info("Failing jobs without heartbeat after %s", limit_dttm)
-
-            tis = (
-                session.query(TI)
-                .join(LJ, TI.job_id == LJ.id)
-                .filter(TI.state == State.RUNNING)
-                .filter(
-                    or_(
-                        LJ.state != State.RUNNING,
-                        LJ.latest_heartbeat < limit_dttm,
-                    )
-                ).all()
-            )
-            self._last_zombie_query_time = timezone.utcnow()
-            for ti in tis:
-                sti = SimpleTaskInstance(ti)
-                self.log.info("Detected zombie job with dag_id %s, task_id %s, and execution date %s",
-                              sti.dag_id, sti.task_id, sti.execution_date.isoformat())
-                zombies.append(sti)
+        for ti in tis:
+            sti = SimpleTaskInstance(ti)
+            self.log.info("Detected zombie job with dag_id %s, task_id %s, and execution date %s",
+                          sti.dag_id, sti.task_id, sti.execution_date.isoformat())
+            zombies.append(sti)
         return zombies
diff --git a/tests/utils/test_dag_processing.py b/tests/utils/test_dag_processing.py
index 6ea9c9b3ec573..d37defbf3b971 100644
--- a/tests/utils/test_dag_processing.py
+++ b/tests/utils/test_dag_processing.py
@@ -23,7 +23,6 @@
 import tempfile
 import unittest
 from unittest import mock
-from datetime import timedelta
 from unittest.mock import MagicMock
@@ -203,8 +202,15 @@ def test_find_zombies(self):
         session.add(ti)
         session.commit()
-        manager._last_zombie_query_time = timezone.utcnow() - timedelta(
-            seconds=manager._zombie_threshold_secs + 1)
+        # initial call should return zombies
+        zombies = manager._find_zombies()
+        self.assertEqual(1, len(zombies))
+        self.assertIsInstance(zombies[0], SimpleTaskInstance)
+        self.assertEqual(ti.dag_id, zombies[0].dag_id)
+        self.assertEqual(ti.task_id, zombies[0].task_id)
+        self.assertEqual(ti.execution_date, zombies[0].execution_date)
+
+        # AIRFLOW-4797: repeated call should return zombies again
         zombies = manager._find_zombies()
         self.assertEqual(1, len(zombies))
         self.assertIsInstance(zombies[0], SimpleTaskInstance)