From 3036c0e61e363e8d2976d3ad860884cdfe8b3c88 Mon Sep 17 00:00:00 2001 From: Zack Strathe Date: Wed, 3 Apr 2024 10:52:26 -0500 Subject: [PATCH 1/4] Bugfix to correct GCSHook being called even when not required with BeamRunPythonPipelineOperator --- airflow/providers/apache/beam/operators/beam.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/airflow/providers/apache/beam/operators/beam.py b/airflow/providers/apache/beam/operators/beam.py index e88923bc05374..daa2a69982c50 100644 --- a/airflow/providers/apache/beam/operators/beam.py +++ b/airflow/providers/apache/beam/operators/beam.py @@ -364,11 +364,13 @@ def execute(self, context: Context): def execute_sync(self, context: Context): with ExitStack() as exit_stack: - gcs_hook = GCSHook(gcp_conn_id=self.gcp_conn_id) if self.py_file.lower().startswith("gs://"): + gcs_hook = GCSHook(gcp_conn_id=self.gcp_conn_id) tmp_gcs_file = exit_stack.enter_context(gcs_hook.provide_file(object_url=self.py_file)) self.py_file = tmp_gcs_file.name if self.snake_case_pipeline_options.get("requirements_file", "").startswith("gs://"): + if 'gcs_hook' not in locals(): + gcs_hook = GCSHook(gcp_conn_id=self.gcp_conn_id) tmp_req_file = exit_stack.enter_context( gcs_hook.provide_file(object_url=self.snake_case_pipeline_options["requirements_file"]) ) From 566f07360b2c2393d64b6432e14d0e828d758b95 Mon Sep 17 00:00:00 2001 From: Zack Strathe Date: Sun, 14 Apr 2024 20:06:11 -0500 Subject: [PATCH 2/4] remove unneccary check for GCSHook and add unit test for BeamRunPythonPipelineOperator to ensure that GCSHook is only called when necessary --- .../providers/apache/beam/operators/beam.py | 3 +- .../apache/beam/operators/test_beam.py | 53 +++++++++++++++++++ 2 files changed, 54 insertions(+), 2 deletions(-) mode change 100644 => 100755 tests/providers/apache/beam/operators/test_beam.py diff --git a/airflow/providers/apache/beam/operators/beam.py b/airflow/providers/apache/beam/operators/beam.py index daa2a69982c50..6187e7971dcdf 100644 --- a/airflow/providers/apache/beam/operators/beam.py +++ b/airflow/providers/apache/beam/operators/beam.py @@ -369,8 +369,7 @@ def execute_sync(self, context: Context): tmp_gcs_file = exit_stack.enter_context(gcs_hook.provide_file(object_url=self.py_file)) self.py_file = tmp_gcs_file.name if self.snake_case_pipeline_options.get("requirements_file", "").startswith("gs://"): - if 'gcs_hook' not in locals(): - gcs_hook = GCSHook(gcp_conn_id=self.gcp_conn_id) + gcs_hook = GCSHook(gcp_conn_id=self.gcp_conn_id) tmp_req_file = exit_stack.enter_context( gcs_hook.provide_file(object_url=self.snake_case_pipeline_options["requirements_file"]) ) diff --git a/tests/providers/apache/beam/operators/test_beam.py b/tests/providers/apache/beam/operators/test_beam.py old mode 100644 new mode 100755 index f7ca9649fb71b..b98ec9c10d351 --- a/tests/providers/apache/beam/operators/test_beam.py +++ b/tests/providers/apache/beam/operators/test_beam.py @@ -256,6 +256,59 @@ def test_on_kill_direct_runner(self, _, dataflow_mock, __): op.on_kill() dataflow_cancel_job.assert_not_called() + @mock.patch(BEAM_OPERATOR_PATH.format("BeamHook")) + @mock.patch(BEAM_OPERATOR_PATH.format("GCSHook")) + def test_execute_gcs_hook_called_only_with_gs_prefix(self, mock_gcs_hook, _): + local_test_op_args = { + "task_id": TASK_ID, + "py_file": 'local_file.py', + "py_options": ['-m'], + "default_pipeline_options": { + "project": TEST_PROJECT, + 'requirements_file': 'local_requirements.txt' + }, + "pipeline_options": { + "output": 'test_local/output', "labels": {"foo": "bar"} + } + } + + """ + Test that execute method does not call GCSHook when neither py_file nor requirements_file + starts with 'gs://'. + """ + test_kwargs_local = copy.deepcopy(local_test_op_args) + op = BeamRunPythonPipelineOperator(**test_kwargs_local) + context_mock = mock.MagicMock() + + op.execute(context_mock) + mock_gcs_hook.assert_not_called() + mock_gcs_hook.reset_mock() + + """ + Test that execute method calls GCSHook when only 'py_file' starts with 'gs://'. + """ + test_kwargs_local = copy.deepcopy(local_test_op_args) + test_kwargs_local['py_file'] = 'gs://gcs_file.py' + op = BeamRunPythonPipelineOperator(**test_kwargs_local) + context_mock = mock.MagicMock() + op.execute(context_mock) + mock_gcs_hook.assert_called_once() + mock_gcs_hook.reset_mock() + + """ + Test that execute calls GCSHook when only pipeline_options 'requirements_file' starts with 'gs://'. + Note: "pipeline_options" is merged with and overrides keys in "default_pipeline_options" when + BeamRunPythonPipelineOperator is instantiated, so testing 'requirements_file' specified + in "pipeline_options" + """ + test_kwargs_local = copy.deepcopy(local_test_op_args) + test_kwargs_local['pipeline_options']['requirements_file'] = 'gs://gcs_requirements.txt' + op = BeamRunPythonPipelineOperator(**test_kwargs_local) + context_mock = mock.MagicMock() + op.execute(context_mock) + mock_gcs_hook.assert_called_once() + mock_gcs_hook.reset_mock() + class TestBeamRunJavaPipelineOperator: @pytest.fixture(autouse=True) From b70cc1d28f046231e2c00da436d3d20dfa294aa6 Mon Sep 17 00:00:00 2001 From: Zack Strathe Date: Tue, 16 Apr 2024 23:23:54 -0500 Subject: [PATCH 3/4] Split out unit tests for TestBeamRunPythonPipelineOperator with GCSHook 'gs://' arg prefixes --- .../apache/beam/operators/test_beam.py | 62 +++++++++++++------ 1 file changed, 44 insertions(+), 18 deletions(-) diff --git a/tests/providers/apache/beam/operators/test_beam.py b/tests/providers/apache/beam/operators/test_beam.py index b98ec9c10d351..325db7337f5da 100755 --- a/tests/providers/apache/beam/operators/test_beam.py +++ b/tests/providers/apache/beam/operators/test_beam.py @@ -258,7 +258,11 @@ def test_on_kill_direct_runner(self, _, dataflow_mock, __): @mock.patch(BEAM_OPERATOR_PATH.format("BeamHook")) @mock.patch(BEAM_OPERATOR_PATH.format("GCSHook")) - def test_execute_gcs_hook_called_only_with_gs_prefix(self, mock_gcs_hook, _): + def test_execute_gcs_hook_not_called_without_gs_prefix(self, mock_gcs_hook, _): + """ + Test that execute method does not call GCSHook when neither py_file nor requirements_file + starts with 'gs://'. (i.e., running pipeline entirely locally) + """ local_test_op_args = { "task_id": TASK_ID, "py_file": 'local_file.py', @@ -272,42 +276,64 @@ def test_execute_gcs_hook_called_only_with_gs_prefix(self, mock_gcs_hook, _): } } - """ - Test that execute method does not call GCSHook when neither py_file nor requirements_file - starts with 'gs://'. - """ - test_kwargs_local = copy.deepcopy(local_test_op_args) - op = BeamRunPythonPipelineOperator(**test_kwargs_local) + op = BeamRunPythonPipelineOperator(**local_test_op_args) context_mock = mock.MagicMock() op.execute(context_mock) mock_gcs_hook.assert_not_called() - mock_gcs_hook.reset_mock() + @mock.patch(BEAM_OPERATOR_PATH.format("BeamHook")) + @mock.patch(BEAM_OPERATOR_PATH.format("GCSHook")) + def test_execute_gcs_hook_called_with_gs_prefix_py_file(self, mock_gcs_hook, _): """ Test that execute method calls GCSHook when only 'py_file' starts with 'gs://'. """ - test_kwargs_local = copy.deepcopy(local_test_op_args) - test_kwargs_local['py_file'] = 'gs://gcs_file.py' - op = BeamRunPythonPipelineOperator(**test_kwargs_local) + local_test_op_args = { + "task_id": TASK_ID, + "py_file": 'gs://gcs_file.py', + "py_options": ['-m'], + "default_pipeline_options": { + "project": TEST_PROJECT, + 'requirements_file': 'local_requirements.txt' + }, + "pipeline_options": { + "output": 'test_local/output', "labels": {"foo": "bar"} + } + } + op = BeamRunPythonPipelineOperator(**local_test_op_args) context_mock = mock.MagicMock() + op.execute(context_mock) mock_gcs_hook.assert_called_once() - mock_gcs_hook.reset_mock() + @mock.patch(BEAM_OPERATOR_PATH.format("BeamHook")) + @mock.patch(BEAM_OPERATOR_PATH.format("GCSHook")) + def test_execute_gcs_hook_called_with_gs_prefix_pipeline_requirements(self, mock_gcs_hook, _): """ - Test that execute calls GCSHook when only pipeline_options 'requirements_file' starts with 'gs://'. + Test that execute method calls GCSHook when only pipeline_options 'requirements_file' starts with + 'gs://'. Note: "pipeline_options" is merged with and overrides keys in "default_pipeline_options" when - BeamRunPythonPipelineOperator is instantiated, so testing 'requirements_file' specified + BeamRunPythonPipelineOperator is instantiated, so testing GCS 'requirements_file' specified in "pipeline_options" """ - test_kwargs_local = copy.deepcopy(local_test_op_args) - test_kwargs_local['pipeline_options']['requirements_file'] = 'gs://gcs_requirements.txt' - op = BeamRunPythonPipelineOperator(**test_kwargs_local) + local_test_op_args = { + "task_id": TASK_ID, + "py_file": 'local_file.py', + "py_options": ['-m'], + "default_pipeline_options": { + "project": TEST_PROJECT, + 'requirements_file': 'gs://gcs_requirements.txt' + }, + "pipeline_options": { + "output": 'test_local/output', "labels": {"foo": "bar"} + } + } + + op = BeamRunPythonPipelineOperator(**local_test_op_args) context_mock = mock.MagicMock() + op.execute(context_mock) mock_gcs_hook.assert_called_once() - mock_gcs_hook.reset_mock() class TestBeamRunJavaPipelineOperator: From 31c4847ff3f0896cf07495ec53d3a4f543e61584 Mon Sep 17 00:00:00 2001 From: Zack Strathe Date: Thu, 18 Apr 2024 17:55:14 -0500 Subject: [PATCH 4/4] Fix formatting --- .../providers/apache/beam/operators/beam.py | 2 +- .../apache/beam/operators/test_beam.py | 36 ++++++++----------- 2 files changed, 16 insertions(+), 22 deletions(-) mode change 100755 => 100644 tests/providers/apache/beam/operators/test_beam.py diff --git a/airflow/providers/apache/beam/operators/beam.py b/airflow/providers/apache/beam/operators/beam.py index 6187e7971dcdf..62f650f19a4b1 100644 --- a/airflow/providers/apache/beam/operators/beam.py +++ b/airflow/providers/apache/beam/operators/beam.py @@ -369,7 +369,7 @@ def execute_sync(self, context: Context): tmp_gcs_file = exit_stack.enter_context(gcs_hook.provide_file(object_url=self.py_file)) self.py_file = tmp_gcs_file.name if self.snake_case_pipeline_options.get("requirements_file", "").startswith("gs://"): - gcs_hook = GCSHook(gcp_conn_id=self.gcp_conn_id) + gcs_hook = GCSHook(gcp_conn_id=self.gcp_conn_id) tmp_req_file = exit_stack.enter_context( gcs_hook.provide_file(object_url=self.snake_case_pipeline_options["requirements_file"]) ) diff --git a/tests/providers/apache/beam/operators/test_beam.py b/tests/providers/apache/beam/operators/test_beam.py old mode 100755 new mode 100644 index 325db7337f5da..a6a4c31c77a5c --- a/tests/providers/apache/beam/operators/test_beam.py +++ b/tests/providers/apache/beam/operators/test_beam.py @@ -265,15 +265,13 @@ def test_execute_gcs_hook_not_called_without_gs_prefix(self, mock_gcs_hook, _): """ local_test_op_args = { "task_id": TASK_ID, - "py_file": 'local_file.py', - "py_options": ['-m'], + "py_file": "local_file.py", + "py_options": ["-m"], "default_pipeline_options": { "project": TEST_PROJECT, - 'requirements_file': 'local_requirements.txt' + "requirements_file": "local_requirements.txt", }, - "pipeline_options": { - "output": 'test_local/output', "labels": {"foo": "bar"} - } + "pipeline_options": {"output": "test_local/output", "labels": {"foo": "bar"}}, } op = BeamRunPythonPipelineOperator(**local_test_op_args) @@ -290,15 +288,13 @@ def test_execute_gcs_hook_called_with_gs_prefix_py_file(self, mock_gcs_hook, _): """ local_test_op_args = { "task_id": TASK_ID, - "py_file": 'gs://gcs_file.py', - "py_options": ['-m'], + "py_file": "gs://gcs_file.py", + "py_options": ["-m"], "default_pipeline_options": { "project": TEST_PROJECT, - 'requirements_file': 'local_requirements.txt' + "requirements_file": "local_requirements.txt", }, - "pipeline_options": { - "output": 'test_local/output', "labels": {"foo": "bar"} - } + "pipeline_options": {"output": "test_local/output", "labels": {"foo": "bar"}}, } op = BeamRunPythonPipelineOperator(**local_test_op_args) context_mock = mock.MagicMock() @@ -312,23 +308,21 @@ def test_execute_gcs_hook_called_with_gs_prefix_pipeline_requirements(self, mock """ Test that execute method calls GCSHook when only pipeline_options 'requirements_file' starts with 'gs://'. - Note: "pipeline_options" is merged with and overrides keys in "default_pipeline_options" when - BeamRunPythonPipelineOperator is instantiated, so testing GCS 'requirements_file' specified + Note: "pipeline_options" is merged with and overrides keys in "default_pipeline_options" when + BeamRunPythonPipelineOperator is instantiated, so testing GCS 'requirements_file' specified in "pipeline_options" """ local_test_op_args = { "task_id": TASK_ID, - "py_file": 'local_file.py', - "py_options": ['-m'], + "py_file": "local_file.py", + "py_options": ["-m"], "default_pipeline_options": { "project": TEST_PROJECT, - 'requirements_file': 'gs://gcs_requirements.txt' + "requirements_file": "gs://gcs_requirements.txt", }, - "pipeline_options": { - "output": 'test_local/output', "labels": {"foo": "bar"} - } + "pipeline_options": {"output": "test_local/output", "labels": {"foo": "bar"}}, } - + op = BeamRunPythonPipelineOperator(**local_test_op_args) context_mock = mock.MagicMock()