diff --git a/python/pyarrow/_dataset.pyx b/python/pyarrow/_dataset.pyx
index 46f78d48d30..273d65ecbdd 100644
--- a/python/pyarrow/_dataset.pyx
+++ b/python/pyarrow/_dataset.pyx
@@ -2149,7 +2149,7 @@ cdef class DatasetFactory(_Weakrefable):
         result = self.factory.Inspect(options)
         return pyarrow_wrap_schema(GetResultValue(result))
 
-    def finish(self, Schema schema=None):
+    def finish(self, Schema schema=None, object validate_schema=None):
         """
         Create a Dataset using the inspected schema or an explicit schema
         (if given).
@@ -2165,16 +2165,48 @@
         Dataset
         """
         cdef:
+            CFinishOptions options
+            CInspectOptions inspect_options
             shared_ptr[CSchema] sp_schema
             CResult[shared_ptr[CDataset]] result
+            bint validate_fragments = False
+            int fragments
 
         if schema is not None:
             sp_schema = pyarrow_unwrap_schema(schema)
-            with nogil:
-                result = self.factory.FinishWithSchema(sp_schema)
+            options.schema = sp_schema
+
+        if validate_schema is None:
+            # default: do not validate a specified schema, and inspect
+            # only the first fragment to infer one (fragments=1)
+            pass
+        elif isinstance(validate_schema, bool):
+            if validate_schema:
+                # validate_schema=True -> validate or inspect all fragments
+                validate_fragments = True
+                inspect_options.fragments = -1
+            else:
+                if schema is None:
+                    raise ValueError(
+                        "cannot specify validate_schema=False when no schema "
+                        "was specified manually"
+                    )
         else:
-            with nogil:
-                result = self.factory.Finish()
+            fragments = validate_schema
+            if fragments > 0:
+                validate_fragments = True
+                inspect_options.fragments = fragments
+            else:
+                raise ValueError(
+                    "need to specify a positive number of fragments for which "
+                    "to validate the schema"
+                )
+
+        options.validate_fragments = validate_fragments
+        options.inspect_options = inspect_options
+
+        with nogil:
+            result = self.factory.FinishWithOptions(options)
 
         return Dataset.wrap(GetResultValue(result))
 
diff --git a/python/pyarrow/dataset.py b/python/pyarrow/dataset.py
index b3c142f6323..f9a63604802 100644
--- a/python/pyarrow/dataset.py
+++ b/python/pyarrow/dataset.py
@@ -382,7 +382,7 @@ def _ensure_single_source(path, filesystem=None):
 def _filesystem_dataset(source, schema=None, filesystem=None,
                         partitioning=None, format=None,
                         partition_base_dir=None, exclude_invalid_files=None,
-                        selector_ignore_prefixes=None):
+                        selector_ignore_prefixes=None, validate_schema=None):
     """
     Create a FileSystemDataset which can be used to build a Dataset.
 
@@ -408,7 +408,7 @@ def _filesystem_dataset(source, schema=None, filesystem=None,
     )
     factory = FileSystemDatasetFactory(fs, paths_or_selector, format, options)
 
-    return factory.finish(schema)
+    return factory.finish(schema, validate_schema=validate_schema)
 
 
 def _in_memory_dataset(source, schema=None, **kwargs):
@@ -499,7 +499,8 @@ def parquet_dataset(metadata_path, schema=None, filesystem=None, format=None,
 
 
 def dataset(source, schema=None, format=None, filesystem=None,
             partitioning=None, partition_base_dir=None,
-            exclude_invalid_files=None, ignore_prefixes=None):
+            exclude_invalid_files=None, ignore_prefixes=None,
+            validate_schema=None):
     """
     Open a dataset.
 
@@ -575,6 +576,17 @@ def dataset(source, schema=None, format=None, filesystem=None,
         discovery process. This is matched to the basename of a path.
         By default this is ['.', '_']. Note that discovery happens only
         if a directory is passed as source.
+    validate_schema : bool or int, optional
+        Whether to validate the specified or inspected schema. By default
+        (``validate_schema=None``), a manually specified `schema` is not
+        validated, and when no schema is given, it is inferred from the
+        first fragment only.
+        With ``validate_schema=True``, all fragments are inspected, either
+        to infer a common (unified) schema or to validate each fragment
+        against the manually specified schema.
+        Alternatively, pass a positive integer to control the exact
+        number of fragments that are inspected to infer or validate
+        the schema.
 
     Returns
     -------
@@ -649,7 +661,8 @@
             format=format,
             partition_base_dir=partition_base_dir,
             exclude_invalid_files=exclude_invalid_files,
-            selector_ignore_prefixes=ignore_prefixes
+            selector_ignore_prefixes=ignore_prefixes,
+            validate_schema=validate_schema,
         )
 
     if _is_path_like(source):
diff --git a/python/pyarrow/includes/libarrow_dataset.pxd b/python/pyarrow/includes/libarrow_dataset.pxd
index 4da29783b20..08432ae6d3f 100644
--- a/python/pyarrow/includes/libarrow_dataset.pxd
+++ b/python/pyarrow/includes/libarrow_dataset.pxd
@@ -159,6 +159,8 @@ cdef extern from "arrow/dataset/api.h" namespace "arrow::dataset" nogil:
         CResult[shared_ptr[CSchema]] Inspect(CInspectOptions)
         CResult[shared_ptr[CDataset]] FinishWithSchema "Finish"(
             const shared_ptr[CSchema]& schema)
+        CResult[shared_ptr[CDataset]] FinishWithOptions "Finish"(
+            CFinishOptions options)
         CResult[shared_ptr[CDataset]] Finish()
         const CExpression& root_partition()
         CStatus SetRootPartition(CExpression partition)
diff --git a/python/pyarrow/tests/test_dataset.py b/python/pyarrow/tests/test_dataset.py
index 26c14e14822..3d300ccce09 100644
--- a/python/pyarrow/tests/test_dataset.py
+++ b/python/pyarrow/tests/test_dataset.py
@@ -2658,6 +2658,113 @@ def test_filter_mismatching_schema(tempdir):
     assert filtered["col"].equals(table["col"].cast('int64').slice(2))
 
 
+@pytest.mark.parquet
+def test_dataset_validate_schema_keyword(tempdir):
+    # ARROW-8221
+    import pyarrow.parquet as pq
+
+    basedir = tempdir / "dataset_mismatched_schemas"
+    basedir.mkdir()
+
+    table1 = pa.table({'a': [1, 2, 3], 'b': [1, 2, 3]})
+    pq.write_table(table1, basedir / "data1.parquet")
+    table2 = pa.table({'a': ["a", "b", "c"], 'b': [1, 2, 3]})
+    pq.write_table(table2, basedir / "data2.parquet")
+
+    msg_scanning = "matching names but differing types"
+    msg_inspecting = "Unable to merge: Field a has incompatible types"
+
+    # default (inspecting only the first fragment) works, but fails scanning
+    dataset = ds.dataset(basedir)
+    assert dataset.schema.equals(table1.schema)
+    with pytest.raises(TypeError, match=msg_scanning):
+        dataset.to_table()
+
+    # validate_schema=True -> inspect all fragments -> fails on inspection
+    with pytest.raises(ValueError, match=msg_inspecting):
+        ds.dataset(basedir, validate_schema=True)
+
+    # validate_schema=False -> not possible when not specifying a schema
+    with pytest.raises(ValueError, match="no schema was specified manually"):
+        ds.dataset(basedir, validate_schema=False)
+
+    # validate_schema=integer -> the number of fragments to inspect
+    dataset = ds.dataset(basedir, validate_schema=1)
+    assert dataset.schema.equals(table1.schema)
+    with pytest.raises(TypeError, match=msg_scanning):
+        dataset.to_table()
+
+    with pytest.raises(ValueError, match=msg_inspecting):
+        ds.dataset(basedir, validate_schema=2)
+
+    # with a manually specified schema
+    schema1 = pa.schema([('a', 'int64'), ('b', 'int64')])
+    schema2 = pa.schema([('a', 'string'), ('b', 'int64')])
+
+    # default (no validation) works, but fails scanning
+    dataset = ds.dataset(basedir, schema=schema1)
+    assert dataset.schema.equals(schema1)
+    with pytest.raises(TypeError, match=msg_scanning):
+        dataset.to_table()
+
+    # validate_schema=False -> same as default
+    dataset = ds.dataset(basedir, schema=schema1, validate_schema=False)
+    assert dataset.schema.equals(schema1)
+    with pytest.raises(TypeError, match=msg_scanning):
+        dataset.to_table()
+
+    # validate_schema=True -> validate schema of all fragments -> fails
+    with pytest.raises(ValueError, match=msg_inspecting):
+        ds.dataset(basedir, schema=schema1, validate_schema=True)
+
+    # validate_schema=integer -> the number of fragments to validate
+    dataset = ds.dataset(basedir, schema=schema1, validate_schema=1)
+    assert dataset.schema.equals(schema1)
+    with pytest.raises(TypeError, match=msg_scanning):
+        dataset.to_table()
+
+    with pytest.raises(ValueError, match=msg_inspecting):
+        ds.dataset(basedir, schema=schema1, validate_schema=2)
+
+    with pytest.raises(ValueError, match=msg_inspecting):
+        ds.dataset(basedir, schema=schema2, validate_schema=1)
+
+    # validate_schema=integer -> the integer needs to be positive
+    with pytest.raises(ValueError, match="positive number of fragments"):
+        ds.dataset(basedir, validate_schema=0)
+
+    with pytest.raises(ValueError, match="positive number of fragments"):
+        ds.dataset(basedir, schema=schema1, validate_schema=0)
+
+    # invalid value for the keyword
+    with pytest.raises(TypeError):
+        ds.dataset(basedir, validate_schema="test")
+
+
+@pytest.mark.parquet
+def test_dataset_validate_schema_unify(tempdir):
+    # ARROW-8221
+    import pyarrow.parquet as pq
+
+    basedir = tempdir / "dataset_mismatched_schemas"
+    basedir.mkdir()
+
+    table1 = pa.table({'a': [1, 2, 3], 'b': [1, 2, 3]})
+    pq.write_table(table1, basedir / "data1.parquet")
+    table2 = pa.table({'b': [4, 5, 6], 'c': ["a", "b", "c"]})
+    pq.write_table(table2, basedir / "data2.parquet")
+
+    # default (inspecting only the first fragment)
+    dataset = ds.dataset(basedir)
+    assert dataset.to_table().schema.equals(table1.schema)
+
+    # inspecting all fragments -> unify schemas
+    dataset = ds.dataset(basedir, validate_schema=True)
+    expected_schema = pa.schema(
+        [("a", "int64"), ("b", "int64"), ("c", "string")])
+    assert dataset.to_table().schema.equals(expected_schema)
+
+
 @pytest.mark.parquet
 @pytest.mark.pandas
 def test_dataset_project_only_partition_columns(tempdir):
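
For reviewers, a minimal usage sketch of the new keyword (the path is
illustrative only; the behaviours follow the semantics implemented in this
patch):

    import pyarrow.dataset as ds

    # Default (validate_schema=None): a manually specified schema is not
    # validated, and without one only the first fragment is inspected to
    # infer the dataset schema.
    dataset = ds.dataset("/path/to/parquet_dir")

    # validate_schema=True: inspect all fragments, either unifying their
    # schemas or validating each against a manually specified schema;
    # mismatches raise at construction time instead of during scanning.
    dataset = ds.dataset("/path/to/parquet_dir", validate_schema=True)

    # validate_schema=<positive int>: bound how many fragments are
    # inspected or validated.
    dataset = ds.dataset("/path/to/parquet_dir", validate_schema=2)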