From b249edaccb88080cec7d6297b8fac1a3f601e0c6 Mon Sep 17 00:00:00 2001 From: David Li Date: Thu, 27 May 2021 09:46:57 -0400 Subject: [PATCH 01/15] ARROW-9430: [C++] Implement override_mask kernel --- cpp/src/arrow/CMakeLists.txt | 1 + cpp/src/arrow/array/array_binary.h | 7 + cpp/src/arrow/array/util.cc | 5 + cpp/src/arrow/compute/api_vector.cc | 5 + cpp/src/arrow/compute/api_vector.h | 17 + cpp/src/arrow/compute/kernels/CMakeLists.txt | 1 + .../arrow/compute/kernels/codegen_internal.h | 2 + .../arrow/compute/kernels/vector_replace.cc | 495 ++++++++++++ .../compute/kernels/vector_replace_test.cc | 736 ++++++++++++++++++ cpp/src/arrow/compute/registry.cc | 1 + cpp/src/arrow/compute/registry_internal.h | 1 + cpp/src/arrow/type_traits.h | 1 + docs/source/cpp/compute.rst | 16 + docs/source/python/api/compute.rst | 8 + 14 files changed, 1296 insertions(+) create mode 100644 cpp/src/arrow/compute/kernels/vector_replace.cc create mode 100644 cpp/src/arrow/compute/kernels/vector_replace_test.cc diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt index 634d202623f..88a92d8c2c9 100644 --- a/cpp/src/arrow/CMakeLists.txt +++ b/cpp/src/arrow/CMakeLists.txt @@ -401,6 +401,7 @@ if(ARROW_COMPUTE) compute/kernels/util_internal.cc compute/kernels/vector_hash.cc compute/kernels/vector_nested.cc + compute/kernels/vector_replace.cc compute/kernels/vector_selection.cc compute/kernels/vector_sort.cc compute/exec/key_hash.cc diff --git a/cpp/src/arrow/array/array_binary.h b/cpp/src/arrow/array/array_binary.h index db3c640b9a4..f8e8c4f8a44 100644 --- a/cpp/src/arrow/array/array_binary.h +++ b/cpp/src/arrow/array/array_binary.h @@ -71,6 +71,13 @@ class BaseBinaryArray : public FlatArray { raw_value_offsets_[i + 1] - pos); } + /// \brief Get binary value as a string_view + /// Provided for consistency with other arrays. + /// + /// \param i the value index + /// \return the view over the selected value + util::string_view Value(int64_t i) const { return GetView(i); } + /// \brief Get binary value as a std::string /// /// \param i the value index diff --git a/cpp/src/arrow/array/util.cc b/cpp/src/arrow/array/util.cc index 688cb20cb9a..ed26ecff4e0 100644 --- a/cpp/src/arrow/array/util.cc +++ b/cpp/src/arrow/array/util.cc @@ -528,6 +528,11 @@ class RepeatedArrayFactory { return FinishFixedWidth(value.data(), value.size()); } + Status Visit(const Decimal256Type&) { + auto value = checked_cast(scalar_).value.ToBytes(); + return FinishFixedWidth(value.data(), value.size()); + } + template enable_if_base_binary Visit(const T&) { std::shared_ptr value = diff --git a/cpp/src/arrow/compute/api_vector.cc b/cpp/src/arrow/compute/api_vector.cc index 9c1ef8533b4..a68969b2ee5 100644 --- a/cpp/src/arrow/compute/api_vector.cc +++ b/cpp/src/arrow/compute/api_vector.cc @@ -162,6 +162,11 @@ Result> NthToIndices(const Array& values, int64_t n, return result.make_array(); } +Result ReplaceWithMask(const Datum& values, const Datum& mask, + const Datum& replacements, ExecContext* ctx) { + return CallFunction("replace_with_mask", {values, mask, replacements}, ctx); +} + Result> SortIndices(const Array& values, SortOrder order, ExecContext* ctx) { ArraySortOptions options(order); diff --git a/cpp/src/arrow/compute/api_vector.h b/cpp/src/arrow/compute/api_vector.h index 6021492320e..9d8d4271db8 100644 --- a/cpp/src/arrow/compute/api_vector.h +++ b/cpp/src/arrow/compute/api_vector.h @@ -171,6 +171,23 @@ Result> GetTakeIndices( } // namespace internal +/// \brief ReplaceWithMask replaces each value in the array corresponding +/// to a true value in the mask with the next element from `replacements`. +/// +/// \param[in] values Array input to replace +/// \param[in] mask Array or Scalar of Boolean mask values +/// \param[in] replacements The replacement values to draw from. There must +/// be as many replacement values as true values in the mask. +/// \param[in] ctx the function execution context, optional +/// +/// \return the resulting datum +/// +/// \since 5.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result ReplaceWithMask(const Datum& values, const Datum& mask, + const Datum& replacements, ExecContext* ctx = NULLPTR); + /// \brief Take from an array of values at indices in another array /// /// The output array will be of the same type as the input values diff --git a/cpp/src/arrow/compute/kernels/CMakeLists.txt b/cpp/src/arrow/compute/kernels/CMakeLists.txt index 3362d91cbe8..f8cb67d82c4 100644 --- a/cpp/src/arrow/compute/kernels/CMakeLists.txt +++ b/cpp/src/arrow/compute/kernels/CMakeLists.txt @@ -48,6 +48,7 @@ add_arrow_compute_test(vector_test SOURCES vector_hash_test.cc vector_nested_test.cc + vector_replace_test.cc vector_selection_test.cc vector_sort_test.cc test_util.cc) diff --git a/cpp/src/arrow/compute/kernels/codegen_internal.h b/cpp/src/arrow/compute/kernels/codegen_internal.h index 33b7006491a..12e80423f7f 100644 --- a/cpp/src/arrow/compute/kernels/codegen_internal.h +++ b/cpp/src/arrow/compute/kernels/codegen_internal.h @@ -1240,6 +1240,7 @@ ArrayKernelExec GenerateTypeAgnosticPrimitive(detail::GetTypeId get_id) { case Type::FLOAT: case Type::DATE32: case Type::TIME32: + case Type::INTERVAL_MONTHS: return Generator::Exec; case Type::UINT64: case Type::INT64: @@ -1248,6 +1249,7 @@ ArrayKernelExec GenerateTypeAgnosticPrimitive(detail::GetTypeId get_id) { case Type::TIMESTAMP: case Type::TIME64: case Type::DURATION: + case Type::INTERVAL_DAY_TIME: return Generator::Exec; default: DCHECK(false); diff --git a/cpp/src/arrow/compute/kernels/vector_replace.cc b/cpp/src/arrow/compute/kernels/vector_replace.cc new file mode 100644 index 00000000000..4b768608c62 --- /dev/null +++ b/cpp/src/arrow/compute/kernels/vector_replace.cc @@ -0,0 +1,495 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/compute/api_scalar.h" +#include "arrow/compute/kernels/common.h" +#include "arrow/util/bitmap_ops.h" + +namespace arrow { +namespace compute { +namespace internal { + +namespace { + +Status ReplacementArrayTooShort(int64_t expected, int64_t actual) { + return Status::Invalid("Replacement array must be of appropriate length (expected ", + expected, " items but got ", actual, " items)"); +} + +// Helper to implement replace_with kernel with scalar mask for fixed-width types, +// using callbacks to handle both bool and byte-sized types +Status ReplaceWithScalarMask(KernelContext* ctx, const ArrayData& array, + const BooleanScalar& mask, const Datum& replacements, + ArrayData* output) { + if (!mask.is_valid) { + // Output = null + ARROW_ASSIGN_OR_RAISE(auto array, + MakeArrayOfNull(array.type, array.length, ctx->memory_pool())); + *output = *array->data(); + return Status::OK(); + } + if (mask.value) { + // Output = replacement + if (replacements.is_scalar()) { + ARROW_ASSIGN_OR_RAISE( + auto replacement_array, + MakeArrayFromScalar(*replacements.scalar(), array.length, ctx->memory_pool())); + *output = *replacement_array->data(); + } else { + auto replacement_array = replacements.array(); + if (replacement_array->length != array.length) { + return ReplacementArrayTooShort(array.length, replacement_array->length); + } + *output = *replacement_array; + } + } else { + // Output = input + *output = array; + } + return Status::OK(); +} + +// Helper to implement replace_with kernel with array mask for fixed-width types, +// using callbacks to handle both bool and byte-sized types and to handle +// scalar and array replacements +template +Status ReplaceWithArrayMask(KernelContext* ctx, const ArrayData& array, + const ArrayData& mask, const Datum& replacements, + ArrayData* output) { + ARROW_ASSIGN_OR_RAISE(output->buffers[1], + Functor::AllocateData(ctx, *array.type, array.length)); + + uint8_t* out_bitmap = nullptr; + uint8_t* out_values = output->buffers[1]->mutable_data(); + const uint8_t* mask_bitmap = mask.MayHaveNulls() ? mask.buffers[0]->data() : nullptr; + const uint8_t* mask_values = mask.buffers[1]->data(); + bool replacements_bitmap; + int64_t replacements_length; + if (replacements.is_array()) { + replacements_bitmap = replacements.array()->MayHaveNulls(); + replacements_length = replacements.array()->length; + } else { + replacements_bitmap = !replacements.scalar()->is_valid; + replacements_length = std::numeric_limits::max(); + } + if (array.MayHaveNulls() || mask.MayHaveNulls() || replacements_bitmap) { + ARROW_ASSIGN_OR_RAISE(output->buffers[0], ctx->AllocateBitmap(array.length)); + out_bitmap = output->buffers[0]->mutable_data(); + output->null_count = -1; + if (array.MayHaveNulls()) { + arrow::internal::CopyBitmap(array.buffers[0]->data(), array.offset, array.length, + out_bitmap, /*dest_offset=*/0); + } else { + std::memset(out_bitmap, 0xFF, output->buffers[0]->size()); + } + } else { + output->null_count = 0; + } + auto copy_bitmap = [&](int64_t out_offset, int64_t in_offset, int64_t length) { + DCHECK(out_bitmap); + if (replacements.is_array()) { + const auto& in_data = *replacements.array(); + const auto in_bitmap = in_data.GetValues(0, /*absolute_offset=*/0); + arrow::internal::CopyBitmap(in_bitmap, in_data.offset + in_offset, length, + out_bitmap, out_offset); + } else { + BitUtil::SetBitsTo(out_bitmap, out_offset, length, !replacements_bitmap); + } + }; + + Functor::CopyData(*array.type, out_values, /*out_offset=*/0, array, /*in_offset=*/0, + array.length); + arrow::internal::BitBlockCounter value_counter(mask_values, mask.offset, mask.length); + arrow::internal::OptionalBitBlockCounter valid_counter(mask_bitmap, mask.offset, + mask.length); + int64_t out_offset = 0; + int64_t replacements_offset = 0; + while (out_offset < array.length) { + BitBlockCount value_block = value_counter.NextWord(); + BitBlockCount valid_block = valid_counter.NextWord(); + DCHECK_EQ(value_block.length, valid_block.length); + if (value_block.AllSet() && valid_block.AllSet()) { + // Copy from replacement array + if (replacements_offset + valid_block.length > replacements_length) { + return ReplacementArrayTooShort(replacements_offset + valid_block.length, + replacements_length); + } + Functor::CopyData(*array.type, out_values, out_offset, replacements, + replacements_offset, valid_block.length); + if (replacements_bitmap) { + copy_bitmap(out_offset, replacements_offset, valid_block.length); + } else if (!replacements_bitmap && out_bitmap) { + BitUtil::SetBitsTo(out_bitmap, out_offset, valid_block.length, true); + } + replacements_offset += valid_block.length; + } else if (value_block.NoneSet() && valid_block.AllSet()) { + // Do nothing + } else if (valid_block.NoneSet()) { + DCHECK(out_bitmap); + BitUtil::SetBitsTo(out_bitmap, out_offset, valid_block.length, false); + } else { + for (int64_t i = 0; i < valid_block.length; ++i) { + if (BitUtil::GetBit(mask_values, out_offset + mask.offset + i) && + (!mask_bitmap || + BitUtil::GetBit(mask_bitmap, out_offset + mask.offset + i))) { + if (replacements_offset >= replacements_length) { + return ReplacementArrayTooShort(replacements_offset + 1, replacements_length); + } + Functor::CopyData(*array.type, out_values, out_offset + i, replacements, + replacements_offset, + /*length=*/1); + if (replacements_bitmap) { + copy_bitmap(out_offset + i, replacements_offset, 1); + } + replacements_offset++; + } + } + } + out_offset += valid_block.length; + } + + if (mask.MayHaveNulls()) { + arrow::internal::BitmapAnd(out_bitmap, /*left_offset=*/0, mask.buffers[0]->data(), + mask.offset, array.length, + /*out_offset=*/0, out_bitmap); + } + return Status::OK(); +} + +template +struct ReplaceWithMask {}; + +template +struct ReplaceWithMask> { + using T = typename TypeTraits::CType; + + static Result> AllocateData(KernelContext* ctx, const DataType&, + const int64_t length) { + return ctx->Allocate(length * sizeof(T)); + } + + static void CopyData(const DataType&, uint8_t* out, const int64_t out_offset, + const Datum& in, const int64_t in_offset, const int64_t length) { + if (in.is_array()) { + const auto& in_data = *in.array(); + const auto in_arr = + in_data.GetValues(1, (in_offset + in_data.offset) * sizeof(T)); + std::memcpy(out + (out_offset * sizeof(T)), in_arr, length * sizeof(T)); + } else { + T* begin = reinterpret_cast(out + (out_offset * sizeof(T))); + T* end = begin + length; + std::fill(begin, end, UnboxScalar::Unbox(*in.scalar())); + } + } + + static Status ExecScalarMask(KernelContext* ctx, const ArrayData& array, + const BooleanScalar& mask, const Datum& replacements, + ArrayData* output) { + return ReplaceWithScalarMask(ctx, array, mask, replacements, output); + } + + static Status ExecArrayMask(KernelContext* ctx, const ArrayData& array, + const ArrayData& mask, const Datum& replacements, + ArrayData* output) { + return ReplaceWithArrayMask>(ctx, array, mask, replacements, + output); + } +}; + +template +struct ReplaceWithMask> { + static Result> AllocateData(KernelContext* ctx, const DataType&, + const int64_t length) { + return ctx->AllocateBitmap(length); + } + + static void CopyData(const DataType&, uint8_t* out, const int64_t out_offset, + const Datum& in, const int64_t in_offset, const int64_t length) { + if (in.is_array()) { + const auto& in_data = *in.array(); + const auto in_arr = in_data.GetValues(1, /*absolute_offset=*/0); + arrow::internal::CopyBitmap(in_arr, in_offset + in_data.offset, length, out, + out_offset); + } else { + BitUtil::SetBitsTo(out, out_offset, length, in.scalar()->is_valid); + } + } + + static Status ExecScalarMask(KernelContext* ctx, const ArrayData& array, + const BooleanScalar& mask, const Datum& replacements, + ArrayData* output) { + return ReplaceWithScalarMask(ctx, array, mask, replacements, output); + } + static Status ExecArrayMask(KernelContext* ctx, const ArrayData& array, + const ArrayData& mask, const Datum& replacements, + ArrayData* output) { + return ReplaceWithArrayMask>(ctx, array, mask, replacements, + output); + } +}; + +template +struct ReplaceWithMask> { + static Result> AllocateData(KernelContext* ctx, + const DataType& ty, + const int64_t length) { + return ctx->Allocate(length * + checked_cast(ty).byte_width()); + } + + static void CopyData(const DataType& ty, uint8_t* out, const int64_t out_offset, + const Datum& in, const int64_t in_offset, const int64_t length) { + const int32_t width = checked_cast(ty).byte_width(); + uint8_t* begin = out + (out_offset * width); + if (in.is_array()) { + const auto& in_data = *in.array(); + const auto in_arr = + in_data.GetValues(1, (in_offset + in_data.offset) * width); + std::memcpy(begin, in_arr, length * width); + } else { + const FixedSizeBinaryScalar& scalar = + checked_cast(*in.scalar()); + // Null scalar may have null value buffer + if (!scalar.value) return; + const Buffer& buffer = *scalar.value; + const uint8_t* value = buffer.data(); + DCHECK_GE(buffer.size(), width); + for (int i = 0; i < length; i++) { + std::memcpy(begin, value, width); + begin += width; + } + } + } + + static Status ExecScalarMask(KernelContext* ctx, const ArrayData& array, + const BooleanScalar& mask, const Datum& replacements, + ArrayData* output) { + return ReplaceWithScalarMask(ctx, array, mask, replacements, output); + } + + static Status ExecArrayMask(KernelContext* ctx, const ArrayData& array, + const ArrayData& mask, const Datum& replacements, + ArrayData* output) { + return ReplaceWithArrayMask>(ctx, array, mask, replacements, + output); + } +}; + +template +struct ReplaceWithMask> { + using ScalarType = typename TypeTraits::ScalarType; + + static Result> AllocateData(KernelContext* ctx, + const DataType& ty, + const int64_t length) { + return ctx->Allocate(length * + checked_cast(ty).byte_width()); + } + + static void CopyData(const DataType& ty, uint8_t* out, const int64_t out_offset, + const Datum& in, const int64_t in_offset, const int64_t length) { + const int32_t width = checked_cast(ty).byte_width(); + uint8_t* begin = out + (out_offset * width); + if (in.is_array()) { + const auto& in_data = *in.array(); + const auto in_arr = + in_data.GetValues(1, (in_offset + in_data.offset) * width); + std::memcpy(begin, in_arr, length * width); + } else { + const ScalarType& scalar = checked_cast(*in.scalar()); + const auto value = scalar.value.ToBytes(); + for (int i = 0; i < length; i++) { + std::memcpy(begin, value.data(), width); + begin += width; + } + } + } + + static Status ExecScalarMask(KernelContext* ctx, const ArrayData& array, + const BooleanScalar& mask, const Datum& replacements, + ArrayData* output) { + return ReplaceWithScalarMask(ctx, array, mask, replacements, output); + } + + static Status ExecArrayMask(KernelContext* ctx, const ArrayData& array, + const ArrayData& mask, const Datum& replacements, + ArrayData* output) { + return ReplaceWithArrayMask>(ctx, array, mask, replacements, + output); + } +}; + +template +struct ReplaceWithMask> { + static Status ExecScalarMask(KernelContext* ctx, const ArrayData& array, + const BooleanScalar& mask, const Datum& replacements, + ArrayData* output) { + *output = array; + return Status::OK(); + } + static Status ExecArrayMask(KernelContext* ctx, const ArrayData& array, + const ArrayData& mask, const Datum& replacements, + ArrayData* output) { + *output = array; + return Status::OK(); + } +}; + +template +struct ReplaceWithMask> { + using offset_type = typename Type::offset_type; + using BuilderType = typename TypeTraits::BuilderType; + + static Status ExecScalarMask(KernelContext* ctx, const ArrayData& array, + const BooleanScalar& mask, const Datum& replacements, + ArrayData* output) { + return ReplaceWithScalarMask(ctx, array, mask, replacements, output); + } + static Status ExecArrayMask(KernelContext* ctx, const ArrayData& array, + const ArrayData& mask, const Datum& replacements, + ArrayData* output) { + BuilderType builder(array.type, ctx->memory_pool()); + RETURN_NOT_OK(builder.Reserve(array.length)); + RETURN_NOT_OK(builder.ReserveData(array.buffers[2]->size())); + int64_t source_offset = 0; + int64_t replacements_offset = 0; + RETURN_NOT_OK(VisitArrayDataInline( + mask, + [&](bool replace) { + if (replace && replacements.is_scalar()) { + const Scalar& scalar = *replacements.scalar(); + if (scalar.is_valid) { + RETURN_NOT_OK(builder.Append(UnboxScalar::Unbox(scalar))); + } else { + RETURN_NOT_OK(builder.AppendNull()); + } + } else { + const ArrayData& source = replace ? *replacements.array() : array; + const int64_t offset = replace ? replacements_offset++ : source_offset; + if (!source.MayHaveNulls() || + BitUtil::GetBit(source.buffers[0]->data(), source.offset + offset)) { + const uint8_t* data = source.buffers[2]->data(); + const offset_type* offsets = source.GetValues(1); + const offset_type offset0 = offsets[offset]; + const offset_type offset1 = offsets[offset + 1]; + RETURN_NOT_OK(builder.Append(data + offset0, offset1 - offset0)); + } else { + RETURN_NOT_OK(builder.AppendNull()); + } + } + source_offset++; + return Status::OK(); + }, + [&]() { + RETURN_NOT_OK(builder.AppendNull()); + source_offset++; + return Status::OK(); + })); + std::shared_ptr temp_output; + RETURN_NOT_OK(builder.Finish(&temp_output)); + *output = *temp_output->data(); + // Builder type != logical type due to GenerateTypeAgnosticVarBinaryBase + output->type = array.type; + return Status::OK(); + } +}; + +template +struct ReplaceWithMaskFunctor { + static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + const ArrayData& array = *batch[0].array(); + const Datum& replacements = batch[2]; + ArrayData* output = out->array().get(); + output->length = array.length; + + // Needed for FixedSizeBinary/parameterized types + if (!array.type->Equals(*replacements.type(), /*check_metadata=*/false)) { + return Status::Invalid("Replacements must be of same type (expected ", + array.type->ToString(), " but got ", + replacements.type()->ToString(), ")"); + } + + if (!replacements.is_array() && !replacements.is_scalar()) { + return Status::Invalid("Replacements must be array or scalar"); + } + + if (batch[1].is_scalar()) { + return ReplaceWithMask::ExecScalarMask( + ctx, array, batch[1].scalar_as(), replacements, output); + } + const ArrayData& mask = *batch[1].array(); + if (array.length != mask.length) { + return Status::Invalid("Mask must be of same length as array (expected ", + array.length, " items but got ", mask.length, " items)"); + } + return ReplaceWithMask::ExecArrayMask(ctx, array, mask, replacements, output); + } +}; + +} // namespace + +const FunctionDoc replace_with_mask_doc( + "Replace items using a mask and replacement values", + ("Given an array and a Boolean mask (either scalar or of equal length), " + "along with replacement values (either scalar or array), " + "each element of the array for which the corresponding mask element is " + "true will be replaced by the next value from the replacements, " + "or with null if the mask is null. " + "Hence, for replacement arrays, len(replacements) == sum(mask == true)."), + {"values", "mask", "replacements"}); + +void RegisterVectorReplace(FunctionRegistry* registry) { + auto func = std::make_shared("replace_with_mask", Arity::Ternary(), + &replace_with_mask_doc); + auto add_kernel = [&](detail::GetTypeId get_id, ArrayKernelExec exec) { + VectorKernel kernel; + kernel.can_execute_chunkwise = false; + kernel.null_handling = NullHandling::type::COMPUTED_NO_PREALLOCATE; + kernel.mem_allocation = MemAllocation::type::NO_PREALLOCATE; + kernel.signature = KernelSignature::Make( + {InputType::Array(get_id.id), InputType(boolean()), InputType(get_id.id)}, + OutputType(FirstType)); + kernel.exec = std::move(exec); + DCHECK_OK(func->AddKernel(std::move(kernel))); + }; + auto add_primitive_kernel = [&](detail::GetTypeId get_id) { + add_kernel(get_id, GenerateTypeAgnosticPrimitive(get_id)); + }; + for (const auto& ty : NumericTypes()) { + add_primitive_kernel(ty); + } + for (const auto& ty : TemporalTypes()) { + add_primitive_kernel(ty); + } + add_primitive_kernel(null()); + add_primitive_kernel(boolean()); + add_primitive_kernel(day_time_interval()); + add_primitive_kernel(month_interval()); + add_kernel(Type::FIXED_SIZE_BINARY, ReplaceWithMaskFunctor::Exec); + add_kernel(Type::DECIMAL128, ReplaceWithMaskFunctor::Exec); + add_kernel(Type::DECIMAL256, ReplaceWithMaskFunctor::Exec); + for (const auto& ty : BaseBinaryTypes()) { + add_kernel(ty->id(), GenerateTypeAgnosticVarBinaryBase(*ty)); + } + // TODO: list types + DCHECK_OK(registry->AddFunction(std::move(func))); + + // TODO(ARROW-9431): "replace_with_indices" +} +} // namespace internal +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/vector_replace_test.cc b/cpp/src/arrow/compute/kernels/vector_replace_test.cc new file mode 100644 index 00000000000..1826b037034 --- /dev/null +++ b/cpp/src/arrow/compute/kernels/vector_replace_test.cc @@ -0,0 +1,736 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include +#include + +#include "arrow/compute/api_vector.h" +#include "arrow/compute/kernels/test_util.h" +#include "arrow/testing/gtest_common.h" +#include "arrow/testing/gtest_util.h" +#include "arrow/util/checked_cast.h" +#include "arrow/util/key_value_metadata.h" +#include "arrow/util/make_unique.h" + +namespace arrow { +namespace compute { + +using arrow::internal::checked_pointer_cast; + +namespace { +template +enable_if_parameter_free> default_type_instance() { + return TypeTraits::type_singleton(); +} +template +enable_if_time> default_type_instance() { + // Time32 requires second/milli, Time64 requires nano/micro + if (TypeTraits::bytes_required(1) == 4) { + return std::make_shared(TimeUnit::type::SECOND); + } else { + return std::make_shared(TimeUnit::type::NANO); + } +} +template +enable_if_timestamp> default_type_instance() { + return std::make_shared(TimeUnit::type::SECOND); +} +template +enable_if_decimal> default_type_instance() { + return std::make_shared(5, 2); +} +template +enable_if_parameter_free::BuilderType>> +builder_instance() { + return arrow::internal::make_unique::BuilderType>(); +} +template +enable_if_time::BuilderType>> +builder_instance() { + return arrow::internal::make_unique::BuilderType>( + default_type_instance(), default_memory_pool()); +} +template +enable_if_timestamp::BuilderType>> +builder_instance() { + return arrow::internal::make_unique::BuilderType>( + default_type_instance(), default_memory_pool()); +} +template +enable_if_t::value, T> max_int_value() { + return static_cast( + std::min(16384.0, static_cast(std::numeric_limits::max()))); +} +template +enable_if_t::value, T> max_int_value() { + return static_cast( + std::min(16384.0, static_cast(std::numeric_limits::max()))); +} +} // namespace + +template +class TestReplaceKernel : public ::testing::Test { + protected: + virtual std::shared_ptr type() = 0; + + using ReplaceFunction = std::function(const Datum&, const Datum&, + const Datum&, ExecContext*)>; + + void SetUp() override { equal_options_ = equal_options_.nans_equal(true); } + + Datum mask_scalar(bool value) { return Datum(std::make_shared(value)); } + + Datum null_mask_scalar() { + auto scalar = std::make_shared(true); + scalar->is_valid = false; + return Datum(std::move(scalar)); + } + + Datum scalar(const std::string& json) { return ScalarFromJSON(type(), json); } + + std::shared_ptr array(const std::string& value) { + return ArrayFromJSON(type(), value); + } + + std::shared_ptr mask(const std::string& value) { + return ArrayFromJSON(boolean(), value); + } + + Status AssertRaises(ReplaceFunction func, const std::shared_ptr& array, + const Datum& mask, const std::shared_ptr& replacements) { + auto result = func(array, mask, replacements, nullptr); + EXPECT_FALSE(result.ok()); + return result.status(); + } + + void Assert(ReplaceFunction func, const std::shared_ptr& array, + const Datum& mask, Datum replacements, + const std::shared_ptr& expected) { + SCOPED_TRACE("Replacements: " + (replacements.is_array() + ? replacements.make_array()->ToString() + : replacements.scalar()->ToString())); + SCOPED_TRACE("Mask: " + (mask.is_array() ? mask.make_array()->ToString() + : mask.scalar()->ToString())); + SCOPED_TRACE("Array: " + array->ToString()); + + ASSERT_OK_AND_ASSIGN(auto actual, func(array, mask, replacements, nullptr)); + ASSERT_TRUE(actual.is_array()); + ASSERT_OK(actual.make_array()->ValidateFull()); + + AssertArraysApproxEqual(*expected, *actual.make_array(), /*verbose=*/true, + equal_options_); + } + + std::shared_ptr NaiveImpl( + const typename TypeTraits::ArrayType& array, const BooleanArray& mask, + const typename TypeTraits::ArrayType& replacements) { + auto length = array.length(); + auto builder = builder_instance(); + int64_t replacement_offset = 0; + for (int64_t i = 0; i < length; ++i) { + if (mask.IsValid(i)) { + if (mask.Value(i)) { + if (replacements.IsValid(replacement_offset)) { + ARROW_EXPECT_OK(builder->Append(replacements.Value(replacement_offset++))); + } else { + ARROW_EXPECT_OK(builder->AppendNull()); + replacement_offset++; + } + } else { + if (array.IsValid(i)) { + ARROW_EXPECT_OK(builder->Append(array.Value(i))); + } else { + ARROW_EXPECT_OK(builder->AppendNull()); + } + } + } else { + ARROW_EXPECT_OK(builder->AppendNull()); + } + } + EXPECT_OK_AND_ASSIGN(auto expected, builder->Finish()); + return expected; + } + + EqualOptions equal_options_ = EqualOptions::Defaults(); +}; + +template +class TestReplaceNumeric : public TestReplaceKernel { + protected: + std::shared_ptr type() override { return default_type_instance(); } +}; + +class TestReplaceBoolean : public TestReplaceKernel { + protected: + std::shared_ptr type() override { + return TypeTraits::type_singleton(); + } +}; + +class TestReplaceFixedSizeBinary : public TestReplaceKernel { + protected: + std::shared_ptr type() override { return fixed_size_binary(3); } +}; + +template +class TestReplaceDecimal : public TestReplaceKernel { + protected: + std::shared_ptr type() override { return default_type_instance(); } +}; + +class TestReplaceDayTimeInterval : public TestReplaceKernel { + protected: + std::shared_ptr type() override { + return TypeTraits::type_singleton(); + } +}; + +template +class TestReplaceBinary : public TestReplaceKernel { + protected: + std::shared_ptr type() override { return default_type_instance(); } +}; + +using NumericBasedTypes = + ::testing::Types; + +TYPED_TEST_SUITE(TestReplaceNumeric, NumericBasedTypes); +TYPED_TEST_SUITE(TestReplaceDecimal, DecimalArrowTypes); +TYPED_TEST_SUITE(TestReplaceBinary, BinaryTypes); + +TYPED_TEST(TestReplaceNumeric, ReplaceWithMask) { + this->Assert(ReplaceWithMask, this->array("[]"), this->mask_scalar(false), + this->array("[]"), this->array("[]")); + this->Assert(ReplaceWithMask, this->array("[]"), this->mask_scalar(true), + this->array("[]"), this->array("[]")); + this->Assert(ReplaceWithMask, this->array("[]"), this->null_mask_scalar(), + this->array("[]"), this->array("[]")); + + this->Assert(ReplaceWithMask, this->array("[1]"), this->mask_scalar(false), + this->array("[]"), this->array("[1]")); + this->Assert(ReplaceWithMask, this->array("[1]"), this->mask_scalar(true), + this->array("[0]"), this->array("[0]")); + this->Assert(ReplaceWithMask, this->array("[1]"), this->null_mask_scalar(), + this->array("[]"), this->array("[null]")); + + this->Assert(ReplaceWithMask, this->array("[0, 0]"), this->mask_scalar(false), + this->scalar("1"), this->array("[0, 0]")); + this->Assert(ReplaceWithMask, this->array("[0, 0]"), this->mask_scalar(true), + this->scalar("1"), this->array("[1, 1]")); + this->Assert(ReplaceWithMask, this->array("[0, 0]"), this->mask_scalar(true), + this->scalar("null"), this->array("[null, null]")); + + this->Assert(ReplaceWithMask, this->array("[]"), this->mask("[]"), this->array("[]"), + this->array("[]")); + this->Assert(ReplaceWithMask, this->array("[0, 1, 2, 3]"), + this->mask("[false, false, false, false]"), this->array("[]"), + this->array("[0, 1, 2, 3]")); + this->Assert(ReplaceWithMask, this->array("[0, 1, 2, 3]"), + this->mask("[true, true, true, true]"), this->array("[10, 11, 12, 13]"), + this->array("[10, 11, 12, 13]")); + this->Assert(ReplaceWithMask, this->array("[0, 1, 2, 3]"), + this->mask("[null, null, null, null]"), this->array("[]"), + this->array("[null, null, null, null]")); + this->Assert(ReplaceWithMask, this->array("[0, 1, 2, null]"), + this->mask("[false, false, false, false]"), this->array("[]"), + this->array("[0, 1, 2, null]")); + this->Assert(ReplaceWithMask, this->array("[0, 1, 2, null]"), + this->mask("[true, true, true, true]"), this->array("[10, 11, 12, 13]"), + this->array("[10, 11, 12, 13]")); + this->Assert(ReplaceWithMask, this->array("[0, 1, 2, null]"), + this->mask("[null, null, null, null]"), this->array("[]"), + this->array("[null, null, null, null]")); + this->Assert(ReplaceWithMask, this->array("[0, 1, 2, 3, 4, 5]"), + this->mask("[true, true, false, false, null, null]"), + this->array("[10, null]"), this->array("[10, null, 2, 3, null, null]")); + this->Assert(ReplaceWithMask, this->array("[null, null, null, null, null, null]"), + this->mask("[true, true, false, false, null, null]"), + this->array("[10, null]"), + this->array("[10, null, null, null, null, null]")); + + this->Assert(ReplaceWithMask, this->array("[]"), this->mask("[]"), this->scalar("1"), + this->array("[]")); + this->Assert(ReplaceWithMask, this->array("[0, 1]"), this->mask("[true, true]"), + this->scalar("10"), this->array("[10, 10]")); + this->Assert(ReplaceWithMask, this->array("[0, 1]"), this->mask("[true, true]"), + this->scalar("null"), this->array("[null, null]")); + this->Assert(ReplaceWithMask, this->array("[0, 1, 2]"), + this->mask("[true, false, null]"), this->scalar("10"), + this->array("[10, 1, null]")); +} + +TYPED_TEST(TestReplaceNumeric, ReplaceWithMaskRandom) { + using ArrayType = typename TypeTraits::ArrayType; + using CType = typename TypeTraits::CType; + auto ty = this->type(); + + random::RandomArrayGenerator rand(/*seed=*/0); + const int64_t length = 1023; + // Clamp the range because date/time types don't print well with extreme values + std::vector values = {"0.01", "0"}; + values.push_back(std::to_string(max_int_value())); + auto options = key_value_metadata({"null_probability", "min", "max"}, values); + auto array = + checked_pointer_cast(rand.ArrayOf(*field("a", ty, options), length)); + auto mask = checked_pointer_cast( + rand.ArrayOf(boolean(), length, /*null_probability=*/0.01)); + const int64_t num_replacements = std::count_if( + mask->begin(), mask->end(), + [](util::optional value) { return value.has_value() && *value; }); + auto replacements = checked_pointer_cast( + rand.ArrayOf(*field("a", ty, options), num_replacements)); + auto expected = this->NaiveImpl(*array, *mask, *replacements); + + this->Assert(ReplaceWithMask, array, mask, replacements, expected); + for (int64_t slice = 1; slice <= 16; slice++) { + auto sliced_array = checked_pointer_cast(array->Slice(slice, 15)); + auto sliced_mask = checked_pointer_cast(mask->Slice(slice, 15)); + auto new_expected = this->NaiveImpl(*sliced_array, *sliced_mask, *replacements); + this->Assert(ReplaceWithMask, sliced_array, sliced_mask, replacements, new_expected); + } +} + +TYPED_TEST(TestReplaceNumeric, ReplaceWithMaskErrors) { + EXPECT_RAISES_WITH_MESSAGE_THAT( + Invalid, + ::testing::HasSubstr("Replacement array must be of appropriate length (expected 1 " + "items but got 2 items)"), + this->AssertRaises(ReplaceWithMask, this->array("[1]"), this->mask_scalar(true), + this->array("[0, 1]"))); + EXPECT_RAISES_WITH_MESSAGE_THAT( + Invalid, + ::testing::HasSubstr("Replacement array must be of appropriate length (expected 2 " + "items but got 1 items)"), + this->AssertRaises(ReplaceWithMask, this->array("[1, 2]"), + this->mask("[true, true]"), this->array("[0]"))); + EXPECT_RAISES_WITH_MESSAGE_THAT( + Invalid, + ::testing::HasSubstr("Replacement array must be of appropriate length (expected 1 " + "items but got 0 items)"), + this->AssertRaises(ReplaceWithMask, this->array("[1, 2]"), + this->mask("[true, null]"), this->array("[]"))); + EXPECT_RAISES_WITH_MESSAGE_THAT( + Invalid, + ::testing::HasSubstr("Mask must be of same length as array (expected 2 " + "items but got 0 items)"), + this->AssertRaises(ReplaceWithMask, this->array("[1, 2]"), this->mask("[]"), + this->array("[]"))); +} + +TEST_F(TestReplaceBoolean, ReplaceWithMask) { + this->Assert(ReplaceWithMask, this->array("[]"), this->mask_scalar(false), + this->array("[]"), this->array("[]")); + this->Assert(ReplaceWithMask, this->array("[]"), this->mask_scalar(true), + this->array("[]"), this->array("[]")); + this->Assert(ReplaceWithMask, this->array("[]"), this->null_mask_scalar(), + this->array("[]"), this->array("[]")); + + this->Assert(ReplaceWithMask, this->array("[true]"), this->mask_scalar(false), + this->array("[]"), this->array("[true]")); + this->Assert(ReplaceWithMask, this->array("[true]"), this->mask_scalar(true), + this->array("[false]"), this->array("[false]")); + this->Assert(ReplaceWithMask, this->array("[true]"), this->null_mask_scalar(), + this->array("[]"), this->array("[null]")); + + this->Assert(ReplaceWithMask, this->array("[false, false]"), this->mask_scalar(false), + this->scalar("true"), this->array("[false, false]")); + this->Assert(ReplaceWithMask, this->array("[false, false]"), this->mask_scalar(true), + this->scalar("true"), this->array("[true, true]")); + this->Assert(ReplaceWithMask, this->array("[false, false]"), this->mask_scalar(true), + this->scalar("null"), this->array("[null, null]")); + + this->Assert(ReplaceWithMask, this->array("[]"), this->mask("[]"), this->array("[]"), + this->array("[]")); + this->Assert(ReplaceWithMask, this->array("[true, true, true, true]"), + this->mask("[false, false, false, false]"), this->array("[]"), + this->array("[true, true, true, true]")); + this->Assert(ReplaceWithMask, this->array("[true, true, true, true]"), + this->mask("[true, true, true, true]"), + this->array("[false, false, false, false]"), + this->array("[false, false, false, false]")); + this->Assert(ReplaceWithMask, this->array("[true, true, true, true]"), + this->mask("[null, null, null, null]"), this->array("[]"), + this->array("[null, null, null, null]")); + this->Assert(ReplaceWithMask, this->array("[true, true, true, null]"), + this->mask("[false, false, false, false]"), this->array("[]"), + this->array("[true, true, true, null]")); + this->Assert(ReplaceWithMask, this->array("[true, true, true, null]"), + this->mask("[true, true, true, true]"), + this->array("[false, false, false, false]"), + this->array("[false, false, false, false]")); + this->Assert(ReplaceWithMask, this->array("[true, true, true, null]"), + this->mask("[null, null, null, null]"), this->array("[]"), + this->array("[null, null, null, null]")); + this->Assert(ReplaceWithMask, this->array("[true, true, true, true, true, true]"), + this->mask("[true, true, false, false, null, null]"), + this->array("[false, null]"), + this->array("[false, null, true, true, null, null]")); + this->Assert(ReplaceWithMask, this->array("[null, null, null, null, null, null]"), + this->mask("[true, true, false, false, null, null]"), + this->array("[false, null]"), + this->array("[false, null, null, null, null, null]")); + + this->Assert(ReplaceWithMask, this->array("[]"), this->mask("[]"), this->scalar("true"), + this->array("[]")); + this->Assert(ReplaceWithMask, this->array("[false, false]"), this->mask("[true, true]"), + this->scalar("true"), this->array("[true, true]")); + this->Assert(ReplaceWithMask, this->array("[false, false]"), this->mask("[true, true]"), + this->scalar("null"), this->array("[null, null]")); + this->Assert(ReplaceWithMask, this->array("[false, false, false]"), + this->mask("[true, false, null]"), this->scalar("true"), + this->array("[true, false, null]")); +} + +TEST_F(TestReplaceBoolean, ReplaceWithMaskErrors) { + EXPECT_RAISES_WITH_MESSAGE_THAT( + Invalid, + ::testing::HasSubstr("Replacement array must be of appropriate length (expected 1 " + "items but got 2 items)"), + this->AssertRaises(ReplaceWithMask, this->array("[true]"), this->mask_scalar(true), + this->array("[false, false]"))); + EXPECT_RAISES_WITH_MESSAGE_THAT( + Invalid, + ::testing::HasSubstr("Replacement array must be of appropriate length (expected 2 " + "items but got 1 items)"), + this->AssertRaises(ReplaceWithMask, this->array("[true, true]"), + this->mask("[true, true]"), this->array("[false]"))); + EXPECT_RAISES_WITH_MESSAGE_THAT( + Invalid, + ::testing::HasSubstr("Replacement array must be of appropriate length (expected 1 " + "items but got 0 items)"), + this->AssertRaises(ReplaceWithMask, this->array("[true, true]"), + this->mask("[true, null]"), this->array("[]"))); + EXPECT_RAISES_WITH_MESSAGE_THAT( + Invalid, + ::testing::HasSubstr("Mask must be of same length as array (expected 2 " + "items but got 0 items)"), + this->AssertRaises(ReplaceWithMask, this->array("[true, true]"), this->mask("[]"), + this->array("[]"))); +} + +TEST_F(TestReplaceFixedSizeBinary, ReplaceWithMask) { + this->Assert(ReplaceWithMask, this->array("[]"), this->mask_scalar(false), + this->array("[]"), this->array("[]")); + this->Assert(ReplaceWithMask, this->array("[]"), this->mask_scalar(true), + this->array("[]"), this->array("[]")); + this->Assert(ReplaceWithMask, this->array("[]"), this->null_mask_scalar(), + this->array("[]"), this->array("[]")); + + this->Assert(ReplaceWithMask, this->array(R"(["foo"])"), this->mask_scalar(false), + this->array("[]"), this->array(R"(["foo"])")); + this->Assert(ReplaceWithMask, this->array(R"(["foo"])"), this->mask_scalar(true), + this->array(R"(["bar"])"), this->array(R"(["bar"])")); + this->Assert(ReplaceWithMask, this->array(R"(["foo"])"), this->null_mask_scalar(), + this->array("[]"), this->array("[null]")); + + this->Assert(ReplaceWithMask, this->array(R"(["foo", "bar"])"), + this->mask_scalar(false), this->scalar(R"("baz")"), + this->array(R"(["foo", "bar"])")); + this->Assert(ReplaceWithMask, this->array(R"(["foo", "bar"])"), this->mask_scalar(true), + this->scalar(R"("baz")"), this->array(R"(["baz", "baz"])")); + this->Assert(ReplaceWithMask, this->array(R"(["foo", "bar"])"), this->mask_scalar(true), + this->scalar("null"), this->array(R"([null, null])")); + + this->Assert(ReplaceWithMask, this->array("[]"), this->mask("[]"), this->array("[]"), + this->array("[]")); + this->Assert(ReplaceWithMask, this->array(R"(["aaa", "bbb", "ccc", "ddd"])"), + this->mask("[false, false, false, false]"), this->array("[]"), + this->array(R"(["aaa", "bbb", "ccc", "ddd"])")); + this->Assert(ReplaceWithMask, this->array(R"(["aaa", "bbb", "ccc", "ddd"])"), + this->mask("[true, true, true, true]"), + this->array(R"(["eee", "fff", "ggg", "hhh"])"), + this->array(R"(["eee", "fff", "ggg", "hhh"])")); + this->Assert(ReplaceWithMask, this->array(R"(["aaa", "bbb", "ccc", "ddd"])"), + this->mask("[null, null, null, null]"), this->array("[]"), + this->array(R"([null, null, null, null])")); + this->Assert(ReplaceWithMask, this->array(R"(["aaa", "bbb", "ccc", null])"), + this->mask("[false, false, false, false]"), this->array("[]"), + this->array(R"(["aaa", "bbb", "ccc", null])")); + this->Assert(ReplaceWithMask, this->array(R"(["aaa", "bbb", "ccc", null])"), + this->mask("[true, true, true, true]"), + this->array(R"(["eee", "fff", "ggg", "hhh"])"), + this->array(R"(["eee", "fff", "ggg", "hhh"])")); + this->Assert(ReplaceWithMask, this->array(R"(["aaa", "bbb", "ccc", null])"), + this->mask("[null, null, null, null]"), this->array("[]"), + this->array(R"([null, null, null, null])")); + this->Assert(ReplaceWithMask, + this->array(R"(["aaa", "bbb", "ccc", "ddd", "eee", "fff"])"), + this->mask("[true, true, false, false, null, null]"), + this->array(R"(["ggg", null])"), + this->array(R"(["ggg", null, "ccc", "ddd", null, null])")); + this->Assert(ReplaceWithMask, this->array(R"([null, null, null, null, null, null])"), + this->mask("[true, true, false, false, null, null]"), + this->array(R"(["aaa", null])"), + this->array(R"(["aaa", null, null, null, null, null])")); + + this->Assert(ReplaceWithMask, this->array("[]"), this->mask("[]"), + this->scalar(R"("zzz")"), this->array("[]")); + this->Assert(ReplaceWithMask, this->array(R"(["aaa", "bbb"])"), + this->mask("[true, true]"), this->scalar(R"("zzz")"), + this->array(R"(["zzz", "zzz"])")); + this->Assert(ReplaceWithMask, this->array(R"(["aaa", "bbb"])"), + this->mask("[true, true]"), this->scalar("null"), + this->array("[null, null]")); + this->Assert(ReplaceWithMask, this->array(R"(["aaa", "bbb", "ccc"])"), + this->mask("[true, false, null]"), this->scalar(R"("zzz")"), + this->array(R"(["zzz", "bbb", null])")); +} + +TEST_F(TestReplaceFixedSizeBinary, ReplaceWithMaskErrors) { + EXPECT_RAISES_WITH_MESSAGE_THAT( + Invalid, + ::testing::AllOf( + ::testing::HasSubstr("Replacements must be of same type (expected "), + ::testing::HasSubstr(this->type()->ToString()), + ::testing::HasSubstr("but got fixed_size_binary[2]")), + this->AssertRaises(ReplaceWithMask, this->array("[]"), this->mask_scalar(true), + ArrayFromJSON(fixed_size_binary(2), "[]"))); +} + +TYPED_TEST(TestReplaceDecimal, ReplaceWithMask) { + this->Assert(ReplaceWithMask, this->array("[]"), this->mask_scalar(false), + this->array("[]"), this->array("[]")); + this->Assert(ReplaceWithMask, this->array("[]"), this->mask_scalar(true), + this->array("[]"), this->array("[]")); + this->Assert(ReplaceWithMask, this->array("[]"), this->null_mask_scalar(), + this->array("[]"), this->array("[]")); + + this->Assert(ReplaceWithMask, this->array(R"(["1.00"])"), this->mask_scalar(false), + this->array("[]"), this->array(R"(["1.00"])")); + this->Assert(ReplaceWithMask, this->array(R"(["1.00"])"), this->mask_scalar(true), + this->array(R"(["0.00"])"), this->array(R"(["0.00"])")); + this->Assert(ReplaceWithMask, this->array(R"(["1.00"])"), this->null_mask_scalar(), + this->array("[]"), this->array("[null]")); + + this->Assert(ReplaceWithMask, this->array(R"(["0.00", "0.00"])"), + this->mask_scalar(false), this->scalar(R"("1.00")"), + this->array(R"(["0.00", "0.00"])")); + this->Assert(ReplaceWithMask, this->array(R"(["0.00", "0.00"])"), + this->mask_scalar(true), this->scalar(R"("1.00")"), + this->array(R"(["1.00", "1.00"])")); + this->Assert(ReplaceWithMask, this->array(R"(["0.00", "0.00"])"), + this->mask_scalar(true), this->scalar("null"), + this->array("[null, null]")); + + this->Assert(ReplaceWithMask, this->array("[]"), this->mask("[]"), this->array("[]"), + this->array("[]")); + this->Assert(ReplaceWithMask, this->array(R"(["0.00", "1.00", "2.00", "3.00"])"), + this->mask("[false, false, false, false]"), this->array("[]"), + this->array(R"(["0.00", "1.00", "2.00", "3.00"])")); + this->Assert(ReplaceWithMask, this->array(R"(["0.00", "1.00", "2.00", "3.00"])"), + this->mask("[true, true, true, true]"), + this->array(R"(["10.00", "11.00", "12.00", "13.00"])"), + this->array(R"(["10.00", "11.00", "12.00", "13.00"])")); + this->Assert(ReplaceWithMask, this->array(R"(["0.00", "1.00", "2.00", "3.00"])"), + this->mask("[null, null, null, null]"), this->array("[]"), + this->array("[null, null, null, null]")); + this->Assert(ReplaceWithMask, this->array(R"(["0.00", "1.00", "2.00", null])"), + this->mask("[false, false, false, false]"), this->array("[]"), + this->array(R"(["0.00", "1.00", "2.00", null])")); + this->Assert(ReplaceWithMask, this->array(R"(["0.00", "1.00", "2.00", null])"), + this->mask("[true, true, true, true]"), + this->array(R"(["10.00", "11.00", "12.00", "13.00"])"), + this->array(R"(["10.00", "11.00", "12.00", "13.00"])")); + this->Assert(ReplaceWithMask, this->array(R"(["0.00", "1.00", "2.00", null])"), + this->mask("[null, null, null, null]"), this->array("[]"), + this->array("[null, null, null, null]")); + this->Assert(ReplaceWithMask, + this->array(R"(["0.00", "1.00", "2.00", "3.00", "4.00", "5.00"])"), + this->mask("[true, true, false, false, null, null]"), + this->array(R"(["10.00", null])"), + this->array(R"(["10.00", null, "2.00", "3.00", null, null])")); + this->Assert(ReplaceWithMask, this->array("[null, null, null, null, null, null]"), + this->mask("[true, true, false, false, null, null]"), + this->array(R"(["10.00", null])"), + this->array(R"(["10.00", null, null, null, null, null])")); + + this->Assert(ReplaceWithMask, this->array("[]"), this->mask("[]"), + this->scalar(R"("1.00")"), this->array("[]")); + this->Assert(ReplaceWithMask, this->array(R"(["0.00", "1.00"])"), + this->mask("[true, true]"), this->scalar(R"("10.00")"), + this->array(R"(["10.00", "10.00"])")); + this->Assert(ReplaceWithMask, this->array(R"(["0.00", "1.00"])"), + this->mask("[true, true]"), this->scalar("null"), + this->array("[null, null]")); + this->Assert(ReplaceWithMask, this->array(R"(["0.00", "1.00", "2.00"])"), + this->mask("[true, false, null]"), this->scalar(R"("10.00")"), + this->array(R"(["10.00", "1.00", null])")); +} + +TEST_F(TestReplaceDayTimeInterval, ReplaceWithMask) { + this->Assert(ReplaceWithMask, this->array("[]"), this->mask_scalar(false), + this->array("[]"), this->array("[]")); + this->Assert(ReplaceWithMask, this->array("[]"), this->mask_scalar(true), + this->array("[]"), this->array("[]")); + this->Assert(ReplaceWithMask, this->array("[]"), this->null_mask_scalar(), + this->array("[]"), this->array("[]")); + + this->Assert(ReplaceWithMask, this->array("[[1, 2]]"), this->mask_scalar(false), + this->array("[]"), this->array("[[1, 2]]")); + this->Assert(ReplaceWithMask, this->array("[[1, 2]]"), this->mask_scalar(true), + this->array("[[3, 4]]"), this->array("[[3, 4]]")); + this->Assert(ReplaceWithMask, this->array("[[1, 2]]"), this->null_mask_scalar(), + this->array("[]"), this->array("[null]")); + + this->Assert(ReplaceWithMask, this->array("[[1, 2], [3, 4]]"), this->mask_scalar(false), + this->scalar("[7, 8]"), this->array("[[1, 2], [3, 4]]")); + this->Assert(ReplaceWithMask, this->array("[[1, 2], [3, 4]]"), this->mask_scalar(true), + this->scalar("[7, 8]"), this->array("[[7, 8], [7, 8]]")); + this->Assert(ReplaceWithMask, this->array("[[1, 2], [3, 4]]"), this->mask_scalar(true), + this->scalar("null"), this->array("[null, null]")); + + this->Assert(ReplaceWithMask, this->array("[]"), this->mask("[]"), this->array("[]"), + this->array("[]")); + this->Assert(ReplaceWithMask, this->array("[[1, 2], [1, 2], [1, 2], [1, 2]]"), + this->mask("[false, false, false, false]"), this->array("[]"), + this->array("[[1, 2], [1, 2], [1, 2], [1, 2]]")); + this->Assert(ReplaceWithMask, this->array("[[1, 2], [1, 2], [1, 2], [1, 2]]"), + this->mask("[true, true, true, true]"), + this->array("[[3, 4], [3, 4], [3, 4], [3, 4]]"), + this->array("[[3, 4], [3, 4], [3, 4], [3, 4]]")); + this->Assert(ReplaceWithMask, this->array("[[1, 2], [1, 2], [1, 2], [1, 2]]"), + this->mask("[null, null, null, null]"), this->array("[]"), + this->array("[null, null, null, null]")); + this->Assert(ReplaceWithMask, this->array("[[1, 2], [1, 2], [1, 2], null]"), + this->mask("[false, false, false, false]"), this->array("[]"), + this->array("[[1, 2], [1, 2], [1, 2], null]")); + this->Assert(ReplaceWithMask, this->array("[[1, 2], [1, 2], [1, 2], null]"), + this->mask("[true, true, true, true]"), + this->array("[[3, 4], [3, 4], [3, 4], [3, 4]]"), + this->array("[[3, 4], [3, 4], [3, 4], [3, 4]]")); + this->Assert(ReplaceWithMask, this->array("[[1, 2], [1, 2], [1, 2], null]"), + this->mask("[null, null, null, null]"), this->array("[]"), + this->array("[null, null, null, null]")); + this->Assert( + ReplaceWithMask, this->array("[[1, 2], [1, 2], [1, 2], [1, 2], [1, 2], [1, 2]]"), + this->mask("[true, true, false, false, null, null]"), this->array("[[3, 4], null]"), + this->array("[[3, 4], null, [1, 2], [1, 2], null, null]")); + this->Assert(ReplaceWithMask, this->array("[null, null, null, null, null, null]"), + this->mask("[true, true, false, false, null, null]"), + this->array("[[3, 4], null]"), + this->array("[[3, 4], null, null, null, null, null]")); + + this->Assert(ReplaceWithMask, this->array("[]"), this->mask("[]"), + this->scalar("[7, 8]"), this->array("[]")); + this->Assert(ReplaceWithMask, this->array("[[1, 2], [3, 4]]"), + this->mask("[true, true]"), this->scalar("[7, 8]"), + this->array("[[7, 8], [7, 8]]")); + this->Assert(ReplaceWithMask, this->array("[[1, 2], [3, 4]]"), + this->mask("[true, true]"), this->scalar("null"), + this->array("[null, null]")); + this->Assert(ReplaceWithMask, this->array("[[1, 2], [3, 4], [5, 6]]"), + this->mask("[true, false, null]"), this->scalar("[7, 8]"), + this->array("[[7, 8], [3, 4], null]")); +} + +TYPED_TEST(TestReplaceBinary, ReplaceWithMask) { + this->Assert(ReplaceWithMask, this->array("[]"), this->mask_scalar(false), + this->array("[]"), this->array("[]")); + this->Assert(ReplaceWithMask, this->array("[]"), this->mask_scalar(true), + this->array("[]"), this->array("[]")); + this->Assert(ReplaceWithMask, this->array("[]"), this->null_mask_scalar(), + this->array("[]"), this->array("[]")); + + this->Assert(ReplaceWithMask, this->array(R"(["foo"])"), this->mask_scalar(false), + this->array("[]"), this->array(R"(["foo"])")); + this->Assert(ReplaceWithMask, this->array(R"(["foo"])"), this->mask_scalar(true), + this->array(R"(["bar"])"), this->array(R"(["bar"])")); + this->Assert(ReplaceWithMask, this->array(R"(["foo"])"), this->null_mask_scalar(), + this->array("[]"), this->array("[null]")); + + this->Assert(ReplaceWithMask, this->array(R"(["foo", "bar"])"), + this->mask_scalar(false), this->scalar(R"("baz")"), + this->array(R"(["foo", "bar"])")); + this->Assert(ReplaceWithMask, this->array(R"(["foo", "bar"])"), this->mask_scalar(true), + this->scalar(R"("baz")"), this->array(R"(["baz", "baz"])")); + this->Assert(ReplaceWithMask, this->array(R"(["foo", "bar"])"), this->mask_scalar(true), + this->scalar("null"), this->array(R"([null, null])")); + + this->Assert(ReplaceWithMask, this->array("[]"), this->mask("[]"), this->array("[]"), + this->array("[]")); + this->Assert(ReplaceWithMask, this->array(R"(["a", "bb", "ccc", "dddd"])"), + this->mask("[false, false, false, false]"), this->array("[]"), + this->array(R"(["a", "bb", "ccc", "dddd"])")); + this->Assert(ReplaceWithMask, this->array(R"(["a", "bb", "ccc", "dddd"])"), + this->mask("[true, true, true, true]"), + this->array(R"(["eeeee", "f", "ggg", "hhh"])"), + this->array(R"(["eeeee", "f", "ggg", "hhh"])")); + this->Assert(ReplaceWithMask, this->array(R"(["a", "bb", "ccc", "dddd"])"), + this->mask("[null, null, null, null]"), this->array("[]"), + this->array(R"([null, null, null, null])")); + this->Assert(ReplaceWithMask, this->array(R"(["a", "bb", "ccc", null])"), + this->mask("[false, false, false, false]"), this->array("[]"), + this->array(R"(["a", "bb", "ccc", null])")); + this->Assert(ReplaceWithMask, this->array(R"(["a", "bb", "ccc", null])"), + this->mask("[true, true, true, true]"), + this->array(R"(["eeeee", "f", "ggg", "hhh"])"), + this->array(R"(["eeeee", "f", "ggg", "hhh"])")); + this->Assert(ReplaceWithMask, this->array(R"(["a", "bb", "ccc", null])"), + this->mask("[null, null, null, null]"), this->array("[]"), + this->array(R"([null, null, null, null])")); + this->Assert(ReplaceWithMask, + this->array(R"(["a", "bb", "ccc", "dddd", "eeeee", "f"])"), + this->mask("[true, true, false, false, null, null]"), + this->array(R"(["ggg", null])"), + this->array(R"(["ggg", null, "ccc", "dddd", null, null])")); + this->Assert(ReplaceWithMask, this->array(R"([null, null, null, null, null, null])"), + this->mask("[true, true, false, false, null, null]"), + this->array(R"(["a", null])"), + this->array(R"(["a", null, null, null, null, null])")); + + this->Assert(ReplaceWithMask, this->array("[]"), this->mask("[]"), + this->scalar(R"("zzz")"), this->array("[]")); + this->Assert(ReplaceWithMask, this->array(R"(["a", "bb"])"), this->mask("[true, true]"), + this->scalar(R"("zzz")"), this->array(R"(["zzz", "zzz"])")); + this->Assert(ReplaceWithMask, this->array(R"(["a", "bb"])"), this->mask("[true, true]"), + this->scalar("null"), this->array("[null, null]")); + this->Assert(ReplaceWithMask, this->array(R"(["a", "bb", "ccc"])"), + this->mask("[true, false, null]"), this->scalar(R"("zzz")"), + this->array(R"(["zzz", "bb", null])")); +} + +TYPED_TEST(TestReplaceBinary, ReplaceWithMaskRandom) { + using ArrayType = typename TypeTraits::ArrayType; + auto ty = this->type(); + + random::RandomArrayGenerator rand(/*seed=*/0); + const int64_t length = 1023; + auto options = key_value_metadata({{"null_probability", "0.01"}, {"max_length", "5"}}); + auto array = + checked_pointer_cast(rand.ArrayOf(*field("a", ty, options), length)); + auto mask = checked_pointer_cast( + rand.ArrayOf(boolean(), length, /*null_probability=*/0.01)); + const int64_t num_replacements = std::count_if( + mask->begin(), mask->end(), + [](util::optional value) { return value.has_value() && *value; }); + auto replacements = checked_pointer_cast( + rand.ArrayOf(*field("a", ty, options), num_replacements)); + auto expected = this->NaiveImpl(*array, *mask, *replacements); + + this->Assert(ReplaceWithMask, array, mask, replacements, expected); + for (int64_t slice = 1; slice <= 16; slice++) { + auto sliced_array = checked_pointer_cast(array->Slice(slice, 15)); + auto sliced_mask = checked_pointer_cast(mask->Slice(slice, 15)); + auto new_expected = this->NaiveImpl(*sliced_array, *sliced_mask, *replacements); + this->Assert(ReplaceWithMask, sliced_array, sliced_mask, replacements, new_expected); + } +} + +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/registry.cc b/cpp/src/arrow/compute/registry.cc index 8a0d9e62518..ca7b6137306 100644 --- a/cpp/src/arrow/compute/registry.cc +++ b/cpp/src/arrow/compute/registry.cc @@ -168,6 +168,7 @@ static std::unique_ptr CreateBuiltInRegistry() { // Vector functions RegisterVectorHash(registry.get()); + RegisterVectorReplace(registry.get()); RegisterVectorSelection(registry.get()); RegisterVectorNested(registry.get()); RegisterVectorSort(registry.get()); diff --git a/cpp/src/arrow/compute/registry_internal.h b/cpp/src/arrow/compute/registry_internal.h index dd0271eb43d..892b54341da 100644 --- a/cpp/src/arrow/compute/registry_internal.h +++ b/cpp/src/arrow/compute/registry_internal.h @@ -41,6 +41,7 @@ void RegisterScalarOptions(FunctionRegistry* registry); // Vector functions void RegisterVectorHash(FunctionRegistry* registry); +void RegisterVectorReplace(FunctionRegistry* registry); void RegisterVectorSelection(FunctionRegistry* registry); void RegisterVectorNested(FunctionRegistry* registry); void RegisterVectorSort(FunctionRegistry* registry); diff --git a/cpp/src/arrow/type_traits.h b/cpp/src/arrow/type_traits.h index 86664bbb162..e4d809967f9 100644 --- a/cpp/src/arrow/type_traits.h +++ b/cpp/src/arrow/type_traits.h @@ -233,6 +233,7 @@ struct TypeTraits { using ArrayType = MonthIntervalArray; using BuilderType = MonthIntervalBuilder; using ScalarType = MonthIntervalScalar; + using CType = MonthIntervalType::c_type; static constexpr int64_t bytes_required(int64_t elements) { return elements * static_cast(sizeof(int32_t)); diff --git a/docs/source/cpp/compute.rst b/docs/source/cpp/compute.rst index fc6c8b7c7e1..6d857ed6822 100644 --- a/docs/source/cpp/compute.rst +++ b/docs/source/cpp/compute.rst @@ -1090,6 +1090,22 @@ Associative transforms Each output element corresponds to a unique value in the input, along with the number of times this value has appeared. +Replacements +~~~~~~~~~~~~ + +These functions create a copy of the first input with some elements +replaced, based on the remaining inputs. + ++-------------------+------------+-----------------------+--------------+--------------+--------------+-------+ +| Function name | Arity | Input type 1 | Input type 2 | Input type 3 | Output type | Notes | ++===================+============+=======================+==============+==============+==============+=======+ +| replace_with_mask | Ternary | Fixed-width or binary | Boolean | Input type 1 | Input type 1 | \(1) | ++-------------------+------------+-----------------------+--------------+--------------+--------------+-------+ + +* \(1) Each element in input 1 for which the corresponding Boolean in input 2 + is true is replaced with the next value from input 3. A null in input 2 + results in a corresponding null in the output. + Selections ~~~~~~~~~~ diff --git a/docs/source/python/api/compute.rst b/docs/source/python/api/compute.rst index a611d2a2384..09c67598193 100644 --- a/docs/source/python/api/compute.rst +++ b/docs/source/python/api/compute.rst @@ -292,6 +292,14 @@ Conversions cast strptime +Replacements +------------ + +.. autosummary:: + :toctree: ../generated/ + + replace_with_mask + Selections ---------- From ccca017bba32a6ea9ffc5287c09eec725faa4462 Mon Sep 17 00:00:00 2001 From: David Li Date: Tue, 15 Jun 2021 08:22:35 -0400 Subject: [PATCH 02/15] ARROW-9430: [C++] Clarify replace_with_mask implementation --- cpp/src/arrow/compute/kernels/vector_replace.cc | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/vector_replace.cc b/cpp/src/arrow/compute/kernels/vector_replace.cc index 4b768608c62..7c0edb78226 100644 --- a/cpp/src/arrow/compute/kernels/vector_replace.cc +++ b/cpp/src/arrow/compute/kernels/vector_replace.cc @@ -37,9 +37,9 @@ Status ReplaceWithScalarMask(KernelContext* ctx, const ArrayData& array, ArrayData* output) { if (!mask.is_valid) { // Output = null - ARROW_ASSIGN_OR_RAISE(auto array, + ARROW_ASSIGN_OR_RAISE(auto replacement_array, MakeArrayOfNull(array.type, array.length, ctx->memory_pool())); - *output = *array->data(); + *output = *replacement_array->data(); return Status::OK(); } if (mask.value) { @@ -91,9 +91,11 @@ Status ReplaceWithArrayMask(KernelContext* ctx, const ArrayData& array, out_bitmap = output->buffers[0]->mutable_data(); output->null_count = -1; if (array.MayHaveNulls()) { + // Copy array's bitmap arrow::internal::CopyBitmap(array.buffers[0]->data(), array.offset, array.length, out_bitmap, /*dest_offset=*/0); } else { + // Array has no bitmap but mask/replacements do, generate an all-valid bitmap std::memset(out_bitmap, 0xFF, output->buffers[0]->size()); } } else { @@ -136,11 +138,8 @@ Status ReplaceWithArrayMask(KernelContext* ctx, const ArrayData& array, BitUtil::SetBitsTo(out_bitmap, out_offset, valid_block.length, true); } replacements_offset += valid_block.length; - } else if (value_block.NoneSet() && valid_block.AllSet()) { + } else if ((value_block.NoneSet() && valid_block.AllSet()) || valid_block.NoneSet()) { // Do nothing - } else if (valid_block.NoneSet()) { - DCHECK(out_bitmap); - BitUtil::SetBitsTo(out_bitmap, out_offset, valid_block.length, false); } else { for (int64_t i = 0; i < valid_block.length; ++i) { if (BitUtil::GetBit(mask_values, out_offset + mask.offset + i) && From ea1517a3db921da1118d90fa517d9773b6b99851 Mon Sep 17 00:00:00 2001 From: David Li Date: Tue, 15 Jun 2021 15:01:18 -0400 Subject: [PATCH 03/15] ARROW-9430: [C++] Cross-reference if_else and replace_with_mask --- docs/source/cpp/compute.rst | 38 ++++++++++++++++++++----------------- 1 file changed, 21 insertions(+), 17 deletions(-) diff --git a/docs/source/cpp/compute.rst b/docs/source/cpp/compute.rst index 6d857ed6822..00391052b1e 100644 --- a/docs/source/cpp/compute.rst +++ b/docs/source/cpp/compute.rst @@ -850,6 +850,7 @@ in reverse order. as given by :struct:`SliceOptions` where ``start`` and ``stop`` are measured in codeunits. Null inputs emit null. +.. _cpp-compute-scalar-structural-transforms: Structural transforms ~~~~~~~~~~~~~~~~~~~~~ @@ -861,7 +862,7 @@ Structural transforms +==========================+============+================================================+=====================+=========+ | fill_null | Binary | Boolean, Null, Numeric, Temporal, String-like | Input type | \(1) | +--------------------------+------------+------------------------------------------------+---------------------+---------+ -| if_else | Ternary | Boolean, Null, Numeric, Temporal | Input type + \(2) | +| if_else | Ternary | Boolean, Null, Numeric, Temporal | Input type | \(2) | +--------------------------+------------+------------------------------------------------+---------------------+---------+ | is_finite | Unary | Float, Double | Boolean | \(3) | +--------------------------+------------+------------------------------------------------+---------------------+---------+ @@ -888,6 +889,8 @@ Structural transforms input. If the nulls present on the first input, they will be promoted to the output, otherwise nulls will be chosen based on the first input values. + Also see: :ref:`replace_with_mask `. + * \(3) Output is true iff the corresponding input element is finite (not Infinity, -Infinity, or NaN). @@ -1090,22 +1093,6 @@ Associative transforms Each output element corresponds to a unique value in the input, along with the number of times this value has appeared. -Replacements -~~~~~~~~~~~~ - -These functions create a copy of the first input with some elements -replaced, based on the remaining inputs. - -+-------------------+------------+-----------------------+--------------+--------------+--------------+-------+ -| Function name | Arity | Input type 1 | Input type 2 | Input type 3 | Output type | Notes | -+===================+============+=======================+==============+==============+==============+=======+ -| replace_with_mask | Ternary | Fixed-width or binary | Boolean | Input type 1 | Input type 1 | \(1) | -+-------------------+------------+-----------------------+--------------+--------------+--------------+-------+ - -* \(1) Each element in input 1 for which the corresponding Boolean in input 2 - is true is replaced with the next value from input 3. A null in input 2 - results in a corresponding null in the output. - Selections ~~~~~~~~~~ @@ -1170,6 +1157,8 @@ value, but smaller than nulls. table. If the input is a record batch or table, one or more sort keys must be specified. +.. _cpp-compute-vector-structural-transforms: + Structural transforms ~~~~~~~~~~~~~~~~~~~~~ @@ -1188,3 +1177,18 @@ Structural transforms * \(2) For each value in the list child array, the index at which it is found in the list array is appended to the output. Nulls in the parent list array are discarded. + +These functions create a copy of the first input with some elements +replaced, based on the remaining inputs. + ++--------------------------+------------+-----------------------+--------------+--------------+--------------+-------+ +| Function name | Arity | Input type 1 | Input type 2 | Input type 3 | Output type | Notes | ++==========================+============+=======================+==============+==============+==============+=======+ +| replace_with_mask | Ternary | Fixed-width or binary | Boolean | Input type 1 | Input type 1 | \(1) | ++--------------------------+------------+-----------------------+--------------+--------------+--------------+-------+ + +* \(1) Each element in input 1 for which the corresponding Boolean in input 2 + is true is replaced with the next value from input 3. A null in input 2 + results in a corresponding null in the output. + + Also see: :ref:`if_else `. From 853eee3388fb4f5b120d0e8f8f11602a48ae335a Mon Sep 17 00:00:00 2001 From: David Li Date: Fri, 25 Jun 2021 14:19:22 -0400 Subject: [PATCH 04/15] ARROW-9430: [C++] Clean up tests slightly --- .../compute/kernels/vector_replace_test.cc | 41 ++++--------------- 1 file changed, 8 insertions(+), 33 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/vector_replace_test.cc b/cpp/src/arrow/compute/kernels/vector_replace_test.cc index 1826b037034..45e99dbf777 100644 --- a/cpp/src/arrow/compute/kernels/vector_replace_test.cc +++ b/cpp/src/arrow/compute/kernels/vector_replace_test.cc @@ -32,6 +32,7 @@ namespace compute { using arrow::internal::checked_pointer_cast; namespace { +// Helper to get a default instance of a type, including parameterized types template enable_if_parameter_free> default_type_instance() { return TypeTraits::type_singleton(); @@ -39,11 +40,10 @@ enable_if_parameter_free> default_type_instance() { template enable_if_time> default_type_instance() { // Time32 requires second/milli, Time64 requires nano/micro - if (TypeTraits::bytes_required(1) == 4) { + if (bit_width(T::type_id) == 32) { return std::make_shared(TimeUnit::type::SECOND); - } else { - return std::make_shared(TimeUnit::type::NANO); } + return std::make_shared(TimeUnit::type::NANO); } template enable_if_timestamp> default_type_instance() { @@ -53,33 +53,6 @@ template enable_if_decimal> default_type_instance() { return std::make_shared(5, 2); } -template -enable_if_parameter_free::BuilderType>> -builder_instance() { - return arrow::internal::make_unique::BuilderType>(); -} -template -enable_if_time::BuilderType>> -builder_instance() { - return arrow::internal::make_unique::BuilderType>( - default_type_instance(), default_memory_pool()); -} -template -enable_if_timestamp::BuilderType>> -builder_instance() { - return arrow::internal::make_unique::BuilderType>( - default_type_instance(), default_memory_pool()); -} -template -enable_if_t::value, T> max_int_value() { - return static_cast( - std::min(16384.0, static_cast(std::numeric_limits::max()))); -} -template -enable_if_t::value, T> max_int_value() { - return static_cast( - std::min(16384.0, static_cast(std::numeric_limits::max()))); -} } // namespace template @@ -139,7 +112,8 @@ class TestReplaceKernel : public ::testing::Test { const typename TypeTraits::ArrayType& array, const BooleanArray& mask, const typename TypeTraits::ArrayType& replacements) { auto length = array.length(); - auto builder = builder_instance(); + auto builder = arrow::internal::make_unique::BuilderType>( + default_type_instance(), default_memory_pool()); int64_t replacement_offset = 0; for (int64_t i = 0; i < length; ++i) { if (mask.IsValid(i)) { @@ -282,9 +256,10 @@ TYPED_TEST(TestReplaceNumeric, ReplaceWithMaskRandom) { random::RandomArrayGenerator rand(/*seed=*/0); const int64_t length = 1023; - // Clamp the range because date/time types don't print well with extreme values std::vector values = {"0.01", "0"}; - values.push_back(std::to_string(max_int_value())); + // Clamp the range because date/time types don't print well with extreme values + values.push_back(std::to_string(static_cast(std::min( + 16384.0, static_cast(std::numeric_limits::max()))))); auto options = key_value_metadata({"null_probability", "min", "max"}, values); auto array = checked_pointer_cast(rand.ArrayOf(*field("a", ty, options), length)); From e15493000020ef0ec5060e4a5ac138841332a6a2 Mon Sep 17 00:00:00 2001 From: David Li Date: Mon, 28 Jun 2021 09:10:34 -0400 Subject: [PATCH 05/15] ARROW-9430: [C++] Move test helper into test_util.h --- cpp/src/arrow/compute/kernels/test_util.h | 22 +++++++++++++++++ .../compute/kernels/vector_replace_test.cc | 24 ------------------- 2 files changed, 22 insertions(+), 24 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/test_util.h b/cpp/src/arrow/compute/kernels/test_util.h index f4854087b51..c691a9f3be3 100644 --- a/cpp/src/arrow/compute/kernels/test_util.h +++ b/cpp/src/arrow/compute/kernels/test_util.h @@ -172,5 +172,27 @@ void CheckDispatchBest(std::string func_name, std::vector descrs, // Check that function fails to produce a Kernel for the set of ValueDescrs. void CheckDispatchFails(std::string func_name, std::vector descrs); +// Helper to get a default instance of a type, including parameterized types +template +enable_if_parameter_free> default_type_instance() { + return TypeTraits::type_singleton(); +} +template +enable_if_time> default_type_instance() { + // Time32 requires second/milli, Time64 requires nano/micro + if (bit_width(T::type_id) == 32) { + return std::make_shared(TimeUnit::type::SECOND); + } + return std::make_shared(TimeUnit::type::NANO); +} +template +enable_if_timestamp> default_type_instance() { + return std::make_shared(TimeUnit::type::SECOND); +} +template +enable_if_decimal> default_type_instance() { + return std::make_shared(5, 2); +} + } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/vector_replace_test.cc b/cpp/src/arrow/compute/kernels/vector_replace_test.cc index 45e99dbf777..a55ed709a0f 100644 --- a/cpp/src/arrow/compute/kernels/vector_replace_test.cc +++ b/cpp/src/arrow/compute/kernels/vector_replace_test.cc @@ -31,30 +31,6 @@ namespace compute { using arrow::internal::checked_pointer_cast; -namespace { -// Helper to get a default instance of a type, including parameterized types -template -enable_if_parameter_free> default_type_instance() { - return TypeTraits::type_singleton(); -} -template -enable_if_time> default_type_instance() { - // Time32 requires second/milli, Time64 requires nano/micro - if (bit_width(T::type_id) == 32) { - return std::make_shared(TimeUnit::type::SECOND); - } - return std::make_shared(TimeUnit::type::NANO); -} -template -enable_if_timestamp> default_type_instance() { - return std::make_shared(TimeUnit::type::SECOND); -} -template -enable_if_decimal> default_type_instance() { - return std::make_shared(5, 2); -} -} // namespace - template class TestReplaceKernel : public ::testing::Test { protected: From 97fabbe4e3891dedbdca7001ca7163fa8444ad00 Mon Sep 17 00:00:00 2001 From: David Li Date: Mon, 28 Jun 2021 09:40:43 -0400 Subject: [PATCH 06/15] ARROW-9430: [C++] Take advantage of preallocation --- .../arrow/compute/kernels/vector_replace.cc | 30 +------------------ 1 file changed, 1 insertion(+), 29 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/vector_replace.cc b/cpp/src/arrow/compute/kernels/vector_replace.cc index 7c0edb78226..9f3aeb2ca7e 100644 --- a/cpp/src/arrow/compute/kernels/vector_replace.cc +++ b/cpp/src/arrow/compute/kernels/vector_replace.cc @@ -70,9 +70,6 @@ template Status ReplaceWithArrayMask(KernelContext* ctx, const ArrayData& array, const ArrayData& mask, const Datum& replacements, ArrayData* output) { - ARROW_ASSIGN_OR_RAISE(output->buffers[1], - Functor::AllocateData(ctx, *array.type, array.length)); - uint8_t* out_bitmap = nullptr; uint8_t* out_values = output->buffers[1]->mutable_data(); const uint8_t* mask_bitmap = mask.MayHaveNulls() ? mask.buffers[0]->data() : nullptr; @@ -176,11 +173,6 @@ template struct ReplaceWithMask> { using T = typename TypeTraits::CType; - static Result> AllocateData(KernelContext* ctx, const DataType&, - const int64_t length) { - return ctx->Allocate(length * sizeof(T)); - } - static void CopyData(const DataType&, uint8_t* out, const int64_t out_offset, const Datum& in, const int64_t in_offset, const int64_t length) { if (in.is_array()) { @@ -211,11 +203,6 @@ struct ReplaceWithMask> { template struct ReplaceWithMask> { - static Result> AllocateData(KernelContext* ctx, const DataType&, - const int64_t length) { - return ctx->AllocateBitmap(length); - } - static void CopyData(const DataType&, uint8_t* out, const int64_t out_offset, const Datum& in, const int64_t in_offset, const int64_t length) { if (in.is_array()) { @@ -243,13 +230,6 @@ struct ReplaceWithMask> { template struct ReplaceWithMask> { - static Result> AllocateData(KernelContext* ctx, - const DataType& ty, - const int64_t length) { - return ctx->Allocate(length * - checked_cast(ty).byte_width()); - } - static void CopyData(const DataType& ty, uint8_t* out, const int64_t out_offset, const Datum& in, const int64_t in_offset, const int64_t length) { const int32_t width = checked_cast(ty).byte_width(); @@ -291,14 +271,6 @@ struct ReplaceWithMask> { template struct ReplaceWithMask> { using ScalarType = typename TypeTraits::ScalarType; - - static Result> AllocateData(KernelContext* ctx, - const DataType& ty, - const int64_t length) { - return ctx->Allocate(length * - checked_cast(ty).byte_width()); - } - static void CopyData(const DataType& ty, uint8_t* out, const int64_t out_offset, const Datum& in, const int64_t in_offset, const int64_t length) { const int32_t width = checked_cast(ty).byte_width(); @@ -458,7 +430,7 @@ void RegisterVectorReplace(FunctionRegistry* registry) { VectorKernel kernel; kernel.can_execute_chunkwise = false; kernel.null_handling = NullHandling::type::COMPUTED_NO_PREALLOCATE; - kernel.mem_allocation = MemAllocation::type::NO_PREALLOCATE; + kernel.mem_allocation = MemAllocation::type::PREALLOCATE; kernel.signature = KernelSignature::Make( {InputType::Array(get_id.id), InputType(boolean()), InputType(get_id.id)}, OutputType(FirstType)); From 68b733b300cccdc3ea297dbf61f7f89a0d3aac7b Mon Sep 17 00:00:00 2001 From: David Li Date: Wed, 7 Jul 2021 14:58:01 -0400 Subject: [PATCH 07/15] ARROW-9430: [C++] Clean up impl --- .../arrow/compute/kernels/vector_replace.cc | 31 ++++++++----------- 1 file changed, 13 insertions(+), 18 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/vector_replace.cc b/cpp/src/arrow/compute/kernels/vector_replace.cc index 9f3aeb2ca7e..a7515d9065d 100644 --- a/cpp/src/arrow/compute/kernels/vector_replace.cc +++ b/cpp/src/arrow/compute/kernels/vector_replace.cc @@ -112,33 +112,28 @@ Status ReplaceWithArrayMask(KernelContext* ctx, const ArrayData& array, Functor::CopyData(*array.type, out_values, /*out_offset=*/0, array, /*in_offset=*/0, array.length); - arrow::internal::BitBlockCounter value_counter(mask_values, mask.offset, mask.length); - arrow::internal::OptionalBitBlockCounter valid_counter(mask_bitmap, mask.offset, - mask.length); + arrow::internal::OptionalBinaryBitBlockCounter counter( + mask_values, mask.offset, mask_bitmap, mask.offset, mask.length); int64_t out_offset = 0; int64_t replacements_offset = 0; while (out_offset < array.length) { - BitBlockCount value_block = value_counter.NextWord(); - BitBlockCount valid_block = valid_counter.NextWord(); - DCHECK_EQ(value_block.length, valid_block.length); - if (value_block.AllSet() && valid_block.AllSet()) { + BitBlockCount block = counter.NextAndBlock(); + if (block.AllSet()) { // Copy from replacement array - if (replacements_offset + valid_block.length > replacements_length) { - return ReplacementArrayTooShort(replacements_offset + valid_block.length, + if (replacements_offset + block.length > replacements_length) { + return ReplacementArrayTooShort(replacements_offset + block.length, replacements_length); } Functor::CopyData(*array.type, out_values, out_offset, replacements, - replacements_offset, valid_block.length); + replacements_offset, block.length); if (replacements_bitmap) { - copy_bitmap(out_offset, replacements_offset, valid_block.length); + copy_bitmap(out_offset, replacements_offset, block.length); } else if (!replacements_bitmap && out_bitmap) { - BitUtil::SetBitsTo(out_bitmap, out_offset, valid_block.length, true); + BitUtil::SetBitsTo(out_bitmap, out_offset, block.length, true); } - replacements_offset += valid_block.length; - } else if ((value_block.NoneSet() && valid_block.AllSet()) || valid_block.NoneSet()) { - // Do nothing - } else { - for (int64_t i = 0; i < valid_block.length; ++i) { + replacements_offset += block.length; + } else if (block.popcount) { + for (int64_t i = 0; i < block.length; ++i) { if (BitUtil::GetBit(mask_values, out_offset + mask.offset + i) && (!mask_bitmap || BitUtil::GetBit(mask_bitmap, out_offset + mask.offset + i))) { @@ -155,7 +150,7 @@ Status ReplaceWithArrayMask(KernelContext* ctx, const ArrayData& array, } } } - out_offset += valid_block.length; + out_offset += block.length; } if (mask.MayHaveNulls()) { From ba75c90b79e87b2264b59739a9488d04bcb70c0c Mon Sep 17 00:00:00 2001 From: David Li Date: Wed, 7 Jul 2021 15:25:47 -0400 Subject: [PATCH 08/15] ARROW-9430: [C++] Add simple benchmark --- cpp/src/arrow/compute/kernels/CMakeLists.txt | 1 + .../kernels/vector_replace_benchmark.cc | 87 +++++++++++++++++++ 2 files changed, 88 insertions(+) create mode 100644 cpp/src/arrow/compute/kernels/vector_replace_benchmark.cc diff --git a/cpp/src/arrow/compute/kernels/CMakeLists.txt b/cpp/src/arrow/compute/kernels/CMakeLists.txt index f8cb67d82c4..474ce1418fd 100644 --- a/cpp/src/arrow/compute/kernels/CMakeLists.txt +++ b/cpp/src/arrow/compute/kernels/CMakeLists.txt @@ -56,6 +56,7 @@ add_arrow_compute_test(vector_test add_arrow_benchmark(vector_hash_benchmark PREFIX "arrow-compute") add_arrow_benchmark(vector_sort_benchmark PREFIX "arrow-compute") add_arrow_benchmark(vector_partition_benchmark PREFIX "arrow-compute") +add_arrow_benchmark(vector_replace_benchmark PREFIX "arrow-compute") add_arrow_benchmark(vector_selection_benchmark PREFIX "arrow-compute") # ---------------------------------------------------------------------- diff --git a/cpp/src/arrow/compute/kernels/vector_replace_benchmark.cc b/cpp/src/arrow/compute/kernels/vector_replace_benchmark.cc new file mode 100644 index 00000000000..12cc623b7cb --- /dev/null +++ b/cpp/src/arrow/compute/kernels/vector_replace_benchmark.cc @@ -0,0 +1,87 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include + +#include "arrow/array.h" +#include "arrow/testing/gtest_util.h" +#include "arrow/testing/random.h" + +#include "arrow/compute/api_vector.h" + +namespace arrow { +namespace compute { + +using ::arrow::internal::checked_pointer_cast; + +static constexpr random::SeedType kRandomSeed = 0xabcdef; +static constexpr random::SeedType kLongLength = 16384; + +static std::shared_ptr MakeReplacements(random::RandomArrayGenerator* generator, + const BooleanArray& mask) { + int64_t count = 0; + for (int64_t i = 0; i < mask.length(); i++) { + count += mask.Value(i) && mask.IsValid(i); + } + return generator->Int64(count, /*min=*/-1.0, /*max=*/1.0, /*null_probability=*/0.1); +} + +static void ReplaceWithMaskLowSelectivityBench( + benchmark::State& state) { // NOLINT non-const reference + random::RandomArrayGenerator generator(kRandomSeed); + const int64_t len = state.range(0); + const int64_t offset = state.range(1); + + auto values = generator.Int64(len, /*min=*/-1.0, /*max=*/1.0, /*null_probability=*/0.1) + ->Slice(offset); + auto mask = checked_pointer_cast( + generator.Boolean(len, /*true_probability=*/0.1, /*null_probability=*/0.1) + ->Slice(offset)); + auto replacements = MakeReplacements(&generator, *mask); + + for (auto _ : state) { + ABORT_NOT_OK(ReplaceWithMask(values, mask, replacements)); + } + state.SetBytesProcessed(state.iterations() * (len - offset) * 8); +} + +static void ReplaceWithMaskHighSelectivityBench( + benchmark::State& state) { // NOLINT non-const reference + random::RandomArrayGenerator generator(kRandomSeed); + const int64_t len = state.range(0); + const int64_t offset = state.range(1); + + auto values = generator.Int64(len, /*min=*/-1.0, /*max=*/1.0, /*null_probability=*/0.1) + ->Slice(offset); + auto mask = checked_pointer_cast( + generator.Boolean(len, /*true_probability=*/0.9, /*null_probability=*/0.1) + ->Slice(offset)); + auto replacements = MakeReplacements(&generator, *mask); + + for (auto _ : state) { + ABORT_NOT_OK(ReplaceWithMask(values, mask, replacements)); + } + state.SetBytesProcessed(state.iterations() * (len - offset) * 8); +} + +BENCHMARK(ReplaceWithMaskLowSelectivityBench)->Args({kLongLength, 0}); +BENCHMARK(ReplaceWithMaskHighSelectivityBench)->Args({kLongLength, 0}); +BENCHMARK(ReplaceWithMaskLowSelectivityBench)->Args({kLongLength, 99}); +BENCHMARK(ReplaceWithMaskHighSelectivityBench)->Args({kLongLength, 99}); + +} // namespace compute +} // namespace arrow From 4ba9837938a0d972237f18f14c12ad89b4357717 Mon Sep 17 00:00:00 2001 From: David Li Date: Wed, 7 Jul 2021 15:34:51 -0400 Subject: [PATCH 09/15] ARROW-9430: [C++] Count replacements up front --- .../arrow/compute/kernels/vector_replace.cc | 23 +++++++++++-------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/vector_replace.cc b/cpp/src/arrow/compute/kernels/vector_replace.cc index a7515d9065d..a5cfad6733d 100644 --- a/cpp/src/arrow/compute/kernels/vector_replace.cc +++ b/cpp/src/arrow/compute/kernels/vector_replace.cc @@ -75,13 +75,23 @@ Status ReplaceWithArrayMask(KernelContext* ctx, const ArrayData& array, const uint8_t* mask_bitmap = mask.MayHaveNulls() ? mask.buffers[0]->data() : nullptr; const uint8_t* mask_values = mask.buffers[1]->data(); bool replacements_bitmap; - int64_t replacements_length; if (replacements.is_array()) { replacements_bitmap = replacements.array()->MayHaveNulls(); - replacements_length = replacements.array()->length; + const int64_t replacements_length = replacements.array()->length; + + arrow::internal::OptionalBinaryBitBlockCounter counter( + mask_values, mask.offset, mask_bitmap, mask.offset, mask.length); + int64_t count = 0; + for (int64_t offset = 0; offset < mask.length;) { + BitBlockCount block = counter.NextAndBlock(); + count += block.popcount; + offset += block.length; + } + if (count > replacements_length) { + return ReplacementArrayTooShort(count, replacements_length); + } } else { replacements_bitmap = !replacements.scalar()->is_valid; - replacements_length = std::numeric_limits::max(); } if (array.MayHaveNulls() || mask.MayHaveNulls() || replacements_bitmap) { ARROW_ASSIGN_OR_RAISE(output->buffers[0], ctx->AllocateBitmap(array.length)); @@ -120,10 +130,6 @@ Status ReplaceWithArrayMask(KernelContext* ctx, const ArrayData& array, BitBlockCount block = counter.NextAndBlock(); if (block.AllSet()) { // Copy from replacement array - if (replacements_offset + block.length > replacements_length) { - return ReplacementArrayTooShort(replacements_offset + block.length, - replacements_length); - } Functor::CopyData(*array.type, out_values, out_offset, replacements, replacements_offset, block.length); if (replacements_bitmap) { @@ -137,9 +143,6 @@ Status ReplaceWithArrayMask(KernelContext* ctx, const ArrayData& array, if (BitUtil::GetBit(mask_values, out_offset + mask.offset + i) && (!mask_bitmap || BitUtil::GetBit(mask_bitmap, out_offset + mask.offset + i))) { - if (replacements_offset >= replacements_length) { - return ReplacementArrayTooShort(replacements_offset + 1, replacements_length); - } Functor::CopyData(*array.type, out_values, out_offset + i, replacements, replacements_offset, /*length=*/1); From 5b2de61674c290854171dc7e7577410aa6dd908e Mon Sep 17 00:00:00 2001 From: David Li Date: Wed, 7 Jul 2021 16:48:54 -0400 Subject: [PATCH 10/15] ARROW-9430: [C++] Fix min/max in benchmark --- .../compute/kernels/vector_replace_benchmark.cc | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/vector_replace_benchmark.cc b/cpp/src/arrow/compute/kernels/vector_replace_benchmark.cc index 12cc623b7cb..4952cc6f341 100644 --- a/cpp/src/arrow/compute/kernels/vector_replace_benchmark.cc +++ b/cpp/src/arrow/compute/kernels/vector_replace_benchmark.cc @@ -37,7 +37,7 @@ static std::shared_ptr MakeReplacements(random::RandomArrayGenerator* gen for (int64_t i = 0; i < mask.length(); i++) { count += mask.Value(i) && mask.IsValid(i); } - return generator->Int64(count, /*min=*/-1.0, /*max=*/1.0, /*null_probability=*/0.1); + return generator->Int64(count, /*min=*/-65536, /*max=*/65536, /*null_probability=*/0.1); } static void ReplaceWithMaskLowSelectivityBench( @@ -46,8 +46,9 @@ static void ReplaceWithMaskLowSelectivityBench( const int64_t len = state.range(0); const int64_t offset = state.range(1); - auto values = generator.Int64(len, /*min=*/-1.0, /*max=*/1.0, /*null_probability=*/0.1) - ->Slice(offset); + auto values = + generator.Int64(len, /*min=*/-65536, /*max=*/65536, /*null_probability=*/0.1) + ->Slice(offset); auto mask = checked_pointer_cast( generator.Boolean(len, /*true_probability=*/0.1, /*null_probability=*/0.1) ->Slice(offset)); @@ -65,8 +66,9 @@ static void ReplaceWithMaskHighSelectivityBench( const int64_t len = state.range(0); const int64_t offset = state.range(1); - auto values = generator.Int64(len, /*min=*/-1.0, /*max=*/1.0, /*null_probability=*/0.1) - ->Slice(offset); + auto values = + generator.Int64(len, /*min=*/-65536, /*max=*/65536, /*null_probability=*/0.1) + ->Slice(offset); auto mask = checked_pointer_cast( generator.Boolean(len, /*true_probability=*/0.9, /*null_probability=*/0.1) ->Slice(offset)); From 0fc1995ec6965716bc36ba001f3d1ae17c4cf090 Mon Sep 17 00:00:00 2001 From: David Li Date: Thu, 8 Jul 2021 10:32:33 -0400 Subject: [PATCH 11/15] ARROW-9430: [C++] Improve performance --- .../arrow/compute/kernels/vector_replace.cc | 250 ++++++++++-------- .../kernels/vector_replace_benchmark.cc | 2 +- 2 files changed, 145 insertions(+), 107 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/vector_replace.cc b/cpp/src/arrow/compute/kernels/vector_replace.cc index a5cfad6733d..8934f1497e6 100644 --- a/cpp/src/arrow/compute/kernels/vector_replace.cc +++ b/cpp/src/arrow/compute/kernels/vector_replace.cc @@ -63,20 +63,93 @@ Status ReplaceWithScalarMask(KernelContext* ctx, const ArrayData& array, return Status::OK(); } +struct CopyArrayBitmap { + const uint8_t* in_bitmap; + int64_t in_offset; + + void CopyBitmap(uint8_t* out_bitmap, int64_t out_offset, int64_t offset, + int64_t length) const { + arrow::internal::CopyBitmap(in_bitmap, in_offset + offset, length, out_bitmap, + out_offset); + } + + void SetBit(uint8_t* out_bitmap, int64_t out_offset, int64_t offset) const { + BitUtil::SetBitTo(out_bitmap, out_offset, + BitUtil::GetBit(in_bitmap, in_offset + offset)); + } +}; + +struct CopyScalarBitmap { + const bool is_valid; + + void CopyBitmap(uint8_t* out_bitmap, int64_t out_offset, int64_t offset, + int64_t length) const { + BitUtil::SetBitsTo(out_bitmap, out_offset, length, is_valid); + } + + void SetBit(uint8_t* out_bitmap, int64_t out_offset, int64_t offset) const { + BitUtil::SetBitTo(out_bitmap, out_offset, is_valid); + } +}; + // Helper to implement replace_with kernel with array mask for fixed-width types, // using callbacks to handle both bool and byte-sized types and to handle // scalar and array replacements +template +void ReplaceWithArrayMaskImpl(const ArrayData& array, const ArrayData& mask, + const Data& replacements, bool replacements_bitmap, + const CopyBitmap copy_bitmap, const uint8_t* mask_bitmap, + const uint8_t* mask_values, uint8_t* out_bitmap, + uint8_t* out_values) { + Functor::CopyData(*array.type, out_values, /*out_offset=*/0, array, /*in_offset=*/0, + array.length); + arrow::internal::OptionalBinaryBitBlockCounter counter( + mask_values, mask.offset, mask_bitmap, mask.offset, mask.length); + int64_t out_offset = 0; + int64_t replacements_offset = 0; + while (out_offset < array.length) { + BitBlockCount block = counter.NextAndBlock(); + if (block.AllSet()) { + // Copy from replacement array + Functor::CopyData(*array.type, out_values, out_offset, replacements, + replacements_offset, block.length); + if (replacements_bitmap) { + copy_bitmap.CopyBitmap(out_bitmap, out_offset, replacements_offset, block.length); + } else if (!replacements_bitmap && out_bitmap) { + BitUtil::SetBitsTo(out_bitmap, out_offset, block.length, true); + } + replacements_offset += block.length; + } else if (block.popcount) { + for (int64_t i = 0; i < block.length; ++i) { + if (BitUtil::GetBit(mask_values, out_offset + mask.offset + i) && + (!mask_bitmap || + BitUtil::GetBit(mask_bitmap, out_offset + mask.offset + i))) { + Functor::CopyData(*array.type, out_values, out_offset + i, replacements, + replacements_offset, /*length=*/1); + if (replacements_bitmap) { + copy_bitmap.SetBit(out_bitmap, out_offset + i, replacements_offset); + } + replacements_offset++; + } + } + } + out_offset += block.length; + } +} + template Status ReplaceWithArrayMask(KernelContext* ctx, const ArrayData& array, - const ArrayData& mask, const Datum& replacements, - ArrayData* output) { + const ArrayData& mask, const Datum& replacements, + ArrayData* output) { uint8_t* out_bitmap = nullptr; uint8_t* out_values = output->buffers[1]->mutable_data(); const uint8_t* mask_bitmap = mask.MayHaveNulls() ? mask.buffers[0]->data() : nullptr; const uint8_t* mask_values = mask.buffers[1]->data(); - bool replacements_bitmap; + const bool replacements_bitmap = replacements.is_array() + ? replacements.array()->MayHaveNulls() + : !replacements.scalar()->is_valid; if (replacements.is_array()) { - replacements_bitmap = replacements.array()->MayHaveNulls(); + // Check that we have enough replacement values const int64_t replacements_length = replacements.array()->length; arrow::internal::OptionalBinaryBitBlockCounter counter( @@ -90,8 +163,6 @@ Status ReplaceWithArrayMask(KernelContext* ctx, const ArrayData& array, if (count > replacements_length) { return ReplacementArrayTooShort(count, replacements_length); } - } else { - replacements_bitmap = !replacements.scalar()->is_valid; } if (array.MayHaveNulls() || mask.MayHaveNulls() || replacements_bitmap) { ARROW_ASSIGN_OR_RAISE(output->buffers[0], ctx->AllocateBitmap(array.length)); @@ -108,52 +179,19 @@ Status ReplaceWithArrayMask(KernelContext* ctx, const ArrayData& array, } else { output->null_count = 0; } - auto copy_bitmap = [&](int64_t out_offset, int64_t in_offset, int64_t length) { - DCHECK(out_bitmap); - if (replacements.is_array()) { - const auto& in_data = *replacements.array(); - const auto in_bitmap = in_data.GetValues(0, /*absolute_offset=*/0); - arrow::internal::CopyBitmap(in_bitmap, in_data.offset + in_offset, length, - out_bitmap, out_offset); - } else { - BitUtil::SetBitsTo(out_bitmap, out_offset, length, !replacements_bitmap); - } - }; - Functor::CopyData(*array.type, out_values, /*out_offset=*/0, array, /*in_offset=*/0, - array.length); - arrow::internal::OptionalBinaryBitBlockCounter counter( - mask_values, mask.offset, mask_bitmap, mask.offset, mask.length); - int64_t out_offset = 0; - int64_t replacements_offset = 0; - while (out_offset < array.length) { - BitBlockCount block = counter.NextAndBlock(); - if (block.AllSet()) { - // Copy from replacement array - Functor::CopyData(*array.type, out_values, out_offset, replacements, - replacements_offset, block.length); - if (replacements_bitmap) { - copy_bitmap(out_offset, replacements_offset, block.length); - } else if (!replacements_bitmap && out_bitmap) { - BitUtil::SetBitsTo(out_bitmap, out_offset, block.length, true); - } - replacements_offset += block.length; - } else if (block.popcount) { - for (int64_t i = 0; i < block.length; ++i) { - if (BitUtil::GetBit(mask_values, out_offset + mask.offset + i) && - (!mask_bitmap || - BitUtil::GetBit(mask_bitmap, out_offset + mask.offset + i))) { - Functor::CopyData(*array.type, out_values, out_offset + i, replacements, - replacements_offset, - /*length=*/1); - if (replacements_bitmap) { - copy_bitmap(out_offset + i, replacements_offset, 1); - } - replacements_offset++; - } - } - } - out_offset += block.length; + if (replacements.is_array()) { + const ArrayData& array_repl = *replacements.array(); + ReplaceWithArrayMaskImpl( + array, mask, array_repl, replacements_bitmap, + CopyArrayBitmap{replacements_bitmap ? array_repl.buffers[0]->data() : nullptr, + array_repl.offset}, + mask_bitmap, mask_values, out_bitmap, out_values); + } else { + const Scalar& scalar_repl = *replacements.scalar(); + ReplaceWithArrayMaskImpl(array, mask, scalar_repl, replacements_bitmap, + CopyScalarBitmap{scalar_repl.is_valid}, mask_bitmap, + mask_values, out_bitmap, out_values); } if (mask.MayHaveNulls()) { @@ -172,17 +210,17 @@ struct ReplaceWithMask> { using T = typename TypeTraits::CType; static void CopyData(const DataType&, uint8_t* out, const int64_t out_offset, - const Datum& in, const int64_t in_offset, const int64_t length) { - if (in.is_array()) { - const auto& in_data = *in.array(); - const auto in_arr = - in_data.GetValues(1, (in_offset + in_data.offset) * sizeof(T)); - std::memcpy(out + (out_offset * sizeof(T)), in_arr, length * sizeof(T)); - } else { - T* begin = reinterpret_cast(out + (out_offset * sizeof(T))); - T* end = begin + length; - std::fill(begin, end, UnboxScalar::Unbox(*in.scalar())); - } + const ArrayData& in, const int64_t in_offset, + const int64_t length) { + const auto in_arr = in.GetValues(1, (in_offset + in.offset) * sizeof(T)); + std::memcpy(out + (out_offset * sizeof(T)), in_arr, length * sizeof(T)); + } + + static void CopyData(const DataType&, uint8_t* out, const int64_t out_offset, + const Scalar& in, const int64_t in_offset, const int64_t length) { + T* begin = reinterpret_cast(out + (out_offset * sizeof(T))); + T* end = begin + length; + std::fill(begin, end, UnboxScalar::Unbox(in)); } static Status ExecScalarMask(KernelContext* ctx, const ArrayData& array, @@ -195,22 +233,21 @@ struct ReplaceWithMask> { const ArrayData& mask, const Datum& replacements, ArrayData* output) { return ReplaceWithArrayMask>(ctx, array, mask, replacements, - output); + output); } }; template struct ReplaceWithMask> { static void CopyData(const DataType&, uint8_t* out, const int64_t out_offset, - const Datum& in, const int64_t in_offset, const int64_t length) { - if (in.is_array()) { - const auto& in_data = *in.array(); - const auto in_arr = in_data.GetValues(1, /*absolute_offset=*/0); - arrow::internal::CopyBitmap(in_arr, in_offset + in_data.offset, length, out, - out_offset); - } else { - BitUtil::SetBitsTo(out, out_offset, length, in.scalar()->is_valid); - } + const ArrayData& in, const int64_t in_offset, + const int64_t length) { + const auto in_arr = in.GetValues(1, /*absolute_offset=*/0); + arrow::internal::CopyBitmap(in_arr, in_offset + in.offset, length, out, out_offset); + } + static void CopyData(const DataType&, uint8_t* out, const int64_t out_offset, + const Scalar& in, const int64_t in_offset, const int64_t length) { + BitUtil::SetBitsTo(out, out_offset, length, in.is_valid); } static Status ExecScalarMask(KernelContext* ctx, const ArrayData& array, @@ -222,33 +259,33 @@ struct ReplaceWithMask> { const ArrayData& mask, const Datum& replacements, ArrayData* output) { return ReplaceWithArrayMask>(ctx, array, mask, replacements, - output); + output); } }; template struct ReplaceWithMask> { static void CopyData(const DataType& ty, uint8_t* out, const int64_t out_offset, - const Datum& in, const int64_t in_offset, const int64_t length) { + const ArrayData& in, const int64_t in_offset, + const int64_t length) { const int32_t width = checked_cast(ty).byte_width(); uint8_t* begin = out + (out_offset * width); - if (in.is_array()) { - const auto& in_data = *in.array(); - const auto in_arr = - in_data.GetValues(1, (in_offset + in_data.offset) * width); - std::memcpy(begin, in_arr, length * width); - } else { - const FixedSizeBinaryScalar& scalar = - checked_cast(*in.scalar()); - // Null scalar may have null value buffer - if (!scalar.value) return; - const Buffer& buffer = *scalar.value; - const uint8_t* value = buffer.data(); - DCHECK_GE(buffer.size(), width); - for (int i = 0; i < length; i++) { - std::memcpy(begin, value, width); - begin += width; - } + const auto in_arr = in.GetValues(1, (in_offset + in.offset) * width); + std::memcpy(begin, in_arr, length * width); + } + static void CopyData(const DataType& ty, uint8_t* out, const int64_t out_offset, + const Scalar& in, const int64_t in_offset, const int64_t length) { + const int32_t width = checked_cast(ty).byte_width(); + uint8_t* begin = out + (out_offset * width); + const auto& scalar = checked_cast(in); + // Null scalar may have null value buffer + if (!scalar.value) return; + const Buffer& buffer = *scalar.value; + const uint8_t* value = buffer.data(); + DCHECK_GE(buffer.size(), width); + for (int i = 0; i < length; i++) { + std::memcpy(begin, value, width); + begin += width; } } @@ -262,7 +299,7 @@ struct ReplaceWithMask> { const ArrayData& mask, const Datum& replacements, ArrayData* output) { return ReplaceWithArrayMask>(ctx, array, mask, replacements, - output); + output); } }; @@ -270,21 +307,22 @@ template struct ReplaceWithMask> { using ScalarType = typename TypeTraits::ScalarType; static void CopyData(const DataType& ty, uint8_t* out, const int64_t out_offset, - const Datum& in, const int64_t in_offset, const int64_t length) { + const ArrayData& in, const int64_t in_offset, + const int64_t length) { const int32_t width = checked_cast(ty).byte_width(); uint8_t* begin = out + (out_offset * width); - if (in.is_array()) { - const auto& in_data = *in.array(); - const auto in_arr = - in_data.GetValues(1, (in_offset + in_data.offset) * width); - std::memcpy(begin, in_arr, length * width); - } else { - const ScalarType& scalar = checked_cast(*in.scalar()); - const auto value = scalar.value.ToBytes(); - for (int i = 0; i < length; i++) { - std::memcpy(begin, value.data(), width); - begin += width; - } + const auto in_arr = in.GetValues(1, (in_offset + in.offset) * width); + std::memcpy(begin, in_arr, length * width); + } + static void CopyData(const DataType& ty, uint8_t* out, const int64_t out_offset, + const Scalar& in, const int64_t in_offset, const int64_t length) { + const int32_t width = checked_cast(ty).byte_width(); + uint8_t* begin = out + (out_offset * width); + const auto& scalar = checked_cast(in); + const auto value = scalar.value.ToBytes(); + for (int i = 0; i < length; i++) { + std::memcpy(begin, value.data(), width); + begin += width; } } @@ -298,7 +336,7 @@ struct ReplaceWithMask> { const ArrayData& mask, const Datum& replacements, ArrayData* output) { return ReplaceWithArrayMask>(ctx, array, mask, replacements, - output); + output); } }; diff --git a/cpp/src/arrow/compute/kernels/vector_replace_benchmark.cc b/cpp/src/arrow/compute/kernels/vector_replace_benchmark.cc index 4952cc6f341..719969d46ea 100644 --- a/cpp/src/arrow/compute/kernels/vector_replace_benchmark.cc +++ b/cpp/src/arrow/compute/kernels/vector_replace_benchmark.cc @@ -81,8 +81,8 @@ static void ReplaceWithMaskHighSelectivityBench( } BENCHMARK(ReplaceWithMaskLowSelectivityBench)->Args({kLongLength, 0}); -BENCHMARK(ReplaceWithMaskHighSelectivityBench)->Args({kLongLength, 0}); BENCHMARK(ReplaceWithMaskLowSelectivityBench)->Args({kLongLength, 99}); +BENCHMARK(ReplaceWithMaskHighSelectivityBench)->Args({kLongLength, 0}); BENCHMARK(ReplaceWithMaskHighSelectivityBench)->Args({kLongLength, 99}); } // namespace compute From 17e3cc3ba4619b2608f777656df57bde9e1cc9cf Mon Sep 17 00:00:00 2001 From: David Li Date: Thu, 8 Jul 2021 10:44:20 -0400 Subject: [PATCH 12/15] ARROW-9430: [C++] Actually run format --- cpp/src/arrow/compute/kernels/vector_replace.cc | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/vector_replace.cc b/cpp/src/arrow/compute/kernels/vector_replace.cc index 8934f1497e6..063e1aef90f 100644 --- a/cpp/src/arrow/compute/kernels/vector_replace.cc +++ b/cpp/src/arrow/compute/kernels/vector_replace.cc @@ -139,8 +139,8 @@ void ReplaceWithArrayMaskImpl(const ArrayData& array, const ArrayData& mask, template Status ReplaceWithArrayMask(KernelContext* ctx, const ArrayData& array, - const ArrayData& mask, const Datum& replacements, - ArrayData* output) { + const ArrayData& mask, const Datum& replacements, + ArrayData* output) { uint8_t* out_bitmap = nullptr; uint8_t* out_values = output->buffers[1]->mutable_data(); const uint8_t* mask_bitmap = mask.MayHaveNulls() ? mask.buffers[0]->data() : nullptr; @@ -233,7 +233,7 @@ struct ReplaceWithMask> { const ArrayData& mask, const Datum& replacements, ArrayData* output) { return ReplaceWithArrayMask>(ctx, array, mask, replacements, - output); + output); } }; @@ -259,7 +259,7 @@ struct ReplaceWithMask> { const ArrayData& mask, const Datum& replacements, ArrayData* output) { return ReplaceWithArrayMask>(ctx, array, mask, replacements, - output); + output); } }; @@ -299,7 +299,7 @@ struct ReplaceWithMask> { const ArrayData& mask, const Datum& replacements, ArrayData* output) { return ReplaceWithArrayMask>(ctx, array, mask, replacements, - output); + output); } }; @@ -336,7 +336,7 @@ struct ReplaceWithMask> { const ArrayData& mask, const Datum& replacements, ArrayData* output) { return ReplaceWithArrayMask>(ctx, array, mask, replacements, - output); + output); } }; From 3a56c31e94dfb459c4b8d621b4e1913d2fa5deec Mon Sep 17 00:00:00 2001 From: David Li Date: Mon, 12 Jul 2021 16:28:51 -0400 Subject: [PATCH 13/15] ARROW-9430: [C++] Preallocate validity buffer too --- .../arrow/compute/kernels/vector_replace.cc | 50 +++++++++++-------- 1 file changed, 29 insertions(+), 21 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/vector_replace.cc b/cpp/src/arrow/compute/kernels/vector_replace.cc index 063e1aef90f..28cea35d65e 100644 --- a/cpp/src/arrow/compute/kernels/vector_replace.cc +++ b/cpp/src/arrow/compute/kernels/vector_replace.cc @@ -100,40 +100,42 @@ void ReplaceWithArrayMaskImpl(const ArrayData& array, const ArrayData& mask, const Data& replacements, bool replacements_bitmap, const CopyBitmap copy_bitmap, const uint8_t* mask_bitmap, const uint8_t* mask_values, uint8_t* out_bitmap, - uint8_t* out_values) { + uint8_t* out_values, const int64_t out_offset) { Functor::CopyData(*array.type, out_values, /*out_offset=*/0, array, /*in_offset=*/0, array.length); arrow::internal::OptionalBinaryBitBlockCounter counter( mask_values, mask.offset, mask_bitmap, mask.offset, mask.length); - int64_t out_offset = 0; + int64_t write_offset = 0; int64_t replacements_offset = 0; - while (out_offset < array.length) { + while (write_offset < array.length) { BitBlockCount block = counter.NextAndBlock(); if (block.AllSet()) { // Copy from replacement array - Functor::CopyData(*array.type, out_values, out_offset, replacements, + Functor::CopyData(*array.type, out_values, out_offset + write_offset, replacements, replacements_offset, block.length); if (replacements_bitmap) { - copy_bitmap.CopyBitmap(out_bitmap, out_offset, replacements_offset, block.length); + copy_bitmap.CopyBitmap(out_bitmap, out_offset + write_offset, replacements_offset, + block.length); } else if (!replacements_bitmap && out_bitmap) { - BitUtil::SetBitsTo(out_bitmap, out_offset, block.length, true); + BitUtil::SetBitsTo(out_bitmap, out_offset + write_offset, block.length, true); } replacements_offset += block.length; } else if (block.popcount) { for (int64_t i = 0; i < block.length; ++i) { - if (BitUtil::GetBit(mask_values, out_offset + mask.offset + i) && + if (BitUtil::GetBit(mask_values, write_offset + mask.offset + i) && (!mask_bitmap || - BitUtil::GetBit(mask_bitmap, out_offset + mask.offset + i))) { - Functor::CopyData(*array.type, out_values, out_offset + i, replacements, - replacements_offset, /*length=*/1); + BitUtil::GetBit(mask_bitmap, write_offset + mask.offset + i))) { + Functor::CopyData(*array.type, out_values, out_offset + write_offset + i, + replacements, replacements_offset, /*length=*/1); if (replacements_bitmap) { - copy_bitmap.SetBit(out_bitmap, out_offset + i, replacements_offset); + copy_bitmap.SetBit(out_bitmap, out_offset + write_offset + i, + replacements_offset); } replacements_offset++; } } } - out_offset += block.length; + write_offset += block.length; } } @@ -141,6 +143,7 @@ template Status ReplaceWithArrayMask(KernelContext* ctx, const ArrayData& array, const ArrayData& mask, const Datum& replacements, ArrayData* output) { + const int64_t out_offset = output->offset; uint8_t* out_bitmap = nullptr; uint8_t* out_values = output->buffers[1]->mutable_data(); const uint8_t* mask_bitmap = mask.MayHaveNulls() ? mask.buffers[0]->data() : nullptr; @@ -165,18 +168,19 @@ Status ReplaceWithArrayMask(KernelContext* ctx, const ArrayData& array, } } if (array.MayHaveNulls() || mask.MayHaveNulls() || replacements_bitmap) { - ARROW_ASSIGN_OR_RAISE(output->buffers[0], ctx->AllocateBitmap(array.length)); out_bitmap = output->buffers[0]->mutable_data(); output->null_count = -1; if (array.MayHaveNulls()) { // Copy array's bitmap arrow::internal::CopyBitmap(array.buffers[0]->data(), array.offset, array.length, - out_bitmap, /*dest_offset=*/0); + out_bitmap, out_offset); } else { // Array has no bitmap but mask/replacements do, generate an all-valid bitmap - std::memset(out_bitmap, 0xFF, output->buffers[0]->size()); + BitUtil::SetBitsTo(out_bitmap, out_offset, array.length, true); } } else { + BitUtil::SetBitsTo(output->buffers[0]->mutable_data(), out_offset, array.length, + true); output->null_count = 0; } @@ -186,18 +190,17 @@ Status ReplaceWithArrayMask(KernelContext* ctx, const ArrayData& array, array, mask, array_repl, replacements_bitmap, CopyArrayBitmap{replacements_bitmap ? array_repl.buffers[0]->data() : nullptr, array_repl.offset}, - mask_bitmap, mask_values, out_bitmap, out_values); + mask_bitmap, mask_values, out_bitmap, out_values, out_offset); } else { const Scalar& scalar_repl = *replacements.scalar(); ReplaceWithArrayMaskImpl(array, mask, scalar_repl, replacements_bitmap, CopyScalarBitmap{scalar_repl.is_valid}, mask_bitmap, - mask_values, out_bitmap, out_values); + mask_values, out_bitmap, out_values, out_offset); } if (mask.MayHaveNulls()) { - arrow::internal::BitmapAnd(out_bitmap, /*left_offset=*/0, mask.buffers[0]->data(), - mask.offset, array.length, - /*out_offset=*/0, out_bitmap); + arrow::internal::BitmapAnd(out_bitmap, out_offset, mask.buffers[0]->data(), + mask.offset, array.length, out_offset, out_bitmap); } return Status::OK(); } @@ -465,7 +468,12 @@ void RegisterVectorReplace(FunctionRegistry* registry) { auto add_kernel = [&](detail::GetTypeId get_id, ArrayKernelExec exec) { VectorKernel kernel; kernel.can_execute_chunkwise = false; - kernel.null_handling = NullHandling::type::COMPUTED_NO_PREALLOCATE; + if (is_fixed_width(get_id.id)) { + kernel.null_handling = NullHandling::type::COMPUTED_PREALLOCATE; + } else { + kernel.can_write_into_slices = false; + kernel.null_handling = NullHandling::type::COMPUTED_NO_PREALLOCATE; + } kernel.mem_allocation = MemAllocation::type::PREALLOCATE; kernel.signature = KernelSignature::Make( {InputType::Array(get_id.id), InputType(boolean()), InputType(get_id.id)}, From b5b656a937c927154c6dc47c000a5e088e31eec7 Mon Sep 17 00:00:00 2001 From: David Li Date: Tue, 13 Jul 2021 14:26:02 -0400 Subject: [PATCH 14/15] ARROW-9430: [C++] Fix replacement array > input array --- .../arrow/compute/kernels/vector_replace.cc | 65 +++++++++++++------ .../compute/kernels/vector_replace_test.cc | 14 +--- 2 files changed, 47 insertions(+), 32 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/vector_replace.cc b/cpp/src/arrow/compute/kernels/vector_replace.cc index 28cea35d65e..55675c2524b 100644 --- a/cpp/src/arrow/compute/kernels/vector_replace.cc +++ b/cpp/src/arrow/compute/kernels/vector_replace.cc @@ -32,6 +32,7 @@ Status ReplacementArrayTooShort(int64_t expected, int64_t actual) { // Helper to implement replace_with kernel with scalar mask for fixed-width types, // using callbacks to handle both bool and byte-sized types +template Status ReplaceWithScalarMask(KernelContext* ctx, const ArrayData& array, const BooleanScalar& mask, const Datum& replacements, ArrayData* output) { @@ -41,8 +42,7 @@ Status ReplaceWithScalarMask(KernelContext* ctx, const ArrayData& array, MakeArrayOfNull(array.type, array.length, ctx->memory_pool())); *output = *replacement_array->data(); return Status::OK(); - } - if (mask.value) { + } else if (mask.value) { // Output = replacement if (replacements.is_scalar()) { ARROW_ASSIGN_OR_RAISE( @@ -50,11 +50,12 @@ Status ReplaceWithScalarMask(KernelContext* ctx, const ArrayData& array, MakeArrayFromScalar(*replacements.scalar(), array.length, ctx->memory_pool())); *output = *replacement_array->data(); } else { - auto replacement_array = replacements.array(); - if (replacement_array->length != array.length) { - return ReplacementArrayTooShort(array.length, replacement_array->length); + const ArrayData& replacement_array = *replacements.array(); + if (replacement_array.length < array.length) { + return ReplacementArrayTooShort(array.length, replacement_array.length); } - *output = *replacement_array; + *output = replacement_array; + output->length = array.length; } } else { // Output = input @@ -98,7 +99,7 @@ struct CopyScalarBitmap { template void ReplaceWithArrayMaskImpl(const ArrayData& array, const ArrayData& mask, const Data& replacements, bool replacements_bitmap, - const CopyBitmap copy_bitmap, const uint8_t* mask_bitmap, + const CopyBitmap& copy_bitmap, const uint8_t* mask_bitmap, const uint8_t* mask_values, uint8_t* out_bitmap, uint8_t* out_values, const int64_t out_offset) { Functor::CopyData(*array.type, out_values, /*out_offset=*/0, array, /*in_offset=*/0, @@ -155,14 +156,9 @@ Status ReplaceWithArrayMask(KernelContext* ctx, const ArrayData& array, // Check that we have enough replacement values const int64_t replacements_length = replacements.array()->length; - arrow::internal::OptionalBinaryBitBlockCounter counter( - mask_values, mask.offset, mask_bitmap, mask.offset, mask.length); - int64_t count = 0; - for (int64_t offset = 0; offset < mask.length;) { - BitBlockCount block = counter.NextAndBlock(); - count += block.popcount; - offset += block.length; - } + BooleanArray mask_arr(mask.length, mask.buffers[1], mask.buffers[0], mask.null_count, + mask.offset); + const int64_t count = mask_arr.true_count(); if (count > replacements_length) { return ReplacementArrayTooShort(count, replacements_length); } @@ -229,7 +225,8 @@ struct ReplaceWithMask> { static Status ExecScalarMask(KernelContext* ctx, const ArrayData& array, const BooleanScalar& mask, const Datum& replacements, ArrayData* output) { - return ReplaceWithScalarMask(ctx, array, mask, replacements, output); + return ReplaceWithScalarMask>(ctx, array, mask, replacements, + output); } static Status ExecArrayMask(KernelContext* ctx, const ArrayData& array, @@ -256,7 +253,8 @@ struct ReplaceWithMask> { static Status ExecScalarMask(KernelContext* ctx, const ArrayData& array, const BooleanScalar& mask, const Datum& replacements, ArrayData* output) { - return ReplaceWithScalarMask(ctx, array, mask, replacements, output); + return ReplaceWithScalarMask>(ctx, array, mask, replacements, + output); } static Status ExecArrayMask(KernelContext* ctx, const ArrayData& array, const ArrayData& mask, const Datum& replacements, @@ -295,7 +293,8 @@ struct ReplaceWithMask> { static Status ExecScalarMask(KernelContext* ctx, const ArrayData& array, const BooleanScalar& mask, const Datum& replacements, ArrayData* output) { - return ReplaceWithScalarMask(ctx, array, mask, replacements, output); + return ReplaceWithScalarMask>(ctx, array, mask, replacements, + output); } static Status ExecArrayMask(KernelContext* ctx, const ArrayData& array, @@ -332,7 +331,8 @@ struct ReplaceWithMask> { static Status ExecScalarMask(KernelContext* ctx, const ArrayData& array, const BooleanScalar& mask, const Datum& replacements, ArrayData* output) { - return ReplaceWithScalarMask(ctx, array, mask, replacements, output); + return ReplaceWithScalarMask>(ctx, array, mask, replacements, + output); } static Status ExecArrayMask(KernelContext* ctx, const ArrayData& array, @@ -367,7 +367,32 @@ struct ReplaceWithMask> { static Status ExecScalarMask(KernelContext* ctx, const ArrayData& array, const BooleanScalar& mask, const Datum& replacements, ArrayData* output) { - return ReplaceWithScalarMask(ctx, array, mask, replacements, output); + if (!mask.is_valid) { + // Output = null + ARROW_ASSIGN_OR_RAISE( + auto replacement_array, + MakeArrayOfNull(array.type, array.length, ctx->memory_pool())); + *output = *replacement_array->data(); + } else if (mask.value) { + // Output = replacement + if (replacements.is_scalar()) { + ARROW_ASSIGN_OR_RAISE(auto replacement_array, + MakeArrayFromScalar(*replacements.scalar(), array.length, + ctx->memory_pool())); + *output = *replacement_array->data(); + } else { + const ArrayData& replacement_array = *replacements.array(); + if (replacement_array.length < array.length) { + return ReplacementArrayTooShort(array.length, replacement_array.length); + } + *output = replacement_array; + output->length = array.length; + } + } else { + // Output = input + *output = array; + } + return Status::OK(); } static Status ExecArrayMask(KernelContext* ctx, const ArrayData& array, const ArrayData& mask, const Datum& replacements, diff --git a/cpp/src/arrow/compute/kernels/vector_replace_test.cc b/cpp/src/arrow/compute/kernels/vector_replace_test.cc index a55ed709a0f..48f253e7ca9 100644 --- a/cpp/src/arrow/compute/kernels/vector_replace_test.cc +++ b/cpp/src/arrow/compute/kernels/vector_replace_test.cc @@ -176,6 +176,8 @@ TYPED_TEST(TestReplaceNumeric, ReplaceWithMask) { this->array("[]"), this->array("[1]")); this->Assert(ReplaceWithMask, this->array("[1]"), this->mask_scalar(true), this->array("[0]"), this->array("[0]")); + this->Assert(ReplaceWithMask, this->array("[1]"), this->mask_scalar(true), + this->array("[2, 0]"), this->array("[2]")); this->Assert(ReplaceWithMask, this->array("[1]"), this->null_mask_scalar(), this->array("[]"), this->array("[null]")); @@ -258,12 +260,6 @@ TYPED_TEST(TestReplaceNumeric, ReplaceWithMaskRandom) { } TYPED_TEST(TestReplaceNumeric, ReplaceWithMaskErrors) { - EXPECT_RAISES_WITH_MESSAGE_THAT( - Invalid, - ::testing::HasSubstr("Replacement array must be of appropriate length (expected 1 " - "items but got 2 items)"), - this->AssertRaises(ReplaceWithMask, this->array("[1]"), this->mask_scalar(true), - this->array("[0, 1]"))); EXPECT_RAISES_WITH_MESSAGE_THAT( Invalid, ::testing::HasSubstr("Replacement array must be of appropriate length (expected 2 " @@ -349,12 +345,6 @@ TEST_F(TestReplaceBoolean, ReplaceWithMask) { } TEST_F(TestReplaceBoolean, ReplaceWithMaskErrors) { - EXPECT_RAISES_WITH_MESSAGE_THAT( - Invalid, - ::testing::HasSubstr("Replacement array must be of appropriate length (expected 1 " - "items but got 2 items)"), - this->AssertRaises(ReplaceWithMask, this->array("[true]"), this->mask_scalar(true), - this->array("[false, false]"))); EXPECT_RAISES_WITH_MESSAGE_THAT( Invalid, ::testing::HasSubstr("Replacement array must be of appropriate length (expected 2 " From 6fa62de979fe0eb12e9a00e6d781fadd80dee0e0 Mon Sep 17 00:00:00 2001 From: David Li Date: Tue, 13 Jul 2021 14:48:07 -0400 Subject: [PATCH 15/15] ARROW-9430: [C++] Properly use preallocation in scalar mask case --- .../arrow/compute/kernels/vector_replace.cc | 39 +++++++++++-------- 1 file changed, 22 insertions(+), 17 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/vector_replace.cc b/cpp/src/arrow/compute/kernels/vector_replace.cc index 55675c2524b..644aec2a4e9 100644 --- a/cpp/src/arrow/compute/kernels/vector_replace.cc +++ b/cpp/src/arrow/compute/kernels/vector_replace.cc @@ -36,30 +36,35 @@ template Status ReplaceWithScalarMask(KernelContext* ctx, const ArrayData& array, const BooleanScalar& mask, const Datum& replacements, ArrayData* output) { + Datum source = array; if (!mask.is_valid) { // Output = null - ARROW_ASSIGN_OR_RAISE(auto replacement_array, - MakeArrayOfNull(array.type, array.length, ctx->memory_pool())); - *output = *replacement_array->data(); - return Status::OK(); + source = MakeNullScalar(output->type); } else if (mask.value) { // Output = replacement - if (replacements.is_scalar()) { - ARROW_ASSIGN_OR_RAISE( - auto replacement_array, - MakeArrayFromScalar(*replacements.scalar(), array.length, ctx->memory_pool())); - *output = *replacement_array->data(); + source = replacements; + } + uint8_t* out_bitmap = output->buffers[0]->mutable_data(); + uint8_t* out_values = output->buffers[1]->mutable_data(); + const int64_t out_offset = output->offset; + if (source.is_array()) { + const ArrayData& in_data = *source.array(); + if (in_data.length < array.length) { + return ReplacementArrayTooShort(array.length, in_data.length); + } + Functor::CopyData(*array.type, out_values, out_offset, in_data, /*in_offset=*/0, + array.length); + if (in_data.MayHaveNulls()) { + arrow::internal::CopyBitmap(in_data.buffers[0]->data(), in_data.offset, + array.length, out_bitmap, out_offset); } else { - const ArrayData& replacement_array = *replacements.array(); - if (replacement_array.length < array.length) { - return ReplacementArrayTooShort(array.length, replacement_array.length); - } - *output = replacement_array; - output->length = array.length; + BitUtil::SetBitsTo(out_bitmap, out_offset, array.length, true); } } else { - // Output = input - *output = array; + const Scalar& in_data = *source.scalar(); + Functor::CopyData(*array.type, out_values, out_offset, in_data, /*in_offset=*/0, + array.length); + BitUtil::SetBitsTo(out_bitmap, out_offset, array.length, in_data.is_valid); } return Status::OK(); }