diff --git a/airflow/providers/apache/spark/hooks/spark_jdbc.py b/airflow/providers/apache/spark/hooks/spark_jdbc.py
index 5d25866eb86b5..8c8d02f1ecc41 100644
--- a/airflow/providers/apache/spark/hooks/spark_jdbc.py
+++ b/airflow/providers/apache/spark/hooks/spark_jdbc.py
@@ -83,6 +83,8 @@ class SparkJDBCHook(SparkSubmitHook):
         (e.g: "name CHAR(64), comments VARCHAR(1024)").
         The specified types should be valid spark sql data
         types.
+    :param use_krb5ccache: if True, configure spark to use ticket cache instead of relying
+        on keytab for Kerberos login
     """
 
     conn_name_attr = "spark_conn_id"
@@ -121,6 +123,7 @@ def __init__(
         upper_bound: str | None = None,
         create_table_column_types: str | None = None,
         *args: Any,
+        use_krb5ccache: bool = False,
         **kwargs: Any,
     ):
         super().__init__(*args, **kwargs)
@@ -153,6 +156,7 @@ def __init__(
         self._upper_bound = upper_bound
         self._create_table_column_types = create_table_column_types
         self._jdbc_connection = self._resolve_jdbc_connection()
+        self._use_krb5ccache = use_krb5ccache
 
     def _resolve_jdbc_connection(self) -> dict[str, Any]:
         conn_data = {"url": "", "schema": "", "conn_prefix": "", "user": "", "password": ""}
diff --git a/airflow/providers/apache/spark/operators/spark_jdbc.py b/airflow/providers/apache/spark/operators/spark_jdbc.py
index 70eb224828065..4b4dd648a67cd 100644
--- a/airflow/providers/apache/spark/operators/spark_jdbc.py
+++ b/airflow/providers/apache/spark/operators/spark_jdbc.py
@@ -91,6 +91,9 @@ class SparkJDBCOperator(SparkSubmitOperator):
         (e.g: "name CHAR(64), comments VARCHAR(1024)").
         The specified types should be valid spark sql data
         types.
+    :param use_krb5ccache: if True, configure spark to use ticket cache instead of relying
+        on keytab for Kerberos login
+
     """
 
     def __init__(
@@ -124,6 +127,7 @@ def __init__(
         lower_bound: str | None = None,
         upper_bound: str | None = None,
         create_table_column_types: str | None = None,
+        use_krb5ccache: bool = False,
         **kwargs: Any,
     ) -> None:
         super().__init__(**kwargs)
@@ -156,6 +160,7 @@ def __init__(
         self._upper_bound = upper_bound
         self._create_table_column_types = create_table_column_types
         self._hook: SparkJDBCHook | None = None
+        self._use_krb5ccache = use_krb5ccache
 
     def execute(self, context: Context) -> None:
         """Call the SparkSubmitHook to run the provided spark job."""
@@ -198,4 +203,5 @@ def _get_hook(self) -> SparkJDBCHook:
             lower_bound=self._lower_bound,
             upper_bound=self._upper_bound,
             create_table_column_types=self._create_table_column_types,
+            use_krb5ccache=self._use_krb5ccache,
         )
diff --git a/airflow/providers/apache/spark/operators/spark_submit.py b/airflow/providers/apache/spark/operators/spark_submit.py
index 0620a44c0ea46..903ff9b7207da 100644
--- a/airflow/providers/apache/spark/operators/spark_submit.py
+++ b/airflow/providers/apache/spark/operators/spark_submit.py
@@ -69,6 +69,8 @@ class SparkSubmitOperator(BaseOperator):
     :param verbose: Whether to pass the verbose flag to spark-submit process for debugging
     :param spark_binary: The command to use for spark submit. Some distros may use spark2-submit
         or spark3-submit.
+    :param use_krb5ccache: if True, configure spark to use ticket cache instead of relying
+        on keytab for Kerberos login
     """
 
     template_fields: Sequence[str] = (
@@ -118,6 +120,7 @@ def __init__(
         env_vars: dict[str, Any] | None = None,
         verbose: bool = False,
         spark_binary: str | None = None,
+        use_krb5ccache: bool = False,
         **kwargs: Any,
     ) -> None:
         super().__init__(**kwargs)
@@ -148,6 +151,7 @@ def __init__(
         self._spark_binary = spark_binary
         self._hook: SparkSubmitHook | None = None
         self._conn_id = conn_id
+        self._use_krb5ccache = use_krb5ccache
 
     def execute(self, context: Context) -> None:
         """Call the SparkSubmitHook to run the provided spark job."""
@@ -187,4 +191,5 @@ def _get_hook(self) -> SparkSubmitHook:
             env_vars=self._env_vars,
             verbose=self._verbose,
             spark_binary=self._spark_binary,
+            use_krb5ccache=self._use_krb5ccache,
         )
diff --git a/tests/providers/apache/spark/operators/test_spark_jdbc.py b/tests/providers/apache/spark/operators/test_spark_jdbc.py
index 9060fd63b72bb..bccc9e536753d 100644
--- a/tests/providers/apache/spark/operators/test_spark_jdbc.py
+++ b/tests/providers/apache/spark/operators/test_spark_jdbc.py
@@ -53,6 +53,7 @@ class TestSparkJDBCOperator:
         "upper_bound": "20",
         "create_table_column_types": "columnMcColumnFace INTEGER(100), name CHAR(64),"
         "comments VARCHAR(1024)",
+        "use_krb5ccache": True,
     }
 
     def setup_method(self):
@@ -95,6 +96,7 @@ def test_execute(self):
             "upper_bound": "20",
             "create_table_column_types": "columnMcColumnFace INTEGER(100), name CHAR(64),"
             "comments VARCHAR(1024)",
+            "use_krb5ccache": True,
         }
 
         assert spark_conn_id == operator._spark_conn_id
@@ -125,3 +127,4 @@
         assert expected_dict["lower_bound"] == operator._lower_bound
         assert expected_dict["upper_bound"] == operator._upper_bound
         assert expected_dict["create_table_column_types"] == operator._create_table_column_types
+        assert expected_dict["use_krb5ccache"] == operator._use_krb5ccache
diff --git a/tests/providers/apache/spark/operators/test_spark_submit.py b/tests/providers/apache/spark/operators/test_spark_submit.py
index cf131fdc1a876..3c6aa78336229 100644
--- a/tests/providers/apache/spark/operators/test_spark_submit.py
+++ b/tests/providers/apache/spark/operators/test_spark_submit.py
@@ -65,6 +65,7 @@ class TestSparkSubmitOperator:
             "--with-spaces",
             "args should keep embedded spaces",
         ],
+        "use_krb5ccache": True,
     }
 
     def setup_method(self):
@@ -75,7 +76,10 @@ def test_execute(self):
         # Given / When
         conn_id = "spark_default"
         operator = SparkSubmitOperator(
-            task_id="spark_submit_job", spark_binary="sparky", dag=self.dag, **self._config
+            task_id="spark_submit_job",
+            spark_binary="sparky",
+            dag=self.dag,
+            **self._config,
         )
 
         # Then expected results
@@ -115,6 +119,7 @@ def test_execute(self):
                 "args should keep embedded spaces",
             ],
             "spark_binary": "sparky",
+            "use_krb5ccache": True,
         }
 
         assert conn_id == operator._conn_id
@@ -142,6 +147,7 @@ def test_execute(self):
         assert expected_dict["driver_memory"] == operator._driver_memory
         assert expected_dict["application_args"] == operator._application_args
         assert expected_dict["spark_binary"] == operator._spark_binary
+        assert expected_dict["use_krb5ccache"] == operator._use_krb5ccache
 
     @pytest.mark.db_test
     def test_render_template(self):
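For context, a minimal sketch of how the new flag might be used from a DAG once a provider release includes this change. The DAG id, task id, application path, and connection id below are illustrative placeholders; only `use_krb5ccache` comes from this diff.

```python
# Illustrative sketch only: dag_id, task_id, application path, and conn_id are
# hypothetical placeholders; use_krb5ccache is the parameter added by this diff.
from datetime import datetime

from airflow import DAG
from airflow.providers.apache.spark.operators.spark_submit import SparkSubmitOperator

with DAG(
    dag_id="example_spark_ticket_cache",
    start_date=datetime(2023, 1, 1),
    schedule=None,
    catchup=False,
) as dag:
    submit = SparkSubmitOperator(
        task_id="submit_app",
        application="/opt/spark/app.py",
        conn_id="spark_default",
        # Reuse an existing Kerberos ticket cache (e.g. created by `kinit` or
        # Airflow's Kerberos ticket renewer) instead of supplying a keytab.
        use_krb5ccache=True,
    )
```

With the keytab-based flow one would pass `principal` and `keytab` instead; the new flag simply tells the hook to rely on the ticket cache for Kerberos login, as the docstrings above describe.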