diff --git a/cpp/src/parquet/arrow/arrow_reader_writer_test.cc b/cpp/src/parquet/arrow/arrow_reader_writer_test.cc
index 6444af99f64..537c4a26412 100644
--- a/cpp/src/parquet/arrow/arrow_reader_writer_test.cc
+++ b/cpp/src/parquet/arrow/arrow_reader_writer_test.cc
@@ -3431,7 +3431,6 @@ TEST(TestArrowWriteDictionaries, AutoReadAsDictionary) {
}
TEST(TestArrowWriteDictionaries, NestedSubfield) {
- // FIXME (ARROW-9943): Automatic decoding of dictionary subfields
auto offsets = ::arrow::ArrayFromJSON(::arrow::int32(), "[0, 0, 2, 3]");
auto indices = ::arrow::ArrayFromJSON(::arrow::int32(), "[0, 0, 0]");
auto dict = ::arrow::ArrayFromJSON(::arrow::utf8(), "[\"foo\"]");
@@ -3442,20 +3441,14 @@ TEST(TestArrowWriteDictionaries, NestedSubfield) {
ASSERT_OK_AND_ASSIGN(auto values,
::arrow::ListArray::FromArrays(*offsets, *dict_values));
- auto dense_ty = ::arrow::list(::arrow::utf8());
- auto dense_values =
- ::arrow::ArrayFromJSON(dense_ty, "[[], [\"foo\", \"foo\"], [\"foo\"]]");
-
auto table = MakeSimpleTable(values, /*nullable=*/true);
- auto expected_table = MakeSimpleTable(dense_values, /*nullable=*/true);
auto props_store_schema = ArrowWriterProperties::Builder().store_schema()->build();
std::shared_ptr<Table> actual;
DoRoundtrip(table, values->length(), &actual, default_writer_properties(),
props_store_schema);
- // The nested subfield is not automatically decoded to dictionary
- ::arrow::AssertTablesEqual(*expected_table, *actual);
+ ::arrow::AssertTablesEqual(*table, *actual);
}
} // namespace arrow
diff --git a/cpp/src/parquet/arrow/reader.cc b/cpp/src/parquet/arrow/reader.cc
index e1bc8a7b198..af437437c19 100644
--- a/cpp/src/parquet/arrow/reader.cc
+++ b/cpp/src/parquet/arrow/reader.cc
@@ -744,12 +744,20 @@ Status StructReader::BuildArray(int64_t length_upper_bound,
// ----------------------------------------------------------------------
// File reader implementation
-Status GetStorageReader(const SchemaField& field,
- const std::shared_ptr<ReaderContext>& ctx,
- std::unique_ptr<ColumnReaderImpl>* out) {
+Status GetReader(const SchemaField& field, const std::shared_ptr<Field>& arrow_field,
+ const std::shared_ptr<ReaderContext>& ctx,
+ std::unique_ptr<ColumnReaderImpl>* out) {
BEGIN_PARQUET_CATCH_EXCEPTIONS
- auto type_id = field.storage_field->type()->id();
+ auto type_id = arrow_field->type()->id();
+
+ if (type_id == ::arrow::Type::EXTENSION) {
+ auto storage_field = arrow_field->WithType(
+ checked_cast<const ::arrow::ExtensionType&>(*arrow_field->type()).storage_type());
+ RETURN_NOT_OK(GetReader(field, storage_field, ctx, out));
+ out->reset(new ExtensionReader(arrow_field, std::move(*out)));
+ return Status::OK();
+ }
if (field.children.size() == 0) {
if (!field.is_leaf()) {
@@ -761,10 +769,8 @@ Status GetStorageReader(const SchemaField& field,
}
std::unique_ptr<FileColumnIterator> input(
ctx->iterator_factory(field.column_index, ctx->reader));
- out->reset(
- new LeafReader(ctx, field.storage_field, std::move(input), field.level_info));
+ out->reset(new LeafReader(ctx, arrow_field, std::move(input), field.level_info));
} else if (type_id == ::arrow::Type::LIST) {
- auto list_field = field.storage_field;
auto child = &field.children[0];
std::unique_ptr<ColumnReaderImpl> child_reader;
RETURN_NOT_OK(GetReader(*child, ctx, &child_reader));
@@ -772,7 +778,7 @@ Status GetStorageReader(const SchemaField& field,
*out = nullptr;
return Status::OK();
}
- out->reset(new ListReader(ctx, list_field, field.level_info,
+ out->reset(new ListReader(ctx, arrow_field, field.level_info,
std::move(child_reader)));
} else if (type_id == ::arrow::Type::STRUCT) {
std::vector<std::shared_ptr<Field>> child_fields;
@@ -792,12 +798,12 @@ Status GetStorageReader(const SchemaField& field,
return Status::OK();
}
auto filtered_field =
- ::arrow::field(field.storage_field->name(), ::arrow::struct_(child_fields),
- field.storage_field->nullable(), field.storage_field->metadata());
+ ::arrow::field(arrow_field->name(), ::arrow::struct_(child_fields),
+ arrow_field->nullable(), arrow_field->metadata());
out->reset(new StructReader(ctx, filtered_field, field.level_info,
std::move(child_readers)));
} else {
- return Status::Invalid("Unsupported nested type: ", field.storage_field->ToString());
+ return Status::Invalid("Unsupported nested type: ", arrow_field->ToString());
}
return Status::OK();
@@ -806,11 +812,7 @@ Status GetStorageReader(const SchemaField& field,
Status GetReader(const SchemaField& field, const std::shared_ptr<ReaderContext>& ctx,
std::unique_ptr<ColumnReaderImpl>* out) {
- RETURN_NOT_OK(GetStorageReader(field, ctx, out));
- if (field.field->type()->id() == ::arrow::Type::EXTENSION) {
- out->reset(new ExtensionReader(field.field, std::move(*out)));
- }
- return Status::OK();
+ return GetReader(field, field.field, ctx, out);
}
} // namespace
diff --git a/cpp/src/parquet/arrow/schema.cc b/cpp/src/parquet/arrow/schema.cc
index 3f7ff922305..6babe9bc7cf 100644
--- a/cpp/src/parquet/arrow/schema.cc
+++ b/cpp/src/parquet/arrow/schema.cc
@@ -17,6 +17,7 @@
#include "parquet/arrow/schema.h"
+#include <functional>
#include <string>
#include <vector>
@@ -35,6 +36,7 @@
#include "parquet/types.h"
using arrow::Field;
+using arrow::FieldVector;
using arrow::KeyValueMetadata;
using arrow::Status;
using arrow::internal::checked_cast;
@@ -425,7 +427,6 @@ Status PopulateLeaf(int column_index, const std::shared_ptr& field,
LevelInfo current_levels, SchemaTreeContext* ctx,
const SchemaField* parent, SchemaField* out) {
out->field = field;
- out->storage_field = field;
out->column_index = column_index;
out->level_info = current_levels;
ctx->RecordLeaf(out);
@@ -462,7 +463,6 @@ Status GroupToStruct(const GroupNode& node, LevelInfo current_levels,
auto struct_type = ::arrow::struct_(arrow_fields);
out->field = ::arrow::field(node.name(), struct_type, node.is_optional(),
FieldIdMetadata(node.field_id()));
- out->storage_field = out->field;
out->level_info = current_levels;
return Status::OK();
}
@@ -543,7 +543,6 @@ Status ListToSchemaField(const GroupNode& group, LevelInfo current_levels,
}
out->field = ::arrow::field(group.name(), ::arrow::list(child_field->field),
group.is_optional(), FieldIdMetadata(group.field_id()));
- out->storage_field = out->field;
out->level_info = current_levels;
// At this point current levels contains the def level for this list,
// we need to reset to the prior parent.
@@ -571,7 +570,6 @@ Status GroupToSchemaField(const GroupNode& node, LevelInfo current_levels,
RETURN_NOT_OK(GroupToStruct(node, current_levels, ctx, out, &out->children[0]));
out->field = ::arrow::field(node.name(), ::arrow::list(out->children[0].field),
/*nullable=*/false, FieldIdMetadata(node.field_id()));
- out->storage_field = out->field;
ctx->LinkParent(&out->children[0], out);
out->level_info = current_levels;
@@ -623,7 +621,6 @@ Status NodeToSchemaField(const Node& node, LevelInfo current_levels,
out->field = ::arrow::field(node.name(), ::arrow::list(child_field),
/*nullable=*/false, FieldIdMetadata(node.field_id()));
- out->storage_field = out->field;
out->level_info = current_levels;
// At this point current_levels has consider this list the ancestor so restore
// the actual ancenstor.
@@ -689,10 +686,63 @@ Status GetOriginSchema(const std::shared_ptr& metadata,
// but that is not necessarily present in the field reconstitued from Parquet data
// (for example, Parquet timestamp types doesn't carry timezone information).
-Status ApplyOriginalStorageMetadata(const Field& origin_field, SchemaField* inferred) {
+Result<bool> ApplyOriginalMetadata(const Field& origin_field, SchemaField* inferred);
+
+std::function<std::shared_ptr<::arrow::DataType>(FieldVector)> GetNestedFactory(
+ const ArrowType& origin_type, const ArrowType& inferred_type) {
+ switch (inferred_type.id()) {
+ case ::arrow::Type::STRUCT:
+ if (origin_type.id() == ::arrow::Type::STRUCT) {
+ return ::arrow::struct_;
+ }
+ break;
+ case ::arrow::Type::LIST:
+ // TODO also allow LARGE_LIST and FIXED_SIZE_LIST
+ if (origin_type.id() == ::arrow::Type::LIST) {
+ return [](FieldVector fields) {
+ DCHECK_EQ(fields.size(), 1);
+ return ::arrow::list(std::move(fields[0]));
+ };
+ }
+ break;
+ default:
+ break;
+ }
+ return {};
+}
+
+Result<bool> ApplyOriginalStorageMetadata(const Field& origin_field,
+ SchemaField* inferred) {
+ bool modified = false;
+
auto origin_type = origin_field.type();
auto inferred_type = inferred->field->type();
+ const int num_children = inferred_type->num_fields();
+
+ if (num_children > 0 && origin_type->num_fields() == num_children) {
+ DCHECK_EQ(static_cast<int>(inferred->children.size()), num_children);
+ const auto factory = GetNestedFactory(*origin_type, *inferred_type);
+ if (factory) {
+ // Apply original metadata recursively to children
+ for (int i = 0; i < inferred_type->num_fields(); ++i) {
+ ARROW_ASSIGN_OR_RAISE(
+ const bool child_modified,
+ ApplyOriginalMetadata(*origin_type->field(i), &inferred->children[i]));
+ modified |= child_modified;
+ }
+ if (modified) {
+ // Recreate this field using the modified child fields
+ ::arrow::FieldVector modified_children(inferred_type->num_fields());
+ for (int i = 0; i < inferred_type->num_fields(); ++i) {
+ modified_children[i] = inferred->children[i].field;
+ }
+ inferred->field =
+ inferred->field->WithType(factory(std::move(modified_children)));
+ }
+ }
+ }
+
if (origin_type->id() == ::arrow::Type::TIMESTAMP &&
inferred_type->id() == ::arrow::Type::TIMESTAMP) {
// Restore time zone, if any
@@ -706,15 +756,19 @@ Status ApplyOriginalStorageMetadata(const Field& origin_field, SchemaField* infe
ts_origin_type.timezone() != "") {
inferred->field = inferred->field->WithType(origin_type);
}
+ modified = true;
}
if (origin_type->id() == ::arrow::Type::DICTIONARY &&
inferred_type->id() != ::arrow::Type::DICTIONARY &&
IsDictionaryReadSupported(*inferred_type)) {
+ // Direct dictionary reads are only supported for a couple primitive types,
+ // so no need to recurse on value types.
const auto& dict_origin_type =
checked_cast<const ::arrow::DictionaryType&>(*origin_type);
inferred->field = inferred->field->WithType(
::arrow::dictionary(::arrow::int32(), inferred_type, dict_origin_type.ordered()));
+ modified = true;
}
// Restore field metadata
@@ -725,23 +779,15 @@ Status ApplyOriginalStorageMetadata(const Field& origin_field, SchemaField* infe
field_metadata = field_metadata->Merge(*inferred->field->metadata());
}
inferred->field = inferred->field->WithMetadata(field_metadata);
+ modified = true;
}
- if (origin_type->id() == ::arrow::Type::EXTENSION) {
- // Restore extension type, if the storage type is as read from Parquet
- const auto& ex_type = checked_cast<const ::arrow::ExtensionType&>(*origin_type);
- if (ex_type.storage_type()->Equals(*inferred_type)) {
- inferred->field = inferred->field->WithType(origin_type);
- }
- }
-
- // TODO Should apply metadata recursively to children, but for that we need
- // to move metadata application inside NodeToSchemaField (ARROW-9943)
-
- return Status::OK();
+ return modified;
}
-Status ApplyOriginalMetadata(const Field& origin_field, SchemaField* inferred) {
+Result<bool> ApplyOriginalMetadata(const Field& origin_field, SchemaField* inferred) {
+ bool modified = false;
+
auto origin_type = origin_field.type();
auto inferred_type = inferred->field->type();
@@ -751,19 +797,18 @@ Status ApplyOriginalMetadata(const Field& origin_field, SchemaField* inferred) {
// Apply metadata recursively to storage type
RETURN_NOT_OK(ApplyOriginalStorageMetadata(*origin_storage_field, inferred));
- inferred->storage_field = inferred->field;
// Restore extension type, if the storage type is the same as inferred
// from the Parquet type
- if (ex_type.storage_type()->Equals(*inferred_type)) {
+ if (ex_type.storage_type()->Equals(*inferred->field->type())) {
inferred->field = inferred->field->WithType(origin_type);
}
+ modified = true;
} else {
- RETURN_NOT_OK(ApplyOriginalStorageMetadata(origin_field, inferred));
- inferred->storage_field = inferred->field;
+ ARROW_ASSIGN_OR_RAISE(modified, ApplyOriginalStorageMetadata(origin_field, inferred));
}
- return Status::OK();
+ return modified;
}
} // namespace
diff --git a/cpp/src/parquet/arrow/schema.h b/cpp/src/parquet/arrow/schema.h
index 5d7fe349945..dd60fde4342 100644
--- a/cpp/src/parquet/arrow/schema.h
+++ b/cpp/src/parquet/arrow/schema.h
@@ -89,9 +89,6 @@ ::arrow::Status FromParquetSchema(const SchemaDescriptor* parquet_schema,
/// \brief Bridge between an arrow::Field and parquet column indices.
struct PARQUET_EXPORT SchemaField {
std::shared_ptr<::arrow::Field> field;
- // If field has an extension type, an equivalent field with the storage type,
- // otherwise the field itself.
- std::shared_ptr<::arrow::Field> storage_field;
std::vector children;
// Only set for leaf nodes
diff --git a/python/pyarrow/tests/test_extension_type.py b/python/pyarrow/tests/test_extension_type.py
index 98e169c83a6..e6c7da1721e 100644
--- a/python/pyarrow/tests/test_extension_type.py
+++ b/python/pyarrow/tests/test_extension_type.py
@@ -68,13 +68,12 @@ def __reduce__(self):
class MyListType(pa.PyExtensionType):
- storage_type = pa.list_(pa.int32())
- def __init__(self):
- pa.PyExtensionType.__init__(self, self.storage_type)
+ def __init__(self, storage_type):
+ pa.PyExtensionType.__init__(self, storage_type)
def __reduce__(self):
- return MyListType, ()
+ return MyListType, (self.storage_type,)
def ipc_write_batch(batch):
@@ -503,7 +502,7 @@ def test_parquet_period(tmpdir, registered_period_type):
@pytest.mark.parquet
-def test_parquet_nested_storage(tmpdir):
+def test_parquet_extension_with_nested_storage(tmpdir):
# Parquet support for extension types with nested storage type
import pyarrow.parquet as pq
@@ -514,8 +513,8 @@ def test_parquet_nested_storage(tmpdir):
mystruct_array = pa.ExtensionArray.from_storage(MyStructType(),
struct_array)
- mylist_array = pa.ExtensionArray.from_storage(MyListType(),
- list_array)
+ mylist_array = pa.ExtensionArray.from_storage(
+ MyListType(list_array.type), list_array)
orig_table = pa.table({'structs': mystruct_array,
'lists': mylist_array})
@@ -528,7 +527,6 @@ def test_parquet_nested_storage(tmpdir):
assert table == orig_table
-@pytest.mark.xfail(reason="Need recursive metadata application (ARROW-9943)")
@pytest.mark.parquet
def test_parquet_nested_extension(tmpdir):
# Parquet support for extension types nested in struct or list
@@ -537,18 +535,54 @@ def test_parquet_nested_extension(tmpdir):
ext_type = IntegerType()
storage = pa.array([4, 5, 6, 7], type=pa.int64())
ext_array = pa.ExtensionArray.from_storage(ext_type, storage)
+
+ # Struct of extensions
struct_array = pa.StructArray.from_arrays(
[storage, ext_array],
names=['ints', 'exts'])
orig_table = pa.table({'structs': struct_array})
- filename = tmpdir / 'nested_extension_type.parquet'
+ filename = tmpdir / 'struct_of_ext.parquet'
pq.write_table(orig_table, filename)
table = pq.read_table(filename)
assert table.column(0).type == struct_array.type
assert table == orig_table
+ # List of extensions
+ list_array = pa.ListArray.from_arrays([0, 1, None, 3], ext_array)
+
+ orig_table = pa.table({'lists': list_array})
+ filename = tmpdir / 'list_of_ext.parquet'
+ pq.write_table(orig_table, filename)
+
+ table = pq.read_table(filename)
+ assert table.column(0).type == list_array.type
+ assert table == orig_table
+
+
+@pytest.mark.parquet
+def test_parquet_extension_nested_in_extension(tmpdir):
+ # Parquet support for extension<list<extension>>
+ import pyarrow.parquet as pq
+
+ inner_ext_type = IntegerType()
+ inner_storage = pa.array([4, 5, 6, 7], type=pa.int64())
+ inner_ext_array = pa.ExtensionArray.from_storage(inner_ext_type,
+ inner_storage)
+
+ list_array = pa.ListArray.from_arrays([0, 1, None, 3], inner_ext_array)
+ mylist_array = pa.ExtensionArray.from_storage(
+ MyListType(list_array.type), list_array)
+
+ orig_table = pa.table({'lists': mylist_array})
+ filename = tmpdir / 'ext_of_list_of_ext.parquet'
+ pq.write_table(orig_table, filename)
+
+ table = pq.read_table(filename)
+ assert table.column(0).type == mylist_array.type
+ assert table == orig_table
+
def test_to_numpy():
period_type = PeriodType('D')