From ecdba0447682e196020896bcf93a357672f75e97 Mon Sep 17 00:00:00 2001
From: Phani Kumar
Date: Fri, 28 Apr 2023 15:19:09 +0530
Subject: [PATCH] optimize deferred execution for GCSObjectsWithPrefixExistenceSensor

---
 airflow/providers/google/cloud/sensors/gcs.py | 27 ++++++++++---------
 .../google/cloud/sensors/test_gcs.py          | 22 ++++++++++++---
 2 files changed, 33 insertions(+), 16 deletions(-)

diff --git a/airflow/providers/google/cloud/sensors/gcs.py b/airflow/providers/google/cloud/sensors/gcs.py
index d346c5303b53b..de2a433fbf1c9 100644
--- a/airflow/providers/google/cloud/sensors/gcs.py
+++ b/airflow/providers/google/cloud/sensors/gcs.py
@@ -323,19 +323,20 @@ def execute(self, context: Context):
             super().execute(context)
             return self._matches
         else:
-            self.defer(
-                timeout=timedelta(seconds=self.timeout),
-                trigger=GCSPrefixBlobTrigger(
-                    bucket=self.bucket,
-                    prefix=self.prefix,
-                    poke_interval=self.poke_interval,
-                    google_cloud_conn_id=self.google_cloud_conn_id,
-                    hook_params={
-                        "impersonation_chain": self.impersonation_chain,
-                    },
-                ),
-                method_name="execute_complete",
-            )
+            if not self.poke(context=context):
+                self.defer(
+                    timeout=timedelta(seconds=self.timeout),
+                    trigger=GCSPrefixBlobTrigger(
+                        bucket=self.bucket,
+                        prefix=self.prefix,
+                        poke_interval=self.poke_interval,
+                        google_cloud_conn_id=self.google_cloud_conn_id,
+                        hook_params={
+                            "impersonation_chain": self.impersonation_chain,
+                        },
+                    ),
+                    method_name="execute_complete",
+                )
 
     def execute_complete(self, context: dict[str, Any], event: dict[str, str | list[str]]) -> str | list[str]:
         """
diff --git a/tests/providers/google/cloud/sensors/test_gcs.py b/tests/providers/google/cloud/sensors/test_gcs.py
index b23a28e238ecb..0299898b661a3 100644
--- a/tests/providers/google/cloud/sensors/test_gcs.py
+++ b/tests/providers/google/cloud/sensors/test_gcs.py
@@ -369,6 +369,21 @@ def test_execute_timeout(self, mock_hook):
             task.execute(mock.MagicMock)
             mock_hook.return_value.list.assert_called_once_with(TEST_BUCKET, prefix=TEST_PREFIX)
 
+    @mock.patch("airflow.providers.google.cloud.sensors.gcs.GCSHook")
+    @mock.patch("airflow.providers.google.cloud.sensors.gcs.GCSObjectsWithPrefixExistenceSensor.defer")
+    def test_gcs_object_prefix_existence_sensor_finish_before_deferred(self, mock_defer, mock_hook):
+        task = GCSObjectsWithPrefixExistenceSensor(
+            task_id="task-id",
+            bucket=TEST_BUCKET,
+            prefix=TEST_PREFIX,
+            google_cloud_conn_id=TEST_GCP_CONN_ID,
+            impersonation_chain=TEST_IMPERSONATION_CHAIN,
+            deferrable=True,
+        )
+        mock_hook.return_value.list.return_value = True
+        task.execute(mock.MagicMock())
+        assert not mock_defer.called
+
 
 class TestGCSObjectsWithPrefixExistenceSensorAsync:
     OPERATOR = GCSObjectsWithPrefixExistenceSensor(
@@ -379,14 +394,15 @@ class TestGCSObjectsWithPrefixExistenceSensorAsync:
         deferrable=True,
     )
 
-    def test_gcs_object_with_prefix_existence_sensor_async(self, context):
+    @mock.patch("airflow.providers.google.cloud.sensors.gcs.GCSHook")
+    def test_gcs_object_with_prefix_existence_sensor_async(self, mock_hook):
         """
         Asserts that a task is deferred and a GCSPrefixBlobTrigger will be fired
         when the GCSObjectsWithPrefixExistenceSensorAsync is executed.
         """
-
+        mock_hook.return_value.list.return_value = False
         with pytest.raises(TaskDeferred) as exc:
-            self.OPERATOR.execute(context)
+            self.OPERATOR.execute(mock.MagicMock())
         assert isinstance(exc.value.trigger, GCSPrefixBlobTrigger), "Trigger is not a GCSPrefixBlobTrigger"
 
     def test_gcs_object_with_prefix_existence_sensor_async_execute_failure(self, context):