From 0079ecc1d8b0c1546f98c1c30daa0ea692117a86 Mon Sep 17 00:00:00 2001 From: c-jamie Date: Fri, 3 Jul 2020 11:45:05 +0100 Subject: [PATCH 1/3] ARROW-1587: [C++] implement fill null --- cpp/src/arrow/CMakeLists.txt | 1 + cpp/src/arrow/compute/api_scalar.cc | 19 ++ cpp/src/arrow/compute/api_scalar.h | 21 ++ cpp/src/arrow/compute/kernels/CMakeLists.txt | 1 + .../arrow/compute/kernels/scalar_fill_null.cc | 223 ++++++++++++++++++ .../compute/kernels/scalar_fill_null_test.cc | 137 +++++++++++ cpp/src/arrow/compute/registry.cc | 1 + cpp/src/arrow/compute/registry_internal.h | 1 + 8 files changed, 404 insertions(+) create mode 100644 cpp/src/arrow/compute/kernels/scalar_fill_null.cc create mode 100644 cpp/src/arrow/compute/kernels/scalar_fill_null_test.cc diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt index eea19e54954..91e67fb423a 100644 --- a/cpp/src/arrow/CMakeLists.txt +++ b/cpp/src/arrow/CMakeLists.txt @@ -363,6 +363,7 @@ if(ARROW_COMPUTE) compute/kernels/scalar_set_lookup.cc compute/kernels/scalar_string.cc compute/kernels/scalar_validity.cc + compute/kernels/scalar_fill_null.cc compute/kernels/util_internal.cc compute/kernels/vector_hash.cc compute/kernels/vector_nested.cc diff --git a/cpp/src/arrow/compute/api_scalar.cc b/cpp/src/arrow/compute/api_scalar.cc index 1b2b8991d9b..03fe2763a92 100644 --- a/cpp/src/arrow/compute/api_scalar.cc +++ b/cpp/src/arrow/compute/api_scalar.cc @@ -126,5 +126,24 @@ Result Compare(const Datum& left, const Datum& right, CompareOptions opti SCALAR_EAGER_UNARY(IsValid, "is_valid") SCALAR_EAGER_UNARY(IsNull, "is_null") +Result FillNull(const Datum& values, const Datum& fill_value, ExecContext* ctx) { + if (!values.is_arraylike()) { + return Status::Invalid("Values must be Array or ChunkedArray"); + } + + if (!fill_value.is_scalar()) { + return Status::Invalid("fill value must be a scalar"); + } + + if (!values.type()->Equals(fill_value.type())) { + std::stringstream ss; + ss << "Array type didn't match type of fill value: " << values.type()->ToString() + << " vs " << fill_value.type()->ToString(); + return Status::Invalid(ss.str()); + } + FillNullOptions options(fill_value); + return CallFunction("fill_null", {values}, &options, ctx); +} + } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/api_scalar.h b/cpp/src/arrow/compute/api_scalar.h index d513173d76f..2c235501b99 100644 --- a/cpp/src/arrow/compute/api_scalar.h +++ b/cpp/src/arrow/compute/api_scalar.h @@ -259,6 +259,27 @@ Result IsValid(const Datum& values, ExecContext* ctx = NULLPTR); ARROW_EXPORT Result IsNull(const Datum& values, ExecContext* ctx = NULLPTR); +struct ARROW_EXPORT FillNullOptions : public FunctionOptions { + explicit FillNullOptions(Datum fill_value) : fill_value(std::move(fill_value)) {} + + Datum fill_value; +}; + +/// \brief FillNull replaces each null element in `values` +/// with `fill_value` +/// +/// \param[in] values input to examine for nullity +/// \param[in] fill_value scalar +/// \param[in] ctx the function execution context, optional +/// +/// \return the resulting datum +/// +/// \since 1.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result FillNull(const Datum& values, const Datum& fill_value, + ExecContext* ctx = NULLPTR); + // ---------------------------------------------------------------------- // String functions diff --git a/cpp/src/arrow/compute/kernels/CMakeLists.txt b/cpp/src/arrow/compute/kernels/CMakeLists.txt index e693a4176ab..fc147e3a69b 100644 --- a/cpp/src/arrow/compute/kernels/CMakeLists.txt +++ b/cpp/src/arrow/compute/kernels/CMakeLists.txt @@ -28,6 +28,7 @@ add_arrow_compute_test(scalar_test scalar_set_lookup_test.cc scalar_string_test.cc scalar_validity_test.cc + scalar_fill_null_test.cc test_util.cc) add_arrow_benchmark(scalar_arithmetic_benchmark PREFIX "arrow-compute") diff --git a/cpp/src/arrow/compute/kernels/scalar_fill_null.cc b/cpp/src/arrow/compute/kernels/scalar_fill_null.cc new file mode 100644 index 00000000000..843cdf45776 --- /dev/null +++ b/cpp/src/arrow/compute/kernels/scalar_fill_null.cc @@ -0,0 +1,223 @@ + +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/array/array_base.h" +#include "arrow/array/builder_primitive.h" +#include "arrow/compute/api_scalar.h" +#include "arrow/compute/kernels/common.h" +#include "arrow/util/bit_util.h" +#include "arrow/util/bitmap_writer.h" +#include "arrow/visitor_inline.h" + +namespace arrow { + +namespace compute { +namespace internal { +namespace { + +template +using enable_if_supports_fill_null = enable_if_t::value, R>; + +template +struct FillNullState : public KernelState { + explicit FillNullState(MemoryPool* pool) {} + + Status Init(const FillNullOptions& options) { + fill_value = options.fill_value.scalar(); + return Status::OK(); + } + + std::shared_ptr fill_value; +}; + +template <> +struct FillNullState : public KernelState { + explicit FillNullState(MemoryPool*) {} + + Status Init(const FillNullOptions& options) { return Status::OK(); } + + std::shared_ptr fill_value; +}; + +struct InitFillNullStateVisitor { + KernelContext* ctx; + const FillNullOptions* options; + std::unique_ptr result; + + InitFillNullStateVisitor(KernelContext* ctx, const FillNullOptions* options) + : ctx(ctx), options(options) {} + + template + Status Init() { + using StateType = FillNullState; + result.reset(new StateType(ctx->exec_context()->memory_pool())); + return static_cast(result.get())->Init(*options); + } + + Status Visit(const DataType&) { return Init(); } + + template + enable_if_supports_fill_null Visit(const Type&) { + return Init(); + } + + Status GetResult(std::unique_ptr* out) { + RETURN_NOT_OK(VisitTypeInline(*options->fill_value.type(), this)); + *out = std::move(result); + return Status::OK(); + } +}; + +std::unique_ptr InitFillNull(KernelContext* ctx, + const KernelInitArgs& args) { + InitFillNullStateVisitor visitor{ctx, + static_cast(args.options)}; + std::unique_ptr result; + ctx->SetStatus(visitor.GetResult(&result)); + return result; +} + +struct ScalarFillVisitor { + KernelContext* ctx; + const ArrayData& data; + Datum* out; + + ScalarFillVisitor(KernelContext* ctx, const ArrayData& data, Datum* out) + : ctx(ctx), data(data), out(out) {} + + Status Visit(const DataType&) { + ArrayData* out_arr = out->mutable_array(); + *out_arr = data; + return Status::OK(); + } + + Status Visit(const BooleanType&) { + const auto& state = checked_cast&>(*ctx->state()); + bool value = UnboxScalar::Unbox(*state.fill_value); + ArrayData* out_arr = out->mutable_array(); + FirstTimeBitmapWriter bit_writer(out_arr->buffers[1]->mutable_data(), out_arr->offset, + out_arr->length); + FirstTimeBitmapWriter bit_writer_validity(out_arr->buffers[0]->mutable_data(), + out_arr->offset, out_arr->length); + if (data.null_count != 0) { + BitmapReader bit_reader(data.buffers[1]->data(), data.offset, data.length); + BitmapReader bit_reader_validity(data.buffers[0]->data(), data.offset, data.length); + for (int64_t i = 0; i < data.length; i++) { + if (bit_reader_validity.IsNotSet()) { + if (value == true) { + bit_writer.Set(); + } else { + bit_writer.Clear(); + } + bit_writer_validity.Set(); + } else { + if (bit_reader.IsSet()) { + bit_writer.Set(); + } else { + bit_writer.Clear(); + } + bit_writer_validity.Set(); + } + bit_reader.Next(); + bit_writer.Next(); + bit_reader_validity.Next(); + bit_writer_validity.Next(); + } + bit_writer_validity.Finish(); + bit_writer.Finish(); + } else { + *out_arr = data; + } + return Status::OK(); + } + + template + enable_if_supports_fill_null Visit(const Type&) { + using T = typename GetViewType::T; + const auto& state = checked_cast&>(*ctx->state()); + T value = UnboxScalar::Unbox(*state.fill_value); + const T* in_data = data.GetValues(1); + ArrayData* out_arr = out->mutable_array(); + auto out_data = out_arr->GetMutableValues(1); + + if (data.null_count != 0) { + BitmapReader bit_reader(data.buffers[0]->data(), data.offset, data.length); + for (int64_t i = 0; i < data.length; i++) { + if (bit_reader.IsNotSet()) { + out_data[i] = value; + } else { + out_data[i] = static_cast(in_data[i]); + } + bit_reader.Next(); + } + BitUtil::SetBitsTo(out_arr->buffers[0]->mutable_data(), out_arr->offset, + out_arr->length, true); + } else { + *out_arr = data; + } + return Status::OK(); + } + + Status Execute() { return VisitTypeInline(*data.type, this); } +}; + +void ExecFillNull(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + ScalarFillVisitor dispatch(ctx, *batch[0].array(), out); + ctx->SetStatus(dispatch.Execute()); +} + +void AddBasicFillNullKernels(ScalarKernel kernel, ScalarFunction* func) { + auto AddKernels = [&](const std::vector>& types) { + for (const std::shared_ptr& ty : types) { + kernel.signature = KernelSignature::Make({InputType::Array(ty)}, ty); + DCHECK_OK(func->AddKernel(kernel)); + } + }; + + AddKernels(NumericTypes()); + AddKernels(TemporalTypes()); + + std::vector> other_types = {boolean()}; + + for (auto ty : other_types) { + kernel.signature = KernelSignature::Make({InputType::Array(ty)}, ty); + DCHECK_OK(func->AddKernel(kernel)); + } +} + +} // namespace + +void RegisterScalarFillNull(FunctionRegistry* registry) { + // Fill Null always writes into preallocated memory + { + ScalarKernel fill_null_base; + fill_null_base.init = InitFillNull; + fill_null_base.exec = ExecFillNull; + auto fill_null = std::make_shared("fill_null", Arity::Unary()); + + AddBasicFillNullKernels(fill_null_base, fill_null.get()); + fill_null_base.signature = KernelSignature::Make({InputType::Array(null())}, null()); + fill_null_base.null_handling = NullHandling::COMPUTED_PREALLOCATE; + DCHECK_OK(fill_null->AddKernel(fill_null_base)); + DCHECK_OK(registry->AddFunction(fill_null)); + } +} + +} // namespace internal +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/scalar_fill_null_test.cc b/cpp/src/arrow/compute/kernels/scalar_fill_null_test.cc new file mode 100644 index 00000000000..5ce85ddb093 --- /dev/null +++ b/cpp/src/arrow/compute/kernels/scalar_fill_null_test.cc @@ -0,0 +1,137 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include "arrow/array/array_base.h" +#include "arrow/array/builder_binary.h" +#include "arrow/array/builder_primitive.h" +#include "arrow/compute/api.h" +#include "arrow/compute/kernels/test_util.h" +#include "arrow/memory_pool.h" +#include "arrow/result.h" +#include "arrow/status.h" +#include "arrow/testing/gtest_compat.h" +#include "arrow/testing/gtest_util.h" +#include "arrow/type.h" +#include "arrow/type_traits.h" + +namespace arrow { +namespace compute { + +template ::c_type> +void CheckFillNull(const std::shared_ptr& type, const std::vector& in_values, + const std::vector& in_is_valid, const Datum fill_value, + const std::vector& out_values, + const std::vector& out_is_valid) { + std::shared_ptr input = _MakeArray(type, in_values, in_is_valid); + std::shared_ptr expected = _MakeArray(type, out_values, out_is_valid); + + ASSERT_OK_AND_ASSIGN(Datum datum_out, FillNull(input, fill_value)); + std::shared_ptr result = datum_out.make_array(); + ASSERT_OK(result->ValidateFull()); + AssertArraysEqual(*expected, *result, /*verbose=*/true); +} + +class TestFillNullKernel : public ::testing::Test {}; + +template +class TestFillNullPrimitive : public ::testing::Test {}; + +typedef ::testing::Types + PrimitiveTypes; + +TYPED_TEST_SUITE(TestFillNullPrimitive, PrimitiveTypes); + +TYPED_TEST(TestFillNullPrimitive, FillNull) { + using T = typename TypeParam::c_type; + using ScalarType = typename TypeTraits::ScalarType; + auto type = TypeTraits::type_singleton(); + auto scalar = std::make_shared(static_cast(5)); + // No Nulls + CheckFillNull(type, {2, 4, 7, 9}, {true, true, true, true}, Datum(scalar), + {2, 4, 7, 9}, {true, true, true, true}); + // Some Nulls + CheckFillNull(type, {2, 4, 7, 8}, {false, true, false, true}, + Datum(scalar), {5, 4, 5, 8}, {true, true, true, true}); + // Empty Array + CheckFillNull(type, {}, {}, Datum(scalar), {}, {}); +} + +TEST_F(TestFillNullKernel, FillNullNull) { + auto datum = Datum(std::make_shared()); + CheckFillNull(null(), {0, 0, 0, 0}, + {false, false, false, false}, datum, + {0, 0, 0, 0}, {false, false, false, false}); + CheckFillNull(null(), {NULL, NULL, NULL, NULL}, {}, datum, + {NULL, NULL, NULL, NULL}, {}); + CheckFillNull(null(), {0, 0, 0, 0}, + {false, false, false, false}, datum, + {0, 0, 0, 0}, {false, false, false, false}); + CheckFillNull(null(), {NULL, NULL, NULL, NULL}, {}, datum, + {NULL, NULL, NULL, NULL}, {}); +} + +TEST_F(TestFillNullKernel, FillNullBoolean) { + auto scalar1 = std::make_shared(false); + auto scalar2 = std::make_shared(true); + // no nulls + CheckFillNull(boolean(), {true, false, true, false}, + {true, true, true, true}, Datum(scalar1), + {true, false, true, false}, {true, true, true, true}); + // some nulls + CheckFillNull(boolean(), {true, false, true, false}, + {false, true, true, false}, Datum(scalar1), + {false, false, true, false}, {true, true, true, true}); + CheckFillNull(boolean(), {true, false, true, false}, + {false, true, false, false}, Datum(scalar2), + {true, false, true, true}, {true, true, true, true}); +} + +TEST_F(TestFillNullKernel, FillNullTimeStamp) { + auto time32_type = time32(TimeUnit::SECOND); + auto time64_type = time64(TimeUnit::NANO); + auto scalar1 = Datum(std::make_shared(5, time32_type)); + auto scalar2 = Datum(std::make_shared(6, time64_type)); + // no nulls + CheckFillNull(time32_type, {2, 1, 6, 9}, {true, true, true, true}, + Datum(scalar1), {2, 1, 6, 9}, + {true, true, true, true}); + CheckFillNull(time32_type, {2, 1, 6, 9}, + {true, false, true, false}, Datum(scalar1), + {2, 5, 6, 5}, {true, true, true, true}); + // some nulls + CheckFillNull(time64_type, {2, 1, 6, 9}, {true, true, true, true}, + scalar2, {2, 1, 6, 9}, {true, true, true, true}); + CheckFillNull(time64_type, {2, 1, 6, 9}, + {true, false, true, false}, scalar2, {2, 6, 6, 6}, + {true, true, true, true}); +} + +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/registry.cc b/cpp/src/arrow/compute/registry.cc index 1e418170159..ed9f38fe31d 100644 --- a/cpp/src/arrow/compute/registry.cc +++ b/cpp/src/arrow/compute/registry.cc @@ -107,6 +107,7 @@ static std::unique_ptr CreateBuiltInRegistry() { RegisterScalarSetLookup(registry.get()); RegisterScalarStringAscii(registry.get()); RegisterScalarValidity(registry.get()); + RegisterScalarFillNull(registry.get()); // Aggregate functions RegisterScalarAggregateBasic(registry.get()); diff --git a/cpp/src/arrow/compute/registry_internal.h b/cpp/src/arrow/compute/registry_internal.h index 993aa59a1d0..e6c68efe7dc 100644 --- a/cpp/src/arrow/compute/registry_internal.h +++ b/cpp/src/arrow/compute/registry_internal.h @@ -33,6 +33,7 @@ void RegisterScalarNested(FunctionRegistry* registry); void RegisterScalarSetLookup(FunctionRegistry* registry); void RegisterScalarStringAscii(FunctionRegistry* registry); void RegisterScalarValidity(FunctionRegistry* registry); +void RegisterScalarFillNull(FunctionRegistry* registry); // Vector functions void RegisterVectorHash(FunctionRegistry* registry); From def327dee3fa30b75a03755f3c674885f1838c98 Mon Sep 17 00:00:00 2001 From: c-jamie Date: Mon, 6 Jul 2020 17:32:46 +0100 Subject: [PATCH 2/3] address review comments --- cpp/src/arrow/compute/api_scalar.cc | 17 +- cpp/src/arrow/compute/api_scalar.h | 6 - .../arrow/compute/kernels/scalar_fill_null.cc | 291 ++++++++++-------- .../compute/kernels/scalar_fill_null_test.cc | 75 ++--- 4 files changed, 186 insertions(+), 203 deletions(-) diff --git a/cpp/src/arrow/compute/api_scalar.cc b/cpp/src/arrow/compute/api_scalar.cc index 03fe2763a92..77893f74fcd 100644 --- a/cpp/src/arrow/compute/api_scalar.cc +++ b/cpp/src/arrow/compute/api_scalar.cc @@ -127,22 +127,7 @@ SCALAR_EAGER_UNARY(IsValid, "is_valid") SCALAR_EAGER_UNARY(IsNull, "is_null") Result FillNull(const Datum& values, const Datum& fill_value, ExecContext* ctx) { - if (!values.is_arraylike()) { - return Status::Invalid("Values must be Array or ChunkedArray"); - } - - if (!fill_value.is_scalar()) { - return Status::Invalid("fill value must be a scalar"); - } - - if (!values.type()->Equals(fill_value.type())) { - std::stringstream ss; - ss << "Array type didn't match type of fill value: " << values.type()->ToString() - << " vs " << fill_value.type()->ToString(); - return Status::Invalid(ss.str()); - } - FillNullOptions options(fill_value); - return CallFunction("fill_null", {values}, &options, ctx); + return CallFunction("fill_null", {values, fill_value}, ctx); } } // namespace compute diff --git a/cpp/src/arrow/compute/api_scalar.h b/cpp/src/arrow/compute/api_scalar.h index 2c235501b99..858e1ff6a19 100644 --- a/cpp/src/arrow/compute/api_scalar.h +++ b/cpp/src/arrow/compute/api_scalar.h @@ -259,12 +259,6 @@ Result IsValid(const Datum& values, ExecContext* ctx = NULLPTR); ARROW_EXPORT Result IsNull(const Datum& values, ExecContext* ctx = NULLPTR); -struct ARROW_EXPORT FillNullOptions : public FunctionOptions { - explicit FillNullOptions(Datum fill_value) : fill_value(std::move(fill_value)) {} - - Datum fill_value; -}; - /// \brief FillNull replaces each null element in `values` /// with `fill_value` /// diff --git a/cpp/src/arrow/compute/kernels/scalar_fill_null.cc b/cpp/src/arrow/compute/kernels/scalar_fill_null.cc index 843cdf45776..f5cc5c069b6 100644 --- a/cpp/src/arrow/compute/kernels/scalar_fill_null.cc +++ b/cpp/src/arrow/compute/kernels/scalar_fill_null.cc @@ -1,4 +1,3 @@ - // Licensed to the Apache Software Foundation (ASF) under one // or more contributor license agreements. See the NOTICE file // distributed with this work for additional information @@ -16,7 +15,6 @@ // specific language governing permissions and limitations // under the License. -#include "arrow/array/array_base.h" #include "arrow/array/builder_primitive.h" #include "arrow/compute/api_scalar.h" #include "arrow/compute/kernels/common.h" @@ -26,194 +24,217 @@ namespace arrow { +using internal::BitBlockCount; +using internal::BitBlockCounter; + namespace compute { namespace internal { -namespace { - -template -using enable_if_supports_fill_null = enable_if_t::value, R>; - -template -struct FillNullState : public KernelState { - explicit FillNullState(MemoryPool* pool) {} - Status Init(const FillNullOptions& options) { - fill_value = options.fill_value.scalar(); - return Status::OK(); - } +namespace { - std::shared_ptr fill_value; -}; +template +struct FillNullFunctor {}; -template <> -struct FillNullState : public KernelState { - explicit FillNullState(MemoryPool*) {} +template +struct FillNullFunctor::value>> { + using value_type = typename TypeTraits::CType; + using BuilderType = typename TypeTraits::BuilderType; - Status Init(const FillNullOptions& options) { return Status::OK(); } + static void Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + const Datum& in_arr = batch[0]; + const Datum& fill_value = batch[1]; - std::shared_ptr fill_value; -}; + if (!in_arr.is_arraylike()) { + ctx->SetStatus(Status::Invalid("Values must be Array or ChunkedArray")); + } + if (!fill_value.is_scalar()) { + ctx->SetStatus(Status::Invalid("fill value must be a scalar")); + } -struct InitFillNullStateVisitor { - KernelContext* ctx; - const FillNullOptions* options; - std::unique_ptr result; + ctx->SetStatus(Fill(ctx, *in_arr.array(), *fill_value.scalar(), out)); + } - InitFillNullStateVisitor(KernelContext* ctx, const FillNullOptions* options) - : ctx(ctx), options(options) {} + static Status Fill(KernelContext* ctx, const ArrayData& data, const Scalar& fill_value, + Datum* out) { + value_type value = UnboxScalar::Unbox(fill_value); + ArrayData* output = out->mutable_array(); - template - Status Init() { - using StateType = FillNullState; - result.reset(new StateType(ctx->exec_context()->memory_pool())); - return static_cast(result.get())->Init(*options); - } + if (data.null_count != 0 && fill_value.is_valid) { + BuilderType builder(data.type, ctx->memory_pool()); + RETURN_NOT_OK(builder.Reserve(data.length)); - Status Visit(const DataType&) { return Init(); } + RETURN_NOT_OK(VisitArrayDataInline( + data, [&](value_type v) { return builder.Append(v); }, + [&]() { return builder.Append(value); })); - template - enable_if_supports_fill_null Visit(const Type&) { - return Init(); - } + std::shared_ptr output_array; + RETURN_NOT_OK(builder.Finish(&output_array)); + *output = std::move(*output_array->data()); - Status GetResult(std::unique_ptr* out) { - RETURN_NOT_OK(VisitTypeInline(*options->fill_value.type(), this)); - *out = std::move(result); + } else { + *output = data; + } return Status::OK(); } }; -std::unique_ptr InitFillNull(KernelContext* ctx, - const KernelInitArgs& args) { - InitFillNullStateVisitor visitor{ctx, - static_cast(args.options)}; - std::unique_ptr result; - ctx->SetStatus(visitor.GetResult(&result)); - return result; -} +template +struct FillNullFunctor::value>> { + using value_type = typename TypeTraits::CType; -struct ScalarFillVisitor { - KernelContext* ctx; - const ArrayData& data; - Datum* out; + static void Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + const Datum& in_arr = batch[0]; + const Datum& fill_value = batch[1]; - ScalarFillVisitor(KernelContext* ctx, const ArrayData& data, Datum* out) - : ctx(ctx), data(data), out(out) {} + if (!in_arr.is_arraylike()) { + ctx->SetStatus(Status::Invalid("Values must be Array or ChunkedArray")); + } + if (!fill_value.is_scalar()) { + ctx->SetStatus(Status::Invalid("fill value must be a scalar")); + } - Status Visit(const DataType&) { - ArrayData* out_arr = out->mutable_array(); - *out_arr = data; - return Status::OK(); + ctx->SetStatus(Fill(ctx, *in_arr.array(), *fill_value.scalar(), out)); } - Status Visit(const BooleanType&) { - const auto& state = checked_cast&>(*ctx->state()); - bool value = UnboxScalar::Unbox(*state.fill_value); - ArrayData* out_arr = out->mutable_array(); - FirstTimeBitmapWriter bit_writer(out_arr->buffers[1]->mutable_data(), out_arr->offset, - out_arr->length); - FirstTimeBitmapWriter bit_writer_validity(out_arr->buffers[0]->mutable_data(), - out_arr->offset, out_arr->length); - if (data.null_count != 0) { - BitmapReader bit_reader(data.buffers[1]->data(), data.offset, data.length); - BitmapReader bit_reader_validity(data.buffers[0]->data(), data.offset, data.length); - for (int64_t i = 0; i < data.length; i++) { - if (bit_reader_validity.IsNotSet()) { - if (value == true) { - bit_writer.Set(); - } else { - bit_writer.Clear(); + static Status Fill(KernelContext* ctx, const ArrayData& data, const Scalar& fill_value, + Datum* out) { + value_type value = UnboxScalar::Unbox(fill_value); + ArrayData* output = out->mutable_array(); + + if (data.null_count != 0 && fill_value.is_valid) { + int64_t position = 0; + const uint8_t* bitmap = data.buffers[1]->data(); + const uint8_t* bitmap_validity = output->buffers[0]->data(); + auto length = data.length; + auto offset = data.offset; + + BooleanBuilder builder(data.type, ctx->memory_pool()); + RETURN_NOT_OK(builder.Reserve(length)); + BitBlockCounter bit_counter(bitmap_validity, offset, length); + while (position < length) { + BitBlockCount block = bit_counter.NextWord(); + if (block.AllSet()) { + for (int64_t i = 0; i < block.length; ++i, ++position) { + if (BitUtil::GetBit(bitmap, offset + position)) { + RETURN_NOT_OK(builder.Append(true)); + } else { + RETURN_NOT_OK(builder.Append(false)); + } + } + } else if (block.NoneSet()) { + for (int64_t i = 0; i < block.length; ++i, ++position) { + RETURN_NOT_OK(builder.Append(value)); } - bit_writer_validity.Set(); } else { - if (bit_reader.IsSet()) { - bit_writer.Set(); - } else { - bit_writer.Clear(); + for (int64_t i = 0; i < block.length; ++i, ++position) { + if (BitUtil::GetBit(bitmap_validity, offset + position)) { + if (BitUtil::GetBit(bitmap, offset + position)) { + RETURN_NOT_OK(builder.Append(true)); + } else { + RETURN_NOT_OK(builder.Append(false)); + } + } else { + RETURN_NOT_OK(builder.Append(value)); + } } - bit_writer_validity.Set(); } - bit_reader.Next(); - bit_writer.Next(); - bit_reader_validity.Next(); - bit_writer_validity.Next(); } - bit_writer_validity.Finish(); - bit_writer.Finish(); + std::shared_ptr output_array; + RETURN_NOT_OK(builder.Finish(&output_array)); + *output = std::move(*output_array->data()); } else { - *out_arr = data; + *output = data; } return Status::OK(); } +}; - template - enable_if_supports_fill_null Visit(const Type&) { - using T = typename GetViewType::T; - const auto& state = checked_cast&>(*ctx->state()); - T value = UnboxScalar::Unbox(*state.fill_value); - const T* in_data = data.GetValues(1); - ArrayData* out_arr = out->mutable_array(); - auto out_data = out_arr->GetMutableValues(1); - - if (data.null_count != 0) { - BitmapReader bit_reader(data.buffers[0]->data(), data.offset, data.length); - for (int64_t i = 0; i < data.length; i++) { - if (bit_reader.IsNotSet()) { - out_data[i] = value; - } else { - out_data[i] = static_cast(in_data[i]); - } - bit_reader.Next(); - } - BitUtil::SetBitsTo(out_arr->buffers[0]->mutable_data(), out_arr->offset, - out_arr->length, true); - } else { - *out_arr = data; +template +struct FillNullFunctor::value>> { + static void Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + const Datum& in_arr = batch[0]; + const Datum& fill_value = batch[1]; + + if (!in_arr.is_arraylike()) { + ctx->SetStatus(Status::Invalid("Values must be Array or ChunkedArray")); } - return Status::OK(); + if (!fill_value.is_scalar()) { + ctx->SetStatus(Status::Invalid("fill value must be a scalar")); + } + + ctx->SetStatus(Fill(ctx, *in_arr.array(), *fill_value.scalar(), out)); } - Status Execute() { return VisitTypeInline(*data.type, this); } + static Status Fill(KernelContext* ctx, const ArrayData& data, const Scalar& fill_value, + Datum* out) { + ArrayData* output = out->mutable_array(); + *output = data; + return Status::OK(); + } }; -void ExecFillNull(KernelContext* ctx, const ExecBatch& batch, Datum* out) { - ScalarFillVisitor dispatch(ctx, *batch[0].array(), out); - ctx->SetStatus(dispatch.Execute()); +template