From a0a308e3b56614894b77791664e659affd81ab2e Mon Sep 17 00:00:00 2001 From: David Li Date: Thu, 17 Jun 2021 10:30:09 -0400 Subject: [PATCH 01/11] ARROW-13220: [C++] Implement 'case_when' for fixed-width types --- cpp/src/arrow/compute/api_scalar.cc | 4 + cpp/src/arrow/compute/api_scalar.h | 16 + .../arrow/compute/kernels/codegen_internal.cc | 2 +- .../arrow/compute/kernels/scalar_if_else.cc | 395 +++++++++++++++++- .../kernels/scalar_if_else_benchmark.cc | 102 +++++ .../compute/kernels/scalar_if_else_test.cc | 325 +++++++++++++- cpp/src/arrow/compute/kernels/test_util.cc | 33 +- cpp/src/arrow/compute/kernels/test_util.h | 3 +- docs/source/cpp/compute.rst | 46 +- docs/source/python/api/compute.rst | 1 + 10 files changed, 863 insertions(+), 64 deletions(-) diff --git a/cpp/src/arrow/compute/api_scalar.cc b/cpp/src/arrow/compute/api_scalar.cc index be6498a74c6..0588f748812 100644 --- a/cpp/src/arrow/compute/api_scalar.cc +++ b/cpp/src/arrow/compute/api_scalar.cc @@ -466,6 +466,10 @@ Result IfElse(const Datum& cond, const Datum& if_true, const Datum& if_fa return CallFunction("if_else", {cond, if_true, if_false}, ctx); } +Result CaseWhen(const std::vector& cases, ExecContext* ctx) { + return CallFunction("case_when", cases, ctx); +} + // ---------------------------------------------------------------------- // Temporal functions diff --git a/cpp/src/arrow/compute/api_scalar.h b/cpp/src/arrow/compute/api_scalar.h index f0aebc8e032..28a0ca53c52 100644 --- a/cpp/src/arrow/compute/api_scalar.h +++ b/cpp/src/arrow/compute/api_scalar.h @@ -741,6 +741,22 @@ ARROW_EXPORT Result IfElse(const Datum& cond, const Datum& left, const Datum& right, ExecContext* ctx = NULLPTR); +/// \brief CaseWhen behaves like a switch/case or if-else if-else statement: for +/// each row, select the first value for which the corresponding condition is +/// true, or (if given) select the 'else' value, else emit null. Note that a +/// null condition is the same as false. +/// +/// \param[in] cases Zero or more pairs of conditions (Boolean) & values (any +/// type), along with an optional 'else' value. +/// \param[in] ctx the function execution context, optional +/// +/// \return the resulting datum +/// +/// \since 5.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result CaseWhen(const std::vector& cases, ExecContext* ctx = NULLPTR); + /// \brief Year returns year for each element of `values` /// /// \param[in] values input to extract year from diff --git a/cpp/src/arrow/compute/kernels/codegen_internal.cc b/cpp/src/arrow/compute/kernels/codegen_internal.cc index e723bd7838e..b9bde999447 100644 --- a/cpp/src/arrow/compute/kernels/codegen_internal.cc +++ b/cpp/src/arrow/compute/kernels/codegen_internal.cc @@ -253,7 +253,7 @@ std::shared_ptr CommonNumeric(const std::vector& descrs) { if (max_width_unsigned == 32) return uint32(); if (max_width_unsigned == 16) return uint16(); DCHECK_EQ(max_width_unsigned, 8); - return int8(); + return uint8(); } if (max_width_signed <= max_width_unsigned) { diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else.cc b/cpp/src/arrow/compute/kernels/scalar_if_else.cc index 54e0725fce7..7baaeeabe1e 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else.cc @@ -30,6 +30,7 @@ using internal::Bitmap; using internal::BitmapWordReader; namespace compute { +namespace internal { namespace { @@ -676,7 +677,348 @@ void AddPrimitiveIfElseKernels(const std::shared_ptr& scalar_fun } } -} // namespace +// Helper to copy or broadcast fixed-width values between buffers. +template +struct CopyFixedWidth {}; +template <> +struct CopyFixedWidth { + static void CopyScalar(const Scalar& scalar, uint8_t* out_values, const int64_t offset, + const int64_t length) { + const bool value = UnboxScalar::Unbox(scalar); + BitUtil::SetBitsTo(out_values, offset, length, value); + } + static void CopyArray(const ArrayData& array, uint8_t* out_values, const int64_t offset, + const int64_t length) { + arrow::internal::CopyBitmap(array.buffers[1]->data(), array.offset + offset, length, + out_values, offset); + } +}; +template +struct CopyFixedWidth> { + using CType = typename TypeTraits::CType; + static void CopyScalar(const Scalar& values, uint8_t* raw_out_values, + const int64_t offset, const int64_t length) { + CType* out_values = reinterpret_cast(raw_out_values); + const CType value = UnboxScalar::Unbox(values); + std::fill(out_values + offset, out_values + offset + length, value); + } + static void CopyArray(const ArrayData& array, uint8_t* raw_out_values, + const int64_t offset, const int64_t length) { + CType* out_values = reinterpret_cast(raw_out_values); + const CType* in_values = array.GetValues(1); + std::copy(in_values + offset, in_values + offset + length, out_values + offset); + } +}; +template +struct CopyFixedWidth> { + static void CopyScalar(const Scalar& values, uint8_t* out_values, const int64_t offset, + const int64_t length) { + const int32_t width = + checked_cast(*values.type).byte_width(); + uint8_t* next = out_values + (width * offset); + const auto& scalar = checked_cast(values); + if (!scalar.is_valid) return; + DCHECK_EQ(scalar.value->size(), width); + for (int i = 0; i < length; i++) { + std::memcpy(next, scalar.value->data(), width); + next += width; + } + } + static void CopyArray(const ArrayData& array, uint8_t* out_values, const int64_t offset, + const int64_t length) { + const int32_t width = + checked_cast(*array.type).byte_width(); + uint8_t* next = out_values + (width * offset); + const auto* in_values = array.GetValues(1, (offset + array.offset) * width); + std::memcpy(next, in_values, length * width); + } +}; +template +struct CopyFixedWidth> { + using ScalarType = typename TypeTraits::ScalarType; + static void CopyScalar(const Scalar& values, uint8_t* out_values, const int64_t offset, + const int64_t length) { + const int32_t width = + checked_cast(*values.type).byte_width(); + uint8_t* next = out_values + (width * offset); + const auto& scalar = checked_cast(values); + const auto value = scalar.value.ToBytes(); + for (int i = 0; i < length; i++) { + std::memcpy(next, value.data(), width); + next += width; + } + } + static void CopyArray(const ArrayData& array, uint8_t* out_values, const int64_t offset, + const int64_t length) { + const int32_t width = + checked_cast(*array.type).byte_width(); + uint8_t* next = out_values + (width * offset); + const auto* in_values = array.GetValues(1, (offset + array.offset) * width); + std::memcpy(next, in_values, length * width); + } +}; +// Copy fixed-width values from a scalar/array datum into an output values buffer +template +void CopyValues(const Datum& values, uint8_t* out_valid, uint8_t* out_values, + const int64_t offset, const int64_t length) { + using Copier = CopyFixedWidth; + if (values.is_scalar()) { + const auto& scalar = *values.scalar(); + if (out_valid) { + BitUtil::SetBitsTo(out_valid, offset, length, scalar.is_valid); + } + Copier::CopyScalar(scalar, out_values, offset, length); + } else { + const ArrayData& array = *values.array(); + if (out_valid) { + if (array.MayHaveNulls()) { + arrow::internal::CopyBitmap(array.buffers[0]->data(), array.offset + offset, + length, out_valid, offset); + } else { + BitUtil::SetBitsTo(out_valid, offset, length, true); + } + } + Copier::CopyArray(array, out_values, offset, length); + } +} + +struct CaseWhenFunction : ScalarFunction { + using ScalarFunction::ScalarFunction; + + Result DispatchBest(std::vector* values) const override { + RETURN_NOT_OK(CheckArity(*values)); + std::vector value_types; + for (size_t i = 0; i < values->size() - 1; i += 2) { + ValueDescr* cond = &(*values)[i]; + if (cond->type->id() == Type::NA) { + cond->type = boolean(); + } + if (cond->type->id() != Type::BOOL) { + return Status::TypeError("Condition arguments must be boolean, but argument ", i, + " was ", cond->type->ToString()); + } + value_types.push_back((*values)[i + 1]); + } + if (values->size() % 2 != 0) { + // Have an ELSE clause + value_types.push_back(values->back()); + } + EnsureDictionaryDecoded(&value_types); + if (auto type = CommonNumeric(value_types)) { + ReplaceTypes(type, &value_types); + } + + const DataType& common_values_type = *value_types.front().type; + auto next_type = value_types.cbegin(); + for (size_t i = 0; i < values->size(); i += 2) { + if (!common_values_type.Equals(next_type->type)) { + return Status::TypeError("Value arguments must be of same type, but argument ", i, + " was ", next_type->type->ToString(), " (expected ", + common_values_type.ToString(), ")"); + } + if (i == values->size() - 1) { + // ELSE + (*values)[i] = *next_type++; + } else { + (*values)[i + 1] = *next_type++; + } + } + + // We register a unary kernel for each value type and dispatch to it after validation. + if (auto kernel = DispatchExactImpl(this, {values->back()})) return kernel; + return arrow::compute::detail::NoMatchingKernel(this, *values); + } +}; + +// Implement a 'case when' (SQL)/'select' (NumPy) function for any scalar arguments +Status ExecScalarCaseWhen(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + for (size_t i = 0; i < batch.values.size() - 1; i += 2) { + const Scalar& cond = *batch[i].scalar(); + if (cond.is_valid && internal::UnboxScalar::Unbox(cond)) { + *out = batch[i + 1]; + return Status::OK(); + } + } + if (batch.values.size() % 2 == 0) { + // No ELSE + *out = MakeNullScalar(batch[1].type()); + } else { + *out = batch.values.back(); + } + return Status::OK(); +} + +// Implement 'case when' for any mix of scalar/array arguments for any fixed-width type, +// given helper functions to copy data from a source array to a target array and to +// allocate a values buffer +template +Status ExecArrayCaseWhen(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + ArrayData* output = out->mutable_array(); + const bool have_else_arg = batch.values.size() % 2 != 0; + // Check if we may need a validity bitmap + uint8_t* out_valid = nullptr; + + bool need_valid_bitmap = false; + if (!have_else_arg) { + // If we don't have an else arg -> need a bitmap since we may emit nulls + need_valid_bitmap = true; + } else if (batch.values.back().null_count() > 0) { + // If the 'else' array has a null count we need a validity bitmap + need_valid_bitmap = true; + } else { + // Otherwise if any value array has a null count we need a validity bitmap + for (size_t i = 1; i < batch.values.size(); i += 2) { + if (batch[i].null_count() > 0) { + need_valid_bitmap = true; + break; + } + } + } + if (need_valid_bitmap) { + ARROW_ASSIGN_OR_RAISE(output->buffers[0], ctx->AllocateBitmap(batch.length)); + out_valid = output->buffers[0]->mutable_data(); + } + + // Initialize values buffer + uint8_t* out_values = output->buffers[1]->mutable_data(); + if (have_else_arg) { + // Copy 'else' value into output + CopyValues(batch.values.back(), out_valid, out_values, /*offset=*/0, + batch.length); + } else if (need_valid_bitmap) { + // There's no 'else' argument, so we should have an all-null validity bitmap + std::memset(out_valid, 0x00, output->buffers[0]->size()); + } + + // Allocate a temporary bitmap to determine which elements still need setting. + ARROW_ASSIGN_OR_RAISE(auto mask_buffer, ctx->AllocateBitmap(batch.length)); + uint8_t* mask = mask_buffer->mutable_data(); + std::memset(mask, 0xFF, mask_buffer->size()); + // Then iterate through each argument in turn and set elements. + for (size_t i = 0; i < batch.values.size() - 1; i += 2) { + const Datum& cond_datum = batch[i]; + const Datum& values_datum = batch[i + 1]; + if (cond_datum.is_scalar()) { + const Scalar& cond_scalar = *cond_datum.scalar(); + const bool cond = + cond_scalar.is_valid && UnboxScalar::Unbox(cond_scalar); + if (!cond) continue; + BitBlockCounter counter(mask, /*start_offset=*/0, batch.length); + int64_t offset = 0; + while (offset < batch.length) { + const auto block = counter.NextWord(); + if (block.AllSet()) { + CopyValues(values_datum, out_valid, out_values, offset, block.length); + } else if (block.popcount) { + for (int64_t j = 0; j < block.length; ++j) { + if (BitUtil::GetBit(mask, offset + j)) { + CopyValues(values_datum, out_valid, out_values, offset + j, + /*length=*/1); + } + } + } + offset += block.length; + } + break; + } + + const ArrayData& cond_array = *cond_datum.array(); + const uint8_t* cond_values = cond_array.buffers[1]->data(); + int64_t offset = 0; + // If no valid buffer, visit mask & value bitmap simultaneously + if (cond_array.GetNullCount() == 0) { + BinaryBitBlockCounter counter(mask, /*start_offset=*/0, cond_values, + cond_array.offset, batch.length); + while (offset < batch.length) { + const auto block = counter.NextAndWord(); + if (block.AllSet()) { + CopyValues(values_datum, out_valid, out_values, offset, block.length); + BitUtil::SetBitsTo(mask, offset, block.length, false); + } else if (block.popcount) { + for (int64_t j = 0; j < block.length; ++j) { + if (BitUtil::GetBit(mask, offset + j) && + BitUtil::GetBit(cond_values, cond_array.offset + offset + j)) { + CopyValues(values_datum, out_valid, out_values, offset + j, + /*length=*/1); + BitUtil::SetBitTo(mask, offset + j, false); + } + } + } + offset += block.length; + } + continue; + } + + // Else visit all three bitmaps simultaneously + const uint8_t* cond_valid = cond_array.buffers[0]->data(); + Bitmap bitmaps[3] = {{mask, /*offset=*/0, batch.length}, + {cond_values, cond_array.offset, batch.length}, + {cond_valid, cond_array.offset, batch.length}}; + Bitmap::VisitWords(bitmaps, [&](std::array words) { + const uint64_t word = words[0] & words[1] & words[2]; + const int64_t block_length = std::min(64, batch.length - offset); + if (word == std::numeric_limits::max()) { + CopyValues(values_datum, out_valid, out_values, offset, block_length); + BitUtil::SetBitsTo(mask, offset, block_length, false); + } else if (word) { + for (int64_t j = 0; j < block_length; ++j) { + if (BitUtil::GetBit(mask, offset + j) && + BitUtil::GetBit(cond_valid, cond_array.offset + offset + j) && + BitUtil::GetBit(cond_values, cond_array.offset + offset + j)) { + CopyValues(values_datum, out_valid, out_values, offset + j, + /*length=*/1); + BitUtil::SetBitTo(mask, offset + j, false); + } + } + } + offset += block_length; + }); + } + return Status::OK(); +} + +template +struct CaseWhenFunctor { + static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + for (const auto& datum : batch.values) { + if (datum.is_array()) { + return ExecArrayCaseWhen(ctx, batch, out); + } + } + return ExecScalarCaseWhen(ctx, batch, out); + } +}; + +template <> +struct CaseWhenFunctor { + static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + return Status::OK(); + } +}; + +Result LastType(KernelContext*, const std::vector& descrs) { + ValueDescr result = descrs.back(); + result.shape = GetBroadcastShape(descrs); + return result; +} + +void AddCaseWhenKernel(const std::shared_ptr& scalar_function, + detail::GetTypeId get_id, ArrayKernelExec exec) { + ScalarKernel kernel(KernelSignature::Make({InputType(get_id.id)}, OutputType(LastType), + /*is_varargs=*/true), + exec); + kernel.null_handling = NullHandling::COMPUTED_NO_PREALLOCATE; + kernel.mem_allocation = MemAllocation::PREALLOCATE; + DCHECK_OK(scalar_function->AddKernel(std::move(kernel))); +} + +void AddPrimitiveCaseWhenKernels(const std::shared_ptr& scalar_function, + const std::vector>& types) { + for (auto&& type : types) { + auto exec = GenerateTypeAgnosticPrimitive(*type); + AddCaseWhenKernel(scalar_function, type, std::move(exec)); + } +} const FunctionDoc if_else_doc{"Choose values based on a condition", ("`cond` must be a Boolean scalar/ array. \n`left` or " @@ -685,22 +1027,45 @@ const FunctionDoc if_else_doc{"Choose values based on a condition", " output."), {"cond", "left", "right"}}; -namespace internal { +const FunctionDoc case_when_doc{ + "Choose values based on multiple conditions", + ("`cond` must be a sequence of alternating Boolean condition data " + "and value data (of any type, but all must be the same type or " + "castable to a common type), along with an optional datum of " + "\"else\" values. At least one datum must be given.\n" + "Each row of the output will be the corresponding value of the " + "first value datum for which the corresponding condition datum " + "is true, or otherwise the \"else\" value (if given), or null. " + "Essentially, this implements a switch-case or if-else if-else " + "statement."), + {"*cond"}}; +} // namespace void RegisterScalarIfElse(FunctionRegistry* registry) { - ScalarKernel scalar_kernel; - scalar_kernel.null_handling = NullHandling::COMPUTED_NO_PREALLOCATE; - scalar_kernel.mem_allocation = MemAllocation::NO_PREALLOCATE; - - auto func = std::make_shared("if_else", Arity::Ternary(), &if_else_doc); - - AddPrimitiveIfElseKernels(func, NumericTypes()); - AddPrimitiveIfElseKernels(func, TemporalTypes()); - AddPrimitiveIfElseKernels(func, {boolean()}); - AddNullIfElseKernel(func); - // todo add binary kernels - - DCHECK_OK(registry->AddFunction(std::move(func))); + { + auto func = + std::make_shared("if_else", Arity::Ternary(), &if_else_doc); + + AddPrimitiveIfElseKernels(func, NumericTypes()); + AddPrimitiveIfElseKernels(func, TemporalTypes()); + AddPrimitiveIfElseKernels(func, {boolean(), day_time_interval(), month_interval()}); + AddNullIfElseKernel(func); + // todo add binary kernels + DCHECK_OK(registry->AddFunction(std::move(func))); + } + { + auto func = std::make_shared( + "case_when", Arity::VarArgs(/*min_args=*/1), &case_when_doc); + AddPrimitiveCaseWhenKernels(func, NumericTypes()); + AddPrimitiveCaseWhenKernels(func, TemporalTypes()); + AddPrimitiveCaseWhenKernels( + func, {boolean(), null(), day_time_interval(), month_interval()}); + AddCaseWhenKernel(func, Type::FIXED_SIZE_BINARY, + CaseWhenFunctor::Exec); + AddCaseWhenKernel(func, Type::DECIMAL128, CaseWhenFunctor::Exec); + AddCaseWhenKernel(func, Type::DECIMAL256, CaseWhenFunctor::Exec); + DCHECK_OK(registry->AddFunction(std::move(func))); + } } } // namespace internal diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else_benchmark.cc b/cpp/src/arrow/compute/kernels/scalar_if_else_benchmark.cc index 98fb675da40..83639235bb5 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else_benchmark.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else_benchmark.cc @@ -97,6 +97,96 @@ static void IfElseBench32Contiguous(benchmark::State& state) { return IfElseBenchContiguous(state); } +template +static void CaseWhenBench(benchmark::State& state) { + using CType = typename Type::c_type; + auto type = TypeTraits::type_singleton(); + using ArrayType = typename TypeTraits::ArrayType; + + int64_t len = state.range(0); + int64_t offset = state.range(1); + + random::RandomArrayGenerator rand(/*seed=*/0); + + auto cond1 = std::static_pointer_cast( + rand.ArrayOf(boolean(), len, /*null_probability=*/0.01)); + auto cond2 = std::static_pointer_cast( + rand.ArrayOf(boolean(), len, /*null_probability=*/0.01)); + auto cond3 = std::static_pointer_cast( + rand.ArrayOf(boolean(), len, /*null_probability=*/0.01)); + auto val1 = std::static_pointer_cast( + rand.ArrayOf(type, len, /*null_probability=*/0.01)); + auto val2 = std::static_pointer_cast( + rand.ArrayOf(type, len, /*null_probability=*/0.01)); + auto val3 = std::static_pointer_cast( + rand.ArrayOf(type, len, /*null_probability=*/0.01)); + auto val4 = std::static_pointer_cast( + rand.ArrayOf(type, len, /*null_probability=*/0.01)); + + for (auto _ : state) { + ABORT_NOT_OK( + CaseWhen({cond1->Slice(offset), val1->Slice(offset), cond2->Slice(offset), + val2->Slice(offset), cond3->Slice(offset), val3->Slice(offset), + val4->Slice(offset)})); + } + + state.SetBytesProcessed(state.iterations() * + ((len - offset) / 8 + 4 * (len - offset) * sizeof(CType))); +} + +template +static void CaseWhenBenchContiguous(benchmark::State& state) { + using CType = typename Type::c_type; + auto type = TypeTraits::type_singleton(); + using ArrayType = typename TypeTraits::ArrayType; + + int64_t len = state.range(0); + int64_t offset = state.range(1); + + ASSERT_OK_AND_ASSIGN(auto trues, MakeArrayFromScalar(BooleanScalar(true), len / 3)); + ASSERT_OK_AND_ASSIGN(auto falses, MakeArrayFromScalar(BooleanScalar(false), len / 3)); + auto null_scalar = MakeNullScalar(boolean()); + ASSERT_OK_AND_ASSIGN(auto nulls, + MakeArrayFromScalar(*null_scalar, len - 2 * (len / 3))); + ASSERT_OK_AND_ASSIGN(auto concat, Concatenate({trues, falses, nulls})); + auto cond1 = std::static_pointer_cast(concat); + + random::RandomArrayGenerator rand(/*seed=*/0); + auto cond2 = std::static_pointer_cast( + rand.ArrayOf(boolean(), len, /*null_probability=*/0.01)); + auto val1 = std::static_pointer_cast( + rand.ArrayOf(type, len, /*null_probability=*/0.01)); + auto val2 = std::static_pointer_cast( + rand.ArrayOf(type, len, /*null_probability=*/0.01)); + auto val3 = std::static_pointer_cast( + rand.ArrayOf(type, len, /*null_probability=*/0.01)); + + for (auto _ : state) { + ABORT_NOT_OK( + CaseWhen({cond1->Slice(offset), val1->Slice(offset), cond2->Slice(offset), + val2->Slice(offset), val3->Slice(offset)})); + } + + state.SetBytesProcessed(state.iterations() * + ((len - offset) / 8 + 3 * (len - offset) * sizeof(CType))); +} + +static void CaseWhenBench64(benchmark::State& state) { + return CaseWhenBench(state); +} + +static void CaseWhenBench32(benchmark::State& state) { + return CaseWhenBench(state); +} + +static void CaseWhenBench64Contiguous(benchmark::State& state) { + return CaseWhenBenchContiguous(state); +} + +static void CaseWhenBench32Contiguous(benchmark::State& state) { + return CaseWhenBenchContiguous(state); +} + BENCHMARK(IfElseBench32)->Args({elems, 0}); BENCHMARK(IfElseBench64)->Args({elems, 0}); @@ -109,5 +199,17 @@ BENCHMARK(IfElseBench64Contiguous)->Args({elems, 0}); BENCHMARK(IfElseBench32Contiguous)->Args({elems, 99}); BENCHMARK(IfElseBench64Contiguous)->Args({elems, 99}); +BENCHMARK(CaseWhenBench32)->Args({elems, 0}); +BENCHMARK(CaseWhenBench64)->Args({elems, 0}); + +BENCHMARK(CaseWhenBench32)->Args({elems, 99}); +BENCHMARK(CaseWhenBench64)->Args({elems, 99}); + +BENCHMARK(CaseWhenBench32Contiguous)->Args({elems, 0}); +BENCHMARK(CaseWhenBench64Contiguous)->Args({elems, 0}); + +BENCHMARK(CaseWhenBench32Contiguous)->Args({elems, 99}); +BENCHMARK(CaseWhenBench64Contiguous)->Args({elems, 99}); + } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc index 670a2d42a3a..0a9f5548368 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc @@ -15,12 +15,13 @@ // specific language governing permissions and limitations // under the License. -#include -#include -#include -#include -#include #include +#include "arrow/array.h" +#include "arrow/array/concatenate.h" +#include "arrow/compute/api_scalar.h" +#include "arrow/compute/kernels/test_util.h" +#include "arrow/compute/registry.h" +#include "arrow/testing/gtest_util.h" namespace arrow { namespace compute { @@ -45,15 +46,16 @@ class TestIfElseKernel : public ::testing::Test {}; template class TestIfElsePrimitive : public ::testing::Test {}; -using PrimitiveTypes = ::testing::Types; +using NumericBasedTypes = + ::testing::Types; -TYPED_TEST_SUITE(TestIfElsePrimitive, PrimitiveTypes); +TYPED_TEST_SUITE(TestIfElsePrimitive, NumericBasedTypes); TYPED_TEST(TestIfElsePrimitive, IfElseFixedSizeRand) { using ArrayType = typename TypeTraits::ArrayType; - auto type = TypeTraits::type_singleton(); + auto type = default_type_instance(); random::RandomArrayGenerator rand(/*seed=*/0); int64_t len = 1000; @@ -71,7 +73,7 @@ TYPED_TEST(TestIfElsePrimitive, IfElseFixedSizeRand) { auto right = std::static_pointer_cast( rand.ArrayOf(type, len, /*null_probability=*/0.01)); - typename TypeTraits::BuilderType builder; + typename TypeTraits::BuilderType builder(type, default_memory_pool()); for (int64_t i = 0; i < len; ++i) { if (!cond->IsValid(i) || (cond->Value(i) && !left->IsValid(i)) || @@ -155,7 +157,7 @@ void CheckWithDifferentShapes(const std::shared_ptr& cond, } TYPED_TEST(TestIfElsePrimitive, IfElseFixedSize) { - auto type = TypeTraits::type_singleton(); + auto type = default_type_instance(); CheckWithDifferentShapes(ArrayFromJSON(boolean(), "[true, true, true, false]"), ArrayFromJSON(type, "[1, 2, 3, 4]"), @@ -316,5 +318,304 @@ TEST_F(TestIfElseKernel, IfElseDispatchBest) { CheckDispatchBest(name, {null(), uint8(), int8()}, {boolean(), int16(), int16()}); } +template +class TestCaseWhenNumeric : public ::testing::Test {}; + +TYPED_TEST_SUITE(TestCaseWhenNumeric, NumericBasedTypes); + +TYPED_TEST(TestCaseWhenNumeric, FixedSize) { + auto type = default_type_instance(); + auto cond_true = ScalarFromJSON(boolean(), "true"); + auto cond_false = ScalarFromJSON(boolean(), "false"); + auto cond_null = ScalarFromJSON(boolean(), "null"); + auto cond1 = ArrayFromJSON(boolean(), "[true, true, null, null]"); + auto cond2 = ArrayFromJSON(boolean(), "[true, false, true, null]"); + auto scalar_null = ScalarFromJSON(type, "null"); + auto scalar1 = ScalarFromJSON(type, "1"); + auto scalar2 = ScalarFromJSON(type, "2"); + auto values_null = ArrayFromJSON(type, "[null, null, null, null]"); + auto values1 = ArrayFromJSON(type, "[3, null, 5, 6]"); + auto values2 = ArrayFromJSON(type, "[7, 8, null, 10]"); + + CheckScalar("case_when", {values1}, values1); + CheckScalar("case_when", {values_null}, values_null); + + CheckScalar("case_when", {cond_true, values1}, values1); + CheckScalar("case_when", {cond_false, values1}, values_null); + CheckScalar("case_when", {cond_null, values1}, values_null); + CheckScalar("case_when", {cond_true, values1, values2}, values1); + CheckScalar("case_when", {cond_false, values1, values2}, values2); + CheckScalar("case_when", {cond_null, values1, values2}, values2); + + CheckScalar("case_when", {cond_true, values1, cond_true, values2}, values1); + CheckScalar("case_when", {cond_false, values1, cond_false, values2}, values_null); + CheckScalar("case_when", {cond_true, values1, cond_false, values2}, values1); + CheckScalar("case_when", {cond_false, values1, cond_true, values2}, values2); + CheckScalar("case_when", {cond_null, values1, cond_true, values2}, values2); + CheckScalar("case_when", {cond_false, values1, cond_false, values2, values2}, values2); + + CheckScalar("case_when", {cond1, scalar1, cond2, scalar2}, + ArrayFromJSON(type, "[1, 1, 2, null]")); + CheckScalar("case_when", {cond1, scalar_null}, values_null); + CheckScalar("case_when", {cond1, scalar_null, scalar1}, + ArrayFromJSON(type, "[null, null, 1, 1]")); + CheckScalar("case_when", {cond1, scalar1, cond2, scalar2, scalar1}, + ArrayFromJSON(type, "[1, 1, 2, 1]")); + + CheckScalar("case_when", {cond1, values1, cond2, values2}, + ArrayFromJSON(type, "[3, null, null, null]")); + CheckScalar("case_when", {cond1, values1, cond2, values2, values1}, + ArrayFromJSON(type, "[3, null, null, 6]")); + CheckScalar("case_when", {cond1, values_null, cond2, values2, values1}, + ArrayFromJSON(type, "[null, null, null, 6]")); + + CheckScalar("case_when", + {ArrayFromJSON(boolean(), + "[true, true, true, false, false, false, null, null, null]"), + ArrayFromJSON(type, "[10, 11, 12, 13, 14, 15, 16, 17, 18]"), + ArrayFromJSON(boolean(), + "[true, false, null, true, false, null, true, false, null]"), + ArrayFromJSON(type, "[20, 21, 22, 23, 24, 25, 26, 27, 28]")}, + ArrayFromJSON(type, "[10, 11, 12, 23, null, null, 26, null, null]")); + CheckScalar("case_when", + {ArrayFromJSON(boolean(), + "[true, true, true, false, false, false, null, null, null]"), + ArrayFromJSON(type, "[10, 11, 12, 13, 14, 15, 16, 17, 18]"), + ArrayFromJSON(boolean(), + "[true, false, null, true, false, null, true, false, null]"), + ArrayFromJSON(type, "[20, 21, 22, 23, 24, 25, 26, 27, 28]"), + ArrayFromJSON(type, "[30, 31, 32, 33, 34, null, 36, 37, null]")}, + ArrayFromJSON(type, "[10, 11, 12, 23, 34, null, 26, 37, null]")); +} + +TEST(TestCaseWhen, Null) { + auto cond_true = ScalarFromJSON(boolean(), "true"); + auto cond_false = ScalarFromJSON(boolean(), "false"); + auto cond_arr = ArrayFromJSON(boolean(), "[true, true, false, null]"); + auto scalar = ScalarFromJSON(null(), "null"); + auto array = ArrayFromJSON(null(), "[null, null, null, null]"); + CheckScalar("case_when", {array}, array); + CheckScalar("case_when", {cond_false, array}, array); + CheckScalar("case_when", {cond_true, array, array}, array); + CheckScalar("case_when", {cond_arr, array, cond_true, array}, array); +} + +TEST(TestCaseWhen, Boolean) { + auto type = boolean(); + auto cond_true = ScalarFromJSON(boolean(), "true"); + auto cond_false = ScalarFromJSON(boolean(), "false"); + auto cond_null = ScalarFromJSON(boolean(), "null"); + auto cond1 = ArrayFromJSON(boolean(), "[true, true, null, null]"); + auto cond2 = ArrayFromJSON(boolean(), "[true, false, true, null]"); + auto scalar_null = ScalarFromJSON(type, "null"); + auto scalar1 = ScalarFromJSON(type, "true"); + auto scalar2 = ScalarFromJSON(type, "false"); + auto values_null = ArrayFromJSON(type, "[null, null, null, null]"); + auto values1 = ArrayFromJSON(type, "[true, null, true, true]"); + auto values2 = ArrayFromJSON(type, "[false, false, null, false]"); + + CheckScalar("case_when", {values1}, values1); + CheckScalar("case_when", {values_null}, values_null); + + CheckScalar("case_when", {cond_true, values1}, values1); + CheckScalar("case_when", {cond_false, values1}, values_null); + CheckScalar("case_when", {cond_null, values1}, values_null); + CheckScalar("case_when", {cond_true, values1, values2}, values1); + CheckScalar("case_when", {cond_false, values1, values2}, values2); + CheckScalar("case_when", {cond_null, values1, values2}, values2); + + CheckScalar("case_when", {cond_true, values1, cond_true, values2}, values1); + CheckScalar("case_when", {cond_false, values1, cond_false, values2}, values_null); + CheckScalar("case_when", {cond_true, values1, cond_false, values2}, values1); + CheckScalar("case_when", {cond_false, values1, cond_true, values2}, values2); + CheckScalar("case_when", {cond_null, values1, cond_true, values2}, values2); + CheckScalar("case_when", {cond_false, values1, cond_false, values2, values2}, values2); + + CheckScalar("case_when", {cond1, scalar1, cond2, scalar2}, + ArrayFromJSON(type, "[true, true, false, null]")); + CheckScalar("case_when", {cond1, scalar_null}, values_null); + CheckScalar("case_when", {cond1, scalar_null, scalar1}, + ArrayFromJSON(type, "[null, null, true, true]")); + CheckScalar("case_when", {cond1, scalar1, cond2, scalar2, scalar1}, + ArrayFromJSON(type, "[true, true, false, true]")); + + CheckScalar("case_when", {cond1, values1, cond2, values2}, + ArrayFromJSON(type, "[true, null, null, null]")); + CheckScalar("case_when", {cond1, values1, cond2, values2, values1}, + ArrayFromJSON(type, "[true, null, null, true]")); + CheckScalar("case_when", {cond1, values_null, cond2, values2, values1}, + ArrayFromJSON(type, "[null, null, null, true]")); +} + +TEST(TestCaseWhen, DayTimeInterval) { + auto type = day_time_interval(); + auto cond_true = ScalarFromJSON(boolean(), "true"); + auto cond_false = ScalarFromJSON(boolean(), "false"); + auto cond_null = ScalarFromJSON(boolean(), "null"); + auto cond1 = ArrayFromJSON(boolean(), "[true, true, null, null]"); + auto cond2 = ArrayFromJSON(boolean(), "[true, false, true, null]"); + auto scalar_null = ScalarFromJSON(type, "null"); + auto scalar1 = ScalarFromJSON(type, "[1, 1]"); + auto scalar2 = ScalarFromJSON(type, "[2, 2]"); + auto values_null = ArrayFromJSON(type, "[null, null, null, null]"); + auto values1 = ArrayFromJSON(type, "[[3, 3], null, [5, 5], [6, 6]]"); + auto values2 = ArrayFromJSON(type, "[[7, 7], [8, 8], null, [10, 10]]"); + + CheckScalar("case_when", {values1}, values1); + CheckScalar("case_when", {values_null}, values_null); + + CheckScalar("case_when", {cond_true, values1}, values1); + CheckScalar("case_when", {cond_false, values1}, values_null); + CheckScalar("case_when", {cond_null, values1}, values_null); + CheckScalar("case_when", {cond_true, values1, values2}, values1); + CheckScalar("case_when", {cond_false, values1, values2}, values2); + CheckScalar("case_when", {cond_null, values1, values2}, values2); + + CheckScalar("case_when", {cond_true, values1, cond_true, values2}, values1); + CheckScalar("case_when", {cond_false, values1, cond_false, values2}, values_null); + CheckScalar("case_when", {cond_true, values1, cond_false, values2}, values1); + CheckScalar("case_when", {cond_false, values1, cond_true, values2}, values2); + CheckScalar("case_when", {cond_null, values1, cond_true, values2}, values2); + CheckScalar("case_when", {cond_false, values1, cond_false, values2, values2}, values2); + + CheckScalar("case_when", {cond1, scalar1, cond2, scalar2}, + ArrayFromJSON(type, "[[1, 1], [1, 1], [2, 2], null]")); + CheckScalar("case_when", {cond1, scalar_null}, values_null); + CheckScalar("case_when", {cond1, scalar_null, scalar1}, + ArrayFromJSON(type, "[null, null, [1, 1], [1, 1]]")); + CheckScalar("case_when", {cond1, scalar1, cond2, scalar2, scalar1}, + ArrayFromJSON(type, "[[1, 1], [1, 1], [2, 2], [1, 1]]")); + + CheckScalar("case_when", {cond1, values1, cond2, values2}, + ArrayFromJSON(type, "[[3, 3], null, null, null]")); + CheckScalar("case_when", {cond1, values1, cond2, values2, values1}, + ArrayFromJSON(type, "[[3, 3], null, null, [6, 6]]")); + CheckScalar("case_when", {cond1, values_null, cond2, values2, values1}, + ArrayFromJSON(type, "[null, null, null, [6, 6]]")); +} + +TEST(TestCaseWhen, Decimal) { + for (const auto& type : + std::vector>{decimal128(3, 2), decimal256(3, 2)}) { + auto cond_true = ScalarFromJSON(boolean(), "true"); + auto cond_false = ScalarFromJSON(boolean(), "false"); + auto cond_null = ScalarFromJSON(boolean(), "null"); + auto cond1 = ArrayFromJSON(boolean(), "[true, true, null, null]"); + auto cond2 = ArrayFromJSON(boolean(), "[true, false, true, null]"); + auto scalar_null = ScalarFromJSON(type, "null"); + auto scalar1 = ScalarFromJSON(type, R"("1.23")"); + auto scalar2 = ScalarFromJSON(type, R"("2.34")"); + auto values_null = ArrayFromJSON(type, "[null, null, null, null]"); + auto values1 = ArrayFromJSON(type, R"(["3.45", null, "5.67", "6.78"])"); + auto values2 = ArrayFromJSON(type, R"(["7.89", "8.90", null, "1.01"])"); + + CheckScalar("case_when", {values1}, values1); + CheckScalar("case_when", {values_null}, values_null); + + CheckScalar("case_when", {cond_true, values1}, values1); + CheckScalar("case_when", {cond_false, values1}, values_null); + CheckScalar("case_when", {cond_null, values1}, values_null); + CheckScalar("case_when", {cond_true, values1, values2}, values1); + CheckScalar("case_when", {cond_false, values1, values2}, values2); + CheckScalar("case_when", {cond_null, values1, values2}, values2); + + CheckScalar("case_when", {cond_true, values1, cond_true, values2}, values1); + CheckScalar("case_when", {cond_false, values1, cond_false, values2}, values_null); + CheckScalar("case_when", {cond_true, values1, cond_false, values2}, values1); + CheckScalar("case_when", {cond_false, values1, cond_true, values2}, values2); + CheckScalar("case_when", {cond_null, values1, cond_true, values2}, values2); + CheckScalar("case_when", {cond_false, values1, cond_false, values2, values2}, + values2); + + CheckScalar("case_when", {cond1, scalar1, cond2, scalar2}, + ArrayFromJSON(type, R"(["1.23", "1.23", "2.34", null])")); + CheckScalar("case_when", {cond1, scalar_null}, values_null); + CheckScalar("case_when", {cond1, scalar_null, scalar1}, + ArrayFromJSON(type, R"([null, null, "1.23", "1.23"])")); + CheckScalar("case_when", {cond1, scalar1, cond2, scalar2, scalar1}, + ArrayFromJSON(type, R"(["1.23", "1.23", "2.34", "1.23"])")); + + CheckScalar("case_when", {cond1, values1, cond2, values2}, + ArrayFromJSON(type, R"(["3.45", null, null, null])")); + CheckScalar("case_when", {cond1, values1, cond2, values2, values1}, + ArrayFromJSON(type, R"(["3.45", null, null, "6.78"])")); + CheckScalar("case_when", {cond1, values_null, cond2, values2, values1}, + ArrayFromJSON(type, R"([null, null, null, "6.78"])")); + } +} + +TEST(TestCaseWhen, FixedSizeBinary) { + auto type = fixed_size_binary(3); + auto cond_true = ScalarFromJSON(boolean(), "true"); + auto cond_false = ScalarFromJSON(boolean(), "false"); + auto cond_null = ScalarFromJSON(boolean(), "null"); + auto cond1 = ArrayFromJSON(boolean(), "[true, true, null, null]"); + auto cond2 = ArrayFromJSON(boolean(), "[true, false, true, null]"); + auto scalar_null = ScalarFromJSON(type, "null"); + auto scalar1 = ScalarFromJSON(type, R"("abc")"); + auto scalar2 = ScalarFromJSON(type, R"("bcd")"); + auto values_null = ArrayFromJSON(type, "[null, null, null, null]"); + auto values1 = ArrayFromJSON(type, R"(["cde", null, "def", "efg"])"); + auto values2 = ArrayFromJSON(type, R"(["fgh", "ghi", null, "hij"])"); + + CheckScalar("case_when", {values1}, values1); + CheckScalar("case_when", {values_null}, values_null); + + CheckScalar("case_when", {cond_true, values1}, values1); + CheckScalar("case_when", {cond_false, values1}, values_null); + CheckScalar("case_when", {cond_null, values1}, values_null); + CheckScalar("case_when", {cond_true, values1, values2}, values1); + CheckScalar("case_when", {cond_false, values1, values2}, values2); + CheckScalar("case_when", {cond_null, values1, values2}, values2); + + CheckScalar("case_when", {cond_true, values1, cond_true, values2}, values1); + CheckScalar("case_when", {cond_false, values1, cond_false, values2}, values_null); + CheckScalar("case_when", {cond_true, values1, cond_false, values2}, values1); + CheckScalar("case_when", {cond_false, values1, cond_true, values2}, values2); + CheckScalar("case_when", {cond_null, values1, cond_true, values2}, values2); + CheckScalar("case_when", {cond_false, values1, cond_false, values2, values2}, values2); + + CheckScalar("case_when", {cond1, scalar1, cond2, scalar2}, + ArrayFromJSON(type, R"(["abc", "abc", "bcd", null])")); + CheckScalar("case_when", {cond1, scalar_null}, values_null); + CheckScalar("case_when", {cond1, scalar_null, scalar1}, + ArrayFromJSON(type, R"([null, null, "abc", "abc"])")); + CheckScalar("case_when", {cond1, scalar1, cond2, scalar2, scalar1}, + ArrayFromJSON(type, R"(["abc", "abc", "bcd", "abc"])")); + + CheckScalar("case_when", {cond1, values1, cond2, values2}, + ArrayFromJSON(type, R"(["cde", null, null, null])")); + CheckScalar("case_when", {cond1, values1, cond2, values2, values1}, + ArrayFromJSON(type, R"(["cde", null, null, "efg"])")); + CheckScalar("case_when", {cond1, values_null, cond2, values2, values1}, + ArrayFromJSON(type, R"([null, null, null, "efg"])")); +} + +TEST(TestCaseWhen, DispatchBest) { + ASSERT_OK_AND_ASSIGN(auto function, GetFunctionRegistry()->GetFunction("case_when")); + auto Check = + [&](std::vector original_values) -> Result> { + auto values = original_values; + RETURN_NOT_OK(function->DispatchBest(&values)); + return values; + }; + + // Since DispatchBest for this kernel pulls tricks, we can't compare it to DispatchExact + // as CheckDispatchBest does + EXPECT_EQ((std::vector{int32()}), *Check({int32()})); + EXPECT_EQ((std::vector{boolean(), int64(), int64()}), + *Check({boolean(), int32(), int64()})); + EXPECT_EQ((std::vector{boolean(), int64(), int64()}), + *Check({null(), int32(), int64()})); + ASSERT_RAISES(TypeError, Check({boolean(), utf8(), int32()})); + ASSERT_RAISES(TypeError, Check({int32(), int32(), int32()})); + ASSERT_RAISES(Invalid, Check({})); +} + +TEST(TestCaseWhen, Errors) { + ASSERT_RAISES(Invalid, CaseWhen({})); + ASSERT_RAISES(TypeError, CaseWhen({ArrayFromJSON(utf8(), "[\"\"]"), + ArrayFromJSON(int32(), "[0]")})); +} } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/test_util.cc b/cpp/src/arrow/compute/kernels/test_util.cc index a1151717d8b..ce8d42e34c2 100644 --- a/cpp/src/arrow/compute/kernels/test_util.cc +++ b/cpp/src/arrow/compute/kernels/test_util.cc @@ -47,12 +47,10 @@ DatumVector GetDatums(const std::vector& inputs) { } void CheckScalarNonRecursive(const std::string& func_name, const DatumVector& inputs, - const std::shared_ptr& expected, - const FunctionOptions* options) { + const Datum& expected, const FunctionOptions* options) { ASSERT_OK_AND_ASSIGN(Datum out, CallFunction(func_name, inputs, options)); - std::shared_ptr actual = std::move(out).make_array(); - ValidateOutput(*actual); - AssertArraysEqual(*expected, *actual, /*verbose=*/true); + ValidateOutput(out); + AssertDatumsEqual(expected, out, /*verbose=*/true); } template @@ -103,35 +101,38 @@ void CheckScalar(std::string func_name, const ScalarVector& inputs, } } -void CheckScalar(std::string func_name, const DatumVector& inputs, - std::shared_ptr expected, const FunctionOptions* options) { - CheckScalarNonRecursive(func_name, inputs, expected, options); +void CheckScalar(std::string func_name, const DatumVector& inputs, Datum expected_datum, + const FunctionOptions* options) { + CheckScalarNonRecursive(func_name, inputs, expected_datum, options); + + if (expected_datum.is_scalar()) return; + ASSERT_TRUE(expected_datum.is_array()) + << "CheckScalar is only implemented for scalar/array expected values"; + auto expected = expected_datum.make_array(); // check for at least 1 array, and make sure the others are of equal length - std::shared_ptr array; + bool has_array = false; for (const auto& input : inputs) { if (input.is_array()) { - if (!array) { - array = input.make_array(); - } else { - ASSERT_EQ(input.array()->length, array->length()); - } + ASSERT_EQ(input.array()->length, expected->length()); + has_array = true; } } + ASSERT_TRUE(has_array) << "Must have at least 1 array input to have an array output"; // Check all the input scalars, if scalars are implemented if (std::none_of(inputs.begin(), inputs.end(), [](const Datum& datum) { return datum.type()->id() == Type::EXTENSION; })) { // Check all the input scalars - for (int64_t i = 0; i < array->length(); ++i) { + for (int64_t i = 0; i < expected->length(); ++i) { CheckScalar(func_name, GetScalars(inputs, i), *expected->GetScalar(i), options); } } // Since it's a scalar function, calling it on sliced inputs should // result in the sliced expected output. - const auto slice_length = array->length() / 3; + const auto slice_length = expected->length() / 3; if (slice_length > 0) { CheckScalarNonRecursive(func_name, SliceArrays(inputs, 0, slice_length), expected->Slice(0, slice_length), options); diff --git a/cpp/src/arrow/compute/kernels/test_util.h b/cpp/src/arrow/compute/kernels/test_util.h index c691a9f3be3..a3fb9308f58 100644 --- a/cpp/src/arrow/compute/kernels/test_util.h +++ b/cpp/src/arrow/compute/kernels/test_util.h @@ -95,8 +95,7 @@ void CheckScalar(std::string func_name, const ScalarVector& inputs, std::shared_ptr expected, const FunctionOptions* options = nullptr); -void CheckScalar(std::string func_name, const DatumVector& inputs, - std::shared_ptr expected, +void CheckScalar(std::string func_name, const DatumVector& inputs, Datum expected, const FunctionOptions* options = nullptr); void CheckScalarUnary(std::string func_name, std::shared_ptr in_ty, diff --git a/docs/source/cpp/compute.rst b/docs/source/cpp/compute.rst index 6ce808aba67..39f07b348e5 100644 --- a/docs/source/cpp/compute.rst +++ b/docs/source/cpp/compute.rst @@ -863,30 +863,40 @@ Structural transforms +--------------------------+------------+------------------------------------------------+---------------------+---------+ | Function name | Arity | Input types | Output type | Notes | +==========================+============+================================================+=====================+=========+ -| fill_null | Binary | Boolean, Null, Numeric, Temporal, String-like | Input type | \(1) | +| case_when | Varargs | Boolean, Any fixed-width | Input type | \(1) | +--------------------------+------------+------------------------------------------------+---------------------+---------+ -| if_else | Ternary | Boolean, Null, Numeric, Temporal | Input type | \(2) | +| fill_null | Binary | Boolean, Null, Numeric, Temporal, String-like | Input type | \(2) | +--------------------------+------------+------------------------------------------------+---------------------+---------+ -| is_finite | Unary | Float, Double | Boolean | \(3) | +| if_else | Ternary | Boolean, Null, Numeric, Temporal | Input type | \(3) | +--------------------------+------------+------------------------------------------------+---------------------+---------+ -| is_inf | Unary | Float, Double | Boolean | \(4) | +| is_finite | Unary | Float, Double | Boolean | \(4) | +--------------------------+------------+------------------------------------------------+---------------------+---------+ -| is_nan | Unary | Float, Double | Boolean | \(5) | +| is_inf | Unary | Float, Double | Boolean | \(5) | +--------------------------+------------+------------------------------------------------+---------------------+---------+ -| is_null | Unary | Any | Boolean | \(6) | +| is_nan | Unary | Float, Double | Boolean | \(6) | +--------------------------+------------+------------------------------------------------+---------------------+---------+ -| is_valid | Unary | Any | Boolean | \(7) | +| is_null | Unary | Any | Boolean | \(7) | +--------------------------+------------+------------------------------------------------+---------------------+---------+ -| list_value_length | Unary | List-like | Int32 or Int64 | \(8) | +| is_valid | Unary | Any | Boolean | \(8) | +--------------------------+------------+------------------------------------------------+---------------------+---------+ -| project | Varargs | Any | Struct | \(9) | +| list_value_length | Unary | List-like | Int32 or Int64 | \(9) | +--------------------------+------------+------------------------------------------------+---------------------+---------+ +| project | Varargs | Any | Struct | \(10) | ++--------------------------+------------+------------------------------------------------+---------------------+---------+ + +* \(1) This function acts like a SQL 'case when' statement or switch-case. The + input is any number of alternating Boolean and value data, followed by an + optional value datum to represent the 'else' or 'default' case. At least one + input must be provided. The output is of the same type as the value inputs; + each row will be the corresponding value from the first value datum for which + the corresponding Boolean is true, or the corresponding value from the + 'default' input, or null otherwise. -* \(1) First input must be an array, second input a scalar of the same type. +* \(2) First input must be an array, second input a scalar of the same type. Output is an array of the same type as the inputs, and with the same values as the first input, except for nulls replaced with the second input value. -* \(2) First input must be a Boolean scalar or array. Second and third inputs +* \(3) First input must be a Boolean scalar or array. Second and third inputs could be scalars or arrays and must be of the same type. Output is an array (or scalar if all inputs are scalar) of the same type as the second/ third input. If the nulls present on the first input, they will be promoted to the @@ -894,21 +904,21 @@ Structural transforms Also see: :ref:`replace_with_mask `. -* \(3) Output is true iff the corresponding input element is finite (not Infinity, +* \(4) Output is true iff the corresponding input element is finite (not Infinity, -Infinity, or NaN). -* \(4) Output is true iff the corresponding input element is Infinity/-Infinity. +* \(5) Output is true iff the corresponding input element is Infinity/-Infinity. -* \(5) Output is true iff the corresponding input element is NaN. +* \(6) Output is true iff the corresponding input element is NaN. -* \(6) Output is true iff the corresponding input element is null. +* \(7) Output is true iff the corresponding input element is null. -* \(7) Output is true iff the corresponding input element is non-null. +* \(8) Output is true iff the corresponding input element is non-null. -* \(8) Each output element is the length of the corresponding input element +* \(9) Each output element is the length of the corresponding input element (null if input is null). Output type is Int32 for List, Int64 for LargeList. -* \(9) The output struct's field types are the types of its arguments. The +* \(10) The output struct's field types are the types of its arguments. The field names are specified using an instance of :struct:`ProjectOptions`. The output shape will be scalar if all inputs are scalar, otherwise any scalars will be broadcast to arrays. diff --git a/docs/source/python/api/compute.rst b/docs/source/python/api/compute.rst index 09c67598193..c12f2f91b26 100644 --- a/docs/source/python/api/compute.rst +++ b/docs/source/python/api/compute.rst @@ -335,6 +335,7 @@ Structural Transforms :toctree: ../generated/ binary_length + case_when fill_null if_else is_finite From 72dfd853b7d3ac6b1a4ea6318f1a4ea38fc39c01 Mon Sep 17 00:00:00 2001 From: David Li Date: Wed, 7 Jul 2021 12:37:12 -0400 Subject: [PATCH 02/11] ARROW-13064: [C++] Tweak signature of case_when --- cpp/src/arrow/compute/api_scalar.cc | 8 +- cpp/src/arrow/compute/api_scalar.h | 7 +- cpp/src/arrow/compute/kernel.cc | 23 +- cpp/src/arrow/compute/kernel.h | 6 +- cpp/src/arrow/compute/kernel_test.cc | 31 +- .../arrow/compute/kernels/codegen_internal.cc | 18 +- .../arrow/compute/kernels/codegen_internal.h | 3 + .../arrow/compute/kernels/scalar_if_else.cc | 251 ++++++------ .../kernels/scalar_if_else_benchmark.cc | 17 +- .../compute/kernels/scalar_if_else_test.cc | 363 ++++++++++-------- 10 files changed, 415 insertions(+), 312 deletions(-) diff --git a/cpp/src/arrow/compute/api_scalar.cc b/cpp/src/arrow/compute/api_scalar.cc index 0588f748812..68df5f98b10 100644 --- a/cpp/src/arrow/compute/api_scalar.cc +++ b/cpp/src/arrow/compute/api_scalar.cc @@ -466,8 +466,12 @@ Result IfElse(const Datum& cond, const Datum& if_true, const Datum& if_fa return CallFunction("if_else", {cond, if_true, if_false}, ctx); } -Result CaseWhen(const std::vector& cases, ExecContext* ctx) { - return CallFunction("case_when", cases, ctx); +Result CaseWhen(const Datum& cond, const std::vector& cases, + ExecContext* ctx) { + std::vector args = {cond}; + args.reserve(cases.size() + 1); + args.insert(args.end(), cases.begin(), cases.end()); + return CallFunction("case_when", args, ctx); } // ---------------------------------------------------------------------- diff --git a/cpp/src/arrow/compute/api_scalar.h b/cpp/src/arrow/compute/api_scalar.h index 28a0ca53c52..bbaa4d13a21 100644 --- a/cpp/src/arrow/compute/api_scalar.h +++ b/cpp/src/arrow/compute/api_scalar.h @@ -746,8 +746,8 @@ Result IfElse(const Datum& cond, const Datum& left, const Datum& right, /// true, or (if given) select the 'else' value, else emit null. Note that a /// null condition is the same as false. /// -/// \param[in] cases Zero or more pairs of conditions (Boolean) & values (any -/// type), along with an optional 'else' value. +/// \param[in] cond Conditions (Boolean) +/// \param[in] cases Values (any type), along with an optional 'else' value. /// \param[in] ctx the function execution context, optional /// /// \return the resulting datum @@ -755,7 +755,8 @@ Result IfElse(const Datum& cond, const Datum& left, const Datum& right, /// \since 5.0.0 /// \note API not yet finalized ARROW_EXPORT -Result CaseWhen(const std::vector& cases, ExecContext* ctx = NULLPTR); +Result CaseWhen(const Datum& cond, const std::vector& cases, + ExecContext* ctx = NULLPTR); /// \brief Year returns year for each element of `values` /// diff --git a/cpp/src/arrow/compute/kernel.cc b/cpp/src/arrow/compute/kernel.cc index 6cdd17adcc9..f131f524d2e 100644 --- a/cpp/src/arrow/compute/kernel.cc +++ b/cpp/src/arrow/compute/kernel.cc @@ -402,8 +402,7 @@ KernelSignature::KernelSignature(std::vector in_types, OutputType out out_type_(std::move(out_type)), is_varargs_(is_varargs), hash_code_(0) { - // VarArgs sigs must have only a single input type to use for argument validation - DCHECK(!is_varargs || (is_varargs && (in_types_.size() == 1))); + DCHECK(!is_varargs || (is_varargs && (in_types_.size() >= 1))); } std::shared_ptr KernelSignature::Make(std::vector in_types, @@ -430,8 +429,8 @@ bool KernelSignature::Equals(const KernelSignature& other) const { bool KernelSignature::MatchesInputs(const std::vector& args) const { if (is_varargs_) { - for (const auto& arg : args) { - if (!in_types_[0].Matches(arg)) { + for (size_t i = 0; i < args.size(); ++i) { + if (!in_types_[std::min(i, in_types_.size() - 1)].Matches(args[i])) { return false; } } @@ -464,15 +463,19 @@ std::string KernelSignature::ToString() const { std::stringstream ss; if (is_varargs_) { - ss << "varargs[" << in_types_[0].ToString() << "]"; + ss << "varargs["; } else { ss << "("; - for (size_t i = 0; i < in_types_.size(); ++i) { - if (i > 0) { - ss << ", "; - } - ss << in_types_[i].ToString(); + } + for (size_t i = 0; i < in_types_.size(); ++i) { + if (i > 0) { + ss << ", "; } + ss << in_types_[i].ToString(); + } + if (is_varargs_) { + ss << "]"; + } else { ss << ")"; } ss << " -> " << out_type_.ToString(); diff --git a/cpp/src/arrow/compute/kernel.h b/cpp/src/arrow/compute/kernel.h index 50b1dd8e55e..36d20c7289e 100644 --- a/cpp/src/arrow/compute/kernel.h +++ b/cpp/src/arrow/compute/kernel.h @@ -366,8 +366,10 @@ class ARROW_EXPORT OutputType { /// \brief Holds the input types and output type of the kernel. /// -/// VarArgs functions should pass a single input type to be used to validate -/// the input types of a function invocation. +/// VarArgs functions with minimum N arguments should pass up to N input types to be +/// used to validate the input types of a function invocation. The first N-1 types +/// will be matched against the first N-1 arguments, and the last type will be +/// matched against the remaining arguments. class ARROW_EXPORT KernelSignature { public: KernelSignature(std::vector in_types, OutputType out_type, diff --git a/cpp/src/arrow/compute/kernel_test.cc b/cpp/src/arrow/compute/kernel_test.cc index a5ef9d44e18..a63c42d4fde 100644 --- a/cpp/src/arrow/compute/kernel_test.cc +++ b/cpp/src/arrow/compute/kernel_test.cc @@ -468,15 +468,28 @@ TEST(KernelSignature, MatchesInputs) { } TEST(KernelSignature, VarArgsMatchesInputs) { - KernelSignature sig({int8()}, utf8(), /*is_varargs=*/true); - - std::vector args = {int8()}; - ASSERT_TRUE(sig.MatchesInputs(args)); - args.push_back(ValueDescr::Scalar(int8())); - args.push_back(ValueDescr::Array(int8())); - ASSERT_TRUE(sig.MatchesInputs(args)); - args.push_back(int32()); - ASSERT_FALSE(sig.MatchesInputs(args)); + { + KernelSignature sig({int8()}, utf8(), /*is_varargs=*/true); + + std::vector args = {int8()}; + ASSERT_TRUE(sig.MatchesInputs(args)); + args.push_back(ValueDescr::Scalar(int8())); + args.push_back(ValueDescr::Array(int8())); + ASSERT_TRUE(sig.MatchesInputs(args)); + args.push_back(int32()); + ASSERT_FALSE(sig.MatchesInputs(args)); + } + { + KernelSignature sig({int8(), utf8()}, utf8(), /*is_varargs=*/true); + + std::vector args = {int8()}; + ASSERT_TRUE(sig.MatchesInputs(args)); + args.push_back(ValueDescr::Scalar(utf8())); + args.push_back(ValueDescr::Array(utf8())); + ASSERT_TRUE(sig.MatchesInputs(args)); + args.push_back(int32()); + ASSERT_FALSE(sig.MatchesInputs(args)); + } } TEST(KernelSignature, ToString) { diff --git a/cpp/src/arrow/compute/kernels/codegen_internal.cc b/cpp/src/arrow/compute/kernels/codegen_internal.cc index b9bde999447..673db088eae 100644 --- a/cpp/src/arrow/compute/kernels/codegen_internal.cc +++ b/cpp/src/arrow/compute/kernels/codegen_internal.cc @@ -218,9 +218,14 @@ void ReplaceTypes(const std::shared_ptr& type, } std::shared_ptr CommonNumeric(const std::vector& descrs) { - DCHECK(!descrs.empty()) << "tried to find CommonNumeric type of an empty set"; + return CommonNumeric(descrs.data(), descrs.size()); +} - for (const auto& descr : descrs) { +std::shared_ptr CommonNumeric(const ValueDescr* begin, size_t count) { + DCHECK_GT(count, 0) << "tried to find CommonNumeric type of an empty set"; + + for (size_t i = 0; i < count; i++) { + const auto& descr = *(begin + i); auto id = descr.type->id(); if (!is_floating(id) && !is_integer(id)) { // a common numeric type is only possible if all types are numeric @@ -232,17 +237,20 @@ std::shared_ptr CommonNumeric(const std::vector& descrs) { } } - for (const auto& descr : descrs) { + for (size_t i = 0; i < count; i++) { + const auto& descr = *(begin + i); if (descr.type->id() == Type::DOUBLE) return float64(); } - for (const auto& descr : descrs) { + for (size_t i = 0; i < count; i++) { + const auto& descr = *(begin + i); if (descr.type->id() == Type::FLOAT) return float32(); } int max_width_signed = 0, max_width_unsigned = 0; - for (const auto& descr : descrs) { + for (size_t i = 0; i < count; i++) { + const auto& descr = *(begin + i); auto id = descr.type->id(); auto max_width = &(is_signed_integer(id) ? max_width_signed : max_width_unsigned); *max_width = std::max(bit_width(id), *max_width); diff --git a/cpp/src/arrow/compute/kernels/codegen_internal.h b/cpp/src/arrow/compute/kernels/codegen_internal.h index 12e80423f7f..d28ede4f77a 100644 --- a/cpp/src/arrow/compute/kernels/codegen_internal.h +++ b/cpp/src/arrow/compute/kernels/codegen_internal.h @@ -1367,6 +1367,9 @@ void ReplaceTypes(const std::shared_ptr&, std::vector* des ARROW_EXPORT std::shared_ptr CommonNumeric(const std::vector& descrs); +ARROW_EXPORT +std::shared_ptr CommonNumeric(const ValueDescr* begin, size_t count); + ARROW_EXPORT std::shared_ptr CommonTimestamp(const std::vector& descrs); diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else.cc b/cpp/src/arrow/compute/kernels/scalar_if_else.cc index 7baaeeabe1e..0364bc40a1f 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else.cc @@ -786,88 +786,94 @@ struct CaseWhenFunction : ScalarFunction { using ScalarFunction::ScalarFunction; Result DispatchBest(std::vector* values) const override { + // The first function is a struct of booleans, where the number of fields in the + // struct is either equal to the number of other arguments or is one less. RETURN_NOT_OK(CheckArity(*values)); - std::vector value_types; - for (size_t i = 0; i < values->size() - 1; i += 2) { - ValueDescr* cond = &(*values)[i]; - if (cond->type->id() == Type::NA) { - cond->type = boolean(); - } - if (cond->type->id() != Type::BOOL) { - return Status::TypeError("Condition arguments must be boolean, but argument ", i, - " was ", cond->type->ToString()); - } - value_types.push_back((*values)[i + 1]); + EnsureDictionaryDecoded(values); + auto first_type = (*values)[0].type; + if (first_type->id() != Type::STRUCT) { + return Status::TypeError("case_when: first argument must be STRUCT, not ", + *first_type); } - if (values->size() % 2 != 0) { - // Have an ELSE clause - value_types.push_back(values->back()); + auto num_fields = static_cast(first_type->num_fields()); + if (num_fields < values->size() - 2 || num_fields >= values->size()) { + return Status::Invalid( + "case_when: number of struct fields must be equal to or one less than count of " + "remaining arguments (", + values->size() - 1, "), got: ", first_type->num_fields()); } - EnsureDictionaryDecoded(&value_types); - if (auto type = CommonNumeric(value_types)) { - ReplaceTypes(type, &value_types); + for (const auto& field : first_type->fields()) { + if (field->type()->id() != Type::BOOL) { + return Status::TypeError( + "case_when: all fields of first argument must be BOOL, but ", field->name(), + " was of type: ", *field->type()); + } } - const DataType& common_values_type = *value_types.front().type; - auto next_type = value_types.cbegin(); - for (size_t i = 0; i < values->size(); i += 2) { - if (!common_values_type.Equals(next_type->type)) { - return Status::TypeError("Value arguments must be of same type, but argument ", i, - " was ", next_type->type->ToString(), " (expected ", - common_values_type.ToString(), ")"); - } - if (i == values->size() - 1) { - // ELSE - (*values)[i] = *next_type++; - } else { - (*values)[i + 1] = *next_type++; + if (auto type = CommonNumeric(values->data() + 1, values->size() - 1)) { + for (auto it = values->begin() + 1; it != values->end(); it++) { + it->type = type; } } - - // We register a unary kernel for each value type and dispatch to it after validation. - if (auto kernel = DispatchExactImpl(this, {values->back()})) return kernel; + if (auto kernel = DispatchExactImpl(this, *values)) return kernel; return arrow::compute::detail::NoMatchingKernel(this, *values); } }; -// Implement a 'case when' (SQL)/'select' (NumPy) function for any scalar arguments +// Implement a 'case when' (SQL)/'select' (NumPy) function for any scalar conditions Status ExecScalarCaseWhen(KernelContext* ctx, const ExecBatch& batch, Datum* out) { - for (size_t i = 0; i < batch.values.size() - 1; i += 2) { - const Scalar& cond = *batch[i].scalar(); - if (cond.is_valid && internal::UnboxScalar::Unbox(cond)) { - *out = batch[i + 1]; - return Status::OK(); + const auto& conds = checked_cast(*batch.values[0].scalar()); + Datum result; + for (size_t i = 0; i < batch.values.size() - 1; i++) { + if (i < conds.value.size()) { + const Scalar& cond = *conds.value[i]; + if (cond.is_valid && internal::UnboxScalar::Unbox(cond)) { + result = batch[i + 1]; + break; + } + } else { + // ELSE clause + result = batch[i + 1]; + break; } } - if (batch.values.size() % 2 == 0) { - // No ELSE - *out = MakeNullScalar(batch[1].type()); + if (out->is_scalar()) { + *out = result.is_scalar() ? result.scalar() : MakeNullScalar(out->type()); + } else if (result.is_value()) { + if (result.is_scalar()) { + ARROW_ASSIGN_OR_RAISE(auto temp, MakeArrayFromScalar(*result.scalar(), batch.length, + ctx->memory_pool())); + *out->mutable_array() = *temp->data(); + } else { + *out->mutable_array() = *result.array(); + } } else { - *out = batch.values.back(); + ARROW_ASSIGN_OR_RAISE(auto temp, + MakeArrayOfNull(out->type(), batch.length, ctx->memory_pool())); + *out->mutable_array() = *temp->data(); } return Status::OK(); } // Implement 'case when' for any mix of scalar/array arguments for any fixed-width type, -// given helper functions to copy data from a source array to a target array and to -// allocate a values buffer +// given helper functions to copy data from a source array to a target array template Status ExecArrayCaseWhen(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + const auto& conds_array = *batch.values[0].array(); ArrayData* output = out->mutable_array(); - const bool have_else_arg = batch.values.size() % 2 != 0; - // Check if we may need a validity bitmap + const auto num_value_args = batch.values.size() - 1; + const bool have_else_arg = + static_cast(conds_array.type->num_fields()) < num_value_args; uint8_t* out_valid = nullptr; + // Check if we may need a validity bitmap bool need_valid_bitmap = false; if (!have_else_arg) { // If we don't have an else arg -> need a bitmap since we may emit nulls need_valid_bitmap = true; - } else if (batch.values.back().null_count() > 0) { - // If the 'else' array has a null count we need a validity bitmap - need_valid_bitmap = true; } else { // Otherwise if any value array has a null count we need a validity bitmap - for (size_t i = 1; i < batch.values.size(); i += 2) { + for (size_t i = 1; i < batch.values.size(); i++) { if (batch[i].null_count() > 0) { need_valid_bitmap = true; break; @@ -880,6 +886,13 @@ Status ExecArrayCaseWhen(KernelContext* ctx, const ExecBatch& batch, Datum* out) } // Initialize values buffer + auto bit_width = checked_cast(*out->type()).bit_width(); + if (bit_width == 1) { + ARROW_ASSIGN_OR_RAISE(output->buffers[1], ctx->AllocateBitmap(batch.length)); + } else { + auto byte_width = BitUtil::BytesForBits(bit_width); + ARROW_ASSIGN_OR_RAISE(output->buffers[1], ctx->Allocate(batch.length * byte_width)); + } uint8_t* out_values = output->buffers[1]->mutable_data(); if (have_else_arg) { // Copy 'else' value into output @@ -894,96 +907,101 @@ Status ExecArrayCaseWhen(KernelContext* ctx, const ExecBatch& batch, Datum* out) ARROW_ASSIGN_OR_RAISE(auto mask_buffer, ctx->AllocateBitmap(batch.length)); uint8_t* mask = mask_buffer->mutable_data(); std::memset(mask, 0xFF, mask_buffer->size()); + // Then iterate through each argument in turn and set elements. - for (size_t i = 0; i < batch.values.size() - 1; i += 2) { - const Datum& cond_datum = batch[i]; + const uint8_t* conds_valid = + conds_array.GetNullCount() > 0 ? conds_array.buffers[0]->data() : nullptr; + for (size_t i = 0; i < batch.values.size() - (have_else_arg ? 2 : 1); i++) { + const ArrayData& cond_array = *conds_array.child_data[i]; + const int64_t cond_offset = conds_array.offset + cond_array.offset; + const uint8_t* cond_values = cond_array.buffers[1]->data(); const Datum& values_datum = batch[i + 1]; - if (cond_datum.is_scalar()) { - const Scalar& cond_scalar = *cond_datum.scalar(); - const bool cond = - cond_scalar.is_valid && UnboxScalar::Unbox(cond_scalar); - if (!cond) continue; - BitBlockCounter counter(mask, /*start_offset=*/0, batch.length); - int64_t offset = 0; + int64_t offset = 0; + + if (!conds_valid && cond_array.GetNullCount() == 0) { + // If no valid buffer, visit mask & cond bitmap simultaneously + BinaryBitBlockCounter counter(mask, /*start_offset=*/0, cond_values, cond_offset, + batch.length); while (offset < batch.length) { - const auto block = counter.NextWord(); + const auto block = counter.NextAndWord(); if (block.AllSet()) { CopyValues(values_datum, out_valid, out_values, offset, block.length); + BitUtil::SetBitsTo(mask, offset, block.length, false); } else if (block.popcount) { for (int64_t j = 0; j < block.length; ++j) { - if (BitUtil::GetBit(mask, offset + j)) { + if (BitUtil::GetBit(mask, offset + j) && + BitUtil::GetBit(cond_values, cond_offset + offset + j)) { CopyValues(values_datum, out_valid, out_values, offset + j, /*length=*/1); + BitUtil::SetBitTo(mask, offset + j, false); } } } offset += block.length; } - break; - } - - const ArrayData& cond_array = *cond_datum.array(); - const uint8_t* cond_values = cond_array.buffers[1]->data(); - int64_t offset = 0; - // If no valid buffer, visit mask & value bitmap simultaneously - if (cond_array.GetNullCount() == 0) { - BinaryBitBlockCounter counter(mask, /*start_offset=*/0, cond_values, - cond_array.offset, batch.length); - while (offset < batch.length) { - const auto block = counter.NextAndWord(); - if (block.AllSet()) { - CopyValues(values_datum, out_valid, out_values, offset, block.length); - BitUtil::SetBitsTo(mask, offset, block.length, false); - } else if (block.popcount) { - for (int64_t j = 0; j < block.length; ++j) { + continue; + } else if (!conds_valid) { + // Visit mask & cond bitmap & cond validity + const uint8_t* cond_valid = cond_array.buffers[0]->data(); + Bitmap bitmaps[3] = {{mask, /*offset=*/0, batch.length}, + {cond_values, cond_offset, batch.length}, + {cond_valid, cond_offset, batch.length}}; + Bitmap::VisitWords(bitmaps, [&](std::array words) { + const uint64_t word = words[0] & words[1] & words[2]; + const int64_t block_length = std::min(64, batch.length - offset); + if (word == std::numeric_limits::max()) { + CopyValues(values_datum, out_valid, out_values, offset, block_length); + BitUtil::SetBitsTo(mask, offset, block_length, false); + } else if (word) { + for (int64_t j = 0; j < block_length; ++j) { if (BitUtil::GetBit(mask, offset + j) && - BitUtil::GetBit(cond_values, cond_array.offset + offset + j)) { + BitUtil::GetBit(cond_valid, cond_offset + offset + j) && + BitUtil::GetBit(cond_values, cond_offset + offset + j)) { CopyValues(values_datum, out_valid, out_values, offset + j, /*length=*/1); BitUtil::SetBitTo(mask, offset + j, false); } } } - offset += block.length; - } - continue; - } - - // Else visit all three bitmaps simultaneously - const uint8_t* cond_valid = cond_array.buffers[0]->data(); - Bitmap bitmaps[3] = {{mask, /*offset=*/0, batch.length}, - {cond_values, cond_array.offset, batch.length}, - {cond_valid, cond_array.offset, batch.length}}; - Bitmap::VisitWords(bitmaps, [&](std::array words) { - const uint64_t word = words[0] & words[1] & words[2]; - const int64_t block_length = std::min(64, batch.length - offset); - if (word == std::numeric_limits::max()) { - CopyValues(values_datum, out_valid, out_values, offset, block_length); - BitUtil::SetBitsTo(mask, offset, block_length, false); - } else if (word) { - for (int64_t j = 0; j < block_length; ++j) { - if (BitUtil::GetBit(mask, offset + j) && - BitUtil::GetBit(cond_valid, cond_array.offset + offset + j) && - BitUtil::GetBit(cond_values, cond_array.offset + offset + j)) { - CopyValues(values_datum, out_valid, out_values, offset + j, - /*length=*/1); - BitUtil::SetBitTo(mask, offset + j, false); + }); + } else { + // Visit mask & cond bitmap & cond validity & struct validity + const uint8_t* cond_valid = cond_array.buffers[0]->data(); + Bitmap bitmaps[4] = {{mask, /*offset=*/0, batch.length}, + {cond_values, cond_offset, batch.length}, + {cond_valid, cond_offset, batch.length}, + {conds_valid, conds_array.offset, batch.length}}; + Bitmap::VisitWords(bitmaps, [&](std::array words) { + const uint64_t word = words[0] & words[1] & words[2] & words[3]; + const int64_t block_length = std::min(64, batch.length - offset); + if (word == std::numeric_limits::max()) { + CopyValues(values_datum, out_valid, out_values, offset, block_length); + BitUtil::SetBitsTo(mask, offset, block_length, false); + } else if (word) { + for (int64_t j = 0; j < block_length; ++j) { + if (BitUtil::GetBit(mask, offset + j) && + BitUtil::GetBit(cond_valid, cond_offset + offset + j) && + BitUtil::GetBit(cond_values, cond_offset + offset + j) && + BitUtil::GetBit(conds_valid, conds_array.offset + offset + j)) { + CopyValues(values_datum, out_valid, out_values, offset + j, + /*length=*/1); + BitUtil::SetBitTo(mask, offset + j, false); + } } } - } - offset += block_length; - }); + offset += block_length; + }); + } } + // TODO: need to initialize output values return Status::OK(); } template struct CaseWhenFunctor { static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { - for (const auto& datum : batch.values) { - if (datum.is_array()) { - return ExecArrayCaseWhen(ctx, batch, out); - } + if (batch.values[0].is_array()) { + return ExecArrayCaseWhen(ctx, batch, out); } return ExecScalarCaseWhen(ctx, batch, out); } @@ -1004,11 +1022,14 @@ Result LastType(KernelContext*, const std::vector& descr void AddCaseWhenKernel(const std::shared_ptr& scalar_function, detail::GetTypeId get_id, ArrayKernelExec exec) { - ScalarKernel kernel(KernelSignature::Make({InputType(get_id.id)}, OutputType(LastType), - /*is_varargs=*/true), - exec); + ScalarKernel kernel( + KernelSignature::Make({InputType(Type::STRUCT), InputType(get_id.id)}, + OutputType(LastType), + /*is_varargs=*/true), + exec); kernel.null_handling = NullHandling::COMPUTED_NO_PREALLOCATE; - kernel.mem_allocation = MemAllocation::PREALLOCATE; + kernel.mem_allocation = MemAllocation::NO_PREALLOCATE; + kernel.can_write_into_slices = false; DCHECK_OK(scalar_function->AddKernel(std::move(kernel))); } diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else_benchmark.cc b/cpp/src/arrow/compute/kernels/scalar_if_else_benchmark.cc index 83639235bb5..e4b71aceab8 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else_benchmark.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else_benchmark.cc @@ -122,12 +122,15 @@ static void CaseWhenBench(benchmark::State& state) { rand.ArrayOf(type, len, /*null_probability=*/0.01)); auto val4 = std::static_pointer_cast( rand.ArrayOf(type, len, /*null_probability=*/0.01)); + ASSERT_OK_AND_ASSIGN( + auto cond, + StructArray::Make({cond1, cond2, cond3}, std::vector{"a", "b", "c"}, + nullptr, /*null_count=*/0)); for (auto _ : state) { ABORT_NOT_OK( - CaseWhen({cond1->Slice(offset), val1->Slice(offset), cond2->Slice(offset), - val2->Slice(offset), cond3->Slice(offset), val3->Slice(offset), - val4->Slice(offset)})); + CaseWhen(cond->Slice(offset), {val1->Slice(offset), val2->Slice(offset), + val3->Slice(offset), val4->Slice(offset)})); } state.SetBytesProcessed(state.iterations() * @@ -160,11 +163,13 @@ static void CaseWhenBenchContiguous(benchmark::State& state) { rand.ArrayOf(type, len, /*null_probability=*/0.01)); auto val3 = std::static_pointer_cast( rand.ArrayOf(type, len, /*null_probability=*/0.01)); + ASSERT_OK_AND_ASSIGN( + auto cond, StructArray::Make({cond1, cond2}, std::vector{"a", "b"}, + nullptr, /*null_count=*/0)); for (auto _ : state) { - ABORT_NOT_OK( - CaseWhen({cond1->Slice(offset), val1->Slice(offset), cond2->Slice(offset), - val2->Slice(offset), val3->Slice(offset)})); + ABORT_NOT_OK(CaseWhen(cond->Slice(offset), {val1->Slice(offset), val2->Slice(offset), + val3->Slice(offset)})); } state.SetBytesProcessed(state.iterations() * diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc index 0a9f5548368..3eb682d97be 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc @@ -323,6 +323,17 @@ class TestCaseWhenNumeric : public ::testing::Test {}; TYPED_TEST_SUITE(TestCaseWhenNumeric, NumericBasedTypes); +Datum MakeStruct(const std::vector& conds) { + ProjectOptions options; + options.field_names.resize(conds.size()); + options.field_metadata.resize(conds.size()); + for (const auto& datum : conds) { + options.field_nullability.push_back(datum.null_count() > 0); + } + EXPECT_OK_AND_ASSIGN(auto result, CallFunction("project", conds, &options)); + return result; +} + TYPED_TEST(TestCaseWhenNumeric, FixedSize) { auto type = default_type_instance(); auto cond_true = ScalarFromJSON(boolean(), "true"); @@ -337,55 +348,66 @@ TYPED_TEST(TestCaseWhenNumeric, FixedSize) { auto values1 = ArrayFromJSON(type, "[3, null, 5, 6]"); auto values2 = ArrayFromJSON(type, "[7, 8, null, 10]"); - CheckScalar("case_when", {values1}, values1); - CheckScalar("case_when", {values_null}, values_null); - - CheckScalar("case_when", {cond_true, values1}, values1); - CheckScalar("case_when", {cond_false, values1}, values_null); - CheckScalar("case_when", {cond_null, values1}, values_null); - CheckScalar("case_when", {cond_true, values1, values2}, values1); - CheckScalar("case_when", {cond_false, values1, values2}, values2); - CheckScalar("case_when", {cond_null, values1, values2}, values2); - - CheckScalar("case_when", {cond_true, values1, cond_true, values2}, values1); - CheckScalar("case_when", {cond_false, values1, cond_false, values2}, values_null); - CheckScalar("case_when", {cond_true, values1, cond_false, values2}, values1); - CheckScalar("case_when", {cond_false, values1, cond_true, values2}, values2); - CheckScalar("case_when", {cond_null, values1, cond_true, values2}, values2); - CheckScalar("case_when", {cond_false, values1, cond_false, values2, values2}, values2); + CheckScalar("case_when", {MakeStruct({}), values1}, values1); + CheckScalar("case_when", {MakeStruct({}), values_null}, values_null); + + CheckScalar("case_when", {MakeStruct({cond_true}), values1}, values1); + CheckScalar("case_when", {MakeStruct({cond_false}), values1}, values_null); + CheckScalar("case_when", {MakeStruct({cond_null}), values1}, values_null); + CheckScalar("case_when", {MakeStruct({cond_true}), values1, values2}, values1); + CheckScalar("case_when", {MakeStruct({cond_false}), values1, values2}, values2); + CheckScalar("case_when", {MakeStruct({cond_null}), values1, values2}, values2); + + CheckScalar("case_when", {MakeStruct({cond_true, cond_true}), values1, values2}, + values1); + CheckScalar("case_when", {MakeStruct({cond_false, cond_false}), values1, values2}, + values_null); + CheckScalar("case_when", {MakeStruct({cond_true, cond_false}), values1, values2}, + values1); + CheckScalar("case_when", {MakeStruct({cond_false, cond_true}), values1, values2}, + values2); + CheckScalar("case_when", {MakeStruct({cond_null, cond_true}), values1, values2}, + values2); + CheckScalar("case_when", + {MakeStruct({cond_false, cond_false}), values1, values2, values2}, values2); - CheckScalar("case_when", {cond1, scalar1, cond2, scalar2}, + CheckScalar("case_when", {MakeStruct({cond1, cond2}), scalar1, scalar2}, ArrayFromJSON(type, "[1, 1, 2, null]")); - CheckScalar("case_when", {cond1, scalar_null}, values_null); - CheckScalar("case_when", {cond1, scalar_null, scalar1}, + CheckScalar("case_when", {MakeStruct({cond1}), scalar_null}, values_null); + CheckScalar("case_when", {MakeStruct({cond1}), scalar_null, scalar1}, ArrayFromJSON(type, "[null, null, 1, 1]")); - CheckScalar("case_when", {cond1, scalar1, cond2, scalar2, scalar1}, + CheckScalar("case_when", {MakeStruct({cond1, cond2}), scalar1, scalar2, scalar1}, ArrayFromJSON(type, "[1, 1, 2, 1]")); - CheckScalar("case_when", {cond1, values1, cond2, values2}, + CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2}, ArrayFromJSON(type, "[3, null, null, null]")); - CheckScalar("case_when", {cond1, values1, cond2, values2, values1}, + CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2, values1}, ArrayFromJSON(type, "[3, null, null, 6]")); - CheckScalar("case_when", {cond1, values_null, cond2, values2, values1}, + CheckScalar("case_when", {MakeStruct({cond1, cond2}), values_null, values2, values1}, ArrayFromJSON(type, "[null, null, null, 6]")); - CheckScalar("case_when", - {ArrayFromJSON(boolean(), - "[true, true, true, false, false, false, null, null, null]"), - ArrayFromJSON(type, "[10, 11, 12, 13, 14, 15, 16, 17, 18]"), - ArrayFromJSON(boolean(), - "[true, false, null, true, false, null, true, false, null]"), - ArrayFromJSON(type, "[20, 21, 22, 23, 24, 25, 26, 27, 28]")}, - ArrayFromJSON(type, "[10, 11, 12, 23, null, null, 26, null, null]")); - CheckScalar("case_when", - {ArrayFromJSON(boolean(), - "[true, true, true, false, false, false, null, null, null]"), - ArrayFromJSON(type, "[10, 11, 12, 13, 14, 15, 16, 17, 18]"), - ArrayFromJSON(boolean(), - "[true, false, null, true, false, null, true, false, null]"), - ArrayFromJSON(type, "[20, 21, 22, 23, 24, 25, 26, 27, 28]"), - ArrayFromJSON(type, "[30, 31, 32, 33, 34, null, 36, 37, null]")}, - ArrayFromJSON(type, "[10, 11, 12, 23, 34, null, 26, 37, null]")); + CheckScalar( + "case_when", + {MakeStruct( + {ArrayFromJSON(boolean(), + "[true, true, true, false, false, false, null, null, null]"), + ArrayFromJSON(boolean(), + "[true, false, null, true, false, null, true, false, null]")}), + ArrayFromJSON(type, "[10, 11, 12, 13, 14, 15, 16, 17, 18]"), + ArrayFromJSON(type, "[20, 21, 22, 23, 24, 25, 26, 27, 28]")}, + ArrayFromJSON(type, "[10, 11, 12, 23, null, null, 26, null, null]")); + CheckScalar( + "case_when", + {MakeStruct( + {ArrayFromJSON(boolean(), + "[true, true, true, false, false, false, null, null, null]"), + ArrayFromJSON(boolean(), + "[true, false, null, true, false, null, true, false, null]")}), + ArrayFromJSON(type, "[10, 11, 12, 13, 14, 15, 16, 17, 18]"), + + ArrayFromJSON(type, "[20, 21, 22, 23, 24, 25, 26, 27, 28]"), + ArrayFromJSON(type, "[30, 31, 32, 33, 34, null, 36, 37, null]")}, + ArrayFromJSON(type, "[10, 11, 12, 23, 34, null, 26, 37, null]")); } TEST(TestCaseWhen, Null) { @@ -394,10 +416,10 @@ TEST(TestCaseWhen, Null) { auto cond_arr = ArrayFromJSON(boolean(), "[true, true, false, null]"); auto scalar = ScalarFromJSON(null(), "null"); auto array = ArrayFromJSON(null(), "[null, null, null, null]"); - CheckScalar("case_when", {array}, array); - CheckScalar("case_when", {cond_false, array}, array); - CheckScalar("case_when", {cond_true, array, array}, array); - CheckScalar("case_when", {cond_arr, array, cond_true, array}, array); + CheckScalar("case_when", {MakeStruct({}), array}, array); + CheckScalar("case_when", {MakeStruct({cond_false}), array}, array); + CheckScalar("case_when", {MakeStruct({cond_true}), array, array}, array); + CheckScalar("case_when", {MakeStruct({cond_arr, cond_true}), array, array}, array); } TEST(TestCaseWhen, Boolean) { @@ -414,36 +436,42 @@ TEST(TestCaseWhen, Boolean) { auto values1 = ArrayFromJSON(type, "[true, null, true, true]"); auto values2 = ArrayFromJSON(type, "[false, false, null, false]"); - CheckScalar("case_when", {values1}, values1); - CheckScalar("case_when", {values_null}, values_null); - - CheckScalar("case_when", {cond_true, values1}, values1); - CheckScalar("case_when", {cond_false, values1}, values_null); - CheckScalar("case_when", {cond_null, values1}, values_null); - CheckScalar("case_when", {cond_true, values1, values2}, values1); - CheckScalar("case_when", {cond_false, values1, values2}, values2); - CheckScalar("case_when", {cond_null, values1, values2}, values2); - - CheckScalar("case_when", {cond_true, values1, cond_true, values2}, values1); - CheckScalar("case_when", {cond_false, values1, cond_false, values2}, values_null); - CheckScalar("case_when", {cond_true, values1, cond_false, values2}, values1); - CheckScalar("case_when", {cond_false, values1, cond_true, values2}, values2); - CheckScalar("case_when", {cond_null, values1, cond_true, values2}, values2); - CheckScalar("case_when", {cond_false, values1, cond_false, values2, values2}, values2); + CheckScalar("case_when", {MakeStruct({}), values1}, values1); + CheckScalar("case_when", {MakeStruct({}), values_null}, values_null); + + CheckScalar("case_when", {MakeStruct({cond_true}), values1}, values1); + CheckScalar("case_when", {MakeStruct({cond_false}), values1}, values_null); + CheckScalar("case_when", {MakeStruct({cond_null}), values1}, values_null); + CheckScalar("case_when", {MakeStruct({cond_true}), values1, values2}, values1); + CheckScalar("case_when", {MakeStruct({cond_false}), values1, values2}, values2); + CheckScalar("case_when", {MakeStruct({cond_null}), values1, values2}, values2); + + CheckScalar("case_when", {MakeStruct({cond_true, cond_true}), values1, values2}, + values1); + CheckScalar("case_when", {MakeStruct({cond_false, cond_false}), values1, values2}, + values_null); + CheckScalar("case_when", {MakeStruct({cond_true, cond_false}), values1, values2}, + values1); + CheckScalar("case_when", {MakeStruct({cond_false, cond_true}), values1, values2}, + values2); + CheckScalar("case_when", {MakeStruct({cond_null, cond_true}), values1, values2}, + values2); + CheckScalar("case_when", + {MakeStruct({cond_false, cond_false}), values1, values2, values2}, values2); - CheckScalar("case_when", {cond1, scalar1, cond2, scalar2}, + CheckScalar("case_when", {MakeStruct({cond1, cond2}), scalar1, scalar2}, ArrayFromJSON(type, "[true, true, false, null]")); - CheckScalar("case_when", {cond1, scalar_null}, values_null); - CheckScalar("case_when", {cond1, scalar_null, scalar1}, + CheckScalar("case_when", {MakeStruct({cond1}), scalar_null}, values_null); + CheckScalar("case_when", {MakeStruct({cond1}), scalar_null, scalar1}, ArrayFromJSON(type, "[null, null, true, true]")); - CheckScalar("case_when", {cond1, scalar1, cond2, scalar2, scalar1}, + CheckScalar("case_when", {MakeStruct({cond1, cond2}), scalar1, scalar2, scalar1}, ArrayFromJSON(type, "[true, true, false, true]")); - CheckScalar("case_when", {cond1, values1, cond2, values2}, + CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2}, ArrayFromJSON(type, "[true, null, null, null]")); - CheckScalar("case_when", {cond1, values1, cond2, values2, values1}, + CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2, values1}, ArrayFromJSON(type, "[true, null, null, true]")); - CheckScalar("case_when", {cond1, values_null, cond2, values2, values1}, + CheckScalar("case_when", {MakeStruct({cond1, cond2}), values_null, values2, values1}, ArrayFromJSON(type, "[null, null, null, true]")); } @@ -461,36 +489,42 @@ TEST(TestCaseWhen, DayTimeInterval) { auto values1 = ArrayFromJSON(type, "[[3, 3], null, [5, 5], [6, 6]]"); auto values2 = ArrayFromJSON(type, "[[7, 7], [8, 8], null, [10, 10]]"); - CheckScalar("case_when", {values1}, values1); - CheckScalar("case_when", {values_null}, values_null); - - CheckScalar("case_when", {cond_true, values1}, values1); - CheckScalar("case_when", {cond_false, values1}, values_null); - CheckScalar("case_when", {cond_null, values1}, values_null); - CheckScalar("case_when", {cond_true, values1, values2}, values1); - CheckScalar("case_when", {cond_false, values1, values2}, values2); - CheckScalar("case_when", {cond_null, values1, values2}, values2); - - CheckScalar("case_when", {cond_true, values1, cond_true, values2}, values1); - CheckScalar("case_when", {cond_false, values1, cond_false, values2}, values_null); - CheckScalar("case_when", {cond_true, values1, cond_false, values2}, values1); - CheckScalar("case_when", {cond_false, values1, cond_true, values2}, values2); - CheckScalar("case_when", {cond_null, values1, cond_true, values2}, values2); - CheckScalar("case_when", {cond_false, values1, cond_false, values2, values2}, values2); + CheckScalar("case_when", {MakeStruct({}), values1}, values1); + CheckScalar("case_when", {MakeStruct({}), values_null}, values_null); + + CheckScalar("case_when", {MakeStruct({cond_true}), values1}, values1); + CheckScalar("case_when", {MakeStruct({cond_false}), values1}, values_null); + CheckScalar("case_when", {MakeStruct({cond_null}), values1}, values_null); + CheckScalar("case_when", {MakeStruct({cond_true}), values1, values2}, values1); + CheckScalar("case_when", {MakeStruct({cond_false}), values1, values2}, values2); + CheckScalar("case_when", {MakeStruct({cond_null}), values1, values2}, values2); + + CheckScalar("case_when", {MakeStruct({cond_true, cond_true}), values1, values2}, + values1); + CheckScalar("case_when", {MakeStruct({cond_false, cond_false}), values1, values2}, + values_null); + CheckScalar("case_when", {MakeStruct({cond_true, cond_false}), values1, values2}, + values1); + CheckScalar("case_when", {MakeStruct({cond_false, cond_true}), values1, values2}, + values2); + CheckScalar("case_when", {MakeStruct({cond_null, cond_true}), values1, values2}, + values2); + CheckScalar("case_when", + {MakeStruct({cond_false, cond_false}), values1, values2, values2}, values2); - CheckScalar("case_when", {cond1, scalar1, cond2, scalar2}, + CheckScalar("case_when", {MakeStruct({cond1, cond2}), scalar1, scalar2}, ArrayFromJSON(type, "[[1, 1], [1, 1], [2, 2], null]")); - CheckScalar("case_when", {cond1, scalar_null}, values_null); - CheckScalar("case_when", {cond1, scalar_null, scalar1}, + CheckScalar("case_when", {MakeStruct({cond1}), scalar_null}, values_null); + CheckScalar("case_when", {MakeStruct({cond1}), scalar_null, scalar1}, ArrayFromJSON(type, "[null, null, [1, 1], [1, 1]]")); - CheckScalar("case_when", {cond1, scalar1, cond2, scalar2, scalar1}, + CheckScalar("case_when", {MakeStruct({cond1, cond2}), scalar1, scalar2, scalar1}, ArrayFromJSON(type, "[[1, 1], [1, 1], [2, 2], [1, 1]]")); - CheckScalar("case_when", {cond1, values1, cond2, values2}, + CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2}, ArrayFromJSON(type, "[[3, 3], null, null, null]")); - CheckScalar("case_when", {cond1, values1, cond2, values2, values1}, + CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2, values1}, ArrayFromJSON(type, "[[3, 3], null, null, [6, 6]]")); - CheckScalar("case_when", {cond1, values_null, cond2, values2, values1}, + CheckScalar("case_when", {MakeStruct({cond1, cond2}), values_null, values2, values1}, ArrayFromJSON(type, "[null, null, null, [6, 6]]")); } @@ -509,37 +543,43 @@ TEST(TestCaseWhen, Decimal) { auto values1 = ArrayFromJSON(type, R"(["3.45", null, "5.67", "6.78"])"); auto values2 = ArrayFromJSON(type, R"(["7.89", "8.90", null, "1.01"])"); - CheckScalar("case_when", {values1}, values1); - CheckScalar("case_when", {values_null}, values_null); - - CheckScalar("case_when", {cond_true, values1}, values1); - CheckScalar("case_when", {cond_false, values1}, values_null); - CheckScalar("case_when", {cond_null, values1}, values_null); - CheckScalar("case_when", {cond_true, values1, values2}, values1); - CheckScalar("case_when", {cond_false, values1, values2}, values2); - CheckScalar("case_when", {cond_null, values1, values2}, values2); - - CheckScalar("case_when", {cond_true, values1, cond_true, values2}, values1); - CheckScalar("case_when", {cond_false, values1, cond_false, values2}, values_null); - CheckScalar("case_when", {cond_true, values1, cond_false, values2}, values1); - CheckScalar("case_when", {cond_false, values1, cond_true, values2}, values2); - CheckScalar("case_when", {cond_null, values1, cond_true, values2}, values2); - CheckScalar("case_when", {cond_false, values1, cond_false, values2, values2}, + CheckScalar("case_when", {MakeStruct({}), values1}, values1); + CheckScalar("case_when", {MakeStruct({}), values_null}, values_null); + + CheckScalar("case_when", {MakeStruct({cond_true}), values1}, values1); + CheckScalar("case_when", {MakeStruct({cond_false}), values1}, values_null); + CheckScalar("case_when", {MakeStruct({cond_null}), values1}, values_null); + CheckScalar("case_when", {MakeStruct({cond_true}), values1, values2}, values1); + CheckScalar("case_when", {MakeStruct({cond_false}), values1, values2}, values2); + CheckScalar("case_when", {MakeStruct({cond_null}), values1, values2}, values2); + + CheckScalar("case_when", {MakeStruct({cond_true, cond_true}), values1, values2}, + values1); + CheckScalar("case_when", {MakeStruct({cond_false, cond_false}), values1, values2}, + values_null); + CheckScalar("case_when", {MakeStruct({cond_true, cond_false}), values1, values2}, + values1); + CheckScalar("case_when", {MakeStruct({cond_false, cond_true}), values1, values2}, + values2); + CheckScalar("case_when", {MakeStruct({cond_null, cond_true}), values1, values2}, + values2); + CheckScalar("case_when", + {MakeStruct({cond_false, cond_false}), values1, values2, values2}, values2); - CheckScalar("case_when", {cond1, scalar1, cond2, scalar2}, + CheckScalar("case_when", {MakeStruct({cond1, cond2}), scalar1, scalar2}, ArrayFromJSON(type, R"(["1.23", "1.23", "2.34", null])")); - CheckScalar("case_when", {cond1, scalar_null}, values_null); - CheckScalar("case_when", {cond1, scalar_null, scalar1}, + CheckScalar("case_when", {MakeStruct({cond1}), scalar_null}, values_null); + CheckScalar("case_when", {MakeStruct({cond1}), scalar_null, scalar1}, ArrayFromJSON(type, R"([null, null, "1.23", "1.23"])")); - CheckScalar("case_when", {cond1, scalar1, cond2, scalar2, scalar1}, + CheckScalar("case_when", {MakeStruct({cond1, cond2}), scalar1, scalar2, scalar1}, ArrayFromJSON(type, R"(["1.23", "1.23", "2.34", "1.23"])")); - CheckScalar("case_when", {cond1, values1, cond2, values2}, + CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2}, ArrayFromJSON(type, R"(["3.45", null, null, null])")); - CheckScalar("case_when", {cond1, values1, cond2, values2, values1}, + CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2, values1}, ArrayFromJSON(type, R"(["3.45", null, null, "6.78"])")); - CheckScalar("case_when", {cond1, values_null, cond2, values2, values1}, + CheckScalar("case_when", {MakeStruct({cond1, cond2}), values_null, values2, values1}, ArrayFromJSON(type, R"([null, null, null, "6.78"])")); } } @@ -558,64 +598,67 @@ TEST(TestCaseWhen, FixedSizeBinary) { auto values1 = ArrayFromJSON(type, R"(["cde", null, "def", "efg"])"); auto values2 = ArrayFromJSON(type, R"(["fgh", "ghi", null, "hij"])"); - CheckScalar("case_when", {values1}, values1); - CheckScalar("case_when", {values_null}, values_null); - - CheckScalar("case_when", {cond_true, values1}, values1); - CheckScalar("case_when", {cond_false, values1}, values_null); - CheckScalar("case_when", {cond_null, values1}, values_null); - CheckScalar("case_when", {cond_true, values1, values2}, values1); - CheckScalar("case_when", {cond_false, values1, values2}, values2); - CheckScalar("case_when", {cond_null, values1, values2}, values2); - - CheckScalar("case_when", {cond_true, values1, cond_true, values2}, values1); - CheckScalar("case_when", {cond_false, values1, cond_false, values2}, values_null); - CheckScalar("case_when", {cond_true, values1, cond_false, values2}, values1); - CheckScalar("case_when", {cond_false, values1, cond_true, values2}, values2); - CheckScalar("case_when", {cond_null, values1, cond_true, values2}, values2); - CheckScalar("case_when", {cond_false, values1, cond_false, values2, values2}, values2); + CheckScalar("case_when", {MakeStruct({}), values1}, values1); + CheckScalar("case_when", {MakeStruct({}), values_null}, values_null); + + CheckScalar("case_when", {MakeStruct({cond_true}), values1}, values1); + CheckScalar("case_when", {MakeStruct({cond_false}), values1}, values_null); + CheckScalar("case_when", {MakeStruct({cond_null}), values1}, values_null); + CheckScalar("case_when", {MakeStruct({cond_true}), values1, values2}, values1); + CheckScalar("case_when", {MakeStruct({cond_false}), values1, values2}, values2); + CheckScalar("case_when", {MakeStruct({cond_null}), values1, values2}, values2); + + CheckScalar("case_when", {MakeStruct({cond_true, cond_true}), values1, values2}, + values1); + CheckScalar("case_when", {MakeStruct({cond_false, cond_false}), values1, values2}, + values_null); + CheckScalar("case_when", {MakeStruct({cond_true, cond_false}), values1, values2}, + values1); + CheckScalar("case_when", {MakeStruct({cond_false, cond_true}), values1, values2}, + values2); + CheckScalar("case_when", {MakeStruct({cond_null, cond_true}), values1, values2}, + values2); + CheckScalar("case_when", + {MakeStruct({cond_false, cond_false}), values1, values2, values2}, values2); - CheckScalar("case_when", {cond1, scalar1, cond2, scalar2}, + CheckScalar("case_when", {MakeStruct({cond1, cond2}), scalar1, scalar2}, ArrayFromJSON(type, R"(["abc", "abc", "bcd", null])")); - CheckScalar("case_when", {cond1, scalar_null}, values_null); - CheckScalar("case_when", {cond1, scalar_null, scalar1}, + CheckScalar("case_when", {MakeStruct({cond1}), scalar_null}, values_null); + CheckScalar("case_when", {MakeStruct({cond1}), scalar_null, scalar1}, ArrayFromJSON(type, R"([null, null, "abc", "abc"])")); - CheckScalar("case_when", {cond1, scalar1, cond2, scalar2, scalar1}, + CheckScalar("case_when", {MakeStruct({cond1, cond2}), scalar1, scalar2, scalar1}, ArrayFromJSON(type, R"(["abc", "abc", "bcd", "abc"])")); - CheckScalar("case_when", {cond1, values1, cond2, values2}, + CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2}, ArrayFromJSON(type, R"(["cde", null, null, null])")); - CheckScalar("case_when", {cond1, values1, cond2, values2, values1}, + CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2, values1}, ArrayFromJSON(type, R"(["cde", null, null, "efg"])")); - CheckScalar("case_when", {cond1, values_null, cond2, values2, values1}, + CheckScalar("case_when", {MakeStruct({cond1, cond2}), values_null, values2, values1}, ArrayFromJSON(type, R"([null, null, null, "efg"])")); } TEST(TestCaseWhen, DispatchBest) { - ASSERT_OK_AND_ASSIGN(auto function, GetFunctionRegistry()->GetFunction("case_when")); - auto Check = - [&](std::vector original_values) -> Result> { - auto values = original_values; - RETURN_NOT_OK(function->DispatchBest(&values)); - return values; - }; - - // Since DispatchBest for this kernel pulls tricks, we can't compare it to DispatchExact - // as CheckDispatchBest does - EXPECT_EQ((std::vector{int32()}), *Check({int32()})); - EXPECT_EQ((std::vector{boolean(), int64(), int64()}), - *Check({boolean(), int32(), int64()})); - EXPECT_EQ((std::vector{boolean(), int64(), int64()}), - *Check({null(), int32(), int64()})); - ASSERT_RAISES(TypeError, Check({boolean(), utf8(), int32()})); - ASSERT_RAISES(TypeError, Check({int32(), int32(), int32()})); - ASSERT_RAISES(Invalid, Check({})); -} - -TEST(TestCaseWhen, Errors) { - ASSERT_RAISES(Invalid, CaseWhen({})); - ASSERT_RAISES(TypeError, CaseWhen({ArrayFromJSON(utf8(), "[\"\"]"), - ArrayFromJSON(int32(), "[0]")})); + CheckDispatchBest("case_when", {struct_({field("", boolean())}), int64(), int32()}, + {struct_({field("", boolean())}), int64(), int64()}); + + ASSERT_RAISES(Invalid, CallFunction("case_when", {})); + // Too many/too few conditions + ASSERT_RAISES( + Invalid, CallFunction("case_when", {MakeStruct({ArrayFromJSON(boolean(), "[]")})})); + ASSERT_RAISES(Invalid, + CallFunction("case_when", {MakeStruct({}), ArrayFromJSON(int64(), "[]"), + ArrayFromJSON(int64(), "[]")})); + // Conditions must be struct of boolean + ASSERT_RAISES(TypeError, + CallFunction("case_when", {MakeStruct({ArrayFromJSON(int64(), "[]")}), + ArrayFromJSON(int64(), "[]")})); + ASSERT_RAISES(TypeError, CallFunction("case_when", {ArrayFromJSON(boolean(), "[true]"), + ArrayFromJSON(int32(), "[0]")})); + // Values must have compatible types + ASSERT_RAISES(NotImplemented, + CallFunction("case_when", {MakeStruct({ArrayFromJSON(boolean(), "[]")}), + ArrayFromJSON(int64(), "[]"), + ArrayFromJSON(utf8(), "[]")})); } } // namespace compute } // namespace arrow From 63c0034812dc80b0241ab24d01180ba93c7cfd72 Mon Sep 17 00:00:00 2001 From: David Li Date: Tue, 13 Jul 2021 11:46:33 -0400 Subject: [PATCH 03/11] ARROW-13064: [C++] Update benchmark to only count output size --- cpp/src/arrow/compute/kernels/scalar_if_else_benchmark.cc | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else_benchmark.cc b/cpp/src/arrow/compute/kernels/scalar_if_else_benchmark.cc index e4b71aceab8..ab9f06c3a5a 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else_benchmark.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else_benchmark.cc @@ -133,8 +133,7 @@ static void CaseWhenBench(benchmark::State& state) { val3->Slice(offset), val4->Slice(offset)})); } - state.SetBytesProcessed(state.iterations() * - ((len - offset) / 8 + 4 * (len - offset) * sizeof(CType))); + state.SetBytesProcessed(state.iterations() * (len - offset) * sizeof(CType)); } template @@ -172,8 +171,7 @@ static void CaseWhenBenchContiguous(benchmark::State& state) { val3->Slice(offset)})); } - state.SetBytesProcessed(state.iterations() * - ((len - offset) / 8 + 3 * (len - offset) * sizeof(CType))); + state.SetBytesProcessed(state.iterations() * (len - offset) * sizeof(CType)); } static void CaseWhenBench64(benchmark::State& state) { From ec49ff9e96d6b1d57283dcb04007f1b7334beb05 Mon Sep 17 00:00:00 2001 From: David Li Date: Tue, 13 Jul 2021 13:56:21 -0400 Subject: [PATCH 04/11] ARROW-13064: [C++] Enable can_write_into_slices for case_when --- .../arrow/compute/kernels/scalar_if_else.cc | 188 ++++++++---------- 1 file changed, 81 insertions(+), 107 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else.cc b/cpp/src/arrow/compute/kernels/scalar_if_else.cc index 0364bc40a1f..780a538583e 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else.cc @@ -682,65 +682,65 @@ template struct CopyFixedWidth {}; template <> struct CopyFixedWidth { - static void CopyScalar(const Scalar& scalar, uint8_t* out_values, const int64_t offset, - const int64_t length) { + static void CopyScalar(const Scalar& scalar, const int64_t length, + uint8_t* raw_out_values, const int64_t out_offset) { const bool value = UnboxScalar::Unbox(scalar); - BitUtil::SetBitsTo(out_values, offset, length, value); + BitUtil::SetBitsTo(raw_out_values, out_offset, length, value); } - static void CopyArray(const ArrayData& array, uint8_t* out_values, const int64_t offset, - const int64_t length) { - arrow::internal::CopyBitmap(array.buffers[1]->data(), array.offset + offset, length, - out_values, offset); + static void CopyArray(const DataType&, const uint8_t* in_values, + const int64_t in_offset, const int64_t length, + uint8_t* raw_out_values, const int64_t out_offset) { + arrow::internal::CopyBitmap(in_values, in_offset, length, raw_out_values, out_offset); } }; template struct CopyFixedWidth> { using CType = typename TypeTraits::CType; - static void CopyScalar(const Scalar& values, uint8_t* raw_out_values, - const int64_t offset, const int64_t length) { + static void CopyScalar(const Scalar& scalar, const int64_t length, + uint8_t* raw_out_values, const int64_t out_offset) { CType* out_values = reinterpret_cast(raw_out_values); - const CType value = UnboxScalar::Unbox(values); - std::fill(out_values + offset, out_values + offset + length, value); + const CType value = UnboxScalar::Unbox(scalar); + std::fill(out_values + out_offset, out_values + out_offset + length, value); } - static void CopyArray(const ArrayData& array, uint8_t* raw_out_values, - const int64_t offset, const int64_t length) { - CType* out_values = reinterpret_cast(raw_out_values); - const CType* in_values = array.GetValues(1); - std::copy(in_values + offset, in_values + offset + length, out_values + offset); + static void CopyArray(const DataType&, const uint8_t* in_values, + const int64_t in_offset, const int64_t length, + uint8_t* raw_out_values, const int64_t out_offset) { + std::memcpy(raw_out_values + out_offset * sizeof(CType), + in_values + in_offset * sizeof(CType), length * sizeof(CType)); } }; template struct CopyFixedWidth> { - static void CopyScalar(const Scalar& values, uint8_t* out_values, const int64_t offset, - const int64_t length) { + static void CopyScalar(const Scalar& values, const int64_t length, + uint8_t* raw_out_values, const int64_t out_offset) { const int32_t width = checked_cast(*values.type).byte_width(); - uint8_t* next = out_values + (width * offset); + uint8_t* next = raw_out_values + (width * out_offset); const auto& scalar = checked_cast(values); - if (!scalar.is_valid) return; + // Scalar may have null value buffer + if (!scalar.value) return; DCHECK_EQ(scalar.value->size(), width); for (int i = 0; i < length; i++) { std::memcpy(next, scalar.value->data(), width); next += width; } } - static void CopyArray(const ArrayData& array, uint8_t* out_values, const int64_t offset, - const int64_t length) { - const int32_t width = - checked_cast(*array.type).byte_width(); - uint8_t* next = out_values + (width * offset); - const auto* in_values = array.GetValues(1, (offset + array.offset) * width); - std::memcpy(next, in_values, length * width); + static void CopyArray(const DataType& type, const uint8_t* in_values, + const int64_t in_offset, const int64_t length, + uint8_t* raw_out_values, const int64_t out_offset) { + const int32_t width = checked_cast(type).byte_width(); + uint8_t* next = raw_out_values + (width * out_offset); + std::memcpy(next, in_values + in_offset * width, length * width); } }; template struct CopyFixedWidth> { using ScalarType = typename TypeTraits::ScalarType; - static void CopyScalar(const Scalar& values, uint8_t* out_values, const int64_t offset, - const int64_t length) { + static void CopyScalar(const Scalar& values, const int64_t length, + uint8_t* raw_out_values, const int64_t out_offset) { const int32_t width = checked_cast(*values.type).byte_width(); - uint8_t* next = out_values + (width * offset); + uint8_t* next = raw_out_values + (width * out_offset); const auto& scalar = checked_cast(values); const auto value = scalar.value.ToBytes(); for (int i = 0; i < length; i++) { @@ -748,37 +748,37 @@ struct CopyFixedWidth> { next += width; } } - static void CopyArray(const ArrayData& array, uint8_t* out_values, const int64_t offset, - const int64_t length) { - const int32_t width = - checked_cast(*array.type).byte_width(); - uint8_t* next = out_values + (width * offset); - const auto* in_values = array.GetValues(1, (offset + array.offset) * width); - std::memcpy(next, in_values, length * width); + static void CopyArray(const DataType& type, const uint8_t* in_values, + const int64_t in_offset, const int64_t length, + uint8_t* raw_out_values, const int64_t out_offset) { + const int32_t width = checked_cast(type).byte_width(); + uint8_t* next = raw_out_values + (width * out_offset); + std::memcpy(next, in_values + in_offset * width, length * width); } }; // Copy fixed-width values from a scalar/array datum into an output values buffer template -void CopyValues(const Datum& values, uint8_t* out_valid, uint8_t* out_values, - const int64_t offset, const int64_t length) { - using Copier = CopyFixedWidth; - if (values.is_scalar()) { - const auto& scalar = *values.scalar(); +void CopyValues(const Datum& in_values, const int64_t in_offset, const int64_t length, + uint8_t* out_valid, uint8_t* out_values, const int64_t out_offset) { + if (in_values.is_scalar()) { + const auto& scalar = *in_values.scalar(); if (out_valid) { - BitUtil::SetBitsTo(out_valid, offset, length, scalar.is_valid); + BitUtil::SetBitsTo(out_valid, out_offset, length, scalar.is_valid); } - Copier::CopyScalar(scalar, out_values, offset, length); + CopyFixedWidth::CopyScalar(scalar, length, out_values, out_offset); } else { - const ArrayData& array = *values.array(); + const ArrayData& array = *in_values.array(); if (out_valid) { if (array.MayHaveNulls()) { - arrow::internal::CopyBitmap(array.buffers[0]->data(), array.offset + offset, - length, out_valid, offset); + arrow::internal::CopyBitmap(array.buffers[0]->data(), array.offset + in_offset, + length, out_valid, out_offset); } else { - BitUtil::SetBitsTo(out_valid, offset, length, true); + BitUtil::SetBitsTo(out_valid, out_offset, length, true); } } - Copier::CopyArray(array, out_values, offset, length); + CopyFixedWidth::CopyArray(*array.type, array.buffers[1]->data(), + array.offset + in_offset, length, out_values, + out_offset); } } @@ -821,6 +821,7 @@ struct CaseWhenFunction : ScalarFunction { }; // Implement a 'case when' (SQL)/'select' (NumPy) function for any scalar conditions +template Status ExecScalarCaseWhen(KernelContext* ctx, const ExecBatch& batch, Datum* out) { const auto& conds = checked_cast(*batch.values[0].scalar()); Datum result; @@ -839,19 +840,16 @@ Status ExecScalarCaseWhen(KernelContext* ctx, const ExecBatch& batch, Datum* out } if (out->is_scalar()) { *out = result.is_scalar() ? result.scalar() : MakeNullScalar(out->type()); - } else if (result.is_value()) { - if (result.is_scalar()) { - ARROW_ASSIGN_OR_RAISE(auto temp, MakeArrayFromScalar(*result.scalar(), batch.length, - ctx->memory_pool())); - *out->mutable_array() = *temp->data(); - } else { - *out->mutable_array() = *result.array(); - } - } else { - ARROW_ASSIGN_OR_RAISE(auto temp, - MakeArrayOfNull(out->type(), batch.length, ctx->memory_pool())); - *out->mutable_array() = *temp->data(); + return Status::OK(); + } + ArrayData* output = out->mutable_array(); + if (!result.is_value()) { + // All conditions false, no 'else' argument + result = MakeNullScalar(out->type()); } + CopyValues(result, /*in_offset=*/0, batch.length, + output->GetMutableValues(0, 0), + output->GetMutableValues(1, 0), output->offset); return Status::OK(); } @@ -861,46 +859,19 @@ template Status ExecArrayCaseWhen(KernelContext* ctx, const ExecBatch& batch, Datum* out) { const auto& conds_array = *batch.values[0].array(); ArrayData* output = out->mutable_array(); + const int64_t out_offset = output->offset; const auto num_value_args = batch.values.size() - 1; const bool have_else_arg = static_cast(conds_array.type->num_fields()) < num_value_args; - uint8_t* out_valid = nullptr; - - // Check if we may need a validity bitmap - bool need_valid_bitmap = false; - if (!have_else_arg) { - // If we don't have an else arg -> need a bitmap since we may emit nulls - need_valid_bitmap = true; - } else { - // Otherwise if any value array has a null count we need a validity bitmap - for (size_t i = 1; i < batch.values.size(); i++) { - if (batch[i].null_count() > 0) { - need_valid_bitmap = true; - break; - } - } - } - if (need_valid_bitmap) { - ARROW_ASSIGN_OR_RAISE(output->buffers[0], ctx->AllocateBitmap(batch.length)); - out_valid = output->buffers[0]->mutable_data(); - } - - // Initialize values buffer - auto bit_width = checked_cast(*out->type()).bit_width(); - if (bit_width == 1) { - ARROW_ASSIGN_OR_RAISE(output->buffers[1], ctx->AllocateBitmap(batch.length)); - } else { - auto byte_width = BitUtil::BytesForBits(bit_width); - ARROW_ASSIGN_OR_RAISE(output->buffers[1], ctx->Allocate(batch.length * byte_width)); - } + uint8_t* out_valid = output->buffers[0]->mutable_data(); uint8_t* out_values = output->buffers[1]->mutable_data(); if (have_else_arg) { // Copy 'else' value into output - CopyValues(batch.values.back(), out_valid, out_values, /*offset=*/0, - batch.length); - } else if (need_valid_bitmap) { + CopyValues(batch.values.back(), /*in_offset=*/0, batch.length, out_valid, + out_values, out_offset); + } else { // There's no 'else' argument, so we should have an all-null validity bitmap - std::memset(out_valid, 0x00, output->buffers[0]->size()); + BitUtil::SetBitsTo(out_valid, out_offset, batch.length, false); } // Allocate a temporary bitmap to determine which elements still need setting. @@ -925,14 +896,15 @@ Status ExecArrayCaseWhen(KernelContext* ctx, const ExecBatch& batch, Datum* out) while (offset < batch.length) { const auto block = counter.NextAndWord(); if (block.AllSet()) { - CopyValues(values_datum, out_valid, out_values, offset, block.length); + CopyValues(values_datum, offset, block.length, out_valid, out_values, + out_offset + offset); BitUtil::SetBitsTo(mask, offset, block.length, false); } else if (block.popcount) { for (int64_t j = 0; j < block.length; ++j) { if (BitUtil::GetBit(mask, offset + j) && BitUtil::GetBit(cond_values, cond_offset + offset + j)) { - CopyValues(values_datum, out_valid, out_values, offset + j, - /*length=*/1); + CopyValues(values_datum, offset + j, /*length=*/1, out_valid, + out_values, out_offset + offset + j); BitUtil::SetBitTo(mask, offset + j, false); } } @@ -950,15 +922,16 @@ Status ExecArrayCaseWhen(KernelContext* ctx, const ExecBatch& batch, Datum* out) const uint64_t word = words[0] & words[1] & words[2]; const int64_t block_length = std::min(64, batch.length - offset); if (word == std::numeric_limits::max()) { - CopyValues(values_datum, out_valid, out_values, offset, block_length); + CopyValues(values_datum, offset, block_length, out_valid, out_values, + out_offset + offset); BitUtil::SetBitsTo(mask, offset, block_length, false); } else if (word) { for (int64_t j = 0; j < block_length; ++j) { if (BitUtil::GetBit(mask, offset + j) && BitUtil::GetBit(cond_valid, cond_offset + offset + j) && BitUtil::GetBit(cond_values, cond_offset + offset + j)) { - CopyValues(values_datum, out_valid, out_values, offset + j, - /*length=*/1); + CopyValues(values_datum, offset + j, /*length=*/1, out_valid, + out_values, out_offset + offset + j); BitUtil::SetBitTo(mask, offset + j, false); } } @@ -975,7 +948,8 @@ Status ExecArrayCaseWhen(KernelContext* ctx, const ExecBatch& batch, Datum* out) const uint64_t word = words[0] & words[1] & words[2] & words[3]; const int64_t block_length = std::min(64, batch.length - offset); if (word == std::numeric_limits::max()) { - CopyValues(values_datum, out_valid, out_values, offset, block_length); + CopyValues(values_datum, offset, block_length, out_valid, out_values, + out_offset + offset); BitUtil::SetBitsTo(mask, offset, block_length, false); } else if (word) { for (int64_t j = 0; j < block_length; ++j) { @@ -983,8 +957,8 @@ Status ExecArrayCaseWhen(KernelContext* ctx, const ExecBatch& batch, Datum* out) BitUtil::GetBit(cond_valid, cond_offset + offset + j) && BitUtil::GetBit(cond_values, cond_offset + offset + j) && BitUtil::GetBit(conds_valid, conds_array.offset + offset + j)) { - CopyValues(values_datum, out_valid, out_values, offset + j, - /*length=*/1); + CopyValues(values_datum, offset + j, /*length=*/1, out_valid, + out_values, out_offset + offset + j); BitUtil::SetBitTo(mask, offset + j, false); } } @@ -1003,7 +977,7 @@ struct CaseWhenFunctor { if (batch.values[0].is_array()) { return ExecArrayCaseWhen(ctx, batch, out); } - return ExecScalarCaseWhen(ctx, batch, out); + return ExecScalarCaseWhen(ctx, batch, out); } }; @@ -1027,9 +1001,9 @@ void AddCaseWhenKernel(const std::shared_ptr& scalar_function, OutputType(LastType), /*is_varargs=*/true), exec); - kernel.null_handling = NullHandling::COMPUTED_NO_PREALLOCATE; - kernel.mem_allocation = MemAllocation::NO_PREALLOCATE; - kernel.can_write_into_slices = false; + kernel.null_handling = NullHandling::COMPUTED_PREALLOCATE; + kernel.mem_allocation = MemAllocation::PREALLOCATE; + kernel.can_write_into_slices = is_fixed_width(get_id.id); DCHECK_OK(scalar_function->AddKernel(std::move(kernel))); } From 879da44f7085b9c630050a8025127c6706366d27 Mon Sep 17 00:00:00 2001 From: David Li Date: Tue, 13 Jul 2021 15:34:18 -0400 Subject: [PATCH 05/11] ARROW-13064: [C++] Update case_when docstrings --- .../arrow/compute/kernels/scalar_if_else.cc | 15 ++--- .../kernels/scalar_if_else_benchmark.cc | 24 ++++---- docs/source/cpp/compute.rst | 59 ++++++++++--------- 3 files changed, 50 insertions(+), 48 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else.cc b/cpp/src/arrow/compute/kernels/scalar_if_else.cc index 780a538583e..ded29e4d729 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else.cc @@ -1024,16 +1024,17 @@ const FunctionDoc if_else_doc{"Choose values based on a condition", const FunctionDoc case_when_doc{ "Choose values based on multiple conditions", - ("`cond` must be a sequence of alternating Boolean condition data " - "and value data (of any type, but all must be the same type or " - "castable to a common type), along with an optional datum of " - "\"else\" values. At least one datum must be given.\n" + ("`cond` must be a struct of Boolean values. `cases` can be a mix " + "of scalar and array arguments (of any type, but all must be the " + "same type or castable to a common type), with either exactly one " + "datum per child of `cond`, or one more `cases` than children of " + "`cond` (in which case we have an \"else\" value).\n" "Each row of the output will be the corresponding value of the " - "first value datum for which the corresponding condition datum " + "first datum in `cases` for which the corresponding child of `cond` " "is true, or otherwise the \"else\" value (if given), or null. " - "Essentially, this implements a switch-case or if-else if-else " + "Essentially, this implements a switch-case or if-else, if-else... " "statement."), - {"*cond"}}; + {"cond", "*cases"}}; } // namespace void RegisterScalarIfElse(FunctionRegistry* registry) { diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else_benchmark.cc b/cpp/src/arrow/compute/kernels/scalar_if_else_benchmark.cc index ab9f06c3a5a..0ad744054c6 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else_benchmark.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else_benchmark.cc @@ -15,12 +15,14 @@ // specific language governing permissions and limitations // under the License. -#include -#include -#include -#include #include +#include "arrow/array/concatenate.h" +#include "arrow/compute/api_scalar.h" +#include "arrow/testing/gtest_util.h" +#include "arrow/testing/random.h" +#include "arrow/util/key_value_metadata.h" + namespace arrow { namespace compute { @@ -114,6 +116,11 @@ static void CaseWhenBench(benchmark::State& state) { rand.ArrayOf(boolean(), len, /*null_probability=*/0.01)); auto cond3 = std::static_pointer_cast( rand.ArrayOf(boolean(), len, /*null_probability=*/0.01)); + auto cond_field = + field("cond", boolean(), key_value_metadata({{"null_probability", "0.01"}})); + auto cond = rand.ArrayOf(*field("", struct_({cond_field, cond_field, cond_field}), + key_value_metadata({{"null_probability", "0.0"}})), + len); auto val1 = std::static_pointer_cast( rand.ArrayOf(type, len, /*null_probability=*/0.01)); auto val2 = std::static_pointer_cast( @@ -122,11 +129,6 @@ static void CaseWhenBench(benchmark::State& state) { rand.ArrayOf(type, len, /*null_probability=*/0.01)); auto val4 = std::static_pointer_cast( rand.ArrayOf(type, len, /*null_probability=*/0.01)); - ASSERT_OK_AND_ASSIGN( - auto cond, - StructArray::Make({cond1, cond2, cond3}, std::vector{"a", "b", "c"}, - nullptr, /*null_count=*/0)); - for (auto _ : state) { ABORT_NOT_OK( CaseWhen(cond->Slice(offset), {val1->Slice(offset), val2->Slice(offset), @@ -147,9 +149,7 @@ static void CaseWhenBenchContiguous(benchmark::State& state) { ASSERT_OK_AND_ASSIGN(auto trues, MakeArrayFromScalar(BooleanScalar(true), len / 3)); ASSERT_OK_AND_ASSIGN(auto falses, MakeArrayFromScalar(BooleanScalar(false), len / 3)); - auto null_scalar = MakeNullScalar(boolean()); - ASSERT_OK_AND_ASSIGN(auto nulls, - MakeArrayFromScalar(*null_scalar, len - 2 * (len / 3))); + ASSERT_OK_AND_ASSIGN(auto nulls, MakeArrayOfNull(boolean(), len - 2 * (len / 3))); ASSERT_OK_AND_ASSIGN(auto concat, Concatenate({trues, falses, nulls})); auto cond1 = std::static_pointer_cast(concat); diff --git a/docs/source/cpp/compute.rst b/docs/source/cpp/compute.rst index 39f07b348e5..ed97faead74 100644 --- a/docs/source/cpp/compute.rst +++ b/docs/source/cpp/compute.rst @@ -860,37 +860,38 @@ Structural transforms .. XXX (this category is a bit of a hodgepodge) -+--------------------------+------------+------------------------------------------------+---------------------+---------+ -| Function name | Arity | Input types | Output type | Notes | -+==========================+============+================================================+=====================+=========+ -| case_when | Varargs | Boolean, Any fixed-width | Input type | \(1) | -+--------------------------+------------+------------------------------------------------+---------------------+---------+ -| fill_null | Binary | Boolean, Null, Numeric, Temporal, String-like | Input type | \(2) | -+--------------------------+------------+------------------------------------------------+---------------------+---------+ -| if_else | Ternary | Boolean, Null, Numeric, Temporal | Input type | \(3) | -+--------------------------+------------+------------------------------------------------+---------------------+---------+ -| is_finite | Unary | Float, Double | Boolean | \(4) | -+--------------------------+------------+------------------------------------------------+---------------------+---------+ -| is_inf | Unary | Float, Double | Boolean | \(5) | -+--------------------------+------------+------------------------------------------------+---------------------+---------+ -| is_nan | Unary | Float, Double | Boolean | \(6) | -+--------------------------+------------+------------------------------------------------+---------------------+---------+ -| is_null | Unary | Any | Boolean | \(7) | -+--------------------------+------------+------------------------------------------------+---------------------+---------+ -| is_valid | Unary | Any | Boolean | \(8) | -+--------------------------+------------+------------------------------------------------+---------------------+---------+ -| list_value_length | Unary | List-like | Int32 or Int64 | \(9) | -+--------------------------+------------+------------------------------------------------+---------------------+---------+ -| project | Varargs | Any | Struct | \(10) | -+--------------------------+------------+------------------------------------------------+---------------------+---------+ ++--------------------------+------------+---------------------------------------------------+---------------------+---------+ +| Function name | Arity | Input types | Output type | Notes | ++==========================+============+===================================================+=====================+=========+ +| case_when | Varargs | Struct of Boolean (Arg 0), Any fixed-width (rest) | Input type | \(1) | ++--------------------------+------------+---------------------------------------------------+---------------------+---------+ +| fill_null | Binary | Boolean, Null, Numeric, Temporal, String-like | Input type | \(2) | ++--------------------------+------------+---------------------------------------------------+---------------------+---------+ +| if_else | Ternary | Boolean, Null, Numeric, Temporal | Input type | \(3) | ++--------------------------+------------+---------------------------------------------------+---------------------+---------+ +| is_finite | Unary | Float, Double | Boolean | \(4) | ++--------------------------+------------+---------------------------------------------------+---------------------+---------+ +| is_inf | Unary | Float, Double | Boolean | \(5) | ++--------------------------+------------+---------------------------------------------------+---------------------+---------+ +| is_nan | Unary | Float, Double | Boolean | \(6) | ++--------------------------+------------+---------------------------------------------------+---------------------+---------+ +| is_null | Unary | Any | Boolean | \(7) | ++--------------------------+------------+---------------------------------------------------+---------------------+---------+ +| is_valid | Unary | Any | Boolean | \(8) | ++--------------------------+------------+---------------------------------------------------+---------------------+---------+ +| list_value_length | Unary | List-like | Int32 or Int64 | \(9) | ++--------------------------+------------+---------------------------------------------------+---------------------+---------+ +| project | Varargs | Any | Struct | \(10) | ++--------------------------+------------+---------------------------------------------------+---------------------+---------+ * \(1) This function acts like a SQL 'case when' statement or switch-case. The - input is any number of alternating Boolean and value data, followed by an - optional value datum to represent the 'else' or 'default' case. At least one - input must be provided. The output is of the same type as the value inputs; - each row will be the corresponding value from the first value datum for which - the corresponding Boolean is true, or the corresponding value from the - 'default' input, or null otherwise. + input is a "condition" value, which is a struct of Booleans, followed by the + values for each "branch". There must be either exactly one value argument for + each child of the condition struct, or one more value argument than children + (in which case we have an 'else' or 'default' value). The output is of the + same type as the value inputs; each row will be the corresponding value from + the first value datum for which the corresponding Boolean is true, or the + corresponding value from the 'default' input, or null otherwise. * \(2) First input must be an array, second input a scalar of the same type. Output is an array of the same type as the inputs, and with the same values From c64dd7b316ef1df4417523baf31beac0d3e8df7e Mon Sep 17 00:00:00 2001 From: David Li Date: Tue, 13 Jul 2021 15:59:34 -0400 Subject: [PATCH 06/11] ARROW-13064: [C++] Benchmark another case --- .../kernels/scalar_if_else_benchmark.cc | 27 ++++++++----------- 1 file changed, 11 insertions(+), 16 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else_benchmark.cc b/cpp/src/arrow/compute/kernels/scalar_if_else_benchmark.cc index 0ad744054c6..22aa4200f4f 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else_benchmark.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else_benchmark.cc @@ -99,7 +99,7 @@ static void IfElseBench32Contiguous(benchmark::State& state) { return IfElseBenchContiguous(state); } -template +template static void CaseWhenBench(benchmark::State& state) { using CType = typename Type::c_type; auto type = TypeTraits::type_singleton(); @@ -118,9 +118,10 @@ static void CaseWhenBench(benchmark::State& state) { rand.ArrayOf(boolean(), len, /*null_probability=*/0.01)); auto cond_field = field("cond", boolean(), key_value_metadata({{"null_probability", "0.01"}})); - auto cond = rand.ArrayOf(*field("", struct_({cond_field, cond_field, cond_field}), - key_value_metadata({{"null_probability", "0.0"}})), - len); + auto cond = rand.ArrayOf( + *field("", struct_({cond_field, cond_field, cond_field}), + key_value_metadata({{"null_probability", outer_nulls ? "0.01" : "0.0"}})), + len); auto val1 = std::static_pointer_cast( rand.ArrayOf(type, len, /*null_probability=*/0.01)); auto val2 = std::static_pointer_cast( @@ -178,18 +179,15 @@ static void CaseWhenBench64(benchmark::State& state) { return CaseWhenBench(state); } -static void CaseWhenBench32(benchmark::State& state) { - return CaseWhenBench(state); +static void CaseWhenBench64OuterNulls(benchmark::State& state) { + // Benchmark where both children of cond have nulls and cond itself has nulls + return CaseWhenBench(state); } static void CaseWhenBench64Contiguous(benchmark::State& state) { return CaseWhenBenchContiguous(state); } -static void CaseWhenBench32Contiguous(benchmark::State& state) { - return CaseWhenBenchContiguous(state); -} - BENCHMARK(IfElseBench32)->Args({elems, 0}); BENCHMARK(IfElseBench64)->Args({elems, 0}); @@ -202,16 +200,13 @@ BENCHMARK(IfElseBench64Contiguous)->Args({elems, 0}); BENCHMARK(IfElseBench32Contiguous)->Args({elems, 99}); BENCHMARK(IfElseBench64Contiguous)->Args({elems, 99}); -BENCHMARK(CaseWhenBench32)->Args({elems, 0}); BENCHMARK(CaseWhenBench64)->Args({elems, 0}); - -BENCHMARK(CaseWhenBench32)->Args({elems, 99}); BENCHMARK(CaseWhenBench64)->Args({elems, 99}); -BENCHMARK(CaseWhenBench32Contiguous)->Args({elems, 0}); -BENCHMARK(CaseWhenBench64Contiguous)->Args({elems, 0}); +BENCHMARK(CaseWhenBench64OuterNulls)->Args({elems, 0}); +BENCHMARK(CaseWhenBench64OuterNulls)->Args({elems, 99}); -BENCHMARK(CaseWhenBench32Contiguous)->Args({elems, 99}); +BENCHMARK(CaseWhenBench64Contiguous)->Args({elems, 0}); BENCHMARK(CaseWhenBench64Contiguous)->Args({elems, 99}); } // namespace compute From 008729024e9c92817f2b483ff5da5e01bdd95edf Mon Sep 17 00:00:00 2001 From: David Li Date: Tue, 13 Jul 2021 15:59:45 -0400 Subject: [PATCH 07/11] ARROW-13064: [C++] Micro-optimize a particularly bad case --- cpp/src/arrow/compute/kernels/scalar_if_else.cc | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else.cc b/cpp/src/arrow/compute/kernels/scalar_if_else.cc index ded29e4d729..dfd53ba9591 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else.cc @@ -770,8 +770,15 @@ void CopyValues(const Datum& in_values, const int64_t in_offset, const int64_t l const ArrayData& array = *in_values.array(); if (out_valid) { if (array.MayHaveNulls()) { - arrow::internal::CopyBitmap(array.buffers[0]->data(), array.offset + in_offset, - length, out_valid, out_offset); + if (length == 1) { + // CopyBitmap is slow for short runs + BitUtil::SetBitTo( + out_valid, out_offset, + BitUtil::GetBit(array.buffers[0]->data(), array.offset + in_offset)); + } else { + arrow::internal::CopyBitmap(array.buffers[0]->data(), array.offset + in_offset, + length, out_valid, out_offset); + } } else { BitUtil::SetBitsTo(out_valid, out_offset, length, true); } From 0a8a3c7a1c477385459c45e25869f51e0a586768 Mon Sep 17 00:00:00 2001 From: David Li Date: Tue, 13 Jul 2021 15:59:54 -0400 Subject: [PATCH 08/11] ARROW-13064: [C++] Make sure to initialize memory --- .../arrow/compute/kernels/scalar_if_else.cc | 30 ++++++++++++++++++- 1 file changed, 29 insertions(+), 1 deletion(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else.cc b/cpp/src/arrow/compute/kernels/scalar_if_else.cc index dfd53ba9591..1a1e8cc193e 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else.cc @@ -974,7 +974,35 @@ Status ExecArrayCaseWhen(KernelContext* ctx, const ExecBatch& batch, Datum* out) }); } } - // TODO: need to initialize output values + if (!have_else_arg) { + // Need to initialize any remaining null slots (uninitialized memory) + BitBlockCounter counter(mask, /*offset=*/0, batch.length); + int64_t offset = 0; + auto bit_width = checked_cast(*out->type()).bit_width(); + auto byte_width = BitUtil::BytesForBits(bit_width); + while (offset < batch.length) { + const auto block = counter.NextWord(); + if (block.AllSet()) { + if (bit_width == 1) { + BitUtil::SetBitsTo(out_values, out_offset + offset, block.length, false); + } else { + std::memset(out_values + (out_offset + offset) * byte_width, 0x00, + byte_width * block.length); + } + } else if (!block.NoneSet()) { + for (int64_t j = 0; j < block.length; ++j) { + if (BitUtil::GetBit(out_valid, out_offset + offset + j)) continue; + if (bit_width == 1) { + BitUtil::ClearBit(out_values, out_offset + offset + j); + } else { + std::memset(out_values + (out_offset + offset + j) * byte_width, 0x00, + byte_width); + } + } + } + offset += block.length; + } + } return Status::OK(); } From c4ea55fec1bfd936793f85868c72a4ea75849a8e Mon Sep 17 00:00:00 2001 From: David Li Date: Wed, 14 Jul 2021 15:21:23 -0400 Subject: [PATCH 09/11] ARROW-13064: [C++] Don't allow outer nulls --- .../arrow/compute/kernels/scalar_if_else.cc | 41 ++++--------------- .../kernels/scalar_if_else_benchmark.cc | 17 ++------ .../compute/kernels/scalar_if_else_test.cc | 13 ++++++ 3 files changed, 25 insertions(+), 46 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else.cc b/cpp/src/arrow/compute/kernels/scalar_if_else.cc index 1a1e8cc193e..dedafb55ff7 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else.cc @@ -831,6 +831,9 @@ struct CaseWhenFunction : ScalarFunction { template Status ExecScalarCaseWhen(KernelContext* ctx, const ExecBatch& batch, Datum* out) { const auto& conds = checked_cast(*batch.values[0].scalar()); + if (!conds.is_valid) { + return Status::Invalid("cond struct must not be null"); + } Datum result; for (size_t i = 0; i < batch.values.size() - 1; i++) { if (i < conds.value.size()) { @@ -865,6 +868,9 @@ Status ExecScalarCaseWhen(KernelContext* ctx, const ExecBatch& batch, Datum* out template Status ExecArrayCaseWhen(KernelContext* ctx, const ExecBatch& batch, Datum* out) { const auto& conds_array = *batch.values[0].array(); + if (conds_array.GetNullCount() > 0) { + return Status::Invalid("cond struct must not have nulls"); + } ArrayData* output = out->mutable_array(); const int64_t out_offset = output->offset; const auto num_value_args = batch.values.size() - 1; @@ -887,8 +893,6 @@ Status ExecArrayCaseWhen(KernelContext* ctx, const ExecBatch& batch, Datum* out) std::memset(mask, 0xFF, mask_buffer->size()); // Then iterate through each argument in turn and set elements. - const uint8_t* conds_valid = - conds_array.GetNullCount() > 0 ? conds_array.buffers[0]->data() : nullptr; for (size_t i = 0; i < batch.values.size() - (have_else_arg ? 2 : 1); i++) { const ArrayData& cond_array = *conds_array.child_data[i]; const int64_t cond_offset = conds_array.offset + cond_array.offset; @@ -896,7 +900,7 @@ Status ExecArrayCaseWhen(KernelContext* ctx, const ExecBatch& batch, Datum* out) const Datum& values_datum = batch[i + 1]; int64_t offset = 0; - if (!conds_valid && cond_array.GetNullCount() == 0) { + if (cond_array.GetNullCount() == 0) { // If no valid buffer, visit mask & cond bitmap simultaneously BinaryBitBlockCounter counter(mask, /*start_offset=*/0, cond_values, cond_offset, batch.length); @@ -918,8 +922,7 @@ Status ExecArrayCaseWhen(KernelContext* ctx, const ExecBatch& batch, Datum* out) } offset += block.length; } - continue; - } else if (!conds_valid) { + } else { // Visit mask & cond bitmap & cond validity const uint8_t* cond_valid = cond_array.buffers[0]->data(); Bitmap bitmaps[3] = {{mask, /*offset=*/0, batch.length}, @@ -944,34 +947,6 @@ Status ExecArrayCaseWhen(KernelContext* ctx, const ExecBatch& batch, Datum* out) } } }); - } else { - // Visit mask & cond bitmap & cond validity & struct validity - const uint8_t* cond_valid = cond_array.buffers[0]->data(); - Bitmap bitmaps[4] = {{mask, /*offset=*/0, batch.length}, - {cond_values, cond_offset, batch.length}, - {cond_valid, cond_offset, batch.length}, - {conds_valid, conds_array.offset, batch.length}}; - Bitmap::VisitWords(bitmaps, [&](std::array words) { - const uint64_t word = words[0] & words[1] & words[2] & words[3]; - const int64_t block_length = std::min(64, batch.length - offset); - if (word == std::numeric_limits::max()) { - CopyValues(values_datum, offset, block_length, out_valid, out_values, - out_offset + offset); - BitUtil::SetBitsTo(mask, offset, block_length, false); - } else if (word) { - for (int64_t j = 0; j < block_length; ++j) { - if (BitUtil::GetBit(mask, offset + j) && - BitUtil::GetBit(cond_valid, cond_offset + offset + j) && - BitUtil::GetBit(cond_values, cond_offset + offset + j) && - BitUtil::GetBit(conds_valid, conds_array.offset + offset + j)) { - CopyValues(values_datum, offset + j, /*length=*/1, out_valid, - out_values, out_offset + offset + j); - BitUtil::SetBitTo(mask, offset + j, false); - } - } - } - offset += block_length; - }); } } if (!have_else_arg) { diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else_benchmark.cc b/cpp/src/arrow/compute/kernels/scalar_if_else_benchmark.cc index 22aa4200f4f..9192cf54ebb 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else_benchmark.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else_benchmark.cc @@ -99,7 +99,7 @@ static void IfElseBench32Contiguous(benchmark::State& state) { return IfElseBenchContiguous(state); } -template +template static void CaseWhenBench(benchmark::State& state) { using CType = typename Type::c_type; auto type = TypeTraits::type_singleton(); @@ -118,10 +118,9 @@ static void CaseWhenBench(benchmark::State& state) { rand.ArrayOf(boolean(), len, /*null_probability=*/0.01)); auto cond_field = field("cond", boolean(), key_value_metadata({{"null_probability", "0.01"}})); - auto cond = rand.ArrayOf( - *field("", struct_({cond_field, cond_field, cond_field}), - key_value_metadata({{"null_probability", outer_nulls ? "0.01" : "0.0"}})), - len); + auto cond = rand.ArrayOf(*field("", struct_({cond_field, cond_field, cond_field}), + key_value_metadata({{"null_probability", "0.0"}})), + len); auto val1 = std::static_pointer_cast( rand.ArrayOf(type, len, /*null_probability=*/0.01)); auto val2 = std::static_pointer_cast( @@ -179,11 +178,6 @@ static void CaseWhenBench64(benchmark::State& state) { return CaseWhenBench(state); } -static void CaseWhenBench64OuterNulls(benchmark::State& state) { - // Benchmark where both children of cond have nulls and cond itself has nulls - return CaseWhenBench(state); -} - static void CaseWhenBench64Contiguous(benchmark::State& state) { return CaseWhenBenchContiguous(state); } @@ -203,9 +197,6 @@ BENCHMARK(IfElseBench64Contiguous)->Args({elems, 99}); BENCHMARK(CaseWhenBench64)->Args({elems, 0}); BENCHMARK(CaseWhenBench64)->Args({elems, 99}); -BENCHMARK(CaseWhenBench64OuterNulls)->Args({elems, 0}); -BENCHMARK(CaseWhenBench64OuterNulls)->Args({elems, 99}); - BENCHMARK(CaseWhenBench64Contiguous)->Args({elems, 0}); BENCHMARK(CaseWhenBench64Contiguous)->Args({elems, 99}); diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc index 3eb682d97be..eacb917735b 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc @@ -408,6 +408,19 @@ TYPED_TEST(TestCaseWhenNumeric, FixedSize) { ArrayFromJSON(type, "[20, 21, 22, 23, 24, 25, 26, 27, 28]"), ArrayFromJSON(type, "[30, 31, 32, 33, 34, null, 36, 37, null]")}, ArrayFromJSON(type, "[10, 11, 12, 23, 34, null, 26, 37, null]")); + + // Error cases + EXPECT_RAISES_WITH_MESSAGE_THAT( + Invalid, ::testing::HasSubstr("cond struct must not be null"), + CallFunction( + "case_when", + {Datum(std::make_shared(struct_({field("", boolean())}))), + Datum(scalar1)})); + EXPECT_RAISES_WITH_MESSAGE_THAT( + Invalid, ::testing::HasSubstr("cond struct must not have nulls"), + CallFunction( + "case_when", + {Datum(*MakeArrayOfNull(struct_({field("", boolean())}), 4)), Datum(values1)})); } TEST(TestCaseWhen, Null) { From 89287c6d9ae3f45060419b8a080731d25ad3f39e Mon Sep 17 00:00:00 2001 From: David Li Date: Thu, 15 Jul 2021 11:46:28 -0400 Subject: [PATCH 10/11] Update cpp/src/arrow/compute/kernels/scalar_if_else.cc Co-authored-by: Joris Van den Bossche --- cpp/src/arrow/compute/kernels/scalar_if_else.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else.cc b/cpp/src/arrow/compute/kernels/scalar_if_else.cc index dedafb55ff7..32307542d97 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else.cc @@ -869,7 +869,7 @@ template Status ExecArrayCaseWhen(KernelContext* ctx, const ExecBatch& batch, Datum* out) { const auto& conds_array = *batch.values[0].array(); if (conds_array.GetNullCount() > 0) { - return Status::Invalid("cond struct must not have nulls"); + return Status::Invalid("cond struct must not have top-level nulls"); } ArrayData* output = out->mutable_array(); const int64_t out_offset = output->offset; From fd3458a4ceb3bcd23ed1244621ffdb046b92525e Mon Sep 17 00:00:00 2001 From: David Li Date: Thu, 15 Jul 2021 11:51:45 -0400 Subject: [PATCH 11/11] ARROW-13064: [C++] Fix test --- cpp/src/arrow/compute/kernels/scalar_if_else_test.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc index eacb917735b..cd2d04a13e0 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc @@ -417,7 +417,7 @@ TYPED_TEST(TestCaseWhenNumeric, FixedSize) { {Datum(std::make_shared(struct_({field("", boolean())}))), Datum(scalar1)})); EXPECT_RAISES_WITH_MESSAGE_THAT( - Invalid, ::testing::HasSubstr("cond struct must not have nulls"), + Invalid, ::testing::HasSubstr("cond struct must not have top-level nulls"), CallFunction( "case_when", {Datum(*MakeArrayOfNull(struct_({field("", boolean())}), 4)), Datum(values1)}));