diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt
index 634d202623f..88a92d8c2c9 100644
--- a/cpp/src/arrow/CMakeLists.txt
+++ b/cpp/src/arrow/CMakeLists.txt
@@ -401,6 +401,7 @@ if(ARROW_COMPUTE)
        compute/kernels/util_internal.cc
        compute/kernels/vector_hash.cc
        compute/kernels/vector_nested.cc
+       compute/kernels/vector_replace.cc
        compute/kernels/vector_selection.cc
        compute/kernels/vector_sort.cc
        compute/exec/key_hash.cc
diff --git a/cpp/src/arrow/array/array_binary.h b/cpp/src/arrow/array/array_binary.h
index db3c640b9a4..f8e8c4f8a44 100644
--- a/cpp/src/arrow/array/array_binary.h
+++ b/cpp/src/arrow/array/array_binary.h
@@ -71,6 +71,13 @@ class BaseBinaryArray : public FlatArray {
                              raw_value_offsets_[i + 1] - pos);
   }
 
+  /// \brief Get binary value as a string_view
+  /// Provided for consistency with other arrays.
+  ///
+  /// \param i the value index
+  /// \return the view over the selected value
+  util::string_view Value(int64_t i) const { return GetView(i); }
+
   /// \brief Get binary value as a std::string
   ///
   /// \param i the value index
diff --git a/cpp/src/arrow/array/util.cc b/cpp/src/arrow/array/util.cc
index 688cb20cb9a..ed26ecff4e0 100644
--- a/cpp/src/arrow/array/util.cc
+++ b/cpp/src/arrow/array/util.cc
@@ -528,6 +528,11 @@ class RepeatedArrayFactory {
     return FinishFixedWidth(value.data(), value.size());
   }
 
+  Status Visit(const Decimal256Type&) {
+    auto value = checked_cast<const Decimal256Scalar&>(scalar_).value.ToBytes();
+    return FinishFixedWidth(value.data(), value.size());
+  }
+
   template <typename T>
   enable_if_base_binary<T, Status> Visit(const T&) {
     std::shared_ptr<Buffer> value =
diff --git a/cpp/src/arrow/compute/api_vector.cc b/cpp/src/arrow/compute/api_vector.cc
index 9c1ef8533b4..a68969b2ee5 100644
--- a/cpp/src/arrow/compute/api_vector.cc
+++ b/cpp/src/arrow/compute/api_vector.cc
@@ -162,6 +162,11 @@ Result<std::shared_ptr<Array>> NthToIndices(const Array& values, int64_t n,
   return result.make_array();
 }
 
+Result<Datum> ReplaceWithMask(const Datum& values, const Datum& mask,
+                              const Datum& replacements, ExecContext* ctx) {
+  return CallFunction("replace_with_mask", {values, mask, replacements}, ctx);
+}
+
 Result<std::shared_ptr<Array>> SortIndices(const Array& values, SortOrder order,
                                            ExecContext* ctx) {
   ArraySortOptions options(order);
diff --git a/cpp/src/arrow/compute/api_vector.h b/cpp/src/arrow/compute/api_vector.h
index 6021492320e..9d8d4271db8 100644
--- a/cpp/src/arrow/compute/api_vector.h
+++ b/cpp/src/arrow/compute/api_vector.h
@@ -171,6 +171,24 @@ Result<std::shared_ptr<ArrayData>> GetTakeIndices(
 
 }  // namespace internal
 
+/// \brief ReplaceWithMask replaces each value in the array corresponding
+/// to a true value in the mask with the next element from `replacements`.
+///
+/// \param[in] values Array input to replace
+/// \param[in] mask Array or Scalar of Boolean mask values. Where the mask
+/// is null, the output is null.
+/// \param[in] replacements Array or Scalar of replacement values. An Array
+/// must hold at least as many values as there are true values in the mask.
+/// \param[in] ctx the function execution context, optional
+///
+/// \return the resulting datum
+///
+/// \since 5.0.0
+/// \note API not yet finalized
+ARROW_EXPORT
+Result<Datum> ReplaceWithMask(const Datum& values, const Datum& mask,
+                              const Datum& replacements, ExecContext* ctx = NULLPTR);
+
 /// \brief Take from an array of values at indices in another array
 ///
 /// The output array will be of the same type as the input values
diff --git a/cpp/src/arrow/compute/kernels/CMakeLists.txt b/cpp/src/arrow/compute/kernels/CMakeLists.txt
index 3362d91cbe8..474ce1418fd 100644
--- a/cpp/src/arrow/compute/kernels/CMakeLists.txt
+++ b/cpp/src/arrow/compute/kernels/CMakeLists.txt
@@ -48,6 +48,7 @@ add_arrow_compute_test(vector_test
                        SOURCES
                        vector_hash_test.cc
                        vector_nested_test.cc
+                       vector_replace_test.cc
                        vector_selection_test.cc
                        vector_sort_test.cc
                        test_util.cc)
@@ -55,6 +56,7 @@ add_arrow_compute_test(vector_test
 add_arrow_benchmark(vector_hash_benchmark PREFIX "arrow-compute")
 add_arrow_benchmark(vector_sort_benchmark PREFIX "arrow-compute")
 add_arrow_benchmark(vector_partition_benchmark PREFIX "arrow-compute")
+add_arrow_benchmark(vector_replace_benchmark PREFIX "arrow-compute")
 add_arrow_benchmark(vector_selection_benchmark PREFIX "arrow-compute")
 
 # ----------------------------------------------------------------------
diff --git a/cpp/src/arrow/compute/kernels/codegen_internal.h b/cpp/src/arrow/compute/kernels/codegen_internal.h
index 33b7006491a..12e80423f7f 100644
--- a/cpp/src/arrow/compute/kernels/codegen_internal.h
+++ b/cpp/src/arrow/compute/kernels/codegen_internal.h
@@ -1240,6 +1240,7 @@ ArrayKernelExec GenerateTypeAgnosticPrimitive(detail::GetTypeId get_id) {
     case Type::FLOAT:
     case Type::DATE32:
     case Type::TIME32:
+    case Type::INTERVAL_MONTHS:
       return Generator<UInt32Type, Args...>::Exec;
     case Type::UINT64:
     case Type::INT64:
@@ -1248,6 +1249,7 @@ ArrayKernelExec GenerateTypeAgnosticPrimitive(detail::GetTypeId get_id) {
     case Type::TIMESTAMP:
     case Type::TIME64:
     case Type::DURATION:
+    case Type::INTERVAL_DAY_TIME:
       return Generator<UInt64Type, Args...>::Exec;
     default:
       DCHECK(false);
diff --git a/cpp/src/arrow/compute/kernels/test_util.h b/cpp/src/arrow/compute/kernels/test_util.h
index f4854087b51..c691a9f3be3 100644
--- a/cpp/src/arrow/compute/kernels/test_util.h
+++ b/cpp/src/arrow/compute/kernels/test_util.h
@@ -172,5 +172,27 @@ void CheckDispatchBest(std::string func_name, std::vector<ValueDescr> descrs,
 // Check that function fails to produce a Kernel for the set of ValueDescrs.
 void CheckDispatchFails(std::string func_name, std::vector<ValueDescr> descrs);
 
+// Helper to get a default instance of a type, including parameterized types
+template <typename T>
+enable_if_parameter_free<T, std::shared_ptr<DataType>> default_type_instance() {
+  return TypeTraits<T>::type_singleton();
+}
+template <typename T>
+enable_if_time<T, std::shared_ptr<DataType>> default_type_instance() {
+  // Time32 requires second/milli, Time64 requires nano/micro
+  if (bit_width(T::type_id) == 32) {
+    return std::make_shared<T>(TimeUnit::type::SECOND);
+  }
+  return std::make_shared<T>(TimeUnit::type::NANO);
+}
+template <typename T>
+enable_if_timestamp<T, std::shared_ptr<DataType>> default_type_instance() {
+  return std::make_shared<T>(TimeUnit::type::SECOND);
+}
+template <typename T>
+enable_if_decimal<T, std::shared_ptr<DataType>> default_type_instance() {
+  return std::make_shared<T>(5, 2);
+}
+
 }  // namespace compute
 }  // namespace arrow
diff --git a/cpp/src/arrow/compute/kernels/vector_replace.cc b/cpp/src/arrow/compute/kernels/vector_replace.cc
new file mode 100644
index 00000000000..644aec2a4e9
--- /dev/null
+++ b/cpp/src/arrow/compute/kernels/vector_replace.cc
@@ -0,0 +1,540 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/compute/api_scalar.h"
+#include "arrow/compute/kernels/common.h"
+#include "arrow/util/bitmap_ops.h"
+
+namespace arrow {
+namespace compute {
+namespace internal {
+
+namespace {
+
+// Build the error returned when fewer replacement values are supplied than needed.
+Status ReplacementArrayTooShort(int64_t expected, int64_t actual) {
+  return Status::Invalid("Replacement array must be of appropriate length (expected ",
+                         expected, " items but got ", actual, " items)");
+}
+
+// Check that an array of replacements holds at least one value for every mask slot
+// that is both valid and true. Scalar replacements never run out, so they pass.
+Status CheckReplacementsLength(const ArrayData& mask, const Datum& replacements) {
+  if (!replacements.is_array()) {
+    return Status::OK();
+  }
+  BooleanArray mask_arr(mask.length, mask.buffers[1], mask.buffers[0], mask.null_count,
+                        mask.offset);
+  const int64_t count = mask_arr.true_count();
+  const int64_t replacements_length = replacements.array()->length;
+  if (count > replacements_length) {
+    return ReplacementArrayTooShort(count, replacements_length);
+  }
+  return Status::OK();
+}
+
+// Helper to implement replace_with kernel with scalar mask for fixed-width types,
+// using callbacks to handle both bool and byte-sized types
+//
+// A null mask produces an all-null output, a true mask copies the replacements
+// (array or scalar) and a false mask copies the input array unchanged.
+template <typename Functor>
+Status ReplaceWithScalarMask(KernelContext* ctx, const ArrayData& array,
+                             const BooleanScalar& mask, const Datum& replacements,
+                             ArrayData* output) {
+  Datum source = array;
+  if (!mask.is_valid) {
+    // Output = null
+    source = MakeNullScalar(output->type);
+  } else if (mask.value) {
+    // Output = replacement
+    source = replacements;
+  }
+  uint8_t* out_bitmap = output->buffers[0]->mutable_data();
+  uint8_t* out_values = output->buffers[1]->mutable_data();
+  const int64_t out_offset = output->offset;
+  if (source.is_array()) {
+    const ArrayData& in_data = *source.array();
+    if (in_data.length < array.length) {
+      return ReplacementArrayTooShort(array.length, in_data.length);
+    }
+    Functor::CopyData(*array.type, out_values, out_offset, in_data, /*in_offset=*/0,
+                      array.length);
+    if (in_data.MayHaveNulls()) {
+      arrow::internal::CopyBitmap(in_data.buffers[0]->data(), in_data.offset,
+                                  array.length, out_bitmap, out_offset);
+    } else {
+      BitUtil::SetBitsTo(out_bitmap, out_offset, array.length, true);
+    }
+  } else {
+    const Scalar& in_data = *source.scalar();
+    Functor::CopyData(*array.type, out_values, out_offset, in_data, /*in_offset=*/0,
+                      array.length);
+    BitUtil::SetBitsTo(out_bitmap, out_offset, array.length, in_data.is_valid);
+  }
+  return Status::OK();
+}
+
+// Validity-bitmap callbacks used when the replacements are an array:
+// validity bits are read from the replacement array's bitmap.
+struct CopyArrayBitmap {
+  const uint8_t* in_bitmap;
+  int64_t in_offset;
+
+  void CopyBitmap(uint8_t* out_bitmap, int64_t out_offset, int64_t offset,
+                  int64_t length) const {
+    arrow::internal::CopyBitmap(in_bitmap, in_offset + offset, length, out_bitmap,
+                                out_offset);
+  }
+
+  void SetBit(uint8_t* out_bitmap, int64_t out_offset, int64_t offset) const {
+    BitUtil::SetBitTo(out_bitmap, out_offset,
+                      BitUtil::GetBit(in_bitmap, in_offset + offset));
+  }
+};
+
+// Validity-bitmap callbacks used when the replacement is a scalar:
+// every written validity bit takes the scalar's validity.
+struct CopyScalarBitmap {
+  const bool is_valid;
+
+  void CopyBitmap(uint8_t* out_bitmap, int64_t out_offset, int64_t offset,
+                  int64_t length) const {
+    BitUtil::SetBitsTo(out_bitmap, out_offset, length, is_valid);
+  }
+
+  void SetBit(uint8_t* out_bitmap, int64_t out_offset, int64_t offset) const {
+    BitUtil::SetBitTo(out_bitmap, out_offset, is_valid);
+  }
+};
+
+// Helper to implement replace_with kernel with array mask for fixed-width types,
+// using callbacks to handle both bool and byte-sized types and to handle
+// scalar and array replacements
+template <typename Functor, typename Data, typename CopyBitmap>
+void ReplaceWithArrayMaskImpl(const ArrayData& array, const ArrayData& mask,
+                              const Data& replacements, bool replacements_bitmap,
+                              const CopyBitmap& copy_bitmap, const uint8_t* mask_bitmap,
+                              const uint8_t* mask_values, uint8_t* out_bitmap,
+                              uint8_t* out_values, const int64_t out_offset) {
+  // Start from a copy of the input values, written at the output's offset like
+  // every other write below; selected slots are then overwritten
+  Functor::CopyData(*array.type, out_values, out_offset, array, /*in_offset=*/0,
+                    array.length);
+  arrow::internal::OptionalBinaryBitBlockCounter counter(
+      mask_values, mask.offset, mask_bitmap, mask.offset, mask.length);
+  int64_t write_offset = 0;
+  int64_t replacements_offset = 0;
+  while (write_offset < array.length) {
+    BitBlockCount block = counter.NextAndBlock();
+    if (block.AllSet()) {
+      // Copy from replacement array
+      Functor::CopyData(*array.type, out_values, out_offset + write_offset, replacements,
+                        replacements_offset, block.length);
+      if (replacements_bitmap) {
+        copy_bitmap.CopyBitmap(out_bitmap, out_offset + write_offset, replacements_offset,
+                               block.length);
+      } else if (out_bitmap) {
+        BitUtil::SetBitsTo(out_bitmap, out_offset + write_offset, block.length, true);
+      }
+      replacements_offset += block.length;
+    } else if (block.popcount) {
+      for (int64_t i = 0; i < block.length; ++i) {
+        if (BitUtil::GetBit(mask_values, write_offset + mask.offset + i) &&
+            (!mask_bitmap ||
+             BitUtil::GetBit(mask_bitmap, write_offset + mask.offset + i))) {
+          Functor::CopyData(*array.type, out_values, out_offset + write_offset + i,
+                            replacements, replacements_offset, /*length=*/1);
+          if (replacements_bitmap) {
+            copy_bitmap.SetBit(out_bitmap, out_offset + write_offset + i,
+                               replacements_offset);
+          } else if (out_bitmap) {
+            // Replacements are all valid: the slot must become valid even if the
+            // input value it overwrites was null
+            BitUtil::SetBitTo(out_bitmap, out_offset + write_offset + i, true);
+          }
+          replacements_offset++;
+        }
+      }
+    }
+    write_offset += block.length;
+  }
+}
+
+// Fixed-width implementation for an array mask: validates the number of
+// replacements, prepares the output validity bitmap, replaces the selected values
+// and finally nulls out every slot where the mask itself is null.
+template <typename Functor>
+Status ReplaceWithArrayMask(KernelContext* ctx, const ArrayData& array,
+                            const ArrayData& mask, const Datum& replacements,
+                            ArrayData* output) {
+  const int64_t out_offset = output->offset;
+  uint8_t* out_bitmap = nullptr;
+  uint8_t* out_values = output->buffers[1]->mutable_data();
+  const uint8_t* mask_bitmap = mask.MayHaveNulls() ? mask.buffers[0]->data() : nullptr;
+  const uint8_t* mask_values = mask.buffers[1]->data();
+  const bool replacements_bitmap = replacements.is_array()
+                                       ? replacements.array()->MayHaveNulls()
+                                       : !replacements.scalar()->is_valid;
+  // Check that we have enough replacement values
+  RETURN_NOT_OK(CheckReplacementsLength(mask, replacements));
+  if (array.MayHaveNulls() || mask.MayHaveNulls() || replacements_bitmap) {
+    out_bitmap = output->buffers[0]->mutable_data();
+    output->null_count = -1;
+    if (array.MayHaveNulls()) {
+      // Copy array's bitmap
+      arrow::internal::CopyBitmap(array.buffers[0]->data(), array.offset, array.length,
+                                  out_bitmap, out_offset);
+    } else {
+      // Array has no bitmap but mask/replacements do, generate an all-valid bitmap
+      BitUtil::SetBitsTo(out_bitmap, out_offset, array.length, true);
+    }
+  } else {
+    BitUtil::SetBitsTo(output->buffers[0]->mutable_data(), out_offset, array.length,
+                       true);
+    output->null_count = 0;
+  }
+
+  if (replacements.is_array()) {
+    const ArrayData& array_repl = *replacements.array();
+    ReplaceWithArrayMaskImpl<Functor>(
+        array, mask, array_repl, replacements_bitmap,
+        CopyArrayBitmap{replacements_bitmap ? array_repl.buffers[0]->data() : nullptr,
+                        array_repl.offset},
+        mask_bitmap, mask_values, out_bitmap, out_values, out_offset);
+  } else {
+    const Scalar& scalar_repl = *replacements.scalar();
+    ReplaceWithArrayMaskImpl<Functor>(array, mask, scalar_repl, replacements_bitmap,
+                                      CopyScalarBitmap{scalar_repl.is_valid}, mask_bitmap,
+                                      mask_values, out_bitmap, out_values, out_offset);
+  }
+
+  if (mask.MayHaveNulls()) {
+    // A null mask slot yields a null output slot
+    arrow::internal::BitmapAnd(out_bitmap, out_offset, mask.buffers[0]->data(),
+                               mask.offset, array.length, out_offset, out_bitmap);
+  }
+  return Status::OK();
+}
+
+template <typename Type, typename Enable = void>
+struct ReplaceWithMask {};
+
+// Numeric types: values are copied as raw bytes of sizeof(T)
+template <typename Type>
+struct ReplaceWithMask<Type, enable_if_number<Type>> {
+  using T = typename TypeTraits<Type>::CType;
+
+  static void CopyData(const DataType&, uint8_t* out, const int64_t out_offset,
+                       const ArrayData& in, const int64_t in_offset,
+                       const int64_t length) {
+    const auto in_arr = in.GetValues<uint8_t>(1, (in_offset + in.offset) * sizeof(T));
+    std::memcpy(out + (out_offset * sizeof(T)), in_arr, length * sizeof(T));
+  }
+
+  static void CopyData(const DataType&, uint8_t* out, const int64_t out_offset,
+                       const Scalar& in, const int64_t in_offset, const int64_t length) {
+    T* begin = reinterpret_cast<T*>(out + (out_offset * sizeof(T)));
+    T* end = begin + length;
+    std::fill(begin, end, UnboxScalar<Type>::Unbox(in));
+  }
+
+  static Status ExecScalarMask(KernelContext* ctx, const ArrayData& array,
+                               const BooleanScalar& mask, const Datum& replacements,
+                               ArrayData* output) {
+    return ReplaceWithScalarMask<ReplaceWithMask<Type>>(ctx, array, mask, replacements,
+                                                        output);
+  }
+
+  static Status ExecArrayMask(KernelContext* ctx, const ArrayData& array,
+                              const ArrayData& mask, const Datum& replacements,
+                              ArrayData* output) {
+    return ReplaceWithArrayMask<ReplaceWithMask<Type>>(ctx, array, mask, replacements,
+                                                       output);
+  }
+};
+
+// Boolean type: values are bit-packed, so offsets are expressed in bits
+template <typename Type>
+struct ReplaceWithMask<Type, enable_if_boolean<Type>> {
+  static void CopyData(const DataType&, uint8_t* out, const int64_t out_offset,
+                       const ArrayData& in, const int64_t in_offset,
+                       const int64_t length) {
+    const auto in_arr = in.GetValues<uint8_t>(1, /*absolute_offset=*/0);
+    arrow::internal::CopyBitmap(in_arr, in_offset + in.offset, length, out, out_offset);
+  }
+  static void CopyData(const DataType&, uint8_t* out, const int64_t out_offset,
+                       const Scalar& in, const int64_t in_offset, const int64_t length) {
+    // Fill with the scalar's value; its validity is written separately into the
+    // validity bitmap. A null scalar has no meaningful value, so write false
+    BitUtil::SetBitsTo(
+        out, out_offset, length,
+        in.is_valid ? checked_cast<const BooleanScalar&>(in).value : false);
+  }
+
+  static Status ExecScalarMask(KernelContext* ctx, const ArrayData& array,
+                               const BooleanScalar& mask, const Datum& replacements,
+                               ArrayData* output) {
+    return ReplaceWithScalarMask<ReplaceWithMask<Type>>(ctx, array, mask, replacements,
+                                                        output);
+  }
+  static Status ExecArrayMask(KernelContext* ctx, const ArrayData& array,
+                              const ArrayData& mask, const Datum& replacements,
+                              ArrayData* output) {
+    return ReplaceWithArrayMask<ReplaceWithMask<Type>>(ctx, array, mask, replacements,
+                                                       output);
+  }
+};
+
+// Fixed-size binary: the value width is taken from the runtime type
+template <typename Type>
+struct ReplaceWithMask<Type, enable_if_same<Type, FixedSizeBinaryType>> {
+  static void CopyData(const DataType& ty, uint8_t* out, const int64_t out_offset,
+                       const ArrayData& in, const int64_t in_offset,
+                       const int64_t length) {
+    const int32_t width = checked_cast<const FixedSizeBinaryType&>(ty).byte_width();
+    uint8_t* begin = out + (out_offset * width);
+    const auto in_arr = in.GetValues<uint8_t>(1, (in_offset + in.offset) * width);
+    std::memcpy(begin, in_arr, length * width);
+  }
+  static void CopyData(const DataType& ty, uint8_t* out, const int64_t out_offset,
+                       const Scalar& in, const int64_t in_offset, const int64_t length) {
+    const int32_t width = checked_cast<const FixedSizeBinaryType&>(ty).byte_width();
+    uint8_t* begin = out + (out_offset * width);
+    const auto& scalar = checked_cast<const FixedSizeBinaryScalar&>(in);
+    // Null scalar may have null value buffer
+    if (!scalar.value) return;
+    const Buffer& buffer = *scalar.value;
+    const uint8_t* value = buffer.data();
+    DCHECK_GE(buffer.size(), width);
+    // 64-bit index: length is an int64_t
+    for (int64_t i = 0; i < length; i++) {
+      std::memcpy(begin, value, width);
+      begin += width;
+    }
+  }
+
+  static Status ExecScalarMask(KernelContext* ctx, const ArrayData& array,
+                               const BooleanScalar& mask, const Datum& replacements,
+                               ArrayData* output) {
+    return ReplaceWithScalarMask<ReplaceWithMask<Type>>(ctx, array, mask, replacements,
+                                                        output);
+  }
+
+  static Status ExecArrayMask(KernelContext* ctx, const ArrayData& array,
+                              const ArrayData& mask, const Datum& replacements,
+                              ArrayData* output) {
+    return ReplaceWithArrayMask<ReplaceWithMask<Type>>(ctx, array, mask, replacements,
+                                                       output);
+  }
+};
+
+// Decimal types: fixed-width values whose scalar is serialized with ToBytes()
+template <typename Type>
+struct ReplaceWithMask<Type, enable_if_decimal<Type>> {
+  using ScalarType = typename TypeTraits<Type>::ScalarType;
+  static void CopyData(const DataType& ty, uint8_t* out, const int64_t out_offset,
+                       const ArrayData& in, const int64_t in_offset,
+                       const int64_t length) {
+    const int32_t width = checked_cast<const FixedSizeBinaryType&>(ty).byte_width();
+    uint8_t* begin = out + (out_offset * width);
+    const auto in_arr = in.GetValues<uint8_t>(1, (in_offset + in.offset) * width);
+    std::memcpy(begin, in_arr, length * width);
+  }
+  static void CopyData(const DataType& ty, uint8_t* out, const int64_t out_offset,
+                       const Scalar& in, const int64_t in_offset, const int64_t length) {
+    const int32_t width = checked_cast<const FixedSizeBinaryType&>(ty).byte_width();
+    uint8_t* begin = out + (out_offset * width);
+    const auto& scalar = checked_cast<const ScalarType&>(in);
+    const auto value = scalar.value.ToBytes();
+    // 64-bit index: length is an int64_t
+    for (int64_t i = 0; i < length; i++) {
+      std::memcpy(begin, value.data(), width);
+      begin += width;
+    }
+  }
+
+  static Status ExecScalarMask(KernelContext* ctx, const ArrayData& array,
+                               const BooleanScalar& mask, const Datum& replacements,
+                               ArrayData* output) {
+    return ReplaceWithScalarMask<ReplaceWithMask<Type>>(ctx, array, mask, replacements,
+                                                        output);
+  }
+
+  static Status ExecArrayMask(KernelContext* ctx, const ArrayData& array,
+                              const ArrayData& mask, const Datum& replacements,
+                              ArrayData* output) {
+    return ReplaceWithArrayMask<ReplaceWithMask<Type>>(ctx, array, mask, replacements,
+                                                       output);
+  }
+};
+
+// Null type: the output is the input array whatever the mask and replacements are
+template <typename Type>
+struct ReplaceWithMask<Type, enable_if_null<Type>> {
+  static Status ExecScalarMask(KernelContext* ctx, const ArrayData& array,
+                               const BooleanScalar& mask, const Datum& replacements,
+                               ArrayData* output) {
+    *output = array;
+    return Status::OK();
+  }
+  static Status ExecArrayMask(KernelContext* ctx, const ArrayData& array,
+                              const ArrayData& mask, const Datum& replacements,
+                              ArrayData* output) {
+    *output = array;
+    return Status::OK();
+  }
+};
+
+// Variable-width binary types: the output is rebuilt value by value with a builder
+template <typename Type>
+struct ReplaceWithMask<Type, enable_if_base_binary<Type>> {
+  using offset_type = typename Type::offset_type;
+  using BuilderType = typename TypeTraits<Type>::BuilderType;
+
+  static Status ExecScalarMask(KernelContext* ctx, const ArrayData& array,
+                               const BooleanScalar& mask, const Datum& replacements,
+                               ArrayData* output) {
+    if (!mask.is_valid) {
+      // Output = null
+      ARROW_ASSIGN_OR_RAISE(
+          auto replacement_array,
+          MakeArrayOfNull(array.type, array.length, ctx->memory_pool()));
+      *output = *replacement_array->data();
+    } else if (mask.value) {
+      // Output = replacement
+      if (replacements.is_scalar()) {
+        ARROW_ASSIGN_OR_RAISE(auto replacement_array,
+                              MakeArrayFromScalar(*replacements.scalar(), array.length,
+                                                  ctx->memory_pool()));
+        *output = *replacement_array->data();
+      } else {
+        const ArrayData& replacement_array = *replacements.array();
+        if (replacement_array.length < array.length) {
+          return ReplacementArrayTooShort(array.length, replacement_array.length);
+        }
+        *output = replacement_array;
+        output->length = array.length;
+      }
+    } else {
+      // Output = input
+      *output = array;
+    }
+    return Status::OK();
+  }
+  static Status ExecArrayMask(KernelContext* ctx, const ArrayData& array,
+                              const ArrayData& mask, const Datum& replacements,
+                              ArrayData* output) {
+    // Check that we have enough replacement values before indexing into them
+    RETURN_NOT_OK(CheckReplacementsLength(mask, replacements));
+    BuilderType builder(array.type, ctx->memory_pool());
+    RETURN_NOT_OK(builder.Reserve(array.length));
+    RETURN_NOT_OK(builder.ReserveData(array.buffers[2]->size()));
+    int64_t source_offset = 0;
+    int64_t replacements_offset = 0;
+    RETURN_NOT_OK(VisitArrayDataInline<BooleanType>(
+        mask,
+        [&](bool replace) {
+          if (replace && replacements.is_scalar()) {
+            const Scalar& scalar = *replacements.scalar();
+            if (scalar.is_valid) {
+              RETURN_NOT_OK(builder.Append(UnboxScalar<Type>::Unbox(scalar)));
+            } else {
+              RETURN_NOT_OK(builder.AppendNull());
+            }
+          } else {
+            // Take the next unused replacement, or the input value at this position
+            const ArrayData& source = replace ? *replacements.array() : array;
+            const int64_t offset = replace ? replacements_offset++ : source_offset;
+            if (!source.MayHaveNulls() ||
+                BitUtil::GetBit(source.buffers[0]->data(), source.offset + offset)) {
+              const uint8_t* data = source.buffers[2]->data();
+              const offset_type* offsets = source.GetValues<offset_type>(1);
+              const offset_type offset0 = offsets[offset];
+              const offset_type offset1 = offsets[offset + 1];
+              RETURN_NOT_OK(builder.Append(data + offset0, offset1 - offset0));
+            } else {
+              RETURN_NOT_OK(builder.AppendNull());
+            }
+          }
+          source_offset++;
+          return Status::OK();
+        },
+        [&]() {
+          // Null mask slot: output is null
+          RETURN_NOT_OK(builder.AppendNull());
+          source_offset++;
+          return Status::OK();
+        }));
+    std::shared_ptr<Array> temp_output;
+    RETURN_NOT_OK(builder.Finish(&temp_output));
+    *output = *temp_output->data();
+    // Builder type != logical type due to GenerateTypeAgnosticVarBinaryBase
+    output->type = array.type;
+    return Status::OK();
+  }
+};
+
+// Kernel entry point: validates the inputs, then dispatches on whether the mask
+// is a scalar or an array.
+template <typename Type>
+struct ReplaceWithMaskFunctor {
+  static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+    const ArrayData& array = *batch[0].array();
+    const Datum& replacements = batch[2];
+    ArrayData* output = out->array().get();
+    output->length = array.length;
+
+    // Check the kind before dereferencing replacements.type() below
+    if (!replacements.is_array() && !replacements.is_scalar()) {
+      return Status::Invalid("Replacements must be array or scalar");
+    }
+
+    // Needed for FixedSizeBinary/parameterized types
+    if (!array.type->Equals(*replacements.type(), /*check_metadata=*/false)) {
+      return Status::Invalid("Replacements must be of same type (expected ",
+                             array.type->ToString(), " but got ",
+                             replacements.type()->ToString(), ")");
+    }
+
+    if (batch[1].is_scalar()) {
+      return ReplaceWithMask<Type>::ExecScalarMask(
+          ctx, array, batch[1].scalar_as<BooleanScalar>(), replacements, output);
+    }
+    const ArrayData& mask = *batch[1].array();
+    if (array.length != mask.length) {
+      return Status::Invalid("Mask must be of same length as array (expected ",
+                             array.length, " items but got ", mask.length, " items)");
+    }
+    return ReplaceWithMask<Type>::ExecArrayMask(ctx, array, mask, replacements, output);
+  }
+};
+
+}  // namespace
+
+const FunctionDoc replace_with_mask_doc(
+    "Replace items using a mask and replacement values",
+    ("Given an array and a Boolean mask (either scalar or of equal length), "
+     "along with replacement values (either scalar or array), "
+     "each element of the array for which the corresponding mask element is "
+     "true will be replaced by the next value from the replacements, "
+     "or with null if the mask is null. "
+     "Hence, for replacement arrays, len(replacements) == sum(mask == true)."),
+    {"values", "mask", "replacements"});
+
+// Register the "replace_with_mask" vector function with one kernel per
+// supported value type.
+void RegisterVectorReplace(FunctionRegistry* registry) {
+  auto func = std::make_shared<VectorFunction>("replace_with_mask", Arity::Ternary(),
+                                               &replace_with_mask_doc);
+  auto add_kernel = [&](detail::GetTypeId get_id, ArrayKernelExec exec) {
+    VectorKernel kernel;
+    kernel.can_execute_chunkwise = false;
+    if (is_fixed_width(get_id.id)) {
+      kernel.null_handling = NullHandling::type::COMPUTED_PREALLOCATE;
+    } else {
+      kernel.can_write_into_slices = false;
+      kernel.null_handling = NullHandling::type::COMPUTED_NO_PREALLOCATE;
+    }
+    kernel.mem_allocation = MemAllocation::type::PREALLOCATE;
+    kernel.signature = KernelSignature::Make(
+        {InputType::Array(get_id.id), InputType(boolean()), InputType(get_id.id)},
+        OutputType(FirstType));
+    kernel.exec = std::move(exec);
+    DCHECK_OK(func->AddKernel(std::move(kernel)));
+  };
+  auto add_primitive_kernel = [&](detail::GetTypeId get_id) {
+    add_kernel(get_id, GenerateTypeAgnosticPrimitive<ReplaceWithMaskFunctor>(get_id));
+  };
+  for (const auto& ty : NumericTypes()) {
+    add_primitive_kernel(ty);
+  }
+  for (const auto& ty : TemporalTypes()) {
+    add_primitive_kernel(ty);
+  }
+  add_primitive_kernel(null());
+  add_primitive_kernel(boolean());
+  add_primitive_kernel(day_time_interval());
+  add_primitive_kernel(month_interval());
+  add_kernel(Type::FIXED_SIZE_BINARY, ReplaceWithMaskFunctor<FixedSizeBinaryType>::Exec);
+  add_kernel(Type::DECIMAL128, ReplaceWithMaskFunctor<Decimal128Type>::Exec);
+  add_kernel(Type::DECIMAL256, ReplaceWithMaskFunctor<Decimal256Type>::Exec);
+  for (const auto& ty : BaseBinaryTypes()) {
+    add_kernel(ty->id(), GenerateTypeAgnosticVarBinaryBase<ReplaceWithMaskFunctor>(*ty));
+  }
+  // TODO: list types
+  DCHECK_OK(registry->AddFunction(std::move(func)));
+
+  // TODO(ARROW-9431): "replace_with_indices"
+}
+}  // namespace internal
+}  // namespace compute
+}  // namespace arrow
diff --git a/cpp/src/arrow/compute/kernels/vector_replace_benchmark.cc b/cpp/src/arrow/compute/kernels/vector_replace_benchmark.cc
new file mode 100644
index 00000000000..719969d46ea
--- /dev/null
+++ b/cpp/src/arrow/compute/kernels/vector_replace_benchmark.cc
@@ -0,0 +1,89 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <benchmark/benchmark.h>
+
+#include "arrow/array.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/testing/random.h"
+
+#include "arrow/compute/api_vector.h"
+
+namespace arrow {
+namespace compute {
+
+using ::arrow::internal::checked_pointer_cast;
+
+static constexpr random::SeedType kRandomSeed = 0xabcdef;
+// Number of generated elements per benchmark run (a length, not a seed)
+static constexpr int64_t kLongLength = 16384;
+
+// Generate one int64 replacement value for every mask slot that is valid and true
+static std::shared_ptr<Array> MakeReplacements(random::RandomArrayGenerator* generator,
+                                               const BooleanArray& mask) {
+  return generator->Int64(mask.true_count(), /*min=*/-65536, /*max=*/65536,
+                          /*null_probability=*/0.1);
+}
+
+// Shared benchmark body. state.range(0) is the generated length and state.range(1)
+// the offset at which values and mask are sliced; `true_probability` is the
+// probability passed to the mask generator.
+static void ReplaceWithMaskBench(benchmark::State& state,  // NOLINT non-const reference
+                                 double true_probability) {
+  random::RandomArrayGenerator generator(kRandomSeed);
+  const int64_t len = state.range(0);
+  const int64_t offset = state.range(1);
+
+  auto values =
+      generator.Int64(len, /*min=*/-65536, /*max=*/65536, /*null_probability=*/0.1)
+          ->Slice(offset);
+  auto mask = checked_pointer_cast<BooleanArray>(
+      generator.Boolean(len, true_probability, /*null_probability=*/0.1)
+          ->Slice(offset));
+  auto replacements = MakeReplacements(&generator, *mask);
+
+  for (auto _ : state) {
+    ABORT_NOT_OK(ReplaceWithMask(values, mask, replacements));
+  }
+  // 8 bytes per int64 value in the sliced input
+  state.SetBytesProcessed(state.iterations() * (len - offset) * 8);
+}
+
+static void ReplaceWithMaskLowSelectivityBench(
+    benchmark::State& state) {  // NOLINT non-const reference
+  ReplaceWithMaskBench(state, /*true_probability=*/0.1);
+}
+
+static void ReplaceWithMaskHighSelectivityBench(
+    benchmark::State& state) {  // NOLINT non-const reference
+  ReplaceWithMaskBench(state, /*true_probability=*/0.9);
+}
+
+BENCHMARK(ReplaceWithMaskLowSelectivityBench)->Args({kLongLength, 0});
+BENCHMARK(ReplaceWithMaskLowSelectivityBench)->Args({kLongLength, 99});
+BENCHMARK(ReplaceWithMaskHighSelectivityBench)->Args({kLongLength, 0});
+BENCHMARK(ReplaceWithMaskHighSelectivityBench)->Args({kLongLength, 99});
+
+}  // namespace compute
+}  // namespace arrow
diff --git a/cpp/src/arrow/compute/kernels/vector_replace_test.cc b/cpp/src/arrow/compute/kernels/vector_replace_test.cc
new file mode 100644
index 00000000000..48f253e7ca9
--- /dev/null
+++ b/cpp/src/arrow/compute/kernels/vector_replace_test.cc
@@ -0,0 +1,677 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+ +#include +#include + +#include "arrow/compute/api_vector.h" +#include "arrow/compute/kernels/test_util.h" +#include "arrow/testing/gtest_common.h" +#include "arrow/testing/gtest_util.h" +#include "arrow/util/checked_cast.h" +#include "arrow/util/key_value_metadata.h" +#include "arrow/util/make_unique.h" + +namespace arrow { +namespace compute { + +using arrow::internal::checked_pointer_cast; + +template +class TestReplaceKernel : public ::testing::Test { + protected: + virtual std::shared_ptr type() = 0; + + using ReplaceFunction = std::function(const Datum&, const Datum&, + const Datum&, ExecContext*)>; + + void SetUp() override { equal_options_ = equal_options_.nans_equal(true); } + + Datum mask_scalar(bool value) { return Datum(std::make_shared(value)); } + + Datum null_mask_scalar() { + auto scalar = std::make_shared(true); + scalar->is_valid = false; + return Datum(std::move(scalar)); + } + + Datum scalar(const std::string& json) { return ScalarFromJSON(type(), json); } + + std::shared_ptr array(const std::string& value) { + return ArrayFromJSON(type(), value); + } + + std::shared_ptr mask(const std::string& value) { + return ArrayFromJSON(boolean(), value); + } + + Status AssertRaises(ReplaceFunction func, const std::shared_ptr& array, + const Datum& mask, const std::shared_ptr& replacements) { + auto result = func(array, mask, replacements, nullptr); + EXPECT_FALSE(result.ok()); + return result.status(); + } + + void Assert(ReplaceFunction func, const std::shared_ptr& array, + const Datum& mask, Datum replacements, + const std::shared_ptr& expected) { + SCOPED_TRACE("Replacements: " + (replacements.is_array() + ? replacements.make_array()->ToString() + : replacements.scalar()->ToString())); + SCOPED_TRACE("Mask: " + (mask.is_array() ? 
mask.make_array()->ToString() + : mask.scalar()->ToString())); + SCOPED_TRACE("Array: " + array->ToString()); + + ASSERT_OK_AND_ASSIGN(auto actual, func(array, mask, replacements, nullptr)); + ASSERT_TRUE(actual.is_array()); + ASSERT_OK(actual.make_array()->ValidateFull()); + + AssertArraysApproxEqual(*expected, *actual.make_array(), /*verbose=*/true, + equal_options_); + } + + std::shared_ptr NaiveImpl( + const typename TypeTraits::ArrayType& array, const BooleanArray& mask, + const typename TypeTraits::ArrayType& replacements) { + auto length = array.length(); + auto builder = arrow::internal::make_unique::BuilderType>( + default_type_instance(), default_memory_pool()); + int64_t replacement_offset = 0; + for (int64_t i = 0; i < length; ++i) { + if (mask.IsValid(i)) { + if (mask.Value(i)) { + if (replacements.IsValid(replacement_offset)) { + ARROW_EXPECT_OK(builder->Append(replacements.Value(replacement_offset++))); + } else { + ARROW_EXPECT_OK(builder->AppendNull()); + replacement_offset++; + } + } else { + if (array.IsValid(i)) { + ARROW_EXPECT_OK(builder->Append(array.Value(i))); + } else { + ARROW_EXPECT_OK(builder->AppendNull()); + } + } + } else { + ARROW_EXPECT_OK(builder->AppendNull()); + } + } + EXPECT_OK_AND_ASSIGN(auto expected, builder->Finish()); + return expected; + } + + EqualOptions equal_options_ = EqualOptions::Defaults(); +}; + +template +class TestReplaceNumeric : public TestReplaceKernel { + protected: + std::shared_ptr type() override { return default_type_instance(); } +}; + +class TestReplaceBoolean : public TestReplaceKernel { + protected: + std::shared_ptr type() override { + return TypeTraits::type_singleton(); + } +}; + +class TestReplaceFixedSizeBinary : public TestReplaceKernel { + protected: + std::shared_ptr type() override { return fixed_size_binary(3); } +}; + +template +class TestReplaceDecimal : public TestReplaceKernel { + protected: + std::shared_ptr type() override { return default_type_instance(); } +}; + +class 
TestReplaceDayTimeInterval : public TestReplaceKernel { + protected: + std::shared_ptr type() override { + return TypeTraits::type_singleton(); + } +}; + +template +class TestReplaceBinary : public TestReplaceKernel { + protected: + std::shared_ptr type() override { return default_type_instance(); } +}; + +using NumericBasedTypes = + ::testing::Types; + +TYPED_TEST_SUITE(TestReplaceNumeric, NumericBasedTypes); +TYPED_TEST_SUITE(TestReplaceDecimal, DecimalArrowTypes); +TYPED_TEST_SUITE(TestReplaceBinary, BinaryTypes); + +TYPED_TEST(TestReplaceNumeric, ReplaceWithMask) { + this->Assert(ReplaceWithMask, this->array("[]"), this->mask_scalar(false), + this->array("[]"), this->array("[]")); + this->Assert(ReplaceWithMask, this->array("[]"), this->mask_scalar(true), + this->array("[]"), this->array("[]")); + this->Assert(ReplaceWithMask, this->array("[]"), this->null_mask_scalar(), + this->array("[]"), this->array("[]")); + + this->Assert(ReplaceWithMask, this->array("[1]"), this->mask_scalar(false), + this->array("[]"), this->array("[1]")); + this->Assert(ReplaceWithMask, this->array("[1]"), this->mask_scalar(true), + this->array("[0]"), this->array("[0]")); + this->Assert(ReplaceWithMask, this->array("[1]"), this->mask_scalar(true), + this->array("[2, 0]"), this->array("[2]")); + this->Assert(ReplaceWithMask, this->array("[1]"), this->null_mask_scalar(), + this->array("[]"), this->array("[null]")); + + this->Assert(ReplaceWithMask, this->array("[0, 0]"), this->mask_scalar(false), + this->scalar("1"), this->array("[0, 0]")); + this->Assert(ReplaceWithMask, this->array("[0, 0]"), this->mask_scalar(true), + this->scalar("1"), this->array("[1, 1]")); + this->Assert(ReplaceWithMask, this->array("[0, 0]"), this->mask_scalar(true), + this->scalar("null"), this->array("[null, null]")); + + this->Assert(ReplaceWithMask, this->array("[]"), this->mask("[]"), this->array("[]"), + this->array("[]")); + this->Assert(ReplaceWithMask, this->array("[0, 1, 2, 3]"), + this->mask("[false, 
false, false, false]"), this->array("[]"), + this->array("[0, 1, 2, 3]")); + this->Assert(ReplaceWithMask, this->array("[0, 1, 2, 3]"), + this->mask("[true, true, true, true]"), this->array("[10, 11, 12, 13]"), + this->array("[10, 11, 12, 13]")); + this->Assert(ReplaceWithMask, this->array("[0, 1, 2, 3]"), + this->mask("[null, null, null, null]"), this->array("[]"), + this->array("[null, null, null, null]")); + this->Assert(ReplaceWithMask, this->array("[0, 1, 2, null]"), + this->mask("[false, false, false, false]"), this->array("[]"), + this->array("[0, 1, 2, null]")); + this->Assert(ReplaceWithMask, this->array("[0, 1, 2, null]"), + this->mask("[true, true, true, true]"), this->array("[10, 11, 12, 13]"), + this->array("[10, 11, 12, 13]")); + this->Assert(ReplaceWithMask, this->array("[0, 1, 2, null]"), + this->mask("[null, null, null, null]"), this->array("[]"), + this->array("[null, null, null, null]")); + this->Assert(ReplaceWithMask, this->array("[0, 1, 2, 3, 4, 5]"), + this->mask("[true, true, false, false, null, null]"), + this->array("[10, null]"), this->array("[10, null, 2, 3, null, null]")); + this->Assert(ReplaceWithMask, this->array("[null, null, null, null, null, null]"), + this->mask("[true, true, false, false, null, null]"), + this->array("[10, null]"), + this->array("[10, null, null, null, null, null]")); + + this->Assert(ReplaceWithMask, this->array("[]"), this->mask("[]"), this->scalar("1"), + this->array("[]")); + this->Assert(ReplaceWithMask, this->array("[0, 1]"), this->mask("[true, true]"), + this->scalar("10"), this->array("[10, 10]")); + this->Assert(ReplaceWithMask, this->array("[0, 1]"), this->mask("[true, true]"), + this->scalar("null"), this->array("[null, null]")); + this->Assert(ReplaceWithMask, this->array("[0, 1, 2]"), + this->mask("[true, false, null]"), this->scalar("10"), + this->array("[10, 1, null]")); +} + +TYPED_TEST(TestReplaceNumeric, ReplaceWithMaskRandom) { + using ArrayType = typename TypeTraits::ArrayType; + using CType 
= typename TypeTraits::CType; + auto ty = this->type(); + + random::RandomArrayGenerator rand(/*seed=*/0); + const int64_t length = 1023; + std::vector values = {"0.01", "0"}; + // Clamp the range because date/time types don't print well with extreme values + values.push_back(std::to_string(static_cast(std::min( + 16384.0, static_cast(std::numeric_limits::max()))))); + auto options = key_value_metadata({"null_probability", "min", "max"}, values); + auto array = + checked_pointer_cast(rand.ArrayOf(*field("a", ty, options), length)); + auto mask = checked_pointer_cast( + rand.ArrayOf(boolean(), length, /*null_probability=*/0.01)); + const int64_t num_replacements = std::count_if( + mask->begin(), mask->end(), + [](util::optional value) { return value.has_value() && *value; }); + auto replacements = checked_pointer_cast( + rand.ArrayOf(*field("a", ty, options), num_replacements)); + auto expected = this->NaiveImpl(*array, *mask, *replacements); + + this->Assert(ReplaceWithMask, array, mask, replacements, expected); + for (int64_t slice = 1; slice <= 16; slice++) { + auto sliced_array = checked_pointer_cast(array->Slice(slice, 15)); + auto sliced_mask = checked_pointer_cast(mask->Slice(slice, 15)); + auto new_expected = this->NaiveImpl(*sliced_array, *sliced_mask, *replacements); + this->Assert(ReplaceWithMask, sliced_array, sliced_mask, replacements, new_expected); + } +} + +TYPED_TEST(TestReplaceNumeric, ReplaceWithMaskErrors) { + EXPECT_RAISES_WITH_MESSAGE_THAT( + Invalid, + ::testing::HasSubstr("Replacement array must be of appropriate length (expected 2 " + "items but got 1 items)"), + this->AssertRaises(ReplaceWithMask, this->array("[1, 2]"), + this->mask("[true, true]"), this->array("[0]"))); + EXPECT_RAISES_WITH_MESSAGE_THAT( + Invalid, + ::testing::HasSubstr("Replacement array must be of appropriate length (expected 1 " + "items but got 0 items)"), + this->AssertRaises(ReplaceWithMask, this->array("[1, 2]"), + this->mask("[true, null]"), this->array("[]"))); + 
EXPECT_RAISES_WITH_MESSAGE_THAT( + Invalid, + ::testing::HasSubstr("Mask must be of same length as array (expected 2 " + "items but got 0 items)"), + this->AssertRaises(ReplaceWithMask, this->array("[1, 2]"), this->mask("[]"), + this->array("[]"))); +} + +TEST_F(TestReplaceBoolean, ReplaceWithMask) { + this->Assert(ReplaceWithMask, this->array("[]"), this->mask_scalar(false), + this->array("[]"), this->array("[]")); + this->Assert(ReplaceWithMask, this->array("[]"), this->mask_scalar(true), + this->array("[]"), this->array("[]")); + this->Assert(ReplaceWithMask, this->array("[]"), this->null_mask_scalar(), + this->array("[]"), this->array("[]")); + + this->Assert(ReplaceWithMask, this->array("[true]"), this->mask_scalar(false), + this->array("[]"), this->array("[true]")); + this->Assert(ReplaceWithMask, this->array("[true]"), this->mask_scalar(true), + this->array("[false]"), this->array("[false]")); + this->Assert(ReplaceWithMask, this->array("[true]"), this->null_mask_scalar(), + this->array("[]"), this->array("[null]")); + + this->Assert(ReplaceWithMask, this->array("[false, false]"), this->mask_scalar(false), + this->scalar("true"), this->array("[false, false]")); + this->Assert(ReplaceWithMask, this->array("[false, false]"), this->mask_scalar(true), + this->scalar("true"), this->array("[true, true]")); + this->Assert(ReplaceWithMask, this->array("[false, false]"), this->mask_scalar(true), + this->scalar("null"), this->array("[null, null]")); + + this->Assert(ReplaceWithMask, this->array("[]"), this->mask("[]"), this->array("[]"), + this->array("[]")); + this->Assert(ReplaceWithMask, this->array("[true, true, true, true]"), + this->mask("[false, false, false, false]"), this->array("[]"), + this->array("[true, true, true, true]")); + this->Assert(ReplaceWithMask, this->array("[true, true, true, true]"), + this->mask("[true, true, true, true]"), + this->array("[false, false, false, false]"), + this->array("[false, false, false, false]")); + 
this->Assert(ReplaceWithMask, this->array("[true, true, true, true]"), + this->mask("[null, null, null, null]"), this->array("[]"), + this->array("[null, null, null, null]")); + this->Assert(ReplaceWithMask, this->array("[true, true, true, null]"), + this->mask("[false, false, false, false]"), this->array("[]"), + this->array("[true, true, true, null]")); + this->Assert(ReplaceWithMask, this->array("[true, true, true, null]"), + this->mask("[true, true, true, true]"), + this->array("[false, false, false, false]"), + this->array("[false, false, false, false]")); + this->Assert(ReplaceWithMask, this->array("[true, true, true, null]"), + this->mask("[null, null, null, null]"), this->array("[]"), + this->array("[null, null, null, null]")); + this->Assert(ReplaceWithMask, this->array("[true, true, true, true, true, true]"), + this->mask("[true, true, false, false, null, null]"), + this->array("[false, null]"), + this->array("[false, null, true, true, null, null]")); + this->Assert(ReplaceWithMask, this->array("[null, null, null, null, null, null]"), + this->mask("[true, true, false, false, null, null]"), + this->array("[false, null]"), + this->array("[false, null, null, null, null, null]")); + + this->Assert(ReplaceWithMask, this->array("[]"), this->mask("[]"), this->scalar("true"), + this->array("[]")); + this->Assert(ReplaceWithMask, this->array("[false, false]"), this->mask("[true, true]"), + this->scalar("true"), this->array("[true, true]")); + this->Assert(ReplaceWithMask, this->array("[false, false]"), this->mask("[true, true]"), + this->scalar("null"), this->array("[null, null]")); + this->Assert(ReplaceWithMask, this->array("[false, false, false]"), + this->mask("[true, false, null]"), this->scalar("true"), + this->array("[true, false, null]")); +} + +TEST_F(TestReplaceBoolean, ReplaceWithMaskErrors) { + EXPECT_RAISES_WITH_MESSAGE_THAT( + Invalid, + ::testing::HasSubstr("Replacement array must be of appropriate length (expected 2 " + "items but got 1 items)"), + 
this->AssertRaises(ReplaceWithMask, this->array("[true, true]"), + this->mask("[true, true]"), this->array("[false]"))); + EXPECT_RAISES_WITH_MESSAGE_THAT( + Invalid, + ::testing::HasSubstr("Replacement array must be of appropriate length (expected 1 " + "items but got 0 items)"), + this->AssertRaises(ReplaceWithMask, this->array("[true, true]"), + this->mask("[true, null]"), this->array("[]"))); + EXPECT_RAISES_WITH_MESSAGE_THAT( + Invalid, + ::testing::HasSubstr("Mask must be of same length as array (expected 2 " + "items but got 0 items)"), + this->AssertRaises(ReplaceWithMask, this->array("[true, true]"), this->mask("[]"), + this->array("[]"))); +} + +TEST_F(TestReplaceFixedSizeBinary, ReplaceWithMask) { + this->Assert(ReplaceWithMask, this->array("[]"), this->mask_scalar(false), + this->array("[]"), this->array("[]")); + this->Assert(ReplaceWithMask, this->array("[]"), this->mask_scalar(true), + this->array("[]"), this->array("[]")); + this->Assert(ReplaceWithMask, this->array("[]"), this->null_mask_scalar(), + this->array("[]"), this->array("[]")); + + this->Assert(ReplaceWithMask, this->array(R"(["foo"])"), this->mask_scalar(false), + this->array("[]"), this->array(R"(["foo"])")); + this->Assert(ReplaceWithMask, this->array(R"(["foo"])"), this->mask_scalar(true), + this->array(R"(["bar"])"), this->array(R"(["bar"])")); + this->Assert(ReplaceWithMask, this->array(R"(["foo"])"), this->null_mask_scalar(), + this->array("[]"), this->array("[null]")); + + this->Assert(ReplaceWithMask, this->array(R"(["foo", "bar"])"), + this->mask_scalar(false), this->scalar(R"("baz")"), + this->array(R"(["foo", "bar"])")); + this->Assert(ReplaceWithMask, this->array(R"(["foo", "bar"])"), this->mask_scalar(true), + this->scalar(R"("baz")"), this->array(R"(["baz", "baz"])")); + this->Assert(ReplaceWithMask, this->array(R"(["foo", "bar"])"), this->mask_scalar(true), + this->scalar("null"), this->array(R"([null, null])")); + + this->Assert(ReplaceWithMask, this->array("[]"), 
this->mask("[]"), this->array("[]"), + this->array("[]")); + this->Assert(ReplaceWithMask, this->array(R"(["aaa", "bbb", "ccc", "ddd"])"), + this->mask("[false, false, false, false]"), this->array("[]"), + this->array(R"(["aaa", "bbb", "ccc", "ddd"])")); + this->Assert(ReplaceWithMask, this->array(R"(["aaa", "bbb", "ccc", "ddd"])"), + this->mask("[true, true, true, true]"), + this->array(R"(["eee", "fff", "ggg", "hhh"])"), + this->array(R"(["eee", "fff", "ggg", "hhh"])")); + this->Assert(ReplaceWithMask, this->array(R"(["aaa", "bbb", "ccc", "ddd"])"), + this->mask("[null, null, null, null]"), this->array("[]"), + this->array(R"([null, null, null, null])")); + this->Assert(ReplaceWithMask, this->array(R"(["aaa", "bbb", "ccc", null])"), + this->mask("[false, false, false, false]"), this->array("[]"), + this->array(R"(["aaa", "bbb", "ccc", null])")); + this->Assert(ReplaceWithMask, this->array(R"(["aaa", "bbb", "ccc", null])"), + this->mask("[true, true, true, true]"), + this->array(R"(["eee", "fff", "ggg", "hhh"])"), + this->array(R"(["eee", "fff", "ggg", "hhh"])")); + this->Assert(ReplaceWithMask, this->array(R"(["aaa", "bbb", "ccc", null])"), + this->mask("[null, null, null, null]"), this->array("[]"), + this->array(R"([null, null, null, null])")); + this->Assert(ReplaceWithMask, + this->array(R"(["aaa", "bbb", "ccc", "ddd", "eee", "fff"])"), + this->mask("[true, true, false, false, null, null]"), + this->array(R"(["ggg", null])"), + this->array(R"(["ggg", null, "ccc", "ddd", null, null])")); + this->Assert(ReplaceWithMask, this->array(R"([null, null, null, null, null, null])"), + this->mask("[true, true, false, false, null, null]"), + this->array(R"(["aaa", null])"), + this->array(R"(["aaa", null, null, null, null, null])")); + + this->Assert(ReplaceWithMask, this->array("[]"), this->mask("[]"), + this->scalar(R"("zzz")"), this->array("[]")); + this->Assert(ReplaceWithMask, this->array(R"(["aaa", "bbb"])"), + this->mask("[true, true]"), this->scalar(R"("zzz")"), + 
this->array(R"(["zzz", "zzz"])")); + this->Assert(ReplaceWithMask, this->array(R"(["aaa", "bbb"])"), + this->mask("[true, true]"), this->scalar("null"), + this->array("[null, null]")); + this->Assert(ReplaceWithMask, this->array(R"(["aaa", "bbb", "ccc"])"), + this->mask("[true, false, null]"), this->scalar(R"("zzz")"), + this->array(R"(["zzz", "bbb", null])")); +} + +TEST_F(TestReplaceFixedSizeBinary, ReplaceWithMaskErrors) { + EXPECT_RAISES_WITH_MESSAGE_THAT( + Invalid, + ::testing::AllOf( + ::testing::HasSubstr("Replacements must be of same type (expected "), + ::testing::HasSubstr(this->type()->ToString()), + ::testing::HasSubstr("but got fixed_size_binary[2]")), + this->AssertRaises(ReplaceWithMask, this->array("[]"), this->mask_scalar(true), + ArrayFromJSON(fixed_size_binary(2), "[]"))); +} + +TYPED_TEST(TestReplaceDecimal, ReplaceWithMask) { + this->Assert(ReplaceWithMask, this->array("[]"), this->mask_scalar(false), + this->array("[]"), this->array("[]")); + this->Assert(ReplaceWithMask, this->array("[]"), this->mask_scalar(true), + this->array("[]"), this->array("[]")); + this->Assert(ReplaceWithMask, this->array("[]"), this->null_mask_scalar(), + this->array("[]"), this->array("[]")); + + this->Assert(ReplaceWithMask, this->array(R"(["1.00"])"), this->mask_scalar(false), + this->array("[]"), this->array(R"(["1.00"])")); + this->Assert(ReplaceWithMask, this->array(R"(["1.00"])"), this->mask_scalar(true), + this->array(R"(["0.00"])"), this->array(R"(["0.00"])")); + this->Assert(ReplaceWithMask, this->array(R"(["1.00"])"), this->null_mask_scalar(), + this->array("[]"), this->array("[null]")); + + this->Assert(ReplaceWithMask, this->array(R"(["0.00", "0.00"])"), + this->mask_scalar(false), this->scalar(R"("1.00")"), + this->array(R"(["0.00", "0.00"])")); + this->Assert(ReplaceWithMask, this->array(R"(["0.00", "0.00"])"), + this->mask_scalar(true), this->scalar(R"("1.00")"), + this->array(R"(["1.00", "1.00"])")); + this->Assert(ReplaceWithMask, 
this->array(R"(["0.00", "0.00"])"), + this->mask_scalar(true), this->scalar("null"), + this->array("[null, null]")); + + this->Assert(ReplaceWithMask, this->array("[]"), this->mask("[]"), this->array("[]"), + this->array("[]")); + this->Assert(ReplaceWithMask, this->array(R"(["0.00", "1.00", "2.00", "3.00"])"), + this->mask("[false, false, false, false]"), this->array("[]"), + this->array(R"(["0.00", "1.00", "2.00", "3.00"])")); + this->Assert(ReplaceWithMask, this->array(R"(["0.00", "1.00", "2.00", "3.00"])"), + this->mask("[true, true, true, true]"), + this->array(R"(["10.00", "11.00", "12.00", "13.00"])"), + this->array(R"(["10.00", "11.00", "12.00", "13.00"])")); + this->Assert(ReplaceWithMask, this->array(R"(["0.00", "1.00", "2.00", "3.00"])"), + this->mask("[null, null, null, null]"), this->array("[]"), + this->array("[null, null, null, null]")); + this->Assert(ReplaceWithMask, this->array(R"(["0.00", "1.00", "2.00", null])"), + this->mask("[false, false, false, false]"), this->array("[]"), + this->array(R"(["0.00", "1.00", "2.00", null])")); + this->Assert(ReplaceWithMask, this->array(R"(["0.00", "1.00", "2.00", null])"), + this->mask("[true, true, true, true]"), + this->array(R"(["10.00", "11.00", "12.00", "13.00"])"), + this->array(R"(["10.00", "11.00", "12.00", "13.00"])")); + this->Assert(ReplaceWithMask, this->array(R"(["0.00", "1.00", "2.00", null])"), + this->mask("[null, null, null, null]"), this->array("[]"), + this->array("[null, null, null, null]")); + this->Assert(ReplaceWithMask, + this->array(R"(["0.00", "1.00", "2.00", "3.00", "4.00", "5.00"])"), + this->mask("[true, true, false, false, null, null]"), + this->array(R"(["10.00", null])"), + this->array(R"(["10.00", null, "2.00", "3.00", null, null])")); + this->Assert(ReplaceWithMask, this->array("[null, null, null, null, null, null]"), + this->mask("[true, true, false, false, null, null]"), + this->array(R"(["10.00", null])"), + this->array(R"(["10.00", null, null, null, null, null])")); + + 
this->Assert(ReplaceWithMask, this->array("[]"), this->mask("[]"), + this->scalar(R"("1.00")"), this->array("[]")); + this->Assert(ReplaceWithMask, this->array(R"(["0.00", "1.00"])"), + this->mask("[true, true]"), this->scalar(R"("10.00")"), + this->array(R"(["10.00", "10.00"])")); + this->Assert(ReplaceWithMask, this->array(R"(["0.00", "1.00"])"), + this->mask("[true, true]"), this->scalar("null"), + this->array("[null, null]")); + this->Assert(ReplaceWithMask, this->array(R"(["0.00", "1.00", "2.00"])"), + this->mask("[true, false, null]"), this->scalar(R"("10.00")"), + this->array(R"(["10.00", "1.00", null])")); +} + +TEST_F(TestReplaceDayTimeInterval, ReplaceWithMask) { + this->Assert(ReplaceWithMask, this->array("[]"), this->mask_scalar(false), + this->array("[]"), this->array("[]")); + this->Assert(ReplaceWithMask, this->array("[]"), this->mask_scalar(true), + this->array("[]"), this->array("[]")); + this->Assert(ReplaceWithMask, this->array("[]"), this->null_mask_scalar(), + this->array("[]"), this->array("[]")); + + this->Assert(ReplaceWithMask, this->array("[[1, 2]]"), this->mask_scalar(false), + this->array("[]"), this->array("[[1, 2]]")); + this->Assert(ReplaceWithMask, this->array("[[1, 2]]"), this->mask_scalar(true), + this->array("[[3, 4]]"), this->array("[[3, 4]]")); + this->Assert(ReplaceWithMask, this->array("[[1, 2]]"), this->null_mask_scalar(), + this->array("[]"), this->array("[null]")); + + this->Assert(ReplaceWithMask, this->array("[[1, 2], [3, 4]]"), this->mask_scalar(false), + this->scalar("[7, 8]"), this->array("[[1, 2], [3, 4]]")); + this->Assert(ReplaceWithMask, this->array("[[1, 2], [3, 4]]"), this->mask_scalar(true), + this->scalar("[7, 8]"), this->array("[[7, 8], [7, 8]]")); + this->Assert(ReplaceWithMask, this->array("[[1, 2], [3, 4]]"), this->mask_scalar(true), + this->scalar("null"), this->array("[null, null]")); + + this->Assert(ReplaceWithMask, this->array("[]"), this->mask("[]"), this->array("[]"), + this->array("[]")); + 
this->Assert(ReplaceWithMask, this->array("[[1, 2], [1, 2], [1, 2], [1, 2]]"), + this->mask("[false, false, false, false]"), this->array("[]"), + this->array("[[1, 2], [1, 2], [1, 2], [1, 2]]")); + this->Assert(ReplaceWithMask, this->array("[[1, 2], [1, 2], [1, 2], [1, 2]]"), + this->mask("[true, true, true, true]"), + this->array("[[3, 4], [3, 4], [3, 4], [3, 4]]"), + this->array("[[3, 4], [3, 4], [3, 4], [3, 4]]")); + this->Assert(ReplaceWithMask, this->array("[[1, 2], [1, 2], [1, 2], [1, 2]]"), + this->mask("[null, null, null, null]"), this->array("[]"), + this->array("[null, null, null, null]")); + this->Assert(ReplaceWithMask, this->array("[[1, 2], [1, 2], [1, 2], null]"), + this->mask("[false, false, false, false]"), this->array("[]"), + this->array("[[1, 2], [1, 2], [1, 2], null]")); + this->Assert(ReplaceWithMask, this->array("[[1, 2], [1, 2], [1, 2], null]"), + this->mask("[true, true, true, true]"), + this->array("[[3, 4], [3, 4], [3, 4], [3, 4]]"), + this->array("[[3, 4], [3, 4], [3, 4], [3, 4]]")); + this->Assert(ReplaceWithMask, this->array("[[1, 2], [1, 2], [1, 2], null]"), + this->mask("[null, null, null, null]"), this->array("[]"), + this->array("[null, null, null, null]")); + this->Assert( + ReplaceWithMask, this->array("[[1, 2], [1, 2], [1, 2], [1, 2], [1, 2], [1, 2]]"), + this->mask("[true, true, false, false, null, null]"), this->array("[[3, 4], null]"), + this->array("[[3, 4], null, [1, 2], [1, 2], null, null]")); + this->Assert(ReplaceWithMask, this->array("[null, null, null, null, null, null]"), + this->mask("[true, true, false, false, null, null]"), + this->array("[[3, 4], null]"), + this->array("[[3, 4], null, null, null, null, null]")); + + this->Assert(ReplaceWithMask, this->array("[]"), this->mask("[]"), + this->scalar("[7, 8]"), this->array("[]")); + this->Assert(ReplaceWithMask, this->array("[[1, 2], [3, 4]]"), + this->mask("[true, true]"), this->scalar("[7, 8]"), + this->array("[[7, 8], [7, 8]]")); + this->Assert(ReplaceWithMask, 
this->array("[[1, 2], [3, 4]]"), + this->mask("[true, true]"), this->scalar("null"), + this->array("[null, null]")); + this->Assert(ReplaceWithMask, this->array("[[1, 2], [3, 4], [5, 6]]"), + this->mask("[true, false, null]"), this->scalar("[7, 8]"), + this->array("[[7, 8], [3, 4], null]")); +} + +TYPED_TEST(TestReplaceBinary, ReplaceWithMask) { + this->Assert(ReplaceWithMask, this->array("[]"), this->mask_scalar(false), + this->array("[]"), this->array("[]")); + this->Assert(ReplaceWithMask, this->array("[]"), this->mask_scalar(true), + this->array("[]"), this->array("[]")); + this->Assert(ReplaceWithMask, this->array("[]"), this->null_mask_scalar(), + this->array("[]"), this->array("[]")); + + this->Assert(ReplaceWithMask, this->array(R"(["foo"])"), this->mask_scalar(false), + this->array("[]"), this->array(R"(["foo"])")); + this->Assert(ReplaceWithMask, this->array(R"(["foo"])"), this->mask_scalar(true), + this->array(R"(["bar"])"), this->array(R"(["bar"])")); + this->Assert(ReplaceWithMask, this->array(R"(["foo"])"), this->null_mask_scalar(), + this->array("[]"), this->array("[null]")); + + this->Assert(ReplaceWithMask, this->array(R"(["foo", "bar"])"), + this->mask_scalar(false), this->scalar(R"("baz")"), + this->array(R"(["foo", "bar"])")); + this->Assert(ReplaceWithMask, this->array(R"(["foo", "bar"])"), this->mask_scalar(true), + this->scalar(R"("baz")"), this->array(R"(["baz", "baz"])")); + this->Assert(ReplaceWithMask, this->array(R"(["foo", "bar"])"), this->mask_scalar(true), + this->scalar("null"), this->array(R"([null, null])")); + + this->Assert(ReplaceWithMask, this->array("[]"), this->mask("[]"), this->array("[]"), + this->array("[]")); + this->Assert(ReplaceWithMask, this->array(R"(["a", "bb", "ccc", "dddd"])"), + this->mask("[false, false, false, false]"), this->array("[]"), + this->array(R"(["a", "bb", "ccc", "dddd"])")); + this->Assert(ReplaceWithMask, this->array(R"(["a", "bb", "ccc", "dddd"])"), + this->mask("[true, true, true, true]"), + 
this->array(R"(["eeeee", "f", "ggg", "hhh"])"), + this->array(R"(["eeeee", "f", "ggg", "hhh"])")); + this->Assert(ReplaceWithMask, this->array(R"(["a", "bb", "ccc", "dddd"])"), + this->mask("[null, null, null, null]"), this->array("[]"), + this->array(R"([null, null, null, null])")); + this->Assert(ReplaceWithMask, this->array(R"(["a", "bb", "ccc", null])"), + this->mask("[false, false, false, false]"), this->array("[]"), + this->array(R"(["a", "bb", "ccc", null])")); + this->Assert(ReplaceWithMask, this->array(R"(["a", "bb", "ccc", null])"), + this->mask("[true, true, true, true]"), + this->array(R"(["eeeee", "f", "ggg", "hhh"])"), + this->array(R"(["eeeee", "f", "ggg", "hhh"])")); + this->Assert(ReplaceWithMask, this->array(R"(["a", "bb", "ccc", null])"), + this->mask("[null, null, null, null]"), this->array("[]"), + this->array(R"([null, null, null, null])")); + this->Assert(ReplaceWithMask, + this->array(R"(["a", "bb", "ccc", "dddd", "eeeee", "f"])"), + this->mask("[true, true, false, false, null, null]"), + this->array(R"(["ggg", null])"), + this->array(R"(["ggg", null, "ccc", "dddd", null, null])")); + this->Assert(ReplaceWithMask, this->array(R"([null, null, null, null, null, null])"), + this->mask("[true, true, false, false, null, null]"), + this->array(R"(["a", null])"), + this->array(R"(["a", null, null, null, null, null])")); + + this->Assert(ReplaceWithMask, this->array("[]"), this->mask("[]"), + this->scalar(R"("zzz")"), this->array("[]")); + this->Assert(ReplaceWithMask, this->array(R"(["a", "bb"])"), this->mask("[true, true]"), + this->scalar(R"("zzz")"), this->array(R"(["zzz", "zzz"])")); + this->Assert(ReplaceWithMask, this->array(R"(["a", "bb"])"), this->mask("[true, true]"), + this->scalar("null"), this->array("[null, null]")); + this->Assert(ReplaceWithMask, this->array(R"(["a", "bb", "ccc"])"), + this->mask("[true, false, null]"), this->scalar(R"("zzz")"), + this->array(R"(["zzz", "bb", null])")); +} + +TYPED_TEST(TestReplaceBinary, 
ReplaceWithMaskRandom) { + using ArrayType = typename TypeTraits::ArrayType; + auto ty = this->type(); + + random::RandomArrayGenerator rand(/*seed=*/0); + const int64_t length = 1023; + auto options = key_value_metadata({{"null_probability", "0.01"}, {"max_length", "5"}}); + auto array = + checked_pointer_cast(rand.ArrayOf(*field("a", ty, options), length)); + auto mask = checked_pointer_cast( + rand.ArrayOf(boolean(), length, /*null_probability=*/0.01)); + const int64_t num_replacements = std::count_if( + mask->begin(), mask->end(), + [](util::optional value) { return value.has_value() && *value; }); + auto replacements = checked_pointer_cast( + rand.ArrayOf(*field("a", ty, options), num_replacements)); + auto expected = this->NaiveImpl(*array, *mask, *replacements); + + this->Assert(ReplaceWithMask, array, mask, replacements, expected); + for (int64_t slice = 1; slice <= 16; slice++) { + auto sliced_array = checked_pointer_cast(array->Slice(slice, 15)); + auto sliced_mask = checked_pointer_cast(mask->Slice(slice, 15)); + auto new_expected = this->NaiveImpl(*sliced_array, *sliced_mask, *replacements); + this->Assert(ReplaceWithMask, sliced_array, sliced_mask, replacements, new_expected); + } +} + +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/registry.cc b/cpp/src/arrow/compute/registry.cc index 8a0d9e62518..ca7b6137306 100644 --- a/cpp/src/arrow/compute/registry.cc +++ b/cpp/src/arrow/compute/registry.cc @@ -168,6 +168,7 @@ static std::unique_ptr CreateBuiltInRegistry() { // Vector functions RegisterVectorHash(registry.get()); + RegisterVectorReplace(registry.get()); RegisterVectorSelection(registry.get()); RegisterVectorNested(registry.get()); RegisterVectorSort(registry.get()); diff --git a/cpp/src/arrow/compute/registry_internal.h b/cpp/src/arrow/compute/registry_internal.h index dd0271eb43d..892b54341da 100644 --- a/cpp/src/arrow/compute/registry_internal.h +++ b/cpp/src/arrow/compute/registry_internal.h @@ -41,6 +41,7 @@ 
void RegisterScalarOptions(FunctionRegistry* registry); // Vector functions void RegisterVectorHash(FunctionRegistry* registry); +void RegisterVectorReplace(FunctionRegistry* registry); void RegisterVectorSelection(FunctionRegistry* registry); void RegisterVectorNested(FunctionRegistry* registry); void RegisterVectorSort(FunctionRegistry* registry); diff --git a/cpp/src/arrow/type_traits.h b/cpp/src/arrow/type_traits.h index 86664bbb162..e4d809967f9 100644 --- a/cpp/src/arrow/type_traits.h +++ b/cpp/src/arrow/type_traits.h @@ -233,6 +233,7 @@ struct TypeTraits<MonthIntervalType> { using ArrayType = MonthIntervalArray; using BuilderType = MonthIntervalBuilder; using ScalarType = MonthIntervalScalar; + using CType = MonthIntervalType::c_type; static constexpr int64_t bytes_required(int64_t elements) { return elements * static_cast<int64_t>(sizeof(int32_t)); diff --git a/docs/source/cpp/compute.rst b/docs/source/cpp/compute.rst index fc6c8b7c7e1..00391052b1e 100644 --- a/docs/source/cpp/compute.rst +++ b/docs/source/cpp/compute.rst @@ -850,6 +850,7 @@ in reverse order. as given by :struct:`SliceOptions` where ``start`` and ``stop`` are measured in codeunits. Null inputs emit null. +..
_cpp-compute-scalar-structural-transforms: Structural transforms ~~~~~~~~~~~~~~~~~~~~~ @@ -861,7 +862,7 @@ Structural transforms +==========================+============+================================================+=====================+=========+ | fill_null | Binary | Boolean, Null, Numeric, Temporal, String-like | Input type | \(1) | +--------------------------+------------+------------------------------------------------+---------------------+---------+ -| if_else | Ternary | Boolean, Null, Numeric, Temporal | Input type + \(2) | +| if_else | Ternary | Boolean, Null, Numeric, Temporal | Input type | \(2) | +--------------------------+------------+------------------------------------------------+---------------------+---------+ | is_finite | Unary | Float, Double | Boolean | \(3) | +--------------------------+------------+------------------------------------------------+---------------------+---------+ @@ -888,6 +889,8 @@ Structural transforms input. If the nulls present on the first input, they will be promoted to the output, otherwise nulls will be chosen based on the first input values. + Also see: :ref:`replace_with_mask <cpp-compute-vector-structural-transforms>`. + * \(3) Output is true iff the corresponding input element is finite (not Infinity, -Infinity, or NaN). @@ -1154,6 +1157,8 @@ value, but smaller than nulls. table. If the input is a record batch or table, one or more sort keys must be specified. +.. _cpp-compute-vector-structural-transforms: + Structural transforms ~~~~~~~~~~~~~~~~~~~~~ @@ -1172,3 +1177,18 @@ Structural transforms * \(2) For each value in the list child array, the index at which it is found in the list array is appended to the output. Nulls in the parent list array are discarded. + +These functions create a copy of the first input with some elements +replaced, based on the remaining inputs.
+ ++--------------------------+------------+-----------------------+--------------+--------------+--------------+-------+ +| Function name | Arity | Input type 1 | Input type 2 | Input type 3 | Output type | Notes | ++==========================+============+=======================+==============+==============+==============+=======+ +| replace_with_mask | Ternary | Fixed-width or binary | Boolean | Input type 1 | Input type 1 | \(1) | ++--------------------------+------------+-----------------------+--------------+--------------+--------------+-------+ + +* \(1) Each element in input 1 for which the corresponding Boolean in input 2 + is true is replaced with the next value from input 3. A null in input 2 + results in a corresponding null in the output. + + Also see: :ref:`if_else <cpp-compute-scalar-structural-transforms>`. diff --git a/docs/source/python/api/compute.rst b/docs/source/python/api/compute.rst index a611d2a2384..09c67598193 100644 --- a/docs/source/python/api/compute.rst +++ b/docs/source/python/api/compute.rst @@ -292,6 +292,14 @@ Conversions cast strptime +Replacements +------------ + +.. autosummary:: + :toctree: ../generated/ + + replace_with_mask + Selections ----------