From b249edaccb88080cec7d6297b8fac1a3f601e0c6 Mon Sep 17 00:00:00 2001
From: David Li
Date: Thu, 27 May 2021 09:46:57 -0400
Subject: [PATCH 01/15] ARROW-9430: [C++] Implement override_mask kernel
---
cpp/src/arrow/CMakeLists.txt | 1 +
cpp/src/arrow/array/array_binary.h | 7 +
cpp/src/arrow/array/util.cc | 5 +
cpp/src/arrow/compute/api_vector.cc | 5 +
cpp/src/arrow/compute/api_vector.h | 17 +
cpp/src/arrow/compute/kernels/CMakeLists.txt | 1 +
.../arrow/compute/kernels/codegen_internal.h | 2 +
.../arrow/compute/kernels/vector_replace.cc | 495 ++++++++++++
.../compute/kernels/vector_replace_test.cc | 736 ++++++++++++++++++
cpp/src/arrow/compute/registry.cc | 1 +
cpp/src/arrow/compute/registry_internal.h | 1 +
cpp/src/arrow/type_traits.h | 1 +
docs/source/cpp/compute.rst | 16 +
docs/source/python/api/compute.rst | 8 +
14 files changed, 1296 insertions(+)
create mode 100644 cpp/src/arrow/compute/kernels/vector_replace.cc
create mode 100644 cpp/src/arrow/compute/kernels/vector_replace_test.cc
diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt
index 634d202623f..88a92d8c2c9 100644
--- a/cpp/src/arrow/CMakeLists.txt
+++ b/cpp/src/arrow/CMakeLists.txt
@@ -401,6 +401,7 @@ if(ARROW_COMPUTE)
compute/kernels/util_internal.cc
compute/kernels/vector_hash.cc
compute/kernels/vector_nested.cc
+ compute/kernels/vector_replace.cc
compute/kernels/vector_selection.cc
compute/kernels/vector_sort.cc
compute/exec/key_hash.cc
diff --git a/cpp/src/arrow/array/array_binary.h b/cpp/src/arrow/array/array_binary.h
index db3c640b9a4..f8e8c4f8a44 100644
--- a/cpp/src/arrow/array/array_binary.h
+++ b/cpp/src/arrow/array/array_binary.h
@@ -71,6 +71,13 @@ class BaseBinaryArray : public FlatArray {
raw_value_offsets_[i + 1] - pos);
}
+ /// \brief Get binary value as a string_view
+ ///
+ /// Provided for consistency with other arrays; this simply forwards to
+ /// GetView(i).
+ ///
+ /// \param i the value index
+ /// \return the view over the selected value
+ util::string_view Value(int64_t i) const { return GetView(i); }
+
/// \brief Get binary value as a std::string
///
/// \param i the value index
diff --git a/cpp/src/arrow/array/util.cc b/cpp/src/arrow/array/util.cc
index 688cb20cb9a..ed26ecff4e0 100644
--- a/cpp/src/arrow/array/util.cc
+++ b/cpp/src/arrow/array/util.cc
@@ -528,6 +528,11 @@ class RepeatedArrayFactory {
return FinishFixedWidth(value.data(), value.size());
}
+ // Decimal256 scalars are repeated as fixed-width values: the scalar's value
+ // is serialized with ToBytes() and the resulting bytes (pointer + size) are
+ // handed to FinishFixedWidth.
+ Status Visit(const Decimal256Type&) {
+ auto value = checked_cast(scalar_).value.ToBytes();
+ return FinishFixedWidth(value.data(), value.size());
+ }
+
template
enable_if_base_binary Visit(const T&) {
std::shared_ptr value =
diff --git a/cpp/src/arrow/compute/api_vector.cc b/cpp/src/arrow/compute/api_vector.cc
index 9c1ef8533b4..a68969b2ee5 100644
--- a/cpp/src/arrow/compute/api_vector.cc
+++ b/cpp/src/arrow/compute/api_vector.cc
@@ -162,6 +162,11 @@ Result> NthToIndices(const Array& values, int64_t n,
return result.make_array();
}
+// Convenience wrapper: invokes the function registered under the name
+// "replace_with_mask" with the arguments (values, mask, replacements),
+// forwarding the caller's ExecContext to CallFunction.
+Result ReplaceWithMask(const Datum& values, const Datum& mask,
+ const Datum& replacements, ExecContext* ctx) {
+ return CallFunction("replace_with_mask", {values, mask, replacements}, ctx);
+}
+
Result> SortIndices(const Array& values, SortOrder order,
ExecContext* ctx) {
ArraySortOptions options(order);
diff --git a/cpp/src/arrow/compute/api_vector.h b/cpp/src/arrow/compute/api_vector.h
index 6021492320e..9d8d4271db8 100644
--- a/cpp/src/arrow/compute/api_vector.h
+++ b/cpp/src/arrow/compute/api_vector.h
@@ -171,6 +171,23 @@ Result> GetTakeIndices(
} // namespace internal
+/// \brief ReplaceWithMask replaces each value in the array corresponding
+/// to a true value in the mask with the next element from `replacements`.
+///
+/// Values whose mask entry is false are kept as-is; values whose mask entry
+/// is null become null in the output.
+///
+/// \param[in] values Array input to replace
+/// \param[in] mask Array or Scalar of Boolean mask values
+/// \param[in] replacements The replacement values to draw from, either a
+/// Scalar or an Array. For an Array, there must be as many replacement
+/// values as true values in the mask.
+/// \param[in] ctx the function execution context, optional
+///
+/// \return the resulting datum
+///
+/// \since 5.0.0
+/// \note API not yet finalized
+ARROW_EXPORT
+Result ReplaceWithMask(const Datum& values, const Datum& mask,
+ const Datum& replacements, ExecContext* ctx = NULLPTR);
+
/// \brief Take from an array of values at indices in another array
///
/// The output array will be of the same type as the input values
diff --git a/cpp/src/arrow/compute/kernels/CMakeLists.txt b/cpp/src/arrow/compute/kernels/CMakeLists.txt
index 3362d91cbe8..f8cb67d82c4 100644
--- a/cpp/src/arrow/compute/kernels/CMakeLists.txt
+++ b/cpp/src/arrow/compute/kernels/CMakeLists.txt
@@ -48,6 +48,7 @@ add_arrow_compute_test(vector_test
SOURCES
vector_hash_test.cc
vector_nested_test.cc
+ vector_replace_test.cc
vector_selection_test.cc
vector_sort_test.cc
test_util.cc)
diff --git a/cpp/src/arrow/compute/kernels/codegen_internal.h b/cpp/src/arrow/compute/kernels/codegen_internal.h
index 33b7006491a..12e80423f7f 100644
--- a/cpp/src/arrow/compute/kernels/codegen_internal.h
+++ b/cpp/src/arrow/compute/kernels/codegen_internal.h
@@ -1240,6 +1240,7 @@ ArrayKernelExec GenerateTypeAgnosticPrimitive(detail::GetTypeId get_id) {
case Type::FLOAT:
case Type::DATE32:
case Type::TIME32:
+ case Type::INTERVAL_MONTHS:
return Generator::Exec;
case Type::UINT64:
case Type::INT64:
@@ -1248,6 +1249,7 @@ ArrayKernelExec GenerateTypeAgnosticPrimitive(detail::GetTypeId get_id) {
case Type::TIMESTAMP:
case Type::TIME64:
case Type::DURATION:
+ case Type::INTERVAL_DAY_TIME:
return Generator::Exec;
default:
DCHECK(false);
diff --git a/cpp/src/arrow/compute/kernels/vector_replace.cc b/cpp/src/arrow/compute/kernels/vector_replace.cc
new file mode 100644
index 00000000000..4b768608c62
--- /dev/null
+++ b/cpp/src/arrow/compute/kernels/vector_replace.cc
@@ -0,0 +1,495 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/compute/api_scalar.h"
+#include "arrow/compute/kernels/common.h"
+#include "arrow/util/bitmap_ops.h"
+
+namespace arrow {
+namespace compute {
+namespace internal {
+
+namespace {
+
+// Builds the Invalid status reported when the number of replacement values
+// supplied does not match what the mask requires.
+//
+// \param expected number of replacement items required
+// \param actual number of replacement items supplied
+// \return a Status::Invalid describing the mismatch
+Status ReplacementArrayTooShort(int64_t expected, int64_t actual) {
+  constexpr const char* kPrefix =
+      "Replacement array must be of appropriate length (expected ";
+  constexpr const char* kInfix = " items but got ";
+  constexpr const char* kSuffix = " items)";
+  return Status::Invalid(kPrefix, expected, kInfix, actual, kSuffix);
+}
+
+// Implements replace_with_mask when the mask is a single Boolean scalar, in
+// which case the whole output is selected at once:
+//   - null mask  -> an all-null array of the input's type and length
+//   - false mask -> the input array, unchanged
+//   - true mask  -> the replacements; a scalar is expanded to the input's
+//                   length, an array must have exactly the input's length
+//
+// Returns Invalid (via ReplacementArrayTooShort) when a replacement array's
+// length differs from the input's length.
+Status ReplaceWithScalarMask(KernelContext* ctx, const ArrayData& array,
+                             const BooleanScalar& mask, const Datum& replacements,
+                             ArrayData* output) {
+  if (!mask.is_valid) {
+    // Null mask: every output slot is null
+    ARROW_ASSIGN_OR_RAISE(auto all_null,
+                          MakeArrayOfNull(array.type, array.length, ctx->memory_pool()));
+    *output = *all_null->data();
+    return Status::OK();
+  }
+  if (!mask.value) {
+    // Nothing is selected for replacement: pass the input through
+    *output = array;
+    return Status::OK();
+  }
+  if (replacements.is_scalar()) {
+    // Everything is replaced by the same value
+    ARROW_ASSIGN_OR_RAISE(
+        auto broadcast,
+        MakeArrayFromScalar(*replacements.scalar(), array.length, ctx->memory_pool()));
+    *output = *broadcast->data();
+    return Status::OK();
+  }
+  // Everything is replaced, one replacement per input slot
+  const auto replacement_data = replacements.array();
+  if (replacement_data->length != array.length) {
+    return ReplacementArrayTooShort(array.length, replacement_data->length);
+  }
+  *output = *replacement_data;
+  return Status::OK();
+}
+
+// Helper to implement the replace_with_mask kernel with an array mask for
+// fixed-width types. The type-specific work is delegated to Functor, which
+// must provide two static members:
+//   - AllocateData(ctx, type, length): allocate the output values buffer
+//   - CopyData(type, out, out_offset, in, in_offset, length): copy `length`
+//     values from `in` (an array or scalar Datum) into `out`
+// This lets one driver serve both bit-packed (Boolean) and byte-sized types,
+// with either scalar or array replacements.
+//
+// Per output slot: null where the mask is null, the next unused replacement
+// where the mask is true, and the input value where the mask is false.
+// Returns Invalid if the mask selects more slots than there are replacements.
+template <typename Functor>
+Status ReplaceWithArrayMask(KernelContext* ctx, const ArrayData& array,
+                            const ArrayData& mask, const Datum& replacements,
+                            ArrayData* output) {
+  ARROW_ASSIGN_OR_RAISE(output->buffers[1],
+                        Functor::AllocateData(ctx, *array.type, array.length));
+
+  uint8_t* out_bitmap = nullptr;
+  uint8_t* out_values = output->buffers[1]->mutable_data();
+  const uint8_t* mask_bitmap = mask.MayHaveNulls() ? mask.buffers[0]->data() : nullptr;
+  const uint8_t* mask_values = mask.buffers[1]->data();
+  // True when a replacement slot may be null (array with nulls, or null scalar)
+  bool replacements_bitmap;
+  int64_t replacements_length;
+  if (replacements.is_array()) {
+    replacements_bitmap = replacements.array()->MayHaveNulls();
+    replacements_length = replacements.array()->length;
+  } else {
+    replacements_bitmap = !replacements.scalar()->is_valid;
+    // A scalar can be repeated as often as needed
+    replacements_length = std::numeric_limits<int64_t>::max();
+  }
+  // A validity bitmap is only materialized if any input could contribute nulls
+  if (array.MayHaveNulls() || mask.MayHaveNulls() || replacements_bitmap) {
+    ARROW_ASSIGN_OR_RAISE(output->buffers[0], ctx->AllocateBitmap(array.length));
+    out_bitmap = output->buffers[0]->mutable_data();
+    output->null_count = -1;
+    if (array.MayHaveNulls()) {
+      // Start from the input's validity; replaced slots are patched below
+      arrow::internal::CopyBitmap(array.buffers[0]->data(), array.offset, array.length,
+                                  out_bitmap, /*dest_offset=*/0);
+    } else {
+      std::memset(out_bitmap, 0xFF, output->buffers[0]->size());
+    }
+  } else {
+    output->null_count = 0;
+  }
+  // Copies `length` validity bits of the replacements into the output bitmap
+  auto copy_bitmap = [&](int64_t out_offset, int64_t in_offset, int64_t length) {
+    DCHECK(out_bitmap);
+    if (replacements.is_array()) {
+      const auto& in_data = *replacements.array();
+      const auto in_bitmap = in_data.GetValues<uint8_t>(0, /*absolute_offset=*/0);
+      arrow::internal::CopyBitmap(in_bitmap, in_data.offset + in_offset, length,
+                                  out_bitmap, out_offset);
+    } else {
+      BitUtil::SetBitsTo(out_bitmap, out_offset, length, !replacements_bitmap);
+    }
+  };
+
+  // Seed the output with the input values, then overwrite the selected slots
+  Functor::CopyData(*array.type, out_values, /*out_offset=*/0, array, /*in_offset=*/0,
+                    array.length);
+  arrow::internal::BitBlockCounter value_counter(mask_values, mask.offset, mask.length);
+  arrow::internal::OptionalBitBlockCounter valid_counter(mask_bitmap, mask.offset,
+                                                         mask.length);
+  int64_t out_offset = 0;
+  int64_t replacements_offset = 0;
+  while (out_offset < array.length) {
+    BitBlockCount value_block = value_counter.NextWord();
+    BitBlockCount valid_block = valid_counter.NextWord();
+    DCHECK_EQ(value_block.length, valid_block.length);
+    if (value_block.AllSet() && valid_block.AllSet()) {
+      // Whole block is selected: bulk copy from the replacements
+      if (replacements_offset + valid_block.length > replacements_length) {
+        return ReplacementArrayTooShort(replacements_offset + valid_block.length,
+                                        replacements_length);
+      }
+      Functor::CopyData(*array.type, out_values, out_offset, replacements,
+                        replacements_offset, valid_block.length);
+      if (replacements_bitmap) {
+        copy_bitmap(out_offset, replacements_offset, valid_block.length);
+      } else if (out_bitmap) {
+        // Replacements are all valid; clear any nulls inherited from the input
+        BitUtil::SetBitsTo(out_bitmap, out_offset, valid_block.length, true);
+      }
+      replacements_offset += valid_block.length;
+    } else if (value_block.NoneSet() && valid_block.AllSet()) {
+      // Whole block keeps the input values: nothing to do
+    } else if (valid_block.NoneSet()) {
+      // Whole block of the mask is null: output is null
+      DCHECK(out_bitmap);
+      BitUtil::SetBitsTo(out_bitmap, out_offset, valid_block.length, false);
+    } else {
+      // Mixed block: examine each slot individually
+      for (int64_t i = 0; i < valid_block.length; ++i) {
+        if (BitUtil::GetBit(mask_values, out_offset + mask.offset + i) &&
+            (!mask_bitmap ||
+             BitUtil::GetBit(mask_bitmap, out_offset + mask.offset + i))) {
+          if (replacements_offset >= replacements_length) {
+            return ReplacementArrayTooShort(replacements_offset + 1, replacements_length);
+          }
+          Functor::CopyData(*array.type, out_values, out_offset + i, replacements,
+                            replacements_offset,
+                            /*length=*/1);
+          if (replacements_bitmap) {
+            copy_bitmap(out_offset + i, replacements_offset, 1);
+          } else if (out_bitmap) {
+            // The replacement is valid, so the slot must be marked valid even
+            // if the input value it overwrites was null (mirrors the
+            // whole-block path above).
+            BitUtil::SetBitsTo(out_bitmap, out_offset + i, /*length=*/1, true);
+          }
+          replacements_offset++;
+        }
+      }
+    }
+    out_offset += valid_block.length;
+  }
+
+  // Finally, null out every slot whose mask entry is null
+  if (mask.MayHaveNulls()) {
+    arrow::internal::BitmapAnd(out_bitmap, /*left_offset=*/0, mask.buffers[0]->data(),
+                               mask.offset, array.length,
+                               /*out_offset=*/0, out_bitmap);
+  }
+  return Status::OK();
+}
+
+// Primary template, intentionally empty: the per-type behavior lives in the
+// specializations of ReplaceWithMask that follow.
+template
+struct ReplaceWithMask {};
+
+// Specialization whose values are stored as plain C values of type T
+// (sizeof(T) bytes per element).
+template
+struct ReplaceWithMask> {
+ using T = typename TypeTraits::CType;
+
+ // Allocates an (uninitialized) values buffer holding `length` elements.
+ static Result> AllocateData(KernelContext* ctx, const DataType&,
+ const int64_t length) {
+ return ctx->Allocate(length * sizeof(T));
+ }
+
+ // Writes `length` elements into `out` starting at element `out_offset`.
+ // For an array input the elements are taken starting at `in_offset`
+ // (relative to the array's own offset); for a scalar input the scalar's
+ // value is repeated and `in_offset` is ignored.
+ static void CopyData(const DataType&, uint8_t* out, const int64_t out_offset,
+ const Datum& in, const int64_t in_offset, const int64_t length) {
+ if (in.is_array()) {
+ const auto& in_data = *in.array();
+ // The offset passed to GetValues is computed in bytes here
+ const auto in_arr =
+ in_data.GetValues(1, (in_offset + in_data.offset) * sizeof(T));
+ std::memcpy(out + (out_offset * sizeof(T)), in_arr, length * sizeof(T));
+ } else {
+ T* begin = reinterpret_cast(out + (out_offset * sizeof(T)));
+ T* end = begin + length;
+ std::fill(begin, end, UnboxScalar::Unbox(*in.scalar()));
+ }
+ }
+
+ // Scalar mask: delegates to the shared ReplaceWithScalarMask helper.
+ static Status ExecScalarMask(KernelContext* ctx, const ArrayData& array,
+ const BooleanScalar& mask, const Datum& replacements,
+ ArrayData* output) {
+ return ReplaceWithScalarMask(ctx, array, mask, replacements, output);
+ }
+
+ // Array mask: delegates to the shared ReplaceWithArrayMask helper, which
+ // calls back into AllocateData/CopyData above.
+ static Status ExecArrayMask(KernelContext* ctx, const ArrayData& array,
+ const ArrayData& mask, const Datum& replacements,
+ ArrayData* output) {
+ return ReplaceWithArrayMask>(ctx, array, mask, replacements,
+ output);
+ }
+};
+
+// Specialization for Boolean values. Values are bit-packed, so the data
+// buffer is allocated as a bitmap and copied at bit granularity.
+template <typename Type>
+struct ReplaceWithMask<Type, enable_if_boolean<Type>> {
+  // Allocates a bitmap large enough for `length` Boolean values.
+  static Result<std::shared_ptr<Buffer>> AllocateData(KernelContext* ctx, const DataType&,
+                                                      const int64_t length) {
+    return ctx->AllocateBitmap(length);
+  }
+
+  // Writes `length` Boolean values into the bitmap `out` starting at bit
+  // `out_offset`. For an array input the bits are taken starting at
+  // `in_offset` (relative to the array's own offset); for a scalar input the
+  // scalar's value is repeated and `in_offset` is ignored.
+  static void CopyData(const DataType&, uint8_t* out, const int64_t out_offset,
+                       const Datum& in, const int64_t in_offset, const int64_t length) {
+    if (in.is_array()) {
+      const auto& in_data = *in.array();
+      const auto in_arr = in_data.GetValues<uint8_t>(1, /*absolute_offset=*/0);
+      arrow::internal::CopyBitmap(in_arr, in_offset + in_data.offset, length, out,
+                                  out_offset);
+    } else {
+      const Scalar& scalar = *in.scalar();
+      // Write the scalar's Boolean *value*, not its validity: a valid `false`
+      // must be stored as 0. Validity is tracked separately by the caller, so
+      // a null scalar simply writes false here.
+      const bool value =
+          scalar.is_valid && checked_cast<const BooleanScalar&>(scalar).value;
+      BitUtil::SetBitsTo(out, out_offset, length, value);
+    }
+  }
+
+  // Scalar mask: delegates to the shared ReplaceWithScalarMask helper.
+  static Status ExecScalarMask(KernelContext* ctx, const ArrayData& array,
+                               const BooleanScalar& mask, const Datum& replacements,
+                               ArrayData* output) {
+    return ReplaceWithScalarMask(ctx, array, mask, replacements, output);
+  }
+
+  // Array mask: delegates to the shared ReplaceWithArrayMask helper, which
+  // calls back into AllocateData/CopyData above.
+  static Status ExecArrayMask(KernelContext* ctx, const ArrayData& array,
+                              const ArrayData& mask, const Datum& replacements,
+                              ArrayData* output) {
+    return ReplaceWithArrayMask<ReplaceWithMask<Type>>(ctx, array, mask, replacements,
+                                                       output);
+  }
+};
+
+// Specialization for fixed-width binary values, whose element width is only
+// known at runtime (taken from the DataType's byte_width()).
+template
+struct ReplaceWithMask> {
+ // Allocates an (uninitialized) values buffer of length * byte_width bytes.
+ static Result> AllocateData(KernelContext* ctx,
+ const DataType& ty,
+ const int64_t length) {
+ return ctx->Allocate(length *
+ checked_cast(ty).byte_width());
+ }
+
+ // Writes `length` elements into `out` starting at element `out_offset`.
+ // For an array input the elements are taken starting at `in_offset`
+ // (relative to the array's own offset); for a scalar input the scalar's
+ // bytes are repeated and `in_offset` is ignored.
+ static void CopyData(const DataType& ty, uint8_t* out, const int64_t out_offset,
+ const Datum& in, const int64_t in_offset, const int64_t length) {
+ const int32_t width = checked_cast(ty).byte_width();
+ uint8_t* begin = out + (out_offset * width);
+ if (in.is_array()) {
+ const auto& in_data = *in.array();
+ // The offset passed to GetValues is computed in bytes here
+ const auto in_arr =
+ in_data.GetValues(1, (in_offset + in_data.offset) * width);
+ std::memcpy(begin, in_arr, length * width);
+ } else {
+ const FixedSizeBinaryScalar& scalar =
+ checked_cast(*in.scalar());
+ // Null scalar may have null value buffer
+ // (in that case the output bytes are left untouched)
+ if (!scalar.value) return;
+ const Buffer& buffer = *scalar.value;
+ const uint8_t* value = buffer.data();
+ DCHECK_GE(buffer.size(), width);
+ // NOTE(review): `i` is an int while `length` is an int64_t
+ for (int i = 0; i < length; i++) {
+ std::memcpy(begin, value, width);
+ begin += width;
+ }
+ }
+ }
+
+ // Scalar mask: delegates to the shared ReplaceWithScalarMask helper.
+ static Status ExecScalarMask(KernelContext* ctx, const ArrayData& array,
+ const BooleanScalar& mask, const Datum& replacements,
+ ArrayData* output) {
+ return ReplaceWithScalarMask(ctx, array, mask, replacements, output);
+ }
+
+ // Array mask: delegates to the shared ReplaceWithArrayMask helper, which
+ // calls back into AllocateData/CopyData above.
+ static Status ExecArrayMask(KernelContext* ctx, const ArrayData& array,
+ const ArrayData& mask, const Datum& replacements,
+ ArrayData* output) {
+ return ReplaceWithArrayMask>(ctx, array, mask, replacements,
+ output);
+ }
+};
+
+// Specialization for values whose scalar exposes `value.ToBytes()` (the
+// decimal kernels are registered with it); element width comes from the
+// DataType's byte_width().
+template
+struct ReplaceWithMask> {
+ using ScalarType = typename TypeTraits::ScalarType;
+
+ // Allocates an (uninitialized) values buffer of length * byte_width bytes.
+ static Result> AllocateData(KernelContext* ctx,
+ const DataType& ty,
+ const int64_t length) {
+ return ctx->Allocate(length *
+ checked_cast(ty).byte_width());
+ }
+
+ // Writes `length` elements into `out` starting at element `out_offset`.
+ // For an array input the elements are taken starting at `in_offset`
+ // (relative to the array's own offset); for a scalar input the scalar's
+ // serialized bytes are repeated and `in_offset` is ignored.
+ static void CopyData(const DataType& ty, uint8_t* out, const int64_t out_offset,
+ const Datum& in, const int64_t in_offset, const int64_t length) {
+ const int32_t width = checked_cast(ty).byte_width();
+ uint8_t* begin = out + (out_offset * width);
+ if (in.is_array()) {
+ const auto& in_data = *in.array();
+ // The offset passed to GetValues is computed in bytes here
+ const auto in_arr =
+ in_data.GetValues(1, (in_offset + in_data.offset) * width);
+ std::memcpy(begin, in_arr, length * width);
+ } else {
+ const ScalarType& scalar = checked_cast(*in.scalar());
+ const auto value = scalar.value.ToBytes();
+ // NOTE(review): `i` is an int while `length` is an int64_t
+ for (int i = 0; i < length; i++) {
+ std::memcpy(begin, value.data(), width);
+ begin += width;
+ }
+ }
+ }
+
+ // Scalar mask: delegates to the shared ReplaceWithScalarMask helper.
+ static Status ExecScalarMask(KernelContext* ctx, const ArrayData& array,
+ const BooleanScalar& mask, const Datum& replacements,
+ ArrayData* output) {
+ return ReplaceWithScalarMask(ctx, array, mask, replacements, output);
+ }
+
+ // Array mask: delegates to the shared ReplaceWithArrayMask helper, which
+ // calls back into AllocateData/CopyData above.
+ static Status ExecArrayMask(KernelContext* ctx, const ArrayData& array,
+ const ArrayData& mask, const Datum& replacements,
+ ArrayData* output) {
+ return ReplaceWithArrayMask>(ctx, array, mask, replacements,
+ output);
+ }
+};
+
+// Specialization that ignores both mask and replacements: the output is
+// always a copy of the input ArrayData. (The null-type kernel is registered
+// through add_primitive_kernel(null()); presumably it lands here, since every
+// value of a null array is null whatever is replaced — confirm.)
+template
+struct ReplaceWithMask> {
+ // Scalar mask: output = input.
+ static Status ExecScalarMask(KernelContext* ctx, const ArrayData& array,
+ const BooleanScalar& mask, const Datum& replacements,
+ ArrayData* output) {
+ *output = array;
+ return Status::OK();
+ }
+ // Array mask: output = input.
+ static Status ExecArrayMask(KernelContext* ctx, const ArrayData& array,
+ const ArrayData& mask, const Datum& replacements,
+ ArrayData* output) {
+ *output = array;
+ return Status::OK();
+ }
+};
+
+// Specialization for variable-length binary-like values. Because element
+// sizes vary, the output cannot be patched in place; it is rebuilt value by
+// value with an array builder.
+template <typename Type>
+struct ReplaceWithMask<Type, enable_if_base_binary<Type>> {
+  using offset_type = typename Type::offset_type;
+  using BuilderType = typename TypeTraits<Type>::BuilderType;
+
+  // Scalar mask: delegates to the shared ReplaceWithScalarMask helper.
+  static Status ExecScalarMask(KernelContext* ctx, const ArrayData& array,
+                               const BooleanScalar& mask, const Datum& replacements,
+                               ArrayData* output) {
+    return ReplaceWithScalarMask(ctx, array, mask, replacements, output);
+  }
+
+  // Array mask: walks the mask once and appends, per slot, null (mask null),
+  // the next replacement (mask true) or the input value (mask false).
+  // Returns Invalid if the mask selects more slots than the replacement array
+  // provides.
+  static Status ExecArrayMask(KernelContext* ctx, const ArrayData& array,
+                              const ArrayData& mask, const Datum& replacements,
+                              ArrayData* output) {
+    BuilderType builder(array.type, ctx->memory_pool());
+    RETURN_NOT_OK(builder.Reserve(array.length));
+    // Reserving is only a size hint; skip it if the input has no data buffer
+    if (array.buffers[2]) {
+      RETURN_NOT_OK(builder.ReserveData(array.buffers[2]->size()));
+    }
+    int64_t source_offset = 0;
+    int64_t replacements_offset = 0;
+    RETURN_NOT_OK(VisitArrayDataInline<BooleanType>(
+        mask,
+        [&](bool replace) {
+          if (replace && replacements.is_scalar()) {
+            const Scalar& scalar = *replacements.scalar();
+            if (scalar.is_valid) {
+              RETURN_NOT_OK(builder.Append(UnboxScalar<Type>::Unbox(scalar)));
+            } else {
+              RETURN_NOT_OK(builder.AppendNull());
+            }
+          } else {
+            if (replace) {
+              // Never read past the end of the replacement array
+              const int64_t replacements_length = replacements.array()->length;
+              if (replacements_offset >= replacements_length) {
+                return ReplacementArrayTooShort(replacements_offset + 1,
+                                                replacements_length);
+              }
+            }
+            const ArrayData& source = replace ? *replacements.array() : array;
+            const int64_t offset = replace ? replacements_offset++ : source_offset;
+            if (!source.MayHaveNulls() ||
+                BitUtil::GetBit(source.buffers[0]->data(), source.offset + offset)) {
+              const uint8_t* data = source.buffers[2]->data();
+              const offset_type* offsets = source.GetValues<offset_type>(1);
+              const offset_type offset0 = offsets[offset];
+              const offset_type offset1 = offsets[offset + 1];
+              RETURN_NOT_OK(builder.Append(data + offset0, offset1 - offset0));
+            } else {
+              RETURN_NOT_OK(builder.AppendNull());
+            }
+          }
+          source_offset++;
+          return Status::OK();
+        },
+        [&]() {
+          // Null mask entry: the output slot is null
+          RETURN_NOT_OK(builder.AppendNull());
+          source_offset++;
+          return Status::OK();
+        }));
+    std::shared_ptr<Array> temp_output;
+    RETURN_NOT_OK(builder.Finish(&temp_output));
+    *output = *temp_output->data();
+    // Builder type != logical type due to GenerateTypeAgnosticVarBinaryBase
+    output->type = array.type;
+    return Status::OK();
+  }
+};
+
+// Kernel entry point for replace_with_mask. Validates the arguments, then
+// dispatches to the ExecScalarMask or ExecArrayMask member of the matching
+// ReplaceWithMask specialization depending on the shape of the mask.
+//
+// batch[0] = values (array), batch[1] = mask, batch[2] = replacements.
+template
+struct ReplaceWithMaskFunctor {
+ static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ const ArrayData& array = *batch[0].array();
+ const Datum& replacements = batch[2];
+ ArrayData* output = out->array().get();
+ output->length = array.length;
+
+ // Needed for FixedSizeBinary/parameterized types
+ // (the kernel signature only matches on type id, see RegisterVectorReplace)
+ if (!array.type->Equals(*replacements.type(), /*check_metadata=*/false)) {
+ return Status::Invalid("Replacements must be of same type (expected ",
+ array.type->ToString(), " but got ",
+ replacements.type()->ToString(), ")");
+ }
+
+ if (!replacements.is_array() && !replacements.is_scalar()) {
+ return Status::Invalid("Replacements must be array or scalar");
+ }
+
+ // Scalar mask: one decision applies to the whole array
+ if (batch[1].is_scalar()) {
+ return ReplaceWithMask::ExecScalarMask(
+ ctx, array, batch[1].scalar_as(), replacements, output);
+ }
+ // Array mask: must line up element-for-element with the values
+ const ArrayData& mask = *batch[1].array();
+ if (array.length != mask.length) {
+ return Status::Invalid("Mask must be of same length as array (expected ",
+ array.length, " items but got ", mask.length, " items)");
+ }
+ return ReplaceWithMask::ExecArrayMask(ctx, array, mask, replacements, output);
+ }
+};
+
+} // namespace
+
+// User-facing documentation (summary, description, argument names) attached
+// to the "replace_with_mask" function when it is registered.
+const FunctionDoc replace_with_mask_doc(
+ "Replace items using a mask and replacement values",
+ ("Given an array and a Boolean mask (either scalar or of equal length), "
+ "along with replacement values (either scalar or array), "
+ "each element of the array for which the corresponding mask element is "
+ "true will be replaced by the next value from the replacements, "
+ "or with null if the mask is null. "
+ "Hence, for replacement arrays, len(replacements) == sum(mask == true)."),
+ {"values", "mask", "replacements"});
+
+// Creates the ternary "replace_with_mask" function, adds one kernel per
+// supported value type, and registers the function in `registry`.
+void RegisterVectorReplace(FunctionRegistry* registry) {
+ auto func = std::make_shared("replace_with_mask", Arity::Ternary(),
+ &replace_with_mask_doc);
+ // Adds a kernel matching (values: array of type id, mask: boolean,
+ // replacements: same type id) whose output type is the first argument's.
+ auto add_kernel = [&](detail::GetTypeId get_id, ArrayKernelExec exec) {
+ VectorKernel kernel;
+ kernel.can_execute_chunkwise = false;
+ // The kernel allocates its own validity bitmap and data buffers
+ kernel.null_handling = NullHandling::type::COMPUTED_NO_PREALLOCATE;
+ kernel.mem_allocation = MemAllocation::type::NO_PREALLOCATE;
+ // Only `values` is restricted to arrays; mask and replacements are matched
+ // on type alone.
+ kernel.signature = KernelSignature::Make(
+ {InputType::Array(get_id.id), InputType(boolean()), InputType(get_id.id)},
+ OutputType(FirstType));
+ kernel.exec = std::move(exec);
+ DCHECK_OK(func->AddKernel(std::move(kernel)));
+ };
+ // Primitive types get their exec function from GenerateTypeAgnosticPrimitive
+ auto add_primitive_kernel = [&](detail::GetTypeId get_id) {
+ add_kernel(get_id, GenerateTypeAgnosticPrimitive(get_id));
+ };
+ for (const auto& ty : NumericTypes()) {
+ add_primitive_kernel(ty);
+ }
+ for (const auto& ty : TemporalTypes()) {
+ add_primitive_kernel(ty);
+ }
+ add_primitive_kernel(null());
+ add_primitive_kernel(boolean());
+ add_primitive_kernel(day_time_interval());
+ add_primitive_kernel(month_interval());
+ // Parameterized fixed-width types are matched by type id only
+ add_kernel(Type::FIXED_SIZE_BINARY, ReplaceWithMaskFunctor::Exec);
+ add_kernel(Type::DECIMAL128, ReplaceWithMaskFunctor::Exec);
+ add_kernel(Type::DECIMAL256, ReplaceWithMaskFunctor::Exec);
+ for (const auto& ty : BaseBinaryTypes()) {
+ add_kernel(ty->id(), GenerateTypeAgnosticVarBinaryBase(*ty));
+ }
+ // TODO: list types
+ DCHECK_OK(registry->AddFunction(std::move(func)));
+
+ // TODO(ARROW-9431): "replace_with_indices"
+}
+} // namespace internal
+} // namespace compute
+} // namespace arrow
diff --git a/cpp/src/arrow/compute/kernels/vector_replace_test.cc b/cpp/src/arrow/compute/kernels/vector_replace_test.cc
new file mode 100644
index 00000000000..1826b037034
--- /dev/null
+++ b/cpp/src/arrow/compute/kernels/vector_replace_test.cc
@@ -0,0 +1,736 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include
+#include
+
+#include "arrow/compute/api_vector.h"
+#include "arrow/compute/kernels/test_util.h"
+#include "arrow/testing/gtest_common.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/key_value_metadata.h"
+#include "arrow/util/make_unique.h"
+
+namespace arrow {
+namespace compute {
+
+using arrow::internal::checked_pointer_cast;
+
+namespace {
+// Test helpers returning a default type instance for the type under test,
+// with one overload per type category (selected through enable_if).
+
+// Parameter-free types: use the TypeTraits singleton.
+template
+enable_if_parameter_free> default_type_instance() {
+ return TypeTraits::type_singleton();
+}
+// Time types: the unit is chosen from the type's storage width.
+template
+enable_if_time> default_type_instance() {
+ // Time32 requires second/milli, Time64 requires nano/micro
+ if (TypeTraits::bytes_required(1) == 4) {
+ return std::make_shared(TimeUnit::type::SECOND);
+ } else {
+ return std::make_shared(TimeUnit::type::NANO);
+ }
+}
+// Timestamp types: second resolution.
+template
+enable_if_timestamp> default_type_instance() {
+ return std::make_shared(TimeUnit::type::SECOND);
+}
+// Decimal types: constructed with arguments (5, 2)
+// (presumably precision 5 and scale 2 — confirm).
+template
+enable_if_decimal> default_type_instance() {
+ return std::make_shared(5, 2);
+}
+// Test helpers creating an array builder for the type under test.
+
+// Parameter-free types: the builder is default-constructed.
+template
+enable_if_parameter_free::BuilderType>>
+builder_instance() {
+ return arrow::internal::make_unique::BuilderType>();
+}
+// Time types: the builder needs the concrete type and a memory pool.
+template
+enable_if_time::BuilderType>>
+builder_instance() {
+ return arrow::internal::make_unique::BuilderType>(
+ default_type_instance(), default_memory_pool());
+}
+// Timestamp types: the builder needs the concrete type and a memory pool.
+template
+enable_if_timestamp::BuilderType>>
+builder_instance() {
+ return arrow::internal::make_unique::BuilderType>(
+ default_type_instance(), default_memory_pool());
+}
+// Upper bound used for randomly generated test values: 16384, or the type's
+// own maximum if that is smaller.
+// NOTE(review): the two overloads have identical bodies and are presumably
+// selected by different enable_if_t predicates — confirm.
+template
+enable_if_t::value, T> max_int_value() {
+ return static_cast(
+ std::min(16384.0, static_cast(std::numeric_limits::max())));
+}
+template
+enable_if_t::value, T> max_int_value() {
+ return static_cast(
+ std::min(16384.0, static_cast(std::numeric_limits::max())));
+}
+} // namespace
+
+// Common fixture for the replace kernel tests. Subclasses supply the Arrow
+// type under test through type(); the helpers below build inputs from JSON
+// and check a replace function's result against an expected array.
+template
+class TestReplaceKernel : public ::testing::Test {
+ protected:
+ // The Arrow data type exercised by the concrete fixture.
+ virtual std::shared_ptr type() = 0;
+
+ // Signature shared by the replace functions under test:
+ // (values, mask, replacements, ctx).
+ using ReplaceFunction = std::function(const Datum&, const Datum&,
+ const Datum&, ExecContext*)>;
+
+ // Treat NaNs as equal when comparing expected and actual arrays.
+ void SetUp() override { equal_options_ = equal_options_.nans_equal(true); }
+
+ // A valid Boolean scalar mask with the given value.
+ Datum mask_scalar(bool value) { return Datum(std::make_shared(value)); }
+
+ // A null Boolean scalar mask.
+ Datum null_mask_scalar() {
+ auto scalar = std::make_shared(true);
+ scalar->is_valid = false;
+ return Datum(std::move(scalar));
+ }
+
+ // A scalar of the type under test, parsed from JSON.
+ Datum scalar(const std::string& json) { return ScalarFromJSON(type(), json); }
+
+ // An array of the type under test, parsed from JSON.
+ std::shared_ptr array(const std::string& value) {
+ return ArrayFromJSON(type(), value);
+ }
+
+ // A Boolean array mask, parsed from JSON.
+ std::shared_ptr mask(const std::string& value) {
+ return ArrayFromJSON(boolean(), value);
+ }
+
+ // Calls `func`, expects it to fail, and returns the failure status so the
+ // caller can inspect it.
+ Status AssertRaises(ReplaceFunction func, const std::shared_ptr& array,
+ const Datum& mask, const std::shared_ptr& replacements) {
+ auto result = func(array, mask, replacements, nullptr);
+ EXPECT_FALSE(result.ok());
+ return result.status();
+ }
+
+ // Calls `func`, expects an array result that passes ValidateFull() and is
+ // approximately equal to `expected`. The inputs are recorded with
+ // SCOPED_TRACE so failures show what was being tested.
+ void Assert(ReplaceFunction func, const std::shared_ptr& array,
+ const Datum& mask, Datum replacements,
+ const std::shared_ptr& expected) {
+ SCOPED_TRACE("Replacements: " + (replacements.is_array()
+ ? replacements.make_array()->ToString()
+ : replacements.scalar()->ToString()));
+ SCOPED_TRACE("Mask: " + (mask.is_array() ? mask.make_array()->ToString()
+ : mask.scalar()->ToString()));
+ SCOPED_TRACE("Array: " + array->ToString());
+
+ ASSERT_OK_AND_ASSIGN(auto actual, func(array, mask, replacements, nullptr));
+ ASSERT_TRUE(actual.is_array());
+ ASSERT_OK(actual.make_array()->ValidateFull());
+
+ AssertArraysApproxEqual(*expected, *actual.make_array(), /*verbose=*/true,
+ equal_options_);
+ }
+
+ // Straightforward reference implementation of replace-with-mask, built
+ // element by element: null where the mask is null, the next replacement
+ // where the mask is true, the original value otherwise.
+ std::shared_ptr NaiveImpl(
+ const typename TypeTraits::ArrayType& array, const BooleanArray& mask,
+ const typename TypeTraits::ArrayType& replacements) {
+ auto length = array.length();
+ auto builder = builder_instance();
+ int64_t replacement_offset = 0;
+ for (int64_t i = 0; i < length; ++i) {
+ if (mask.IsValid(i)) {
+ if (mask.Value(i)) {
+ // Consume one replacement, whether it is valid or null
+ if (replacements.IsValid(replacement_offset)) {
+ ARROW_EXPECT_OK(builder->Append(replacements.Value(replacement_offset++)));
+ } else {
+ ARROW_EXPECT_OK(builder->AppendNull());
+ replacement_offset++;
+ }
+ } else {
+ if (array.IsValid(i)) {
+ ARROW_EXPECT_OK(builder->Append(array.Value(i)));
+ } else {
+ ARROW_EXPECT_OK(builder->AppendNull());
+ }
+ }
+ } else {
+ ARROW_EXPECT_OK(builder->AppendNull());
+ }
+ }
+ EXPECT_OK_AND_ASSIGN(auto expected, builder->Finish());
+ return expected;
+ }
+
+ // Comparison options used by Assert(); adjusted in SetUp().
+ EqualOptions equal_options_ = EqualOptions::Defaults();
+};
+
+// Typed fixture whose type() comes from default_type_instance().
+template
+class TestReplaceNumeric : public TestReplaceKernel {
+ protected:
+ std::shared_ptr type() override { return default_type_instance(); }
+};
+
+// Fixture for the Boolean tests; type() returns a TypeTraits type singleton.
+class TestReplaceBoolean : public TestReplaceKernel {
+ protected:
+ std::shared_ptr type() override {
+ return TypeTraits::type_singleton();
+ }
+};
+
+// Fixture for fixed-size binary tests; values are 3 bytes wide.
+class TestReplaceFixedSizeBinary : public TestReplaceKernel {
+ protected:
+ std::shared_ptr type() override { return fixed_size_binary(3); }
+};
+
+// Typed fixture for decimal tests; type() comes from default_type_instance().
+template
+class TestReplaceDecimal : public TestReplaceKernel {
+ protected:
+ std::shared_ptr type() override { return default_type_instance(); }
+};
+
+// Fixture for day-time interval tests; type() returns a TypeTraits type
+// singleton.
+class TestReplaceDayTimeInterval : public TestReplaceKernel {
+ protected:
+ std::shared_ptr type() override {
+ return TypeTraits::type_singleton();
+ }
+};
+
+// Typed fixture for binary tests; type() comes from default_type_instance().
+template
+class TestReplaceBinary : public TestReplaceKernel {
+ protected:
+ std::shared_ptr type() override { return default_type_instance(); }
+};
+
+// Type list over which the TestReplaceNumeric suite is instantiated.
+using NumericBasedTypes =
+ ::testing::Types;
+
+// Bind each typed fixture to the list of types it runs over.
+TYPED_TEST_SUITE(TestReplaceNumeric, NumericBasedTypes);
+TYPED_TEST_SUITE(TestReplaceDecimal, DecimalArrowTypes);
+TYPED_TEST_SUITE(TestReplaceBinary, BinaryTypes);
+
+// Table-driven checks of ReplaceWithMask on the numeric-based types. Each
+// Assert call passes (function, values, mask, replacements, expected).
+TYPED_TEST(TestReplaceNumeric, ReplaceWithMask) {
+ // Scalar mask (false / true / null) on an empty array
+ this->Assert(ReplaceWithMask, this->array("[]"), this->mask_scalar(false),
+ this->array("[]"), this->array("[]"));
+ this->Assert(ReplaceWithMask, this->array("[]"), this->mask_scalar(true),
+ this->array("[]"), this->array("[]"));
+ this->Assert(ReplaceWithMask, this->array("[]"), this->null_mask_scalar(),
+ this->array("[]"), this->array("[]"));
+
+ // Scalar mask on a single-element array, with array replacements
+ this->Assert(ReplaceWithMask, this->array("[1]"), this->mask_scalar(false),
+ this->array("[]"), this->array("[1]"));
+ this->Assert(ReplaceWithMask, this->array("[1]"), this->mask_scalar(true),
+ this->array("[0]"), this->array("[0]"));
+ this->Assert(ReplaceWithMask, this->array("[1]"), this->null_mask_scalar(),
+ this->array("[]"), this->array("[null]"));
+
+ // Scalar mask with scalar replacements (valid and null)
+ this->Assert(ReplaceWithMask, this->array("[0, 0]"), this->mask_scalar(false),
+ this->scalar("1"), this->array("[0, 0]"));
+ this->Assert(ReplaceWithMask, this->array("[0, 0]"), this->mask_scalar(true),
+ this->scalar("1"), this->array("[1, 1]"));
+ this->Assert(ReplaceWithMask, this->array("[0, 0]"), this->mask_scalar(true),
+ this->scalar("null"), this->array("[null, null]"));
+
+ // Array mask with array replacements: all-false, all-true, all-null and
+ // mixed masks, with and without nulls in the input
+ this->Assert(ReplaceWithMask, this->array("[]"), this->mask("[]"), this->array("[]"),
+ this->array("[]"));
+ this->Assert(ReplaceWithMask, this->array("[0, 1, 2, 3]"),
+ this->mask("[false, false, false, false]"), this->array("[]"),
+ this->array("[0, 1, 2, 3]"));
+ this->Assert(ReplaceWithMask, this->array("[0, 1, 2, 3]"),
+ this->mask("[true, true, true, true]"), this->array("[10, 11, 12, 13]"),
+ this->array("[10, 11, 12, 13]"));
+ this->Assert(ReplaceWithMask, this->array("[0, 1, 2, 3]"),
+ this->mask("[null, null, null, null]"), this->array("[]"),
+ this->array("[null, null, null, null]"));
+ this->Assert(ReplaceWithMask, this->array("[0, 1, 2, null]"),
+ this->mask("[false, false, false, false]"), this->array("[]"),
+ this->array("[0, 1, 2, null]"));
+ this->Assert(ReplaceWithMask, this->array("[0, 1, 2, null]"),
+ this->mask("[true, true, true, true]"), this->array("[10, 11, 12, 13]"),
+ this->array("[10, 11, 12, 13]"));
+ this->Assert(ReplaceWithMask, this->array("[0, 1, 2, null]"),
+ this->mask("[null, null, null, null]"), this->array("[]"),
+ this->array("[null, null, null, null]"));
+ this->Assert(ReplaceWithMask, this->array("[0, 1, 2, 3, 4, 5]"),
+ this->mask("[true, true, false, false, null, null]"),
+ this->array("[10, null]"), this->array("[10, null, 2, 3, null, null]"));
+ this->Assert(ReplaceWithMask, this->array("[null, null, null, null, null, null]"),
+ this->mask("[true, true, false, false, null, null]"),
+ this->array("[10, null]"),
+ this->array("[10, null, null, null, null, null]"));
+
+ // Array mask with scalar replacements (valid and null)
+ this->Assert(ReplaceWithMask, this->array("[]"), this->mask("[]"), this->scalar("1"),
+ this->array("[]"));
+ this->Assert(ReplaceWithMask, this->array("[0, 1]"), this->mask("[true, true]"),
+ this->scalar("10"), this->array("[10, 10]"));
+ this->Assert(ReplaceWithMask, this->array("[0, 1]"), this->mask("[true, true]"),
+ this->scalar("null"), this->array("[null, null]"));
+ this->Assert(ReplaceWithMask, this->array("[0, 1, 2]"),
+ this->mask("[true, false, null]"), this->scalar("10"),
+ this->array("[10, 1, null]"));
+}
+
+TYPED_TEST(TestReplaceNumeric, ReplaceWithMaskRandom) {
+ using ArrayType = typename TypeTraits<TypeParam>::ArrayType;
+ using CType = typename TypeTraits<TypeParam>::CType;
+ auto ty = this->type();
+ // Random inputs from a fixed seed; one replacement per valid, true mask slot
+ random::RandomArrayGenerator rand(/*seed=*/0);
+ const int64_t length = 1023;
+ // Clamp the range because date/time types don't print well with extreme values
+ std::vector<std::string> values = {"0.01", "0"};
+ values.push_back(std::to_string(max_int_value<CType>()));
+ auto options = key_value_metadata({"null_probability", "min", "max"}, values);
+ auto array =
+ checked_pointer_cast<ArrayType>(rand.ArrayOf(*field("a", ty, options), length));
+ auto mask = checked_pointer_cast<BooleanArray>(
+ rand.ArrayOf(boolean(), length, /*null_probability=*/0.01));
+ const int64_t num_replacements = std::count_if(
+ mask->begin(), mask->end(),
+ [](util::optional<bool> value) { return value.has_value() && *value; });
+ auto replacements = checked_pointer_cast<ArrayType>(
+ rand.ArrayOf(*field("a", ty, options), num_replacements));
+ auto expected = this->NaiveImpl(*array, *mask, *replacements);
+ // Check against NaiveImpl on the full arrays, then on 15-element slices at offsets 1-16
+ this->Assert(ReplaceWithMask, array, mask, replacements, expected);
+ for (int64_t slice = 1; slice <= 16; slice++) {
+ auto sliced_array = checked_pointer_cast<ArrayType>(array->Slice(slice, 15));
+ auto sliced_mask = checked_pointer_cast<BooleanArray>(mask->Slice(slice, 15));
+ auto new_expected = this->NaiveImpl(*sliced_array, *sliced_mask, *replacements);
+ this->Assert(ReplaceWithMask, sliced_array, sliced_mask, replacements, new_expected);
+ }
+}
+
+TYPED_TEST(TestReplaceNumeric, ReplaceWithMaskErrors) {  // length mismatches
+ EXPECT_RAISES_WITH_MESSAGE_THAT(  // true scalar mask, one element, two replacements
+ Invalid,
+ ::testing::HasSubstr("Replacement array must be of appropriate length (expected 1 "
+ "items but got 2 items)"),
+ this->AssertRaises(ReplaceWithMask, this->array("[1]"), this->mask_scalar(true),
+ this->array("[0, 1]")));
+ EXPECT_RAISES_WITH_MESSAGE_THAT(  // two true mask slots, one replacement
+ Invalid,
+ ::testing::HasSubstr("Replacement array must be of appropriate length (expected 2 "
+ "items but got 1 items)"),
+ this->AssertRaises(ReplaceWithMask, this->array("[1, 2]"),
+ this->mask("[true, true]"), this->array("[0]")));
+ EXPECT_RAISES_WITH_MESSAGE_THAT(  // only the true slot counts; null slot does not
+ Invalid,
+ ::testing::HasSubstr("Replacement array must be of appropriate length (expected 1 "
+ "items but got 0 items)"),
+ this->AssertRaises(ReplaceWithMask, this->array("[1, 2]"),
+ this->mask("[true, null]"), this->array("[]")));
+ EXPECT_RAISES_WITH_MESSAGE_THAT(  // mask array shorter than the input array
+ Invalid,
+ ::testing::HasSubstr("Mask must be of same length as array (expected 2 "
+ "items but got 0 items)"),
+ this->AssertRaises(ReplaceWithMask, this->array("[1, 2]"), this->mask("[]"),
+ this->array("[]")));
+}
+
+TEST_F(TestReplaceBoolean, ReplaceWithMask) {  // Empty array with each scalar mask
+ this->Assert(ReplaceWithMask, this->array("[]"), this->mask_scalar(false),
+ this->array("[]"), this->array("[]"));
+ this->Assert(ReplaceWithMask, this->array("[]"), this->mask_scalar(true),
+ this->array("[]"), this->array("[]"));
+ this->Assert(ReplaceWithMask, this->array("[]"), this->null_mask_scalar(),
+ this->array("[]"), this->array("[]"));
+ // One-element array, scalar mask: false keeps, true replaces, null yields null
+ this->Assert(ReplaceWithMask, this->array("[true]"), this->mask_scalar(false),
+ this->array("[]"), this->array("[true]"));
+ this->Assert(ReplaceWithMask, this->array("[true]"), this->mask_scalar(true),
+ this->array("[false]"), this->array("[false]"));
+ this->Assert(ReplaceWithMask, this->array("[true]"), this->null_mask_scalar(),
+ this->array("[]"), this->array("[null]"));
+ // Scalar mask with a scalar replacement (valid and null)
+ this->Assert(ReplaceWithMask, this->array("[false, false]"), this->mask_scalar(false),
+ this->scalar("true"), this->array("[false, false]"));
+ this->Assert(ReplaceWithMask, this->array("[false, false]"), this->mask_scalar(true),
+ this->scalar("true"), this->array("[true, true]"));
+ this->Assert(ReplaceWithMask, this->array("[false, false]"), this->mask_scalar(true),
+ this->scalar("null"), this->array("[null, null]"));
+ // Array mask with array replacements; nulls in the array, mask and replacements
+ this->Assert(ReplaceWithMask, this->array("[]"), this->mask("[]"), this->array("[]"),
+ this->array("[]"));
+ this->Assert(ReplaceWithMask, this->array("[true, true, true, true]"),
+ this->mask("[false, false, false, false]"), this->array("[]"),
+ this->array("[true, true, true, true]"));
+ this->Assert(ReplaceWithMask, this->array("[true, true, true, true]"),
+ this->mask("[true, true, true, true]"),
+ this->array("[false, false, false, false]"),
+ this->array("[false, false, false, false]"));
+ this->Assert(ReplaceWithMask, this->array("[true, true, true, true]"),
+ this->mask("[null, null, null, null]"), this->array("[]"),
+ this->array("[null, null, null, null]"));
+ this->Assert(ReplaceWithMask, this->array("[true, true, true, null]"),
+ this->mask("[false, false, false, false]"), this->array("[]"),
+ this->array("[true, true, true, null]"));
+ this->Assert(ReplaceWithMask, this->array("[true, true, true, null]"),
+ this->mask("[true, true, true, true]"),
+ this->array("[false, false, false, false]"),
+ this->array("[false, false, false, false]"));
+ this->Assert(ReplaceWithMask, this->array("[true, true, true, null]"),
+ this->mask("[null, null, null, null]"), this->array("[]"),
+ this->array("[null, null, null, null]"));
+ this->Assert(ReplaceWithMask, this->array("[true, true, true, true, true, true]"),
+ this->mask("[true, true, false, false, null, null]"),
+ this->array("[false, null]"),
+ this->array("[false, null, true, true, null, null]"));
+ this->Assert(ReplaceWithMask, this->array("[null, null, null, null, null, null]"),
+ this->mask("[true, true, false, false, null, null]"),
+ this->array("[false, null]"),
+ this->array("[false, null, null, null, null, null]"));
+ // Array mask with a scalar replacement
+ this->Assert(ReplaceWithMask, this->array("[]"), this->mask("[]"), this->scalar("true"),
+ this->array("[]"));
+ this->Assert(ReplaceWithMask, this->array("[false, false]"), this->mask("[true, true]"),
+ this->scalar("true"), this->array("[true, true]"));
+ this->Assert(ReplaceWithMask, this->array("[false, false]"), this->mask("[true, true]"),
+ this->scalar("null"), this->array("[null, null]"));
+ this->Assert(ReplaceWithMask, this->array("[false, false, false]"),
+ this->mask("[true, false, null]"), this->scalar("true"),
+ this->array("[true, false, null]"));
+}
+
+TEST_F(TestReplaceBoolean, ReplaceWithMaskErrors) {  // length mismatches
+ EXPECT_RAISES_WITH_MESSAGE_THAT(  // true scalar mask, one element, two replacements
+ Invalid,
+ ::testing::HasSubstr("Replacement array must be of appropriate length (expected 1 "
+ "items but got 2 items)"),
+ this->AssertRaises(ReplaceWithMask, this->array("[true]"), this->mask_scalar(true),
+ this->array("[false, false]")));
+ EXPECT_RAISES_WITH_MESSAGE_THAT(  // two true mask slots, one replacement
+ Invalid,
+ ::testing::HasSubstr("Replacement array must be of appropriate length (expected 2 "
+ "items but got 1 items)"),
+ this->AssertRaises(ReplaceWithMask, this->array("[true, true]"),
+ this->mask("[true, true]"), this->array("[false]")));
+ EXPECT_RAISES_WITH_MESSAGE_THAT(  // only the true slot counts; null slot does not
+ Invalid,
+ ::testing::HasSubstr("Replacement array must be of appropriate length (expected 1 "
+ "items but got 0 items)"),
+ this->AssertRaises(ReplaceWithMask, this->array("[true, true]"),
+ this->mask("[true, null]"), this->array("[]")));
+ EXPECT_RAISES_WITH_MESSAGE_THAT(  // mask array shorter than the input array
+ Invalid,
+ ::testing::HasSubstr("Mask must be of same length as array (expected 2 "
+ "items but got 0 items)"),
+ this->AssertRaises(ReplaceWithMask, this->array("[true, true]"), this->mask("[]"),
+ this->array("[]")));
+}
+
+TEST_F(TestReplaceFixedSizeBinary, ReplaceWithMask) {  // Empty array, scalar masks
+ this->Assert(ReplaceWithMask, this->array("[]"), this->mask_scalar(false),
+ this->array("[]"), this->array("[]"));
+ this->Assert(ReplaceWithMask, this->array("[]"), this->mask_scalar(true),
+ this->array("[]"), this->array("[]"));
+ this->Assert(ReplaceWithMask, this->array("[]"), this->null_mask_scalar(),
+ this->array("[]"), this->array("[]"));
+ // One-element array, scalar mask: false keeps, true replaces, null yields null
+ this->Assert(ReplaceWithMask, this->array(R"(["foo"])"), this->mask_scalar(false),
+ this->array("[]"), this->array(R"(["foo"])"));
+ this->Assert(ReplaceWithMask, this->array(R"(["foo"])"), this->mask_scalar(true),
+ this->array(R"(["bar"])"), this->array(R"(["bar"])"));
+ this->Assert(ReplaceWithMask, this->array(R"(["foo"])"), this->null_mask_scalar(),
+ this->array("[]"), this->array("[null]"));
+ // Scalar mask with a scalar replacement (valid and null)
+ this->Assert(ReplaceWithMask, this->array(R"(["foo", "bar"])"),
+ this->mask_scalar(false), this->scalar(R"("baz")"),
+ this->array(R"(["foo", "bar"])"));
+ this->Assert(ReplaceWithMask, this->array(R"(["foo", "bar"])"), this->mask_scalar(true),
+ this->scalar(R"("baz")"), this->array(R"(["baz", "baz"])"));
+ this->Assert(ReplaceWithMask, this->array(R"(["foo", "bar"])"), this->mask_scalar(true),
+ this->scalar("null"), this->array(R"([null, null])"));
+ // Array mask with array replacements; nulls in the array, mask and replacements
+ this->Assert(ReplaceWithMask, this->array("[]"), this->mask("[]"), this->array("[]"),
+ this->array("[]"));
+ this->Assert(ReplaceWithMask, this->array(R"(["aaa", "bbb", "ccc", "ddd"])"),
+ this->mask("[false, false, false, false]"), this->array("[]"),
+ this->array(R"(["aaa", "bbb", "ccc", "ddd"])"));
+ this->Assert(ReplaceWithMask, this->array(R"(["aaa", "bbb", "ccc", "ddd"])"),
+ this->mask("[true, true, true, true]"),
+ this->array(R"(["eee", "fff", "ggg", "hhh"])"),
+ this->array(R"(["eee", "fff", "ggg", "hhh"])"));
+ this->Assert(ReplaceWithMask, this->array(R"(["aaa", "bbb", "ccc", "ddd"])"),
+ this->mask("[null, null, null, null]"), this->array("[]"),
+ this->array(R"([null, null, null, null])"));
+ this->Assert(ReplaceWithMask, this->array(R"(["aaa", "bbb", "ccc", null])"),
+ this->mask("[false, false, false, false]"), this->array("[]"),
+ this->array(R"(["aaa", "bbb", "ccc", null])"));
+ this->Assert(ReplaceWithMask, this->array(R"(["aaa", "bbb", "ccc", null])"),
+ this->mask("[true, true, true, true]"),
+ this->array(R"(["eee", "fff", "ggg", "hhh"])"),
+ this->array(R"(["eee", "fff", "ggg", "hhh"])"));
+ this->Assert(ReplaceWithMask, this->array(R"(["aaa", "bbb", "ccc", null])"),
+ this->mask("[null, null, null, null]"), this->array("[]"),
+ this->array(R"([null, null, null, null])"));
+ this->Assert(ReplaceWithMask,
+ this->array(R"(["aaa", "bbb", "ccc", "ddd", "eee", "fff"])"),
+ this->mask("[true, true, false, false, null, null]"),
+ this->array(R"(["ggg", null])"),
+ this->array(R"(["ggg", null, "ccc", "ddd", null, null])"));
+ this->Assert(ReplaceWithMask, this->array(R"([null, null, null, null, null, null])"),
+ this->mask("[true, true, false, false, null, null]"),
+ this->array(R"(["aaa", null])"),
+ this->array(R"(["aaa", null, null, null, null, null])"));
+ // Array mask with a scalar replacement
+ this->Assert(ReplaceWithMask, this->array("[]"), this->mask("[]"),
+ this->scalar(R"("zzz")"), this->array("[]"));
+ this->Assert(ReplaceWithMask, this->array(R"(["aaa", "bbb"])"),
+ this->mask("[true, true]"), this->scalar(R"("zzz")"),
+ this->array(R"(["zzz", "zzz"])"));
+ this->Assert(ReplaceWithMask, this->array(R"(["aaa", "bbb"])"),
+ this->mask("[true, true]"), this->scalar("null"),
+ this->array("[null, null]"));
+ this->Assert(ReplaceWithMask, this->array(R"(["aaa", "bbb", "ccc"])"),
+ this->mask("[true, false, null]"), this->scalar(R"("zzz")"),
+ this->array(R"(["zzz", "bbb", null])"));
+}
+
+TEST_F(TestReplaceFixedSizeBinary, ReplaceWithMaskErrors) {
+ EXPECT_RAISES_WITH_MESSAGE_THAT(  // replacements of another type are rejected
+ Invalid,
+ ::testing::AllOf(
+ ::testing::HasSubstr("Replacements must be of same type (expected "),
+ ::testing::HasSubstr(this->type()->ToString()),
+ ::testing::HasSubstr("but got fixed_size_binary[2]")),
+ this->AssertRaises(ReplaceWithMask, this->array("[]"), this->mask_scalar(true),
+ ArrayFromJSON(fixed_size_binary(2), "[]")));
+}
+
+TYPED_TEST(TestReplaceDecimal, ReplaceWithMask) {  // Empty array with each scalar mask
+ this->Assert(ReplaceWithMask, this->array("[]"), this->mask_scalar(false),
+ this->array("[]"), this->array("[]"));
+ this->Assert(ReplaceWithMask, this->array("[]"), this->mask_scalar(true),
+ this->array("[]"), this->array("[]"));
+ this->Assert(ReplaceWithMask, this->array("[]"), this->null_mask_scalar(),
+ this->array("[]"), this->array("[]"));
+ // One-element array, scalar mask: false keeps, true replaces, null yields null
+ this->Assert(ReplaceWithMask, this->array(R"(["1.00"])"), this->mask_scalar(false),
+ this->array("[]"), this->array(R"(["1.00"])"));
+ this->Assert(ReplaceWithMask, this->array(R"(["1.00"])"), this->mask_scalar(true),
+ this->array(R"(["0.00"])"), this->array(R"(["0.00"])"));
+ this->Assert(ReplaceWithMask, this->array(R"(["1.00"])"), this->null_mask_scalar(),
+ this->array("[]"), this->array("[null]"));
+ // Scalar mask with a scalar replacement (valid and null)
+ this->Assert(ReplaceWithMask, this->array(R"(["0.00", "0.00"])"),
+ this->mask_scalar(false), this->scalar(R"("1.00")"),
+ this->array(R"(["0.00", "0.00"])"));
+ this->Assert(ReplaceWithMask, this->array(R"(["0.00", "0.00"])"),
+ this->mask_scalar(true), this->scalar(R"("1.00")"),
+ this->array(R"(["1.00", "1.00"])"));
+ this->Assert(ReplaceWithMask, this->array(R"(["0.00", "0.00"])"),
+ this->mask_scalar(true), this->scalar("null"),
+ this->array("[null, null]"));
+ // Array mask with array replacements; nulls in the array, mask and replacements
+ this->Assert(ReplaceWithMask, this->array("[]"), this->mask("[]"), this->array("[]"),
+ this->array("[]"));
+ this->Assert(ReplaceWithMask, this->array(R"(["0.00", "1.00", "2.00", "3.00"])"),
+ this->mask("[false, false, false, false]"), this->array("[]"),
+ this->array(R"(["0.00", "1.00", "2.00", "3.00"])"));
+ this->Assert(ReplaceWithMask, this->array(R"(["0.00", "1.00", "2.00", "3.00"])"),
+ this->mask("[true, true, true, true]"),
+ this->array(R"(["10.00", "11.00", "12.00", "13.00"])"),
+ this->array(R"(["10.00", "11.00", "12.00", "13.00"])"));
+ this->Assert(ReplaceWithMask, this->array(R"(["0.00", "1.00", "2.00", "3.00"])"),
+ this->mask("[null, null, null, null]"), this->array("[]"),
+ this->array("[null, null, null, null]"));
+ this->Assert(ReplaceWithMask, this->array(R"(["0.00", "1.00", "2.00", null])"),
+ this->mask("[false, false, false, false]"), this->array("[]"),
+ this->array(R"(["0.00", "1.00", "2.00", null])"));
+ this->Assert(ReplaceWithMask, this->array(R"(["0.00", "1.00", "2.00", null])"),
+ this->mask("[true, true, true, true]"),
+ this->array(R"(["10.00", "11.00", "12.00", "13.00"])"),
+ this->array(R"(["10.00", "11.00", "12.00", "13.00"])"));
+ this->Assert(ReplaceWithMask, this->array(R"(["0.00", "1.00", "2.00", null])"),
+ this->mask("[null, null, null, null]"), this->array("[]"),
+ this->array("[null, null, null, null]"));
+ this->Assert(ReplaceWithMask,
+ this->array(R"(["0.00", "1.00", "2.00", "3.00", "4.00", "5.00"])"),
+ this->mask("[true, true, false, false, null, null]"),
+ this->array(R"(["10.00", null])"),
+ this->array(R"(["10.00", null, "2.00", "3.00", null, null])"));
+ this->Assert(ReplaceWithMask, this->array("[null, null, null, null, null, null]"),
+ this->mask("[true, true, false, false, null, null]"),
+ this->array(R"(["10.00", null])"),
+ this->array(R"(["10.00", null, null, null, null, null])"));
+ // Array mask with a scalar replacement
+ this->Assert(ReplaceWithMask, this->array("[]"), this->mask("[]"),
+ this->scalar(R"("1.00")"), this->array("[]"));
+ this->Assert(ReplaceWithMask, this->array(R"(["0.00", "1.00"])"),
+ this->mask("[true, true]"), this->scalar(R"("10.00")"),
+ this->array(R"(["10.00", "10.00"])"));
+ this->Assert(ReplaceWithMask, this->array(R"(["0.00", "1.00"])"),
+ this->mask("[true, true]"), this->scalar("null"),
+ this->array("[null, null]"));
+ this->Assert(ReplaceWithMask, this->array(R"(["0.00", "1.00", "2.00"])"),
+ this->mask("[true, false, null]"), this->scalar(R"("10.00")"),
+ this->array(R"(["10.00", "1.00", null])"));
+}
+
+TEST_F(TestReplaceDayTimeInterval, ReplaceWithMask) {  // Empty array, scalar masks
+ this->Assert(ReplaceWithMask, this->array("[]"), this->mask_scalar(false),
+ this->array("[]"), this->array("[]"));
+ this->Assert(ReplaceWithMask, this->array("[]"), this->mask_scalar(true),
+ this->array("[]"), this->array("[]"));
+ this->Assert(ReplaceWithMask, this->array("[]"), this->null_mask_scalar(),
+ this->array("[]"), this->array("[]"));
+ // One-element array, scalar mask: false keeps, true replaces, null yields null
+ this->Assert(ReplaceWithMask, this->array("[[1, 2]]"), this->mask_scalar(false),
+ this->array("[]"), this->array("[[1, 2]]"));
+ this->Assert(ReplaceWithMask, this->array("[[1, 2]]"), this->mask_scalar(true),
+ this->array("[[3, 4]]"), this->array("[[3, 4]]"));
+ this->Assert(ReplaceWithMask, this->array("[[1, 2]]"), this->null_mask_scalar(),
+ this->array("[]"), this->array("[null]"));
+ // Scalar mask with a scalar replacement (valid and null)
+ this->Assert(ReplaceWithMask, this->array("[[1, 2], [3, 4]]"), this->mask_scalar(false),
+ this->scalar("[7, 8]"), this->array("[[1, 2], [3, 4]]"));
+ this->Assert(ReplaceWithMask, this->array("[[1, 2], [3, 4]]"), this->mask_scalar(true),
+ this->scalar("[7, 8]"), this->array("[[7, 8], [7, 8]]"));
+ this->Assert(ReplaceWithMask, this->array("[[1, 2], [3, 4]]"), this->mask_scalar(true),
+ this->scalar("null"), this->array("[null, null]"));
+ // Array mask with array replacements; nulls in the array, mask and replacements
+ this->Assert(ReplaceWithMask, this->array("[]"), this->mask("[]"), this->array("[]"),
+ this->array("[]"));
+ this->Assert(ReplaceWithMask, this->array("[[1, 2], [1, 2], [1, 2], [1, 2]]"),
+ this->mask("[false, false, false, false]"), this->array("[]"),
+ this->array("[[1, 2], [1, 2], [1, 2], [1, 2]]"));
+ this->Assert(ReplaceWithMask, this->array("[[1, 2], [1, 2], [1, 2], [1, 2]]"),
+ this->mask("[true, true, true, true]"),
+ this->array("[[3, 4], [3, 4], [3, 4], [3, 4]]"),
+ this->array("[[3, 4], [3, 4], [3, 4], [3, 4]]"));
+ this->Assert(ReplaceWithMask, this->array("[[1, 2], [1, 2], [1, 2], [1, 2]]"),
+ this->mask("[null, null, null, null]"), this->array("[]"),
+ this->array("[null, null, null, null]"));
+ this->Assert(ReplaceWithMask, this->array("[[1, 2], [1, 2], [1, 2], null]"),
+ this->mask("[false, false, false, false]"), this->array("[]"),
+ this->array("[[1, 2], [1, 2], [1, 2], null]"));
+ this->Assert(ReplaceWithMask, this->array("[[1, 2], [1, 2], [1, 2], null]"),
+ this->mask("[true, true, true, true]"),
+ this->array("[[3, 4], [3, 4], [3, 4], [3, 4]]"),
+ this->array("[[3, 4], [3, 4], [3, 4], [3, 4]]"));
+ this->Assert(ReplaceWithMask, this->array("[[1, 2], [1, 2], [1, 2], null]"),
+ this->mask("[null, null, null, null]"), this->array("[]"),
+ this->array("[null, null, null, null]"));
+ this->Assert(
+ ReplaceWithMask, this->array("[[1, 2], [1, 2], [1, 2], [1, 2], [1, 2], [1, 2]]"),
+ this->mask("[true, true, false, false, null, null]"), this->array("[[3, 4], null]"),
+ this->array("[[3, 4], null, [1, 2], [1, 2], null, null]"));
+ this->Assert(ReplaceWithMask, this->array("[null, null, null, null, null, null]"),
+ this->mask("[true, true, false, false, null, null]"),
+ this->array("[[3, 4], null]"),
+ this->array("[[3, 4], null, null, null, null, null]"));
+ // Array mask with a scalar replacement
+ this->Assert(ReplaceWithMask, this->array("[]"), this->mask("[]"),
+ this->scalar("[7, 8]"), this->array("[]"));
+ this->Assert(ReplaceWithMask, this->array("[[1, 2], [3, 4]]"),
+ this->mask("[true, true]"), this->scalar("[7, 8]"),
+ this->array("[[7, 8], [7, 8]]"));
+ this->Assert(ReplaceWithMask, this->array("[[1, 2], [3, 4]]"),
+ this->mask("[true, true]"), this->scalar("null"),
+ this->array("[null, null]"));
+ this->Assert(ReplaceWithMask, this->array("[[1, 2], [3, 4], [5, 6]]"),
+ this->mask("[true, false, null]"), this->scalar("[7, 8]"),
+ this->array("[[7, 8], [3, 4], null]"));
+}
+
+TYPED_TEST(TestReplaceBinary, ReplaceWithMask) {  // Empty array with each scalar mask
+ this->Assert(ReplaceWithMask, this->array("[]"), this->mask_scalar(false),
+ this->array("[]"), this->array("[]"));
+ this->Assert(ReplaceWithMask, this->array("[]"), this->mask_scalar(true),
+ this->array("[]"), this->array("[]"));
+ this->Assert(ReplaceWithMask, this->array("[]"), this->null_mask_scalar(),
+ this->array("[]"), this->array("[]"));
+ // One-element array, scalar mask: false keeps, true replaces, null yields null
+ this->Assert(ReplaceWithMask, this->array(R"(["foo"])"), this->mask_scalar(false),
+ this->array("[]"), this->array(R"(["foo"])"));
+ this->Assert(ReplaceWithMask, this->array(R"(["foo"])"), this->mask_scalar(true),
+ this->array(R"(["bar"])"), this->array(R"(["bar"])"));
+ this->Assert(ReplaceWithMask, this->array(R"(["foo"])"), this->null_mask_scalar(),
+ this->array("[]"), this->array("[null]"));
+ // Scalar mask with a scalar replacement (valid and null)
+ this->Assert(ReplaceWithMask, this->array(R"(["foo", "bar"])"),
+ this->mask_scalar(false), this->scalar(R"("baz")"),
+ this->array(R"(["foo", "bar"])"));
+ this->Assert(ReplaceWithMask, this->array(R"(["foo", "bar"])"), this->mask_scalar(true),
+ this->scalar(R"("baz")"), this->array(R"(["baz", "baz"])"));
+ this->Assert(ReplaceWithMask, this->array(R"(["foo", "bar"])"), this->mask_scalar(true),
+ this->scalar("null"), this->array(R"([null, null])"));
+ // Array mask with array replacements; replacement lengths differ from the originals
+ this->Assert(ReplaceWithMask, this->array("[]"), this->mask("[]"), this->array("[]"),
+ this->array("[]"));
+ this->Assert(ReplaceWithMask, this->array(R"(["a", "bb", "ccc", "dddd"])"),
+ this->mask("[false, false, false, false]"), this->array("[]"),
+ this->array(R"(["a", "bb", "ccc", "dddd"])"));
+ this->Assert(ReplaceWithMask, this->array(R"(["a", "bb", "ccc", "dddd"])"),
+ this->mask("[true, true, true, true]"),
+ this->array(R"(["eeeee", "f", "ggg", "hhh"])"),
+ this->array(R"(["eeeee", "f", "ggg", "hhh"])"));
+ this->Assert(ReplaceWithMask, this->array(R"(["a", "bb", "ccc", "dddd"])"),
+ this->mask("[null, null, null, null]"), this->array("[]"),
+ this->array(R"([null, null, null, null])"));
+ this->Assert(ReplaceWithMask, this->array(R"(["a", "bb", "ccc", null])"),
+ this->mask("[false, false, false, false]"), this->array("[]"),
+ this->array(R"(["a", "bb", "ccc", null])"));
+ this->Assert(ReplaceWithMask, this->array(R"(["a", "bb", "ccc", null])"),
+ this->mask("[true, true, true, true]"),
+ this->array(R"(["eeeee", "f", "ggg", "hhh"])"),
+ this->array(R"(["eeeee", "f", "ggg", "hhh"])"));
+ this->Assert(ReplaceWithMask, this->array(R"(["a", "bb", "ccc", null])"),
+ this->mask("[null, null, null, null]"), this->array("[]"),
+ this->array(R"([null, null, null, null])"));
+ this->Assert(ReplaceWithMask,
+ this->array(R"(["a", "bb", "ccc", "dddd", "eeeee", "f"])"),
+ this->mask("[true, true, false, false, null, null]"),
+ this->array(R"(["ggg", null])"),
+ this->array(R"(["ggg", null, "ccc", "dddd", null, null])"));
+ this->Assert(ReplaceWithMask, this->array(R"([null, null, null, null, null, null])"),
+ this->mask("[true, true, false, false, null, null]"),
+ this->array(R"(["a", null])"),
+ this->array(R"(["a", null, null, null, null, null])"));
+ // Array mask with a scalar replacement
+ this->Assert(ReplaceWithMask, this->array("[]"), this->mask("[]"),
+ this->scalar(R"("zzz")"), this->array("[]"));
+ this->Assert(ReplaceWithMask, this->array(R"(["a", "bb"])"), this->mask("[true, true]"),
+ this->scalar(R"("zzz")"), this->array(R"(["zzz", "zzz"])"));
+ this->Assert(ReplaceWithMask, this->array(R"(["a", "bb"])"), this->mask("[true, true]"),
+ this->scalar("null"), this->array("[null, null]"));
+ this->Assert(ReplaceWithMask, this->array(R"(["a", "bb", "ccc"])"),
+ this->mask("[true, false, null]"), this->scalar(R"("zzz")"),
+ this->array(R"(["zzz", "bb", null])"));
+}
+
+TYPED_TEST(TestReplaceBinary, ReplaceWithMaskRandom) {
+ using ArrayType = typename TypeTraits<TypeParam>::ArrayType;
+ auto ty = this->type();
+ // Random inputs from a fixed seed; one replacement per valid, true mask slot
+ random::RandomArrayGenerator rand(/*seed=*/0);
+ const int64_t length = 1023;
+ auto options = key_value_metadata({{"null_probability", "0.01"}, {"max_length", "5"}});
+ auto array =
+ checked_pointer_cast<ArrayType>(rand.ArrayOf(*field("a", ty, options), length));
+ auto mask = checked_pointer_cast<BooleanArray>(
+ rand.ArrayOf(boolean(), length, /*null_probability=*/0.01));
+ const int64_t num_replacements = std::count_if(
+ mask->begin(), mask->end(),
+ [](util::optional<bool> value) { return value.has_value() && *value; });
+ auto replacements = checked_pointer_cast<ArrayType>(
+ rand.ArrayOf(*field("a", ty, options), num_replacements));
+ auto expected = this->NaiveImpl(*array, *mask, *replacements);
+ // Check against NaiveImpl on the full arrays, then on 15-element slices at offsets 1-16
+ this->Assert(ReplaceWithMask, array, mask, replacements, expected);
+ for (int64_t slice = 1; slice <= 16; slice++) {
+ auto sliced_array = checked_pointer_cast<ArrayType>(array->Slice(slice, 15));
+ auto sliced_mask = checked_pointer_cast<BooleanArray>(mask->Slice(slice, 15));
+ auto new_expected = this->NaiveImpl(*sliced_array, *sliced_mask, *replacements);
+ this->Assert(ReplaceWithMask, sliced_array, sliced_mask, replacements, new_expected);
+ }
+}
+
+} // namespace compute
+} // namespace arrow
diff --git a/cpp/src/arrow/compute/registry.cc b/cpp/src/arrow/compute/registry.cc
index 8a0d9e62518..ca7b6137306 100644
--- a/cpp/src/arrow/compute/registry.cc
+++ b/cpp/src/arrow/compute/registry.cc
@@ -168,6 +168,7 @@ static std::unique_ptr<FunctionRegistry> CreateBuiltInRegistry() {
// Vector functions
RegisterVectorHash(registry.get());
+ RegisterVectorReplace(registry.get());
RegisterVectorSelection(registry.get());
RegisterVectorNested(registry.get());
RegisterVectorSort(registry.get());
diff --git a/cpp/src/arrow/compute/registry_internal.h b/cpp/src/arrow/compute/registry_internal.h
index dd0271eb43d..892b54341da 100644
--- a/cpp/src/arrow/compute/registry_internal.h
+++ b/cpp/src/arrow/compute/registry_internal.h
@@ -41,6 +41,7 @@ void RegisterScalarOptions(FunctionRegistry* registry);
// Vector functions
void RegisterVectorHash(FunctionRegistry* registry);
+void RegisterVectorReplace(FunctionRegistry* registry);
void RegisterVectorSelection(FunctionRegistry* registry);
void RegisterVectorNested(FunctionRegistry* registry);
void RegisterVectorSort(FunctionRegistry* registry);
diff --git a/cpp/src/arrow/type_traits.h b/cpp/src/arrow/type_traits.h
index 86664bbb162..e4d809967f9 100644
--- a/cpp/src/arrow/type_traits.h
+++ b/cpp/src/arrow/type_traits.h
@@ -233,6 +233,7 @@ struct TypeTraits<MonthIntervalType> {
using ArrayType = MonthIntervalArray;
using BuilderType = MonthIntervalBuilder;
using ScalarType = MonthIntervalScalar;
+ using CType = MonthIntervalType::c_type;
static constexpr int64_t bytes_required(int64_t elements) {
 return elements * static_cast<int64_t>(sizeof(int32_t));
diff --git a/docs/source/cpp/compute.rst b/docs/source/cpp/compute.rst
index fc6c8b7c7e1..6d857ed6822 100644
--- a/docs/source/cpp/compute.rst
+++ b/docs/source/cpp/compute.rst
@@ -1090,6 +1090,22 @@ Associative transforms
Each output element corresponds to a unique value in the input, along
with the number of times this value has appeared.
+Replacements
+~~~~~~~~~~~~
+
+These functions create a copy of the first input with some elements
+replaced, based on the remaining inputs.
+
++-------------------+------------+-----------------------+--------------+--------------+--------------+-------+
+| Function name | Arity | Input type 1 | Input type 2 | Input type 3 | Output type | Notes |
++===================+============+=======================+==============+==============+==============+=======+
+| replace_with_mask | Ternary | Fixed-width or binary | Boolean | Input type 1 | Input type 1 | \(1) |
++-------------------+------------+-----------------------+--------------+--------------+--------------+-------+
+
+* \(1) Each element in input 1 for which the corresponding Boolean in input 2
+ is true is replaced with the next value from input 3. A null in input 2
+ results in a corresponding null in the output.
+
Selections
~~~~~~~~~~
diff --git a/docs/source/python/api/compute.rst b/docs/source/python/api/compute.rst
index a611d2a2384..09c67598193 100644
--- a/docs/source/python/api/compute.rst
+++ b/docs/source/python/api/compute.rst
@@ -292,6 +292,14 @@ Conversions
cast
strptime
+Replacements
+------------
+
+.. autosummary::
+ :toctree: ../generated/
+
+ replace_with_mask
+
Selections
----------
From ccca017bba32a6ea9ffc5287c09eec725faa4462 Mon Sep 17 00:00:00 2001
From: David Li
Date: Tue, 15 Jun 2021 08:22:35 -0400
Subject: [PATCH 02/15] ARROW-9430: [C++] Clarify replace_with_mask
implementation
---
cpp/src/arrow/compute/kernels/vector_replace.cc | 11 +++++------
1 file changed, 5 insertions(+), 6 deletions(-)
diff --git a/cpp/src/arrow/compute/kernels/vector_replace.cc b/cpp/src/arrow/compute/kernels/vector_replace.cc
index 4b768608c62..7c0edb78226 100644
--- a/cpp/src/arrow/compute/kernels/vector_replace.cc
+++ b/cpp/src/arrow/compute/kernels/vector_replace.cc
@@ -37,9 +37,9 @@ Status ReplaceWithScalarMask(KernelContext* ctx, const ArrayData& array,
ArrayData* output) {
if (!mask.is_valid) {
// Output = null
- ARROW_ASSIGN_OR_RAISE(auto array,
+ ARROW_ASSIGN_OR_RAISE(auto replacement_array,
MakeArrayOfNull(array.type, array.length, ctx->memory_pool()));
- *output = *array->data();
+ *output = *replacement_array->data();
return Status::OK();
}
if (mask.value) {
@@ -91,9 +91,11 @@ Status ReplaceWithArrayMask(KernelContext* ctx, const ArrayData& array,
out_bitmap = output->buffers[0]->mutable_data();
output->null_count = -1;
if (array.MayHaveNulls()) {
+ // Copy array's bitmap
arrow::internal::CopyBitmap(array.buffers[0]->data(), array.offset, array.length,
out_bitmap, /*dest_offset=*/0);
} else {
+ // Array has no bitmap but mask/replacements do, generate an all-valid bitmap
std::memset(out_bitmap, 0xFF, output->buffers[0]->size());
}
} else {
@@ -136,11 +138,8 @@ Status ReplaceWithArrayMask(KernelContext* ctx, const ArrayData& array,
BitUtil::SetBitsTo(out_bitmap, out_offset, valid_block.length, true);
}
replacements_offset += valid_block.length;
- } else if (value_block.NoneSet() && valid_block.AllSet()) {
+ } else if ((value_block.NoneSet() && valid_block.AllSet()) || valid_block.NoneSet()) {
// Do nothing
- } else if (valid_block.NoneSet()) {
- DCHECK(out_bitmap);
- BitUtil::SetBitsTo(out_bitmap, out_offset, valid_block.length, false);
} else {
for (int64_t i = 0; i < valid_block.length; ++i) {
if (BitUtil::GetBit(mask_values, out_offset + mask.offset + i) &&
From ea1517a3db921da1118d90fa517d9773b6b99851 Mon Sep 17 00:00:00 2001
From: David Li
Date: Tue, 15 Jun 2021 15:01:18 -0400
Subject: [PATCH 03/15] ARROW-9430: [C++] Cross-reference if_else and
replace_with_mask
---
docs/source/cpp/compute.rst | 38 ++++++++++++++++++++-----------------
1 file changed, 21 insertions(+), 17 deletions(-)
diff --git a/docs/source/cpp/compute.rst b/docs/source/cpp/compute.rst
index 6d857ed6822..00391052b1e 100644
--- a/docs/source/cpp/compute.rst
+++ b/docs/source/cpp/compute.rst
@@ -850,6 +850,7 @@ in reverse order.
as given by :struct:`SliceOptions` where ``start`` and ``stop`` are measured
in codeunits. Null inputs emit null.
+.. _cpp-compute-scalar-structural-transforms:
Structural transforms
~~~~~~~~~~~~~~~~~~~~~
@@ -861,7 +862,7 @@ Structural transforms
+==========================+============+================================================+=====================+=========+
| fill_null | Binary | Boolean, Null, Numeric, Temporal, String-like | Input type | \(1) |
+--------------------------+------------+------------------------------------------------+---------------------+---------+
-| if_else | Ternary | Boolean, Null, Numeric, Temporal | Input type + \(2) |
+| if_else | Ternary | Boolean, Null, Numeric, Temporal | Input type | \(2) |
+--------------------------+------------+------------------------------------------------+---------------------+---------+
| is_finite | Unary | Float, Double | Boolean | \(3) |
+--------------------------+------------+------------------------------------------------+---------------------+---------+
@@ -888,6 +889,8 @@ Structural transforms
input. If the nulls present on the first input, they will be promoted to the
output, otherwise nulls will be chosen based on the first input values.
+ Also see: :ref:`replace_with_mask <cpp-compute-vector-structural-transforms>`.
+
* \(3) Output is true iff the corresponding input element is finite (not Infinity,
-Infinity, or NaN).
@@ -1090,22 +1093,6 @@ Associative transforms
Each output element corresponds to a unique value in the input, along
with the number of times this value has appeared.
-Replacements
-~~~~~~~~~~~~
-
-These functions create a copy of the first input with some elements
-replaced, based on the remaining inputs.
-
-+-------------------+------------+-----------------------+--------------+--------------+--------------+-------+
-| Function name | Arity | Input type 1 | Input type 2 | Input type 3 | Output type | Notes |
-+===================+============+=======================+==============+==============+==============+=======+
-| replace_with_mask | Ternary | Fixed-width or binary | Boolean | Input type 1 | Input type 1 | \(1) |
-+-------------------+------------+-----------------------+--------------+--------------+--------------+-------+
-
-* \(1) Each element in input 1 for which the corresponding Boolean in input 2
- is true is replaced with the next value from input 3. A null in input 2
- results in a corresponding null in the output.
-
Selections
~~~~~~~~~~
@@ -1170,6 +1157,8 @@ value, but smaller than nulls.
table. If the input is a record batch or table, one or more sort
keys must be specified.
+.. _cpp-compute-vector-structural-transforms:
+
Structural transforms
~~~~~~~~~~~~~~~~~~~~~
@@ -1188,3 +1177,18 @@ Structural transforms
* \(2) For each value in the list child array, the index at which it is found
in the list array is appended to the output. Nulls in the parent list array
are discarded.
+
+These functions create a copy of the first input with some elements
+replaced, based on the remaining inputs.
+
++--------------------------+------------+-----------------------+--------------+--------------+--------------+-------+
+| Function name | Arity | Input type 1 | Input type 2 | Input type 3 | Output type | Notes |
++==========================+============+=======================+==============+==============+==============+=======+
+| replace_with_mask | Ternary | Fixed-width or binary | Boolean | Input type 1 | Input type 1 | \(1) |
++--------------------------+------------+-----------------------+--------------+--------------+--------------+-------+
+
+* \(1) Each element in input 1 for which the corresponding Boolean in input 2
+ is true is replaced with the next value from input 3. A null in input 2
+ results in a corresponding null in the output.
+
+ Also see: :ref:`if_else <cpp-compute-scalar-structural-transforms>`.
From 853eee3388fb4f5b120d0e8f8f11602a48ae335a Mon Sep 17 00:00:00 2001
From: David Li
Date: Fri, 25 Jun 2021 14:19:22 -0400
Subject: [PATCH 04/15] ARROW-9430: [C++] Clean up tests slightly
---
.../compute/kernels/vector_replace_test.cc | 41 ++++---------------
1 file changed, 8 insertions(+), 33 deletions(-)
diff --git a/cpp/src/arrow/compute/kernels/vector_replace_test.cc b/cpp/src/arrow/compute/kernels/vector_replace_test.cc
index 1826b037034..45e99dbf777 100644
--- a/cpp/src/arrow/compute/kernels/vector_replace_test.cc
+++ b/cpp/src/arrow/compute/kernels/vector_replace_test.cc
@@ -32,6 +32,7 @@ namespace compute {
using arrow::internal::checked_pointer_cast;
namespace {
+// Helper to get a default instance of a type, including parameterized types
template
enable_if_parameter_free> default_type_instance() {
return TypeTraits::type_singleton();
@@ -39,11 +40,10 @@ enable_if_parameter_free> default_type_instance() {
template
enable_if_time> default_type_instance() {
// Time32 requires second/milli, Time64 requires nano/micro
- if (TypeTraits::bytes_required(1) == 4) {
+ if (bit_width(T::type_id) == 32) {
return std::make_shared(TimeUnit::type::SECOND);
- } else {
- return std::make_shared(TimeUnit::type::NANO);
}
+ return std::make_shared(TimeUnit::type::NANO);
}
template
enable_if_timestamp> default_type_instance() {
@@ -53,33 +53,6 @@ template
enable_if_decimal> default_type_instance() {
return std::make_shared(5, 2);
}
-template
-enable_if_parameter_free::BuilderType>>
-builder_instance() {
- return arrow::internal::make_unique::BuilderType>();
-}
-template
-enable_if_time::BuilderType>>
-builder_instance() {
- return arrow::internal::make_unique::BuilderType>(
- default_type_instance(), default_memory_pool());
-}
-template
-enable_if_timestamp::BuilderType>>
-builder_instance() {
- return arrow::internal::make_unique::BuilderType>(
- default_type_instance(), default_memory_pool());
-}
-template
-enable_if_t::value, T> max_int_value() {
- return static_cast(
- std::min(16384.0, static_cast(std::numeric_limits::max())));
-}
-template
-enable_if_t::value, T> max_int_value() {
- return static_cast(
- std::min(16384.0, static_cast(std::numeric_limits::max())));
-}
} // namespace
template
@@ -139,7 +112,8 @@ class TestReplaceKernel : public ::testing::Test {
const typename TypeTraits::ArrayType& array, const BooleanArray& mask,
const typename TypeTraits::ArrayType& replacements) {
auto length = array.length();
- auto builder = builder_instance();
+ auto builder = arrow::internal::make_unique::BuilderType>(
+ default_type_instance(), default_memory_pool());
int64_t replacement_offset = 0;
for (int64_t i = 0; i < length; ++i) {
if (mask.IsValid(i)) {
@@ -282,9 +256,10 @@ TYPED_TEST(TestReplaceNumeric, ReplaceWithMaskRandom) {
random::RandomArrayGenerator rand(/*seed=*/0);
const int64_t length = 1023;
- // Clamp the range because date/time types don't print well with extreme values
std::vector values = {"0.01", "0"};
- values.push_back(std::to_string(max_int_value()));
+ // Clamp the range because date/time types don't print well with extreme values
+ values.push_back(std::to_string(static_cast(std::min(
+ 16384.0, static_cast(std::numeric_limits::max())))));
auto options = key_value_metadata({"null_probability", "min", "max"}, values);
auto array =
checked_pointer_cast(rand.ArrayOf(*field("a", ty, options), length));
From e15493000020ef0ec5060e4a5ac138841332a6a2 Mon Sep 17 00:00:00 2001
From: David Li
Date: Mon, 28 Jun 2021 09:10:34 -0400
Subject: [PATCH 05/15] ARROW-9430: [C++] Move test helper into test_util.h
---
cpp/src/arrow/compute/kernels/test_util.h | 22 +++++++++++++++++
.../compute/kernels/vector_replace_test.cc | 24 -------------------
2 files changed, 22 insertions(+), 24 deletions(-)
diff --git a/cpp/src/arrow/compute/kernels/test_util.h b/cpp/src/arrow/compute/kernels/test_util.h
index f4854087b51..c691a9f3be3 100644
--- a/cpp/src/arrow/compute/kernels/test_util.h
+++ b/cpp/src/arrow/compute/kernels/test_util.h
@@ -172,5 +172,27 @@ void CheckDispatchBest(std::string func_name, std::vector descrs,
// Check that function fails to produce a Kernel for the set of ValueDescrs.
void CheckDispatchFails(std::string func_name, std::vector descrs);
+// Helper to get a default instance of a type, including parameterized types
+template
+enable_if_parameter_free> default_type_instance() {
+ return TypeTraits::type_singleton();
+}
+template
+enable_if_time> default_type_instance() {
+ // Time32 requires second/milli, Time64 requires nano/micro
+ if (bit_width(T::type_id) == 32) {
+ return std::make_shared(TimeUnit::type::SECOND);
+ }
+ return std::make_shared(TimeUnit::type::NANO);
+}
+template
+enable_if_timestamp> default_type_instance() {
+ return std::make_shared(TimeUnit::type::SECOND);
+}
+template
+enable_if_decimal> default_type_instance() {
+ return std::make_shared(5, 2);
+}
+
} // namespace compute
} // namespace arrow
diff --git a/cpp/src/arrow/compute/kernels/vector_replace_test.cc b/cpp/src/arrow/compute/kernels/vector_replace_test.cc
index 45e99dbf777..a55ed709a0f 100644
--- a/cpp/src/arrow/compute/kernels/vector_replace_test.cc
+++ b/cpp/src/arrow/compute/kernels/vector_replace_test.cc
@@ -31,30 +31,6 @@ namespace compute {
using arrow::internal::checked_pointer_cast;
-namespace {
-// Helper to get a default instance of a type, including parameterized types
-template
-enable_if_parameter_free> default_type_instance() {
- return TypeTraits::type_singleton();
-}
-template
-enable_if_time> default_type_instance() {
- // Time32 requires second/milli, Time64 requires nano/micro
- if (bit_width(T::type_id) == 32) {
- return std::make_shared(TimeUnit::type::SECOND);
- }
- return std::make_shared(TimeUnit::type::NANO);
-}
-template
-enable_if_timestamp> default_type_instance() {
- return std::make_shared(TimeUnit::type::SECOND);
-}
-template
-enable_if_decimal> default_type_instance() {
- return std::make_shared(5, 2);
-}
-} // namespace
-
template
class TestReplaceKernel : public ::testing::Test {
protected:
From 97fabbe4e3891dedbdca7001ca7163fa8444ad00 Mon Sep 17 00:00:00 2001
From: David Li
Date: Mon, 28 Jun 2021 09:40:43 -0400
Subject: [PATCH 06/15] ARROW-9430: [C++] Take advantage of preallocation
---
.../arrow/compute/kernels/vector_replace.cc | 30 +------------------
1 file changed, 1 insertion(+), 29 deletions(-)
diff --git a/cpp/src/arrow/compute/kernels/vector_replace.cc b/cpp/src/arrow/compute/kernels/vector_replace.cc
index 7c0edb78226..9f3aeb2ca7e 100644
--- a/cpp/src/arrow/compute/kernels/vector_replace.cc
+++ b/cpp/src/arrow/compute/kernels/vector_replace.cc
@@ -70,9 +70,6 @@ template
Status ReplaceWithArrayMask(KernelContext* ctx, const ArrayData& array,
const ArrayData& mask, const Datum& replacements,
ArrayData* output) {
- ARROW_ASSIGN_OR_RAISE(output->buffers[1],
- Functor::AllocateData(ctx, *array.type, array.length));
-
uint8_t* out_bitmap = nullptr;
uint8_t* out_values = output->buffers[1]->mutable_data();
const uint8_t* mask_bitmap = mask.MayHaveNulls() ? mask.buffers[0]->data() : nullptr;
@@ -176,11 +173,6 @@ template
struct ReplaceWithMask> {
using T = typename TypeTraits::CType;
- static Result> AllocateData(KernelContext* ctx, const DataType&,
- const int64_t length) {
- return ctx->Allocate(length * sizeof(T));
- }
-
static void CopyData(const DataType&, uint8_t* out, const int64_t out_offset,
const Datum& in, const int64_t in_offset, const int64_t length) {
if (in.is_array()) {
@@ -211,11 +203,6 @@ struct ReplaceWithMask> {
template