26 changes: 25 additions & 1 deletion cpp/src/arrow/dataset/partition.cc
@@ -306,6 +306,16 @@ class KeyValuePartitioningInspectImpl {
     return ::arrow::schema(std::move(fields));
   }
 
+  std::vector<std::string> FieldNames() {
+    std::vector<std::string> names;
+    names.reserve(name_to_index_.size());
+
+    for (const auto& kv : name_to_index_) {
+      names.push_back(kv.first);
+    }
+    return names;
+  }
+
  private:
   std::unordered_map<std::string, int> name_to_index_;
   std::vector<std::set<std::string>> values_;
@@ -646,15 +656,29 @@ class HivePartitioningFactory : public PartitioningFactory {
       }
     }
 
+    field_names_ = impl.FieldNames();
     return impl.Finish(&dictionaries_);
   }
 
   Result<std::shared_ptr<Partitioning>> Finish(
       const std::shared_ptr<Schema>& schema) const override {
-    return std::shared_ptr<Partitioning>(new HivePartitioning(schema, dictionaries_));
+    if (dictionaries_.empty()) {
+      return std::make_shared<HivePartitioning>(schema, dictionaries_);
+    } else {
+      for (FieldRef ref : field_names_) {
+        // ensure all of field_names_ are present in schema
+        RETURN_NOT_OK(ref.FindOne(*schema).status());
+      }
+
+      // drop fields which aren't in field_names_
+      auto out_schema = SchemaFromColumnNames(schema, field_names_);
+
+      return std::make_shared<HivePartitioning>(std::move(out_schema), dictionaries_);
+    }
   }
 
  private:
+  std::vector<std::string> field_names_;
   ArrayVector dictionaries_;
   PartitioningFactoryOptions options_;
 };
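The net effect of the C++ change: when dictionary inference is enabled, `HivePartitioningFactory` records the field names seen while inspecting paths, and `Finish` prunes the final partitioning schema down to exactly those fields, erroring if any of them is missing from the supplied schema. A rough sketch of the resulting user-facing behavior in pyarrow (the `/data` path and layout are hypothetical; `max_partition_dictionary_size=-1` mirrors the test below and requests unbounded dictionary inference):

```python
import pyarrow as pa
import pyarrow.dataset as ds

# Hypothetical hive-style layout: /data/part=A/..., /data/part=B/...
factory = ds.HivePartitioning.discover(max_partition_dictionary_size=-1)
dataset = ds.dataset("/data", partitioning=factory)

# The discovered partition field is dictionary-encoded, and the
# partitioning schema contains only the fields found during inspection.
assert dataset.schema.field("part").type.equals(
    pa.dictionary(pa.int32(), pa.string()))
```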
29 changes: 29 additions & 0 deletions python/pyarrow/tests/test_dataset.py
@@ -1303,6 +1303,35 @@ def test_open_dataset_non_existing_file():
         ds.dataset('file:i-am-not-existing.parquet', format='parquet')
 
 
+@pytest.mark.parquet
+@pytest.mark.parametrize('partitioning', ["directory", "hive"])
+def test_open_dataset_partitioned_dictionary_type(tempdir, partitioning):
+    # ARROW-9288
+    import pyarrow.parquet as pq
+    table = pa.table({'a': range(9), 'b': [0.] * 4 + [1.] * 5})
+
+    path = tempdir / "dataset"
+    path.mkdir()
+
+    for part in ["A", "B", "C"]:
+        fmt = "{}" if partitioning == "directory" else "part={}"
+        part_dir = path / fmt.format(part)
+        part_dir.mkdir()
+        pq.write_table(table, part_dir / "test.parquet")
+
+    if partitioning == "directory":
+        part = ds.DirectoryPartitioning.discover(
+            ["part"], max_partition_dictionary_size=-1)
+    else:
+        part = ds.HivePartitioning.discover(max_partition_dictionary_size=-1)
+
+    dataset = ds.dataset(str(path), partitioning=part)
+    expected_schema = table.schema.append(
+        pa.field("part", pa.dictionary(pa.int32(), pa.string()))
+    )
+    assert dataset.schema.equals(expected_schema)
+
+
 @pytest.fixture
 def s3_example_simple(s3_connection, s3_server):
     from pyarrow.fs import FileSystem
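For context, a sketch of what the new test exercises for the directory flavor, outside the test harness. The base path is illustrative and assumes data already written in a `base/A/`, `base/B/`, `base/C/` layout as in the test; `DirectoryPartitioning.discover(["part"], max_partition_dictionary_size=-1)` comes straight from the test above:

```python
import pyarrow as pa
import pyarrow.dataset as ds

# Illustrative layout: /tmp/base/A/test.parquet, /tmp/base/B/..., etc.
factory = ds.DirectoryPartitioning.discover(
    ["part"], max_partition_dictionary_size=-1)
dataset = ds.dataset("/tmp/base", partitioning=factory)

# Reading materializes the partition column with the inferred
# dictionary type rather than a plain string type.
table = dataset.to_table()
assert table.column("part").type.equals(
    pa.dictionary(pa.int32(), pa.string()))
```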