From f0eb33981c50413f037dcfb70d029f870cc67897 Mon Sep 17 00:00:00 2001 From: romsharon98 Date: Tue, 2 Jan 2024 15:58:41 +0200 Subject: [PATCH 1/2] move assignment of templated field to execute --- airflow/providers/papermill/operators/papermill.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/airflow/providers/papermill/operators/papermill.py b/airflow/providers/papermill/operators/papermill.py index 3326b2e3c93a8..7a0a4527cdd71 100644 --- a/airflow/providers/papermill/operators/papermill.py +++ b/airflow/providers/papermill/operators/papermill.py @@ -81,15 +81,11 @@ def __init__( if not input_nb: raise ValueError("Input notebook is not specified") - elif not isinstance(input_nb, NoteBook): - self.input_nb = NoteBook(url=input_nb, parameters=self.parameters) else: self.input_nb = input_nb if not output_nb: raise ValueError("Output notebook is not specified") - elif not isinstance(output_nb, NoteBook): - self.output_nb = NoteBook(url=output_nb) else: self.output_nb = output_nb @@ -101,6 +97,10 @@ def __init__( self.outlets.append(self.output_nb) def execute(self, context: Context): + if not isinstance(self.input_nb, NoteBook): + self.input_nb = NoteBook(url=self.input_nb, parameters=self.parameters) + if not isinstance(self.output_nb, NoteBook): + self.output_nb = NoteBook(url=self.output_nb) remote_kernel_kwargs = {} kernel_hook = self.hook if kernel_hook: From 24d54e855096ccc272114df105ff3a3f8e91a9bd Mon Sep 17 00:00:00 2001 From: romsharon98 Date: Tue, 2 Jan 2024 16:37:41 +0200 Subject: [PATCH 2/2] fix tests --- .../papermill/operators/papermill.py | 11 +++----- .../papermill/operators/test_papermill.py | 25 +++++++++---------- 2 files changed, 16 insertions(+), 20 deletions(-) diff --git a/airflow/providers/papermill/operators/papermill.py b/airflow/providers/papermill/operators/papermill.py index 7a0a4527cdd71..b89266d130889 100644 --- a/airflow/providers/papermill/operators/papermill.py +++ b/airflow/providers/papermill/operators/papermill.py @@ -81,26 +81,23 @@ def __init__( if not input_nb: raise ValueError("Input notebook is not specified") - else: - self.input_nb = input_nb + self.input_nb = input_nb if not output_nb: raise ValueError("Output notebook is not specified") - else: - self.output_nb = output_nb + self.output_nb = output_nb self.kernel_name = kernel_name self.language_name = language_name self.kernel_conn_id = kernel_conn_id - self.inlets.append(self.input_nb) - self.outlets.append(self.output_nb) - def execute(self, context: Context): if not isinstance(self.input_nb, NoteBook): self.input_nb = NoteBook(url=self.input_nb, parameters=self.parameters) if not isinstance(self.output_nb, NoteBook): self.output_nb = NoteBook(url=self.output_nb) + self.inlets.append(self.input_nb) + self.outlets.append(self.output_nb) remote_kernel_kwargs = {} kernel_hook = self.hook if kernel_hook: diff --git a/tests/providers/papermill/operators/test_papermill.py b/tests/providers/papermill/operators/test_papermill.py index 03fd0dc74e17b..e734bb8ed61a4 100644 --- a/tests/providers/papermill/operators/test_papermill.py +++ b/tests/providers/papermill/operators/test_papermill.py @@ -69,12 +69,21 @@ def test_mandatory_attributes(self): pytest.param(NoteBook(TEST_INPUT_URL), id="input-as-notebook-object"), ], ) - def test_notebooks_objects(self, input_nb, output_nb): + @patch("airflow.providers.papermill.operators.papermill.pm") + @patch("airflow.providers.papermill.operators.papermill.PapermillOperator.hook") + def test_notebooks_objects(self, mock_papermill, mock_hook, input_nb, output_nb): """Test different type of Input/Output notebooks arguments.""" op = PapermillOperator(task_id="test_notebooks_objects", input_nb=input_nb, output_nb=output_nb) + + op.execute(None) + assert op.input_nb.url == TEST_INPUT_URL assert op.output_nb.url == TEST_OUTPUT_URL + # Test render Lineage inlets/outlets + assert op.inlets[0] == op.input_nb + assert op.outlets[0] == op.output_nb + @patch("airflow.providers.papermill.operators.papermill.pm") def test_execute(self, mock_papermill): in_nb = "/tmp/does_not_exist" @@ -173,19 +182,9 @@ def test_render_template(self, create_task_instance_of_operator): task = ti.render_templates() # Test render Input/Output notebook attributes - assert task.input_nb.url == "/tmp/test_render_template.ipynb" - assert task.input_nb.parameters == { - "msgs": "dag id is test_render_template!", - "test_dt": DEFAULT_DATE.date().isoformat(), - } - assert task.output_nb.url == "/tmp/out-test_render_template.ipynb" - assert task.output_nb.parameters == {} + assert task.input_nb == "/tmp/test_render_template.ipynb" + assert task.output_nb == "/tmp/out-test_render_template.ipynb" # Test render other templated attributes - assert task.parameters == task.input_nb.parameters assert "python3" == task.kernel_name assert "python" == task.language_name - - # Test render Lineage inlets/outlets - assert task.inlets[0] == task.input_nb - assert task.outlets[0] == task.output_nb