diff --git a/airflow/executors/base_executor.py b/airflow/executors/base_executor.py index 2b574efec246d..f0b232183f7b3 100644 --- a/airflow/executors/base_executor.py +++ b/airflow/executors/base_executor.py @@ -23,6 +23,7 @@ import warnings from collections import defaultdict from dataclasses import dataclass, field +from functools import cached_property from typing import TYPE_CHECKING, Any, List, Optional, Sequence, Tuple import pendulum @@ -32,6 +33,7 @@ from airflow.exceptions import RemovedInAirflow3Warning from airflow.stats import Stats from airflow.utils.log.logging_mixin import LoggingMixin +from airflow.utils.log.task_context_logger import TaskContextLogger from airflow.utils.state import TaskInstanceState PARALLELISM: int = conf.getint("core", "PARALLELISM") @@ -284,8 +286,12 @@ def trigger_tasks(self, open_slots: int) -> None: self.log.info("queued but still running; attempt=%s task=%s", attempt.total_tries, key) continue # Otherwise, we give up and remove the task from the queue. - self.log.error( - "could not queue task %s (still running after %d attempts)", key, attempt.total_tries + self.send_message_to_task_logs( + logging.ERROR, + "could not queue task %s (still running after %d attempts).", + key, + attempt.total_tries, + ti=ti, ) del self.attempts[key] del self.queued_tasks[key] @@ -512,6 +518,16 @@ def send_callback(self, request: CallbackRequest) -> None: raise ValueError("Callback sink is not ready.") self.callback_sink.send(request) + @cached_property + def _task_context_logger(self) -> TaskContextLogger: + return TaskContextLogger( + component_name="Executor", + call_site_logger=self.log, + ) + + def send_message_to_task_logs(self, level: int, msg: str, *args, ti: TaskInstance | TaskInstanceKey): + self._task_context_logger._log(level, msg, *args, ti=ti) + @staticmethod def get_cli_commands() -> list[GroupCommand]: """ diff --git a/airflow/providers/amazon/aws/executors/ecs/ecs_executor.py b/airflow/providers/amazon/aws/executors/ecs/ecs_executor.py index 5b768252296ec..48c286c6fbe90 100644 --- a/airflow/providers/amazon/aws/executors/ecs/ecs_executor.py +++ b/airflow/providers/amazon/aws/executors/ecs/ecs_executor.py @@ -23,6 +23,7 @@ from __future__ import annotations +import logging import time from collections import defaultdict, deque from copy import deepcopy @@ -347,7 +348,7 @@ def attempt_task_runs(self): queue = ecs_task.queue exec_config = ecs_task.executor_config attempt_number = ecs_task.attempt_number - _failure_reasons = [] + failure_reasons = [] if timezone.utcnow() < ecs_task.next_attempt_time: self.pending_tasks.append(ecs_task) continue @@ -361,23 +362,21 @@ def attempt_task_runs(self): if error_code in INVALID_CREDENTIALS_EXCEPTIONS: self.pending_tasks.append(ecs_task) raise - _failure_reasons.append(str(e)) + failure_reasons.append(str(e)) except Exception as e: # Failed to even get a response back from the Boto3 API or something else went # wrong. For any possible failure we want to add the exception reasons to the # failure list so that it is logged to the user and most importantly the task is # added back to the pending list to be retried later. - _failure_reasons.append(str(e)) + failure_reasons.append(str(e)) else: # We got a response back, check if there were failures. If so, add them to the # failures list so that it is logged to the user and most importantly the task # is added back to the pending list to be retried later. if run_task_response["failures"]: - _failure_reasons.extend([f["reason"] for f in run_task_response["failures"]]) + failure_reasons.extend([f["reason"] for f in run_task_response["failures"]]) - if _failure_reasons: - for reason in _failure_reasons: - failure_reasons[reason] += 1 + if failure_reasons: # Make sure the number of attempts does not exceed MAX_RUN_TASK_ATTEMPTS if int(attempt_number) < int(self.__class__.MAX_RUN_TASK_ATTEMPTS): ecs_task.attempt_number += 1 @@ -386,14 +385,19 @@ def attempt_task_runs(self): ) self.pending_tasks.append(ecs_task) else: - self.log.error( - "ECS task %s has failed a maximum of %s times. Marking as failed", + self.send_message_to_task_logs( + logging.ERROR, + "ECS task %s has failed a maximum of %s times. Marking as failed. Reasons: %s", task_key, attempt_number, + ", ".join(failure_reasons), + ti=task_key, ) self.fail(task_key) elif not run_task_response["tasks"]: - self.log.error("ECS RunTask Response: %s", run_task_response) + self.send_message_to_task_logs( + logging.ERROR, "ECS RunTask Response: %s", run_task_response, ti=task_key + ) raise EcsExecutorException( "No failures and no ECS tasks provided in response. This should never happen." ) @@ -407,11 +411,6 @@ def attempt_task_runs(self): # executor feature). # TODO: remove when min airflow version >= 2.9.2 pass - if failure_reasons: - self.log.error( - "Pending ECS tasks failed to launch for the following reasons: %s. Retrying later.", - dict(failure_reasons), - ) def _run_task( self, task_id: TaskInstanceKey, cmd: CommandType, queue: str, exec_config: ExecutorConfigType @@ -543,3 +542,11 @@ def try_adopt_task_instances(self, tis: Sequence[TaskInstance]) -> Sequence[Task not_adopted_tis = [ti for ti in tis if ti not in adopted_tis] return not_adopted_tis + + def send_message_to_task_logs(self, level: int, msg: str, *args, ti: TaskInstance | TaskInstanceKey): + # TODO: remove this method when min_airflow_version is set to higher than 2.10.0 + try: + super().send_message_to_task_logs(level, msg, *args, ti=ti) + except AttributeError: + # ``send_message_to_task_logs`` is added in 2.10.0 + self.log.error(msg, *args) diff --git a/airflow/utils/log/file_task_handler.py b/airflow/utils/log/file_task_handler.py index b233de6a14179..e177d232d5b3e 100644 --- a/airflow/utils/log/file_task_handler.py +++ b/airflow/utils/log/file_task_handler.py @@ -47,7 +47,8 @@ from pendulum import DateTime from airflow.models import DagRun - from airflow.models.taskinstance import TaskInstance, TaskInstanceKey + from airflow.models.taskinstance import TaskInstance + from airflow.models.taskinstancekey import TaskInstanceKey from airflow.serialization.pydantic.dag_run import DagRunPydantic from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic @@ -265,7 +266,7 @@ def close(self): @internal_api_call @provide_session def _render_filename_db_access( - *, ti, try_number: int, session=None + *, ti: TaskInstance | TaskInstancePydantic, try_number: int, session=None ) -> tuple[DagRun | DagRunPydantic, TaskInstance | TaskInstancePydantic, str | None, str | None]: ti = _ensure_ti(ti, session) dag_run = ti.get_dagrun(session=session) @@ -281,9 +282,7 @@ def _render_filename_db_access( filename = render_template_to_string(jinja_tpl, context) return dag_run, ti, str_tpl, filename - def _render_filename( - self, ti: TaskInstance | TaskInstanceKey | TaskInstancePydantic, try_number: int - ) -> str: + def _render_filename(self, ti: TaskInstance | TaskInstancePydantic, try_number: int) -> str: """Return the worker log filename.""" dag_run, ti, str_tpl, filename = self._render_filename_db_access(ti=ti, try_number=try_number) if filename: diff --git a/airflow/utils/log/task_context_logger.py b/airflow/utils/log/task_context_logger.py index 61cecfd92f5bb..1d2301b65be81 100644 --- a/airflow/utils/log/task_context_logger.py +++ b/airflow/utils/log/task_context_logger.py @@ -24,6 +24,9 @@ from typing import TYPE_CHECKING from airflow.configuration import conf +from airflow.models.taskinstancekey import TaskInstanceKey +from airflow.utils.log.file_task_handler import _ensure_ti +from airflow.utils.session import create_session if TYPE_CHECKING: from airflow.models.taskinstance import TaskInstance @@ -57,7 +60,7 @@ def __init__(self, component_name: str, call_site_logger: Logger | None = None): def _should_enable(self) -> bool: if not conf.getboolean("logging", "enable_task_context_logger"): return False - if not getattr(self.task_handler, "supports_task_context_logging", False): + if not self.task_handler: logger.warning("Task handler does not support task context logging") return False logger.info("Task context logging is enabled") @@ -78,13 +81,13 @@ def _get_task_handler() -> FileTaskHandler | None: assert isinstance(h, FileTaskHandler) return h - def _log(self, level: int, msg: str, *args, ti: TaskInstance): + def _log(self, level: int, msg: str, *args, ti: TaskInstance | TaskInstanceKey): """ Emit a log message to the task instance logs. :param level: the log level :param msg: the message to relay to task context log - :param ti: the task instance + :param ti: the task instance or the task instance key """ if self.call_site_logger and self.call_site_logger.isEnabledFor(level=level): with suppress(Exception): @@ -98,6 +101,9 @@ def _log(self, level: int, msg: str, *args, ti: TaskInstance): task_handler = copy(self.task_handler) try: + if isinstance(ti, TaskInstanceKey): + with create_session() as session: + ti = _ensure_ti(ti, session) task_handler.set_context(ti, identifier=self.component_name) if hasattr(task_handler, "mark_end_on_close"): task_handler.mark_end_on_close = False @@ -109,7 +115,7 @@ def _log(self, level: int, msg: str, *args, ti: TaskInstance): finally: task_handler.close() - def critical(self, msg: str, *args, ti: TaskInstance): + def critical(self, msg: str, *args, ti: TaskInstance | TaskInstanceKey): """ Emit a log message with level CRITICAL to the task instance logs. @@ -118,7 +124,7 @@ def critical(self, msg: str, *args, ti: TaskInstance): """ self._log(logging.CRITICAL, msg, *args, ti=ti) - def fatal(self, msg: str, *args, ti: TaskInstance): + def fatal(self, msg: str, *args, ti: TaskInstance | TaskInstanceKey): """ Emit a log message with level FATAL to the task instance logs. @@ -127,7 +133,7 @@ def fatal(self, msg: str, *args, ti: TaskInstance): """ self._log(logging.FATAL, msg, *args, ti=ti) - def error(self, msg: str, *args, ti: TaskInstance): + def error(self, msg: str, *args, ti: TaskInstance | TaskInstanceKey): """ Emit a log message with level ERROR to the task instance logs. @@ -136,7 +142,7 @@ def error(self, msg: str, *args, ti: TaskInstance): """ self._log(logging.ERROR, msg, *args, ti=ti) - def warn(self, msg: str, *args, ti: TaskInstance): + def warn(self, msg: str, *args, ti: TaskInstance | TaskInstanceKey): """ Emit a log message with level WARN to the task instance logs. @@ -145,7 +151,7 @@ def warn(self, msg: str, *args, ti: TaskInstance): """ self._log(logging.WARNING, msg, *args, ti=ti) - def warning(self, msg: str, *args, ti: TaskInstance): + def warning(self, msg: str, *args, ti: TaskInstance | TaskInstanceKey): """ Emit a log message with level WARNING to the task instance logs. @@ -154,7 +160,7 @@ def warning(self, msg: str, *args, ti: TaskInstance): """ self._log(logging.WARNING, msg, *args, ti=ti) - def info(self, msg: str, *args, ti: TaskInstance): + def info(self, msg: str, *args, ti: TaskInstance | TaskInstanceKey): """ Emit a log message with level INFO to the task instance logs. @@ -163,7 +169,7 @@ def info(self, msg: str, *args, ti: TaskInstance): """ self._log(logging.INFO, msg, *args, ti=ti) - def debug(self, msg: str, *args, ti: TaskInstance): + def debug(self, msg: str, *args, ti: TaskInstance | TaskInstanceKey): """ Emit a log message with level DEBUG to the task instance logs. @@ -172,7 +178,7 @@ def debug(self, msg: str, *args, ti: TaskInstance): """ self._log(logging.DEBUG, msg, *args, ti=ti) - def notset(self, msg: str, *args, ti: TaskInstance): + def notset(self, msg: str, *args, ti: TaskInstance | TaskInstanceKey): """ Emit a log message with level NOTSET to the task instance logs. diff --git a/tests/providers/amazon/aws/executors/ecs/test_ecs_executor.py b/tests/providers/amazon/aws/executors/ecs/test_ecs_executor.py index b547f398337f9..6e7669288811f 100644 --- a/tests/providers/amazon/aws/executors/ecs/test_ecs_executor.py +++ b/tests/providers/amazon/aws/executors/ecs/test_ecs_executor.py @@ -25,7 +25,7 @@ from functools import partial from typing import Callable from unittest import mock -from unittest.mock import MagicMock +from unittest.mock import MagicMock, call import pytest import yaml @@ -462,8 +462,11 @@ def test_failed_execute_api(self, mock_executor): # Task is not stored in active workers. assert len(mock_executor.active_workers) == 0 + @mock.patch.object(AwsEcsExecutor, "send_message_to_task_logs") @mock.patch.object(ecs_executor, "calculate_next_attempt_delay", return_value=dt.timedelta(seconds=0)) - def test_attempt_task_runs_attempts_when_tasks_fail(self, _, mock_executor, caplog): + def test_attempt_task_runs_attempts_when_tasks_fail( + self, _, mock_send_message_to_task_logs, mock_executor + ): """ Test case when all tasks fail to run. @@ -474,7 +477,6 @@ def test_attempt_task_runs_attempts_when_tasks_fail(self, _, mock_executor, capl airflow_keys = [mock.Mock(spec=tuple), mock.Mock(spec=tuple)] airflow_cmd1 = mock.Mock(spec=list) airflow_cmd2 = mock.Mock(spec=list) - caplog.set_level("ERROR") commands = [airflow_cmd1, airflow_cmd2] failures = [Exception("Failure 1"), Exception("Failure 2")] @@ -491,11 +493,9 @@ def test_attempt_task_runs_attempts_when_tasks_fail(self, _, mock_executor, capl for i in range(2): RUN_TASK_KWARGS["overrides"]["containerOverrides"][0]["command"] = commands[i] assert mock_executor.ecs.run_task.call_args_list[i].kwargs == RUN_TASK_KWARGS - assert "Pending ECS tasks failed to launch for the following reasons: " in caplog.messages[0] assert len(mock_executor.pending_tasks) == 2 assert len(mock_executor.active_workers.get_all_arns()) == 0 - caplog.clear() mock_executor.ecs.run_task.call_args_list.clear() mock_executor.ecs.run_task.side_effect = failures @@ -504,11 +504,9 @@ def test_attempt_task_runs_attempts_when_tasks_fail(self, _, mock_executor, capl for i in range(2): RUN_TASK_KWARGS["overrides"]["containerOverrides"][0]["command"] = commands[i] assert mock_executor.ecs.run_task.call_args_list[i].kwargs == RUN_TASK_KWARGS - assert "Pending ECS tasks failed to launch for the following reasons: " in caplog.messages[0] assert len(mock_executor.pending_tasks) == 2 assert len(mock_executor.active_workers.get_all_arns()) == 0 - caplog.clear() mock_executor.ecs.run_task.call_args_list.clear() mock_executor.ecs.run_task.side_effect = failures @@ -517,15 +515,25 @@ def test_attempt_task_runs_attempts_when_tasks_fail(self, _, mock_executor, capl assert len(mock_executor.active_workers.get_all_arns()) == 0 assert len(mock_executor.pending_tasks) == 0 - assert len(caplog.messages) == 3 + calls = [] for i in range(2): - assert ( - f"ECS task {airflow_keys[i]} has failed a maximum of 3 times. Marking as failed" - == caplog.messages[i] + calls.append( + call( + logging.ERROR, + "ECS task %s has failed a maximum of %s times. Marking as failed. Reasons: %s", + airflow_keys[i], + 3, + f"Failure {i + 1}", + ti=airflow_keys[i], + ) ) + mock_send_message_to_task_logs.assert_has_calls(calls) + @mock.patch.object(AwsEcsExecutor, "send_message_to_task_logs") @mock.patch.object(ecs_executor, "calculate_next_attempt_delay", return_value=dt.timedelta(seconds=0)) - def test_attempt_task_runs_attempts_when_some_tasks_fal(self, _, mock_executor, caplog): + def test_attempt_task_runs_attempts_when_some_tasks_fal( + self, _, mock_send_message_to_task_logs, mock_executor + ): """ Test case when one task fail to run, and a new task gets queued. @@ -537,7 +545,6 @@ def test_attempt_task_runs_attempts_when_some_tasks_fal(self, _, mock_executor, airflow_keys = [mock.Mock(spec=tuple), mock.Mock(spec=tuple)] airflow_cmd1 = mock.Mock(spec=list) airflow_cmd2 = mock.Mock(spec=list) - caplog.set_level("ERROR") airflow_commands = [airflow_cmd1, airflow_cmd2] task = { "taskArn": ARN1, @@ -564,7 +571,6 @@ def test_attempt_task_runs_attempts_when_some_tasks_fal(self, _, mock_executor, assert len(mock_executor.pending_tasks) == 1 assert len(mock_executor.active_workers.get_all_arns()) == 1 - caplog.clear() mock_executor.ecs.run_task.call_args_list.clear() # queue new task @@ -590,7 +596,6 @@ def test_attempt_task_runs_attempts_when_some_tasks_fal(self, _, mock_executor, assert len(mock_executor.pending_tasks) == 1 assert len(mock_executor.active_workers.get_all_arns()) == 2 - caplog.clear() mock_executor.ecs.run_task.call_args_list.clear() responses = [Exception("Failure 1")] @@ -600,11 +605,13 @@ def test_attempt_task_runs_attempts_when_some_tasks_fal(self, _, mock_executor, RUN_TASK_KWARGS["overrides"]["containerOverrides"][0]["command"] = airflow_commands[0] assert mock_executor.ecs.run_task.call_args_list[0].kwargs == RUN_TASK_KWARGS - assert len(caplog.messages) == 2 - - assert ( - f"ECS task {airflow_keys[0]} has failed a maximum of 3 times. Marking as failed" - == caplog.messages[0] + mock_send_message_to_task_logs.assert_called_once_with( + logging.ERROR, + "ECS task %s has failed a maximum of %s times. Marking as failed. Reasons: %s", + airflow_keys[0], + 3, + "Failure 1", + ti=airflow_keys[0], ) @mock.patch.object(ecs_executor, "calculate_next_attempt_delay", return_value=dt.timedelta(seconds=0)) diff --git a/tests/test_utils/mock_executor.py b/tests/test_utils/mock_executor.py index eafb0640565a1..a02af122bfa01 100644 --- a/tests/test_utils/mock_executor.py +++ b/tests/test_utils/mock_executor.py @@ -18,7 +18,7 @@ from __future__ import annotations from collections import defaultdict -from unittest.mock import MagicMock +from unittest.mock import MagicMock, Mock from airflow.executors.base_executor import BaseExecutor from airflow.models.taskinstancekey import TaskInstanceKey @@ -50,6 +50,8 @@ def __init__(self, do_update=True, *args, **kwargs): super().__init__(*args, **kwargs) + self.task_context_logger = Mock() + def success(self): return State.SUCCESS diff --git a/tests/utils/log/test_task_context_logger.py b/tests/utils/log/test_task_context_logger.py index 1171a40864ff9..d1e5339b0435e 100644 --- a/tests/utils/log/test_task_context_logger.py +++ b/tests/utils/log/test_task_context_logger.py @@ -17,10 +17,12 @@ from __future__ import annotations import logging +from unittest import mock from unittest.mock import Mock import pytest +from airflow.models.taskinstancekey import TaskInstanceKey from airflow.utils.log.task_context_logger import TaskContextLogger from tests.test_utils.config import conf_vars @@ -78,6 +80,23 @@ def test_task_context_log_with_correct_arguments(ti, mock_handler, supported): mock_handler.emit.assert_not_called() +@pytest.mark.db_test +@mock.patch("airflow.utils.log.task_context_logger._ensure_ti") +@pytest.mark.parametrize("supported", [True, False]) +def test_task_context_log_with_task_instance_key(mock_ensure_ti, ti, mock_handler, supported): + mock_handler.supports_task_context_logging = supported + mock_ensure_ti.return_value = ti + task_instance_key = TaskInstanceKey(ti.dag_id, ti.task_id, ti.run_id, ti.try_number, ti.map_index) + t = TaskContextLogger(component_name="test_component") + t.info("test message with args %s, %s", "a", "b", ti=task_instance_key) + if supported: + mock_handler.set_context.assert_called_once_with(ti, identifier="test_component") + mock_handler.emit.assert_called_once() + else: + mock_handler.set_context.assert_not_called() + mock_handler.emit.assert_not_called() + + @pytest.mark.db_test def test_task_context_log_closes_task_handler(ti, mock_handler): t = TaskContextLogger("blah")