6 changes: 4 additions & 2 deletions airflow/executors/base_executor.py
@@ -22,7 +22,7 @@
import sys
from collections import defaultdict, deque
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, List, Optional, Sequence, Tuple
from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Sequence, Tuple

import pendulum

@@ -540,7 +540,9 @@ def terminate(self):
"""Get called when the daemon receives a SIGTERM."""
raise NotImplementedError

def cleanup_stuck_queued_tasks(self, tis: list[TaskInstance]) -> list[str]: # pragma: no cover
def cleanup_stuck_queued_tasks(
self, tis: list[TaskInstance]
) -> Iterable[TaskInstance]: # pragma: no cover
"""
Handle remnants of tasks that were failed because they were stuck in queued.

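The interface change above means an executor no longer returns a list of repr() strings destined for a warning message; it yields the actual TaskInstance objects it cleaned up, and the scheduler decides per-TI whether to requeue or fail. A minimal sketch of a conforming override (the executor class and its `_revoke` helper are hypothetical, not from this PR):

```python
from __future__ import annotations

from typing import Iterable

from airflow.executors.base_executor import BaseExecutor
from airflow.models.taskinstance import TaskInstance, TaskInstanceKey


class MyQueueExecutor(BaseExecutor):
    """Hypothetical executor illustrating the new contract."""

    def _revoke(self, key: TaskInstanceKey) -> None:
        # Stand-in for backend-specific cancellation, e.g. revoking a
        # broker message or deleting a worker pod.
        ...

    def cleanup_stuck_queued_tasks(self, tis: list[TaskInstance]) -> Iterable[TaskInstance]:
        for ti in tis:
            # Yield each TI that was actually handled; the scheduler
            # (not the executor) then requeues or fails it.
            yield ti
            self._revoke(ti.key)
```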
131 changes: 71 additions & 60 deletions airflow/jobs/scheduler_job_runner.py
@@ -98,7 +98,8 @@
DR = DagRun
DM = DagModel

RESCHEDULE_STUCK_IN_QUEUED_EVENT = "rescheduling stuck in queued"
STUCK_IN_QUEUED_EVENT = "stuck in queued"
@dstandish (Contributor Author) commented:
One reason I removed the "rescheduling" part is that at the point where you log this, you don't know that it's reschedulable -- you only know that further down.

""":meta private:"""


class ConcurrencyMap:
@@ -1790,7 +1791,7 @@ def _handle_tasks_stuck_in_queued(self, session: Session = NEW_SESSION) -> None:

As a compromise between always failing a stuck task and always rescheduling a stuck task (which could
lead to tasks being stuck in queued forever without informing the user), we have created the config
`[core] num_stuck_reschedules`. With this new configuration, an airflow admin can decide how
``[scheduler] num_stuck_in_queued_retries``. With this new configuration, an airflow admin can decide how
@dstandish commented on Nov 4, 2024:
it's a scheduler setting not core, and more a retry than a reschedule

sensitive they would like their airflow to be WRT failing stuck tasks.
"""
self.log.debug("Calling SchedulerJob._fail_tasks_stuck_in_queued method")
@@ -1803,65 +1804,73 @@ def _handle_tasks_stuck_in_queued(self, session: Session = NEW_SESSION) -> None:
)
).all()

num_allowed_retries = conf.getint("core", "num_stuck_reschedules")
num_allowed_retries = conf.getint("scheduler", "num_stuck_in_queued_retries")
for executor, stuck_tis in self._executor_to_tis(tasks_stuck_in_queued).items():
try:
cleaned_up_task_instances = set(executor.cleanup_stuck_queued_tasks(tis=stuck_tis))
for ti in stuck_tis:
if repr(ti) in cleaned_up_task_instances:
num_times_stuck = self._get_num_times_stuck_in_queued(ti, session)
if num_times_stuck < num_allowed_retries:
self.log.warning(
"Task %s was stuck in queued and will be requeued, once it has hit %s attempts"
" the task will be marked as failed. After that, if the task instance has "
"available retries, it will be retried.", ti.key, num_allowed_retries
)
session.add(
Log(
event=RESCHEDULE_STUCK_IN_QUEUED_EVENT,
task_instance=ti.key,
extra=(
f"Task was stuck in queued and will be requeued, once it has hit {num_allowed_retries} attempts"
"Task will be marked as failed. After that, if the task instance has "
"available retries, it will be retried."
),
)
)
if not hasattr(executor, "cleanup_stuck_queued_tasks"):
continue

executor.change_state(ti.key, State.SCHEDULED)
session.execute(
update(TI)
.where(TI.filter_for_tis([ti]))
.values(
# TODO[ha]: should we use func.now()? How does that work with DB timezone
# on mysql when it's not UTC?
state=TaskInstanceState.SCHEDULED,
queued_dttm=None,
# queued_by_job_id=None,
)
.execution_options(synchronize_session=False)
)
else:
self.log.warning(
"Marking task instance %s stuck in queued as failed. "
"If the task instance has available retries, it will be retried.",
ti,
)
session.add(
Log(
event="failing stuck in queued",
task_instance=ti.key,
extra=(
"Task will be marked as failed. If the task instance has "
"available retries, it will be retried."
),
)
)
executor.fail(ti.key)
for ti in executor.cleanup_stuck_queued_tasks(tis=stuck_tis):
@dstandish (Contributor Author) commented:
The thing that still bothers me, @dimberman, is that it doesn't feel right that we defer to the executor and only conditionally log if it "cleans up" the ti. We have already observed that it is stuck in queued, so why not log that?

I guess the problem is we are logging the wrong event. The event is not that it is "stuck in queued" (which is an unconditional observation) but rather that it was requeued. That's the thing that conditionally happens.

if not isinstance(ti, TaskInstance):
# this is for backcompat. the pre-2.10.4 version of the interface
# expected a string return val.
self.log.warning(
"Marking task instance %s stuck in queued as failed. "
"If the task instance has available retries, it will be retried.",
ti,
)
continue

session.add(
Log(
event=STUCK_IN_QUEUED_EVENT,
task_instance=ti.key,
extra=(
"Task was in queued state for longer "
f"than {self._task_queued_timeout} seconds."
),
)
)
self.log.warning("Task stuck in queued and may be requeued task_id=%s", ti.key)

num_times_stuck = self._get_num_times_stuck_in_queued(ti, session)
if num_times_stuck < num_allowed_retries:
session.add(
Log(
event=STUCK_IN_QUEUED_EVENT,
task_instance=ti.key,
extra=(
f"Task was stuck in queued and will be requeued, once it has hit {num_allowed_retries} attempts"
"Task will be marked as failed. After that, if the task instance has "
"available retries, it will be retried."
),
)
)

except NotImplementedError:
self.log.debug("Executor doesn't support cleanup of stuck queued tasks. Skipping.")
executor.change_state(ti.key, State.SCHEDULED)
session.execute(
update(TI)
.where(TI.filter_for_tis([ti]))
.values(
state=TaskInstanceState.SCHEDULED,
queued_dttm=None,
)
.execution_options(synchronize_session=False)
)
else:
self.log.warning(
"Task requeue attempts exceeded max; marking failed. task_instance=%s", ti
)
session.add(
Log(
event="stuck in queued tries exceeded",
task_instance=ti.key,
extra=(
f"Task was requeued more than {num_allowed_retries} times "
"and will be failed."
),
)
)
executor.fail(ti.key)
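As the docstring above notes, the retry budget is now an admin-facing scheduler option. A quick sketch of reading it through the standard config API (the fallback values here are assumptions for illustration, not taken from this PR):

```python
from airflow.configuration import conf

# How many times a stuck-in-queued TI may be requeued before being failed.
num_allowed_retries = conf.getint("scheduler", "num_stuck_in_queued_retries", fallback=2)

# How long a TI may sit in QUEUED before the scheduler treats it as stuck.
task_queued_timeout = conf.getfloat("scheduler", "task_queued_timeout", fallback=600.0)

print(f"requeue up to {num_allowed_retries}x after {task_queued_timeout}s in queued")
```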

@provide_session
def _get_num_times_stuck_in_queued(self, ti: TaskInstance, session: Session = NEW_SESSION) -> int:
@@ -1871,14 +1880,16 @@ def _get_num_times_stuck_in_queued(self, ti: TaskInstance, session: Session = NEW_SESSION) -> int:
We can then use this information to determine whether to reschedule a task or fail it.
"""
return (
session.query(Log).where(
session.query(Log)
.where(
Log.task_id == ti.task_id,
Log.dag_id == ti.dag_id,
Log.run_id == ti.run_id,
Log.map_index == ti.map_index,
Log.try_number == ti.try_number,
Log.event == RESCHEDULE_STUCK_IN_QUEUED_EVENT,
).count()
Log.event == STUCK_IN_QUEUED_EVENT,
)
.count()
)
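Both the event written in `_handle_tasks_stuck_in_queued` and this counter key on the same `STUCK_IN_QUEUED_EVENT` string, scoped to one specific try of one specific TI. A condensed sketch of that round trip, mirroring the code above (the helper name is mine, not from this PR):

```python
from airflow.models import Log

STUCK_IN_QUEUED_EVENT = "stuck in queued"

def record_stuck_event_and_count(ti, session) -> int:
    # Record one observation for this try, then count all observations
    # for the same (dag, task, run, map_index, try_number).
    session.add(Log(event=STUCK_IN_QUEUED_EVENT, task_instance=ti.key))
    session.flush()  # make the new row visible to the query below
    return (
        session.query(Log)
        .where(
            Log.task_id == ti.task_id,
            Log.dag_id == ti.dag_id,
            Log.run_id == ti.run_id,
            Log.map_index == ti.map_index,
            Log.try_number == ti.try_number,
            Log.event == STUCK_IN_QUEUED_EVENT,
        )
        .count()
    )
```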

@provide_session
airflow/providers/celery/executors/celery_executor.py
@@ -32,7 +32,7 @@
from collections import Counter
from concurrent.futures import ProcessPoolExecutor
from multiprocessing import cpu_count
from typing import TYPE_CHECKING, Any, Optional, Sequence, Tuple
from typing import TYPE_CHECKING, Any, Generator, Optional, Sequence, Tuple

from celery import states as celery_states
from packaging.version import Version
@@ -433,7 +433,7 @@ def try_adopt_task_instances(self, tis: Sequence[TaskInstance]) -> Sequence[TaskInstance]:

return not_adopted_tis

def cleanup_stuck_queued_tasks(self, tis: list[TaskInstance]) -> list[str]:
def cleanup_stuck_queued_tasks(self, tis: list[TaskInstance]) -> Generator[TaskInstance, None, None]:
"""
Handle remnants of tasks that were failed because they were stuck in queued.

@@ -442,13 +442,11 @@ def cleanup_stuck_queued_tasks(self, tis: list[TaskInstance]) -> list[str]:
if it doesn't.

:param tis: List of Task Instances to clean up
:return: List of readable task instances for a warning message
"""
readable_tis = []
from airflow.providers.celery.executors.celery_executor_utils import app

for ti in tis:
readable_tis.append(repr(ti))
yield ti
task_instance_key = ti.key
if Version(airflow_version) < Version("2.10.4"):
self.fail(task_instance_key)
@@ -458,7 +456,6 @@ def cleanup_stuck_queued_tasks(self, tis: list[TaskInstance]) -> list[str]:
app.control.revoke(celery_async_result.task_id)
except Exception as ex:
self.log.error("Error revoking task instance %s from celery: %s", task_instance_key, ex)
return readable_tis

@staticmethod
def get_cli_commands() -> list[GroupCommand]:
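A subtlety of the new generator-based contract, visible in the Celery implementation above: the revocations are lazy, and each TI is yielded *before* its revoke runs. A toy illustration of the resulting execution order (names hypothetical):

```python
def revoke(ti):
    # Stand-in for the real broker revocation.
    print(f"revoked {ti}")

def cleanup_stuck_queued_tasks(tis):
    for ti in tis:
        yield ti    # caller sees the TI first...
        revoke(ti)  # ...the side effect runs on the next advance

gen = cleanup_stuck_queued_tasks(["ti-1", "ti-2"])  # nothing revoked yet
first = next(gen)  # yields "ti-1"; revoke("ti-1") has NOT run yet
rest = list(gen)   # runs revoke("ti-1"), yields "ti-2", runs revoke("ti-2")
```

If the consuming loop stops early, pending revokes never execute, so the scheduler iterating the generator to completion is effectively part of the contract.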
airflow/providers/celery/executors/celery_kubernetes_executor.py
@@ -18,7 +18,7 @@
from __future__ import annotations

from functools import cached_property
from typing import TYPE_CHECKING, Sequence
from typing import TYPE_CHECKING, Generator, Sequence

from airflow.configuration import conf
from airflow.executors.base_executor import BaseExecutor
@@ -246,13 +246,11 @@ def try_adopt_task_instances(self, tis: Sequence[TaskInstance]) -> Sequence[TaskInstance]:
*self.kubernetes_executor.try_adopt_task_instances(kubernetes_tis),
]

def cleanup_stuck_queued_tasks(self, tis: list[TaskInstance]) -> list[str]:
def cleanup_stuck_queued_tasks(self, tis: list[TaskInstance]) -> Generator[TaskInstance, None, None]:
celery_tis = [ti for ti in tis if ti.queue != self.kubernetes_queue]
kubernetes_tis = [ti for ti in tis if ti.queue == self.kubernetes_queue]
return [
*self.celery_executor.cleanup_stuck_queued_tasks(celery_tis),
*self.kubernetes_executor.cleanup_stuck_queued_tasks(kubernetes_tis),
]
yield from self.celery_executor.cleanup_stuck_queued_tasks(celery_tis)
yield from self.kubernetes_executor.cleanup_stuck_queued_tasks(kubernetes_tis)

def end(self) -> None:
"""End celery and kubernetes executor."""
airflow/providers/cncf/kubernetes/executors/kubernetes_executor.py
@@ -33,7 +33,7 @@
from contextlib import suppress
from datetime import datetime
from queue import Empty, Queue
from typing import TYPE_CHECKING, Any, Sequence
from typing import TYPE_CHECKING, Any, Generator, Sequence

from kubernetes.dynamic import DynamicClient
from packaging.version import Version
@@ -607,7 +607,7 @@ def _iter_tis_to_flush():
tis_to_flush.extend(_iter_tis_to_flush())
return tis_to_flush

def cleanup_stuck_queued_tasks(self, tis: list[TaskInstance]) -> list[str]:
def cleanup_stuck_queued_tasks(self, tis: list[TaskInstance]) -> Generator[TaskInstance, None, None]:
"""
Handle remnants of tasks that were failed because they were stuck in queued.

@@ -621,9 +621,6 @@ def cleanup_stuck_queued_tasks(self, tis: list[TaskInstance]) -> list[str]:
if TYPE_CHECKING:
assert self.kube_client
assert self.kube_scheduler
readable_tis: list[str] = []
if not tis:
return readable_tis
pod_combined_search_str_to_pod_map = self.get_pod_combined_search_str_to_pod_map()
for ti in tis:
# Build the pod selector
@@ -637,13 +634,17 @@ def cleanup_stuck_queued_tasks(self, tis: list[TaskInstance]) -> list[str]:
if not pod:
self.log.warning("Cannot find pod for ti %s", ti)
continue
<<<<<<< HEAD
readable_tis.append(repr(ti))
if Version(airflow_version) >= Version("2.10.4"):
=======
if Version(airflow_version) < Version("2.10.4"):
>>>>>>> d6d1caa641 (Simplify the handle stuck in queued interface)
self.kube_scheduler.patch_pod_delete_stuck(
pod_name=pod.metadata.name, namespace=pod.metadata.namespace
)
yield ti
self.kube_scheduler.delete_pod(pod_name=pod.metadata.name, namespace=pod.metadata.namespace)
return readable_tis

def adopt_launched_task(
self,
airflow/providers/cncf/kubernetes/executors/local_kubernetes_executor.py
@@ -17,7 +17,7 @@
# under the License.
from __future__ import annotations

from typing import TYPE_CHECKING, Sequence
from typing import TYPE_CHECKING, Generator, Sequence

from airflow.configuration import conf
from airflow.executors.base_executor import BaseExecutor
@@ -230,11 +230,11 @@ def try_adopt_task_instances(self, tis: Sequence[TaskInstance]) -> Sequence[TaskInstance]:
*self.kubernetes_executor.try_adopt_task_instances(kubernetes_tis),
]

def cleanup_stuck_queued_tasks(self, tis: list[TaskInstance]) -> list[str]:
def cleanup_stuck_queued_tasks(self, tis: list[TaskInstance]) -> Generator[TaskInstance, None, None]:
# LocalExecutor doesn't have a cleanup_stuck_queued_tasks method, so we
# will only run KubernetesExecutor's
kubernetes_tis = [ti for ti in tis if ti.queue == self.KUBERNETES_QUEUE]
return self.kubernetes_executor.cleanup_stuck_queued_tasks(kubernetes_tis)
yield from self.kubernetes_executor.cleanup_stuck_queued_tasks(kubernetes_tis)

def end(self) -> None:
"""End local and kubernetes executor."""