From 8a5545fad496395bc251203344486a36eff52a5e Mon Sep 17 00:00:00 2001
From: David Li
Date: Wed, 21 Jul 2021 09:42:04 -0400
Subject: [PATCH 01/31] ARROW-13222: [C++] Improve type support for case_when
kernel
---
.../arrow/compute/kernels/scalar_if_else.cc | 163 +++++++++++++++++-
.../compute/kernels/scalar_if_else_test.cc | 82 +++++++++
2 files changed, 242 insertions(+), 3 deletions(-)
diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else.cc b/cpp/src/arrow/compute/kernels/scalar_if_else.cc
index cb261ec59a7..e130d2b3ad9 100644
--- a/cpp/src/arrow/compute/kernels/scalar_if_else.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_if_else.cc
@@ -1413,6 +1413,148 @@ struct CaseWhenFunctor {
}
};
+template
+struct CaseWhenFunctor> {
+ using offset_type = typename Type::offset_type;
+ using BuilderType = typename TypeTraits::BuilderType;
+ static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ if (batch[0].null_count() > 0) {
+ return Status::Invalid("cond struct must not have outer nulls");
+ }
+ if (batch[0].is_scalar()) {
+ return ExecScalar(ctx, batch, out);
+ }
+ return ExecArray(ctx, batch, out);
+ }
+
+ static Status ExecScalar(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ const auto& conds = checked_cast(*batch.values[0].scalar());
+ Datum result;
+ for (size_t i = 0; i < batch.values.size() - 1; i++) {
+ if (i < conds.value.size()) {
+ const Scalar& cond = *conds.value[i];
+ if (cond.is_valid && internal::UnboxScalar::Unbox(cond)) {
+ result = batch[i + 1];
+ break;
+ }
+ } else {
+ // ELSE clause
+ result = batch[i + 1];
+ break;
+ }
+ }
+ if (out->is_scalar()) {
+ *out = result.is_scalar() ? result.scalar() : MakeNullScalar(out->type());
+ return Status::OK();
+ }
+ ArrayData* output = out->mutable_array();
+ if (!result.is_value()) {
+ // All conditions false, no 'else' argument
+ ARROW_ASSIGN_OR_RAISE(
+ auto array, MakeArrayOfNull(output->type, batch.length, ctx->memory_pool()));
+ *output = *array->data();
+ } else if (result.is_scalar()) {
+ ARROW_ASSIGN_OR_RAISE(
+ auto array,
+ MakeArrayFromScalar(*result.scalar(), batch.length, ctx->memory_pool()));
+ *output = *array->data();
+ } else {
+ // Copy offsets/data
+ const ArrayData& source = *result.array();
+ output->length = source.length;
+ output->SetNullCount(source.null_count);
+ if (source.MayHaveNulls()) {
+ ARROW_ASSIGN_OR_RAISE(
+ output->buffers[0],
+ arrow::internal::CopyBitmap(ctx->memory_pool(), source.buffers[0]->data(),
+ source.offset, source.length));
+ }
+ ARROW_ASSIGN_OR_RAISE(output->buffers[1],
+ ctx->Allocate(sizeof(offset_type) * (source.length + 1)));
+ const offset_type* in_offsets = source.GetValues(1);
+ offset_type* out_offsets = output->GetMutableValues(1);
+ std::transform(in_offsets, in_offsets + source.length + 1, out_offsets,
+ [&](offset_type offset) { return offset - in_offsets[0]; });
+ auto data_length = out_offsets[output->length] - out_offsets[0];
+ ARROW_ASSIGN_OR_RAISE(output->buffers[2], ctx->Allocate(data_length));
+ std::memcpy(output->buffers[2]->mutable_data(),
+ source.buffers[2]->data() + in_offsets[0], data_length);
+ }
+ return Status::OK();
+ }
+
+ static Status ExecArray(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ const auto& conds_array = *batch.values[0].array();
+ if (conds_array.GetNullCount() > 0) {
+ return Status::Invalid("cond struct must not have top-level nulls");
+ }
+ ArrayData* output = out->mutable_array();
+ const bool have_else_arg =
+ static_cast(conds_array.type->num_fields()) < (batch.values.size() - 1);
+ BuilderType builder(batch[0].type(), ctx->memory_pool());
+ RETURN_NOT_OK(builder.Reserve(batch.length));
+ int64_t reservation = 0;
+ for (size_t arg = 1; arg < batch.values.size(); arg++) {
+ auto source = batch.values[arg];
+ if (source.is_scalar()) {
+ const auto& scalar = checked_cast(*source.scalar());
+ if (!scalar.value) continue;
+ reservation = std::max(reservation, batch.length * scalar.value->size());
+ } else {
+ const auto& array = *source.array();
+ const auto& offsets = array.GetValues(1);
+ reservation = std::max(reservation, offsets[array.length] - offsets[0]);
+ }
+ }
+ RETURN_NOT_OK(builder.ReserveData(reservation));
+
+ for (int64_t row = 0; row < batch.length; row++) {
+ int64_t selected = have_else_arg ? batch.values.size() - 1 : -1;
+ for (int64_t arg = 0; static_cast(arg) < conds_array.child_data.size();
+ arg++) {
+ const ArrayData& cond_array = *conds_array.child_data[arg];
+ if ((!cond_array.buffers[0] ||
+ BitUtil::GetBit(cond_array.buffers[0]->data(),
+ conds_array.offset + cond_array.offset + row)) &&
+ BitUtil::GetBit(cond_array.buffers[1]->data(),
+ conds_array.offset + cond_array.offset + row)) {
+ selected = arg + 1;
+ break;
+ }
+ }
+ if (selected < 0) {
+ RETURN_NOT_OK(builder.AppendNull());
+ continue;
+ }
+ const Datum& source = batch.values[selected];
+ if (source.is_scalar()) {
+ const auto& scalar = checked_cast(*source.scalar());
+ if (!scalar.is_valid) {
+ RETURN_NOT_OK(builder.AppendNull());
+ } else {
+ RETURN_NOT_OK(builder.Append(scalar.value->data(), scalar.value->size()));
+ }
+ } else {
+ const auto& array = *source.array();
+ if (!array.buffers[0] ||
+ BitUtil::GetBit(array.buffers[0]->data(), array.offset + row)) {
+ const offset_type* offsets = array.GetValues(1);
+ RETURN_NOT_OK(builder.Append(array.buffers[2]->data() + offsets[row],
+ offsets[row + 1] - offsets[row]));
+ } else {
+ RETURN_NOT_OK(builder.AppendNull());
+ }
+ }
+ }
+
+ ARROW_ASSIGN_OR_RAISE(auto temp_output, builder.Finish());
+ *output = *temp_output->data();
+ // Builder type != logical type due to GenerateTypeAgnosticVarBinaryBase
+ output->type = batch[1].type();
+ return Status::OK();
+ }
+};
+
struct CoalesceFunction : ScalarFunction {
using ScalarFunction::ScalarFunction;
@@ -1841,9 +1983,15 @@ void AddCaseWhenKernel(const std::shared_ptr& scalar_function,
OutputType(LastType),
/*is_varargs=*/true),
exec);
- kernel.null_handling = NullHandling::COMPUTED_PREALLOCATE;
- kernel.mem_allocation = MemAllocation::PREALLOCATE;
- kernel.can_write_into_slices = is_fixed_width(get_id.id);
+ if (is_fixed_width(get_id.id)) {
+ kernel.null_handling = NullHandling::COMPUTED_PREALLOCATE;
+ kernel.mem_allocation = MemAllocation::PREALLOCATE;
+ kernel.can_write_into_slices = true;
+ } else {
+ kernel.null_handling = NullHandling::COMPUTED_NO_PREALLOCATE;
+ kernel.mem_allocation = MemAllocation::NO_PREALLOCATE;
+ kernel.can_write_into_slices = false;
+ }
DCHECK_OK(scalar_function->AddKernel(std::move(kernel)));
}
@@ -1855,6 +2003,14 @@ void AddPrimitiveCaseWhenKernels(const std::shared_ptr& scalar
}
}
+void AddBinaryCaseWhenKernels(const std::shared_ptr& scalar_function,
+ const std::vector>& types) {
+ for (auto&& type : types) {
+ auto exec = GenerateTypeAgnosticVarBinaryBase(*type);
+ AddCaseWhenKernel(scalar_function, type, std::move(exec));
+ }
+}
+
void AddCoalesceKernel(const std::shared_ptr& scalar_function,
detail::GetTypeId get_id, ArrayKernelExec exec) {
ScalarKernel kernel(KernelSignature::Make({InputType(get_id.id)}, OutputType(FirstType),
@@ -1957,6 +2113,7 @@ void RegisterScalarIfElse(FunctionRegistry* registry) {
CaseWhenFunctor::Exec);
AddCaseWhenKernel(func, Type::DECIMAL128, CaseWhenFunctor::Exec);
AddCaseWhenKernel(func, Type::DECIMAL256, CaseWhenFunctor::Exec);
+ AddBinaryCaseWhenKernels(func, BaseBinaryTypes());
DCHECK_OK(registry->AddFunction(std::move(func)));
}
{
diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc
index 8a6ccd69865..90f73107338 100644
--- a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc
@@ -531,6 +531,10 @@ TYPED_TEST(TestCaseWhenNumeric, FixedSize) {
CheckScalar("case_when", {MakeStruct({}), values1}, values1);
CheckScalar("case_when", {MakeStruct({}), values_null}, values_null);
+ CheckScalar("case_when", {MakeStruct({cond_true}), scalar1, values1},
+ *MakeArrayFromScalar(*scalar1, 4));
+ CheckScalar("case_when", {MakeStruct({cond_false}), scalar1, values1}, values1);
+
CheckScalar("case_when", {MakeStruct({cond_true}), values1}, values1);
CheckScalar("case_when", {MakeStruct({cond_false}), values1}, values_null);
CheckScalar("case_when", {MakeStruct({cond_null}), values1}, values_null);
@@ -632,6 +636,10 @@ TEST(TestCaseWhen, Boolean) {
CheckScalar("case_when", {MakeStruct({}), values1}, values1);
CheckScalar("case_when", {MakeStruct({}), values_null}, values_null);
+ CheckScalar("case_when", {MakeStruct({cond_true}), scalar1, values1},
+ *MakeArrayFromScalar(*scalar1, 4));
+ CheckScalar("case_when", {MakeStruct({cond_false}), scalar1, values1}, values1);
+
CheckScalar("case_when", {MakeStruct({cond_true}), values1}, values1);
CheckScalar("case_when", {MakeStruct({cond_false}), values1}, values_null);
CheckScalar("case_when", {MakeStruct({cond_null}), values1}, values_null);
@@ -685,6 +693,10 @@ TEST(TestCaseWhen, DayTimeInterval) {
CheckScalar("case_when", {MakeStruct({}), values1}, values1);
CheckScalar("case_when", {MakeStruct({}), values_null}, values_null);
+ CheckScalar("case_when", {MakeStruct({cond_true}), scalar1, values1},
+ *MakeArrayFromScalar(*scalar1, 4));
+ CheckScalar("case_when", {MakeStruct({cond_false}), scalar1, values1}, values1);
+
CheckScalar("case_when", {MakeStruct({cond_true}), values1}, values1);
CheckScalar("case_when", {MakeStruct({cond_false}), values1}, values_null);
CheckScalar("case_when", {MakeStruct({cond_null}), values1}, values_null);
@@ -739,6 +751,10 @@ TEST(TestCaseWhen, Decimal) {
CheckScalar("case_when", {MakeStruct({}), values1}, values1);
CheckScalar("case_when", {MakeStruct({}), values_null}, values_null);
+ CheckScalar("case_when", {MakeStruct({cond_true}), scalar1, values1},
+ *MakeArrayFromScalar(*scalar1, 4));
+ CheckScalar("case_when", {MakeStruct({cond_false}), scalar1, values1}, values1);
+
CheckScalar("case_when", {MakeStruct({cond_true}), values1}, values1);
CheckScalar("case_when", {MakeStruct({cond_false}), values1}, values_null);
CheckScalar("case_when", {MakeStruct({cond_null}), values1}, values_null);
@@ -794,6 +810,10 @@ TEST(TestCaseWhen, FixedSizeBinary) {
CheckScalar("case_when", {MakeStruct({}), values1}, values1);
CheckScalar("case_when", {MakeStruct({}), values_null}, values_null);
+ CheckScalar("case_when", {MakeStruct({cond_true}), scalar1, values1},
+ *MakeArrayFromScalar(*scalar1, 4));
+ CheckScalar("case_when", {MakeStruct({cond_false}), scalar1, values1}, values1);
+
CheckScalar("case_when", {MakeStruct({cond_true}), values1}, values1);
CheckScalar("case_when", {MakeStruct({cond_false}), values1}, values_null);
CheckScalar("case_when", {MakeStruct({cond_null}), values1}, values_null);
@@ -830,6 +850,68 @@ TEST(TestCaseWhen, FixedSizeBinary) {
ArrayFromJSON(type, R"([null, null, null, "efg"])"));
}
+template
+class TestCaseWhenBinary : public ::testing::Test {};
+
+TYPED_TEST_SUITE(TestCaseWhenBinary, BinaryTypes);
+
+TYPED_TEST(TestCaseWhenBinary, Basics) {
+ auto type = default_type_instance();
+ auto cond_true = ScalarFromJSON(boolean(), "true");
+ auto cond_false = ScalarFromJSON(boolean(), "false");
+ auto cond_null = ScalarFromJSON(boolean(), "null");
+ auto cond1 = ArrayFromJSON(boolean(), "[true, true, null, null]");
+ auto cond2 = ArrayFromJSON(boolean(), "[true, false, true, null]");
+ auto scalar_null = ScalarFromJSON(type, "null");
+ auto scalar1 = ScalarFromJSON(type, R"("aBxYz")");
+ auto scalar2 = ScalarFromJSON(type, R"("b")");
+ auto values_null = ArrayFromJSON(type, "[null, null, null, null]");
+ auto values1 = ArrayFromJSON(type, R"(["cDE", null, "degfhi", "efg"])");
+ auto values2 = ArrayFromJSON(type, R"(["fghijk", "ghi", null, "hi"])");
+
+ CheckScalar("case_when", {MakeStruct({}), values1}, values1);
+ CheckScalar("case_when", {MakeStruct({}), values_null}, values_null);
+
+ CheckScalar("case_when", {MakeStruct({cond_true}), scalar1, values1},
+ *MakeArrayFromScalar(*scalar1, 4));
+ CheckScalar("case_when", {MakeStruct({cond_false}), scalar1, values1}, values1);
+
+ CheckScalar("case_when", {MakeStruct({cond_true}), values1}, values1);
+ CheckScalar("case_when", {MakeStruct({cond_false}), values1}, values_null);
+ CheckScalar("case_when", {MakeStruct({cond_null}), values1}, values_null);
+ CheckScalar("case_when", {MakeStruct({cond_true}), values1, values2}, values1);
+ CheckScalar("case_when", {MakeStruct({cond_false}), values1, values2}, values2);
+ CheckScalar("case_when", {MakeStruct({cond_null}), values1, values2}, values2);
+
+ CheckScalar("case_when", {MakeStruct({cond_true, cond_true}), values1, values2},
+ values1);
+ CheckScalar("case_when", {MakeStruct({cond_false, cond_false}), values1, values2},
+ values_null);
+ CheckScalar("case_when", {MakeStruct({cond_true, cond_false}), values1, values2},
+ values1);
+ CheckScalar("case_when", {MakeStruct({cond_false, cond_true}), values1, values2},
+ values2);
+ CheckScalar("case_when", {MakeStruct({cond_null, cond_true}), values1, values2},
+ values2);
+ CheckScalar("case_when",
+ {MakeStruct({cond_false, cond_false}), values1, values2, values2}, values2);
+
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), scalar1, scalar2},
+ ArrayFromJSON(type, R"(["aBxYz", "aBxYz", "b", null])"));
+ CheckScalar("case_when", {MakeStruct({cond1}), scalar_null}, values_null);
+ CheckScalar("case_when", {MakeStruct({cond1}), scalar_null, scalar1},
+ ArrayFromJSON(type, R"([null, null, "aBxYz", "aBxYz"])"));
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), scalar1, scalar2, scalar1},
+ ArrayFromJSON(type, R"(["aBxYz", "aBxYz", "b", "aBxYz"])"));
+
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2},
+ ArrayFromJSON(type, R"(["cDE", null, null, null])"));
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2, values1},
+ ArrayFromJSON(type, R"(["cDE", null, null, "efg"])"));
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), values_null, values2, values1},
+ ArrayFromJSON(type, R"([null, null, null, "efg"])"));
+}
+
TEST(TestCaseWhen, DispatchBest) {
CheckDispatchBest("case_when", {struct_({field("", boolean())}), int64(), int32()},
{struct_({field("", boolean())}), int64(), int64()});
From 2202b74434e7f1892c7b269363a0fe4cd0d321d8 Mon Sep 17 00:00:00 2001
From: David Li
Date: Wed, 21 Jul 2021 14:56:58 -0400
Subject: [PATCH 02/31] ARROW-13222: [C++] Add list type support for case_when
kernel
---
.../arrow/compute/kernels/scalar_if_else.cc | 206 +++++++++++++-----
.../compute/kernels/scalar_if_else_test.cc | 67 +++++-
cpp/src/arrow/testing/gtest_util.h | 2 +
3 files changed, 217 insertions(+), 58 deletions(-)
diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else.cc b/cpp/src/arrow/compute/kernels/scalar_if_else.cc
index e130d2b3ad9..d946c5ed78b 100644
--- a/cpp/src/arrow/compute/kernels/scalar_if_else.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_if_else.cc
@@ -15,6 +15,7 @@
// specific language governing permissions and limitations
// under the License.
+#include
#include
#include
#include
@@ -1413,6 +1414,46 @@ struct CaseWhenFunctor {
}
};
+template
+static Status ExecVarWidthScalarCaseWhen(KernelContext* ctx, const ExecBatch& batch,
+ Datum* out, CopyArray copy_array) {
+ const auto& conds = checked_cast(*batch.values[0].scalar());
+ Datum result;
+ for (size_t i = 0; i < batch.values.size() - 1; i++) {
+ if (i < conds.value.size()) {
+ const Scalar& cond = *conds.value[i];
+ if (cond.is_valid && internal::UnboxScalar::Unbox(cond)) {
+ result = batch[i + 1];
+ break;
+ }
+ } else {
+ // ELSE clause
+ result = batch[i + 1];
+ break;
+ }
+ }
+ if (out->is_scalar()) {
+ *out = result.is_scalar() ? result.scalar() : MakeNullScalar(out->type());
+ return Status::OK();
+ }
+ ArrayData* output = out->mutable_array();
+ if (!result.is_value()) {
+ // All conditions false, no 'else' argument
+ ARROW_ASSIGN_OR_RAISE(
+ auto array, MakeArrayOfNull(output->type, batch.length, ctx->memory_pool()));
+ *output = *array->data();
+ } else if (result.is_scalar()) {
+ ARROW_ASSIGN_OR_RAISE(auto array, MakeArrayFromScalar(*result.scalar(), batch.length,
+ ctx->memory_pool()));
+ *output = *array->data();
+ } else {
+ // Copy offsets/data
+ const ArrayData& source = *result.array();
+ RETURN_NOT_OK(copy_array(ctx, source, output));
+ }
+ return Status::OK();
+}
+
template
struct CaseWhenFunctor> {
using offset_type = typename Type::offset_type;
@@ -1428,70 +1469,37 @@ struct CaseWhenFunctor> {
}
static Status ExecScalar(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
- const auto& conds = checked_cast(*batch.values[0].scalar());
- Datum result;
- for (size_t i = 0; i < batch.values.size() - 1; i++) {
- if (i < conds.value.size()) {
- const Scalar& cond = *conds.value[i];
- if (cond.is_valid && internal::UnboxScalar::Unbox(cond)) {
- result = batch[i + 1];
- break;
- }
- } else {
- // ELSE clause
- result = batch[i + 1];
- break;
- }
- }
- if (out->is_scalar()) {
- *out = result.is_scalar() ? result.scalar() : MakeNullScalar(out->type());
- return Status::OK();
- }
- ArrayData* output = out->mutable_array();
- if (!result.is_value()) {
- // All conditions false, no 'else' argument
- ARROW_ASSIGN_OR_RAISE(
- auto array, MakeArrayOfNull(output->type, batch.length, ctx->memory_pool()));
- *output = *array->data();
- } else if (result.is_scalar()) {
- ARROW_ASSIGN_OR_RAISE(
- auto array,
- MakeArrayFromScalar(*result.scalar(), batch.length, ctx->memory_pool()));
- *output = *array->data();
- } else {
- // Copy offsets/data
- const ArrayData& source = *result.array();
- output->length = source.length;
- output->SetNullCount(source.null_count);
- if (source.MayHaveNulls()) {
- ARROW_ASSIGN_OR_RAISE(
- output->buffers[0],
- arrow::internal::CopyBitmap(ctx->memory_pool(), source.buffers[0]->data(),
- source.offset, source.length));
- }
- ARROW_ASSIGN_OR_RAISE(output->buffers[1],
- ctx->Allocate(sizeof(offset_type) * (source.length + 1)));
- const offset_type* in_offsets = source.GetValues(1);
- offset_type* out_offsets = output->GetMutableValues(1);
- std::transform(in_offsets, in_offsets + source.length + 1, out_offsets,
- [&](offset_type offset) { return offset - in_offsets[0]; });
- auto data_length = out_offsets[output->length] - out_offsets[0];
- ARROW_ASSIGN_OR_RAISE(output->buffers[2], ctx->Allocate(data_length));
- std::memcpy(output->buffers[2]->mutable_data(),
- source.buffers[2]->data() + in_offsets[0], data_length);
- }
- return Status::OK();
+ return ExecVarWidthScalarCaseWhen(
+ ctx, batch, out,
+ [](KernelContext* ctx, const ArrayData& source, ArrayData* output) {
+ output->length = source.length;
+ output->SetNullCount(source.null_count);
+ if (source.MayHaveNulls()) {
+ ARROW_ASSIGN_OR_RAISE(
+ output->buffers[0],
+ arrow::internal::CopyBitmap(ctx->memory_pool(), source.buffers[0]->data(),
+ source.offset, source.length));
+ }
+ ARROW_ASSIGN_OR_RAISE(output->buffers[1],
+ ctx->Allocate(sizeof(offset_type) * (source.length + 1)));
+ const offset_type* in_offsets = source.GetValues(1);
+ offset_type* out_offsets = output->GetMutableValues(1);
+ std::transform(in_offsets, in_offsets + source.length + 1, out_offsets,
+ [&](offset_type offset) { return offset - in_offsets[0]; });
+ auto data_length = out_offsets[output->length] - out_offsets[0];
+ ARROW_ASSIGN_OR_RAISE(output->buffers[2], ctx->Allocate(data_length));
+ std::memcpy(output->buffers[2]->mutable_data(),
+ source.buffers[2]->data() + in_offsets[0], data_length);
+ return Status::OK();
+ });
}
static Status ExecArray(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
const auto& conds_array = *batch.values[0].array();
- if (conds_array.GetNullCount() > 0) {
- return Status::Invalid("cond struct must not have top-level nulls");
- }
ArrayData* output = out->mutable_array();
const bool have_else_arg =
static_cast(conds_array.type->num_fields()) < (batch.values.size() - 1);
- BuilderType builder(batch[0].type(), ctx->memory_pool());
+ BuilderType builder(out->type(), ctx->memory_pool());
RETURN_NOT_OK(builder.Reserve(batch.length));
int64_t reservation = 0;
for (size_t arg = 1; arg < batch.values.size(); arg++) {
@@ -1555,6 +1563,88 @@ struct CaseWhenFunctor> {
}
};
+template
+struct CaseWhenFunctor> {
+ using offset_type = typename Type::offset_type;
+ using BuilderType = typename TypeTraits::BuilderType;
+ static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ if (batch[0].null_count() > 0) {
+ return Status::Invalid("cond struct must not have outer nulls");
+ }
+ if (batch[0].is_scalar()) {
+ return ExecScalar(ctx, batch, out);
+ }
+ return ExecArray(ctx, batch, out);
+ }
+
+ static Status ExecScalar(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ return ExecVarWidthScalarCaseWhen(
+ ctx, batch, out,
+ [](KernelContext* ctx, const ArrayData& source, ArrayData* output) {
+ *output = source;
+ return Status::OK();
+ });
+ }
+ static Status ExecArray(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ const auto& conds_array = *batch.values[0].array();
+ ArrayData* output = out->mutable_array();
+ const bool have_else_arg =
+ static_cast(conds_array.type->num_fields()) < (batch.values.size() - 1);
+ std::unique_ptr raw_builder;
+ RETURN_NOT_OK(MakeBuilder(ctx->memory_pool(), out->type(), &raw_builder));
+ BuilderType* builder = checked_cast(raw_builder.get());
+ RETURN_NOT_OK(builder->Reserve(batch.length));
+
+ for (int64_t row = 0; row < batch.length; row++) {
+ int64_t selected = have_else_arg ? batch.values.size() - 1 : -1;
+ for (int64_t arg = 0; static_cast(arg) < conds_array.child_data.size();
+ arg++) {
+ const ArrayData& cond_array = *conds_array.child_data[arg];
+ if ((!cond_array.buffers[0] ||
+ BitUtil::GetBit(cond_array.buffers[0]->data(),
+ conds_array.offset + cond_array.offset + row)) &&
+ BitUtil::GetBit(cond_array.buffers[1]->data(),
+ conds_array.offset + cond_array.offset + row)) {
+ selected = arg + 1;
+ break;
+ }
+ }
+ if (selected < 0) {
+ RETURN_NOT_OK(builder->AppendNull());
+ continue;
+ }
+ const Datum& source = batch.values[selected];
+ // This is horrendously slow, but generic
+ if (source.is_scalar()) {
+ const auto& scalar = *source.scalar();
+ if (!scalar.is_valid) {
+ RETURN_NOT_OK(builder->AppendNull());
+ } else {
+ RETURN_NOT_OK(builder->AppendScalar(scalar));
+ }
+ } else {
+ const auto& array = *source.array();
+ if (!array.buffers[0] ||
+ BitUtil::GetBit(array.buffers[0]->data(), array.offset + row)) {
+ const auto boxed_array = source.make_array();
+ if (boxed_array->IsValid(row)) {
+ ARROW_ASSIGN_OR_RAISE(auto element, boxed_array->GetScalar(row));
+ RETURN_NOT_OK(builder->AppendScalar(*element));
+ } else {
+ RETURN_NOT_OK(builder->AppendNull());
+ }
+ } else {
+ RETURN_NOT_OK(builder->AppendNull());
+ }
+ }
+ }
+
+ ARROW_ASSIGN_OR_RAISE(auto temp_output, builder->Finish());
+ *output = *temp_output->data();
+ return Status::OK();
+ }
+};
+
struct CoalesceFunction : ScalarFunction {
using ScalarFunction::ScalarFunction;
@@ -2114,6 +2204,8 @@ void RegisterScalarIfElse(FunctionRegistry* registry) {
AddCaseWhenKernel(func, Type::DECIMAL128, CaseWhenFunctor::Exec);
AddCaseWhenKernel(func, Type::DECIMAL256, CaseWhenFunctor::Exec);
AddBinaryCaseWhenKernels(func, BaseBinaryTypes());
+ AddCaseWhenKernel(func, Type::LIST, CaseWhenFunctor::Exec);
+ AddCaseWhenKernel(func, Type::LARGE_LIST, CaseWhenFunctor::Exec);
DCHECK_OK(registry->AddFunction(std::move(func)));
}
{
diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc
index 90f73107338..39ff7a90d08 100644
--- a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc
@@ -853,7 +853,7 @@ TEST(TestCaseWhen, FixedSizeBinary) {
template
class TestCaseWhenBinary : public ::testing::Test {};
-TYPED_TEST_SUITE(TestCaseWhenBinary, BinaryTypes);
+TYPED_TEST_SUITE(TestCaseWhenBinary, BinaryArrowTypes);
TYPED_TEST(TestCaseWhenBinary, Basics) {
auto type = default_type_instance();
@@ -912,6 +912,71 @@ TYPED_TEST(TestCaseWhenBinary, Basics) {
ArrayFromJSON(type, R"([null, null, null, "efg"])"));
}
+template
+class TestCaseWhenList : public ::testing::Test {};
+
+TYPED_TEST_SUITE(TestCaseWhenList, ListArrowTypes);
+
+TYPED_TEST(TestCaseWhenList, ListOfString) {
+ auto type = std::make_shared(utf8());
+ auto cond_true = ScalarFromJSON(boolean(), "true");
+ auto cond_false = ScalarFromJSON(boolean(), "false");
+ auto cond_null = ScalarFromJSON(boolean(), "null");
+ auto cond1 = ArrayFromJSON(boolean(), "[true, true, null, null]");
+ auto cond2 = ArrayFromJSON(boolean(), "[true, false, true, null]");
+ auto scalar_null = ScalarFromJSON(type, "null");
+ auto scalar1 = ScalarFromJSON(type, R"(["aB", "xYz"])");
+ auto scalar2 = ScalarFromJSON(type, R"(["b", null])");
+ auto values_null = ArrayFromJSON(type, "[null, null, null, null]");
+ auto values1 =
+ ArrayFromJSON(type, R"([["cD", "E"], null, ["de", "gf", "hi"], ["ef", "g"]])");
+ auto values2 = ArrayFromJSON(type, R"([["f", "ghi", "jk"], ["ghi"], null, ["hi"]])");
+
+ CheckScalar("case_when", {MakeStruct({}), values1}, values1);
+ CheckScalar("case_when", {MakeStruct({}), values_null}, values_null);
+
+ CheckScalar("case_when", {MakeStruct({cond_true}), scalar1, values1},
+ *MakeArrayFromScalar(*scalar1, 4));
+ CheckScalar("case_when", {MakeStruct({cond_false}), scalar1, values1}, values1);
+
+ CheckScalar("case_when", {MakeStruct({cond_true}), values1}, values1);
+ CheckScalar("case_when", {MakeStruct({cond_false}), values1}, values_null);
+ CheckScalar("case_when", {MakeStruct({cond_null}), values1}, values_null);
+ CheckScalar("case_when", {MakeStruct({cond_true}), values1, values2}, values1);
+ CheckScalar("case_when", {MakeStruct({cond_false}), values1, values2}, values2);
+ CheckScalar("case_when", {MakeStruct({cond_null}), values1, values2}, values2);
+
+ CheckScalar("case_when", {MakeStruct({cond_true, cond_true}), values1, values2},
+ values1);
+ CheckScalar("case_when", {MakeStruct({cond_false, cond_false}), values1, values2},
+ values_null);
+ CheckScalar("case_when", {MakeStruct({cond_true, cond_false}), values1, values2},
+ values1);
+ CheckScalar("case_when", {MakeStruct({cond_false, cond_true}), values1, values2},
+ values2);
+ CheckScalar("case_when", {MakeStruct({cond_null, cond_true}), values1, values2},
+ values2);
+ CheckScalar("case_when",
+ {MakeStruct({cond_false, cond_false}), values1, values2, values2}, values2);
+
+ CheckScalar(
+ "case_when", {MakeStruct({cond1, cond2}), scalar1, scalar2},
+ ArrayFromJSON(type, R"([["aB", "xYz"], ["aB", "xYz"], ["b", null], null])"));
+ CheckScalar("case_when", {MakeStruct({cond1}), scalar_null}, values_null);
+ CheckScalar("case_when", {MakeStruct({cond1}), scalar_null, scalar1},
+ ArrayFromJSON(type, R"([null, null, ["aB", "xYz"], ["aB", "xYz"]])"));
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), scalar1, scalar2, scalar1},
+ ArrayFromJSON(
+ type, R"([["aB", "xYz"], ["aB", "xYz"], ["b", null], ["aB", "xYz"]])"));
+
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2},
+ ArrayFromJSON(type, R"([["cD", "E"], null, null, null])"));
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2, values1},
+ ArrayFromJSON(type, R"([["cD", "E"], null, null, ["ef", "g"]])"));
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), values_null, values2, values1},
+ ArrayFromJSON(type, R"([null, null, null, ["ef", "g"]])"));
+}
+
TEST(TestCaseWhen, DispatchBest) {
CheckDispatchBest("case_when", {struct_({field("", boolean())}), int64(), int32()},
{struct_({field("", boolean())}), int64(), int64()});
diff --git a/cpp/src/arrow/testing/gtest_util.h b/cpp/src/arrow/testing/gtest_util.h
index 3f9408ecdcb..a58064e4261 100644
--- a/cpp/src/arrow/testing/gtest_util.h
+++ b/cpp/src/arrow/testing/gtest_util.h
@@ -170,6 +170,8 @@ using BinaryArrowTypes =
using StringArrowTypes = ::testing::Types;
+using ListArrowTypes = ::testing::Types;
+
using UnionArrowTypes = ::testing::Types;
class Array;
From 737347a707999af31b2e4f76c5037fbc680e067b Mon Sep 17 00:00:00 2001
From: David Li
Date: Thu, 22 Jul 2021 12:47:46 -0400
Subject: [PATCH 03/31] ARROW-13222: [C++] Add more benchmarks for case_when
kernel
---
.../arrow/compute/kernels/scalar_if_else.cc | 235 +++++++++---------
.../kernels/scalar_if_else_benchmark.cc | 103 ++++++--
2 files changed, 196 insertions(+), 142 deletions(-)
diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else.cc b/cpp/src/arrow/compute/kernels/scalar_if_else.cc
index d946c5ed78b..70ad1676de4 100644
--- a/cpp/src/arrow/compute/kernels/scalar_if_else.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_if_else.cc
@@ -1454,6 +1454,62 @@ static Status ExecVarWidthScalarCaseWhen(KernelContext* ctx, const ExecBatch& ba
return Status::OK();
}
+template
+static Status ExecVarWidthArrayCaseWhen(KernelContext* ctx, const ExecBatch& batch,
+ Datum* out, ReserveData reserve_data,
+ AppendScalar append_scalar,
+ AppendArray append_array) {
+ const auto& conds_array = *batch.values[0].array();
+ ArrayData* output = out->mutable_array();
+ const bool have_else_arg =
+ static_cast(conds_array.type->num_fields()) < (batch.values.size() - 1);
+ std::unique_ptr raw_builder;
+ RETURN_NOT_OK(MakeBuilder(ctx->memory_pool(), out->type(), &raw_builder));
+ RETURN_NOT_OK(raw_builder->Reserve(batch.length));
+ reserve_data(raw_builder.get());
+
+ for (int64_t row = 0; row < batch.length; row++) {
+ int64_t selected = have_else_arg ? batch.values.size() - 1 : -1;
+ for (int64_t arg = 0; static_cast(arg) < conds_array.child_data.size();
+ arg++) {
+ const ArrayData& cond_array = *conds_array.child_data[arg];
+ if ((!cond_array.buffers[0] ||
+ BitUtil::GetBit(cond_array.buffers[0]->data(),
+ conds_array.offset + cond_array.offset + row)) &&
+ BitUtil::GetBit(cond_array.buffers[1]->data(),
+ conds_array.offset + cond_array.offset + row)) {
+ selected = arg + 1;
+ break;
+ }
+ }
+ if (selected < 0) {
+ RETURN_NOT_OK(raw_builder->AppendNull());
+ continue;
+ }
+ const Datum& source = batch.values[selected];
+ if (source.is_scalar()) {
+ const auto& scalar = *source.scalar();
+ if (!scalar.is_valid) {
+ RETURN_NOT_OK(raw_builder->AppendNull());
+ } else {
+ RETURN_NOT_OK(append_scalar(raw_builder.get(), scalar));
+ }
+ } else {
+ const auto& array = source.array();
+ if (!array->buffers[0] ||
+ BitUtil::GetBit(array->buffers[0]->data(), array->offset + row)) {
+ RETURN_NOT_OK(append_array(raw_builder.get(), array, row));
+ } else {
+ RETURN_NOT_OK(raw_builder->AppendNull());
+ }
+ }
+ }
+
+ ARROW_ASSIGN_OR_RAISE(auto temp_output, raw_builder->Finish());
+ *output = *temp_output->data();
+ return Status::OK();
+}
+
template
struct CaseWhenFunctor> {
using offset_type = typename Type::offset_type;
@@ -1495,71 +1551,43 @@ struct CaseWhenFunctor> {
}
static Status ExecArray(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
- const auto& conds_array = *batch.values[0].array();
- ArrayData* output = out->mutable_array();
- const bool have_else_arg =
- static_cast(conds_array.type->num_fields()) < (batch.values.size() - 1);
- BuilderType builder(out->type(), ctx->memory_pool());
- RETURN_NOT_OK(builder.Reserve(batch.length));
- int64_t reservation = 0;
- for (size_t arg = 1; arg < batch.values.size(); arg++) {
- auto source = batch.values[arg];
- if (source.is_scalar()) {
- const auto& scalar = checked_cast(*source.scalar());
- if (!scalar.value) continue;
- reservation = std::max(reservation, batch.length * scalar.value->size());
- } else {
- const auto& array = *source.array();
- const auto& offsets = array.GetValues(1);
- reservation = std::max(reservation, offsets[array.length] - offsets[0]);
- }
- }
- RETURN_NOT_OK(builder.ReserveData(reservation));
-
- for (int64_t row = 0; row < batch.length; row++) {
- int64_t selected = have_else_arg ? batch.values.size() - 1 : -1;
- for (int64_t arg = 0; static_cast(arg) < conds_array.child_data.size();
- arg++) {
- const ArrayData& cond_array = *conds_array.child_data[arg];
- if ((!cond_array.buffers[0] ||
- BitUtil::GetBit(cond_array.buffers[0]->data(),
- conds_array.offset + cond_array.offset + row)) &&
- BitUtil::GetBit(cond_array.buffers[1]->data(),
- conds_array.offset + cond_array.offset + row)) {
- selected = arg + 1;
- break;
- }
- }
- if (selected < 0) {
- RETURN_NOT_OK(builder.AppendNull());
- continue;
- }
- const Datum& source = batch.values[selected];
- if (source.is_scalar()) {
- const auto& scalar = checked_cast(*source.scalar());
- if (!scalar.is_valid) {
- RETURN_NOT_OK(builder.AppendNull());
- } else {
- RETURN_NOT_OK(builder.Append(scalar.value->data(), scalar.value->size()));
- }
- } else {
- const auto& array = *source.array();
- if (!array.buffers[0] ||
- BitUtil::GetBit(array.buffers[0]->data(), array.offset + row)) {
- const offset_type* offsets = array.GetValues(1);
- RETURN_NOT_OK(builder.Append(array.buffers[2]->data() + offsets[row],
- offsets[row + 1] - offsets[row]));
- } else {
- RETURN_NOT_OK(builder.AppendNull());
- }
- }
- }
-
- ARROW_ASSIGN_OR_RAISE(auto temp_output, builder.Finish());
- *output = *temp_output->data();
- // Builder type != logical type due to GenerateTypeAgnosticVarBinaryBase
- output->type = batch[1].type();
- return Status::OK();
+ return ExecVarWidthArrayCaseWhen(
+ ctx, batch, out,
+ // ReserveData
+ [&](ArrayBuilder* raw_builder) {
+ int64_t reservation = 0;
+ for (size_t arg = 1; arg < batch.values.size(); arg++) {
+ auto source = batch.values[arg];
+ if (source.is_scalar()) {
+ const auto& scalar =
+ checked_cast(*source.scalar());
+ if (!scalar.value) continue;
+ reservation =
+ std::max(reservation, batch.length * scalar.value->size());
+ } else {
+ const auto& array = *source.array();
+ const auto& offsets = array.GetValues(1);
+ reservation =
+ std::max(reservation, offsets[array.length] - offsets[0]);
+ }
+ }
+ // checked_cast works since (Large)StringBuilder <: (Large)BinaryBuilder
+ return checked_cast(raw_builder)->ReserveData(reservation);
+ },
+ // AppendScalar
+ [](ArrayBuilder* raw_builder, const Scalar& raw_scalar) {
+ const auto& scalar = checked_cast(raw_scalar);
+ return checked_cast(raw_builder)
+ ->Append(scalar.value->data(), scalar.value->size());
+ },
+ // AppendArray
+ [](ArrayBuilder* raw_builder, const std::shared_ptr& array,
+ const int64_t row) {
+ const offset_type* offsets = array->GetValues(1);
+ return checked_cast(raw_builder)
+ ->Append(array->buffers[2]->data() + offsets[row],
+ offsets[row + 1] - offsets[row]);
+ });
}
};
@@ -1586,65 +1614,30 @@ struct CaseWhenFunctor> {
});
}
static Status ExecArray(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
- const auto& conds_array = *batch.values[0].array();
- ArrayData* output = out->mutable_array();
- const bool have_else_arg =
- static_cast(conds_array.type->num_fields()) < (batch.values.size() - 1);
- std::unique_ptr raw_builder;
- RETURN_NOT_OK(MakeBuilder(ctx->memory_pool(), out->type(), &raw_builder));
- BuilderType* builder = checked_cast(raw_builder.get());
- RETURN_NOT_OK(builder->Reserve(batch.length));
-
- for (int64_t row = 0; row < batch.length; row++) {
- int64_t selected = have_else_arg ? batch.values.size() - 1 : -1;
- for (int64_t arg = 0; static_cast(arg) < conds_array.child_data.size();
- arg++) {
- const ArrayData& cond_array = *conds_array.child_data[arg];
- if ((!cond_array.buffers[0] ||
- BitUtil::GetBit(cond_array.buffers[0]->data(),
- conds_array.offset + cond_array.offset + row)) &&
- BitUtil::GetBit(cond_array.buffers[1]->data(),
- conds_array.offset + cond_array.offset + row)) {
- selected = arg + 1;
- break;
- }
- }
- if (selected < 0) {
- RETURN_NOT_OK(builder->AppendNull());
- continue;
- }
- const Datum& source = batch.values[selected];
- // This is horrendously slow, but generic
- if (source.is_scalar()) {
- const auto& scalar = *source.scalar();
- if (!scalar.is_valid) {
- RETURN_NOT_OK(builder->AppendNull());
- } else {
- RETURN_NOT_OK(builder->AppendScalar(scalar));
- }
- } else {
- const auto& array = *source.array();
- if (!array.buffers[0] ||
- BitUtil::GetBit(array.buffers[0]->data(), array.offset + row)) {
- const auto boxed_array = source.make_array();
- if (boxed_array->IsValid(row)) {
- ARROW_ASSIGN_OR_RAISE(auto element, boxed_array->GetScalar(row));
- RETURN_NOT_OK(builder->AppendScalar(*element));
- } else {
- RETURN_NOT_OK(builder->AppendNull());
- }
- } else {
- RETURN_NOT_OK(builder->AppendNull());
- }
- }
- }
-
- ARROW_ASSIGN_OR_RAISE(auto temp_output, builder->Finish());
- *output = *temp_output->data();
- return Status::OK();
+ // TODO: horrendously slow, but generic
+ return ExecVarWidthArrayCaseWhen(
+ ctx, batch, out,
+ // ReserveData
+ [](ArrayBuilder*) {},
+ // AppendScalar
+ [](ArrayBuilder* raw_builder, const Scalar& scalar) {
+ return raw_builder->AppendScalar(scalar);
+ },
+ // AppendArray
+ [](ArrayBuilder* raw_builder, const std::shared_ptr& array,
+ const int64_t row) {
+ ARROW_ASSIGN_OR_RAISE(auto scalar, MakeArray(array)->GetScalar(row));
+ return raw_builder->AppendScalar(*scalar);
+ });
}
};
+// TODO: map, fixed size list, struct, union, dictionary
+
+// TODO: file a separate issue for dictionary support? We need a utility to
+// unify the dictionaries and return the index mapping (similar to an R
+// factor), and we may need to promote the index type.
+
struct CoalesceFunction : ScalarFunction {
using ScalarFunction::ScalarFunction;
diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else_benchmark.cc b/cpp/src/arrow/compute/kernels/scalar_if_else_benchmark.cc
index 9b59d54c3da..dbe6f957247 100644
--- a/cpp/src/arrow/compute/kernels/scalar_if_else_benchmark.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_if_else_benchmark.cc
@@ -27,32 +27,31 @@ namespace arrow {
namespace compute {
const int64_t kNumItems = 1024 * 1024;
+const int64_t kFewItems = 64 * 1024;
template
-struct SetBytesProcessed {};
+struct GetBytesProcessed {};
+
+template <>
+struct GetBytesProcessed {
+ static int64_t Get(const std::shared_ptr& arr) { return arr->length() / 8; }
+};
template
-struct SetBytesProcessed> {
- static void Set(const std::shared_ptr& cond, const std::shared_ptr& left,
- const std::shared_ptr& right, benchmark::State* state) {
+struct GetBytesProcessed> {
+ static int64_t Get(const std::shared_ptr& arr) {
using CType = typename Type::c_type;
- state->SetBytesProcessed(state->iterations() *
- (cond->length() / 8 + 2 * cond->length() * sizeof(CType)));
+ return arr->length() * sizeof(CType);
}
};
template
-struct SetBytesProcessed> {
- static void Set(const std::shared_ptr& cond, const std::shared_ptr& left,
- const std::shared_ptr& right, benchmark::State* state) {
+struct GetBytesProcessed> {
+ static int64_t Get(const std::shared_ptr& arr) {
using ArrayType = typename TypeTraits::ArrayType;
using OffsetType = typename TypeTraits::OffsetType::c_type;
-
- state->SetBytesProcessed(
- state->iterations() *
- (cond->length() / 8 + 2 * cond->length() * sizeof(OffsetType) +
- std::static_pointer_cast(left)->total_values_length() +
- std::static_pointer_cast(right)->total_values_length()));
+ return arr->length() * sizeof(OffsetType) +
+ std::static_pointer_cast(arr)->total_values_length();
}
};
@@ -80,7 +79,10 @@ static void IfElseBench(benchmark::State& state) {
ABORT_NOT_OK(IfElse(cond, left, right));
}
- SetBytesProcessed::Set(cond, left, right, &state);
+ state.SetBytesProcessed(state.iterations() *
+ (GetBytesProcessed::Get(cond) +
+ GetBytesProcessed::Get(left) +
+ GetBytesProcessed::Get(right)));
}
template
@@ -109,7 +111,10 @@ static void IfElseBenchContiguous(benchmark::State& state) {
ABORT_NOT_OK(IfElse(cond, left, right));
}
- SetBytesProcessed::Set(cond, left, right, &state);
+ state.SetBytesProcessed(state.iterations() *
+ (GetBytesProcessed::Get(cond) +
+ GetBytesProcessed::Get(left) +
+ GetBytesProcessed::Get(right)));
}
static void IfElseBench64(benchmark::State& state) {
@@ -146,7 +151,6 @@ static void IfElseBenchString32Contiguous(benchmark::State& state) {
template
static void CaseWhenBench(benchmark::State& state) {
- using CType = typename Type::c_type;
auto type = TypeTraits::type_singleton();
using ArrayType = typename TypeTraits::ArrayType;
@@ -180,12 +184,50 @@ static void CaseWhenBench(benchmark::State& state) {
val3->Slice(offset), val4->Slice(offset)}));
}
- state.SetBytesProcessed(state.iterations() * (len - offset) * sizeof(CType));
+ // Set bytes processed to ~length of output
+ state.SetBytesProcessed(state.iterations() * GetBytesProcessed::Get(val1));
+ state.SetItemsProcessed(state.iterations() * (len - offset));
+}
+
+static void CaseWhenBenchList(benchmark::State& state) {
+ auto type = list(int64());
+ auto fld = field("", type);
+
+ int64_t len = state.range(0);
+ int64_t offset = state.range(1);
+
+ random::RandomArrayGenerator rand(/*seed=*/0);
+
+ auto cond1 = std::static_pointer_cast(
+ rand.ArrayOf(boolean(), len, /*null_probability=*/0.01));
+ auto cond2 = std::static_pointer_cast(
+ rand.ArrayOf(boolean(), len, /*null_probability=*/0.01));
+ auto cond3 = std::static_pointer_cast(
+ rand.ArrayOf(boolean(), len, /*null_probability=*/0.01));
+ auto cond_field =
+ field("cond", boolean(), key_value_metadata({{"null_probability", "0.01"}}));
+ auto cond = rand.ArrayOf(*field("", struct_({cond_field, cond_field, cond_field}),
+ key_value_metadata({{"null_probability", "0.0"}})),
+ len);
+ auto val1 = rand.ArrayOf(*fld, len);
+ auto val2 = rand.ArrayOf(*fld, len);
+ auto val3 = rand.ArrayOf(*fld, len);
+ auto val4 = rand.ArrayOf(*fld, len);
+ for (auto _ : state) {
+ ABORT_NOT_OK(
+ CaseWhen(cond->Slice(offset), {val1->Slice(offset), val2->Slice(offset),
+ val3->Slice(offset), val4->Slice(offset)}));
+ }
+
+ // Set bytes processed to ~length of output
+ state.SetBytesProcessed(state.iterations() *
+ GetBytesProcessed::Get(
+ std::static_pointer_cast(val1)->values()));
+ state.SetItemsProcessed(state.iterations() * (len - offset));
}
template
static void CaseWhenBenchContiguous(benchmark::State& state) {
- using CType = typename Type::c_type;
auto type = TypeTraits::type_singleton();
using ArrayType = typename TypeTraits::ArrayType;
@@ -216,7 +258,9 @@ static void CaseWhenBenchContiguous(benchmark::State& state) {
val3->Slice(offset)}));
}
- state.SetBytesProcessed(state.iterations() * (len - offset) * sizeof(CType));
+ // Set bytes processed to ~length of output
+ state.SetBytesProcessed(state.iterations() * GetBytesProcessed::Get(val1));
+ state.SetItemsProcessed(state.iterations() * (len - offset));
}
static void CaseWhenBench64(benchmark::State& state) {
@@ -227,6 +271,14 @@ static void CaseWhenBench64Contiguous(benchmark::State& state) {
return CaseWhenBenchContiguous(state);
}
+static void CaseWhenBenchString(benchmark::State& state) {
+ return CaseWhenBench(state);
+}
+
+static void CaseWhenBenchStringContiguous(benchmark::State& state) {
+ return CaseWhenBenchContiguous(state);
+}
+
template
static void CoalesceBench(benchmark::State& state) {
using CType = typename Type::c_type;
@@ -337,6 +389,15 @@ BENCHMARK(CaseWhenBench64)->Args({kNumItems, 99});
BENCHMARK(CaseWhenBench64Contiguous)->Args({kNumItems, 0});
BENCHMARK(CaseWhenBench64Contiguous)->Args({kNumItems, 99});
+BENCHMARK(CaseWhenBenchList)->Args({kFewItems, 0});
+BENCHMARK(CaseWhenBenchList)->Args({kFewItems, 99});
+
+BENCHMARK(CaseWhenBenchString)->Args({kFewItems, 0});
+BENCHMARK(CaseWhenBenchString)->Args({kFewItems, 99});
+
+BENCHMARK(CaseWhenBenchStringContiguous)->Args({kFewItems, 0});
+BENCHMARK(CaseWhenBenchStringContiguous)->Args({kFewItems, 99});
+
BENCHMARK(CoalesceBench64)->Args({kNumItems, 0});
BENCHMARK(CoalesceBench64)->Args({kNumItems, 99});
From 6ffb2727955fa4f6f264628b4ba2b172ea198bbd Mon Sep 17 00:00:00 2001
From: David Li
Date: Thu, 22 Jul 2021 15:10:05 -0400
Subject: [PATCH 04/31] ARROW-13222: [C++] Add basic implementations for map,
fixed size list
---
.../arrow/compute/kernels/scalar_if_else.cc | 83 +++++++-
.../compute/kernels/scalar_if_else_test.cc | 184 ++++++++++++++++++
2 files changed, 265 insertions(+), 2 deletions(-)
diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else.cc b/cpp/src/arrow/compute/kernels/scalar_if_else.cc
index 70ad1676de4..d7522988136 100644
--- a/cpp/src/arrow/compute/kernels/scalar_if_else.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_if_else.cc
@@ -1593,8 +1593,84 @@ struct CaseWhenFunctor> {
template
struct CaseWhenFunctor> {
- using offset_type = typename Type::offset_type;
- using BuilderType = typename TypeTraits::BuilderType;
+ static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ if (batch[0].null_count() > 0) {
+ return Status::Invalid("cond struct must not have outer nulls");
+ }
+ if (batch[0].is_scalar()) {
+ return ExecScalar(ctx, batch, out);
+ }
+ return ExecArray(ctx, batch, out);
+ }
+
+ static Status ExecScalar(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ return ExecVarWidthScalarCaseWhen(
+ ctx, batch, out,
+ [](KernelContext* ctx, const ArrayData& source, ArrayData* output) {
+ *output = source;
+ return Status::OK();
+ });
+ }
+ static Status ExecArray(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ // TODO: horrendously slow, but generic
+ return ExecVarWidthArrayCaseWhen(
+ ctx, batch, out,
+ // ReserveData
+ [](ArrayBuilder*) {},
+ // AppendScalar
+ [](ArrayBuilder* raw_builder, const Scalar& scalar) {
+ return raw_builder->AppendScalar(scalar);
+ },
+ // AppendArray
+ [](ArrayBuilder* raw_builder, const std::shared_ptr& array,
+ const int64_t row) {
+ ARROW_ASSIGN_OR_RAISE(auto scalar, MakeArray(array)->GetScalar(row));
+ return raw_builder->AppendScalar(*scalar);
+ });
+ }
+};
+
+template <>
+struct CaseWhenFunctor {
+ static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ if (batch[0].null_count() > 0) {
+ return Status::Invalid("cond struct must not have outer nulls");
+ }
+ if (batch[0].is_scalar()) {
+ return ExecScalar(ctx, batch, out);
+ }
+ return ExecArray(ctx, batch, out);
+ }
+
+ static Status ExecScalar(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ return ExecVarWidthScalarCaseWhen(
+ ctx, batch, out,
+ [](KernelContext* ctx, const ArrayData& source, ArrayData* output) {
+ *output = source;
+ return Status::OK();
+ });
+ }
+ static Status ExecArray(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ // TODO: horrendously slow, but generic
+ return ExecVarWidthArrayCaseWhen(
+ ctx, batch, out,
+ // ReserveData
+ [](ArrayBuilder*) {},
+ // AppendScalar
+ [](ArrayBuilder* raw_builder, const Scalar& scalar) {
+ return raw_builder->AppendScalar(scalar);
+ },
+ // AppendArray
+ [](ArrayBuilder* raw_builder, const std::shared_ptr& array,
+ const int64_t row) {
+ ARROW_ASSIGN_OR_RAISE(auto scalar, MakeArray(array)->GetScalar(row));
+ return raw_builder->AppendScalar(*scalar);
+ });
+ }
+};
+
+template <>
+struct CaseWhenFunctor {
static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
if (batch[0].null_count() > 0) {
return Status::Invalid("cond struct must not have outer nulls");
@@ -2199,6 +2275,9 @@ void RegisterScalarIfElse(FunctionRegistry* registry) {
AddBinaryCaseWhenKernels(func, BaseBinaryTypes());
AddCaseWhenKernel(func, Type::LIST, CaseWhenFunctor::Exec);
AddCaseWhenKernel(func, Type::LARGE_LIST, CaseWhenFunctor::Exec);
+ AddCaseWhenKernel(func, Type::FIXED_SIZE_LIST,
+ CaseWhenFunctor::Exec);
+ AddCaseWhenKernel(func, Type::MAP, CaseWhenFunctor::Exec);
DCHECK_OK(registry->AddFunction(std::move(func)));
}
{
diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc
index 39ff7a90d08..b1056ca94fe 100644
--- a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc
@@ -977,6 +977,190 @@ TYPED_TEST(TestCaseWhenList, ListOfString) {
ArrayFromJSON(type, R"([null, null, null, ["ef", "g"]])"));
}
+TEST(TestCaseWhen, Map) {
+ auto type = map(int64(), utf8());
+ auto cond_true = ScalarFromJSON(boolean(), "true");
+ auto cond_false = ScalarFromJSON(boolean(), "false");
+ auto cond_null = ScalarFromJSON(boolean(), "null");
+ auto cond1 = ArrayFromJSON(boolean(), "[true, true, null, null]");
+ auto cond2 = ArrayFromJSON(boolean(), "[true, false, true, null]");
+ auto scalar_null = ScalarFromJSON(type, "null");
+ auto scalar1 = ScalarFromJSON(type, R"([[1, "abc"], [2, "de"]])");
+ auto scalar2 = ScalarFromJSON(type, R"([[3, "fghi"]])");
+ auto values_null = ArrayFromJSON(type, "[null, null, null, null]");
+ auto values1 =
+ ArrayFromJSON(type, R"([[[4, "kl"]], null, [[5, "mn"]], [[6, "o"], [7, "pq"]]])");
+ auto values2 = ArrayFromJSON(type, R"([[[8, "r"], [9, "st"]], [[10, "u"]], null, []])");
+
+ CheckScalar("case_when", {MakeStruct({}), values1}, values1);
+ CheckScalar("case_when", {MakeStruct({}), values_null}, values_null);
+
+ CheckScalar("case_when", {MakeStruct({cond_true}), scalar1, values1},
+ *MakeArrayFromScalar(*scalar1, 4));
+ CheckScalar("case_when", {MakeStruct({cond_false}), scalar1, values1}, values1);
+
+ CheckScalar("case_when", {MakeStruct({cond_true}), values1}, values1);
+ CheckScalar("case_when", {MakeStruct({cond_false}), values1}, values_null);
+ CheckScalar("case_when", {MakeStruct({cond_null}), values1}, values_null);
+ CheckScalar("case_when", {MakeStruct({cond_true}), values1, values2}, values1);
+ CheckScalar("case_when", {MakeStruct({cond_false}), values1, values2}, values2);
+ CheckScalar("case_when", {MakeStruct({cond_null}), values1, values2}, values2);
+
+ CheckScalar("case_when", {MakeStruct({cond_true, cond_true}), values1, values2},
+ values1);
+ CheckScalar("case_when", {MakeStruct({cond_false, cond_false}), values1, values2},
+ values_null);
+ CheckScalar("case_when", {MakeStruct({cond_true, cond_false}), values1, values2},
+ values1);
+ CheckScalar("case_when", {MakeStruct({cond_false, cond_true}), values1, values2},
+ values2);
+ CheckScalar("case_when", {MakeStruct({cond_null, cond_true}), values1, values2},
+ values2);
+ CheckScalar("case_when",
+ {MakeStruct({cond_false, cond_false}), values1, values2, values2}, values2);
+
+ CheckScalar(
+ "case_when", {MakeStruct({cond1, cond2}), scalar1, scalar2},
+ ArrayFromJSON(
+ type,
+ R"([[[1, "abc"], [2, "de"]], [[1, "abc"], [2, "de"]], [[3, "fghi"]], null])"));
+ CheckScalar("case_when", {MakeStruct({cond1}), scalar_null}, values_null);
+ CheckScalar(
+ "case_when", {MakeStruct({cond1}), scalar_null, scalar1},
+ ArrayFromJSON(type,
+ R"([null, null, [[1, "abc"], [2, "de"]], [[1, "abc"], [2, "de"]]])"));
+ CheckScalar(
+ "case_when", {MakeStruct({cond1, cond2}), scalar1, scalar2, scalar1},
+ ArrayFromJSON(
+ type,
+ R"([[[1, "abc"], [2, "de"]], [[1, "abc"], [2, "de"]], [[3, "fghi"]], [[1, "abc"], [2, "de"]]])"));
+
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2},
+ ArrayFromJSON(type, R"([[[4, "kl"]], null, null, null])"));
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2, values1},
+ ArrayFromJSON(type, R"([[[4, "kl"]], null, null, [[6, "o"], [7, "pq"]]])"));
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), values_null, values2, values1},
+ ArrayFromJSON(type, R"([null, null, null, [[6, "o"], [7, "pq"]]])"));
+}
+
+TEST(TestCaseWhen, FixedSizeListOfInt) {
+ auto type = fixed_size_list(int64(), 2);
+ auto cond_true = ScalarFromJSON(boolean(), "true");
+ auto cond_false = ScalarFromJSON(boolean(), "false");
+ auto cond_null = ScalarFromJSON(boolean(), "null");
+ auto cond1 = ArrayFromJSON(boolean(), "[true, true, null, null]");
+ auto cond2 = ArrayFromJSON(boolean(), "[true, false, true, null]");
+ auto scalar_null = ScalarFromJSON(type, "null");
+ auto scalar1 = ScalarFromJSON(type, R"([1, 2])");
+ auto scalar2 = ScalarFromJSON(type, R"([3, null])");
+ auto values_null = ArrayFromJSON(type, "[null, null, null, null]");
+ auto values1 = ArrayFromJSON(type, R"([[4, 5], null, [6, 7], [8, 9]])");
+ auto values2 = ArrayFromJSON(type, R"([[10, 11], [12, null], null, [null, 13]])");
+
+ CheckScalar("case_when", {MakeStruct({}), values1}, values1);
+ CheckScalar("case_when", {MakeStruct({}), values_null}, values_null);
+
+ CheckScalar("case_when", {MakeStruct({cond_true}), scalar1, values1},
+ *MakeArrayFromScalar(*scalar1, 4));
+ CheckScalar("case_when", {MakeStruct({cond_false}), scalar1, values1}, values1);
+
+ CheckScalar("case_when", {MakeStruct({cond_true}), values1}, values1);
+ CheckScalar("case_when", {MakeStruct({cond_false}), values1}, values_null);
+ CheckScalar("case_when", {MakeStruct({cond_null}), values1}, values_null);
+ CheckScalar("case_when", {MakeStruct({cond_true}), values1, values2}, values1);
+ CheckScalar("case_when", {MakeStruct({cond_false}), values1, values2}, values2);
+ CheckScalar("case_when", {MakeStruct({cond_null}), values1, values2}, values2);
+
+ CheckScalar("case_when", {MakeStruct({cond_true, cond_true}), values1, values2},
+ values1);
+ CheckScalar("case_when", {MakeStruct({cond_false, cond_false}), values1, values2},
+ values_null);
+ CheckScalar("case_when", {MakeStruct({cond_true, cond_false}), values1, values2},
+ values1);
+ CheckScalar("case_when", {MakeStruct({cond_false, cond_true}), values1, values2},
+ values2);
+ CheckScalar("case_when", {MakeStruct({cond_null, cond_true}), values1, values2},
+ values2);
+ CheckScalar("case_when",
+ {MakeStruct({cond_false, cond_false}), values1, values2, values2}, values2);
+
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), scalar1, scalar2},
+ ArrayFromJSON(type, R"([[1, 2], [1, 2], [3, null], null])"));
+ CheckScalar("case_when", {MakeStruct({cond1}), scalar_null}, values_null);
+ CheckScalar("case_when", {MakeStruct({cond1}), scalar_null, scalar1},
+ ArrayFromJSON(type, R"([null, null, [1, 2], [1, 2]])"));
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), scalar1, scalar2, scalar1},
+ ArrayFromJSON(type, R"([[1, 2], [1, 2], [3, null], [1, 2]])"));
+
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2},
+ ArrayFromJSON(type, R"([[4, 5], null, null, null])"));
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2, values1},
+ ArrayFromJSON(type, R"([[4, 5], null, null, [8, 9]])"));
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), values_null, values2, values1},
+ ArrayFromJSON(type, R"([null, null, null, [8, 9]])"));
+}
+
+TEST(TestCaseWhen, FixedSizeListOfString) {
+ auto type = fixed_size_list(utf8(), 2);
+ auto cond_true = ScalarFromJSON(boolean(), "true");
+ auto cond_false = ScalarFromJSON(boolean(), "false");
+ auto cond_null = ScalarFromJSON(boolean(), "null");
+ auto cond1 = ArrayFromJSON(boolean(), "[true, true, null, null]");
+ auto cond2 = ArrayFromJSON(boolean(), "[true, false, true, null]");
+ auto scalar_null = ScalarFromJSON(type, "null");
+ auto scalar1 = ScalarFromJSON(type, R"(["aB", "xYz"])");
+ auto scalar2 = ScalarFromJSON(type, R"(["b", null])");
+ auto values_null = ArrayFromJSON(type, "[null, null, null, null]");
+ auto values1 =
+ ArrayFromJSON(type, R"([["cD", "E"], null, ["de", "gfhi"], ["ef", "g"]])");
+ auto values2 =
+ ArrayFromJSON(type, R"([["fghi", "jk"], ["ghi", null], null, [null, "hi"]])");
+
+ CheckScalar("case_when", {MakeStruct({}), values1}, values1);
+ CheckScalar("case_when", {MakeStruct({}), values_null}, values_null);
+
+ CheckScalar("case_when", {MakeStruct({cond_true}), scalar1, values1},
+ *MakeArrayFromScalar(*scalar1, 4));
+ CheckScalar("case_when", {MakeStruct({cond_false}), scalar1, values1}, values1);
+
+ CheckScalar("case_when", {MakeStruct({cond_true}), values1}, values1);
+ CheckScalar("case_when", {MakeStruct({cond_false}), values1}, values_null);
+ CheckScalar("case_when", {MakeStruct({cond_null}), values1}, values_null);
+ CheckScalar("case_when", {MakeStruct({cond_true}), values1, values2}, values1);
+ CheckScalar("case_when", {MakeStruct({cond_false}), values1, values2}, values2);
+ CheckScalar("case_when", {MakeStruct({cond_null}), values1, values2}, values2);
+
+ CheckScalar("case_when", {MakeStruct({cond_true, cond_true}), values1, values2},
+ values1);
+ CheckScalar("case_when", {MakeStruct({cond_false, cond_false}), values1, values2},
+ values_null);
+ CheckScalar("case_when", {MakeStruct({cond_true, cond_false}), values1, values2},
+ values1);
+ CheckScalar("case_when", {MakeStruct({cond_false, cond_true}), values1, values2},
+ values2);
+ CheckScalar("case_when", {MakeStruct({cond_null, cond_true}), values1, values2},
+ values2);
+ CheckScalar("case_when",
+ {MakeStruct({cond_false, cond_false}), values1, values2, values2}, values2);
+
+ CheckScalar(
+ "case_when", {MakeStruct({cond1, cond2}), scalar1, scalar2},
+ ArrayFromJSON(type, R"([["aB", "xYz"], ["aB", "xYz"], ["b", null], null])"));
+ CheckScalar("case_when", {MakeStruct({cond1}), scalar_null}, values_null);
+ CheckScalar("case_when", {MakeStruct({cond1}), scalar_null, scalar1},
+ ArrayFromJSON(type, R"([null, null, ["aB", "xYz"], ["aB", "xYz"]])"));
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), scalar1, scalar2, scalar1},
+ ArrayFromJSON(
+ type, R"([["aB", "xYz"], ["aB", "xYz"], ["b", null], ["aB", "xYz"]])"));
+
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2},
+ ArrayFromJSON(type, R"([["cD", "E"], null, null, null])"));
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2, values1},
+ ArrayFromJSON(type, R"([["cD", "E"], null, null, ["ef", "g"]])"));
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), values_null, values2, values1},
+ ArrayFromJSON(type, R"([null, null, null, ["ef", "g"]])"));
+}
+
TEST(TestCaseWhen, DispatchBest) {
CheckDispatchBest("case_when", {struct_({field("", boolean())}), int64(), int32()},
{struct_({field("", boolean())}), int64(), int64()});
From 0d42d408a6af540ea7370269ba2237503997c445 Mon Sep 17 00:00:00 2001
From: David Li
Date: Fri, 23 Jul 2021 15:22:10 -0400
Subject: [PATCH 05/31] ARROW-13222: [C++] Much faster, but less generic
case_when for nested types
---
cpp/src/arrow/array/builder_base.h | 11 ++
cpp/src/arrow/array/builder_primitive.h | 15 ++
cpp/src/arrow/buffer_builder.h | 11 ++
.../arrow/compute/kernels/scalar_if_else.cc | 166 ++++++++++++++++--
.../compute/kernels/scalar_if_else_test.cc | 35 ++++
5 files changed, 220 insertions(+), 18 deletions(-)
diff --git a/cpp/src/arrow/array/builder_base.h b/cpp/src/arrow/array/builder_base.h
index c2aba4e959f..b52a6b668ec 100644
--- a/cpp/src/arrow/array/builder_base.h
+++ b/cpp/src/arrow/array/builder_base.h
@@ -189,6 +189,17 @@ class ARROW_EXPORT ArrayBuilder {
null_count_ = null_bitmap_builder_.false_count();
}
+ // Vector append. Copy from a given bitmap. If bitmap is null assume
+ // all of length bits are valid.
+ void UnsafeAppendToBitmap(const uint8_t* bitmap, int64_t offset, int64_t length) {
+ if (bitmap == NULLPTR) {
+ return UnsafeSetNotNull(length);
+ }
+ null_bitmap_builder_.UnsafeAppend(bitmap, offset, length);
+ length_ += length;
+ null_count_ = null_bitmap_builder_.false_count();
+ }
+
// Append the same validity value a given number of times.
void UnsafeAppendToBitmap(const int64_t num_bits, bool value) {
if (value) {
diff --git a/cpp/src/arrow/array/builder_primitive.h b/cpp/src/arrow/array/builder_primitive.h
index e0f39f97967..2643006b072 100644
--- a/cpp/src/arrow/array/builder_primitive.h
+++ b/cpp/src/arrow/array/builder_primitive.h
@@ -153,6 +153,21 @@ class NumericBuilder : public ArrayBuilder {
return Status::OK();
}
+ /// \brief Append a sequence of elements in one shot
+ /// \param[in] values a contiguous C array of values
+ /// \param[in] length the number of values to append
+ /// \param[in] bitmap a validity bitmap to copy (may be null)
+ /// \param[in] bitmap_offset an offset into the validity bitmap
+ /// \return Status
+ Status AppendValues(const value_type* values, int64_t length, const uint8_t* bitmap,
+ int64_t bitmap_offset) {
+ ARROW_RETURN_NOT_OK(Reserve(length));
+ data_builder_.UnsafeAppend(values, length);
+ // length_ is updated by these
+ ArrayBuilder::UnsafeAppendToBitmap(bitmap, bitmap_offset, length);
+ return Status::OK();
+ }
+
/// \brief Append a sequence of elements in one shot
/// \param[in] values a contiguous C array of values
/// \param[in] length the number of values to append
diff --git a/cpp/src/arrow/buffer_builder.h b/cpp/src/arrow/buffer_builder.h
index eb3f68affc0..2c5afde1fdc 100644
--- a/cpp/src/arrow/buffer_builder.h
+++ b/cpp/src/arrow/buffer_builder.h
@@ -350,6 +350,17 @@ class TypedBufferBuilder {
bit_length_ += num_elements;
}
+ void UnsafeAppend(const uint8_t* bitmap, int64_t offset, int64_t num_elements) {
+ if (num_elements == 0) return;
+ int64_t i = offset;
+ internal::GenerateBitsUnrolled(mutable_data(), bit_length_, num_elements, [&] {
+ bool value = BitUtil::GetBit(bitmap, i++);
+ false_count_ += !value;
+ return value;
+ });
+ bit_length_ += num_elements;
+ }
+
void UnsafeAppend(const int64_t num_copies, bool value) {
BitUtil::SetBitsTo(mutable_data(), bit_length_, num_copies, value);
false_count_ += num_copies * !value;
diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else.cc b/cpp/src/arrow/compute/kernels/scalar_if_else.cc
index d7522988136..78187cc0cde 100644
--- a/cpp/src/arrow/compute/kernels/scalar_if_else.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_if_else.cc
@@ -16,6 +16,7 @@
// under the License.
#include
+#include
#include
#include
#include
@@ -1591,8 +1592,101 @@ struct CaseWhenFunctor> {
}
};
+using ArrayAppenderFunc = std::function&, int64_t, int64_t)>;
+
+static Status GetValueAppenders(const DataType& type, ArrayAppenderFunc* array_appender);
+
+struct GetAppenders {
+ template
+ enable_if_number Visit(const T&) {
+ using BuilderType = typename TypeTraits::BuilderType;
+ using c_type = typename T::c_type;
+ array_appender = [](ArrayBuilder* raw_builder,
+ const std::shared_ptr& array, const int64_t offset,
+ const int64_t length) {
+ return checked_cast(raw_builder)
+ ->AppendValues(array->GetValues(1) + offset, length,
+ array->GetValues(0, 0), array->offset + offset);
+ };
+ return Status::OK();
+ }
+
+ Status Visit(const StringType&) {
+ array_appender = [](ArrayBuilder* raw_builder,
+ const std::shared_ptr& array, const int64_t offset,
+ const int64_t length) {
+ auto builder = checked_cast(raw_builder);
+ auto bitmap = array->GetValues(0, 0);
+ auto offsets = array->GetValues(1);
+ auto data = array->GetValues(2, 0);
+ for (int64_t i = 0; i < length; i++) {
+ if (!bitmap || BitUtil::GetBit(bitmap, offset + i)) {
+ const int32_t start = offsets[offset + i];
+ const int32_t end = offsets[offset + i + 1];
+ RETURN_NOT_OK(builder->Append(data + start, end - start));
+ } else {
+ RETURN_NOT_OK(builder->AppendNull());
+ }
+ }
+ return Status::OK();
+ };
+ return Status::OK();
+ }
+
+ template
+ enable_if_var_size_list Visit(const T& ty) {
+ // TODO: reuse this below, or make a fully generic (but runtime-
+ // dispatched) implementation of the top-level case_when kernel
+ using BuilderType = typename TypeTraits::BuilderType;
+ using offset_type = typename T::offset_type;
+ const auto& list_ty = checked_cast(ty);
+ ArrayAppenderFunc sub_appender;
+ RETURN_NOT_OK(GetValueAppenders(*list_ty.value_type(), &sub_appender));
+ array_appender = [=](ArrayBuilder* raw_builder,
+ const std::shared_ptr& array, const int64_t offset,
+ const int64_t length) {
+ auto builder = checked_cast(raw_builder);
+ auto child_builder = builder->value_builder();
+ const offset_type* offsets = array->GetValues(1);
+ const uint8_t* validity =
+ array->MayHaveNulls() ? array->buffers[0]->data() : nullptr;
+ for (int64_t row = offset; row < offset + length; row++) {
+ if (!validity || BitUtil::GetBit(validity, array->offset + row)) {
+ RETURN_NOT_OK(builder->Append());
+ int64_t length = offsets[row + 1] - offsets[row];
+ RETURN_NOT_OK(
+ sub_appender(child_builder, array->child_data[0], offsets[row], length));
+ } else {
+ RETURN_NOT_OK(builder->AppendNull());
+ }
+ }
+ return Status::OK();
+ };
+ return Status::OK();
+ }
+
+ Status Visit(const DataType& ty) {
+ return Status::NotImplemented("Appender for type ", ty);
+ }
+
+ ArrayAppenderFunc GetArrayAppender() { return array_appender; }
+
+ ArrayAppenderFunc array_appender;
+};
+
+static Status GetValueAppenders(const DataType& type, ArrayAppenderFunc* array_appender) {
+ // TODO: should cover scalars too
+ GetAppenders get_appenders;
+ RETURN_NOT_OK(VisitTypeInline(type, &get_appenders));
+ *array_appender = std::move(get_appenders.array_appender);
+ return Status::OK();
+}
+
template <typename Type>
struct CaseWhenFunctor<Type, enable_if_var_size_list<Type>> {
+ using offset_type = typename Type::offset_type;
+ using BuilderType = typename TypeTraits::BuilderType;
static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
if (batch[0].null_count() > 0) {
return Status::Invalid("cond struct must not have outer nulls");
@@ -1612,20 +1706,45 @@ struct CaseWhenFunctor> {
});
}
static Status ExecArray(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
- // TODO: horrendously slow, but generic
+ const auto& ty = checked_cast<const Type&>(*out->type());
+ ArrayAppenderFunc array_appender;
+ RETURN_NOT_OK(GetValueAppenders(*ty.value_type(), &array_appender));
return ExecVarWidthArrayCaseWhen(
ctx, batch, out,
// ReserveData
- [](ArrayBuilder*) {},
+ [&](ArrayBuilder* raw_builder) {
+ auto builder = checked_cast<BuilderType*>(raw_builder);
+ auto child_builder = builder->value_builder();
+
+ int64_t reservation = 0;
+ for (size_t arg = 1; arg < batch.values.size(); arg++) {
+ auto source = batch.values[arg];
+ if (!source.is_array()) {
+ const auto& scalar = checked_cast<const BaseListScalar&>(*source.scalar());
+ if (!scalar.value) continue;
+ reservation =
+ std::max(reservation, batch.length * scalar.value->length());
+ } else {
+ const auto& array = *source.array();
+ reservation = std::max(reservation, array.child_data[0]->length);
+ }
+ }
+ return child_builder->Reserve(reservation);
+ },
// AppendScalar
[](ArrayBuilder* raw_builder, const Scalar& scalar) {
return raw_builder->AppendScalar(scalar);
},
// AppendArray
- [](ArrayBuilder* raw_builder, const std::shared_ptr& array,
- const int64_t row) {
- ARROW_ASSIGN_OR_RAISE(auto scalar, MakeArray(array)->GetScalar(row));
- return raw_builder->AppendScalar(*scalar);
+ [&](ArrayBuilder* raw_builder, const std::shared_ptr<ArrayData>& array,
+ const int64_t row) {
+ auto builder = checked_cast<BuilderType*>(raw_builder);
+ auto child_builder = builder->value_builder();
+ RETURN_NOT_OK(builder->Append());
+ const offset_type* offsets = array->GetValues<offset_type>(1);
+ int64_t length = offsets[row + 1] - offsets[row];
+ return array_appender(child_builder, array->child_data[0], offsets[row],
+ length);
});
}
};
@@ -1690,30 +1809,41 @@ struct CaseWhenFunctor {
});
}
static Status ExecArray(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
- // TODO: horrendously slow, but generic
+ const auto& ty = checked_cast<const FixedSizeListType&>(*out->type());
+ const int64_t width = ty.list_size();
+ ArrayAppenderFunc array_appender;
+ RETURN_NOT_OK(GetValueAppenders(*ty.value_type(), &array_appender));
return ExecVarWidthArrayCaseWhen(
ctx, batch, out,
// ReserveData
- [](ArrayBuilder*) {},
+ [&](ArrayBuilder* raw_builder) {
+ int64_t children = batch.length * width;
+ return checked_cast<FixedSizeListBuilder*>(raw_builder)
+ ->value_builder()
+ ->Reserve(children);
+ },
// AppendScalar
[](ArrayBuilder* raw_builder, const Scalar& scalar) {
+ // // Append the boxed array to the child builder, then append a new offset
+ // auto child_builder =
+ // checked_cast(raw_builder)->value_builder(); return
+ // raw_builder->Append();
return raw_builder->AppendScalar(scalar);
},
// AppendArray
- [](ArrayBuilder* raw_builder, const std::shared_ptr& array,
- const int64_t row) {
- ARROW_ASSIGN_OR_RAISE(auto scalar, MakeArray(array)->GetScalar(row));
- return raw_builder->AppendScalar(*scalar);
+ [&](ArrayBuilder* raw_builder, const std::shared_ptr<ArrayData>& array,
+ const int64_t row) {
+ // Append a slice of the child array to the child builder, then append a new
+ // offset
+ auto builder = checked_cast<FixedSizeListBuilder*>(raw_builder);
+ auto child_builder = builder->value_builder();
+ array_appender(child_builder, array->child_data[0],
+ width * (array->offset + row), width);
+ return builder->Append();
});
}
};
-// TODO: map, fixed size list, struct, union, dictionary
-
-// TODO: file separate issue for dictionary? need utility to unify
-// dictionary and return mapping. what is an R factor? may need to
-// promote index type
-
struct CoalesceFunction : ScalarFunction {
using ScalarFunction::ScalarFunction;
diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc
index b1056ca94fe..02dc5f1cc94 100644
--- a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc
@@ -977,6 +977,41 @@ TYPED_TEST(TestCaseWhenList, ListOfString) {
ArrayFromJSON(type, R"([null, null, null, ["ef", "g"]])"));
}
+TYPED_TEST(TestCaseWhenList, ListOfInt) {
+ // More minimal test to check type coverage
+ auto type = std::make_shared<TypeParam>(int64());
+ auto cond1 = ArrayFromJSON(boolean(), "[true, true, null, null]");
+ auto cond2 = ArrayFromJSON(boolean(), "[true, false, true, null]");
+ auto values_null = ArrayFromJSON(type, "[null, null, null, null]");
+ auto values1 = ArrayFromJSON(type, R"([[1, 2], null, [3, 4, 5], [6, null]])");
+ auto values2 = ArrayFromJSON(type, R"([[8, 9, 10], [11], null, [12]])");
+
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2},
+ ArrayFromJSON(type, R"([[1, 2], null, null, null])"));
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2, values1},
+ ArrayFromJSON(type, R"([[1, 2], null, null, [6, null]])"));
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), values_null, values2, values1},
+ ArrayFromJSON(type, R"([null, null, null, [6, null]])"));
+}
+
+TYPED_TEST(TestCaseWhenList, ListOfListOfInt) {
+ // More minimal test to check type coverage
+ auto type = std::make_shared<TypeParam>(list(int64()));
+ auto cond1 = ArrayFromJSON(boolean(), "[true, true, null, null]");
+ auto cond2 = ArrayFromJSON(boolean(), "[true, false, true, null]");
+ auto values_null = ArrayFromJSON(type, "[null, null, null, null]");
+ auto values1 =
+ ArrayFromJSON(type, R"([[[1, 2], []], null, [[3, 4, 5]], [[6, null], null]])");
+ auto values2 = ArrayFromJSON(type, R"([[[8, 9, 10]], [[11]], null, [[12]]])");
+
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2},
+ ArrayFromJSON(type, R"([[[1, 2], []], null, null, null])"));
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2, values1},
+ ArrayFromJSON(type, R"([[[1, 2], []], null, null, [[6, null], null]])"));
+ CheckScalar("case_when", {MakeStruct({cond1, cond2}), values_null, values2, values1},
+ ArrayFromJSON(type, R"([null, null, null, [[6, null], null]])"));
+}
+
TEST(TestCaseWhen, Map) {
auto type = map(int64(), utf8());
auto cond_true = ScalarFromJSON(boolean(), "true");
From 3ec727ab83d036c7035e6cb5ea8c7a75e51a4b88 Mon Sep 17 00:00:00 2001
From: David Li
Date: Mon, 26 Jul 2021 12:38:22 -0400
Subject: [PATCH 06/31] ARROW-13222: [C++] Expand type support
---
.../arrow/compute/kernels/scalar_if_else.cc | 173 ++++++++++++++----
.../compute/kernels/scalar_if_else_test.cc | 121 ++++++++++++
2 files changed, 262 insertions(+), 32 deletions(-)
diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else.cc b/cpp/src/arrow/compute/kernels/scalar_if_else.cc
index 78187cc0cde..e1896e591c9 100644
--- a/cpp/src/arrow/compute/kernels/scalar_if_else.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_if_else.cc
@@ -1592,6 +1592,7 @@ struct CaseWhenFunctor> {
}
};
+// Given an array and a builder, append a slice of the array to the builder
using ArrayAppenderFunc = std::function<Status(ArrayBuilder*, const std::shared_ptr<ArrayData>&, int64_t, int64_t)>;
@@ -1612,13 +1613,16 @@ struct GetAppenders {
return Status::OK();
}
- Status Visit(const StringType&) {
+ template <typename T>
+ enable_if_base_binary<T, Status> Visit(const T&) {
+ using BuilderType = typename TypeTraits<T>::BuilderType;
+ using offset_type = typename T::offset_type;
array_appender = [](ArrayBuilder* raw_builder,
const std::shared_ptr& array, const int64_t offset,
const int64_t length) {
- auto builder = checked_cast(raw_builder);
+ auto builder = checked_cast<BuilderType*>(raw_builder);
auto bitmap = array->GetValues<uint8_t>(0, 0);
- auto offsets = array->GetValues(1);
+ auto offsets = array->GetValues<offset_type>(1);
auto data = array->GetValues<uint8_t>(2, 0);
for (int64_t i = 0; i < length; i++) {
if (!bitmap || BitUtil::GetBit(bitmap, offset + i)) {
@@ -1636,8 +1640,6 @@ struct GetAppenders {
template <typename T>
enable_if_var_size_list<T, Status> Visit(const T& ty) {
- // TODO: reuse this below? Or make a fully generic (but runtime
- // dispatched) impl of the top level case when
using BuilderType = typename TypeTraits::BuilderType;
using offset_type = typename T::offset_type;
const auto& list_ty = checked_cast(ty);
@@ -1666,6 +1668,88 @@ struct GetAppenders {
return Status::OK();
}
+ Status Visit(const MapType& ty) {
+ const auto& map_ty = checked_cast<const MapType&>(ty);
+ ArrayAppenderFunc key_appender, item_appender;
+ RETURN_NOT_OK(GetValueAppenders(*map_ty.key_type(), &key_appender));
+ RETURN_NOT_OK(GetValueAppenders(*map_ty.item_type(), &item_appender));
+ array_appender = [=](ArrayBuilder* raw_builder,
+ const std::shared_ptr& array, const int64_t offset,
+ const int64_t length) {
+ auto builder = checked_cast<MapBuilder*>(raw_builder);
+ auto key_builder = builder->key_builder();
+ auto item_builder = builder->item_builder();
+ const int32_t* offsets = array->GetValues<int32_t>(1);
+ const uint8_t* validity =
+ array->MayHaveNulls() ? array->buffers[0]->data() : nullptr;
+ for (int64_t row = offset; row < offset + length; row++) {
+ if (!validity || BitUtil::GetBit(validity, array->offset + row)) {
+ RETURN_NOT_OK(builder->Append());
+ int64_t length = offsets[row + 1] - offsets[row];
+ RETURN_NOT_OK(key_appender(key_builder, array->child_data[0]->child_data[0],
+ offsets[row], length));
+ RETURN_NOT_OK(item_appender(item_builder, array->child_data[0]->child_data[1],
+ offsets[row], length));
+ } else {
+ RETURN_NOT_OK(builder->AppendNull());
+ }
+ }
+ return Status::OK();
+ };
+ return Status::OK();
+ }
+
+ Status Visit(const StructType& ty) {
+ const auto& struct_ty = checked_cast<const StructType&>(ty);
+ std::vector<ArrayAppenderFunc> appenders(struct_ty.num_fields());
+ for (int i = 0; static_cast<size_t>(i) < appenders.size(); i++) {
+ RETURN_NOT_OK(GetValueAppenders(*struct_ty.field(i)->type(), &appenders[i]));
+ }
+ array_appender = [=](ArrayBuilder* raw_builder,
+ const std::shared_ptr& array, const int64_t offset,
+ const int64_t length) {
+ auto builder = checked_cast<StructBuilder*>(raw_builder);
+ for (int i = 0; static_cast<size_t>(i) < appenders.size(); i++) {
+ RETURN_NOT_OK(appenders[i](builder->field_builder(i), array->child_data[i],
+ array->offset + offset, length));
+ }
+ const uint8_t* validity =
+ array->MayHaveNulls() ? array->buffers[0]->data() : nullptr;
+ for (int64_t row = offset; row < offset + length; row++) {
+ RETURN_NOT_OK(
+ builder->Append(!validity || BitUtil::GetBit(validity, array->offset + row)));
+ }
+ return Status::OK();
+ };
+ return Status::OK();
+ }
+
+ Status Visit(const FixedSizeListType& ty) {
+ const auto& list_ty = checked_cast<const FixedSizeListType&>(ty);
+ const int64_t width = list_ty.list_size();
+ ArrayAppenderFunc sub_appender;
+ RETURN_NOT_OK(GetValueAppenders(*list_ty.value_type(), &sub_appender));
+ array_appender = [=](ArrayBuilder* raw_builder,
+ const std::shared_ptr& array, const int64_t offset,
+ const int64_t length) {
+ auto builder = checked_cast<FixedSizeListBuilder*>(raw_builder);
+ auto child_builder = builder->value_builder();
+ const uint8_t* validity =
+ array->MayHaveNulls() ? array->buffers[0]->data() : nullptr;
+ for (int64_t row = offset; row < offset + length; row++) {
+ if (!validity || BitUtil::GetBit(validity, array->offset + row)) {
+ RETURN_NOT_OK(sub_appender(child_builder, array->child_data[0],
+ width * (array->offset + row), width));
+ RETURN_NOT_OK(builder->Append());
+ } else {
+ RETURN_NOT_OK(builder->AppendNull());
+ }
+ }
+ return Status::OK();
+ };
+ return Status::OK();
+ }
+
Status Visit(const DataType& ty) {
return Status::NotImplemented("Appender for type ", ty);
}
@@ -1708,7 +1792,7 @@ struct CaseWhenFunctor> {
static Status ExecArray(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
const auto& ty = checked_cast<const Type&>(*out->type());
ArrayAppenderFunc array_appender;
- RETURN_NOT_OK(GetValueAppenders(*ty.value_type(), &array_appender));
+ RETURN_NOT_OK(GetValueAppenders(ty, &array_appender));
return ExecVarWidthArrayCaseWhen(
ctx, batch, out,
// ReserveData
@@ -1738,13 +1822,7 @@ struct CaseWhenFunctor> {
// AppendArray
[&](ArrayBuilder* raw_builder, const std::shared_ptr& array,
const int64_t row) {
- auto builder = checked_cast(raw_builder);
- auto child_builder = builder->value_builder();
- RETURN_NOT_OK(builder->Append());
- const offset_type* offsets = array->GetValues(1);
- int64_t length = offsets[row + 1] - offsets[row];
- return array_appender(child_builder, array->child_data[0], offsets[row],
- length);
+ return array_appender(raw_builder, array, row, /*length=*/1);
});
}
};
@@ -1770,7 +1848,8 @@ struct CaseWhenFunctor {
});
}
static Status ExecArray(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
- // TODO: horrendously slow, but generic
+ ArrayAppenderFunc array_appender;
+ RETURN_NOT_OK(GetValueAppenders(*out->type(), &array_appender));
return ExecVarWidthArrayCaseWhen(
ctx, batch, out,
// ReserveData
@@ -1780,10 +1859,48 @@ struct CaseWhenFunctor {
return raw_builder->AppendScalar(scalar);
},
// AppendArray
- [](ArrayBuilder* raw_builder, const std::shared_ptr& array,
- const int64_t row) {
- ARROW_ASSIGN_OR_RAISE(auto scalar, MakeArray(array)->GetScalar(row));
- return raw_builder->AppendScalar(*scalar);
+ [&](ArrayBuilder* raw_builder, const std::shared_ptr<ArrayData>& array,
+ const int64_t row) {
+ return array_appender(raw_builder, array, row, /*length=*/1);
+ });
+ }
+};
+
+template <>
+struct CaseWhenFunctor<MapType> {
+ static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ if (batch[0].null_count() > 0) {
+ return Status::Invalid("cond struct must not have outer nulls");
+ }
+ if (batch[0].is_scalar()) {
+ return ExecScalar(ctx, batch, out);
+ }
+ return ExecArray(ctx, batch, out);
+ }
+
+ static Status ExecScalar(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ return ExecVarWidthScalarCaseWhen(
+ ctx, batch, out,
+ [](KernelContext* ctx, const ArrayData& source, ArrayData* output) {
+ *output = source;
+ return Status::OK();
+ });
+ }
+ static Status ExecArray(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ ArrayAppenderFunc array_appender;
+ RETURN_NOT_OK(GetValueAppenders(*out->type(), &array_appender));
+ return ExecVarWidthArrayCaseWhen(
+ ctx, batch, out,
+ // ReserveData
+ [](ArrayBuilder*) {},
+ // AppendScalar
+ [](ArrayBuilder* raw_builder, const Scalar& scalar) {
+ return raw_builder->AppendScalar(scalar);
+ },
+ // AppendArray
+ [&](ArrayBuilder* raw_builder, const std::shared_ptr<ArrayData>& array,
+ const int64_t row) {
+ return array_appender(raw_builder, array, row, /*length=*/1);
});
}
};
@@ -1812,7 +1929,7 @@ struct CaseWhenFunctor {
const auto& ty = checked_cast<const FixedSizeListType&>(*out->type());
const int64_t width = ty.list_size();
ArrayAppenderFunc array_appender;
- RETURN_NOT_OK(GetValueAppenders(*ty.value_type(), &array_appender));
+ RETURN_NOT_OK(GetValueAppenders(ty, &array_appender));
return ExecVarWidthArrayCaseWhen(
ctx, batch, out,
// ReserveData
@@ -1824,22 +1941,12 @@ struct CaseWhenFunctor