diff --git a/airflow/providers/cncf/kubernetes/hooks/kubernetes.py b/airflow/providers/cncf/kubernetes/hooks/kubernetes.py index f2866e5d714b6..91c00f483c1e8 100644 --- a/airflow/providers/cncf/kubernetes/hooks/kubernetes.py +++ b/airflow/providers/cncf/kubernetes/hooks/kubernetes.py @@ -17,6 +17,7 @@ from __future__ import annotations import contextlib +import json import tempfile import warnings from typing import TYPE_CHECKING, Any, Generator @@ -101,6 +102,9 @@ def get_connection_form_widgets() -> dict[str, Any]: "xcom_sidecar_container_image": StringField( lazy_gettext("XCom sidecar image"), widget=BS3TextFieldWidget() ), + "xcom_sidecar_container_resources": StringField( + lazy_gettext("XCom sidecar resources (JSON format)"), widget=BS3TextFieldWidget() + ), } @staticmethod @@ -366,6 +370,13 @@ def get_xcom_sidecar_container_image(self): """Returns the xcom sidecar image that defined in the connection""" return self._get_field("xcom_sidecar_container_image") + def get_xcom_sidecar_container_resources(self): + """Returns the xcom sidecar resources that defined in the connection""" + field = self._get_field("xcom_sidecar_container_resources") + if not field: + return None + return json.loads(field) + def get_pod_log_stream( self, pod_name: str, diff --git a/airflow/providers/cncf/kubernetes/operators/pod.py b/airflow/providers/cncf/kubernetes/operators/pod.py index 5267d1aeb6374..a1ab51878f506 100644 --- a/airflow/providers/cncf/kubernetes/operators/pod.py +++ b/airflow/providers/cncf/kubernetes/operators/pod.py @@ -810,7 +810,9 @@ def build_pod_request_obj(self, context: Context | None = None) -> k8s.V1Pod: if self.do_xcom_push: self.log.debug("Adding xcom sidecar to task %s", self.task_id) pod = xcom_sidecar.add_xcom_sidecar( - pod, sidecar_container_image=self.hook.get_xcom_sidecar_container_image() + pod, + sidecar_container_image=self.hook.get_xcom_sidecar_container_image(), + sidecar_container_resources=self.hook.get_xcom_sidecar_container_resources(), ) labels = self._get_ti_pod_labels(context) diff --git a/airflow/providers/cncf/kubernetes/utils/xcom_sidecar.py b/airflow/providers/cncf/kubernetes/utils/xcom_sidecar.py index 81b3047993691..462e6870b69ca 100644 --- a/airflow/providers/cncf/kubernetes/utils/xcom_sidecar.py +++ b/airflow/providers/cncf/kubernetes/utils/xcom_sidecar.py @@ -42,12 +42,18 @@ class PodDefaults: resources=k8s.V1ResourceRequirements( requests={ "cpu": "1m", - } + "memory": "10Mi", + }, ), ) -def add_xcom_sidecar(pod: k8s.V1Pod, *, sidecar_container_image=None) -> k8s.V1Pod: +def add_xcom_sidecar( + pod: k8s.V1Pod, + *, + sidecar_container_image: str | None = None, + sidecar_container_resources: k8s.V1ResourceRequirements | dict | None = None, +) -> k8s.V1Pod: """Adds sidecar""" pod_cp = copy.deepcopy(pod) pod_cp.spec.volumes = pod.spec.volumes or [] @@ -56,6 +62,8 @@ def add_xcom_sidecar(pod: k8s.V1Pod, *, sidecar_container_image=None) -> k8s.V1P pod_cp.spec.containers[0].volume_mounts.insert(0, PodDefaults.VOLUME_MOUNT) sidecar = copy.deepcopy(PodDefaults.SIDECAR_CONTAINER) sidecar.image = sidecar_container_image or PodDefaults.SIDECAR_CONTAINER.image + if sidecar_container_resources: + sidecar.resources = sidecar_container_resources pod_cp.spec.containers.append(sidecar) return pod_cp diff --git a/kubernetes_tests/test_kubernetes_pod_operator.py b/kubernetes_tests/test_kubernetes_pod_operator.py index 63443bb15336c..bc7e356975a9f 100644 --- a/kubernetes_tests/test_kubernetes_pod_operator.py +++ b/kubernetes_tests/test_kubernetes_pod_operator.py @@ -699,18 +699,18 @@ def test_pod_template_file_system(self): assert result == {"hello": "world"} @pytest.mark.parametrize( - "input", + "env_vars", [ param([k8s.V1EnvVar(name="env_name", value="value")], id="current"), param({"env_name": "value"}, id="backcompat"), # todo: remove? ], ) - def test_pod_template_file_with_overrides_system(self, input, test_label): + def test_pod_template_file_with_overrides_system(self, env_vars, test_label): fixture = sys.path[0] + "/tests/kubernetes/basic_pod.yaml" k = KubernetesPodOperator( task_id=str(uuid4()), labels=self.labels, - env_vars=[k8s.V1EnvVar(name="env_name", value="value")], + env_vars=env_vars, in_cluster=False, pod_template_file=fixture, do_xcom_push=True, @@ -890,6 +890,7 @@ def test_pod_template_file( await_xcom_sidecar_container_start_mock.return_value = None hook_mock.return_value.is_in_cluster = False hook_mock.return_value.get_xcom_sidecar_container_image.return_value = None + hook_mock.return_value.get_xcom_sidecar_container_resources.return_value = None extract_xcom_mock.return_value = "{}" path = sys.path[0] + "/tests/kubernetes/pod.yaml" k = KubernetesPodOperator( @@ -956,7 +957,9 @@ def test_pod_template_file( "command": ["sh", "-c", 'trap "exit 0" INT; while true; do sleep 1; done;'], "image": "alpine", "name": "airflow-xcom-sidecar", - "resources": {"requests": {"cpu": "1m"}}, + "resources": { + "requests": {"cpu": "1m", "memory": "10Mi"}, + }, "volumeMounts": [{"mountPath": "/airflow/xcom", "name": "xcom"}], }, ], diff --git a/tests/providers/cncf/kubernetes/decorators/test_kubernetes.py b/tests/providers/cncf/kubernetes/decorators/test_kubernetes.py index 853d056e2362f..48f042526ded9 100644 --- a/tests/providers/cncf/kubernetes/decorators/test_kubernetes.py +++ b/tests/providers/cncf/kubernetes/decorators/test_kubernetes.py @@ -128,6 +128,10 @@ def f(arg1, arg2, kwarg1=None, kwarg2=None): (ti,) = dr.task_instances mock_hook.return_value.get_xcom_sidecar_container_image.return_value = XCOM_IMAGE + mock_hook.return_value.get_xcom_sidecar_container_resources.return_value = { + "requests": {"cpu": "1m", "memory": "10Mi"}, + "limits": {"cpu": "1m", "memory": "50Mi"}, + } dag.get_task("my_task_id").execute(context=ti.get_template_context(session=session)) @@ -139,6 +143,7 @@ def f(arg1, arg2, kwarg1=None, kwarg2=None): ) assert mock_create_pod.call_count == 1 assert mock_hook.return_value.get_xcom_sidecar_container_image.call_count == 1 + assert mock_hook.return_value.get_xcom_sidecar_container_resources.call_count == 1 containers = mock_create_pod.call_args[1]["pod"].spec.containers diff --git a/tests/providers/cncf/kubernetes/hooks/test_kubernetes_pod.py b/tests/providers/cncf/kubernetes/hooks/test_kubernetes_pod.py index 204060449768d..9ccca08abcad5 100644 --- a/tests/providers/cncf/kubernetes/hooks/test_kubernetes_pod.py +++ b/tests/providers/cncf/kubernetes/hooks/test_kubernetes_pod.py @@ -76,6 +76,18 @@ def setup_class(cls) -> None: ("disable_tcp_keepalive_empty", {"disable_tcp_keepalive": ""}), ("sidecar_container_image", {"xcom_sidecar_container_image": "private.repo.com/alpine:3.16"}), ("sidecar_container_image_empty", {"xcom_sidecar_container_image": ""}), + ( + "sidecar_container_resources", + { + "xcom_sidecar_container_resources": json.dumps( + { + "requests": {"cpu": "1m", "memory": "10Mi"}, + "limits": {"cpu": "1m", "memory": "50Mi"}, + } + ), + }, + ), + ("sidecar_container_resources_empty", {"xcom_sidecar_container_resources": ""}), ]: db.merge_conn(Connection(conn_type="kubernetes", conn_id=conn_id, extra=json.dumps(extra))) @@ -341,6 +353,27 @@ def test_get_xcom_sidecar_container_image(self, conn_id, expected): hook = KubernetesHook(conn_id=conn_id) assert hook.get_xcom_sidecar_container_image() == expected + @pytest.mark.parametrize( + "conn_id, expected", + ( + pytest.param( + "sidecar_container_resources", + { + "requests": {"cpu": "1m", "memory": "10Mi"}, + "limits": { + "cpu": "1m", + "memory": "50Mi", + }, + }, + id="sidecar-with-resources", + ), + pytest.param("sidecar_container_resources_empty", None, id="sidecar-without-resources"), + ), + ) + def test_get_xcom_sidecar_container_resources(self, conn_id, expected): + hook = KubernetesHook(conn_id=conn_id) + assert hook.get_xcom_sidecar_container_resources() == expected + @patch("kubernetes.config.kube_config.KubeConfigLoader") @patch("kubernetes.config.kube_config.KubeConfigMerger") def test_client_types(self, mock_kube_config_merger, mock_kube_config_loader): diff --git a/tests/providers/cncf/kubernetes/operators/test_pod.py b/tests/providers/cncf/kubernetes/operators/test_pod.py index 53457ae374fa9..c542ccbdb130a 100644 --- a/tests/providers/cncf/kubernetes/operators/test_pod.py +++ b/tests/providers/cncf/kubernetes/operators/test_pod.py @@ -513,6 +513,39 @@ def test_xcom_sidecar_container_image_custom(self, hook_mock): pod = k.build_pod_request_obj(create_context(k)) assert pod.spec.containers[1].image == "private.repo/alpine:3.13" + @patch(HOOK_CLASS) + def test_xcom_sidecar_container_resources_default(self, hook_mock): + hook_mock.return_value.get_xcom_sidecar_container_resources.return_value = None + k = KubernetesPodOperator( + name="test", + task_id="task", + do_xcom_push=True, + ) + pod = k.build_pod_request_obj(create_context(k)) + assert pod.spec.containers[1].resources == k8s.V1ResourceRequirements( + requests={ + "cpu": "1m", + "memory": "10Mi", + }, + ) + + @patch(HOOK_CLASS) + def test_xcom_sidecar_container_resources_custom(self, hook_mock): + hook_mock.return_value.get_xcom_sidecar_container_resources.return_value = { + "requests": {"cpu": "1m", "memory": "10Mi"}, + "limits": {"cpu": "10m", "memory": "50Mi"}, + } + k = KubernetesPodOperator( + name="test", + task_id="task", + do_xcom_push=True, + ) + pod = k.build_pod_request_obj(create_context(k)) + assert pod.spec.containers[1].resources == { + "requests": {"cpu": "1m", "memory": "10Mi"}, + "limits": {"cpu": "10m", "memory": "50Mi"}, + } + def test_image_pull_policy_correctly_set(self): k = KubernetesPodOperator( task_id="task", @@ -1264,6 +1297,23 @@ def test_async_xcom_sidecar_container_image_default_should_execute_successfully( pod = k.build_pod_request_obj(create_context(k)) assert pod.spec.containers[1].image == "alpine" + @patch(HOOK_CLASS) + def test_async_xcom_sidecar_container_resources_default_should_execute_successfully(self, hook_mock): + hook_mock.return_value.get_xcom_sidecar_container_resources.return_value = None + k = KubernetesPodOperator( + name=TEST_NAME, + task_id="task", + do_xcom_push=True, + deferrable=True, + ) + pod = k.build_pod_request_obj(create_context(k)) + assert pod.spec.containers[1].resources == k8s.V1ResourceRequirements( + requests={ + "cpu": "1m", + "memory": "10Mi", + }, + ) + @pytest.mark.parametrize("do_xcom_push", [True, False]) @patch(KUB_OP_PATH.format("post_complete_action")) @patch(KUB_OP_PATH.format("extract_xcom"))