diff --git a/airflow/providers/cncf/kubernetes/utils/pod_manager.py b/airflow/providers/cncf/kubernetes/utils/pod_manager.py
index 3bacb95f4ff76..a531316666ca5 100644
--- a/airflow/providers/cncf/kubernetes/utils/pod_manager.py
+++ b/airflow/providers/cncf/kubernetes/utils/pod_manager.py
@@ -544,6 +544,19 @@ def await_xcom_sidecar_container_start(self, pod: V1Pod) -> None:
 
     def extract_xcom(self, pod: V1Pod) -> str:
         """Retrieves XCom value and kills xcom sidecar container."""
+        try:
+            result = self.extract_xcom_json(pod)
+            return result
+        finally:
+            self.extract_xcom_kill(pod)
+
+    @tenacity.retry(
+        stop=tenacity.stop_after_attempt(5),
+        wait=tenacity.wait_exponential(multiplier=1, min=4, max=10),
+        reraise=True,
+    )
+    def extract_xcom_json(self, pod: V1Pod) -> str:
+        """Retrieves XCom value and also checks if xcom json is valid."""
         with closing(
             kubernetes_stream(
                 self._client.connect_get_namespaced_pod_exec,
@@ -562,11 +575,38 @@ def extract_xcom(self, pod: V1Pod) -> str:
                 resp,
                 f"if [ -s {PodDefaults.XCOM_MOUNT_PATH}/return.json ]; then cat {PodDefaults.XCOM_MOUNT_PATH}/return.json; else echo __airflow_xcom_result_empty__; fi",  # noqa
             )
-            self._exec_pod_command(resp, "kill -s SIGINT 1")
+            if result and result.rstrip() != "__airflow_xcom_result_empty__":
+                # Note: result string is parsed to check if it is valid JSON.
+                # This function still returns a string which is converted into json in the calling method.
+                json.loads(result)
+
         if result is None:
             raise AirflowException(f"Failed to extract xcom from pod: {pod.metadata.name}")
         return result
 
+    @tenacity.retry(
+        stop=tenacity.stop_after_attempt(5),
+        wait=tenacity.wait_exponential(multiplier=1, min=4, max=10),
+        reraise=True,
+    )
+    def extract_xcom_kill(self, pod: V1Pod) -> None:
+        """Kills xcom sidecar container."""
+        with closing(
+            kubernetes_stream(
+                self._client.connect_get_namespaced_pod_exec,
+                pod.metadata.name,
+                pod.metadata.namespace,
+                container=PodDefaults.SIDECAR_CONTAINER_NAME,
+                command=["/bin/sh"],
+                stdin=True,
+                stdout=True,
+                stderr=True,
+                tty=False,
+                _preload_content=False,
+            )
+        ) as resp:
+            self._exec_pod_command(resp, "kill -s SIGINT 1")
+
     def _exec_pod_command(self, resp, command: str) -> str | None:
         res = None
         if resp.is_open():
diff --git a/tests/providers/cncf/kubernetes/utils/test_pod_manager.py b/tests/providers/cncf/kubernetes/utils/test_pod_manager.py
index a55be38a5e227..8f28d33dfdea5 100644
--- a/tests/providers/cncf/kubernetes/utils/test_pod_manager.py
+++ b/tests/providers/cncf/kubernetes/utils/test_pod_manager.py
@@ -18,6 +18,7 @@
 
 import logging
 from datetime import datetime
+from json.decoder import JSONDecodeError
 from unittest import mock
 from unittest.mock import MagicMock
 
@@ -370,6 +371,53 @@ def test_container_is_terminated_with_waiting_state(self, container_state, expec
         pod_info.status.container_statuses = [container_status]
         assert container_is_terminated(pod_info, "base") == expected_is_terminated
 
+    @mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.kubernetes_stream")
+    @mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager._exec_pod_command")
+    @mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.extract_xcom_kill")
+    def test_extract_xcom_success(self, mock_exec_xcom_kill, mock_exec_pod_command, mock_kubernetes_stream):
+        """test when valid json is retrieved from xcom sidecar container."""
+        xcom_json = """{"a": "true"}"""
+        mock_pod = MagicMock()
+        mock_exec_pod_command.return_value = xcom_json
+        ret = self.pod_manager.extract_xcom(pod=mock_pod)
+        assert ret == xcom_json
+        assert mock_exec_xcom_kill.call_count == 1
+
+    @mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.kubernetes_stream")
+    @mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager._exec_pod_command")
+    @mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.extract_xcom_kill")
+    def test_extract_xcom_failure(self, mock_exec_xcom_kill, mock_exec_pod_command, mock_kubernetes_stream):
+        """test when invalid json is retrieved from xcom sidecar container."""
+        with pytest.raises(JSONDecodeError):
+            xcom_json = """{"a": "tru"""
+            mock_pod = MagicMock()
+            mock_exec_pod_command.return_value = xcom_json
+            self.pod_manager.extract_xcom(pod=mock_pod)
+        assert mock_exec_xcom_kill.call_count == 1
+
+    @mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.kubernetes_stream")
+    @mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager._exec_pod_command")
+    @mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.extract_xcom_kill")
+    def test_extract_xcom_empty(self, mock_exec_xcom_kill, mock_exec_pod_command, mock_kubernetes_stream):
+        """test when __airflow_xcom_result_empty__ is retrieved from xcom sidecar container."""
+        mock_pod = MagicMock()
+        xcom_result = "__airflow_xcom_result_empty__"
+        mock_exec_pod_command.return_value = xcom_result
+        ret = self.pod_manager.extract_xcom(pod=mock_pod)
+        assert ret == xcom_result
+        assert mock_exec_xcom_kill.call_count == 1
+
+    @mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.kubernetes_stream")
+    @mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager._exec_pod_command")
+    @mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.extract_xcom_kill")
+    def test_extract_xcom_none(self, mock_exec_xcom_kill, mock_exec_pod_command, mock_kubernetes_stream):
+        """test when None is retrieved from xcom sidecar container."""
+        with pytest.raises(AirflowException):
+            mock_pod = MagicMock()
+            mock_exec_pod_command.return_value = None
+            self.pod_manager.extract_xcom(pod=mock_pod)
+        assert mock_exec_xcom_kill.call_count == 1
+
 
 def params_for_test_container_is_running():
     """The `container_is_running` method is designed to handle an assortment of bad objects