Commit 40a3752

Allow passing schema as dicts in pandas helpers
1 parent e0cc7fd commit 40a3752

File tree

2 files changed: +181 −10 lines

2 files changed

+181
-10
lines changed
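In practice, this commit lets the pandas helpers accept a schema given as plain dicts (the API representation of a field) in place of SchemaField objects. A minimal, illustrative sketch of the new call pattern; the DataFrame contents and column names are made up for the example, and note that _pandas_helpers is an internal module:

    import pandas

    from google.cloud.bigquery import _pandas_helpers

    df = pandas.DataFrame({"name": ["alice", "bob"], "age": [30, 40]})

    # Schema expressed as dicts; previously this had to be a sequence of
    # google.cloud.bigquery.schema.SchemaField objects.
    dict_schema = [
        {"name": "name", "type": "STRING", "mode": "REQUIRED"},
        {"name": "age", "type": "INTEGER", "mode": "NULLABLE"},
    ]

    # Returns a tuple of SchemaField objects, per the new tests below.
    bq_schema = _pandas_helpers.dataframe_to_bq_schema(df, dict_schema)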

bigquery/google/cloud/bigquery/_pandas_helpers.py

Lines changed: 44 additions & 10 deletions
@@ -239,7 +239,10 @@ def dataframe_to_bq_schema(dataframe, bq_schema):
     Args:
         dataframe (pandas.DataFrame):
             DataFrame for which the client determines the BigQuery schema.
-        bq_schema (Sequence[google.cloud.bigquery.schema.SchemaField]):
+        bq_schema (Sequence[Union[ \
+            :class:`~google.cloud.bigquery.schema.SchemaField`, \
+            Mapping[str, str] \
+        ]]):
             A BigQuery schema. Use this argument to override the autodetected
             type for some or all of the DataFrame columns.
@@ -249,6 +252,7 @@ def dataframe_to_bq_schema(dataframe, bq_schema):
             any column cannot be determined.
     """
     if bq_schema:
+        bq_schema = schema._to_schema_fields(bq_schema)
         for field in bq_schema:
             if field.field_type in schema._STRUCT_TYPES:
                 raise ValueError(
@@ -297,7 +301,10 @@ def dataframe_to_arrow(dataframe, bq_schema):
     Args:
         dataframe (pandas.DataFrame):
             DataFrame to convert to Arrow table.
-        bq_schema (Sequence[google.cloud.bigquery.schema.SchemaField]):
+        bq_schema (Sequence[Union[ \
+            :class:`~google.cloud.bigquery.schema.SchemaField`, \
+            Mapping[str, str] \
+        ]]):
             Desired BigQuery schema. Number of columns must match number of
             columns in the DataFrame.
@@ -310,6 +317,8 @@ def dataframe_to_arrow(dataframe, bq_schema):
     column_and_index_names = set(
         name for name, _ in list_columns_and_indexes(dataframe)
     )
+
+    bq_schema = schema._to_schema_fields(bq_schema)
     bq_field_names = set(field.name for field in bq_schema)

     extra_fields = bq_field_names - column_and_index_names
@@ -354,7 +363,10 @@ def dataframe_to_parquet(dataframe, bq_schema, filepath, parquet_compression="SN
     Args:
         dataframe (pandas.DataFrame):
             DataFrame to convert to Parquet file.
-        bq_schema (Sequence[google.cloud.bigquery.schema.SchemaField]):
+        bq_schema (Sequence[Union[ \
+            :class:`~google.cloud.bigquery.schema.SchemaField`, \
+            Mapping[str, str] \
+        ]]):
             Desired BigQuery schema. Number of columns must match number of
             columns in the DataFrame.
         filepath (str):
@@ -368,6 +380,7 @@ def dataframe_to_parquet(dataframe, bq_schema, filepath, parquet_compression="SN
     if pyarrow is None:
         raise ValueError("pyarrow is required for BigQuery schema conversion.")

+    bq_schema = schema._to_schema_fields(bq_schema)
     arrow_table = dataframe_to_arrow(dataframe, bq_schema)
     pyarrow.parquet.write_table(arrow_table, filepath, compression=parquet_compression)
@@ -388,20 +401,24 @@ def _tabledata_list_page_to_arrow(page, column_names, arrow_types):
     return pyarrow.RecordBatch.from_arrays(arrays, names=column_names)


-def download_arrow_tabledata_list(pages, schema):
+def download_arrow_tabledata_list(pages, bq_schema):
     """Use tabledata.list to construct an iterable of RecordBatches.

     Args:
         pages (Iterator[:class:`google.api_core.page_iterator.Page`]):
             An iterator over the result pages.
-        schema (Sequence[google.cloud.bigquery.schema.SchemaField]):
+        bq_schema (Sequence[Union[ \
+            :class:`~google.cloud.bigquery.schema.SchemaField`, \
+            Mapping[str, str] \
+        ]]):
             A description of the fields in result pages.
     Yields:
         :class:`pyarrow.RecordBatch`
             The next page of records as a ``pyarrow`` record batch.
     """
-    column_names = bq_to_arrow_schema(schema) or [field.name for field in schema]
-    arrow_types = [bq_to_arrow_data_type(field) for field in schema]
+    bq_schema = schema._to_schema_fields(bq_schema)
+    column_names = bq_to_arrow_schema(bq_schema) or [field.name for field in bq_schema]
+    arrow_types = [bq_to_arrow_data_type(field) for field in bq_schema]

     for page in pages:
         yield _tabledata_list_page_to_arrow(page, column_names, arrow_types)
@@ -422,9 +439,26 @@ def _tabledata_list_page_to_dataframe(page, column_names, dtypes):
     return pandas.DataFrame(columns, columns=column_names)


-def download_dataframe_tabledata_list(pages, schema, dtypes):
-    """Use (slower, but free) tabledata.list to construct a DataFrame."""
-    column_names = [field.name for field in schema]
+def download_dataframe_tabledata_list(pages, bq_schema, dtypes):
+    """Use (slower, but free) tabledata.list to construct a DataFrame.
+
+    Args:
+        pages (Iterator[:class:`google.api_core.page_iterator.Page`]):
+            An iterator over the result pages.
+        bq_schema (Sequence[Union[ \
+            :class:`~google.cloud.bigquery.schema.SchemaField`, \
+            Mapping[str, str] \
+        ]]):
+            A description of the fields in result pages.
+        dtypes (Mapping[str, numpy.dtype]):
+            The types of columns in result data to hint construction of the
+            resulting DataFrame. Not all column types have to be specified.
+    Yields:
+        :class:`pandas.DataFrame`
+            The next page of records as a ``pandas.DataFrame``.
+    """
+    bq_schema = schema._to_schema_fields(bq_schema)
+    column_names = [field.name for field in bq_schema]
     for page in pages:
         yield _tabledata_list_page_to_dataframe(page, column_names, dtypes)
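Each helper now funnels its schema argument through schema._to_schema_fields, which this commit does not show. A plausible sketch of that normalization step, assuming it converts dict entries via SchemaField.from_api_repr (a documented classmethod of the schema module) and passes SchemaField objects through unchanged:

    from google.cloud.bigquery.schema import SchemaField

    def _to_schema_fields(schema):
        # Sketch only: the real helper lives in google.cloud.bigquery.schema
        # and is not part of this diff. Build SchemaField instances from
        # dict (API-representation) entries such as
        # {"name": ..., "type": ..., "mode": ...}; keep SchemaField as-is.
        return tuple(
            field
            if isinstance(field, SchemaField)
            else SchemaField.from_api_repr(field)
            for field in schema
        )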

bigquery/tests/unit/test__pandas_helpers.py

Lines changed: 137 additions & 0 deletions
@@ -701,6 +701,32 @@ def test_list_columns_and_indexes_with_multiindex(module_under_test):
     assert columns_and_indexes == expected


+@pytest.mark.skipif(pandas is None, reason="Requires `pandas`")
+def test_dataframe_to_bq_schema_dict_sequence(module_under_test):
+    df_data = collections.OrderedDict(
+        [
+            ("str_column", [u"hello", u"world"]),
+            ("int_column", [42, 8]),
+            ("bool_column", [True, False]),
+        ]
+    )
+    dataframe = pandas.DataFrame(df_data)
+
+    dict_schema = [
+        {"name": "str_column", "type": "STRING", "mode": "NULLABLE"},
+        {"name": "bool_column", "type": "BOOL", "mode": "REQUIRED"},
+    ]
+
+    returned_schema = module_under_test.dataframe_to_bq_schema(dataframe, dict_schema)
+
+    expected_schema = (
+        schema.SchemaField("str_column", "STRING", "NULLABLE"),
+        schema.SchemaField("int_column", "INTEGER", "NULLABLE"),
+        schema.SchemaField("bool_column", "BOOL", "REQUIRED"),
+    )
+    assert returned_schema == expected_schema
+
+
 @pytest.mark.skipif(pandas is None, reason="Requires `pandas`")
 @pytest.mark.skipif(pyarrow is None, reason="Requires `pyarrow`")
 def test_dataframe_to_arrow_with_multiindex(module_under_test):
@@ -856,6 +882,28 @@ def test_dataframe_to_arrow_with_unknown_type(module_under_test):
     assert arrow_schema[3].name == "field03"


+@pytest.mark.skipif(pandas is None, reason="Requires `pandas`")
+@pytest.mark.skipif(pyarrow is None, reason="Requires `pyarrow`")
+def test_dataframe_to_arrow_dict_sequence_schema(module_under_test):
+    dict_schema = [
+        {"name": "field01", "type": "STRING", "mode": "REQUIRED"},
+        {"name": "field02", "type": "BOOL", "mode": "NULLABLE"},
+    ]
+
+    dataframe = pandas.DataFrame(
+        {"field01": [u"hello", u"world"], "field02": [True, False]}
+    )
+
+    arrow_table = module_under_test.dataframe_to_arrow(dataframe, dict_schema)
+    arrow_schema = arrow_table.schema
+
+    expected_fields = [
+        pyarrow.field("field01", "string", nullable=False),
+        pyarrow.field("field02", "bool", nullable=True),
+    ]
+    assert list(arrow_schema) == expected_fields
+
+
 @pytest.mark.skipif(pandas is None, reason="Requires `pandas`")
 def test_dataframe_to_parquet_without_pyarrow(module_under_test, monkeypatch):
     monkeypatch.setattr(module_under_test, "pyarrow", None)
@@ -908,6 +956,36 @@ def test_dataframe_to_parquet_compression_method(module_under_test):
     assert call_args.kwargs.get("compression") == "ZSTD"


+@pytest.mark.skipif(pandas is None, reason="Requires `pandas`")
+@pytest.mark.skipif(pyarrow is None, reason="Requires `pyarrow`")
+def test_dataframe_to_parquet_dict_sequence_schema(module_under_test):
+    dict_schema = [
+        {"name": "field01", "type": "STRING", "mode": "REQUIRED"},
+        {"name": "field02", "type": "BOOL", "mode": "NULLABLE"},
+    ]
+
+    dataframe = pandas.DataFrame(
+        {"field01": [u"hello", u"world"], "field02": [True, False]}
+    )
+
+    write_table_patch = mock.patch.object(
+        module_under_test.pyarrow.parquet, "write_table", autospec=True
+    )
+    to_arrow_patch = mock.patch.object(
+        module_under_test, "dataframe_to_arrow", autospec=True
+    )
+
+    with write_table_patch, to_arrow_patch as fake_to_arrow:
+        module_under_test.dataframe_to_parquet(dataframe, dict_schema, None)
+
+    expected_schema_arg = [
+        schema.SchemaField("field01", "STRING", mode="REQUIRED"),
+        schema.SchemaField("field02", "BOOL", mode="NULLABLE"),
+    ]
+    schema_arg = fake_to_arrow.call_args.args[1]
+    assert schema_arg == expected_schema_arg
+
+
 @pytest.mark.skipif(pyarrow is None, reason="Requires `pyarrow`")
 def test_download_arrow_tabledata_list_unknown_field_type(module_under_test):
     fake_page = api_core.page_iterator.Page(
@@ -977,3 +1055,62 @@ def test_download_arrow_tabledata_list_known_field_type(module_under_test):
     col = result.columns[1]
     assert type(col) is pyarrow.lib.StringArray
     assert list(col) == ["2.2", "22.22", "222.222"]
+
+
+@pytest.mark.skipif(pyarrow is None, reason="Requires `pyarrow`")
+def test_download_arrow_tabledata_list_dict_sequence_schema(module_under_test):
+    fake_page = api_core.page_iterator.Page(
+        parent=mock.Mock(),
+        items=[{"page_data": "foo"}],
+        item_to_value=api_core.page_iterator._item_to_value_identity,
+    )
+    fake_page._columns = [[1, 10, 100], ["2.2", "22.22", "222.222"]]
+    pages = [fake_page]
+
+    dict_schema = [
+        {"name": "population_size", "type": "INTEGER", "mode": "NULLABLE"},
+        {"name": "non_alien_field", "type": "STRING", "mode": "NULLABLE"},
+    ]
+
+    results_gen = module_under_test.download_arrow_tabledata_list(pages, dict_schema)
+    result = next(results_gen)
+
+    assert len(result.columns) == 2
+    col = result.columns[0]
+    assert type(col) is pyarrow.lib.Int64Array
+    assert list(col) == [1, 10, 100]
+    col = result.columns[1]
+    assert type(col) is pyarrow.lib.StringArray
+    assert list(col) == ["2.2", "22.22", "222.222"]
+
+
+@pytest.mark.skipif(pandas is None, reason="Requires `pandas`")
+@pytest.mark.skipif(pyarrow is None, reason="Requires `pyarrow`")
+def test_download_dataframe_tabledata_list_dict_sequence_schema(module_under_test):
+    fake_page = api_core.page_iterator.Page(
+        parent=mock.Mock(),
+        items=[{"page_data": "foo"}],
+        item_to_value=api_core.page_iterator._item_to_value_identity,
+    )
+    fake_page._columns = [[1, 10, 100], ["2.2", "22.22", "222.222"]]
+    pages = [fake_page]
+
+    dict_schema = [
+        {"name": "population_size", "type": "INTEGER", "mode": "NULLABLE"},
+        {"name": "non_alien_field", "type": "STRING", "mode": "NULLABLE"},
+    ]
+
+    results_gen = module_under_test.download_dataframe_tabledata_list(
+        pages, dict_schema, dtypes={}
+    )
+    result = next(results_gen)
+
+    expected_result = pandas.DataFrame(
+        collections.OrderedDict(
+            [
+                ("population_size", [1, 10, 100]),
+                ("non_alien_field", ["2.2", "22.22", "222.222"]),
+            ]
+        )
+    )
+    assert result.equals(expected_result)
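Taken together, the tests exercise the dict-based schema end to end, from schema autodetection through Arrow conversion and Parquet output. A short usage sketch combining the pieces; the output path is illustrative, and pyarrow must be installed:

    import pandas

    from google.cloud.bigquery import _pandas_helpers

    dataframe = pandas.DataFrame(
        {"field01": ["hello", "world"], "field02": [True, False]}
    )
    dict_schema = [
        {"name": "field01", "type": "STRING", "mode": "REQUIRED"},
        {"name": "field02", "type": "BOOL", "mode": "NULLABLE"},
    ]

    # Writes a Parquet file whose Arrow schema marks field01 as required
    # (nullable=False), matching test_dataframe_to_arrow_dict_sequence_schema.
    _pandas_helpers.dataframe_to_parquet(dataframe, dict_schema, "/tmp/data.parquet")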
