From 3146941c886a59c3b076c0200bd28553bc4aaaa4 Mon Sep 17 00:00:00 2001
From: Ankit Chaurasia <8670962+sunank200@users.noreply.github.com>
Date: Tue, 7 May 2024 01:48:28 +0545
Subject: [PATCH 1/4] Workaround for asyncio cancelled error for
 DataprocCreateClusterOperator

---
 .../google/cloud/operators/dataproc.py     |  6 +++
 .../google/cloud/triggers/dataproc.py      | 50 +++++++++++++++++--
 .../google/cloud/triggers/test_dataproc.py | 12 +++++
 3 files changed, 65 insertions(+), 3 deletions(-)

diff --git a/airflow/providers/google/cloud/operators/dataproc.py b/airflow/providers/google/cloud/operators/dataproc.py
index 34e9395f23af4..a873d2f07a8ec 100644
--- a/airflow/providers/google/cloud/operators/dataproc.py
+++ b/airflow/providers/google/cloud/operators/dataproc.py
@@ -810,6 +810,9 @@ def execute(self, context: Context) -> dict:
         else:
             self.defer(
                 trigger=DataprocClusterTrigger(
+                    dag_id=self.dag_id,
+                    task_id=self.task_id,
+                    run_id=context.get("run_id"),
                     cluster_name=self.cluster_name,
                     project_id=self.project_id,
                     region=self.region,
@@ -2748,6 +2751,9 @@ def execute(self, context: Context):
         if cluster.status.state != cluster.status.State.RUNNING:
             self.defer(
                 trigger=DataprocClusterTrigger(
+                    dag_id=self.dag_id,
+                    task_id=self.task_id,
+                    run_id=context.get("run_id"),
                     cluster_name=self.cluster_name,
                     project_id=self.project_id,
                     region=self.region,
diff --git a/airflow/providers/google/cloud/triggers/dataproc.py b/airflow/providers/google/cloud/triggers/dataproc.py
index 427bf8a09615c..bbfc752dd948e 100644
--- a/airflow/providers/google/cloud/triggers/dataproc.py
+++ b/airflow/providers/google/cloud/triggers/dataproc.py
@@ -22,16 +22,22 @@
 import asyncio
 import re
 import time
-from typing import Any, AsyncIterator, Sequence
+from typing import TYPE_CHECKING, Any, AsyncIterator, Sequence
 
 from google.api_core.exceptions import NotFound
 from google.cloud.dataproc_v1 import Batch, Cluster, ClusterStatus, JobStatus
 
 from airflow.exceptions import AirflowException
+from airflow.models.taskinstance import TaskInstance
 from airflow.providers.google.cloud.hooks.dataproc import DataprocAsyncHook, DataprocHook
 from airflow.providers.google.cloud.utils.dataproc import DataprocOperationType
 from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID
 from airflow.triggers.base import BaseTrigger, TriggerEvent
+from airflow.utils.session import provide_session
+from airflow.utils.state import TaskInstanceState
+
+if TYPE_CHECKING:
+    from sqlalchemy.orm import Session
 
 
 class DataprocBaseTrigger(BaseTrigger):
@@ -160,14 +166,20 @@ class DataprocClusterTrigger(DataprocBaseTrigger):
     :param polling_interval_seconds: polling period in seconds to check for the status
     """
 
-    def __init__(self, cluster_name: str, **kwargs):
+    def __init__(self, dag_id: str, task_id: str, run_id: str | None, cluster_name: str, **kwargs):
         super().__init__(**kwargs)
+        self.dag_id = dag_id
+        self.task_id = task_id
+        self.run_id = run_id
         self.cluster_name = cluster_name
 
     def serialize(self) -> tuple[str, dict[str, Any]]:
         return (
             "airflow.providers.google.cloud.triggers.dataproc.DataprocClusterTrigger",
             {
+                "dag_id": self.dag_id,
+                "task_id": self.task_id,
+                "run_id": self.run_id,
                 "cluster_name": self.cluster_name,
                 "project_id": self.project_id,
                 "region": self.region,
@@ -178,6 +190,38 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
             },
         )
 
+    @provide_session
+    def get_task_instance(self, session: Session) -> TaskInstance:
+        """
+        Get the task instance for the current task.
+
+        :param session: SQLAlchemy session
+        """
+        query = session.query(TaskInstance).filter(
+            TaskInstance.dag_id == self.dag_id,
+            TaskInstance.task_id == self.task_id,
+            TaskInstance.run_id == self.run_id,
+        )
+        task_instance = query.one_or_none()
+        if task_instance is None:
+            raise AirflowException(
+                f"TaskInstance {self.dag_id}.{self.task_id} with run_id {self.run_id} not found"
+            )
+        return task_instance
+
+    def safe_to_cancel(self) -> bool:
+        """
+        Whether it is safe to cancel the external job which is being executed by this trigger.
+
+        This is to avoid the case that `asyncio.CancelledError` is raised because the trigger itself is stopped.
+        In those cases, we should NOT cancel the external job.
+        """
+        task_instance = self.get_task_instance()  # type: ignore[call-arg]
+        return task_instance.state not in {
+            TaskInstanceState.RUNNING,
+            TaskInstanceState.DEFERRED,
+        }
+
     async def run(self) -> AsyncIterator[TriggerEvent]:
         try:
             while True:
@@ -207,7 +251,7 @@ async def run(self) -> AsyncIterator[TriggerEvent]:
                 await asyncio.sleep(self.polling_interval_seconds)
         except asyncio.CancelledError:
             try:
-                if self.delete_on_error:
+                if self.delete_on_error and self.safe_to_cancel():
                     self.log.info("Deleting cluster %s.", self.cluster_name)
                     # The synchronous hook is utilized to delete the cluster when a task is cancelled.
                     # This is because the asynchronous hook deletion is not awaited when the trigger task
diff --git a/tests/providers/google/cloud/triggers/test_dataproc.py b/tests/providers/google/cloud/triggers/test_dataproc.py
index f41fc3a280283..e0e0917a9e50d 100644
--- a/tests/providers/google/cloud/triggers/test_dataproc.py
+++ b/tests/providers/google/cloud/triggers/test_dataproc.py
@@ -49,11 +49,17 @@
 TEST_GCP_CONN_ID = "google_cloud_default"
 TEST_OPERATION_NAME = "name"
 TEST_JOB_ID = "test-job-id"
+TEST_DAG_ID = "test_dag_id"
+TEST_TASK_ID = "test_task_id"
+TEST_RUN_ID = "test_run_id"
 
 
 @pytest.fixture
 def cluster_trigger():
     return DataprocClusterTrigger(
+        dag_id=TEST_DAG_ID,
+        task_id=TEST_TASK_ID,
+        run_id=TEST_RUN_ID,
         cluster_name=TEST_CLUSTER_NAME,
         project_id=TEST_PROJECT_ID,
         region=TEST_REGION,
@@ -156,6 +162,9 @@ def test_async_cluster_trigger_serialization_should_execute_successfully(self, c
         classpath, kwargs = cluster_trigger.serialize()
         assert classpath == "airflow.providers.google.cloud.triggers.dataproc.DataprocClusterTrigger"
         assert kwargs == {
+            "dag_id": TEST_DAG_ID,
+            "task_id": TEST_TASK_ID,
+            "run_id": TEST_RUN_ID,
             "cluster_name": TEST_CLUSTER_NAME,
             "project_id": TEST_PROJECT_ID,
             "region": TEST_REGION,
@@ -258,6 +267,9 @@ async def test_cluster_trigger_cancellation_handling(
         mock_get_sync_hook.return_value.delete_cluster = mock_delete_cluster
 
         cluster_trigger = DataprocClusterTrigger(
+            dag_id="dag_id",
+            task_id="task_id",
+            run_id="run_id",
             cluster_name="cluster_name",
             project_id="project-id",
             region="region",

From 5141a5646df0a2487ede812ce13645ea21d09950 Mon Sep 17 00:00:00 2001
From: Ankit Chaurasia <8670962+sunank200@users.noreply.github.com>
Date: Tue, 7 May 2024 16:02:34 +0545
Subject: [PATCH 2/4] Use task_instance from BaseTrigger directly.
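
Instead of plumbing dag_id, task_id, and run_id through the operator, the
trigger's constructor, and serialize(), rely on the task_instance attribute
that BaseTrigger already carries. A minimal sketch of the check this change
reduces to (an illustration rather than the exact hunk below; it assumes
BaseTrigger populates self.task_instance, whose state is the snapshot taken
at deferral time, which is why patch 3 switches to a live database lookup):

    def safe_to_cancel(self) -> bool:
        # self.task_instance is provided by BaseTrigger, so no extra
        # constructor arguments or serialized fields are needed.
        return self.task_instance.state not in {
            TaskInstanceState.RUNNING,
            TaskInstanceState.DEFERRED,
        }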
---
 .../google/cloud/operators/dataproc.py     |  6 ---
 .../google/cloud/triggers/dataproc.py      | 37 ++-----------------
 .../google/cloud/triggers/test_dataproc.py | 12 ------
 3 files changed, 3 insertions(+), 52 deletions(-)

diff --git a/airflow/providers/google/cloud/operators/dataproc.py b/airflow/providers/google/cloud/operators/dataproc.py
index a873d2f07a8ec..34e9395f23af4 100644
--- a/airflow/providers/google/cloud/operators/dataproc.py
+++ b/airflow/providers/google/cloud/operators/dataproc.py
@@ -810,9 +810,6 @@ def execute(self, context: Context) -> dict:
         else:
             self.defer(
                 trigger=DataprocClusterTrigger(
-                    dag_id=self.dag_id,
-                    task_id=self.task_id,
-                    run_id=context.get("run_id"),
                     cluster_name=self.cluster_name,
                     project_id=self.project_id,
                     region=self.region,
@@ -2751,9 +2748,6 @@ def execute(self, context: Context):
         if cluster.status.state != cluster.status.State.RUNNING:
             self.defer(
                 trigger=DataprocClusterTrigger(
-                    dag_id=self.dag_id,
-                    task_id=self.task_id,
-                    run_id=context.get("run_id"),
                     cluster_name=self.cluster_name,
                     project_id=self.project_id,
                     region=self.region,
diff --git a/airflow/providers/google/cloud/triggers/dataproc.py b/airflow/providers/google/cloud/triggers/dataproc.py
index bbfc752dd948e..3b5059d1aba50 100644
--- a/airflow/providers/google/cloud/triggers/dataproc.py
+++ b/airflow/providers/google/cloud/triggers/dataproc.py
@@ -22,23 +22,18 @@
 import asyncio
 import re
 import time
-from typing import TYPE_CHECKING, Any, AsyncIterator, Sequence
+from typing import Any, AsyncIterator, Sequence
 
 from google.api_core.exceptions import NotFound
 from google.cloud.dataproc_v1 import Batch, Cluster, ClusterStatus, JobStatus
 
 from airflow.exceptions import AirflowException
-from airflow.models.taskinstance import TaskInstance
 from airflow.providers.google.cloud.hooks.dataproc import DataprocAsyncHook, DataprocHook
 from airflow.providers.google.cloud.utils.dataproc import DataprocOperationType
 from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID
 from airflow.triggers.base import BaseTrigger, TriggerEvent
-from airflow.utils.session import provide_session
 from airflow.utils.state import TaskInstanceState
 
-if TYPE_CHECKING:
-    from sqlalchemy.orm import Session
-
 
 class DataprocBaseTrigger(BaseTrigger):
     """Base class for Dataproc triggers."""
@@ -166,20 +161,14 @@ class DataprocClusterTrigger(DataprocBaseTrigger):
     :param polling_interval_seconds: polling period in seconds to check for the status
     """
 
-    def __init__(self, dag_id: str, task_id: str, run_id: str | None, cluster_name: str, **kwargs):
+    def __init__(self, cluster_name: str, **kwargs):
         super().__init__(**kwargs)
-        self.dag_id = dag_id
-        self.task_id = task_id
-        self.run_id = run_id
         self.cluster_name = cluster_name
 
     def serialize(self) -> tuple[str, dict[str, Any]]:
         return (
             "airflow.providers.google.cloud.triggers.dataproc.DataprocClusterTrigger",
             {
-                "dag_id": self.dag_id,
-                "task_id": self.task_id,
-                "run_id": self.run_id,
                 "cluster_name": self.cluster_name,
                 "project_id": self.project_id,
                 "region": self.region,
@@ -190,25 +179,6 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
             },
         )
 
-    @provide_session
-    def get_task_instance(self, session: Session) -> TaskInstance:
-        """
-        Get the task instance for the current task.
-
-        :param session: SQLAlchemy session
-        """
-        query = session.query(TaskInstance).filter(
-            TaskInstance.dag_id == self.dag_id,
-            TaskInstance.task_id == self.task_id,
-            TaskInstance.run_id == self.run_id,
-        )
-        task_instance = query.one_or_none()
-        if task_instance is None:
-            raise AirflowException(
-                f"TaskInstance {self.dag_id}.{self.task_id} with run_id {self.run_id} not found"
-            )
-        return task_instance
-
     def safe_to_cancel(self) -> bool:
         """
         Whether it is safe to cancel the external job which is being executed by this trigger.
@@ -216,8 +186,7 @@ def safe_to_cancel(self) -> bool:
         This is to avoid the case that `asyncio.CancelledError` is raised because the trigger itself is stopped.
         In those cases, we should NOT cancel the external job.
         """
-        task_instance = self.get_task_instance()  # type: ignore[call-arg]
-        return task_instance.state not in {
+        return self.task_instance not in {
             TaskInstanceState.RUNNING,
             TaskInstanceState.DEFERRED,
         }
diff --git a/tests/providers/google/cloud/triggers/test_dataproc.py b/tests/providers/google/cloud/triggers/test_dataproc.py
index e0e0917a9e50d..f41fc3a280283 100644
--- a/tests/providers/google/cloud/triggers/test_dataproc.py
+++ b/tests/providers/google/cloud/triggers/test_dataproc.py
@@ -49,17 +49,11 @@
 TEST_GCP_CONN_ID = "google_cloud_default"
 TEST_OPERATION_NAME = "name"
 TEST_JOB_ID = "test-job-id"
-TEST_DAG_ID = "test_dag_id"
-TEST_TASK_ID = "test_task_id"
-TEST_RUN_ID = "test_run_id"
 
 
 @pytest.fixture
 def cluster_trigger():
     return DataprocClusterTrigger(
-        dag_id=TEST_DAG_ID,
-        task_id=TEST_TASK_ID,
-        run_id=TEST_RUN_ID,
         cluster_name=TEST_CLUSTER_NAME,
         project_id=TEST_PROJECT_ID,
         region=TEST_REGION,
@@ -162,9 +156,6 @@ def test_async_cluster_trigger_serialization_should_execute_successfully(self, c
         classpath, kwargs = cluster_trigger.serialize()
         assert classpath == "airflow.providers.google.cloud.triggers.dataproc.DataprocClusterTrigger"
         assert kwargs == {
-            "dag_id": TEST_DAG_ID,
-            "task_id": TEST_TASK_ID,
-            "run_id": TEST_RUN_ID,
             "cluster_name": TEST_CLUSTER_NAME,
             "project_id": TEST_PROJECT_ID,
             "region": TEST_REGION,
@@ -267,9 +258,6 @@ async def test_cluster_trigger_cancellation_handling(
         mock_get_sync_hook.return_value.delete_cluster = mock_delete_cluster
 
         cluster_trigger = DataprocClusterTrigger(
-            dag_id="dag_id",
-            task_id="task_id",
-            run_id="run_id",
             cluster_name="cluster_name",
             project_id="project-id",
             region="region",

From 24f47bec1021d6e6624b7623688ecb37ecf81e8d Mon Sep 17 00:00:00 2001
From: Ankit Chaurasia <8670962+sunank200@users.noreply.github.com>
Date: Tue, 7 May 2024 19:39:26 +0545
Subject: [PATCH 3/4] Use task_instance to fetch the status of the task from
 the database

---
 .../google/cloud/triggers/dataproc.py | 37 ++++++++++++++++---
 1 file changed, 32 insertions(+), 5 deletions(-)

diff --git a/airflow/providers/google/cloud/triggers/dataproc.py b/airflow/providers/google/cloud/triggers/dataproc.py
index 3b5059d1aba50..6214e147ebcf1 100644
--- a/airflow/providers/google/cloud/triggers/dataproc.py
+++ b/airflow/providers/google/cloud/triggers/dataproc.py
@@ -22,18 +22,23 @@
 import asyncio
 import re
 import time
-from typing import Any, AsyncIterator, Sequence
+from typing import TYPE_CHECKING, Any, AsyncIterator, Sequence
 
 from google.api_core.exceptions import NotFound
 from google.cloud.dataproc_v1 import Batch, Cluster, ClusterStatus, JobStatus
 
 from airflow.exceptions import AirflowException
+from airflow.models.taskinstance import TaskInstance
 from airflow.providers.google.cloud.hooks.dataproc import DataprocAsyncHook, DataprocHook
 from airflow.providers.google.cloud.utils.dataproc import DataprocOperationType
 from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID
 from airflow.triggers.base import BaseTrigger, TriggerEvent
+from airflow.utils.session import provide_session
 from airflow.utils.state import TaskInstanceState
 
+if TYPE_CHECKING:
+    from sqlalchemy.orm.session import Session
+
 
 class DataprocBaseTrigger(BaseTrigger):
@@ -179,6 +184,25 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
             },
         )
 
+    @provide_session
+    def get_task_instance(self, session: Session) -> TaskInstance:
+        query = session.query(TaskInstance).filter(
+            TaskInstance.dag_id == self.task_instance.dag_id,
+            TaskInstance.task_id == self.task_instance.task_id,
+            TaskInstance.run_id == self.task_instance.run_id,
+            TaskInstance.map_index == self.task_instance.map_index,
+        )
+        task_instance = query.one_or_none()
+        if task_instance is None:
+            raise AirflowException(
+                "TaskInstance dag_id: %s,task_id: %s, run_id: %s and map_index: %s not found",
+                self.task_instance.dag_id,
+                self.task_instance.task_id,
+                self.task_instance.run_id,
+                self.task_instance.map_index,
+            )
+        return task_instance
+
     def safe_to_cancel(self) -> bool:
         """
         Whether it is safe to cancel the external job which is being executed by this trigger.
@@ -186,10 +210,9 @@ def safe_to_cancel(self) -> bool:
         This is to avoid the case that `asyncio.CancelledError` is raised because the trigger itself is stopped.
         In those cases, we should NOT cancel the external job.
         """
-        return self.task_instance not in {
-            TaskInstanceState.RUNNING,
-            TaskInstanceState.DEFERRED,
-        }
+        # Database query is needed to get the latest state of the task instance.
+        task_instance = self.get_task_instance()  # type: ignore[call-arg]
+        return task_instance.state != TaskInstanceState.DEFERRED
 
     async def run(self) -> AsyncIterator[TriggerEvent]:
         try:
@@ -221,6 +244,10 @@ async def run(self) -> AsyncIterator[TriggerEvent]:
         except asyncio.CancelledError:
             try:
                 if self.delete_on_error and self.safe_to_cancel():
+                    self.log.info(
+                        "Deleting the cluster as it is safe to do so: the Airflow TaskInstance is not in "
+                        "deferred state."
+                    )
                     self.log.info("Deleting cluster %s.", self.cluster_name)
                     # The synchronous hook is utilized to delete the cluster when a task is cancelled.
# This is because the asynchronous hook deletion is not awaited when the trigger task From b32e305d7d1ae093818d7cee145c064451a3db73 Mon Sep 17 00:00:00 2001 From: Ankit Chaurasia <8670962+sunank200@users.noreply.github.com> Date: Tue, 7 May 2024 22:11:57 +0545 Subject: [PATCH 4/4] Add the tests --- .../google/cloud/triggers/dataproc.py | 2 +- .../google/cloud/triggers/test_dataproc.py | 35 ++++++++++++++++++- 2 files changed, 35 insertions(+), 2 deletions(-) diff --git a/airflow/providers/google/cloud/triggers/dataproc.py b/airflow/providers/google/cloud/triggers/dataproc.py index 6214e147ebcf1..939e5bbcac716 100644 --- a/airflow/providers/google/cloud/triggers/dataproc.py +++ b/airflow/providers/google/cloud/triggers/dataproc.py @@ -195,7 +195,7 @@ def get_task_instance(self, session: Session) -> TaskInstance: task_instance = query.one_or_none() if task_instance is None: raise AirflowException( - "TaskInstance dag_id: %s,task_id: %s, run_id: %s and map_index: %s not found", + "TaskInstance with dag_id: %s,task_id: %s, run_id: %s and map_index: %s is not found.", self.task_instance.dag_id, self.task_instance.task_id, self.task_instance.run_id, diff --git a/tests/providers/google/cloud/triggers/test_dataproc.py b/tests/providers/google/cloud/triggers/test_dataproc.py index f41fc3a280283..08294a5ac59d2 100644 --- a/tests/providers/google/cloud/triggers/test_dataproc.py +++ b/tests/providers/google/cloud/triggers/test_dataproc.py @@ -18,7 +18,7 @@ import asyncio import logging -from asyncio import Future +from asyncio import CancelledError, Future, sleep from unittest import mock import pytest @@ -60,6 +60,7 @@ def cluster_trigger(): gcp_conn_id=TEST_GCP_CONN_ID, impersonation_chain=None, polling_interval_seconds=TEST_POLL_INTERVAL, + delete_on_error=True, ) @@ -328,6 +329,38 @@ async def test_delete_when_error_occurred(self, mock_delete_cluster, cluster_tri mock_delete_cluster.assert_not_called() + @pytest.mark.asyncio + @mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocClusterTrigger.get_async_hook") + @mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocClusterTrigger.get_sync_hook") + @mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocClusterTrigger.safe_to_cancel") + async def test_cluster_trigger_run_cancelled_not_safe_to_cancel( + self, mock_safe_to_cancel, mock_get_sync_hook, mock_get_async_hook, cluster_trigger + ): + """Test the trigger's cancellation behavior when it is not safe to cancel.""" + mock_safe_to_cancel.return_value = False + cluster = Cluster(status=ClusterStatus(state=ClusterStatus.State.RUNNING)) + future_cluster = asyncio.Future() + future_cluster.set_result(cluster) + mock_get_async_hook.return_value.get_cluster.return_value = future_cluster + + mock_delete_cluster = mock.MagicMock() + mock_get_sync_hook.return_value.delete_cluster = mock_delete_cluster + + cluster_trigger.delete_on_error = True + + async_gen = cluster_trigger.run() + task = asyncio.create_task(async_gen.__anext__()) + await sleep(0) + task.cancel() + + try: + await task + except CancelledError: + pass + + assert mock_delete_cluster.call_count == 0 + mock_delete_cluster.assert_not_called() + @pytest.mark.db_test class TestDataprocBatchTrigger:
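
The series settles on a pattern that is not Dataproc-specific: when a
trigger's run() is interrupted by asyncio.CancelledError, look up the task
instance's current state in the database before tearing down the external
resource. A task that is still DEFERRED at that point means the triggerer
itself is being stopped (for example during a restart), not that the user
cancelled the task. A minimal standalone sketch of the pattern follows;
ExternalJobTrigger, job_done(), delete_job(), and the serialized classpath
are illustrative placeholders rather than real provider APIs, and the
sketch assumes BaseTrigger exposes the task_instance attribute relied on
throughout this series.

import asyncio
from typing import Any, AsyncIterator

from airflow.exceptions import AirflowException
from airflow.models.taskinstance import TaskInstance
from airflow.triggers.base import BaseTrigger, TriggerEvent
from airflow.utils.session import provide_session
from airflow.utils.state import TaskInstanceState


class ExternalJobTrigger(BaseTrigger):
    """Illustrative trigger; job_done() and delete_job() stand in for real hook calls."""

    def __init__(self, poll_interval: float = 10.0, **kwargs):
        super().__init__(**kwargs)
        self.poll_interval = poll_interval

    def serialize(self) -> tuple[str, dict[str, Any]]:
        return ("example.ExternalJobTrigger", {"poll_interval": self.poll_interval})

    @provide_session
    def get_task_instance(self, session) -> TaskInstance:
        # Re-query the TaskInstance: self.task_instance holds the snapshot
        # serialized when the task deferred, so its state may be stale.
        ti = (
            session.query(TaskInstance)
            .filter(
                TaskInstance.dag_id == self.task_instance.dag_id,
                TaskInstance.task_id == self.task_instance.task_id,
                TaskInstance.run_id == self.task_instance.run_id,
                TaskInstance.map_index == self.task_instance.map_index,
            )
            .one_or_none()
        )
        if ti is None:
            raise AirflowException("TaskInstance not found")
        return ti

    def safe_to_cancel(self) -> bool:
        # Still DEFERRED when the trigger is cancelled means the triggerer
        # is being torn down (e.g. a restart), not that the user cancelled
        # the task, so the external job should be left running.
        return self.get_task_instance().state != TaskInstanceState.DEFERRED  # type: ignore[call-arg]

    async def run(self) -> AsyncIterator[TriggerEvent]:
        try:
            while not self.job_done():
                await asyncio.sleep(self.poll_interval)
            yield TriggerEvent({"status": "success"})
        except asyncio.CancelledError:
            if self.safe_to_cancel():
                # Synchronous cleanup on purpose: an awaited async call
                # would never resume once this task is cancelled.
                self.delete_job()
            raise

    def job_done(self) -> bool:
        raise NotImplementedError  # placeholder for a real status poll

    def delete_job(self) -> None:
        raise NotImplementedError  # placeholder for real cleanup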