diff --git a/airflow/providers/amazon/aws/operators/batch.py b/airflow/providers/amazon/aws/operators/batch.py index afca0fc615ff0..00b6287145a81 100644 --- a/airflow/providers/amazon/aws/operators/batch.py +++ b/airflow/providers/amazon/aws/operators/batch.py @@ -206,9 +206,7 @@ def __init__( self.scheduling_priority_override = scheduling_priority_override self.array_properties = array_properties self.parameters = parameters or {} - self.retry_strategy = retry_strategy or {} - if not self.retry_strategy.get("attempts", None): - self.retry_strategy["attempts"] = 1 + self.retry_strategy = retry_strategy self.waiters = waiters self.tags = tags or {} self.wait_for_completion = wait_for_completion diff --git a/tests/providers/amazon/aws/operators/test_batch.py b/tests/providers/amazon/aws/operators/test_batch.py index 2ac95578136ba..f769c1baa8181 100644 --- a/tests/providers/amazon/aws/operators/test_batch.py +++ b/tests/providers/amazon/aws/operators/test_batch.py @@ -20,6 +20,7 @@ from unittest import mock from unittest.mock import patch +import botocore.client import pytest from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, TaskDeferred @@ -64,7 +65,7 @@ def setup_method(self, _, get_client_type_mock): max_retries=self.MAX_RETRIES, status_retries=self.STATUS_RETRIES, parameters=None, - retry_strategy=None, + retry_strategy={"attempts": 1}, container_overrides={}, array_properties=None, aws_conn_id="airflow_test", @@ -112,6 +113,34 @@ def test_init(self): self.get_client_type_mock.assert_called_once_with(region_name="eu-west-1") + def test_init_defaults(self): + """Test constructor default values""" + batch_job = BatchOperator( + task_id="task", + job_name=JOB_NAME, + job_queue="queue", + job_definition="hello-world", + ) + assert batch_job.job_id is None + assert batch_job.job_name == JOB_NAME + assert batch_job.job_queue == "queue" + assert batch_job.job_definition == "hello-world" + assert batch_job.waiters is None + assert batch_job.hook.max_retries == 4200 + assert batch_job.hook.status_retries == 10 + assert batch_job.parameters == {} + assert batch_job.retry_strategy is None + assert batch_job.container_overrides is None + assert batch_job.array_properties is None + assert batch_job.node_overrides is None + assert batch_job.share_identifier is None + assert batch_job.scheduling_priority_override is None + assert batch_job.hook.region_name is None + assert batch_job.hook.aws_conn_id is None + assert issubclass(type(batch_job.hook.client), botocore.client.BaseClient) + assert batch_job.tags == {} + assert batch_job.wait_for_completion is True + def test_template_fields_overrides(self): assert self.batch.template_fields == ( "job_id", @@ -238,7 +267,6 @@ def test_override_not_sent_if_not_set(self, client_mock, override): "jobName": JOB_NAME, "jobDefinition": "hello-world", "parameters": {}, - "retryStrategy": {"attempts": 1}, "tags": {}, } if override == "overrides":