airflow/executors/base_executor.py (20 changes: 18 additions & 2 deletions)
@@ -23,6 +23,7 @@
import warnings
from collections import defaultdict
from dataclasses import dataclass, field
+from functools import cached_property
from typing import TYPE_CHECKING, Any, List, Optional, Sequence, Tuple

import pendulum
@@ -32,6 +33,7 @@
from airflow.exceptions import RemovedInAirflow3Warning
from airflow.stats import Stats
from airflow.utils.log.logging_mixin import LoggingMixin
+from airflow.utils.log.task_context_logger import TaskContextLogger
from airflow.utils.state import TaskInstanceState

PARALLELISM: int = conf.getint("core", "PARALLELISM")
@@ -284,8 +286,12 @@ def trigger_tasks(self, open_slots: int) -> None:
self.log.info("queued but still running; attempt=%s task=%s", attempt.total_tries, key)
continue
# Otherwise, we give up and remove the task from the queue.
-                self.log.error(
-                    "could not queue task %s (still running after %d attempts)", key, attempt.total_tries
+                self.send_message_to_task_logs(
+                    logging.ERROR,
+                    "could not queue task %s (still running after %d attempts).",
+                    key,
+                    attempt.total_tries,
+                    ti=ti,
                )
del self.attempts[key]
del self.queued_tasks[key]
@@ -512,6 +518,16 @@ def send_callback(self, request: CallbackRequest) -> None:
raise ValueError("Callback sink is not ready.")
self.callback_sink.send(request)

+    @cached_property
+    def _task_context_logger(self) -> TaskContextLogger:
+        return TaskContextLogger(
+            component_name="Executor",
+            call_site_logger=self.log,
+        )
+
+    def send_message_to_task_logs(self, level: int, msg: str, *args, ti: TaskInstance | TaskInstanceKey):
+        self._task_context_logger._log(level, msg, *args, ti=ti)

@staticmethod
def get_cli_commands() -> list[GroupCommand]:
"""
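Note: a minimal sketch of how an executor subclass might use the new hook. Everything here except send_message_to_task_logs itself is hypothetical (MyExecutor and submit_to_backend are made up for illustration). Passing the TaskInstanceKey as ti works because the new signature accepts either a key or a full task instance.

import logging

from airflow.executors.base_executor import BaseExecutor


def submit_to_backend(key, command):
    """Hypothetical submission helper; stands in for a real backend API call."""
    raise RuntimeError("backend unavailable")


class MyExecutor(BaseExecutor):
    """Hypothetical executor, shown only to illustrate the new logging hook."""

    def execute_async(self, key, command, queue=None, executor_config=None):
        try:
            submit_to_backend(key, command)
        except Exception:
            # Surface the failure in the task instance's own log, where users
            # actually look, rather than only in the executor/scheduler log.
            self.send_message_to_task_logs(
                logging.ERROR, "could not submit task %s", key, ti=key
            )
            raise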
airflow/providers/amazon/aws/executors/ecs/ecs_executor.py (37 changes: 22 additions & 15 deletions)
@@ -23,6 +23,7 @@

from __future__ import annotations

+import logging
import time
from collections import defaultdict, deque
from copy import deepcopy
@@ -347,7 +348,7 @@ def attempt_task_runs(self):
queue = ecs_task.queue
exec_config = ecs_task.executor_config
attempt_number = ecs_task.attempt_number
-            _failure_reasons = []
+            failure_reasons = []
if timezone.utcnow() < ecs_task.next_attempt_time:
self.pending_tasks.append(ecs_task)
continue
@@ -361,23 +362,21 @@
if error_code in INVALID_CREDENTIALS_EXCEPTIONS:
self.pending_tasks.append(ecs_task)
raise
-                    _failure_reasons.append(str(e))
+                    failure_reasons.append(str(e))
except Exception as e:
# Failed to even get a response back from the Boto3 API or something else went
# wrong. For any possible failure we want to add the exception reasons to the
# failure list so that it is logged to the user and most importantly the task is
# added back to the pending list to be retried later.
-                _failure_reasons.append(str(e))
+                failure_reasons.append(str(e))
else:
# We got a response back, check if there were failures. If so, add them to the
# failures list so that it is logged to the user and most importantly the task
# is added back to the pending list to be retried later.
if run_task_response["failures"]:
-                    _failure_reasons.extend([f["reason"] for f in run_task_response["failures"]])
+                    failure_reasons.extend([f["reason"] for f in run_task_response["failures"]])

-            if _failure_reasons:
-                for reason in _failure_reasons:
-                    failure_reasons[reason] += 1
+            if failure_reasons:
# Make sure the number of attempts does not exceed MAX_RUN_TASK_ATTEMPTS
if int(attempt_number) < int(self.__class__.MAX_RUN_TASK_ATTEMPTS):
ecs_task.attempt_number += 1
@@ -386,14 +385,19 @@
)
self.pending_tasks.append(ecs_task)
else:
-                    self.log.error(
-                        "ECS task %s has failed a maximum of %s times. Marking as failed",
+                    self.send_message_to_task_logs(
+                        logging.ERROR,
+                        "ECS task %s has failed a maximum of %s times. Marking as failed. Reasons: %s",
                         task_key,
                         attempt_number,
+                        ", ".join(failure_reasons),
+                        ti=task_key,
                     )
self.fail(task_key)
elif not run_task_response["tasks"]:
-                self.log.error("ECS RunTask Response: %s", run_task_response)
+                self.send_message_to_task_logs(
+                    logging.ERROR, "ECS RunTask Response: %s", run_task_response, ti=task_key
+                )
raise EcsExecutorException(
"No failures and no ECS tasks provided in response. This should never happen."
)
@@ -407,11 +411,6 @@
# executor feature).
# TODO: remove when min airflow version >= 2.9.2
pass
-        if failure_reasons:
-            self.log.error(
-                "Pending ECS tasks failed to launch for the following reasons: %s. Retrying later.",
-                dict(failure_reasons),
-            )

def _run_task(
self, task_id: TaskInstanceKey, cmd: CommandType, queue: str, exec_config: ExecutorConfigType
@@ -543,3 +542,11 @@ def try_adopt_task_instances(self, tis: Sequence[TaskInstance]) -> Sequence[Task

not_adopted_tis = [ti for ti in tis if ti not in adopted_tis]
return not_adopted_tis

+    def send_message_to_task_logs(self, level: int, msg: str, *args, ti: TaskInstance | TaskInstanceKey):
+        # TODO: remove this method when min_airflow_version is set to higher than 2.10.0
+        try:
+            super().send_message_to_task_logs(level, msg, *args, ti=ti)
+        except AttributeError:
+            # ``send_message_to_task_logs`` is added in 2.10.0
+            self.log.error(msg, *args)
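The override above is a thin compatibility shim: on Airflow releases older than 2.10.0 the base executor has no send_message_to_task_logs, so the super() call raises AttributeError and the message falls back to the executor's own log. A hedged sketch of the same idea as a reusable mixin follows; the mixin is illustrative, not part of this PR, and it preserves the requested level instead of forcing ERROR as the shim above does.

import logging

from airflow.utils.log.logging_mixin import LoggingMixin


class TaskLogsCompatMixin(LoggingMixin):
    """Illustrative only: degrades gracefully where the base class
    predates the task-context logging API."""

    def send_message_to_task_logs(self, level: int, msg: str, *args, ti):
        try:
            super().send_message_to_task_logs(level, msg, *args, ti=ti)
        except AttributeError:
            # Pre-2.10.0: keep the message visible in the component log,
            # at the level the caller asked for.
            self.log.log(level, msg, *args)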
airflow/utils/log/file_task_handler.py (9 changes: 4 additions & 5 deletions)
@@ -47,7 +47,8 @@
from pendulum import DateTime

from airflow.models import DagRun
-    from airflow.models.taskinstance import TaskInstance, TaskInstanceKey
+    from airflow.models.taskinstance import TaskInstance
+    from airflow.models.taskinstancekey import TaskInstanceKey
from airflow.serialization.pydantic.dag_run import DagRunPydantic
from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic

@@ -265,7 +266,7 @@ def close(self):
@internal_api_call
@provide_session
def _render_filename_db_access(
-        *, ti, try_number: int, session=None
+        *, ti: TaskInstance | TaskInstancePydantic, try_number: int, session=None
) -> tuple[DagRun | DagRunPydantic, TaskInstance | TaskInstancePydantic, str | None, str | None]:
ti = _ensure_ti(ti, session)
dag_run = ti.get_dagrun(session=session)
@@ -281,9 +282,7 @@ def _render_filename_db_access(
filename = render_template_to_string(jinja_tpl, context)
return dag_run, ti, str_tpl, filename

-    def _render_filename(
-        self, ti: TaskInstance | TaskInstanceKey | TaskInstancePydantic, try_number: int
-    ) -> str:
+    def _render_filename(self, ti: TaskInstance | TaskInstancePydantic, try_number: int) -> str:
"""Return the worker log filename."""
dag_run, ti, str_tpl, filename = self._render_filename_db_access(ti=ti, try_number=try_number)
if filename:
airflow/utils/log/task_context_logger.py (28 changes: 17 additions & 11 deletions)
@@ -24,6 +24,9 @@
from typing import TYPE_CHECKING

from airflow.configuration import conf
+from airflow.models.taskinstancekey import TaskInstanceKey
+from airflow.utils.log.file_task_handler import _ensure_ti
+from airflow.utils.session import create_session

if TYPE_CHECKING:
from airflow.models.taskinstance import TaskInstance
@@ -57,7 +60,7 @@ def __init__(self, component_name: str, call_site_logger: Logger | None = None):
def _should_enable(self) -> bool:
if not conf.getboolean("logging", "enable_task_context_logger"):
return False
-        if not getattr(self.task_handler, "supports_task_context_logging", False):
+        if not self.task_handler:
logger.warning("Task handler does not support task context logging")
return False
logger.info("Task context logging is enabled")
@@ -78,13 +81,13 @@ def _get_task_handler() -> FileTaskHandler | None:
assert isinstance(h, FileTaskHandler)
return h

-    def _log(self, level: int, msg: str, *args, ti: TaskInstance):
+    def _log(self, level: int, msg: str, *args, ti: TaskInstance | TaskInstanceKey):
"""
Emit a log message to the task instance logs.

:param level: the log level
:param msg: the message to relay to task context log
-        :param ti: the task instance
+        :param ti: the task instance or the task instance key
"""
if self.call_site_logger and self.call_site_logger.isEnabledFor(level=level):
with suppress(Exception):
@@ -98,6 +101,9 @@ def _log(self, level: int, msg: str, *args, ti: TaskInstance):

task_handler = copy(self.task_handler)
try:
+            if isinstance(ti, TaskInstanceKey):
+                with create_session() as session:
+                    ti = _ensure_ti(ti, session)
task_handler.set_context(ti, identifier=self.component_name)
if hasattr(task_handler, "mark_end_on_close"):
task_handler.mark_end_on_close = False
@@ -109,7 +115,7 @@ def _log(self, level: int, msg: str, *args, ti: TaskInstance):
finally:
task_handler.close()

-    def critical(self, msg: str, *args, ti: TaskInstance):
+    def critical(self, msg: str, *args, ti: TaskInstance | TaskInstanceKey):
"""
Emit a log message with level CRITICAL to the task instance logs.

@@ -118,7 +124,7 @@ def critical(self, msg: str, *args, ti: TaskInstance):
"""
self._log(logging.CRITICAL, msg, *args, ti=ti)

-    def fatal(self, msg: str, *args, ti: TaskInstance):
+    def fatal(self, msg: str, *args, ti: TaskInstance | TaskInstanceKey):
"""
Emit a log message with level FATAL to the task instance logs.

@@ -127,7 +133,7 @@ def fatal(self, msg: str, *args, ti: TaskInstance):
"""
self._log(logging.FATAL, msg, *args, ti=ti)

-    def error(self, msg: str, *args, ti: TaskInstance):
+    def error(self, msg: str, *args, ti: TaskInstance | TaskInstanceKey):
"""
Emit a log message with level ERROR to the task instance logs.

@@ -136,7 +142,7 @@ def error(self, msg: str, *args, ti: TaskInstance):
"""
self._log(logging.ERROR, msg, *args, ti=ti)

-    def warn(self, msg: str, *args, ti: TaskInstance):
+    def warn(self, msg: str, *args, ti: TaskInstance | TaskInstanceKey):
"""
Emit a log message with level WARN to the task instance logs.

@@ -145,7 +151,7 @@ def warn(self, msg: str, *args, ti: TaskInstance):
"""
self._log(logging.WARNING, msg, *args, ti=ti)

-    def warning(self, msg: str, *args, ti: TaskInstance):
+    def warning(self, msg: str, *args, ti: TaskInstance | TaskInstanceKey):
"""
Emit a log message with level WARNING to the task instance logs.

@@ -154,7 +160,7 @@ def warning(self, msg: str, *args, ti: TaskInstance):
"""
self._log(logging.WARNING, msg, *args, ti=ti)

-    def info(self, msg: str, *args, ti: TaskInstance):
+    def info(self, msg: str, *args, ti: TaskInstance | TaskInstanceKey):
"""
Emit a log message with level INFO to the task instance logs.

@@ -163,7 +169,7 @@ def info(self, msg: str, *args, ti: TaskInstance):
"""
self._log(logging.INFO, msg, *args, ti=ti)

-    def debug(self, msg: str, *args, ti: TaskInstance):
+    def debug(self, msg: str, *args, ti: TaskInstance | TaskInstanceKey):
"""
Emit a log message with level DEBUG to the task instance logs.

@@ -172,7 +178,7 @@ def debug(self, msg: str, *args, ti: TaskInstance):
"""
self._log(logging.DEBUG, msg, *args, ti=ti)

-    def notset(self, msg: str, *args, ti: TaskInstance):
+    def notset(self, msg: str, *args, ti: TaskInstance | TaskInstanceKey):
"""
Emit a log message with level NOTSET to the task instance logs.

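For reference, a minimal usage sketch of the key-resolution path added in this file. The DAG, task, and run IDs below are made up, and the call only works where the metadata database is reachable, since _ensure_ti loads the TaskInstance for the key from the DB.

from airflow.models.taskinstancekey import TaskInstanceKey
from airflow.utils.log.task_context_logger import TaskContextLogger

# The component name appears as the identifier on the emitted log.
ctx_logger = TaskContextLogger(component_name="Scheduler")

# Executors and the scheduler often hold only a key, not the ORM object.
key = TaskInstanceKey(
    dag_id="example_dag", task_id="example_task", run_id="manual__2024-01-01", try_number=1
)

# The logger now resolves the key to a TaskInstance internally (via
# _ensure_ti in a fresh session) before routing the message to the
# task's own log handler.
ctx_logger.error("task stuck in queued state; rescheduling", ti=key)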