Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions airflow/providers/amazon/aws/operators/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,9 +206,7 @@ def __init__(
self.scheduling_priority_override = scheduling_priority_override
self.array_properties = array_properties
self.parameters = parameters or {}
self.retry_strategy = retry_strategy or {}
if not self.retry_strategy.get("attempts", None):
self.retry_strategy["attempts"] = 1
self.retry_strategy = retry_strategy
self.waiters = waiters
self.tags = tags or {}
self.wait_for_completion = wait_for_completion
Expand Down
32 changes: 30 additions & 2 deletions tests/providers/amazon/aws/operators/test_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from unittest import mock
from unittest.mock import patch

import botocore.client
import pytest

from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, TaskDeferred
Expand Down Expand Up @@ -64,7 +65,7 @@ def setup_method(self, _, get_client_type_mock):
max_retries=self.MAX_RETRIES,
status_retries=self.STATUS_RETRIES,
parameters=None,
retry_strategy=None,
retry_strategy={"attempts": 1},
container_overrides={},
array_properties=None,
aws_conn_id="airflow_test",
Expand Down Expand Up @@ -112,6 +113,34 @@ def test_init(self):

self.get_client_type_mock.assert_called_once_with(region_name="eu-west-1")

def test_init_defaults(self):
"""Test constructor default values"""
batch_job = BatchOperator(
task_id="task",
job_name=JOB_NAME,
job_queue="queue",
job_definition="hello-world",
)
assert batch_job.job_id is None
assert batch_job.job_name == JOB_NAME
assert batch_job.job_queue == "queue"
assert batch_job.job_definition == "hello-world"
assert batch_job.waiters is None
assert batch_job.hook.max_retries == 4200
assert batch_job.hook.status_retries == 10
assert batch_job.parameters == {}
assert batch_job.retry_strategy is None
assert batch_job.container_overrides is None
assert batch_job.array_properties is None
assert batch_job.node_overrides is None
assert batch_job.share_identifier is None
assert batch_job.scheduling_priority_override is None
assert batch_job.hook.region_name is None
assert batch_job.hook.aws_conn_id is None
assert issubclass(type(batch_job.hook.client), botocore.client.BaseClient)
assert batch_job.tags == {}
assert batch_job.wait_for_completion is True

def test_template_fields_overrides(self):
assert self.batch.template_fields == (
"job_id",
Expand Down Expand Up @@ -238,7 +267,6 @@ def test_override_not_sent_if_not_set(self, client_mock, override):
"jobName": JOB_NAME,
"jobDefinition": "hello-world",
"parameters": {},
"retryStrategy": {"attempts": 1},
"tags": {},
}
if override == "overrides":
Expand Down