Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 2 additions & 6 deletions airflow/providers/amazon/aws/sensors/athena.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
if TYPE_CHECKING:
from airflow.utils.context import Context

from airflow.exceptions import AirflowException, AirflowSkipException
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.athena import AthenaHook


Expand Down Expand Up @@ -88,11 +88,7 @@ def poke(self, context: Context) -> bool:
state = self.hook.poll_query_status(self.query_execution_id, self.max_retries, self.sleep_time)

if state in self.FAILURE_STATES:
# TODO: remove this if block when min_airflow_version is set to higher than 2.7.1
message = "Athena sensor failed"
if self.soft_fail:
raise AirflowSkipException(message)
raise AirflowException(message)
raise AirflowException("Athena sensor failed")

if state in self.INTERMEDIATE_STATES:
return False
Expand Down
40 changes: 7 additions & 33 deletions airflow/providers/amazon/aws/sensors/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,18 +86,7 @@ def poke(self, context: Context) -> bool:
if state in BatchClientHook.INTERMEDIATE_STATES:
return False

if state == BatchClientHook.FAILURE_STATE:
# TODO: remove this if block when min_airflow_version is set to higher than 2.7.1
message = f"Batch sensor failed. AWS Batch job status: {state}"
if self.soft_fail:
raise AirflowSkipException(message)
raise AirflowException(message)

# TODO: remove this if block when min_airflow_version is set to higher than 2.7.1
message = f"Batch sensor failed. Unknown AWS Batch job status: {state}"
if self.soft_fail:
raise AirflowSkipException(message)
raise AirflowException(message)
raise AirflowException(f"Batch sensor failed. AWS Batch job status: {state}")

def execute(self, context: Context) -> None:
if not self.deferrable:
Expand Down Expand Up @@ -127,12 +116,7 @@ def execute_complete(self, context: Context, event: dict[str, Any]) -> None:
Relies on trigger to throw an exception, otherwise it assumes execution was successful.
"""
if event["status"] != "success":
message = f"Error while running job: {event}"
# TODO: remove this if-else block when min_airflow_version is set to higher than the version that
# changed in https://github.com/apache/airflow/pull/33424 is released (2.7.1)
if self.soft_fail:
raise AirflowSkipException(message)
raise AirflowException(message)
raise AirflowException(f"Error while running job: {event}")
job_id = event["job_id"]
self.log.info("Batch Job %s complete", job_id)

Expand Down Expand Up @@ -198,11 +182,7 @@ def poke(self, context: Context) -> bool:
)

if not response["computeEnvironments"]:
message = f"AWS Batch compute environment {self.compute_environment} not found"
# TODO: remove this if block when min_airflow_version is set to higher than 2.7.1
if self.soft_fail:
raise AirflowSkipException(message)
raise AirflowException(message)
raise AirflowException(f"AWS Batch compute environment {self.compute_environment} not found")

status = response["computeEnvironments"][0]["status"]

Expand All @@ -212,11 +192,9 @@ def poke(self, context: Context) -> bool:
if status in BatchClientHook.COMPUTE_ENVIRONMENT_INTERMEDIATE_STATUS:
return False

# TODO: remove this if block when min_airflow_version is set to higher than 2.7.1
message = f"AWS Batch compute environment failed. AWS Batch compute environment status: {status}"
if self.soft_fail:
raise AirflowSkipException(message)
raise AirflowException(message)
raise AirflowException(
f"AWS Batch compute environment failed. AWS Batch compute environment status: {status}"
)


class BatchJobQueueSensor(BaseSensorOperator):
Expand Down Expand Up @@ -276,11 +254,7 @@ def poke(self, context: Context) -> bool:
if self.treat_non_existing_as_deleted:
return True
else:
# TODO: remove this if block when min_airflow_version is set to higher than 2.7.1
message = f"AWS Batch job queue {self.job_queue} not found"
if self.soft_fail:
raise AirflowSkipException(message)
raise AirflowException(message)
raise AirflowException(f"AWS Batch job queue {self.job_queue} not found")

status = response["jobQueues"][0]["status"]

Expand Down
5 changes: 1 addition & 4 deletions airflow/providers/amazon/aws/sensors/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from typing import TYPE_CHECKING, Any, Sequence, TypeVar

from airflow.configuration import conf
from airflow.exceptions import AirflowException, AirflowSkipException
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.bedrock import BedrockAgentHook, BedrockHook
from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor
from airflow.providers.amazon.aws.triggers.bedrock import (
Expand Down Expand Up @@ -76,9 +76,6 @@ def __init__(
def poke(self, context: Context, **kwargs) -> bool:
state = self.get_state()
if state in self.FAILURE_STATES:
# TODO: remove this if block when min_airflow_version is set to higher than 2.7.1
if self.soft_fail:
raise AirflowSkipException(self.FAILURE_MESSAGE)
raise AirflowException(self.FAILURE_MESSAGE)

return state not in self.INTERMEDIATE_STATES
Expand Down
13 changes: 2 additions & 11 deletions airflow/providers/amazon/aws/sensors/cloud_formation.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
if TYPE_CHECKING:
from airflow.utils.context import Context

from airflow.exceptions import AirflowSkipException
from airflow.providers.amazon.aws.hooks.cloud_formation import CloudFormationHook


Expand Down Expand Up @@ -67,11 +66,7 @@ def poke(self, context: Context):
if stack_status in ("CREATE_IN_PROGRESS", None):
return False

# TODO: remove this if check when min_airflow_version is set to higher than 2.7.1
message = f"Stack {self.stack_name} in bad state: {stack_status}"
if self.soft_fail:
raise AirflowSkipException(message)
raise ValueError(message)
raise ValueError(f"Stack {self.stack_name} in bad state: {stack_status}")


class CloudFormationDeleteStackSensor(AwsBaseSensor[CloudFormationHook]):
Expand Down Expand Up @@ -119,8 +114,4 @@ def poke(self, context: Context):
if stack_status == "DELETE_IN_PROGRESS":
return False

# TODO: remove this if check when min_airflow_version is set to higher than 2.7.1
message = f"Stack {self.stack_name} in bad state: {stack_status}"
if self.soft_fail:
raise AirflowSkipException(message)
raise ValueError(message)
raise ValueError(f"Stack {self.stack_name} in bad state: {stack_status}")
8 changes: 1 addition & 7 deletions airflow/providers/amazon/aws/sensors/comprehend.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from typing import TYPE_CHECKING, Any, Sequence

from airflow.configuration import conf
from airflow.exceptions import AirflowException, AirflowSkipException
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.comprehend import ComprehendHook
from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor
from airflow.providers.amazon.aws.triggers.comprehend import (
Expand Down Expand Up @@ -71,9 +71,6 @@ def __init__(
def poke(self, context: Context, **kwargs) -> bool:
state = self.get_state()
if state in self.FAILURE_STATES:
# TODO: remove this if block when min_airflow_version is set to higher than 2.7.1
if self.soft_fail:
raise AirflowSkipException(self.FAILURE_MESSAGE)
raise AirflowException(self.FAILURE_MESSAGE)

return state not in self.INTERMEDIATE_STATES
Expand Down Expand Up @@ -241,9 +238,6 @@ def poke(self, context: Context, **kwargs) -> bool:
)

if status in self.FAILURE_STATES:
# TODO: remove this if block when min_airflow_version is set to higher than 2.7.1
if self.soft_fail:
raise AirflowSkipException(self.FAILURE_MESSAGE)
raise AirflowException(self.FAILURE_MESSAGE)

if status in self.SUCCESS_STATES:
Expand Down
16 changes: 5 additions & 11 deletions airflow/providers/amazon/aws/sensors/dms.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

from deprecated import deprecated

from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, AirflowSkipException
from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
from airflow.providers.amazon.aws.hooks.dms import DmsHook
from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor
from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
Expand Down Expand Up @@ -75,23 +75,17 @@ def get_hook(self) -> DmsHook:

def poke(self, context: Context):
if not (status := self.hook.get_task_status(self.replication_task_arn)):
# TODO: remove this if check when min_airflow_version is set to higher than 2.7.1
message = f"Failed to read task status, task with ARN {self.replication_task_arn} not found"
if self.soft_fail:
raise AirflowSkipException(message)
raise AirflowException(message)
raise AirflowException(
f"Failed to read task status, task with ARN {self.replication_task_arn} not found"
)

self.log.info("DMS Replication task (%s) has status: %s", self.replication_task_arn, status)

if status in self.target_statuses:
return True

if status in self.termination_statuses:
# TODO: remove this if check when min_airflow_version is set to higher than 2.7.1
message = f"Unexpected status: {status}"
if self.soft_fail:
raise AirflowSkipException(message)
raise AirflowException(message)
raise AirflowException(f"Unexpected status: {status}")

return False

Expand Down
8 changes: 2 additions & 6 deletions airflow/providers/amazon/aws/sensors/ec2.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from typing import TYPE_CHECKING, Any, Sequence

from airflow.configuration import conf
from airflow.exceptions import AirflowException, AirflowSkipException
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.ec2 import EC2Hook
from airflow.providers.amazon.aws.triggers.ec2 import EC2StateSensorTrigger
from airflow.providers.amazon.aws.utils import validate_execute_complete_event
Expand Down Expand Up @@ -97,8 +97,4 @@ def execute_complete(self, context: Context, event: dict[str, Any] | None = None
event = validate_execute_complete_event(event)

if event["status"] != "success":
# TODO: remove this if check when min_airflow_version is set to higher than 2.7.1
message = f"Error: {event}"
if self.soft_fail:
raise AirflowSkipException(message)
raise AirflowException(message)
raise AirflowException(f"Error: {event}")
10 changes: 4 additions & 6 deletions airflow/providers/amazon/aws/sensors/ecs.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from functools import cached_property
from typing import TYPE_CHECKING, Sequence

from airflow.exceptions import AirflowException, AirflowSkipException
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.ecs import (
EcsClusterStates,
EcsHook,
Expand All @@ -37,11 +37,9 @@

def _check_failed(current_state, target_state, failure_states, soft_fail: bool) -> None:
if (current_state != target_state) and (current_state in failure_states):
# TODO: remove this if block when min_airflow_version is set to higher than 2.7.1
message = f"Terminal state reached. Current state: {current_state}, Expected state: {target_state}"
if soft_fail:
raise AirflowSkipException(message)
raise AirflowException(message)
raise AirflowException(
f"Terminal state reached. Current state: {current_state}, Expected state: {target_state}"
)


class EcsBaseSensor(AwsBaseSensor[EcsHook]):
Expand Down
12 changes: 5 additions & 7 deletions airflow/providers/amazon/aws/sensors/eks.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from functools import cached_property
from typing import TYPE_CHECKING, Sequence

from airflow.exceptions import AirflowException, AirflowSkipException
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.eks import (
ClusterStates,
EksHook,
Expand Down Expand Up @@ -106,12 +106,10 @@ def poke(self, context: Context) -> bool:
state = self.get_state()
self.log.info("Current state: %s", state)
if state in (self.get_terminal_states() - {self.target_state}):
# If we reach a terminal state which is not the target state:
# TODO: remove this if check when min_airflow_version is set to higher than 2.7.1
message = f"Terminal state reached. Current state: {state}, Expected state: {self.target_state}"
if self.soft_fail:
raise AirflowSkipException(message)
raise AirflowException(message)
# If we reach a terminal state which is not the target state
raise AirflowException(
f"Terminal state reached. Current state: {state}, Expected state: {self.target_state}"
)
return state == self.target_state

@abstractmethod
Expand Down
45 changes: 9 additions & 36 deletions airflow/providers/amazon/aws/sensors/emr.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
from airflow.exceptions import (
AirflowException,
AirflowProviderDeprecationWarning,
AirflowSkipException,
)
from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook, EmrHook, EmrServerlessHook
from airflow.providers.amazon.aws.links.emr import EmrClusterLink, EmrLogsLink, get_log_uri
Expand Down Expand Up @@ -91,11 +90,7 @@ def poke(self, context: Context):
return True

if state in self.failed_states:
# TODO: remove this if check when min_airflow_version is set to higher than 2.7.1
message = f"EMR job failed: {self.failure_message_from_response(response)}"
if self.soft_fail:
raise AirflowSkipException(message)
raise AirflowException(message)
raise AirflowException(f"EMR job failed: {self.failure_message_from_response(response)}")

return False

Expand Down Expand Up @@ -172,11 +167,9 @@ def poke(self, context: Context) -> bool:
state = response["jobRun"]["state"]

if state in EmrServerlessHook.JOB_FAILURE_STATES:
failure_message = f"EMR Serverless job failed: {self.failure_message_from_response(response)}"
# TODO: remove this if check when min_airflow_version is set to higher than 2.7.1
if self.soft_fail:
raise AirflowSkipException(failure_message)
raise AirflowException(failure_message)
raise AirflowException(
f"EMR Serverless job failed: {self.failure_message_from_response(response)}"
)

return state in self.target_states

Expand Down Expand Up @@ -234,13 +227,9 @@ def poke(self, context: Context) -> bool:
state = response["application"]["state"]

if state in EmrServerlessHook.APPLICATION_FAILURE_STATES:
# TODO: remove this if check when min_airflow_version is set to higher than 2.7.1
failure_message = (
raise AirflowException(
f"EMR Serverless application failed: {self.failure_message_from_response(response)}"
)
if self.soft_fail:
raise AirflowSkipException(failure_message)
raise AirflowException(failure_message)

return state in self.target_states

Expand Down Expand Up @@ -328,11 +317,7 @@ def poke(self, context: Context) -> bool:
)

if state in self.FAILURE_STATES:
# TODO: remove this if check when min_airflow_version is set to higher than 2.7.1
message = "EMR Containers sensor failed"
if self.soft_fail:
raise AirflowSkipException(message)
raise AirflowException(message)
raise AirflowException("EMR Containers sensor failed")

if state in self.INTERMEDIATE_STATES:
return False
Expand Down Expand Up @@ -370,11 +355,7 @@ def execute_complete(self, context: Context, event: dict[str, Any] | None = None
event = validate_execute_complete_event(event)

if event["status"] != "success":
# TODO: remove this if check when min_airflow_version is set to higher than 2.7.1
message = f"Error while running job: {event}"
if self.soft_fail:
raise AirflowSkipException(message)
raise AirflowException(message)
raise AirflowException(f"Error while running job: {event}")

self.log.info("Job completed.")

Expand Down Expand Up @@ -563,11 +544,7 @@ def execute_complete(self, context: Context, event: dict[str, Any] | None = None
event = validate_execute_complete_event(event)

if event["status"] != "success":
# TODO: remove this if check when min_airflow_version is set to higher than 2.7.1
message = f"Error while running job: {event}"
if self.soft_fail:
raise AirflowSkipException(message)
raise AirflowException(message)
raise AirflowException(f"Error while running job: {event}")
self.log.info("Job completed.")


Expand Down Expand Up @@ -696,10 +673,6 @@ def execute_complete(self, context: Context, event: dict[str, Any] | None = None
event = validate_execute_complete_event(event)

if event["status"] != "success":
# TODO: remove this if check when min_airflow_version is set to higher than 2.7.1
message = f"Error while running job: {event}"
if self.soft_fail:
raise AirflowSkipException(message)
raise AirflowException(message)
raise AirflowException(f"Error while running job: {event}")

self.log.info("Job %s completed.", self.job_flow_id)
Loading