From 76ba7441501628c695899ac73c3f015d745e98e2 Mon Sep 17 00:00:00 2001
From: vincbeck
Date: Thu, 27 Jun 2024 15:11:23 -0400
Subject: [PATCH 01/19] Send important executor logs to task logs

---
 airflow/executors/base_executor.py            | 14 ++++-
 .../amazon/aws/executors/ecs/ecs_executor.py  |  8 ++-
 airflow/utils/log/file_task_handler.py        | 33 ++---
 airflow/utils/log/task_context_logger.py      | 56 +++++++++++++++----
 4 files changed, 66 insertions(+), 45 deletions(-)

diff --git a/airflow/executors/base_executor.py b/airflow/executors/base_executor.py
index 2b574efec246d..150e1b9fc9719 100644
--- a/airflow/executors/base_executor.py
+++ b/airflow/executors/base_executor.py
@@ -32,6 +32,7 @@
 from airflow.exceptions import RemovedInAirflow3Warning
 from airflow.stats import Stats
 from airflow.utils.log.logging_mixin import LoggingMixin
+from airflow.utils.log.task_context_logger import TaskContextLogger
 from airflow.utils.state import TaskInstanceState
 
 PARALLELISM: int = conf.getint("core", "PARALLELISM")
@@ -130,6 +131,10 @@ def __init__(self, parallelism: int = PARALLELISM):
         self.running: set[TaskInstanceKey] = set()
         self.event_buffer: dict[TaskInstanceKey, EventBufferValueType] = {}
         self.attempts: dict[TaskInstanceKey, RunningRetryAttemptType] = defaultdict(RunningRetryAttemptType)
+        self.task_context_logger: TaskContextLogger = TaskContextLogger(
+            component_name="Executor",
+            call_site_logger=self.log,
+        )
 
     def __repr__(self):
         return f"{self.__class__.__name__}(parallelism={self.parallelism})"
@@ -149,7 +154,7 @@ def queue_command(
             self.log.info("Adding to queue: %s", command)
             self.queued_tasks[task_instance.key] = (command, priority, queue, task_instance)
         else:
-            self.log.error("could not queue task %s", task_instance.key)
+            self.task_context_logger.error("could not queue task %s", task_instance.key, ti=task_instance)
 
     def queue_task_instance(
         self,
@@ -284,8 +289,11 @@ def trigger_tasks(self, open_slots: int) -> None:
                     self.log.info("queued but still running; attempt=%s task=%s", attempt.total_tries, key)
                     continue
                 # Otherwise, we give up and remove the task from the queue.
-                self.log.error(
-                    "could not queue task %s (still running after %d attempts)", key, attempt.total_tries
+                self.task_context_logger.error(
+                    "could not queue task %s (still running after %d attempts)",
+                    key,
+                    attempt.total_tries,
+                    ti=ti,
                 )
                 del self.attempts[key]
                 del self.queued_tasks[key]
diff --git a/airflow/providers/amazon/aws/executors/ecs/ecs_executor.py b/airflow/providers/amazon/aws/executors/ecs/ecs_executor.py
index 5b768252296ec..2f0badd54424b 100644
--- a/airflow/providers/amazon/aws/executors/ecs/ecs_executor.py
+++ b/airflow/providers/amazon/aws/executors/ecs/ecs_executor.py
@@ -320,10 +320,11 @@ def __handle_failed_task(self, task_arn: str, reason: str):
                 )
             )
         else:
-            self.log.error(
+            self.task_context_logger.error(
                 "Airflow task %s has failed a maximum of %s times. Marking as failed",
                 task_key,
                 failure_count,
+                ti=task_key,
             )
             self.fail(task_key)
         self.active_workers.pop_by_key(task_key)
@@ -386,14 +387,15 @@ def attempt_task_runs(self):
                     )
                     self.pending_tasks.append(ecs_task)
                 else:
-                    self.log.error(
+                    self.task_context_logger.error(
                         "ECS task %s has failed a maximum of %s times. Marking as failed",
                         task_key,
                         attempt_number,
+                        ti=task_key,
                     )
                     self.fail(task_key)
             elif not run_task_response["tasks"]:
-                self.log.error("ECS RunTask Response: %s", run_task_response)
+                self.task_context_logger.error("ECS RunTask Response: %s", run_task_response, ti=task_key)
                 raise EcsExecutorException(
                     "No failures and no ECS tasks provided in response. This should never happen."
                 )
diff --git a/airflow/utils/log/file_task_handler.py b/airflow/utils/log/file_task_handler.py
index b233de6a14179..a267d122ca1cc 100644
--- a/airflow/utils/log/file_task_handler.py
+++ b/airflow/utils/log/file_task_handler.py
@@ -34,12 +34,13 @@
 
 from airflow.api_internal.internal_api_call import internal_api_call
 from airflow.configuration import conf
-from airflow.exceptions import AirflowException, RemovedInAirflow3Warning
+from airflow.exceptions import RemovedInAirflow3Warning
 from airflow.executors.executor_loader import ExecutorLoader
 from airflow.utils.context import Context
 from airflow.utils.helpers import parse_template_string, render_template_to_string
 from airflow.utils.log.logging_mixin import SetContextPropagate
 from airflow.utils.log.non_caching_file_handler import NonCachingFileHandler
+from airflow.utils.log.task_context_logger import ensure_ti
 from airflow.utils.session import provide_session
 from airflow.utils.state import State, TaskInstanceState
 
@@ -140,32 +141,6 @@ def _interleave_logs(*logs):
         last = v
 
 
-def _ensure_ti(ti: TaskInstanceKey | TaskInstance | TaskInstancePydantic, session) -> TaskInstance:
-    """
-    Given TI | TIKey, return a TI object.
-
-    Will raise exception if no TI is found in the database.
-    """
-    from airflow.models.taskinstance import TaskInstance
-
-    if isinstance(ti, TaskInstance):
-        return ti
-    val = (
-        session.query(TaskInstance)
-        .filter(
-            TaskInstance.task_id == ti.task_id,
-            TaskInstance.dag_id == ti.dag_id,
-            TaskInstance.run_id == ti.run_id,
-            TaskInstance.map_index == ti.map_index,
-        )
-        .one_or_none()
-    )
-    if not val:
-        raise AirflowException(f"Could not find TaskInstance for {ti}")
-    val.try_number = ti.try_number
-    return val
-
-
 class FileTaskHandler(logging.Handler):
     """
     FileTaskHandler is a python log handler that handles and reads task instance logs.
@@ -265,9 +240,9 @@ def close(self):
     @internal_api_call
     @provide_session
     def _render_filename_db_access(
-        *, ti, try_number: int, session=None
+        *, ti: TaskInstance | TaskInstanceKey | TaskInstancePydantic, try_number: int, session=None
     ) -> tuple[DagRun | DagRunPydantic, TaskInstance | TaskInstancePydantic, str | None, str | None]:
-        ti = _ensure_ti(ti, session)
+        ti = ensure_ti(ti, session)
         dag_run = ti.get_dagrun(session=session)
         template = dag_run.get_log_template(session=session).filename
         str_tpl, jinja_tpl = parse_template_string(template)
diff --git a/airflow/utils/log/task_context_logger.py b/airflow/utils/log/task_context_logger.py
index 61cecfd92f5bb..7297e3b9ffc18 100644
--- a/airflow/utils/log/task_context_logger.py
+++ b/airflow/utils/log/task_context_logger.py
@@ -24,14 +24,46 @@
 from typing import TYPE_CHECKING
 
 from airflow.configuration import conf
+from airflow.exceptions import AirflowException
+from airflow.utils.session import NEW_SESSION, provide_session
 
 if TYPE_CHECKING:
+    from sqlalchemy.orm import Session
+
     from airflow.models.taskinstance import TaskInstance
+    from airflow.models.taskinstancekey import TaskInstanceKey
+    from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic
     from airflow.utils.log.file_task_handler import FileTaskHandler
 
 logger = logging.getLogger(__name__)
 
 
+def ensure_ti(ti: TaskInstanceKey | TaskInstance | TaskInstancePydantic, session) -> TaskInstance:
+    """
+    Given TI | TIKey, return a TI object.
+
+    Will raise exception if no TI is found in the database.
+    """
+    from airflow.models.taskinstance import TaskInstance
+
+    if isinstance(ti, TaskInstance):
+        return ti
+    val = (
+        session.query(TaskInstance)
+        .filter(
+            TaskInstance.task_id == ti.task_id,
+            TaskInstance.dag_id == ti.dag_id,
+            TaskInstance.run_id == ti.run_id,
+            TaskInstance.map_index == ti.map_index,
+        )
+        .one_or_none()
+    )
+    if not val:
+        raise AirflowException(f"Could not find TaskInstance for {ti}")
+    val.try_number = ti.try_number
+    return val
+
+
 class TaskContextLogger:
     """
     Class for sending messages to task instance logs from outside task execution context.
@@ -57,7 +89,7 @@ def __init__(self, component_name: str, call_site_logger: Logger | None = None)
     def _should_enable(self) -> bool:
         if not conf.getboolean("logging", "enable_task_context_logger"):
             return False
-        if not getattr(self.task_handler, "supports_task_context_logging", False):
+        if not self.task_handler:
             logger.warning("Task handler does not support task context logging")
             return False
         logger.info("Task context logging is enabled")
@@ -78,7 +110,10 @@ def _get_task_handler() -> FileTaskHandler | None:
         assert isinstance(h, FileTaskHandler)
         return h
 
-    def _log(self, level: int, msg: str, *args, ti: TaskInstance):
+    @provide_session
+    def _log(
+        self, level: int, msg: str, *args, ti: TaskInstance | TaskInstanceKey, session: Session = NEW_SESSION
+    ):
         """
         Emit a log message to the task instance logs.
@@ -98,6 +133,7 @@ def _log(self, level: int, msg: str, *args, ti: TaskInstance):
 
         task_handler = copy(self.task_handler)
         try:
+            ti = ensure_ti(ti, session)
             task_handler.set_context(ti, identifier=self.component_name)
             if hasattr(task_handler, "mark_end_on_close"):
                 task_handler.mark_end_on_close = False
@@ -109,7 +145,7 @@ def _log(self, level: int, msg: str, *args, ti: TaskInstance):
         finally:
             task_handler.close()
 
-    def critical(self, msg: str, *args, ti: TaskInstance):
+    def critical(self, msg: str, *args, ti: TaskInstance | TaskInstanceKey):
         """
         Emit a log message with level CRITICAL to the task instance logs.
 
@@ -118,7 +154,7 @@ def critical(self, msg: str, *args, ti: TaskInstance):
         """
         self._log(logging.CRITICAL, msg, *args, ti=ti)
 
-    def fatal(self, msg: str, *args, ti: TaskInstance):
+    def fatal(self, msg: str, *args, ti: TaskInstance | TaskInstanceKey):
         """
         Emit a log message with level FATAL to the task instance logs.
 
@@ -127,7 +163,7 @@ def fatal(self, msg: str, *args, ti: TaskInstance):
         """
         self._log(logging.FATAL, msg, *args, ti=ti)
 
-    def error(self, msg: str, *args, ti: TaskInstance):
+    def error(self, msg: str, *args, ti: TaskInstance | TaskInstanceKey):
         """
         Emit a log message with level ERROR to the task instance logs.
 
@@ -136,7 +172,7 @@ def error(self, msg: str, *args, ti: TaskInstance):
         """
         self._log(logging.ERROR, msg, *args, ti=ti)
 
-    def warn(self, msg: str, *args, ti: TaskInstance):
+    def warn(self, msg: str, *args, ti: TaskInstance | TaskInstanceKey):
         """
         Emit a log message with level WARN to the task instance logs.
 
@@ -145,7 +181,7 @@ def warn(self, msg: str, *args, ti: TaskInstance):
         """
         self._log(logging.WARNING, msg, *args, ti=ti)
 
-    def warning(self, msg: str, *args, ti: TaskInstance):
+    def warning(self, msg: str, *args, ti: TaskInstance | TaskInstanceKey):
         """
         Emit a log message with level WARNING to the task instance logs.
 
@@ -154,7 +190,7 @@ def warning(self, msg: str, *args, ti: TaskInstance):
         """
         self._log(logging.WARNING, msg, *args, ti=ti)
 
-    def info(self, msg: str, *args, ti: TaskInstance):
+    def info(self, msg: str, *args, ti: TaskInstance | TaskInstanceKey):
         """
         Emit a log message with level INFO to the task instance logs.
 
@@ -163,7 +199,7 @@ def info(self, msg: str, *args, ti: TaskInstance):
         """
         self._log(logging.INFO, msg, *args, ti=ti)
 
-    def debug(self, msg: str, *args, ti: TaskInstance):
+    def debug(self, msg: str, *args, ti: TaskInstance | TaskInstanceKey):
         """
         Emit a log message with level DEBUG to the task instance logs.
 
@@ -172,7 +208,7 @@ def debug(self, msg: str, *args, ti: TaskInstance):
         """
         self._log(logging.DEBUG, msg, *args, ti=ti)
 
-    def notset(self, msg: str, *args, ti: TaskInstance):
+    def notset(self, msg: str, *args, ti: TaskInstance | TaskInstanceKey):
         """
         Emit a log message with level NOTSET to the task instance logs.
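A minimal usage sketch of what patch 01 sets up (illustrative only; `IllustrativeExecutor` and `report_stuck_task` are hypothetical names, not code from this series): every `BaseExecutor` subclass now carries a `task_context_logger`, so a message logged with the extra `ti=` argument lands both in the executor's own log (via `call_site_logger`) and, when the configured task handler supports it, in the task instance's log file shown in the UI.

    from airflow.executors.base_executor import BaseExecutor

    class IllustrativeExecutor(BaseExecutor):
        def report_stuck_task(self, ti):
            # Hypothetical helper. The message is emitted twice: once through
            # self.log (the call_site_logger passed to TaskContextLogger in
            # __init__ above), and once into the task instance's own log via a
            # copy of the configured task handler.
            self.task_context_logger.error("could not queue task %s", ti.key, ti=ti)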
From de2af6cb401c41b891917d84ca50d0991299fe4f Mon Sep 17 00:00:00 2001
From: vincbeck
Date: Thu, 27 Jun 2024 15:36:00 -0400
Subject: [PATCH 02/19] Add unit test

---
 tests/utils/log/test_task_context_logger.py | 19 +++++++++++++++++++
 1 file changed, 19 insertions(+)

diff --git a/tests/utils/log/test_task_context_logger.py b/tests/utils/log/test_task_context_logger.py
index 1171a40864ff9..9759dda1fc772 100644
--- a/tests/utils/log/test_task_context_logger.py
+++ b/tests/utils/log/test_task_context_logger.py
@@ -17,10 +17,12 @@
 from __future__ import annotations
 
 import logging
+from unittest import mock
 from unittest.mock import Mock
 
 import pytest
 
+from airflow.models.taskinstancekey import TaskInstanceKey
 from airflow.utils.log.task_context_logger import TaskContextLogger
 from tests.test_utils.config import conf_vars
 
@@ -78,6 +80,23 @@ def test_task_context_log_with_correct_arguments(ti, mock_handler, supported):
         mock_handler.emit.assert_not_called()
 
 
+@pytest.mark.db_test
+@mock.patch("airflow.utils.log.task_context_logger.ensure_ti")
+@pytest.mark.parametrize("supported", [True, False])
+def test_task_context_log_with_task_instance_key(mock_ensure_ti, ti, mock_handler, supported):
+    mock_handler.supports_task_context_logging = supported
+    mock_ensure_ti.return_value = ti
+    task_instance_key = TaskInstanceKey(ti.dag_id, ti.task_id, ti.run_id, ti.try_number, ti.map_index)
+    t = TaskContextLogger(component_name="test_component")
+    t.info("test message with args %s, %s", "a", "b", ti=task_instance_key)
+    if supported:
+        mock_handler.set_context.assert_called_once_with(ti, identifier="test_component")
+        mock_handler.emit.assert_called_once()
+    else:
+        mock_handler.set_context.assert_not_called()
+        mock_handler.emit.assert_not_called()
+
+
 @pytest.mark.db_test
 def test_task_context_log_closes_task_handler(ti, mock_handler):
     t = TaskContextLogger("blah")

From 907083b631468a8064505704c4776aee855a9d5f Mon Sep 17 00:00:00 2001
From: vincbeck
Date: Fri, 28 Jun 2024 11:57:48 -0400
Subject: [PATCH 03/19] Create session only when ti is `TaskInstanceKey`

---
 airflow/utils/log/task_context_logger.py | 18 ++++++++----------
 1 file changed, 8 insertions(+), 10 deletions(-)

diff --git a/airflow/utils/log/task_context_logger.py b/airflow/utils/log/task_context_logger.py
index 7297e3b9ffc18..ff2ae93ab4fa8 100644
--- a/airflow/utils/log/task_context_logger.py
+++ b/airflow/utils/log/task_context_logger.py
@@ -23,15 +23,14 @@
 from logging import Logger
 from typing import TYPE_CHECKING
 
+from sqlalchemy.orm import create_session
+
 from airflow.configuration import conf
 from airflow.exceptions import AirflowException
-from airflow.utils.session import NEW_SESSION, provide_session
+from airflow.models.taskinstancekey import TaskInstanceKey
 
 if TYPE_CHECKING:
-    from sqlalchemy.orm import Session
-
     from airflow.models.taskinstance import TaskInstance
-    from airflow.models.taskinstancekey import TaskInstanceKey
     from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic
     from airflow.utils.log.file_task_handler import FileTaskHandler
 
 logger = logging.getLogger(__name__)
@@ -110,16 +109,13 @@ def _get_task_handler() -> FileTaskHandler | None:
         assert isinstance(h, FileTaskHandler)
         return h
 
-    @provide_session
-    def _log(
-        self, level: int, msg: str, *args, ti: TaskInstance | TaskInstanceKey, session: Session = NEW_SESSION
-    ):
+    def _log(self, level: int, msg: str, *args, ti: TaskInstance | TaskInstanceKey):
         """
         Emit a log message to the task instance logs.
 
         :param level: the log level
         :param msg: the message to relay to task context log
-        :param ti: the task instance
+        :param ti: the task instance or the task instance key
         """
         if self.call_site_logger and self.call_site_logger.isEnabledFor(level=level):
             with suppress(Exception):
@@ -133,7 +129,9 @@ def _log(
         task_handler = copy(self.task_handler)
         try:
-            ti = ensure_ti(ti, session)
+            if isinstance(ti, TaskInstanceKey):
+                with create_session() as session:
+                    ti = ensure_ti(ti, session)
             task_handler.set_context(ti, identifier=self.component_name)
             if hasattr(task_handler, "mark_end_on_close"):
                 task_handler.mark_end_on_close = False

From 6eb42867434a3064ddd91b01a4ce42a191481488 Mon Sep 17 00:00:00 2001
From: vincbeck
Date: Fri, 28 Jun 2024 12:27:17 -0400
Subject: [PATCH 04/19] Fix tests

---
 airflow/executors/base_executor.py            |  1 +
 .../aws/executors/ecs/test_ecs_executor.py    | 39 ++++++++++++-------
 2 files changed, 27 insertions(+), 13 deletions(-)

diff --git a/airflow/executors/base_executor.py b/airflow/executors/base_executor.py
index 150e1b9fc9719..d6db7132837b9 100644
--- a/airflow/executors/base_executor.py
+++ b/airflow/executors/base_executor.py
@@ -123,6 +123,7 @@ class BaseExecutor(LoggingMixin):
     job_id: None | int | str = None
     name: None | ExecutorName = None
     callback_sink: BaseCallbackSink | None = None
+    task_context_logger: TaskContextLogger
 
     def __init__(self, parallelism: int = PARALLELISM):
         super().__init__()
diff --git a/tests/providers/amazon/aws/executors/ecs/test_ecs_executor.py b/tests/providers/amazon/aws/executors/ecs/test_ecs_executor.py
index b547f398337f9..5a8a0773efa73 100644
--- a/tests/providers/amazon/aws/executors/ecs/test_ecs_executor.py
+++ b/tests/providers/amazon/aws/executors/ecs/test_ecs_executor.py
@@ -25,7 +25,7 @@
 from functools import partial
 from typing import Callable
 from unittest import mock
-from unittest.mock import MagicMock
+from unittest.mock import MagicMock, Mock, call
 
 import pytest
 import yaml
@@ -160,6 +160,7 @@ def mock_executor(set_env_vars) -> AwsEcsExecutor:
     run_task_ret_val = {"tasks": [{"taskArn": ARN1}], "failures": []}
     ecs_mock.run_task.return_value = run_task_ret_val
     executor.ecs = ecs_mock
+    executor.task_context_logger = Mock()
 
     return executor
 
@@ -517,12 +518,17 @@ def test_attempt_task_runs_attempts_when_tasks_fail(self, _, mock_executor, capl
         assert len(mock_executor.active_workers.get_all_arns()) == 0
         assert len(mock_executor.pending_tasks) == 0
 
-        assert len(caplog.messages) == 3
+        calls = []
         for i in range(2):
-            assert (
-                f"ECS task {airflow_keys[i]} has failed a maximum of 3 times. Marking as failed"
-                == caplog.messages[i]
+            calls.append(
+                call(
+                    "ECS task %s has failed a maximum of %s times. Marking as failed",
+                    airflow_keys[i],
+                    3,
+                    ti=airflow_keys[i],
+                )
             )
+        mock_executor.task_context_logger.error.assert_has_calls(calls)
 
     @mock.patch.object(ecs_executor, "calculate_next_attempt_delay", return_value=dt.timedelta(seconds=0))
     def test_attempt_task_runs_attempts_when_some_tasks_fal(self, _, mock_executor, caplog):
@@ -600,11 +606,12 @@ def test_attempt_task_runs_attempts_when_some_tasks_fal(self, _, mock_executor,
         RUN_TASK_KWARGS["overrides"]["containerOverrides"][0]["command"] = airflow_commands[0]
         assert mock_executor.ecs.run_task.call_args_list[0].kwargs == RUN_TASK_KWARGS
 
-        assert len(caplog.messages) == 2
-
-        assert (
-            f"ECS task {airflow_keys[0]} has failed a maximum of 3 times. Marking as failed"
-            == caplog.messages[0]
+        assert len(caplog.messages) == 1
+        mock_executor.task_context_logger.error.assert_called_once_with(
+            "ECS task %s has failed a maximum of %s times. Marking as failed",
+            airflow_keys[0],
+            3,
+            ti=airflow_keys[0],
         )
 
     @mock.patch.object(ecs_executor, "calculate_next_attempt_delay", return_value=dt.timedelta(seconds=0))
@@ -704,11 +711,17 @@ def test_task_retry_on_api_failure_all_tasks_fail(self, _, mock_executor, caplog
             assert mock_executor.ecs.run_task.call_args_list[i].kwargs == RUN_TASK_KWARGS
 
         mock_executor.sync_running_tasks()
+        calls = []
         for i in range(2):
-            assert (
-                f"Airflow task {airflow_keys[i]} has failed a maximum of 2 times. Marking as failed"
-                in caplog.messages[i]
+            calls.append(
+                call(
+                    "Airflow task %s has failed a maximum of %s times. Marking as failed",
+                    airflow_keys[i],
+                    2,
+                    ti=airflow_keys[i],
+                )
             )
+        mock_executor.task_context_logger.error.assert_has_calls(calls)
 
     @mock.patch.object(BaseExecutor, "fail")
     @mock.patch.object(BaseExecutor, "success")

From 34e00ce8a50d7a39c3fbaf6192b1362168fc86a8 Mon Sep 17 00:00:00 2001
From: vincbeck
Date: Fri, 28 Jun 2024 15:18:59 -0400
Subject: [PATCH 05/19] Fix tests + adjustments

---
 .../amazon/aws/executors/ecs/ecs_executor.py  | 20 +++++++------------
 .../aws/executors/ecs/test_ecs_executor.py    | 19 ++++++------------
 tests/test_utils/mock_executor.py             |  4 +++-
 3 files changed, 16 insertions(+), 27 deletions(-)

diff --git a/airflow/providers/amazon/aws/executors/ecs/ecs_executor.py b/airflow/providers/amazon/aws/executors/ecs/ecs_executor.py
index 2f0badd54424b..6b469604a7cdc 100644
--- a/airflow/providers/amazon/aws/executors/ecs/ecs_executor.py
+++ b/airflow/providers/amazon/aws/executors/ecs/ecs_executor.py
@@ -348,7 +348,7 @@ def attempt_task_runs(self):
             queue = ecs_task.queue
             exec_config = ecs_task.executor_config
             attempt_number = ecs_task.attempt_number
-            _failure_reasons = []
+            failure_reasons = []
             if timezone.utcnow() < ecs_task.next_attempt_time:
                 self.pending_tasks.append(ecs_task)
                 continue
@@ -362,23 +362,21 @@ def attempt_task_runs(self):
                 if error_code in INVALID_CREDENTIALS_EXCEPTIONS:
                     self.pending_tasks.append(ecs_task)
                     raise
-                _failure_reasons.append(str(e))
+                failure_reasons.append(str(e))
             except Exception as e:
                 # Failed to even get a response back from the Boto3 API or something else went
                 # wrong. For any possible failure we want to add the exception reasons to the
                 # failure list so that it is logged to the user and most importantly the task is
                 # added back to the pending list to be retried later.
-                _failure_reasons.append(str(e))
+                failure_reasons.append(str(e))
             else:
                 # We got a response back, check if there were failures. If so, add them to the
                 # failures list so that it is logged to the user and most importantly the task
                 # is added back to the pending list to be retried later.
                 if run_task_response["failures"]:
-                    _failure_reasons.extend([f["reason"] for f in run_task_response["failures"]])
+                    failure_reasons.extend([f["reason"] for f in run_task_response["failures"]])
 
-            if _failure_reasons:
-                for reason in _failure_reasons:
-                    failure_reasons[reason] += 1
+            if failure_reasons:
                 # Make sure the number of attempts does not exceed MAX_RUN_TASK_ATTEMPTS
                 if int(attempt_number) < int(self.__class__.MAX_RUN_TASK_ATTEMPTS):
                     ecs_task.attempt_number += 1
@@ -388,9 +386,10 @@ def attempt_task_runs(self):
                     self.pending_tasks.append(ecs_task)
                 else:
                     self.task_context_logger.error(
-                        "ECS task %s has failed a maximum of %s times. Marking as failed",
+                        "ECS task %s has failed a maximum of %s times. Marking as failed. Reasons: %s",
                         task_key,
                         attempt_number,
+                        ", ".join(failure_reasons),
                         ti=task_key,
                     )
                     self.fail(task_key)
@@ -409,11 +408,6 @@ def attempt_task_runs(self):
             # executor feature).
             # TODO: remove when min airflow version >= 2.9.2
             pass
-        if failure_reasons:
-            self.log.error(
-                "Pending ECS tasks failed to launch for the following reasons: %s. Retrying later.",
-                dict(failure_reasons),
-            )
 
     def _run_task(
         self, task_id: TaskInstanceKey, cmd: CommandType, queue: str, exec_config: ExecutorConfigType
diff --git a/tests/providers/amazon/aws/executors/ecs/test_ecs_executor.py b/tests/providers/amazon/aws/executors/ecs/test_ecs_executor.py
index 5a8a0773efa73..257cbfba82c8f 100644
--- a/tests/providers/amazon/aws/executors/ecs/test_ecs_executor.py
+++ b/tests/providers/amazon/aws/executors/ecs/test_ecs_executor.py
@@ -464,7 +464,7 @@ def test_failed_execute_api(self, mock_executor):
         assert len(mock_executor.active_workers) == 0
 
     @mock.patch.object(ecs_executor, "calculate_next_attempt_delay", return_value=dt.timedelta(seconds=0))
-    def test_attempt_task_runs_attempts_when_tasks_fail(self, _, mock_executor, caplog):
+    def test_attempt_task_runs_attempts_when_tasks_fail(self, _, mock_executor):
         """
         Test case when all tasks fail to run.
 
@@ -475,7 +475,6 @@ def test_attempt_task_runs_attempts_when_tasks_fail(self, _, mock_executor, capl
         airflow_keys = [mock.Mock(spec=tuple), mock.Mock(spec=tuple)]
         airflow_cmd1 = mock.Mock(spec=list)
         airflow_cmd2 = mock.Mock(spec=list)
-        caplog.set_level("ERROR")
 
         commands = [airflow_cmd1, airflow_cmd2]
         failures = [Exception("Failure 1"), Exception("Failure 2")]
@@ -492,11 +491,9 @@ def test_attempt_task_runs_attempts_when_tasks_fail(self, _, mock_executor, capl
         for i in range(2):
             RUN_TASK_KWARGS["overrides"]["containerOverrides"][0]["command"] = commands[i]
             assert mock_executor.ecs.run_task.call_args_list[i].kwargs == RUN_TASK_KWARGS
-        assert "Pending ECS tasks failed to launch for the following reasons: " in caplog.messages[0]
         assert len(mock_executor.pending_tasks) == 2
         assert len(mock_executor.active_workers.get_all_arns()) == 0
 
-        caplog.clear()
         mock_executor.ecs.run_task.call_args_list.clear()
 
         mock_executor.ecs.run_task.side_effect = failures
@@ -505,11 +502,9 @@ def test_attempt_task_runs_attempts_when_tasks_fail(self, _, mock_executor, capl
         for i in range(2):
             RUN_TASK_KWARGS["overrides"]["containerOverrides"][0]["command"] = commands[i]
             assert mock_executor.ecs.run_task.call_args_list[i].kwargs == RUN_TASK_KWARGS
-        assert "Pending ECS tasks failed to launch for the following reasons: " in caplog.messages[0]
         assert len(mock_executor.pending_tasks) == 2
         assert len(mock_executor.active_workers.get_all_arns()) == 0
 
-        caplog.clear()
         mock_executor.ecs.run_task.call_args_list.clear()
 
         mock_executor.ecs.run_task.side_effect = failures
@@ -522,16 +517,17 @@ def test_attempt_task_runs_attempts_when_tasks_fail(self, _, mock_executor, capl
         assert len(mock_executor.active_workers.get_all_arns()) == 0
         assert len(mock_executor.pending_tasks) == 0
 
         calls = []
         for i in range(2):
             calls.append(
                 call(
-                    "ECS task %s has failed a maximum of %s times. Marking as failed",
+                    "ECS task %s has failed a maximum of %s times. Marking as failed. Reasons: %s",
                     airflow_keys[i],
                     3,
+                    f"Failure {i + 1}",
                     ti=airflow_keys[i],
                 )
             )
         mock_executor.task_context_logger.error.assert_has_calls(calls)
 
@@ -539,7 +535,7 @@ def test_attempt_task_runs_attempts_when_tasks_fail(self, _, mock_executor, capl
     @mock.patch.object(ecs_executor, "calculate_next_attempt_delay", return_value=dt.timedelta(seconds=0))
-    def test_attempt_task_runs_attempts_when_some_tasks_fal(self, _, mock_executor, caplog):
+    def test_attempt_task_runs_attempts_when_some_tasks_fal(self, _, mock_executor):
         """
         Test case when one task fail to run, and a new task gets queued.
 
@@ -549,7 +545,6 @@ def test_attempt_task_runs_attempts_when_some_tasks_fal(self, _, mock_executor,
         airflow_keys = [mock.Mock(spec=tuple), mock.Mock(spec=tuple)]
         airflow_cmd1 = mock.Mock(spec=list)
         airflow_cmd2 = mock.Mock(spec=list)
-        caplog.set_level("ERROR")
 
         airflow_commands = [airflow_cmd1, airflow_cmd2]
         task = {
             "taskArn": ARN1,
@@ -565,7 +560,6 @@ def test_attempt_task_runs_attempts_when_some_tasks_fal(self, _, mock_executor,
         assert len(mock_executor.pending_tasks) == 1
         assert len(mock_executor.active_workers.get_all_arns()) == 1
 
-        caplog.clear()
         mock_executor.ecs.run_task.call_args_list.clear()
 
         # queue new task
@@ -590,7 +584,6 @@ def test_attempt_task_runs_attempts_when_some_tasks_fal(self, _, mock_executor,
         assert len(mock_executor.pending_tasks) == 1
         assert len(mock_executor.active_workers.get_all_arns()) == 2
 
-        caplog.clear()
         mock_executor.ecs.run_task.call_args_list.clear()
 
         responses = [Exception("Failure 1")]
@@ -600,11 +593,12 @@ def test_attempt_task_runs_attempts_when_some_tasks_fal(self, _, mock_executor,
         RUN_TASK_KWARGS["overrides"]["containerOverrides"][0]["command"] = airflow_commands[0]
         assert mock_executor.ecs.run_task.call_args_list[0].kwargs == RUN_TASK_KWARGS
 
-        assert len(caplog.messages) == 1
         mock_executor.task_context_logger.error.assert_called_once_with(
-            "ECS task %s has failed a maximum of %s times. Marking as failed",
+            "ECS task %s has failed a maximum of %s times. Marking as failed. Reasons: %s",
             airflow_keys[0],
             3,
+            "Failure 1",
             ti=airflow_keys[0],
         )
diff --git a/tests/test_utils/mock_executor.py b/tests/test_utils/mock_executor.py
index eafb0640565a1..a02af122bfa01 100644
--- a/tests/test_utils/mock_executor.py
+++ b/tests/test_utils/mock_executor.py
@@ -18,7 +18,7 @@
 from __future__ import annotations
 
 from collections import defaultdict
-from unittest.mock import MagicMock
+from unittest.mock import MagicMock, Mock
 
 from airflow.executors.base_executor import BaseExecutor
 from airflow.models.taskinstancekey import TaskInstanceKey
@@ -50,6 +50,8 @@ def __init__(self, do_update=True, *args, **kwargs):
 
         super().__init__(*args, **kwargs)
 
+        self.task_context_logger = Mock()
+
     def success(self):
         return State.SUCCESS

From 4fcdefc7569837b761c479ff529854beb96f379e Mon Sep 17 00:00:00 2001
From: vincbeck
Date: Tue, 2 Jul 2024 11:41:43 -0400
Subject: [PATCH 06/19] Update error message

---
 airflow/executors/base_executor.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/airflow/executors/base_executor.py b/airflow/executors/base_executor.py
index d6db7132837b9..e6cc19c09e3da 100644
--- a/airflow/executors/base_executor.py
+++ b/airflow/executors/base_executor.py
@@ -291,7 +291,8 @@ def trigger_tasks(self, open_slots: int) -> None:
                     continue
                 # Otherwise, we give up and remove the task from the queue.
                 self.task_context_logger.error(
-                    "could not queue task %s (still running after %d attempts)",
+                    "Could not queue task %s (still running after %d attempts). It generally means that the "
+                    "task has been killed externally and not yet been marked as failed.",
                     key,
                     attempt.total_tries,
                     ti=ti,

From ae80de4508ca314d008efe34231a3ac706ff6a40 Mon Sep 17 00:00:00 2001
From: vincbeck
Date: Tue, 2 Jul 2024 13:22:20 -0400
Subject: [PATCH 07/19] Fix test

---
 tests/task/task_runner/test_task_runner.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/task/task_runner/test_task_runner.py b/tests/task/task_runner/test_task_runner.py
index 6214930e36477..d5d4cc0bc3a4e 100644
--- a/tests/task/task_runner/test_task_runner.py
+++ b/tests/task/task_runner/test_task_runner.py
@@ -35,7 +35,7 @@ class TestGetTaskRunner:
     def test_should_have_valid_imports(self, import_path):
         assert import_string(import_path) is not None
 
-    @mock.patch("airflow.utils.log.file_task_handler._ensure_ti")
+    @mock.patch("airflow.utils.log.task_context_logger.ensure_ti")
     @mock.patch("airflow.task.task_runner.base_task_runner.subprocess")
     @mock.patch("airflow.task.task_runner._TASK_RUNNER_NAME", "StandardTaskRunner")
     def test_should_support_core_task_runner(self, mock_subprocess, mock_ensure_ti):

From 2abc248c1603b1c91d570ecccccfc00887b4ab4c Mon Sep 17 00:00:00 2001
From: vincbeck
Date: Tue, 2 Jul 2024 14:00:11 -0400
Subject: [PATCH 08/19] Fix test

---
 tests/task/task_runner/test_task_runner.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/task/task_runner/test_task_runner.py b/tests/task/task_runner/test_task_runner.py
index d5d4cc0bc3a4e..af6636ef636b4 100644
--- a/tests/task/task_runner/test_task_runner.py
+++ b/tests/task/task_runner/test_task_runner.py
@@ -35,7 +35,7 @@ class TestGetTaskRunner:
     def test_should_have_valid_imports(self, import_path):
         assert import_string(import_path) is not None
 
-    @mock.patch("airflow.utils.log.task_context_logger.ensure_ti")
+    @mock.patch("airflow.utils.log.file_task_handler.ensure_ti")
    @mock.patch("airflow.task.task_runner.base_task_runner.subprocess")
     @mock.patch("airflow.task.task_runner._TASK_RUNNER_NAME", "StandardTaskRunner")
     def test_should_support_core_task_runner(self, mock_subprocess, mock_ensure_ti):

From 20e89e4d4e389f0e7431658a738eb4bd9729dce3 Mon Sep 17 00:00:00 2001
From: vincbeck
Date: Tue, 2 Jul 2024 14:27:05 -0400
Subject: [PATCH 09/19] Remove `TaskInstanceKey` type from
 `_render_filename_db_access` method

---
 airflow/utils/log/file_task_handler.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/airflow/utils/log/file_task_handler.py b/airflow/utils/log/file_task_handler.py
index a267d122ca1cc..d78eb0a33a72c 100644
--- a/airflow/utils/log/file_task_handler.py
+++ b/airflow/utils/log/file_task_handler.py
@@ -240,7 +240,7 @@ def close(self):
     @internal_api_call
     @provide_session
     def _render_filename_db_access(
-        *, ti: TaskInstance | TaskInstanceKey | TaskInstancePydantic, try_number: int, session=None
+        *, ti: TaskInstance | TaskInstancePydantic, try_number: int, session=None
     ) -> tuple[DagRun | DagRunPydantic, TaskInstance | TaskInstancePydantic, str | None, str | None]:
         ti = ensure_ti(ti, session)
         dag_run = ti.get_dagrun(session=session)
@@ -257,7 +257,7 @@ def _render_filename_db_access(
         return dag_run, ti, str_tpl, filename
 
     def _render_filename(
-        self, ti: TaskInstance | TaskInstanceKey | TaskInstancePydantic, try_number: int
+        self, ti: TaskInstance | TaskInstancePydantic, try_number: int
     ) -> str:
         """Return the worker log filename."""
         dag_run, ti, str_tpl, filename = self._render_filename_db_access(ti=ti, try_number=try_number)

From 595cbd635bc7f2f58f8dcaee51d2920c5eddc991 Mon Sep 17 00:00:00 2001
From: vincbeck
Date: Tue, 2 Jul 2024 14:41:31 -0400
Subject: [PATCH 10/19] Fix ruff

---
 airflow/utils/log/file_task_handler.py | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/airflow/utils/log/file_task_handler.py b/airflow/utils/log/file_task_handler.py
index d78eb0a33a72c..eeccf4e5eb949 100644
--- a/airflow/utils/log/file_task_handler.py
+++ b/airflow/utils/log/file_task_handler.py
@@ -48,7 +48,7 @@
     from pendulum import DateTime
 
     from airflow.models import DagRun
-    from airflow.models.taskinstance import TaskInstance, TaskInstanceKey
+    from airflow.models.taskinstance import TaskInstance
     from airflow.serialization.pydantic.dag_run import DagRunPydantic
     from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic
 
@@ -256,9 +256,7 @@ def _render_filename_db_access(
         filename = render_template_to_string(jinja_tpl, context)
         return dag_run, ti, str_tpl, filename
 
-    def _render_filename(
-        self, ti: TaskInstance | TaskInstancePydantic, try_number: int
-    ) -> str:
+    def _render_filename(self, ti: TaskInstance | TaskInstancePydantic, try_number: int) -> str:
         """Return the worker log filename."""
         dag_run, ti, str_tpl, filename = self._render_filename_db_access(ti=ti, try_number=try_number)
         if filename:

From bb7f1b62493a0dc4d50a11db46dc8eadd25e4f30 Mon Sep 17 00:00:00 2001
From: vincbeck
Date: Tue, 2 Jul 2024 17:10:05 -0400
Subject: [PATCH 11/19] Use correct `create_session`

---
 airflow/utils/log/task_context_logger.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/airflow/utils/log/task_context_logger.py b/airflow/utils/log/task_context_logger.py
index ff2ae93ab4fa8..956afa89a0a94 100644
--- a/airflow/utils/log/task_context_logger.py
+++ b/airflow/utils/log/task_context_logger.py
@@ -23,11 +23,10 @@
 from logging import Logger
 from typing import TYPE_CHECKING
 
-from sqlalchemy.orm import create_session
-
 from airflow.configuration import conf
 from airflow.exceptions import AirflowException
 from airflow.models.taskinstancekey import TaskInstanceKey
+from airflow.utils.session import create_session
 
 if TYPE_CHECKING:
     from airflow.models.taskinstance import TaskInstance

From 9aad244ba88aa3de358c09fef87735948ea24ce7 Mon Sep 17 00:00:00 2001
From: vincbeck
Date: Tue, 2 Jul 2024 17:15:17 -0400
Subject: [PATCH 12/19] Update error message

---
 airflow/executors/base_executor.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/airflow/executors/base_executor.py b/airflow/executors/base_executor.py
index e6cc19c09e3da..1df5b4bfd17ce 100644
--- a/airflow/executors/base_executor.py
+++ b/airflow/executors/base_executor.py
@@ -291,8 +291,8 @@ def trigger_tasks(self, open_slots: int) -> None:
                     continue
                 # Otherwise, we give up and remove the task from the queue.
                 self.task_context_logger.error(
-                    "Could not queue task %s (still running after %d attempts). It generally means that the "
-                    "task has been killed externally and not yet been marked as failed.",
+                    "Failed to queue task %s after %d attempts; executor reports task is currently "
+                    "running.",
                     key,
                     attempt.total_tries,
                     ti=ti,

From 083d15b98924750a1ff05a09d904412f18458a79 Mon Sep 17 00:00:00 2001
From: vincbeck
Date: Wed, 3 Jul 2024 10:46:50 -0400
Subject: [PATCH 13/19] Update error message

---
 airflow/executors/base_executor.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/airflow/executors/base_executor.py b/airflow/executors/base_executor.py
index 1df5b4bfd17ce..efa2582d791b6 100644
--- a/airflow/executors/base_executor.py
+++ b/airflow/executors/base_executor.py
@@ -291,10 +291,10 @@ def trigger_tasks(self, open_slots: int) -> None:
                     continue
                 # Otherwise, we give up and remove the task from the queue.
                 self.task_context_logger.error(
-                    "Failed to queue task %s after %d attempts; executor reports task is currently "
-                    "running.",
+                    "Could not queue task %s as it is seen as still running after %d attempts (tried for %d seconds). It looks like it was killed externally. Look for external reasons why it has been killed (likely a bug or deployment issue).",
                     key,
                     attempt.total_tries,
+                    RunningRetryAttemptType.MIN_SECONDS,
                     ti=ti,
                 )
                 del self.attempts[key]
                 del self.queued_tasks[key]

From b3e6963328f3dba7b0f3ab3427552f3f1ffae8cd Mon Sep 17 00:00:00 2001
From: vincbeck
Date: Wed, 3 Jul 2024 11:39:21 -0400
Subject: [PATCH 14/19] Move `ensure_ti` and makes it back private

---
 airflow/utils/log/file_task_handler.py     | 32 ++++++++++++++++++++--
 airflow/utils/log/task_context_logger.py   | 31 ++-------------------
 tests/task/task_runner/test_task_runner.py |  2 +-
 3 files changed, 32 insertions(+), 33 deletions(-)

diff --git a/airflow/utils/log/file_task_handler.py b/airflow/utils/log/file_task_handler.py
index eeccf4e5eb949..e177d232d5b3e 100644
--- a/airflow/utils/log/file_task_handler.py
+++ b/airflow/utils/log/file_task_handler.py
@@ -34,13 +34,12 @@
 
 from airflow.api_internal.internal_api_call import internal_api_call
 from airflow.configuration import conf
-from airflow.exceptions import RemovedInAirflow3Warning
+from airflow.exceptions import AirflowException, RemovedInAirflow3Warning
 from airflow.executors.executor_loader import ExecutorLoader
 from airflow.utils.context import Context
 from airflow.utils.helpers import parse_template_string, render_template_to_string
 from airflow.utils.log.logging_mixin import SetContextPropagate
 from airflow.utils.log.non_caching_file_handler import NonCachingFileHandler
-from airflow.utils.log.task_context_logger import ensure_ti
 from airflow.utils.session import provide_session
 from airflow.utils.state import State, TaskInstanceState
 
@@ -48,6 +47,7 @@
 
     from airflow.models import DagRun
     from airflow.models.taskinstance import TaskInstance
+    from airflow.models.taskinstancekey import TaskInstanceKey
     from airflow.serialization.pydantic.dag_run import DagRunPydantic
     from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic
 
@@ -141,6 +141,32 @@ def _interleave_logs(*logs):
         last = v
 
 
+def _ensure_ti(ti: TaskInstanceKey | TaskInstance | TaskInstancePydantic, session) -> TaskInstance:
+    """
+    Given TI | TIKey, return a TI object.
+
+    Will raise exception if no TI is found in the database.
+    """
+    from airflow.models.taskinstance import TaskInstance
+
+    if isinstance(ti, TaskInstance):
+        return ti
+    val = (
+        session.query(TaskInstance)
+        .filter(
+            TaskInstance.task_id == ti.task_id,
+            TaskInstance.dag_id == ti.dag_id,
+            TaskInstance.run_id == ti.run_id,
+            TaskInstance.map_index == ti.map_index,
+        )
+        .one_or_none()
+    )
+    if not val:
+        raise AirflowException(f"Could not find TaskInstance for {ti}")
+    val.try_number = ti.try_number
+    return val
+
+
 class FileTaskHandler(logging.Handler):
     """
     FileTaskHandler is a python log handler that handles and reads task instance logs.
@@ -268,9 +294,9 @@ def close(self):
     @internal_api_call
     @provide_session
     def _render_filename_db_access(
         *, ti: TaskInstance | TaskInstancePydantic, try_number: int, session=None
     ) -> tuple[DagRun | DagRunPydantic, TaskInstance | TaskInstancePydantic, str | None, str | None]:
-        ti = ensure_ti(ti, session)
+        ti = _ensure_ti(ti, session)
         dag_run = ti.get_dagrun(session=session)
         template = dag_run.get_log_template(session=session).filename
         str_tpl, jinja_tpl = parse_template_string(template)
diff --git a/airflow/utils/log/task_context_logger.py b/airflow/utils/log/task_context_logger.py
index 956afa89a0a94..1d2301b65be81 100644
--- a/airflow/utils/log/task_context_logger.py
+++ b/airflow/utils/log/task_context_logger.py
@@ -24,44 +24,17 @@
 from typing import TYPE_CHECKING
 
 from airflow.configuration import conf
-from airflow.exceptions import AirflowException
 from airflow.models.taskinstancekey import TaskInstanceKey
+from airflow.utils.log.file_task_handler import _ensure_ti
 from airflow.utils.session import create_session
 
 if TYPE_CHECKING:
     from airflow.models.taskinstance import TaskInstance
-    from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic
     from airflow.utils.log.file_task_handler import FileTaskHandler
 
 logger = logging.getLogger(__name__)
 
 
-def ensure_ti(ti: TaskInstanceKey | TaskInstance | TaskInstancePydantic, session) -> TaskInstance:
-    """
-    Given TI | TIKey, return a TI object.
-
-    Will raise exception if no TI is found in the database.
-    """
-    from airflow.models.taskinstance import TaskInstance
-
-    if isinstance(ti, TaskInstance):
-        return ti
-    val = (
-        session.query(TaskInstance)
-        .filter(
-            TaskInstance.task_id == ti.task_id,
-            TaskInstance.dag_id == ti.dag_id,
-            TaskInstance.run_id == ti.run_id,
-            TaskInstance.map_index == ti.map_index,
-        )
-        .one_or_none()
-    )
-    if not val:
-        raise AirflowException(f"Could not find TaskInstance for {ti}")
-    val.try_number = ti.try_number
-    return val
-
-
 class TaskContextLogger:
     """
     Class for sending messages to task instance logs from outside task execution context.
@@ -130,7 +103,7 @@ def _log(self, level: int, msg: str, *args, ti: TaskInstance | TaskInstanceKey)
         try:
             if isinstance(ti, TaskInstanceKey):
                 with create_session() as session:
-                    ti = ensure_ti(ti, session)
+                    ti = _ensure_ti(ti, session)
             task_handler.set_context(ti, identifier=self.component_name)
             if hasattr(task_handler, "mark_end_on_close"):
                 task_handler.mark_end_on_close = False
diff --git a/tests/task/task_runner/test_task_runner.py b/tests/task/task_runner/test_task_runner.py
index af6636ef636b4..6214930e36477 100644
--- a/tests/task/task_runner/test_task_runner.py
+++ b/tests/task/task_runner/test_task_runner.py
@@ -35,7 +35,7 @@ class TestGetTaskRunner:
     def test_should_have_valid_imports(self, import_path):
         assert import_string(import_path) is not None
 
-    @mock.patch("airflow.utils.log.file_task_handler.ensure_ti")
+    @mock.patch("airflow.utils.log.file_task_handler._ensure_ti")
     @mock.patch("airflow.task.task_runner.base_task_runner.subprocess")
     @mock.patch("airflow.task.task_runner._TASK_RUNNER_NAME", "StandardTaskRunner")
     def test_should_support_core_task_runner(self, mock_subprocess, mock_ensure_ti):

From 260a1d547e7bdf71c366ebff0c644a40738712e2 Mon Sep 17 00:00:00 2001
From: vincbeck
Date: Wed, 3 Jul 2024 12:12:31 -0400
Subject: [PATCH 15/19] Fix test

---
 tests/utils/log/test_task_context_logger.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/utils/log/test_task_context_logger.py b/tests/utils/log/test_task_context_logger.py
index 9759dda1fc772..d1e5339b0435e 100644
--- a/tests/utils/log/test_task_context_logger.py
+++ b/tests/utils/log/test_task_context_logger.py
@@ -81,7 +81,7 @@ def test_task_context_log_with_correct_arguments(ti, mock_handler, supported):
 
 
 @pytest.mark.db_test
-@mock.patch("airflow.utils.log.task_context_logger.ensure_ti")
+@mock.patch("airflow.utils.log.task_context_logger._ensure_ti")
 @pytest.mark.parametrize("supported", [True, False])
 def test_task_context_log_with_task_instance_key(mock_ensure_ti, ti, mock_handler, supported):

From 6b5061d97b824d47d53f092f089c6d58af7ecaf5 Mon Sep 17 00:00:00 2001
From: vincbeck
Date: Wed, 3 Jul 2024 14:36:54 -0400
Subject: [PATCH 16/19] Reduce number of calls to `task_context_logger`

---
 airflow/executors/base_executor.py            |  2 +-
 .../amazon/aws/executors/ecs/ecs_executor.py  |  3 +--
 .../aws/executors/ecs/test_ecs_executor.py    | 12 +++---------
 3 files changed, 5 insertions(+), 12 deletions(-)

diff --git a/airflow/executors/base_executor.py b/airflow/executors/base_executor.py
index efa2582d791b6..4659cde84eb74 100644
--- a/airflow/executors/base_executor.py
+++ b/airflow/executors/base_executor.py
@@ -155,7 +155,7 @@ def queue_command(
             self.log.info("Adding to queue: %s", command)
             self.queued_tasks[task_instance.key] = (command, priority, queue, task_instance)
         else:
-            self.task_context_logger.error("could not queue task %s", task_instance.key, ti=task_instance)
+            self.log.error("could not queue task %s", task_instance.key)
 
     def queue_task_instance(
         self,
diff --git a/airflow/providers/amazon/aws/executors/ecs/ecs_executor.py b/airflow/providers/amazon/aws/executors/ecs/ecs_executor.py
index 6b469604a7cdc..7b54a8a90caf6 100644
--- a/airflow/providers/amazon/aws/executors/ecs/ecs_executor.py
+++ b/airflow/providers/amazon/aws/executors/ecs/ecs_executor.py
@@ -320,11 +320,10 @@ def __handle_failed_task(self, task_arn: str, reason: str):
                 )
             )
         else:
-            self.task_context_logger.error(
+            self.log.error(
                 "Airflow task %s has failed a maximum of %s times. Marking as failed",
                 task_key,
                 failure_count,
-                ti=task_key,
             )
             self.fail(task_key)
         self.active_workers.pop_by_key(task_key)
diff --git a/tests/providers/amazon/aws/executors/ecs/test_ecs_executor.py b/tests/providers/amazon/aws/executors/ecs/test_ecs_executor.py
index 257cbfba82c8f..023c03af01c94 100644
--- a/tests/providers/amazon/aws/executors/ecs/test_ecs_executor.py
+++ b/tests/providers/amazon/aws/executors/ecs/test_ecs_executor.py
@@ -704,17 +704,11 @@ def test_task_retry_on_api_failure_all_tasks_fail(self, _, mock_executor, caplog
             assert mock_executor.ecs.run_task.call_args_list[i].kwargs == RUN_TASK_KWARGS
 
         mock_executor.sync_running_tasks()
-        calls = []
         for i in range(2):
-            calls.append(
-                call(
-                    "Airflow task %s has failed a maximum of %s times. Marking as failed",
-                    airflow_keys[i],
-                    2,
-                    ti=airflow_keys[i],
-                )
+            assert (
+                f"Airflow task {airflow_keys[i]} has failed a maximum of 2 times. Marking as failed"
+                in caplog.messages[i]
             )
-        mock_executor.task_context_logger.error.assert_has_calls(calls)
 
     @mock.patch.object(BaseExecutor, "fail")
     @mock.patch.object(BaseExecutor, "success")

From b553197bee756f78b1ce664a53c0f750190f902e Mon Sep 17 00:00:00 2001
From: vincbeck
Date: Tue, 9 Jul 2024 18:26:23 -0400
Subject: [PATCH 17/19] Revert error message

---
 airflow/executors/base_executor.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/airflow/executors/base_executor.py b/airflow/executors/base_executor.py
index 4659cde84eb74..3cf66c3a6e160 100644
--- a/airflow/executors/base_executor.py
+++ b/airflow/executors/base_executor.py
@@ -291,10 +291,9 @@ def trigger_tasks(self, open_slots: int) -> None:
                     continue
                 # Otherwise, we give up and remove the task from the queue.
                 self.task_context_logger.error(
-                    "Could not queue task %s as it is seen as still running after %d attempts (tried for %d seconds). It looks like it was killed externally. Look for external reasons why it has been killed (likely a bug or deployment issue).",
+                    "could not queue task %s (still running after %d attempts).",
                     key,
                     attempt.total_tries,
-                    RunningRetryAttemptType.MIN_SECONDS,
                     ti=ti,
                 )
                 del self.attempts[key]
                 del self.queued_tasks[key]

From af1ca5b4eacac0c1a1accd32cf3b2d0779a1186f Mon Sep 17 00:00:00 2001
From: vincbeck
Date: Tue, 9 Jul 2024 18:51:17 -0400
Subject: [PATCH 18/19] Create method `send_message_to_task_logs` in
 `base_executor`

---
 airflow/executors/base_executor.py            | 19 +++++++++++++------
 .../amazon/aws/executors/ecs/ecs_executor.py  |  8 ++++++--
 .../aws/executors/ecs/test_ecs_executor.py    | 19 +++++++++++++------
 3 files changed, 32 insertions(+), 14 deletions(-)

diff --git a/airflow/executors/base_executor.py b/airflow/executors/base_executor.py
index 3cf66c3a6e160..f0b232183f7b3 100644
--- a/airflow/executors/base_executor.py
+++ b/airflow/executors/base_executor.py
@@ -23,6 +23,7 @@
 import warnings
 from collections import defaultdict
 from dataclasses import dataclass, field
+from functools import cached_property
 from typing import TYPE_CHECKING, Any, List, Optional, Sequence, Tuple
 
 import pendulum
@@ -123,7 +124,6 @@ class BaseExecutor(LoggingMixin):
     job_id: None | int | str = None
     name: None | ExecutorName = None
     callback_sink: BaseCallbackSink | None = None
-    task_context_logger: TaskContextLogger
 
     def __init__(self, parallelism: int = PARALLELISM):
         super().__init__()
@@ -132,10 +132,6 @@ def __init__(self, parallelism: int = PARALLELISM):
         self.running: set[TaskInstanceKey] = set()
         self.event_buffer: dict[TaskInstanceKey, EventBufferValueType] = {}
         self.attempts: dict[TaskInstanceKey, RunningRetryAttemptType] = defaultdict(RunningRetryAttemptType)
-        self.task_context_logger: TaskContextLogger = TaskContextLogger(
-            component_name="Executor",
-            call_site_logger=self.log,
-        )
 
     def __repr__(self):
         return f"{self.__class__.__name__}(parallelism={self.parallelism})"
@@ -286,7 +286,8 @@ def trigger_tasks(self, open_slots: int) -> None:
                     self.log.info("queued but still running; attempt=%s task=%s", attempt.total_tries, key)
                     continue
                 # Otherwise, we give up and remove the task from the queue.
-                self.task_context_logger.error(
+                self.send_message_to_task_logs(
+                    logging.ERROR,
                     "could not queue task %s (still running after %d attempts).",
                     key,
                     attempt.total_tries,
                     ti=ti,
                 )
                 del self.attempts[key]
                 del self.queued_tasks[key]
@@ -518,6 +518,16 @@ def send_callback(self, request: CallbackRequest) -> None:
             raise ValueError("Callback sink is not ready.")
         self.callback_sink.send(request)
 
+    @cached_property
+    def _task_context_logger(self) -> TaskContextLogger:
+        return TaskContextLogger(
+            component_name="Executor",
+            call_site_logger=self.log,
+        )
+
+    def send_message_to_task_logs(self, level: int, msg: str, *args, ti: TaskInstance | TaskInstanceKey):
+        self._task_context_logger._log(level, msg, *args, ti=ti)
+
     @staticmethod
     def get_cli_commands() -> list[GroupCommand]:
         """
diff --git a/airflow/providers/amazon/aws/executors/ecs/ecs_executor.py b/airflow/providers/amazon/aws/executors/ecs/ecs_executor.py
index 7b54a8a90caf6..983ad8b0007ab 100644
--- a/airflow/providers/amazon/aws/executors/ecs/ecs_executor.py
+++ b/airflow/providers/amazon/aws/executors/ecs/ecs_executor.py
@@ -23,6 +23,7 @@
 
 from __future__ import annotations
 
+import logging
 import time
 from collections import defaultdict, deque
 from copy import deepcopy
@@ -384,7 +385,8 @@ def attempt_task_runs(self):
                     )
                     self.pending_tasks.append(ecs_task)
                 else:
-                    self.task_context_logger.error(
+                    self.send_message_to_task_logs(
+                        logging.ERROR,
                         "ECS task %s has failed a maximum of %s times. Marking as failed. Reasons: %s",
                         task_key,
                         attempt_number,
@@ -395,7 +397,9 @@ def attempt_task_runs(self):
                         ti=task_key,
                     )
                     self.fail(task_key)
             elif not run_task_response["tasks"]:
-                self.task_context_logger.error("ECS RunTask Response: %s", run_task_response, ti=task_key)
+                self.send_message_to_task_logs(
+                    logging.ERROR, "ECS RunTask Response: %s", run_task_response, ti=task_key
+                )
                 raise EcsExecutorException(
                     "No failures and no ECS tasks provided in response. This should never happen."
                 )
diff --git a/tests/providers/amazon/aws/executors/ecs/test_ecs_executor.py b/tests/providers/amazon/aws/executors/ecs/test_ecs_executor.py
index 023c03af01c94..6e7669288811f 100644
--- a/tests/providers/amazon/aws/executors/ecs/test_ecs_executor.py
+++ b/tests/providers/amazon/aws/executors/ecs/test_ecs_executor.py
@@ -25,7 +25,7 @@
 from functools import partial
 from typing import Callable
 from unittest import mock
-from unittest.mock import MagicMock, Mock, call
+from unittest.mock import MagicMock, call
 
 import pytest
 import yaml
@@ -160,7 +160,6 @@ def mock_executor(set_env_vars) -> AwsEcsExecutor:
     run_task_ret_val = {"tasks": [{"taskArn": ARN1}], "failures": []}
     ecs_mock.run_task.return_value = run_task_ret_val
     executor.ecs = ecs_mock
-    executor.task_context_logger = Mock()
 
     return executor
 
@@ -462,8 +461,11 @@ def test_failed_execute_api(self, mock_executor):
         # Task is not stored in active workers.
         assert len(mock_executor.active_workers) == 0
 
+    @mock.patch.object(AwsEcsExecutor, "send_message_to_task_logs")
     @mock.patch.object(ecs_executor, "calculate_next_attempt_delay", return_value=dt.timedelta(seconds=0))
-    def test_attempt_task_runs_attempts_when_tasks_fail(self, _, mock_executor):
+    def test_attempt_task_runs_attempts_when_tasks_fail(
+        self, _, mock_send_message_to_task_logs, mock_executor
+    ):
         """
         Test case when all tasks fail to run.
 
@@ -515,12 +517,13 @@ def test_attempt_task_runs_attempts_when_tasks_fail(
         calls = []
         for i in range(2):
             calls.append(
                 call(
+                    logging.ERROR,
                     "ECS task %s has failed a maximum of %s times. Marking as failed. Reasons: %s",
                     airflow_keys[i],
                     3,
                     f"Failure {i + 1}",
                     ti=airflow_keys[i],
                 )
             )
-        mock_executor.task_context_logger.error.assert_has_calls(calls)
+        mock_send_message_to_task_logs.assert_has_calls(calls)
 
+    @mock.patch.object(AwsEcsExecutor, "send_message_to_task_logs")
     @mock.patch.object(ecs_executor, "calculate_next_attempt_delay", return_value=dt.timedelta(seconds=0))
-    def test_attempt_task_runs_attempts_when_some_tasks_fal(self, _, mock_executor):
+    def test_attempt_task_runs_attempts_when_some_tasks_fal(
+        self, _, mock_send_message_to_task_logs, mock_executor
+    ):
         """
         Test case when one task fail to run, and a new task gets queued.
 
@@ -601,11 +606,12 @@ def test_attempt_task_runs_attempts_when_some_tasks_fal(
         RUN_TASK_KWARGS["overrides"]["containerOverrides"][0]["command"] = airflow_commands[0]
         assert mock_executor.ecs.run_task.call_args_list[0].kwargs == RUN_TASK_KWARGS
 
-        mock_executor.task_context_logger.error.assert_called_once_with(
+        mock_send_message_to_task_logs.assert_called_once_with(
+            logging.ERROR,
             "ECS task %s has failed a maximum of %s times. Marking as failed. Reasons: %s",
             airflow_keys[0],
             3,
             "Failure 1",
             ti=airflow_keys[0],
         )

From 540e9413dc1cfdd9e0d67ba1c1ec0b943ebac399 Mon Sep 17 00:00:00 2001
From: vincbeck
Date: Wed, 10 Jul 2024 11:04:08 -0400
Subject: [PATCH 19/19] Handle case when `send_message_to_task_logs` does not
 exist in `base_executor`

---
 .../providers/amazon/aws/executors/ecs/ecs_executor.py | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/airflow/providers/amazon/aws/executors/ecs/ecs_executor.py b/airflow/providers/amazon/aws/executors/ecs/ecs_executor.py
index 983ad8b0007ab..48c286c6fbe90 100644
--- a/airflow/providers/amazon/aws/executors/ecs/ecs_executor.py
+++ b/airflow/providers/amazon/aws/executors/ecs/ecs_executor.py
@@ -542,3 +542,11 @@ def try_adopt_task_instances(self, tis: Sequence[TaskInstance]) -> Sequence[Task
 
         not_adopted_tis = [ti for ti in tis if ti not in adopted_tis]
         return not_adopted_tis
+
+    def send_message_to_task_logs(self, level: int, msg: str, *args, ti: TaskInstance | TaskInstanceKey):
+        # TODO: remove this method when min_airflow_version is set to higher than 2.10.0
+        try:
+            super().send_message_to_task_logs(level, msg, *args, ti=ti)
+        except AttributeError:
+            # ``send_message_to_task_logs`` is added in 2.10.0
+            self.log.error(msg, *args)
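Patches 18 and 19 together define the final calling convention: core code goes through BaseExecutor.send_message_to_task_logs, which lazily builds the TaskContextLogger via cached_property, while the ECS executor adds a fallback so the provider still works on Airflow cores older than 2.10.0, where the base method is absent. A rough sketch of a call site under those assumptions (`executor`, `task_key`, and `reasons` are illustrative stand-ins, not names from the patches):

    import logging

    def report_permanent_failure(executor, task_key, reasons):
        # `executor` is any BaseExecutor subclass; `task_key` may be a
        # TaskInstance or a TaskInstanceKey -- _ensure_ti() resolves a key
        # against the metadata DB before the message is written to the
        # task's own log.
        executor.send_message_to_task_logs(
            logging.ERROR,
            "task %s failed permanently. Reasons: %s",
            task_key,
            ", ".join(reasons),
            ti=task_key,
        )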