Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 40 additions & 1 deletion airflow/providers/google/cloud/triggers/dataproc.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,41 @@ def serialize(self):
},
)

@provide_session
def get_task_instance(self, session: Session) -> TaskInstance:
"""
Get the task instance for the current task.

:param session: Sqlalchemy session
"""
query = session.query(TaskInstance).filter(
TaskInstance.dag_id == self.task_instance.dag_id,
TaskInstance.task_id == self.task_instance.task_id,
TaskInstance.run_id == self.task_instance.run_id,
TaskInstance.map_index == self.task_instance.map_index,
)
task_instance = query.one_or_none()
if task_instance is None:
raise AirflowException(
"TaskInstance with dag_id: %s,task_id: %s, run_id: %s and map_index: %s is not found",
self.task_instance.dag_id,
self.task_instance.task_id,
self.task_instance.run_id,
self.task_instance.map_index,
)
return task_instance

def safe_to_cancel(self) -> bool:
"""
Whether it is safe to cancel the external job which is being executed by this trigger.

This is to avoid the case that `asyncio.CancelledError` is called because the trigger itself is stopped.
Because in those cases, we should NOT cancel the external job.
"""
# Database query is needed to get the latest state of the task instance.
task_instance = self.get_task_instance() # type: ignore[call-arg]
return task_instance.state != TaskInstanceState.DEFERRED

async def run(self):
try:
while True:
Expand All @@ -131,7 +166,11 @@ async def run(self):
except asyncio.CancelledError:
self.log.info("Task got cancelled.")
try:
if self.job_id and self.cancel_on_kill:
if self.job_id and self.cancel_on_kill and self.safe_to_cancel():
self.log.info(
"Cancelling the job as it is safe to do so. Note that the airflow TaskInstance is not"
" in deferred state."
)
self.log.info("Cancelling the job: %s", self.job_id)
# The synchronous hook is utilized to delete the cluster when a task is cancelled. This
# is because the asynchronous hook deletion is not awaited when the trigger task is
Expand Down
8 changes: 6 additions & 2 deletions tests/providers/google/cloud/triggers/test_dataproc.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ def submit_trigger():
region=TEST_REGION,
gcp_conn_id=TEST_GCP_CONN_ID,
polling_interval_seconds=TEST_POLL_INTERVAL,
cancel_on_kill=True,
)


Expand Down Expand Up @@ -569,12 +570,15 @@ async def test_submit_trigger_run_error(self, mock_get_async_hook, submit_trigge
assert event.payload == expected_event.payload

@pytest.mark.asyncio
@pytest.mark.parametrize("is_safe_to_cancel", [True, False])
@mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitTrigger.get_async_hook")
@mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitTrigger.get_sync_hook")
@mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitTrigger.safe_to_cancel")
async def test_submit_trigger_run_cancelled(
self, mock_get_sync_hook, mock_get_async_hook, submit_trigger
self, mock_safe_to_cancel, mock_get_sync_hook, mock_get_async_hook, submit_trigger, is_safe_to_cancel
):
"""Test the trigger correctly handles an asyncio.CancelledError."""
mock_safe_to_cancel.return_value = is_safe_to_cancel
mock_async_hook = mock_get_async_hook.return_value
mock_async_hook.get_job.side_effect = asyncio.CancelledError

Expand All @@ -598,7 +602,7 @@ async def test_submit_trigger_run_cancelled(
pytest.fail(f"Unexpected exception raised: {e}")

# Check if cancel_job was correctly called
if submit_trigger.cancel_on_kill:
if submit_trigger.cancel_on_kill and is_safe_to_cancel:
mock_sync_hook.cancel_job.assert_called_once_with(
job_id=submit_trigger.job_id,
project_id=submit_trigger.project_id,
Expand Down