diff --git a/cpp/src/arrow/compute/api_scalar.cc b/cpp/src/arrow/compute/api_scalar.cc index be6498a74c6..68df5f98b10 100644 --- a/cpp/src/arrow/compute/api_scalar.cc +++ b/cpp/src/arrow/compute/api_scalar.cc @@ -466,6 +466,14 @@ Result IfElse(const Datum& cond, const Datum& if_true, const Datum& if_fa return CallFunction("if_else", {cond, if_true, if_false}, ctx); } +Result CaseWhen(const Datum& cond, const std::vector& cases, + ExecContext* ctx) { + std::vector args = {cond}; + args.reserve(cases.size() + 1); + args.insert(args.end(), cases.begin(), cases.end()); + return CallFunction("case_when", args, ctx); +} + // ---------------------------------------------------------------------- // Temporal functions diff --git a/cpp/src/arrow/compute/api_scalar.h b/cpp/src/arrow/compute/api_scalar.h index f0aebc8e032..bbaa4d13a21 100644 --- a/cpp/src/arrow/compute/api_scalar.h +++ b/cpp/src/arrow/compute/api_scalar.h @@ -741,6 +741,23 @@ ARROW_EXPORT Result IfElse(const Datum& cond, const Datum& left, const Datum& right, ExecContext* ctx = NULLPTR); +/// \brief CaseWhen behaves like a switch/case or if-else if-else statement: for +/// each row, select the first value for which the corresponding condition is +/// true, or (if given) select the 'else' value, else emit null. Note that a +/// null condition is the same as false. +/// +/// \param[in] cond Conditions (Boolean) +/// \param[in] cases Values (any type), along with an optional 'else' value. +/// \param[in] ctx the function execution context, optional +/// +/// \return the resulting datum +/// +/// \since 5.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result CaseWhen(const Datum& cond, const std::vector& cases, + ExecContext* ctx = NULLPTR); + /// \brief Year returns year for each element of `values` /// /// \param[in] values input to extract year from diff --git a/cpp/src/arrow/compute/kernel.cc b/cpp/src/arrow/compute/kernel.cc index 6cdd17adcc9..f131f524d2e 100644 --- a/cpp/src/arrow/compute/kernel.cc +++ b/cpp/src/arrow/compute/kernel.cc @@ -402,8 +402,7 @@ KernelSignature::KernelSignature(std::vector in_types, OutputType out out_type_(std::move(out_type)), is_varargs_(is_varargs), hash_code_(0) { - // VarArgs sigs must have only a single input type to use for argument validation - DCHECK(!is_varargs || (is_varargs && (in_types_.size() == 1))); + DCHECK(!is_varargs || (is_varargs && (in_types_.size() >= 1))); } std::shared_ptr KernelSignature::Make(std::vector in_types, @@ -430,8 +429,8 @@ bool KernelSignature::Equals(const KernelSignature& other) const { bool KernelSignature::MatchesInputs(const std::vector& args) const { if (is_varargs_) { - for (const auto& arg : args) { - if (!in_types_[0].Matches(arg)) { + for (size_t i = 0; i < args.size(); ++i) { + if (!in_types_[std::min(i, in_types_.size() - 1)].Matches(args[i])) { return false; } } @@ -464,15 +463,19 @@ std::string KernelSignature::ToString() const { std::stringstream ss; if (is_varargs_) { - ss << "varargs[" << in_types_[0].ToString() << "]"; + ss << "varargs["; } else { ss << "("; - for (size_t i = 0; i < in_types_.size(); ++i) { - if (i > 0) { - ss << ", "; - } - ss << in_types_[i].ToString(); + } + for (size_t i = 0; i < in_types_.size(); ++i) { + if (i > 0) { + ss << ", "; } + ss << in_types_[i].ToString(); + } + if (is_varargs_) { + ss << "]"; + } else { ss << ")"; } ss << " -> " << out_type_.ToString(); diff --git a/cpp/src/arrow/compute/kernel.h b/cpp/src/arrow/compute/kernel.h index 50b1dd8e55e..36d20c7289e 100644 --- a/cpp/src/arrow/compute/kernel.h +++ b/cpp/src/arrow/compute/kernel.h @@ -366,8 +366,10 @@ class ARROW_EXPORT OutputType { /// \brief Holds the input types and output type of the kernel. /// -/// VarArgs functions should pass a single input type to be used to validate -/// the input types of a function invocation. +/// VarArgs functions with minimum N arguments should pass up to N input types to be +/// used to validate the input types of a function invocation. The first N-1 types +/// will be matched against the first N-1 arguments, and the last type will be +/// matched against the remaining arguments. class ARROW_EXPORT KernelSignature { public: KernelSignature(std::vector in_types, OutputType out_type, diff --git a/cpp/src/arrow/compute/kernel_test.cc b/cpp/src/arrow/compute/kernel_test.cc index a5ef9d44e18..a63c42d4fde 100644 --- a/cpp/src/arrow/compute/kernel_test.cc +++ b/cpp/src/arrow/compute/kernel_test.cc @@ -468,15 +468,28 @@ TEST(KernelSignature, MatchesInputs) { } TEST(KernelSignature, VarArgsMatchesInputs) { - KernelSignature sig({int8()}, utf8(), /*is_varargs=*/true); - - std::vector args = {int8()}; - ASSERT_TRUE(sig.MatchesInputs(args)); - args.push_back(ValueDescr::Scalar(int8())); - args.push_back(ValueDescr::Array(int8())); - ASSERT_TRUE(sig.MatchesInputs(args)); - args.push_back(int32()); - ASSERT_FALSE(sig.MatchesInputs(args)); + { + KernelSignature sig({int8()}, utf8(), /*is_varargs=*/true); + + std::vector args = {int8()}; + ASSERT_TRUE(sig.MatchesInputs(args)); + args.push_back(ValueDescr::Scalar(int8())); + args.push_back(ValueDescr::Array(int8())); + ASSERT_TRUE(sig.MatchesInputs(args)); + args.push_back(int32()); + ASSERT_FALSE(sig.MatchesInputs(args)); + } + { + KernelSignature sig({int8(), utf8()}, utf8(), /*is_varargs=*/true); + + std::vector args = {int8()}; + ASSERT_TRUE(sig.MatchesInputs(args)); + args.push_back(ValueDescr::Scalar(utf8())); + args.push_back(ValueDescr::Array(utf8())); + ASSERT_TRUE(sig.MatchesInputs(args)); + args.push_back(int32()); + ASSERT_FALSE(sig.MatchesInputs(args)); + } } TEST(KernelSignature, ToString) { diff --git a/cpp/src/arrow/compute/kernels/codegen_internal.cc b/cpp/src/arrow/compute/kernels/codegen_internal.cc index e723bd7838e..673db088eae 100644 --- a/cpp/src/arrow/compute/kernels/codegen_internal.cc +++ b/cpp/src/arrow/compute/kernels/codegen_internal.cc @@ -218,9 +218,14 @@ void ReplaceTypes(const std::shared_ptr& type, } std::shared_ptr CommonNumeric(const std::vector& descrs) { - DCHECK(!descrs.empty()) << "tried to find CommonNumeric type of an empty set"; + return CommonNumeric(descrs.data(), descrs.size()); +} - for (const auto& descr : descrs) { +std::shared_ptr CommonNumeric(const ValueDescr* begin, size_t count) { + DCHECK_GT(count, 0) << "tried to find CommonNumeric type of an empty set"; + + for (size_t i = 0; i < count; i++) { + const auto& descr = *(begin + i); auto id = descr.type->id(); if (!is_floating(id) && !is_integer(id)) { // a common numeric type is only possible if all types are numeric @@ -232,17 +237,20 @@ std::shared_ptr CommonNumeric(const std::vector& descrs) { } } - for (const auto& descr : descrs) { + for (size_t i = 0; i < count; i++) { + const auto& descr = *(begin + i); if (descr.type->id() == Type::DOUBLE) return float64(); } - for (const auto& descr : descrs) { + for (size_t i = 0; i < count; i++) { + const auto& descr = *(begin + i); if (descr.type->id() == Type::FLOAT) return float32(); } int max_width_signed = 0, max_width_unsigned = 0; - for (const auto& descr : descrs) { + for (size_t i = 0; i < count; i++) { + const auto& descr = *(begin + i); auto id = descr.type->id(); auto max_width = &(is_signed_integer(id) ? max_width_signed : max_width_unsigned); *max_width = std::max(bit_width(id), *max_width); @@ -253,7 +261,7 @@ std::shared_ptr CommonNumeric(const std::vector& descrs) { if (max_width_unsigned == 32) return uint32(); if (max_width_unsigned == 16) return uint16(); DCHECK_EQ(max_width_unsigned, 8); - return int8(); + return uint8(); } if (max_width_signed <= max_width_unsigned) { diff --git a/cpp/src/arrow/compute/kernels/codegen_internal.h b/cpp/src/arrow/compute/kernels/codegen_internal.h index 12e80423f7f..d28ede4f77a 100644 --- a/cpp/src/arrow/compute/kernels/codegen_internal.h +++ b/cpp/src/arrow/compute/kernels/codegen_internal.h @@ -1367,6 +1367,9 @@ void ReplaceTypes(const std::shared_ptr&, std::vector* des ARROW_EXPORT std::shared_ptr CommonNumeric(const std::vector& descrs); +ARROW_EXPORT +std::shared_ptr CommonNumeric(const ValueDescr* begin, size_t count); + ARROW_EXPORT std::shared_ptr CommonTimestamp(const std::vector& descrs); diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else.cc b/cpp/src/arrow/compute/kernels/scalar_if_else.cc index 54e0725fce7..32307542d97 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else.cc @@ -30,6 +30,7 @@ using internal::Bitmap; using internal::BitmapWordReader; namespace compute { +namespace internal { namespace { @@ -676,7 +677,353 @@ void AddPrimitiveIfElseKernels(const std::shared_ptr& scalar_fun } } -} // namespace +// Helper to copy or broadcast fixed-width values between buffers. +template +struct CopyFixedWidth {}; +template <> +struct CopyFixedWidth { + static void CopyScalar(const Scalar& scalar, const int64_t length, + uint8_t* raw_out_values, const int64_t out_offset) { + const bool value = UnboxScalar::Unbox(scalar); + BitUtil::SetBitsTo(raw_out_values, out_offset, length, value); + } + static void CopyArray(const DataType&, const uint8_t* in_values, + const int64_t in_offset, const int64_t length, + uint8_t* raw_out_values, const int64_t out_offset) { + arrow::internal::CopyBitmap(in_values, in_offset, length, raw_out_values, out_offset); + } +}; +template +struct CopyFixedWidth> { + using CType = typename TypeTraits::CType; + static void CopyScalar(const Scalar& scalar, const int64_t length, + uint8_t* raw_out_values, const int64_t out_offset) { + CType* out_values = reinterpret_cast(raw_out_values); + const CType value = UnboxScalar::Unbox(scalar); + std::fill(out_values + out_offset, out_values + out_offset + length, value); + } + static void CopyArray(const DataType&, const uint8_t* in_values, + const int64_t in_offset, const int64_t length, + uint8_t* raw_out_values, const int64_t out_offset) { + std::memcpy(raw_out_values + out_offset * sizeof(CType), + in_values + in_offset * sizeof(CType), length * sizeof(CType)); + } +}; +template +struct CopyFixedWidth> { + static void CopyScalar(const Scalar& values, const int64_t length, + uint8_t* raw_out_values, const int64_t out_offset) { + const int32_t width = + checked_cast(*values.type).byte_width(); + uint8_t* next = raw_out_values + (width * out_offset); + const auto& scalar = checked_cast(values); + // Scalar may have null value buffer + if (!scalar.value) return; + DCHECK_EQ(scalar.value->size(), width); + for (int i = 0; i < length; i++) { + std::memcpy(next, scalar.value->data(), width); + next += width; + } + } + static void CopyArray(const DataType& type, const uint8_t* in_values, + const int64_t in_offset, const int64_t length, + uint8_t* raw_out_values, const int64_t out_offset) { + const int32_t width = checked_cast(type).byte_width(); + uint8_t* next = raw_out_values + (width * out_offset); + std::memcpy(next, in_values + in_offset * width, length * width); + } +}; +template +struct CopyFixedWidth> { + using ScalarType = typename TypeTraits::ScalarType; + static void CopyScalar(const Scalar& values, const int64_t length, + uint8_t* raw_out_values, const int64_t out_offset) { + const int32_t width = + checked_cast(*values.type).byte_width(); + uint8_t* next = raw_out_values + (width * out_offset); + const auto& scalar = checked_cast(values); + const auto value = scalar.value.ToBytes(); + for (int i = 0; i < length; i++) { + std::memcpy(next, value.data(), width); + next += width; + } + } + static void CopyArray(const DataType& type, const uint8_t* in_values, + const int64_t in_offset, const int64_t length, + uint8_t* raw_out_values, const int64_t out_offset) { + const int32_t width = checked_cast(type).byte_width(); + uint8_t* next = raw_out_values + (width * out_offset); + std::memcpy(next, in_values + in_offset * width, length * width); + } +}; +// Copy fixed-width values from a scalar/array datum into an output values buffer +template +void CopyValues(const Datum& in_values, const int64_t in_offset, const int64_t length, + uint8_t* out_valid, uint8_t* out_values, const int64_t out_offset) { + if (in_values.is_scalar()) { + const auto& scalar = *in_values.scalar(); + if (out_valid) { + BitUtil::SetBitsTo(out_valid, out_offset, length, scalar.is_valid); + } + CopyFixedWidth::CopyScalar(scalar, length, out_values, out_offset); + } else { + const ArrayData& array = *in_values.array(); + if (out_valid) { + if (array.MayHaveNulls()) { + if (length == 1) { + // CopyBitmap is slow for short runs + BitUtil::SetBitTo( + out_valid, out_offset, + BitUtil::GetBit(array.buffers[0]->data(), array.offset + in_offset)); + } else { + arrow::internal::CopyBitmap(array.buffers[0]->data(), array.offset + in_offset, + length, out_valid, out_offset); + } + } else { + BitUtil::SetBitsTo(out_valid, out_offset, length, true); + } + } + CopyFixedWidth::CopyArray(*array.type, array.buffers[1]->data(), + array.offset + in_offset, length, out_values, + out_offset); + } +} + +struct CaseWhenFunction : ScalarFunction { + using ScalarFunction::ScalarFunction; + + Result DispatchBest(std::vector* values) const override { + // The first function is a struct of booleans, where the number of fields in the + // struct is either equal to the number of other arguments or is one less. + RETURN_NOT_OK(CheckArity(*values)); + EnsureDictionaryDecoded(values); + auto first_type = (*values)[0].type; + if (first_type->id() != Type::STRUCT) { + return Status::TypeError("case_when: first argument must be STRUCT, not ", + *first_type); + } + auto num_fields = static_cast(first_type->num_fields()); + if (num_fields < values->size() - 2 || num_fields >= values->size()) { + return Status::Invalid( + "case_when: number of struct fields must be equal to or one less than count of " + "remaining arguments (", + values->size() - 1, "), got: ", first_type->num_fields()); + } + for (const auto& field : first_type->fields()) { + if (field->type()->id() != Type::BOOL) { + return Status::TypeError( + "case_when: all fields of first argument must be BOOL, but ", field->name(), + " was of type: ", *field->type()); + } + } + + if (auto type = CommonNumeric(values->data() + 1, values->size() - 1)) { + for (auto it = values->begin() + 1; it != values->end(); it++) { + it->type = type; + } + } + if (auto kernel = DispatchExactImpl(this, *values)) return kernel; + return arrow::compute::detail::NoMatchingKernel(this, *values); + } +}; + +// Implement a 'case when' (SQL)/'select' (NumPy) function for any scalar conditions +template +Status ExecScalarCaseWhen(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + const auto& conds = checked_cast(*batch.values[0].scalar()); + if (!conds.is_valid) { + return Status::Invalid("cond struct must not be null"); + } + Datum result; + for (size_t i = 0; i < batch.values.size() - 1; i++) { + if (i < conds.value.size()) { + const Scalar& cond = *conds.value[i]; + if (cond.is_valid && internal::UnboxScalar::Unbox(cond)) { + result = batch[i + 1]; + break; + } + } else { + // ELSE clause + result = batch[i + 1]; + break; + } + } + if (out->is_scalar()) { + *out = result.is_scalar() ? result.scalar() : MakeNullScalar(out->type()); + return Status::OK(); + } + ArrayData* output = out->mutable_array(); + if (!result.is_value()) { + // All conditions false, no 'else' argument + result = MakeNullScalar(out->type()); + } + CopyValues(result, /*in_offset=*/0, batch.length, + output->GetMutableValues(0, 0), + output->GetMutableValues(1, 0), output->offset); + return Status::OK(); +} + +// Implement 'case when' for any mix of scalar/array arguments for any fixed-width type, +// given helper functions to copy data from a source array to a target array +template +Status ExecArrayCaseWhen(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + const auto& conds_array = *batch.values[0].array(); + if (conds_array.GetNullCount() > 0) { + return Status::Invalid("cond struct must not have top-level nulls"); + } + ArrayData* output = out->mutable_array(); + const int64_t out_offset = output->offset; + const auto num_value_args = batch.values.size() - 1; + const bool have_else_arg = + static_cast(conds_array.type->num_fields()) < num_value_args; + uint8_t* out_valid = output->buffers[0]->mutable_data(); + uint8_t* out_values = output->buffers[1]->mutable_data(); + if (have_else_arg) { + // Copy 'else' value into output + CopyValues(batch.values.back(), /*in_offset=*/0, batch.length, out_valid, + out_values, out_offset); + } else { + // There's no 'else' argument, so we should have an all-null validity bitmap + BitUtil::SetBitsTo(out_valid, out_offset, batch.length, false); + } + + // Allocate a temporary bitmap to determine which elements still need setting. + ARROW_ASSIGN_OR_RAISE(auto mask_buffer, ctx->AllocateBitmap(batch.length)); + uint8_t* mask = mask_buffer->mutable_data(); + std::memset(mask, 0xFF, mask_buffer->size()); + + // Then iterate through each argument in turn and set elements. + for (size_t i = 0; i < batch.values.size() - (have_else_arg ? 2 : 1); i++) { + const ArrayData& cond_array = *conds_array.child_data[i]; + const int64_t cond_offset = conds_array.offset + cond_array.offset; + const uint8_t* cond_values = cond_array.buffers[1]->data(); + const Datum& values_datum = batch[i + 1]; + int64_t offset = 0; + + if (cond_array.GetNullCount() == 0) { + // If no valid buffer, visit mask & cond bitmap simultaneously + BinaryBitBlockCounter counter(mask, /*start_offset=*/0, cond_values, cond_offset, + batch.length); + while (offset < batch.length) { + const auto block = counter.NextAndWord(); + if (block.AllSet()) { + CopyValues(values_datum, offset, block.length, out_valid, out_values, + out_offset + offset); + BitUtil::SetBitsTo(mask, offset, block.length, false); + } else if (block.popcount) { + for (int64_t j = 0; j < block.length; ++j) { + if (BitUtil::GetBit(mask, offset + j) && + BitUtil::GetBit(cond_values, cond_offset + offset + j)) { + CopyValues(values_datum, offset + j, /*length=*/1, out_valid, + out_values, out_offset + offset + j); + BitUtil::SetBitTo(mask, offset + j, false); + } + } + } + offset += block.length; + } + } else { + // Visit mask & cond bitmap & cond validity + const uint8_t* cond_valid = cond_array.buffers[0]->data(); + Bitmap bitmaps[3] = {{mask, /*offset=*/0, batch.length}, + {cond_values, cond_offset, batch.length}, + {cond_valid, cond_offset, batch.length}}; + Bitmap::VisitWords(bitmaps, [&](std::array words) { + const uint64_t word = words[0] & words[1] & words[2]; + const int64_t block_length = std::min(64, batch.length - offset); + if (word == std::numeric_limits::max()) { + CopyValues(values_datum, offset, block_length, out_valid, out_values, + out_offset + offset); + BitUtil::SetBitsTo(mask, offset, block_length, false); + } else if (word) { + for (int64_t j = 0; j < block_length; ++j) { + if (BitUtil::GetBit(mask, offset + j) && + BitUtil::GetBit(cond_valid, cond_offset + offset + j) && + BitUtil::GetBit(cond_values, cond_offset + offset + j)) { + CopyValues(values_datum, offset + j, /*length=*/1, out_valid, + out_values, out_offset + offset + j); + BitUtil::SetBitTo(mask, offset + j, false); + } + } + } + }); + } + } + if (!have_else_arg) { + // Need to initialize any remaining null slots (uninitialized memory) + BitBlockCounter counter(mask, /*offset=*/0, batch.length); + int64_t offset = 0; + auto bit_width = checked_cast(*out->type()).bit_width(); + auto byte_width = BitUtil::BytesForBits(bit_width); + while (offset < batch.length) { + const auto block = counter.NextWord(); + if (block.AllSet()) { + if (bit_width == 1) { + BitUtil::SetBitsTo(out_values, out_offset + offset, block.length, false); + } else { + std::memset(out_values + (out_offset + offset) * byte_width, 0x00, + byte_width * block.length); + } + } else if (!block.NoneSet()) { + for (int64_t j = 0; j < block.length; ++j) { + if (BitUtil::GetBit(out_valid, out_offset + offset + j)) continue; + if (bit_width == 1) { + BitUtil::ClearBit(out_values, out_offset + offset + j); + } else { + std::memset(out_values + (out_offset + offset + j) * byte_width, 0x00, + byte_width); + } + } + } + offset += block.length; + } + } + return Status::OK(); +} + +template +struct CaseWhenFunctor { + static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + if (batch.values[0].is_array()) { + return ExecArrayCaseWhen(ctx, batch, out); + } + return ExecScalarCaseWhen(ctx, batch, out); + } +}; + +template <> +struct CaseWhenFunctor { + static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + return Status::OK(); + } +}; + +Result LastType(KernelContext*, const std::vector& descrs) { + ValueDescr result = descrs.back(); + result.shape = GetBroadcastShape(descrs); + return result; +} + +void AddCaseWhenKernel(const std::shared_ptr& scalar_function, + detail::GetTypeId get_id, ArrayKernelExec exec) { + ScalarKernel kernel( + KernelSignature::Make({InputType(Type::STRUCT), InputType(get_id.id)}, + OutputType(LastType), + /*is_varargs=*/true), + exec); + kernel.null_handling = NullHandling::COMPUTED_PREALLOCATE; + kernel.mem_allocation = MemAllocation::PREALLOCATE; + kernel.can_write_into_slices = is_fixed_width(get_id.id); + DCHECK_OK(scalar_function->AddKernel(std::move(kernel))); +} + +void AddPrimitiveCaseWhenKernels(const std::shared_ptr& scalar_function, + const std::vector>& types) { + for (auto&& type : types) { + auto exec = GenerateTypeAgnosticPrimitive(*type); + AddCaseWhenKernel(scalar_function, type, std::move(exec)); + } +} const FunctionDoc if_else_doc{"Choose values based on a condition", ("`cond` must be a Boolean scalar/ array. \n`left` or " @@ -685,22 +1032,46 @@ const FunctionDoc if_else_doc{"Choose values based on a condition", " output."), {"cond", "left", "right"}}; -namespace internal { +const FunctionDoc case_when_doc{ + "Choose values based on multiple conditions", + ("`cond` must be a struct of Boolean values. `cases` can be a mix " + "of scalar and array arguments (of any type, but all must be the " + "same type or castable to a common type), with either exactly one " + "datum per child of `cond`, or one more `cases` than children of " + "`cond` (in which case we have an \"else\" value).\n" + "Each row of the output will be the corresponding value of the " + "first datum in `cases` for which the corresponding child of `cond` " + "is true, or otherwise the \"else\" value (if given), or null. " + "Essentially, this implements a switch-case or if-else, if-else... " + "statement."), + {"cond", "*cases"}}; +} // namespace void RegisterScalarIfElse(FunctionRegistry* registry) { - ScalarKernel scalar_kernel; - scalar_kernel.null_handling = NullHandling::COMPUTED_NO_PREALLOCATE; - scalar_kernel.mem_allocation = MemAllocation::NO_PREALLOCATE; - - auto func = std::make_shared("if_else", Arity::Ternary(), &if_else_doc); - - AddPrimitiveIfElseKernels(func, NumericTypes()); - AddPrimitiveIfElseKernels(func, TemporalTypes()); - AddPrimitiveIfElseKernels(func, {boolean()}); - AddNullIfElseKernel(func); - // todo add binary kernels - - DCHECK_OK(registry->AddFunction(std::move(func))); + { + auto func = + std::make_shared("if_else", Arity::Ternary(), &if_else_doc); + + AddPrimitiveIfElseKernels(func, NumericTypes()); + AddPrimitiveIfElseKernels(func, TemporalTypes()); + AddPrimitiveIfElseKernels(func, {boolean(), day_time_interval(), month_interval()}); + AddNullIfElseKernel(func); + // todo add binary kernels + DCHECK_OK(registry->AddFunction(std::move(func))); + } + { + auto func = std::make_shared( + "case_when", Arity::VarArgs(/*min_args=*/1), &case_when_doc); + AddPrimitiveCaseWhenKernels(func, NumericTypes()); + AddPrimitiveCaseWhenKernels(func, TemporalTypes()); + AddPrimitiveCaseWhenKernels( + func, {boolean(), null(), day_time_interval(), month_interval()}); + AddCaseWhenKernel(func, Type::FIXED_SIZE_BINARY, + CaseWhenFunctor::Exec); + AddCaseWhenKernel(func, Type::DECIMAL128, CaseWhenFunctor::Exec); + AddCaseWhenKernel(func, Type::DECIMAL256, CaseWhenFunctor::Exec); + DCHECK_OK(registry->AddFunction(std::move(func))); + } } } // namespace internal diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else_benchmark.cc b/cpp/src/arrow/compute/kernels/scalar_if_else_benchmark.cc index 98fb675da40..9192cf54ebb 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else_benchmark.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else_benchmark.cc @@ -15,12 +15,14 @@ // specific language governing permissions and limitations // under the License. -#include -#include -#include -#include #include +#include "arrow/array/concatenate.h" +#include "arrow/compute/api_scalar.h" +#include "arrow/testing/gtest_util.h" +#include "arrow/testing/random.h" +#include "arrow/util/key_value_metadata.h" + namespace arrow { namespace compute { @@ -97,6 +99,89 @@ static void IfElseBench32Contiguous(benchmark::State& state) { return IfElseBenchContiguous(state); } +template +static void CaseWhenBench(benchmark::State& state) { + using CType = typename Type::c_type; + auto type = TypeTraits::type_singleton(); + using ArrayType = typename TypeTraits::ArrayType; + + int64_t len = state.range(0); + int64_t offset = state.range(1); + + random::RandomArrayGenerator rand(/*seed=*/0); + + auto cond1 = std::static_pointer_cast( + rand.ArrayOf(boolean(), len, /*null_probability=*/0.01)); + auto cond2 = std::static_pointer_cast( + rand.ArrayOf(boolean(), len, /*null_probability=*/0.01)); + auto cond3 = std::static_pointer_cast( + rand.ArrayOf(boolean(), len, /*null_probability=*/0.01)); + auto cond_field = + field("cond", boolean(), key_value_metadata({{"null_probability", "0.01"}})); + auto cond = rand.ArrayOf(*field("", struct_({cond_field, cond_field, cond_field}), + key_value_metadata({{"null_probability", "0.0"}})), + len); + auto val1 = std::static_pointer_cast( + rand.ArrayOf(type, len, /*null_probability=*/0.01)); + auto val2 = std::static_pointer_cast( + rand.ArrayOf(type, len, /*null_probability=*/0.01)); + auto val3 = std::static_pointer_cast( + rand.ArrayOf(type, len, /*null_probability=*/0.01)); + auto val4 = std::static_pointer_cast( + rand.ArrayOf(type, len, /*null_probability=*/0.01)); + for (auto _ : state) { + ABORT_NOT_OK( + CaseWhen(cond->Slice(offset), {val1->Slice(offset), val2->Slice(offset), + val3->Slice(offset), val4->Slice(offset)})); + } + + state.SetBytesProcessed(state.iterations() * (len - offset) * sizeof(CType)); +} + +template +static void CaseWhenBenchContiguous(benchmark::State& state) { + using CType = typename Type::c_type; + auto type = TypeTraits::type_singleton(); + using ArrayType = typename TypeTraits::ArrayType; + + int64_t len = state.range(0); + int64_t offset = state.range(1); + + ASSERT_OK_AND_ASSIGN(auto trues, MakeArrayFromScalar(BooleanScalar(true), len / 3)); + ASSERT_OK_AND_ASSIGN(auto falses, MakeArrayFromScalar(BooleanScalar(false), len / 3)); + ASSERT_OK_AND_ASSIGN(auto nulls, MakeArrayOfNull(boolean(), len - 2 * (len / 3))); + ASSERT_OK_AND_ASSIGN(auto concat, Concatenate({trues, falses, nulls})); + auto cond1 = std::static_pointer_cast(concat); + + random::RandomArrayGenerator rand(/*seed=*/0); + auto cond2 = std::static_pointer_cast( + rand.ArrayOf(boolean(), len, /*null_probability=*/0.01)); + auto val1 = std::static_pointer_cast( + rand.ArrayOf(type, len, /*null_probability=*/0.01)); + auto val2 = std::static_pointer_cast( + rand.ArrayOf(type, len, /*null_probability=*/0.01)); + auto val3 = std::static_pointer_cast( + rand.ArrayOf(type, len, /*null_probability=*/0.01)); + ASSERT_OK_AND_ASSIGN( + auto cond, StructArray::Make({cond1, cond2}, std::vector{"a", "b"}, + nullptr, /*null_count=*/0)); + + for (auto _ : state) { + ABORT_NOT_OK(CaseWhen(cond->Slice(offset), {val1->Slice(offset), val2->Slice(offset), + val3->Slice(offset)})); + } + + state.SetBytesProcessed(state.iterations() * (len - offset) * sizeof(CType)); +} + +static void CaseWhenBench64(benchmark::State& state) { + return CaseWhenBench(state); +} + +static void CaseWhenBench64Contiguous(benchmark::State& state) { + return CaseWhenBenchContiguous(state); +} + BENCHMARK(IfElseBench32)->Args({elems, 0}); BENCHMARK(IfElseBench64)->Args({elems, 0}); @@ -109,5 +194,11 @@ BENCHMARK(IfElseBench64Contiguous)->Args({elems, 0}); BENCHMARK(IfElseBench32Contiguous)->Args({elems, 99}); BENCHMARK(IfElseBench64Contiguous)->Args({elems, 99}); +BENCHMARK(CaseWhenBench64)->Args({elems, 0}); +BENCHMARK(CaseWhenBench64)->Args({elems, 99}); + +BENCHMARK(CaseWhenBench64Contiguous)->Args({elems, 0}); +BENCHMARK(CaseWhenBench64Contiguous)->Args({elems, 99}); + } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc index 670a2d42a3a..cd2d04a13e0 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc @@ -15,12 +15,13 @@ // specific language governing permissions and limitations // under the License. -#include -#include -#include -#include -#include #include +#include "arrow/array.h" +#include "arrow/array/concatenate.h" +#include "arrow/compute/api_scalar.h" +#include "arrow/compute/kernels/test_util.h" +#include "arrow/compute/registry.h" +#include "arrow/testing/gtest_util.h" namespace arrow { namespace compute { @@ -45,15 +46,16 @@ class TestIfElseKernel : public ::testing::Test {}; template class TestIfElsePrimitive : public ::testing::Test {}; -using PrimitiveTypes = ::testing::Types; +using NumericBasedTypes = + ::testing::Types; -TYPED_TEST_SUITE(TestIfElsePrimitive, PrimitiveTypes); +TYPED_TEST_SUITE(TestIfElsePrimitive, NumericBasedTypes); TYPED_TEST(TestIfElsePrimitive, IfElseFixedSizeRand) { using ArrayType = typename TypeTraits::ArrayType; - auto type = TypeTraits::type_singleton(); + auto type = default_type_instance(); random::RandomArrayGenerator rand(/*seed=*/0); int64_t len = 1000; @@ -71,7 +73,7 @@ TYPED_TEST(TestIfElsePrimitive, IfElseFixedSizeRand) { auto right = std::static_pointer_cast( rand.ArrayOf(type, len, /*null_probability=*/0.01)); - typename TypeTraits::BuilderType builder; + typename TypeTraits::BuilderType builder(type, default_memory_pool()); for (int64_t i = 0; i < len; ++i) { if (!cond->IsValid(i) || (cond->Value(i) && !left->IsValid(i)) || @@ -155,7 +157,7 @@ void CheckWithDifferentShapes(const std::shared_ptr& cond, } TYPED_TEST(TestIfElsePrimitive, IfElseFixedSize) { - auto type = TypeTraits::type_singleton(); + auto type = default_type_instance(); CheckWithDifferentShapes(ArrayFromJSON(boolean(), "[true, true, true, false]"), ArrayFromJSON(type, "[1, 2, 3, 4]"), @@ -316,5 +318,360 @@ TEST_F(TestIfElseKernel, IfElseDispatchBest) { CheckDispatchBest(name, {null(), uint8(), int8()}, {boolean(), int16(), int16()}); } +template +class TestCaseWhenNumeric : public ::testing::Test {}; + +TYPED_TEST_SUITE(TestCaseWhenNumeric, NumericBasedTypes); + +Datum MakeStruct(const std::vector& conds) { + ProjectOptions options; + options.field_names.resize(conds.size()); + options.field_metadata.resize(conds.size()); + for (const auto& datum : conds) { + options.field_nullability.push_back(datum.null_count() > 0); + } + EXPECT_OK_AND_ASSIGN(auto result, CallFunction("project", conds, &options)); + return result; +} + +TYPED_TEST(TestCaseWhenNumeric, FixedSize) { + auto type = default_type_instance(); + auto cond_true = ScalarFromJSON(boolean(), "true"); + auto cond_false = ScalarFromJSON(boolean(), "false"); + auto cond_null = ScalarFromJSON(boolean(), "null"); + auto cond1 = ArrayFromJSON(boolean(), "[true, true, null, null]"); + auto cond2 = ArrayFromJSON(boolean(), "[true, false, true, null]"); + auto scalar_null = ScalarFromJSON(type, "null"); + auto scalar1 = ScalarFromJSON(type, "1"); + auto scalar2 = ScalarFromJSON(type, "2"); + auto values_null = ArrayFromJSON(type, "[null, null, null, null]"); + auto values1 = ArrayFromJSON(type, "[3, null, 5, 6]"); + auto values2 = ArrayFromJSON(type, "[7, 8, null, 10]"); + + CheckScalar("case_when", {MakeStruct({}), values1}, values1); + CheckScalar("case_when", {MakeStruct({}), values_null}, values_null); + + CheckScalar("case_when", {MakeStruct({cond_true}), values1}, values1); + CheckScalar("case_when", {MakeStruct({cond_false}), values1}, values_null); + CheckScalar("case_when", {MakeStruct({cond_null}), values1}, values_null); + CheckScalar("case_when", {MakeStruct({cond_true}), values1, values2}, values1); + CheckScalar("case_when", {MakeStruct({cond_false}), values1, values2}, values2); + CheckScalar("case_when", {MakeStruct({cond_null}), values1, values2}, values2); + + CheckScalar("case_when", {MakeStruct({cond_true, cond_true}), values1, values2}, + values1); + CheckScalar("case_when", {MakeStruct({cond_false, cond_false}), values1, values2}, + values_null); + CheckScalar("case_when", {MakeStruct({cond_true, cond_false}), values1, values2}, + values1); + CheckScalar("case_when", {MakeStruct({cond_false, cond_true}), values1, values2}, + values2); + CheckScalar("case_when", {MakeStruct({cond_null, cond_true}), values1, values2}, + values2); + CheckScalar("case_when", + {MakeStruct({cond_false, cond_false}), values1, values2, values2}, values2); + + CheckScalar("case_when", {MakeStruct({cond1, cond2}), scalar1, scalar2}, + ArrayFromJSON(type, "[1, 1, 2, null]")); + CheckScalar("case_when", {MakeStruct({cond1}), scalar_null}, values_null); + CheckScalar("case_when", {MakeStruct({cond1}), scalar_null, scalar1}, + ArrayFromJSON(type, "[null, null, 1, 1]")); + CheckScalar("case_when", {MakeStruct({cond1, cond2}), scalar1, scalar2, scalar1}, + ArrayFromJSON(type, "[1, 1, 2, 1]")); + + CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2}, + ArrayFromJSON(type, "[3, null, null, null]")); + CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2, values1}, + ArrayFromJSON(type, "[3, null, null, 6]")); + CheckScalar("case_when", {MakeStruct({cond1, cond2}), values_null, values2, values1}, + ArrayFromJSON(type, "[null, null, null, 6]")); + + CheckScalar( + "case_when", + {MakeStruct( + {ArrayFromJSON(boolean(), + "[true, true, true, false, false, false, null, null, null]"), + ArrayFromJSON(boolean(), + "[true, false, null, true, false, null, true, false, null]")}), + ArrayFromJSON(type, "[10, 11, 12, 13, 14, 15, 16, 17, 18]"), + ArrayFromJSON(type, "[20, 21, 22, 23, 24, 25, 26, 27, 28]")}, + ArrayFromJSON(type, "[10, 11, 12, 23, null, null, 26, null, null]")); + CheckScalar( + "case_when", + {MakeStruct( + {ArrayFromJSON(boolean(), + "[true, true, true, false, false, false, null, null, null]"), + ArrayFromJSON(boolean(), + "[true, false, null, true, false, null, true, false, null]")}), + ArrayFromJSON(type, "[10, 11, 12, 13, 14, 15, 16, 17, 18]"), + + ArrayFromJSON(type, "[20, 21, 22, 23, 24, 25, 26, 27, 28]"), + ArrayFromJSON(type, "[30, 31, 32, 33, 34, null, 36, 37, null]")}, + ArrayFromJSON(type, "[10, 11, 12, 23, 34, null, 26, 37, null]")); + + // Error cases + EXPECT_RAISES_WITH_MESSAGE_THAT( + Invalid, ::testing::HasSubstr("cond struct must not be null"), + CallFunction( + "case_when", + {Datum(std::make_shared(struct_({field("", boolean())}))), + Datum(scalar1)})); + EXPECT_RAISES_WITH_MESSAGE_THAT( + Invalid, ::testing::HasSubstr("cond struct must not have top-level nulls"), + CallFunction( + "case_when", + {Datum(*MakeArrayOfNull(struct_({field("", boolean())}), 4)), Datum(values1)})); +} + +TEST(TestCaseWhen, Null) { + auto cond_true = ScalarFromJSON(boolean(), "true"); + auto cond_false = ScalarFromJSON(boolean(), "false"); + auto cond_arr = ArrayFromJSON(boolean(), "[true, true, false, null]"); + auto scalar = ScalarFromJSON(null(), "null"); + auto array = ArrayFromJSON(null(), "[null, null, null, null]"); + CheckScalar("case_when", {MakeStruct({}), array}, array); + CheckScalar("case_when", {MakeStruct({cond_false}), array}, array); + CheckScalar("case_when", {MakeStruct({cond_true}), array, array}, array); + CheckScalar("case_when", {MakeStruct({cond_arr, cond_true}), array, array}, array); +} + +TEST(TestCaseWhen, Boolean) { + auto type = boolean(); + auto cond_true = ScalarFromJSON(boolean(), "true"); + auto cond_false = ScalarFromJSON(boolean(), "false"); + auto cond_null = ScalarFromJSON(boolean(), "null"); + auto cond1 = ArrayFromJSON(boolean(), "[true, true, null, null]"); + auto cond2 = ArrayFromJSON(boolean(), "[true, false, true, null]"); + auto scalar_null = ScalarFromJSON(type, "null"); + auto scalar1 = ScalarFromJSON(type, "true"); + auto scalar2 = ScalarFromJSON(type, "false"); + auto values_null = ArrayFromJSON(type, "[null, null, null, null]"); + auto values1 = ArrayFromJSON(type, "[true, null, true, true]"); + auto values2 = ArrayFromJSON(type, "[false, false, null, false]"); + + CheckScalar("case_when", {MakeStruct({}), values1}, values1); + CheckScalar("case_when", {MakeStruct({}), values_null}, values_null); + + CheckScalar("case_when", {MakeStruct({cond_true}), values1}, values1); + CheckScalar("case_when", {MakeStruct({cond_false}), values1}, values_null); + CheckScalar("case_when", {MakeStruct({cond_null}), values1}, values_null); + CheckScalar("case_when", {MakeStruct({cond_true}), values1, values2}, values1); + CheckScalar("case_when", {MakeStruct({cond_false}), values1, values2}, values2); + CheckScalar("case_when", {MakeStruct({cond_null}), values1, values2}, values2); + + CheckScalar("case_when", {MakeStruct({cond_true, cond_true}), values1, values2}, + values1); + CheckScalar("case_when", {MakeStruct({cond_false, cond_false}), values1, values2}, + values_null); + CheckScalar("case_when", {MakeStruct({cond_true, cond_false}), values1, values2}, + values1); + CheckScalar("case_when", {MakeStruct({cond_false, cond_true}), values1, values2}, + values2); + CheckScalar("case_when", {MakeStruct({cond_null, cond_true}), values1, values2}, + values2); + CheckScalar("case_when", + {MakeStruct({cond_false, cond_false}), values1, values2, values2}, values2); + + CheckScalar("case_when", {MakeStruct({cond1, cond2}), scalar1, scalar2}, + ArrayFromJSON(type, "[true, true, false, null]")); + CheckScalar("case_when", {MakeStruct({cond1}), scalar_null}, values_null); + CheckScalar("case_when", {MakeStruct({cond1}), scalar_null, scalar1}, + ArrayFromJSON(type, "[null, null, true, true]")); + CheckScalar("case_when", {MakeStruct({cond1, cond2}), scalar1, scalar2, scalar1}, + ArrayFromJSON(type, "[true, true, false, true]")); + + CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2}, + ArrayFromJSON(type, "[true, null, null, null]")); + CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2, values1}, + ArrayFromJSON(type, "[true, null, null, true]")); + CheckScalar("case_when", {MakeStruct({cond1, cond2}), values_null, values2, values1}, + ArrayFromJSON(type, "[null, null, null, true]")); +} + +TEST(TestCaseWhen, DayTimeInterval) { + auto type = day_time_interval(); + auto cond_true = ScalarFromJSON(boolean(), "true"); + auto cond_false = ScalarFromJSON(boolean(), "false"); + auto cond_null = ScalarFromJSON(boolean(), "null"); + auto cond1 = ArrayFromJSON(boolean(), "[true, true, null, null]"); + auto cond2 = ArrayFromJSON(boolean(), "[true, false, true, null]"); + auto scalar_null = ScalarFromJSON(type, "null"); + auto scalar1 = ScalarFromJSON(type, "[1, 1]"); + auto scalar2 = ScalarFromJSON(type, "[2, 2]"); + auto values_null = ArrayFromJSON(type, "[null, null, null, null]"); + auto values1 = ArrayFromJSON(type, "[[3, 3], null, [5, 5], [6, 6]]"); + auto values2 = ArrayFromJSON(type, "[[7, 7], [8, 8], null, [10, 10]]"); + + CheckScalar("case_when", {MakeStruct({}), values1}, values1); + CheckScalar("case_when", {MakeStruct({}), values_null}, values_null); + + CheckScalar("case_when", {MakeStruct({cond_true}), values1}, values1); + CheckScalar("case_when", {MakeStruct({cond_false}), values1}, values_null); + CheckScalar("case_when", {MakeStruct({cond_null}), values1}, values_null); + CheckScalar("case_when", {MakeStruct({cond_true}), values1, values2}, values1); + CheckScalar("case_when", {MakeStruct({cond_false}), values1, values2}, values2); + CheckScalar("case_when", {MakeStruct({cond_null}), values1, values2}, values2); + + CheckScalar("case_when", {MakeStruct({cond_true, cond_true}), values1, values2}, + values1); + CheckScalar("case_when", {MakeStruct({cond_false, cond_false}), values1, values2}, + values_null); + CheckScalar("case_when", {MakeStruct({cond_true, cond_false}), values1, values2}, + values1); + CheckScalar("case_when", {MakeStruct({cond_false, cond_true}), values1, values2}, + values2); + CheckScalar("case_when", {MakeStruct({cond_null, cond_true}), values1, values2}, + values2); + CheckScalar("case_when", + {MakeStruct({cond_false, cond_false}), values1, values2, values2}, values2); + + CheckScalar("case_when", {MakeStruct({cond1, cond2}), scalar1, scalar2}, + ArrayFromJSON(type, "[[1, 1], [1, 1], [2, 2], null]")); + CheckScalar("case_when", {MakeStruct({cond1}), scalar_null}, values_null); + CheckScalar("case_when", {MakeStruct({cond1}), scalar_null, scalar1}, + ArrayFromJSON(type, "[null, null, [1, 1], [1, 1]]")); + CheckScalar("case_when", {MakeStruct({cond1, cond2}), scalar1, scalar2, scalar1}, + ArrayFromJSON(type, "[[1, 1], [1, 1], [2, 2], [1, 1]]")); + + CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2}, + ArrayFromJSON(type, "[[3, 3], null, null, null]")); + CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2, values1}, + ArrayFromJSON(type, "[[3, 3], null, null, [6, 6]]")); + CheckScalar("case_when", {MakeStruct({cond1, cond2}), values_null, values2, values1}, + ArrayFromJSON(type, "[null, null, null, [6, 6]]")); +} + +TEST(TestCaseWhen, Decimal) { + for (const auto& type : + std::vector>{decimal128(3, 2), decimal256(3, 2)}) { + auto cond_true = ScalarFromJSON(boolean(), "true"); + auto cond_false = ScalarFromJSON(boolean(), "false"); + auto cond_null = ScalarFromJSON(boolean(), "null"); + auto cond1 = ArrayFromJSON(boolean(), "[true, true, null, null]"); + auto cond2 = ArrayFromJSON(boolean(), "[true, false, true, null]"); + auto scalar_null = ScalarFromJSON(type, "null"); + auto scalar1 = ScalarFromJSON(type, R"("1.23")"); + auto scalar2 = ScalarFromJSON(type, R"("2.34")"); + auto values_null = ArrayFromJSON(type, "[null, null, null, null]"); + auto values1 = ArrayFromJSON(type, R"(["3.45", null, "5.67", "6.78"])"); + auto values2 = ArrayFromJSON(type, R"(["7.89", "8.90", null, "1.01"])"); + + CheckScalar("case_when", {MakeStruct({}), values1}, values1); + CheckScalar("case_when", {MakeStruct({}), values_null}, values_null); + + CheckScalar("case_when", {MakeStruct({cond_true}), values1}, values1); + CheckScalar("case_when", {MakeStruct({cond_false}), values1}, values_null); + CheckScalar("case_when", {MakeStruct({cond_null}), values1}, values_null); + CheckScalar("case_when", {MakeStruct({cond_true}), values1, values2}, values1); + CheckScalar("case_when", {MakeStruct({cond_false}), values1, values2}, values2); + CheckScalar("case_when", {MakeStruct({cond_null}), values1, values2}, values2); + + CheckScalar("case_when", {MakeStruct({cond_true, cond_true}), values1, values2}, + values1); + CheckScalar("case_when", {MakeStruct({cond_false, cond_false}), values1, values2}, + values_null); + CheckScalar("case_when", {MakeStruct({cond_true, cond_false}), values1, values2}, + values1); + CheckScalar("case_when", {MakeStruct({cond_false, cond_true}), values1, values2}, + values2); + CheckScalar("case_when", {MakeStruct({cond_null, cond_true}), values1, values2}, + values2); + CheckScalar("case_when", + {MakeStruct({cond_false, cond_false}), values1, values2, values2}, + values2); + + CheckScalar("case_when", {MakeStruct({cond1, cond2}), scalar1, scalar2}, + ArrayFromJSON(type, R"(["1.23", "1.23", "2.34", null])")); + CheckScalar("case_when", {MakeStruct({cond1}), scalar_null}, values_null); + CheckScalar("case_when", {MakeStruct({cond1}), scalar_null, scalar1}, + ArrayFromJSON(type, R"([null, null, "1.23", "1.23"])")); + CheckScalar("case_when", {MakeStruct({cond1, cond2}), scalar1, scalar2, scalar1}, + ArrayFromJSON(type, R"(["1.23", "1.23", "2.34", "1.23"])")); + + CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2}, + ArrayFromJSON(type, R"(["3.45", null, null, null])")); + CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2, values1}, + ArrayFromJSON(type, R"(["3.45", null, null, "6.78"])")); + CheckScalar("case_when", {MakeStruct({cond1, cond2}), values_null, values2, values1}, + ArrayFromJSON(type, R"([null, null, null, "6.78"])")); + } +} + +TEST(TestCaseWhen, FixedSizeBinary) { + auto type = fixed_size_binary(3); + auto cond_true = ScalarFromJSON(boolean(), "true"); + auto cond_false = ScalarFromJSON(boolean(), "false"); + auto cond_null = ScalarFromJSON(boolean(), "null"); + auto cond1 = ArrayFromJSON(boolean(), "[true, true, null, null]"); + auto cond2 = ArrayFromJSON(boolean(), "[true, false, true, null]"); + auto scalar_null = ScalarFromJSON(type, "null"); + auto scalar1 = ScalarFromJSON(type, R"("abc")"); + auto scalar2 = ScalarFromJSON(type, R"("bcd")"); + auto values_null = ArrayFromJSON(type, "[null, null, null, null]"); + auto values1 = ArrayFromJSON(type, R"(["cde", null, "def", "efg"])"); + auto values2 = ArrayFromJSON(type, R"(["fgh", "ghi", null, "hij"])"); + + CheckScalar("case_when", {MakeStruct({}), values1}, values1); + CheckScalar("case_when", {MakeStruct({}), values_null}, values_null); + + CheckScalar("case_when", {MakeStruct({cond_true}), values1}, values1); + CheckScalar("case_when", {MakeStruct({cond_false}), values1}, values_null); + CheckScalar("case_when", {MakeStruct({cond_null}), values1}, values_null); + CheckScalar("case_when", {MakeStruct({cond_true}), values1, values2}, values1); + CheckScalar("case_when", {MakeStruct({cond_false}), values1, values2}, values2); + CheckScalar("case_when", {MakeStruct({cond_null}), values1, values2}, values2); + + CheckScalar("case_when", {MakeStruct({cond_true, cond_true}), values1, values2}, + values1); + CheckScalar("case_when", {MakeStruct({cond_false, cond_false}), values1, values2}, + values_null); + CheckScalar("case_when", {MakeStruct({cond_true, cond_false}), values1, values2}, + values1); + CheckScalar("case_when", {MakeStruct({cond_false, cond_true}), values1, values2}, + values2); + CheckScalar("case_when", {MakeStruct({cond_null, cond_true}), values1, values2}, + values2); + CheckScalar("case_when", + {MakeStruct({cond_false, cond_false}), values1, values2, values2}, values2); + + CheckScalar("case_when", {MakeStruct({cond1, cond2}), scalar1, scalar2}, + ArrayFromJSON(type, R"(["abc", "abc", "bcd", null])")); + CheckScalar("case_when", {MakeStruct({cond1}), scalar_null}, values_null); + CheckScalar("case_when", {MakeStruct({cond1}), scalar_null, scalar1}, + ArrayFromJSON(type, R"([null, null, "abc", "abc"])")); + CheckScalar("case_when", {MakeStruct({cond1, cond2}), scalar1, scalar2, scalar1}, + ArrayFromJSON(type, R"(["abc", "abc", "bcd", "abc"])")); + + CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2}, + ArrayFromJSON(type, R"(["cde", null, null, null])")); + CheckScalar("case_when", {MakeStruct({cond1, cond2}), values1, values2, values1}, + ArrayFromJSON(type, R"(["cde", null, null, "efg"])")); + CheckScalar("case_when", {MakeStruct({cond1, cond2}), values_null, values2, values1}, + ArrayFromJSON(type, R"([null, null, null, "efg"])")); +} + +TEST(TestCaseWhen, DispatchBest) { + CheckDispatchBest("case_when", {struct_({field("", boolean())}), int64(), int32()}, + {struct_({field("", boolean())}), int64(), int64()}); + + ASSERT_RAISES(Invalid, CallFunction("case_when", {})); + // Too many/too few conditions + ASSERT_RAISES( + Invalid, CallFunction("case_when", {MakeStruct({ArrayFromJSON(boolean(), "[]")})})); + ASSERT_RAISES(Invalid, + CallFunction("case_when", {MakeStruct({}), ArrayFromJSON(int64(), "[]"), + ArrayFromJSON(int64(), "[]")})); + // Conditions must be struct of boolean + ASSERT_RAISES(TypeError, + CallFunction("case_when", {MakeStruct({ArrayFromJSON(int64(), "[]")}), + ArrayFromJSON(int64(), "[]")})); + ASSERT_RAISES(TypeError, CallFunction("case_when", {ArrayFromJSON(boolean(), "[true]"), + ArrayFromJSON(int32(), "[0]")})); + // Values must have compatible types + ASSERT_RAISES(NotImplemented, + CallFunction("case_when", {MakeStruct({ArrayFromJSON(boolean(), "[]")}), + ArrayFromJSON(int64(), "[]"), + ArrayFromJSON(utf8(), "[]")})); +} } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/test_util.cc b/cpp/src/arrow/compute/kernels/test_util.cc index a1151717d8b..ce8d42e34c2 100644 --- a/cpp/src/arrow/compute/kernels/test_util.cc +++ b/cpp/src/arrow/compute/kernels/test_util.cc @@ -47,12 +47,10 @@ DatumVector GetDatums(const std::vector& inputs) { } void CheckScalarNonRecursive(const std::string& func_name, const DatumVector& inputs, - const std::shared_ptr& expected, - const FunctionOptions* options) { + const Datum& expected, const FunctionOptions* options) { ASSERT_OK_AND_ASSIGN(Datum out, CallFunction(func_name, inputs, options)); - std::shared_ptr actual = std::move(out).make_array(); - ValidateOutput(*actual); - AssertArraysEqual(*expected, *actual, /*verbose=*/true); + ValidateOutput(out); + AssertDatumsEqual(expected, out, /*verbose=*/true); } template @@ -103,35 +101,38 @@ void CheckScalar(std::string func_name, const ScalarVector& inputs, } } -void CheckScalar(std::string func_name, const DatumVector& inputs, - std::shared_ptr expected, const FunctionOptions* options) { - CheckScalarNonRecursive(func_name, inputs, expected, options); +void CheckScalar(std::string func_name, const DatumVector& inputs, Datum expected_datum, + const FunctionOptions* options) { + CheckScalarNonRecursive(func_name, inputs, expected_datum, options); + + if (expected_datum.is_scalar()) return; + ASSERT_TRUE(expected_datum.is_array()) + << "CheckScalar is only implemented for scalar/array expected values"; + auto expected = expected_datum.make_array(); // check for at least 1 array, and make sure the others are of equal length - std::shared_ptr array; + bool has_array = false; for (const auto& input : inputs) { if (input.is_array()) { - if (!array) { - array = input.make_array(); - } else { - ASSERT_EQ(input.array()->length, array->length()); - } + ASSERT_EQ(input.array()->length, expected->length()); + has_array = true; } } + ASSERT_TRUE(has_array) << "Must have at least 1 array input to have an array output"; // Check all the input scalars, if scalars are implemented if (std::none_of(inputs.begin(), inputs.end(), [](const Datum& datum) { return datum.type()->id() == Type::EXTENSION; })) { // Check all the input scalars - for (int64_t i = 0; i < array->length(); ++i) { + for (int64_t i = 0; i < expected->length(); ++i) { CheckScalar(func_name, GetScalars(inputs, i), *expected->GetScalar(i), options); } } // Since it's a scalar function, calling it on sliced inputs should // result in the sliced expected output. - const auto slice_length = array->length() / 3; + const auto slice_length = expected->length() / 3; if (slice_length > 0) { CheckScalarNonRecursive(func_name, SliceArrays(inputs, 0, slice_length), expected->Slice(0, slice_length), options); diff --git a/cpp/src/arrow/compute/kernels/test_util.h b/cpp/src/arrow/compute/kernels/test_util.h index c691a9f3be3..a3fb9308f58 100644 --- a/cpp/src/arrow/compute/kernels/test_util.h +++ b/cpp/src/arrow/compute/kernels/test_util.h @@ -95,8 +95,7 @@ void CheckScalar(std::string func_name, const ScalarVector& inputs, std::shared_ptr expected, const FunctionOptions* options = nullptr); -void CheckScalar(std::string func_name, const DatumVector& inputs, - std::shared_ptr expected, +void CheckScalar(std::string func_name, const DatumVector& inputs, Datum expected, const FunctionOptions* options = nullptr); void CheckScalarUnary(std::string func_name, std::shared_ptr in_ty, diff --git a/docs/source/cpp/compute.rst b/docs/source/cpp/compute.rst index 6ce808aba67..ed97faead74 100644 --- a/docs/source/cpp/compute.rst +++ b/docs/source/cpp/compute.rst @@ -860,33 +860,44 @@ Structural transforms .. XXX (this category is a bit of a hodgepodge) -+--------------------------+------------+------------------------------------------------+---------------------+---------+ -| Function name | Arity | Input types | Output type | Notes | -+==========================+============+================================================+=====================+=========+ -| fill_null | Binary | Boolean, Null, Numeric, Temporal, String-like | Input type | \(1) | -+--------------------------+------------+------------------------------------------------+---------------------+---------+ -| if_else | Ternary | Boolean, Null, Numeric, Temporal | Input type | \(2) | -+--------------------------+------------+------------------------------------------------+---------------------+---------+ -| is_finite | Unary | Float, Double | Boolean | \(3) | -+--------------------------+------------+------------------------------------------------+---------------------+---------+ -| is_inf | Unary | Float, Double | Boolean | \(4) | -+--------------------------+------------+------------------------------------------------+---------------------+---------+ -| is_nan | Unary | Float, Double | Boolean | \(5) | -+--------------------------+------------+------------------------------------------------+---------------------+---------+ -| is_null | Unary | Any | Boolean | \(6) | -+--------------------------+------------+------------------------------------------------+---------------------+---------+ -| is_valid | Unary | Any | Boolean | \(7) | -+--------------------------+------------+------------------------------------------------+---------------------+---------+ -| list_value_length | Unary | List-like | Int32 or Int64 | \(8) | -+--------------------------+------------+------------------------------------------------+---------------------+---------+ -| project | Varargs | Any | Struct | \(9) | -+--------------------------+------------+------------------------------------------------+---------------------+---------+ - -* \(1) First input must be an array, second input a scalar of the same type. ++--------------------------+------------+---------------------------------------------------+---------------------+---------+ +| Function name | Arity | Input types | Output type | Notes | ++==========================+============+===================================================+=====================+=========+ +| case_when | Varargs | Struct of Boolean (Arg 0), Any fixed-width (rest) | Input type | \(1) | ++--------------------------+------------+---------------------------------------------------+---------------------+---------+ +| fill_null | Binary | Boolean, Null, Numeric, Temporal, String-like | Input type | \(2) | ++--------------------------+------------+---------------------------------------------------+---------------------+---------+ +| if_else | Ternary | Boolean, Null, Numeric, Temporal | Input type | \(3) | ++--------------------------+------------+---------------------------------------------------+---------------------+---------+ +| is_finite | Unary | Float, Double | Boolean | \(4) | ++--------------------------+------------+---------------------------------------------------+---------------------+---------+ +| is_inf | Unary | Float, Double | Boolean | \(5) | ++--------------------------+------------+---------------------------------------------------+---------------------+---------+ +| is_nan | Unary | Float, Double | Boolean | \(6) | ++--------------------------+------------+---------------------------------------------------+---------------------+---------+ +| is_null | Unary | Any | Boolean | \(7) | ++--------------------------+------------+---------------------------------------------------+---------------------+---------+ +| is_valid | Unary | Any | Boolean | \(8) | ++--------------------------+------------+---------------------------------------------------+---------------------+---------+ +| list_value_length | Unary | List-like | Int32 or Int64 | \(9) | ++--------------------------+------------+---------------------------------------------------+---------------------+---------+ +| project | Varargs | Any | Struct | \(10) | ++--------------------------+------------+---------------------------------------------------+---------------------+---------+ + +* \(1) This function acts like a SQL 'case when' statement or switch-case. The + input is a "condition" value, which is a struct of Booleans, followed by the + values for each "branch". There must be either exactly one value argument for + each child of the condition struct, or one more value argument than children + (in which case we have an 'else' or 'default' value). The output is of the + same type as the value inputs; each row will be the corresponding value from + the first value datum for which the corresponding Boolean is true, or the + corresponding value from the 'default' input, or null otherwise. + +* \(2) First input must be an array, second input a scalar of the same type. Output is an array of the same type as the inputs, and with the same values as the first input, except for nulls replaced with the second input value. -* \(2) First input must be a Boolean scalar or array. Second and third inputs +* \(3) First input must be a Boolean scalar or array. Second and third inputs could be scalars or arrays and must be of the same type. Output is an array (or scalar if all inputs are scalar) of the same type as the second/ third input. If the nulls present on the first input, they will be promoted to the @@ -894,21 +905,21 @@ Structural transforms Also see: :ref:`replace_with_mask `. -* \(3) Output is true iff the corresponding input element is finite (not Infinity, +* \(4) Output is true iff the corresponding input element is finite (not Infinity, -Infinity, or NaN). -* \(4) Output is true iff the corresponding input element is Infinity/-Infinity. +* \(5) Output is true iff the corresponding input element is Infinity/-Infinity. -* \(5) Output is true iff the corresponding input element is NaN. +* \(6) Output is true iff the corresponding input element is NaN. -* \(6) Output is true iff the corresponding input element is null. +* \(7) Output is true iff the corresponding input element is null. -* \(7) Output is true iff the corresponding input element is non-null. +* \(8) Output is true iff the corresponding input element is non-null. -* \(8) Each output element is the length of the corresponding input element +* \(9) Each output element is the length of the corresponding input element (null if input is null). Output type is Int32 for List, Int64 for LargeList. -* \(9) The output struct's field types are the types of its arguments. The +* \(10) The output struct's field types are the types of its arguments. The field names are specified using an instance of :struct:`ProjectOptions`. The output shape will be scalar if all inputs are scalar, otherwise any scalars will be broadcast to arrays. diff --git a/docs/source/python/api/compute.rst b/docs/source/python/api/compute.rst index 09c67598193..c12f2f91b26 100644 --- a/docs/source/python/api/compute.rst +++ b/docs/source/python/api/compute.rst @@ -335,6 +335,7 @@ Structural Transforms :toctree: ../generated/ binary_length + case_when fill_null if_else is_finite