diff --git a/airflow/providers/google/cloud/operators/dataproc.py b/airflow/providers/google/cloud/operators/dataproc.py
index e4fccfedd87b6..34e9395f23af4 100644
--- a/airflow/providers/google/cloud/operators/dataproc.py
+++ b/airflow/providers/google/cloud/operators/dataproc.py
@@ -2592,6 +2592,7 @@ def execute(self, context: Context):
                     gcp_conn_id=self.gcp_conn_id,
                     impersonation_chain=self.impersonation_chain,
                     polling_interval_seconds=self.polling_interval_seconds,
+                    cancel_on_kill=self.cancel_on_kill,
                 ),
                 method_name="execute_complete",
             )
diff --git a/airflow/providers/google/cloud/triggers/dataproc.py b/airflow/providers/google/cloud/triggers/dataproc.py
index 32b536a2ecaa3..427bf8a09615c 100644
--- a/airflow/providers/google/cloud/triggers/dataproc.py
+++ b/airflow/providers/google/cloud/triggers/dataproc.py
@@ -44,6 +44,7 @@ def __init__(
         gcp_conn_id: str = "google_cloud_default",
         impersonation_chain: str | Sequence[str] | None = None,
         polling_interval_seconds: int = 30,
+        cancel_on_kill: bool = True,
         delete_on_error: bool = True,
     ):
         super().__init__()
@@ -52,6 +53,7 @@ def __init__(
         self.gcp_conn_id = gcp_conn_id
         self.impersonation_chain = impersonation_chain
         self.polling_interval_seconds = polling_interval_seconds
+        self.cancel_on_kill = cancel_on_kill
         self.delete_on_error = delete_on_error

     def get_async_hook(self):
@@ -63,8 +65,8 @@ def get_async_hook(self):
     def get_sync_hook(self):
         # The synchronous hook is utilized to delete the cluster when a task is cancelled.
         # This is because the asynchronous hook deletion is not awaited when the trigger task
-        # is cancelled. The call for deleting the cluster through the sync hook is not a blocking
-        # call, which means it does not wait until the cluster is deleted.
+        # is cancelled. Deleting the cluster or cancelling the job through the sync hook is not a
+        # blocking call, which means it does not wait until the cluster is deleted or the job is cancelled.
         return DataprocHook(
             gcp_conn_id=self.gcp_conn_id,
             impersonation_chain=self.impersonation_chain,
@@ -104,20 +106,39 @@ def serialize(self):
                 "gcp_conn_id": self.gcp_conn_id,
                 "impersonation_chain": self.impersonation_chain,
                 "polling_interval_seconds": self.polling_interval_seconds,
+                "cancel_on_kill": self.cancel_on_kill,
             },
         )

     async def run(self):
-        while True:
-            job = await self.get_async_hook().get_job(
-                project_id=self.project_id, region=self.region, job_id=self.job_id
-            )
-            state = job.status.state
-            self.log.info("Dataproc job: %s is in state: %s", self.job_id, state)
-            if state in (JobStatus.State.DONE, JobStatus.State.CANCELLED, JobStatus.State.ERROR):
-                break
-            await asyncio.sleep(self.polling_interval_seconds)
-        yield TriggerEvent({"job_id": self.job_id, "job_state": state, "job": job})
+        try:
+            while True:
+                job = await self.get_async_hook().get_job(
+                    project_id=self.project_id, region=self.region, job_id=self.job_id
+                )
+                state = job.status.state
+                self.log.info("Dataproc job: %s is in state: %s", self.job_id, state)
+                if state in (JobStatus.State.DONE, JobStatus.State.CANCELLED, JobStatus.State.ERROR):
+                    break
+                await asyncio.sleep(self.polling_interval_seconds)
+            yield TriggerEvent({"job_id": self.job_id, "job_state": state, "job": job})
+        except asyncio.CancelledError:
+            self.log.info("Task got cancelled.")
+            try:
+                if self.job_id and self.cancel_on_kill:
+                    self.log.info("Cancelling the job: %s", self.job_id)
+                    # The synchronous hook is utilized to cancel the job when the trigger task is
+                    # cancelled, because a cancellation issued through the asynchronous hook would
+                    # not be awaited in that case. Cancelling the job through the sync hook is not
+                    # a blocking call, which means it does not wait until the job is cancelled.
+                    self.get_sync_hook().cancel_job(
+                        job_id=self.job_id, project_id=self.project_id, region=self.region
+                    )
+                    self.log.info("Job: %s is cancelled", self.job_id)
+                    yield TriggerEvent({"job_id": self.job_id, "job_state": ClusterStatus.State.DELETING})
+            except Exception as e:
+                self.log.error("Failed to cancel the job: %s with error: %s", self.job_id, str(e))
+                raise e


 class DataprocClusterTrigger(DataprocBaseTrigger):
diff --git a/tests/providers/google/cloud/triggers/test_dataproc.py b/tests/providers/google/cloud/triggers/test_dataproc.py
index e310f2e0dfc9e..f41fc3a280283 100644
--- a/tests/providers/google/cloud/triggers/test_dataproc.py
+++ b/tests/providers/google/cloud/triggers/test_dataproc.py
@@ -22,7 +22,7 @@
 from unittest import mock

 import pytest
-from google.cloud.dataproc_v1 import Batch, Cluster, ClusterStatus
+from google.cloud.dataproc_v1 import Batch, Cluster, ClusterStatus, JobStatus
 from google.protobuf.any_pb2 import Any
 from google.rpc.status_pb2 import Status

@@ -30,6 +30,7 @@
     DataprocBatchTrigger,
     DataprocClusterTrigger,
     DataprocOperationTrigger,
+    DataprocSubmitTrigger,
 )
 from airflow.providers.google.cloud.utils.dataproc import DataprocOperationType
 from airflow.triggers.base import TriggerEvent
@@ -47,6 +48,7 @@
 TEST_POLL_INTERVAL = 5
 TEST_GCP_CONN_ID = "google_cloud_default"
 TEST_OPERATION_NAME = "name"
+TEST_JOB_ID = "test-job-id"


 @pytest.fixture
@@ -113,6 +115,17 @@ def func(**kwargs):
     return func


+@pytest.fixture
+def submit_trigger():
+    return DataprocSubmitTrigger(
+        job_id=TEST_JOB_ID,
+        project_id=TEST_PROJECT_ID,
+        region=TEST_REGION,
+        gcp_conn_id=TEST_GCP_CONN_ID,
+        polling_interval_seconds=TEST_POLL_INTERVAL,
+    )
+
+
 @pytest.fixture
 def async_get_batch():
     def func(**kwargs):
@@ -472,3 +485,94 @@ async def test_async_operation_triggers_on_error(self, mock_hook, operation_trig
         )
         actual_event = await operation_trigger.run().asend(None)
         assert expected_event == actual_event
+
+
+@pytest.mark.db_test
+class TestDataprocSubmitTrigger:
+    def test_submit_trigger_serialization(self, submit_trigger):
+        """Test that the trigger serializes its configuration correctly."""
+        classpath, kwargs = submit_trigger.serialize()
+        assert classpath == "airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitTrigger"
+        assert kwargs == {
+            "job_id": TEST_JOB_ID,
+            "project_id": TEST_PROJECT_ID,
+            "region": TEST_REGION,
+            "gcp_conn_id": TEST_GCP_CONN_ID,
+            "polling_interval_seconds": TEST_POLL_INTERVAL,
+            "cancel_on_kill": True,
+            "impersonation_chain": None,
+        }
+
+    @pytest.mark.asyncio
+    @mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitTrigger.get_async_hook")
+    async def test_submit_trigger_run_success(self, mock_get_async_hook, submit_trigger):
+        """Test the trigger correctly handles a job completion."""
+        mock_hook = mock_get_async_hook.return_value
+        mock_hook.get_job = mock.AsyncMock(
+            return_value=mock.AsyncMock(status=mock.AsyncMock(state=JobStatus.State.DONE))
+        )
+
+        async_gen = submit_trigger.run()
+        event = await async_gen.asend(None)
+        expected_event = TriggerEvent(
+            {"job_id": TEST_JOB_ID, "job_state": JobStatus.State.DONE, "job": mock_hook.get_job.return_value}
+        )
+        assert event.payload == expected_event.payload
+
+    @pytest.mark.asyncio
+    @mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitTrigger.get_async_hook")
+    async def test_submit_trigger_run_error(self, mock_get_async_hook, submit_trigger):
+        """Test the trigger correctly handles a job error."""
+        mock_hook = mock_get_async_hook.return_value
+        mock_hook.get_job = mock.AsyncMock(
+            return_value=mock.AsyncMock(status=mock.AsyncMock(state=JobStatus.State.ERROR))
+        )
+
+        async_gen = submit_trigger.run()
+        event = await async_gen.asend(None)
+        expected_event = TriggerEvent(
+            {"job_id": TEST_JOB_ID, "job_state": JobStatus.State.ERROR, "job": mock_hook.get_job.return_value}
+        )
+        assert event.payload == expected_event.payload
+
+    @pytest.mark.asyncio
+    @mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitTrigger.get_async_hook")
+    @mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitTrigger.get_sync_hook")
+    async def test_submit_trigger_run_cancelled(
+        self, mock_get_sync_hook, mock_get_async_hook, submit_trigger
+    ):
+        """Test the trigger correctly handles an asyncio.CancelledError."""
+        mock_async_hook = mock_get_async_hook.return_value
+        mock_async_hook.get_job.side_effect = asyncio.CancelledError
+
+        mock_sync_hook = mock_get_sync_hook.return_value
+        mock_sync_hook.cancel_job = mock.MagicMock()
+
+        async_gen = submit_trigger.run()
+
+        try:
+            await async_gen.asend(None)
+            # Should raise StopAsyncIteration if no more items to yield
+            await async_gen.asend(None)
+        except asyncio.CancelledError:
+            # Handle the cancellation as expected
+            pass
+        except StopAsyncIteration:
+            # The generator should be properly closed after handling the cancellation
+            pass
+        except Exception as e:
+            # Catch any other exceptions that should not occur
+            pytest.fail(f"Unexpected exception raised: {e}")
+
+        # Check if cancel_job was correctly called
+        if submit_trigger.cancel_on_kill:
+            mock_sync_hook.cancel_job.assert_called_once_with(
+                job_id=submit_trigger.job_id,
+                project_id=submit_trigger.project_id,
+                region=submit_trigger.region,
+            )
+        else:
+            mock_sync_hook.cancel_job.assert_not_called()
+
+        # Clean up the generator
+        await async_gen.aclose()
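
For context on how this change is exercised end to end: with deferrable=True, DataprocSubmitJobOperator defers to DataprocSubmitTrigger, and with this patch cancel_on_kill=True (the operator default) is forwarded to the trigger, so killing the Airflow task while it is deferred now cancels the underlying Dataproc job instead of leaving it running. A minimal DAG sketch follows; it is illustrative only, not part of the patch, and the project, region, cluster name, and GCS path are placeholder values.

    import pendulum

    from airflow import DAG
    from airflow.providers.google.cloud.operators.dataproc import DataprocSubmitJobOperator

    PROJECT_ID = "my-project"  # placeholder
    REGION = "us-central1"  # placeholder

    PYSPARK_JOB = {
        "reference": {"project_id": PROJECT_ID},
        "placement": {"cluster_name": "example-cluster"},  # placeholder; assumed to exist
        "pyspark_job": {"main_python_file_uri": "gs://my-bucket/job.py"},  # placeholder
    }

    with DAG(
        dag_id="dataproc_submit_deferrable_example",
        start_date=pendulum.datetime(2024, 1, 1, tz="UTC"),
        schedule=None,
    ):
        DataprocSubmitJobOperator(
            task_id="pyspark_task",
            project_id=PROJECT_ID,
            region=REGION,
            job=PYSPARK_JOB,
            deferrable=True,  # hand polling off to DataprocSubmitTrigger
            cancel_on_kill=True,  # default; now propagated to the trigger by this patch
            polling_interval_seconds=30,
        )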