diff --git a/cpp/src/arrow/dataset/partition.cc b/cpp/src/arrow/dataset/partition.cc
index 4631b4f02e5..dcb91fa8fac 100644
--- a/cpp/src/arrow/dataset/partition.cc
+++ b/cpp/src/arrow/dataset/partition.cc
@@ -306,6 +306,16 @@ class KeyValuePartitioningInspectImpl {
     return ::arrow::schema(std::move(fields));
   }
 
+  std::vector<std::string> FieldNames() {
+    std::vector<std::string> names;
+    names.reserve(name_to_index_.size());
+
+    for (auto kv : name_to_index_) {
+      names.push_back(kv.first);
+    }
+    return names;
+  }
+
  private:
   std::unordered_map<std::string, int> name_to_index_;
   std::vector<std::set<std::string>> values_;
@@ -646,15 +656,29 @@ class HivePartitioningFactory : public PartitioningFactory {
       }
     }
 
+    field_names_ = impl.FieldNames();
     return impl.Finish(&dictionaries_);
   }
 
   Result<std::shared_ptr<Partitioning>> Finish(
       const std::shared_ptr<Schema>& schema) const override {
-    return std::shared_ptr<Partitioning>(new HivePartitioning(schema, dictionaries_));
+    if (dictionaries_.empty()) {
+      return std::make_shared<HivePartitioning>(schema, dictionaries_);
+    } else {
+      for (FieldRef ref : field_names_) {
+        // ensure all of field_names_ are present in schema
+        RETURN_NOT_OK(ref.FindOne(*schema).status());
+      }
+
+      // drop fields which aren't in field_names_
+      auto out_schema = SchemaFromColumnNames(schema, field_names_);
+
+      return std::make_shared<HivePartitioning>(std::move(out_schema), dictionaries_);
+    }
   }
 
  private:
+  std::vector<std::string> field_names_;
   ArrayVector dictionaries_;
   PartitioningFactoryOptions options_;
 };
diff --git a/python/pyarrow/tests/test_dataset.py b/python/pyarrow/tests/test_dataset.py
index e71bb7cee40..b7200ed497c 100644
--- a/python/pyarrow/tests/test_dataset.py
+++ b/python/pyarrow/tests/test_dataset.py
@@ -1303,6 +1303,35 @@ def test_open_dataset_non_existing_file():
         ds.dataset('file:i-am-not-existing.parquet', format='parquet')
 
 
+@pytest.mark.parquet
+@pytest.mark.parametrize('partitioning', ["directory", "hive"])
+def test_open_dataset_partitioned_dictionary_type(tempdir, partitioning):
+    # ARROW-9288
+    import pyarrow.parquet as pq
+    table = pa.table({'a': range(9), 'b': [0.] * 4 + [1.] * 5})
+
+    path = tempdir / "dataset"
+    path.mkdir()
+
+    for part in ["A", "B", "C"]:
+        fmt = "{}" if partitioning == "directory" else "part={}"
+        part = path / fmt.format(part)
+        part.mkdir()
+        pq.write_table(table, part / "test.parquet")
+
+    if partitioning == "directory":
+        part = ds.DirectoryPartitioning.discover(
+            ["part"], max_partition_dictionary_size=-1)
+    else:
+        part = ds.HivePartitioning.discover(max_partition_dictionary_size=-1)
+
+    dataset = ds.dataset(str(path), partitioning=part)
+    expected_schema = table.schema.append(
+        pa.field("part", pa.dictionary(pa.int32(), pa.string()))
+    )
+    assert dataset.schema.equals(expected_schema)
+
+
 @pytest.fixture
 def s3_example_simple(s3_connection, s3_server):
     from pyarrow.fs import FileSystem
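
Usage note (not part of the patch): with this change, HivePartitioningFactory::Finish checks that every discovered partition field name resolves in the schema it is given, and drops all other fields before constructing the HivePartitioning, so the recorded dictionaries line up with the partition fields. A minimal sketch of the resulting Python behavior, assuming a pyarrow build that contains this fix and still exposes the max_partition_dictionary_size discovery option used in the test above (later pyarrow versions replaced it):

    import pathlib
    import tempfile

    import pyarrow as pa
    import pyarrow.dataset as ds
    import pyarrow.parquet as pq

    # Write a tiny hive-partitioned dataset: part=A/ and part=B/.
    root = pathlib.Path(tempfile.mkdtemp()) / "dataset"
    for key in ["A", "B"]:
        part_dir = root / "part={}".format(key)
        part_dir.mkdir(parents=True)
        pq.write_table(pa.table({"a": range(3)}), part_dir / "chunk.parquet")

    # max_partition_dictionary_size=-1 requests dictionary encoding for
    # every discovered partition field (same option as the test above).
    partitioning = ds.HivePartitioning.discover(max_partition_dictionary_size=-1)
    dataset = ds.dataset(str(root), partitioning=partitioning)

    # The inspected schema carries the partition field as a dictionary type.
    print(dataset.schema.field("part").type)
    # -> dictionary<values=string, indices=int32, ordered=0>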