From 881a0594887dd56919a27b8de44547085ef4db79 Mon Sep 17 00:00:00 2001 From: aandres Date: Tue, 14 May 2024 09:22:47 +0100 Subject: [PATCH 1/2] Fix default value for aws batch operator retry strategy --- .../providers/amazon/aws/operators/batch.py | 4 +-- .../amazon/aws/operators/test_batch.py | 31 ++++++++++++++++++- 2 files changed, 32 insertions(+), 3 deletions(-) diff --git a/airflow/providers/amazon/aws/operators/batch.py b/airflow/providers/amazon/aws/operators/batch.py index afca0fc615ff0..7d97dfa104936 100644 --- a/airflow/providers/amazon/aws/operators/batch.py +++ b/airflow/providers/amazon/aws/operators/batch.py @@ -206,8 +206,8 @@ def __init__( self.scheduling_priority_override = scheduling_priority_override self.array_properties = array_properties self.parameters = parameters or {} - self.retry_strategy = retry_strategy or {} - if not self.retry_strategy.get("attempts", None): + self.retry_strategy = retry_strategy + if self.retry_strategy is not None and not self.retry_strategy.get("attempts", None): self.retry_strategy["attempts"] = 1 self.waiters = waiters self.tags = tags or {} diff --git a/tests/providers/amazon/aws/operators/test_batch.py b/tests/providers/amazon/aws/operators/test_batch.py index 2ac95578136ba..2d6856c24df15 100644 --- a/tests/providers/amazon/aws/operators/test_batch.py +++ b/tests/providers/amazon/aws/operators/test_batch.py @@ -20,6 +20,7 @@ from unittest import mock from unittest.mock import patch +import botocore.client import pytest from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, TaskDeferred @@ -64,7 +65,7 @@ def setup_method(self, _, get_client_type_mock): max_retries=self.MAX_RETRIES, status_retries=self.STATUS_RETRIES, parameters=None, - retry_strategy=None, + retry_strategy={}, container_overrides={}, array_properties=None, aws_conn_id="airflow_test", @@ -112,6 +113,34 @@ def test_init(self): self.get_client_type_mock.assert_called_once_with(region_name="eu-west-1") + def test_init_defaults(self): + """Test constructor default values""" + batch_job = BatchOperator( + task_id="task", + job_name=JOB_NAME, + job_queue="queue", + job_definition="hello-world", + ) + assert batch_job.job_id is None + assert batch_job.job_name == JOB_NAME + assert batch_job.job_queue == "queue" + assert batch_job.job_definition == "hello-world" + assert batch_job.waiters is None + assert batch_job.hook.max_retries == 4200 + assert batch_job.hook.status_retries == 10 + assert batch_job.parameters == {} + assert batch_job.retry_strategy is None + assert batch_job.container_overrides is None + assert batch_job.array_properties is None + assert batch_job.node_overrides is None + assert batch_job.share_identifier is None + assert batch_job.scheduling_priority_override is None + assert batch_job.hook.region_name is None + assert batch_job.hook.aws_conn_id is None + assert issubclass(type(batch_job.hook.client), botocore.client.BaseClient) + assert batch_job.tags == {} + assert batch_job.wait_for_completion is True + def test_template_fields_overrides(self): assert self.batch.template_fields == ( "job_id", From 8e96551bd22d4e7b9a1543a34570a7c48912c4ed Mon Sep 17 00:00:00 2001 From: aandres Date: Tue, 14 May 2024 15:37:26 +0100 Subject: [PATCH 2/2] Remove fix of attempt and fix tests --- airflow/providers/amazon/aws/operators/batch.py | 2 -- tests/providers/amazon/aws/operators/test_batch.py | 3 +-- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/airflow/providers/amazon/aws/operators/batch.py b/airflow/providers/amazon/aws/operators/batch.py index 7d97dfa104936..00b6287145a81 100644 --- a/airflow/providers/amazon/aws/operators/batch.py +++ b/airflow/providers/amazon/aws/operators/batch.py @@ -207,8 +207,6 @@ def __init__( self.array_properties = array_properties self.parameters = parameters or {} self.retry_strategy = retry_strategy - if self.retry_strategy is not None and not self.retry_strategy.get("attempts", None): - self.retry_strategy["attempts"] = 1 self.waiters = waiters self.tags = tags or {} self.wait_for_completion = wait_for_completion diff --git a/tests/providers/amazon/aws/operators/test_batch.py b/tests/providers/amazon/aws/operators/test_batch.py index 2d6856c24df15..f769c1baa8181 100644 --- a/tests/providers/amazon/aws/operators/test_batch.py +++ b/tests/providers/amazon/aws/operators/test_batch.py @@ -65,7 +65,7 @@ def setup_method(self, _, get_client_type_mock): max_retries=self.MAX_RETRIES, status_retries=self.STATUS_RETRIES, parameters=None, - retry_strategy={}, + retry_strategy={"attempts": 1}, container_overrides={}, array_properties=None, aws_conn_id="airflow_test", @@ -267,7 +267,6 @@ def test_override_not_sent_if_not_set(self, client_mock, override): "jobName": JOB_NAME, "jobDefinition": "hello-world", "parameters": {}, - "retryStrategy": {"attempts": 1}, "tags": {}, } if override == "overrides":