From 270938b5b02b5dcecb6252d854c6b8e48bc50a05 Mon Sep 17 00:00:00 2001
From: Joris Van den Bossche
Date: Tue, 31 Mar 2020 17:52:00 +0200
Subject: [PATCH] ARROW-8292: [Python] Allow to manually specify schema in dataset() function

---
 python/pyarrow/dataset.py            |  9 +++--
 python/pyarrow/tests/test_dataset.py | 50 ++++++++++++++++++++++++++++
 2 files changed, 56 insertions(+), 3 deletions(-)

diff --git a/python/pyarrow/dataset.py b/python/pyarrow/dataset.py
index e2811b36e32..fff11be6d77 100644
--- a/python/pyarrow/dataset.py
+++ b/python/pyarrow/dataset.py
@@ -298,7 +298,7 @@ def _ensure_factory(src, **kwargs):
 
 
 def dataset(paths_or_factories, filesystem=None, partitioning=None,
-            format=None):
+            format=None, schema=None):
     """
     Open a dataset.
 
@@ -317,6 +317,9 @@ def dataset(paths_or_factories, filesystem=None, partitioning=None,
         field names a DirectionaryPartitioning will be inferred.
     format : str
         Currently only "parquet" is supported.
+    schema : Schema, optional
+        Optionally provide the Schema for the Dataset, in which case it will
+        not be inferred from the source.
 
     Returns
     -------
@@ -347,8 +350,8 @@ def dataset(paths_or_factories, filesystem=None, partitioning=None,
 
     factories = [_ensure_factory(f, **kwargs) for f in paths_or_factories]
     if single_dataset:
-        return factories[0].finish()
-    return UnionDatasetFactory(factories).finish()
+        return factories[0].finish(schema=schema)
+    return UnionDatasetFactory(factories).finish(schema=schema)
 
 
 def field(name):
diff --git a/python/pyarrow/tests/test_dataset.py b/python/pyarrow/tests/test_dataset.py
index 938f426f246..9d46acf6e1c 100644
--- a/python/pyarrow/tests/test_dataset.py
+++ b/python/pyarrow/tests/test_dataset.py
@@ -1002,6 +1002,56 @@ def test_multiple_factories_with_selectors(multisourcefs):
     assert dataset.schema.equals(expected_schema)
 
 
+@pytest.mark.parquet
+def test_specified_schema(tempdir):
+    import pyarrow.parquet as pq
+
+    table = pa.table({'a': [1, 2, 3], 'b': [.1, .2, .3]})
+    pq.write_table(table, tempdir / "data.parquet")
+
+    def _check_dataset(schema, expected, expected_schema=None):
+        dataset = ds.dataset(str(tempdir / "data.parquet"), schema=schema)
+        if expected_schema is not None:
+            assert dataset.schema.equals(expected_schema)
+        else:
+            assert dataset.schema.equals(schema)
+        result = dataset.to_table()
+        assert result.equals(expected)
+
+    # no schema specified
+    schema = None
+    expected = table
+    _check_dataset(schema, expected, expected_schema=table.schema)
+
+    # identical schema specified
+    schema = table.schema
+    expected = table
+    _check_dataset(schema, expected)
+
+    # Specifying schema with change column order
+    schema = pa.schema([('b', 'float64'), ('a', 'int64')])
+    expected = pa.table([[.1, .2, .3], [1, 2, 3]], names=['b', 'a'])
+    _check_dataset(schema, expected)
+
+    # Specifying schema with missing column
+    schema = pa.schema([('a', 'int64')])
+    expected = pa.table({'a': [1, 2, 3]})
+    _check_dataset(schema, expected)
+
+    # Specifying schema with additional column
+    schema = pa.schema([('a', 'int64'), ('c', 'int32')])
+    expected = pa.table({'a': [1, 2, 3],
+                         'c': pa.array([None, None, None], type='int32')})
+    _check_dataset(schema, expected)
+
+    # Specifying with incompatible schema
+    schema = pa.schema([('a', 'int32'), ('b', 'float64')])
+    dataset = ds.dataset(str(tempdir / "data.parquet"), schema=schema)
+    assert dataset.schema.equals(schema)
+    with pytest.raises(TypeError):
+        dataset.to_table()
+
+
 def test_ipc_format(tempdir):
     table = pa.table({'a': pa.array([1, 2, 3], type="int8"),
                       'b': pa.array([.1, .2, .3], type="float64")})