diff --git a/airflow/providers/cncf/kubernetes/hooks/kubernetes.py b/airflow/providers/cncf/kubernetes/hooks/kubernetes.py
index a3851a8987c93..ba0b9021c4698 100644
--- a/airflow/providers/cncf/kubernetes/hooks/kubernetes.py
+++ b/airflow/providers/cncf/kubernetes/hooks/kubernetes.py
@@ -38,7 +38,7 @@
 LOADING_KUBE_CONFIG_FILE_RESOURCE = "Loading Kubernetes configuration file kube_config from {}..."
 
 
-def _load_body_to_dict(body: str) -> dict:
+def _load_body_to_dict(body):
     try:
         body_dict = yaml.safe_load(body)
     except yaml.YAMLError as e:
@@ -287,22 +287,37 @@ def create_custom_object(
         :param namespace: kubernetes namespace
         """
         api: client.CustomObjectsApi = self.custom_object_client
+        namespace = namespace or self.get_namespace() or self.DEFAULT_NAMESPACE
 
         if isinstance(body, str):
             body_dict = _load_body_to_dict(body)
         else:
             body_dict = body
 
-        response = api.create_namespaced_custom_object(
-            group=group,
-            version=version,
-            namespace=namespace or self.get_namespace() or self.DEFAULT_NAMESPACE,
-            plural=plural,
-            body=body_dict,
-        )
+        # Attribute "name" is not mandatory if "generateName" is used instead
+        if "name" in body_dict["metadata"]:
+            try:
+                api.delete_namespaced_custom_object(
+                    group=group,
+                    version=version,
+                    namespace=namespace,
+                    plural=plural,
+                    name=body_dict["metadata"]["name"],
+                )
+
+                self.log.warning("Deleted SparkApplication with the same name")
+            except client.rest.ApiException:
+                self.log.info("SparkApplication %s not found", body_dict["metadata"]["name"])
 
-        self.log.debug("Response: %s", response)
-        return response
+        try:
+            response = api.create_namespaced_custom_object(
+                group=group, version=version, namespace=namespace, plural=plural, body=body_dict
+            )
+
+            self.log.debug("Response: %s", response)
+            return response
+        except client.rest.ApiException as e:
+            raise AirflowException(f"Exception when calling -> create_custom_object: {e}\n")
 
     def get_custom_object(
         self, group: str, version: str, plural: str, name: str, namespace: str | None = None
@@ -317,36 +332,14 @@
         :param namespace: kubernetes namespace
         """
         api = client.CustomObjectsApi(self.api_client)
-        response = api.get_namespaced_custom_object(
-            group=group,
-            version=version,
-            namespace=namespace or self.get_namespace() or self.DEFAULT_NAMESPACE,
-            plural=plural,
-            name=name,
-        )
-        return response
-
-    def delete_custom_object(
-        self, group: str, version: str, plural: str, name: str, namespace: str | None = None, **kwargs
-    ):
-        """
-        Delete custom resource definition object from Kubernetes.
-
-        :param group: api group
-        :param version: api version
-        :param plural: api plural
-        :param name: crd object name
-        :param namespace: kubernetes namespace
-        """
-        api = client.CustomObjectsApi(self.api_client)
-        return api.delete_namespaced_custom_object(
-            group=group,
-            version=version,
-            namespace=namespace or self.get_namespace() or self.DEFAULT_NAMESPACE,
-            plural=plural,
-            name=name,
-            **kwargs,
-        )
+        namespace = namespace or self.get_namespace() or self.DEFAULT_NAMESPACE
+        try:
+            response = api.get_namespaced_custom_object(
+                group=group, version=version, namespace=namespace, plural=plural, name=name
+            )
+            return response
+        except client.rest.ApiException as e:
+            raise AirflowException(f"Exception when calling -> get_custom_object: {e}\n")
 
     def get_namespace(self) -> str | None:
         """Returns the namespace that defined in the connection."""
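Note: after this hunk, create_custom_object becomes delete-then-create whenever the body carries metadata.name, and delete_custom_object disappears from the hook's public surface. A minimal usage sketch of the new behavior — the connection id and the abbreviated spec below are illustrative assumptions, not part of this patch:

    from airflow.providers.cncf.kubernetes.hooks.kubernetes import KubernetesHook

    hook = KubernetesHook(conn_id="kubernetes_default")  # assumed connection id
    # If "name" is present, any existing object with that name is deleted first,
    # then the object is (re)created; "generateName" bodies skip the deletion.
    response = hook.create_custom_object(
        group="sparkoperator.k8s.io",
        version="v1beta2",
        plural="sparkapplications",
        body={
            "apiVersion": "sparkoperator.k8s.io/v1beta2",
            "kind": "SparkApplication",
            "metadata": {"name": "spark-pi", "namespace": "default"},
            "spec": {},  # spec abbreviated for the sketch
        },
    )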
diff --git a/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py b/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py
index 2206e04a735d5..05705fd6774cd 100644
--- a/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py
+++ b/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py
@@ -19,10 +19,8 @@
 
 from typing import TYPE_CHECKING, Sequence
 
-from kubernetes.watch import Watch
-
 from airflow.models import BaseOperator
-from airflow.providers.cncf.kubernetes.hooks.kubernetes import KubernetesHook, _load_body_to_dict
+from airflow.providers.cncf.kubernetes.hooks.kubernetes import KubernetesHook
 
 if TYPE_CHECKING:
     from airflow.utils.context import Context
@@ -57,71 +55,24 @@ def __init__(
         kubernetes_conn_id: str = "kubernetes_default",
         api_group: str = "sparkoperator.k8s.io",
         api_version: str = "v1beta2",
-        in_cluster: bool | None = None,
-        cluster_context: str | None = None,
-        config_file: str | None = None,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
+        self.application_file = application_file
         self.namespace = namespace
         self.kubernetes_conn_id = kubernetes_conn_id
         self.api_group = api_group
         self.api_version = api_version
         self.plural = "sparkapplications"
-        self.application_file = application_file
-        self.in_cluster = in_cluster
-        self.cluster_context = cluster_context
-        self.config_file = config_file
-
-        self.hook = KubernetesHook(
-            conn_id=self.kubernetes_conn_id,
-            in_cluster=self.in_cluster,
-            config_file=self.config_file,
-            cluster_context=self.cluster_context,
-        )
 
     def execute(self, context: Context):
-        body = _load_body_to_dict(self.application_file)
-        name = body["metadata"]["name"]
-        namespace = self.namespace or self.hook.get_namespace()
-        namespace_event_stream = Watch().stream(
-            self.hook.core_v1_client.list_namespaced_pod,
-            namespace=namespace,
-            _preload_content=False,
-            watch=True,
-            label_selector=f"sparkoperator.k8s.io/app-name={name},spark-role=driver",
-            field_selector="status.phase=Running",
-        )
-
-        self.hook.create_custom_object(
-            group=self.api_group,
-            version=self.api_version,
-            plural=self.plural,
-            body=body,
-            namespace=namespace,
-        )
-        for event in namespace_event_stream:
-            if event["type"] == "ADDED":
-                pod_log_stream = Watch().stream(
-                    self.hook.core_v1_client.read_namespaced_pod_log,
-                    name=f"{name}-driver",
-                    namespace=namespace,
-                    _preload_content=False,
-                    timestamps=True,
-                )
-                for line in pod_log_stream:
-                    self.log.info(line)
-            else:
-                break
-
-    def on_kill(self) -> None:
-        body = _load_body_to_dict(self.application_file)
-        name = body["metadata"]["name"]
-        namespace = self.namespace or self.hook.get_namespace()
-        self.hook.delete_custom_object(
+        hook = KubernetesHook(conn_id=self.kubernetes_conn_id)
+        self.log.info("Creating sparkApplication")
+        response = hook.create_custom_object(
             group=self.api_group,
             version=self.api_version,
             plural=self.plural,
-            namespace=namespace,
-            name=name,
+            body=self.application_file,
+            namespace=self.namespace,
         )
+        return response
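For context, a hedged sketch of how the slimmed-down operator is used in a DAG; with this change, execute returns the create response (so it lands in XCom). The dag id, dates, and manifest path are illustrative, not from this patch:

    from airflow import DAG
    from airflow.providers.cncf.kubernetes.operators.spark_kubernetes import SparkKubernetesOperator
    from airflow.utils import timezone

    # Assumed DAG scaffolding, mirroring the style of the tests below.
    with DAG("spark_pi", default_args={"owner": "airflow", "start_date": timezone.datetime(2020, 2, 1)}) as dag:
        submit = SparkKubernetesOperator(
            task_id="submit_spark_pi",
            application_file="spark_pi.yaml",  # hypothetical templated manifest (path or YAML/JSON string)
            namespace="default",
            kubernetes_conn_id="kubernetes_default",
        )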
diff --git a/tests/providers/apache/flink/operators/test_flink_kubernetes.py b/tests/providers/apache/flink/operators/test_flink_kubernetes.py
index 3e7cdb01f334e..598c6edd94c42 100644
--- a/tests/providers/apache/flink/operators/test_flink_kubernetes.py
+++ b/tests/providers/apache/flink/operators/test_flink_kubernetes.py
@@ -197,8 +197,11 @@ def setup_method(self):
         args = {"owner": "airflow", "start_date": timezone.datetime(2020, 2, 1)}
         self.dag = DAG("test_dag_id", default_args=args)
 
+    @patch("kubernetes.client.api.custom_objects_api.CustomObjectsApi.delete_namespaced_custom_object")
     @patch("kubernetes.client.api.custom_objects_api.CustomObjectsApi.create_namespaced_custom_object")
-    def test_create_application_from_yaml(self, mock_create_namespaced_crd, mock_kubernetes_hook):
+    def test_create_application_from_yaml(
+        self, mock_create_namespaced_crd, mock_delete_namespaced_crd, mock_kubernetes_hook
+    ):
         op = FlinkKubernetesOperator(
             application_file=TEST_VALID_APPLICATION_YAML,
             dag=self.dag,
@@ -207,7 +210,13 @@ def test_create_application_from_yaml(self, mock_create_namespaced_crd, mock_kub
         )
         op.execute(None)
         mock_kubernetes_hook.assert_called_once_with()
-
+        mock_delete_namespaced_crd.assert_called_once_with(
+            group="flink.apache.org",
+            namespace="default",
+            plural="flinkdeployments",
+            version="v1beta1",
+            name=TEST_APPLICATION_DICT["metadata"]["name"],
+        )
         mock_create_namespaced_crd.assert_called_with(
             body=TEST_APPLICATION_DICT,
             group="flink.apache.org",
@@ -216,8 +225,11 @@ def test_create_application_from_yaml(self, mock_create_namespaced_crd, mock_kub
             version="v1beta1",
         )
 
+    @patch("kubernetes.client.api.custom_objects_api.CustomObjectsApi.delete_namespaced_custom_object")
     @patch("kubernetes.client.api.custom_objects_api.CustomObjectsApi.create_namespaced_custom_object")
-    def test_create_application_from_json(self, mock_create_namespaced_crd, mock_kubernetes_hook):
+    def test_create_application_from_json(
+        self, mock_create_namespaced_crd, mock_delete_namespaced_crd, mock_kubernetes_hook
+    ):
         op = FlinkKubernetesOperator(
             application_file=TEST_VALID_APPLICATION_JSON,
             dag=self.dag,
@@ -226,7 +238,13 @@ def test_create_application_from_json(self, mock_create_namespaced_crd, mock_kub
         )
         op.execute(None)
         mock_kubernetes_hook.assert_called_once_with()
-
+        mock_delete_namespaced_crd.assert_called_once_with(
+            group="flink.apache.org",
+            namespace="default",
+            plural="flinkdeployments",
+            version="v1beta1",
+            name=TEST_APPLICATION_DICT["metadata"]["name"],
+        )
         mock_create_namespaced_crd.assert_called_with(
             body=TEST_APPLICATION_DICT,
             group="flink.apache.org",
@@ -235,9 +253,10 @@ def test_create_application_from_json(self, mock_create_namespaced_crd, mock_kub
             version="v1beta1",
         )
 
+    @patch("kubernetes.client.api.custom_objects_api.CustomObjectsApi.delete_namespaced_custom_object")
     @patch("kubernetes.client.api.custom_objects_api.CustomObjectsApi.create_namespaced_custom_object")
     def test_create_application_from_json_with_api_group_and_version(
-        self, mock_create_namespaced_crd, mock_kubernetes_hook
+        self, mock_create_namespaced_crd, mock_delete_namespaced_crd, mock_kubernetes_hook
     ):
         api_group = "flink.apache.org"
         api_version = "v1beta1"
@@ -251,7 +270,13 @@ def test_create_application_from_json_with_api_group_and_version(
         )
         op.execute(None)
         mock_kubernetes_hook.assert_called_once_with()
-
+        mock_delete_namespaced_crd.assert_called_once_with(
+            group=api_group,
+            namespace="default",
+            plural="flinkdeployments",
+            version=api_version,
+            name=TEST_APPLICATION_DICT["metadata"]["name"],
+        )
         mock_create_namespaced_crd.assert_called_with(
             body=TEST_APPLICATION_DICT,
             group=api_group,
@@ -260,8 +285,11 @@ def test_create_application_from_json_with_api_group_and_version(
             version=api_version,
         )
 
+    @patch("kubernetes.client.api.custom_objects_api.CustomObjectsApi.delete_namespaced_custom_object")
     @patch("kubernetes.client.api.custom_objects_api.CustomObjectsApi.create_namespaced_custom_object")
-    def test_namespace_from_operator(self, mock_create_namespaced_crd, mock_kubernetes_hook):
+    def test_namespace_from_operator(
+        self, mock_create_namespaced_crd, mock_delete_namespaced_crd, mock_kubernetes_hook
+    ):
         op = FlinkKubernetesOperator(
             application_file=TEST_VALID_APPLICATION_JSON,
             dag=self.dag,
@@ -271,7 +299,13 @@ def test_namespace_from_operator(self, mock_create_namespaced_crd, mock_kubernet
         )
         op.execute(None)
         mock_kubernetes_hook.assert_called_once_with()
-
+        mock_delete_namespaced_crd.assert_called_once_with(
+            group="flink.apache.org",
+            namespace="operator_namespace",
+            plural="flinkdeployments",
+            version="v1beta1",
+            name=TEST_APPLICATION_DICT["metadata"]["name"],
+        )
         mock_create_namespaced_crd.assert_called_with(
             body=TEST_APPLICATION_DICT,
             group="flink.apache.org",
@@ -280,8 +314,11 @@ def test_namespace_from_operator(self, mock_create_namespaced_crd, mock_kubernet
             version="v1beta1",
         )
 
+    @patch("kubernetes.client.api.custom_objects_api.CustomObjectsApi.delete_namespaced_custom_object")
     @patch("kubernetes.client.api.custom_objects_api.CustomObjectsApi.create_namespaced_custom_object")
-    def test_namespace_from_connection(self, mock_create_namespaced_crd, mock_kubernetes_hook):
+    def test_namespace_from_connection(
+        self, mock_create_namespaced_crd, mock_delete_namespaced_crd, mock_kubernetes_hook
+    ):
         op = FlinkKubernetesOperator(
             application_file=TEST_VALID_APPLICATION_JSON,
             dag=self.dag,
@@ -291,7 +328,13 @@ def test_namespace_from_connection(self, mock_create_namespaced_crd, mock_kubern
 
         op.execute(None)
         mock_kubernetes_hook.assert_called_once_with()
-
+        mock_delete_namespaced_crd.assert_called_once_with(
+            group="flink.apache.org",
+            namespace="mock_namespace",
+            plural="flinkdeployments",
+            version="v1beta1",
+            name=TEST_APPLICATION_DICT["metadata"]["name"],
+        )
         mock_create_namespaced_crd.assert_called_with(
             body=TEST_APPLICATION_DICT,
             group="flink.apache.org",
diff --git a/tests/providers/cncf/kubernetes/hooks/test_kubernetes.py b/tests/providers/cncf/kubernetes/hooks/test_kubernetes.py
index ba151efaf7fa6..85ffbcdb18e9a 100644
--- a/tests/providers/cncf/kubernetes/hooks/test_kubernetes.py
+++ b/tests/providers/cncf/kubernetes/hooks/test_kubernetes.py
@@ -373,31 +373,6 @@ def test_missing_default_connection_is_ok(self, remove_default_conn):
         with pytest.raises(AirflowNotFoundException, match="The conn_id `some_conn` isn't defined"):
             hook.conn_extras
 
-    @patch("kubernetes.config.kube_config.KubeConfigLoader")
-    @patch("kubernetes.config.kube_config.KubeConfigMerger")
-    @patch(f"{HOOK_MODULE}.client.CustomObjectsApi")
-    def test_delete_custom_object(
-        self, mock_custom_object_api, mock_kube_config_merger, mock_kube_config_loader
-    ):
-        hook = KubernetesHook()
-        hook.delete_custom_object(
-            group="group",
-            version="version",
-            plural="plural",
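Since KubernetesHook.delete_custom_object (and its test, removed in the next hunk) goes away with this patch, callers that still need an explicit delete can reach the Kubernetes client directly; a minimal sketch under the same assumed connection id:

    from airflow.providers.cncf.kubernetes.hooks.kubernetes import KubernetesHook

    hook = KubernetesHook(conn_id="kubernetes_default")  # assumed connection id
    # custom_object_client is the hook's CustomObjectsApi instance (see the hook diff above).
    hook.custom_object_client.delete_namespaced_custom_object(
        group="sparkoperator.k8s.io",
        version="v1beta2",
        namespace="default",
        plural="sparkapplications",
        name="spark-pi",
    )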
-            name="name",
-            namespace="namespace",
-            _preload_content="_preload_content",
-        )
-
-        mock_custom_object_api.return_value.delete_namespaced_custom_object.assert_called_once_with(
-            group="group",
-            version="version",
-            plural="plural",
-            name="name",
-            namespace="namespace",
-            _preload_content="_preload_content",
-        )
-
 
 class TestKubernetesHookIncorrectConfiguration:
     @pytest.mark.parametrize(
diff --git a/tests/providers/cncf/kubernetes/operators/test_spark_kubernetes.py b/tests/providers/cncf/kubernetes/operators/test_spark_kubernetes.py
index 7ea67aec80a95..6989337a0b276 100644
--- a/tests/providers/cncf/kubernetes/operators/test_spark_kubernetes.py
+++ b/tests/providers/cncf/kubernetes/operators/test_spark_kubernetes.py
@@ -16,81 +16,418 @@
 # under the License.
 from __future__ import annotations
 
+import json
 from unittest.mock import patch
 
+from airflow import DAG
+from airflow.models import Connection
 from airflow.providers.cncf.kubernetes.operators.spark_kubernetes import SparkKubernetesOperator
+from airflow.utils import db, timezone
 
+TEST_VALID_APPLICATION_YAML = """
+apiVersion: "sparkoperator.k8s.io/v1beta2"
+kind: SparkApplication
+metadata:
+  name: spark-pi
+  namespace: default
+spec:
+  type: Scala
+  mode: cluster
+  image: "gcr.io/spark-operator/spark:v2.4.5"
+  imagePullPolicy: Always
+  mainClass: org.apache.spark.examples.SparkPi
+  mainApplicationFile: "local:///opt/spark/examples/jars/spark-examples_2.11-2.4.5.jar"
+  sparkVersion: "2.4.5"
+  restartPolicy:
+    type: Never
+  volumes:
+    - name: "test-volume"
+      hostPath:
+        path: "/tmp"
+        type: Directory
+  driver:
+    cores: 1
+    coreLimit: "1200m"
+    memory: "512m"
+    labels:
+      version: 2.4.5
+    serviceAccount: spark
+    volumeMounts:
+      - name: "test-volume"
+        mountPath: "/tmp"
+  executor:
+    cores: 1
+    instances: 1
+    memory: "512m"
+    labels:
+      version: 2.4.5
+    volumeMounts:
+      - name: "test-volume"
+        mountPath: "/tmp"
+"""
 
-@patch("airflow.providers.cncf.kubernetes.operators.spark_kubernetes.KubernetesHook")
-def test_spark_kubernetes_operator(mock_kubernetes_hook):
-    SparkKubernetesOperator(
-        task_id="task_id",
-        application_file="application_file",
-        kubernetes_conn_id="kubernetes_conn_id",
-        in_cluster=True,
-        cluster_context="cluster_context",
-        config_file="config_file",
-    )
-
-    mock_kubernetes_hook.assert_called_once_with(
-        conn_id="kubernetes_conn_id",
-        in_cluster=True,
-        cluster_context="cluster_context",
-        config_file="config_file",
-    )
-
-
-@patch("airflow.providers.cncf.kubernetes.operators.spark_kubernetes.Watch.stream")
-@patch("airflow.providers.cncf.kubernetes.operators.spark_kubernetes._load_body_to_dict")
-@patch("airflow.providers.cncf.kubernetes.operators.spark_kubernetes.KubernetesHook")
-def test_execute(mock_kubernetes_hook, mock_load_body_to_dict, mock_stream):
-    mock_load_body_to_dict.return_value = {"metadata": {"name": "spark-app"}}
-    mock_kubernetes_hook.return_value.get_namespace.return_value = "default"
-    mock_stream.side_effect = [[{"type": "ADDED"}], []]
-
-    op = SparkKubernetesOperator(task_id="task_id", application_file="application_file")
-    op.execute({})
-    mock_kubernetes_hook.return_value.create_custom_object.assert_called_once_with(
-        group="sparkoperator.k8s.io",
-        version="v1beta2",
-        plural="sparkapplications",
-        body={"metadata": {"name": "spark-app"}},
-        namespace="default",
-    )
-
-    assert mock_stream.call_count == 2
-    mock_stream.assert_any_call(
-        mock_kubernetes_hook.return_value.core_v1_client.list_namespaced_pod,
-        namespace="default",
-        _preload_content=False,
-        watch=True,
-        label_selector="sparkoperator.k8s.io/app-name=spark-app,spark-role=driver",
-        field_selector="status.phase=Running",
-    )
-
-    mock_stream.assert_any_call(
-        mock_kubernetes_hook.return_value.core_v1_client.read_namespaced_pod_log,
-        name="spark-app-driver",
-        namespace="default",
-        _preload_content=False,
-        timestamps=True,
-    )
-
-
-@patch("airflow.providers.cncf.kubernetes.operators.spark_kubernetes._load_body_to_dict")
-@patch("airflow.providers.cncf.kubernetes.operators.spark_kubernetes.KubernetesHook")
-def test_on_kill(mock_kubernetes_hook, mock_load_body_to_dict):
-    mock_load_body_to_dict.return_value = {"metadata": {"name": "spark-app"}}
-    mock_kubernetes_hook.return_value.get_namespace.return_value = "default"
-
-    op = SparkKubernetesOperator(task_id="task_id", application_file="application_file")
-
-    op.on_kill()
-
-    mock_kubernetes_hook.return_value.delete_custom_object.assert_called_once_with(
-        group="sparkoperator.k8s.io",
-        version="v1beta2",
-        plural="sparkapplications",
-        namespace="default",
-        name="spark-app",
-    )
+TEST_VALID_APPLICATION_YAML_USING_GENERATE_NAME = """
+apiVersion: "sparkoperator.k8s.io/v1beta2"
+kind: SparkApplication
+metadata:
+  generateName: spark-pi
+  namespace: default
+spec:
+  type: Scala
+  mode: cluster
+  image: "gcr.io/spark-operator/spark:v2.4.5"
+  imagePullPolicy: Always
+  mainClass: org.apache.spark.examples.SparkPi
+  mainApplicationFile: "local:///opt/spark/examples/jars/spark-examples_2.11-2.4.5.jar"
+  sparkVersion: "2.4.5"
+  restartPolicy:
+    type: Never
+  volumes:
+    - name: "test-volume"
+      hostPath:
+        path: "/tmp"
+        type: Directory
+  driver:
+    cores: 1
+    coreLimit: "1200m"
+    memory: "512m"
+    labels:
+      version: 2.4.5
+    serviceAccount: spark
+    volumeMounts:
+      - name: "test-volume"
+        mountPath: "/tmp"
+  executor:
+    cores: 1
+    instances: 1
+    memory: "512m"
+    labels:
+      version: 2.4.5
+    volumeMounts:
+      - name: "test-volume"
+        mountPath: "/tmp"
+"""
+
+TEST_VALID_APPLICATION_JSON = """
+{
+   "apiVersion":"sparkoperator.k8s.io/v1beta2",
+   "kind":"SparkApplication",
+   "metadata":{
+      "name":"spark-pi",
+      "namespace":"default"
+   },
+   "spec":{
+      "type":"Scala",
+      "mode":"cluster",
+      "image":"gcr.io/spark-operator/spark:v2.4.5",
+      "imagePullPolicy":"Always",
+      "mainClass":"org.apache.spark.examples.SparkPi",
+      "mainApplicationFile":"local:///opt/spark/examples/jars/spark-examples_2.11-2.4.5.jar",
+      "sparkVersion":"2.4.5",
+      "restartPolicy":{
+         "type":"Never"
+      },
+      "volumes":[
+         {
+            "name":"test-volume",
+            "hostPath":{
+               "path":"/tmp",
+               "type":"Directory"
+            }
+         }
+      ],
+      "driver":{
+         "cores":1,
+         "coreLimit":"1200m",
+         "memory":"512m",
+         "labels":{
+            "version":"2.4.5"
+         },
+         "serviceAccount":"spark",
+         "volumeMounts":[
+            {
+               "name":"test-volume",
+               "mountPath":"/tmp"
+            }
+         ]
+      },
+      "executor":{
+         "cores":1,
+         "instances":1,
+         "memory":"512m",
+         "labels":{
+            "version":"2.4.5"
+         },
+         "volumeMounts":[
+            {
+               "name":"test-volume",
+               "mountPath":"/tmp"
+            }
+         ]
+      }
+   }
+}
+"""
+
+TEST_APPLICATION_DICT = {
+    "apiVersion": "sparkoperator.k8s.io/v1beta2",
+    "kind": "SparkApplication",
+    "metadata": {"name": "spark-pi", "namespace": "default"},
+    "spec": {
+        "driver": {
+            "coreLimit": "1200m",
+            "cores": 1,
+            "labels": {"version": "2.4.5"},
+            "memory": "512m",
+            "serviceAccount": "spark",
+            "volumeMounts": [{"mountPath": "/tmp", "name": "test-volume"}],
+        },
+        "executor": {
+            "cores": 1,
+            "instances": 1,
+            "labels": {"version": "2.4.5"},
+            "memory": "512m",
+            "volumeMounts": [{"mountPath": "/tmp", "name": "test-volume"}],
+        },
"gcr.io/spark-operator/spark:v2.4.5", + "imagePullPolicy": "Always", + "mainApplicationFile": "local:///opt/spark/examples/jars/spark-examples_2.11-2.4.5.jar", + "mainClass": "org.apache.spark.examples.SparkPi", + "mode": "cluster", + "restartPolicy": {"type": "Never"}, + "sparkVersion": "2.4.5", + "type": "Scala", + "volumes": [{"hostPath": {"path": "/tmp", "type": "Directory"}, "name": "test-volume"}], + }, +} + +TEST_APPLICATION_DICT_WITH_GENERATE_NAME = { + "apiVersion": "sparkoperator.k8s.io/v1beta2", + "kind": "SparkApplication", + "metadata": {"generateName": "spark-pi", "namespace": "default"}, + "spec": { + "driver": { + "coreLimit": "1200m", + "cores": 1, + "labels": {"version": "2.4.5"}, + "memory": "512m", + "serviceAccount": "spark", + "volumeMounts": [{"mountPath": "/tmp", "name": "test-volume"}], + }, + "executor": { + "cores": 1, + "instances": 1, + "labels": {"version": "2.4.5"}, + "memory": "512m", + "volumeMounts": [{"mountPath": "/tmp", "name": "test-volume"}], + }, + "image": "gcr.io/spark-operator/spark:v2.4.5", + "imagePullPolicy": "Always", + "mainApplicationFile": "local:///opt/spark/examples/jars/spark-examples_2.11-2.4.5.jar", + "mainClass": "org.apache.spark.examples.SparkPi", + "mode": "cluster", + "restartPolicy": {"type": "Never"}, + "sparkVersion": "2.4.5", + "type": "Scala", + "volumes": [{"hostPath": {"path": "/tmp", "type": "Directory"}, "name": "test-volume"}], + }, +} + + +@patch("airflow.providers.cncf.kubernetes.hooks.kubernetes.KubernetesHook.get_conn") +@patch("kubernetes.client.api.custom_objects_api.CustomObjectsApi.delete_namespaced_custom_object") +@patch("kubernetes.client.api.custom_objects_api.CustomObjectsApi.create_namespaced_custom_object") +@patch("airflow.utils.context.Context") +class TestSparkKubernetesOperator: + def setup_method(self): + db.merge_conn( + Connection( + conn_id="kubernetes_default_kube_config", + conn_type="kubernetes", + extra=json.dumps({}), + ) + ) + + db.merge_conn( + Connection( + conn_id="kubernetes_with_namespace", + conn_type="kubernetes", + extra=json.dumps({"namespace": "mock_namespace"}), + ) + ) + + args = {"owner": "airflow", "start_date": timezone.datetime(2020, 2, 1)} + self.dag = DAG("test_dag_id", default_args=args) + + def test_create_application_from_yaml( + self, context, mock_create_namespaced_crd, mock_delete_namespaced_crd, mock_kubernetes_hook + ): + op = SparkKubernetesOperator( + application_file=TEST_VALID_APPLICATION_YAML, + dag=self.dag, + kubernetes_conn_id="kubernetes_default_kube_config", + task_id="test_task_id", + ) + + op.execute(context) + mock_kubernetes_hook.assert_called_once_with() + + mock_delete_namespaced_crd.assert_called_once_with( + group="sparkoperator.k8s.io", + namespace="default", + plural="sparkapplications", + version="v1beta2", + name=TEST_APPLICATION_DICT["metadata"]["name"], + ) + + mock_create_namespaced_crd.assert_called_with( + body=TEST_APPLICATION_DICT, + group="sparkoperator.k8s.io", + namespace="default", + plural="sparkapplications", + version="v1beta2", + ) + + def test_create_application_from_yaml_using_generate_name( + self, context, mock_create_namespaced_crd, mock_delete_namespaced_crd, mock_kubernetes_hook + ): + op = SparkKubernetesOperator( + application_file=TEST_VALID_APPLICATION_YAML_USING_GENERATE_NAME, + dag=self.dag, + kubernetes_conn_id="kubernetes_default_kube_config", + task_id="test_task_id", + ) + + op.execute(context) + mock_kubernetes_hook.assert_called_once_with() + mock_delete_namespaced_crd.assert_not_called() + + 
+
+        mock_create_namespaced_crd.assert_called_with(
+            body=TEST_APPLICATION_DICT_WITH_GENERATE_NAME,
+            group="sparkoperator.k8s.io",
+            namespace="default",
+            plural="sparkapplications",
+            version="v1beta2",
+        )
+
+    def test_create_application_from_json(
+        self, context, mock_create_namespaced_crd, mock_delete_namespaced_crd, mock_kubernetes_hook
+    ):
+        op = SparkKubernetesOperator(
+            application_file=TEST_VALID_APPLICATION_JSON,
+            dag=self.dag,
+            kubernetes_conn_id="kubernetes_default_kube_config",
+            task_id="test_task_id",
+        )
+
+        op.execute(context)
+        mock_kubernetes_hook.assert_called_once_with()
+
+        mock_delete_namespaced_crd.assert_called_once_with(
+            group="sparkoperator.k8s.io",
+            namespace="default",
+            plural="sparkapplications",
+            version="v1beta2",
+            name=TEST_APPLICATION_DICT["metadata"]["name"],
+        )
+
+        mock_create_namespaced_crd.assert_called_with(
+            body=TEST_APPLICATION_DICT,
+            group="sparkoperator.k8s.io",
+            namespace="default",
+            plural="sparkapplications",
+            version="v1beta2",
+        )
+
+    def test_create_application_from_json_with_api_group_and_version(
+        self, context, mock_create_namespaced_crd, mock_delete_namespaced_crd, mock_kubernetes_hook
+    ):
+        api_group = "sparkoperator.example.com"
+        api_version = "v1alpha1"
+
+        op = SparkKubernetesOperator(
+            application_file=TEST_VALID_APPLICATION_JSON,
+            dag=self.dag,
+            kubernetes_conn_id="kubernetes_default_kube_config",
+            task_id="test_task_id",
+            api_group=api_group,
+            api_version=api_version,
+        )
+
+        op.execute(context)
+        mock_kubernetes_hook.assert_called_once_with()
+
+        mock_delete_namespaced_crd.assert_called_once_with(
+            group=api_group,
+            namespace="default",
+            plural="sparkapplications",
+            version=api_version,
+            name=TEST_APPLICATION_DICT["metadata"]["name"],
+        )
+
+        mock_create_namespaced_crd.assert_called_with(
+            body=TEST_APPLICATION_DICT,
+            group=api_group,
+            namespace="default",
+            plural="sparkapplications",
+            version=api_version,
+        )
+
+    def test_namespace_from_operator(
+        self, context, mock_create_namespaced_crd, mock_delete_namespaced_crd, mock_kubernetes_hook
+    ):
+        op = SparkKubernetesOperator(
+            application_file=TEST_VALID_APPLICATION_JSON,
+            dag=self.dag,
+            namespace="operator_namespace",
+            kubernetes_conn_id="kubernetes_with_namespace",
+            task_id="test_task_id",
+        )
+
+        op.execute(context)
+        mock_kubernetes_hook.assert_called_once_with()
+
+        mock_delete_namespaced_crd.assert_called_once_with(
+            group="sparkoperator.k8s.io",
+            namespace="operator_namespace",
+            plural="sparkapplications",
+            version="v1beta2",
+            name=TEST_APPLICATION_DICT["metadata"]["name"],
+        )
+
+        mock_create_namespaced_crd.assert_called_with(
+            body=TEST_APPLICATION_DICT,
+            group="sparkoperator.k8s.io",
+            namespace="operator_namespace",
+            plural="sparkapplications",
+            version="v1beta2",
+        )
+
+    def test_namespace_from_connection(
+        self, context, mock_create_namespaced_crd, mock_delete_namespaced_crd, mock_kubernetes_hook
+    ):
+        op = SparkKubernetesOperator(
+            application_file=TEST_VALID_APPLICATION_JSON,
+            dag=self.dag,
+            kubernetes_conn_id="kubernetes_with_namespace",
+            task_id="test_task_id",
+        )
+
+        op.execute(context)
+        mock_kubernetes_hook.assert_called_once_with()
+
+        mock_delete_namespaced_crd.assert_called_once_with(
+            group="sparkoperator.k8s.io",
+            namespace="mock_namespace",
+            plural="sparkapplications",
+            version="v1beta2",
+            name=TEST_APPLICATION_DICT["metadata"]["name"],
+        )
+
+        mock_create_namespaced_crd.assert_called_with(
+            body=TEST_APPLICATION_DICT,
+            group="sparkoperator.k8s.io",
+            namespace="mock_namespace",
+            plural="sparkapplications",
+            version="v1beta2",
+        )
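A closing note on the generateName path these tests exercise: the hook only attempts the pre-create deletion when metadata.name is present, so manifests relying on metadata.generateName (where the API server appends a random suffix) are created unconditionally. An illustrative sketch, not part of the patch:

    body = {
        "apiVersion": "sparkoperator.k8s.io/v1beta2",
        "kind": "SparkApplication",
        "metadata": {"generateName": "spark-pi", "namespace": "default"},  # no "name"
        "spec": {},  # abbreviated
    }
    # "name" not in body["metadata"] -> no delete_namespaced_custom_object call;
    # the server assigns a unique name such as "spark-pi-8x2lq" (example suffix).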