From 868d15fa110921fa9f211e65c88c8ca2b6a0894d Mon Sep 17 00:00:00 2001 From: Bruno LE HYARIC Date: Mon, 28 Dec 2020 01:54:32 +0100 Subject: [PATCH 1/8] ARROW-11044: [C++] Add "replace" kernel --- cpp/src/arrow/CMakeLists.txt | 1 + cpp/src/arrow/compute/api_scalar.cc | 4 + cpp/src/arrow/compute/api_scalar.h | 16 ++ cpp/src/arrow/compute/kernels/CMakeLists.txt | 1 + .../arrow/compute/kernels/scalar_replace.cc | 241 ++++++++++++++++++ .../compute/kernels/scalar_replace_test.cc | 225 ++++++++++++++++ cpp/src/arrow/compute/registry.cc | 1 + cpp/src/arrow/compute/registry_internal.h | 1 + 8 files changed, 490 insertions(+) create mode 100644 cpp/src/arrow/compute/kernels/scalar_replace.cc create mode 100644 cpp/src/arrow/compute/kernels/scalar_replace_test.cc diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt index eaa6b325cc6..f8d4cd4286d 100644 --- a/cpp/src/arrow/CMakeLists.txt +++ b/cpp/src/arrow/CMakeLists.txt @@ -382,6 +382,7 @@ if(ARROW_COMPUTE) compute/kernels/scalar_string.cc compute/kernels/scalar_validity.cc compute/kernels/scalar_fill_null.cc + compute/kernels/scalar_replace.cc compute/kernels/util_internal.cc compute/kernels/vector_hash.cc compute/kernels/vector_nested.cc diff --git a/cpp/src/arrow/compute/api_scalar.cc b/cpp/src/arrow/compute/api_scalar.cc index 671c8246378..a4be2eb0c7d 100644 --- a/cpp/src/arrow/compute/api_scalar.cc +++ b/cpp/src/arrow/compute/api_scalar.cc @@ -140,5 +140,9 @@ Result FillNull(const Datum& values, const Datum& fill_value, ExecContext return CallFunction("fill_null", {values, fill_value}, ctx); } +Result Replace(const Datum& values, const Datum& mask, const Datum& replacement, ExecContext* ctx) { + return CallFunction("replace", {values, mask, replacement}, ctx); +} + } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/api_scalar.h b/cpp/src/arrow/compute/api_scalar.h index 37f3077e4bd..f0bbfe5bab6 100644 --- a/cpp/src/arrow/compute/api_scalar.h +++ b/cpp/src/arrow/compute/api_scalar.h @@ -388,5 +388,21 @@ ARROW_EXPORT Result FillNull(const Datum& values, const Datum& fill_value, ExecContext* ctx = NULLPTR); +/// \brief Replace replaces each element in `values` for which mask bit is true +/// with `fill_value` +/// +/// \param[in] values input to replace based on mask +/// \param[in] mask bits +/// \param[in] replacement scalar +/// \param[in] ctx the function execution context, optional +/// +/// \return the resulting datum +/// +/// \since X.X.X +/// \note API not yet finalized +ARROW_EXPORT +Result Replace(const Datum& values, const Datum& mask, const Datum& replacement, + ExecContext* ctx = NULLPTR); + } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/CMakeLists.txt b/cpp/src/arrow/compute/kernels/CMakeLists.txt index 577b250da87..4383ceb8b00 100644 --- a/cpp/src/arrow/compute/kernels/CMakeLists.txt +++ b/cpp/src/arrow/compute/kernels/CMakeLists.txt @@ -29,6 +29,7 @@ add_arrow_compute_test(scalar_test scalar_string_test.cc scalar_validity_test.cc scalar_fill_null_test.cc + scalar_replace_test.cc test_util.cc) add_arrow_benchmark(scalar_arithmetic_benchmark PREFIX "arrow-compute") diff --git a/cpp/src/arrow/compute/kernels/scalar_replace.cc b/cpp/src/arrow/compute/kernels/scalar_replace.cc new file mode 100644 index 00000000000..a4200ad4b9f --- /dev/null +++ b/cpp/src/arrow/compute/kernels/scalar_replace.cc @@ -0,0 +1,241 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/compute/kernels/common.h" +#include "arrow/util/bit_block_counter.h" +#include "arrow/util/bit_util.h" +#include "arrow/util/bitmap_ops.h" + +namespace arrow { + +using internal::BitBlockCount; +using internal::BitBlockCounter; + +namespace compute { +namespace internal { + +namespace { + +template +struct ReplaceFunctor {}; + +// Numeric inputs + +template +struct ReplaceFunctor::value>> { + using T = typename TypeTraits::CType; + + static void Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + const ArrayData& data = *batch[0].array(); + const ArrayData& mask = *batch[1].array(); + const Scalar& replacement = *batch[2].scalar(); + ArrayData* output = out->mutable_array(); + + if (replacement.is_valid) { + KERNEL_ASSIGN_OR_RAISE(std::shared_ptr out_buf, ctx, + ctx->Allocate(data.length * sizeof(T))); + T value = UnboxScalar::Unbox(replacement); + const uint8_t* to_replace = mask.buffers[1]->data(); + const T* in_values = data.GetValues(1); + T* out_values = reinterpret_cast(out_buf->mutable_data()); + int64_t offset = data.offset; + BitBlockCounter bit_counter(to_replace, data.offset, data.length); + while (offset < data.offset + data.length) { + BitBlockCount block = bit_counter.NextWord(); + if (block.NoneSet()) { + std::memcpy(out_values, in_values, block.length * sizeof(T)); + } else if (block.AllSet()) { + std::fill(out_values, out_values + block.length, value); + } else { + for (int64_t i = 0; i < block.length; ++i) { + out_values[i] = BitUtil::GetBit(to_replace, offset + i) ? value : in_values[i]; + } + } + offset += block.length; + out_values += block.length; + in_values += block.length; + } + output->buffers[1] = out_buf; + } else { + *output = data; + } + } +}; + +// Boolean input + +template +struct ReplaceFunctor::value>> { + static void Exec(KernelContext* ctx, const ExecBatch batch, Datum* out) { + const ArrayData& data = *batch[0].array(); + const ArrayData& mask = *batch[1].array(); + const Scalar& replacement = *batch[2].scalar(); + ArrayData* output = out->mutable_array(); + + bool value = UnboxScalar::Unbox(replacement); + if (replacement.is_valid) { + KERNEL_ASSIGN_OR_RAISE(std::shared_ptr out_buf, ctx, + ctx->AllocateBitmap(data.length)); + + const uint8_t* to_replace = mask.buffers[1]->data(); + const uint8_t* data_bitmap = data.buffers[1]->data(); + uint8_t* out_bitmap = out_buf->mutable_data(); + + int64_t data_offset = data.offset; + BitBlockCounter bit_counter(to_replace, data.offset, data.length); + + int64_t out_offset = 0; + while (out_offset < data.length) { + BitBlockCount block = bit_counter.NextWord(); + if (block.NoneSet()) { + ::arrow::internal::CopyBitmap(data_bitmap, data_offset, block.length, + out_bitmap, out_offset); + } else if (block.AllSet()) { + BitUtil::SetBitsTo(out_bitmap, out_offset, block.length, value); + } else { + for (int64_t i = 0 ; i < block.length ; ++i) { + BitUtil::SetBitTo(out_bitmap, + out_offset + i, + BitUtil::GetBit(to_replace, data_offset + i) ? value : BitUtil::GetBit(data_bitmap, data_offset + i)); + } + } + data_offset += block.length; + out_offset += block.length; + } + output->buffers[1] = out_buf; + } else { + *output = data; + } + } +}; + +// Null input + +template +struct ReplaceFunctor::value>> { + static void Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + // Nothing preallocated, so we assign into the output + *out->mutable_array() = *batch[0].array(); + } +}; + +// Binary-like + +template +struct ReplaceFunctor::value>> { + using BuilderType = typename TypeTraits::BuilderType; + using OffsetType = typename Type::offset_type; + + static void Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + const ArrayData& input = *batch[0].array(); + const ArrayData& mask = *batch[1].array(); + const auto& replacement_scalar = checked_cast(*batch[2].scalar()); + util::string_view replacement(*replacement_scalar.value); + ArrayData* output = out->mutable_array(); + + const uint8_t* to_replace = mask.buffers[1]->data(); + uint64_t replace_count = 0; + { + BitBlockCounter bit_counter(to_replace, input.offset, input.length); + int64_t out_offset = 0; + while (out_offset < input.length) { + BitBlockCount block = bit_counter.NextWord(); + replace_count += block.popcount; + out_offset += block.length; + } + } + + if (replace_count > 0 && replacement_scalar.is_valid) { + auto input_offsets = input.GetValues(1); + auto input_values = input.GetValues(2, input.offset); + BuilderType builder(input.type, ctx->memory_pool()); + KERNEL_RETURN_IF_ERROR(ctx, builder.ReserveData(input.buffers[2]->size() + + replacement.length() * replace_count)); + KERNEL_RETURN_IF_ERROR(ctx, builder.Resize(input.length)); + + BitBlockCounter bit_counter(to_replace, input.offset, input.length); + int64_t input_offset = 0; + while (input_offset < input.length) { + BitBlockCount block = bit_counter.NextWord(); + for (int64_t i = 0 ; i < block.length ; ++i) { + if (BitUtil::GetBit(to_replace, input_offset + i)) { + builder.UnsafeAppend(replacement); + } else { + auto current_offset = input_offsets[input_offset+i]; + auto next_offset = input_offsets[input_offset+i+1]; + auto string_value = util::string_view(input_values + current_offset, + next_offset - current_offset); + builder.UnsafeAppend(string_value); + } + } + input_offset += block.length; + } + std::shared_ptr string_array; + KERNEL_RETURN_IF_ERROR(ctx, builder.Finish(&string_array)); + *output = *string_array->data(); + // The builder does not match the logical type, due to + // GenerateTypeAgnosticVarBinaryBase + output->type = input.type; + } else { + *output = input; + } + } +}; + +void AddBasicReplaceKernels(ScalarKernel kernel, ScalarFunction* func) { + auto AddKernels = [&](const std::vector>& types) { + for (const std::shared_ptr& ty : types) { + kernel.signature = KernelSignature::Make({InputType::Array(ty), InputType::Array(boolean()), InputType::Scalar(ty)}, ty); + kernel.exec = GenerateTypeAgnosticPrimitive(*ty); + DCHECK_OK(func->AddKernel(kernel)); + } + }; + AddKernels(NumericTypes()); + AddKernels(TemporalTypes()); + AddKernels({boolean(), null()}); +} + +void AddBinaryReplaceKernels(ScalarKernel kernel, ScalarFunction* func) { + for (const std::shared_ptr& ty : BaseBinaryTypes()) { + kernel.signature = KernelSignature::Make({InputType::Array(ty), InputType::Array(boolean()), InputType::Scalar(ty)}, ty); + kernel.exec = GenerateTypeAgnosticVarBinaryBase(*ty); + DCHECK_OK(func->AddKernel(kernel)); + } +} + +const FunctionDoc replace_doc{ + "Replace selected elements", + ("`replacement` must be a scalar of the same type as `values`.\n" + "Each unmasked value in `values` is emitted as-is.\n" + "Each masked value in `values` is replaced with `replacement`."), + {"values", "mask", "replacement"}}; + +} // namespace + +void RegisterScalarReplace(FunctionRegistry* registry) { + ScalarKernel replace_base; + replace_base.null_handling = NullHandling::COMPUTED_NO_PREALLOCATE; + replace_base.mem_allocation = MemAllocation::NO_PREALLOCATE; + auto replace = std::make_shared("replace", Arity::Ternary(), &replace_doc); + AddBasicReplaceKernels(replace_base, replace.get()); + AddBinaryReplaceKernels(replace_base, replace.get()); + DCHECK_OK(registry->AddFunction(replace)); +} + +} // namespace internal +} // namespace compute +} // namespace arrow \ No newline at end of file diff --git a/cpp/src/arrow/compute/kernels/scalar_replace_test.cc b/cpp/src/arrow/compute/kernels/scalar_replace_test.cc new file mode 100644 index 00000000000..491277af37a --- /dev/null +++ b/cpp/src/arrow/compute/kernels/scalar_replace_test.cc @@ -0,0 +1,225 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include + +#include "arrow/compute/api.h" +#include "arrow/scalar.h" +#include "arrow/testing/gtest_util.h" +#include "arrow/testing/random.h" +#include "arrow/type.h" +#include "arrow/type_traits.h" +#include "arrow/util/bitmap_ops.h" + +namespace arrow { + +using internal::InvertBitmap; + +namespace compute { + +void CheckReplace(const Array& input, + const Array& mask, + const Datum& replacement, + const Array& expected) { + auto Check = [&](const Array& input, const Array& mask, const Array& expected) { + ASSERT_OK_AND_ASSIGN(Datum datum_out, Replace(input, mask, replacement)); + std::shared_ptr result = datum_out.make_array(); + ASSERT_OK(result->ValidateFull()); + AssertArraysEqual(expected, *result, /*verbose=*/true); + }; + + Check(input, mask, expected); + if (input.length() > 0) { + Check(*input.Slice(1), + *mask.Slice(1), + *expected.Slice(1)); + } +} + +void CheckReplace(const std::shared_ptr& type, + const std::string& in_values, + const std::string& in_mask, + const Datum& replacement, + const std::string& out_values) { + std::shared_ptr input = ArrayFromJSON(type, in_values); + std::shared_ptr mask = ArrayFromJSON(boolean(), in_mask); + std::shared_ptr expected = ArrayFromJSON(type, out_values); + CheckReplace(*input, *mask, replacement, *expected); +} + +class TestReplaceKernel : public ::testing::Test {}; + +template +class TestReplacePrimitive : public ::testing::Test {}; + +typedef ::testing::Types + PrimitiveTypes; + +TEST_F(TestReplaceKernel, ReplaceInvalidScalar) { + auto scalar = std::make_shared(3); + scalar->is_valid = false; + CheckReplace(int8(), + "[2, 4, 7, 9]", + "[true, false, false, false]", + Datum(scalar), + "[2, 4, 7, 9]"); +} + +TYPED_TEST_SUITE(TestReplacePrimitive, PrimitiveTypes); + +TYPED_TEST(TestReplacePrimitive, Replace) { + using T = typename TypeParam::c_type; + using ArrayType = typename TypeTraits::ArrayType; + using ScalarType = typename TypeTraits::ScalarType; + auto type = TypeTraits::type_singleton(); + auto scalar = std::make_shared(static_cast(42)); + + // No replacement + CheckReplace(type, "[2, 4, 7, 9]", "[false, false, false, false]", Datum(scalar), "[2, 4, 7, 9]"); + // Some replacements + CheckReplace(type, "[2, 4, 7, 9]", "[true, false, true, false]", Datum(scalar), "[42, 4, 42, 9]"); + // Empty Array + CheckReplace(type, "[]", "[]", Datum(scalar), "[]"); + + random::RandomArrayGenerator rand(/*seed=*/0); + auto arr = std::static_pointer_cast(rand.ArrayOf(type, 1000, /*null_probability=*/0.01)); + // use arr inverted null bits as mask, so expect to replace all null values... + auto mask_data = std::make_shared(boolean(), arr->length(), 0); + mask_data->null_count = 0; + mask_data->buffers.resize(2); + mask_data->buffers[0] = nullptr; + mask_data->buffers[1] = *AllocateEmptyBitmap(arr->length()); + InvertBitmap(arr->data()->buffers[0]->data(), arr->offset(), arr->length(), + mask_data->buffers[1]->mutable_data(), mask_data->offset); + std::shared_ptr expected_data = arr->data()->Copy(); + expected_data->null_count = 0; + expected_data->buffers[0] = nullptr; + expected_data->buffers[1] = *AllocateBuffer(arr->length() * sizeof(T)); + T* out_data = expected_data->GetMutableValues(1); + for (int64_t i = 0 ; i < arr->length() ; ++i) { + if (arr->IsValid(i)) { + out_data[i] = arr->Value(i); + } else { + out_data[i] = scalar->value; + } + } + CheckReplace(*arr, BooleanArray(mask_data), Datum(scalar), ArrayType(expected_data)); +} + +TEST_F(TestReplaceKernel, ReplaceNull) { + auto datum = Datum(std::make_shared()); + CheckReplace(null(), + "[null, null, null, null]", + "[true, true, true, true]", + /*replacement=*/datum, + "[null, null, null, null]"); +} + +TEST_F(TestReplaceKernel, ReplaceBoolean) { + auto scalar1 = std::make_shared(false); + auto scalar2 = std::make_shared(true); + + // No replacement + CheckReplace(boolean(), + "[true, false, true, false]", + "[false, false, false, false]", + Datum(scalar1), + "[true, false, true, false]"); + // Some replacements + CheckReplace(boolean(), + "[true, false, true, false]", + "[true, false, true, false]", + Datum(scalar1), + "[false, false, false, false]"); + + random::RandomArrayGenerator rand(/*seed=*/0); + auto arr = std::static_pointer_cast(rand.Boolean(1000, + /*true_probability=*/0.5, + /*null_probability=*/0.01)); + // use arr inverted null bits as mask, so expect to replace all null values... + auto mask_data = std::make_shared(boolean(), arr->length(), 0); + mask_data->null_count = 0; + mask_data->buffers.resize(2); + mask_data->buffers[0] = nullptr; + mask_data->buffers[1] = *AllocateEmptyBitmap(arr->length()); + InvertBitmap(arr->data()->buffers[0]->data(), arr->offset(), arr->length(), + mask_data->buffers[1]->mutable_data(), mask_data->offset); + auto expected_data = arr->data()->Copy(); + expected_data->null_count = 0; + expected_data->buffers[0] = nullptr; + expected_data->buffers[1] = *AllocateEmptyBitmap(arr->length()); + uint8_t* out_data = expected_data->buffers[1]->mutable_data(); + for (int64_t i = 0 ; i < arr->length() ; ++i) { + if (arr->IsValid(i)) { + BitUtil::SetBitTo(out_data, i, arr->Value(i)); + } else { + BitUtil::SetBitTo(out_data, i, scalar1->value); + } + } + CheckReplace(*arr, BooleanArray(mask_data), Datum(scalar1), BooleanArray(expected_data)); +} + +TEST_F(TestReplaceKernel, ReplaceTimestamp) { + auto time32_type = time32(TimeUnit::SECOND); + auto time64_type = time64(TimeUnit::NANO); + auto scalar1 = std::make_shared(5, time32_type); + auto scalar2 = std::make_shared(6, time64_type); + // No replacement + CheckReplace(time32_type, + "[2, 1, 6, 9]", + "[false, false, false, false]", + Datum(scalar1), + "[2, 1, 6, 9]"); + CheckReplace(time64_type, + "[2, 1, 6, 9]", + "[false, false, false, false]", + Datum(scalar2), + "[2, 1, 6, 9]"); + // Some replacements + CheckReplace(time32_type, + "[2, 1, 6, 9]", + "[true, false, true, false]", + Datum(scalar1), + "[5, 1, 5, 9]"); + CheckReplace(time64_type, + "[2, 1, 6, 9]", + "[false, true, false, true]", + Datum(scalar2), + "[2, 6, 6, 6]"); +} + +TEST_F(TestReplaceKernel, ReplaceString) { + auto type = large_utf8(); + auto scalar = std::make_shared("arrow"); + // No replacement + CheckReplace(type, + R"(["foo", "bar"])", + "[false, false]", + Datum(scalar), + R"(["foo", "bar"])"); + // Some replacements + CheckReplace(type, + R"(["foo", "bar"])", + "[true, false]", + Datum(scalar), + R"(["arrow", "bar"])"); +} + +} // namespace compute +} // namespace arrow \ No newline at end of file diff --git a/cpp/src/arrow/compute/registry.cc b/cpp/src/arrow/compute/registry.cc index b1e0d48ccdc..bb9d0ce0e37 100644 --- a/cpp/src/arrow/compute/registry.cc +++ b/cpp/src/arrow/compute/registry.cc @@ -125,6 +125,7 @@ static std::unique_ptr CreateBuiltInRegistry() { RegisterScalarStringAscii(registry.get()); RegisterScalarValidity(registry.get()); RegisterScalarFillNull(registry.get()); + RegisterScalarReplace(registry.get()); // Aggregate functions RegisterScalarAggregateBasic(registry.get()); diff --git a/cpp/src/arrow/compute/registry_internal.h b/cpp/src/arrow/compute/registry_internal.h index 4e39eeb8204..ca85624e060 100644 --- a/cpp/src/arrow/compute/registry_internal.h +++ b/cpp/src/arrow/compute/registry_internal.h @@ -34,6 +34,7 @@ void RegisterScalarSetLookup(FunctionRegistry* registry); void RegisterScalarStringAscii(FunctionRegistry* registry); void RegisterScalarValidity(FunctionRegistry* registry); void RegisterScalarFillNull(FunctionRegistry* registry); +void RegisterScalarReplace(FunctionRegistry* registry); // Vector functions void RegisterVectorHash(FunctionRegistry* registry); From c9789adff70eee686a453508a3db1c23fbb12925 Mon Sep 17 00:00:00 2001 From: Bruno LE HYARIC Date: Tue, 29 Dec 2020 01:20:06 +0100 Subject: [PATCH 2/8] ARROW-11044: [C++] Add "replace" kernel Run clang-format to fix indentation and pass CI lint check. --- cpp/src/arrow/compute/api_scalar.cc | 3 +- .../arrow/compute/kernels/scalar_replace.cc | 321 +++++++++--------- .../compute/kernels/scalar_replace_test.cc | 273 +++++++-------- 3 files changed, 287 insertions(+), 310 deletions(-) diff --git a/cpp/src/arrow/compute/api_scalar.cc b/cpp/src/arrow/compute/api_scalar.cc index a4be2eb0c7d..3034ac0b7b8 100644 --- a/cpp/src/arrow/compute/api_scalar.cc +++ b/cpp/src/arrow/compute/api_scalar.cc @@ -140,7 +140,8 @@ Result FillNull(const Datum& values, const Datum& fill_value, ExecContext return CallFunction("fill_null", {values, fill_value}, ctx); } -Result Replace(const Datum& values, const Datum& mask, const Datum& replacement, ExecContext* ctx) { +Result Replace(const Datum& values, const Datum& mask, const Datum& replacement, + ExecContext* ctx) { return CallFunction("replace", {values, mask, replacement}, ctx); } diff --git a/cpp/src/arrow/compute/kernels/scalar_replace.cc b/cpp/src/arrow/compute/kernels/scalar_replace.cc index a4200ad4b9f..f2036d359ba 100644 --- a/cpp/src/arrow/compute/kernels/scalar_replace.cc +++ b/cpp/src/arrow/compute/kernels/scalar_replace.cc @@ -37,203 +37,210 @@ struct ReplaceFunctor {}; template struct ReplaceFunctor::value>> { - using T = typename TypeTraits::CType; - - static void Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { - const ArrayData& data = *batch[0].array(); - const ArrayData& mask = *batch[1].array(); - const Scalar& replacement = *batch[2].scalar(); - ArrayData* output = out->mutable_array(); - - if (replacement.is_valid) { - KERNEL_ASSIGN_OR_RAISE(std::shared_ptr out_buf, ctx, - ctx->Allocate(data.length * sizeof(T))); - T value = UnboxScalar::Unbox(replacement); - const uint8_t* to_replace = mask.buffers[1]->data(); - const T* in_values = data.GetValues(1); - T* out_values = reinterpret_cast(out_buf->mutable_data()); - int64_t offset = data.offset; - BitBlockCounter bit_counter(to_replace, data.offset, data.length); - while (offset < data.offset + data.length) { - BitBlockCount block = bit_counter.NextWord(); - if (block.NoneSet()) { - std::memcpy(out_values, in_values, block.length * sizeof(T)); - } else if (block.AllSet()) { - std::fill(out_values, out_values + block.length, value); - } else { - for (int64_t i = 0; i < block.length; ++i) { - out_values[i] = BitUtil::GetBit(to_replace, offset + i) ? value : in_values[i]; - } - } - offset += block.length; - out_values += block.length; - in_values += block.length; - } - output->buffers[1] = out_buf; + using T = typename TypeTraits::CType; + + static void Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + const ArrayData& data = *batch[0].array(); + const ArrayData& mask = *batch[1].array(); + const Scalar& replacement = *batch[2].scalar(); + ArrayData* output = out->mutable_array(); + + if (replacement.is_valid) { + KERNEL_ASSIGN_OR_RAISE(std::shared_ptr out_buf, ctx, + ctx->Allocate(data.length * sizeof(T))); + T value = UnboxScalar::Unbox(replacement); + const uint8_t* to_replace = mask.buffers[1]->data(); + const T* in_values = data.GetValues(1); + T* out_values = reinterpret_cast(out_buf->mutable_data()); + int64_t offset = data.offset; + BitBlockCounter bit_counter(to_replace, data.offset, data.length); + while (offset < data.offset + data.length) { + BitBlockCount block = bit_counter.NextWord(); + if (block.NoneSet()) { + std::memcpy(out_values, in_values, block.length * sizeof(T)); + } else if (block.AllSet()) { + std::fill(out_values, out_values + block.length, value); } else { - *output = data; + for (int64_t i = 0; i < block.length; ++i) { + out_values[i] = + BitUtil::GetBit(to_replace, offset + i) ? value : in_values[i]; + } } + offset += block.length; + out_values += block.length; + in_values += block.length; + } + output->buffers[1] = out_buf; + } else { + *output = data; } + } }; // Boolean input template struct ReplaceFunctor::value>> { - static void Exec(KernelContext* ctx, const ExecBatch batch, Datum* out) { - const ArrayData& data = *batch[0].array(); - const ArrayData& mask = *batch[1].array(); - const Scalar& replacement = *batch[2].scalar(); - ArrayData* output = out->mutable_array(); - - bool value = UnboxScalar::Unbox(replacement); - if (replacement.is_valid) { - KERNEL_ASSIGN_OR_RAISE(std::shared_ptr out_buf, ctx, - ctx->AllocateBitmap(data.length)); - - const uint8_t* to_replace = mask.buffers[1]->data(); - const uint8_t* data_bitmap = data.buffers[1]->data(); - uint8_t* out_bitmap = out_buf->mutable_data(); - - int64_t data_offset = data.offset; - BitBlockCounter bit_counter(to_replace, data.offset, data.length); - - int64_t out_offset = 0; - while (out_offset < data.length) { - BitBlockCount block = bit_counter.NextWord(); - if (block.NoneSet()) { - ::arrow::internal::CopyBitmap(data_bitmap, data_offset, block.length, - out_bitmap, out_offset); - } else if (block.AllSet()) { - BitUtil::SetBitsTo(out_bitmap, out_offset, block.length, value); - } else { - for (int64_t i = 0 ; i < block.length ; ++i) { - BitUtil::SetBitTo(out_bitmap, - out_offset + i, - BitUtil::GetBit(to_replace, data_offset + i) ? value : BitUtil::GetBit(data_bitmap, data_offset + i)); - } - } - data_offset += block.length; - out_offset += block.length; - } - output->buffers[1] = out_buf; + static void Exec(KernelContext* ctx, const ExecBatch batch, Datum* out) { + const ArrayData& data = *batch[0].array(); + const ArrayData& mask = *batch[1].array(); + const Scalar& replacement = *batch[2].scalar(); + ArrayData* output = out->mutable_array(); + + bool value = UnboxScalar::Unbox(replacement); + if (replacement.is_valid) { + KERNEL_ASSIGN_OR_RAISE(std::shared_ptr out_buf, ctx, + ctx->AllocateBitmap(data.length)); + + const uint8_t* to_replace = mask.buffers[1]->data(); + const uint8_t* data_bitmap = data.buffers[1]->data(); + uint8_t* out_bitmap = out_buf->mutable_data(); + + int64_t data_offset = data.offset; + BitBlockCounter bit_counter(to_replace, data.offset, data.length); + + int64_t out_offset = 0; + while (out_offset < data.length) { + BitBlockCount block = bit_counter.NextWord(); + if (block.NoneSet()) { + ::arrow::internal::CopyBitmap(data_bitmap, data_offset, block.length, + out_bitmap, out_offset); + } else if (block.AllSet()) { + BitUtil::SetBitsTo(out_bitmap, out_offset, block.length, value); } else { - *output = data; + for (int64_t i = 0; i < block.length; ++i) { + BitUtil::SetBitTo(out_bitmap, out_offset + i, + BitUtil::GetBit(to_replace, data_offset + i) + ? value + : BitUtil::GetBit(data_bitmap, data_offset + i)); + } } + data_offset += block.length; + out_offset += block.length; + } + output->buffers[1] = out_buf; + } else { + *output = data; } + } }; // Null input template struct ReplaceFunctor::value>> { - static void Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { - // Nothing preallocated, so we assign into the output - *out->mutable_array() = *batch[0].array(); - } + static void Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + // Nothing preallocated, so we assign into the output + *out->mutable_array() = *batch[0].array(); + } }; // Binary-like template struct ReplaceFunctor::value>> { - using BuilderType = typename TypeTraits::BuilderType; - using OffsetType = typename Type::offset_type; - - static void Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { - const ArrayData& input = *batch[0].array(); - const ArrayData& mask = *batch[1].array(); - const auto& replacement_scalar = checked_cast(*batch[2].scalar()); - util::string_view replacement(*replacement_scalar.value); - ArrayData* output = out->mutable_array(); - - const uint8_t* to_replace = mask.buffers[1]->data(); - uint64_t replace_count = 0; - { - BitBlockCounter bit_counter(to_replace, input.offset, input.length); - int64_t out_offset = 0; - while (out_offset < input.length) { - BitBlockCount block = bit_counter.NextWord(); - replace_count += block.popcount; - out_offset += block.length; - } - } + using BuilderType = typename TypeTraits::BuilderType; + using OffsetType = typename Type::offset_type; + + static void Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + const ArrayData& input = *batch[0].array(); + const ArrayData& mask = *batch[1].array(); + const auto& replacement_scalar = + checked_cast(*batch[2].scalar()); + util::string_view replacement(*replacement_scalar.value); + ArrayData* output = out->mutable_array(); + + const uint8_t* to_replace = mask.buffers[1]->data(); + uint64_t replace_count = 0; + { + BitBlockCounter bit_counter(to_replace, input.offset, input.length); + int64_t out_offset = 0; + while (out_offset < input.length) { + BitBlockCount block = bit_counter.NextWord(); + replace_count += block.popcount; + out_offset += block.length; + } + } - if (replace_count > 0 && replacement_scalar.is_valid) { - auto input_offsets = input.GetValues(1); - auto input_values = input.GetValues(2, input.offset); - BuilderType builder(input.type, ctx->memory_pool()); - KERNEL_RETURN_IF_ERROR(ctx, builder.ReserveData(input.buffers[2]->size() - + replacement.length() * replace_count)); - KERNEL_RETURN_IF_ERROR(ctx, builder.Resize(input.length)); - - BitBlockCounter bit_counter(to_replace, input.offset, input.length); - int64_t input_offset = 0; - while (input_offset < input.length) { - BitBlockCount block = bit_counter.NextWord(); - for (int64_t i = 0 ; i < block.length ; ++i) { - if (BitUtil::GetBit(to_replace, input_offset + i)) { - builder.UnsafeAppend(replacement); - } else { - auto current_offset = input_offsets[input_offset+i]; - auto next_offset = input_offsets[input_offset+i+1]; - auto string_value = util::string_view(input_values + current_offset, - next_offset - current_offset); - builder.UnsafeAppend(string_value); - } - } - input_offset += block.length; - } - std::shared_ptr string_array; - KERNEL_RETURN_IF_ERROR(ctx, builder.Finish(&string_array)); - *output = *string_array->data(); - // The builder does not match the logical type, due to - // GenerateTypeAgnosticVarBinaryBase - output->type = input.type; - } else { - *output = input; + if (replace_count > 0 && replacement_scalar.is_valid) { + auto input_offsets = input.GetValues(1); + auto input_values = input.GetValues(2, input.offset); + BuilderType builder(input.type, ctx->memory_pool()); + KERNEL_RETURN_IF_ERROR(ctx, + builder.ReserveData(input.buffers[2]->size() + + replacement.length() * replace_count)); + KERNEL_RETURN_IF_ERROR(ctx, builder.Resize(input.length)); + + BitBlockCounter bit_counter(to_replace, input.offset, input.length); + int64_t input_offset = 0; + while (input_offset < input.length) { + BitBlockCount block = bit_counter.NextWord(); + for (int64_t i = 0; i < block.length; ++i) { + if (BitUtil::GetBit(to_replace, input_offset + i)) { + builder.UnsafeAppend(replacement); + } else { + auto current_offset = input_offsets[input_offset + i]; + auto next_offset = input_offsets[input_offset + i + 1]; + auto string_value = util::string_view(input_values + current_offset, + next_offset - current_offset); + builder.UnsafeAppend(string_value); + } } + input_offset += block.length; + } + std::shared_ptr string_array; + KERNEL_RETURN_IF_ERROR(ctx, builder.Finish(&string_array)); + *output = *string_array->data(); + // The builder does not match the logical type, due to + // GenerateTypeAgnosticVarBinaryBase + output->type = input.type; + } else { + *output = input; } + } }; void AddBasicReplaceKernels(ScalarKernel kernel, ScalarFunction* func) { - auto AddKernels = [&](const std::vector>& types) { - for (const std::shared_ptr& ty : types) { - kernel.signature = KernelSignature::Make({InputType::Array(ty), InputType::Array(boolean()), InputType::Scalar(ty)}, ty); - kernel.exec = GenerateTypeAgnosticPrimitive(*ty); - DCHECK_OK(func->AddKernel(kernel)); - } - }; - AddKernels(NumericTypes()); - AddKernels(TemporalTypes()); - AddKernels({boolean(), null()}); + auto AddKernels = [&](const std::vector>& types) { + for (const std::shared_ptr& ty : types) { + kernel.signature = KernelSignature::Make( + {InputType::Array(ty), InputType::Array(boolean()), InputType::Scalar(ty)}, ty); + kernel.exec = GenerateTypeAgnosticPrimitive(*ty); + DCHECK_OK(func->AddKernel(kernel)); + } + }; + AddKernels(NumericTypes()); + AddKernels(TemporalTypes()); + AddKernels({boolean(), null()}); } void AddBinaryReplaceKernels(ScalarKernel kernel, ScalarFunction* func) { - for (const std::shared_ptr& ty : BaseBinaryTypes()) { - kernel.signature = KernelSignature::Make({InputType::Array(ty), InputType::Array(boolean()), InputType::Scalar(ty)}, ty); - kernel.exec = GenerateTypeAgnosticVarBinaryBase(*ty); - DCHECK_OK(func->AddKernel(kernel)); - } + for (const std::shared_ptr& ty : BaseBinaryTypes()) { + kernel.signature = KernelSignature::Make( + {InputType::Array(ty), InputType::Array(boolean()), InputType::Scalar(ty)}, ty); + kernel.exec = GenerateTypeAgnosticVarBinaryBase(*ty); + DCHECK_OK(func->AddKernel(kernel)); + } } const FunctionDoc replace_doc{ - "Replace selected elements", - ("`replacement` must be a scalar of the same type as `values`.\n" - "Each unmasked value in `values` is emitted as-is.\n" - "Each masked value in `values` is replaced with `replacement`."), - {"values", "mask", "replacement"}}; + "Replace selected elements", + ("`replacement` must be a scalar of the same type as `values`.\n" + "Each unmasked value in `values` is emitted as-is.\n" + "Each masked value in `values` is replaced with `replacement`."), + {"values", "mask", "replacement"}}; } // namespace void RegisterScalarReplace(FunctionRegistry* registry) { - ScalarKernel replace_base; - replace_base.null_handling = NullHandling::COMPUTED_NO_PREALLOCATE; - replace_base.mem_allocation = MemAllocation::NO_PREALLOCATE; - auto replace = std::make_shared("replace", Arity::Ternary(), &replace_doc); - AddBasicReplaceKernels(replace_base, replace.get()); - AddBinaryReplaceKernels(replace_base, replace.get()); - DCHECK_OK(registry->AddFunction(replace)); + ScalarKernel replace_base; + replace_base.null_handling = NullHandling::COMPUTED_NO_PREALLOCATE; + replace_base.mem_allocation = MemAllocation::NO_PREALLOCATE; + auto replace = + std::make_shared("replace", Arity::Ternary(), &replace_doc); + AddBasicReplaceKernels(replace_base, replace.get()); + AddBinaryReplaceKernels(replace_base, replace.get()); + DCHECK_OK(registry->AddFunction(replace)); } } // namespace internal diff --git a/cpp/src/arrow/compute/kernels/scalar_replace_test.cc b/cpp/src/arrow/compute/kernels/scalar_replace_test.cc index 491277af37a..06dc8360ace 100644 --- a/cpp/src/arrow/compute/kernels/scalar_replace_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_replace_test.cc @@ -31,34 +31,28 @@ using internal::InvertBitmap; namespace compute { -void CheckReplace(const Array& input, - const Array& mask, - const Datum& replacement, +void CheckReplace(const Array& input, const Array& mask, const Datum& replacement, const Array& expected) { - auto Check = [&](const Array& input, const Array& mask, const Array& expected) { - ASSERT_OK_AND_ASSIGN(Datum datum_out, Replace(input, mask, replacement)); - std::shared_ptr result = datum_out.make_array(); - ASSERT_OK(result->ValidateFull()); - AssertArraysEqual(expected, *result, /*verbose=*/true); - }; - - Check(input, mask, expected); - if (input.length() > 0) { - Check(*input.Slice(1), - *mask.Slice(1), - *expected.Slice(1)); - } + auto Check = [&](const Array& input, const Array& mask, const Array& expected) { + ASSERT_OK_AND_ASSIGN(Datum datum_out, Replace(input, mask, replacement)); + std::shared_ptr result = datum_out.make_array(); + ASSERT_OK(result->ValidateFull()); + AssertArraysEqual(expected, *result, /*verbose=*/true); + }; + + Check(input, mask, expected); + if (input.length() > 0) { + Check(*input.Slice(1), *mask.Slice(1), *expected.Slice(1)); + } } -void CheckReplace(const std::shared_ptr& type, - const std::string& in_values, - const std::string& in_mask, - const Datum& replacement, +void CheckReplace(const std::shared_ptr& type, const std::string& in_values, + const std::string& in_mask, const Datum& replacement, const std::string& out_values) { - std::shared_ptr input = ArrayFromJSON(type, in_values); - std::shared_ptr mask = ArrayFromJSON(boolean(), in_mask); - std::shared_ptr expected = ArrayFromJSON(type, out_values); - CheckReplace(*input, *mask, replacement, *expected); + std::shared_ptr input = ArrayFromJSON(type, in_values); + std::shared_ptr mask = ArrayFromJSON(boolean(), in_mask); + std::shared_ptr expected = ArrayFromJSON(type, out_values); + CheckReplace(*input, *mask, replacement, *expected); } class TestReplaceKernel : public ::testing::Test {}; @@ -72,153 +66,128 @@ typedef ::testing::Types(3); - scalar->is_valid = false; - CheckReplace(int8(), - "[2, 4, 7, 9]", - "[true, false, false, false]", - Datum(scalar), - "[2, 4, 7, 9]"); + auto scalar = std::make_shared(3); + scalar->is_valid = false; + CheckReplace(int8(), "[2, 4, 7, 9]", "[true, false, false, false]", Datum(scalar), + "[2, 4, 7, 9]"); } TYPED_TEST_SUITE(TestReplacePrimitive, PrimitiveTypes); TYPED_TEST(TestReplacePrimitive, Replace) { - using T = typename TypeParam::c_type; - using ArrayType = typename TypeTraits::ArrayType; - using ScalarType = typename TypeTraits::ScalarType; - auto type = TypeTraits::type_singleton(); - auto scalar = std::make_shared(static_cast(42)); - - // No replacement - CheckReplace(type, "[2, 4, 7, 9]", "[false, false, false, false]", Datum(scalar), "[2, 4, 7, 9]"); - // Some replacements - CheckReplace(type, "[2, 4, 7, 9]", "[true, false, true, false]", Datum(scalar), "[42, 4, 42, 9]"); - // Empty Array - CheckReplace(type, "[]", "[]", Datum(scalar), "[]"); - - random::RandomArrayGenerator rand(/*seed=*/0); - auto arr = std::static_pointer_cast(rand.ArrayOf(type, 1000, /*null_probability=*/0.01)); - // use arr inverted null bits as mask, so expect to replace all null values... - auto mask_data = std::make_shared(boolean(), arr->length(), 0); - mask_data->null_count = 0; - mask_data->buffers.resize(2); - mask_data->buffers[0] = nullptr; - mask_data->buffers[1] = *AllocateEmptyBitmap(arr->length()); - InvertBitmap(arr->data()->buffers[0]->data(), arr->offset(), arr->length(), - mask_data->buffers[1]->mutable_data(), mask_data->offset); - std::shared_ptr expected_data = arr->data()->Copy(); - expected_data->null_count = 0; - expected_data->buffers[0] = nullptr; - expected_data->buffers[1] = *AllocateBuffer(arr->length() * sizeof(T)); - T* out_data = expected_data->GetMutableValues(1); - for (int64_t i = 0 ; i < arr->length() ; ++i) { - if (arr->IsValid(i)) { - out_data[i] = arr->Value(i); - } else { - out_data[i] = scalar->value; - } + using T = typename TypeParam::c_type; + using ArrayType = typename TypeTraits::ArrayType; + using ScalarType = typename TypeTraits::ScalarType; + auto type = TypeTraits::type_singleton(); + auto scalar = std::make_shared(static_cast(42)); + + // No replacement + CheckReplace(type, "[2, 4, 7, 9]", "[false, false, false, false]", Datum(scalar), + "[2, 4, 7, 9]"); + // Some replacements + CheckReplace(type, "[2, 4, 7, 9]", "[true, false, true, false]", Datum(scalar), + "[42, 4, 42, 9]"); + // Empty Array + CheckReplace(type, "[]", "[]", Datum(scalar), "[]"); + + random::RandomArrayGenerator rand(/*seed=*/0); + auto arr = std::static_pointer_cast( + rand.ArrayOf(type, 1000, /*null_probability=*/0.01)); + // use arr inverted null bits as mask, so expect to replace all null values... + auto mask_data = std::make_shared(boolean(), arr->length(), 0); + mask_data->null_count = 0; + mask_data->buffers.resize(2); + mask_data->buffers[0] = nullptr; + mask_data->buffers[1] = *AllocateEmptyBitmap(arr->length()); + InvertBitmap(arr->data()->buffers[0]->data(), arr->offset(), arr->length(), + mask_data->buffers[1]->mutable_data(), mask_data->offset); + std::shared_ptr expected_data = arr->data()->Copy(); + expected_data->null_count = 0; + expected_data->buffers[0] = nullptr; + expected_data->buffers[1] = *AllocateBuffer(arr->length() * sizeof(T)); + T* out_data = expected_data->GetMutableValues(1); + for (int64_t i = 0; i < arr->length(); ++i) { + if (arr->IsValid(i)) { + out_data[i] = arr->Value(i); + } else { + out_data[i] = scalar->value; } - CheckReplace(*arr, BooleanArray(mask_data), Datum(scalar), ArrayType(expected_data)); + } + CheckReplace(*arr, BooleanArray(mask_data), Datum(scalar), ArrayType(expected_data)); } TEST_F(TestReplaceKernel, ReplaceNull) { - auto datum = Datum(std::make_shared()); - CheckReplace(null(), - "[null, null, null, null]", - "[true, true, true, true]", - /*replacement=*/datum, - "[null, null, null, null]"); + auto datum = Datum(std::make_shared()); + CheckReplace(null(), "[null, null, null, null]", "[true, true, true, true]", + /*replacement=*/datum, "[null, null, null, null]"); } TEST_F(TestReplaceKernel, ReplaceBoolean) { - auto scalar1 = std::make_shared(false); - auto scalar2 = std::make_shared(true); - - // No replacement - CheckReplace(boolean(), - "[true, false, true, false]", - "[false, false, false, false]", - Datum(scalar1), - "[true, false, true, false]"); - // Some replacements - CheckReplace(boolean(), - "[true, false, true, false]", - "[true, false, true, false]", - Datum(scalar1), - "[false, false, false, false]"); - - random::RandomArrayGenerator rand(/*seed=*/0); - auto arr = std::static_pointer_cast(rand.Boolean(1000, - /*true_probability=*/0.5, - /*null_probability=*/0.01)); - // use arr inverted null bits as mask, so expect to replace all null values... - auto mask_data = std::make_shared(boolean(), arr->length(), 0); - mask_data->null_count = 0; - mask_data->buffers.resize(2); - mask_data->buffers[0] = nullptr; - mask_data->buffers[1] = *AllocateEmptyBitmap(arr->length()); - InvertBitmap(arr->data()->buffers[0]->data(), arr->offset(), arr->length(), - mask_data->buffers[1]->mutable_data(), mask_data->offset); - auto expected_data = arr->data()->Copy(); - expected_data->null_count = 0; - expected_data->buffers[0] = nullptr; - expected_data->buffers[1] = *AllocateEmptyBitmap(arr->length()); - uint8_t* out_data = expected_data->buffers[1]->mutable_data(); - for (int64_t i = 0 ; i < arr->length() ; ++i) { - if (arr->IsValid(i)) { - BitUtil::SetBitTo(out_data, i, arr->Value(i)); - } else { - BitUtil::SetBitTo(out_data, i, scalar1->value); - } + auto scalar1 = std::make_shared(false); + auto scalar2 = std::make_shared(true); + + // No replacement + CheckReplace(boolean(), "[true, false, true, false]", "[false, false, false, false]", + Datum(scalar1), "[true, false, true, false]"); + // Some replacements + CheckReplace(boolean(), "[true, false, true, false]", "[true, false, true, false]", + Datum(scalar1), "[false, false, false, false]"); + + random::RandomArrayGenerator rand(/*seed=*/0); + auto arr = + std::static_pointer_cast(rand.Boolean(1000, + /*true_probability=*/0.5, + /*null_probability=*/0.01)); + // use arr inverted null bits as mask, so expect to replace all null values... + auto mask_data = std::make_shared(boolean(), arr->length(), 0); + mask_data->null_count = 0; + mask_data->buffers.resize(2); + mask_data->buffers[0] = nullptr; + mask_data->buffers[1] = *AllocateEmptyBitmap(arr->length()); + InvertBitmap(arr->data()->buffers[0]->data(), arr->offset(), arr->length(), + mask_data->buffers[1]->mutable_data(), mask_data->offset); + auto expected_data = arr->data()->Copy(); + expected_data->null_count = 0; + expected_data->buffers[0] = nullptr; + expected_data->buffers[1] = *AllocateEmptyBitmap(arr->length()); + uint8_t* out_data = expected_data->buffers[1]->mutable_data(); + for (int64_t i = 0; i < arr->length(); ++i) { + if (arr->IsValid(i)) { + BitUtil::SetBitTo(out_data, i, arr->Value(i)); + } else { + BitUtil::SetBitTo(out_data, i, scalar1->value); } - CheckReplace(*arr, BooleanArray(mask_data), Datum(scalar1), BooleanArray(expected_data)); + } + CheckReplace(*arr, BooleanArray(mask_data), Datum(scalar1), + BooleanArray(expected_data)); } TEST_F(TestReplaceKernel, ReplaceTimestamp) { - auto time32_type = time32(TimeUnit::SECOND); - auto time64_type = time64(TimeUnit::NANO); - auto scalar1 = std::make_shared(5, time32_type); - auto scalar2 = std::make_shared(6, time64_type); - // No replacement - CheckReplace(time32_type, - "[2, 1, 6, 9]", - "[false, false, false, false]", - Datum(scalar1), - "[2, 1, 6, 9]"); - CheckReplace(time64_type, - "[2, 1, 6, 9]", - "[false, false, false, false]", - Datum(scalar2), - "[2, 1, 6, 9]"); - // Some replacements - CheckReplace(time32_type, - "[2, 1, 6, 9]", - "[true, false, true, false]", - Datum(scalar1), - "[5, 1, 5, 9]"); - CheckReplace(time64_type, - "[2, 1, 6, 9]", - "[false, true, false, true]", - Datum(scalar2), - "[2, 6, 6, 6]"); + auto time32_type = time32(TimeUnit::SECOND); + auto time64_type = time64(TimeUnit::NANO); + auto scalar1 = std::make_shared(5, time32_type); + auto scalar2 = std::make_shared(6, time64_type); + // No replacement + CheckReplace(time32_type, "[2, 1, 6, 9]", "[false, false, false, false]", + Datum(scalar1), "[2, 1, 6, 9]"); + CheckReplace(time64_type, "[2, 1, 6, 9]", "[false, false, false, false]", + Datum(scalar2), "[2, 1, 6, 9]"); + // Some replacements + CheckReplace(time32_type, "[2, 1, 6, 9]", "[true, false, true, false]", Datum(scalar1), + "[5, 1, 5, 9]"); + CheckReplace(time64_type, "[2, 1, 6, 9]", "[false, true, false, true]", Datum(scalar2), + "[2, 6, 6, 6]"); } TEST_F(TestReplaceKernel, ReplaceString) { - auto type = large_utf8(); - auto scalar = std::make_shared("arrow"); - // No replacement - CheckReplace(type, - R"(["foo", "bar"])", - "[false, false]", - Datum(scalar), - R"(["foo", "bar"])"); - // Some replacements - CheckReplace(type, - R"(["foo", "bar"])", - "[true, false]", - Datum(scalar), - R"(["arrow", "bar"])"); + auto type = large_utf8(); + auto scalar = std::make_shared("arrow"); + // No replacement + CheckReplace(type, R"(["foo", "bar"])", "[false, false]", Datum(scalar), + R"(["foo", "bar"])"); + // Some replacements + CheckReplace(type, R"(["foo", "bar"])", "[true, false]", Datum(scalar), + R"(["arrow", "bar"])"); } } // namespace compute From 83192dc43bb2f7f03c228b271be8ebfbf7705802 Mon Sep 17 00:00:00 2001 From: Bruno LE HYARIC Date: Tue, 29 Dec 2020 01:53:42 +0100 Subject: [PATCH 3/8] ARROW-11044: [C++] Add "replace" kernel Fix newline characters at end of file for CI lint check. --- cpp/src/arrow/compute/kernels/scalar_replace.cc | 2 +- cpp/src/arrow/compute/kernels/scalar_replace_test.cc | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_replace.cc b/cpp/src/arrow/compute/kernels/scalar_replace.cc index f2036d359ba..d7f8f415c8f 100644 --- a/cpp/src/arrow/compute/kernels/scalar_replace.cc +++ b/cpp/src/arrow/compute/kernels/scalar_replace.cc @@ -245,4 +245,4 @@ void RegisterScalarReplace(FunctionRegistry* registry) { } // namespace internal } // namespace compute -} // namespace arrow \ No newline at end of file +} // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/scalar_replace_test.cc b/cpp/src/arrow/compute/kernels/scalar_replace_test.cc index 06dc8360ace..1ae5ba54d52 100644 --- a/cpp/src/arrow/compute/kernels/scalar_replace_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_replace_test.cc @@ -191,4 +191,4 @@ TEST_F(TestReplaceKernel, ReplaceString) { } } // namespace compute -} // namespace arrow \ No newline at end of file +} // namespace arrow From 28d53278eb3cbb91e9d7b4a156139f65070dacf3 Mon Sep 17 00:00:00 2001 From: Bruno LE HYARIC Date: Tue, 29 Dec 2020 21:01:15 +0100 Subject: [PATCH 4/8] ARROW-11044: [C++] Add "replace" kernel Add null handling to "replace" kernel: * Add some more tests with nulls (start with the boolean kernel) --- cpp/src/arrow/compute/kernels/scalar_replace.cc | 16 ++++++++++++++++ .../compute/kernels/scalar_replace_test.cc | 17 ++++++++++++++--- 2 files changed, 30 insertions(+), 3 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_replace.cc b/cpp/src/arrow/compute/kernels/scalar_replace.cc index d7f8f415c8f..eec3dca74b8 100644 --- a/cpp/src/arrow/compute/kernels/scalar_replace.cc +++ b/cpp/src/arrow/compute/kernels/scalar_replace.cc @@ -45,6 +45,10 @@ struct ReplaceFunctor::value>> { const Scalar& replacement = *batch[2].scalar(); ArrayData* output = out->mutable_array(); + // Ensure the kernel is configured properly to have no validity bitmap / + // null count 0 unless we explicitly propagate it below. + DCHECK(output->buffers[0] == nullptr); + if (replacement.is_valid) { KERNEL_ASSIGN_OR_RAISE(std::shared_ptr out_buf, ctx, ctx->Allocate(data.length * sizeof(T))); @@ -87,8 +91,16 @@ struct ReplaceFunctor::value>> { const Scalar& replacement = *batch[2].scalar(); ArrayData* output = out->mutable_array(); + // Ensure the kernel is configured properly to have no validity bitmap / + // null count 0 unless we explicitly propagate it below. + DCHECK(output->buffers[0] == nullptr); + bool value = UnboxScalar::Unbox(replacement); if (replacement.is_valid) { + + // TODO: Allocate bitmap and compute data.buffers[0] | (mask.buffers[0] & mask.buffers[1]) + // Then factor the code in a function to reuse in all ReplaceFunctors... + KERNEL_ASSIGN_OR_RAISE(std::shared_ptr out_buf, ctx, ctx->AllocateBitmap(data.length)); @@ -150,6 +162,10 @@ struct ReplaceFunctor::value>> { util::string_view replacement(*replacement_scalar.value); ArrayData* output = out->mutable_array(); + // Ensure the kernel is configured properly to have no validity bitmap / + // null count 0 unless we explicitly propagate it below. + DCHECK(output->buffers[0] == nullptr); + const uint8_t* to_replace = mask.buffers[1]->data(); uint64_t replace_count = 0; { diff --git a/cpp/src/arrow/compute/kernels/scalar_replace_test.cc b/cpp/src/arrow/compute/kernels/scalar_replace_test.cc index 1ae5ba54d52..92c48124a26 100644 --- a/cpp/src/arrow/compute/kernels/scalar_replace_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_replace_test.cc @@ -117,9 +117,20 @@ TYPED_TEST(TestReplacePrimitive, Replace) { } TEST_F(TestReplaceKernel, ReplaceNull) { - auto datum = Datum(std::make_shared()); - CheckReplace(null(), "[null, null, null, null]", "[true, true, true, true]", - /*replacement=*/datum, "[null, null, null, null]"); + auto null_scalar = Datum(MakeNullScalar(boolean())); + auto true_scalar = Datum(MakeScalar(true)); + // Replace with invalid null value + CheckReplace(boolean(), "[null, null, null, null]", "[true, true, true, true]", + /*replacement=*/null_scalar, "[null, null, null, null]"); + // No replacement + CheckReplace(boolean(), "[null, null, null, null]", "[false, false, false, false]", + /*replacement=*/true_scalar, "[null, null, null, null]"); + // Some replacements + CheckReplace(boolean(), "[null, null, null, null]", "[true, false, true, false]", + /*replacement=*/true_scalar, "[true, null, true, null]"); + // Replace all + CheckReplace(boolean(), "[null, null, null, null]", "[true, true, true, true]", + /*replacement=*/true_scalar, "[true, true, true, true]"); } TEST_F(TestReplaceKernel, ReplaceBoolean) { From c27a696465f1e5c6e6e047eb12cf71e8b4c19a1c Mon Sep 17 00:00:00 2001 From: Bruno LE HYARIC Date: Wed, 30 Dec 2020 03:14:34 +0100 Subject: [PATCH 5/8] ARROW-11044: [C++] Add "replace" kernel Add null handling to "replace" kernel: * Add null handling to the boolean kernel implementation. * Test pass but still need more tests. --- .../arrow/compute/kernels/scalar_replace.cc | 33 ++++++++++++++++++- 1 file changed, 32 insertions(+), 1 deletion(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_replace.cc b/cpp/src/arrow/compute/kernels/scalar_replace.cc index eec3dca74b8..aec7718a3b4 100644 --- a/cpp/src/arrow/compute/kernels/scalar_replace.cc +++ b/cpp/src/arrow/compute/kernels/scalar_replace.cc @@ -101,9 +101,40 @@ struct ReplaceFunctor::value>> { // TODO: Allocate bitmap and compute data.buffers[0] | (mask.buffers[0] & mask.buffers[1]) // Then factor the code in a function to reuse in all ReplaceFunctors... + if (data.MayHaveNulls()) { + KERNEL_ASSIGN_OR_RAISE(std::shared_ptr out_nulls, ctx, + ctx->AllocateBitmap(data.length)); + + if (mask.MayHaveNulls()) { + ::arrow::internal::BitmapAnd(mask.buffers[0]->data(), + mask.offset, + mask.buffers[1]->data(), + mask.offset, + mask.length, + output->offset, + out_nulls->mutable_data()); + ::arrow::internal::BitmapOr(data.buffers[0]->data(), + data.offset, + out_nulls->data(), + output->offset, + data.length, + output->offset, + out_nulls->mutable_data()); + } else { + ::arrow::internal::BitmapOr(data.buffers[0]->data(), + data.offset, + mask.buffers[1]->data(), + mask.offset, + mask.length, + output->offset, + out_nulls->mutable_data()); + } + + output->buffers[0] = out_nulls; + } + KERNEL_ASSIGN_OR_RAISE(std::shared_ptr out_buf, ctx, ctx->AllocateBitmap(data.length)); - const uint8_t* to_replace = mask.buffers[1]->data(); const uint8_t* data_bitmap = data.buffers[1]->data(); uint8_t* out_bitmap = out_buf->mutable_data(); From 509e67ae8e5d04e8ce00787735368b64b3c31c84 Mon Sep 17 00:00:00 2001 From: Bruno LE HYARIC Date: Thu, 31 Dec 2020 02:01:08 +0100 Subject: [PATCH 6/8] ARROW-11044: [C++] Add "replace" kernel Generalize null handling to "replace" kernel: * Add null handling to the Number, Timestamps, and String kernel implementations. * Add more tests regarding null handling logic (with nulls either in input values and mask). * Still need to fix a tricky bug in test of the String kernel implementation. --- .../arrow/compute/kernels/scalar_replace.cc | 113 +++++++++++------- .../compute/kernels/scalar_replace_test.cc | 67 +++++++---- 2 files changed, 110 insertions(+), 70 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_replace.cc b/cpp/src/arrow/compute/kernels/scalar_replace.cc index aec7718a3b4..fce27be4c3b 100644 --- a/cpp/src/arrow/compute/kernels/scalar_replace.cc +++ b/cpp/src/arrow/compute/kernels/scalar_replace.cc @@ -30,6 +30,41 @@ namespace internal { namespace { +void handle_nulls(KernelContext* ctx, const ArrayData& data, const ArrayData& mask, ArrayData* output) { + if (data.MayHaveNulls()) { + KERNEL_ASSIGN_OR_RAISE(std::shared_ptr out_nulls, ctx, + ctx->AllocateBitmap(data.length)); + + if (mask.MayHaveNulls()) { + ::arrow::internal::BitmapAnd(mask.buffers[0]->data(), + mask.offset, + mask.buffers[1]->data(), + mask.offset, + mask.length, + output->offset, + out_nulls->mutable_data()); + ::arrow::internal::BitmapOr(data.buffers[0]->data(), + data.offset, + out_nulls->data(), + output->offset, + data.length, + output->offset, + out_nulls->mutable_data()); + } else { + ::arrow::internal::BitmapOr(data.buffers[0]->data(), + data.offset, + mask.buffers[1]->data(), + mask.offset, + mask.length, + output->offset, + out_nulls->mutable_data()); + } + + if (::arrow::internal::CountSetBits(out_nulls->data(), output->offset, data.length) < data.length) + output->buffers[0] = out_nulls; + } +} + template struct ReplaceFunctor {}; @@ -50,6 +85,8 @@ struct ReplaceFunctor::value>> { DCHECK(output->buffers[0] == nullptr); if (replacement.is_valid) { + handle_nulls(ctx, data, mask, output); + KERNEL_ASSIGN_OR_RAISE(std::shared_ptr out_buf, ctx, ctx->Allocate(data.length * sizeof(T))); T value = UnboxScalar::Unbox(replacement); @@ -97,41 +134,7 @@ struct ReplaceFunctor::value>> { bool value = UnboxScalar::Unbox(replacement); if (replacement.is_valid) { - - // TODO: Allocate bitmap and compute data.buffers[0] | (mask.buffers[0] & mask.buffers[1]) - // Then factor the code in a function to reuse in all ReplaceFunctors... - - if (data.MayHaveNulls()) { - KERNEL_ASSIGN_OR_RAISE(std::shared_ptr out_nulls, ctx, - ctx->AllocateBitmap(data.length)); - - if (mask.MayHaveNulls()) { - ::arrow::internal::BitmapAnd(mask.buffers[0]->data(), - mask.offset, - mask.buffers[1]->data(), - mask.offset, - mask.length, - output->offset, - out_nulls->mutable_data()); - ::arrow::internal::BitmapOr(data.buffers[0]->data(), - data.offset, - out_nulls->data(), - output->offset, - data.length, - output->offset, - out_nulls->mutable_data()); - } else { - ::arrow::internal::BitmapOr(data.buffers[0]->data(), - data.offset, - mask.buffers[1]->data(), - mask.offset, - mask.length, - output->offset, - out_nulls->mutable_data()); - } - - output->buffers[0] = out_nulls; - } + handle_nulls(ctx, data, mask, output); KERNEL_ASSIGN_OR_RAISE(std::shared_ptr out_buf, ctx, ctx->AllocateBitmap(data.length)); @@ -197,10 +200,22 @@ struct ReplaceFunctor::value>> { // null count 0 unless we explicitly propagate it below. DCHECK(output->buffers[0] == nullptr); - const uint8_t* to_replace = mask.buffers[1]->data(); + const uint8_t* mask_validities = mask.buffers[0] == nullptr ? nullptr : mask.buffers[0]->data(); + const std::shared_ptr to_replace = mask.buffers[1]; + std::shared_ptr to_replace_valid = nullptr; uint64_t replace_count = 0; { - BitBlockCounter bit_counter(to_replace, input.offset, input.length); + if (mask_validities == nullptr) { + to_replace_valid = to_replace; + } else { + KERNEL_ASSIGN_OR_RAISE(to_replace_valid, + ctx, + ::arrow::internal::BitmapAnd(ctx->memory_pool(), + mask_validities, mask.offset, + to_replace->data(), mask.offset, + mask.length, output->offset)); + } + BitBlockCounter bit_counter(to_replace_valid->data(), input.offset, input.length); int64_t out_offset = 0; while (out_offset < input.length) { BitBlockCount block = bit_counter.NextWord(); @@ -210,27 +225,33 @@ struct ReplaceFunctor::value>> { } if (replace_count > 0 && replacement_scalar.is_valid) { - auto input_offsets = input.GetValues(1); - auto input_values = input.GetValues(2, input.offset); + const uint8_t* input_validities = input.buffers[0] == nullptr ? nullptr : input.buffers[0]->data(); + const auto input_offsets = input.GetValues(1); + const auto input_values = input.GetValues(2, input.offset); BuilderType builder(input.type, ctx->memory_pool()); KERNEL_RETURN_IF_ERROR(ctx, builder.ReserveData(input.buffers[2]->size() + replacement.length() * replace_count)); KERNEL_RETURN_IF_ERROR(ctx, builder.Resize(input.length)); - BitBlockCounter bit_counter(to_replace, input.offset, input.length); + BitBlockCounter bit_counter(to_replace_valid->data(), input.offset, input.length); int64_t input_offset = 0; while (input_offset < input.length) { BitBlockCount block = bit_counter.NextWord(); for (int64_t i = 0; i < block.length; ++i) { - if (BitUtil::GetBit(to_replace, input_offset + i)) { + + if (BitUtil::GetBit(to_replace_valid->data(), input_offset + i)) { builder.UnsafeAppend(replacement); } else { - auto current_offset = input_offsets[input_offset + i]; - auto next_offset = input_offsets[input_offset + i + 1]; - auto string_value = util::string_view(input_values + current_offset, - next_offset - current_offset); - builder.UnsafeAppend(string_value); + if (input_validities == nullptr || BitUtil::GetBit(input_validities, input_offset + i)) { + auto current_offset = input_offsets[input_offset + i]; + auto next_offset = input_offsets[input_offset + i + 1]; + auto string_value = util::string_view(input_values + current_offset, + next_offset - current_offset); + builder.UnsafeAppend(string_value); + } else { + builder.UnsafeAppendNull(); + } } } input_offset += block.length; diff --git a/cpp/src/arrow/compute/kernels/scalar_replace_test.cc b/cpp/src/arrow/compute/kernels/scalar_replace_test.cc index 92c48124a26..9b00d393d38 100644 --- a/cpp/src/arrow/compute/kernels/scalar_replace_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_replace_test.cc @@ -31,13 +31,20 @@ using internal::InvertBitmap; namespace compute { -void CheckReplace(const Array& input, const Array& mask, const Datum& replacement, - const Array& expected) { +void CheckReplace(const Array& input, + const Array& mask, + const Datum& replacement, + const Array& expected, + const bool ensure_no_nulls = false) { auto Check = [&](const Array& input, const Array& mask, const Array& expected) { ASSERT_OK_AND_ASSIGN(Datum datum_out, Replace(input, mask, replacement)); std::shared_ptr result = datum_out.make_array(); ASSERT_OK(result->ValidateFull()); AssertArraysEqual(expected, *result, /*verbose=*/true); + if (ensure_no_nulls) { + if (result->null_count() != 0 || result->data()->buffers[0] != nullptr) + FAIL() << "Result shall have null_count == 0 and validity bitmap == nullptr!"; + } }; Check(input, mask, expected); @@ -48,11 +55,12 @@ void CheckReplace(const Array& input, const Array& mask, const Datum& replacemen void CheckReplace(const std::shared_ptr& type, const std::string& in_values, const std::string& in_mask, const Datum& replacement, - const std::string& out_values) { + const std::string& out_values, + const bool ensure_no_nulls = false) { std::shared_ptr input = ArrayFromJSON(type, in_values); std::shared_ptr mask = ArrayFromJSON(boolean(), in_mask); std::shared_ptr expected = ArrayFromJSON(type, out_values); - CheckReplace(*input, *mask, replacement, *expected); + CheckReplace(*input, *mask, replacement, *expected, ensure_no_nulls); } class TestReplaceKernel : public ::testing::Test {}; @@ -82,11 +90,11 @@ TYPED_TEST(TestReplacePrimitive, Replace) { auto scalar = std::make_shared(static_cast(42)); // No replacement - CheckReplace(type, "[2, 4, 7, 9]", "[false, false, false, false]", Datum(scalar), - "[2, 4, 7, 9]"); + CheckReplace(type, "[2, 4, 7, 9, null]", "[false, false, false, false, false]", Datum(scalar), + "[2, 4, 7, 9, null]"); // Some replacements - CheckReplace(type, "[2, 4, 7, 9]", "[true, false, true, false]", Datum(scalar), - "[42, 4, 42, 9]"); + CheckReplace(type, "[2, 4, 7, 9, null]", "[true, false, true, false, true]", Datum(scalar), + "[42, 4, 42, 9, 42]", true); // Empty Array CheckReplace(type, "[]", "[]", Datum(scalar), "[]"); @@ -113,7 +121,7 @@ TYPED_TEST(TestReplacePrimitive, Replace) { out_data[i] = scalar->value; } } - CheckReplace(*arr, BooleanArray(mask_data), Datum(scalar), ArrayType(expected_data)); + CheckReplace(*arr, BooleanArray(mask_data), Datum(scalar), ArrayType(expected_data), true); } TEST_F(TestReplaceKernel, ReplaceNull) { @@ -128,9 +136,12 @@ TEST_F(TestReplaceKernel, ReplaceNull) { // Some replacements CheckReplace(boolean(), "[null, null, null, null]", "[true, false, true, false]", /*replacement=*/true_scalar, "[true, null, true, null]"); + // Some replacements with some nulls in mask + CheckReplace(boolean(), "[null, null, null, null]", "[true, null, true, false]", + /*replacement=*/true_scalar, "[true, null, true, null]"); // Replace all CheckReplace(boolean(), "[null, null, null, null]", "[true, true, true, true]", - /*replacement=*/true_scalar, "[true, true, true, true]"); + /*replacement=*/true_scalar, "[true, true, true, true]", true); } TEST_F(TestReplaceKernel, ReplaceBoolean) { @@ -143,6 +154,15 @@ TEST_F(TestReplaceKernel, ReplaceBoolean) { // Some replacements CheckReplace(boolean(), "[true, false, true, false]", "[true, false, true, false]", Datum(scalar1), "[false, false, false, false]"); + // Some replacements with nulls in input + CheckReplace(boolean(), "[true, null, true, null]", "[true, false, true, false]", + Datum(scalar1), "[false, null, false, null]"); + // Some replacements with nulls in mask + CheckReplace(boolean(), "[true, false, true, null]", "[true, null, null, false]", + Datum(scalar1), "[false, false, true, null]"); + // Replace all + CheckReplace(boolean(), "[true, false, true, null]", "[true, true, true, true]", + Datum(scalar1), "[false, false, false, false]", true); random::RandomArrayGenerator rand(/*seed=*/0); auto arr = @@ -169,8 +189,7 @@ TEST_F(TestReplaceKernel, ReplaceBoolean) { BitUtil::SetBitTo(out_data, i, scalar1->value); } } - CheckReplace(*arr, BooleanArray(mask_data), Datum(scalar1), - BooleanArray(expected_data)); + CheckReplace(*arr, BooleanArray(mask_data), Datum(scalar1), BooleanArray(expected_data), true); } TEST_F(TestReplaceKernel, ReplaceTimestamp) { @@ -179,26 +198,26 @@ TEST_F(TestReplaceKernel, ReplaceTimestamp) { auto scalar1 = std::make_shared(5, time32_type); auto scalar2 = std::make_shared(6, time64_type); // No replacement - CheckReplace(time32_type, "[2, 1, 6, 9]", "[false, false, false, false]", - Datum(scalar1), "[2, 1, 6, 9]"); - CheckReplace(time64_type, "[2, 1, 6, 9]", "[false, false, false, false]", - Datum(scalar2), "[2, 1, 6, 9]"); + CheckReplace(time32_type, "[2, 1, 6, null]", "[false, false, false, false]", + Datum(scalar1), "[2, 1, 6, null]"); + CheckReplace(time64_type, "[2, 1, 6, null]", "[false, false, false, false]", + Datum(scalar2), "[2, 1, 6, null]"); // Some replacements - CheckReplace(time32_type, "[2, 1, 6, 9]", "[true, false, true, false]", Datum(scalar1), - "[5, 1, 5, 9]"); - CheckReplace(time64_type, "[2, 1, 6, 9]", "[false, true, false, true]", Datum(scalar2), - "[2, 6, 6, 6]"); + CheckReplace(time32_type, "[2, 1, null, 9]", "[true, false, true, false]", Datum(scalar1), + "[5, 1, 5, 9]", true); + CheckReplace(time64_type, "[2, 1, 6, null]", "[false, true, false, true]", Datum(scalar2), + "[2, 6, 6, 6]", true); } TEST_F(TestReplaceKernel, ReplaceString) { auto type = large_utf8(); auto scalar = std::make_shared("arrow"); // No replacement - CheckReplace(type, R"(["foo", "bar"])", "[false, false]", Datum(scalar), - R"(["foo", "bar"])"); + CheckReplace(type, R"(["foo", "bar", null])", "[false, false, false]", Datum(scalar), + R"(["foo", "bar", null])"); // Some replacements - CheckReplace(type, R"(["foo", "bar"])", "[true, false]", Datum(scalar), - R"(["arrow", "bar"])"); + CheckReplace(type, R"(["foo", "bar", null, null])", "[true, false, true, false]", Datum(scalar), + R"(["arrow", "bar", "arrow", null])"); } } // namespace compute From 7022b91db712287a2eacf546f785719085b01945 Mon Sep 17 00:00:00 2001 From: Bruno LE HYARIC Date: Thu, 31 Dec 2020 15:35:02 +0100 Subject: [PATCH 7/8] ARROW-11044: [C++] Add "replace" kernel Fix String kernel implementation: * offsets were used inconsistently hence a bug on sliced arrays. * ArrayData::GetValues(i) use "byte offset" even for buffer[2] which store variable width values... => Should it rather use buffer[1] value offset ? --- .../arrow/compute/kernels/scalar_replace.cc | 31 ++++++++++--------- .../compute/kernels/scalar_replace_test.cc | 2 ++ 2 files changed, 19 insertions(+), 14 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_replace.cc b/cpp/src/arrow/compute/kernels/scalar_replace.cc index fce27be4c3b..56e32392263 100644 --- a/cpp/src/arrow/compute/kernels/scalar_replace.cc +++ b/cpp/src/arrow/compute/kernels/scalar_replace.cc @@ -216,36 +216,39 @@ struct ReplaceFunctor::value>> { mask.length, output->offset)); } BitBlockCounter bit_counter(to_replace_valid->data(), input.offset, input.length); - int64_t out_offset = 0; - while (out_offset < input.length) { + int64_t i = 0; + while (i < input.length) { BitBlockCount block = bit_counter.NextWord(); replace_count += block.popcount; - out_offset += block.length; + i += block.length; } } if (replace_count > 0 && replacement_scalar.is_valid) { const uint8_t* input_validities = input.buffers[0] == nullptr ? nullptr : input.buffers[0]->data(); - const auto input_offsets = input.GetValues(1); - const auto input_values = input.GetValues(2, input.offset); + const auto input_offsets = input.GetValues(1, input.offset); + // offset is 0 otherwise GetValue() will "shift" the buffer by input.offset bytes + // (should it rather shift by the lengths of the first input.offset string values ?) + const auto input_values = input.GetValues(2, 0); BuilderType builder(input.type, ctx->memory_pool()); KERNEL_RETURN_IF_ERROR(ctx, - builder.ReserveData(input.buffers[2]->size() + - replacement.length() * replace_count)); + builder.ReserveData(input.buffers[2]->size() + - input_offsets[0] + + replace_count * replacement.length())); KERNEL_RETURN_IF_ERROR(ctx, builder.Resize(input.length)); BitBlockCounter bit_counter(to_replace_valid->data(), input.offset, input.length); - int64_t input_offset = 0; - while (input_offset < input.length) { + int64_t j = 0; + while (j < input.length) { BitBlockCount block = bit_counter.NextWord(); for (int64_t i = 0; i < block.length; ++i) { - if (BitUtil::GetBit(to_replace_valid->data(), input_offset + i)) { + if (BitUtil::GetBit(to_replace_valid->data(), input.offset + j + i)) { builder.UnsafeAppend(replacement); } else { - if (input_validities == nullptr || BitUtil::GetBit(input_validities, input_offset + i)) { - auto current_offset = input_offsets[input_offset + i]; - auto next_offset = input_offsets[input_offset + i + 1]; + if (input_validities == nullptr || BitUtil::GetBit(input_validities, input.offset + j + i)) { + auto current_offset = input_offsets[j + i]; + auto next_offset = input_offsets[j + i + 1]; auto string_value = util::string_view(input_values + current_offset, next_offset - current_offset); builder.UnsafeAppend(string_value); @@ -254,7 +257,7 @@ struct ReplaceFunctor::value>> { } } } - input_offset += block.length; + j += block.length; } std::shared_ptr string_array; KERNEL_RETURN_IF_ERROR(ctx, builder.Finish(&string_array)); diff --git a/cpp/src/arrow/compute/kernels/scalar_replace_test.cc b/cpp/src/arrow/compute/kernels/scalar_replace_test.cc index 9b00d393d38..540d33f3ebd 100644 --- a/cpp/src/arrow/compute/kernels/scalar_replace_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_replace_test.cc @@ -36,6 +36,7 @@ void CheckReplace(const Array& input, const Datum& replacement, const Array& expected, const bool ensure_no_nulls = false) { + auto Check = [&](const Array& input, const Array& mask, const Array& expected) { ASSERT_OK_AND_ASSIGN(Datum datum_out, Replace(input, mask, replacement)); std::shared_ptr result = datum_out.make_array(); @@ -48,6 +49,7 @@ void CheckReplace(const Array& input, }; Check(input, mask, expected); + if (input.length() > 0) { Check(*input.Slice(1), *mask.Slice(1), *expected.Slice(1)); } From a1032045b8e6300d615036d3b3a181b173d3ffc2 Mon Sep 17 00:00:00 2001 From: Bruno LE HYARIC Date: Thu, 31 Dec 2020 15:56:32 +0100 Subject: [PATCH 8/8] ARROW-11044: [C++] Add "replace" kernel All type kernels implemented and tested, polish PR: * apply clang-format to pass CI lint check. --- .../arrow/compute/kernels/scalar_replace.cc | 62 ++++++++----------- .../compute/kernels/scalar_replace_test.cc | 45 +++++++------- 2 files changed, 47 insertions(+), 60 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_replace.cc b/cpp/src/arrow/compute/kernels/scalar_replace.cc index 56e32392263..13b27d52a19 100644 --- a/cpp/src/arrow/compute/kernels/scalar_replace.cc +++ b/cpp/src/arrow/compute/kernels/scalar_replace.cc @@ -30,37 +30,27 @@ namespace internal { namespace { -void handle_nulls(KernelContext* ctx, const ArrayData& data, const ArrayData& mask, ArrayData* output) { +void handle_nulls(KernelContext* ctx, const ArrayData& data, const ArrayData& mask, + ArrayData* output) { if (data.MayHaveNulls()) { KERNEL_ASSIGN_OR_RAISE(std::shared_ptr out_nulls, ctx, ctx->AllocateBitmap(data.length)); if (mask.MayHaveNulls()) { - ::arrow::internal::BitmapAnd(mask.buffers[0]->data(), - mask.offset, - mask.buffers[1]->data(), - mask.offset, - mask.length, - output->offset, - out_nulls->mutable_data()); - ::arrow::internal::BitmapOr(data.buffers[0]->data(), - data.offset, - out_nulls->data(), - output->offset, - data.length, - output->offset, + ::arrow::internal::BitmapAnd(mask.buffers[0]->data(), mask.offset, + mask.buffers[1]->data(), mask.offset, mask.length, + output->offset, out_nulls->mutable_data()); + ::arrow::internal::BitmapOr(data.buffers[0]->data(), data.offset, out_nulls->data(), + output->offset, data.length, output->offset, out_nulls->mutable_data()); } else { - ::arrow::internal::BitmapOr(data.buffers[0]->data(), - data.offset, - mask.buffers[1]->data(), - mask.offset, - mask.length, - output->offset, - out_nulls->mutable_data()); + ::arrow::internal::BitmapOr(data.buffers[0]->data(), data.offset, + mask.buffers[1]->data(), mask.offset, mask.length, + output->offset, out_nulls->mutable_data()); } - if (::arrow::internal::CountSetBits(out_nulls->data(), output->offset, data.length) < data.length) + if (::arrow::internal::CountSetBits(out_nulls->data(), output->offset, data.length) < + data.length) output->buffers[0] = out_nulls; } } @@ -200,7 +190,8 @@ struct ReplaceFunctor::value>> { // null count 0 unless we explicitly propagate it below. DCHECK(output->buffers[0] == nullptr); - const uint8_t* mask_validities = mask.buffers[0] == nullptr ? nullptr : mask.buffers[0]->data(); + const uint8_t* mask_validities = + mask.buffers[0] == nullptr ? nullptr : mask.buffers[0]->data(); const std::shared_ptr to_replace = mask.buffers[1]; std::shared_ptr to_replace_valid = nullptr; uint64_t replace_count = 0; @@ -208,12 +199,11 @@ struct ReplaceFunctor::value>> { if (mask_validities == nullptr) { to_replace_valid = to_replace; } else { - KERNEL_ASSIGN_OR_RAISE(to_replace_valid, - ctx, - ::arrow::internal::BitmapAnd(ctx->memory_pool(), - mask_validities, mask.offset, - to_replace->data(), mask.offset, - mask.length, output->offset)); + KERNEL_ASSIGN_OR_RAISE( + to_replace_valid, ctx, + ::arrow::internal::BitmapAnd(ctx->memory_pool(), mask_validities, mask.offset, + to_replace->data(), mask.offset, mask.length, + output->offset)); } BitBlockCounter bit_counter(to_replace_valid->data(), input.offset, input.length); int64_t i = 0; @@ -225,16 +215,16 @@ struct ReplaceFunctor::value>> { } if (replace_count > 0 && replacement_scalar.is_valid) { - const uint8_t* input_validities = input.buffers[0] == nullptr ? nullptr : input.buffers[0]->data(); + const uint8_t* input_validities = + input.buffers[0] == nullptr ? nullptr : input.buffers[0]->data(); const auto input_offsets = input.GetValues(1, input.offset); // offset is 0 otherwise GetValue() will "shift" the buffer by input.offset bytes // (should it rather shift by the lengths of the first input.offset string values ?) const auto input_values = input.GetValues(2, 0); BuilderType builder(input.type, ctx->memory_pool()); - KERNEL_RETURN_IF_ERROR(ctx, - builder.ReserveData(input.buffers[2]->size() - - input_offsets[0] - + replace_count * replacement.length())); + KERNEL_RETURN_IF_ERROR( + ctx, builder.ReserveData(input.buffers[2]->size() - input_offsets[0] + + replace_count * replacement.length())); KERNEL_RETURN_IF_ERROR(ctx, builder.Resize(input.length)); BitBlockCounter bit_counter(to_replace_valid->data(), input.offset, input.length); @@ -242,11 +232,11 @@ struct ReplaceFunctor::value>> { while (j < input.length) { BitBlockCount block = bit_counter.NextWord(); for (int64_t i = 0; i < block.length; ++i) { - if (BitUtil::GetBit(to_replace_valid->data(), input.offset + j + i)) { builder.UnsafeAppend(replacement); } else { - if (input_validities == nullptr || BitUtil::GetBit(input_validities, input.offset + j + i)) { + if (input_validities == nullptr || + BitUtil::GetBit(input_validities, input.offset + j + i)) { auto current_offset = input_offsets[j + i]; auto next_offset = input_offsets[j + i + 1]; auto string_value = util::string_view(input_values + current_offset, diff --git a/cpp/src/arrow/compute/kernels/scalar_replace_test.cc b/cpp/src/arrow/compute/kernels/scalar_replace_test.cc index 540d33f3ebd..65647e67846 100644 --- a/cpp/src/arrow/compute/kernels/scalar_replace_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_replace_test.cc @@ -31,12 +31,8 @@ using internal::InvertBitmap; namespace compute { -void CheckReplace(const Array& input, - const Array& mask, - const Datum& replacement, - const Array& expected, - const bool ensure_no_nulls = false) { - +void CheckReplace(const Array& input, const Array& mask, const Datum& replacement, + const Array& expected, const bool ensure_no_nulls = false) { auto Check = [&](const Array& input, const Array& mask, const Array& expected) { ASSERT_OK_AND_ASSIGN(Datum datum_out, Replace(input, mask, replacement)); std::shared_ptr result = datum_out.make_array(); @@ -57,8 +53,7 @@ void CheckReplace(const Array& input, void CheckReplace(const std::shared_ptr& type, const std::string& in_values, const std::string& in_mask, const Datum& replacement, - const std::string& out_values, - const bool ensure_no_nulls = false) { + const std::string& out_values, const bool ensure_no_nulls = false) { std::shared_ptr input = ArrayFromJSON(type, in_values); std::shared_ptr mask = ArrayFromJSON(boolean(), in_mask); std::shared_ptr expected = ArrayFromJSON(type, out_values); @@ -92,11 +87,11 @@ TYPED_TEST(TestReplacePrimitive, Replace) { auto scalar = std::make_shared(static_cast(42)); // No replacement - CheckReplace(type, "[2, 4, 7, 9, null]", "[false, false, false, false, false]", Datum(scalar), - "[2, 4, 7, 9, null]"); + CheckReplace(type, "[2, 4, 7, 9, null]", "[false, false, false, false, false]", + Datum(scalar), "[2, 4, 7, 9, null]"); // Some replacements - CheckReplace(type, "[2, 4, 7, 9, null]", "[true, false, true, false, true]", Datum(scalar), - "[42, 4, 42, 9, 42]", true); + CheckReplace(type, "[2, 4, 7, 9, null]", "[true, false, true, false, true]", + Datum(scalar), "[42, 4, 42, 9, 42]", true); // Empty Array CheckReplace(type, "[]", "[]", Datum(scalar), "[]"); @@ -123,7 +118,8 @@ TYPED_TEST(TestReplacePrimitive, Replace) { out_data[i] = scalar->value; } } - CheckReplace(*arr, BooleanArray(mask_data), Datum(scalar), ArrayType(expected_data), true); + CheckReplace(*arr, BooleanArray(mask_data), Datum(scalar), ArrayType(expected_data), + true); } TEST_F(TestReplaceKernel, ReplaceNull) { @@ -140,7 +136,7 @@ TEST_F(TestReplaceKernel, ReplaceNull) { /*replacement=*/true_scalar, "[true, null, true, null]"); // Some replacements with some nulls in mask CheckReplace(boolean(), "[null, null, null, null]", "[true, null, true, false]", - /*replacement=*/true_scalar, "[true, null, true, null]"); + /*replacement=*/true_scalar, "[true, null, true, null]"); // Replace all CheckReplace(boolean(), "[null, null, null, null]", "[true, true, true, true]", /*replacement=*/true_scalar, "[true, true, true, true]", true); @@ -158,13 +154,13 @@ TEST_F(TestReplaceKernel, ReplaceBoolean) { Datum(scalar1), "[false, false, false, false]"); // Some replacements with nulls in input CheckReplace(boolean(), "[true, null, true, null]", "[true, false, true, false]", - Datum(scalar1), "[false, null, false, null]"); + Datum(scalar1), "[false, null, false, null]"); // Some replacements with nulls in mask CheckReplace(boolean(), "[true, false, true, null]", "[true, null, null, false]", - Datum(scalar1), "[false, false, true, null]"); + Datum(scalar1), "[false, false, true, null]"); // Replace all CheckReplace(boolean(), "[true, false, true, null]", "[true, true, true, true]", - Datum(scalar1), "[false, false, false, false]", true); + Datum(scalar1), "[false, false, false, false]", true); random::RandomArrayGenerator rand(/*seed=*/0); auto arr = @@ -191,7 +187,8 @@ TEST_F(TestReplaceKernel, ReplaceBoolean) { BitUtil::SetBitTo(out_data, i, scalar1->value); } } - CheckReplace(*arr, BooleanArray(mask_data), Datum(scalar1), BooleanArray(expected_data), true); + CheckReplace(*arr, BooleanArray(mask_data), Datum(scalar1), BooleanArray(expected_data), + true); } TEST_F(TestReplaceKernel, ReplaceTimestamp) { @@ -205,10 +202,10 @@ TEST_F(TestReplaceKernel, ReplaceTimestamp) { CheckReplace(time64_type, "[2, 1, 6, null]", "[false, false, false, false]", Datum(scalar2), "[2, 1, 6, null]"); // Some replacements - CheckReplace(time32_type, "[2, 1, null, 9]", "[true, false, true, false]", Datum(scalar1), - "[5, 1, 5, 9]", true); - CheckReplace(time64_type, "[2, 1, 6, null]", "[false, true, false, true]", Datum(scalar2), - "[2, 6, 6, 6]", true); + CheckReplace(time32_type, "[2, 1, null, 9]", "[true, false, true, false]", + Datum(scalar1), "[5, 1, 5, 9]", true); + CheckReplace(time64_type, "[2, 1, 6, null]", "[false, true, false, true]", + Datum(scalar2), "[2, 6, 6, 6]", true); } TEST_F(TestReplaceKernel, ReplaceString) { @@ -218,8 +215,8 @@ TEST_F(TestReplaceKernel, ReplaceString) { CheckReplace(type, R"(["foo", "bar", null])", "[false, false, false]", Datum(scalar), R"(["foo", "bar", null])"); // Some replacements - CheckReplace(type, R"(["foo", "bar", null, null])", "[true, false, true, false]", Datum(scalar), - R"(["arrow", "bar", "arrow", null])"); + CheckReplace(type, R"(["foo", "bar", null, null])", "[true, false, true, false]", + Datum(scalar), R"(["arrow", "bar", "arrow", null])"); } } // namespace compute