diff --git a/cpp/src/parquet/arrow/arrow_reader_writer_test.cc b/cpp/src/parquet/arrow/arrow_reader_writer_test.cc index 6444af99f64..537c4a26412 100644 --- a/cpp/src/parquet/arrow/arrow_reader_writer_test.cc +++ b/cpp/src/parquet/arrow/arrow_reader_writer_test.cc @@ -3431,7 +3431,6 @@ TEST(TestArrowWriteDictionaries, AutoReadAsDictionary) { } TEST(TestArrowWriteDictionaries, NestedSubfield) { - // FIXME (ARROW-9943): Automatic decoding of dictionary subfields auto offsets = ::arrow::ArrayFromJSON(::arrow::int32(), "[0, 0, 2, 3]"); auto indices = ::arrow::ArrayFromJSON(::arrow::int32(), "[0, 0, 0]"); auto dict = ::arrow::ArrayFromJSON(::arrow::utf8(), "[\"foo\"]"); @@ -3442,20 +3441,14 @@ TEST(TestArrowWriteDictionaries, NestedSubfield) { ASSERT_OK_AND_ASSIGN(auto values, ::arrow::ListArray::FromArrays(*offsets, *dict_values)); - auto dense_ty = ::arrow::list(::arrow::utf8()); - auto dense_values = - ::arrow::ArrayFromJSON(dense_ty, "[[], [\"foo\", \"foo\"], [\"foo\"]]"); - auto table = MakeSimpleTable(values, /*nullable=*/true); - auto expected_table = MakeSimpleTable(dense_values, /*nullable=*/true); auto props_store_schema = ArrowWriterProperties::Builder().store_schema()->build(); std::shared_ptr actual; DoRoundtrip(table, values->length(), &actual, default_writer_properties(), props_store_schema); - // The nested subfield is not automatically decoded to dictionary - ::arrow::AssertTablesEqual(*expected_table, *actual); + ::arrow::AssertTablesEqual(*table, *actual); } } // namespace arrow diff --git a/cpp/src/parquet/arrow/reader.cc b/cpp/src/parquet/arrow/reader.cc index e1bc8a7b198..af437437c19 100644 --- a/cpp/src/parquet/arrow/reader.cc +++ b/cpp/src/parquet/arrow/reader.cc @@ -744,12 +744,20 @@ Status StructReader::BuildArray(int64_t length_upper_bound, // ---------------------------------------------------------------------- // File reader implementation -Status GetStorageReader(const SchemaField& field, - const std::shared_ptr& ctx, - std::unique_ptr* 
out) { +Status GetReader(const SchemaField& field, const std::shared_ptr& arrow_field, + const std::shared_ptr& ctx, + std::unique_ptr* out) { BEGIN_PARQUET_CATCH_EXCEPTIONS - auto type_id = field.storage_field->type()->id(); + auto type_id = arrow_field->type()->id(); + + if (type_id == ::arrow::Type::EXTENSION) { + auto storage_field = arrow_field->WithType( + checked_cast(*arrow_field->type()).storage_type()); + RETURN_NOT_OK(GetReader(field, storage_field, ctx, out)); + out->reset(new ExtensionReader(arrow_field, std::move(*out))); + return Status::OK(); + } if (field.children.size() == 0) { if (!field.is_leaf()) { @@ -761,10 +769,8 @@ Status GetStorageReader(const SchemaField& field, } std::unique_ptr input( ctx->iterator_factory(field.column_index, ctx->reader)); - out->reset( - new LeafReader(ctx, field.storage_field, std::move(input), field.level_info)); + out->reset(new LeafReader(ctx, arrow_field, std::move(input), field.level_info)); } else if (type_id == ::arrow::Type::LIST) { - auto list_field = field.storage_field; auto child = &field.children[0]; std::unique_ptr child_reader; RETURN_NOT_OK(GetReader(*child, ctx, &child_reader)); @@ -772,7 +778,7 @@ Status GetStorageReader(const SchemaField& field, *out = nullptr; return Status::OK(); } - out->reset(new ListReader(ctx, list_field, field.level_info, + out->reset(new ListReader(ctx, arrow_field, field.level_info, std::move(child_reader))); } else if (type_id == ::arrow::Type::STRUCT) { std::vector> child_fields; @@ -792,12 +798,12 @@ Status GetStorageReader(const SchemaField& field, return Status::OK(); } auto filtered_field = - ::arrow::field(field.storage_field->name(), ::arrow::struct_(child_fields), - field.storage_field->nullable(), field.storage_field->metadata()); + ::arrow::field(arrow_field->name(), ::arrow::struct_(child_fields), + arrow_field->nullable(), arrow_field->metadata()); out->reset(new StructReader(ctx, filtered_field, field.level_info, std::move(child_readers))); } else { - return 
Status::Invalid("Unsupported nested type: ", field.storage_field->ToString()); + return Status::Invalid("Unsupported nested type: ", arrow_field->ToString()); } return Status::OK(); @@ -806,11 +812,7 @@ Status GetStorageReader(const SchemaField& field, Status GetReader(const SchemaField& field, const std::shared_ptr& ctx, std::unique_ptr* out) { - RETURN_NOT_OK(GetStorageReader(field, ctx, out)); - if (field.field->type()->id() == ::arrow::Type::EXTENSION) { - out->reset(new ExtensionReader(field.field, std::move(*out))); - } - return Status::OK(); + return GetReader(field, field.field, ctx, out); } } // namespace diff --git a/cpp/src/parquet/arrow/schema.cc b/cpp/src/parquet/arrow/schema.cc index 3f7ff922305..6babe9bc7cf 100644 --- a/cpp/src/parquet/arrow/schema.cc +++ b/cpp/src/parquet/arrow/schema.cc @@ -17,6 +17,7 @@ #include "parquet/arrow/schema.h" +#include #include #include @@ -35,6 +36,7 @@ #include "parquet/types.h" using arrow::Field; +using arrow::FieldVector; using arrow::KeyValueMetadata; using arrow::Status; using arrow::internal::checked_cast; @@ -425,7 +427,6 @@ Status PopulateLeaf(int column_index, const std::shared_ptr& field, LevelInfo current_levels, SchemaTreeContext* ctx, const SchemaField* parent, SchemaField* out) { out->field = field; - out->storage_field = field; out->column_index = column_index; out->level_info = current_levels; ctx->RecordLeaf(out); @@ -462,7 +463,6 @@ Status GroupToStruct(const GroupNode& node, LevelInfo current_levels, auto struct_type = ::arrow::struct_(arrow_fields); out->field = ::arrow::field(node.name(), struct_type, node.is_optional(), FieldIdMetadata(node.field_id())); - out->storage_field = out->field; out->level_info = current_levels; return Status::OK(); } @@ -543,7 +543,6 @@ Status ListToSchemaField(const GroupNode& group, LevelInfo current_levels, } out->field = ::arrow::field(group.name(), ::arrow::list(child_field->field), group.is_optional(), FieldIdMetadata(group.field_id())); - out->storage_field = 
out->field; out->level_info = current_levels; // At this point current levels contains the def level for this list, // we need to reset to the prior parent. @@ -571,7 +570,6 @@ Status GroupToSchemaField(const GroupNode& node, LevelInfo current_levels, RETURN_NOT_OK(GroupToStruct(node, current_levels, ctx, out, &out->children[0])); out->field = ::arrow::field(node.name(), ::arrow::list(out->children[0].field), /*nullable=*/false, FieldIdMetadata(node.field_id())); - out->storage_field = out->field; ctx->LinkParent(&out->children[0], out); out->level_info = current_levels; @@ -623,7 +621,6 @@ Status NodeToSchemaField(const Node& node, LevelInfo current_levels, out->field = ::arrow::field(node.name(), ::arrow::list(child_field), /*nullable=*/false, FieldIdMetadata(node.field_id())); - out->storage_field = out->field; out->level_info = current_levels; // At this point current_levels has consider this list the ancestor so restore // the actual ancenstor. @@ -689,10 +686,63 @@ Status GetOriginSchema(const std::shared_ptr& metadata, // but that is not necessarily present in the field reconstitued from Parquet data // (for example, Parquet timestamp types doesn't carry timezone information). 
-Status ApplyOriginalStorageMetadata(const Field& origin_field, SchemaField* inferred) { +Result ApplyOriginalMetadata(const Field& origin_field, SchemaField* inferred); + +std::function(FieldVector)> GetNestedFactory( + const ArrowType& origin_type, const ArrowType& inferred_type) { + switch (inferred_type.id()) { + case ::arrow::Type::STRUCT: + if (origin_type.id() == ::arrow::Type::STRUCT) { + return ::arrow::struct_; + } + break; + case ::arrow::Type::LIST: + // TODO also allow LARGE_LIST and FIXED_SIZE_LIST + if (origin_type.id() == ::arrow::Type::LIST) { + return [](FieldVector fields) { + DCHECK_EQ(fields.size(), 1); + return ::arrow::list(std::move(fields[0])); + }; + } + break; + default: + break; + } + return {}; +} + +Result ApplyOriginalStorageMetadata(const Field& origin_field, + SchemaField* inferred) { + bool modified = false; + auto origin_type = origin_field.type(); auto inferred_type = inferred->field->type(); + const int num_children = inferred_type->num_fields(); + + if (num_children > 0 && origin_type->num_fields() == num_children) { + DCHECK_EQ(static_cast(inferred->children.size()), num_children); + const auto factory = GetNestedFactory(*origin_type, *inferred_type); + if (factory) { + // Apply original metadata recursively to children + for (int i = 0; i < inferred_type->num_fields(); ++i) { + ARROW_ASSIGN_OR_RAISE( + const bool child_modified, + ApplyOriginalMetadata(*origin_type->field(i), &inferred->children[i])); + modified |= child_modified; + } + if (modified) { + // Recreate this field using the modified child fields + ::arrow::FieldVector modified_children(inferred_type->num_fields()); + for (int i = 0; i < inferred_type->num_fields(); ++i) { + modified_children[i] = inferred->children[i].field; + } + inferred->field = + inferred->field->WithType(factory(std::move(modified_children))); + } + } + } + if (origin_type->id() == ::arrow::Type::TIMESTAMP && inferred_type->id() == ::arrow::Type::TIMESTAMP) { // Restore time zone, if any @@ 
-706,15 +756,19 @@ Status ApplyOriginalStorageMetadata(const Field& origin_field, SchemaField* infe ts_origin_type.timezone() != "") { inferred->field = inferred->field->WithType(origin_type); } + modified = true; } if (origin_type->id() == ::arrow::Type::DICTIONARY && inferred_type->id() != ::arrow::Type::DICTIONARY && IsDictionaryReadSupported(*inferred_type)) { + // Direct dictionary reads are only supported for a couple primitive types, + // so no need to recurse on value types. const auto& dict_origin_type = checked_cast(*origin_type); inferred->field = inferred->field->WithType( ::arrow::dictionary(::arrow::int32(), inferred_type, dict_origin_type.ordered())); + modified = true; } // Restore field metadata @@ -725,23 +779,15 @@ Status ApplyOriginalStorageMetadata(const Field& origin_field, SchemaField* infe field_metadata = field_metadata->Merge(*inferred->field->metadata()); } inferred->field = inferred->field->WithMetadata(field_metadata); + modified = true; } - if (origin_type->id() == ::arrow::Type::EXTENSION) { - // Restore extension type, if the storage type is as read from Parquet - const auto& ex_type = checked_cast(*origin_type); - if (ex_type.storage_type()->Equals(*inferred_type)) { - inferred->field = inferred->field->WithType(origin_type); - } - } - - // TODO Should apply metadata recursively to children, but for that we need - // to move metadata application inside NodeToSchemaField (ARROW-9943) - - return Status::OK(); + return modified; } -Status ApplyOriginalMetadata(const Field& origin_field, SchemaField* inferred) { +Result ApplyOriginalMetadata(const Field& origin_field, SchemaField* inferred) { + bool modified = false; + auto origin_type = origin_field.type(); auto inferred_type = inferred->field->type(); @@ -751,19 +797,18 @@ Status ApplyOriginalMetadata(const Field& origin_field, SchemaField* inferred) { // Apply metadata recursively to storage type RETURN_NOT_OK(ApplyOriginalStorageMetadata(*origin_storage_field, inferred)); - 
inferred->storage_field = inferred->field; // Restore extension type, if the storage type is the same as inferred // from the Parquet type - if (ex_type.storage_type()->Equals(*inferred_type)) { + if (ex_type.storage_type()->Equals(*inferred->field->type())) { inferred->field = inferred->field->WithType(origin_type); } + modified = true; } else { - RETURN_NOT_OK(ApplyOriginalStorageMetadata(origin_field, inferred)); - inferred->storage_field = inferred->field; + ARROW_ASSIGN_OR_RAISE(modified, ApplyOriginalStorageMetadata(origin_field, inferred)); } - return Status::OK(); + return modified; } } // namespace diff --git a/cpp/src/parquet/arrow/schema.h b/cpp/src/parquet/arrow/schema.h index 5d7fe349945..dd60fde4342 100644 --- a/cpp/src/parquet/arrow/schema.h +++ b/cpp/src/parquet/arrow/schema.h @@ -89,9 +89,6 @@ ::arrow::Status FromParquetSchema(const SchemaDescriptor* parquet_schema, /// \brief Bridge between an arrow::Field and parquet column indices. struct PARQUET_EXPORT SchemaField { std::shared_ptr<::arrow::Field> field; - // If field has an extension type, an equivalent field with the storage type, - // otherwise the field itself. 
- std::shared_ptr<::arrow::Field> storage_field; std::vector children; // Only set for leaf nodes diff --git a/python/pyarrow/tests/test_extension_type.py b/python/pyarrow/tests/test_extension_type.py index 98e169c83a6..e6c7da1721e 100644 --- a/python/pyarrow/tests/test_extension_type.py +++ b/python/pyarrow/tests/test_extension_type.py @@ -68,13 +68,12 @@ def __reduce__(self): class MyListType(pa.PyExtensionType): - storage_type = pa.list_(pa.int32()) - def __init__(self): - pa.PyExtensionType.__init__(self, self.storage_type) + def __init__(self, storage_type): + pa.PyExtensionType.__init__(self, storage_type) def __reduce__(self): - return MyListType, () + return MyListType, (self.storage_type,) def ipc_write_batch(batch): @@ -503,7 +502,7 @@ def test_parquet_period(tmpdir, registered_period_type): @pytest.mark.parquet -def test_parquet_nested_storage(tmpdir): +def test_parquet_extension_with_nested_storage(tmpdir): # Parquet support for extension types with nested storage type import pyarrow.parquet as pq @@ -514,8 +513,8 @@ def test_parquet_nested_storage(tmpdir): mystruct_array = pa.ExtensionArray.from_storage(MyStructType(), struct_array) - mylist_array = pa.ExtensionArray.from_storage(MyListType(), - list_array) + mylist_array = pa.ExtensionArray.from_storage( + MyListType(list_array.type), list_array) orig_table = pa.table({'structs': mystruct_array, 'lists': mylist_array}) @@ -528,7 +527,6 @@ def test_parquet_nested_storage(tmpdir): assert table == orig_table -@pytest.mark.xfail(reason="Need recursive metadata application (ARROW-9943)") @pytest.mark.parquet def test_parquet_nested_extension(tmpdir): # Parquet support for extension types nested in struct or list @@ -537,18 +535,54 @@ def test_parquet_nested_extension(tmpdir): ext_type = IntegerType() storage = pa.array([4, 5, 6, 7], type=pa.int64()) ext_array = pa.ExtensionArray.from_storage(ext_type, storage) + + # Struct of extensions struct_array = pa.StructArray.from_arrays( [storage, ext_array], 
names=['ints', 'exts']) orig_table = pa.table({'structs': struct_array}) - filename = tmpdir / 'nested_extension_type.parquet' + filename = tmpdir / 'struct_of_ext.parquet' pq.write_table(orig_table, filename) table = pq.read_table(filename) assert table.column(0).type == struct_array.type assert table == orig_table + # List of extensions + list_array = pa.ListArray.from_arrays([0, 1, None, 3], ext_array) + + orig_table = pa.table({'lists': list_array}) + filename = tmpdir / 'list_of_ext.parquet' + pq.write_table(orig_table, filename) + + table = pq.read_table(filename) + assert table.column(0).type == list_array.type + assert table == orig_table + + +@pytest.mark.parquet +def test_parquet_extension_nested_in_extension(tmpdir): + # Parquet support for extension> + import pyarrow.parquet as pq + + inner_ext_type = IntegerType() + inner_storage = pa.array([4, 5, 6, 7], type=pa.int64()) + inner_ext_array = pa.ExtensionArray.from_storage(inner_ext_type, + inner_storage) + + list_array = pa.ListArray.from_arrays([0, 1, None, 3], inner_ext_array) + mylist_array = pa.ExtensionArray.from_storage( + MyListType(list_array.type), list_array) + + orig_table = pa.table({'lists': mylist_array}) + filename = tmpdir / 'ext_of_list_of_ext.parquet' + pq.write_table(orig_table, filename) + + table = pq.read_table(filename) + assert table.column(0).type == mylist_array.type + assert table == orig_table + def test_to_numpy(): period_type = PeriodType('D')