Skip to content
69 changes: 54 additions & 15 deletions airflow/providers/cncf/kubernetes/hooks/kubernetes.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,9 @@ def get_connection_form_widgets() -> Dict[str, Any]:
"extra__kubernetes__namespace": StringField(
lazy_gettext('Namespace'), widget=BS3TextFieldWidget()
),
"extra__kubernetes__cluster_context": StringField(
lazy_gettext('Cluster context'), widget=BS3TextFieldWidget()
),
}

@staticmethod
Expand All @@ -96,25 +99,49 @@ def get_ui_field_behaviour() -> Dict:
}

def __init__(
self, conn_id: str = default_conn_name, client_configuration: Optional[client.Configuration] = None
self,
conn_id: Optional[str] = default_conn_name,
client_configuration: Optional[client.Configuration] = None,
cluster_context: Optional[str] = None,
config_file: Optional[str] = None,
in_cluster: Optional[bool] = None,
) -> None:
super().__init__()
self.conn_id = conn_id
self.client_configuration = client_configuration
self.cluster_context = cluster_context
self.config_file = config_file
self.in_cluster = in_cluster

@staticmethod
def _coalesce_param(*params):
for param in params:
if param is not None:
return param

def get_conn(self) -> Any:
"""Returns kubernetes api session for use with requests"""
connection = self.get_connection(self.conn_id)
extras = connection.extra_dejson
in_cluster = extras.get("extra__kubernetes__in_cluster") or None
kubeconfig_path = extras.get("extra__kubernetes__kube_config_path") or None
if self.conn_id:
connection = self.get_connection(self.conn_id)
extras = connection.extra_dejson
else:
extras = {}
in_cluster = self._coalesce_param(
self.in_cluster, extras.get("extra__kubernetes__in_cluster") or None
)
cluster_context = self._coalesce_param(
self.cluster_context, extras.get("extra__kubernetes__cluster_context") or None
)
kubeconfig_path = self._coalesce_param(
self.config_file, extras.get("extra__kubernetes__kube_config_path") or None
)
kubeconfig = extras.get("extra__kubernetes__kube_config") or None
num_selected_configuration = len([o for o in [in_cluster, kubeconfig, kubeconfig_path] if o])

if num_selected_configuration > 1:
raise AirflowException(
"Invalid connection configuration. Options extra__kubernetes__kube_config_path, "
"extra__kubernetes__kube_config, extra__kubernetes__in_cluster are mutually exclusive. "
"Invalid connection configuration. Options kube_config_path, "
"kube_config, in_cluster are mutually exclusive. "
"You can only use one option at a time."
)
if in_cluster:
Expand All @@ -125,7 +152,9 @@ def get_conn(self) -> Any:
if kubeconfig_path is not None:
self.log.debug("loading kube_config from: %s", kubeconfig_path)
config.load_kube_config(
config_file=kubeconfig_path, client_configuration=self.client_configuration
config_file=kubeconfig_path,
client_configuration=self.client_configuration,
context=cluster_context,
)
return client.ApiClient()

Expand All @@ -135,19 +164,28 @@ def get_conn(self) -> Any:
temp_config.write(kubeconfig.encode())
temp_config.flush()
config.load_kube_config(
config_file=temp_config.name, client_configuration=self.client_configuration
config_file=temp_config.name,
client_configuration=self.client_configuration,
context=cluster_context,
)
return client.ApiClient()

self.log.debug("loading kube_config from: default file")
config.load_kube_config(client_configuration=self.client_configuration)
config.load_kube_config(
client_configuration=self.client_configuration,
context=cluster_context,
)
return client.ApiClient()

@cached_property
def api_client(self) -> Any:
"""Cached Kubernetes API client"""
return self.get_conn()

@cached_property
def core_v1_client(self):
return client.CoreV1Api(api_client=self.api_client)

def create_custom_object(
self, group: str, version: str, plural: str, body: Union[str, dict], namespace: Optional[str] = None
):
Expand Down Expand Up @@ -207,12 +245,13 @@ def get_custom_object(
except client.rest.ApiException as e:
raise AirflowException(f"Exception when calling -> get_custom_object: {e}\n")

def get_namespace(self) -> str:
def get_namespace(self) -> Optional[str]:
"""Returns the namespace that defined in the connection"""
connection = self.get_connection(self.conn_id)
extras = connection.extra_dejson
namespace = extras.get("extra__kubernetes__namespace", "default")
return namespace
if self.conn_id:
connection = self.get_connection(self.conn_id)
extras = connection.extra_dejson
namespace = extras.get("extra__kubernetes__namespace", "default")
return namespace

def get_pod_log_stream(
self,
Expand Down
211 changes: 145 additions & 66 deletions tests/providers/cncf/kubernetes/hooks/test_kubernetes.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,11 @@
import json
import os
import tempfile
import unittest
from unittest import mock
from unittest.mock import patch

import kubernetes
import pytest
from parameterized import parameterized

from airflow import AirflowException
from airflow.models import Connection
Expand All @@ -37,107 +35,188 @@
KUBE_CONFIG_PATH = os.getenv('KUBECONFIG', '~/.kube/config')


class TestKubernetesHook(unittest.TestCase):
class TestKubernetesHook:
@classmethod
def setUpClass(cls) -> None:
def setup_class(cls) -> None:
for conn_id, extra in [
('kubernetes_in_cluster', {'extra__kubernetes__in_cluster': True}),
('kubernetes_kube_config', {'extra__kubernetes__kube_config': '{"test": "kube"}'}),
('kubernetes_kube_config_path', {'extra__kubernetes__kube_config_path': 'path/to/file'}),
('kubernetes_in_cluster_empty', {'extra__kubernetes__in_cluster': ''}),
('kubernetes_kube_config_empty', {'extra__kubernetes__kube_config': ''}),
('kubernetes_kube_config_path_empty', {'extra__kubernetes__kube_config_path': ''}),
('kubernetes_with_namespace', {'extra__kubernetes__namespace': 'mock_namespace'}),
('kubernetes_default_kube_config', {}),
('in_cluster', {'extra__kubernetes__in_cluster': True}),
('kube_config', {'extra__kubernetes__kube_config': '{"test": "kube"}'}),
('kube_config_path', {'extra__kubernetes__kube_config_path': 'path/to/file'}),
('in_cluster_empty', {'extra__kubernetes__in_cluster': ''}),
('kube_config_empty', {'extra__kubernetes__kube_config': ''}),
('kube_config_path_empty', {'extra__kubernetes__kube_config_path': ''}),
('kube_config_empty', {'extra__kubernetes__kube_config': ''}),
('kube_config_path_empty', {'extra__kubernetes__kube_config_path': ''}),
('context_empty', {'extra__kubernetes__cluster_context': ''}),
('context', {'extra__kubernetes__cluster_context': 'my-context'}),
('with_namespace', {'extra__kubernetes__namespace': 'mock_namespace'}),
('default_kube_config', {}),
]:
db.merge_conn(Connection(conn_type='kubernetes', conn_id=conn_id, extra=json.dumps(extra)))

@classmethod
def tearDownClass(cls) -> None:
def teardown_class(cls) -> None:
clear_db_connections()

@patch("kubernetes.config.incluster_config.InClusterConfigLoader")
def test_in_cluster_connection(self, mock_kube_config_loader):
kubernetes_hook = KubernetesHook(conn_id='kubernetes_in_cluster')
api_conn = kubernetes_hook.get_conn()
mock_kube_config_loader.assert_called_once()
assert isinstance(api_conn, kubernetes.client.api_client.ApiClient)

@patch("kubernetes.config.kube_config.KubeConfigMerger")
@patch("kubernetes.config.kube_config.KubeConfigLoader")
def test_in_cluster_connection_empty(self, mock_kube_config_merger, mock_kube_config_loader):
kubernetes_hook = KubernetesHook(conn_id='kubernetes_in_cluster_empty')
api_conn = kubernetes_hook.get_conn()
mock_kube_config_loader.assert_called_once_with(KUBE_CONFIG_PATH)
mock_kube_config_merger.assert_called_once()
assert isinstance(api_conn, kubernetes.client.api_client.ApiClient)

@pytest.mark.parametrize(
'in_cluster_param, conn_id, in_cluster_called',
(
(True, None, True),
(None, None, False),
(False, None, False),
(None, 'in_cluster', True),
(True, 'in_cluster', True),
(False, 'in_cluster', False),
(None, 'in_cluster_empty', False),
(True, 'in_cluster_empty', True),
(False, 'in_cluster_empty', False),
),
)
@patch("kubernetes.config.kube_config.KubeConfigLoader")
@patch("kubernetes.config.kube_config.KubeConfigMerger")
def test_kube_config_path(self, mock_kube_config_loader, mock_kube_config_merger):
kubernetes_hook = KubernetesHook(conn_id='kubernetes_kube_config_path')
@patch("kubernetes.config.incluster_config.InClusterConfigLoader")
def test_in_cluster_connection(
self,
mock_in_cluster_loader,
mock_merger,
mock_loader,
in_cluster_param,
conn_id,
in_cluster_called,
):
"""
Verifies whether in_cluster is called depending on combination of hook param and connection extra.
Hook param should beat extra.
"""
kubernetes_hook = KubernetesHook(conn_id=conn_id, in_cluster=in_cluster_param)
api_conn = kubernetes_hook.get_conn()
mock_kube_config_loader.assert_called_once_with("path/to/file")
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems that the merger and loader mock arguments were swapped in some of these tests.

mock_kube_config_merger.assert_called_once()
if in_cluster_called:
mock_in_cluster_loader.assert_called_once()
mock_merger.assert_not_called()
mock_loader.assert_not_called()
else:
mock_in_cluster_loader.assert_not_called()
mock_merger.assert_called_once_with(KUBE_CONFIG_PATH)
mock_loader.assert_called_once()
assert isinstance(api_conn, kubernetes.client.api_client.ApiClient)

@pytest.mark.parametrize(
'config_path_param, conn_id, call_path',
(
(None, None, KUBE_CONFIG_PATH),
('/my/path/override', None, '/my/path/override'),
(None, 'kube_config_path', 'path/to/file'),
('/my/path/override', 'kube_config_path', '/my/path/override'),
(None, 'kube_config_path_empty', KUBE_CONFIG_PATH),
('/my/path/override', 'kube_config_path_empty', '/my/path/override'),
),
)
@patch("kubernetes.config.kube_config.KubeConfigLoader")
@patch("kubernetes.config.kube_config.KubeConfigMerger")
def test_kube_config_path_empty(self, mock_kube_config_loader, mock_kube_config_merger):
kubernetes_hook = KubernetesHook(conn_id='kubernetes_kube_config_path_empty')
def test_kube_config_path(
self, mock_kube_config_merger, mock_kube_config_loader, config_path_param, conn_id, call_path
):
"""
Verifies kube config path depending on combination of hook param and connection extra.
Hook param should beat extra.
"""
kubernetes_hook = KubernetesHook(conn_id=conn_id, config_file=config_path_param)
api_conn = kubernetes_hook.get_conn()
mock_kube_config_loader.assert_called_once_with(KUBE_CONFIG_PATH)
mock_kube_config_merger.assert_called_once()
mock_kube_config_merger.assert_called_once_with(call_path)
mock_kube_config_loader.assert_called_once()
assert isinstance(api_conn, kubernetes.client.api_client.ApiClient)

@pytest.mark.parametrize(
'conn_id, has_config',
(
(None, False),
('kube_config', True),
('kube_config_empty', False),
),
)
@patch("kubernetes.config.kube_config.KubeConfigLoader")
@patch("kubernetes.config.kube_config.KubeConfigMerger")
@patch.object(tempfile, 'NamedTemporaryFile')
def test_kube_config_connection(self, mock_kube_config_loader, mock_kube_config_merger, mock_tempfile):
kubernetes_hook = KubernetesHook(conn_id='kubernetes_kube_config')
def test_kube_config_connection(
self, mock_tempfile, mock_kube_config_merger, mock_kube_config_loader, conn_id, has_config
):
"""
Verifies whether temporary kube config file is created.
"""
mock_tempfile.return_value.__enter__.return_value.name = "fake-temp-file"
mock_kube_config_merger.return_value.config = {"fake_config": "value"}
kubernetes_hook = KubernetesHook(conn_id=conn_id)
api_conn = kubernetes_hook.get_conn()
mock_tempfile.is_called_once()
mock_kube_config_loader.assert_called_once()
mock_kube_config_merger.assert_called_once()
if has_config:
mock_tempfile.is_called_once()
mock_kube_config_loader.assert_called_once()
mock_kube_config_merger.assert_called_once_with('fake-temp-file')
else:
mock_tempfile.assert_not_called()
mock_kube_config_loader.assert_called_once()
mock_kube_config_merger.assert_called_once_with(KUBE_CONFIG_PATH)
assert isinstance(api_conn, kubernetes.client.api_client.ApiClient)

@patch("kubernetes.config.kube_config.KubeConfigLoader")
@patch("kubernetes.config.kube_config.KubeConfigMerger")
def test_kube_config_connection_empty(self, mock_kube_config_loader, mock_kube_config_merger):
kubernetes_hook = KubernetesHook(conn_id='kubernetes_kube_config_empty')
api_conn = kubernetes_hook.get_conn()
mock_kube_config_loader.assert_called_once_with(KUBE_CONFIG_PATH)
mock_kube_config_merger.assert_called_once()
assert isinstance(api_conn, kubernetes.client.api_client.ApiClient)
@pytest.mark.parametrize(
'context_param, conn_id, expected_context',
(
('param-context', None, 'param-context'),
(None, None, None),
('param-context', 'context', 'param-context'),
(None, 'context', 'my-context'),
('param-context', 'context_empty', 'param-context'),
(None, 'context_empty', None),
),
)
@patch("kubernetes.config.load_kube_config")
def test_cluster_context(self, mock_load_kube_config, context_param, conn_id, expected_context):
"""
Verifies cluster context depending on combination of hook param and connection extra.
Hook param should beat extra.
"""
kubernetes_hook = KubernetesHook(conn_id=conn_id, cluster_context=context_param)
kubernetes_hook.get_conn()
mock_load_kube_config.assert_called_with(client_configuration=None, context=expected_context)

@patch("kubernetes.config.kube_config.KubeConfigLoader")
@patch("kubernetes.config.kube_config.KubeConfigMerger")
@patch("kubernetes.config.kube_config.KUBE_CONFIG_DEFAULT_LOCATION", "/mock/config")
def test_default_kube_config_connection(
self,
mock_kube_config_loader,
mock_kube_config_merger,
):
kubernetes_hook = KubernetesHook(conn_id='kubernetes_default_kube_config')
def test_default_kube_config_connection(self, mock_kube_config_merger, mock_kube_config_loader):
kubernetes_hook = KubernetesHook(conn_id='default_kube_config')
api_conn = kubernetes_hook.get_conn()
mock_kube_config_loader.assert_called_once_with("/mock/config")
mock_kube_config_merger.assert_called_once()
mock_kube_config_merger.assert_called_once_with("/mock/config")
mock_kube_config_loader.assert_called_once()
assert isinstance(api_conn, kubernetes.client.api_client.ApiClient)

def test_get_namespace(self):
kubernetes_hook_with_namespace = KubernetesHook(conn_id='kubernetes_with_namespace')
kubernetes_hook_without_namespace = KubernetesHook(conn_id='kubernetes_default_kube_config')
assert kubernetes_hook_with_namespace.get_namespace() == 'mock_namespace'
assert kubernetes_hook_without_namespace.get_namespace() == 'default'
@pytest.mark.parametrize(
'conn_id, expected',
(
pytest.param(None, None, id='no-conn-id'),
pytest.param('with_namespace', 'mock_namespace', id='conn-with-namespace'),
pytest.param('default_kube_config', 'default', id='conn-without-namespace'),
),
)
def test_get_namespace(self, conn_id, expected):
hook = KubernetesHook(conn_id=conn_id)
assert hook.get_namespace() == expected

@patch("kubernetes.config.kube_config.KubeConfigLoader")
@patch("kubernetes.config.kube_config.KubeConfigMerger")
def test_client_types(self, mock_kube_config_merger, mock_kube_config_loader):
hook = KubernetesHook(None)
assert isinstance(hook.core_v1_client, kubernetes.client.CoreV1Api)
assert isinstance(hook.api_client, kubernetes.client.ApiClient)
assert isinstance(hook.get_conn(), kubernetes.client.ApiClient)


class TestKubernetesHookIncorrectConfiguration(unittest.TestCase):
@parameterized.expand(
class TestKubernetesHookIncorrectConfiguration:
@pytest.mark.parametrize(
'conn_uri',
(
"kubernetes://?extra__kubernetes__kube_config_path=/tmp/&extra__kubernetes__kube_config=[1,2,3]",
"kubernetes://?extra__kubernetes__kube_config_path=/tmp/&extra__kubernetes__in_cluster=[1,2,3]",
"kubernetes://?extra__kubernetes__kube_config=/tmp/&extra__kubernetes__in_cluster=[1,2,3]",
)
),
)
def test_should_raise_exception_on_invalid_configuration(self, conn_uri):
with mock.patch.dict("os.environ", AIRFLOW_CONN_KUBERNETES_DEFAULT=conn_uri), pytest.raises(
Expand Down