Skip to content

Commit bf00d45

Browse files
committed
fix: close bqstorage client if created by to_dataframe/to_arrow
1 parent: 1303d57 · commit: bf00d45

File tree

2 files changed

+106
-32
lines changed

2 files changed

+106
-32
lines changed

bigquery/google/cloud/bigquery/table.py

Lines changed: 45 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1519,25 +1519,31 @@ def to_arrow(
15191519
if pyarrow is None:
15201520
raise ValueError(_NO_PYARROW_ERROR)
15211521

1522+
owns_bqstorage_client = False
15221523
if not bqstorage_client and create_bqstorage_client:
1524+
owns_bqstorage_client = True
15231525
bqstorage_client = self.client._create_bqstorage_client()
15241526

1525-
progress_bar = self._get_progress_bar(progress_bar_type)
1527+
try:
1528+
progress_bar = self._get_progress_bar(progress_bar_type)
15261529

1527-
record_batches = []
1528-
for record_batch in self._to_arrow_iterable(bqstorage_client=bqstorage_client):
1529-
record_batches.append(record_batch)
1530+
record_batches = []
1531+
for record_batch in self._to_arrow_iterable(bqstorage_client=bqstorage_client):
1532+
record_batches.append(record_batch)
15301533

1531-
if progress_bar is not None:
1532-
# In some cases, the number of total rows is not populated
1533-
# until the first page of rows is fetched. Update the
1534-
# progress bar's total to keep an accurate count.
1535-
progress_bar.total = progress_bar.total or self.total_rows
1536-
progress_bar.update(record_batch.num_rows)
1534+
if progress_bar is not None:
1535+
# In some cases, the number of total rows is not populated
1536+
# until the first page of rows is fetched. Update the
1537+
# progress bar's total to keep an accurate count.
1538+
progress_bar.total = progress_bar.total or self.total_rows
1539+
progress_bar.update(record_batch.num_rows)
15371540

1538-
if progress_bar is not None:
1539-
# Indicate that the download has finished.
1540-
progress_bar.close()
1541+
if progress_bar is not None:
1542+
# Indicate that the download has finished.
1543+
progress_bar.close()
1544+
finally:
1545+
if owns_bqstorage_client:
1546+
bqstorage_client.transport.channel.close()
15411547

15421548
if record_batches:
15431549
return pyarrow.Table.from_batches(record_batches)
@@ -1655,35 +1661,42 @@ def to_dataframe(
16551661
if dtypes is None:
16561662
dtypes = {}
16571663

1658-
if not bqstorage_client and create_bqstorage_client:
1659-
bqstorage_client = self.client._create_bqstorage_client()
1660-
1661-
if bqstorage_client and self.max_results is not None:
1664+
if (bqstorage_client or create_bqstorage_client) and self.max_results is not None:
16621665
warnings.warn(
16631666
"Cannot use bqstorage_client if max_results is set, "
16641667
"reverting to fetching data with the tabledata.list endpoint.",
16651668
stacklevel=2,
16661669
)
1670+
create_bqstorage_client = False
16671671
bqstorage_client = None
16681672

1669-
progress_bar = self._get_progress_bar(progress_bar_type)
1673+
owns_bqstorage_client = False
1674+
if not bqstorage_client and create_bqstorage_client:
1675+
owns_bqstorage_client = True
1676+
bqstorage_client = self.client._create_bqstorage_client()
16701677

1671-
frames = []
1672-
for frame in self._to_dataframe_iterable(
1673-
bqstorage_client=bqstorage_client, dtypes=dtypes
1674-
):
1675-
frames.append(frame)
1678+
try:
1679+
progress_bar = self._get_progress_bar(progress_bar_type)
1680+
1681+
frames = []
1682+
for frame in self._to_dataframe_iterable(
1683+
bqstorage_client=bqstorage_client, dtypes=dtypes
1684+
):
1685+
frames.append(frame)
1686+
1687+
if progress_bar is not None:
1688+
# In some cases, the number of total rows is not populated
1689+
# until the first page of rows is fetched. Update the
1690+
# progress bar's total to keep an accurate count.
1691+
progress_bar.total = progress_bar.total or self.total_rows
1692+
progress_bar.update(len(frame))
16761693

16771694
if progress_bar is not None:
1678-
# In some cases, the number of total rows is not populated
1679-
# until the first page of rows is fetched. Update the
1680-
# progress bar's total to keep an accurate count.
1681-
progress_bar.total = progress_bar.total or self.total_rows
1682-
progress_bar.update(len(frame))
1683-
1684-
if progress_bar is not None:
1685-
# Indicate that the download has finished.
1686-
progress_bar.close()
1695+
# Indicate that the download has finished.
1696+
progress_bar.close()
1697+
finally:
1698+
if owns_bqstorage_client:
1699+
bqstorage_client.transport.channel.close()
16871700

16881701
# Avoid concatting an empty list.
16891702
if not frames:

bigquery/tests/unit/test_table.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,12 @@
2626

2727
try:
2828
from google.cloud import bigquery_storage_v1beta1
29+
from google.cloud.bigquery_storage_v1beta1.gapic.transports import (
30+
big_query_storage_grpc_transport,
31+
)
2932
except ImportError: # pragma: NO COVER
3033
bigquery_storage_v1beta1 = None
34+
big_query_storage_grpc_transport = None
3135

3236
try:
3337
import pandas
@@ -1817,6 +1821,9 @@ def test_to_arrow_w_bqstorage(self):
18171821
bqstorage_client = mock.create_autospec(
18181822
bigquery_storage_v1beta1.BigQueryStorageClient
18191823
)
1824+
bqstorage_client.transport = mock.create_autospec(
1825+
big_query_storage_grpc_transport.BigQueryStorageGrpcTransport
1826+
)
18201827
streams = [
18211828
# Use two streams we want to check frames are read from each stream.
18221829
{"name": "/projects/proj/dataset/dset/tables/tbl/streams/1234"},
@@ -1882,6 +1889,9 @@ def test_to_arrow_w_bqstorage(self):
18821889
total_rows = expected_num_rows * total_pages
18831890
self.assertEqual(actual_tbl.num_rows, total_rows)
18841891

1892+
# Don't close the client if it was passed in.
1893+
bqstorage_client.transport.channel.close.assert_not_called()
1894+
18851895
@unittest.skipIf(pyarrow is None, "Requires `pyarrow`")
18861896
@unittest.skipIf(
18871897
bigquery_storage_v1beta1 is None, "Requires `google-cloud-bigquery-storage`"
@@ -1894,6 +1904,9 @@ def test_to_arrow_w_bqstorage_creates_client(self):
18941904
bqstorage_client = mock.create_autospec(
18951905
bigquery_storage_v1beta1.BigQueryStorageClient
18961906
)
1907+
bqstorage_client.transport = mock.create_autospec(
1908+
big_query_storage_grpc_transport.BigQueryStorageGrpcTransport
1909+
)
18971910
mock_client._create_bqstorage_client.return_value = bqstorage_client
18981911
session = bigquery_storage_v1beta1.types.ReadSession()
18991912
bqstorage_client.create_read_session.return_value = session
@@ -1910,6 +1923,7 @@ def test_to_arrow_w_bqstorage_creates_client(self):
19101923
)
19111924
row_iterator.to_arrow(create_bqstorage_client=True)
19121925
mock_client._create_bqstorage_client.assert_called_once()
1926+
bqstorage_client.transport.channel.close.assert_called_once()
19131927

19141928
@unittest.skipIf(pyarrow is None, "Requires `pyarrow`")
19151929
@unittest.skipIf(
@@ -2321,6 +2335,43 @@ def test_to_dataframe_max_results_w_bqstorage_warning(self):
23212335
]
23222336
self.assertEqual(len(matches), 1, msg="User warning was not emitted.")
23232337

2338+
@unittest.skipIf(pandas is None, "Requires `pandas`")
2339+
def test_to_dataframe_max_results_w_create_bqstorage_warning(self):
2340+
from google.cloud.bigquery.schema import SchemaField
2341+
2342+
schema = [
2343+
SchemaField("name", "STRING", mode="REQUIRED"),
2344+
SchemaField("age", "INTEGER", mode="REQUIRED"),
2345+
]
2346+
rows = [
2347+
{"f": [{"v": "Phred Phlyntstone"}, {"v": "32"}]},
2348+
{"f": [{"v": "Bharney Rhubble"}, {"v": "33"}]},
2349+
]
2350+
path = "/foo"
2351+
api_request = mock.Mock(return_value={"rows": rows})
2352+
mock_client = _mock_client()
2353+
2354+
row_iterator = self._make_one(
2355+
client=mock_client,
2356+
api_request=api_request,
2357+
path=path,
2358+
schema=schema,
2359+
max_results=42,
2360+
)
2361+
2362+
with warnings.catch_warnings(record=True) as warned:
2363+
row_iterator.to_dataframe(create_bqstorage_client=True)
2364+
2365+
matches = [
2366+
warning
2367+
for warning in warned
2368+
if warning.category is UserWarning
2369+
and "cannot use bqstorage_client" in str(warning).lower()
2370+
and "tabledata.list" in str(warning)
2371+
]
2372+
self.assertEqual(len(matches), 1, msg="User warning was not emitted.")
2373+
mock_client._create_bqstorage_client.assert_not_called()
2374+
23242375
@unittest.skipIf(pandas is None, "Requires `pandas`")
23252376
@unittest.skipIf(
23262377
bigquery_storage_v1beta1 is None, "Requires `google-cloud-bigquery-storage`"
@@ -2333,6 +2384,9 @@ def test_to_dataframe_w_bqstorage_creates_client(self):
23332384
bqstorage_client = mock.create_autospec(
23342385
bigquery_storage_v1beta1.BigQueryStorageClient
23352386
)
2387+
bqstorage_client.transport = mock.create_autospec(
2388+
big_query_storage_grpc_transport.BigQueryStorageGrpcTransport
2389+
)
23362390
mock_client._create_bqstorage_client.return_value = bqstorage_client
23372391
session = bigquery_storage_v1beta1.types.ReadSession()
23382392
bqstorage_client.create_read_session.return_value = session
@@ -2349,6 +2403,7 @@ def test_to_dataframe_w_bqstorage_creates_client(self):
23492403
)
23502404
row_iterator.to_dataframe(create_bqstorage_client=True)
23512405
mock_client._create_bqstorage_client.assert_called_once()
2406+
bqstorage_client.transport.channel.close.assert_called_once()
23522407

23532408
@unittest.skipIf(pandas is None, "Requires `pandas`")
23542409
@unittest.skipIf(
@@ -2485,6 +2540,9 @@ def test_to_dataframe_w_bqstorage_nonempty(self):
24852540
bqstorage_client = mock.create_autospec(
24862541
bigquery_storage_v1beta1.BigQueryStorageClient
24872542
)
2543+
bqstorage_client.transport = mock.create_autospec(
2544+
big_query_storage_grpc_transport.BigQueryStorageGrpcTransport
2545+
)
24882546
streams = [
24892547
# Use two streams we want to check frames are read from each stream.
24902548
{"name": "/projects/proj/dataset/dset/tables/tbl/streams/1234"},
@@ -2539,6 +2597,9 @@ def test_to_dataframe_w_bqstorage_nonempty(self):
25392597
total_rows = len(page_items) * total_pages
25402598
self.assertEqual(len(got.index), total_rows)
25412599

2600+
# Don't close the client if it was passed in.
2601+
bqstorage_client.transport.channel.close.assert_not_called()
2602+
25422603
@unittest.skipIf(pandas is None, "Requires `pandas`")
25432604
@unittest.skipIf(
25442605
bigquery_storage_v1beta1 is None, "Requires `google-cloud-bigquery-storage`"

0 commit comments

Comments
 (0)