Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 19 additions & 16 deletions airflow/providers/amazon/aws/executors/batch/batch_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,9 @@
from __future__ import annotations

import contextlib
import logging
import time
from collections import defaultdict, deque
from collections import deque
from copy import deepcopy
from typing import TYPE_CHECKING, Any, Dict, List, Sequence

Expand Down Expand Up @@ -264,15 +265,14 @@ def attempt_submit_jobs(self):
in the next iteration of the sync() method, unless it has exceeded the maximum number of
attempts. If a job exceeds the maximum number of attempts, it is removed from the queue.
"""
failure_reasons = defaultdict(int)
for _ in range(len(self.pending_jobs)):
batch_job = self.pending_jobs.popleft()
key = batch_job.key
cmd = batch_job.command
queue = batch_job.queue
exec_config = batch_job.executor_config
attempt_number = batch_job.attempt_number
_failure_reason = []
failure_reason: str | None = None
if timezone.utcnow() < batch_job.next_attempt_time:
self.pending_jobs.append(batch_job)
continue
Expand All @@ -286,18 +286,18 @@ def attempt_submit_jobs(self):
if error_code in INVALID_CREDENTIALS_EXCEPTIONS:
self.pending_jobs.append(batch_job)
raise
_failure_reason.append(str(e))
failure_reason = str(e)
except Exception as e:
_failure_reason.append(str(e))

if _failure_reason:
for reason in _failure_reason:
failure_reasons[reason] += 1
failure_reason = str(e)

if failure_reason:
if attempt_number >= int(self.__class__.MAX_SUBMIT_JOB_ATTEMPTS):
self.log.error(
"This job has been unsuccessfully attempted too many times (%s). Dropping the task.",
self.send_message_to_task_logs(
logging.ERROR,
"This job has been unsuccessfully attempted too many times (%s). Dropping the task. Reason: %s",
attempt_number,
failure_reason,
ti=key,
)
self.fail(key=key)
else:
Expand All @@ -322,11 +322,6 @@ def attempt_submit_jobs(self):
# running_state is added in Airflow 2.10 and only needed to support task adoption
# (an optional executor feature).
self.running_state(key, job_id)
if failure_reasons:
self.log.error(
"Pending Batch jobs failed to launch for the following reasons: %s. Retrying later.",
dict(failure_reasons),
)

def _describe_jobs(self, job_ids) -> list[BatchJob]:
all_jobs = []
Expand Down Expand Up @@ -462,3 +457,11 @@ def try_adopt_task_instances(self, tis: Sequence[TaskInstance]) -> Sequence[Task

not_adopted_tis = [ti for ti in tis if ti not in adopted_tis]
return not_adopted_tis

def send_message_to_task_logs(self, level: int, msg: str, *args, ti: TaskInstance | TaskInstanceKey) -> None:
    """
    Forward a message to the task logs, falling back to this object's own logger.

    The call is delegated to the parent class implementation. If that lookup or call
    raises ``AttributeError``, the message is emitted through ``self.log`` instead.

    :param level: Numeric logging level for the message (e.g. ``logging.ERROR``).
    :param msg: %-style format string of the message.
    :param args: Values lazily interpolated into ``msg``.
    :param ti: Task instance, or task instance key, the message relates to.
    """
    # TODO: remove this method when min_airflow_version is set to higher than 2.10.0
    try:
        super().send_message_to_task_logs(level, msg, *args, ti=ti)
    except AttributeError:
        # ``send_message_to_task_logs`` is added in 2.10.0. On older versions fall back
        # to the regular logger, honouring the requested level instead of always
        # logging at ERROR.
        self.log.log(level, msg, *args)
47 changes: 23 additions & 24 deletions tests/providers/amazon/aws/executors/batch/test_batch_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import logging
import os
from unittest import mock
from unittest.mock import call

import pytest
import yaml
Expand Down Expand Up @@ -194,8 +195,9 @@ def test_execute(self, mock_executor):
mock_executor.batch.submit_job.assert_called_once()
assert len(mock_executor.active_workers) == 1

@mock.patch.object(AwsBatchExecutor, "send_message_to_task_logs")
@mock.patch.object(batch_executor, "calculate_next_attempt_delay", return_value=dt.timedelta(seconds=0))
def test_attempt_all_jobs_when_some_jobs_fail(self, _, mock_executor, caplog):
def test_attempt_all_jobs_when_some_jobs_fail(self, _, mock_send_message_to_task_logs, mock_executor):
"""
Test how jobs are tried when one job fails, but others pass.

Expand All @@ -206,7 +208,6 @@ def test_attempt_all_jobs_when_some_jobs_fail(self, _, mock_executor, caplog):
airflow_key = mock.Mock(spec=tuple)
airflow_cmd1 = mock.Mock(spec=list)
airflow_cmd2 = mock.Mock(spec=list)
caplog.set_level("ERROR")
airflow_commands = [airflow_cmd1, airflow_cmd2]
responses = [Exception("Failure 1"), {"jobId": "job-2"}]

Expand All @@ -229,13 +230,10 @@ def test_attempt_all_jobs_when_some_jobs_fail(self, _, mock_executor, caplog):
for i in range(2):
submit_job_args["containerOverrides"]["command"] = airflow_commands[i]
assert mock_executor.batch.submit_job.call_args_list[i].kwargs == submit_job_args
assert "Pending Batch jobs failed to launch for the following reasons" in caplog.messages[0]
assert len(mock_executor.pending_jobs) == 1
mock_executor.pending_jobs[0].command == airflow_cmd1
assert len(mock_executor.active_workers.get_all_jobs()) == 1

caplog.clear()

# Add more tasks to pending_jobs. This simulates tasks being scheduled by Airflow
airflow_cmd3 = mock.Mock(spec=list)
airflow_cmd4 = mock.Mock(spec=list)
Expand All @@ -252,26 +250,27 @@ def test_attempt_all_jobs_when_some_jobs_fail(self, _, mock_executor, caplog):
for i in range(2, 5):
submit_job_args["containerOverrides"]["command"] = airflow_commands[i]
assert mock_executor.batch.submit_job.call_args_list[i].kwargs == submit_job_args
assert "Pending Batch jobs failed to launch for the following reasons" in caplog.messages[0]
assert len(mock_executor.pending_jobs) == 1
mock_executor.pending_jobs[0].command == airflow_cmd1
assert len(mock_executor.active_workers.get_all_jobs()) == 3

caplog.clear()

airflow_commands.append(airflow_cmd1)
responses.append(Exception("Failure 1"))

mock_executor.attempt_submit_jobs()
submit_job_args["containerOverrides"]["command"] = airflow_commands[0]
assert mock_executor.batch.submit_job.call_args_list[5].kwargs == submit_job_args
assert (
"This job has been unsuccessfully attempted too many times (3). Dropping the task."
== caplog.messages[0]
mock_send_message_to_task_logs.assert_called_once_with(
logging.ERROR,
"This job has been unsuccessfully attempted too many times (%s). Dropping the task. Reason: %s",
3,
"Failure 1",
ti=airflow_key,
)

@mock.patch.object(AwsBatchExecutor, "send_message_to_task_logs")
@mock.patch.object(batch_executor, "calculate_next_attempt_delay", return_value=dt.timedelta(seconds=0))
def test_attempt_all_jobs_when_jobs_fail(self, _, mock_executor, caplog):
def test_attempt_all_jobs_when_jobs_fail(self, _, mock_send_message_to_task_logs, mock_executor):
"""
Test job retry behaviour when jobs fail validation.

Expand All @@ -282,7 +281,6 @@ def test_attempt_all_jobs_when_jobs_fail(self, _, mock_executor, caplog):
airflow_key = mock.Mock(spec=tuple)
airflow_cmd1 = mock.Mock(spec=list)
airflow_cmd2 = mock.Mock(spec=list)
caplog.set_level("ERROR")
commands = [airflow_cmd1, airflow_cmd2]
failures = [Exception("Failure 1"), Exception("Failure 2")]
submit_job_args = {
Expand All @@ -304,29 +302,29 @@ def test_attempt_all_jobs_when_jobs_fail(self, _, mock_executor, caplog):
for i in range(2):
submit_job_args["containerOverrides"]["command"] = commands[i]
assert mock_executor.batch.submit_job.call_args_list[i].kwargs == submit_job_args
assert "Pending Batch jobs failed to launch for the following reasons" in caplog.messages[0]
assert len(mock_executor.pending_jobs) == 2

caplog.clear()

mock_executor.batch.submit_job.side_effect = failures
mock_executor.attempt_submit_jobs()
for i in range(2):
submit_job_args["containerOverrides"]["command"] = commands[i]
assert mock_executor.batch.submit_job.call_args_list[i].kwargs == submit_job_args
assert "Pending Batch jobs failed to launch for the following reasons" in caplog.messages[0]
assert len(mock_executor.pending_jobs) == 2

caplog.clear()

mock_executor.batch.submit_job.side_effect = failures
mock_executor.attempt_submit_jobs()
assert len(caplog.messages) == 3
calls = []
for i in range(2):
assert (
"This job has been unsuccessfully attempted too many times (3). Dropping the task."
== caplog.messages[i]
calls.append(
call(
logging.ERROR,
"This job has been unsuccessfully attempted too many times (%s). Dropping the task. Reason: %s",
3,
f"Failure {i + 1}",
ti=airflow_key,
)
)
mock_send_message_to_task_logs.assert_has_calls(calls)

def test_attempt_submit_jobs_failure(self, mock_executor):
mock_executor.batch.submit_job.side_effect = NoCredentialsError()
Expand Down Expand Up @@ -467,8 +465,9 @@ def test_sync(self, success_mock, fail_mock, mock_airflow_key, mock_executor):

@mock.patch.object(BaseExecutor, "fail")
@mock.patch.object(BaseExecutor, "success")
@mock.patch.object(AwsBatchExecutor, "send_message_to_task_logs")
@mock.patch.object(batch_executor, "calculate_next_attempt_delay", return_value=dt.timedelta(seconds=0))
def test_failed_sync(self, _, success_mock, fail_mock, mock_airflow_key, mock_executor):
def test_failed_sync(self, _, _2, success_mock, fail_mock, mock_airflow_key, mock_executor):
"""Test failure states"""
self._mock_sync(
executor=mock_executor, airflow_key=mock_airflow_key(), status="FAILED", attempt_number=2
Expand Down