diff --git a/providers/amazon/src/airflow/providers/amazon/aws/hooks/glue.py b/providers/amazon/src/airflow/providers/amazon/aws/hooks/glue.py
index 7093120ceafd0..ceb1c99928282 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/hooks/glue.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/hooks/glue.py
@@ -46,7 +46,6 @@ class GlueJobHook(AwsBaseHook):
     :param script_location: path to etl script on s3
     :param retry_limit: Maximum number of times to retry this job if it fails
     :param num_of_dpus: Number of AWS Glue DPUs to allocate to this Job
-    :param region_name: aws region name (example: us-east-1)
     :param iam_role_name: AWS IAM Role for Glue Job Execution. If set `iam_role_arn` must equal None.
     :param iam_role_arn: AWS IAM Role ARN for Glue Job Execution. If set `iam_role_name` must equal None.
     :param create_job_kwargs: Extra arguments for Glue Job Creation
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/operators/glue.py b/providers/amazon/src/airflow/providers/amazon/aws/operators/glue.py
index 4b80d47b046c8..5cf8f2ecb5ff1 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/operators/glue.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/operators/glue.py
@@ -60,7 +60,6 @@ class GlueJobOperator(AwsBaseOperator[GlueJobHook]):
     :param script_args: etl script arguments and AWS Glue arguments (templated)
     :param retry_limit: The maximum number of times to retry this job if it fails
     :param num_of_dpus: Number of AWS Glue DPUs to allocate to this Job.
-    :param region_name: aws region name (example: us-east-1)
     :param s3_bucket: S3 bucket where logs and local etl script will be uploaded
     :param iam_role_name: AWS IAM Role for Glue Job Execution. If set `iam_role_arn` must equal None.
     :param iam_role_arn: AWS IAM ARN for Glue Job Execution. If set `iam_role_name` must equal None.
@@ -79,6 +78,17 @@ class GlueJobOperator(AwsBaseOperator[GlueJobHook]):
         Thus if status is returned immediately it might end up in case of more than 1 concurrent run.
         It is recommended to set this parameter to 10 when you are using concurrency=1.
         For more information see: https://repost.aws/questions/QUaKgpLBMPSGWO0iq2Fob_bw/glue-run-concurrent-jobs#ANFpCL2fRnQRqgDFuIU_rpvA
+
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
     """
 
     aws_hook_class = GlueJobHook
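As a usage sketch for reviewers (not part of the patch): the job name, script path, and config values below are hypothetical, and `region_name`, `verify`, and `botocore_config` are inherited from `AwsBaseOperator` rather than introduced here; this change only documents them.

    from datetime import datetime

    from airflow import DAG
    from airflow.providers.amazon.aws.operators.glue import GlueJobOperator

    # Hypothetical DAG wiring; every AWS-wide argument falls back to the
    # default boto3 behaviour when omitted, as the docstring above states.
    with DAG("example_glue", start_date=datetime(2025, 1, 1), schedule=None):
        submit_job = GlueJobOperator(
            task_id="submit_glue_job",
            job_name="my-glue-job",                   # hypothetical job
            script_location="s3://my-bucket/etl.py",  # hypothetical script
            region_name="us-east-1",
            verify=True,
            botocore_config={"retries": {"mode": "standard", "max_attempts": 5}},
        )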
@@ -122,9 +132,9 @@ def __init__(
         verbose: bool = False,
         replace_script_file: bool = False,
         update_config: bool = False,
-        job_poll_interval: int | float = 6,
         stop_job_run_on_kill: bool = False,
         sleep_before_return: int = 0,
+        job_poll_interval: int | float = 6,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -231,7 +241,8 @@ def execute(self, context: Context):
                     run_id=self._job_run_id,
                     verbose=self.verbose,
                     aws_conn_id=self.aws_conn_id,
-                    job_poll_interval=self.job_poll_interval,
+                    waiter_delay=int(self.job_poll_interval),
+                    waiter_max_attempts=self.retry_limit,
                 ),
                 method_name="execute_complete",
             )
@@ -254,7 +265,7 @@ def execute_complete(self, context: Context, event: dict[str, Any] | None = None
         if validated_event["status"] != "success":
             raise AirflowException(f"Error in glue job: {validated_event}")
 
-        return validated_event["value"]
+        return validated_event["run_id"]
 
     def on_kill(self):
         """Cancel the running AWS Glue Job."""
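A worked sketch of the deferral arithmetic above (values illustrative): `job_poll_interval` becomes the waiter's delay, truncated to an integer, and `retry_limit` caps the number of polls, so the trigger waits at most `waiter_delay * waiter_max_attempts` seconds before timing out.

    job_poll_interval = 6.5  # operator default is 6
    retry_limit = 75         # illustrative; also used as the Glue job retry count

    waiter_delay = int(job_poll_interval)                  # 6 (fraction truncated)
    waiter_max_attempts = retry_limit                      # 75
    max_wait_seconds = waiter_delay * waiter_max_attempts  # 450, i.e. 7.5 minutes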
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/sensors/glue.py b/providers/amazon/src/airflow/providers/amazon/aws/sensors/glue.py
index fb360f9102171..6437d5513cd90 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/sensors/glue.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/sensors/glue.py
@@ -18,7 +18,6 @@
 from __future__ import annotations
 
 from collections.abc import Sequence
-from functools import cached_property
 from typing import TYPE_CHECKING, Any
 
 from airflow.configuration import conf
@@ -28,16 +27,16 @@
 from airflow.providers.amazon.aws.triggers.glue import (
     GlueDataQualityRuleRecommendationRunCompleteTrigger,
     GlueDataQualityRuleSetEvaluationRunCompleteTrigger,
+    GlueJobCompleteTrigger,
 )
 from airflow.providers.amazon.aws.utils import validate_execute_complete_event
 from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
-from airflow.sensors.base import BaseSensorOperator
 
 if TYPE_CHECKING:
     from airflow.utils.context import Context
 
 
-class GlueJobSensor(BaseSensorOperator):
+class GlueJobSensor(AwsBaseSensor[GlueJobHook]):
     """
     Waits for an AWS Glue Job to reach any of the status below.
 
@@ -50,9 +49,29 @@ class GlueJobSensor(BaseSensorOperator):
     :param job_name: The AWS Glue Job unique name
     :param run_id: The AWS Glue current running job identifier
     :param verbose: If True, more Glue Job Run logs show in the Airflow Task Logs. (default: False)
+    :param deferrable: If True, the sensor will operate in deferrable mode. This mode requires aiobotocore
+        module to be installed.
+        (default: False, but can be overridden in config file by setting default_deferrable to True)
+    :param poke_interval: Polling period in seconds to check for the status of the job. (default: 120)
+    :param max_retries: Number of times before returning the current state. (default: 60)
+
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
     """
 
-    template_fields: Sequence[str] = ("job_name", "run_id")
+    SUCCESS_STATES = ("SUCCEEDED",)
+    FAILURE_STATES = ("FAILED", "STOPPED", "TIMEOUT")
+
+    aws_hook_class = GlueJobHook
+    template_fields: Sequence[str] = aws_template_fields("job_name", "run_id")
 
     def __init__(
         self,
@@ -60,6 +79,9 @@ def __init__(
         job_name: str,
         run_id: str,
         verbose: bool = False,
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
+        poke_interval: int = 120,
+        max_retries: int = 60,
         aws_conn_id: str | None = "aws_default",
         **kwargs,
     ):
@@ -67,24 +89,46 @@ def __init__(
         self.job_name = job_name
         self.run_id = run_id
         self.verbose = verbose
+        self.deferrable = deferrable
+        self.poke_interval = poke_interval
+        self.max_retries = max_retries
         self.aws_conn_id = aws_conn_id
-        self.success_states: list[str] = ["SUCCEEDED"]
-        self.errored_states: list[str] = ["FAILED", "STOPPED", "TIMEOUT"]
         self.next_log_tokens = GlueJobHook.LogContinuationTokens()
 
-    @cached_property
-    def hook(self):
-        return GlueJobHook(aws_conn_id=self.aws_conn_id)
+    def execute(self, context: Context) -> Any:
+        if self.deferrable:
+            self.defer(
+                trigger=GlueJobCompleteTrigger(
+                    job_name=self.job_name,
+                    run_id=self.run_id,
+                    verbose=self.verbose,
+                    aws_conn_id=self.aws_conn_id,
+                    waiter_delay=int(self.poke_interval),
+                    waiter_max_attempts=self.max_retries,
+                ),
+                method_name="execute_complete",
+            )
+        else:
+            super().execute(context=context)
 
-    def poke(self, context: Context):
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None:
+        validated_event = validate_execute_complete_event(event)
+
+        if validated_event["status"] != "success":
+            message = f"Error: AWS Glue Job: {validated_event}"
+            raise AirflowException(message)
+
+        self.log.info("AWS Glue Job completed.")
+
+    def poke(self, context: Context) -> bool:
         self.log.info("Poking for job run status :for Glue Job %s and ID %s", self.job_name, self.run_id)
         job_state = self.hook.get_job_state(job_name=self.job_name, run_id=self.run_id)
 
         try:
-            if job_state in self.success_states:
+            if job_state in self.SUCCESS_STATES:
                 self.log.info("Exiting Job %s Run State: %s", self.run_id, job_state)
                 return True
-            if job_state in self.errored_states:
+            if job_state in self.FAILURE_STATES:
                 job_error_message = "Exiting Job %s Run State: %s", self.run_id, job_state
                 self.log.info(job_error_message)
                 raise AirflowException(job_error_message)
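A minimal sketch of the sensor in deferrable mode; the upstream task id and XCom pull are hypothetical. `poke_interval` and `max_retries` are forwarded to the trigger as `waiter_delay` and `waiter_max_attempts`.

    from airflow.providers.amazon.aws.sensors.glue import GlueJobSensor

    wait_for_job = GlueJobSensor(
        task_id="wait_for_glue_job",
        job_name="my-glue-job",  # hypothetical
        run_id="{{ ti.xcom_pull(task_ids='submit_glue_job') }}",  # hypothetical
        deferrable=True,
        poke_interval=60,  # seconds between GetJobRun polls in the trigger
        max_retries=120,   # at most 120 polls before the trigger gives up
    )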
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/triggers/glue.py b/providers/amazon/src/airflow/providers/amazon/aws/triggers/glue.py
index 4a56f47689a1b..6314bce52886f 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/triggers/glue.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/triggers/glue.py
@@ -31,49 +31,62 @@
 from airflow.triggers.base import BaseTrigger, TriggerEvent
 
 
-class GlueJobCompleteTrigger(BaseTrigger):
+class GlueJobCompleteTrigger(AwsBaseWaiterTrigger):
     """
     Watches for a glue job, triggers when it finishes.
 
     :param job_name: glue job name
     :param run_id: the ID of the specific run to watch for that job
     :param verbose: whether to print the job's logs in airflow logs or not
-    :param aws_conn_id: The Airflow connection used for AWS credentials.
+    :param waiter_delay: The amount of time in seconds to wait between attempts. (default: 60)
+    :param waiter_max_attempts: The maximum number of attempts to be made. (default: 75)
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+    :param region_name: Optional aws region name (example: us-east-1). Uses region from connection
+        if not specified.
+    :param verify: Whether or not to verify SSL certificates.
+    :param botocore_config: Configuration dictionary (key-values) for botocore client.
     """
 
     def __init__(
         self,
         job_name: str,
         run_id: str,
-        verbose: bool,
-        aws_conn_id: str | None,
-        job_poll_interval: int | float,
+        verbose: bool = False,
+        waiter_delay: int = 60,
+        waiter_max_attempts: int = 75,
+        aws_conn_id: str | None = "aws_default",
+        region_name: str | None = None,
+        verify: bool | str | None = None,
+        botocore_config: dict | None = None,
     ):
-        super().__init__()
+        super().__init__(
+            serialized_fields={"job_name": job_name, "run_id": run_id, "verbose": verbose},
+            waiter_name="job_complete",
+            waiter_args={"JobName": job_name, "RunId": run_id},
+            failure_message="AWS Glue job failed.",
+            status_message="Status of AWS Glue job is",
+            status_queries=["JobRun.JobRunState"],
+            return_key="run_id",
+            return_value=run_id,
+            waiter_delay=waiter_delay,
+            waiter_max_attempts=waiter_max_attempts,
+            aws_conn_id=aws_conn_id,
+            region_name=region_name,
+            verify=verify,
+            botocore_config=botocore_config,
+        )
         self.job_name = job_name
         self.run_id = run_id
         self.verbose = verbose
-        self.aws_conn_id = aws_conn_id
-        self.job_poll_interval = job_poll_interval
 
-    def serialize(self) -> tuple[str, dict[str, Any]]:
-        return (
-            # dynamically generate the fully qualified name of the class
-            self.__class__.__module__ + "." + self.__class__.__qualname__,
-            {
-                "job_name": self.job_name,
-                "run_id": self.run_id,
-                "verbose": self.verbose,
-                "aws_conn_id": self.aws_conn_id,
-                "job_poll_interval": self.job_poll_interval,
-            },
-        )
+    def hook(self) -> AwsGenericHook:
+        return GlueJobHook(
+            aws_conn_id=self.aws_conn_id,
+            region_name=self.region_name,
+            verify=self.verify,
+            config=self.botocore_config,
+        )
 
-    async def run(self) -> AsyncIterator[TriggerEvent]:
-        hook = GlueJobHook(aws_conn_id=self.aws_conn_id, job_poll_interval=self.job_poll_interval)
-        await hook.async_job_completion(self.job_name, self.run_id, self.verbose)
-        yield TriggerEvent({"status": "success", "message": "Job done", "value": self.run_id})
-
 
 class GlueCatalogPartitionTrigger(BaseTrigger):
     """
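For context: `AwsBaseWaiterTrigger` resolves `waiter_name="job_complete"` through the hook, which loads the provider-supplied model added below in `waiters/glue.json`. A roughly equivalent synchronous call, with a hypothetical job name and run id (this is what the unit tests at the bottom of this patch exercise):

    from airflow.providers.amazon.aws.hooks.glue import GlueJobHook

    waiter = GlueJobHook(aws_conn_id="aws_default").get_waiter("job_complete")
    waiter.wait(
        JobName="my-glue-job",        # hypothetical
        RunId="jr_0123456789abcdef",  # hypothetical
        WaiterConfig={"Delay": 60, "MaxAttempts": 75},
    )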
"expected": "TIMEOUT", + "state": "failure" + }, + { + "matcher": "path", + "argument": "JobRun.JobRunState", + "expected": "SUCCEEDED", + "state": "success" + } + ] + }, "crawler_ready": { "operation": "GetCrawler", "delay": 5, diff --git a/providers/amazon/tests/unit/amazon/aws/sensors/test_glue.py b/providers/amazon/tests/unit/amazon/aws/sensors/test_glue.py index 2d4925a016c18..ef57439523bc6 100644 --- a/providers/amazon/tests/unit/amazon/aws/sensors/test_glue.py +++ b/providers/amazon/tests/unit/amazon/aws/sensors/test_glue.py @@ -21,7 +21,7 @@ import pytest -from airflow.exceptions import AirflowException +from airflow.exceptions import AirflowException, TaskDeferred from airflow.providers.amazon.aws.hooks.glue import GlueJobHook from airflow.providers.amazon.aws.sensors.glue import GlueJobSensor @@ -152,3 +152,66 @@ def test_fail_poke(self, get_job_state): job_error_message = "Exiting Job" with pytest.raises(AirflowException, match=job_error_message): op.poke(context={}) + + def test_deferrable_execute_raises_task_deferred(self): + job_name = "job_name" + job_run_id = "job_run_id" + sensor = GlueJobSensor( + task_id="test_glue_job_sensor", + job_name=job_name, + run_id=job_run_id, + deferrable=True, + poke_interval=1, + timeout=5, + ) + with pytest.raises(TaskDeferred): + sensor.execute({}) + + @mock.patch.object(GlueJobSensor, "defer") + def test_default_timeout(self, mock_defer): + mock_defer.side_effect = TaskDeferred(trigger=mock.Mock(), method_name="execute_complete") + sensor = GlueJobSensor( + task_id="test_glue_job_sensor", + job_name="job_name", + run_id="job_run_id", + deferrable=True, + poke_interval=5, + max_retries=30, + ) + with pytest.raises(TaskDeferred): + sensor.execute({}) + call_kwargs = mock_defer.call_args.kwargs["trigger"] + assert call_kwargs.attempts == 30 + mock_defer.assert_called_once() + + def test_default_args(self): + job_name = "job_name" + job_run_id = "job_run_id" + sensor = GlueJobSensor( + task_id="test_glue_job_sensor", + job_name=job_name, + run_id=job_run_id, + ) + assert sensor.poke_interval == 120 + assert sensor.verbose is False + assert sensor.deferrable is False or isinstance(sensor.deferrable, bool) + assert sensor.aws_conn_id == "aws_default" + + def test_custom_args(self): + job_name = "job_name" + job_run_id = "job_run_id" + sensor = GlueJobSensor( + task_id="test_glue_job_sensor", + job_name=job_name, + run_id=job_run_id, + verbose=True, + deferrable=True, + poke_interval=10, + aws_conn_id="custom_conn", + max_retries=20, + ) + assert sensor.verbose is True + assert sensor.deferrable is True + assert sensor.poke_interval == 10 + assert sensor.aws_conn_id == "custom_conn" + assert sensor.max_retries == 20 diff --git a/providers/amazon/tests/unit/amazon/aws/triggers/test_glue.py b/providers/amazon/tests/unit/amazon/aws/triggers/test_glue.py index 354d70cfab6ed..d9b8ccaa452c9 100644 --- a/providers/amazon/tests/unit/amazon/aws/triggers/test_glue.py +++ b/providers/amazon/tests/unit/amazon/aws/triggers/test_glue.py @@ -40,47 +40,55 @@ class TestGlueJobTrigger: @pytest.mark.asyncio - @mock.patch.object(GlueJobHook, "async_get_job_state") - async def test_wait_job(self, get_state_mock: mock.MagicMock): + @mock.patch.object(GlueJobHook, "get_waiter") + @mock.patch.object(GlueJobHook, "get_async_conn") + async def test_wait_job(self, mock_async_conn, mock_get_waiter): + mock_async_conn.__aenter__.return_value = mock.MagicMock() + mock_get_waiter().wait = AsyncMock() trigger = GlueJobCompleteTrigger( job_name="job_name", run_id="JobRunId", 
diff --git a/providers/amazon/tests/unit/amazon/aws/sensors/test_glue.py b/providers/amazon/tests/unit/amazon/aws/sensors/test_glue.py
index 2d4925a016c18..ef57439523bc6 100644
--- a/providers/amazon/tests/unit/amazon/aws/sensors/test_glue.py
+++ b/providers/amazon/tests/unit/amazon/aws/sensors/test_glue.py
@@ -21,7 +21,7 @@
 
 import pytest
 
-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, TaskDeferred
 from airflow.providers.amazon.aws.hooks.glue import GlueJobHook
 from airflow.providers.amazon.aws.sensors.glue import GlueJobSensor
 
@@ -152,3 +152,66 @@ def test_fail_poke(self, get_job_state):
         job_error_message = "Exiting Job"
         with pytest.raises(AirflowException, match=job_error_message):
             op.poke(context={})
+
+    def test_deferrable_execute_raises_task_deferred(self):
+        job_name = "job_name"
+        job_run_id = "job_run_id"
+        sensor = GlueJobSensor(
+            task_id="test_glue_job_sensor",
+            job_name=job_name,
+            run_id=job_run_id,
+            deferrable=True,
+            poke_interval=1,
+            timeout=5,
+        )
+        with pytest.raises(TaskDeferred):
+            sensor.execute({})
+
+    @mock.patch.object(GlueJobSensor, "defer")
+    def test_max_retries_passed_to_trigger(self, mock_defer):
+        mock_defer.side_effect = TaskDeferred(trigger=mock.Mock(), method_name="execute_complete")
+        sensor = GlueJobSensor(
+            task_id="test_glue_job_sensor",
+            job_name="job_name",
+            run_id="job_run_id",
+            deferrable=True,
+            poke_interval=5,
+            max_retries=30,
+        )
+        with pytest.raises(TaskDeferred):
+            sensor.execute({})
+        mock_defer.assert_called_once()
+        trigger = mock_defer.call_args.kwargs["trigger"]
+        assert trigger.attempts == 30
+
+    def test_default_args(self):
+        job_name = "job_name"
+        job_run_id = "job_run_id"
+        sensor = GlueJobSensor(
+            task_id="test_glue_job_sensor",
+            job_name=job_name,
+            run_id=job_run_id,
+        )
+        assert sensor.poke_interval == 120
+        assert sensor.verbose is False
+        assert sensor.deferrable is False
+        assert sensor.aws_conn_id == "aws_default"
+
+    def test_custom_args(self):
+        job_name = "job_name"
+        job_run_id = "job_run_id"
+        sensor = GlueJobSensor(
+            task_id="test_glue_job_sensor",
+            job_name=job_name,
+            run_id=job_run_id,
+            verbose=True,
+            deferrable=True,
+            poke_interval=10,
+            aws_conn_id="custom_conn",
+            max_retries=20,
+        )
+        assert sensor.verbose is True
+        assert sensor.deferrable is True
+        assert sensor.poke_interval == 10
+        assert sensor.aws_conn_id == "custom_conn"
+        assert sensor.max_retries == 20
diff --git a/providers/amazon/tests/unit/amazon/aws/triggers/test_glue.py b/providers/amazon/tests/unit/amazon/aws/triggers/test_glue.py
index 354d70cfab6ed..d9b8ccaa452c9 100644
--- a/providers/amazon/tests/unit/amazon/aws/triggers/test_glue.py
+++ b/providers/amazon/tests/unit/amazon/aws/triggers/test_glue.py
@@ -40,47 +40,55 @@ class TestGlueJobTrigger:
     @pytest.mark.asyncio
-    @mock.patch.object(GlueJobHook, "async_get_job_state")
-    async def test_wait_job(self, get_state_mock: mock.MagicMock):
+    @mock.patch.object(GlueJobHook, "get_waiter")
+    @mock.patch.object(GlueJobHook, "get_async_conn")
+    async def test_wait_job(self, mock_async_conn, mock_get_waiter):
+        mock_async_conn.__aenter__.return_value = mock.MagicMock()
+        mock_get_waiter().wait = AsyncMock()
         trigger = GlueJobCompleteTrigger(
             job_name="job_name",
             run_id="JobRunId",
             verbose=False,
             aws_conn_id="aws_conn_id",
-            job_poll_interval=0.1,
+            waiter_max_attempts=3,
+            waiter_delay=10,
         )
-        get_state_mock.side_effect = [
-            "RUNNING",
-            "RUNNING",
-            "SUCCEEDED",
-        ]
-
         generator = trigger.run()
         event = await generator.asend(None)  # type:ignore[attr-defined]
 
-        assert get_state_mock.call_count == 3
+        assert_expected_waiter_type(mock_get_waiter, "job_complete")
+        mock_get_waiter().wait.assert_called_once()
         assert event.payload["status"] == "success"
+        assert event.payload["run_id"] == "JobRunId"
 
     @pytest.mark.asyncio
-    @mock.patch.object(GlueJobHook, "async_get_job_state")
-    async def test_wait_job_failed(self, get_state_mock: mock.MagicMock):
+    @mock.patch.object(GlueJobHook, "get_waiter")
+    @mock.patch.object(GlueJobHook, "get_async_conn")
+    async def test_wait_job_failed(self, mock_async_conn, mock_get_waiter):
+        from botocore.exceptions import WaiterError
+
+        mock_async_conn.__aenter__.return_value = mock.MagicMock()
+        mock_get_waiter().wait = AsyncMock(
+            side_effect=WaiterError(
+                name="job_complete",
+                reason="Waiter encountered a terminal failure state",
+                last_response={"JobRun": {"JobRunState": "FAILED"}},
+            )
+        )
+
         trigger = GlueJobCompleteTrigger(
             job_name="job_name",
             run_id="JobRunId",
             verbose=False,
             aws_conn_id="aws_conn_id",
-            job_poll_interval=0.1,
+            waiter_max_attempts=3,
+            waiter_delay=10,
         )
-        get_state_mock.side_effect = [
-            "RUNNING",
-            "RUNNING",
-            "FAILED",
-        ]
+        generator = trigger.run()
 
         with pytest.raises(AirflowException):
-            await trigger.run().asend(None)  # type:ignore[attr-defined]
-
-        assert get_state_mock.call_count == 3
+            await generator.asend(None)  # type:ignore[attr-defined]
+        assert_expected_waiter_type(mock_get_waiter, "job_complete")
 
     def test_serialization(self):
         trigger = GlueJobCompleteTrigger(
@@ -88,7 +96,8 @@ def test_serialization(self):
             run_id="JobRunId",
             verbose=False,
             aws_conn_id="aws_conn_id",
-            job_poll_interval=0.1,
+            waiter_max_attempts=3,
+            waiter_delay=10,
         )
         classpath, kwargs = trigger.serialize()
         assert classpath == "airflow.providers.amazon.aws.triggers.glue.GlueJobCompleteTrigger"
@@ -97,7 +106,8 @@ def test_serialization(self):
             "run_id": "JobRunId",
             "verbose": False,
             "aws_conn_id": "aws_conn_id",
-            "job_poll_interval": 0.1,
+            "waiter_max_attempts": 3,
+            "waiter_delay": 10,
         }
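The async-mocking pattern used in the trigger tests above, shown in isolation: `wait` must be an `AsyncMock` because the trigger awaits it, and calling `mock_get_waiter()` once up front pins the same child mock that the trigger will receive later. A condensed sketch:

    from unittest import mock
    from unittest.mock import AsyncMock

    from airflow.providers.amazon.aws.hooks.glue import GlueJobHook

    with mock.patch.object(GlueJobHook, "get_waiter") as mock_get_waiter:
        # get_waiter() always returns the same MagicMock child, so the
        # wait() patched here is the one trigger.run() will eventually await.
        mock_get_waiter().wait = AsyncMock()
        # ...construct GlueJobCompleteTrigger and drive trigger.run() here...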
diff --git a/providers/amazon/tests/unit/amazon/aws/waiters/test_glue.py b/providers/amazon/tests/unit/amazon/aws/waiters/test_glue.py
index 431082ee21f39..e4f8fb39e7fb2 100644
--- a/providers/amazon/tests/unit/amazon/aws/waiters/test_glue.py
+++ b/providers/amazon/tests/unit/amazon/aws/waiters/test_glue.py
@@ -22,7 +22,7 @@
 import botocore
 import pytest
 
-from airflow.providers.amazon.aws.hooks.glue import GlueDataQualityHook
+from airflow.providers.amazon.aws.hooks.glue import GlueDataQualityHook, GlueJobHook
 from airflow.providers.amazon.aws.sensors.glue import (
     GlueDataQualityRuleRecommendationRunSensor,
     GlueDataQualityRuleSetEvaluationRunSensor,
@@ -104,3 +104,46 @@ def test_data_quality_rule_recommendation_run_wait(self, mock_get_job):
         GlueDataQualityHook().get_waiter(self.WAITER_NAME).wait(
             RunId="run_id", WaiterConfig={"Delay": 0.01, "MaxAttempts": 3}
         )
+
+
+class TestGlueJobCompleteCustomWaiterBase:
+    @pytest.fixture(autouse=True)
+    def mock_conn(self, monkeypatch):
+        self.client = boto3.client("glue")
+        monkeypatch.setattr(GlueJobHook, "conn", self.client)
+
+
+class TestGlueJobCompleteWaiter(TestGlueJobCompleteCustomWaiterBase):
+    WAITER_NAME = "job_complete"
+
+    @pytest.fixture
+    def mock_get_job_run(self):
+        with mock.patch.object(self.client, "get_job_run") as mock_getter:
+            yield mock_getter
+
+    @pytest.mark.parametrize("state", ["SUCCEEDED"])
+    def test_glue_job_run_success(self, state, mock_get_job_run):
+        mock_get_job_run.return_value = {"JobRun": {"JobRunState": state}}
+
+        GlueJobHook().get_waiter(self.WAITER_NAME).wait(JobName="example", RunId="run_id")
+
+    @pytest.mark.parametrize("state", ["STOPPED", "FAILED", "ERROR", "TIMEOUT"])
+    def test_glue_job_run_failure(self, state, mock_get_job_run):
+        mock_get_job_run.return_value = {"JobRun": {"JobRunState": state}}
+
+        with pytest.raises(botocore.exceptions.WaiterError):
+            GlueJobHook().get_waiter(self.WAITER_NAME).wait(JobName="example", RunId="run_id")
+
+    @pytest.mark.parametrize("intermediate", ["STARTING", "RUNNING", "STOPPING"])
+    def test_glue_job_run_retry_then_success(self, intermediate, mock_get_job_run):
+        mock_get_job_run.side_effect = [
+            {"JobRun": {"JobRunState": intermediate}},
+            {"JobRun": {"JobRunState": intermediate}},
+            {"JobRun": {"JobRunState": "SUCCEEDED"}},
+        ]
+
+        GlueJobHook().get_waiter(self.WAITER_NAME).wait(
+            JobName="example",
+            RunId="run_id",
+            WaiterConfig={"Delay": 0.01, "MaxAttempts": 3},
+        )