From b6aca4968595d4d7b66b3c7dcb25d181a23c8542 Mon Sep 17 00:00:00 2001 From: Dominik Heilbock Date: Tue, 17 Jun 2025 19:24:41 +0200 Subject: [PATCH 1/6] Adjusted the GlueJobSensor to inherit from AwsBaseSensor --- .../providers/amazon/aws/sensors/glue.py | 70 +++++++++++++++---- .../providers/amazon/aws/triggers/glue.py | 3 +- .../unit/amazon/aws/sensors/test_glue.py | 46 +++++++++++- 3 files changed, 105 insertions(+), 14 deletions(-) diff --git a/providers/amazon/src/airflow/providers/amazon/aws/sensors/glue.py b/providers/amazon/src/airflow/providers/amazon/aws/sensors/glue.py index fb360f9102171..0fa725c228122 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/sensors/glue.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/sensors/glue.py @@ -18,7 +18,7 @@ from __future__ import annotations from collections.abc import Sequence -from functools import cached_property +from datetime import timedelta from typing import TYPE_CHECKING, Any from airflow.configuration import conf @@ -28,16 +28,16 @@ from airflow.providers.amazon.aws.triggers.glue import ( GlueDataQualityRuleRecommendationRunCompleteTrigger, GlueDataQualityRuleSetEvaluationRunCompleteTrigger, + GlueJobCompleteTrigger, ) from airflow.providers.amazon.aws.utils import validate_execute_complete_event from airflow.providers.amazon.aws.utils.mixins import aws_template_fields -from airflow.sensors.base import BaseSensorOperator if TYPE_CHECKING: from airflow.utils.context import Context -class GlueJobSensor(BaseSensorOperator): +class GlueJobSensor(AwsBaseSensor[GlueJobHook]): """ Waits for an AWS Glue Job to reach any of the status below. @@ -50,9 +50,28 @@ class GlueJobSensor(BaseSensorOperator): :param job_name: The AWS Glue Job unique name :param run_id: The AWS Glue current running job identifier :param verbose: If True, more Glue Job Run logs show in the Airflow Task Logs. (default: False) + :param deferrable: If True, the sensor will operate in deferrable mode. This mode requires aiobotocore + module to be installed. + (default: False, but can be overridden in config file by setting default_deferrable to True) + :param poke_interval: Polling period in seconds to check for the status of the job. (default: 120) + :param max_retries: Number of times before returning the current state. (default: 60) + :param aws_conn_id: The Airflow connection used for AWS credentials. + If this is ``None`` or empty then the default boto3 behaviour is used. If + running Airflow in a distributed manner and aws_conn_id is None or + empty, then default boto3 configuration would be used (and must be + maintained on each worker node). + :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used. + :param verify: Whether to verify SSL certificates. See: + https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html + :param botocore_config: Configuration dictionary (key-values) for botocore client. See: + https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html """ - template_fields: Sequence[str] = ("job_name", "run_id") + SUCCESS_STATES = ("SUCCEEDED",) + FAILURE_STATES = ("FAILED", "STOPPED", "TIMEOUT") + + aws_hook_class = GlueJobHook + template_fields: Sequence[str] = aws_template_fields("job_name", "run_id") def __init__( self, @@ -60,6 +79,8 @@ def __init__( job_name: str, run_id: str, verbose: bool = False, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), + poke_interval: int = 120, aws_conn_id: str | None = "aws_default", **kwargs, ): @@ -67,24 +88,49 @@ def __init__( self.job_name = job_name self.run_id = run_id self.verbose = verbose + self.deferrable = deferrable + self.poke_interval = poke_interval self.aws_conn_id = aws_conn_id - self.success_states: list[str] = ["SUCCEEDED"] - self.errored_states: list[str] = ["FAILED", "STOPPED", "TIMEOUT"] self.next_log_tokens = GlueJobHook.LogContinuationTokens() - @cached_property - def hook(self): - return GlueJobHook(aws_conn_id=self.aws_conn_id) + def execute(self, context: Context) -> Any: + if self.deferrable: + self.defer( + trigger=GlueJobCompleteTrigger( + job_name=self.job_name, + run_id=self.run_id, + verbose=self.verbose, + aws_conn_id=self.aws_conn_id, + job_poll_interval=self.poke_interval, + ), + method_name="execute_complete", + timeout=timedelta(seconds=self.timeout), + ) + else: + super().execute(context=context) - def poke(self, context: Context): + def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None: + validated_event = validate_execute_complete_event(event) + + if validated_event["status"] != "success": + message = f"Error: AWS Glue Job: {validated_event}" + raise AirflowException(message) + + self.log.info( + "AWS Glue Job completed. Job Name: %s, Run ID: %s", + self.job_name, + self.run_id, + ) + + def poke(self, context: Context) -> bool: self.log.info("Poking for job run status :for Glue Job %s and ID %s", self.job_name, self.run_id) job_state = self.hook.get_job_state(job_name=self.job_name, run_id=self.run_id) try: - if job_state in self.success_states: + if job_state in self.SUCCESS_STATES: self.log.info("Exiting Job %s Run State: %s", self.run_id, job_state) return True - if job_state in self.errored_states: + if job_state in self.FAILURE_STATES: job_error_message = "Exiting Job %s Run State: %s", self.run_id, job_state self.log.info(job_error_message) raise AirflowException(job_error_message) diff --git a/providers/amazon/src/airflow/providers/amazon/aws/triggers/glue.py b/providers/amazon/src/airflow/providers/amazon/aws/triggers/glue.py index 4a56f47689a1b..03eed0caab0c2 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/triggers/glue.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/triggers/glue.py @@ -38,7 +38,8 @@ class GlueJobCompleteTrigger(BaseTrigger): :param job_name: glue job name :param run_id: the ID of the specific run to watch for that job :param verbose: whether to print the job's logs in airflow logs or not - :param aws_conn_id: The Airflow connection used for AWS credentials. + :param aws_conn_id: The Airflow connection used for AWS credentials + :param job_poll_interval: The interval in which to poll the status of a job """ def __init__( diff --git a/providers/amazon/tests/unit/amazon/aws/sensors/test_glue.py b/providers/amazon/tests/unit/amazon/aws/sensors/test_glue.py index 2d4925a016c18..f2d68d824cef9 100644 --- a/providers/amazon/tests/unit/amazon/aws/sensors/test_glue.py +++ b/providers/amazon/tests/unit/amazon/aws/sensors/test_glue.py @@ -21,7 +21,7 @@ import pytest -from airflow.exceptions import AirflowException +from airflow.exceptions import AirflowException, TaskDeferred from airflow.providers.amazon.aws.hooks.glue import GlueJobHook from airflow.providers.amazon.aws.sensors.glue import GlueJobSensor @@ -152,3 +152,47 @@ def test_fail_poke(self, get_job_state): job_error_message = "Exiting Job" with pytest.raises(AirflowException, match=job_error_message): op.poke(context={}) + + def test_deferrable_execute_raises_task_deferred(self): + job_name = "job_name" + job_run_id = "job_run_id" + sensor = GlueJobSensor( + task_id="test_glue_job_sensor", + job_name=job_name, + run_id=job_run_id, + deferrable=True, + poke_interval=1, + timeout=5, + ) + with pytest.raises(TaskDeferred): + sensor.execute({}) + + def test_default_args(self): + job_name = "job_name" + job_run_id = "job_run_id" + sensor = GlueJobSensor( + task_id="test_glue_job_sensor", + job_name=job_name, + run_id=job_run_id, + ) + assert sensor.poke_interval == 120 + assert sensor.verbose is False + assert sensor.deferrable is False or isinstance(sensor.deferrable, bool) + assert sensor.aws_conn_id == "aws_default" + + def test_custom_args(self): + job_name = "job_name" + job_run_id = "job_run_id" + sensor = GlueJobSensor( + task_id="test_glue_job_sensor", + job_name=job_name, + run_id=job_run_id, + verbose=True, + deferrable=True, + poke_interval=10, + aws_conn_id="custom_conn", + ) + assert sensor.verbose is True + assert sensor.deferrable is True + assert sensor.poke_interval == 10 + assert sensor.aws_conn_id == "custom_conn" From 5dee2312211dc7ce1c54b936e5a7be89ac8eff07 Mon Sep 17 00:00:00 2001 From: Dominik Heilbock Date: Wed, 18 Jun 2025 10:12:31 +0200 Subject: [PATCH 2/6] Changed timeout logic and added further tests --- .../providers/amazon/aws/sensors/glue.py | 2 +- .../tests/unit/amazon/aws/sensors/test_glue.py | 18 ++++++++++++++++++ 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/providers/amazon/src/airflow/providers/amazon/aws/sensors/glue.py b/providers/amazon/src/airflow/providers/amazon/aws/sensors/glue.py index 0fa725c228122..c6c01a8c76977 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/sensors/glue.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/sensors/glue.py @@ -54,7 +54,7 @@ class GlueJobSensor(AwsBaseSensor[GlueJobHook]): module to be installed. (default: False, but can be overridden in config file by setting default_deferrable to True) :param poke_interval: Polling period in seconds to check for the status of the job. (default: 120) - :param max_retries: Number of times before returning the current state. (default: 60) + :param aws_conn_id: The Airflow connection used for AWS credentials. If this is ``None`` or empty then the default boto3 behaviour is used. If running Airflow in a distributed manner and aws_conn_id is None or diff --git a/providers/amazon/tests/unit/amazon/aws/sensors/test_glue.py b/providers/amazon/tests/unit/amazon/aws/sensors/test_glue.py index f2d68d824cef9..4766f7294c8a6 100644 --- a/providers/amazon/tests/unit/amazon/aws/sensors/test_glue.py +++ b/providers/amazon/tests/unit/amazon/aws/sensors/test_glue.py @@ -16,6 +16,7 @@ # under the License. from __future__ import annotations +from datetime import timedelta from unittest import mock from unittest.mock import ANY @@ -167,6 +168,21 @@ def test_deferrable_execute_raises_task_deferred(self): with pytest.raises(TaskDeferred): sensor.execute({}) + @mock.patch.object(GlueJobSensor, "defer") + def test_max_retries_and_timeout(self, mock_defer): + mock_defer.side_effect = TaskDeferred(trigger=mock.Mock(), method_name="execute_complete") + sensor = GlueJobSensor( + task_id="test_glue_job_sensor", + job_name="job_name", + run_id="job_run_id", + deferrable=True, + poke_interval=5, + ) + with pytest.raises(TaskDeferred): + sensor.execute({}) + assert mock_defer.call_args[1]["timeout"] == timedelta(days=7) + mock_defer.assert_called_once() + def test_default_args(self): job_name = "job_name" job_run_id = "job_run_id" @@ -191,8 +207,10 @@ def test_custom_args(self): deferrable=True, poke_interval=10, aws_conn_id="custom_conn", + timeout=120, ) assert sensor.verbose is True assert sensor.deferrable is True assert sensor.poke_interval == 10 assert sensor.aws_conn_id == "custom_conn" + assert sensor.timeout == 120 From 7d69f8f695bc26db72b3eeec4b2efb49f608a86e Mon Sep 17 00:00:00 2001 From: Dominik Heilbock Date: Wed, 18 Jun 2025 10:23:54 +0200 Subject: [PATCH 3/6] Renamed test case due to removal of max_retries param --- providers/amazon/tests/unit/amazon/aws/sensors/test_glue.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/providers/amazon/tests/unit/amazon/aws/sensors/test_glue.py b/providers/amazon/tests/unit/amazon/aws/sensors/test_glue.py index 4766f7294c8a6..c3e24617693fe 100644 --- a/providers/amazon/tests/unit/amazon/aws/sensors/test_glue.py +++ b/providers/amazon/tests/unit/amazon/aws/sensors/test_glue.py @@ -169,7 +169,7 @@ def test_deferrable_execute_raises_task_deferred(self): sensor.execute({}) @mock.patch.object(GlueJobSensor, "defer") - def test_max_retries_and_timeout(self, mock_defer): + def test_default_timeout(self, mock_defer): mock_defer.side_effect = TaskDeferred(trigger=mock.Mock(), method_name="execute_complete") sensor = GlueJobSensor( task_id="test_glue_job_sensor", From 81f2501b699cf196feeeb36530ee51d4e68cdcf8 Mon Sep 17 00:00:00 2001 From: Dominik Heilbock Date: Wed, 18 Jun 2025 13:01:16 +0200 Subject: [PATCH 4/6] Added custom GlueJob waiter --- .../providers/amazon/aws/operators/glue.py | 9 ++- .../providers/amazon/aws/sensors/glue.py | 8 ++- .../providers/amazon/aws/triggers/glue.py | 62 ++++++++++++------- .../providers/amazon/aws/waiters/glue.json | 55 ++++++++++++++++ .../unit/amazon/aws/sensors/test_glue.py | 9 +-- .../unit/amazon/aws/waiters/test_glue.py | 45 +++++++++++++- 6 files changed, 154 insertions(+), 34 deletions(-) diff --git a/providers/amazon/src/airflow/providers/amazon/aws/operators/glue.py b/providers/amazon/src/airflow/providers/amazon/aws/operators/glue.py index 4b80d47b046c8..61b1cc98e338d 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/operators/glue.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/operators/glue.py @@ -122,9 +122,11 @@ def __init__( verbose: bool = False, replace_script_file: bool = False, update_config: bool = False, - job_poll_interval: int | float = 6, + waiter_delay: int = 60, + waiter_max_attempts: int = 20, stop_job_run_on_kill: bool = False, sleep_before_return: int = 0, + job_poll_interval: int | float = 6, **kwargs, ): super().__init__(**kwargs) @@ -147,6 +149,8 @@ def __init__( self.update_config = update_config self.replace_script_file = replace_script_file self.deferrable = deferrable + self.waiter_delay = waiter_delay + self.waiter_max_attempts = waiter_max_attempts self.job_poll_interval = job_poll_interval self.stop_job_run_on_kill = stop_job_run_on_kill self._job_run_id: str | None = None @@ -231,7 +235,8 @@ def execute(self, context: Context): run_id=self._job_run_id, verbose=self.verbose, aws_conn_id=self.aws_conn_id, - job_poll_interval=self.job_poll_interval, + waiter_delay=int(self.waiter_delay), + waiter_max_attempts=self.waiter_max_attempts, ), method_name="execute_complete", ) diff --git a/providers/amazon/src/airflow/providers/amazon/aws/sensors/glue.py b/providers/amazon/src/airflow/providers/amazon/aws/sensors/glue.py index c6c01a8c76977..7d7b5e5e2afc6 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/sensors/glue.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/sensors/glue.py @@ -18,7 +18,6 @@ from __future__ import annotations from collections.abc import Sequence -from datetime import timedelta from typing import TYPE_CHECKING, Any from airflow.configuration import conf @@ -54,6 +53,7 @@ class GlueJobSensor(AwsBaseSensor[GlueJobHook]): module to be installed. (default: False, but can be overridden in config file by setting default_deferrable to True) :param poke_interval: Polling period in seconds to check for the status of the job. (default: 120) + :param max_retries: Number of times before returning the current state. (default: 60) :param aws_conn_id: The Airflow connection used for AWS credentials. If this is ``None`` or empty then the default boto3 behaviour is used. If @@ -81,6 +81,7 @@ def __init__( verbose: bool = False, deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), poke_interval: int = 120, + max_retries: int = 60, aws_conn_id: str | None = "aws_default", **kwargs, ): @@ -90,6 +91,7 @@ def __init__( self.verbose = verbose self.deferrable = deferrable self.poke_interval = poke_interval + self.max_retries = max_retries self.aws_conn_id = aws_conn_id self.next_log_tokens = GlueJobHook.LogContinuationTokens() @@ -101,10 +103,10 @@ def execute(self, context: Context) -> Any: run_id=self.run_id, verbose=self.verbose, aws_conn_id=self.aws_conn_id, - job_poll_interval=self.poke_interval, + waiter_delay=int(self.poke_interval), + waiter_max_attempts=self.max_retries, ), method_name="execute_complete", - timeout=timedelta(seconds=self.timeout), ) else: super().execute(context=context) diff --git a/providers/amazon/src/airflow/providers/amazon/aws/triggers/glue.py b/providers/amazon/src/airflow/providers/amazon/aws/triggers/glue.py index 03eed0caab0c2..3f3b196668f97 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/triggers/glue.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/triggers/glue.py @@ -31,50 +31,64 @@ from airflow.triggers.base import BaseTrigger, TriggerEvent -class GlueJobCompleteTrigger(BaseTrigger): +class GlueJobCompleteTrigger(AwsBaseWaiterTrigger): """ Watches for a glue job, triggers when it finishes. :param job_name: glue job name :param run_id: the ID of the specific run to watch for that job :param verbose: whether to print the job's logs in airflow logs or not + :param waiter_delay: The amount of time in seconds to wait between attempts. (default: 60) + :param waiter_max_attempts: The maximum number of attempts to be made. (default: 75) :param aws_conn_id: The Airflow connection used for AWS credentials - :param job_poll_interval: The interval in which to poll the status of a job + :param region_name: Optional aws region name (example: us-east-1). Uses region from connection + if not specified. + :param verify: Whether or not to verify SSL certificates. + :param botocore_config: Configuration dictionary (key-values) for botocore client. """ def __init__( self, job_name: str, run_id: str, - verbose: bool, - aws_conn_id: str | None, - job_poll_interval: int | float, + verbose: bool = False, + waiter_delay: int = 60, + waiter_max_attempts: int = 75, + aws_conn_id: str | None = "aws_default", + region_name: str | None = None, + verify: bool | str | None = None, + botocore_config: dict | None = None, ): - super().__init__() + super().__init__( + serialized_fields={"job_name": job_name, "run_id": run_id, "verbose": verbose}, + waiter_name="job_complete", + waiter_args={"JobName": job_name, "RunId": run_id}, + failure_message="AWS Glue job failed.", + status_message="Status of AWS Glue job is", + status_queries=["JobRun.JobRunState"], + return_key="run_id", + return_value=run_id, + waiter_delay=waiter_delay, + waiter_max_attempts=waiter_max_attempts, + aws_conn_id=aws_conn_id, + region_name=region_name, + verify=verify, + botocore_config=botocore_config, + ) self.job_name = job_name self.run_id = run_id self.verbose = verbose - self.aws_conn_id = aws_conn_id - self.job_poll_interval = job_poll_interval + self.waiter_delay = waiter_delay + self.waiter_max_attempts = waiter_max_attempts - def serialize(self) -> tuple[str, dict[str, Any]]: - return ( - # dynamically generate the fully qualified name of the class - self.__class__.__module__ + "." + self.__class__.__qualname__, - { - "job_name": self.job_name, - "run_id": self.run_id, - "verbose": self.verbose, - "aws_conn_id": self.aws_conn_id, - "job_poll_interval": self.job_poll_interval, - }, + def hook(self) -> AwsGenericHook: + return GlueJobHook( + aws_conn_id=self.aws_conn_id, + region_name=self.region_name, + verify=self.verify, + config=self.botocore_config, ) - async def run(self) -> AsyncIterator[TriggerEvent]: - hook = GlueJobHook(aws_conn_id=self.aws_conn_id, job_poll_interval=self.job_poll_interval) - await hook.async_job_completion(self.job_name, self.run_id, self.verbose) - yield TriggerEvent({"status": "success", "message": "Job done", "value": self.run_id}) - class GlueCatalogPartitionTrigger(BaseTrigger): """ diff --git a/providers/amazon/src/airflow/providers/amazon/aws/waiters/glue.json b/providers/amazon/src/airflow/providers/amazon/aws/waiters/glue.json index 2fb355809dab4..68a4842ce0bae 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/waiters/glue.json +++ b/providers/amazon/src/airflow/providers/amazon/aws/waiters/glue.json @@ -1,6 +1,61 @@ { "version": 2, "waiters": { + "job_complete": { + "operation": "GetJobRun", + "delay": 60, + "maxAttempts": 75, + "acceptors": [ + { + "matcher": "path", + "argument": "JobRun.JobRunState", + "expected": "STARTING", + "state": "retry" + }, + { + "matcher": "path", + "argument": "JobRun.JobRunState", + "expected": "RUNNING", + "state": "retry" + }, + { + "matcher": "path", + "argument": "JobRun.JobRunState", + "expected": "STOPPING", + "state": "retry" + }, + { + "matcher": "path", + "argument": "JobRun.JobRunState", + "expected": "STOPPED", + "state": "failure" + }, + { + "matcher": "path", + "argument": "JobRun.JobRunState", + "expected": "FAILED", + "state": "failure" + }, + { + "matcher": "path", + "argument": "JobRun.JobRunState", + "expected": "ERROR", + "state": "failure" + }, + { + "matcher": "path", + "argument": "JobRun.JobRunState", + "expected": "TIMEOUT", + "state": "failure" + }, + { + "matcher": "path", + "argument": "JobRun.JobRunState", + "expected": "SUCCEEDED", + "state": "success" + } + ] + }, "crawler_ready": { "operation": "GetCrawler", "delay": 5, diff --git a/providers/amazon/tests/unit/amazon/aws/sensors/test_glue.py b/providers/amazon/tests/unit/amazon/aws/sensors/test_glue.py index c3e24617693fe..ef57439523bc6 100644 --- a/providers/amazon/tests/unit/amazon/aws/sensors/test_glue.py +++ b/providers/amazon/tests/unit/amazon/aws/sensors/test_glue.py @@ -16,7 +16,6 @@ # under the License. from __future__ import annotations -from datetime import timedelta from unittest import mock from unittest.mock import ANY @@ -177,10 +176,12 @@ def test_default_timeout(self, mock_defer): run_id="job_run_id", deferrable=True, poke_interval=5, + max_retries=30, ) with pytest.raises(TaskDeferred): sensor.execute({}) - assert mock_defer.call_args[1]["timeout"] == timedelta(days=7) + call_kwargs = mock_defer.call_args.kwargs["trigger"] + assert call_kwargs.attempts == 30 mock_defer.assert_called_once() def test_default_args(self): @@ -207,10 +208,10 @@ def test_custom_args(self): deferrable=True, poke_interval=10, aws_conn_id="custom_conn", - timeout=120, + max_retries=20, ) assert sensor.verbose is True assert sensor.deferrable is True assert sensor.poke_interval == 10 assert sensor.aws_conn_id == "custom_conn" - assert sensor.timeout == 120 + assert sensor.max_retries == 20 diff --git a/providers/amazon/tests/unit/amazon/aws/waiters/test_glue.py b/providers/amazon/tests/unit/amazon/aws/waiters/test_glue.py index 431082ee21f39..e4f8fb39e7fb2 100644 --- a/providers/amazon/tests/unit/amazon/aws/waiters/test_glue.py +++ b/providers/amazon/tests/unit/amazon/aws/waiters/test_glue.py @@ -22,7 +22,7 @@ import botocore import pytest -from airflow.providers.amazon.aws.hooks.glue import GlueDataQualityHook +from airflow.providers.amazon.aws.hooks.glue import GlueDataQualityHook, GlueJobHook from airflow.providers.amazon.aws.sensors.glue import ( GlueDataQualityRuleRecommendationRunSensor, GlueDataQualityRuleSetEvaluationRunSensor, @@ -104,3 +104,46 @@ def test_data_quality_rule_recommendation_run_wait(self, mock_get_job): GlueDataQualityHook().get_waiter(self.WAITER_NAME).wait( RunIc="run_id", WaiterConfig={"Delay": 0.01, "MaxAttempts": 3} ) + + +class TestGlueJobCompleteCustomWaiterBase: + @pytest.fixture(autouse=True) + def mock_conn(self, monkeypatch): + self.client = boto3.client("glue") + monkeypatch.setattr(GlueJobHook, "conn", self.client) + + +class TestGlueJobCompleteWaiter(TestGlueJobCompleteCustomWaiterBase): + WAITER_NAME = "job_complete" + + @pytest.fixture + def mock_get_job(self): + with mock.patch.object(self.client, "get_job_run") as mock_getter: + yield mock_getter + + @pytest.mark.parametrize("state", ["SUCCEEDED"]) + def test_glue_job_run_success(self, state, mock_get_job): + mock_get_job.return_value = {"JobRun": {"JobRunState": state}} + + GlueJobHook().get_waiter(self.WAITER_NAME).wait(JobName="example", RunId="run_id") + + @pytest.mark.parametrize("state", ["STOPPED", "FAILED", "ERROR", "TIMEOUT"]) + def test_glue_job_run_failure(self, state, mock_get_job): + mock_get_job.return_value = {"JobRun": {"JobRunState": state}} + + with pytest.raises(botocore.exceptions.WaiterError): + GlueJobHook().get_waiter(self.WAITER_NAME).wait(JobName="example", RunId="run_id") + + @pytest.mark.parametrize("intermediate", ["STARTING", "RUNNING", "STOPPING"]) + def test_glue_job_run_retry_then_success(self, intermediate, mock_get_job): + mock_get_job.side_effect = [ + {"JobRun": {"JobRunState": intermediate}}, + {"JobRun": {"JobRunState": intermediate}}, + {"JobRun": {"JobRunState": "SUCCEEDED"}}, + ] + + GlueJobHook().get_waiter(self.WAITER_NAME).wait( + JobName="example", + RunId="run_id", + WaiterConfig={"Delay": 0.01, "MaxAttempts": 3}, + ) From 444f1c773a6d9628604990e859ef307c7002561c Mon Sep 17 00:00:00 2001 From: Dominik Heilbock Date: Fri, 20 Jun 2025 14:09:34 +0200 Subject: [PATCH 5/6] Added new params to GlueJobOperator and fixed GlueTrigger tests --- .../providers/amazon/aws/operators/glue.py | 4 +- .../unit/amazon/aws/triggers/test_glue.py | 56 +++++++++++-------- 2 files changed, 36 insertions(+), 24 deletions(-) diff --git a/providers/amazon/src/airflow/providers/amazon/aws/operators/glue.py b/providers/amazon/src/airflow/providers/amazon/aws/operators/glue.py index 61b1cc98e338d..61114b14f338c 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/operators/glue.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/operators/glue.py @@ -66,6 +66,8 @@ class GlueJobOperator(AwsBaseOperator[GlueJobHook]): :param iam_role_arn: AWS IAM ARN for Glue Job Execution. If set `iam_role_name` must equal None. :param create_job_kwargs: Extra arguments for Glue Job Creation :param run_job_kwargs: Extra arguments for Glue Job Run + :param waiter_delay: Time in seconds to wait between status checks. (default: 60) + :param waiter_max_attempts: Maximum number of attempts to check for job completion. (default: 20) :param wait_for_completion: Whether to wait for job run completion. (default: True) :param deferrable: If True, the operator will wait asynchronously for the job to complete. This implies waiting for completion. This mode requires aiobotocore module to be installed. @@ -259,7 +261,7 @@ def execute_complete(self, context: Context, event: dict[str, Any] | None = None if validated_event["status"] != "success": raise AirflowException(f"Error in glue job: {validated_event}") - return validated_event["value"] + return validated_event["run_id"] def on_kill(self): """Cancel the running AWS Glue Job.""" diff --git a/providers/amazon/tests/unit/amazon/aws/triggers/test_glue.py b/providers/amazon/tests/unit/amazon/aws/triggers/test_glue.py index 354d70cfab6ed..d9b8ccaa452c9 100644 --- a/providers/amazon/tests/unit/amazon/aws/triggers/test_glue.py +++ b/providers/amazon/tests/unit/amazon/aws/triggers/test_glue.py @@ -40,47 +40,55 @@ class TestGlueJobTrigger: @pytest.mark.asyncio - @mock.patch.object(GlueJobHook, "async_get_job_state") - async def test_wait_job(self, get_state_mock: mock.MagicMock): + @mock.patch.object(GlueJobHook, "get_waiter") + @mock.patch.object(GlueJobHook, "get_async_conn") + async def test_wait_job(self, mock_async_conn, mock_get_waiter): + mock_async_conn.__aenter__.return_value = mock.MagicMock() + mock_get_waiter().wait = AsyncMock() trigger = GlueJobCompleteTrigger( job_name="job_name", run_id="JobRunId", verbose=False, aws_conn_id="aws_conn_id", - job_poll_interval=0.1, + waiter_max_attempts=3, + waiter_delay=10, ) - get_state_mock.side_effect = [ - "RUNNING", - "RUNNING", - "SUCCEEDED", - ] - generator = trigger.run() event = await generator.asend(None) # type:ignore[attr-defined] - assert get_state_mock.call_count == 3 + assert_expected_waiter_type(mock_get_waiter, "job_complete") + mock_get_waiter().wait.assert_called_once() assert event.payload["status"] == "success" + assert event.payload["run_id"] == "JobRunId" @pytest.mark.asyncio - @mock.patch.object(GlueJobHook, "async_get_job_state") - async def test_wait_job_failed(self, get_state_mock: mock.MagicMock): + @mock.patch.object(GlueJobHook, "get_waiter") + @mock.patch.object(GlueJobHook, "get_async_conn") + async def test_wait_job_failed(self, mock_async_conn, mock_get_waiter): + mock_async_conn.__aenter__.return_value = mock.MagicMock() + from botocore.exceptions import WaiterError + + mock_get_waiter().wait = AsyncMock( + side_effect=WaiterError( + name="job_complete", + reason="Waiter encountered a terminal failure state", + last_response={"JobRun": {"JobRunState": "FAILED"}}, + ) + ) + trigger = GlueJobCompleteTrigger( job_name="job_name", run_id="JobRunId", verbose=False, aws_conn_id="aws_conn_id", - job_poll_interval=0.1, + waiter_max_attempts=3, + waiter_delay=10, ) - get_state_mock.side_effect = [ - "RUNNING", - "RUNNING", - "FAILED", - ] + generator = trigger.run() with pytest.raises(AirflowException): - await trigger.run().asend(None) # type:ignore[attr-defined] - - assert get_state_mock.call_count == 3 + await generator.asend(None) # type:ignore[attr-defined] + assert_expected_waiter_type(mock_get_waiter, "job_complete") def test_serialization(self): trigger = GlueJobCompleteTrigger( @@ -88,7 +96,8 @@ def test_serialization(self): run_id="JobRunId", verbose=False, aws_conn_id="aws_conn_id", - job_poll_interval=0.1, + waiter_max_attempts=3, + waiter_delay=10, ) classpath, kwargs = trigger.serialize() assert classpath == "airflow.providers.amazon.aws.triggers.glue.GlueJobCompleteTrigger" @@ -97,7 +106,8 @@ def test_serialization(self): "run_id": "JobRunId", "verbose": False, "aws_conn_id": "aws_conn_id", - "job_poll_interval": 0.1, + "waiter_max_attempts": 3, + "waiter_delay": 10, } From 97677a11cc06e0b74ce93593d8504ea3c92d311d Mon Sep 17 00:00:00 2001 From: Dominik Heilbock Date: Tue, 24 Jun 2025 22:36:24 +0200 Subject: [PATCH 6/6] Refined params of operator, trigger and hook --- .../providers/amazon/aws/hooks/glue.py | 1 - .../providers/amazon/aws/operators/glue.py | 22 +++++++++++-------- .../providers/amazon/aws/sensors/glue.py | 6 +---- .../providers/amazon/aws/triggers/glue.py | 2 -- 4 files changed, 14 insertions(+), 17 deletions(-) diff --git a/providers/amazon/src/airflow/providers/amazon/aws/hooks/glue.py b/providers/amazon/src/airflow/providers/amazon/aws/hooks/glue.py index 7093120ceafd0..ceb1c99928282 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/hooks/glue.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/hooks/glue.py @@ -46,7 +46,6 @@ class GlueJobHook(AwsBaseHook): :param script_location: path to etl script on s3 :param retry_limit: Maximum number of times to retry this job if it fails :param num_of_dpus: Number of AWS Glue DPUs to allocate to this Job - :param region_name: aws region name (example: us-east-1) :param iam_role_name: AWS IAM Role for Glue Job Execution. If set `iam_role_arn` must equal None. :param iam_role_arn: AWS IAM Role ARN for Glue Job Execution, If set `iam_role_name` must equal None. :param create_job_kwargs: Extra arguments for Glue Job Creation diff --git a/providers/amazon/src/airflow/providers/amazon/aws/operators/glue.py b/providers/amazon/src/airflow/providers/amazon/aws/operators/glue.py index 61114b14f338c..5cf8f2ecb5ff1 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/operators/glue.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/operators/glue.py @@ -60,14 +60,11 @@ class GlueJobOperator(AwsBaseOperator[GlueJobHook]): :param script_args: etl script arguments and AWS Glue arguments (templated) :param retry_limit: The maximum number of times to retry this job if it fails :param num_of_dpus: Number of AWS Glue DPUs to allocate to this Job. - :param region_name: aws region name (example: us-east-1) :param s3_bucket: S3 bucket where logs and local etl script will be uploaded :param iam_role_name: AWS IAM Role for Glue Job Execution. If set `iam_role_arn` must equal None. :param iam_role_arn: AWS IAM ARN for Glue Job Execution. If set `iam_role_name` must equal None. :param create_job_kwargs: Extra arguments for Glue Job Creation :param run_job_kwargs: Extra arguments for Glue Job Run - :param waiter_delay: Time in seconds to wait between status checks. (default: 60) - :param waiter_max_attempts: Maximum number of attempts to check for job completion. (default: 20) :param wait_for_completion: Whether to wait for job run completion. (default: True) :param deferrable: If True, the operator will wait asynchronously for the job to complete. This implies waiting for completion. This mode requires aiobotocore module to be installed. @@ -81,6 +78,17 @@ class GlueJobOperator(AwsBaseOperator[GlueJobHook]): Thus if status is returned immediately it might end up in case of more than 1 concurrent run. It is recommended to set this parameter to 10 when you are using concurrency=1. For more information see: https://repost.aws/questions/QUaKgpLBMPSGWO0iq2Fob_bw/glue-run-concurrent-jobs#ANFpCL2fRnQRqgDFuIU_rpvA + + :param aws_conn_id: The Airflow connection used for AWS credentials. + If this is ``None`` or empty then the default boto3 behaviour is used. If + running Airflow in a distributed manner and aws_conn_id is None or + empty, then default boto3 configuration would be used (and must be + maintained on each worker node). + :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used. + :param verify: Whether or not to verify SSL certificates. See: + https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html + :param botocore_config: Configuration dictionary (key-values) for botocore client. See: + https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html """ aws_hook_class = GlueJobHook @@ -124,8 +132,6 @@ def __init__( verbose: bool = False, replace_script_file: bool = False, update_config: bool = False, - waiter_delay: int = 60, - waiter_max_attempts: int = 20, stop_job_run_on_kill: bool = False, sleep_before_return: int = 0, job_poll_interval: int | float = 6, @@ -151,8 +157,6 @@ def __init__( self.update_config = update_config self.replace_script_file = replace_script_file self.deferrable = deferrable - self.waiter_delay = waiter_delay - self.waiter_max_attempts = waiter_max_attempts self.job_poll_interval = job_poll_interval self.stop_job_run_on_kill = stop_job_run_on_kill self._job_run_id: str | None = None @@ -237,8 +241,8 @@ def execute(self, context: Context): run_id=self._job_run_id, verbose=self.verbose, aws_conn_id=self.aws_conn_id, - waiter_delay=int(self.waiter_delay), - waiter_max_attempts=self.waiter_max_attempts, + waiter_delay=int(self.job_poll_interval), + waiter_max_attempts=self.retry_limit, ), method_name="execute_complete", ) diff --git a/providers/amazon/src/airflow/providers/amazon/aws/sensors/glue.py b/providers/amazon/src/airflow/providers/amazon/aws/sensors/glue.py index 7d7b5e5e2afc6..6437d5513cd90 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/sensors/glue.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/sensors/glue.py @@ -118,11 +118,7 @@ def execute_complete(self, context: Context, event: dict[str, Any] | None = None message = f"Error: AWS Glue Job: {validated_event}" raise AirflowException(message) - self.log.info( - "AWS Glue Job completed. Job Name: %s, Run ID: %s", - self.job_name, - self.run_id, - ) + self.log.info("AWS Glue Job completed.") def poke(self, context: Context) -> bool: self.log.info("Poking for job run status :for Glue Job %s and ID %s", self.job_name, self.run_id) diff --git a/providers/amazon/src/airflow/providers/amazon/aws/triggers/glue.py b/providers/amazon/src/airflow/providers/amazon/aws/triggers/glue.py index 3f3b196668f97..6314bce52886f 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/triggers/glue.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/triggers/glue.py @@ -78,8 +78,6 @@ def __init__( self.job_name = job_name self.run_id = run_id self.verbose = verbose - self.waiter_delay = waiter_delay - self.waiter_max_attempts = waiter_max_attempts def hook(self) -> AwsGenericHook: return GlueJobHook(