diff --git a/airflow/providers/papermill/operators/papermill.py b/airflow/providers/papermill/operators/papermill.py index 3326b2e3c93a8..b89266d130889 100644 --- a/airflow/providers/papermill/operators/papermill.py +++ b/airflow/providers/papermill/operators/papermill.py @@ -81,26 +81,23 @@ def __init__( if not input_nb: raise ValueError("Input notebook is not specified") - elif not isinstance(input_nb, NoteBook): - self.input_nb = NoteBook(url=input_nb, parameters=self.parameters) - else: - self.input_nb = input_nb + self.input_nb = input_nb if not output_nb: raise ValueError("Output notebook is not specified") - elif not isinstance(output_nb, NoteBook): - self.output_nb = NoteBook(url=output_nb) - else: - self.output_nb = output_nb + self.output_nb = output_nb self.kernel_name = kernel_name self.language_name = language_name self.kernel_conn_id = kernel_conn_id + def execute(self, context: Context): + if not isinstance(self.input_nb, NoteBook): + self.input_nb = NoteBook(url=self.input_nb, parameters=self.parameters) + if not isinstance(self.output_nb, NoteBook): + self.output_nb = NoteBook(url=self.output_nb) self.inlets.append(self.input_nb) self.outlets.append(self.output_nb) - - def execute(self, context: Context): remote_kernel_kwargs = {} kernel_hook = self.hook if kernel_hook: diff --git a/tests/providers/papermill/operators/test_papermill.py b/tests/providers/papermill/operators/test_papermill.py index 03fd0dc74e17b..e734bb8ed61a4 100644 --- a/tests/providers/papermill/operators/test_papermill.py +++ b/tests/providers/papermill/operators/test_papermill.py @@ -69,12 +69,21 @@ def test_mandatory_attributes(self): pytest.param(NoteBook(TEST_INPUT_URL), id="input-as-notebook-object"), ], ) - def test_notebooks_objects(self, input_nb, output_nb): + @patch("airflow.providers.papermill.operators.papermill.pm") + @patch("airflow.providers.papermill.operators.papermill.PapermillOperator.hook") + def test_notebooks_objects(self, mock_papermill, mock_hook, input_nb, output_nb): """Test different type of Input/Output notebooks arguments.""" op = PapermillOperator(task_id="test_notebooks_objects", input_nb=input_nb, output_nb=output_nb) + + op.execute(None) + assert op.input_nb.url == TEST_INPUT_URL assert op.output_nb.url == TEST_OUTPUT_URL + # Test render Lineage inlets/outlets + assert op.inlets[0] == op.input_nb + assert op.outlets[0] == op.output_nb + @patch("airflow.providers.papermill.operators.papermill.pm") def test_execute(self, mock_papermill): in_nb = "/tmp/does_not_exist" @@ -173,19 +182,9 @@ def test_render_template(self, create_task_instance_of_operator): task = ti.render_templates() # Test render Input/Output notebook attributes - assert task.input_nb.url == "/tmp/test_render_template.ipynb" - assert task.input_nb.parameters == { - "msgs": "dag id is test_render_template!", - "test_dt": DEFAULT_DATE.date().isoformat(), - } - assert task.output_nb.url == "/tmp/out-test_render_template.ipynb" - assert task.output_nb.parameters == {} + assert task.input_nb == "/tmp/test_render_template.ipynb" + assert task.output_nb == "/tmp/out-test_render_template.ipynb" # Test render other templated attributes - assert task.parameters == task.input_nb.parameters assert "python3" == task.kernel_name assert "python" == task.language_name - - # Test render Lineage inlets/outlets - assert task.inlets[0] == task.input_nb - assert task.outlets[0] == task.output_nb