Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions python/pyarrow/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,7 @@ def _ensure_factory(src, **kwargs):


def dataset(paths_or_factories, filesystem=None, partitioning=None,
format=None):
format=None, schema=None):
"""
Open a dataset.

Expand All @@ -317,6 +317,9 @@ def dataset(paths_or_factories, filesystem=None, partitioning=None,
field names a DirectoryPartitioning will be inferred.
format : str
Currently only "parquet" is supported.
schema : Schema, optional
Optionally provide the Schema for the Dataset, in which case it will
not be inferred from the source.

Returns
-------
Expand Down Expand Up @@ -347,8 +350,8 @@ def dataset(paths_or_factories, filesystem=None, partitioning=None,

factories = [_ensure_factory(f, **kwargs) for f in paths_or_factories]
if single_dataset:
return factories[0].finish()
return UnionDatasetFactory(factories).finish()
return factories[0].finish(schema=schema)
return UnionDatasetFactory(factories).finish(schema=schema)


def field(name):
Expand Down
50 changes: 50 additions & 0 deletions python/pyarrow/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1002,6 +1002,56 @@ def test_multiple_factories_with_selectors(multisourcefs):
assert dataset.schema.equals(expected_schema)


@pytest.mark.parquet
def test_specified_schema(tempdir):
    """Opening a dataset with an explicit ``schema=`` should use that schema
    (reordering, subsetting, padding or casting columns) instead of the one
    inferred from the files."""
    import pyarrow.parquet as pq

    table = pa.table({'a': [1, 2, 3], 'b': [.1, .2, .3]})
    pq.write_table(table, tempdir / "data.parquet")

    def _assert_dataset(schema, expected, expected_schema=None):
        # The dataset must report the requested schema (or, when none was
        # requested, the one we expect to be inferred) and scan back the
        # expected table contents.
        dataset = ds.dataset(str(tempdir / "data.parquet"), schema=schema)
        reported = expected_schema if expected_schema is not None else schema
        assert dataset.schema.equals(reported)
        assert dataset.to_table().equals(expected)

    # No schema given -> inferred from the parquet file itself.
    _assert_dataset(None, table, expected_schema=table.schema)

    # Explicit schema identical to the file's schema.
    _assert_dataset(table.schema, table)

    # Explicit schema with the columns in a different order.
    reordered = pa.schema([('b', 'float64'), ('a', 'int64')])
    _assert_dataset(
        reordered,
        pa.table([[.1, .2, .3], [1, 2, 3]], names=['b', 'a']))

    # Explicit schema that drops one of the file's columns.
    _assert_dataset(pa.schema([('a', 'int64')]), pa.table({'a': [1, 2, 3]}))

    # Explicit schema with a column the file does not have -> filled
    # with nulls of the requested type.
    padded = pa.schema([('a', 'int64'), ('c', 'int32')])
    _assert_dataset(
        padded,
        pa.table({'a': [1, 2, 3],
                  'c': pa.array([None, None, None], type='int32')}))

    # Incompatible schema: the dataset is created and reports the schema,
    # but materializing it fails.
    incompatible = pa.schema([('a', 'int32'), ('b', 'float64')])
    dataset = ds.dataset(str(tempdir / "data.parquet"), schema=incompatible)
    assert dataset.schema.equals(incompatible)
    with pytest.raises(TypeError):
        dataset.to_table()


def test_ipc_format(tempdir):
table = pa.table({'a': pa.array([1, 2, 3], type="int8"),
'b': pa.array([.1, .2, .3], type="float64")})
Expand Down