Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cpp/src/arrow/array/builder_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ class ARROW_EXPORT ArrayBuilder {
explicit ArrayBuilder(MemoryPool* pool) : pool_(pool), null_bitmap_builder_(pool) {}

virtual ~ArrayBuilder() = default;
ARROW_DEFAULT_MOVE_AND_ASSIGN(ArrayBuilder);

/// For nested types. Since the objects are owned by this class instance, we
/// skip shared pointers and just return a raw pointer
Expand Down
77 changes: 77 additions & 0 deletions cpp/src/arrow/compute/kernels/vector_selection.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1668,6 +1668,81 @@ struct ListImpl : public Selection<ListImpl<Type>, Type> {
}
};

struct DenseUnionImpl : public Selection<DenseUnionImpl, DenseUnionType> {
using Base = Selection<DenseUnionImpl, DenseUnionType>;
LIFT_BASE_MEMBERS();

TypedBufferBuilder<int32_t> value_offset_buffer_builder_;
TypedBufferBuilder<int8_t> child_id_buffer_builder_;
std::vector<int8_t> type_codes_;
std::vector<Int32Builder> child_indices_builders_;

DenseUnionImpl(KernelContext* ctx, const ExecBatch& batch, int64_t output_length,
Datum* out)
: Base(ctx, batch, output_length, out),
value_offset_buffer_builder_(ctx->memory_pool()),
child_id_buffer_builder_(ctx->memory_pool()),
type_codes_(checked_cast<const UnionType&>(*this->values->type).type_codes()),
child_indices_builders_(type_codes_.size()) {
for (auto& child_indices_builder : child_indices_builders_) {
child_indices_builder = Int32Builder(ctx->memory_pool());
}
}

template <typename Adapter>
Status GenerateOutput() {
DenseUnionArray typed_values(this->values);
Adapter adapter(this);
RETURN_NOT_OK(adapter.Generate(
[&](int64_t index) {
int8_t child_id = typed_values.child_id(index);
child_id_buffer_builder_.UnsafeAppend(type_codes_[child_id]);
int32_t value_offset = typed_values.value_offset(index);
value_offset_buffer_builder_.UnsafeAppend(
static_cast<int32_t>(child_indices_builders_[child_id].length()));
RETURN_NOT_OK(child_indices_builders_[child_id].Reserve(1));
child_indices_builders_[child_id].UnsafeAppend(value_offset);
return Status::OK();
},
[&]() {
int8_t child_id = 0;
child_id_buffer_builder_.UnsafeAppend(type_codes_[child_id]);
value_offset_buffer_builder_.UnsafeAppend(
static_cast<int32_t>(child_indices_builders_[child_id].length()));
RETURN_NOT_OK(child_indices_builders_[child_id].Reserve(1));
child_indices_builders_[child_id].UnsafeAppendNull();
return Status::OK();
}));
return Status::OK();
}

Status Init() override {
RETURN_NOT_OK(child_id_buffer_builder_.Reserve(output_length));
RETURN_NOT_OK(value_offset_buffer_builder_.Reserve(output_length));
return Status::OK();
}

Status Finish() override {
ARROW_ASSIGN_OR_RAISE(auto child_ids_buffer, child_id_buffer_builder_.Finish());
ARROW_ASSIGN_OR_RAISE(auto value_offsets_buffer,
value_offset_buffer_builder_.Finish());
DenseUnionArray typed_values(this->values);
auto num_fields = typed_values.num_fields();
auto num_rows = child_ids_buffer->size();
BufferVector buffers{nullptr, std::move(child_ids_buffer),
std::move(value_offsets_buffer)};
*out = ArrayData(typed_values.type(), num_rows, std::move(buffers), 0);
for (auto i = 0; i < num_fields; i++) {
ARROW_ASSIGN_OR_RAISE(auto child_indices_array,
child_indices_builders_[i].Finish());
ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Array> child_array,
Take(*typed_values.field(i), *child_indices_array));
out->child_data.push_back(child_array->data());
}
return Status::OK();
}
};

struct FSLImpl : public Selection<FSLImpl, FixedSizeListType> {
Int64Builder child_index_builder;

Expand Down Expand Up @@ -2141,6 +2216,7 @@ void RegisterVectorSelection(FunctionRegistry* registry) {
{InputType::Array(Type::LIST), FilterExec<ListImpl<ListType>>},
{InputType::Array(Type::LARGE_LIST), FilterExec<ListImpl<LargeListType>>},
{InputType::Array(Type::FIXED_SIZE_LIST), FilterExec<FSLImpl>},
{InputType::Array(Type::DENSE_UNION), FilterExec<DenseUnionImpl>},
{InputType::Array(Type::STRUCT), StructFilter},
// TODO: Reuse ListType kernel for MAP
{InputType::Array(Type::MAP), FilterExec<ListImpl<MapType>>},
Expand Down Expand Up @@ -2170,6 +2246,7 @@ void RegisterVectorSelection(FunctionRegistry* registry) {
{InputType::Array(Type::LIST), TakeExec<ListImpl<ListType>>},
{InputType::Array(Type::LARGE_LIST), TakeExec<ListImpl<LargeListType>>},
{InputType::Array(Type::FIXED_SIZE_LIST), TakeExec<FSLImpl>},
{InputType::Array(Type::DENSE_UNION), TakeExec<DenseUnionImpl>},
{InputType::Array(Type::STRUCT), TakeExec<StructImpl>},
// TODO: Reuse ListType kernel for MAP
{InputType::Array(Type::MAP), TakeExec<ListImpl<MapType>>},
Expand Down
65 changes: 32 additions & 33 deletions cpp/src/arrow/compute/kernels/vector_selection_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -607,31 +607,31 @@ TEST_F(TestFilterKernelWithStruct, FilterStruct) {

class TestFilterKernelWithUnion : public TestFilterKernel<UnionType> {};

TEST_F(TestFilterKernelWithUnion, DISABLED_FilterUnion) {
for (auto union_ : UnionTypeFactories()) {
auto union_type = union_({field("a", int32()), field("b", utf8())}, {2, 5});
auto union_json = R"([
null,
TEST_F(TestFilterKernelWithUnion, FilterUnion) {
auto union_type = dense_union({field("a", int32()), field("b", utf8())}, {2, 5});
auto union_json = R"([
[2, null],
[2, 222],
[5, "hello"],
[5, "eh"],
null,
[2, 111]
[2, null],
[2, 111],
[5, null]
])";
this->AssertFilter(union_type, union_json, "[0, 0, 0, 0, 0, 0]", "[]");
this->AssertFilter(union_type, union_json, "[0, 1, 1, null, 0, 1]", R"([
this->AssertFilter(union_type, union_json, "[0, 0, 0, 0, 0, 0, 0]", "[]");
this->AssertFilter(union_type, union_json, "[0, 1, 1, null, 0, 1, 1]", R"([
[2, 222],
[5, "hello"],
null,
[2, 111]
[2, null],
[2, 111],
[5, null]
])");
this->AssertFilter(union_type, union_json, "[1, 0, 1, 0, 1, 0]", R"([
null,
this->AssertFilter(union_type, union_json, "[1, 0, 1, 0, 1, 0, 0]", R"([
[2, null],
[5, "hello"],
null
[2, null]
])");
this->AssertFilter(union_type, union_json, "[1, 1, 1, 1, 1, 1]", union_json);
}
this->AssertFilter(union_type, union_json, "[1, 1, 1, 1, 1, 1, 1]", union_json);
}

class TestFilterKernelWithRecordBatch : public TestFilterKernel<RecordBatch> {
Expand Down Expand Up @@ -1281,42 +1281,41 @@ TEST_F(TestTakeKernelWithStruct, TakeStruct) {

class TestTakeKernelWithUnion : public TestTakeKernelTyped<UnionType> {};

// TODO: Restore Union take functionality
TEST_F(TestTakeKernelWithUnion, DISABLED_TakeUnion) {
for (auto union_ : UnionTypeFactories()) {
auto union_type = union_({field("a", int32()), field("b", utf8())}, {2, 5});
auto union_json = R"([
null,
TEST_F(TestTakeKernelWithUnion, TakeUnion) {
auto union_type = dense_union({field("a", int32()), field("b", utf8())}, {2, 5});
auto union_json = R"([
[2, null],
[2, 222],
[5, "hello"],
[5, "eh"],
null,
[2, 111]
[2, null],
[2, 111],
[5, null]
])";
CheckTake(union_type, union_json, "[]", "[]");
CheckTake(union_type, union_json, "[3, 1, 3, 1, 3]", R"([
CheckTake(union_type, union_json, "[]", "[]");
CheckTake(union_type, union_json, "[3, 1, 3, 1, 3]", R"([
[5, "eh"],
[2, 222],
[5, "eh"],
[2, 222],
[5, "eh"]
])");
CheckTake(union_type, union_json, "[4, 2, 1]", R"([
null,
CheckTake(union_type, union_json, "[4, 2, 1, 6]", R"([
[2, null],
[5, "hello"],
[2, 222]
[2, 222],
[5, null]
])");
CheckTake(union_type, union_json, "[0, 1, 2, 3, 4, 5]", union_json);
CheckTake(union_type, union_json, "[0, 2, 2, 2, 2, 2, 2]", R"([
null,
CheckTake(union_type, union_json, "[0, 1, 2, 3, 4, 5, 6]", union_json);
CheckTake(union_type, union_json, "[0, 2, 2, 2, 2, 2, 2]", R"([
[2, null],
[5, "hello"],
[5, "hello"],
[5, "hello"],
[5, "hello"],
[5, "hello"],
[5, "hello"]
])");
}
}

class TestPermutationsWithTake : public TestBase {
Expand Down