From 78450c75f6e2318fba529453f0f16a46ed2ab349 Mon Sep 17 00:00:00 2001
From: Hossein Torabi
Date: Tue, 7 Mar 2023 16:32:54 +0100
Subject: [PATCH] enhance spark_k8s_operator

---
 .../cncf/kubernetes/hooks/kubernetes.py       |  75 +--
 .../kubernetes/operators/spark_kubernetes.py  |  71 ++-
 .../flink/operators/test_flink_kubernetes.py  |  63 +--
 .../kubernetes/hooks/test_kubernetes_pod.py   |  25 +
 .../operators/test_spark_kubernetes.py        | 486 +++----------------
 5 files changed, 216 insertions(+), 504 deletions(-)

diff --git a/airflow/providers/cncf/kubernetes/hooks/kubernetes.py b/airflow/providers/cncf/kubernetes/hooks/kubernetes.py
index 469a8d6706548..f2866e5d714b6 100644
--- a/airflow/providers/cncf/kubernetes/hooks/kubernetes.py
+++ b/airflow/providers/cncf/kubernetes/hooks/kubernetes.py
@@ -37,7 +37,7 @@
 LOADING_KUBE_CONFIG_FILE_RESOURCE = "Loading Kubernetes configuration file kube_config from {}..."
 
 
-def _load_body_to_dict(body):
+def _load_body_to_dict(body: str) -> dict:
     try:
         body_dict = yaml.safe_load(body)
     except yaml.YAMLError as e:
@@ -272,37 +272,22 @@ def create_custom_object(
         :param namespace: kubernetes namespace
         """
         api: client.CustomObjectsApi = self.custom_object_client
-        namespace = namespace or self._get_namespace() or self.DEFAULT_NAMESPACE
 
         if isinstance(body, str):
             body_dict = _load_body_to_dict(body)
         else:
             body_dict = body
 
-        # Attribute "name" is not mandatory if "generateName" is used instead
-        if "name" in body_dict["metadata"]:
-            try:
-                api.delete_namespaced_custom_object(
-                    group=group,
-                    version=version,
-                    namespace=namespace,
-                    plural=plural,
-                    name=body_dict["metadata"]["name"],
-                )
-
-                self.log.warning("Deleted SparkApplication with the same name")
-            except client.rest.ApiException:
-                self.log.info("SparkApplication %s not found", body_dict["metadata"]["name"])
-
-        try:
-            response = api.create_namespaced_custom_object(
-                group=group, version=version, namespace=namespace, plural=plural, body=body_dict
-            )
+        response = api.create_namespaced_custom_object(
+            group=group,
+            version=version,
+            namespace=namespace or self.get_namespace(),
+            plural=plural,
+            body=body_dict,
+        )
 
-            self.log.debug("Response: %s", response)
-            return response
-        except client.rest.ApiException as e:
-            raise AirflowException(f"Exception when calling -> create_custom_object: {e}\n")
+        self.log.debug("Response: %s", response)
+        return response
 
     def get_custom_object(
         self, group: str, version: str, plural: str, name: str, namespace: str | None = None
@@ -317,14 +302,38 @@ def get_custom_object(
         :param namespace: kubernetes namespace
         """
         api = client.CustomObjectsApi(self.api_client)
-        namespace = namespace or self._get_namespace() or self.DEFAULT_NAMESPACE
-        try:
-            response = api.get_namespaced_custom_object(
-                group=group, version=version, namespace=namespace, plural=plural, name=name
-            )
-            return response
-        except client.rest.ApiException as e:
-            raise AirflowException(f"Exception when calling -> get_custom_object: {e}\n")
+        response = api.get_namespaced_custom_object(
+            group=group,
+            version=version,
+            namespace=namespace or self.get_namespace(),
+            plural=plural,
+            name=name,
+        )
+        return response
+
+    def delete_custom_object(
+        self, group: str, version: str, plural: str, name: str, namespace: str | None = None, **kwargs
+    ):
+        """
+        Delete a custom resource definition object from Kubernetes.
+
+        :param group: api group
+        :param version: api version
+        :param plural: api plural
+        :param name: crd object name
+        :param namespace: kubernetes namespace
+        """
+        api = client.CustomObjectsApi(self.api_client)
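+        # Any additional kwargs (e.g. _preload_content) are forwarded verbatim
+        # to the underlying Kubernetes client call.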
+        return api.delete_namespaced_custom_object(
+            group=group,
+            version=version,
+            namespace=namespace or self.get_namespace(),
+            plural=plural,
+            name=name,
+            **kwargs,
+        )
 
     def get_namespace(self) -> str | None:
         """
diff --git a/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py b/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py
index ff3828ffb0d4a..bb376967115a0 100644
--- a/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py
+++ b/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py
@@ -19,8 +19,10 @@
 
 from typing import TYPE_CHECKING, Sequence
 
+from kubernetes.watch import Watch
+
 from airflow.models import BaseOperator
-from airflow.providers.cncf.kubernetes.hooks.kubernetes import KubernetesHook
+from airflow.providers.cncf.kubernetes.hooks.kubernetes import KubernetesHook, _load_body_to_dict
 
 if TYPE_CHECKING:
     from airflow.utils.context import Context
@@ -55,24 +57,77 @@ def __init__(
         kubernetes_conn_id: str = "kubernetes_default",
         api_group: str = "sparkoperator.k8s.io",
         api_version: str = "v1beta2",
+        in_cluster: bool | None = None,
+        cluster_context: str | None = None,
+        config_file: str | None = None,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
-        self.application_file = application_file
         self.namespace = namespace
         self.kubernetes_conn_id = kubernetes_conn_id
         self.api_group = api_group
         self.api_version = api_version
         self.plural = "sparkapplications"
+        self.application_file = application_file
+        self.in_cluster = in_cluster
+        self.cluster_context = cluster_context
+        self.config_file = config_file
+
+        self.hook = KubernetesHook(
+            conn_id=self.kubernetes_conn_id,
+            in_cluster=self.in_cluster,
+            config_file=self.config_file,
+            cluster_context=self.cluster_context,
+        )
 
     def execute(self, context: Context):
-        hook = KubernetesHook(conn_id=self.kubernetes_conn_id)
-        self.log.info("Creating sparkApplication")
-        response = hook.create_custom_object(
+        body = _load_body_to_dict(self.application_file)
+        name = body["metadata"]["name"]
+        namespace = self.namespace or self.hook.get_namespace()
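+        # Watch for the driver pod before submitting the application so that
+        # the ADDED event below cannot be missed.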
+        namespace_event_stream = Watch().stream(
+            self.hook.core_v1_client.list_namespaced_pod,
+            namespace=namespace,
+            _preload_content=False,
+            watch=True,
+            label_selector=f"sparkoperator.k8s.io/app-name={name},spark-role=driver",
+            field_selector="status.phase=Running",
+        )
+
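+        # Submit the SparkApplication custom resource.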
+        self.hook.create_custom_object(
+            group=self.api_group,
+            version=self.api_version,
+            plural=self.plural,
+            body=body,
+            namespace=namespace,
+        )
+        for event in namespace_event_stream:
+            if event["type"] == "ADDED":
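+                # The driver pod has reached the Running phase: tail its logs.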
+                pod_log_stream = Watch().stream(
+                    self.hook.core_v1_client.read_namespaced_pod_log,
+                    name=f"{name}-driver",
+                    namespace=namespace,
+                    _preload_content=False,
+                    timestamps=True,
+                )
+                for line in pod_log_stream:
+                    self.log.info(line)
+            else:
+                break
+
+    def on_kill(self) -> None:
+        body = _load_body_to_dict(self.application_file)
+        name = body["metadata"]["name"]
+        namespace = self.namespace or self.hook.get_namespace()
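+        # Deleting the SparkApplication CR lets the Spark operator clean up
+        # the driver and executor pods it manages.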
[{"mountPath": "/tmp", "name": "test-volume"}], - }, - "image": "gcr.io/spark-operator/spark:v2.4.5", - "imagePullPolicy": "Always", - "mainApplicationFile": "local:///opt/spark/examples/jars/spark-examples_2.11-2.4.5.jar", - "mainClass": "org.apache.spark.examples.SparkPi", - "mode": "cluster", - "restartPolicy": {"type": "Never"}, - "sparkVersion": "2.4.5", - "type": "Scala", - "volumes": [{"hostPath": {"path": "/tmp", "type": "Directory"}, "name": "test-volume"}], - }, -} - -TEST_APPLICATION_DICT_WITH_GENERATE_NAME = { - "apiVersion": "sparkoperator.k8s.io/v1beta2", - "kind": "SparkApplication", - "metadata": {"generateName": "spark-pi", "namespace": "default"}, - "spec": { - "driver": { - "coreLimit": "1200m", - "cores": 1, - "labels": {"version": "2.4.5"}, - "memory": "512m", - "serviceAccount": "spark", - "volumeMounts": [{"mountPath": "/tmp", "name": "test-volume"}], - }, - "executor": { - "cores": 1, - "instances": 1, - "labels": {"version": "2.4.5"}, - "memory": "512m", - "volumeMounts": [{"mountPath": "/tmp", "name": "test-volume"}], - }, - "image": "gcr.io/spark-operator/spark:v2.4.5", - "imagePullPolicy": "Always", - "mainApplicationFile": "local:///opt/spark/examples/jars/spark-examples_2.11-2.4.5.jar", - "mainClass": "org.apache.spark.examples.SparkPi", - "mode": "cluster", - "restartPolicy": {"type": "Never"}, - "sparkVersion": "2.4.5", - "type": "Scala", - "volumes": [{"hostPath": {"path": "/tmp", "type": "Directory"}, "name": "test-volume"}], - }, -} - - -@patch("airflow.providers.cncf.kubernetes.hooks.kubernetes.KubernetesHook.get_conn") -@patch("kubernetes.client.api.custom_objects_api.CustomObjectsApi.delete_namespaced_custom_object") -@patch("kubernetes.client.api.custom_objects_api.CustomObjectsApi.create_namespaced_custom_object") -@patch("airflow.utils.context.Context") -class TestSparkKubernetesOperator: - def setup_method(self): - db.merge_conn( - Connection( - conn_id="kubernetes_default_kube_config", - conn_type="kubernetes", - extra=json.dumps({}), - ) - ) - - db.merge_conn( - Connection( - conn_id="kubernetes_with_namespace", - conn_type="kubernetes", - extra=json.dumps({"namespace": "mock_namespace"}), - ) - ) - - args = {"owner": "airflow", "start_date": timezone.datetime(2020, 2, 1)} - self.dag = DAG("test_dag_id", default_args=args) - - def test_create_application_from_yaml( - self, context, mock_create_namespaced_crd, mock_delete_namespaced_crd, mock_kubernetes_hook - ): - op = SparkKubernetesOperator( - application_file=TEST_VALID_APPLICATION_YAML, - dag=self.dag, - kubernetes_conn_id="kubernetes_default_kube_config", - task_id="test_task_id", - ) - - op.execute(context) - mock_kubernetes_hook.assert_called_once_with() - - mock_delete_namespaced_crd.assert_called_once_with( - group="sparkoperator.k8s.io", - namespace="default", - plural="sparkapplications", - version="v1beta2", - name=TEST_APPLICATION_DICT["metadata"]["name"], - ) - - mock_create_namespaced_crd.assert_called_with( - body=TEST_APPLICATION_DICT, - group="sparkoperator.k8s.io", - namespace="default", - plural="sparkapplications", - version="v1beta2", - ) - - def test_create_application_from_yaml_using_generate_name( - self, context, mock_create_namespaced_crd, mock_delete_namespaced_crd, mock_kubernetes_hook - ): - op = SparkKubernetesOperator( - application_file=TEST_VALID_APPLICATION_YAML_USING_GENERATE_NAME, - dag=self.dag, - kubernetes_conn_id="kubernetes_default_kube_config", - task_id="test_task_id", - ) - - op.execute(context) - mock_kubernetes_hook.assert_called_once_with() - 
+    mock_stream.side_effect = [[{"type": "ADDED"}], []]
+
+    op = SparkKubernetesOperator(task_id="task_id", application_file="application_file")
+    op.execute({})
+    mock_kubernetes_hook.return_value.create_custom_object.assert_called_once_with(
+        group="sparkoperator.k8s.io",
+        version="v1beta2",
+        plural="sparkapplications",
+        body={"metadata": {"name": "spark-app"}},
+        namespace="default",
+    )
+
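+    # One stream watches the namespace for the driver pod; the other tails its logs.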
group="sparkoperator.k8s.io", - namespace="mock_namespace", - plural="sparkapplications", - version="v1beta2", - ) +@patch("airflow.providers.cncf.kubernetes.operators.spark_kubernetes.KubernetesHook") +def test_spark_kubernetes_operator(mock_kubernetes_hook): + SparkKubernetesOperator( + task_id="task_id", + application_file="application_file", + kubernetes_conn_id="kubernetes_conn_id", + in_cluster=True, + cluster_context="cluster_context", + config_file="config_file", + ) + + mock_kubernetes_hook.assert_called_once_with( + conn_id="kubernetes_conn_id", + in_cluster=True, + cluster_context="cluster_context", + config_file="config_file", + ) + + +@patch("airflow.providers.cncf.kubernetes.operators.spark_kubernetes.Watch.stream") +@patch("airflow.providers.cncf.kubernetes.operators.spark_kubernetes._load_body_to_dict") +@patch("airflow.providers.cncf.kubernetes.operators.spark_kubernetes.KubernetesHook") +def test_execute(mock_kubernetes_hook, mock_load_body_to_dict, mock_stream): + mock_load_body_to_dict.return_value = {"metadata": {"name": "spark-app"}} + mock_kubernetes_hook.return_value.get_namespace.return_value = "default" + mock_stream.side_effect = [[{"type": "ADDED"}], []] + + op = SparkKubernetesOperator(task_id="task_id", application_file="application_file") + op.execute({}) + mock_kubernetes_hook.return_value.create_custom_object.assert_called_once_with( + group="sparkoperator.k8s.io", + version="v1beta2", + plural="sparkapplications", + body={"metadata": {"name": "spark-app"}}, + namespace="default", + ) + + assert mock_stream.call_count == 2 + mock_stream.assert_any_call( + mock_kubernetes_hook.return_value.core_v1_client.list_namespaced_pod, + namespace="default", + _preload_content=False, + watch=True, + label_selector="sparkoperator.k8s.io/app-name=spark-app,spark-role=driver", + field_selector="status.phase=Running", + ) + + mock_stream.assert_any_call( + mock_kubernetes_hook.return_value.core_v1_client.read_namespaced_pod_log, + name="spark-app-driver", + namespace="default", + _preload_content=False, + timestamps=True, + ) + + +@patch("airflow.providers.cncf.kubernetes.operators.spark_kubernetes._load_body_to_dict") +@patch("airflow.providers.cncf.kubernetes.operators.spark_kubernetes.KubernetesHook") +def test_on_kill(mock_kubernetes_hook, mock_load_body_to_dict): + mock_load_body_to_dict.return_value = {"metadata": {"name": "spark-app"}} + mock_kubernetes_hook.return_value.get_namespace.return_value = "default" + + op = SparkKubernetesOperator(task_id="task_id", application_file="application_file") + + op.on_kill() + + mock_kubernetes_hook.return_value.delete_custom_object.assert_called_once_with( + group="sparkoperator.k8s.io", + version="v1beta2", + plural="sparkapplications", + namespace="default", + name="spark-app", + )