From 60e8aa20e5ae2d09b96b0b0785ff1d67136760be Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Rapha=C3=ABl=20Vandon?=
Date: Fri, 16 Jun 2023 11:51:45 -0700
Subject: [PATCH 1/7] add waiter for compute env

---
 .../providers/amazon/aws/waiters/batch.json | 26 +++++++++++++++++++
 1 file changed, 26 insertions(+)

diff --git a/airflow/providers/amazon/aws/waiters/batch.json b/airflow/providers/amazon/aws/waiters/batch.json
index fa9752ea14c41..3fbdd433771c8 100644
--- a/airflow/providers/amazon/aws/waiters/batch.json
+++ b/airflow/providers/amazon/aws/waiters/batch.json
@@ -20,6 +20,32 @@
                     "state": "failed"
                 }
             ]
+        },
+
+        "compute_env_ready": {
+            "delay": 30,
+            "operation": "DescribeComputeEnvironments",
+            "maxAttempts": 100,
+            "acceptors": [
+                {
+                    "argument": "computeEnvironments[].status",
+                    "expected": "VALID",
+                    "matcher": "pathAll",
+                    "state": "success"
+                },
+                {
+                    "argument": "computeEnvironments[].status",
+                    "expected": "INVALID",
+                    "matcher": "pathAny",
+                    "state": "failed"
+                },
+                {
+                    "argument": "computeEnvironments[].status",
+                    "expected": "DELETED",
+                    "matcher": "pathAny",
+                    "state": "failed"
+                }
+            ]
         }
     }
 }

From c593ac1b3f310af3580eab7daf3115c9ccf1e1ea Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Rapha=C3:ABl=20Vandon?=
Date: Fri, 16 Jun 2023 12:47:30 -0700
Subject: [PATCH 2/7] deprecated unused param

---
 airflow/providers/amazon/aws/operators/batch.py | 11 ++++++++---
 1 file changed, 8 insertions(+), 3 deletions(-)

diff --git a/airflow/providers/amazon/aws/operators/batch.py b/airflow/providers/amazon/aws/operators/batch.py
index 2825ed5a01353..daf22cf716ef6 100644
--- a/airflow/providers/amazon/aws/operators/batch.py
+++ b/airflow/providers/amazon/aws/operators/batch.py
@@ -402,6 +402,14 @@ def __init__(
         **kwargs,
     ):
         super().__init__(**kwargs)
+        if status_retries is not None:
+            warnings.warn(
+                "The `status_retries` parameter is unused and should be removed. "
" + "It'll be deleted in a future version.", + AirflowProviderDeprecationWarning, + stacklevel=2, + ) + self.compute_environment_name = compute_environment_name self.environment_type = environment_type self.state = state @@ -410,7 +418,6 @@ def __init__( self.service_role = service_role self.tags = tags or {} self.max_retries = max_retries - self.status_retries = status_retries self.aws_conn_id = aws_conn_id self.region_name = region_name @@ -418,8 +425,6 @@ def __init__( def hook(self): """Create and return a BatchClientHook.""" return BatchClientHook( - max_retries=self.max_retries, - status_retries=self.status_retries, aws_conn_id=self.aws_conn_id, region_name=self.region_name, ) From a5ff5ec326a1907c9b211df32ad9e904067253a1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rapha=C3=ABl=20Vandon?= Date: Mon, 19 Jun 2023 16:23:59 -0700 Subject: [PATCH 3/7] add deferrable mode --- .../providers/amazon/aws/operators/batch.py | 28 ++++++- .../providers/amazon/aws/triggers/batch.py | 77 +++++++++++++++++++ 2 files changed, 102 insertions(+), 3 deletions(-) diff --git a/airflow/providers/amazon/aws/operators/batch.py b/airflow/providers/amazon/aws/operators/batch.py index daf22cf716ef6..3cb488a7b9662 100644 --- a/airflow/providers/amazon/aws/operators/batch.py +++ b/airflow/providers/amazon/aws/operators/batch.py @@ -37,7 +37,10 @@ BatchJobQueueLink, ) from airflow.providers.amazon.aws.links.logs import CloudWatchEventsLink -from airflow.providers.amazon.aws.triggers.batch import BatchOperatorTrigger +from airflow.providers.amazon.aws.triggers.batch import ( + BatchCreateComputeEnvironmentTrigger, + BatchOperatorTrigger, +) from airflow.providers.amazon.aws.utils import trim_none_values if TYPE_CHECKING: @@ -395,10 +398,12 @@ def __init__( unmanaged_v_cpus: int | None = None, service_role: str | None = None, tags: dict | None = None, - max_retries: int | None = None, + poll_interval: int = 30, + max_retries: int = 120, status_retries: int | None = None, aws_conn_id: str | None = None, region_name: str | None = None, + deferrable: bool = True, **kwargs, ): super().__init__(**kwargs) @@ -417,9 +422,11 @@ def __init__( self.compute_resources = compute_resources self.service_role = service_role self.tags = tags or {} + self.poll_interval = poll_interval self.max_retries = max_retries self.aws_conn_id = aws_conn_id self.region_name = region_name + self.deferrable = deferrable @cached_property def hook(self): @@ -440,6 +447,21 @@ def execute(self, context: Context): "serviceRole": self.service_role, "tags": self.tags, } - self.hook.client.create_compute_environment(**trim_none_values(kwargs)) + response = self.hook.client.create_compute_environment(**trim_none_values(kwargs)) + arn = response["computeEnvironmentArn"] + + if self.deferrable: + self.defer( + trigger=BatchCreateComputeEnvironmentTrigger( + arn, self.poll_interval, self.max_retries, self.aws_conn_id, self.region_name + ), + method_name="execute_complete", + ) self.log.info("AWS Batch compute environment created successfully") + return + + def execute_complete(self, context, event=None): + if event["status"] != "success": + raise AirflowException(f"Error while waiting for the compute environment to be ready: {event}") + return event["value"] diff --git a/airflow/providers/amazon/aws/triggers/batch.py b/airflow/providers/amazon/aws/triggers/batch.py index f4a5de15254fa..37c86c6d6b3f4 100644 --- a/airflow/providers/amazon/aws/triggers/batch.py +++ b/airflow/providers/amazon/aws/triggers/batch.py @@ -22,6 +22,7 @@ from botocore.exceptions import 
 
+from airflow import AirflowException
 from airflow.providers.amazon.aws.hooks.batch_client import BatchClientHook
 from airflow.triggers.base import BaseTrigger, TriggerEvent
 
@@ -188,3 +189,79 @@ async def run(self):
                 "message": f"Job {self.job_id} Succeeded",
             }
         )
+
+
+class BatchCreateComputeEnvironmentTrigger(BaseTrigger):
+    """
+    Trigger for BatchCreateComputeEnvironmentOperator.
+    The trigger will asynchronously poll the boto3 API and wait for the compute environment to be ready.
+
+    :param compute_env_arn: The ARN of the compute environment to wait for.
+    :param max_retries: The maximum number of attempts to be made.
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+    :param region_name: region name to use in AWS Hook
+    :param poll_interval: The amount of time in seconds to wait between attempts.
+    """
+
+    def __init__(
+        self,
+        compute_env_arn: str | None = None,
+        poll_interval: int = 30,
+        max_retries: int = 10,
+        aws_conn_id: str | None = "aws_default",
+        region_name: str | None = None,
+    ):
+        super().__init__()
+        self.compute_env_arn = compute_env_arn
+        self.max_retries = max_retries
+        self.aws_conn_id = aws_conn_id
+        self.region_name = region_name
+        self.poll_interval = poll_interval
+
+    def serialize(self) -> tuple[str, dict[str, Any]]:
+        """Serializes BatchCreateComputeEnvironmentTrigger arguments and classpath."""
+        return (
+            self.__class__.__module__ + "." + self.__class__.__qualname__,
+            {
+                "compute_env_arn": self.compute_env_arn,
+                "max_retries": self.max_retries,
+                "aws_conn_id": self.aws_conn_id,
+                "region_name": self.region_name,
+                "poll_interval": self.poll_interval,
+            },
+        )
+
+    async def run(self):
+        hook = BatchClientHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)
+        async with hook.async_conn as client:
+            waiter = hook.get_waiter("compute_env_ready", deferrable=True, client=client)
+            attempt = 0
+            while attempt < self.max_retries:
+                attempt = attempt + 1
+                try:
+                    await waiter.wait(
+                        computeEnvironments=[self.compute_env_arn],
+                        WaiterConfig={
+                            "Delay": self.poll_interval,
+                            "MaxAttempts": 1,
+                        },
+                    )
+                    break
+                except WaiterError as error:
+                    if "terminal failure" in str(error):
+                        raise
+                    self.log.info(
+                        "Compute Environment status is %s. Retrying attempt %s/%s",
+                        error.last_response["computeEnvironments"][0]["status"],
+                        attempt,
+                        self.max_retries,
+                    )
+                    await asyncio.sleep(int(self.poll_interval))
+
+        if attempt >= self.max_retries:
+            raise AirflowException(
+                f"Compute Environment {self.compute_env_arn} is still not ready "
+                f"after checking its state {self.max_retries} times."
+ ) + else: + yield TriggerEvent({"status": "success", "value": self.compute_env_arn}) From 5a2bd77fa3834810a09a45025deb191b784812e5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rapha=C3=ABl=20Vandon?= Date: Tue, 20 Jun 2023 11:03:30 -0700 Subject: [PATCH 4/7] add utests --- .../amazon/aws/operators/test_batch.py | 28 ++++++++++++- .../amazon/aws/triggers/test_batch.py | 42 ++++++++++++++++++- 2 files changed, 68 insertions(+), 2 deletions(-) diff --git a/tests/providers/amazon/aws/operators/test_batch.py b/tests/providers/amazon/aws/operators/test_batch.py index f559424dff8cc..ac9780efc1e4e 100644 --- a/tests/providers/amazon/aws/operators/test_batch.py +++ b/tests/providers/amazon/aws/operators/test_batch.py @@ -27,7 +27,10 @@ from airflow.providers.amazon.aws.operators.batch import BatchCreateComputeEnvironmentOperator, BatchOperator # Use dummy AWS credentials -from airflow.providers.amazon.aws.triggers.batch import BatchOperatorTrigger +from airflow.providers.amazon.aws.triggers.batch import ( + BatchCreateComputeEnvironmentTrigger, + BatchOperatorTrigger, +) AWS_REGION = "eu-west-1" AWS_ACCESS_KEY_ID = "airflow_dummy_key" @@ -298,3 +301,26 @@ def test_execute(self, mock_conn): computeResources=compute_resources, tags=tags, ) + + @mock.patch.object(BatchClientHook, "client") + def test_defer(self, client_mock): + client_mock.create_compute_environment.return_value = {"computeEnvironmentArn": "my_arn"} + + operator = BatchCreateComputeEnvironmentOperator( + task_id="task", + compute_environment_name="my_env_name", + environment_type="my_env_type", + state="my_state", + compute_resources={}, + max_retries=123456, + poll_interval=456789, + deferrable=True, + ) + + with pytest.raises(TaskDeferred) as deferred: + operator.execute(None) + + assert isinstance(deferred.value.trigger, BatchCreateComputeEnvironmentTrigger) + assert deferred.value.trigger.compute_env_arn == "my_arn" + assert deferred.value.trigger.poll_interval == 456789 + assert deferred.value.trigger.max_retries == 123456 diff --git a/tests/providers/amazon/aws/triggers/test_batch.py b/tests/providers/amazon/aws/triggers/test_batch.py index 5cf125f8280a5..5e405839b3b4a 100644 --- a/tests/providers/amazon/aws/triggers/test_batch.py +++ b/tests/providers/amazon/aws/triggers/test_batch.py @@ -22,7 +22,12 @@ import pytest from botocore.exceptions import WaiterError -from airflow.providers.amazon.aws.triggers.batch import BatchOperatorTrigger, BatchSensorTrigger +from airflow.providers.amazon.aws.hooks.batch_client import BatchClientHook +from airflow.providers.amazon.aws.triggers.batch import ( + BatchCreateComputeEnvironmentTrigger, + BatchOperatorTrigger, + BatchSensorTrigger, +) from airflow.triggers.base import TriggerEvent BATCH_JOB_ID = "job_id" @@ -181,3 +186,38 @@ async def test_batch_sensor_trigger_failure( assert actual_response == TriggerEvent( {"status": "failure", "message": f"Job Failed: Waiter {name} failed: {reason}"} ) + + +class TestBatchCreateComputeEnvironmentTrigger: + @pytest.mark.asyncio + @mock.patch.object(BatchClientHook, "async_conn") + @mock.patch.object(BatchClientHook, "get_waiter") + async def test_success(self, get_waiter_mock, conn_mock): + get_waiter_mock().wait = AsyncMock( + side_effect=[ + WaiterError( + "situation normal", "first try", {"computeEnvironments": [{"status": "my_status"}]} + ), + {}, + ] + ) + trigger = BatchCreateComputeEnvironmentTrigger("my_arn", poll_interval=0, max_retries=3) + + generator = trigger.run() + response: TriggerEvent = await generator.asend(None) + + assert 
+        assert response.payload["value"] == "my_arn"
+
+    @pytest.mark.asyncio
+    @mock.patch.object(BatchClientHook, "async_conn")
+    @mock.patch.object(BatchClientHook, "get_waiter")
+    async def test_failure(self, get_waiter_mock, conn_mock):
+        get_waiter_mock().wait = AsyncMock(
+            side_effect=[WaiterError("terminal failure", "terminal failure reason", {})]
+        )
+        trigger = BatchCreateComputeEnvironmentTrigger("my_arn", poll_interval=0, max_retries=3)
+
+        with pytest.raises(WaiterError):
+            generator = trigger.run()
+            await generator.asend(None)

From f606d67152b929f4acf53d56c4c5ef93641f893c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Rapha=C3=ABl=20Vandon?=
Date: Tue, 20 Jun 2023 15:09:45 -0700
Subject: [PATCH 5/7] fixes: docstring, default val, return arn

---
 airflow/providers/amazon/aws/operators/batch.py | 14 ++++++++------
 1 file changed, 8 insertions(+), 6 deletions(-)

diff --git a/airflow/providers/amazon/aws/operators/batch.py b/airflow/providers/amazon/aws/operators/batch.py
index 069cf2486e46e..d5eb8dd78d9a3 100644
--- a/airflow/providers/amazon/aws/operators/batch.py
+++ b/airflow/providers/amazon/aws/operators/batch.py
@@ -405,14 +405,16 @@ class BatchCreateComputeEnvironmentOperator(BaseOperator):
         services on your behalf (templated).
     :param tags: Tags that you apply to the compute-environment to help you categorize and
         organize your resources.
-    :param max_retries: Exponential back-off retries, 4200 = 48 hours; polling
-        is only used when waiters is None.
-    :param status_retries: Number of HTTP retries to get job status, 10; polling
-        is only used when waiters is None.
+    :param poll_interval: How long to wait in seconds between two polls of the environment status.
+        Only useful when deferrable is True.
+    :param max_retries: How many times to poll for the environment status.
+        Only useful when deferrable is True.
     :param aws_conn_id: Connection ID of AWS credentials / region name. If None,
         credential boto3 strategy will be used.
     :param region_name: Region name to use in AWS Hook. Overrides the
         ``region_name`` in connection if provided.
+    :param deferrable: If True, the operator will wait asynchronously for the environment to be created.
+        This mode requires aiobotocore module to be installed. (default: False)
     """
 
     template_fields: Sequence[str] = (
@@ -436,7 +438,7 @@ def __init__(
         status_retries: int | None = None,
         aws_conn_id: str | None = None,
         region_name: str | None = None,
-        deferrable: bool = True,
+        deferrable: bool = False,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -492,7 +494,7 @@ def execute(self, context: Context):
             )
 
         self.log.info("AWS Batch compute environment created successfully")
-        return
+        return arn
 
     def execute_complete(self, context, event=None):
         if event["status"] != "success":

From 49dc7749e224a3184ce081ce766c5f299302fd2d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Rapha=C3=ABl=20Vandon?=
Date: Wed, 21 Jun 2023 09:56:58 -0700
Subject: [PATCH 6/7] apply suggestions around obsolescence

---
 airflow/providers/amazon/aws/operators/batch.py | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/airflow/providers/amazon/aws/operators/batch.py b/airflow/providers/amazon/aws/operators/batch.py
index d5eb8dd78d9a3..b9b3322c49102 100644
--- a/airflow/providers/amazon/aws/operators/batch.py
+++ b/airflow/providers/amazon/aws/operators/batch.py
@@ -434,21 +434,22 @@ def __init__(
         service_role: str | None = None,
         tags: dict | None = None,
         poll_interval: int = 30,
-        max_retries: int = 120,
-        status_retries: int | None = None,
+        max_retries: int | None = None,
         aws_conn_id: str | None = None,
         region_name: str | None = None,
         deferrable: bool = False,
         **kwargs,
     ):
-        super().__init__(**kwargs)
-        if status_retries is not None:
+        if "status_retries" in kwargs:
             warnings.warn(
                 "The `status_retries` parameter is unused and should be removed. "
                 "It'll be deleted in a future version.",
                 AirflowProviderDeprecationWarning,
                 stacklevel=2,
             )
+            kwargs.pop("status_retries")  # remove before calling super() to prevent unexpected arg error
+
+        super().__init__(**kwargs)
 
         self.compute_environment_name = compute_environment_name
         self.environment_type = environment_type
@@ -458,7 +459,7 @@ def __init__(
         self.service_role = service_role
         self.tags = tags or {}
         self.poll_interval = poll_interval
-        self.max_retries = max_retries
+        self.max_retries = max_retries or 120
         self.aws_conn_id = aws_conn_id
         self.region_name = region_name
         self.deferrable = deferrable

From c453baf30591d6d6486ede7ef6dcc8a0301d0003 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Rapha=C3=ABl=20Vandon?=
Date: Tue, 27 Jun 2023 14:25:40 -0700
Subject: [PATCH 7/7] switch to common method

---
 .../providers/amazon/aws/triggers/batch.py | 38 +++++--------------
 .../amazon/aws/triggers/test_batch.py      |  3 +-
 2 files changed, 11 insertions(+), 30 deletions(-)

diff --git a/airflow/providers/amazon/aws/triggers/batch.py b/airflow/providers/amazon/aws/triggers/batch.py
index 37c86c6d6b3f4..b0bdbc0d4578b 100644
--- a/airflow/providers/amazon/aws/triggers/batch.py
+++ b/airflow/providers/amazon/aws/triggers/batch.py
@@ -22,8 +22,8 @@
 from botocore.exceptions import WaiterError
 
-from airflow import AirflowException
 from airflow.providers.amazon.aws.hooks.batch_client import BatchClientHook
+from airflow.providers.amazon.aws.utils.waiter_with_logging import async_wait
 from airflow.triggers.base import BaseTrigger, TriggerEvent
 
@@ -235,33 +235,13 @@ async def run(self):
         hook = BatchClientHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)
         async with hook.async_conn as client:
             waiter = hook.get_waiter("compute_env_ready", deferrable=True, client=client)
-            attempt = 0
-            while attempt < self.max_retries:
-                attempt = attempt + 1
-                try:
-                    await waiter.wait(
-                        computeEnvironments=[self.compute_env_arn],
-                        WaiterConfig={
-                            "Delay": self.poll_interval,
-                            "MaxAttempts": 1,
-                        },
-                    )
-                    break
-                except WaiterError as error:
-                    if "terminal failure" in str(error):
-                        raise
-                    self.log.info(
-                        "Compute Environment status is %s. Retrying attempt %s/%s",
-                        error.last_response["computeEnvironments"][0]["status"],
-                        attempt,
-                        self.max_retries,
-                    )
-                    await asyncio.sleep(int(self.poll_interval))
-
-        if attempt >= self.max_retries:
-            raise AirflowException(
-                f"Compute Environment {self.compute_env_arn} is still not ready "
-                f"after checking its state {self.max_retries} times."
+            await async_wait(
+                waiter=waiter,
+                waiter_delay=self.poll_interval,
+                waiter_max_attempts=self.max_retries,
+                args={"computeEnvironments": [self.compute_env_arn]},
+                failure_message="Failure while creating Compute Environment",
+                status_message="Compute Environment not ready yet",
+                status_args=["computeEnvironments[].status", "computeEnvironments[].statusReason"],
             )
-        else:
             yield TriggerEvent({"status": "success", "value": self.compute_env_arn})

diff --git a/tests/providers/amazon/aws/triggers/test_batch.py b/tests/providers/amazon/aws/triggers/test_batch.py
index 5e405839b3b4a..e33736076237f 100644
--- a/tests/providers/amazon/aws/triggers/test_batch.py
+++ b/tests/providers/amazon/aws/triggers/test_batch.py
@@ -22,6 +22,7 @@
 import pytest
 from botocore.exceptions import WaiterError
 
+from airflow import AirflowException
 from airflow.providers.amazon.aws.hooks.batch_client import BatchClientHook
 from airflow.providers.amazon.aws.triggers.batch import (
     BatchCreateComputeEnvironmentTrigger,
@@ -218,6 +219,6 @@ async def test_failure(self, get_waiter_mock, conn_mock):
         )
         trigger = BatchCreateComputeEnvironmentTrigger("my_arn", poll_interval=0, max_retries=3)
 
-        with pytest.raises(WaiterError):
+        with pytest.raises(AirflowException):
             generator = trigger.run()
             await generator.asend(None)
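
For reviewers: a minimal sketch of how the deferrable mode added in this series would be used from a DAG. The DAG id, environment name, subnet and security group ids below are placeholders (not part of this change), and the argument values are illustrative only; the parameters mirror the operator signature as it stands after patches 5 and 6.

    from datetime import datetime

    from airflow import DAG
    from airflow.providers.amazon.aws.operators.batch import BatchCreateComputeEnvironmentOperator

    with DAG(dag_id="example_batch_compute_env", start_date=datetime(2023, 1, 1), schedule=None):
        # Creates the environment, then defers to BatchCreateComputeEnvironmentTrigger,
        # which polls the compute_env_ready waiter added in patch 1. The task resumes
        # in execute_complete() and returns the environment ARN.
        create_env = BatchCreateComputeEnvironmentOperator(
            task_id="create_compute_env",
            compute_environment_name="my-compute-env",  # placeholder
            environment_type="MANAGED",
            state="ENABLED",
            compute_resources={
                "type": "FARGATE",
                "maxvCpus": 10,
                "subnets": ["subnet-12345"],  # placeholder
                "securityGroupIds": ["sg-12345"],  # placeholder
            },
            deferrable=True,
            poll_interval=30,  # seconds between status polls while deferred
            max_retries=120,  # give up after this many polls
        )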