From a5a99021573363035f86e414184d5e1d923d53e4 Mon Sep 17 00:00:00 2001 From: Andrey Anshin Date: Mon, 12 Feb 2024 17:59:14 +0400 Subject: [PATCH] Fix rendering `LivyOperator.spark_params` --- .../providers/apache/livy/operators/livy.py | 8 +++- .../apache/livy/operators/test_livy.py | 48 +++++++++++++++++++ 2 files changed, 54 insertions(+), 2 deletions(-) diff --git a/airflow/providers/apache/livy/operators/livy.py b/airflow/providers/apache/livy/operators/livy.py index 5d9d1126a8d46..34c8a73eb2586 100644 --- a/airflow/providers/apache/livy/operators/livy.py +++ b/airflow/providers/apache/livy/operators/livy.py @@ -64,6 +64,9 @@ class LivyOperator(BaseOperator): See Tenacity documentation at https://github.com/jd/tenacity """ + template_fields: Sequence[str] = ("spark_params",) + template_fields_renderers = {"spark_params": "json"} + def __init__( self, *, @@ -94,7 +97,8 @@ def __init__( ) -> None: super().__init__(**kwargs) - self.spark_params = { + spark_params = { + # Prepare spark parameters, it will be templated later. "file": file, "class_name": class_name, "args": args, @@ -112,7 +116,7 @@ def __init__( "conf": conf, "proxy_user": proxy_user, } - + self.spark_params = spark_params self._livy_conn_id = livy_conn_id self._livy_conn_auth_type = livy_conn_auth_type self._polling_interval = polling_interval diff --git a/tests/providers/apache/livy/operators/test_livy.py b/tests/providers/apache/livy/operators/test_livy.py index 156f4f4d03009..02e8231eb2896 100644 --- a/tests/providers/apache/livy/operators/test_livy.py +++ b/tests/providers/apache/livy/operators/test_livy.py @@ -379,3 +379,51 @@ def test_execute_complete_error(self, mock_post): }, ) self.mock_context["ti"].xcom_push.assert_not_called() + + +@pytest.mark.db_test +def test_spark_params_templating(create_task_instance_of_operator): + ti = create_task_instance_of_operator( + LivyOperator, + # Templated fields + file="{{ 'literal-file' }}", + class_name="{{ 'literal-class-name' }}", + args="{{ 'literal-args' }}", + jars="{{ 'literal-jars' }}", + py_files="{{ 'literal-py-files' }}", + files="{{ 'literal-files' }}", + driver_memory="{{ 'literal-driver-memory' }}", + driver_cores="{{ 'literal-driver-cores' }}", + executor_memory="{{ 'literal-executor-memory' }}", + executor_cores="{{ 'literal-executor-cores' }}", + num_executors="{{ 'literal-num-executors' }}", + archives="{{ 'literal-archives' }}", + queue="{{ 'literal-queue' }}", + name="{{ 'literal-name' }}", + conf="{{ 'literal-conf' }}", + proxy_user="{{ 'literal-proxy-user' }}", + # Other parameters + dag_id="test_template_body_templating_dag", + task_id="test_template_body_templating_task", + execution_date=timezone.datetime(2024, 2, 1, tzinfo=timezone.utc), + ) + ti.render_templates() + task: LivyOperator = ti.task + assert task.spark_params == { + "archives": "literal-archives", + "args": "literal-args", + "class_name": "literal-class-name", + "conf": "literal-conf", + "driver_cores": "literal-driver-cores", + "driver_memory": "literal-driver-memory", + "executor_cores": "literal-executor-cores", + "executor_memory": "literal-executor-memory", + "file": "literal-file", + "files": "literal-files", + "jars": "literal-jars", + "name": "literal-name", + "num_executors": "literal-num-executors", + "proxy_user": "literal-proxy-user", + "py_files": "literal-py-files", + "queue": "literal-queue", + }