diff --git a/sdks/python/apache_beam/io/gcp/bigquery_read_internal.py b/sdks/python/apache_beam/io/gcp/bigquery_read_internal.py index f038b48e04d5..059cad04f3df 100644 --- a/sdks/python/apache_beam/io/gcp/bigquery_read_internal.py +++ b/sdks/python/apache_beam/io/gcp/bigquery_read_internal.py @@ -237,6 +237,17 @@ def _get_temp_dataset_id(self): else: raise ValueError("temp_dataset has to be either str or DatasetReference") + def _get_temp_dataset_project(self): + """Returns the project ID for temporary dataset operations. + + If temp_dataset is a DatasetReference, returns its projectId. + Otherwise, returns the pipeline project for billing. + """ + if isinstance(self.temp_dataset, DatasetReference): + return self.temp_dataset.projectId + else: + return self._get_project() + def start_bundle(self): self.bq = bigquery_tools.BigQueryWrapper( temp_dataset_id=self._get_temp_dataset_id(), @@ -276,7 +287,9 @@ def process(self, def finish_bundle(self): if self.bq.created_temp_dataset: - self.bq.clean_up_temporary_dataset(self._get_project()) + # Use the same project that was used to create the temp dataset + temp_dataset_project = self._get_temp_dataset_project() + self.bq.clean_up_temporary_dataset(temp_dataset_project) def _get_bq_metadata(self): if not self.bq_io_metadata: @@ -300,7 +313,10 @@ def _setup_temporary_dataset( element: 'ReadFromBigQueryRequest'): location = bq.get_query_location( self._get_project(), element.query, not element.use_standard_sql) - bq.create_temporary_dataset(self._get_project(), location) + # Use the project from temp_dataset if it's a DatasetReference, + # otherwise use the pipeline project + temp_dataset_project = self._get_temp_dataset_project() + bq.create_temporary_dataset(temp_dataset_project, location) def _execute_query( self, diff --git a/sdks/python/apache_beam/io/gcp/bigquery_read_internal_test.py b/sdks/python/apache_beam/io/gcp/bigquery_read_internal_test.py new file mode 100644 index 000000000000..9d162457df54 --- /dev/null +++ b/sdks/python/apache_beam/io/gcp/bigquery_read_internal_test.py @@ -0,0 +1,170 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Unit tests for BigQuery read internal module.""" + +import unittest +from unittest import mock + +from apache_beam.io.gcp import bigquery_read_internal +from apache_beam.options.pipeline_options import GoogleCloudOptions +from apache_beam.options.pipeline_options import PipelineOptions +from apache_beam.options.value_provider import StaticValueProvider + +try: + from apache_beam.io.gcp.internal.clients.bigquery import DatasetReference +except ImportError: + DatasetReference = None + + +class BigQueryReadSplitTest(unittest.TestCase): + """Tests for _BigQueryReadSplit DoFn.""" + def setUp(self): + if DatasetReference is None: + self.skipTest('BigQuery dependencies are not installed') + self.options = PipelineOptions() + self.gcp_options = self.options.view_as(GoogleCloudOptions) + self.gcp_options.project = 'test-project' + + def test_get_temp_dataset_project_with_string_temp_dataset(self): + """Test _get_temp_dataset_project with string temp_dataset.""" + split = bigquery_read_internal._BigQueryReadSplit( + options=self.options, temp_dataset='temp_dataset_id') + + # Should return the pipeline project when temp_dataset is a string + self.assertEqual(split._get_temp_dataset_project(), 'test-project') + + def test_get_temp_dataset_project_with_dataset_reference(self): + """Test _get_temp_dataset_project with DatasetReference temp_dataset.""" + dataset_ref = DatasetReference( + projectId='custom-project', datasetId='temp_dataset_id') + split = bigquery_read_internal._BigQueryReadSplit( + options=self.options, temp_dataset=dataset_ref) + + # Should return the project from DatasetReference + self.assertEqual(split._get_temp_dataset_project(), 'custom-project') + + def test_get_temp_dataset_project_with_none_temp_dataset(self): + """Test _get_temp_dataset_project with None temp_dataset.""" + split = bigquery_read_internal._BigQueryReadSplit( + options=self.options, temp_dataset=None) + + # Should return the pipeline project when temp_dataset is None + self.assertEqual(split._get_temp_dataset_project(), 'test-project') + + def test_get_temp_dataset_project_with_value_provider_project(self): + """Test _get_temp_dataset_project with ValueProvider project.""" + self.gcp_options.project = StaticValueProvider(str, 'vp-project') + dataset_ref = DatasetReference( + projectId='custom-project', datasetId='temp_dataset_id') + split = bigquery_read_internal._BigQueryReadSplit( + options=self.options, temp_dataset=dataset_ref) + + # Should still return the project from DatasetReference + self.assertEqual(split._get_temp_dataset_project(), 'custom-project') + + @mock.patch('apache_beam.io.gcp.bigquery_tools.BigQueryWrapper') + def test_setup_temporary_dataset_uses_correct_project(self, mock_bq_wrapper): + """Test that _setup_temporary_dataset uses the correct project.""" + dataset_ref = DatasetReference( + projectId='custom-project', datasetId='temp_dataset_id') + split = bigquery_read_internal._BigQueryReadSplit( + options=self.options, temp_dataset=dataset_ref) + + # Mock the BigQueryWrapper instance + mock_bq = mock.Mock() + mock_bq.get_query_location.return_value = 'US' + + # Mock ReadFromBigQueryRequest + mock_element = mock.Mock() + mock_element.query = 'SELECT * FROM table' + mock_element.use_standard_sql = True + + # Call _setup_temporary_dataset + split._setup_temporary_dataset(mock_bq, mock_element) + + # Verify that create_temporary_dataset was called with the custom project + mock_bq.create_temporary_dataset.assert_called_once_with( + 'custom-project', 'US') + # Verify that get_query_location was called with the pipeline project + mock_bq.get_query_location.assert_called_once_with( + 'test-project', 'SELECT * FROM table', False) + + @mock.patch('apache_beam.io.gcp.bigquery_tools.BigQueryWrapper') + def test_finish_bundle_uses_correct_project(self, mock_bq_wrapper): + """Test that finish_bundle uses the correct project for cleanup.""" + dataset_ref = DatasetReference( + projectId='custom-project', datasetId='temp_dataset_id') + split = bigquery_read_internal._BigQueryReadSplit( + options=self.options, temp_dataset=dataset_ref) + + # Mock the BigQueryWrapper instance + mock_bq = mock.Mock() + mock_bq.created_temp_dataset = True + split.bq = mock_bq + + # Call finish_bundle + split.finish_bundle() + + # Verify that clean_up_temporary_dataset was called with the custom project + mock_bq.clean_up_temporary_dataset.assert_called_once_with('custom-project') + + @mock.patch('apache_beam.io.gcp.bigquery_tools.BigQueryWrapper') + def test_setup_temporary_dataset_with_string_temp_dataset( + self, mock_bq_wrapper): + """Test _setup_temporary_dataset with string temp_dataset uses pipeline + project.""" + split = bigquery_read_internal._BigQueryReadSplit( + options=self.options, temp_dataset='temp_dataset_id') + + # Mock the BigQueryWrapper instance + mock_bq = mock.Mock() + mock_bq.get_query_location.return_value = 'US' + + # Mock ReadFromBigQueryRequest + mock_element = mock.Mock() + mock_element.query = 'SELECT * FROM table' + mock_element.use_standard_sql = True + + # Call _setup_temporary_dataset + split._setup_temporary_dataset(mock_bq, mock_element) + + # Verify that create_temporary_dataset was called with the pipeline project + mock_bq.create_temporary_dataset.assert_called_once_with( + 'test-project', 'US') + + @mock.patch('apache_beam.io.gcp.bigquery_tools.BigQueryWrapper') + def test_finish_bundle_with_string_temp_dataset(self, mock_bq_wrapper): + """Test finish_bundle with string temp_dataset uses pipeline project.""" + split = bigquery_read_internal._BigQueryReadSplit( + options=self.options, temp_dataset='temp_dataset_id') + + # Mock the BigQueryWrapper instance + mock_bq = mock.Mock() + mock_bq.created_temp_dataset = True + split.bq = mock_bq + + # Call finish_bundle + split.finish_bundle() + + # Verify that clean_up_temporary_dataset was called with the pipeline + # project + mock_bq.clean_up_temporary_dataset.assert_called_once_with('test-project') + + +if __name__ == '__main__': + unittest.main() diff --git a/sdks/python/apache_beam/io/gcp/bigquery_tools.py b/sdks/python/apache_beam/io/gcp/bigquery_tools.py index b31f6449fe90..8a130faa35cb 100644 --- a/sdks/python/apache_beam/io/gcp/bigquery_tools.py +++ b/sdks/python/apache_beam/io/gcp/bigquery_tools.py @@ -412,6 +412,17 @@ def _get_temp_table(self, project_id): dataset=self.temp_dataset_id, project=project_id) + def _get_temp_table_project(self, fallback_project_id): + """Returns the project ID for temporary table operations. + + If temp_table_ref exists, returns its projectId. + Otherwise, returns the fallback_project_id. + """ + if self.temp_table_ref: + return self.temp_table_ref.projectId + else: + return fallback_project_id + def _get_temp_dataset(self): if self.temp_table_ref: return self.temp_table_ref.datasetId @@ -639,7 +650,8 @@ def _start_query_job( query=query, useLegacySql=use_legacy_sql, allowLargeResults=not dry_run, - destinationTable=self._get_temp_table(project_id) + destinationTable=self._get_temp_table( + self._get_temp_table_project(project_id)) if not dry_run else None, flattenResults=flatten_results, priority=priority, diff --git a/sdks/python/apache_beam/io/gcp/bigquery_tools_test.py b/sdks/python/apache_beam/io/gcp/bigquery_tools_test.py index 1307a7886924..e5552d5222ec 100644 --- a/sdks/python/apache_beam/io/gcp/bigquery_tools_test.py +++ b/sdks/python/apache_beam/io/gcp/bigquery_tools_test.py @@ -578,6 +578,27 @@ def test_start_query_job_priority_configuration(self): client.jobs.Insert.call_args[0][0].job.configuration.query.priority, 'INTERACTIVE') + def test_get_temp_table_project_with_temp_table_ref(self): + """Test _get_temp_table_project returns project from temp_table_ref.""" + client = mock.Mock() + temp_table_ref = bigquery.TableReference( + projectId='temp-project', + datasetId='temp_dataset', + tableId='temp_table') + wrapper = beam.io.gcp.bigquery_tools.BigQueryWrapper( + client, temp_table_ref=temp_table_ref) + + result = wrapper._get_temp_table_project('fallback-project') + self.assertEqual(result, 'temp-project') + + def test_get_temp_table_project_without_temp_table_ref(self): + """Test _get_temp_table_project returns fallback when no temp_table_ref.""" + client = mock.Mock() + wrapper = beam.io.gcp.bigquery_tools.BigQueryWrapper(client) + + result = wrapper._get_temp_table_project('fallback-project') + self.assertEqual(result, 'fallback-project') + @unittest.skipIf(HttpError is None, 'GCP dependencies are not installed') class TestRowAsDictJsonCoder(unittest.TestCase):