From d71deda1093cb18438bafc78824d049d7aaf5a82 Mon Sep 17 00:00:00 2001
From: David Li
Date: Fri, 5 Nov 2021 16:04:29 -0400
Subject: [PATCH 1/3] ARROW-14615: [C++] Refactor nested field refs and add
union support
---
cpp/src/arrow/array/array_nested.cc | 36 +++++
cpp/src/arrow/array/array_nested.h | 8 +
cpp/src/arrow/array/array_union_test.cc | 49 ++++++
cpp/src/arrow/compute/api_scalar.cc | 8 +
cpp/src/arrow/compute/api_scalar.h | 12 ++
cpp/src/arrow/compute/exec/expression.cc | 28 +---
.../arrow/compute/kernels/scalar_nested.cc | 150 ++++++++++++++++++
.../compute/kernels/scalar_nested_test.cc | 103 ++++++++++++
docs/source/cpp/compute.rst | 43 +++--
docs/source/python/api/compute.rst | 1 +
10 files changed, 406 insertions(+), 32 deletions(-)
diff --git a/cpp/src/arrow/array/array_nested.cc b/cpp/src/arrow/array/array_nested.cc
index 2b4006961c7..b954b4cdd64 100644
--- a/cpp/src/arrow/array/array_nested.cc
+++ b/cpp/src/arrow/array/array_nested.cc
@@ -35,6 +35,7 @@
#include "arrow/type_traits.h"
#include "arrow/util/atomic_shared_ptr.h"
#include "arrow/util/bit_util.h"
+#include "arrow/util/bitmap_generate.h"
#include "arrow/util/bitmap_ops.h"
#include "arrow/util/checked_cast.h"
#include "arrow/util/logging.h"
@@ -650,6 +651,41 @@ SparseUnionArray::SparseUnionArray(std::shared_ptr type, int64_t lengt
SetData(std::move(internal_data));
}
+Result> SparseUnionArray::GetFlattenedField(
+ int index, MemoryPool* pool) const {
+ auto child_data = data_->child_data[index]->Copy();
+ // Adjust the result offset/length to be absolute.
+ if (data_->offset != 0 || data_->length != child_data->length) {
+ child_data = child_data->Slice(data_->offset, data_->length);
+ }
+ std::shared_ptr child_null_bitmap = child_data->buffers[0];
+ const int64_t child_offset = child_data->offset;
+
+ // Synthesize a null bitmap based on the union discriminant.
+ // Make sure the bitmap has extra bits corresponding to the child offset.
+ ARROW_ASSIGN_OR_RAISE(std::shared_ptr flattened_null_bitmap,
+ AllocateEmptyBitmap(child_data->length + child_offset, pool));
+ const int8_t type_code = union_type()->type_codes()[index];
+ const int8_t* type_codes = raw_type_codes();
+ int64_t offset = 0;
+ internal::GenerateBitsUnrolled(flattened_null_bitmap->mutable_data(), child_offset,
+ data_->length,
+ [&] { return type_codes[offset++] == type_code; });
+
+ // The validity of a flattened datum is the logical AND of the synthesized
+ // null bitmap buffer and the individual field element's validity.
+ if (child_null_bitmap) {
+ BitmapAnd(flattened_null_bitmap->data(), child_offset, child_null_bitmap->data(),
+ child_offset, child_data->length, child_offset,
+ flattened_null_bitmap->mutable_data());
+ }
+
+ auto flattened_data = child_data->Copy();
+ flattened_data->buffers[0] = std::move(flattened_null_bitmap);
+ flattened_data->null_count = kUnknownNullCount;
+ return MakeArray(flattened_data);
+}
+
DenseUnionArray::DenseUnionArray(const std::shared_ptr& data) {
SetData(data);
}
diff --git a/cpp/src/arrow/array/array_nested.h b/cpp/src/arrow/array/array_nested.h
index 178a0589d5a..b89680b79c0 100644
--- a/cpp/src/arrow/array/array_nested.h
+++ b/cpp/src/arrow/array/array_nested.h
@@ -464,6 +464,14 @@ class ARROW_EXPORT SparseUnionArray : public UnionArray {
return internal::checked_cast(union_type_);
}
+ /// \brief Get one of the child arrays, adjusting its null bitmap
+ /// where the union array type code does not match.
+ ///
+ /// \param[in] index Which child array to get (i.e. the physical index, not the type
+ /// code) \param[in] pool The pool to allocate null bitmaps from, if necessary
+ Result> GetFlattenedField(
+ int index, MemoryPool* pool = default_memory_pool()) const;
+
protected:
void SetData(std::shared_ptr data);
};
diff --git a/cpp/src/arrow/array/array_union_test.cc b/cpp/src/arrow/array/array_union_test.cc
index d3afe40df8d..2aeccbed31d 100644
--- a/cpp/src/arrow/array/array_union_test.cc
+++ b/cpp/src/arrow/array/array_union_test.cc
@@ -68,6 +68,55 @@ TEST(TestUnionArray, TestSliceEquals) {
CheckUnion(batch->column(1));
}
+TEST(TestSparseUnionArray, GetFlattenedField) {
+ auto ty = sparse_union({field("ints", int64()), field("strs", utf8())}, {2, 7});
+ auto ints = ArrayFromJSON(int64(), "[0, 1, 2, 3]");
+ auto strs = ArrayFromJSON(utf8(), R"(["a", "b", "c", "d"])");
+ auto ids = ArrayFromJSON(int8(), "[2, 7, 2, 7]")->data()->buffers[1];
+ const int length = 4;
+
+ {
+ SparseUnionArray arr(ty, length, {ints, strs}, ids);
+ ASSERT_OK(arr.ValidateFull());
+
+ ASSERT_OK_AND_ASSIGN(auto flattened, arr.GetFlattenedField(0));
+ AssertArraysEqual(*ArrayFromJSON(int64(), "[0, null, 0, null]"), *flattened,
+ /*verbose=*/true);
+
+ ASSERT_OK_AND_ASSIGN(flattened, arr.GetFlattenedField(1));
+ AssertArraysEqual(*ArrayFromJSON(utf8(), R"([null, "b", null, "d"])"), *flattened,
+ /*verbose=*/true);
+
+ const auto& sliced = checked_cast(*arr.Slice(1, 2));
+
+ ASSERT_OK_AND_ASSIGN(flattened, sliced.GetFlattenedField(0));
+ AssertArraysEqual(*ArrayFromJSON(int64(), "[null, 0]"), *flattened, /*verbose=*/true);
+
+ ASSERT_OK_AND_ASSIGN(flattened, sliced.GetFlattenedField(1));
+ AssertArraysEqual(*ArrayFromJSON(utf8(), R"(["b", null])"), *flattened,
+ /*verbose=*/true);
+ }
+ {
+ SparseUnionArray arr(ty, length - 2, {ints->Slice(1, 2), strs->Slice(1, 2)}, ids);
+ ASSERT_OK(arr.ValidateFull());
+
+ ASSERT_OK_AND_ASSIGN(auto flattened, arr.GetFlattenedField(0));
+ AssertArraysEqual(*ArrayFromJSON(int64(), "[null, 0]"), *flattened, /*verbose=*/true);
+
+ ASSERT_OK_AND_ASSIGN(flattened, arr.GetFlattenedField(1));
+ AssertArraysEqual(*ArrayFromJSON(utf8(), R"(["b", null])"), *flattened,
+ /*verbose=*/true);
+
+ const auto& sliced = checked_cast(*arr.Slice(1, 1));
+
+ ASSERT_OK_AND_ASSIGN(flattened, sliced.GetFlattenedField(0));
+ AssertArraysEqual(*ArrayFromJSON(int64(), "[0]"), *flattened, /*verbose=*/true);
+
+ ASSERT_OK_AND_ASSIGN(flattened, sliced.GetFlattenedField(1));
+ AssertArraysEqual(*ArrayFromJSON(utf8(), R"([null])"), *flattened, /*verbose=*/true);
+ }
+}
+
TEST(TestSparseUnionArray, Validate) {
auto a = ArrayFromJSON(int32(), "[4, 5]");
auto type = sparse_union({field("a", int32())});
diff --git a/cpp/src/arrow/compute/api_scalar.cc b/cpp/src/arrow/compute/api_scalar.cc
index e3fe1bdf73d..44aaa52a8f6 100644
--- a/cpp/src/arrow/compute/api_scalar.cc
+++ b/cpp/src/arrow/compute/api_scalar.cc
@@ -223,6 +223,8 @@ static auto kExtractRegexOptionsType = GetFunctionOptionsType(
DataMember("value_set", &SetLookupOptions::value_set),
DataMember("skip_nulls", &SetLookupOptions::skip_nulls));
+static auto kStructFieldOptionsType = GetFunctionOptionsType(
+ DataMember("indices", &StructFieldOptions::indices));
static auto kStrptimeOptionsType = GetFunctionOptionsType(
DataMember("format", &StrptimeOptions::format),
DataMember("unit", &StrptimeOptions::unit));
@@ -351,6 +353,11 @@ SetLookupOptions::SetLookupOptions(Datum value_set, bool skip_nulls)
SetLookupOptions::SetLookupOptions() : SetLookupOptions({}, false) {}
constexpr char SetLookupOptions::kTypeName[];
+StructFieldOptions::StructFieldOptions(std::vector indices)
+ : FunctionOptions(internal::kStructFieldOptionsType), indices(std::move(indices)) {}
+StructFieldOptions::StructFieldOptions() : StructFieldOptions(std::vector()) {}
+constexpr char StructFieldOptions::kTypeName[];
+
StrptimeOptions::StrptimeOptions(std::string format, TimeUnit::type unit)
: FunctionOptions(internal::kStrptimeOptionsType),
format(std::move(format)),
@@ -444,6 +451,7 @@ void RegisterScalarOptions(FunctionRegistry* registry) {
DCHECK_OK(registry->AddFunctionOptionsType(kReplaceSubstringOptionsType));
DCHECK_OK(registry->AddFunctionOptionsType(kExtractRegexOptionsType));
DCHECK_OK(registry->AddFunctionOptionsType(kSetLookupOptionsType));
+ DCHECK_OK(registry->AddFunctionOptionsType(kStructFieldOptionsType));
DCHECK_OK(registry->AddFunctionOptionsType(kStrptimeOptionsType));
DCHECK_OK(registry->AddFunctionOptionsType(kStrftimeOptionsType));
DCHECK_OK(registry->AddFunctionOptionsType(kAssumeTimezoneOptionsType));
diff --git a/cpp/src/arrow/compute/api_scalar.h b/cpp/src/arrow/compute/api_scalar.h
index 4bb18b37527..d2234a6182d 100644
--- a/cpp/src/arrow/compute/api_scalar.h
+++ b/cpp/src/arrow/compute/api_scalar.h
@@ -223,6 +223,18 @@ class ARROW_EXPORT SetLookupOptions : public FunctionOptions {
bool skip_nulls;
};
+/// Options for struct_field function
+class ARROW_EXPORT StructFieldOptions : public FunctionOptions {
+ public:
+ explicit StructFieldOptions(std::vector indices);
+ StructFieldOptions();
+ constexpr static char const kTypeName[] = "StructFieldOptions";
+
+ /// The child indices to extract. For instance, to get the 2nd child
+ /// of the 1st child of a struct or union, this would be {0, 1}.
+ std::vector indices;
+};
+
class ARROW_EXPORT StrptimeOptions : public FunctionOptions {
public:
explicit StrptimeOptions(std::string format, TimeUnit::type unit);
diff --git a/cpp/src/arrow/compute/exec/expression.cc b/cpp/src/arrow/compute/exec/expression.cc
index 03db24b5413..4249179e1bf 100644
--- a/cpp/src/arrow/compute/exec/expression.cc
+++ b/cpp/src/arrow/compute/exec/expression.cc
@@ -510,29 +510,11 @@ Result ExecuteScalarExpression(const Expression& expr, const ExecBatch& i
}
Datum field = input[param->indices[0]];
- for (auto it = param->indices.begin() + 1; it != param->indices.end(); ++it) {
- if (field.type()->id() != Type::STRUCT) {
- return Status::Invalid("Nested field reference into a non-struct: ",
- *field.type());
- }
- const int index = *it;
- if (index < 0 || index >= field.type()->num_fields()) {
- return Status::Invalid("Out of bounds field reference: ", index, " but type has ",
- field.type()->num_fields(), " fields");
- }
- if (field.is_scalar()) {
- const auto& struct_scalar = field.scalar_as();
- if (!struct_scalar.is_valid) {
- return MakeNullScalar(param->descr.type);
- }
- field = struct_scalar.value[index];
- } else if (field.is_array()) {
- const auto& struct_array = field.array_as();
- ARROW_ASSIGN_OR_RAISE(
- field, struct_array->GetFlattenedField(index, exec_context->memory_pool()));
- } else {
- return Status::NotImplemented("Nested field reference into a ", field.ToString());
- }
+ if (param->indices.size() > 1) {
+ std::vector indices(param->indices.begin() + 1, param->indices.end());
+ compute::StructFieldOptions options(std::move(indices));
+ ARROW_ASSIGN_OR_RAISE(
+ field, compute::CallFunction("struct_field", {std::move(field)}, &options));
}
if (!field.type()->Equals(param->descr.type)) {
return Status::Invalid("Referenced field ", expr.ToString(), " was ",
diff --git a/cpp/src/arrow/compute/kernels/scalar_nested.cc b/cpp/src/arrow/compute/kernels/scalar_nested.cc
index aeac0d747b1..330ea5120e5 100644
--- a/cpp/src/arrow/compute/kernels/scalar_nested.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_nested.cc
@@ -22,6 +22,7 @@
#include "arrow/compute/kernels/common.h"
#include "arrow/result.h"
#include "arrow/util/bit_block_counter.h"
+#include "arrow/util/bitmap_generate.h"
namespace arrow {
namespace compute {
@@ -187,6 +188,150 @@ const FunctionDoc list_element_doc(
"is emitted. Null values emit a null in the output."),
{"lists", "index"});
+struct StructFieldFunctor {
+ static Status ExecArray(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ const auto& options = OptionsWrapper::Get(ctx);
+ std::shared_ptr current = batch[0].make_array();
+ for (const auto& index : options.indices) {
+ RETURN_NOT_OK(CheckIndex(index, *current->type()));
+ switch (current->type()->id()) {
+ case Type::STRUCT: {
+ const auto& struct_array = checked_cast(*current);
+ ARROW_ASSIGN_OR_RAISE(
+ current, struct_array.GetFlattenedField(index, ctx->memory_pool()));
+ break;
+ }
+ case Type::DENSE_UNION: {
+ // We implement this here instead of in DenseUnionArray since it's
+ // easiest to do via Take(), but DenseUnionArray can't rely on
+ // arrow::compute. See ARROW-8891.
+ const auto& union_array = checked_cast(*current);
+
+ // Generate a bitmap for the offsets buffer based on the type codes buffer.
+ ARROW_ASSIGN_OR_RAISE(
+ std::shared_ptr take_bitmap,
+ ctx->AllocateBitmap(union_array.length() + union_array.offset()));
+ const int8_t* type_codes = union_array.raw_type_codes();
+ const int8_t type_code = union_array.union_type()->type_codes()[index];
+ int64_t offset = 0;
+ arrow::internal::GenerateBitsUnrolled(
+ take_bitmap->mutable_data(), union_array.offset(), union_array.length(),
+ [&] { return type_codes[offset++] == type_code; });
+
+ // Pass the combined buffer to Take().
+ Datum take_indices(
+ ArrayData(int32(), union_array.length(),
+ {std::move(take_bitmap), union_array.value_offsets()},
+ kUnknownNullCount, union_array.offset()));
+ // Do not slice the child since the indices are relative to the unsliced array.
+ ARROW_ASSIGN_OR_RAISE(
+ Datum result,
+ CallFunction("take", {union_array.field(index), std::move(take_indices)}));
+ current = result.make_array();
+ break;
+ }
+ case Type::SPARSE_UNION: {
+ const auto& union_array = checked_cast(*current);
+ ARROW_ASSIGN_OR_RAISE(current,
+ union_array.GetFlattenedField(index, ctx->memory_pool()));
+ break;
+ }
+ default:
+ // Should have been checked in ResolveStructFieldType
+ return Status::Invalid("struct_field: cannot reference child field of type ",
+ *current->type());
+ }
+ }
+ *out = current;
+ return Status::OK();
+ }
+
+ static Status ExecScalar(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ const auto& options = OptionsWrapper::Get(ctx);
+ const std::shared_ptr* current = &batch[0].scalar();
+ for (const auto& index : options.indices) {
+ RETURN_NOT_OK(CheckIndex(index, *(*current)->type));
+ if (!(*current)->is_valid) {
+ // out should already be a null scalar of the appropriate type
+ return Status::OK();
+ }
+
+ switch ((*current)->type->id()) {
+ case Type::STRUCT: {
+ current = &checked_cast(**current).value[index];
+ break;
+ }
+ case Type::DENSE_UNION:
+ case Type::SPARSE_UNION: {
+ const auto& union_scalar = checked_cast(**current);
+ const auto& union_ty = checked_cast(*(*current)->type);
+ if (union_scalar.type_code != union_ty.type_codes()[index]) {
+ // out should already be a null scalar of the appropriate type
+ return Status::OK();
+ }
+ current = &union_scalar.value;
+ break;
+ }
+ default:
+ // Should have been checked in ResolveStructFieldType
+ return Status::Invalid("struct_field: cannot reference child field of type ",
+ *(*current)->type);
+ }
+ }
+ *out = *current;
+ return Status::OK();
+ }
+
+ static Status CheckIndex(int index, const DataType& type) {
+ if (!ValidParentType(type)) {
+ return Status::Invalid("struct_field: cannot subscript field of type ", type);
+ } else if (index < 0 || index > type.num_fields()) {
+ return Status::Invalid("struct_field: out-of-bounds field reference to field ",
+ index, " in type ", type, " with ", type.num_fields(),
+ " fields");
+ }
+ return Status::OK();
+ }
+
+ static bool ValidParentType(const DataType& type) {
+ return type.id() == Type::STRUCT || type.id() == Type::DENSE_UNION ||
+ type.id() == Type::SPARSE_UNION;
+ }
+};
+
+Result ResolveStructFieldType(KernelContext* ctx,
+ const std::vector& descrs) {
+ const auto& options = OptionsWrapper::Get(ctx);
+ const std::shared_ptr* type = &descrs.front().type;
+ for (const auto& index : options.indices) {
+ RETURN_NOT_OK(StructFieldFunctor::CheckIndex(index, **type));
+ type = &(*type)->field(index)->type();
+ }
+ return ValueDescr(*type, descrs.front().shape);
+}
+
+void AddStructFieldKernels(ScalarFunction* func) {
+ for (const auto shape : {ValueDescr::ARRAY, ValueDescr::SCALAR}) {
+ for (const auto in_type : {Type::STRUCT, Type::DENSE_UNION, Type::SPARSE_UNION}) {
+ ScalarKernel kernel({InputType(in_type, shape)}, OutputType(ResolveStructFieldType),
+ shape == ValueDescr::ARRAY ? StructFieldFunctor::ExecArray
+ : StructFieldFunctor::ExecScalar,
+ OptionsWrapper::Init);
+ kernel.null_handling = NullHandling::COMPUTED_NO_PREALLOCATE;
+ kernel.mem_allocation = MemAllocation::NO_PREALLOCATE;
+ DCHECK_OK(func->AddKernel(std::move(kernel)));
+ }
+ }
+}
+
+const FunctionDoc struct_field_doc(
+ "Extract children of a struct or union value by index.",
+ ("Given a series of indices, extract the child array or scalar referenced "
+ "by the index. For union values, mask the child based on the type codes "
+ "of the union array. The indices are always the child index and not the "
+ "type code (for unions) - so the first child is always index 0."),
+ {"container"}, "StructFieldOptions");
+
Result MakeStructResolve(KernelContext* ctx,
const std::vector& descrs) {
auto names = OptionsWrapper::Get(ctx).field_names;
@@ -298,6 +443,11 @@ void RegisterScalarNested(FunctionRegistry* registry) {
AddListElementScalarKernels(list_element.get());
DCHECK_OK(registry->AddFunction(std::move(list_element)));
+ auto struct_field =
+ std::make_shared("struct_field", Arity::Unary(), &struct_field_doc);
+ AddStructFieldKernels(struct_field.get());
+ DCHECK_OK(registry->AddFunction(std::move(struct_field)));
+
static MakeStructOptions kDefaultMakeStructOptions;
auto make_struct_function = std::make_shared(
"make_struct", Arity::VarArgs(), &make_struct_doc, &kDefaultMakeStructOptions);
diff --git a/cpp/src/arrow/compute/kernels/scalar_nested_test.cc b/cpp/src/arrow/compute/kernels/scalar_nested_test.cc
index cb16257399d..5733ef81293 100644
--- a/cpp/src/arrow/compute/kernels/scalar_nested_test.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_nested_test.cc
@@ -107,6 +107,109 @@ TEST(TestScalarNested, ListElementInvalid) {
Raises(StatusCode::Invalid));
}
+TEST(TestScalarNested, StructField) {
+ StructFieldOptions trivial;
+ StructFieldOptions extract0({0});
+ StructFieldOptions extract20({2, 0});
+ StructFieldOptions invalid1({-1});
+ StructFieldOptions invalid2({2, 4});
+ StructFieldOptions invalid3({0, 1});
+ FieldVector fields = {field("a", int32()), field("b", utf8()),
+ field("c", struct_({
+ field("d", int64()),
+ field("e", float64()),
+ }))};
+ {
+ auto arr = ArrayFromJSON(struct_(fields), R"([
+ [1, "a", [10, 10.0]],
+ [null, "b", [11, 11.0]],
+ [3, null, [12, 12.0]],
+ null
+ ])");
+ CheckScalar("struct_field", {arr}, arr, &trivial);
+ CheckScalar("struct_field", {arr}, ArrayFromJSON(int32(), "[1, null, 3, null]"),
+ &extract0);
+ CheckScalar("struct_field", {arr}, ArrayFromJSON(int64(), "[10, 11, 12, null]"),
+ &extract20);
+ EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid,
+ ::testing::HasSubstr("out-of-bounds field reference"),
+ CallFunction("struct_field", {arr}, &invalid1));
+ EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid,
+ ::testing::HasSubstr("out-of-bounds field reference"),
+ CallFunction("struct_field", {arr}, &invalid2));
+ EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, ::testing::HasSubstr("cannot subscript"),
+ CallFunction("struct_field", {arr}, &invalid3));
+ }
+ {
+ auto ty = dense_union(fields, {2, 5, 8});
+ auto arr = ArrayFromJSON(ty, R"([
+ [2, 1],
+ [5, "foo"],
+ [8, null],
+ [8, [10, 10.0]]
+ ])");
+ CheckScalar("struct_field", {arr}, arr, &trivial);
+ CheckScalar("struct_field", {arr}, ArrayFromJSON(int32(), "[1, null, null, null]"),
+ &extract0);
+ CheckScalar("struct_field", {arr}, ArrayFromJSON(int64(), "[null, null, null, 10]"),
+ &extract20);
+ EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid,
+ ::testing::HasSubstr("out-of-bounds field reference"),
+ CallFunction("struct_field", {arr}, &invalid1));
+ EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid,
+ ::testing::HasSubstr("out-of-bounds field reference"),
+ CallFunction("struct_field", {arr}, &invalid2));
+ EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, ::testing::HasSubstr("cannot subscript"),
+ CallFunction("struct_field", {arr}, &invalid3));
+
+ // Test edge cases for union representation
+ auto ints = ArrayFromJSON(fields[0]->type(), "[null, 2, 3]");
+ auto strs = ArrayFromJSON(fields[1]->type(), R"([null, "bar"])");
+ auto nested = ArrayFromJSON(fields[2]->type(), R"([null, [10, 10.0]])");
+ auto type_ids = ArrayFromJSON(int8(), "[2, 5, 8, 2, 5, 8]")->data()->buffers[1];
+ auto offsets = ArrayFromJSON(int32(), "[0, 0, 0, 1, 1, 1]")->data()->buffers[1];
+
+ arr = std::make_shared(ty, /*length=*/6,
+ ArrayVector{ints, strs, nested}, type_ids,
+ offsets, /*offset=*/0);
+ // Sliced parent
+ CheckScalar("struct_field", {arr->Slice(3, 3)},
+ ArrayFromJSON(int32(), "[2, null, null]"), &extract0);
+ // Sliced child
+ arr = std::make_shared(ty, /*length=*/6,
+ ArrayVector{ints->Slice(1, 2), strs, nested},
+ type_ids, offsets, /*offset=*/0);
+ CheckScalar("struct_field", {arr},
+ ArrayFromJSON(int32(), "[2, null, null, 3, null, null]"), &extract0);
+ // Sliced parent + sliced child
+ CheckScalar("struct_field", {arr->Slice(3, 3)},
+ ArrayFromJSON(int32(), "[3, null, null]"), &extract0);
+ }
+ {
+ // The underlying implementation is tested directly/more thoroughly in
+ // array_union_test.cc.
+ auto arr = ArrayFromJSON(sparse_union(fields, {2, 5, 8}), R"([
+ [2, 1],
+ [5, "foo"],
+ [8, null],
+ [8, [10, 10.0]]
+ ])");
+ CheckScalar("struct_field", {arr}, arr, &trivial);
+ CheckScalar("struct_field", {arr}, ArrayFromJSON(int32(), "[1, null, null, null]"),
+ &extract0);
+ CheckScalar("struct_field", {arr}, ArrayFromJSON(int64(), "[null, null, null, 10]"),
+ &extract20);
+ EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid,
+ ::testing::HasSubstr("out-of-bounds field reference"),
+ CallFunction("struct_field", {arr}, &invalid1));
+ EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid,
+ ::testing::HasSubstr("out-of-bounds field reference"),
+ CallFunction("struct_field", {arr}, &invalid2));
+ EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, ::testing::HasSubstr("cannot subscript"),
+ CallFunction("struct_field", {arr}, &invalid3));
+ }
+}
+
struct {
Result operator()(std::vector args) {
return CallFunction("make_struct", args);
diff --git a/docs/source/cpp/compute.rst b/docs/source/cpp/compute.rst
index 824170481a3..455a1bc0e09 100644
--- a/docs/source/cpp/compute.rst
+++ b/docs/source/cpp/compute.rst
@@ -1587,15 +1587,17 @@ in the respective option classes.
Structural transforms
~~~~~~~~~~~~~~~~~~~~~
-+---------------------+------------+-------------------------------------+------------------+--------+
-| Function name | Arity | Input types | Output type | Notes |
-+=====================+============+=====================================+==================+========+
-| list_element | Binary | List-like (Arg 0), Integral (Arg 1) | List value type | \(1) |
-+---------------------+------------+-------------------------------------+------------------+--------+
-| list_flatten | Unary | List-like | List value type | \(2) |
-+---------------------+------------+-------------------------------------+------------------+--------+
-| list_parent_indices | Unary | List-like | Int32 or Int64 | \(3) |
-+---------------------+------------+-------------------------------------+------------------+--------+
++---------------------+------------+-------------------------------------+------------------+------------------------------+--------+
+| Function name | Arity | Input types | Output type | Options class | Notes |
++=====================+============+=====================================+==================+==============================+========+
+| list_element | Binary | List-like (Arg 0), Integral (Arg 1) | List value type | | \(1) |
++---------------------+------------+-------------------------------------+------------------+------------------------------+--------+
+| list_flatten | Unary | List-like | List value type | | \(2) |
++---------------------+------------+-------------------------------------+------------------+------------------------------+--------+
+| list_parent_indices | Unary | List-like | Int32 or Int64 | | \(3) |
++---------------------+------------+-------------------------------------+------------------+------------------------------+--------+
+| struct_field | Unary | Struct or Union | Computed | :struct:`StructFieldOptions` | \(4) |
++---------------------+------------+-------------------------------------+------------------+------------------------------+--------+
* \(1) Output is an array of the same length as the input list array. The
output values are the values at the specified index of each child list.
@@ -1609,6 +1611,29 @@ Structural transforms
are discarded. Output type is Int32 for List and FixedSizeList, Int64 for
LargeList.
+* \(4) Extract a child value based on a sequence of indices passed in
+ the options. The validity bitmap of the result will be the
+ intersection of all intermediate validity bitmaps. For example, for
+ an array with type ``struct>``:
+
+ * An empty sequence of indices yields the original value unchanged.
+ * The index ``0`` yields an array of type ``int32`` whose validity
+ bitmap is the intersection of the bitmap for the outermost struct
+ and the bitmap for the child ``a``.
+ * The index ``1, 1`` yields an array of type ``float64`` whose
+ validity bitmap is the intersection of the bitmaps for the
+ outermost struct, for struct ``b``, and for the child ``d``.
+
+ For unions, a validity bitmap is synthesized based on the type
+ codes. Also, the index is always the child index and not a type code.
+ Hence for array with type ``sparse_union<2: int32, 7: utf8>``:
+
+ * The index ``0`` yields an array of type ``int32``, which is valid
+ at an index *n* if and only if the child array ``a`` is valid at
+ index *n* and the type code at index *n* is 2.
+ * The indices ``2`` and ``7`` are invalid.
+
These functions create a copy of the first input with some elements
replaced, based on the remaining inputs.
diff --git a/docs/source/python/api/compute.rst b/docs/source/python/api/compute.rst
index 225d853718f..00897a24983 100644
--- a/docs/source/python/api/compute.rst
+++ b/docs/source/python/api/compute.rst
@@ -497,3 +497,4 @@ Structural Transforms
list_value_length
make_struct
replace_with_mask
+ struct_field
From 745baa1b48f3a34dd016c5c3efc2dc2924d8a62e Mon Sep 17 00:00:00 2001
From: David Li
Date: Mon, 8 Nov 2021 13:35:13 -0500
Subject: [PATCH 2/3] ARROW-14615: [C++] Fix test case
---
cpp/src/arrow/array/array_nested.cc | 3 +++
cpp/src/arrow/array/array_union_test.cc | 34 ++++++++++++++-----------
2 files changed, 22 insertions(+), 15 deletions(-)
diff --git a/cpp/src/arrow/array/array_nested.cc b/cpp/src/arrow/array/array_nested.cc
index b954b4cdd64..d5768383e79 100644
--- a/cpp/src/arrow/array/array_nested.cc
+++ b/cpp/src/arrow/array/array_nested.cc
@@ -653,6 +653,9 @@ SparseUnionArray::SparseUnionArray(std::shared_ptr type, int64_t lengt
Result> SparseUnionArray::GetFlattenedField(
int index, MemoryPool* pool) const {
+ if (index < 0 || index >= num_fields()) {
+ return Status::Invalid("Index out of range: ", index);
+ }
auto child_data = data_->child_data[index]->Copy();
// Adjust the result offset/length to be absolute.
if (data_->offset != 0 || data_->length != child_data->length) {
diff --git a/cpp/src/arrow/array/array_union_test.cc b/cpp/src/arrow/array/array_union_test.cc
index 2aeccbed31d..22e1a56134f 100644
--- a/cpp/src/arrow/array/array_union_test.cc
+++ b/cpp/src/arrow/array/array_union_test.cc
@@ -32,6 +32,7 @@
namespace arrow {
using internal::checked_cast;
+using internal::checked_pointer_cast;
TEST(TestUnionArray, TestSliceEquals) {
std::shared_ptr batch;
@@ -71,7 +72,7 @@ TEST(TestUnionArray, TestSliceEquals) {
TEST(TestSparseUnionArray, GetFlattenedField) {
auto ty = sparse_union({field("ints", int64()), field("strs", utf8())}, {2, 7});
auto ints = ArrayFromJSON(int64(), "[0, 1, 2, 3]");
- auto strs = ArrayFromJSON(utf8(), R"(["a", "b", "c", "d"])");
+ auto strs = ArrayFromJSON(utf8(), R"(["a", null, "c", "d"])");
auto ids = ArrayFromJSON(int8(), "[2, 7, 2, 7]")->data()->buffers[1];
const int length = 4;
@@ -80,40 +81,43 @@ TEST(TestSparseUnionArray, GetFlattenedField) {
ASSERT_OK(arr.ValidateFull());
ASSERT_OK_AND_ASSIGN(auto flattened, arr.GetFlattenedField(0));
- AssertArraysEqual(*ArrayFromJSON(int64(), "[0, null, 0, null]"), *flattened,
+ AssertArraysEqual(*ArrayFromJSON(int64(), "[0, null, 2, null]"), *flattened,
/*verbose=*/true);
ASSERT_OK_AND_ASSIGN(flattened, arr.GetFlattenedField(1));
- AssertArraysEqual(*ArrayFromJSON(utf8(), R"([null, "b", null, "d"])"), *flattened,
+ AssertArraysEqual(*ArrayFromJSON(utf8(), R"([null, null, null, "d"])"), *flattened,
/*verbose=*/true);
- const auto& sliced = checked_cast(*arr.Slice(1, 2));
+ const auto sliced = checked_pointer_cast(arr.Slice(1, 2));
- ASSERT_OK_AND_ASSIGN(flattened, sliced.GetFlattenedField(0));
- AssertArraysEqual(*ArrayFromJSON(int64(), "[null, 0]"), *flattened, /*verbose=*/true);
+ ASSERT_OK_AND_ASSIGN(flattened, sliced->GetFlattenedField(0));
+ AssertArraysEqual(*ArrayFromJSON(int64(), "[null, 2]"), *flattened, /*verbose=*/true);
- ASSERT_OK_AND_ASSIGN(flattened, sliced.GetFlattenedField(1));
- AssertArraysEqual(*ArrayFromJSON(utf8(), R"(["b", null])"), *flattened,
+ ASSERT_OK_AND_ASSIGN(flattened, sliced->GetFlattenedField(1));
+ AssertArraysEqual(*ArrayFromJSON(utf8(), R"([null, null])"), *flattened,
/*verbose=*/true);
+
+ ASSERT_RAISES(Invalid, arr.GetFlattenedField(-1));
+ ASSERT_RAISES(Invalid, arr.GetFlattenedField(2));
}
{
SparseUnionArray arr(ty, length - 2, {ints->Slice(1, 2), strs->Slice(1, 2)}, ids);
ASSERT_OK(arr.ValidateFull());
ASSERT_OK_AND_ASSIGN(auto flattened, arr.GetFlattenedField(0));
- AssertArraysEqual(*ArrayFromJSON(int64(), "[null, 0]"), *flattened, /*verbose=*/true);
+ AssertArraysEqual(*ArrayFromJSON(int64(), "[1, null]"), *flattened, /*verbose=*/true);
ASSERT_OK_AND_ASSIGN(flattened, arr.GetFlattenedField(1));
- AssertArraysEqual(*ArrayFromJSON(utf8(), R"(["b", null])"), *flattened,
+ AssertArraysEqual(*ArrayFromJSON(utf8(), R"([null, "c"])"), *flattened,
/*verbose=*/true);
- const auto& sliced = checked_cast(*arr.Slice(1, 1));
+ const auto sliced = checked_pointer_cast(arr.Slice(1, 1));
- ASSERT_OK_AND_ASSIGN(flattened, sliced.GetFlattenedField(0));
- AssertArraysEqual(*ArrayFromJSON(int64(), "[0]"), *flattened, /*verbose=*/true);
+ ASSERT_OK_AND_ASSIGN(flattened, sliced->GetFlattenedField(0));
+ AssertArraysEqual(*ArrayFromJSON(int64(), "[null]"), *flattened, /*verbose=*/true);
- ASSERT_OK_AND_ASSIGN(flattened, sliced.GetFlattenedField(1));
- AssertArraysEqual(*ArrayFromJSON(utf8(), R"([null])"), *flattened, /*verbose=*/true);
+ ASSERT_OK_AND_ASSIGN(flattened, sliced->GetFlattenedField(1));
+ AssertArraysEqual(*ArrayFromJSON(utf8(), R"(["c"])"), *flattened, /*verbose=*/true);
}
}
From 6f82b430a701e5f83362b3118d2943c269b18552 Mon Sep 17 00:00:00 2001
From: David Li
Date: Tue, 9 Nov 2021 14:06:13 -0500
Subject: [PATCH 3/3] ARROW-14615: [C++] Address feedback
---
cpp/src/arrow/array/array_nested.cc | 7 +++----
cpp/src/arrow/array/array_union_test.cc | 12 +++++++++++
.../arrow/compute/kernels/scalar_nested.cc | 20 ++++++++++---------
.../compute/kernels/scalar_nested_test.cc | 11 +++++++---
4 files changed, 34 insertions(+), 16 deletions(-)
diff --git a/cpp/src/arrow/array/array_nested.cc b/cpp/src/arrow/array/array_nested.cc
index d5768383e79..a3c1fab054e 100644
--- a/cpp/src/arrow/array/array_nested.cc
+++ b/cpp/src/arrow/array/array_nested.cc
@@ -683,10 +683,9 @@ Result> SparseUnionArray::GetFlattenedField(
flattened_null_bitmap->mutable_data());
}
- auto flattened_data = child_data->Copy();
- flattened_data->buffers[0] = std::move(flattened_null_bitmap);
- flattened_data->null_count = kUnknownNullCount;
- return MakeArray(flattened_data);
+ child_data->buffers[0] = std::move(flattened_null_bitmap);
+ child_data->null_count = kUnknownNullCount;
+ return MakeArray(child_data);
}
DenseUnionArray::DenseUnionArray(const std::shared_ptr& data) {
diff --git a/cpp/src/arrow/array/array_union_test.cc b/cpp/src/arrow/array/array_union_test.cc
index 22e1a56134f..3bd87a3438f 100644
--- a/cpp/src/arrow/array/array_union_test.cc
+++ b/cpp/src/arrow/array/array_union_test.cc
@@ -119,6 +119,18 @@ TEST(TestSparseUnionArray, GetFlattenedField) {
ASSERT_OK_AND_ASSIGN(flattened, sliced->GetFlattenedField(1));
AssertArraysEqual(*ArrayFromJSON(utf8(), R"(["c"])"), *flattened, /*verbose=*/true);
}
+ {
+ SparseUnionArray arr(ty, /*length=*/0, {ints->Slice(length), strs->Slice(length)},
+ ids);
+ ASSERT_OK(arr.ValidateFull());
+
+ ASSERT_OK_AND_ASSIGN(auto flattened, arr.GetFlattenedField(0));
+ AssertArraysEqual(*ArrayFromJSON(int64(), "[]"), *flattened, /*verbose=*/true);
+
+ ASSERT_OK_AND_ASSIGN(flattened, arr.GetFlattenedField(1));
+ AssertArraysEqual(*ArrayFromJSON(utf8(), "[]"), *flattened,
+ /*verbose=*/true);
+ }
}
TEST(TestSparseUnionArray, Validate) {
diff --git a/cpp/src/arrow/compute/kernels/scalar_nested.cc b/cpp/src/arrow/compute/kernels/scalar_nested.cc
index 330ea5120e5..682f73632b2 100644
--- a/cpp/src/arrow/compute/kernels/scalar_nested.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_nested.cc
@@ -238,8 +238,8 @@ struct StructFieldFunctor {
}
default:
// Should have been checked in ResolveStructFieldType
- return Status::Invalid("struct_field: cannot reference child field of type ",
- *current->type());
+ return Status::TypeError("struct_field: cannot reference child field of type ",
+ *current->type());
}
}
*out = current;
@@ -274,8 +274,8 @@ struct StructFieldFunctor {
}
default:
// Should have been checked in ResolveStructFieldType
- return Status::Invalid("struct_field: cannot reference child field of type ",
- *(*current)->type);
+ return Status::TypeError("struct_field: cannot reference child field of type ",
+ *(*current)->type);
}
}
*out = *current;
@@ -284,7 +284,7 @@ struct StructFieldFunctor {
static Status CheckIndex(int index, const DataType& type) {
if (!ValidParentType(type)) {
- return Status::Invalid("struct_field: cannot subscript field of type ", type);
+ return Status::TypeError("struct_field: cannot subscript field of type ", type);
} else if (index < 0 || index > type.num_fields()) {
return Status::Invalid("struct_field: out-of-bounds field reference to field ",
index, " in type ", type, " with ", type.num_fields(),
@@ -326,10 +326,12 @@ void AddStructFieldKernels(ScalarFunction* func) {
const FunctionDoc struct_field_doc(
"Extract children of a struct or union value by index.",
- ("Given a series of indices, extract the child array or scalar referenced "
- "by the index. For union values, mask the child based on the type codes "
- "of the union array. The indices are always the child index and not the "
- "type code (for unions) - so the first child is always index 0."),
+ ("Given a series of indices (passed via StructFieldOptions), extract the "
+ "child array or scalar referenced by the index. For union values, mask "
+ "the child based on the type codes of the union array. The indices are "
+ "always the child index and not the type code (for unions) - so the "
+ "first child is always index 0. An empty set of indices returns the "
+ "argument unchanged."),
{"container"}, "StructFieldOptions");
Result MakeStructResolve(KernelContext* ctx,
diff --git a/cpp/src/arrow/compute/kernels/scalar_nested_test.cc b/cpp/src/arrow/compute/kernels/scalar_nested_test.cc
index 5733ef81293..0b6f7bcc1ec 100644
--- a/cpp/src/arrow/compute/kernels/scalar_nested_test.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_nested_test.cc
@@ -137,7 +137,7 @@ TEST(TestScalarNested, StructField) {
EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid,
::testing::HasSubstr("out-of-bounds field reference"),
CallFunction("struct_field", {arr}, &invalid2));
- EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, ::testing::HasSubstr("cannot subscript"),
+ EXPECT_RAISES_WITH_MESSAGE_THAT(TypeError, ::testing::HasSubstr("cannot subscript"),
CallFunction("struct_field", {arr}, &invalid3));
}
{
@@ -159,7 +159,7 @@ TEST(TestScalarNested, StructField) {
EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid,
::testing::HasSubstr("out-of-bounds field reference"),
CallFunction("struct_field", {arr}, &invalid2));
- EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, ::testing::HasSubstr("cannot subscript"),
+ EXPECT_RAISES_WITH_MESSAGE_THAT(TypeError, ::testing::HasSubstr("cannot subscript"),
CallFunction("struct_field", {arr}, &invalid3));
// Test edge cases for union representation
@@ -205,9 +205,14 @@ TEST(TestScalarNested, StructField) {
EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid,
::testing::HasSubstr("out-of-bounds field reference"),
CallFunction("struct_field", {arr}, &invalid2));
- EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, ::testing::HasSubstr("cannot subscript"),
+ EXPECT_RAISES_WITH_MESSAGE_THAT(TypeError, ::testing::HasSubstr("cannot subscript"),
CallFunction("struct_field", {arr}, &invalid3));
}
+ {
+ auto arr = ArrayFromJSON(int32(), "[0, 1, 2, 3]");
+ ASSERT_RAISES(NotImplemented, CallFunction("struct_field", {arr}, &trivial));
+ ASSERT_RAISES(NotImplemented, CallFunction("struct_field", {arr}, &extract0));
+ }
}
struct {