diff --git a/airflow/providers/amazon/aws/sensors/athena.py b/airflow/providers/amazon/aws/sensors/athena.py
index 38f2bb54f8b33..1dc8a4fe3335f 100644
--- a/airflow/providers/amazon/aws/sensors/athena.py
+++ b/airflow/providers/amazon/aws/sensors/athena.py
@@ -25,7 +25,7 @@
 if TYPE_CHECKING:
     from airflow.utils.context import Context

-from airflow.exceptions import AirflowException, AirflowSkipException
+from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.athena import AthenaHook

@@ -88,11 +88,7 @@ def poke(self, context: Context) -> bool:
         state = self.hook.poll_query_status(self.query_execution_id, self.max_retries, self.sleep_time)

         if state in self.FAILURE_STATES:
-            # TODO: remove this if block when min_airflow_version is set to higher than 2.7.1
-            message = "Athena sensor failed"
-            if self.soft_fail:
-                raise AirflowSkipException(message)
-            raise AirflowException(message)
+            raise AirflowException("Athena sensor failed")

         if state in self.INTERMEDIATE_STATES:
             return False
diff --git a/airflow/providers/amazon/aws/sensors/batch.py b/airflow/providers/amazon/aws/sensors/batch.py
index c5dcb0e46de6a..9c1a29f8098f9 100644
--- a/airflow/providers/amazon/aws/sensors/batch.py
+++ b/airflow/providers/amazon/aws/sensors/batch.py
@@ -86,18 +86,7 @@ def poke(self, context: Context) -> bool:
         if state in BatchClientHook.INTERMEDIATE_STATES:
             return False

-        if state == BatchClientHook.FAILURE_STATE:
-            # TODO: remove this if block when min_airflow_version is set to higher than 2.7.1
-            message = f"Batch sensor failed. AWS Batch job status: {state}"
-            if self.soft_fail:
-                raise AirflowSkipException(message)
-            raise AirflowException(message)
-
-        # TODO: remove this if block when min_airflow_version is set to higher than 2.7.1
-        message = f"Batch sensor failed. Unknown AWS Batch job status: {state}"
-        if self.soft_fail:
-            raise AirflowSkipException(message)
-        raise AirflowException(message)
+        raise AirflowException(f"Batch sensor failed. AWS Batch job status: {state}")

     def execute(self, context: Context) -> None:
         if not self.deferrable:
@@ -127,12 +116,7 @@ def execute_complete(self, context: Context, event: dict[str, Any]) -> None:
         Relies on trigger to throw an exception, otherwise it assumes execution was successful.
         """
         if event["status"] != "success":
-            message = f"Error while running job: {event}"
-            # TODO: remove this if-else block when min_airflow_version is set to higher than the version that
-            # changed in https://github.com/apache/airflow/pull/33424 is released (2.7.1)
-            if self.soft_fail:
-                raise AirflowSkipException(message)
-            raise AirflowException(message)
+            raise AirflowException(f"Error while running job: {event}")
         job_id = event["job_id"]
         self.log.info("Batch Job %s complete", job_id)

@@ -198,11 +182,7 @@ def poke(self, context: Context) -> bool:
         )

         if not response["computeEnvironments"]:
-            message = f"AWS Batch compute environment {self.compute_environment} not found"
-            # TODO: remove this if block when min_airflow_version is set to higher than 2.7.1
-            if self.soft_fail:
-                raise AirflowSkipException(message)
-            raise AirflowException(message)
+            raise AirflowException(f"AWS Batch compute environment {self.compute_environment} not found")

         status = response["computeEnvironments"][0]["status"]

@@ -212,11 +192,9 @@ def poke(self, context: Context) -> bool:
         if status in BatchClientHook.COMPUTE_ENVIRONMENT_INTERMEDIATE_STATUS:
             return False

-        # TODO: remove this if block when min_airflow_version is set to higher than 2.7.1
-        message = f"AWS Batch compute environment failed. AWS Batch compute environment status: {status}"
-        if self.soft_fail:
-            raise AirflowSkipException(message)
-        raise AirflowException(message)
+        raise AirflowException(
+            f"AWS Batch compute environment failed. AWS Batch compute environment status: {status}"
+        )


 class BatchJobQueueSensor(BaseSensorOperator):
@@ -276,11 +254,7 @@ def poke(self, context: Context) -> bool:
             if self.treat_non_existing_as_deleted:
                 return True
             else:
-                # TODO: remove this if block when min_airflow_version is set to higher than 2.7.1
-                message = f"AWS Batch job queue {self.job_queue} not found"
-                if self.soft_fail:
-                    raise AirflowSkipException(message)
-                raise AirflowException(message)
+                raise AirflowException(f"AWS Batch job queue {self.job_queue} not found")

         status = response["jobQueues"][0]["status"]
diff --git a/airflow/providers/amazon/aws/sensors/bedrock.py b/airflow/providers/amazon/aws/sensors/bedrock.py
index 8532886554868..e9157ab9c12a7 100644
--- a/airflow/providers/amazon/aws/sensors/bedrock.py
+++ b/airflow/providers/amazon/aws/sensors/bedrock.py
@@ -21,7 +21,7 @@
 from typing import TYPE_CHECKING, Any, Sequence, TypeVar

 from airflow.configuration import conf
-from airflow.exceptions import AirflowException, AirflowSkipException
+from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.bedrock import BedrockAgentHook, BedrockHook
 from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor
 from airflow.providers.amazon.aws.triggers.bedrock import (
@@ -76,9 +76,6 @@ def poke(self, context: Context, **kwargs) -> bool:
         state = self.get_state()
         if state in self.FAILURE_STATES:
-            # TODO: remove this if block when min_airflow_version is set to higher than 2.7.1
-            if self.soft_fail:
-                raise AirflowSkipException(self.FAILURE_MESSAGE)
             raise AirflowException(self.FAILURE_MESSAGE)

         return state not in self.INTERMEDIATE_STATES
diff --git a/airflow/providers/amazon/aws/sensors/cloud_formation.py b/airflow/providers/amazon/aws/sensors/cloud_formation.py
index f67278ecec384..ba07433bf0c80 100644
--- a/airflow/providers/amazon/aws/sensors/cloud_formation.py
+++ b/airflow/providers/amazon/aws/sensors/cloud_formation.py
@@ -27,7 +27,6 @@
 if TYPE_CHECKING:
     from airflow.utils.context import Context

-from airflow.exceptions import AirflowSkipException
 from airflow.providers.amazon.aws.hooks.cloud_formation import CloudFormationHook

@@ -67,11 +66,7 @@ def poke(self, context: Context):
         if stack_status in ("CREATE_IN_PROGRESS", None):
             return False

-        # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1
-        message = f"Stack {self.stack_name} in bad state: {stack_status}"
-        if self.soft_fail:
-            raise AirflowSkipException(message)
-        raise ValueError(message)
+        raise ValueError(f"Stack {self.stack_name} in bad state: {stack_status}")


 class CloudFormationDeleteStackSensor(AwsBaseSensor[CloudFormationHook]):
@@ -119,8 +114,4 @@ def poke(self, context: Context):
         if stack_status == "DELETE_IN_PROGRESS":
             return False

-        # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1
-        message = f"Stack {self.stack_name} in bad state: {stack_status}"
-        if self.soft_fail:
-            raise AirflowSkipException(message)
-        raise ValueError(message)
+        raise ValueError(f"Stack {self.stack_name} in bad state: {stack_status}")
diff --git a/airflow/providers/amazon/aws/sensors/comprehend.py b/airflow/providers/amazon/aws/sensors/comprehend.py
index 42344f65e9839..545f7b02fc8dc 100644
--- a/airflow/providers/amazon/aws/sensors/comprehend.py
+++ b/airflow/providers/amazon/aws/sensors/comprehend.py
@@ -20,7 +20,7 @@
 from typing import TYPE_CHECKING, Any, Sequence

 from airflow.configuration import conf
-from airflow.exceptions import AirflowException, AirflowSkipException
+from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.comprehend import ComprehendHook
 from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor
 from airflow.providers.amazon.aws.triggers.comprehend import (
@@ -71,9 +71,6 @@ def poke(self, context: Context, **kwargs) -> bool:
         state = self.get_state()
         if state in self.FAILURE_STATES:
-            # TODO: remove this if block when min_airflow_version is set to higher than 2.7.1
-            if self.soft_fail:
-                raise AirflowSkipException(self.FAILURE_MESSAGE)
             raise AirflowException(self.FAILURE_MESSAGE)

         return state not in self.INTERMEDIATE_STATES
@@ -241,9 +238,6 @@ def poke(self, context: Context, **kwargs) -> bool:
         )

         if status in self.FAILURE_STATES:
-            # TODO: remove this if block when min_airflow_version is set to higher than 2.7.1
-            if self.soft_fail:
-                raise AirflowSkipException(self.FAILURE_MESSAGE)
             raise AirflowException(self.FAILURE_MESSAGE)

         if status in self.SUCCESS_STATES:
diff --git a/airflow/providers/amazon/aws/sensors/dms.py b/airflow/providers/amazon/aws/sensors/dms.py
index 864a3b5276c32..2ea52ea0b5c35 100644
--- a/airflow/providers/amazon/aws/sensors/dms.py
+++ b/airflow/providers/amazon/aws/sensors/dms.py
@@ -21,7 +21,7 @@
 from deprecated import deprecated

-from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, AirflowSkipException
+from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
 from airflow.providers.amazon.aws.hooks.dms import DmsHook
 from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor
 from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
@@ -75,11 +75,9 @@ def get_hook(self) -> DmsHook:

     def poke(self, context: Context):
         if not (status := self.hook.get_task_status(self.replication_task_arn)):
-            # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1
-            message = f"Failed to read task status, task with ARN {self.replication_task_arn} not found"
-            if self.soft_fail:
-                raise AirflowSkipException(message)
-            raise AirflowException(message)
+            raise AirflowException(
+                f"Failed to read task status, task with ARN {self.replication_task_arn} not found"
+            )

         self.log.info("DMS Replication task (%s) has status: %s", self.replication_task_arn, status)

@@ -87,11 +85,7 @@ def poke(self, context: Context):
             return True

         if status in self.termination_statuses:
-            # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1
-            message = f"Unexpected status: {status}"
-            if self.soft_fail:
-                raise AirflowSkipException(message)
-            raise AirflowException(message)
+            raise AirflowException(f"Unexpected status: {status}")

         return False
diff --git a/airflow/providers/amazon/aws/sensors/ec2.py b/airflow/providers/amazon/aws/sensors/ec2.py
index 778bc49caa522..0736c63393ae7 100644
--- a/airflow/providers/amazon/aws/sensors/ec2.py
+++ b/airflow/providers/amazon/aws/sensors/ec2.py
@@ -21,7 +21,7 @@
 from typing import TYPE_CHECKING, Any, Sequence

 from airflow.configuration import conf
-from airflow.exceptions import AirflowException, AirflowSkipException
+from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.ec2 import EC2Hook
 from airflow.providers.amazon.aws.triggers.ec2 import EC2StateSensorTrigger
 from airflow.providers.amazon.aws.utils import validate_execute_complete_event
@@ -97,8 +97,4 @@ def execute_complete(self, context: Context, event: dict[str, Any] | None = None
         event = validate_execute_complete_event(event)

         if event["status"] != "success":
-            # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1
-            message = f"Error: {event}"
-            if self.soft_fail:
-                raise AirflowSkipException(message)
-            raise AirflowException(message)
+            raise AirflowException(f"Error: {event}")
diff --git a/airflow/providers/amazon/aws/sensors/ecs.py b/airflow/providers/amazon/aws/sensors/ecs.py
index 02a212fbde0d7..aba3e55922680 100644
--- a/airflow/providers/amazon/aws/sensors/ecs.py
+++ b/airflow/providers/amazon/aws/sensors/ecs.py
@@ -19,7 +19,7 @@
 from functools import cached_property
 from typing import TYPE_CHECKING, Sequence

-from airflow.exceptions import AirflowException, AirflowSkipException
+from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.ecs import (
     EcsClusterStates,
     EcsHook,
@@ -37,11 +37,9 @@ def _check_failed(current_state, target_state, failure_states, soft_fail: bool) -> None:
     if (current_state != target_state) and (current_state in failure_states):
-        # TODO: remove this if block when min_airflow_version is set to higher than 2.7.1
-        message = f"Terminal state reached. Current state: {current_state}, Expected state: {target_state}"
-        if soft_fail:
-            raise AirflowSkipException(message)
-        raise AirflowException(message)
+        raise AirflowException(
+            f"Terminal state reached. Current state: {current_state}, Expected state: {target_state}"
+        )


 class EcsBaseSensor(AwsBaseSensor[EcsHook]):
diff --git a/airflow/providers/amazon/aws/sensors/eks.py b/airflow/providers/amazon/aws/sensors/eks.py
index a5dcdeb0ef185..79e160b007402 100644
--- a/airflow/providers/amazon/aws/sensors/eks.py
+++ b/airflow/providers/amazon/aws/sensors/eks.py
@@ -22,7 +22,7 @@
 from functools import cached_property
 from typing import TYPE_CHECKING, Sequence

-from airflow.exceptions import AirflowException, AirflowSkipException
+from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.eks import (
     ClusterStates,
     EksHook,
@@ -106,12 +106,10 @@ def poke(self, context: Context) -> bool:
         state = self.get_state()
         self.log.info("Current state: %s", state)
         if state in (self.get_terminal_states() - {self.target_state}):
-            # If we reach a terminal state which is not the target state:
-            # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1
-            message = f"Terminal state reached. Current state: {state}, Expected state: {self.target_state}"
-            if self.soft_fail:
-                raise AirflowSkipException(message)
-            raise AirflowException(message)
+            # If we reach a terminal state which is not the target state
+            raise AirflowException(
+                f"Terminal state reached. Current state: {state}, Expected state: {self.target_state}"
+            )
         return state == self.target_state

     @abstractmethod
diff --git a/airflow/providers/amazon/aws/sensors/emr.py b/airflow/providers/amazon/aws/sensors/emr.py
index 19e026e7a6c4e..e79642d35c693 100644
--- a/airflow/providers/amazon/aws/sensors/emr.py
+++ b/airflow/providers/amazon/aws/sensors/emr.py
@@ -27,7 +27,6 @@
 from airflow.exceptions import (
     AirflowException,
     AirflowProviderDeprecationWarning,
-    AirflowSkipException,
 )
 from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook, EmrHook, EmrServerlessHook
 from airflow.providers.amazon.aws.links.emr import EmrClusterLink, EmrLogsLink, get_log_uri
@@ -91,11 +90,7 @@ def poke(self, context: Context):
             return True

         if state in self.failed_states:
-            # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1
-            message = f"EMR job failed: {self.failure_message_from_response(response)}"
-            if self.soft_fail:
-                raise AirflowSkipException(message)
-            raise AirflowException(message)
+            raise AirflowException(f"EMR job failed: {self.failure_message_from_response(response)}")

         return False

@@ -172,11 +167,9 @@ def poke(self, context: Context) -> bool:
         state = response["jobRun"]["state"]

         if state in EmrServerlessHook.JOB_FAILURE_STATES:
-            failure_message = f"EMR Serverless job failed: {self.failure_message_from_response(response)}"
-            # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1
-            if self.soft_fail:
-                raise AirflowSkipException(failure_message)
-            raise AirflowException(failure_message)
+            raise AirflowException(
+                f"EMR Serverless job failed: {self.failure_message_from_response(response)}"
+            )

         return state in self.target_states

@@ -234,13 +227,9 @@ def poke(self, context: Context) -> bool:
         state = response["application"]["state"]

         if state in EmrServerlessHook.APPLICATION_FAILURE_STATES:
-            # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1
-            failure_message = (
+            raise AirflowException(
                 f"EMR Serverless application failed: {self.failure_message_from_response(response)}"
             )
-            if self.soft_fail:
-                raise AirflowSkipException(failure_message)
-            raise AirflowException(failure_message)

         return state in self.target_states

@@ -328,11 +317,7 @@ def poke(self, context: Context) -> bool:
         )

         if state in self.FAILURE_STATES:
-            # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1
-            message = "EMR Containers sensor failed"
-            if self.soft_fail:
-                raise AirflowSkipException(message)
-            raise AirflowException(message)
+            raise AirflowException("EMR Containers sensor failed")

         if state in self.INTERMEDIATE_STATES:
             return False
@@ -370,11 +355,7 @@ def execute_complete(self, context: Context, event: dict[str, Any] | None = None
         event = validate_execute_complete_event(event)

         if event["status"] != "success":
-            # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1
-            message = f"Error while running job: {event}"
-            if self.soft_fail:
-                raise AirflowSkipException(message)
-            raise AirflowException(message)
+            raise AirflowException(f"Error while running job: {event}")

         self.log.info("Job completed.")
@@ -563,11 +544,7 @@ def execute_complete(self, context: Context, event: dict[str, Any] | None = None
         event = validate_execute_complete_event(event)

         if event["status"] != "success":
-            # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1
-            message = f"Error while running job: {event}"
-            if self.soft_fail:
-                raise AirflowSkipException(message)
-            raise AirflowException(message)
+            raise AirflowException(f"Error while running job: {event}")

         self.log.info("Job completed.")
@@ -696,10 +673,6 @@ def execute_complete(self, context: Context, event: dict[str, Any] | None = None
         event = validate_execute_complete_event(event)

         if event["status"] != "success":
-            # TODO: remove this if check when min_airflow_version is set to higher than 2.7.1
-            message = f"Error while running job: {event}"
-            if self.soft_fail:
-                raise AirflowSkipException(message)
-            raise AirflowException(message)
+            raise AirflowException(f"Error while running job: {event}")

         self.log.info("Job %s completed.", self.job_flow_id)
Job status: {response["Action"]}, code status: {response["StatusCode"]}' ) - if self.soft_fail: - raise AirflowSkipException(message) - raise AirflowException(message) diff --git a/airflow/providers/amazon/aws/sensors/glue.py b/airflow/providers/amazon/aws/sensors/glue.py index 8493c2fd4ab6d..062e4ab3efd07 100644 --- a/airflow/providers/amazon/aws/sensors/glue.py +++ b/airflow/providers/amazon/aws/sensors/glue.py @@ -86,9 +86,6 @@ def poke(self, context: Context): elif job_state in self.errored_states: job_error_message = "Exiting Job %s Run State: %s", self.run_id, job_state self.log.info(job_error_message) - # TODO: remove this if block when min_airflow_version is set to higher than 2.7.1 - if self.soft_fail: - raise AirflowSkipException(job_error_message) raise AirflowException(job_error_message) else: return False @@ -223,9 +220,6 @@ def poke(self, context: Context): f": {response.get('ErrorString')}" ) self.log.info(job_error_message) - # TODO: remove this if block when min_airflow_version is set to higher than 2.7.1 - if self.soft_fail: - raise AirflowSkipException(job_error_message) raise AirflowException(job_error_message) else: return False @@ -343,9 +337,6 @@ def poke(self, context: Context) -> bool: f": {response.get('ErrorString')}" ) self.log.info(job_error_message) - # TODO: remove this if block when min_airflow_version is set to higher than 2.7.1 - if self.soft_fail: - raise AirflowSkipException(job_error_message) raise AirflowException(job_error_message) else: return False diff --git a/airflow/providers/amazon/aws/sensors/glue_catalog_partition.py b/airflow/providers/amazon/aws/sensors/glue_catalog_partition.py index f397a446becb4..af125e2dda6a8 100644 --- a/airflow/providers/amazon/aws/sensors/glue_catalog_partition.py +++ b/airflow/providers/amazon/aws/sensors/glue_catalog_partition.py @@ -23,7 +23,7 @@ from deprecated import deprecated from airflow.configuration import conf -from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, AirflowSkipException +from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning from airflow.providers.amazon.aws.hooks.glue_catalog import GlueCatalogHook from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor from airflow.providers.amazon.aws.triggers.glue import GlueCatalogPartitionTrigger @@ -127,11 +127,7 @@ def execute_complete(self, context: Context, event: dict | None = None) -> None: event = validate_execute_complete_event(event) if event["status"] != "success": - # TODO: remove this if block when min_airflow_version is set to higher than 2.7.1 - message = f"Trigger error: event is {event}" - if self.soft_fail: - raise AirflowSkipException(message) - raise AirflowException(message) + raise AirflowException(f"Trigger error: event is {event}") self.log.info("Partition exists in the Glue Catalog") @deprecated(reason="use `hook` property instead.", category=AirflowProviderDeprecationWarning) diff --git a/airflow/providers/amazon/aws/sensors/glue_crawler.py b/airflow/providers/amazon/aws/sensors/glue_crawler.py index ce35aef2cfb2d..2d4396c010c39 100644 --- a/airflow/providers/amazon/aws/sensors/glue_crawler.py +++ b/airflow/providers/amazon/aws/sensors/glue_crawler.py @@ -21,7 +21,7 @@ from deprecated import deprecated -from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, AirflowSkipException +from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning from airflow.providers.amazon.aws.hooks.glue_crawler import 
GlueCrawlerHook from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor from airflow.providers.amazon.aws.utils.mixins import aws_template_fields @@ -75,11 +75,7 @@ def poke(self, context: Context): self.log.info("Status: %s", crawler_status) return True else: - # TODO: remove this if block when min_airflow_version is set to higher than 2.7.1 - message = f"Status: {crawler_status}" - if self.soft_fail: - raise AirflowSkipException(message) - raise AirflowException(message) + raise AirflowException(f"Status: {crawler_status}") else: return False diff --git a/airflow/providers/amazon/aws/sensors/kinesis_analytics.py b/airflow/providers/amazon/aws/sensors/kinesis_analytics.py index 2c02e050c77f4..673445e67d5f9 100644 --- a/airflow/providers/amazon/aws/sensors/kinesis_analytics.py +++ b/airflow/providers/amazon/aws/sensors/kinesis_analytics.py @@ -19,7 +19,7 @@ from typing import TYPE_CHECKING, Any, Sequence from airflow.configuration import conf -from airflow.exceptions import AirflowException, AirflowSkipException +from airflow.exceptions import AirflowException from airflow.providers.amazon.aws.hooks.kinesis_analytics import KinesisAnalyticsV2Hook from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor from airflow.providers.amazon.aws.triggers.kinesis_analytics import ( @@ -80,9 +80,6 @@ def poke(self, context: Context, **kwargs) -> bool: ) if status in self.FAILURE_STATES: - # TODO: remove this if block when min_airflow_version is set to higher than 2.7.1 - if self.soft_fail: - raise AirflowSkipException(self.FAILURE_MESSAGE) raise AirflowException(self.FAILURE_MESSAGE) if status in self.SUCCESS_STATES: diff --git a/airflow/providers/amazon/aws/sensors/lambda_function.py b/airflow/providers/amazon/aws/sensors/lambda_function.py index c54dfbe8b71fe..8e01d40235ec8 100644 --- a/airflow/providers/amazon/aws/sensors/lambda_function.py +++ b/airflow/providers/amazon/aws/sensors/lambda_function.py @@ -19,7 +19,7 @@ from typing import TYPE_CHECKING, Any, Sequence -from airflow.exceptions import AirflowException, AirflowSkipException +from airflow.exceptions import AirflowException from airflow.providers.amazon.aws.hooks.lambda_function import LambdaHook from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor from airflow.providers.amazon.aws.utils import trim_none_values @@ -78,10 +78,8 @@ def poke(self, context: Context) -> bool: state = self.hook.conn.get_function(**trim_none_values(get_function_args))["Configuration"]["State"] if state in self.FAILURE_STATES: - message = "Lambda function state sensor failed because the Lambda is in a failed state" - # TODO: remove this if block when min_airflow_version is set to higher than 2.7.1 - if self.soft_fail: - raise AirflowSkipException(message) - raise AirflowException(message) + raise AirflowException( + "Lambda function state sensor failed because the Lambda is in a failed state" + ) return state in self.target_states diff --git a/airflow/providers/amazon/aws/sensors/opensearch_serverless.py b/airflow/providers/amazon/aws/sensors/opensearch_serverless.py index 7f5f650e0ee07..0f539f9bfb82d 100644 --- a/airflow/providers/amazon/aws/sensors/opensearch_serverless.py +++ b/airflow/providers/amazon/aws/sensors/opensearch_serverless.py @@ -19,7 +19,7 @@ from typing import TYPE_CHECKING, Any, Sequence from airflow.configuration import conf -from airflow.exceptions import AirflowException, AirflowSkipException +from airflow.exceptions import AirflowException from 
diff --git a/airflow/providers/amazon/aws/sensors/opensearch_serverless.py b/airflow/providers/amazon/aws/sensors/opensearch_serverless.py
index 7f5f650e0ee07..0f539f9bfb82d 100644
--- a/airflow/providers/amazon/aws/sensors/opensearch_serverless.py
+++ b/airflow/providers/amazon/aws/sensors/opensearch_serverless.py
@@ -19,7 +19,7 @@
 from typing import TYPE_CHECKING, Any, Sequence

 from airflow.configuration import conf
-from airflow.exceptions import AirflowException, AirflowSkipException
+from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.opensearch_serverless import OpenSearchServerlessHook
 from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor
 from airflow.providers.amazon.aws.triggers.opensearch_serverless import (
@@ -104,9 +104,6 @@ def poke(self, context: Context, **kwargs) -> bool:
         state = self.hook.conn.batch_get_collection(**call_args)["collectionDetails"][0]["status"]

         if state in self.FAILURE_STATES:
-            # TODO: remove this if block when min_airflow_version is set to higher than 2.7.1
-            if self.soft_fail:
-                raise AirflowSkipException(self.FAILURE_MESSAGE)
             raise AirflowException(self.FAILURE_MESSAGE)

         if state in self.INTERMEDIATE_STATES:
diff --git a/airflow/providers/amazon/aws/sensors/quicksight.py b/airflow/providers/amazon/aws/sensors/quicksight.py
index 321fa56dd235f..848c0dc7048fc 100644
--- a/airflow/providers/amazon/aws/sensors/quicksight.py
+++ b/airflow/providers/amazon/aws/sensors/quicksight.py
@@ -22,7 +22,7 @@

 from deprecated import deprecated

-from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, AirflowSkipException
+from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
 from airflow.providers.amazon.aws.hooks.quicksight import QuickSightHook
 from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor

@@ -74,10 +74,7 @@ def poke(self, context: Context) -> bool:
         self.log.info("QuickSight Status: %s", quicksight_ingestion_state)
         if quicksight_ingestion_state in self.errored_statuses:
             error = self.hook.get_error_info(None, self.data_set_id, self.ingestion_id)
-            message = f"The QuickSight Ingestion failed. Error info: {error}"
-            if self.soft_fail:
-                raise AirflowSkipException(message)
-            raise AirflowException(message)
+            raise AirflowException(f"The QuickSight Ingestion failed. Error info: {error}")
         return quicksight_ingestion_state == self.success_status

     @cached_property
diff --git a/airflow/providers/amazon/aws/sensors/redshift_cluster.py b/airflow/providers/amazon/aws/sensors/redshift_cluster.py
index 7d2f4ba4724ba..243c71e61fe78 100644
--- a/airflow/providers/amazon/aws/sensors/redshift_cluster.py
+++ b/airflow/providers/amazon/aws/sensors/redshift_cluster.py
@@ -23,7 +23,7 @@
 from deprecated import deprecated

 from airflow.configuration import conf
-from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, AirflowSkipException
+from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
 from airflow.providers.amazon.aws.hooks.redshift_cluster import RedshiftHook
 from airflow.providers.amazon.aws.triggers.redshift_cluster import RedshiftClusterTrigger
 from airflow.providers.amazon.aws.utils import validate_execute_complete_event
@@ -93,11 +93,7 @@ def execute_complete(self, context: Context, event: dict[str, Any] | None = None
         status = event["status"]
         if status == "error":
-            # TODO: remove this if block when min_airflow_version is set to higher than 2.7.1
-            message = f"{event['status']}: {event['message']}"
-            if self.soft_fail:
-                raise AirflowSkipException(message)
-            raise AirflowException(message)
+            raise AirflowException(f"{event['status']}: {event['message']}")
         elif status == "success":
             self.log.info("%s completed successfully.", self.task_id)
             self.log.info("Cluster Identifier %s is in %s state", self.cluster_identifier, self.target_status)
diff --git a/airflow/providers/amazon/aws/sensors/s3.py b/airflow/providers/amazon/aws/sensors/s3.py
index 9c524494cdeb5..2f32fff3d30ac 100644
--- a/airflow/providers/amazon/aws/sensors/s3.py
+++ b/airflow/providers/amazon/aws/sensors/s3.py
@@ -33,7 +33,7 @@
 if TYPE_CHECKING:
     from airflow.utils.context import Context

-from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, AirflowSkipException
+from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
 from airflow.providers.amazon.aws.hooks.s3 import S3Hook
 from airflow.providers.amazon.aws.triggers.s3 import S3KeysUnchangedTrigger, S3KeyTrigger
 from airflow.sensors.base import BaseSensorOperator, poke_mode_only
@@ -219,9 +219,6 @@ def execute_complete(self, context: Context, event: dict[str, Any]) -> None:
             if not found_keys:
                 self._defer()
         elif event["status"] == "error":
-            # TODO: remove this if block when min_airflow_version is set to higher than 2.7.1
-            if self.soft_fail:
-                raise AirflowSkipException(event["message"])
             raise AirflowException(event["message"])

     @deprecated(reason="use `hook` property instead.", category=AirflowProviderDeprecationWarning)
@@ -342,14 +339,9 @@ def is_keys_unchanged(self, current_objects: set[str]) -> bool:
                 )
                 return False

-            # TODO: remove this if block when min_airflow_version is set to higher than 2.7.1
-            message = (
-                f"Illegal behavior: objects were deleted in"
-                f" {os.path.join(self.bucket_name, self.prefix)} between pokes."
+            raise AirflowException(
+                f"Illegal behavior: objects were deleted in {os.path.join(self.bucket_name, self.prefix)} between pokes."
             )
-            if self.soft_fail:
-                raise AirflowSkipException(message)
-            raise AirflowException(message)

         if self.last_activity_time:
             self.inactivity_seconds = int((datetime.now() - self.last_activity_time).total_seconds())
@@ -411,8 +403,5 @@ def execute_complete(self, context: Context, event: dict[str, Any] | None = None
         event = validate_execute_complete_event(event)

         if event and event["status"] == "error":
-            # TODO: remove this if block when min_airflow_version is set to higher than 2.7.1
-            if self.soft_fail:
-                raise AirflowSkipException(event["message"])
             raise AirflowException(event["message"])
         return None
diff --git a/airflow/providers/amazon/aws/sensors/sagemaker.py b/airflow/providers/amazon/aws/sensors/sagemaker.py
index 97ac8ad483c1b..b01e24cd5b815 100644
--- a/airflow/providers/amazon/aws/sensors/sagemaker.py
+++ b/airflow/providers/amazon/aws/sensors/sagemaker.py
@@ -22,7 +22,7 @@

 from deprecated import deprecated

-from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, AirflowSkipException
+from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
 from airflow.providers.amazon.aws.hooks.sagemaker import LogState, SageMakerHook
 from airflow.sensors.base import BaseSensorOperator

@@ -65,11 +65,9 @@ def poke(self, context: Context):
             return False
         if state in self.failed_states():
             failed_reason = self.get_failed_reason_from_response(response)
-            # TODO: remove this if block when min_airflow_version is set to higher than 2.7.1
-            message = f"Sagemaker {self.resource_type} failed for the following reason: {failed_reason}"
-            if self.soft_fail:
-                raise AirflowSkipException(message)
-            raise AirflowException(message)
+            raise AirflowException(
+                f"Sagemaker {self.resource_type} failed for the following reason: {failed_reason}"
+            )
         return True

     def non_terminal_states(self) -> set[str]:
diff --git a/airflow/providers/amazon/aws/sensors/sqs.py b/airflow/providers/amazon/aws/sensors/sqs.py
index 657c6d9599c7d..d04a8cf820b01 100644
--- a/airflow/providers/amazon/aws/sensors/sqs.py
+++ b/airflow/providers/amazon/aws/sensors/sqs.py
@@ -25,7 +25,7 @@
 from deprecated import deprecated

 from airflow.configuration import conf
-from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, AirflowSkipException
+from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
 from airflow.providers.amazon.aws.hooks.sqs import SqsHook
 from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor
 from airflow.providers.amazon.aws.triggers.sqs import SqsSensorTrigger
@@ -160,11 +160,7 @@ def execute_complete(self, context: Context, event: dict | None = None) -> None:
         event = validate_execute_complete_event(event)

         if event["status"] != "success":
-            # TODO: remove this if block when min_airflow_version is set to higher than 2.7.1
-            message = f"Trigger error: event is {event}"
-            if self.soft_fail:
-                raise AirflowSkipException(message)
-            raise AirflowException(message)
+            raise AirflowException(f"Trigger error: event is {event}")
         context["ti"].xcom_push(key="messages", value=event["message_batch"])

     def poll_sqs(self, sqs_conn: BaseAwsConnection) -> Collection:
@@ -221,11 +217,7 @@ def poke(self, context: Context):
                 response = self.hook.conn.delete_message_batch(QueueUrl=self.sqs_queue, Entries=entries)

                 if "Successful" not in response:
-                    # TODO: remove this if block when min_airflow_version is set to higher than 2.7.1
-                    error_message = f"Delete SQS Messages failed {response} for messages {messages}"
-                    if self.soft_fail:
-                        raise AirflowSkipException(error_message)
-                    raise AirflowException(error_message)
+                    raise AirflowException(f"Delete SQS Messages failed {response} for messages {messages}")

         if message_batch:
             context["ti"].xcom_push(key="messages", value=message_batch)
             return True
diff --git a/airflow/providers/amazon/aws/sensors/step_function.py b/airflow/providers/amazon/aws/sensors/step_function.py
index 5e0d3cfcf79cc..8af3bb6fe9c67 100644
--- a/airflow/providers/amazon/aws/sensors/step_function.py
+++ b/airflow/providers/amazon/aws/sensors/step_function.py
@@ -21,7 +21,7 @@

 from deprecated import deprecated

-from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, AirflowSkipException
+from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
 from airflow.providers.amazon.aws.hooks.step_function import StepFunctionHook
 from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor
 from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
@@ -76,11 +76,7 @@ def poke(self, context: Context):
         output = json.loads(execution_status["output"]) if "output" in execution_status else None

         if state in self.FAILURE_STATES:
-            # TODO: remove this if block when min_airflow_version is set to higher than 2.7.1
-            message = f"Step Function sensor failed. State Machine Output: {output}"
-            if self.soft_fail:
-                raise AirflowSkipException(message)
-            raise AirflowException(message)
+            raise AirflowException(f"Step Function sensor failed. State Machine Output: {output}")

         if state in self.INTERMEDIATE_STATES:
             return False
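Every block removed above followed the same template: build a message, raise AirflowSkipException when self.soft_fail is set, otherwise raise AirflowException. The TODO comments point at why that became redundant: starting with Airflow 2.7.1 (apache/airflow#33424), BaseSensorOperator translates exceptions raised inside poke() into a skip on its own when soft_fail=True. A minimal sketch of that centralized behavior, using a hypothetical helper name and an approximate exception list (not the verbatim upstream code in airflow/sensors/base.py):

```python
# Sketch of the base-sensor behavior the removed blocks now rely on.
# _poke_with_soft_fail is a hypothetical name for illustration only.
from airflow.exceptions import AirflowFailException, AirflowSkipException


def _poke_with_soft_fail(sensor, context) -> bool:
    try:
        return sensor.poke(context)
    except AirflowSkipException:
        raise  # already a skip; nothing to translate
    except AirflowFailException:
        raise  # an explicit "fail without retry" is never downgraded to a skip
    except Exception as e:
        if sensor.soft_fail:
            # Plain failures (e.g. the AirflowException raised by the sensors
            # in this diff) surface as a skipped task instead of a failed one.
            raise AirflowSkipException("Skipping due to soft_fail being set to True.") from e
        raise
```

With this in place, a sensor only has to describe the failure; whether that failure fails or skips the task is decided in one spot instead of in every provider, which is what lets the test changes below drop their soft_fail parametrizations.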
diff --git a/tests/providers/amazon/aws/sensors/test_athena.py b/tests/providers/amazon/aws/sensors/test_athena.py
index a973c76a38e4a..c770ab7dde38c 100644
--- a/tests/providers/amazon/aws/sensors/test_athena.py
+++ b/tests/providers/amazon/aws/sensors/test_athena.py
@@ -21,7 +21,7 @@

 import pytest

-from airflow.exceptions import AirflowException, AirflowSkipException
+from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.athena import AthenaHook
 from airflow.providers.amazon.aws.sensors.athena import AthenaSensor

@@ -73,17 +73,10 @@ def test_poke_intermediate_states(self, state, mock_poll_query_status):
         mock_poll_query_status.side_effect = [state]
         assert self.sensor.poke({}) is False

-    @pytest.mark.parametrize(
-        "soft_fail, expected_exception",
-        [
-            pytest.param(False, AirflowException, id="not-soft-fail"),
-            pytest.param(True, AirflowSkipException, id="soft-fail"),
-        ],
-    )
     @pytest.mark.parametrize("state", ["FAILED", "CANCELLED"])
-    def test_poke_failure_states(self, state, soft_fail, expected_exception, mock_poll_query_status):
+    def test_poke_failure_states(self, state, mock_poll_query_status):
         mock_poll_query_status.side_effect = [state]
-        sensor = AthenaSensor(**self.default_op_kwargs, aws_conn_id=None, soft_fail=soft_fail)
+        sensor = AthenaSensor(**self.default_op_kwargs, aws_conn_id=None)
         message = "Athena sensor failed"
-        with pytest.raises(expected_exception, match=message):
+        with pytest.raises(AirflowException, match=message):
             sensor.poke({})
diff --git a/tests/providers/amazon/aws/sensors/test_batch.py b/tests/providers/amazon/aws/sensors/test_batch.py
index 267aeb998f871..a8ec1b926bb5c 100644
--- a/tests/providers/amazon/aws/sensors/test_batch.py
+++ b/tests/providers/amazon/aws/sensors/test_batch.py
@@ -20,7 +20,7 @@

 import pytest

-from airflow.exceptions import AirflowException, AirflowSkipException, TaskDeferred
+from airflow.exceptions import AirflowException, TaskDeferred
 from airflow.providers.amazon.aws.hooks.batch_client import BatchClientHook
 from airflow.providers.amazon.aws.sensors.batch import (
     BatchComputeEnvironmentSensor,
@@ -67,9 +67,7 @@ def test_poke_on_failure_state(self, mock_get_job_description, batch_sensor: Bat
     @mock.patch.object(BatchClientHook, "get_job_description")
     def test_poke_on_invalid_state(self, mock_get_job_description, batch_sensor: BatchSensor):
         mock_get_job_description.return_value = {"status": "INVALID"}
-        with pytest.raises(
-            AirflowException, match="Batch sensor failed. Unknown AWS Batch job status: INVALID"
-        ):
+        with pytest.raises(AirflowException, match="Batch sensor failed. AWS Batch job status: INVALID"):
             batch_sensor.poke({})

         mock_get_job_description.assert_called_once_with(JOB_ID)
@@ -100,23 +98,11 @@ def test_execute_failure_in_deferrable_mode(self, deferrable_batch_sensor: Batch
         with pytest.raises(AirflowException):
             deferrable_batch_sensor.execute_complete(context={}, event={"status": "failure"})

-    def test_execute_failure_in_deferrable_mode_with_soft_fail(self, deferrable_batch_sensor: BatchSensor):
-        """Tests that an AirflowSkipException is raised in case of error event and soft_fail is set to True"""
-        deferrable_batch_sensor.soft_fail = True
-        with pytest.raises(AirflowSkipException):
-            deferrable_batch_sensor.execute_complete(context={}, event={"status": "failure"})
-
-    @pytest.mark.parametrize(
-        "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException))
-    )
     @pytest.mark.parametrize(
-        "state, error_message",
+        "state",
         (
-            (
-                BatchClientHook.FAILURE_STATE,
-                f"Batch sensor failed. AWS Batch job status: {BatchClientHook.FAILURE_STATE}",
-            ),
-            ("unknown_state", "Batch sensor failed. Unknown AWS Batch job status: unknown_state"),
+            BatchClientHook.FAILURE_STATE,
+            "unknown_state",
         ),
     )
     @mock.patch.object(BatchClientHook, "get_job_description")
@@ -125,13 +111,9 @@ def test_fail_poke(
         mock_get_job_description,
         batch_sensor: BatchSensor,
         state,
-        error_message,
-        soft_fail,
-        expected_exception,
     ):
         mock_get_job_description.return_value = {"status": state}
-        batch_sensor.soft_fail = soft_fail
-        with pytest.raises(expected_exception, match=error_message):
+        with pytest.raises(AirflowException, match=f"Batch sensor failed. AWS Batch job status: {state}"):
             batch_sensor.poke({})

@@ -202,9 +184,6 @@ def test_poke_invalid(
         )
         assert "AWS Batch compute environment failed" in str(ctx.value)

-    @pytest.mark.parametrize(
-        "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException))
-    )
     @pytest.mark.parametrize(
         "compute_env, error_message",
         (
@@ -222,12 +201,9 @@ def test_fail_poke(
         batch_compute_environment_sensor: BatchComputeEnvironmentSensor,
         compute_env,
         error_message,
-        soft_fail,
-        expected_exception,
     ):
         mock_batch_client.describe_compute_environments.return_value = {"computeEnvironments": compute_env}
-        batch_compute_environment_sensor.soft_fail = soft_fail
-        with pytest.raises(expected_exception, match=error_message):
+        with pytest.raises(AirflowException, match=error_message):
             batch_compute_environment_sensor.poke({})

@@ -299,9 +275,6 @@ def test_poke_invalid(self, mock_batch_client, batch_job_queue_sensor: BatchJobQ
         )
         assert "AWS Batch job queue failed" in str(ctx.value)

-    @pytest.mark.parametrize(
-        "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException))
-    )
     @pytest.mark.parametrize("job_queue", ([], [{"status": "UNKNOWN_STATUS"}]))
     @mock.patch.object(BatchClientHook, "client")
     def test_fail_poke(
@@ -309,12 +282,9 @@ def test_fail_poke(
         mock_batch_client,
         batch_job_queue_sensor: BatchJobQueueSensor,
         job_queue,
-        soft_fail,
-        expected_exception,
     ):
         mock_batch_client.describe_job_queues.return_value = {"jobQueues": job_queue}
         batch_job_queue_sensor.treat_non_existing_as_deleted = False
-        batch_job_queue_sensor.soft_fail = soft_fail
         message = "AWS Batch job queue"
-        with pytest.raises(expected_exception, match=message):
+        with pytest.raises(AirflowException, match=message):
             batch_job_queue_sensor.poke({})
diff --git a/tests/providers/amazon/aws/sensors/test_bedrock.py b/tests/providers/amazon/aws/sensors/test_bedrock.py
index 4ed531931608d..52151cd953942 100644
--- a/tests/providers/amazon/aws/sensors/test_bedrock.py
+++ b/tests/providers/amazon/aws/sensors/test_bedrock.py
@@ -21,7 +21,7 @@

 import pytest

-from airflow.exceptions import AirflowException, AirflowSkipException
+from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.bedrock import BedrockAgentHook, BedrockHook
 from airflow.providers.amazon.aws.sensors.bedrock import (
     BedrockCustomizeModelCompletedSensor,
@@ -75,19 +75,12 @@ def test_poke_intermediate_states(self, mock_conn, state):
         mock_conn.get_model_customization_job.return_value = {"status": state}
         assert self.sensor.poke({}) is False

-    @pytest.mark.parametrize(
-        "soft_fail, expected_exception",
-        [
-            pytest.param(False, AirflowException, id="not-soft-fail"),
-            pytest.param(True, AirflowSkipException, id="soft-fail"),
-        ],
-    )
     @pytest.mark.parametrize("state", SENSOR.FAILURE_STATES)
     @mock.patch.object(BedrockHook, "conn")
-    def test_poke_failure_states(self, mock_conn, state, soft_fail, expected_exception):
+    def test_poke_failure_states(self, mock_conn, state):
         mock_conn.get_model_customization_job.return_value = {"status": state}
-        sensor = self.SENSOR(**self.default_op_kwargs, aws_conn_id=None, soft_fail=soft_fail)
-        with pytest.raises(expected_exception, match=sensor.FAILURE_MESSAGE):
+        sensor = self.SENSOR(**self.default_op_kwargs, aws_conn_id=None)
+        with pytest.raises(AirflowException, match=sensor.FAILURE_MESSAGE):
             sensor.poke({})

@@ -135,20 +128,13 @@ def test_poke_intermediate_states(self, mock_conn, state):
         mock_conn.get_provisioned_model_throughput.return_value = {"status": state}
         assert self.sensor.poke({}) is False

-    @pytest.mark.parametrize(
-        "soft_fail, expected_exception",
-        [
-            pytest.param(False, AirflowException, id="not-soft-fail"),
-            pytest.param(True, AirflowSkipException, id="soft-fail"),
-        ],
-    )
     @pytest.mark.parametrize("state", SENSOR.FAILURE_STATES)
     @mock.patch.object(BedrockHook, "conn")
-    def test_poke_failure_states(self, mock_conn, state, soft_fail, expected_exception):
+    def test_poke_failure_states(self, mock_conn, state):
         mock_conn.get_provisioned_model_throughput.return_value = {"status": state}
-        sensor = self.SENSOR(**self.default_op_kwargs, aws_conn_id=None, soft_fail=soft_fail)
+        sensor = self.SENSOR(**self.default_op_kwargs, aws_conn_id=None)

-        with pytest.raises(expected_exception, match=sensor.FAILURE_MESSAGE):
+        with pytest.raises(AirflowException, match=sensor.FAILURE_MESSAGE):
             sensor.poke({})

@@ -196,19 +182,12 @@ def test_poke_intermediate_states(self, mock_conn, state):
         mock_conn.get_knowledge_base.return_value = {"knowledgeBase": {"status": state}}
         assert self.sensor.poke({}) is False

-    @pytest.mark.parametrize(
-        "soft_fail, expected_exception",
-        [
-            pytest.param(False, AirflowException, id="not-soft-fail"),
-            pytest.param(True, AirflowSkipException, id="soft-fail"),
-        ],
-    )
     @pytest.mark.parametrize("state", SENSOR.FAILURE_STATES)
     @mock.patch.object(BedrockAgentHook, "conn")
-    def test_poke_failure_states(self, mock_conn, state, soft_fail, expected_exception):
+    def test_poke_failure_states(self, mock_conn, state):
         mock_conn.get_knowledge_base.return_value = {"knowledgeBase": {"status": state}}
-        sensor = self.SENSOR(**self.default_op_kwargs, aws_conn_id=None, soft_fail=soft_fail)
-        with pytest.raises(expected_exception, match=sensor.FAILURE_MESSAGE):
+        sensor = self.SENSOR(**self.default_op_kwargs, aws_conn_id=None)
+        with pytest.raises(AirflowException, match=sensor.FAILURE_MESSAGE):
             sensor.poke({})

@@ -258,17 +237,10 @@ def test_poke_intermediate_states(self, mock_conn, state):
         mock_conn.get_ingestion_job.return_value = {"ingestionJob": {"status": state}}
         assert self.sensor.poke({}) is False

-    @pytest.mark.parametrize(
-        "soft_fail, expected_exception",
-        [
-            pytest.param(False, AirflowException, id="not-soft-fail"),
-            pytest.param(True, AirflowSkipException, id="soft-fail"),
-        ],
-    )
     @pytest.mark.parametrize("state", SENSOR.FAILURE_STATES)
     @mock.patch.object(BedrockAgentHook, "conn")
-    def test_poke_failure_states(self, mock_conn, state, soft_fail, expected_exception):
+    def test_poke_failure_states(self, mock_conn, state):
         mock_conn.get_ingestion_job.return_value = {"ingestionJob": {"status": state}}
-        sensor = self.SENSOR(**self.default_op_kwargs, aws_conn_id=None, soft_fail=soft_fail)
-        with pytest.raises(expected_exception, match=sensor.FAILURE_MESSAGE):
+        sensor = self.SENSOR(**self.default_op_kwargs, aws_conn_id=None)
+        with pytest.raises(AirflowException, match=sensor.FAILURE_MESSAGE):
             sensor.poke({})
diff --git a/tests/providers/amazon/aws/sensors/test_cloud_formation.py b/tests/providers/amazon/aws/sensors/test_cloud_formation.py
index 7edc252864a85..514934637179b 100644
--- a/tests/providers/amazon/aws/sensors/test_cloud_formation.py
+++ b/tests/providers/amazon/aws/sensors/test_cloud_formation.py
@@ -23,7 +23,6 @@
 import pytest
 from moto import mock_aws

-from airflow.exceptions import AirflowSkipException
 from airflow.providers.amazon.aws.sensors.cloud_formation import (
     CloudFormationCreateStackSensor,
     CloudFormationDeleteStackSensor,
@@ -76,17 +75,10 @@ def test_poke_false(self, mocked_hook_client):
         op = CloudFormationCreateStackSensor(task_id="task", stack_name="foo")
         assert not op.poke({})

-    @pytest.mark.parametrize(
-        "soft_fail, expected_exception",
-        [
-            pytest.param(True, AirflowSkipException, id="soft-fail"),
-            pytest.param(False, ValueError, id="non-soft-fail"),
-        ],
-    )
-    def test_poke_stack_in_unsuccessful_state(self, mocked_hook_client, soft_fail, expected_exception):
+    def test_poke_stack_in_unsuccessful_state(self, mocked_hook_client):
         mocked_hook_client.describe_stacks.return_value = {"Stacks": [{"StackStatus": "bar"}]}
-        op = CloudFormationCreateStackSensor(task_id="task", stack_name="foo", soft_fail=soft_fail)
-        with pytest.raises(expected_exception, match="Stack foo in bad state: bar"):
+        op = CloudFormationCreateStackSensor(task_id="task", stack_name="foo")
+        with pytest.raises(ValueError, match="Stack foo in bad state: bar"):
             op.poke({})

@@ -132,17 +124,10 @@ def test_poke_false(self, mocked_hook_client):
         op = CloudFormationDeleteStackSensor(task_id="task", stack_name="foo")
         assert not op.poke({})

-    @pytest.mark.parametrize(
-        "soft_fail, expected_exception",
-        [
-            pytest.param(True, AirflowSkipException, id="soft-fail"),
-            pytest.param(False, ValueError, id="non-soft-fail"),
-        ],
-    )
-    def test_poke_stack_in_unsuccessful_state(self, mocked_hook_client, soft_fail, expected_exception):
+    def test_poke_stack_in_unsuccessful_state(self, mocked_hook_client):
         mocked_hook_client.describe_stacks.return_value = {"Stacks": [{"StackStatus": "bar"}]}
-        op = CloudFormationDeleteStackSensor(task_id="task", stack_name="foo", soft_fail=soft_fail)
-        with pytest.raises(expected_exception, match="Stack foo in bad state: bar"):
+        op = CloudFormationDeleteStackSensor(task_id="task", stack_name="foo")
+        with pytest.raises(ValueError, match="Stack foo in bad state: bar"):
             op.poke({})

     @mock_aws
diff --git a/tests/providers/amazon/aws/sensors/test_comprehend.py b/tests/providers/amazon/aws/sensors/test_comprehend.py
index 1c80ca5f79803..20e2dc25dd6af 100644
--- a/tests/providers/amazon/aws/sensors/test_comprehend.py
+++ b/tests/providers/amazon/aws/sensors/test_comprehend.py
@@ -20,7 +20,7 @@

 import pytest

-from airflow.exceptions import AirflowException, AirflowSkipException
+from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.comprehend import ComprehendHook
 from airflow.providers.amazon.aws.sensors.comprehend import (
     ComprehendCreateDocumentClassifierCompletedSensor,
@@ -76,22 +76,15 @@ def test_intermediate_state(self, mock_conn, state):
         }
         assert self.sensor.poke({}) is False

-    @pytest.mark.parametrize(
-        "soft_fail, expected_exception",
-        [
-            pytest.param(False, AirflowException, id="not-soft-fail"),
-            pytest.param(True, AirflowSkipException, id="soft-fail"),
-        ],
-    )
     @pytest.mark.parametrize("state", SENSOR.FAILURE_STATES)
     @mock.patch.object(ComprehendHook, "conn")
-    def test_poke_failure_states(self, mock_conn, state, soft_fail, expected_exception):
+    def test_poke_failure_states(self, mock_conn, state):
         mock_conn.describe_pii_entities_detection_job.return_value = {
             "PiiEntitiesDetectionJobProperties": {"JobStatus": state}
         }
-        sensor = self.SENSOR(**self.default_op_kwargs, aws_conn_id=None, soft_fail=soft_fail)
+        sensor = self.SENSOR(**self.default_op_kwargs, aws_conn_id=None)

-        with pytest.raises(expected_exception, match=sensor.FAILURE_MESSAGE):
+        with pytest.raises(AirflowException, match=sensor.FAILURE_MESSAGE):
             sensor.poke({})

@@ -176,20 +169,13 @@ def test_intermediate_state(self, mock_conn, state):
         }
         assert self.sensor.poke({}) is False

-    @pytest.mark.parametrize(
-        "soft_fail, expected_exception",
-        [
-            pytest.param(False, AirflowException, id="not-soft-fail"),
-            pytest.param(True, AirflowSkipException, id="soft-fail"),
-        ],
-    )
     @pytest.mark.parametrize("state", SENSOR.FAILURE_STATES)
     @mock.patch.object(ComprehendHook, "conn")
-    def test_poke_failure_states(self, mock_conn, state, soft_fail, expected_exception):
+    def test_poke_failure_states(self, mock_conn, state):
         mock_conn.describe_document_classifier.return_value = {
             "DocumentClassifierProperties": {"Status": state}
         }
-        sensor = self.SENSOR(**self.default_op_kwargs, aws_conn_id=None, soft_fail=soft_fail)
+        sensor = self.SENSOR(**self.default_op_kwargs, aws_conn_id=None)

-        with pytest.raises(expected_exception, match=sensor.FAILURE_MESSAGE):
+        with pytest.raises(AirflowException, match=sensor.FAILURE_MESSAGE):
             sensor.poke({})
diff --git a/tests/providers/amazon/aws/sensors/test_dms.py b/tests/providers/amazon/aws/sensors/test_dms.py
index eb99949ea053b..3ec2ffccc70da 100644
--- a/tests/providers/amazon/aws/sensors/test_dms.py
+++ b/tests/providers/amazon/aws/sensors/test_dms.py
@@ -20,7 +20,7 @@

 import pytest

-from airflow.exceptions import AirflowException, AirflowSkipException
+from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.dms import DmsHook
 from airflow.providers.amazon.aws.sensors.dms import DmsTaskCompletedSensor

@@ -87,27 +87,13 @@ def test_poke_not_completed(self, mocked_get_task_status, status):
             "testing",
         ],
     )
-    @pytest.mark.parametrize(
-        "soft_fail, expected_exception",
-        [
-            pytest.param(True, AirflowSkipException, id="soft-fail"),
-            pytest.param(False, AirflowException, id="non-soft-fail"),
-        ],
-    )
-    def test_poke_terminated_status(self, mocked_get_task_status, status, soft_fail, expected_exception):
+    def test_poke_terminated_status(self, mocked_get_task_status, status):
         mocked_get_task_status.return_value = status
         error_message = f"Unexpected status: {status}"
         with pytest.raises(AirflowException, match=error_message):
-            DmsTaskCompletedSensor(**self.default_op_kwargs, soft_fail=soft_fail).poke({})
+            DmsTaskCompletedSensor(**self.default_op_kwargs).poke({})

-    @pytest.mark.parametrize(
-        "soft_fail, expected_exception",
-        [
-            pytest.param(True, AirflowSkipException, id="soft-fail"),
-            pytest.param(False, AirflowException, id="non-soft-fail"),
-        ],
-    )
-    def test_poke_none_status(self, mocked_get_task_status, soft_fail, expected_exception):
+    def test_poke_none_status(self, mocked_get_task_status):
         mocked_get_task_status.return_value = None
         with pytest.raises(AirflowException, match="task with ARN .* not found"):
-            DmsTaskCompletedSensor(**self.default_op_kwargs, soft_fail=soft_fail).poke({})
+            DmsTaskCompletedSensor(**self.default_op_kwargs).poke({})
diff --git a/tests/providers/amazon/aws/sensors/test_ecs.py b/tests/providers/amazon/aws/sensors/test_ecs.py
index 43210ceeb2716..dbbcef8c3d168 100644
--- a/tests/providers/amazon/aws/sensors/test_ecs.py
+++ b/tests/providers/amazon/aws/sensors/test_ecs.py
@@ -24,7 +24,7 @@
 import pytest
 from slugify import slugify

-from airflow.exceptions import AirflowException, AirflowSkipException
+from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.sensors.ecs import (
     EcsBaseSensor,
     EcsClusterStates,
@@ -34,7 +34,6 @@
     EcsTaskDefinitionStateSensor,
     EcsTaskStates,
     EcsTaskStateSensor,
-    _check_failed,
 )
 from airflow.utils import timezone
 from airflow.utils.types import NOTSET
@@ -265,20 +264,3 @@ def test_custom_values_terminal_state(self, failure_states, return_state):
         with pytest.raises(AirflowException, match="Terminal state reached"):
             task.poke({})
         m.assert_called_once_with(cluster=TEST_CLUSTER_NAME, task=TEST_TASK_ARN)
-
-
-@pytest.mark.parametrize(
-    "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException))
-)
-def test_fail__check_failed(soft_fail, expected_exception):
-    current_state = "FAILED"
-    target_state = "SUCCESS"
-    failure_states = ["FAILED"]
-    message = f"Terminal state reached. Current state: {current_state}, Expected state: {target_state}"
-    with pytest.raises(expected_exception, match=message):
-        _check_failed(
-            current_state=current_state,
-            target_state=target_state,
-            failure_states=failure_states,
-            soft_fail=soft_fail,
-        )
diff --git a/tests/providers/amazon/aws/sensors/test_emr_serverless_application.py b/tests/providers/amazon/aws/sensors/test_emr_serverless_application.py
index c35d84e7fa5af..e52d2ab9e4209 100644
--- a/tests/providers/amazon/aws/sensors/test_emr_serverless_application.py
+++ b/tests/providers/amazon/aws/sensors/test_emr_serverless_application.py
@@ -21,7 +21,7 @@

 import pytest

-from airflow.exceptions import AirflowException, AirflowSkipException
+from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.sensors.emr import EmrServerlessApplicationSensor

@@ -75,16 +75,3 @@ def test_poke_raises_airflow_exception_with_failure_states(self, state):
         assert exception_msg == str(ctx.value)

         self.assert_get_application_was_called_once_with_app_id()
-
-
-class TestPokeRaisesAirflowSkipException(TestEmrServerlessApplicationSensor):
-    def test_when_state_is_failed_and_soft_fail_is_true_poke_should_raise_skip_exception(self):
-        self.sensor.soft_fail = True
-        self.set_get_application_return_value(
-            {"application": {"state": "STOPPED", "stateDetails": "mock stopped"}}
-        )
-        with pytest.raises(AirflowSkipException) as ctx:
-            self.sensor.poke(None)
-        assert "EMR Serverless application failed: mock stopped" == str(ctx.value)
-        self.assert_get_application_was_called_once_with_app_id()
-        self.sensor.soft_fail = False
diff --git a/tests/providers/amazon/aws/sensors/test_emr_serverless_job.py b/tests/providers/amazon/aws/sensors/test_emr_serverless_job.py
index 299efe3fd277e..942abab677e1e 100644
--- a/tests/providers/amazon/aws/sensors/test_emr_serverless_job.py
+++ b/tests/providers/amazon/aws/sensors/test_emr_serverless_job.py
@@ -21,7 +21,7 @@

 import pytest

-from airflow.exceptions import AirflowException, AirflowSkipException
+from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.sensors.emr import EmrServerlessJobSensor

@@ -78,14 +78,3 @@ def test_poke_raises_airflow_exception_with_specified_states(self, state):
         assert exception_msg == str(ctx.value)

         self.assert_get_job_run_was_called_once_with_app_and_run_id()
-
-
-class TestPokeRaisesAirflowSkipException(TestEmrServerlessJobSensor):
-    def test_when_state_is_failed_and_soft_fail_is_true_poke_should_raise_skip_exception(self):
-        self.sensor.soft_fail = True
-        self.set_get_job_run_return_value({"jobRun": {"state": "FAILED", "stateDetails": "mock failed"}})
-        with pytest.raises(AirflowSkipException) as ctx:
-            self.sensor.poke(None)
-        assert "EMR Serverless job failed: mock failed" == str(ctx.value)
-        self.assert_get_job_run_was_called_once_with_app_and_run_id()
-        self.sensor.soft_fail = False
a/tests/providers/amazon/aws/sensors/test_glacier.py +++ b/tests/providers/amazon/aws/sensors/test_glacier.py @@ -21,7 +21,7 @@ import pytest -from airflow.exceptions import AirflowException, AirflowSkipException +from airflow.exceptions import AirflowException from airflow.providers.amazon.aws.sensors.glacier import GlacierJobOperationSensor, JobStatus SUCCEEDED = "Succeeded" @@ -77,19 +77,11 @@ def test_poke_fail(self, mocked_describe_job): with pytest.raises(AirflowException, match="Sensor failed"): self.op.poke(None) - @pytest.mark.parametrize( - "soft_fail, expected_exception", - [ - pytest.param(False, AirflowException, id="not-soft-fail"), - pytest.param(True, AirflowSkipException, id="soft-fail"), - ], - ) - def test_fail_poke(self, soft_fail, expected_exception, mocked_describe_job): - self.op.soft_fail = soft_fail + def test_fail_poke(self, mocked_describe_job): response = {"Action": "some action", "StatusCode": "Failed"} message = f'Sensor failed. Job status: {response["Action"]}, code status: {response["StatusCode"]}' mocked_describe_job.return_value = response - with pytest.raises(expected_exception, match=message): + with pytest.raises(AirflowException, match=message): self.op.poke(context={}) diff --git a/tests/providers/amazon/aws/sensors/test_glue.py b/tests/providers/amazon/aws/sensors/test_glue.py index d179feb98ee59..2d4925a016c18 100644 --- a/tests/providers/amazon/aws/sensors/test_glue.py +++ b/tests/providers/amazon/aws/sensors/test_glue.py @@ -21,7 +21,7 @@ import pytest -from airflow.exceptions import AirflowException, AirflowSkipException +from airflow.exceptions import AirflowException from airflow.providers.amazon.aws.hooks.glue import GlueJobHook from airflow.providers.amazon.aws.sensors.glue import GlueJobSensor @@ -134,11 +134,8 @@ def test_poke_failed_job_with_verbose_logging(self, mock_get_job_state, mock_con continuation_tokens=ANY, ) - @pytest.mark.parametrize( - "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException)) - ) @mock.patch("airflow.providers.amazon.aws.hooks.glue.GlueJobHook.get_job_state") - def test_fail_poke(self, get_job_state, soft_fail, expected_exception): + def test_fail_poke(self, get_job_state): job_name = "job_name" job_run_id = "job_run_id" op = GlueJobSensor( @@ -150,9 +147,8 @@ def test_fail_poke(self, get_job_state, soft_fail, expected_exception): verbose=True, ) op.verbose = False - op.soft_fail = soft_fail job_state = "FAILED" get_job_state.return_value = job_state job_error_message = "Exiting Job" - with pytest.raises(expected_exception, match=job_error_message): + with pytest.raises(AirflowException, match=job_error_message): op.poke(context={}) diff --git a/tests/providers/amazon/aws/sensors/test_glue_catalog_partition.py b/tests/providers/amazon/aws/sensors/test_glue_catalog_partition.py index 767a6ca0863a5..f8556092d188e 100644 --- a/tests/providers/amazon/aws/sensors/test_glue_catalog_partition.py +++ b/tests/providers/amazon/aws/sensors/test_glue_catalog_partition.py @@ -22,7 +22,7 @@ import pytest from moto import mock_aws -from airflow.exceptions import AirflowException, AirflowSkipException, TaskDeferred +from airflow.exceptions import AirflowException, TaskDeferred from airflow.providers.amazon.aws.hooks.glue_catalog import GlueCatalogHook from airflow.providers.amazon.aws.sensors.glue_catalog_partition import GlueCatalogPartitionSensor @@ -112,15 +112,11 @@ def test_execute_complete_succeeds_if_status_is_success(self, caplog): op.execute_complete(context={}, event=event) assert 
"Partition exists in the Glue Catalog" in caplog.messages - @pytest.mark.parametrize( - "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException)) - ) - def test_fail_execute_complete(self, soft_fail, expected_exception): + def test_fail_execute_complete(self): op = GlueCatalogPartitionSensor(task_id=self.task_id, table_name="tbl", deferrable=True) - op.soft_fail = soft_fail event = {"status": "Failed"} message = f"Trigger error: event is {event}" - with pytest.raises(expected_exception, match=message): + with pytest.raises(AirflowException, match=message): op.execute_complete(context={}, event=event) def test_init(self): diff --git a/tests/providers/amazon/aws/sensors/test_glue_crawler.py b/tests/providers/amazon/aws/sensors/test_glue_crawler.py index 83b3795a1ae13..d5762c8ee61a2 100644 --- a/tests/providers/amazon/aws/sensors/test_glue_crawler.py +++ b/tests/providers/amazon/aws/sensors/test_glue_crawler.py @@ -20,7 +20,7 @@ import pytest -from airflow.exceptions import AirflowException, AirflowSkipException +from airflow.exceptions import AirflowException from airflow.providers.amazon.aws.hooks.glue_crawler import GlueCrawlerHook from airflow.providers.amazon.aws.sensors.glue_crawler import GlueCrawlerSensor @@ -53,16 +53,12 @@ def test_poke_cancelled(self, mock_get_crawler): assert self.sensor.poke({}) is False mock_get_crawler.assert_called_once_with("aws_test_glue_crawler") - @pytest.mark.parametrize( - "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException)) - ) @mock.patch("airflow.providers.amazon.aws.hooks.glue_crawler.GlueCrawlerHook.get_crawler") - def test_fail_poke(self, get_crawler, soft_fail, expected_exception): - self.sensor.soft_fail = soft_fail + def test_fail_poke(self, get_crawler): crawler_status = "FAILED" get_crawler.return_value = {"State": "READY", "LastCrawl": {"Status": crawler_status}} message = f"Status: {crawler_status}" - with pytest.raises(expected_exception, match=message): + with pytest.raises(AirflowException, match=message): self.sensor.poke(context={}) def test_base_aws_op_attributes(self): diff --git a/tests/providers/amazon/aws/sensors/test_glue_data_quality.py b/tests/providers/amazon/aws/sensors/test_glue_data_quality.py index a37bc0b700ff9..585487f22b4b6 100644 --- a/tests/providers/amazon/aws/sensors/test_glue_data_quality.py +++ b/tests/providers/amazon/aws/sensors/test_glue_data_quality.py @@ -20,7 +20,7 @@ import pytest -from airflow.exceptions import AirflowException, AirflowSkipException, TaskDeferred +from airflow.exceptions import AirflowException, TaskDeferred from airflow.providers.amazon.aws.hooks.glue import GlueDataQualityHook from airflow.providers.amazon.aws.sensors.glue import ( GlueDataQualityRuleRecommendationRunSensor, @@ -133,16 +133,9 @@ def test_poke_intermediate_state(self, mock_conn): assert self.sensor.poke({}) is False - @pytest.mark.parametrize( - "soft_fail, expected_exception", - [ - pytest.param(False, AirflowException, id="not-soft-fail"), - pytest.param(True, AirflowSkipException, id="soft-fail"), - ], - ) @pytest.mark.parametrize("state", SENSOR.FAILURE_STATES) @mock.patch.object(GlueDataQualityHook, "conn") - def test_poke_failure_states(self, mock_conn, state, soft_fail, expected_exception): + def test_poke_failure_states(self, mock_conn, state): mock_conn.get_data_quality_ruleset_evaluation_run.return_value = { "RunId": "12345", "Status": state, @@ -150,11 +143,11 @@ def test_poke_failure_states(self, mock_conn, state, soft_fail, expected_excepti 
"ErrorString": "unknown error", } - sensor = self.SENSOR(**self.default_args, aws_conn_id=None, soft_fail=soft_fail) + sensor = self.SENSOR(**self.default_args, aws_conn_id=None) message = f"Error: AWS Glue data quality ruleset evaluation run RunId: 12345 Run Status: {state}: unknown error" - with pytest.raises(expected_exception, match=message): + with pytest.raises(AirflowException, match=message): sensor.poke({}) mock_conn.get_data_quality_ruleset_evaluation_run.assert_called_once_with(RunId="12345") @@ -247,29 +240,22 @@ def test_poke_intermediate_state(self, mock_conn): ) assert self.sensor.poke({}) is False - @pytest.mark.parametrize( - "soft_fail, expected_exception", - [ - pytest.param(False, AirflowException, id="not-soft-fail"), - pytest.param(True, AirflowSkipException, id="soft-fail"), - ], - ) @pytest.mark.parametrize("state", SENSOR.FAILURE_STATES) @mock.patch.object(GlueDataQualityHook, "conn") - def test_poke_failure_states(self, mock_conn, state, soft_fail, expected_exception): + def test_poke_failure_states(self, mock_conn, state): mock_conn.get_data_quality_rule_recommendation_run.return_value = { "RunId": "12345", "Status": state, "ErrorString": "unknown error", } - sensor = self.SENSOR(**self.default_args, aws_conn_id=None, soft_fail=soft_fail) + sensor = self.SENSOR(**self.default_args, aws_conn_id=None) message = ( f"Error: AWS Glue data quality recommendation run RunId: 12345 Run Status: {state}: unknown error" ) - with pytest.raises(expected_exception, match=message): + with pytest.raises(AirflowException, match=message): sensor.poke({}) mock_conn.get_data_quality_rule_recommendation_run.assert_called_once_with(RunId="12345") diff --git a/tests/providers/amazon/aws/sensors/test_kinesis_analytics.py b/tests/providers/amazon/aws/sensors/test_kinesis_analytics.py index 73a2cfdb7753e..b335931de2682 100644 --- a/tests/providers/amazon/aws/sensors/test_kinesis_analytics.py +++ b/tests/providers/amazon/aws/sensors/test_kinesis_analytics.py @@ -20,7 +20,7 @@ import pytest -from airflow.exceptions import AirflowException, AirflowSkipException +from airflow.exceptions import AirflowException from airflow.providers.amazon.aws.hooks.kinesis_analytics import KinesisAnalyticsV2Hook from airflow.providers.amazon.aws.sensors.kinesis_analytics import ( KinesisAnalyticsV2StartApplicationCompletedSensor, @@ -78,24 +78,17 @@ def test_intermediate_state(self, mock_conn, state): } assert self.sensor.poke({}) is False - @pytest.mark.parametrize( - "soft_fail, expected_exception", - [ - pytest.param(False, AirflowException, id="not-soft-fail"), - pytest.param(True, AirflowSkipException, id="soft-fail"), - ], - ) @pytest.mark.parametrize("state", SENSOR.FAILURE_STATES) @mock.patch.object(KinesisAnalyticsV2Hook, "conn") - def test_poke_failure_states(self, mock_conn, state, soft_fail, expected_exception): + def test_poke_failure_states(self, mock_conn, state): mock_conn.describe_application.return_value = { "ApplicationDetail": {"ApplicationARN": self.APPLICATION_ARN, "ApplicationStatus": state} } - sensor = self.SENSOR(**self.default_op_kwargs, aws_conn_id=None, soft_fail=soft_fail) + sensor = self.SENSOR(**self.default_op_kwargs, aws_conn_id=None) with pytest.raises( - expected_exception, match="AWS Managed Service for Apache Flink application start failed" + AirflowException, match="AWS Managed Service for Apache Flink application start failed" ): sensor.poke({}) @@ -150,23 +143,16 @@ def test_intermediate_state(self, mock_conn, state): } assert self.sensor.poke({}) is False - 
@pytest.mark.parametrize( - "soft_fail, expected_exception", - [ - pytest.param(False, AirflowException, id="not-soft-fail"), - pytest.param(True, AirflowSkipException, id="soft-fail"), - ], - ) @pytest.mark.parametrize("state", SENSOR.FAILURE_STATES) @mock.patch.object(KinesisAnalyticsV2Hook, "conn") - def test_poke_failure_states(self, mock_conn, state, soft_fail, expected_exception): + def test_poke_failure_states(self, mock_conn, state): mock_conn.describe_application.return_value = { "ApplicationDetail": {"ApplicationARN": self.APPLICATION_ARN, "ApplicationStatus": state} } - sensor = self.SENSOR(**self.default_op_kwargs, aws_conn_id=None, soft_fail=soft_fail) + sensor = self.SENSOR(**self.default_op_kwargs, aws_conn_id=None) with pytest.raises( - expected_exception, match="AWS Managed Service for Apache Flink application stop failed" + AirflowException, match="AWS Managed Service for Apache Flink application stop failed" ): sensor.poke({}) diff --git a/tests/providers/amazon/aws/sensors/test_lambda_function.py b/tests/providers/amazon/aws/sensors/test_lambda_function.py index 8f537ead185b7..2e1fbe8981e78 100644 --- a/tests/providers/amazon/aws/sensors/test_lambda_function.py +++ b/tests/providers/amazon/aws/sensors/test_lambda_function.py @@ -20,7 +20,7 @@ import pytest -from airflow.exceptions import AirflowException, AirflowSkipException +from airflow.exceptions import AirflowException from airflow.providers.amazon.aws.hooks.lambda_function import LambdaHook from airflow.providers.amazon.aws.sensors.lambda_function import LambdaFunctionStateSensor @@ -87,17 +87,13 @@ def test_poke(self, get_function_output, expect_failure, expected): FunctionName=FUNCTION_NAME, ) - @pytest.mark.parametrize( - "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException)) - ) - def test_fail_poke(self, soft_fail, expected_exception): + def test_fail_poke(self): sensor = LambdaFunctionStateSensor( task_id="test_sensor", function_name=FUNCTION_NAME, ) - sensor.soft_fail = soft_fail message = "Lambda function state sensor failed because the Lambda is in a failed state" with mock.patch("airflow.providers.amazon.aws.hooks.lambda_function.LambdaHook.conn") as conn: conn.get_function.return_value = {"Configuration": {"State": "Failed"}} - with pytest.raises(expected_exception, match=message): + with pytest.raises(AirflowException, match=message): sensor.poke(context={}) diff --git a/tests/providers/amazon/aws/sensors/test_opensearch_serverless.py b/tests/providers/amazon/aws/sensors/test_opensearch_serverless.py index 3b7474aeac19d..5543934468145 100644 --- a/tests/providers/amazon/aws/sensors/test_opensearch_serverless.py +++ b/tests/providers/amazon/aws/sensors/test_opensearch_serverless.py @@ -20,7 +20,7 @@ import pytest -from airflow.exceptions import AirflowException, AirflowSkipException +from airflow.exceptions import AirflowException from airflow.providers.amazon.aws.hooks.opensearch_serverless import OpenSearchServerlessHook from airflow.providers.amazon.aws.sensors.opensearch_serverless import ( OpenSearchServerlessCollectionActiveSensor, @@ -95,19 +95,10 @@ def test_poke_intermediate_states(self, mock_conn, state): mock_conn.batch_get_collection.return_value = {"collectionDetails": [{"status": state}]} assert self.sensor.poke({}) is False - @pytest.mark.parametrize( - "soft_fail, expected_exception", - [ - pytest.param(False, AirflowException, id="not-soft-fail"), - pytest.param(True, AirflowSkipException, id="soft-fail"), - ], - ) @pytest.mark.parametrize("state", 
list(OpenSearchServerlessCollectionActiveSensor.FAILURE_STATES)) @mock.patch.object(OpenSearchServerlessHook, "conn") - def test_poke_failure_states(self, mock_conn, state, soft_fail, expected_exception): + def test_poke_failure_states(self, mock_conn, state): mock_conn.batch_get_collection.return_value = {"collectionDetails": [{"status": state}]} - sensor = OpenSearchServerlessCollectionActiveSensor( - **self.default_op_kwargs, aws_conn_id=None, soft_fail=soft_fail - ) - with pytest.raises(expected_exception, match=sensor.FAILURE_MESSAGE): + sensor = OpenSearchServerlessCollectionActiveSensor(**self.default_op_kwargs, aws_conn_id=None) + with pytest.raises(AirflowException, match=sensor.FAILURE_MESSAGE): sensor.poke({}) diff --git a/tests/providers/amazon/aws/sensors/test_quicksight.py b/tests/providers/amazon/aws/sensors/test_quicksight.py index 9eb9a49587808..46890a69cbfff 100644 --- a/tests/providers/amazon/aws/sensors/test_quicksight.py +++ b/tests/providers/amazon/aws/sensors/test_quicksight.py @@ -21,7 +21,7 @@ import pytest -from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, AirflowSkipException +from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning from airflow.providers.amazon.aws.hooks.quicksight import QuickSightHook from airflow.providers.amazon.aws.sensors.quicksight import QuickSightSensor @@ -88,20 +88,11 @@ def test_poke_not_completed(self, status, mocked_get_status): mocked_get_status.assert_called_once_with(None, DATA_SET_ID, INGESTION_ID) @pytest.mark.parametrize("status", ["FAILED", "CANCELLED"]) - @pytest.mark.parametrize( - "soft_fail, expected_exception", - [ - pytest.param(True, AirflowSkipException, id="soft-fail"), - pytest.param(False, AirflowException, id="non-soft-fail"), - ], - ) - def test_poke_terminated_status( - self, status, soft_fail, expected_exception, mocked_get_status, mocked_get_error_info - ): + def test_poke_terminated_status(self, status, mocked_get_status, mocked_get_error_info): mocked_get_status.return_value = status mocked_get_error_info.return_value = "something bad happen" - with pytest.raises(expected_exception, match="Error info: something bad happen"): - QuickSightSensor(**self.default_op_kwargs, soft_fail=soft_fail).poke({}) + with pytest.raises(AirflowException, match="Error info: something bad happen"): + QuickSightSensor(**self.default_op_kwargs).poke({}) mocked_get_status.assert_called_once_with(None, DATA_SET_ID, INGESTION_ID) mocked_get_error_info.assert_called_once_with(None, DATA_SET_ID, INGESTION_ID) diff --git a/tests/providers/amazon/aws/sensors/test_s3.py b/tests/providers/amazon/aws/sensors/test_s3.py index 2d9ee9d52eae7..3c5606dcbe56f 100644 --- a/tests/providers/amazon/aws/sensors/test_s3.py +++ b/tests/providers/amazon/aws/sensors/test_s3.py @@ -24,7 +24,7 @@ import time_machine from moto import mock_aws -from airflow.exceptions import AirflowException, AirflowSkipException +from airflow.exceptions import AirflowException from airflow.models import DAG, DagRun, TaskInstance from airflow.models.variable import Variable from airflow.providers.amazon.aws.hooks.s3 import S3Hook @@ -282,18 +282,14 @@ def check_fn(files: list) -> bool: sensor.execute_complete(context={}, event={"status": "running", "files": [{"Size": 10}]}) is None ) - @pytest.mark.parametrize( - "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException)) - ) - def test_fail_execute_complete(self, soft_fail, expected_exception): + def test_fail_execute_complete(self): 
         op = S3KeySensor(
             task_id="s3_key_sensor",
             bucket_key=["s3://test_bucket/file*", "s3://test_bucket/*.zip"],
             wildcard_match=True,
         )
-        op.soft_fail = soft_fail
         message = "error"
-        with pytest.raises(expected_exception, match=message):
+        with pytest.raises(AirflowException, match=message):
             op.execute_complete(context={}, event={"status": "error", "message": message})
 
     @mock_aws
@@ -524,25 +520,17 @@ def test_poke_succeeds_on_upload_complete(self, mock_hook, time_machine):
         time_machine.coordinates.shift(10)
         assert self.sensor.poke(dict())
 
-    @pytest.mark.parametrize(
-        "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException))
-    )
-    def test_fail_is_keys_unchanged(self, soft_fail, expected_exception):
+    def test_fail_is_keys_unchanged(self):
         op = S3KeysUnchangedSensor(task_id="sensor", bucket_name="test-bucket", prefix="test-prefix/path")
-        op.soft_fail = soft_fail
         op.previous_objects = {"1", "2", "3"}
         current_objects = {"1", "2"}
         op.allow_delete = False
         message = "Illegal behavior: objects were deleted in"
-        with pytest.raises(expected_exception, match=message):
+        with pytest.raises(AirflowException, match=message):
             op.is_keys_unchanged(current_objects=current_objects)
 
-    @pytest.mark.parametrize(
-        "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException))
-    )
-    def test_fail_execute_complete(self, soft_fail, expected_exception):
+    def test_fail_execute_complete(self):
         op = S3KeysUnchangedSensor(task_id="sensor", bucket_name="test-bucket", prefix="test-prefix/path")
-        op.soft_fail = soft_fail
         message = "test message"
-        with pytest.raises(expected_exception, match=message):
+        with pytest.raises(AirflowException, match=message):
             op.execute_complete(context={}, event={"status": "error", "message": message})
diff --git a/tests/providers/amazon/aws/sensors/test_sagemaker_base.py b/tests/providers/amazon/aws/sensors/test_sagemaker_base.py
index bc324473dc791..c8a3ade2a18be 100644
--- a/tests/providers/amazon/aws/sensors/test_sagemaker_base.py
+++ b/tests/providers/amazon/aws/sensors/test_sagemaker_base.py
@@ -19,7 +19,7 @@
 
 import pytest
 
-from airflow.exceptions import AirflowException, AirflowSkipException
+from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.sensors.sagemaker import SageMakerBaseSensor
 
 
@@ -110,10 +110,7 @@ def state_from_response(self, response):
         with pytest.raises(AirflowException):
             sensor.poke(None)
 
-    @pytest.mark.parametrize(
-        "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException))
-    )
-    def test_fail_poke(self, soft_fail, expected_exception):
+    def test_fail_poke(self):
         resource_type = "job"
 
         class SageMakerBaseSensorSubclass(SageMakerBaseSensor):
@@ -132,10 +129,9 @@ def state_from_response(self, response):
         sensor = SageMakerBaseSensorSubclass(
             task_id="test_task", poke_interval=2, aws_conn_id="aws_test", resource_type=resource_type
         )
-        sensor.soft_fail = soft_fail
         message = (
             f"Sagemaker {resource_type} failed for the following reason:"
             f" {sensor.get_failed_reason_from_response({})}"
         )
-        with pytest.raises(expected_exception, match=message):
+        with pytest.raises(AirflowException, match=message):
             sensor.poke(context={})
diff --git a/tests/providers/amazon/aws/sensors/test_sqs.py b/tests/providers/amazon/aws/sensors/test_sqs.py
index 0cddcfb36d9aa..457195313c7be 100644
--- a/tests/providers/amazon/aws/sensors/test_sqs.py
+++ b/tests/providers/amazon/aws/sensors/test_sqs.py
@@ -23,7 +23,7 @@
 import pytest
 from moto import mock_aws
 
-from airflow.exceptions import AirflowException, AirflowSkipException, TaskDeferred
+from airflow.exceptions import AirflowException, TaskDeferred
 from airflow.providers.amazon.aws.hooks.sqs import SqsHook
 from airflow.providers.amazon.aws.sensors.sqs import SqsSensor
 
@@ -418,31 +418,17 @@ def test_sqs_deferrable(self):
         with pytest.raises(TaskDeferred):
             sensor.execute({})
 
-    @pytest.mark.parametrize(
-        "soft_fail, expected_exception",
-        [
-            pytest.param(True, AirflowSkipException, id="soft-fail"),
-            pytest.param(False, AirflowException, id="non-soft-fail"),
-        ],
-    )
-    def test_fail_execute_complete(self, soft_fail, expected_exception):
-        sensor = SqsSensor(**self.default_op_kwargs, deferrable=True, soft_fail=soft_fail)
+    def test_fail_execute_complete(self):
+        sensor = SqsSensor(**self.default_op_kwargs, deferrable=True)
         event = {"status": "failed"}
         message = f"Trigger error: event is {event}"
-        with pytest.raises(expected_exception, match=message):
+        with pytest.raises(AirflowException, match=message):
             sensor.execute_complete(context={}, event=event)
 
-    @pytest.mark.parametrize(
-        "soft_fail, expected_exception",
-        [
-            pytest.param(True, AirflowSkipException, id="soft-fail"),
-            pytest.param(False, AirflowException, id="non-soft-fail"),
-        ],
-    )
     @mock.patch("airflow.providers.amazon.aws.sensors.sqs.SqsSensor.poll_sqs")
     @mock.patch("airflow.providers.amazon.aws.sensors.sqs.process_response")
     @mock.patch("airflow.providers.amazon.aws.hooks.sqs.SqsHook.conn")
-    def test_fail_poke(self, mocked_client, process_response, poll_sqs, soft_fail, expected_exception):
+    def test_fail_poke(self, mocked_client, process_response, poll_sqs):
         response = "error message"
         messages = [{"MessageId": "1", "ReceiptHandle": "test"}]
         poll_sqs.return_value = response
@@ -450,6 +436,6 @@ def test_fail_poke(self, mocked_client, process_response, poll_sqs, soft_fail, e
         mocked_client.delete_message_batch.return_value = response
         error_message = f"Delete SQS Messages failed {response} for messages"
 
-        sensor = SqsSensor(**self.default_op_kwargs, soft_fail=soft_fail)
-        with pytest.raises(expected_exception, match=error_message):
+        sensor = SqsSensor(**self.default_op_kwargs)
+        with pytest.raises(AirflowException, match=error_message):
             sensor.poke(context={})
diff --git a/tests/providers/amazon/aws/sensors/test_step_function.py b/tests/providers/amazon/aws/sensors/test_step_function.py
index 878691dc1cddc..5289fb053e144 100644
--- a/tests/providers/amazon/aws/sensors/test_step_function.py
+++ b/tests/providers/amazon/aws/sensors/test_step_function.py
@@ -22,7 +22,7 @@
 
 import pytest
 
-from airflow.exceptions import AirflowException, AirflowSkipException
+from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.sensors.step_function import StepFunctionExecutionSensor
 
 TASK_ID = "step_function_execution_sensor"
@@ -78,19 +78,10 @@ def test_succeeded(self, mocked_hook, status, mocked_context):
 
     @mock.patch.object(StepFunctionExecutionSensor, "hook")
     @pytest.mark.parametrize("status", StepFunctionExecutionSensor.FAILURE_STATES)
-    @pytest.mark.parametrize(
-        "soft_fail, expected_exception",
-        [
-            pytest.param(True, AirflowSkipException, id="soft-fail"),
-            pytest.param(False, AirflowException, id="non-soft-fail"),
-        ],
-    )
-    def test_failure(self, mocked_hook, status, soft_fail, expected_exception, mocked_context):
+    def test_failure(self, mocked_hook, status, mocked_context):
         output = {"test": "test"}
         mocked_hook.describe_execution.return_value = {"status": status, "output": json.dumps(output)}
-        sensor = StepFunctionExecutionSensor(
-            task_id=TASK_ID, execution_arn=EXECUTION_ARN, aws_conn_id=None, soft_fail=soft_fail
-        )
+        sensor = StepFunctionExecutionSensor(task_id=TASK_ID, execution_arn=EXECUTION_ARN, aws_conn_id=None)
         message = f"Step Function sensor failed. State Machine Output: {output}"
-        with pytest.raises(expected_exception, match=message):
+        with pytest.raises(AirflowException, match=message):
             sensor.poke(mocked_context)
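
Why every test above can now assert a plain AirflowException: `soft_fail` handling is applied centrally by core Airflow's `BaseSensorOperator` rather than inside each provider sensor, so the per-sensor `AirflowSkipException` branches, and the `soft_fail`/`expected_exception` parametrizations that exercised them, are redundant. The snippet below is a minimal, simplified sketch of that centralized pattern; it is not the actual implementation in `airflow/sensors/base.py` (the real code's exception allow-list, message text, and class structure differ), and the class name is hypothetical.

# Simplified sketch of centralized soft_fail handling; illustrative only.
from airflow.exceptions import AirflowException, AirflowSkipException


class SketchSensor:
    """Hypothetical stand-in for BaseSensorOperator's soft_fail behavior."""

    def __init__(self, soft_fail: bool = False):
        self.soft_fail = soft_fail

    def poke(self) -> bool:
        # Provider sensors now simply raise AirflowException on failure states.
        raise AirflowException("Sensor failed")

    def execute(self) -> None:
        try:
            self.poke()
        except AirflowSkipException:
            # Already a skip; propagate unchanged.
            raise
        except Exception as e:
            if self.soft_fail:
                # Centralized conversion: with soft_fail=True the task is
                # skipped instead of failed, so individual sensors need no
                # AirflowSkipException branches of their own.
                raise AirflowSkipException("Skipping due to soft_fail.") from e
            raise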