diff --git a/airflow/providers/cncf/kubernetes/hooks/kubernetes.py b/airflow/providers/cncf/kubernetes/hooks/kubernetes.py index f2e3d4d3fe5d1..aa784b2293641 100644 --- a/airflow/providers/cncf/kubernetes/hooks/kubernetes.py +++ b/airflow/providers/cncf/kubernetes/hooks/kubernetes.py @@ -85,6 +85,9 @@ def get_connection_form_widgets() -> Dict[str, Any]: "extra__kubernetes__namespace": StringField( lazy_gettext('Namespace'), widget=BS3TextFieldWidget() ), + "extra__kubernetes__cluster_context": StringField( + lazy_gettext('Cluster context'), widget=BS3TextFieldWidget() + ), } @staticmethod @@ -96,25 +99,49 @@ def get_ui_field_behaviour() -> Dict: } def __init__( - self, conn_id: str = default_conn_name, client_configuration: Optional[client.Configuration] = None + self, + conn_id: Optional[str] = default_conn_name, + client_configuration: Optional[client.Configuration] = None, + cluster_context: Optional[str] = None, + config_file: Optional[str] = None, + in_cluster: Optional[bool] = None, ) -> None: super().__init__() self.conn_id = conn_id self.client_configuration = client_configuration + self.cluster_context = cluster_context + self.config_file = config_file + self.in_cluster = in_cluster + + @staticmethod + def _coalesce_param(*params): + for param in params: + if param is not None: + return param def get_conn(self) -> Any: """Returns kubernetes api session for use with requests""" - connection = self.get_connection(self.conn_id) - extras = connection.extra_dejson - in_cluster = extras.get("extra__kubernetes__in_cluster") or None - kubeconfig_path = extras.get("extra__kubernetes__kube_config_path") or None + if self.conn_id: + connection = self.get_connection(self.conn_id) + extras = connection.extra_dejson + else: + extras = {} + in_cluster = self._coalesce_param( + self.in_cluster, extras.get("extra__kubernetes__in_cluster") or None + ) + cluster_context = self._coalesce_param( + self.cluster_context, extras.get("extra__kubernetes__cluster_context") or None + ) + kubeconfig_path = self._coalesce_param( + self.config_file, extras.get("extra__kubernetes__kube_config_path") or None + ) kubeconfig = extras.get("extra__kubernetes__kube_config") or None num_selected_configuration = len([o for o in [in_cluster, kubeconfig, kubeconfig_path] if o]) if num_selected_configuration > 1: raise AirflowException( - "Invalid connection configuration. Options extra__kubernetes__kube_config_path, " - "extra__kubernetes__kube_config, extra__kubernetes__in_cluster are mutually exclusive. " + "Invalid connection configuration. Options kube_config_path, " + "kube_config, in_cluster are mutually exclusive. " "You can only use one option at a time." ) if in_cluster: @@ -125,7 +152,9 @@ def get_conn(self) -> Any: if kubeconfig_path is not None: self.log.debug("loading kube_config from: %s", kubeconfig_path) config.load_kube_config( - config_file=kubeconfig_path, client_configuration=self.client_configuration + config_file=kubeconfig_path, + client_configuration=self.client_configuration, + context=cluster_context, ) return client.ApiClient() @@ -135,12 +164,17 @@ def get_conn(self) -> Any: temp_config.write(kubeconfig.encode()) temp_config.flush() config.load_kube_config( - config_file=temp_config.name, client_configuration=self.client_configuration + config_file=temp_config.name, + client_configuration=self.client_configuration, + context=cluster_context, ) return client.ApiClient() self.log.debug("loading kube_config from: default file") - config.load_kube_config(client_configuration=self.client_configuration) + config.load_kube_config( + client_configuration=self.client_configuration, + context=cluster_context, + ) return client.ApiClient() @cached_property @@ -148,6 +182,10 @@ def api_client(self) -> Any: """Cached Kubernetes API client""" return self.get_conn() + @cached_property + def core_v1_client(self): + return client.CoreV1Api(api_client=self.api_client) + def create_custom_object( self, group: str, version: str, plural: str, body: Union[str, dict], namespace: Optional[str] = None ): @@ -207,12 +245,13 @@ def get_custom_object( except client.rest.ApiException as e: raise AirflowException(f"Exception when calling -> get_custom_object: {e}\n") - def get_namespace(self) -> str: + def get_namespace(self) -> Optional[str]: """Returns the namespace that defined in the connection""" - connection = self.get_connection(self.conn_id) - extras = connection.extra_dejson - namespace = extras.get("extra__kubernetes__namespace", "default") - return namespace + if self.conn_id: + connection = self.get_connection(self.conn_id) + extras = connection.extra_dejson + namespace = extras.get("extra__kubernetes__namespace", "default") + return namespace def get_pod_log_stream( self, diff --git a/tests/providers/cncf/kubernetes/hooks/test_kubernetes.py b/tests/providers/cncf/kubernetes/hooks/test_kubernetes.py index 31e81d8029609..256974ebd40d8 100644 --- a/tests/providers/cncf/kubernetes/hooks/test_kubernetes.py +++ b/tests/providers/cncf/kubernetes/hooks/test_kubernetes.py @@ -20,13 +20,11 @@ import json import os import tempfile -import unittest from unittest import mock from unittest.mock import patch import kubernetes import pytest -from parameterized import parameterized from airflow import AirflowException from airflow.models import Connection @@ -37,107 +35,188 @@ KUBE_CONFIG_PATH = os.getenv('KUBECONFIG', '~/.kube/config') -class TestKubernetesHook(unittest.TestCase): +class TestKubernetesHook: @classmethod - def setUpClass(cls) -> None: + def setup_class(cls) -> None: for conn_id, extra in [ - ('kubernetes_in_cluster', {'extra__kubernetes__in_cluster': True}), - ('kubernetes_kube_config', {'extra__kubernetes__kube_config': '{"test": "kube"}'}), - ('kubernetes_kube_config_path', {'extra__kubernetes__kube_config_path': 'path/to/file'}), - ('kubernetes_in_cluster_empty', {'extra__kubernetes__in_cluster': ''}), - ('kubernetes_kube_config_empty', {'extra__kubernetes__kube_config': ''}), - ('kubernetes_kube_config_path_empty', {'extra__kubernetes__kube_config_path': ''}), - ('kubernetes_with_namespace', {'extra__kubernetes__namespace': 'mock_namespace'}), - ('kubernetes_default_kube_config', {}), + ('in_cluster', {'extra__kubernetes__in_cluster': True}), + ('kube_config', {'extra__kubernetes__kube_config': '{"test": "kube"}'}), + ('kube_config_path', {'extra__kubernetes__kube_config_path': 'path/to/file'}), + ('in_cluster_empty', {'extra__kubernetes__in_cluster': ''}), + ('kube_config_empty', {'extra__kubernetes__kube_config': ''}), + ('kube_config_path_empty', {'extra__kubernetes__kube_config_path': ''}), + ('kube_config_empty', {'extra__kubernetes__kube_config': ''}), + ('kube_config_path_empty', {'extra__kubernetes__kube_config_path': ''}), + ('context_empty', {'extra__kubernetes__cluster_context': ''}), + ('context', {'extra__kubernetes__cluster_context': 'my-context'}), + ('with_namespace', {'extra__kubernetes__namespace': 'mock_namespace'}), + ('default_kube_config', {}), ]: db.merge_conn(Connection(conn_type='kubernetes', conn_id=conn_id, extra=json.dumps(extra))) @classmethod - def tearDownClass(cls) -> None: + def teardown_class(cls) -> None: clear_db_connections() - @patch("kubernetes.config.incluster_config.InClusterConfigLoader") - def test_in_cluster_connection(self, mock_kube_config_loader): - kubernetes_hook = KubernetesHook(conn_id='kubernetes_in_cluster') - api_conn = kubernetes_hook.get_conn() - mock_kube_config_loader.assert_called_once() - assert isinstance(api_conn, kubernetes.client.api_client.ApiClient) - - @patch("kubernetes.config.kube_config.KubeConfigMerger") - @patch("kubernetes.config.kube_config.KubeConfigLoader") - def test_in_cluster_connection_empty(self, mock_kube_config_merger, mock_kube_config_loader): - kubernetes_hook = KubernetesHook(conn_id='kubernetes_in_cluster_empty') - api_conn = kubernetes_hook.get_conn() - mock_kube_config_loader.assert_called_once_with(KUBE_CONFIG_PATH) - mock_kube_config_merger.assert_called_once() - assert isinstance(api_conn, kubernetes.client.api_client.ApiClient) - + @pytest.mark.parametrize( + 'in_cluster_param, conn_id, in_cluster_called', + ( + (True, None, True), + (None, None, False), + (False, None, False), + (None, 'in_cluster', True), + (True, 'in_cluster', True), + (False, 'in_cluster', False), + (None, 'in_cluster_empty', False), + (True, 'in_cluster_empty', True), + (False, 'in_cluster_empty', False), + ), + ) @patch("kubernetes.config.kube_config.KubeConfigLoader") @patch("kubernetes.config.kube_config.KubeConfigMerger") - def test_kube_config_path(self, mock_kube_config_loader, mock_kube_config_merger): - kubernetes_hook = KubernetesHook(conn_id='kubernetes_kube_config_path') + @patch("kubernetes.config.incluster_config.InClusterConfigLoader") + def test_in_cluster_connection( + self, + mock_in_cluster_loader, + mock_merger, + mock_loader, + in_cluster_param, + conn_id, + in_cluster_called, + ): + """ + Verifies whether in_cluster is called depending on combination of hook param and connection extra. + Hook param should beat extra. + """ + kubernetes_hook = KubernetesHook(conn_id=conn_id, in_cluster=in_cluster_param) api_conn = kubernetes_hook.get_conn() - mock_kube_config_loader.assert_called_once_with("path/to/file") - mock_kube_config_merger.assert_called_once() + if in_cluster_called: + mock_in_cluster_loader.assert_called_once() + mock_merger.assert_not_called() + mock_loader.assert_not_called() + else: + mock_in_cluster_loader.assert_not_called() + mock_merger.assert_called_once_with(KUBE_CONFIG_PATH) + mock_loader.assert_called_once() assert isinstance(api_conn, kubernetes.client.api_client.ApiClient) + @pytest.mark.parametrize( + 'config_path_param, conn_id, call_path', + ( + (None, None, KUBE_CONFIG_PATH), + ('/my/path/override', None, '/my/path/override'), + (None, 'kube_config_path', 'path/to/file'), + ('/my/path/override', 'kube_config_path', '/my/path/override'), + (None, 'kube_config_path_empty', KUBE_CONFIG_PATH), + ('/my/path/override', 'kube_config_path_empty', '/my/path/override'), + ), + ) @patch("kubernetes.config.kube_config.KubeConfigLoader") @patch("kubernetes.config.kube_config.KubeConfigMerger") - def test_kube_config_path_empty(self, mock_kube_config_loader, mock_kube_config_merger): - kubernetes_hook = KubernetesHook(conn_id='kubernetes_kube_config_path_empty') + def test_kube_config_path( + self, mock_kube_config_merger, mock_kube_config_loader, config_path_param, conn_id, call_path + ): + """ + Verifies kube config path depending on combination of hook param and connection extra. + Hook param should beat extra. + """ + kubernetes_hook = KubernetesHook(conn_id=conn_id, config_file=config_path_param) api_conn = kubernetes_hook.get_conn() - mock_kube_config_loader.assert_called_once_with(KUBE_CONFIG_PATH) - mock_kube_config_merger.assert_called_once() + mock_kube_config_merger.assert_called_once_with(call_path) + mock_kube_config_loader.assert_called_once() assert isinstance(api_conn, kubernetes.client.api_client.ApiClient) + @pytest.mark.parametrize( + 'conn_id, has_config', + ( + (None, False), + ('kube_config', True), + ('kube_config_empty', False), + ), + ) @patch("kubernetes.config.kube_config.KubeConfigLoader") @patch("kubernetes.config.kube_config.KubeConfigMerger") @patch.object(tempfile, 'NamedTemporaryFile') - def test_kube_config_connection(self, mock_kube_config_loader, mock_kube_config_merger, mock_tempfile): - kubernetes_hook = KubernetesHook(conn_id='kubernetes_kube_config') + def test_kube_config_connection( + self, mock_tempfile, mock_kube_config_merger, mock_kube_config_loader, conn_id, has_config + ): + """ + Verifies whether temporary kube config file is created. + """ + mock_tempfile.return_value.__enter__.return_value.name = "fake-temp-file" + mock_kube_config_merger.return_value.config = {"fake_config": "value"} + kubernetes_hook = KubernetesHook(conn_id=conn_id) api_conn = kubernetes_hook.get_conn() - mock_tempfile.is_called_once() - mock_kube_config_loader.assert_called_once() - mock_kube_config_merger.assert_called_once() + if has_config: + mock_tempfile.is_called_once() + mock_kube_config_loader.assert_called_once() + mock_kube_config_merger.assert_called_once_with('fake-temp-file') + else: + mock_tempfile.assert_not_called() + mock_kube_config_loader.assert_called_once() + mock_kube_config_merger.assert_called_once_with(KUBE_CONFIG_PATH) assert isinstance(api_conn, kubernetes.client.api_client.ApiClient) - @patch("kubernetes.config.kube_config.KubeConfigLoader") - @patch("kubernetes.config.kube_config.KubeConfigMerger") - def test_kube_config_connection_empty(self, mock_kube_config_loader, mock_kube_config_merger): - kubernetes_hook = KubernetesHook(conn_id='kubernetes_kube_config_empty') - api_conn = kubernetes_hook.get_conn() - mock_kube_config_loader.assert_called_once_with(KUBE_CONFIG_PATH) - mock_kube_config_merger.assert_called_once() - assert isinstance(api_conn, kubernetes.client.api_client.ApiClient) + @pytest.mark.parametrize( + 'context_param, conn_id, expected_context', + ( + ('param-context', None, 'param-context'), + (None, None, None), + ('param-context', 'context', 'param-context'), + (None, 'context', 'my-context'), + ('param-context', 'context_empty', 'param-context'), + (None, 'context_empty', None), + ), + ) + @patch("kubernetes.config.load_kube_config") + def test_cluster_context(self, mock_load_kube_config, context_param, conn_id, expected_context): + """ + Verifies cluster context depending on combination of hook param and connection extra. + Hook param should beat extra. + """ + kubernetes_hook = KubernetesHook(conn_id=conn_id, cluster_context=context_param) + kubernetes_hook.get_conn() + mock_load_kube_config.assert_called_with(client_configuration=None, context=expected_context) @patch("kubernetes.config.kube_config.KubeConfigLoader") @patch("kubernetes.config.kube_config.KubeConfigMerger") @patch("kubernetes.config.kube_config.KUBE_CONFIG_DEFAULT_LOCATION", "/mock/config") - def test_default_kube_config_connection( - self, - mock_kube_config_loader, - mock_kube_config_merger, - ): - kubernetes_hook = KubernetesHook(conn_id='kubernetes_default_kube_config') + def test_default_kube_config_connection(self, mock_kube_config_merger, mock_kube_config_loader): + kubernetes_hook = KubernetesHook(conn_id='default_kube_config') api_conn = kubernetes_hook.get_conn() - mock_kube_config_loader.assert_called_once_with("/mock/config") - mock_kube_config_merger.assert_called_once() + mock_kube_config_merger.assert_called_once_with("/mock/config") + mock_kube_config_loader.assert_called_once() assert isinstance(api_conn, kubernetes.client.api_client.ApiClient) - def test_get_namespace(self): - kubernetes_hook_with_namespace = KubernetesHook(conn_id='kubernetes_with_namespace') - kubernetes_hook_without_namespace = KubernetesHook(conn_id='kubernetes_default_kube_config') - assert kubernetes_hook_with_namespace.get_namespace() == 'mock_namespace' - assert kubernetes_hook_without_namespace.get_namespace() == 'default' + @pytest.mark.parametrize( + 'conn_id, expected', + ( + pytest.param(None, None, id='no-conn-id'), + pytest.param('with_namespace', 'mock_namespace', id='conn-with-namespace'), + pytest.param('default_kube_config', 'default', id='conn-without-namespace'), + ), + ) + def test_get_namespace(self, conn_id, expected): + hook = KubernetesHook(conn_id=conn_id) + assert hook.get_namespace() == expected + + @patch("kubernetes.config.kube_config.KubeConfigLoader") + @patch("kubernetes.config.kube_config.KubeConfigMerger") + def test_client_types(self, mock_kube_config_merger, mock_kube_config_loader): + hook = KubernetesHook(None) + assert isinstance(hook.core_v1_client, kubernetes.client.CoreV1Api) + assert isinstance(hook.api_client, kubernetes.client.ApiClient) + assert isinstance(hook.get_conn(), kubernetes.client.ApiClient) -class TestKubernetesHookIncorrectConfiguration(unittest.TestCase): - @parameterized.expand( +class TestKubernetesHookIncorrectConfiguration: + @pytest.mark.parametrize( + 'conn_uri', ( "kubernetes://?extra__kubernetes__kube_config_path=/tmp/&extra__kubernetes__kube_config=[1,2,3]", "kubernetes://?extra__kubernetes__kube_config_path=/tmp/&extra__kubernetes__in_cluster=[1,2,3]", "kubernetes://?extra__kubernetes__kube_config=/tmp/&extra__kubernetes__in_cluster=[1,2,3]", - ) + ), ) def test_should_raise_exception_on_invalid_configuration(self, conn_uri): with mock.patch.dict("os.environ", AIRFLOW_CONN_KUBERNETES_DEFAULT=conn_uri), pytest.raises(