From f2e650c554b8ef15e88f9dabe618852659436a11 Mon Sep 17 00:00:00 2001 From: James Nowell Date: Mon, 5 Jun 2023 11:14:44 -0500 Subject: [PATCH] Revert "enhance spark_k8s_operator (#29977)" This reverts commit 9a4f6748521c9c3b66d96598036be08fd94ccf89. Based on the discussion found [here](https://github.com/apache/airflow/issues/31183), previous changes to the Spark K8s Operator broke existing functionality and did not update the documentation for the newly enabled functionality. The Spark Sensor no longer works, XCOM no longer works on the Operator itself, and the Operator does not fail when the Spark job fails. Rather than attempt to fix or resolve the current implementation, I am reverting to the existing, documented implementation. I would propose creating a _new_ Operator with alternative functionality (one which does not need a Sensor, copies logs, etc.) if that is desired. --- .../cncf/kubernetes/hooks/kubernetes.py | 73 ++- .../kubernetes/operators/spark_kubernetes.py | 65 +-- .../flink/operators/test_flink_kubernetes.py | 63 ++- .../cncf/kubernetes/hooks/test_kubernetes.py | 25 - .../operators/test_spark_kubernetes.py | 483 +++++++++++++++--- 5 files changed, 504 insertions(+), 205 deletions(-) diff --git a/airflow/providers/cncf/kubernetes/hooks/kubernetes.py b/airflow/providers/cncf/kubernetes/hooks/kubernetes.py index a3851a8987c93..ba0b9021c4698 100644 --- a/airflow/providers/cncf/kubernetes/hooks/kubernetes.py +++ b/airflow/providers/cncf/kubernetes/hooks/kubernetes.py @@ -38,7 +38,7 @@ LOADING_KUBE_CONFIG_FILE_RESOURCE = "Loading Kubernetes configuration file kube_config from {}..." -def _load_body_to_dict(body: str) -> dict: +def _load_body_to_dict(body): try: body_dict = yaml.safe_load(body) except yaml.YAMLError as e: @@ -287,22 +287,37 @@ def create_custom_object( :param namespace: kubernetes namespace """ api: client.CustomObjectsApi = self.custom_object_client + namespace = namespace or self.get_namespace() or self.DEFAULT_NAMESPACE if isinstance(body, str): body_dict = _load_body_to_dict(body) else: body_dict = body - response = api.create_namespaced_custom_object( - group=group, - version=version, - namespace=namespace or self.get_namespace() or self.DEFAULT_NAMESPACE, - plural=plural, - body=body_dict, - ) + # Attribute "name" is not mandatory if "generateName" is used instead + if "name" in body_dict["metadata"]: + try: + api.delete_namespaced_custom_object( + group=group, + version=version, + namespace=namespace, + plural=plural, + name=body_dict["metadata"]["name"], + ) + + self.log.warning("Deleted SparkApplication with the same name") + except client.rest.ApiException: + self.log.info("SparkApplication %s not found", body_dict["metadata"]["name"]) - self.log.debug("Response: %s", response) - return response + try: + response = api.create_namespaced_custom_object( + group=group, version=version, namespace=namespace, plural=plural, body=body_dict + ) + + self.log.debug("Response: %s", response) + return response + except client.rest.ApiException as e: + raise AirflowException(f"Exception when calling -> create_custom_object: {e}\n") def get_custom_object( self, group: str, version: str, plural: str, name: str, namespace: str | None = None @@ -317,36 +332,14 @@ def get_custom_object( :param namespace: kubernetes namespace """ api = client.CustomObjectsApi(self.api_client) - response = api.get_namespaced_custom_object( - group=group, - version=version, - namespace=namespace or self.get_namespace() or self.DEFAULT_NAMESPACE, - plural=plural, - name=name, - ) - return 
response - - def delete_custom_object( - self, group: str, version: str, plural: str, name: str, namespace: str | None = None, **kwargs - ): - """ - Delete custom resource definition object from Kubernetes. - - :param group: api group - :param version: api version - :param plural: api plural - :param name: crd object name - :param namespace: kubernetes namespace - """ - api = client.CustomObjectsApi(self.api_client) - return api.delete_namespaced_custom_object( - group=group, - version=version, - namespace=namespace or self.get_namespace() or self.DEFAULT_NAMESPACE, - plural=plural, - name=name, - **kwargs, - ) + namespace = namespace or self.get_namespace() or self.DEFAULT_NAMESPACE + try: + response = api.get_namespaced_custom_object( + group=group, version=version, namespace=namespace, plural=plural, name=name + ) + return response + except client.rest.ApiException as e: + raise AirflowException(f"Exception when calling -> get_custom_object: {e}\n") def get_namespace(self) -> str | None: """Returns the namespace that defined in the connection.""" diff --git a/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py b/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py index 2206e04a735d5..05705fd6774cd 100644 --- a/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py +++ b/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py @@ -19,10 +19,8 @@ from typing import TYPE_CHECKING, Sequence -from kubernetes.watch import Watch - from airflow.models import BaseOperator -from airflow.providers.cncf.kubernetes.hooks.kubernetes import KubernetesHook, _load_body_to_dict +from airflow.providers.cncf.kubernetes.hooks.kubernetes import KubernetesHook if TYPE_CHECKING: from airflow.utils.context import Context @@ -57,71 +55,24 @@ def __init__( kubernetes_conn_id: str = "kubernetes_default", api_group: str = "sparkoperator.k8s.io", api_version: str = "v1beta2", - in_cluster: bool | None = None, - cluster_context: str | None = None, - config_file: str | None = None, **kwargs, ) -> None: super().__init__(**kwargs) + self.application_file = application_file self.namespace = namespace self.kubernetes_conn_id = kubernetes_conn_id self.api_group = api_group self.api_version = api_version self.plural = "sparkapplications" - self.application_file = application_file - self.in_cluster = in_cluster - self.cluster_context = cluster_context - self.config_file = config_file - - self.hook = KubernetesHook( - conn_id=self.kubernetes_conn_id, - in_cluster=self.in_cluster, - config_file=self.config_file, - cluster_context=self.cluster_context, - ) def execute(self, context: Context): - body = _load_body_to_dict(self.application_file) - name = body["metadata"]["name"] - namespace = self.namespace or self.hook.get_namespace() - namespace_event_stream = Watch().stream( - self.hook.core_v1_client.list_namespaced_pod, - namespace=namespace, - _preload_content=False, - watch=True, - label_selector=f"sparkoperator.k8s.io/app-name={name},spark-role=driver", - field_selector="status.phase=Running", - ) - - self.hook.create_custom_object( - group=self.api_group, - version=self.api_version, - plural=self.plural, - body=body, - namespace=namespace, - ) - for event in namespace_event_stream: - if event["type"] == "ADDED": - pod_log_stream = Watch().stream( - self.hook.core_v1_client.read_namespaced_pod_log, - name=f"{name}-driver", - namespace=namespace, - _preload_content=False, - timestamps=True, - ) - for line in pod_log_stream: - self.log.info(line) - else: - break - - def on_kill(self) 
-> None: - body = _load_body_to_dict(self.application_file) - name = body["metadata"]["name"] - namespace = self.namespace or self.hook.get_namespace() - self.hook.delete_custom_object( + hook = KubernetesHook(conn_id=self.kubernetes_conn_id) + self.log.info("Creating sparkApplication") + response = hook.create_custom_object( group=self.api_group, version=self.api_version, plural=self.plural, - namespace=namespace, - name=name, + body=self.application_file, + namespace=self.namespace, ) + return response diff --git a/tests/providers/apache/flink/operators/test_flink_kubernetes.py b/tests/providers/apache/flink/operators/test_flink_kubernetes.py index 3e7cdb01f334e..598c6edd94c42 100644 --- a/tests/providers/apache/flink/operators/test_flink_kubernetes.py +++ b/tests/providers/apache/flink/operators/test_flink_kubernetes.py @@ -197,8 +197,11 @@ def setup_method(self): args = {"owner": "airflow", "start_date": timezone.datetime(2020, 2, 1)} self.dag = DAG("test_dag_id", default_args=args) + @patch("kubernetes.client.api.custom_objects_api.CustomObjectsApi.delete_namespaced_custom_object") @patch("kubernetes.client.api.custom_objects_api.CustomObjectsApi.create_namespaced_custom_object") - def test_create_application_from_yaml(self, mock_create_namespaced_crd, mock_kubernetes_hook): + def test_create_application_from_yaml( + self, mock_create_namespaced_crd, mock_delete_namespaced_crd, mock_kubernetes_hook + ): op = FlinkKubernetesOperator( application_file=TEST_VALID_APPLICATION_YAML, dag=self.dag, @@ -207,7 +210,13 @@ def test_create_application_from_yaml(self, mock_create_namespaced_crd, mock_kub ) op.execute(None) mock_kubernetes_hook.assert_called_once_with() - + mock_delete_namespaced_crd.assert_called_once_with( + group="flink.apache.org", + namespace="default", + plural="flinkdeployments", + version="v1beta1", + name=TEST_APPLICATION_DICT["metadata"]["name"], + ) mock_create_namespaced_crd.assert_called_with( body=TEST_APPLICATION_DICT, group="flink.apache.org", @@ -216,8 +225,11 @@ def test_create_application_from_yaml(self, mock_create_namespaced_crd, mock_kub version="v1beta1", ) + @patch("kubernetes.client.api.custom_objects_api.CustomObjectsApi.delete_namespaced_custom_object") @patch("kubernetes.client.api.custom_objects_api.CustomObjectsApi.create_namespaced_custom_object") - def test_create_application_from_json(self, mock_create_namespaced_crd, mock_kubernetes_hook): + def test_create_application_from_json( + self, mock_create_namespaced_crd, mock_delete_namespaced_crd, mock_kubernetes_hook + ): op = FlinkKubernetesOperator( application_file=TEST_VALID_APPLICATION_JSON, dag=self.dag, @@ -226,7 +238,13 @@ def test_create_application_from_json(self, mock_create_namespaced_crd, mock_kub ) op.execute(None) mock_kubernetes_hook.assert_called_once_with() - + mock_delete_namespaced_crd.assert_called_once_with( + group="flink.apache.org", + namespace="default", + plural="flinkdeployments", + version="v1beta1", + name=TEST_APPLICATION_DICT["metadata"]["name"], + ) mock_create_namespaced_crd.assert_called_with( body=TEST_APPLICATION_DICT, group="flink.apache.org", @@ -235,9 +253,10 @@ def test_create_application_from_json(self, mock_create_namespaced_crd, mock_kub version="v1beta1", ) + @patch("kubernetes.client.api.custom_objects_api.CustomObjectsApi.delete_namespaced_custom_object") @patch("kubernetes.client.api.custom_objects_api.CustomObjectsApi.create_namespaced_custom_object") def test_create_application_from_json_with_api_group_and_version( - self, mock_create_namespaced_crd, 
mock_kubernetes_hook + self, mock_create_namespaced_crd, mock_delete_namespaced_crd, mock_kubernetes_hook ): api_group = "flink.apache.org" api_version = "v1beta1" @@ -251,7 +270,13 @@ def test_create_application_from_json_with_api_group_and_version( ) op.execute(None) mock_kubernetes_hook.assert_called_once_with() - + mock_delete_namespaced_crd.assert_called_once_with( + group=api_group, + namespace="default", + plural="flinkdeployments", + version=api_version, + name=TEST_APPLICATION_DICT["metadata"]["name"], + ) mock_create_namespaced_crd.assert_called_with( body=TEST_APPLICATION_DICT, group=api_group, @@ -260,8 +285,11 @@ def test_create_application_from_json_with_api_group_and_version( version=api_version, ) + @patch("kubernetes.client.api.custom_objects_api.CustomObjectsApi.delete_namespaced_custom_object") @patch("kubernetes.client.api.custom_objects_api.CustomObjectsApi.create_namespaced_custom_object") - def test_namespace_from_operator(self, mock_create_namespaced_crd, mock_kubernetes_hook): + def test_namespace_from_operator( + self, mock_create_namespaced_crd, mock_delete_namespaced_crd, mock_kubernetes_hook + ): op = FlinkKubernetesOperator( application_file=TEST_VALID_APPLICATION_JSON, dag=self.dag, @@ -271,7 +299,13 @@ def test_namespace_from_operator(self, mock_create_namespaced_crd, mock_kubernet ) op.execute(None) mock_kubernetes_hook.assert_called_once_with() - + mock_delete_namespaced_crd.assert_called_once_with( + group="flink.apache.org", + namespace="operator_namespace", + plural="flinkdeployments", + version="v1beta1", + name=TEST_APPLICATION_DICT["metadata"]["name"], + ) mock_create_namespaced_crd.assert_called_with( body=TEST_APPLICATION_DICT, group="flink.apache.org", @@ -280,8 +314,11 @@ def test_namespace_from_operator(self, mock_create_namespaced_crd, mock_kubernet version="v1beta1", ) + @patch("kubernetes.client.api.custom_objects_api.CustomObjectsApi.delete_namespaced_custom_object") @patch("kubernetes.client.api.custom_objects_api.CustomObjectsApi.create_namespaced_custom_object") - def test_namespace_from_connection(self, mock_create_namespaced_crd, mock_kubernetes_hook): + def test_namespace_from_connection( + self, mock_create_namespaced_crd, mock_delete_namespaced_crd, mock_kubernetes_hook + ): op = FlinkKubernetesOperator( application_file=TEST_VALID_APPLICATION_JSON, dag=self.dag, @@ -291,7 +328,13 @@ def test_namespace_from_connection(self, mock_create_namespaced_crd, mock_kubern op.execute(None) mock_kubernetes_hook.assert_called_once_with() - + mock_delete_namespaced_crd.assert_called_once_with( + group="flink.apache.org", + namespace="mock_namespace", + plural="flinkdeployments", + version="v1beta1", + name=TEST_APPLICATION_DICT["metadata"]["name"], + ) mock_create_namespaced_crd.assert_called_with( body=TEST_APPLICATION_DICT, group="flink.apache.org", diff --git a/tests/providers/cncf/kubernetes/hooks/test_kubernetes.py b/tests/providers/cncf/kubernetes/hooks/test_kubernetes.py index ba151efaf7fa6..85ffbcdb18e9a 100644 --- a/tests/providers/cncf/kubernetes/hooks/test_kubernetes.py +++ b/tests/providers/cncf/kubernetes/hooks/test_kubernetes.py @@ -373,31 +373,6 @@ def test_missing_default_connection_is_ok(self, remove_default_conn): with pytest.raises(AirflowNotFoundException, match="The conn_id `some_conn` isn't defined"): hook.conn_extras - @patch("kubernetes.config.kube_config.KubeConfigLoader") - @patch("kubernetes.config.kube_config.KubeConfigMerger") - @patch(f"{HOOK_MODULE}.client.CustomObjectsApi") - def test_delete_custom_object( - self, 
mock_custom_object_api, mock_kube_config_merger, mock_kube_config_loader - ): - hook = KubernetesHook() - hook.delete_custom_object( - group="group", - version="version", - plural="plural", - name="name", - namespace="namespace", - _preload_content="_preload_content", - ) - - mock_custom_object_api.return_value.delete_namespaced_custom_object.assert_called_once_with( - group="group", - version="version", - plural="plural", - name="name", - namespace="namespace", - _preload_content="_preload_content", - ) - class TestKubernetesHookIncorrectConfiguration: @pytest.mark.parametrize( diff --git a/tests/providers/cncf/kubernetes/operators/test_spark_kubernetes.py b/tests/providers/cncf/kubernetes/operators/test_spark_kubernetes.py index 7ea67aec80a95..6989337a0b276 100644 --- a/tests/providers/cncf/kubernetes/operators/test_spark_kubernetes.py +++ b/tests/providers/cncf/kubernetes/operators/test_spark_kubernetes.py @@ -16,81 +16,418 @@ # under the License. from __future__ import annotations +import json from unittest.mock import patch +from airflow import DAG +from airflow.models import Connection from airflow.providers.cncf.kubernetes.operators.spark_kubernetes import SparkKubernetesOperator +from airflow.utils import db, timezone +TEST_VALID_APPLICATION_YAML = """ +apiVersion: "sparkoperator.k8s.io/v1beta2" +kind: SparkApplication +metadata: + name: spark-pi + namespace: default +spec: + type: Scala + mode: cluster + image: "gcr.io/spark-operator/spark:v2.4.5" + imagePullPolicy: Always + mainClass: org.apache.spark.examples.SparkPi + mainApplicationFile: "local:///opt/spark/examples/jars/spark-examples_2.11-2.4.5.jar" + sparkVersion: "2.4.5" + restartPolicy: + type: Never + volumes: + - name: "test-volume" + hostPath: + path: "/tmp" + type: Directory + driver: + cores: 1 + coreLimit: "1200m" + memory: "512m" + labels: + version: 2.4.5 + serviceAccount: spark + volumeMounts: + - name: "test-volume" + mountPath: "/tmp" + executor: + cores: 1 + instances: 1 + memory: "512m" + labels: + version: 2.4.5 + volumeMounts: + - name: "test-volume" + mountPath: "/tmp" +""" -@patch("airflow.providers.cncf.kubernetes.operators.spark_kubernetes.KubernetesHook") -def test_spark_kubernetes_operator(mock_kubernetes_hook): - SparkKubernetesOperator( - task_id="task_id", - application_file="application_file", - kubernetes_conn_id="kubernetes_conn_id", - in_cluster=True, - cluster_context="cluster_context", - config_file="config_file", - ) - - mock_kubernetes_hook.assert_called_once_with( - conn_id="kubernetes_conn_id", - in_cluster=True, - cluster_context="cluster_context", - config_file="config_file", - ) - - -@patch("airflow.providers.cncf.kubernetes.operators.spark_kubernetes.Watch.stream") -@patch("airflow.providers.cncf.kubernetes.operators.spark_kubernetes._load_body_to_dict") -@patch("airflow.providers.cncf.kubernetes.operators.spark_kubernetes.KubernetesHook") -def test_execute(mock_kubernetes_hook, mock_load_body_to_dict, mock_stream): - mock_load_body_to_dict.return_value = {"metadata": {"name": "spark-app"}} - mock_kubernetes_hook.return_value.get_namespace.return_value = "default" - mock_stream.side_effect = [[{"type": "ADDED"}], []] - - op = SparkKubernetesOperator(task_id="task_id", application_file="application_file") - op.execute({}) - mock_kubernetes_hook.return_value.create_custom_object.assert_called_once_with( - group="sparkoperator.k8s.io", - version="v1beta2", - plural="sparkapplications", - body={"metadata": {"name": "spark-app"}}, - namespace="default", - ) - - assert 
mock_stream.call_count == 2 - mock_stream.assert_any_call( - mock_kubernetes_hook.return_value.core_v1_client.list_namespaced_pod, - namespace="default", - _preload_content=False, - watch=True, - label_selector="sparkoperator.k8s.io/app-name=spark-app,spark-role=driver", - field_selector="status.phase=Running", - ) - - mock_stream.assert_any_call( - mock_kubernetes_hook.return_value.core_v1_client.read_namespaced_pod_log, - name="spark-app-driver", - namespace="default", - _preload_content=False, - timestamps=True, - ) - - -@patch("airflow.providers.cncf.kubernetes.operators.spark_kubernetes._load_body_to_dict") -@patch("airflow.providers.cncf.kubernetes.operators.spark_kubernetes.KubernetesHook") -def test_on_kill(mock_kubernetes_hook, mock_load_body_to_dict): - mock_load_body_to_dict.return_value = {"metadata": {"name": "spark-app"}} - mock_kubernetes_hook.return_value.get_namespace.return_value = "default" - - op = SparkKubernetesOperator(task_id="task_id", application_file="application_file") - - op.on_kill() - - mock_kubernetes_hook.return_value.delete_custom_object.assert_called_once_with( - group="sparkoperator.k8s.io", - version="v1beta2", - plural="sparkapplications", - namespace="default", - name="spark-app", - ) +TEST_VALID_APPLICATION_YAML_USING_GENERATE_NAME = """ +apiVersion: "sparkoperator.k8s.io/v1beta2" +kind: SparkApplication +metadata: + generateName: spark-pi + namespace: default +spec: + type: Scala + mode: cluster + image: "gcr.io/spark-operator/spark:v2.4.5" + imagePullPolicy: Always + mainClass: org.apache.spark.examples.SparkPi + mainApplicationFile: "local:///opt/spark/examples/jars/spark-examples_2.11-2.4.5.jar" + sparkVersion: "2.4.5" + restartPolicy: + type: Never + volumes: + - name: "test-volume" + hostPath: + path: "/tmp" + type: Directory + driver: + cores: 1 + coreLimit: "1200m" + memory: "512m" + labels: + version: 2.4.5 + serviceAccount: spark + volumeMounts: + - name: "test-volume" + mountPath: "/tmp" + executor: + cores: 1 + instances: 1 + memory: "512m" + labels: + version: 2.4.5 + volumeMounts: + - name: "test-volume" + mountPath: "/tmp" +""" + +TEST_VALID_APPLICATION_JSON = """ +{ + "apiVersion":"sparkoperator.k8s.io/v1beta2", + "kind":"SparkApplication", + "metadata":{ + "name":"spark-pi", + "namespace":"default" + }, + "spec":{ + "type":"Scala", + "mode":"cluster", + "image":"gcr.io/spark-operator/spark:v2.4.5", + "imagePullPolicy":"Always", + "mainClass":"org.apache.spark.examples.SparkPi", + "mainApplicationFile":"local:///opt/spark/examples/jars/spark-examples_2.11-2.4.5.jar", + "sparkVersion":"2.4.5", + "restartPolicy":{ + "type":"Never" + }, + "volumes":[ + { + "name":"test-volume", + "hostPath":{ + "path":"/tmp", + "type":"Directory" + } + } + ], + "driver":{ + "cores":1, + "coreLimit":"1200m", + "memory":"512m", + "labels":{ + "version":"2.4.5" + }, + "serviceAccount":"spark", + "volumeMounts":[ + { + "name":"test-volume", + "mountPath":"/tmp" + } + ] + }, + "executor":{ + "cores":1, + "instances":1, + "memory":"512m", + "labels":{ + "version":"2.4.5" + }, + "volumeMounts":[ + { + "name":"test-volume", + "mountPath":"/tmp" + } + ] + } + } +} +""" + +TEST_APPLICATION_DICT = { + "apiVersion": "sparkoperator.k8s.io/v1beta2", + "kind": "SparkApplication", + "metadata": {"name": "spark-pi", "namespace": "default"}, + "spec": { + "driver": { + "coreLimit": "1200m", + "cores": 1, + "labels": {"version": "2.4.5"}, + "memory": "512m", + "serviceAccount": "spark", + "volumeMounts": [{"mountPath": "/tmp", "name": "test-volume"}], + }, + "executor": { 
+ "cores": 1, + "instances": 1, + "labels": {"version": "2.4.5"}, + "memory": "512m", + "volumeMounts": [{"mountPath": "/tmp", "name": "test-volume"}], + }, + "image": "gcr.io/spark-operator/spark:v2.4.5", + "imagePullPolicy": "Always", + "mainApplicationFile": "local:///opt/spark/examples/jars/spark-examples_2.11-2.4.5.jar", + "mainClass": "org.apache.spark.examples.SparkPi", + "mode": "cluster", + "restartPolicy": {"type": "Never"}, + "sparkVersion": "2.4.5", + "type": "Scala", + "volumes": [{"hostPath": {"path": "/tmp", "type": "Directory"}, "name": "test-volume"}], + }, +} + +TEST_APPLICATION_DICT_WITH_GENERATE_NAME = { + "apiVersion": "sparkoperator.k8s.io/v1beta2", + "kind": "SparkApplication", + "metadata": {"generateName": "spark-pi", "namespace": "default"}, + "spec": { + "driver": { + "coreLimit": "1200m", + "cores": 1, + "labels": {"version": "2.4.5"}, + "memory": "512m", + "serviceAccount": "spark", + "volumeMounts": [{"mountPath": "/tmp", "name": "test-volume"}], + }, + "executor": { + "cores": 1, + "instances": 1, + "labels": {"version": "2.4.5"}, + "memory": "512m", + "volumeMounts": [{"mountPath": "/tmp", "name": "test-volume"}], + }, + "image": "gcr.io/spark-operator/spark:v2.4.5", + "imagePullPolicy": "Always", + "mainApplicationFile": "local:///opt/spark/examples/jars/spark-examples_2.11-2.4.5.jar", + "mainClass": "org.apache.spark.examples.SparkPi", + "mode": "cluster", + "restartPolicy": {"type": "Never"}, + "sparkVersion": "2.4.5", + "type": "Scala", + "volumes": [{"hostPath": {"path": "/tmp", "type": "Directory"}, "name": "test-volume"}], + }, +} + + +@patch("airflow.providers.cncf.kubernetes.hooks.kubernetes.KubernetesHook.get_conn") +@patch("kubernetes.client.api.custom_objects_api.CustomObjectsApi.delete_namespaced_custom_object") +@patch("kubernetes.client.api.custom_objects_api.CustomObjectsApi.create_namespaced_custom_object") +@patch("airflow.utils.context.Context") +class TestSparkKubernetesOperator: + def setup_method(self): + db.merge_conn( + Connection( + conn_id="kubernetes_default_kube_config", + conn_type="kubernetes", + extra=json.dumps({}), + ) + ) + + db.merge_conn( + Connection( + conn_id="kubernetes_with_namespace", + conn_type="kubernetes", + extra=json.dumps({"namespace": "mock_namespace"}), + ) + ) + + args = {"owner": "airflow", "start_date": timezone.datetime(2020, 2, 1)} + self.dag = DAG("test_dag_id", default_args=args) + + def test_create_application_from_yaml( + self, context, mock_create_namespaced_crd, mock_delete_namespaced_crd, mock_kubernetes_hook + ): + op = SparkKubernetesOperator( + application_file=TEST_VALID_APPLICATION_YAML, + dag=self.dag, + kubernetes_conn_id="kubernetes_default_kube_config", + task_id="test_task_id", + ) + + op.execute(context) + mock_kubernetes_hook.assert_called_once_with() + + mock_delete_namespaced_crd.assert_called_once_with( + group="sparkoperator.k8s.io", + namespace="default", + plural="sparkapplications", + version="v1beta2", + name=TEST_APPLICATION_DICT["metadata"]["name"], + ) + + mock_create_namespaced_crd.assert_called_with( + body=TEST_APPLICATION_DICT, + group="sparkoperator.k8s.io", + namespace="default", + plural="sparkapplications", + version="v1beta2", + ) + + def test_create_application_from_yaml_using_generate_name( + self, context, mock_create_namespaced_crd, mock_delete_namespaced_crd, mock_kubernetes_hook + ): + op = SparkKubernetesOperator( + application_file=TEST_VALID_APPLICATION_YAML_USING_GENERATE_NAME, + dag=self.dag, + kubernetes_conn_id="kubernetes_default_kube_config", + 
task_id="test_task_id", + ) + + op.execute(context) + mock_kubernetes_hook.assert_called_once_with() + mock_delete_namespaced_crd.assert_not_called() + + mock_create_namespaced_crd.assert_called_with( + body=TEST_APPLICATION_DICT_WITH_GENERATE_NAME, + group="sparkoperator.k8s.io", + namespace="default", + plural="sparkapplications", + version="v1beta2", + ) + + def test_create_application_from_json( + self, context, mock_create_namespaced_crd, mock_delete_namespaced_crd, mock_kubernetes_hook + ): + op = SparkKubernetesOperator( + application_file=TEST_VALID_APPLICATION_JSON, + dag=self.dag, + kubernetes_conn_id="kubernetes_default_kube_config", + task_id="test_task_id", + ) + + op.execute(context) + mock_kubernetes_hook.assert_called_once_with() + + mock_delete_namespaced_crd.assert_called_once_with( + group="sparkoperator.k8s.io", + namespace="default", + plural="sparkapplications", + version="v1beta2", + name=TEST_APPLICATION_DICT["metadata"]["name"], + ) + + mock_create_namespaced_crd.assert_called_with( + body=TEST_APPLICATION_DICT, + group="sparkoperator.k8s.io", + namespace="default", + plural="sparkapplications", + version="v1beta2", + ) + + def test_create_application_from_json_with_api_group_and_version( + self, context, mock_create_namespaced_crd, mock_delete_namespaced_crd, mock_kubernetes_hook + ): + api_group = "sparkoperator.example.com" + api_version = "v1alpha1" + + op = SparkKubernetesOperator( + application_file=TEST_VALID_APPLICATION_JSON, + dag=self.dag, + kubernetes_conn_id="kubernetes_default_kube_config", + task_id="test_task_id", + api_group=api_group, + api_version=api_version, + ) + + op.execute(context) + mock_kubernetes_hook.assert_called_once_with() + + mock_delete_namespaced_crd.assert_called_once_with( + group=api_group, + namespace="default", + plural="sparkapplications", + version=api_version, + name=TEST_APPLICATION_DICT["metadata"]["name"], + ) + + mock_create_namespaced_crd.assert_called_with( + body=TEST_APPLICATION_DICT, + group=api_group, + namespace="default", + plural="sparkapplications", + version=api_version, + ) + + def test_namespace_from_operator( + self, context, mock_create_namespaced_crd, mock_delete_namespaced_crd, mock_kubernetes_hook + ): + op = SparkKubernetesOperator( + application_file=TEST_VALID_APPLICATION_JSON, + dag=self.dag, + namespace="operator_namespace", + kubernetes_conn_id="kubernetes_with_namespace", + task_id="test_task_id", + ) + + op.execute(context) + mock_kubernetes_hook.assert_called_once_with() + + mock_delete_namespaced_crd.assert_called_once_with( + group="sparkoperator.k8s.io", + namespace="operator_namespace", + plural="sparkapplications", + version="v1beta2", + name=TEST_APPLICATION_DICT["metadata"]["name"], + ) + + mock_create_namespaced_crd.assert_called_with( + body=TEST_APPLICATION_DICT, + group="sparkoperator.k8s.io", + namespace="operator_namespace", + plural="sparkapplications", + version="v1beta2", + ) + + def test_namespace_from_connection( + self, context, mock_create_namespaced_crd, mock_delete_namespaced_crd, mock_kubernetes_hook + ): + op = SparkKubernetesOperator( + application_file=TEST_VALID_APPLICATION_JSON, + dag=self.dag, + kubernetes_conn_id="kubernetes_with_namespace", + task_id="test_task_id", + ) + + op.execute(context) + mock_kubernetes_hook.assert_called_once_with() + + mock_delete_namespaced_crd.assert_called_once_with( + group="sparkoperator.k8s.io", + namespace="mock_namespace", + plural="sparkapplications", + version="v1beta2", + name=TEST_APPLICATION_DICT["metadata"]["name"], + ) + 
+ mock_create_namespaced_crd.assert_called_with( + body=TEST_APPLICATION_DICT, + group="sparkoperator.k8s.io", + namespace="mock_namespace", + plural="sparkapplications", + version="v1beta2", + )
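
Reviewer note: for anyone who wants the restored behavior at a glance, below is a minimal sketch of the documented submit-and-monitor pattern this revert brings back. The DAG id, task ids, and YAML filename are illustrative only, not part of this patch. The Operator only creates the `SparkApplication` object and, with `do_xcom_push=True`, returns the create response so downstream tasks can read `metadata.name` from XCom; the `SparkKubernetesSensor` then polls the application state and fails when the Spark job fails.

```python
from datetime import datetime

from airflow import DAG
from airflow.providers.cncf.kubernetes.operators.spark_kubernetes import SparkKubernetesOperator
from airflow.providers.cncf.kubernetes.sensors.spark_kubernetes import SparkKubernetesSensor

with DAG(
    dag_id="spark_pi_example",  # illustrative
    start_date=datetime(2023, 6, 1),
    schedule=None,
    catchup=False,
) as dag:
    # Submit the SparkApplication CRD; the create response (including
    # metadata.name) is pushed to XCom because do_xcom_push=True.
    submit = SparkKubernetesOperator(
        task_id="spark_pi_submit",
        namespace="default",
        application_file="example_spark_kubernetes_spark_pi.yaml",  # illustrative path
        kubernetes_conn_id="kubernetes_default",
        do_xcom_push=True,
    )

    # Poll the application status until it finishes; the sensor raises on a
    # failed Spark job, which is how failure propagates to the DAG run.
    monitor = SparkKubernetesSensor(
        task_id="spark_pi_monitor",
        namespace="default",
        application_name="{{ task_instance.xcom_pull(task_ids='spark_pi_submit')['metadata']['name'] }}",
        kubernetes_conn_id="kubernetes_default",
    )

    submit >> monitor
```

This submit/monitor split is what the reverted change collapsed into a single blocking `execute()` (watching the driver pod and streaming its logs), which is why the Sensor, XCom, and failure propagation all stopped working.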