From 1adddc81e1b87288454d65c485eb65103d4728ab Mon Sep 17 00:00:00 2001
From: Ankit Chaurasia <8670962+sunank200@users.noreply.github.com>
Date: Tue, 7 May 2024 02:07:04 +0545
Subject: [PATCH 1/4] Workaround for asyncio cancelled error for
 DataprocSubmitJobOperator

---
 .../google/cloud/operators/dataproc.py     |  6 +++
 .../google/cloud/triggers/dataproc.py      | 53 +++++++++++++++++--
 .../google/cloud/triggers/test_dataproc.py | 13 ++++-
 3 files changed, 68 insertions(+), 4 deletions(-)

diff --git a/airflow/providers/google/cloud/operators/dataproc.py b/airflow/providers/google/cloud/operators/dataproc.py
index 34e9395f23af4..546bd740a8e89 100644
--- a/airflow/providers/google/cloud/operators/dataproc.py
+++ b/airflow/providers/google/cloud/operators/dataproc.py
@@ -1454,6 +1454,9 @@ def execute(self, context: Context):
         if self.deferrable:
             self.defer(
                 trigger=DataprocSubmitTrigger(
+                    dag_id=self.dag_id,
+                    task_id=self.task_id,
+                    run_id=context.get("run_id"),
                     job_id=job_id,
                     project_id=self.project_id,
                     region=self.region,
@@ -2586,6 +2589,9 @@ def execute(self, context: Context):
             raise AirflowException(f"Job was cancelled:\n{job}")
         self.defer(
             trigger=DataprocSubmitTrigger(
+                dag_id=self.dag_id,
+                task_id=self.task_id,
+                run_id=context.get("run_id"),
                 job_id=self.job_id,
                 project_id=self.project_id,
                 region=self.region,
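
Some context on why the operator now hands `dag_id`, `task_id`, and `run_id` to the trigger: when a deferrable operator calls `self.defer`, only what the trigger's `serialize()` returns survives the hop to the triggerer process, so any identity the trigger needs later must travel as an explicit constructor argument and come back out of `serialize()`. Below is a minimal, self-contained sketch of that round trip; `EchoTrigger` and `rebuild` are illustrative stand-ins, not Airflow's actual machinery:

```python
from __future__ import annotations

import importlib


class EchoTrigger:
    """Toy trigger: every __init__ argument must reappear in serialize()."""

    def __init__(self, job_id: str, dag_id: str, task_id: str, run_id: str | None):
        self.job_id = job_id
        self.dag_id = dag_id
        self.task_id = task_id
        self.run_id = run_id

    def serialize(self) -> tuple[str, dict]:
        # The triggerer persists (classpath, kwargs) and rebuilds the trigger from
        # them, so a field omitted here is silently lost on the trigger worker.
        classpath = f"{type(self).__module__}.{type(self).__qualname__}"
        return classpath, {
            "job_id": self.job_id,
            "dag_id": self.dag_id,
            "task_id": self.task_id,
            "run_id": self.run_id,
        }


def rebuild(classpath: str, kwargs: dict) -> EchoTrigger:
    """Mimic how the triggerer re-instantiates a trigger from its serialized form."""
    module_name, _, class_name = classpath.rpartition(".")
    cls = getattr(importlib.import_module(module_name), class_name)
    return cls(**kwargs)


trigger = EchoTrigger("job-1", "my_dag", "submit_job", "manual__2024-05-07")
restored = rebuild(*trigger.serialize())
assert restored.run_id == trigger.run_id
```
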
diff --git a/airflow/providers/google/cloud/triggers/dataproc.py b/airflow/providers/google/cloud/triggers/dataproc.py
index 427bf8a09615c..056d9d56d6654 100644
--- a/airflow/providers/google/cloud/triggers/dataproc.py
+++ b/airflow/providers/google/cloud/triggers/dataproc.py
@@ -22,16 +22,22 @@
 import asyncio
 import re
 import time
-from typing import Any, AsyncIterator, Sequence
+from typing import TYPE_CHECKING, Any, AsyncIterator, Sequence

 from google.api_core.exceptions import NotFound
 from google.cloud.dataproc_v1 import Batch, Cluster, ClusterStatus, JobStatus

 from airflow.exceptions import AirflowException
+from airflow.models.taskinstance import TaskInstance
 from airflow.providers.google.cloud.hooks.dataproc import DataprocAsyncHook, DataprocHook
 from airflow.providers.google.cloud.utils.dataproc import DataprocOperationType
 from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID
 from airflow.triggers.base import BaseTrigger, TriggerEvent
+from airflow.utils.session import provide_session
+from airflow.utils.state import TaskInstanceState
+
+if TYPE_CHECKING:
+    from sqlalchemy.orm import Session


 class DataprocBaseTrigger(BaseTrigger):
@@ -77,6 +83,9 @@ class DataprocSubmitTrigger(DataprocBaseTrigger):
     """
     DataprocSubmitTrigger run on the trigger worker to perform create Build operation.

+    :param dag_id: The DAG ID.
+    :param task_id: The task ID.
+    :param run_id: The run ID.
     :param job_id: The ID of a Dataproc job.
     :param project_id: Google Cloud Project where the job is running
     :param region: The Cloud Dataproc region in which to handle the request.
@@ -92,14 +101,20 @@ class DataprocSubmitTrigger(DataprocBaseTrigger):
     :param polling_interval_seconds: polling period in seconds to check for the status
     """

-    def __init__(self, job_id: str, **kwargs):
+    def __init__(self, job_id: str, dag_id: str, task_id: str, run_id: str | None, **kwargs):
         self.job_id = job_id
+        self.dag_id = dag_id
+        self.task_id = task_id
+        self.run_id = run_id
         super().__init__(**kwargs)

     def serialize(self):
         return (
             "airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitTrigger",
             {
+                "dag_id": self.dag_id,
+                "task_id": self.task_id,
+                "run_id": self.run_id,
                 "job_id": self.job_id,
                 "project_id": self.project_id,
                 "region": self.region,
@@ -110,6 +125,38 @@ def serialize(self):
             },
         )

+    @provide_session
+    def get_task_instance(self, session: Session) -> TaskInstance:
+        """
+        Get the task instance for the current task.
+
+        :param session: Sqlalchemy session
+        """
+        query = session.query(TaskInstance).filter(
+            TaskInstance.dag_id == self.dag_id,
+            TaskInstance.task_id == self.task_id,
+            TaskInstance.run_id == self.run_id,
+        )
+        task_instance = query.one_or_none()
+        if task_instance is None:
+            raise AirflowException(
+                f"TaskInstance {self.dag_id}.{self.task_id} with run_id {self.run_id} not found"
+            )
+        return task_instance
+
+    def safe_to_cancel(self) -> bool:
+        """
+        Whether it is safe to cancel the external job which is being executed by this trigger.
+
+        This is to avoid the case that `asyncio.CancelledError` is called because the trigger itself is stopped.
+        Because in those cases, we should NOT cancel the external job.
+        """
+        task_instance = self.get_task_instance()  # type: ignore[call-arg]
+        return task_instance.state not in {
+            TaskInstanceState.RUNNING,
+            TaskInstanceState.DEFERRED,
+        }
+
     async def run(self):
         try:
             while True:
@@ -125,7 +172,7 @@ async def run(self):
         except asyncio.CancelledError:
             self.log.info("Task got cancelled.")
             try:
-                if self.job_id and self.cancel_on_kill:
+                if self.job_id and self.cancel_on_kill and self.safe_to_cancel():
                     self.log.info("Cancelling the job: %s", self.job_id)
                     # The synchronous hook is utilized to delete the cluster when a task is cancelled. This
                     # is because the asynchronous hook deletion is not awaited when the trigger task is
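
The heart of the workaround is the guarded `except asyncio.CancelledError` branch above: the trigger's coroutine is cancelled both when a user actually kills the task and when the triggerer merely shuts down and re-queues the trigger, and only the first case should cancel the remote job. A stripped-down sketch of the pattern in plain asyncio, with the `should_cleanup` callback standing in for `safe_to_cancel` (assumed names, not the provider's API):

```python
import asyncio
from typing import Callable


async def watch_job(job_id: str, should_cleanup: Callable[[], bool]) -> None:
    """Poll forever; on cancellation, clean up only when the guard allows it."""
    try:
        while True:
            print(f"polling {job_id}")
            await asyncio.sleep(1)
    except asyncio.CancelledError:
        # Cancellation can mean "user killed the task" OR "triggerer restarting".
        # Cleaning up on a mere restart would kill a job whose task is still
        # legitimately deferred, hence the guard.
        if should_cleanup():
            print(f"cancelling remote job {job_id}")
        raise  # always let the cancellation propagate


async def main() -> None:
    task = asyncio.create_task(watch_job("job-1", should_cleanup=lambda: True))
    await asyncio.sleep(0.1)  # let the watcher start polling
    task.cancel()
    try:
        await task
    except asyncio.CancelledError:
        print("watcher cancelled")


asyncio.run(main())
```
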
diff --git a/tests/providers/google/cloud/triggers/test_dataproc.py b/tests/providers/google/cloud/triggers/test_dataproc.py
index f41fc3a280283..d12b821d9afae 100644
--- a/tests/providers/google/cloud/triggers/test_dataproc.py
+++ b/tests/providers/google/cloud/triggers/test_dataproc.py
@@ -49,6 +49,9 @@
 TEST_GCP_CONN_ID = "google_cloud_default"
 TEST_OPERATION_NAME = "name"
 TEST_JOB_ID = "test-job-id"
+TEST_DAG_ID = "test_dag_id"
+TEST_TASK_ID = "test_task_id"
+TEST_RUN_ID = "test_run_id"


 @pytest.fixture
@@ -118,6 +121,9 @@ def func(**kwargs):
 @pytest.fixture
 def submit_trigger():
     return DataprocSubmitTrigger(
+        dag_id=TEST_DAG_ID,
+        task_id=TEST_TASK_ID,
+        run_id=TEST_RUN_ID,
         job_id=TEST_JOB_ID,
         project_id=TEST_PROJECT_ID,
         region=TEST_REGION,
@@ -494,6 +500,9 @@ def test_submit_trigger_serialization(self, submit_trigger):
         classpath, kwargs = submit_trigger.serialize()
         assert classpath == "airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitTrigger"
         assert kwargs == {
+            "dag_id": TEST_DAG_ID,
+            "task_id": TEST_TASK_ID,
+            "run_id": TEST_RUN_ID,
             "job_id": TEST_JOB_ID,
             "project_id": TEST_PROJECT_ID,
             "region": TEST_REGION,
@@ -538,10 +547,12 @@ async def test_submit_trigger_run_error(self, mock_get_async_hook, submit_trigge
     @pytest.mark.asyncio
     @mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitTrigger.get_async_hook")
     @mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitTrigger.get_sync_hook")
+    @mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitTrigger.safe_to_cancel")
     async def test_submit_trigger_run_cancelled(
-        self, mock_get_sync_hook, mock_get_async_hook, submit_trigger
+        self, mock_safe_to_cancel, mock_get_sync_hook, mock_get_async_hook, submit_trigger
     ):
         """Test the trigger correctly handles an asyncio.CancelledError."""
+        mock_safe_to_cancel.return_value = True
         mock_async_hook = mock_get_async_hook.return_value
         mock_async_hook.get_job.side_effect = asyncio.CancelledError
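
Note how the test reaches the cancellation branch: `safe_to_cancel` is patched by its dotted path on the trigger class, so no metadata database is needed. The same mechanics in miniature, with a hypothetical `Watcher` class and the standard library's `unittest.mock`:

```python
from unittest import mock


class Watcher:
    def safe_to_cancel(self) -> bool:
        raise RuntimeError("would need a database in real life")

    def on_cancel(self) -> str:
        return "cancelled job" if self.safe_to_cancel() else "left job running"


# Patching on the class swaps the method for every instance inside the block,
# which is why the Airflow test patches the dotted path to the trigger class.
with mock.patch.object(Watcher, "safe_to_cancel", return_value=True):
    assert Watcher().on_cancel() == "cancelled job"

with mock.patch.object(Watcher, "safe_to_cancel", return_value=False):
    assert Watcher().on_cancel() == "left job running"
```
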
From fee2a3fb19a79b46527c9d9157878df6e6f1083b Mon Sep 17 00:00:00 2001
From: Ankit Chaurasia <8670962+sunank200@users.noreply.github.com>
Date: Tue, 7 May 2024 19:54:24 +0545
Subject: [PATCH 2/4] Use task_instance to fetch the status of the task from
 the database

---
 .../google/cloud/operators/dataproc.py     |  3 --
 .../google/cloud/triggers/dataproc.py      | 34 +++++++++----------
 .../google/cloud/triggers/test_dataproc.py |  9 -----
 3 files changed, 16 insertions(+), 30 deletions(-)

diff --git a/airflow/providers/google/cloud/operators/dataproc.py b/airflow/providers/google/cloud/operators/dataproc.py
index 546bd740a8e89..3de0abdaca100 100644
--- a/airflow/providers/google/cloud/operators/dataproc.py
+++ b/airflow/providers/google/cloud/operators/dataproc.py
@@ -2589,9 +2589,6 @@ def execute(self, context: Context):
             raise AirflowException(f"Job was cancelled:\n{job}")
         self.defer(
             trigger=DataprocSubmitTrigger(
-                dag_id=self.dag_id,
-                task_id=self.task_id,
-                run_id=context.get("run_id"),
                 job_id=self.job_id,
                 project_id=self.project_id,
                 region=self.region,
diff --git a/airflow/providers/google/cloud/triggers/dataproc.py b/airflow/providers/google/cloud/triggers/dataproc.py
index 056d9d56d6654..707be0d1f2ff9 100644
--- a/airflow/providers/google/cloud/triggers/dataproc.py
+++ b/airflow/providers/google/cloud/triggers/dataproc.py
@@ -83,9 +83,6 @@ class DataprocSubmitTrigger(DataprocBaseTrigger):
     """
     DataprocSubmitTrigger run on the trigger worker to perform create Build operation.

-    :param dag_id: The DAG ID.
-    :param task_id: The task ID.
-    :param run_id: The run ID.
     :param job_id: The ID of a Dataproc job.
     :param project_id: Google Cloud Project where the job is running
     :param region: The Cloud Dataproc region in which to handle the request.
@@ -101,20 +98,14 @@ class DataprocSubmitTrigger(DataprocBaseTrigger):
     :param polling_interval_seconds: polling period in seconds to check for the status
     """

-    def __init__(self, job_id: str, dag_id: str, task_id: str, run_id: str | None, **kwargs):
+    def __init__(self, job_id: str, **kwargs):
         self.job_id = job_id
-        self.dag_id = dag_id
-        self.task_id = task_id
-        self.run_id = run_id
         super().__init__(**kwargs)

     def serialize(self):
         return (
             "airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitTrigger",
             {
-                "dag_id": self.dag_id,
-                "task_id": self.task_id,
-                "run_id": self.run_id,
                 "job_id": self.job_id,
                 "project_id": self.project_id,
                 "region": self.region,
@@ -133,14 +124,19 @@ def get_task_instance(self, session: Session) -> TaskInstance:
         :param session: Sqlalchemy session
         """
         query = session.query(TaskInstance).filter(
-            TaskInstance.dag_id == self.dag_id,
-            TaskInstance.task_id == self.task_id,
-            TaskInstance.run_id == self.run_id,
+            TaskInstance.dag_id == self.task_instance.dag_id,
+            TaskInstance.task_id == self.task_instance.task_id,
+            TaskInstance.run_id == self.task_instance.run_id,
+            TaskInstance.map_index == self.task_instance.map_index,
         )
         task_instance = query.one_or_none()
         if task_instance is None:
             raise AirflowException(
-                f"TaskInstance {self.dag_id}.{self.task_id} with run_id {self.run_id} not found"
+                "TaskInstance with dag_id: %s, task_id: %s, run_id: %s and map_index: %s is not found",
+                self.task_instance.dag_id,
+                self.task_instance.task_id,
+                self.task_instance.run_id,
+                self.task_instance.map_index,
             )
         return task_instance

@@ -151,11 +147,9 @@ def safe_to_cancel(self) -> bool:
         This is to avoid the case that `asyncio.CancelledError` is called because the trigger itself is stopped.
         Because in those cases, we should NOT cancel the external job.
         """
+        # Database query is needed to get the latest state of the task instance.
         task_instance = self.get_task_instance()  # type: ignore[call-arg]
-        return task_instance.state not in {
-            TaskInstanceState.RUNNING,
-            TaskInstanceState.DEFERRED,
-        }
+        return task_instance.state != TaskInstanceState.DEFERRED

     async def run(self):
         try:
@@ -173,6 +167,10 @@ async def run(self):
             self.log.info("Task got cancelled.")
             try:
                 if self.job_id and self.cancel_on_kill and self.safe_to_cancel():
+                    self.log.info(
+                        "Cancelling the job as it is safe to do so. Note that the airflow TaskInstance is not"
+                        " in deferred state."
+                    )
                     self.log.info("Cancelling the job: %s", self.job_id)
                     # The synchronous hook is utilized to delete the cluster when a task is cancelled. This
                     # is because the asynchronous hook deletion is not awaited when the trigger task is
diff --git a/tests/providers/google/cloud/triggers/test_dataproc.py b/tests/providers/google/cloud/triggers/test_dataproc.py
index d12b821d9afae..7d6d8ec9bdad2 100644
--- a/tests/providers/google/cloud/triggers/test_dataproc.py
+++ b/tests/providers/google/cloud/triggers/test_dataproc.py
@@ -49,9 +49,6 @@
 TEST_GCP_CONN_ID = "google_cloud_default"
 TEST_OPERATION_NAME = "name"
 TEST_JOB_ID = "test-job-id"
-TEST_DAG_ID = "test_dag_id"
-TEST_TASK_ID = "test_task_id"
-TEST_RUN_ID = "test_run_id"


 @pytest.fixture
@@ -121,9 +118,6 @@ def func(**kwargs):
 @pytest.fixture
 def submit_trigger():
     return DataprocSubmitTrigger(
-        dag_id=TEST_DAG_ID,
-        task_id=TEST_TASK_ID,
-        run_id=TEST_RUN_ID,
         job_id=TEST_JOB_ID,
         project_id=TEST_PROJECT_ID,
         region=TEST_REGION,
@@ -500,9 +494,6 @@ def test_submit_trigger_serialization(self, submit_trigger):
         classpath, kwargs = submit_trigger.serialize()
         assert classpath == "airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitTrigger"
         assert kwargs == {
-            "dag_id": TEST_DAG_ID,
-            "task_id": TEST_TASK_ID,
-            "run_id": TEST_RUN_ID,
             "job_id": TEST_JOB_ID,
             "project_id": TEST_PROJECT_ID,
             "region": TEST_REGION,
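
Patch 2 can drop the extra constructor arguments because a trigger running under the triggerer already carries a `task_instance` attribute; the lookup keys now come from there, and filtering on `map_index` as well makes the query unambiguous for mapped tasks. `one_or_none()` then turns a missing row into an explicit exception rather than an `AttributeError` on `None`. A self-contained SQLAlchemy sketch of that query shape against a throwaway in-memory table; the `TI` model is an illustrative stand-in for Airflow's `TaskInstance`:

```python
from sqlalchemy import Column, Integer, String, create_engine
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()


class TI(Base):  # stand-in for airflow.models.taskinstance.TaskInstance
    __tablename__ = "task_instance"
    id = Column(Integer, primary_key=True)
    dag_id = Column(String)
    task_id = Column(String)
    run_id = Column(String)
    map_index = Column(Integer)  # -1 for unmapped tasks, 0..n for mapped ones
    state = Column(String)


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(TI(dag_id="d", task_id="t", run_id="r", map_index=-1, state="deferred"))
    session.commit()

    ti = (
        session.query(TI)
        .filter(TI.dag_id == "d", TI.task_id == "t", TI.run_id == "r", TI.map_index == -1)
        .one_or_none()  # None when missing, error when duplicated; both are bugs here
    )
    if ti is None:
        raise RuntimeError("TaskInstance not found")
    print(ti.state != "deferred")  # False: still deferred, so not safe to cancel
```
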
From 013893ef204136bd518d7e72b0fb717669f24fa6 Mon Sep 17 00:00:00 2001
From: Ankit Chaurasia <8670962+sunank200@users.noreply.github.com>
Date: Tue, 7 May 2024 20:40:05 +0545
Subject: [PATCH 3/4] Update the test

---
 tests/providers/google/cloud/triggers/test_dataproc.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/tests/providers/google/cloud/triggers/test_dataproc.py b/tests/providers/google/cloud/triggers/test_dataproc.py
index 7d6d8ec9bdad2..9f1f4a5d5ee99 100644
--- a/tests/providers/google/cloud/triggers/test_dataproc.py
+++ b/tests/providers/google/cloud/triggers/test_dataproc.py
@@ -123,6 +123,7 @@ def submit_trigger():
         region=TEST_REGION,
         gcp_conn_id=TEST_GCP_CONN_ID,
         polling_interval_seconds=TEST_POLL_INTERVAL,
+        cancel_on_kill=True,
     )


@@ -536,14 +537,15 @@ async def test_submit_trigger_run_error(self, mock_get_async_hook, submit_trigge
         assert event.payload == expected_event.payload

     @pytest.mark.asyncio
+    @pytest.mark.parametrize("is_safe_to_cancel", [True, False])
     @mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitTrigger.get_async_hook")
     @mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitTrigger.get_sync_hook")
     @mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitTrigger.safe_to_cancel")
     async def test_submit_trigger_run_cancelled(
-        self, mock_safe_to_cancel, mock_get_sync_hook, mock_get_async_hook, submit_trigger
+        self, mock_safe_to_cancel, mock_get_sync_hook, mock_get_async_hook, submit_trigger, is_safe_to_cancel
     ):
         """Test the trigger correctly handles an asyncio.CancelledError."""
-        mock_safe_to_cancel.return_value = True
+        mock_safe_to_cancel.return_value = is_safe_to_cancel
         mock_async_hook = mock_get_async_hook.return_value
         mock_async_hook.get_job.side_effect = asyncio.CancelledError

@@ -567,7 +569,7 @@ async def test_submit_trigger_run_cancelled(
             pytest.fail(f"Unexpected exception raised: {e}")

         # Check if cancel_job was correctly called
-        if submit_trigger.cancel_on_kill:
+        if submit_trigger.cancel_on_kill and is_safe_to_cancel:
             mock_sync_hook.cancel_job.assert_called_once_with(
                 job_id=submit_trigger.job_id,
                 project_id=submit_trigger.project_id,
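
Patch 3 exercises both guard outcomes by adding `pytest.mark.parametrize` on top of the stacked `mock.patch` decorators. The argument ordering is the part that is easy to get wrong: patch decorators inject their mocks bottom-up as the leading positional parameters, while parametrized values (and fixtures) are matched by name. A minimal runnable illustration with toy `guard`/`act` functions (hypothetical names):

```python
from unittest import mock

import pytest


def guard() -> bool:
    raise RuntimeError("never called for real in this test")


def act() -> str:
    return "cancelled" if guard() else "skipped"


@pytest.mark.parametrize("is_safe", [True, False])
@mock.patch(f"{__name__}.guard")  # innermost decorator -> first mock parameter
def test_act(mock_guard, is_safe):
    mock_guard.return_value = is_safe
    assert act() == ("cancelled" if is_safe else "skipped")
```
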
From f8cc2324e47ab703cf32e1b4391535fa3b3e50dd Mon Sep 17 00:00:00 2001
From: Ankit Chaurasia <8670962+sunank200@users.noreply.github.com>
Date: Tue, 7 May 2024 20:45:04 +0545
Subject: [PATCH 4/4] Remove leftover dag_id, task_id and run_id

---
 airflow/providers/google/cloud/operators/dataproc.py | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/airflow/providers/google/cloud/operators/dataproc.py b/airflow/providers/google/cloud/operators/dataproc.py
index 3de0abdaca100..34e9395f23af4 100644
--- a/airflow/providers/google/cloud/operators/dataproc.py
+++ b/airflow/providers/google/cloud/operators/dataproc.py
@@ -1454,9 +1454,6 @@ def execute(self, context: Context):
         if self.deferrable:
             self.defer(
                 trigger=DataprocSubmitTrigger(
-                    dag_id=self.dag_id,
-                    task_id=self.task_id,
-                    run_id=context.get("run_id"),
                     job_id=job_id,
                     project_id=self.project_id,
                     region=self.region,
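
The net effect of the series for users: with a deferrable submit operator, killing the task from the UI still cancels the Dataproc job, while a triggerer restart no longer does. A usage sketch with placeholder project, region, cluster, and job values (substitute your own):

```python
from datetime import datetime

from airflow import DAG
from airflow.providers.google.cloud.operators.dataproc import DataprocSubmitJobOperator

# Placeholder values: replace with a real project, region, cluster, and main file.
PROJECT_ID = "my-project"
REGION = "us-central1"
PYSPARK_JOB = {
    "reference": {"project_id": PROJECT_ID},
    "placement": {"cluster_name": "my-cluster"},
    "pyspark_job": {"main_python_file_uri": "gs://my-bucket/job.py"},
}

with DAG("dataproc_deferrable_example", start_date=datetime(2024, 1, 1), schedule=None):
    DataprocSubmitJobOperator(
        task_id="submit_job",
        job=PYSPARK_JOB,
        region=REGION,
        project_id=PROJECT_ID,
        deferrable=True,      # poll from the triggerer instead of holding a worker slot
        cancel_on_kill=True,  # ask Dataproc to cancel the job when the task is killed
    )
```
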