Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion airflow/providers/apache/beam/operators/beam.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,11 +364,12 @@ def execute(self, context: Context):

def execute_sync(self, context: Context):
with ExitStack() as exit_stack:
gcs_hook = GCSHook(gcp_conn_id=self.gcp_conn_id)
if self.py_file.lower().startswith("gs://"):
gcs_hook = GCSHook(gcp_conn_id=self.gcp_conn_id)
tmp_gcs_file = exit_stack.enter_context(gcs_hook.provide_file(object_url=self.py_file))
self.py_file = tmp_gcs_file.name
if self.snake_case_pipeline_options.get("requirements_file", "").startswith("gs://"):
gcs_hook = GCSHook(gcp_conn_id=self.gcp_conn_id)
tmp_req_file = exit_stack.enter_context(
gcs_hook.provide_file(object_url=self.snake_case_pipeline_options["requirements_file"])
)
Expand Down
73 changes: 73 additions & 0 deletions tests/providers/apache/beam/operators/test_beam.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,79 @@ def test_on_kill_direct_runner(self, _, dataflow_mock, __):
op.on_kill()
dataflow_cancel_job.assert_not_called()

@mock.patch(BEAM_OPERATOR_PATH.format("BeamHook"))
@mock.patch(BEAM_OPERATOR_PATH.format("GCSHook"))
def test_execute_gcs_hook_not_called_without_gs_prefix(self, mock_gcs_hook, _):
    """Check that a fully local pipeline never instantiates ``GCSHook``.

    Both ``py_file`` and ``requirements_file`` are plain local paths (no
    ``gs://`` prefix), so ``execute`` has nothing to fetch from GCS.
    """
    default_options = {
        "project": TEST_PROJECT,
        "requirements_file": "local_requirements.txt",
    }
    run_options = {"output": "test_local/output", "labels": {"foo": "bar"}}
    operator = BeamRunPythonPipelineOperator(
        task_id=TASK_ID,
        py_file="local_file.py",
        py_options=["-m"],
        default_pipeline_options=default_options,
        pipeline_options=run_options,
    )

    operator.execute(mock.MagicMock())

    # The patched GCSHook class must not have been constructed at all.
    mock_gcs_hook.assert_not_called()

@mock.patch(BEAM_OPERATOR_PATH.format("BeamHook"))
@mock.patch(BEAM_OPERATOR_PATH.format("GCSHook"))
def test_execute_gcs_hook_called_with_gs_prefix_py_file(self, mock_gcs_hook, _):
    """Check that ``GCSHook`` is built exactly once when only ``py_file`` is on GCS.

    ``requirements_file`` stays a local path, so the ``gs://`` ``py_file`` is
    the single reason for the operator to create a hook.
    """
    default_options = {
        "project": TEST_PROJECT,
        "requirements_file": "local_requirements.txt",
    }
    run_options = {"output": "test_local/output", "labels": {"foo": "bar"}}
    operator = BeamRunPythonPipelineOperator(
        task_id=TASK_ID,
        py_file="gs://gcs_file.py",
        py_options=["-m"],
        default_pipeline_options=default_options,
        pipeline_options=run_options,
    )

    operator.execute(mock.MagicMock())

    mock_gcs_hook.assert_called_once()

@mock.patch(BEAM_OPERATOR_PATH.format("BeamHook"))
@mock.patch(BEAM_OPERATOR_PATH.format("GCSHook"))
def test_execute_gcs_hook_called_with_gs_prefix_pipeline_requirements(self, mock_gcs_hook, _):
    """
    Test that execute method calls GCSHook when only the 'requirements_file' given in
    "pipeline_options" starts with 'gs://'.

    Note: "pipeline_options" is merged with and overrides keys in "default_pipeline_options",
    so the fixture supplies a local 'requirements_file' in "default_pipeline_options" and a
    GCS one in "pipeline_options". The hook is therefore only constructed if the override
    is the value that the operator actually inspects.
    """
    gcs_requirements_op_args = {
        "task_id": TASK_ID,
        "py_file": "local_file.py",
        "py_options": ["-m"],
        "default_pipeline_options": {
            "project": TEST_PROJECT,
            # Local default that must lose to the "pipeline_options" entry below.
            "requirements_file": "local_requirements.txt",
        },
        "pipeline_options": {
            "output": "test_local/output",
            "labels": {"foo": "bar"},
            "requirements_file": "gs://gcs_requirements.txt",
        },
    }

    op = BeamRunPythonPipelineOperator(**gcs_requirements_op_args)
    context_mock = mock.MagicMock()

    op.execute(context_mock)
    # py_file is local, so the single construction comes from the requirements file.
    mock_gcs_hook.assert_called_once()


class TestBeamRunJavaPipelineOperator:
@pytest.fixture(autouse=True)
Expand Down