diff --git a/cpp/src/arrow/array/builder_base.h b/cpp/src/arrow/array/builder_base.h index 15c726241b5..5f447bbad11 100644 --- a/cpp/src/arrow/array/builder_base.h +++ b/cpp/src/arrow/array/builder_base.h @@ -51,6 +51,7 @@ class ARROW_EXPORT ArrayBuilder { explicit ArrayBuilder(MemoryPool* pool) : pool_(pool), null_bitmap_builder_(pool) {} virtual ~ArrayBuilder() = default; + ARROW_DEFAULT_MOVE_AND_ASSIGN(ArrayBuilder); /// For nested types. Since the objects are owned by this class instance, we /// skip shared pointers and just return a raw pointer diff --git a/cpp/src/arrow/compute/kernels/vector_selection.cc b/cpp/src/arrow/compute/kernels/vector_selection.cc index 6376ae10404..5845a7ee2d0 100644 --- a/cpp/src/arrow/compute/kernels/vector_selection.cc +++ b/cpp/src/arrow/compute/kernels/vector_selection.cc @@ -1668,6 +1668,81 @@ struct ListImpl : public Selection, Type> { } }; +struct DenseUnionImpl : public Selection { + using Base = Selection; + LIFT_BASE_MEMBERS(); + + TypedBufferBuilder value_offset_buffer_builder_; + TypedBufferBuilder child_id_buffer_builder_; + std::vector type_codes_; + std::vector child_indices_builders_; + + DenseUnionImpl(KernelContext* ctx, const ExecBatch& batch, int64_t output_length, + Datum* out) + : Base(ctx, batch, output_length, out), + value_offset_buffer_builder_(ctx->memory_pool()), + child_id_buffer_builder_(ctx->memory_pool()), + type_codes_(checked_cast(*this->values->type).type_codes()), + child_indices_builders_(type_codes_.size()) { + for (auto& child_indices_builder : child_indices_builders_) { + child_indices_builder = Int32Builder(ctx->memory_pool()); + } + } + + template + Status GenerateOutput() { + DenseUnionArray typed_values(this->values); + Adapter adapter(this); + RETURN_NOT_OK(adapter.Generate( + [&](int64_t index) { + int8_t child_id = typed_values.child_id(index); + child_id_buffer_builder_.UnsafeAppend(type_codes_[child_id]); + int32_t value_offset = typed_values.value_offset(index); + value_offset_buffer_builder_.UnsafeAppend( + static_cast(child_indices_builders_[child_id].length())); + RETURN_NOT_OK(child_indices_builders_[child_id].Reserve(1)); + child_indices_builders_[child_id].UnsafeAppend(value_offset); + return Status::OK(); + }, + [&]() { + int8_t child_id = 0; + child_id_buffer_builder_.UnsafeAppend(type_codes_[child_id]); + value_offset_buffer_builder_.UnsafeAppend( + static_cast(child_indices_builders_[child_id].length())); + RETURN_NOT_OK(child_indices_builders_[child_id].Reserve(1)); + child_indices_builders_[child_id].UnsafeAppendNull(); + return Status::OK(); + })); + return Status::OK(); + } + + Status Init() override { + RETURN_NOT_OK(child_id_buffer_builder_.Reserve(output_length)); + RETURN_NOT_OK(value_offset_buffer_builder_.Reserve(output_length)); + return Status::OK(); + } + + Status Finish() override { + ARROW_ASSIGN_OR_RAISE(auto child_ids_buffer, child_id_buffer_builder_.Finish()); + ARROW_ASSIGN_OR_RAISE(auto value_offsets_buffer, + value_offset_buffer_builder_.Finish()); + DenseUnionArray typed_values(this->values); + auto num_fields = typed_values.num_fields(); + auto num_rows = child_ids_buffer->size(); + BufferVector buffers{nullptr, std::move(child_ids_buffer), + std::move(value_offsets_buffer)}; + *out = ArrayData(typed_values.type(), num_rows, std::move(buffers), 0); + for (auto i = 0; i < num_fields; i++) { + ARROW_ASSIGN_OR_RAISE(auto child_indices_array, + child_indices_builders_[i].Finish()); + ARROW_ASSIGN_OR_RAISE(std::shared_ptr child_array, + Take(*typed_values.field(i), *child_indices_array)); + out->child_data.push_back(child_array->data()); + } + return Status::OK(); + } +}; + struct FSLImpl : public Selection { Int64Builder child_index_builder; @@ -2141,6 +2216,7 @@ void RegisterVectorSelection(FunctionRegistry* registry) { {InputType::Array(Type::LIST), FilterExec>}, {InputType::Array(Type::LARGE_LIST), FilterExec>}, {InputType::Array(Type::FIXED_SIZE_LIST), FilterExec}, + {InputType::Array(Type::DENSE_UNION), FilterExec}, {InputType::Array(Type::STRUCT), StructFilter}, // TODO: Reuse ListType kernel for MAP {InputType::Array(Type::MAP), FilterExec>}, @@ -2170,6 +2246,7 @@ void RegisterVectorSelection(FunctionRegistry* registry) { {InputType::Array(Type::LIST), TakeExec>}, {InputType::Array(Type::LARGE_LIST), TakeExec>}, {InputType::Array(Type::FIXED_SIZE_LIST), TakeExec}, + {InputType::Array(Type::DENSE_UNION), TakeExec}, {InputType::Array(Type::STRUCT), TakeExec}, // TODO: Reuse ListType kernel for MAP {InputType::Array(Type::MAP), TakeExec>}, diff --git a/cpp/src/arrow/compute/kernels/vector_selection_test.cc b/cpp/src/arrow/compute/kernels/vector_selection_test.cc index f428da0fe35..e367d888d00 100644 --- a/cpp/src/arrow/compute/kernels/vector_selection_test.cc +++ b/cpp/src/arrow/compute/kernels/vector_selection_test.cc @@ -607,31 +607,31 @@ TEST_F(TestFilterKernelWithStruct, FilterStruct) { class TestFilterKernelWithUnion : public TestFilterKernel {}; -TEST_F(TestFilterKernelWithUnion, DISABLED_FilterUnion) { - for (auto union_ : UnionTypeFactories()) { - auto union_type = union_({field("a", int32()), field("b", utf8())}, {2, 5}); - auto union_json = R"([ - null, +TEST_F(TestFilterKernelWithUnion, FilterUnion) { + auto union_type = dense_union({field("a", int32()), field("b", utf8())}, {2, 5}); + auto union_json = R"([ + [2, null], [2, 222], [5, "hello"], [5, "eh"], - null, - [2, 111] + [2, null], + [2, 111], + [5, null] ])"; - this->AssertFilter(union_type, union_json, "[0, 0, 0, 0, 0, 0]", "[]"); - this->AssertFilter(union_type, union_json, "[0, 1, 1, null, 0, 1]", R"([ + this->AssertFilter(union_type, union_json, "[0, 0, 0, 0, 0, 0, 0]", "[]"); + this->AssertFilter(union_type, union_json, "[0, 1, 1, null, 0, 1, 1]", R"([ [2, 222], [5, "hello"], - null, - [2, 111] + [2, null], + [2, 111], + [5, null] ])"); - this->AssertFilter(union_type, union_json, "[1, 0, 1, 0, 1, 0]", R"([ - null, + this->AssertFilter(union_type, union_json, "[1, 0, 1, 0, 1, 0, 0]", R"([ + [2, null], [5, "hello"], - null + [2, null] ])"); - this->AssertFilter(union_type, union_json, "[1, 1, 1, 1, 1, 1]", union_json); - } + this->AssertFilter(union_type, union_json, "[1, 1, 1, 1, 1, 1, 1]", union_json); } class TestFilterKernelWithRecordBatch : public TestFilterKernel { @@ -1281,34 +1281,34 @@ TEST_F(TestTakeKernelWithStruct, TakeStruct) { class TestTakeKernelWithUnion : public TestTakeKernelTyped {}; -// TODO: Restore Union take functionality -TEST_F(TestTakeKernelWithUnion, DISABLED_TakeUnion) { - for (auto union_ : UnionTypeFactories()) { - auto union_type = union_({field("a", int32()), field("b", utf8())}, {2, 5}); - auto union_json = R"([ - null, +TEST_F(TestTakeKernelWithUnion, TakeUnion) { + auto union_type = dense_union({field("a", int32()), field("b", utf8())}, {2, 5}); + auto union_json = R"([ + [2, null], [2, 222], [5, "hello"], [5, "eh"], - null, - [2, 111] + [2, null], + [2, 111], + [5, null] ])"; - CheckTake(union_type, union_json, "[]", "[]"); - CheckTake(union_type, union_json, "[3, 1, 3, 1, 3]", R"([ + CheckTake(union_type, union_json, "[]", "[]"); + CheckTake(union_type, union_json, "[3, 1, 3, 1, 3]", R"([ [5, "eh"], [2, 222], [5, "eh"], [2, 222], [5, "eh"] ])"); - CheckTake(union_type, union_json, "[4, 2, 1]", R"([ - null, + CheckTake(union_type, union_json, "[4, 2, 1, 6]", R"([ + [2, null], [5, "hello"], - [2, 222] + [2, 222], + [5, null] ])"); - CheckTake(union_type, union_json, "[0, 1, 2, 3, 4, 5]", union_json); - CheckTake(union_type, union_json, "[0, 2, 2, 2, 2, 2, 2]", R"([ - null, + CheckTake(union_type, union_json, "[0, 1, 2, 3, 4, 5, 6]", union_json); + CheckTake(union_type, union_json, "[0, 2, 2, 2, 2, 2, 2]", R"([ + [2, null], [5, "hello"], [5, "hello"], [5, "hello"], @@ -1316,7 +1316,6 @@ TEST_F(TestTakeKernelWithUnion, DISABLED_TakeUnion) { [5, "hello"], [5, "hello"] ])"); - } } class TestPermutationsWithTake : public TestBase {