From 4aa1b1991815d719b7a67bd38c32643dbba5ef08 Mon Sep 17 00:00:00 2001 From: Ankit Chaurasia <8670962+sunank200@users.noreply.github.com> Date: Tue, 7 May 2024 00:17:02 +0545 Subject: [PATCH 1/7] workaround for asyncio cancelled error --- .../google/cloud/operators/bigquery.py | 15 +++ .../google/cloud/triggers/bigquery.py | 110 +++++++++++++++++- .../google/cloud/triggers/test_bigquery.py | 77 +++++++++++- 3 files changed, 197 insertions(+), 5 deletions(-) diff --git a/airflow/providers/google/cloud/operators/bigquery.py b/airflow/providers/google/cloud/operators/bigquery.py index 1cf0f9ee9a350..f81978d9e18e8 100644 --- a/airflow/providers/google/cloud/operators/bigquery.py +++ b/airflow/providers/google/cloud/operators/bigquery.py @@ -315,6 +315,9 @@ def execute(self, context: Context): self.defer( timeout=self.execution_timeout, trigger=BigQueryCheckTrigger( + dag_id=self.dag_id, + task_id=self.task_id, + run_id=context.get("run_id"), conn_id=self.gcp_conn_id, job_id=job.job_id, project_id=hook.project_id, @@ -453,6 +456,9 @@ def execute(self, context: Context) -> None: # type: ignore[override] self.defer( timeout=self.execution_timeout, trigger=BigQueryValueCheckTrigger( + dag_id=self.dag_id, + task_id=self.task_id, + run_id=context.get("run_id"), conn_id=self.gcp_conn_id, job_id=job.job_id, project_id=hook.project_id, @@ -608,6 +614,9 @@ def execute(self, context: Context): self.defer( timeout=self.execution_timeout, trigger=BigQueryIntervalCheckTrigger( + dag_id=self.dag_id, + task_id=self.task_id, + run_id=context.get("run_id"), conn_id=self.gcp_conn_id, first_job_id=job_1.job_id, second_job_id=job_2.job_id, @@ -1124,6 +1133,9 @@ def execute(self, context: Context): self.defer( timeout=self.execution_timeout, trigger=BigQueryGetDataTrigger( + dag_id=self.dag_id, + task_id=self.task_id, + run_id=context.get("run_id"), conn_id=self.gcp_conn_id, job_id=job.job_id, dataset_id=self.dataset_id, @@ -2956,6 +2968,9 @@ def execute(self, context: Any): self.defer( timeout=self.execution_timeout, trigger=BigQueryInsertJobTrigger( + dag_id=self.dag_id, + task_id=self.task_id, + run_id=context.get("run_id"), conn_id=self.gcp_conn_id, job_id=self.job_id, project_id=self.project_id, diff --git a/airflow/providers/google/cloud/triggers/bigquery.py b/airflow/providers/google/cloud/triggers/bigquery.py index e2e0e82f6b0eb..9cdf471ac19e9 100644 --- a/airflow/providers/google/cloud/triggers/bigquery.py +++ b/airflow/providers/google/cloud/triggers/bigquery.py @@ -17,19 +17,29 @@ from __future__ import annotations import asyncio -from typing import Any, AsyncIterator, Sequence, SupportsAbs +from typing import TYPE_CHECKING, Any, AsyncIterator, Sequence, SupportsAbs from aiohttp import ClientSession from aiohttp.client_exceptions import ClientResponseError +from airflow.exceptions import AirflowException +from airflow.models.taskinstance import TaskInstance from airflow.providers.google.cloud.hooks.bigquery import BigQueryAsyncHook, BigQueryTableAsyncHook from airflow.triggers.base import BaseTrigger, TriggerEvent +from airflow.utils.session import provide_session +from airflow.utils.state import TaskInstanceState + +if TYPE_CHECKING: + from sqlalchemy.orm import Session class BigQueryInsertJobTrigger(BaseTrigger): """ BigQueryInsertJobTrigger run on the trigger worker to perform insert operation. + :param dag_id: The DAG ID. + :param task_id: The task ID. + :param run_id: The run ID. :param conn_id: Reference to google cloud connection id :param job_id: The ID of the job. 
It will be suffixed with hash of job configuration :param project_id: Google Cloud Project where the job is running @@ -49,6 +59,9 @@ class BigQueryInsertJobTrigger(BaseTrigger): def __init__( self, + dag_id: str, + task_id: str, + run_id: str | None, conn_id: str, job_id: str | None, project_id: str, @@ -61,6 +74,9 @@ def __init__( ): super().__init__() self.log.info("Using the connection %s .", conn_id) + self.dag_id = dag_id + self.task_id = task_id + self.run_id = run_id self.conn_id = conn_id self.job_id = job_id self._job_conn = None @@ -77,6 +93,9 @@ def serialize(self) -> tuple[str, dict[str, Any]]: return ( "airflow.providers.google.cloud.triggers.bigquery.BigQueryInsertJobTrigger", { + "dag_id": self.dag_id, + "task_id": self.task_id, + "run_id": self.run_id, "conn_id": self.conn_id, "job_id": self.job_id, "dataset_id": self.dataset_id, @@ -89,6 +108,38 @@ def serialize(self) -> tuple[str, dict[str, Any]]: }, ) + @provide_session + def get_task_instance(self, session: Session) -> TaskInstance: + """ + Get the task instance for the current task. + + :param session: Sqlalchemy session + """ + query = session.query(TaskInstance).filter( + TaskInstance.dag_id == self.dag_id, + TaskInstance.task_id == self.task_id, + TaskInstance.run_id == self.run_id, + ) + task_instance = query.one_or_none() + if task_instance is None: + raise AirflowException( + f"TaskInstance {self.dag_id}.{self.task_id} with run_id {self.run_id} not found" + ) + return task_instance + + def safe_to_cancel(self) -> bool: + """ + Whether it is safe to cancel the external job which is being executed by this trigger. + + This is to avoid the case that `asyncio.CancelledError` is called because the trigger itself is stopped. + Because in those cases, we should NOT cancel the external job. + """ + task_instance = self.get_task_instance() # type: ignore[call-arg] + return task_instance.state not in { + TaskInstanceState.RUNNING, + TaskInstanceState.DEFERRED, + } + async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override] """Get current job execution status and yields a TriggerEvent.""" hook = self._get_async_hook() @@ -118,7 +169,7 @@ async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override] await asyncio.sleep(self.poll_interval) except asyncio.CancelledError: self.log.info("Task was killed.") - if self.job_id and self.cancel_on_kill: + if self.job_id and self.cancel_on_kill and self.safe_to_cancel(): await hook.cancel_job( # type: ignore[union-attr] job_id=self.job_id, project_id=self.project_id, location=self.location ) @@ -140,6 +191,9 @@ def serialize(self) -> tuple[str, dict[str, Any]]: return ( "airflow.providers.google.cloud.triggers.bigquery.BigQueryCheckTrigger", { + "dag_id": self.dag_id, + "task_id": self.task_id, + "run_id": self.run_id, "conn_id": self.conn_id, "job_id": self.job_id, "dataset_id": self.dataset_id, @@ -148,6 +202,7 @@ def serialize(self) -> tuple[str, dict[str, Any]]: "table_id": self.table_id, "poll_interval": self.poll_interval, "impersonation_chain": self.impersonation_chain, + "cancel_on_kill": self.cancel_on_kill, }, ) @@ -201,12 +256,26 @@ class BigQueryGetDataTrigger(BigQueryInsertJobTrigger): """ BigQueryGetDataTrigger run on the trigger worker, inherits from BigQueryInsertJobTrigger class. + :param dag_id: The DAG ID. + :param task_id: The task ID. + :param run_id: The run ID. :param as_dict: if True returns the result as a list of dictionaries, otherwise as list of lists (default: False). 
""" - def __init__(self, as_dict: bool = False, selected_fields: str | None = None, **kwargs): - super().__init__(**kwargs) + def __init__( + self, + dag_id: str, + task_id: str, + run_id: str | None, + as_dict: bool = False, + selected_fields: str | None = None, + **kwargs, + ): + super().__init__(dag_id=dag_id, task_id=task_id, run_id=run_id, **kwargs) + self.dag_id = dag_id + self.task_id = task_id + self.run_id = run_id self.as_dict = as_dict self.selected_fields = selected_fields @@ -215,6 +284,9 @@ def serialize(self) -> tuple[str, dict[str, Any]]: return ( "airflow.providers.google.cloud.triggers.bigquery.BigQueryGetDataTrigger", { + "dag_id": self.dag_id, + "task_id": self.task_id, + "run_id": self.run_id, "conn_id": self.conn_id, "job_id": self.job_id, "dataset_id": self.dataset_id, @@ -270,6 +342,9 @@ class BigQueryIntervalCheckTrigger(BigQueryInsertJobTrigger): """ BigQueryIntervalCheckTrigger run on the trigger worker, inherits from BigQueryInsertJobTrigger class. + :param dag_id: The DAG ID. + :param task_id: The task ID. + :param run_id: The run ID. :param conn_id: Reference to google cloud connection id :param first_job_id: The ID of the job 1 performed :param second_job_id: The ID of the job 2 performed @@ -296,6 +371,9 @@ class BigQueryIntervalCheckTrigger(BigQueryInsertJobTrigger): def __init__( self, + dag_id: str, + task_id: str, + run_id: str | None, conn_id: str, first_job_id: str, second_job_id: str, @@ -313,6 +391,9 @@ def __init__( impersonation_chain: str | Sequence[str] | None = None, ): super().__init__( + dag_id=dag_id, + task_id=task_id, + run_id=run_id, conn_id=conn_id, job_id=first_job_id, project_id=project_id, @@ -322,6 +403,9 @@ def __init__( poll_interval=poll_interval, impersonation_chain=impersonation_chain, ) + self.dag_id = dag_id + self.task_id = task_id + self.run_id = run_id self.conn_id = conn_id self.first_job_id = first_job_id self.second_job_id = second_job_id @@ -338,6 +422,9 @@ def serialize(self) -> tuple[str, dict[str, Any]]: return ( "airflow.providers.google.cloud.triggers.bigquery.BigQueryIntervalCheckTrigger", { + "dag_id": self.dag_id, + "task_id": self.task_id, + "run_id": self.run_id, "conn_id": self.conn_id, "first_job_id": self.first_job_id, "second_job_id": self.second_job_id, @@ -434,6 +521,9 @@ class BigQueryValueCheckTrigger(BigQueryInsertJobTrigger): """ BigQueryValueCheckTrigger run on the trigger worker, inherits from BigQueryInsertJobTrigger class. + :param dag_id: The DAG ID. + :param task_id: The task ID. + :param run_id: The run ID. 
:param conn_id: Reference to google cloud connection id :param sql: the sql to be executed :param pass_value: pass value @@ -456,6 +546,9 @@ class BigQueryValueCheckTrigger(BigQueryInsertJobTrigger): def __init__( self, + dag_id: str, + task_id: str, + run_id: str | None, conn_id: str, sql: str, pass_value: int | float | str, @@ -469,6 +562,9 @@ def __init__( impersonation_chain: str | Sequence[str] | None = None, ): super().__init__( + dag_id=dag_id, + task_id=task_id, + run_id=run_id, conn_id=conn_id, job_id=job_id, project_id=project_id, @@ -478,6 +574,9 @@ def __init__( poll_interval=poll_interval, impersonation_chain=impersonation_chain, ) + self.dag_id = dag_id + self.task_id = task_id + self.run_id = run_id self.sql = sql self.pass_value = pass_value self.tolerance = tolerance @@ -487,6 +586,9 @@ def serialize(self) -> tuple[str, dict[str, Any]]: return ( "airflow.providers.google.cloud.triggers.bigquery.BigQueryValueCheckTrigger", { + "dag_id": self.dag_id, + "task_id": self.task_id, + "run_id": self.run_id, "conn_id": self.conn_id, "pass_value": self.pass_value, "job_id": self.job_id, diff --git a/tests/providers/google/cloud/triggers/test_bigquery.py b/tests/providers/google/cloud/triggers/test_bigquery.py index 436872903eb5b..c8c85056ab665 100644 --- a/tests/providers/google/cloud/triggers/test_bigquery.py +++ b/tests/providers/google/cloud/triggers/test_bigquery.py @@ -65,11 +65,17 @@ TEST_HOOK_PARAMS: dict[str, Any] = {} TEST_PARTITION_ID = "1234" TEST_SELECTED_FIELDS = "f0_,f1_" +TEST_DAG_ID = "test_dag_id" +TEST_TASK_ID = "test_task_id" +TEST_RUN_ID = "test_run_id" @pytest.fixture def insert_job_trigger(): return BigQueryInsertJobTrigger( + dag_id=TEST_DAG_ID, + task_id=TEST_TASK_ID, + run_id=TEST_RUN_ID, conn_id=TEST_CONN_ID, job_id=TEST_JOB_ID, project_id=TEST_GCP_PROJECT_ID, @@ -84,6 +90,9 @@ def insert_job_trigger(): @pytest.fixture def get_data_trigger(): return BigQueryGetDataTrigger( + dag_id=TEST_DAG_ID, + task_id=TEST_TASK_ID, + run_id=TEST_RUN_ID, conn_id=TEST_CONN_ID, job_id=TEST_JOB_ID, project_id=TEST_GCP_PROJECT_ID, @@ -112,6 +121,9 @@ def table_existence_trigger(): @pytest.fixture def interval_check_trigger(): return BigQueryIntervalCheckTrigger( + dag_id=TEST_DAG_ID, + task_id=TEST_TASK_ID, + run_id=TEST_RUN_ID, conn_id=TEST_CONN_ID, first_job_id=TEST_FIRST_JOB_ID, second_job_id=TEST_SECOND_JOB_ID, @@ -132,6 +144,9 @@ def interval_check_trigger(): @pytest.fixture def check_trigger(): return BigQueryCheckTrigger( + dag_id=TEST_DAG_ID, + task_id=TEST_TASK_ID, + run_id=TEST_RUN_ID, conn_id=TEST_CONN_ID, job_id=TEST_JOB_ID, project_id=TEST_GCP_PROJECT_ID, @@ -146,6 +161,9 @@ def check_trigger(): @pytest.fixture def value_check_trigger(): return BigQueryValueCheckTrigger( + dag_id=TEST_DAG_ID, + task_id=TEST_TASK_ID, + run_id=TEST_RUN_ID, conn_id=TEST_CONN_ID, pass_value=TEST_PASS_VALUE, job_id=TEST_JOB_ID, @@ -167,6 +185,9 @@ def test_serialization(self, insert_job_trigger): classpath, kwargs = insert_job_trigger.serialize() assert classpath == "airflow.providers.google.cloud.triggers.bigquery.BigQueryInsertJobTrigger" assert kwargs == { + "dag_id": TEST_DAG_ID, + "task_id": TEST_TASK_ID, + "run_id": TEST_RUN_ID, "cancel_on_kill": True, "conn_id": TEST_CONN_ID, "job_id": TEST_JOB_ID, @@ -239,13 +260,15 @@ async def test_bigquery_op_trigger_exception(self, mock_job_status, caplog, inse @pytest.mark.asyncio @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.cancel_job") 
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.get_job_status") + @mock.patch("airflow.providers.google.cloud.triggers.bigquery.BigQueryInsertJobTrigger.safe_to_cancel") async def test_bigquery_insert_job_trigger_cancellation( - self, mock_get_job_status, mock_cancel_job, caplog, insert_job_trigger + self, mock_get_task_instance, mock_get_job_status, mock_cancel_job, caplog, insert_job_trigger ): """ Test that BigQueryInsertJobTrigger handles cancellation correctly, logs the appropriate message, and conditionally cancels the job based on the `cancel_on_kill` attribute. """ + mock_get_task_instance.return_value = True insert_job_trigger.cancel_on_kill = True insert_job_trigger.job_id = "1234" @@ -271,6 +294,42 @@ async def test_bigquery_insert_job_trigger_cancellation( ), "Expected messages about task status or cancellation not found in log." mock_cancel_job.assert_awaited_once() + @pytest.mark.asyncio + @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.cancel_job") + @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.get_job_status") + @mock.patch("airflow.providers.google.cloud.triggers.bigquery.BigQueryInsertJobTrigger.safe_to_cancel") + async def test_bigquery_insert_job_trigger_cancellation_unsafe_cancellation( + self, mock_safe_to_cancel, mock_get_job_status, mock_cancel_job, caplog, insert_job_trigger + ): + """ + Test that BigQueryInsertJobTrigger logs the appropriate message and does not cancel the job + if safe_to_cancel returns False even when the task is cancelled. + """ + mock_safe_to_cancel.return_value = False + insert_job_trigger.cancel_on_kill = True + insert_job_trigger.job_id = "1234" + + # Simulate the initial job status as running + mock_get_job_status.side_effect = [ + {"status": "running", "message": "Job is still running"}, + asyncio.CancelledError(), + {"status": "running", "message": "Job is still running after cancellation"}, + ] + + caplog.set_level(logging.INFO) + + try: + async for _ in insert_job_trigger.run(): + pass + except asyncio.CancelledError: + pass + + assert "Task was killed" in caplog.text, "Expected message about task status not found in log." + assert ( + "Skipping to cancel job" in caplog.text + ), "Expected message about skipping cancellation not found in log." 
+ assert mock_get_job_status.call_count == 2, "Job status should be checked multiple times" + class TestBigQueryGetDataTrigger: def test_bigquery_get_data_trigger_serialization(self, get_data_trigger): @@ -279,6 +338,9 @@ def test_bigquery_get_data_trigger_serialization(self, get_data_trigger): classpath, kwargs = get_data_trigger.serialize() assert classpath == "airflow.providers.google.cloud.triggers.bigquery.BigQueryGetDataTrigger" assert kwargs == { + "dag_id": TEST_DAG_ID, + "task_id": TEST_TASK_ID, + "run_id": TEST_RUN_ID, "as_dict": False, "conn_id": TEST_CONN_ID, "impersonation_chain": TEST_IMPERSONATION_CHAIN, @@ -439,6 +501,9 @@ def test_check_trigger_serialization(self, check_trigger): classpath, kwargs = check_trigger.serialize() assert classpath == "airflow.providers.google.cloud.triggers.bigquery.BigQueryCheckTrigger" assert kwargs == { + "dag_id": TEST_DAG_ID, + "task_id": TEST_TASK_ID, + "run_id": TEST_RUN_ID, "conn_id": TEST_CONN_ID, "impersonation_chain": TEST_IMPERSONATION_CHAIN, "job_id": TEST_JOB_ID, @@ -447,6 +512,7 @@ def test_check_trigger_serialization(self, check_trigger): "table_id": TEST_TABLE_ID, "location": None, "poll_interval": POLLING_PERIOD_SECONDS, + "cancel_on_kill": True, } @pytest.mark.asyncio @@ -521,6 +587,9 @@ def test_interval_check_trigger_serialization(self, interval_check_trigger): classpath, kwargs = interval_check_trigger.serialize() assert classpath == "airflow.providers.google.cloud.triggers.bigquery.BigQueryIntervalCheckTrigger" assert kwargs == { + "dag_id": TEST_DAG_ID, + "task_id": TEST_TASK_ID, + "run_id": TEST_RUN_ID, "conn_id": TEST_CONN_ID, "impersonation_chain": TEST_IMPERSONATION_CHAIN, "first_job_id": TEST_FIRST_JOB_ID, @@ -615,6 +684,9 @@ def test_bigquery_value_check_op_trigger_serialization(self, value_check_trigger assert classpath == "airflow.providers.google.cloud.triggers.bigquery.BigQueryValueCheckTrigger" assert kwargs == { + "dag_id": TEST_DAG_ID, + "task_id": TEST_TASK_ID, + "run_id": TEST_RUN_ID, "conn_id": TEST_CONN_ID, "impersonation_chain": TEST_IMPERSONATION_CHAIN, "pass_value": TEST_PASS_VALUE, @@ -690,6 +762,9 @@ async def test_value_check_trigger_exception(self, mock_job_status): mock_job_status.side_effect = Exception("Test exception") trigger = BigQueryValueCheckTrigger( + dag_id=TEST_DAG_ID, + task_id=TEST_TASK_ID, + run_id=TEST_RUN_ID, conn_id=TEST_CONN_ID, sql=TEST_SQL_QUERY, pass_value=TEST_PASS_VALUE, From 696ac241eff40e7fd393a33dca801806ef103ad5 Mon Sep 17 00:00:00 2001 From: Ankit Chaurasia <8670962+sunank200@users.noreply.github.com> Date: Tue, 7 May 2024 15:41:57 +0545 Subject: [PATCH 2/7] Use trigger task instance instead --- .../google/cloud/operators/bigquery.py | 15 --- .../google/cloud/triggers/bigquery.py | 91 +------------------ .../google/cloud/triggers/test_bigquery.py | 36 -------- 3 files changed, 5 insertions(+), 137 deletions(-) diff --git a/airflow/providers/google/cloud/operators/bigquery.py b/airflow/providers/google/cloud/operators/bigquery.py index f81978d9e18e8..1cf0f9ee9a350 100644 --- a/airflow/providers/google/cloud/operators/bigquery.py +++ b/airflow/providers/google/cloud/operators/bigquery.py @@ -315,9 +315,6 @@ def execute(self, context: Context): self.defer( timeout=self.execution_timeout, trigger=BigQueryCheckTrigger( - dag_id=self.dag_id, - task_id=self.task_id, - run_id=context.get("run_id"), conn_id=self.gcp_conn_id, job_id=job.job_id, project_id=hook.project_id, @@ -456,9 +453,6 @@ def execute(self, context: Context) -> None: # type: ignore[override] self.defer( 
timeout=self.execution_timeout, trigger=BigQueryValueCheckTrigger( - dag_id=self.dag_id, - task_id=self.task_id, - run_id=context.get("run_id"), conn_id=self.gcp_conn_id, job_id=job.job_id, project_id=hook.project_id, @@ -614,9 +608,6 @@ def execute(self, context: Context): self.defer( timeout=self.execution_timeout, trigger=BigQueryIntervalCheckTrigger( - dag_id=self.dag_id, - task_id=self.task_id, - run_id=context.get("run_id"), conn_id=self.gcp_conn_id, first_job_id=job_1.job_id, second_job_id=job_2.job_id, @@ -1133,9 +1124,6 @@ def execute(self, context: Context): self.defer( timeout=self.execution_timeout, trigger=BigQueryGetDataTrigger( - dag_id=self.dag_id, - task_id=self.task_id, - run_id=context.get("run_id"), conn_id=self.gcp_conn_id, job_id=job.job_id, dataset_id=self.dataset_id, @@ -2968,9 +2956,6 @@ def execute(self, context: Any): self.defer( timeout=self.execution_timeout, trigger=BigQueryInsertJobTrigger( - dag_id=self.dag_id, - task_id=self.task_id, - run_id=context.get("run_id"), conn_id=self.gcp_conn_id, job_id=self.job_id, project_id=self.project_id, diff --git a/airflow/providers/google/cloud/triggers/bigquery.py b/airflow/providers/google/cloud/triggers/bigquery.py index 9cdf471ac19e9..7af8ff3b52ed2 100644 --- a/airflow/providers/google/cloud/triggers/bigquery.py +++ b/airflow/providers/google/cloud/triggers/bigquery.py @@ -17,29 +17,20 @@ from __future__ import annotations import asyncio -from typing import TYPE_CHECKING, Any, AsyncIterator, Sequence, SupportsAbs +from typing import Any, AsyncIterator, Sequence, SupportsAbs from aiohttp import ClientSession from aiohttp.client_exceptions import ClientResponseError -from airflow.exceptions import AirflowException -from airflow.models.taskinstance import TaskInstance from airflow.providers.google.cloud.hooks.bigquery import BigQueryAsyncHook, BigQueryTableAsyncHook from airflow.triggers.base import BaseTrigger, TriggerEvent -from airflow.utils.session import provide_session from airflow.utils.state import TaskInstanceState -if TYPE_CHECKING: - from sqlalchemy.orm import Session - class BigQueryInsertJobTrigger(BaseTrigger): """ BigQueryInsertJobTrigger run on the trigger worker to perform insert operation. - :param dag_id: The DAG ID. - :param task_id: The task ID. - :param run_id: The run ID. :param conn_id: Reference to google cloud connection id :param job_id: The ID of the job. It will be suffixed with hash of job configuration :param project_id: Google Cloud Project where the job is running @@ -59,9 +50,6 @@ class BigQueryInsertJobTrigger(BaseTrigger): def __init__( self, - dag_id: str, - task_id: str, - run_id: str | None, conn_id: str, job_id: str | None, project_id: str, @@ -74,9 +62,6 @@ def __init__( ): super().__init__() self.log.info("Using the connection %s .", conn_id) - self.dag_id = dag_id - self.task_id = task_id - self.run_id = run_id self.conn_id = conn_id self.job_id = job_id self._job_conn = None @@ -93,9 +78,6 @@ def serialize(self) -> tuple[str, dict[str, Any]]: return ( "airflow.providers.google.cloud.triggers.bigquery.BigQueryInsertJobTrigger", { - "dag_id": self.dag_id, - "task_id": self.task_id, - "run_id": self.run_id, "conn_id": self.conn_id, "job_id": self.job_id, "dataset_id": self.dataset_id, @@ -108,25 +90,6 @@ def serialize(self) -> tuple[str, dict[str, Any]]: }, ) - @provide_session - def get_task_instance(self, session: Session) -> TaskInstance: - """ - Get the task instance for the current task. 
- - :param session: Sqlalchemy session - """ - query = session.query(TaskInstance).filter( - TaskInstance.dag_id == self.dag_id, - TaskInstance.task_id == self.task_id, - TaskInstance.run_id == self.run_id, - ) - task_instance = query.one_or_none() - if task_instance is None: - raise AirflowException( - f"TaskInstance {self.dag_id}.{self.task_id} with run_id {self.run_id} not found" - ) - return task_instance - def safe_to_cancel(self) -> bool: """ Whether it is safe to cancel the external job which is being executed by this trigger. @@ -134,8 +97,8 @@ def safe_to_cancel(self) -> bool: This is to avoid the case that `asyncio.CancelledError` is called because the trigger itself is stopped. Because in those cases, we should NOT cancel the external job. """ - task_instance = self.get_task_instance() # type: ignore[call-arg] - return task_instance.state not in { + self.log.info("Checking if it is safe to cancel the job.") + return self.task_instance not in { TaskInstanceState.RUNNING, TaskInstanceState.DEFERRED, } @@ -170,6 +133,7 @@ async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override] except asyncio.CancelledError: self.log.info("Task was killed.") if self.job_id and self.cancel_on_kill and self.safe_to_cancel(): + self.log.info("Cancelling job: %s:%s.%s", self.project_id, self.location, self.job_id) await hook.cancel_job( # type: ignore[union-attr] job_id=self.job_id, project_id=self.project_id, location=self.location ) @@ -191,9 +155,6 @@ def serialize(self) -> tuple[str, dict[str, Any]]: return ( "airflow.providers.google.cloud.triggers.bigquery.BigQueryCheckTrigger", { - "dag_id": self.dag_id, - "task_id": self.task_id, - "run_id": self.run_id, "conn_id": self.conn_id, "job_id": self.job_id, "dataset_id": self.dataset_id, @@ -256,26 +217,17 @@ class BigQueryGetDataTrigger(BigQueryInsertJobTrigger): """ BigQueryGetDataTrigger run on the trigger worker, inherits from BigQueryInsertJobTrigger class. - :param dag_id: The DAG ID. - :param task_id: The task ID. - :param run_id: The run ID. :param as_dict: if True returns the result as a list of dictionaries, otherwise as list of lists (default: False). """ def __init__( self, - dag_id: str, - task_id: str, - run_id: str | None, as_dict: bool = False, selected_fields: str | None = None, **kwargs, ): - super().__init__(dag_id=dag_id, task_id=task_id, run_id=run_id, **kwargs) - self.dag_id = dag_id - self.task_id = task_id - self.run_id = run_id + super().__init__(**kwargs) self.as_dict = as_dict self.selected_fields = selected_fields @@ -284,9 +236,6 @@ def serialize(self) -> tuple[str, dict[str, Any]]: return ( "airflow.providers.google.cloud.triggers.bigquery.BigQueryGetDataTrigger", { - "dag_id": self.dag_id, - "task_id": self.task_id, - "run_id": self.run_id, "conn_id": self.conn_id, "job_id": self.job_id, "dataset_id": self.dataset_id, @@ -342,9 +291,6 @@ class BigQueryIntervalCheckTrigger(BigQueryInsertJobTrigger): """ BigQueryIntervalCheckTrigger run on the trigger worker, inherits from BigQueryInsertJobTrigger class. - :param dag_id: The DAG ID. - :param task_id: The task ID. - :param run_id: The run ID. 
:param conn_id: Reference to google cloud connection id :param first_job_id: The ID of the job 1 performed :param second_job_id: The ID of the job 2 performed @@ -371,9 +317,6 @@ class BigQueryIntervalCheckTrigger(BigQueryInsertJobTrigger): def __init__( self, - dag_id: str, - task_id: str, - run_id: str | None, conn_id: str, first_job_id: str, second_job_id: str, @@ -391,9 +334,6 @@ def __init__( impersonation_chain: str | Sequence[str] | None = None, ): super().__init__( - dag_id=dag_id, - task_id=task_id, - run_id=run_id, conn_id=conn_id, job_id=first_job_id, project_id=project_id, @@ -403,9 +343,6 @@ def __init__( poll_interval=poll_interval, impersonation_chain=impersonation_chain, ) - self.dag_id = dag_id - self.task_id = task_id - self.run_id = run_id self.conn_id = conn_id self.first_job_id = first_job_id self.second_job_id = second_job_id @@ -422,9 +359,6 @@ def serialize(self) -> tuple[str, dict[str, Any]]: return ( "airflow.providers.google.cloud.triggers.bigquery.BigQueryIntervalCheckTrigger", { - "dag_id": self.dag_id, - "task_id": self.task_id, - "run_id": self.run_id, "conn_id": self.conn_id, "first_job_id": self.first_job_id, "second_job_id": self.second_job_id, @@ -521,9 +455,6 @@ class BigQueryValueCheckTrigger(BigQueryInsertJobTrigger): """ BigQueryValueCheckTrigger run on the trigger worker, inherits from BigQueryInsertJobTrigger class. - :param dag_id: The DAG ID. - :param task_id: The task ID. - :param run_id: The run ID. :param conn_id: Reference to google cloud connection id :param sql: the sql to be executed :param pass_value: pass value @@ -546,9 +477,6 @@ class BigQueryValueCheckTrigger(BigQueryInsertJobTrigger): def __init__( self, - dag_id: str, - task_id: str, - run_id: str | None, conn_id: str, sql: str, pass_value: int | float | str, @@ -562,9 +490,6 @@ def __init__( impersonation_chain: str | Sequence[str] | None = None, ): super().__init__( - dag_id=dag_id, - task_id=task_id, - run_id=run_id, conn_id=conn_id, job_id=job_id, project_id=project_id, @@ -574,9 +499,6 @@ def __init__( poll_interval=poll_interval, impersonation_chain=impersonation_chain, ) - self.dag_id = dag_id - self.task_id = task_id - self.run_id = run_id self.sql = sql self.pass_value = pass_value self.tolerance = tolerance @@ -586,9 +508,6 @@ def serialize(self) -> tuple[str, dict[str, Any]]: return ( "airflow.providers.google.cloud.triggers.bigquery.BigQueryValueCheckTrigger", { - "dag_id": self.dag_id, - "task_id": self.task_id, - "run_id": self.run_id, "conn_id": self.conn_id, "pass_value": self.pass_value, "job_id": self.job_id, diff --git a/tests/providers/google/cloud/triggers/test_bigquery.py b/tests/providers/google/cloud/triggers/test_bigquery.py index c8c85056ab665..51fada53d458b 100644 --- a/tests/providers/google/cloud/triggers/test_bigquery.py +++ b/tests/providers/google/cloud/triggers/test_bigquery.py @@ -65,17 +65,11 @@ TEST_HOOK_PARAMS: dict[str, Any] = {} TEST_PARTITION_ID = "1234" TEST_SELECTED_FIELDS = "f0_,f1_" -TEST_DAG_ID = "test_dag_id" -TEST_TASK_ID = "test_task_id" -TEST_RUN_ID = "test_run_id" @pytest.fixture def insert_job_trigger(): return BigQueryInsertJobTrigger( - dag_id=TEST_DAG_ID, - task_id=TEST_TASK_ID, - run_id=TEST_RUN_ID, conn_id=TEST_CONN_ID, job_id=TEST_JOB_ID, project_id=TEST_GCP_PROJECT_ID, @@ -90,9 +84,6 @@ def insert_job_trigger(): @pytest.fixture def get_data_trigger(): return BigQueryGetDataTrigger( - dag_id=TEST_DAG_ID, - task_id=TEST_TASK_ID, - run_id=TEST_RUN_ID, conn_id=TEST_CONN_ID, job_id=TEST_JOB_ID, project_id=TEST_GCP_PROJECT_ID, @@ 
-121,9 +112,6 @@ def table_existence_trigger(): @pytest.fixture def interval_check_trigger(): return BigQueryIntervalCheckTrigger( - dag_id=TEST_DAG_ID, - task_id=TEST_TASK_ID, - run_id=TEST_RUN_ID, conn_id=TEST_CONN_ID, first_job_id=TEST_FIRST_JOB_ID, second_job_id=TEST_SECOND_JOB_ID, @@ -144,9 +132,6 @@ def interval_check_trigger(): @pytest.fixture def check_trigger(): return BigQueryCheckTrigger( - dag_id=TEST_DAG_ID, - task_id=TEST_TASK_ID, - run_id=TEST_RUN_ID, conn_id=TEST_CONN_ID, job_id=TEST_JOB_ID, project_id=TEST_GCP_PROJECT_ID, @@ -161,9 +146,6 @@ def check_trigger(): @pytest.fixture def value_check_trigger(): return BigQueryValueCheckTrigger( - dag_id=TEST_DAG_ID, - task_id=TEST_TASK_ID, - run_id=TEST_RUN_ID, conn_id=TEST_CONN_ID, pass_value=TEST_PASS_VALUE, job_id=TEST_JOB_ID, @@ -185,9 +167,6 @@ def test_serialization(self, insert_job_trigger): classpath, kwargs = insert_job_trigger.serialize() assert classpath == "airflow.providers.google.cloud.triggers.bigquery.BigQueryInsertJobTrigger" assert kwargs == { - "dag_id": TEST_DAG_ID, - "task_id": TEST_TASK_ID, - "run_id": TEST_RUN_ID, "cancel_on_kill": True, "conn_id": TEST_CONN_ID, "job_id": TEST_JOB_ID, @@ -338,9 +317,6 @@ def test_bigquery_get_data_trigger_serialization(self, get_data_trigger): classpath, kwargs = get_data_trigger.serialize() assert classpath == "airflow.providers.google.cloud.triggers.bigquery.BigQueryGetDataTrigger" assert kwargs == { - "dag_id": TEST_DAG_ID, - "task_id": TEST_TASK_ID, - "run_id": TEST_RUN_ID, "as_dict": False, "conn_id": TEST_CONN_ID, "impersonation_chain": TEST_IMPERSONATION_CHAIN, @@ -501,9 +477,6 @@ def test_check_trigger_serialization(self, check_trigger): classpath, kwargs = check_trigger.serialize() assert classpath == "airflow.providers.google.cloud.triggers.bigquery.BigQueryCheckTrigger" assert kwargs == { - "dag_id": TEST_DAG_ID, - "task_id": TEST_TASK_ID, - "run_id": TEST_RUN_ID, "conn_id": TEST_CONN_ID, "impersonation_chain": TEST_IMPERSONATION_CHAIN, "job_id": TEST_JOB_ID, @@ -587,9 +560,6 @@ def test_interval_check_trigger_serialization(self, interval_check_trigger): classpath, kwargs = interval_check_trigger.serialize() assert classpath == "airflow.providers.google.cloud.triggers.bigquery.BigQueryIntervalCheckTrigger" assert kwargs == { - "dag_id": TEST_DAG_ID, - "task_id": TEST_TASK_ID, - "run_id": TEST_RUN_ID, "conn_id": TEST_CONN_ID, "impersonation_chain": TEST_IMPERSONATION_CHAIN, "first_job_id": TEST_FIRST_JOB_ID, @@ -684,9 +654,6 @@ def test_bigquery_value_check_op_trigger_serialization(self, value_check_trigger assert classpath == "airflow.providers.google.cloud.triggers.bigquery.BigQueryValueCheckTrigger" assert kwargs == { - "dag_id": TEST_DAG_ID, - "task_id": TEST_TASK_ID, - "run_id": TEST_RUN_ID, "conn_id": TEST_CONN_ID, "impersonation_chain": TEST_IMPERSONATION_CHAIN, "pass_value": TEST_PASS_VALUE, @@ -762,9 +729,6 @@ async def test_value_check_trigger_exception(self, mock_job_status): mock_job_status.side_effect = Exception("Test exception") trigger = BigQueryValueCheckTrigger( - dag_id=TEST_DAG_ID, - task_id=TEST_TASK_ID, - run_id=TEST_RUN_ID, conn_id=TEST_CONN_ID, sql=TEST_SQL_QUERY, pass_value=TEST_PASS_VALUE, From 856f919dfa9525c44ce758fc89397c5bd7bf6c8c Mon Sep 17 00:00:00 2001 From: Ankit Chaurasia <8670962+sunank200@users.noreply.github.com> Date: Tue, 7 May 2024 16:14:44 +0545 Subject: [PATCH 3/7] Fix PR comment --- airflow/providers/google/cloud/triggers/bigquery.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git 
a/airflow/providers/google/cloud/triggers/bigquery.py b/airflow/providers/google/cloud/triggers/bigquery.py index 7af8ff3b52ed2..83e17e7e13ab4 100644 --- a/airflow/providers/google/cloud/triggers/bigquery.py +++ b/airflow/providers/google/cloud/triggers/bigquery.py @@ -97,7 +97,6 @@ def safe_to_cancel(self) -> bool: This is to avoid the case that `asyncio.CancelledError` is called because the trigger itself is stopped. Because in those cases, we should NOT cancel the external job. """ - self.log.info("Checking if it is safe to cancel the job.") return self.task_instance not in { TaskInstanceState.RUNNING, TaskInstanceState.DEFERRED, @@ -221,12 +220,7 @@ class BigQueryGetDataTrigger(BigQueryInsertJobTrigger): (default: False). """ - def __init__( - self, - as_dict: bool = False, - selected_fields: str | None = None, - **kwargs, - ): + def __init__(self, as_dict: bool = False, selected_fields: str | None = None, **kwargs): super().__init__(**kwargs) self.as_dict = as_dict self.selected_fields = selected_fields From b882a49b8cd538cfd961fa5ef97448226389682c Mon Sep 17 00:00:00 2001 From: Ankit Chaurasia <8670962+sunank200@users.noreply.github.com> Date: Tue, 7 May 2024 16:33:59 +0545 Subject: [PATCH 4/7] Fix the PR comments --- airflow/providers/google/cloud/triggers/bigquery.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/airflow/providers/google/cloud/triggers/bigquery.py b/airflow/providers/google/cloud/triggers/bigquery.py index 83e17e7e13ab4..778cf4d41159a 100644 --- a/airflow/providers/google/cloud/triggers/bigquery.py +++ b/airflow/providers/google/cloud/triggers/bigquery.py @@ -132,7 +132,12 @@ async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override] except asyncio.CancelledError: self.log.info("Task was killed.") if self.job_id and self.cancel_on_kill and self.safe_to_cancel(): - self.log.info("Cancelling job: %s:%s.%s", self.project_id, self.location, self.job_id) + self.log.info( + "Cancelling job. 
Project ID: %s, Location: %s, Job ID: %s", + self.project_id, + self.location, + self.job_id, + ) await hook.cancel_job( # type: ignore[union-attr] job_id=self.job_id, project_id=self.project_id, location=self.location ) From d6332ff05c1a760e85438479a6db073e9f70bd37 Mon Sep 17 00:00:00 2001 From: Ankit Chaurasia <8670962+sunank200@users.noreply.github.com> Date: Tue, 7 May 2024 18:39:07 +0545 Subject: [PATCH 5/7] Use task_instance to fetch the status of task from database --- .../google/cloud/triggers/bigquery.py | 45 ++++++++++++++++--- 1 file changed, 38 insertions(+), 7 deletions(-) diff --git a/airflow/providers/google/cloud/triggers/bigquery.py b/airflow/providers/google/cloud/triggers/bigquery.py index 778cf4d41159a..7dfb13159cfeb 100644 --- a/airflow/providers/google/cloud/triggers/bigquery.py +++ b/airflow/providers/google/cloud/triggers/bigquery.py @@ -17,15 +17,21 @@ from __future__ import annotations import asyncio -from typing import Any, AsyncIterator, Sequence, SupportsAbs +from typing import TYPE_CHECKING, Any, AsyncIterator, Sequence, SupportsAbs from aiohttp import ClientSession from aiohttp.client_exceptions import ClientResponseError +from airflow.exceptions import AirflowException +from airflow.models.taskinstance import TaskInstance from airflow.providers.google.cloud.hooks.bigquery import BigQueryAsyncHook, BigQueryTableAsyncHook from airflow.triggers.base import BaseTrigger, TriggerEvent +from airflow.utils.session import provide_session from airflow.utils.state import TaskInstanceState +if TYPE_CHECKING: + from sqlalchemy.orm.session import Session + class BigQueryInsertJobTrigger(BaseTrigger): """ @@ -90,6 +96,24 @@ def serialize(self) -> tuple[str, dict[str, Any]]: }, ) + @provide_session + def get_task_instance(self, session: Session) -> TaskInstance: + query = session.query(TaskInstance).filter( + TaskInstance.dag_id == self.task_instance.dag_id, + TaskInstance.task_id == self.task_instance.task_id, + TaskInstance.run_id == self.task_instance.run_id, + TaskInstance.map_index == self.task_instance.map_index, + ) + task_instance = query.one_or_none() + if task_instance is None: + raise AirflowException( + "TaskInstance %s, %s with run_id %s not found", + self.task_instance.dag_id, + self.task_instance.task_id, + self.task_instance.run_id, + ) + return task_instance + def safe_to_cancel(self) -> bool: """ Whether it is safe to cancel the external job which is being executed by this trigger. @@ -97,10 +121,9 @@ def safe_to_cancel(self) -> bool: This is to avoid the case that `asyncio.CancelledError` is called because the trigger itself is stopped. Because in those cases, we should NOT cancel the external job. """ - return self.task_instance not in { - TaskInstanceState.RUNNING, - TaskInstanceState.DEFERRED, - } + # Database query is needed to get the latest state of the task instance. + task_instance = self.get_task_instance() # type: ignore[call-arg] + return task_instance.state != TaskInstanceState.DEFERRED async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override] """Get current job execution status and yields a TriggerEvent.""" @@ -130,8 +153,10 @@ async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override] ) await asyncio.sleep(self.poll_interval) except asyncio.CancelledError: - self.log.info("Task was killed.") if self.job_id and self.cancel_on_kill and self.safe_to_cancel(): + self.log.info( + "The job is safe to cancel the as airflow TaskInstance is not in deferred state." + ) self.log.info( "Cancelling job. 
Project ID: %s, Location: %s, Job ID: %s", self.project_id, @@ -142,7 +167,13 @@ async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override] job_id=self.job_id, project_id=self.project_id, location=self.location ) else: - self.log.info("Skipping to cancel job: %s:%s.%s", self.project_id, self.location, self.job_id) + self.log.info( + "Trigger may have shutdown. Skipping to cancel job because the airflow " + "task is not cancelled yet: Project ID: %s, Location:%s, Job ID:%s", + self.project_id, + self.location, + self.job_id, + ) except Exception as e: self.log.exception("Exception occurred while checking for query completion") yield TriggerEvent({"status": "error", "message": str(e)}) From 1026cdea16c330a5139e563d232ab37de64c067f Mon Sep 17 00:00:00 2001 From: Ankit Chaurasia <8670962+sunank200@users.noreply.github.com> Date: Tue, 7 May 2024 18:44:40 +0545 Subject: [PATCH 6/7] Fix the test --- tests/providers/google/cloud/triggers/test_bigquery.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/providers/google/cloud/triggers/test_bigquery.py b/tests/providers/google/cloud/triggers/test_bigquery.py index 51fada53d458b..bbb1a50356882 100644 --- a/tests/providers/google/cloud/triggers/test_bigquery.py +++ b/tests/providers/google/cloud/triggers/test_bigquery.py @@ -303,7 +303,6 @@ async def test_bigquery_insert_job_trigger_cancellation_unsafe_cancellation( except asyncio.CancelledError: pass - assert "Task was killed" in caplog.text, "Expected message about task status not found in log." assert ( "Skipping to cancel job" in caplog.text ), "Expected message about skipping cancellation not found in log." From 770010aab78cdf2bb5c791b848bc8b087a12f36e Mon Sep 17 00:00:00 2001 From: Ankit Chaurasia <8670962+sunank200@users.noreply.github.com> Date: Tue, 7 May 2024 19:43:04 +0545 Subject: [PATCH 7/7] Fix the line comment in PR --- airflow/providers/google/cloud/triggers/bigquery.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/airflow/providers/google/cloud/triggers/bigquery.py b/airflow/providers/google/cloud/triggers/bigquery.py index 7dfb13159cfeb..fc19db988126f 100644 --- a/airflow/providers/google/cloud/triggers/bigquery.py +++ b/airflow/providers/google/cloud/triggers/bigquery.py @@ -107,10 +107,11 @@ def get_task_instance(self, session: Session) -> TaskInstance: task_instance = query.one_or_none() if task_instance is None: raise AirflowException( - "TaskInstance %s, %s with run_id %s not found", + "TaskInstance with dag_id: %s, task_id: %s, run_id: %s and map_index: %s is not found", self.task_instance.dag_id, self.task_instance.task_id, self.task_instance.run_id, + self.task_instance.map_index, ) return task_instance