diff --git a/airflow/providers/cncf/kubernetes/triggers/pod.py b/airflow/providers/cncf/kubernetes/triggers/pod.py index e2b5edd8e4583..6cfb6c523a31d 100644 --- a/airflow/providers/cncf/kubernetes/triggers/pod.py +++ b/airflow/providers/cncf/kubernetes/triggers/pod.py @@ -134,15 +134,23 @@ async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override] self.log.debug("Container %s status: %s", self.base_container_name, container_state) if container_state == ContainerState.TERMINATED: - yield TriggerEvent( - { - "name": self.pod_name, - "namespace": self.pod_namespace, - "status": "success", - "message": "All containers inside pod have started successfully.", - } - ) - return + if pod_status not in PodPhase.terminal_states: + self.log.info( + "Pod %s is still running. Sleeping for %s seconds.", + self.pod_name, + self.poll_interval, + ) + await asyncio.sleep(self.poll_interval) + else: + yield TriggerEvent( + { + "name": self.pod_name, + "namespace": self.pod_namespace, + "status": "success", + "message": "All containers inside pod have started successfully.", + } + ) + return elif self.should_wait(pod_phase=pod_status, container_state=container_state): self.log.info("Container is not completed and still working.") diff --git a/tests/providers/cncf/kubernetes/triggers/test_pod.py b/tests/providers/cncf/kubernetes/triggers/test_pod.py index 28159ca5ff9f4..6d5d18d028c18 100644 --- a/tests/providers/cncf/kubernetes/triggers/test_pod.py +++ b/tests/providers/cncf/kubernetes/triggers/test_pod.py @@ -95,7 +95,8 @@ def test_serialize(self, trigger): @mock.patch(f"{TRIGGER_PATH}.define_container_state") @mock.patch(f"{TRIGGER_PATH}._get_async_hook") async def test_run_loop_return_success_event(self, mock_hook, mock_method, trigger): - mock_hook.return_value.get_pod.return_value = self._mock_pod_result(mock.MagicMock()) + pod_mock = mock.MagicMock(**{"status.phase": "Succeeded"}) + mock_hook.return_value.get_pod.return_value = self._mock_pod_result(pod_mock) mock_method.return_value = ContainerState.TERMINATED expected_event = TriggerEvent( @@ -110,6 +111,35 @@ async def test_run_loop_return_success_event(self, mock_hook, mock_method, trigg assert actual_event == expected_event + @pytest.mark.asyncio + @mock.patch(f"{TRIGGER_PATH}.define_container_state") + @mock.patch(f"{TRIGGER_PATH}._get_async_hook") + async def test_run_loop_wait_pod_termination_before_returning_success_event( + self, mock_hook, mock_method, trigger + ): + running_state = mock.MagicMock(**{"status.phase": "Running"}) + succeeded_state = mock.MagicMock(**{"status.phase": "Succeeded"}) + mock_hook.return_value.get_pod.side_effect = [ + self._mock_pod_result(running_state), + self._mock_pod_result(running_state), + self._mock_pod_result(succeeded_state), + ] + mock_method.return_value = ContainerState.TERMINATED + + expected_event = TriggerEvent( + { + "name": POD_NAME, + "namespace": NAMESPACE, + "status": "success", + "message": "All containers inside pod have started successfully.", + } + ) + with mock.patch.object(asyncio, "sleep") as mock_sleep: + actual_event = await (trigger.run()).asend(None) + + assert actual_event == expected_event + assert mock_sleep.call_count == 2 + @pytest.mark.asyncio @mock.patch(f"{TRIGGER_PATH}.define_container_state") @mock.patch(f"{TRIGGER_PATH}._get_async_hook") diff --git a/tests/providers/google/cloud/triggers/test_kubernetes_engine.py b/tests/providers/google/cloud/triggers/test_kubernetes_engine.py index 3be05c3b1b66b..e957767e3eeaf 100644 --- a/tests/providers/google/cloud/triggers/test_kubernetes_engine.py +++ b/tests/providers/google/cloud/triggers/test_kubernetes_engine.py @@ -108,7 +108,13 @@ def test_serialize_should_execute_successfully(self, trigger): async def test_run_loop_return_success_event_should_execute_successfully( self, mock_hook, mock_method, trigger ): - mock_hook.return_value.get_pod.return_value = self._mock_pod_result(mock.MagicMock()) + running_state = mock.MagicMock(**{"status.phase": "Running"}) + succeeded_state = mock.MagicMock(**{"status.phase": "Succeeded"}) + mock_hook.return_value.get_pod.side_effect = [ + self._mock_pod_result(running_state), + self._mock_pod_result(running_state), + self._mock_pod_result(succeeded_state), + ] mock_method.return_value = ContainerState.TERMINATED expected_event = TriggerEvent( @@ -119,9 +125,11 @@ async def test_run_loop_return_success_event_should_execute_successfully( "message": "All containers inside pod have started successfully.", } ) - actual_event = await (trigger.run()).asend(None) + with mock.patch.object(asyncio, "sleep") as mock_sleep: + actual_event = await (trigger.run()).asend(None) assert actual_event == expected_event + assert mock_sleep.call_count == 2 @pytest.mark.asyncio @mock.patch(f"{TRIGGER_KUB_PATH}.define_container_state")