From bcb5a4c9a1534e60634b25adbfced9b791d82d73 Mon Sep 17 00:00:00 2001 From: Jarek Potiuk Date: Wed, 5 Apr 2023 20:54:26 +0200 Subject: [PATCH] Revert "Add AWS deferrable BatchOperator (#29300)" This reverts commit 77c272e6e8ecda0ce48917064e58ba14f6a15844. --- .../amazon/aws/hooks/batch_client.py | 239 +----------------- .../providers/amazon/aws/operators/batch.py | 37 --- .../providers/amazon/aws/triggers/batch.py | 123 --------- .../operators/batch.rst | 5 +- .../aws/deferrable/hooks/test_batch_client.py | 213 ---------------- .../aws/deferrable/triggers/test_batch.py | 131 ---------- .../amazon/aws/hooks/test_batch_client.py | 1 - .../amazon/aws/operators/test_batch.py | 120 +-------- 8 files changed, 3 insertions(+), 866 deletions(-) delete mode 100644 airflow/providers/amazon/aws/triggers/batch.py delete mode 100644 tests/providers/amazon/aws/deferrable/hooks/test_batch_client.py delete mode 100644 tests/providers/amazon/aws/deferrable/triggers/test_batch.py diff --git a/airflow/providers/amazon/aws/hooks/batch_client.py b/airflow/providers/amazon/aws/hooks/batch_client.py index 10b93afbbc2b3..526ab9a8a4f01 100644 --- a/airflow/providers/amazon/aws/hooks/batch_client.py +++ b/airflow/providers/amazon/aws/hooks/batch_client.py @@ -26,17 +26,15 @@ """ from __future__ import annotations -import asyncio from random import uniform from time import sleep -from typing import Any import botocore.client import botocore.exceptions import botocore.waiter from airflow.exceptions import AirflowException -from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseAsyncHook, AwsBaseHook +from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook from airflow.typing_compat import Protocol, runtime_checkable @@ -546,238 +544,3 @@ def exp(tries): delay = 1 + pow(tries * 0.6, 2) delay = min(max_interval, delay) return uniform(delay / 3, delay) - - -class BatchClientAsyncHook(BatchClientHook, AwsBaseAsyncHook): - """ - Async client for AWS Batch services. - - :param job_id: the job ID, usually unknown (None) until the - submit_job operation gets the jobId defined by AWS Batch - - :param waiters: an :py:class:`.BatchWaiters` object (see note below); - if None, polling is used with max_retries and status_retries. - - .. note:: - Several methods use a default random delay to check or poll for job status, i.e. - ``random.sample()`` - Using a random interval helps to avoid AWS API throttle limits - when many concurrent tasks request job-descriptions. - - To modify the global defaults for the range of jitter allowed when a - random delay is used to check Batch job status, modify these defaults, e.g.: - - BatchClient.DEFAULT_DELAY_MIN = 0 - BatchClient.DEFAULT_DELAY_MAX = 5 - - When explicit delay values are used, a 1 second random jitter is applied to the - delay . It is generally recommended that random jitter is added to API requests. - A convenience method is provided for this, e.g. to get a random delay of - 10 sec +/- 5 sec: ``delay = BatchClient.add_jitter(10, width=5, minima=0)`` - """ - - def __init__(self, job_id: str | None, waiters: Any = None, *args: Any, **kwargs: Any) -> None: - super().__init__(*args, **kwargs) - self.job_id = job_id - self.waiters = waiters - - async def monitor_job(self) -> dict[str, str] | None: - """ - Monitor an AWS Batch job - monitor_job can raise an exception or an AirflowTaskTimeout can be raised if execution_timeout - is given while creating the task. 
These exceptions should be handled in taskinstance.py - instead of here like it was previously done - - :raises: AirflowException - """ - if not self.job_id: - raise AirflowException("AWS Batch job - job_id was not found") - - if self.waiters: - self.waiters.wait_for_job(self.job_id) - return None - else: - await self.wait_for_job(self.job_id) - await self.check_job_success(self.job_id) - success_msg = f"AWS Batch job ({self.job_id}) succeeded" - self.log.info(success_msg) - return {"status": "success", "message": success_msg} - - async def check_job_success(self, job_id: str) -> bool: # type: ignore[override] - """ - Check the final status of the Batch job; return True if the job - 'SUCCEEDED', else raise an AirflowException - - :param job_id: a Batch job ID - - :raises: AirflowException - """ - job = await self.get_job_description(job_id) - job_status = job.get("status") - if job_status == self.SUCCESS_STATE: - self.log.info("AWS Batch job (%s) succeeded: %s", job_id, job) - return True - - if job_status == self.FAILURE_STATE: - raise AirflowException(f"AWS Batch job ({job_id}) failed: {job}") - - if job_status in self.INTERMEDIATE_STATES: - raise AirflowException(f"AWS Batch job ({job_id}) is not complete: {job}") - - raise AirflowException(f"AWS Batch job ({job_id}) has unknown status: {job}") - - @staticmethod - async def delay(delay: int | float | None = None) -> None: # type: ignore[override] - """ - Pause execution for ``delay`` seconds. - - :param delay: a delay to pause execution using ``time.sleep(delay)``; - a small 1 second jitter is applied to the delay. - - .. note:: - This method uses a default random delay, i.e. - ``random.sample()``; - using a random interval helps to avoid AWS API throttle limits - when many concurrent tasks request job-descriptions. - """ - if delay is None: - delay = uniform(BatchClientHook.DEFAULT_DELAY_MIN, BatchClientHook.DEFAULT_DELAY_MAX) - else: - delay = BatchClientAsyncHook.add_jitter(delay) - await asyncio.sleep(delay) - - async def wait_for_job( # type: ignore[override] - self, job_id: str, delay: int | float | None = None - ) -> None: - """ - Wait for Batch job to complete. - - :param job_id: a Batch job ID - - :param delay: a delay before polling for job status - - :raises: AirflowException - """ - await self.delay(delay) - await self.poll_for_job_running(job_id, delay) - await self.poll_for_job_complete(job_id, delay) - self.log.info("AWS Batch job (%s) has completed", job_id) - - async def poll_for_job_complete( # type: ignore[override] - self, job_id: str, delay: int | float | None = None - ) -> None: - """ - Poll for job completion. The status that indicates job completion - are: 'SUCCEEDED'|'FAILED'. - - So the status options that this will wait for are the transitions from: - 'SUBMITTED'>'PENDING'>'RUNNABLE'>'STARTING'>'RUNNING'>'SUCCEEDED'|'FAILED' - - :param job_id: a Batch job ID - - :param delay: a delay before polling for job status - - :raises: AirflowException - """ - await self.delay(delay) - complete_status = [self.SUCCESS_STATE, self.FAILURE_STATE] - await self.poll_job_status(job_id, complete_status) - - async def poll_for_job_running( # type: ignore[override] - self, job_id: str, delay: int | float | None = None - ) -> None: - """ - Poll for job running. The status that indicates a job is running or - already complete are: 'RUNNING'|'SUCCEEDED'|'FAILED'. 
- - So the status options that this will wait for are the transitions from: - 'SUBMITTED'>'PENDING'>'RUNNABLE'>'STARTING'>'RUNNING'|'SUCCEEDED'|'FAILED' - - The completed status options are included for cases where the status - changes too quickly for polling to detect a RUNNING status that moves - quickly from STARTING to RUNNING to completed (often a failure). - - :param job_id: a Batch job ID - - :param delay: a delay before polling for job status - - :raises: AirflowException - """ - await self.delay(delay) - running_status = [self.RUNNING_STATE, self.SUCCESS_STATE, self.FAILURE_STATE] - await self.poll_job_status(job_id, running_status) - - async def get_job_description(self, job_id: str) -> dict[str, str]: # type: ignore[override] - """ - Get job description (using status_retries). - - :param job_id: a Batch job ID - :raises: AirflowException - """ - retries = 0 - async with await self.get_client_async() as client: - while True: - try: - response = client.describe_jobs(jobs=[job_id]) - return self.parse_job_description(job_id, response) - - except botocore.exceptions.ClientError as err: - error = err.response.get("Error", {}) - if error.get("Code") == "TooManyRequestsException": - pass # allow it to retry, if possible - else: - raise AirflowException(f"AWS Batch job ({job_id}) description error: {err}") - - retries += 1 - if retries >= self.status_retries: - raise AirflowException( - f"AWS Batch job ({job_id}) description error: exceeded status_retries " - f"({self.status_retries})" - ) - - pause = self.exponential_delay(retries) - self.log.info( - "AWS Batch job (%s) description retry (%d of %d) in the next %.2f seconds", - job_id, - retries, - self.status_retries, - pause, - ) - await self.delay(pause) - - async def poll_job_status(self, job_id: str, match_status: list[str]) -> bool: # type: ignore[override] - """ - Poll for job status using an exponential back-off strategy (with max_retries). 
- The Batch job status polled are: - 'SUBMITTED'|'PENDING'|'RUNNABLE'|'STARTING'|'RUNNING'|'SUCCEEDED'|'FAILED' - - :param job_id: a Batch job ID - :param match_status: a list of job status to match - :raises: AirflowException - """ - retries = 0 - while True: - job = await self.get_job_description(job_id) - job_status = job.get("status") - self.log.info( - "AWS Batch job (%s) check status (%s) in %s", - job_id, - job_status, - match_status, - ) - if job_status in match_status: - return True - - if retries >= self.max_retries: - raise AirflowException(f"AWS Batch job ({job_id}) status checks exceed max_retries") - - retries += 1 - pause = self.exponential_delay(retries) - self.log.info( - "AWS Batch job (%s) status check (%d of %d) in the next %.2f seconds", - job_id, - retries, - self.max_retries, - pause, - ) - await self.delay(pause) diff --git a/airflow/providers/amazon/aws/operators/batch.py b/airflow/providers/amazon/aws/operators/batch.py index 79a10a7b17c22..6565bcecfbaf1 100644 --- a/airflow/providers/amazon/aws/operators/batch.py +++ b/airflow/providers/amazon/aws/operators/batch.py @@ -37,7 +37,6 @@ BatchJobQueueLink, ) from airflow.providers.amazon.aws.links.logs import CloudWatchEventsLink -from airflow.providers.amazon.aws.triggers.batch import BatchOperatorTrigger from airflow.providers.amazon.aws.utils import trim_none_values if TYPE_CHECKING: @@ -72,7 +71,6 @@ class BatchOperator(BaseOperator): Override the region_name in connection (if provided) :param tags: collection of tags to apply to the AWS Batch job submission if None, no tags are submitted - :param deferrable: Run operator in the deferrable mode. .. note:: Any custom waiters must return a waiter for these calls: @@ -127,7 +125,6 @@ def __init__( region_name: str | None = None, tags: dict | None = None, wait_for_completion: bool = True, - deferrable: bool = False, **kwargs, ): @@ -142,8 +139,6 @@ def __init__( self.waiters = waiters self.tags = tags or {} self.wait_for_completion = wait_for_completion - self.deferrable = deferrable - self.hook = BatchClientHook( max_retries=max_retries, status_retries=status_retries, @@ -159,43 +154,11 @@ def execute(self, context: Context): """ self.submit_job(context) - if self.deferrable: - self.defer( - timeout=self.execution_timeout, - trigger=BatchOperatorTrigger( - job_id=self.job_id, - job_name=self.job_name, - job_definition=self.job_definition, - job_queue=self.job_queue, - overrides=self.overrides, - array_properties=self.array_properties, - parameters=self.parameters, - waiters=self.waiters, - tags=self.tags, - max_retries=self.hook.max_retries, - status_retries=self.hook.status_retries, - aws_conn_id=self.hook.aws_conn_id, - region_name=self.hook.region_name, - ), - method_name="execute_complete", - ) - if self.wait_for_completion: self.monitor_job(context) return self.job_id - def execute_complete(self, context: Context, event: dict[str, Any]): - """ - Callback for when the trigger fires - returns immediately. - Relies on trigger to throw an exception, otherwise it assumes execution was - successful. 
- """ - if "status" in event and event["status"] == "error": - raise AirflowException(event["message"]) - self.log.info(event["message"]) - return self.job_id - def on_kill(self): response = self.hook.client.terminate_job(jobId=self.job_id, reason="Task killed by the user") self.log.info("AWS Batch job (%s) terminated: %s", self.job_id, response) diff --git a/airflow/providers/amazon/aws/triggers/batch.py b/airflow/providers/amazon/aws/triggers/batch.py deleted file mode 100644 index eb5a80a3c956b..0000000000000 --- a/airflow/providers/amazon/aws/triggers/batch.py +++ /dev/null @@ -1,123 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -from __future__ import annotations - -from typing import Any, AsyncIterator - -from airflow.providers.amazon.aws.hooks.batch_client import BatchClientAsyncHook -from airflow.triggers.base import BaseTrigger, TriggerEvent - - -class BatchOperatorTrigger(BaseTrigger): - """ - Checks for the state of a previously submitted job to AWS Batch. - BatchOperatorTrigger is fired as deferred class with params to poll the job state in Triggerer - - :param job_id: the job ID, usually unknown (None) until the - submit_job operation gets the jobId defined by AWS Batch - :param job_name: the name for the job that will run on AWS Batch (templated) - :param job_definition: the job definition name on AWS Batch - :param job_queue: the queue name on AWS Batch - :param overrides: the `containerOverrides` parameter for boto3 (templated) - :param array_properties: the `arrayProperties` parameter for boto3 - :param parameters: the `parameters` for boto3 (templated) - :param waiters: a :class:`.BatchWaiters` object (see note below); - if None, polling is used with max_retries and status_retries. - :param tags: collection of tags to apply to the AWS Batch job submission - if None, no tags are submitted - :param max_retries: exponential back-off retries, 4200 = 48 hours; - polling is only used when waiters is None - :param status_retries: number of HTTP retries to get job status, 10; - polling is only used when waiters is None - :param aws_conn_id: connection id of AWS credentials / region name. If None, - credential boto3 strategy will be used. - :param region_name: AWS region name to use . 
- Override the region_name in connection (if provided) - """ - - def __init__( - self, - job_id: str | None, - job_name: str, - job_definition: str, - job_queue: str, - overrides: dict[str, str], - array_properties: dict[str, str], - parameters: dict[str, str], - waiters: Any, - tags: dict[str, str], - max_retries: int, - status_retries: int, - region_name: str | None, - aws_conn_id: str | None = "aws_default", - ): - super().__init__() - self.job_id = job_id - self.job_name = job_name - self.job_definition = job_definition - self.job_queue = job_queue - self.overrides = overrides or {} - self.array_properties = array_properties or {} - self.parameters = parameters or {} - self.waiters = waiters - self.tags = tags or {} - self.max_retries = max_retries - self.status_retries = status_retries - self.aws_conn_id = aws_conn_id - self.region_name = region_name - - def serialize(self) -> tuple[str, dict[str, Any]]: - """Serializes BatchOperatorTrigger arguments and classpath.""" - return ( - "airflow.providers.amazon.aws.triggers.batch.BatchOperatorTrigger", - { - "job_id": self.job_id, - "job_name": self.job_name, - "job_definition": self.job_definition, - "job_queue": self.job_queue, - "overrides": self.overrides, - "array_properties": self.array_properties, - "parameters": self.parameters, - "waiters": self.waiters, - "tags": self.tags, - "max_retries": self.max_retries, - "status_retries": self.status_retries, - "aws_conn_id": self.aws_conn_id, - "region_name": self.region_name, - }, - ) - - async def run(self) -> AsyncIterator["TriggerEvent"]: - """ - Make async connection using aiobotocore library to AWS Batch, - periodically poll for the job status on the Triggerer - - The status that indicates job completion are: 'SUCCEEDED'|'FAILED'. - - So the status options that this will poll for are the transitions from: - 'SUBMITTED'>'PENDING'>'RUNNABLE'>'STARTING'>'RUNNING'>'SUCCEEDED'|'FAILED' - """ - hook = BatchClientAsyncHook(job_id=self.job_id, waiters=self.waiters, aws_conn_id=self.aws_conn_id) - try: - response = await hook.monitor_job() - if response: - yield TriggerEvent(response) - else: - error_message = f"{self.job_id} failed" - yield TriggerEvent({"status": "error", "message": error_message}) - except Exception as e: - yield TriggerEvent({"status": "error", "message": str(e)}) diff --git a/docs/apache-airflow-providers-amazon/operators/batch.rst b/docs/apache-airflow-providers-amazon/operators/batch.rst index 0c686184b9224..ba280cb38d37e 100644 --- a/docs/apache-airflow-providers-amazon/operators/batch.rst +++ b/docs/apache-airflow-providers-amazon/operators/batch.rst @@ -37,10 +37,7 @@ Operators Submit a new AWS Batch job ========================== -To submit a new AWS Batch job and monitor it until it reaches a terminal state. -You can also run this operator in deferrable mode by setting the parameter ``deferrable`` to True. -This will lead to efficient utilization of Airflow workers as polling for job status happens on -the triggerer asynchronously. Note that this will need triggerer to be available on your Airflow deployment. +To submit a new AWS Batch job and monitor it until it reaches a terminal state you can use :class:`~airflow.providers.amazon.aws.operators.batch.BatchOperator`. .. 
exampleinclude:: /../../tests/system/providers/amazon/aws/example_batch.py diff --git a/tests/providers/amazon/aws/deferrable/hooks/test_batch_client.py b/tests/providers/amazon/aws/deferrable/hooks/test_batch_client.py deleted file mode 100644 index 10be746ef9c45..0000000000000 --- a/tests/providers/amazon/aws/deferrable/hooks/test_batch_client.py +++ /dev/null @@ -1,213 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -from __future__ import annotations - -import sys - -import botocore -import pytest - -from airflow.exceptions import AirflowException -from airflow.providers.amazon.aws.hooks.batch_client import BatchClientAsyncHook - -if sys.version_info < (3, 8): - # For compatibility with Python 3.7 - from asynctest import mock as async_mock -else: - from unittest import mock as async_mock - -pytest.importorskip("aiobotocore") - - -class TestBatchClientAsyncHook: - JOB_ID = "e2a459c5-381b-494d-b6e8-d6ee334db4e2" - BATCH_API_SUCCESS_RESPONSE = {"jobs": [{"jobId": JOB_ID, "status": "SUCCEEDED"}]} - - @pytest.mark.asyncio - @async_mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientAsyncHook.get_client_async") - @async_mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientAsyncHook.poll_job_status") - async def test_monitor_job_with_success(self, mock_poll_job_status, mock_client): - """Tests that the monitor_job method returns expected event once successful""" - mock_poll_job_status.return_value = True - mock_client.return_value.__aenter__.return_value.describe_jobs.return_value = ( - self.BATCH_API_SUCCESS_RESPONSE - ) - hook = BatchClientAsyncHook(job_id=self.JOB_ID, waiters=None) - result = await hook.monitor_job() - assert result == {"status": "success", "message": f"AWS Batch job ({self.JOB_ID}) succeeded"} - - @pytest.mark.asyncio - @async_mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientAsyncHook.get_client_async") - @async_mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientAsyncHook.poll_job_status") - async def test_monitor_job_with_no_job_id(self, mock_poll_job_status, mock_client): - """Tests that the monitor_job method raises expected exception when incorrect job id is passed""" - mock_poll_job_status.return_value = True - mock_client.return_value.__aenter__.return_value.describe_jobs.return_value = ( - self.BATCH_API_SUCCESS_RESPONSE - ) - - with pytest.raises(AirflowException) as exc_info: - hook = BatchClientAsyncHook(job_id=False, waiters=None) - await hook.monitor_job() - assert str(exc_info.value) == "AWS Batch job - job_id was not found" - - @pytest.mark.asyncio - @async_mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientAsyncHook.get_client_async") - 
@async_mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientAsyncHook.poll_job_status") - async def test_hit_api_throttle(self, mock_poll_job_status, mock_client): - """ - Tests that the get_job_description method raises correct exception when retries - exceed the threshold - """ - mock_poll_job_status.return_value = True - mock_client.return_value.__aenter__.return_value.describe_jobs.side_effect = ( - botocore.exceptions.ClientError( - error_response={ - "Error": { - "Code": "TooManyRequestsException", - } - }, - operation_name="get job description", - ) - ) - """status_retries = 2 ensures that exponential_delay block is covered in batch_client.py - otherwise the code coverage will drop""" - hook = BatchClientAsyncHook(job_id=self.JOB_ID, waiters=None, status_retries=2) - with pytest.raises(AirflowException) as exc_info: - await hook.get_job_description(job_id=self.JOB_ID) - assert ( - str(exc_info.value) == f"AWS Batch job ({self.JOB_ID}) description error: exceeded " - "status_retries (2)" - ) - - @pytest.mark.asyncio - @async_mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientAsyncHook.get_client_async") - @async_mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientAsyncHook.poll_job_status") - async def test_client_error(self, mock_poll_job_status, mock_client): - """Test that the get_job_description method raises correct exception when the error code - from boto3 api is not TooManyRequestsException""" - mock_poll_job_status.return_value = True - mock_client.return_value.__aenter__.return_value.describe_jobs.side_effect = ( - botocore.exceptions.ClientError( - error_response={"Error": {"Code": "InvalidClientTokenId", "Message": "Malformed Token"}}, - operation_name="get job description", - ) - ) - hook = BatchClientAsyncHook(job_id=self.JOB_ID, waiters=None, status_retries=1) - with pytest.raises(AirflowException) as exc_info: - await hook.get_job_description(job_id=self.JOB_ID) - assert ( - str(exc_info.value) == f"AWS Batch job ({self.JOB_ID}) description error: An error " - "occurred (InvalidClientTokenId) when calling the get job description operation: " - "Malformed Token" - ) - - @pytest.mark.asyncio - @async_mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientAsyncHook.get_client_async") - async def test_check_job_success(self, mock_client): - """Tests that the check_job_success method returns True when job succeeds""" - mock_client.return_value.__aenter__.return_value.describe_jobs.return_value = ( - self.BATCH_API_SUCCESS_RESPONSE - ) - hook = BatchClientAsyncHook(job_id=self.JOB_ID, waiters=None) - result = await hook.check_job_success(job_id=self.JOB_ID) - assert result is True - - @pytest.mark.asyncio - @async_mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientAsyncHook.get_client_async") - async def test_check_job_raises_exception_failed(self, mock_client): - """Tests that the check_job_success method raises exception correctly as per job state""" - mock_job = {"jobs": [{"jobId": self.JOB_ID, "status": "FAILED"}]} - mock_client.return_value.__aenter__.return_value.describe_jobs.return_value = mock_job - hook = BatchClientAsyncHook(job_id=self.JOB_ID, waiters=None) - with pytest.raises(AirflowException) as exc_info: - await hook.check_job_success(job_id=self.JOB_ID) - assert str(exc_info.value) == f"AWS Batch job ({self.JOB_ID}) failed" + ": " + str( - mock_job["jobs"][0] - ) - - @pytest.mark.asyncio - 
@async_mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientAsyncHook.get_client_async") - async def test_check_job_raises_exception_pending(self, mock_client): - """Tests that the check_job_success method raises exception correctly as per job state""" - mock_job = {"jobs": [{"jobId": self.JOB_ID, "status": "PENDING"}]} - mock_client.return_value.__aenter__.return_value.describe_jobs.return_value = mock_job - hook = BatchClientAsyncHook(job_id=self.JOB_ID, waiters=None) - with pytest.raises(AirflowException) as exc_info: - await hook.check_job_success(job_id=self.JOB_ID) - assert str(exc_info.value) == f"AWS Batch job ({self.JOB_ID}) is not complete" + ": " + str( - mock_job["jobs"][0] - ) - - @pytest.mark.asyncio - @async_mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientAsyncHook.get_client_async") - async def test_check_job_raises_exception_strange(self, mock_client): - """Tests that the check_job_success method raises exception correctly as per job state""" - mock_job = {"jobs": [{"jobId": self.JOB_ID, "status": "STRANGE"}]} - mock_client.return_value.__aenter__.return_value.describe_jobs.return_value = mock_job - hook = BatchClientAsyncHook(job_id=self.JOB_ID, waiters=None) - with pytest.raises(AirflowException) as exc_info: - await hook.check_job_success(job_id=self.JOB_ID) - assert str(exc_info.value) == f"AWS Batch job ({self.JOB_ID}) has unknown status" + ": " + str( - mock_job["jobs"][0] - ) - - @pytest.mark.asyncio - @async_mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientAsyncHook.get_client_async") - async def test_check_job_raises_exception_runnable(self, mock_client): - """Tests that the check_job_success method raises exception correctly as per job state""" - mock_job = {"jobs": [{"jobId": self.JOB_ID, "status": "RUNNABLE"}]} - mock_client.return_value.__aenter__.return_value.describe_jobs.return_value = mock_job - hook = BatchClientAsyncHook(job_id=self.JOB_ID, waiters=None) - with pytest.raises(AirflowException) as exc_info: - await hook.check_job_success(job_id=self.JOB_ID) - assert str(exc_info.value) == f"AWS Batch job ({self.JOB_ID}) is not complete" + ": " + str( - mock_job["jobs"][0] - ) - - @pytest.mark.asyncio - @async_mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientAsyncHook.get_client_async") - async def test_check_job_raises_exception_submitted(self, mock_client): - """Tests that the check_job_success method raises exception correctly as per job state""" - mock_job = {"jobs": [{"jobId": self.JOB_ID, "status": "SUBMITTED"}]} - mock_client.return_value.__aenter__.return_value.describe_jobs.return_value = mock_job - hook = BatchClientAsyncHook(job_id=self.JOB_ID, waiters=None) - with pytest.raises(AirflowException) as exc_info: - await hook.check_job_success(job_id=self.JOB_ID) - assert str(exc_info.value) == f"AWS Batch job ({self.JOB_ID}) is not complete" + ": " + str( - mock_job["jobs"][0] - ) - - @pytest.mark.asyncio - @async_mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientAsyncHook.get_client_async") - async def test_poll_job_status_raises_for_max_retries(self, mock_client): - mock_job = {"jobs": [{"jobId": self.JOB_ID, "status": "RUNNABLE"}]} - mock_client.return_value.__aenter__.return_value.describe_jobs.return_value = mock_job - hook = BatchClientAsyncHook(job_id=self.JOB_ID, waiters=None, max_retries=1) - with pytest.raises(AirflowException) as exc_info: - await hook.poll_job_status(job_id=self.JOB_ID, match_status=["SUCCEEDED"]) - assert 
str(exc_info.value) == f"AWS Batch job ({self.JOB_ID}) status checks exceed " "max_retries" - - @pytest.mark.asyncio - @async_mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientAsyncHook.get_client_async") - async def test_poll_job_status_in_match_status(self, mock_client): - mock_job = self.BATCH_API_SUCCESS_RESPONSE - mock_client.return_value.__aenter__.return_value.describe_jobs.return_value = mock_job - hook = BatchClientAsyncHook(job_id=self.JOB_ID, waiters=None, max_retries=1) - result = await hook.poll_job_status(job_id=self.JOB_ID, match_status=["SUCCEEDED"]) - assert result is True diff --git a/tests/providers/amazon/aws/deferrable/triggers/test_batch.py b/tests/providers/amazon/aws/deferrable/triggers/test_batch.py deleted file mode 100644 index ad534619f0c02..0000000000000 --- a/tests/providers/amazon/aws/deferrable/triggers/test_batch.py +++ /dev/null @@ -1,131 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -from __future__ import annotations - -import asyncio -import importlib.util - -import pytest - -from airflow.providers.amazon.aws.triggers.batch import ( - BatchOperatorTrigger, -) -from airflow.triggers.base import TriggerEvent -from tests.providers.amazon.aws.utils.compat import async_mock - -JOB_NAME = "51455483-c62c-48ac-9b88-53a6a725baa3" -JOB_ID = "8ba9d676-4108-4474-9dca-8bbac1da9b19" -MAX_RETRIES = 2 -STATUS_RETRIES = 3 -POKE_INTERVAL = 5 -AWS_CONN_ID = "airflow_test" -REGION_NAME = "eu-west-1" - - -@pytest.mark.skipif(not bool(importlib.util.find_spec("aiobotocore")), reason="aiobotocore require") -class TestBatchOperatorTrigger: - TRIGGER = BatchOperatorTrigger( - job_id=JOB_ID, - job_name=JOB_NAME, - job_definition="hello-world", - job_queue="queue", - waiters=None, - tags={}, - max_retries=MAX_RETRIES, - status_retries=STATUS_RETRIES, - parameters={}, - overrides={}, - array_properties={}, - region_name="eu-west-1", - aws_conn_id="airflow_test", - ) - - def test_batch_trigger_serialization(self): - """ - Asserts that the BatchOperatorTrigger correctly serializes its arguments - and classpath. 
- """ - - classpath, kwargs = self.TRIGGER.serialize() - assert classpath == "airflow.providers.amazon.aws.triggers.batch.BatchOperatorTrigger" - assert kwargs == { - "job_id": JOB_ID, - "job_name": JOB_NAME, - "job_definition": "hello-world", - "job_queue": "queue", - "waiters": None, - "tags": {}, - "max_retries": MAX_RETRIES, - "status_retries": STATUS_RETRIES, - "parameters": {}, - "overrides": {}, - "array_properties": {}, - "region_name": "eu-west-1", - "aws_conn_id": "airflow_test", - } - - @pytest.mark.asyncio - async def test_batch_trigger_run(self): - """Test that the task is not done when event is not returned from trigger.""" - - task = asyncio.create_task(self.TRIGGER.run().__anext__()) - await asyncio.sleep(0.5) - # TriggerEvent was not returned - assert task.done() is False - - @pytest.mark.asyncio - @async_mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientAsyncHook.monitor_job") - async def test_batch_trigger_completed(self, mock_response): - """Test if the success event is returned from trigger.""" - mock_response.return_value = {"status": "success", "message": f"AWS Batch job ({JOB_ID}) succeeded"} - - generator = self.TRIGGER.run() - actual_response = await generator.asend(None) - assert ( - TriggerEvent({"status": "success", "message": f"AWS Batch job ({JOB_ID}) succeeded"}) - == actual_response - ) - - @pytest.mark.asyncio - @async_mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientAsyncHook.monitor_job") - async def test_batch_trigger_failure(self, mock_response): - """Test if the failure event is returned from trigger.""" - mock_response.return_value = {"status": "error", "message": f"{JOB_ID} failed"} - - generator = self.TRIGGER.run() - actual_response = await generator.asend(None) - assert TriggerEvent({"status": "error", "message": f"{JOB_ID} failed"}) == actual_response - - @pytest.mark.asyncio - @async_mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientAsyncHook.monitor_job") - async def test_batch_trigger_none(self, mock_response): - """Test if the failure event is returned when there is no response from hook.""" - mock_response.return_value = None - - generator = self.TRIGGER.run() - actual_response = await generator.asend(None) - assert TriggerEvent({"status": "error", "message": f"{JOB_ID} failed"}) == actual_response - - @pytest.mark.asyncio - @async_mock.patch("airflow.providers.amazon.aws.hooks.batch_client.BatchClientAsyncHook.monitor_job") - async def test_batch_trigger_exception(self, mock_response): - """Test if the exception is raised from trigger.""" - mock_response.side_effect = Exception("Test exception") - - task = [i async for i in self.TRIGGER.run()] - assert len(task) == 1 - assert TriggerEvent({"status": "error", "message": "Test exception"}) in task diff --git a/tests/providers/amazon/aws/hooks/test_batch_client.py b/tests/providers/amazon/aws/hooks/test_batch_client.py index d7e06d9eb23ae..13726e5518ff4 100644 --- a/tests/providers/amazon/aws/hooks/test_batch_client.py +++ b/tests/providers/amazon/aws/hooks/test_batch_client.py @@ -20,7 +20,6 @@ import logging from unittest import mock -import botocore import botocore.exceptions import pytest diff --git a/tests/providers/amazon/aws/operators/test_batch.py b/tests/providers/amazon/aws/operators/test_batch.py index 2192b7c4e20aa..0ddfcea591713 100644 --- a/tests/providers/amazon/aws/operators/test_batch.py +++ b/tests/providers/amazon/aws/operators/test_batch.py @@ -19,18 +19,11 @@ from unittest import mock -import pendulum import 
pytest -from airflow.exceptions import AirflowException, TaskDeferred -from airflow.models import DAG -from airflow.models.dagrun import DagRun -from airflow.models.taskinstance import TaskInstance +from airflow.exceptions import AirflowException from airflow.providers.amazon.aws.hooks.batch_client import BatchClientHook from airflow.providers.amazon.aws.operators.batch import BatchCreateComputeEnvironmentOperator, BatchOperator -from airflow.providers.amazon.aws.triggers.batch import BatchOperatorTrigger -from airflow.utils import timezone -from airflow.utils.types import DagRunType # Use dummy AWS credentials AWS_REGION = "eu-west-1" @@ -218,114 +211,3 @@ def test_execute(self, mock_conn): computeResources=compute_resources, tags=tags, ) - - -def create_context(task, dag=None): - if dag is None: - dag = DAG(dag_id="dag") - tzinfo = pendulum.timezone("UTC") - execution_date = timezone.datetime(2022, 1, 1, 1, 0, 0, tzinfo=tzinfo) - dag_run = DagRun( - dag_id=dag.dag_id, - execution_date=execution_date, - run_id=DagRun.generate_run_id(DagRunType.MANUAL, execution_date), - ) - - task_instance = TaskInstance(task=task) - task_instance.dag_run = dag_run - task_instance.xcom_push = mock.Mock() - return { - "dag": dag, - "ts": execution_date.isoformat(), - "task": task, - "ti": task_instance, - "task_instance": task_instance, - "run_id": dag_run.run_id, - "dag_run": dag_run, - "execution_date": execution_date, - "data_interval_end": execution_date, - "logical_date": execution_date, - } - - -class TestBatchOperatorAsync: - JOB_NAME = "51455483-c62c-48ac-9b88-53a6a725baa3" - JOB_ID = "8ba9d676-4108-4474-9dca-8bbac1da9b19" - MAX_RETRIES = 2 - STATUS_RETRIES = 3 - RESPONSE_WITHOUT_FAILURES = { - "jobName": JOB_NAME, - "jobId": JOB_ID, - } - - @mock.patch("airflow.providers.amazon.aws.hooks.batch_client.AwsBaseHook.get_client_type") - def test_batch_op_async(self, get_client_type_mock): - get_client_type_mock.return_value.submit_job.return_value = self.RESPONSE_WITHOUT_FAILURES - task = BatchOperator( - task_id="task", - job_name=self.JOB_NAME, - job_queue="queue", - job_definition="hello-world", - max_retries=self.MAX_RETRIES, - status_retries=self.STATUS_RETRIES, - parameters=None, - overrides={}, - array_properties=None, - aws_conn_id="airflow_test", - region_name="eu-west-1", - tags={}, - deferrable=True, - ) - context = create_context(task) - with pytest.raises(TaskDeferred) as exc: - task.execute(context) - assert isinstance(exc.value.trigger, BatchOperatorTrigger), "Trigger is not a BatchOperatorTrigger" - - def test_batch_op_async_execute_failure(self): - """Tests that an AirflowException is raised in case of error event""" - - task = BatchOperator( - task_id="task", - job_name=self.JOB_NAME, - job_queue="queue", - job_definition="hello-world", - max_retries=self.MAX_RETRIES, - status_retries=self.STATUS_RETRIES, - parameters=None, - overrides={}, - array_properties=None, - aws_conn_id="airflow_test", - region_name="eu-west-1", - tags={}, - deferrable=True, - ) - with pytest.raises(AirflowException) as exc_info: - task.execute_complete(context=None, event={"status": "error", "message": "test failure message"}) - - assert str(exc_info.value) == "test failure message" - - @pytest.mark.parametrize( - "event", - [{"status": "success", "message": f"AWS Batch job ({JOB_ID}) succeeded"}], - ) - def test_batch_op_async_execute_complete(self, caplog, event): - """Tests that execute_complete method returns None and that it prints expected log""" - task = BatchOperator( - task_id="task", - 
job_name=self.JOB_NAME, - job_queue="queue", - job_definition="hello-world", - max_retries=self.MAX_RETRIES, - status_retries=self.STATUS_RETRIES, - parameters=None, - overrides={}, - array_properties=None, - aws_conn_id="airflow_test", - region_name="eu-west-1", - tags={}, - deferrable=True, - ) - with mock.patch.object(task.log, "info") as mock_log_info: - assert task.execute_complete(context=None, event=event) is None - - mock_log_info.assert_called_with(f"AWS Batch job ({self.JOB_ID}) succeeded")
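For anyone re-introducing this functionality later, a few reference sketches of the behaviour the revert removes. First, the retry timing: the hook's exponential delay grows quadratically, caps at a maximum interval, and randomises the result (its `exp()` body is visible verbatim in the context lines of the hooks hunk above). `add_jitter` is reconstructed here from the docstring's description ("10 sec +/- 5 sec: ``delay = BatchClient.add_jitter(10, width=5, minima=0)``"), so treat its exact clamping, and the 600-second cap, as assumptions rather than the library's exact implementation:

    from random import uniform


    def exponential_delay(tries: int, max_interval: float = 600.0) -> float:
        # Quadratic growth, capped, then randomised across [delay/3, delay],
        # mirroring the exp() helper shown in the hook's context lines above.
        delay = 1 + pow(tries * 0.6, 2)
        delay = min(max_interval, delay)
        return uniform(delay / 3, delay)


    def add_jitter(delay: float, width: float = 1.0, minima: float = 0.0) -> float:
        # Reconstructed from the docstring: delay +/- width, never below minima.
        lower = max(minima, delay - width)
        return uniform(lower, delay + width)


    # e.g. the docstring's "10 sec +/- 5 sec":
    print(add_jitter(10, width=5, minima=0))

Randomising the interval is what avoids synchronised bursts of describe_jobs calls (and hence API throttling) when many tasks poll concurrently.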
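Second, the polling loop. The deleted BatchClientAsyncHook split this across get_job_description (HTTP-level retries on TooManyRequestsException, bounded by status_retries) and poll_job_status (state polling bounded by max_retries). Below is a condensed, standalone version of the same loop, a sketch rather than the reverted code: it assumes aiobotocore is installed and AWS credentials are in the environment, reuses exponential_delay from the previous sketch, and uses placeholder job ID and region values.

    from __future__ import annotations

    import asyncio

    import botocore.exceptions
    from aiobotocore.session import get_session


    async def poll_job_status(job_id: str, match_status: list[str],
                              region_name: str = "eu-west-1",
                              max_retries: int = 4200,
                              status_retries: int = 10) -> bool:
        session = get_session()
        async with session.create_client("batch", region_name=region_name) as client:
            retries = 0
            while True:
                # get_job_description, folded in: retry only on throttling.
                for attempt in range(1, status_retries + 1):
                    try:
                        response = await client.describe_jobs(jobs=[job_id])
                        break
                    except botocore.exceptions.ClientError as err:
                        if err.response.get("Error", {}).get("Code") != "TooManyRequestsException":
                            raise
                        await asyncio.sleep(exponential_delay(attempt))
                else:
                    raise RuntimeError(
                        f"AWS Batch job ({job_id}) description error: "
                        f"exceeded status_retries ({status_retries})"
                    )
                job_status = response["jobs"][0]["status"]
                if job_status in match_status:
                    return True
                if retries >= max_retries:
                    raise RuntimeError(f"AWS Batch job ({job_id}) status checks exceed max_retries")
                retries += 1
                await asyncio.sleep(exponential_delay(retries))


    # asyncio.run(poll_job_status("my-job-id", ["SUCCEEDED", "FAILED"]))

Waiting on ["RUNNING", "SUCCEEDED", "FAILED"] first and then on ["SUCCEEDED", "FAILED"] reproduces the two-phase wait_for_job behaviour; the terminal states are included in the first phase because a short job can pass through RUNNING faster than the polling interval.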
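Finally, the hand-off itself. In the reverted operator, execute() called self.defer(), which raises TaskDeferred and frees the worker slot; the triggerer later invoked execute_complete() with the TriggerEvent payload, and an "error" status was turned back into a task failure there. A minimal sketch of that contract, assuming an Airflow 2.x install; the trigger below is a stand-in for the deleted BatchOperatorTrigger (which carried many more kwargs), and it reuses poll_job_status from the sketch above.

    from __future__ import annotations

    from typing import Any, AsyncIterator

    from airflow.exceptions import AirflowException
    from airflow.models import BaseOperator
    from airflow.triggers.base import BaseTrigger, TriggerEvent


    class MinimalBatchTrigger(BaseTrigger):
        """Stand-in for the deleted BatchOperatorTrigger."""

        def __init__(self, job_id: str | None):
            super().__init__()
            self.job_id = job_id

        def serialize(self) -> tuple[str, dict[str, Any]]:
            # The triggerer re-imports the class from this path and rebuilds
            # it with these kwargs; "my_dag_module" is a placeholder.
            return ("my_dag_module.MinimalBatchTrigger", {"job_id": self.job_id})

        async def run(self) -> AsyncIterator[TriggerEvent]:
            try:
                await poll_job_status(self.job_id, ["SUCCEEDED", "FAILED"])
                # The deleted hook then ran check_job_success() to distinguish
                # SUCCEEDED from FAILED (see the hooks hunk above); condensed here.
                yield TriggerEvent(
                    {"status": "success", "message": f"AWS Batch job ({self.job_id}) succeeded"}
                )
            except Exception as e:
                yield TriggerEvent({"status": "error", "message": str(e)})


    class DeferrableBatchSketch(BaseOperator):
        def __init__(self, *, deferrable: bool = False, **kwargs):
            super().__init__(**kwargs)
            self.deferrable = deferrable
            self.job_id: str | None = None

        def execute(self, context):
            # self.submit_job(context) would run here and set self.job_id.
            if self.deferrable:
                self.defer(  # raises TaskDeferred; the worker slot is released
                    timeout=self.execution_timeout,
                    trigger=MinimalBatchTrigger(job_id=self.job_id),
                    method_name="execute_complete",
                )
            return self.job_id

        def execute_complete(self, context, event: dict[str, Any]):
            # The trigger only reports; raising here is what fails the task.
            if event.get("status") == "error":
                raise AirflowException(event["message"])
            self.log.info(event["message"])
            return self.job_id

Note that execute_complete relies entirely on the event payload: the trigger never raises into the task, it yields exactly one TriggerEvent, which is why the deleted tests asserted on the event dict rather than on exceptions from run().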