Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 18 additions & 2 deletions sdks/python/apache_beam/io/gcp/bigquery_read_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,17 @@ def _get_temp_dataset_id(self):
else:
raise ValueError("temp_dataset has to be either str or DatasetReference")

def _get_temp_dataset_project(self):
"""Returns the project ID for temporary dataset operations.

If temp_dataset is a DatasetReference, returns its projectId.
Otherwise, returns the pipeline project for billing.
"""
if isinstance(self.temp_dataset, DatasetReference):
return self.temp_dataset.projectId
else:
return self._get_project()

def start_bundle(self):
self.bq = bigquery_tools.BigQueryWrapper(
temp_dataset_id=self._get_temp_dataset_id(),
Expand Down Expand Up @@ -276,7 +287,9 @@ def process(self,

def finish_bundle(self):
if self.bq.created_temp_dataset:
self.bq.clean_up_temporary_dataset(self._get_project())
# Use the same project that was used to create the temp dataset
temp_dataset_project = self._get_temp_dataset_project()
self.bq.clean_up_temporary_dataset(temp_dataset_project)

def _get_bq_metadata(self):
if not self.bq_io_metadata:
Expand All @@ -300,7 +313,10 @@ def _setup_temporary_dataset(
element: 'ReadFromBigQueryRequest'):
location = bq.get_query_location(
self._get_project(), element.query, not element.use_standard_sql)
bq.create_temporary_dataset(self._get_project(), location)
# Use the project from temp_dataset if it's a DatasetReference,
# otherwise use the pipeline project
temp_dataset_project = self._get_temp_dataset_project()
bq.create_temporary_dataset(temp_dataset_project, location)

def _execute_query(
self,
Expand Down
170 changes: 170 additions & 0 deletions sdks/python/apache_beam/io/gcp/bigquery_read_internal_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Unit tests for BigQuery read internal module."""

import unittest
from unittest import mock

from apache_beam.io.gcp import bigquery_read_internal
from apache_beam.options.pipeline_options import GoogleCloudOptions
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.options.value_provider import StaticValueProvider

try:
from apache_beam.io.gcp.internal.clients.bigquery import DatasetReference
except ImportError:
DatasetReference = None


class BigQueryReadSplitTest(unittest.TestCase):
"""Tests for _BigQueryReadSplit DoFn."""
def setUp(self):
if DatasetReference is None:
self.skipTest('BigQuery dependencies are not installed')
self.options = PipelineOptions()
self.gcp_options = self.options.view_as(GoogleCloudOptions)
self.gcp_options.project = 'test-project'

def test_get_temp_dataset_project_with_string_temp_dataset(self):
"""Test _get_temp_dataset_project with string temp_dataset."""
split = bigquery_read_internal._BigQueryReadSplit(
options=self.options, temp_dataset='temp_dataset_id')

# Should return the pipeline project when temp_dataset is a string
self.assertEqual(split._get_temp_dataset_project(), 'test-project')

def test_get_temp_dataset_project_with_dataset_reference(self):
"""Test _get_temp_dataset_project with DatasetReference temp_dataset."""
dataset_ref = DatasetReference(
projectId='custom-project', datasetId='temp_dataset_id')
split = bigquery_read_internal._BigQueryReadSplit(
options=self.options, temp_dataset=dataset_ref)

# Should return the project from DatasetReference
self.assertEqual(split._get_temp_dataset_project(), 'custom-project')

def test_get_temp_dataset_project_with_none_temp_dataset(self):
"""Test _get_temp_dataset_project with None temp_dataset."""
split = bigquery_read_internal._BigQueryReadSplit(
options=self.options, temp_dataset=None)

# Should return the pipeline project when temp_dataset is None
self.assertEqual(split._get_temp_dataset_project(), 'test-project')

def test_get_temp_dataset_project_with_value_provider_project(self):
"""Test _get_temp_dataset_project with ValueProvider project."""
self.gcp_options.project = StaticValueProvider(str, 'vp-project')
dataset_ref = DatasetReference(
projectId='custom-project', datasetId='temp_dataset_id')
split = bigquery_read_internal._BigQueryReadSplit(
options=self.options, temp_dataset=dataset_ref)

# Should still return the project from DatasetReference
self.assertEqual(split._get_temp_dataset_project(), 'custom-project')

@mock.patch('apache_beam.io.gcp.bigquery_tools.BigQueryWrapper')
def test_setup_temporary_dataset_uses_correct_project(self, mock_bq_wrapper):
"""Test that _setup_temporary_dataset uses the correct project."""
dataset_ref = DatasetReference(
projectId='custom-project', datasetId='temp_dataset_id')
split = bigquery_read_internal._BigQueryReadSplit(
options=self.options, temp_dataset=dataset_ref)

# Mock the BigQueryWrapper instance
mock_bq = mock.Mock()
mock_bq.get_query_location.return_value = 'US'

# Mock ReadFromBigQueryRequest
mock_element = mock.Mock()
mock_element.query = 'SELECT * FROM table'
mock_element.use_standard_sql = True

# Call _setup_temporary_dataset
split._setup_temporary_dataset(mock_bq, mock_element)

# Verify that create_temporary_dataset was called with the custom project
mock_bq.create_temporary_dataset.assert_called_once_with(
'custom-project', 'US')
# Verify that get_query_location was called with the pipeline project
mock_bq.get_query_location.assert_called_once_with(
'test-project', 'SELECT * FROM table', False)

@mock.patch('apache_beam.io.gcp.bigquery_tools.BigQueryWrapper')
def test_finish_bundle_uses_correct_project(self, mock_bq_wrapper):
"""Test that finish_bundle uses the correct project for cleanup."""
dataset_ref = DatasetReference(
projectId='custom-project', datasetId='temp_dataset_id')
split = bigquery_read_internal._BigQueryReadSplit(
options=self.options, temp_dataset=dataset_ref)

# Mock the BigQueryWrapper instance
mock_bq = mock.Mock()
mock_bq.created_temp_dataset = True
split.bq = mock_bq

# Call finish_bundle
split.finish_bundle()

# Verify that clean_up_temporary_dataset was called with the custom project
mock_bq.clean_up_temporary_dataset.assert_called_once_with('custom-project')

@mock.patch('apache_beam.io.gcp.bigquery_tools.BigQueryWrapper')
def test_setup_temporary_dataset_with_string_temp_dataset(
self, mock_bq_wrapper):
"""Test _setup_temporary_dataset with string temp_dataset uses pipeline
project."""
split = bigquery_read_internal._BigQueryReadSplit(
options=self.options, temp_dataset='temp_dataset_id')

# Mock the BigQueryWrapper instance
mock_bq = mock.Mock()
mock_bq.get_query_location.return_value = 'US'

# Mock ReadFromBigQueryRequest
mock_element = mock.Mock()
mock_element.query = 'SELECT * FROM table'
mock_element.use_standard_sql = True

# Call _setup_temporary_dataset
split._setup_temporary_dataset(mock_bq, mock_element)

# Verify that create_temporary_dataset was called with the pipeline project
mock_bq.create_temporary_dataset.assert_called_once_with(
'test-project', 'US')

@mock.patch('apache_beam.io.gcp.bigquery_tools.BigQueryWrapper')
def test_finish_bundle_with_string_temp_dataset(self, mock_bq_wrapper):
"""Test finish_bundle with string temp_dataset uses pipeline project."""
split = bigquery_read_internal._BigQueryReadSplit(
options=self.options, temp_dataset='temp_dataset_id')

# Mock the BigQueryWrapper instance
mock_bq = mock.Mock()
mock_bq.created_temp_dataset = True
split.bq = mock_bq

# Call finish_bundle
split.finish_bundle()

# Verify that clean_up_temporary_dataset was called with the pipeline
# project
mock_bq.clean_up_temporary_dataset.assert_called_once_with('test-project')


if __name__ == '__main__':
unittest.main()
14 changes: 13 additions & 1 deletion sdks/python/apache_beam/io/gcp/bigquery_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,17 @@ def _get_temp_table(self, project_id):
dataset=self.temp_dataset_id,
project=project_id)

def _get_temp_table_project(self, fallback_project_id):
"""Returns the project ID for temporary table operations.

If temp_table_ref exists, returns its projectId.
Otherwise, returns the fallback_project_id.
"""
if self.temp_table_ref:
return self.temp_table_ref.projectId
else:
return fallback_project_id

def _get_temp_dataset(self):
if self.temp_table_ref:
return self.temp_table_ref.datasetId
Expand Down Expand Up @@ -639,7 +650,8 @@ def _start_query_job(
query=query,
useLegacySql=use_legacy_sql,
allowLargeResults=not dry_run,
destinationTable=self._get_temp_table(project_id)
destinationTable=self._get_temp_table(
self._get_temp_table_project(project_id))
if not dry_run else None,
flattenResults=flatten_results,
priority=priority,
Expand Down
21 changes: 21 additions & 0 deletions sdks/python/apache_beam/io/gcp/bigquery_tools_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -578,6 +578,27 @@ def test_start_query_job_priority_configuration(self):
client.jobs.Insert.call_args[0][0].job.configuration.query.priority,
'INTERACTIVE')

def test_get_temp_table_project_with_temp_table_ref(self):
"""Test _get_temp_table_project returns project from temp_table_ref."""
client = mock.Mock()
temp_table_ref = bigquery.TableReference(
projectId='temp-project',
datasetId='temp_dataset',
tableId='temp_table')
wrapper = beam.io.gcp.bigquery_tools.BigQueryWrapper(
client, temp_table_ref=temp_table_ref)

result = wrapper._get_temp_table_project('fallback-project')
self.assertEqual(result, 'temp-project')

def test_get_temp_table_project_without_temp_table_ref(self):
"""Test _get_temp_table_project returns fallback when no temp_table_ref."""
client = mock.Mock()
wrapper = beam.io.gcp.bigquery_tools.BigQueryWrapper(client)

result = wrapper._get_temp_table_project('fallback-project')
self.assertEqual(result, 'fallback-project')


@unittest.skipIf(HttpError is None, 'GCP dependencies are not installed')
class TestRowAsDictJsonCoder(unittest.TestCase):
Expand Down
Loading