From ebc637a0d932b097193945986a94d9d7bc82a984 Mon Sep 17 00:00:00 2001 From: niranda perera Date: Thu, 20 May 2021 14:22:48 -0400 Subject: [PATCH 1/5] adding basic structure --- cpp/src/arrow/CMakeLists.txt | 1 + cpp/src/arrow/compute/api_scalar.cc | 5 + cpp/src/arrow/compute/api_scalar.h | 16 ++ cpp/src/arrow/compute/kernels/CMakeLists.txt | 1 + .../arrow/compute/kernels/scalar_if_else.cc | 137 ++++++++++++++++++ .../compute/kernels/scalar_if_else_test.cc | 17 +++ cpp/src/arrow/compute/registry.cc | 1 + cpp/src/arrow/compute/registry_internal.h | 1 + 8 files changed, 179 insertions(+) create mode 100644 cpp/src/arrow/compute/kernels/scalar_if_else.cc create mode 100644 cpp/src/arrow/compute/kernels/scalar_if_else_test.cc diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt index bee14ae4ce3..0d04d967915 100644 --- a/cpp/src/arrow/CMakeLists.txt +++ b/cpp/src/arrow/CMakeLists.txt @@ -393,6 +393,7 @@ if(ARROW_COMPUTE) compute/kernels/scalar_string.cc compute/kernels/scalar_validity.cc compute/kernels/scalar_fill_null.cc + compute/kernels/scalar_if_else.cc compute/kernels/util_internal.cc compute/kernels/vector_hash.cc compute/kernels/vector_nested.cc diff --git a/cpp/src/arrow/compute/api_scalar.cc b/cpp/src/arrow/compute/api_scalar.cc index c7c049af980..b88850d2ce9 100644 --- a/cpp/src/arrow/compute/api_scalar.cc +++ b/cpp/src/arrow/compute/api_scalar.cc @@ -156,5 +156,10 @@ Result FillNull(const Datum& values, const Datum& fill_value, ExecContext return CallFunction("fill_null", {values, fill_value}, ctx); } +Result IfElse(const Datum& cond, const Datum& if_true, const Datum& if_false, + ExecContext* ctx) { + return CallFunction("if_else", {cond, if_true, if_false}, ctx); +} + } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/api_scalar.h b/cpp/src/arrow/compute/api_scalar.h index 3e390df47e7..5ff8cc09e83 100644 --- a/cpp/src/arrow/compute/api_scalar.h +++ b/cpp/src/arrow/compute/api_scalar.h @@ -450,5 +450,21 @@ ARROW_EXPORT Result FillNull(const Datum& values, const Datum& fill_value, ExecContext* ctx = NULLPTR); +/// \brief IfElse returns elements chosen from `left` or `right` +/// depending on `cond`. `Null` values would be promoted to the result +/// +/// \param[in] cond `BooleanArray` condition array +/// \param[in] left scalar/ Array +/// \param[in] right scalar/ Array +/// \param[in] ctx the function execution context, optional +/// +/// \return the resulting datum +/// +/// \since x.x.x +/// \note API not yet finalized +ARROW_EXPORT +Result IfElse(const Datum& cond, const Datum& left, const Datum& right, + ExecContext* ctx = NULLPTR); + } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/CMakeLists.txt b/cpp/src/arrow/compute/kernels/CMakeLists.txt index 5e223a1f906..fc11d144105 100644 --- a/cpp/src/arrow/compute/kernels/CMakeLists.txt +++ b/cpp/src/arrow/compute/kernels/CMakeLists.txt @@ -29,6 +29,7 @@ add_arrow_compute_test(scalar_test scalar_string_test.cc scalar_validity_test.cc scalar_fill_null_test.cc + scalar_if_else_test.cc test_util.cc) add_arrow_benchmark(scalar_arithmetic_benchmark PREFIX "arrow-compute") diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else.cc b/cpp/src/arrow/compute/kernels/scalar_if_else.cc new file mode 100644 index 00000000000..dc461b6a3d4 --- /dev/null +++ b/cpp/src/arrow/compute/kernels/scalar_if_else.cc @@ -0,0 +1,137 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include +#include + +#include "codegen_internal.h" + +namespace arrow { +namespace compute { + +namespace { + +template +struct IfElseFunctor { + static Status Call(KernelContext* ctx, const ArrayData& cond, const ArrayData& left, + const ArrayData& right, ArrayData* out) { + return Status::OK(); + } + + static Status Call(KernelContext* ctx, const ArrayData& cond, const ArrayData& left, + const Scalar& right, ArrayData* out) { + return Status::OK(); + } + + static Status Call(KernelContext* ctx, const Scalar& cond, const Scalar& left, + const Scalar& right, Scalar* out) { + return Status::OK(); + } +}; + +template +struct ResolveExec { + static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + if (batch.length == 0) return Status::OK(); + + if (batch[0].kind() == Datum::ARRAY) { + if (batch[1].kind() == Datum::ARRAY) { + if (batch[2].kind() == Datum::ARRAY) { // AAA + return IfElseFunctor::Call(ctx, *batch[0].array(), *batch[1].array(), + *batch[2].array(), out->mutable_array()); + } else { // AAS + return IfElseFunctor::Call(ctx, *batch[0].array(), *batch[1].array(), + *batch[2].scalar(), out->mutable_array()); + } + } else { + return Status::Invalid(""); + // if (batch[2].kind() == Datum::ARRAY) { // ASA + // return IfElseFunctor::Call(ctx, *batch[0].array(), + // *batch[2].array(), + // *batch[1].scalar(), + // out->mutable_array()); + // } else { // ASS + // return IfElseFunctor::Call(ctx, *batch[0].array(), + // *batch[1].scalar(), + // *batch[2].scalar(), + // out->mutable_array()); + // } + } + } else { // when cond is scalar, output will also be scalar + if (batch[1].kind() == Datum::ARRAY) { + return Status::Invalid(""); + // if (batch[2].kind() == Datum::ARRAY) { // SAA + // return IfElseFunctor::Call(ctx, *batch[0].scalar(), + // *batch[1].array(), + // *batch[2].array(), + // out->scalar().get()); + // } else { // SAS + // return IfElseFunctor::Call(ctx, *batch[0].scalar(), + // *batch[1].array(), + // *batch[2].scalar(), + // out->scalar().get()); + // } + } else { + if (batch[2].kind() == Datum::ARRAY) { // SSA + return Status::Invalid(""); + // return IfElseFunctor::Call(ctx, *batch[0].scalar(), + // *batch[1].scalar(), + // *batch[2].array(), + // out->scalar().get()); + } else { // SSS + return IfElseFunctor::Call(ctx, *batch[0].scalar(), *batch[1].scalar(), + *batch[2].scalar(), out->scalar().get()); + } + } + } + } +}; + +void AddPrimitiveKernels(const std::shared_ptr& scalar_function, + const std::vector>& types) { + for (auto&& type : types) { + auto exec = internal::GenerateTypeAgnosticPrimitive(*type); + ScalarKernel kernel({boolean(), type, type}, type, exec); + kernel.null_handling = NullHandling::COMPUTED_NO_PREALLOCATE; + kernel.mem_allocation = MemAllocation::NO_PREALLOCATE; + + DCHECK_OK(scalar_function->AddKernel(std::move(kernel))); + } +} + +} // namespace + +const FunctionDoc if_else_doc{"", ("`"), {"cond", "left", "right"}}; + +namespace internal { + +void RegisterScalarIfElse(FunctionRegistry* registry) { + ScalarKernel scalar_kernel; + scalar_kernel.null_handling = NullHandling::COMPUTED_NO_PREALLOCATE; + scalar_kernel.mem_allocation = MemAllocation::NO_PREALLOCATE; + + auto func = std::make_shared("if_else", Arity::Ternary(), &if_else_doc); + + AddPrimitiveKernels(func, NumericTypes()); + // todo add temporal, boolean, null and binary kernels + + DCHECK_OK(registry->AddFunction(std::move(func))); +} + +} // namespace internal +} // namespace compute +} // namespace arrow \ No newline at end of file diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc new file mode 100644 index 00000000000..5cd17fb5a64 --- /dev/null +++ b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc @@ -0,0 +1,17 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + diff --git a/cpp/src/arrow/compute/registry.cc b/cpp/src/arrow/compute/registry.cc index 3a8a3a0eb85..1d713b96e1e 100644 --- a/cpp/src/arrow/compute/registry.cc +++ b/cpp/src/arrow/compute/registry.cc @@ -125,6 +125,7 @@ static std::unique_ptr CreateBuiltInRegistry() { RegisterScalarStringAscii(registry.get()); RegisterScalarValidity(registry.get()); RegisterScalarFillNull(registry.get()); + RegisterScalarIfElse(registry.get()); // Vector functions RegisterVectorHash(registry.get()); diff --git a/cpp/src/arrow/compute/registry_internal.h b/cpp/src/arrow/compute/registry_internal.h index e4008cf3f27..f97553af4b1 100644 --- a/cpp/src/arrow/compute/registry_internal.h +++ b/cpp/src/arrow/compute/registry_internal.h @@ -34,6 +34,7 @@ void RegisterScalarSetLookup(FunctionRegistry* registry); void RegisterScalarStringAscii(FunctionRegistry* registry); void RegisterScalarValidity(FunctionRegistry* registry); void RegisterScalarFillNull(FunctionRegistry* registry); +void RegisterScalarIfElse(FunctionRegistry* registry); // Vector functions void RegisterVectorHash(FunctionRegistry* registry); From 8b4616f00695bdc5a05665581100b9931151e179 Mon Sep 17 00:00:00 2001 From: niranda perera Date: Fri, 21 May 2021 16:45:47 -0400 Subject: [PATCH 2/5] working primitive types --- .../arrow/compute/kernels/scalar_if_else.cc | 195 +++++++++++++++++- .../compute/kernels/scalar_if_else_test.cc | 92 +++++++++ 2 files changed, 281 insertions(+), 6 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else.cc b/cpp/src/arrow/compute/kernels/scalar_if_else.cc index dc461b6a3d4..0d1b68da5e5 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else.cc @@ -16,17 +16,199 @@ // under the License. #include -#include +#include +#include #include "codegen_internal.h" namespace arrow { +using internal::BitBlockCount; +using internal::BitBlockCounter; + namespace compute { namespace { +// nulls will be promoted as follows +// cond.val && (cond.data && left.val || ~cond.data && right.val) +Status promote_nulls(KernelContext* ctx, const ArrayData& cond, const ArrayData& left, + const ArrayData& right, ArrayData* output) { + if (!cond.MayHaveNulls() && !left.MayHaveNulls() && !right.MayHaveNulls()) { + return Status::OK(); // no nulls to handle + } + const int64_t len = cond.length; + + ARROW_ASSIGN_OR_RAISE(std::shared_ptr out_validity, ctx->AllocateBitmap(len)); + arrow::internal::InvertBitmap(out_validity->data(), 0, len, + out_validity->mutable_data(), 0); + if (right.MayHaveNulls()) { + // out_validity = right.val && ~cond.data + arrow::internal::BitmapAndNot(right.buffers[0]->data(), right.offset, + cond.buffers[1]->data(), cond.offset, len, 0, + out_validity->mutable_data()); + } + + if (left.MayHaveNulls()) { + ARROW_ASSIGN_OR_RAISE(std::shared_ptr temp_buf, ctx->AllocateBitmap(len)); + // tmp_buf = left.val && cond.data + arrow::internal::BitmapAnd(left.buffers[0]->data(), left.offset, + cond.buffers[1]->data(), cond.offset, len, 0, + temp_buf->mutable_data()); + // out_validity = cond.data && left.val || ~cond.data && right.val + arrow::internal::BitmapOr(out_validity->data(), 0, temp_buf->data(), 0, len, 0, + out_validity->mutable_data()); + } + + if (cond.MayHaveNulls()) { + // out_validity &= cond.val + ::arrow::internal::BitmapAnd(out_validity->data(), 0, cond.buffers[0]->data(), + cond.offset, len, 0, out_validity->mutable_data()); + } + + output->buffers[0] = std::move(out_validity); + output->GetNullCount(); // update null count + return Status::OK(); +} + template -struct IfElseFunctor { +struct IfElseFunctor {}; + +template +struct IfElseFunctor::value>> { + using T = typename TypeTraits::CType; + + static Status Call(KernelContext* ctx, const ArrayData& cond, const ArrayData& left, + const ArrayData& right, ArrayData* out) { + ARROW_RETURN_NOT_OK(promote_nulls(ctx, cond, left, right, out)); + + ARROW_ASSIGN_OR_RAISE(std::shared_ptr out_buf, + arrow::internal::CopyBitmap(ctx->memory_pool(), ) + + ctx->Allocate(cond.length * sizeof(T))); + T* out_values = reinterpret_cast(out_buf->mutable_data()); + + // copy right data to out_buff + const T* right_data = right.GetValues(1); + std::memcpy(out_values, right_data, right.length * sizeof(T)); + + const auto* cond_data = cond.buffers[1]->data(); // this is a BoolArray + BitBlockCounter bit_counter(cond_data, cond.offset, cond.length); + + // selectively copy values from left data + const T* left_data = left.GetValues(1); + int64_t offset = cond.offset; + + // todo this can be improved by intrinsics. ex: _mm*_mask_store_e* (vmovdqa*) + while (offset < cond.offset + cond.length) { + const BitBlockCount& block = bit_counter.NextWord(); + if (block.AllSet()) { // all from left + std::memcpy(out_values, left_data, block.length * sizeof(T)); + } else if (block.popcount) { // selectively copy from left + for (int64_t i = 0; i < block.length; ++i) { + if (BitUtil::GetBit(cond_data, offset + i)) { + out_values[i] = left_data[i]; + } + } + } + + offset += block.length; + out_values += block.length; + left_data += block.length; + } + + out->buffers[1] = std::move(out_buf); + return Status::OK(); + } + + static Status Call(KernelContext* ctx, const ArrayData& cond, const ArrayData& left, + const Scalar& right, ArrayData* out) { + // todo impl + return Status::OK(); + } + + static Status Call(KernelContext* ctx, const Scalar& cond, const Scalar& left, + const Scalar& right, Scalar* out) { + // todo impl + return Status::OK(); + } +}; + +template +struct IfElseFunctor::value>> { + static Status Call(KernelContext* ctx, const ArrayData& cond, const ArrayData& left, + const ArrayData& right, ArrayData* out) { + ARROW_RETURN_NOT_OK(promote_nulls(ctx, cond, left, right, out)); + + ARROW_ASSIGN_OR_RAISE(std::shared_ptr out_buf, + ctx->AllocateBitmap(cond.length)); + uint8_t* out_values = out_buf->mutable_data(); + + // copy right data to out_buff + const T* right_data = right.GetValues(1); + std::memcpy(out_values, right_data, right.length * sizeof(T)); + + const auto* cond_data = cond.buffers[1]->data(); // this is a BoolArray + BitBlockCounter bit_counter(cond_data, cond.offset, cond.length); + + // selectively copy values from left data + const T* left_data = left.GetValues(1); + int64_t offset = cond.offset; + + // todo this can be improved by intrinsics. ex: _mm*_mask_store_e* (vmovdqa*) + while (offset < cond.offset + cond.length) { + const BitBlockCount& block = bit_counter.NextWord(); + if (block.AllSet()) { // all from left + std::memcpy(out_values, left_data, block.length * sizeof(T)); + } else if (block.popcount) { // selectively copy from left + for (int64_t i = 0; i < block.length; ++i) { + if (BitUtil::GetBit(cond_data, offset + i)) { + out_values[i] = left_data[i]; + } + } + } + + offset += block.length; + out_values += block.length; + left_data += block.length; + } + + out->buffers[1] = std::move(out_buf); + return Status::OK(); + } + + static Status Call(KernelContext* ctx, const ArrayData& cond, const ArrayData& left, + const Scalar& right, ArrayData* out) { + // todo impl + return Status::OK(); + } + + static Status Call(KernelContext* ctx, const Scalar& cond, const Scalar& left, + const Scalar& right, Scalar* out) { + // todo impl + return Status::OK(); + } +}; + +template +struct IfElseFunctor::value>> { + static Status Call(KernelContext* ctx, const ArrayData& cond, const ArrayData& left, + const ArrayData& right, ArrayData* out) { + return Status::OK(); + } + + static Status Call(KernelContext* ctx, const ArrayData& cond, const ArrayData& left, + const Scalar& right, ArrayData* out) { + return Status::OK(); + } + + static Status Call(KernelContext* ctx, const Scalar& cond, const Scalar& left, + const Scalar& right, Scalar* out) { + return Status::OK(); + } +}; + +template +struct IfElseFunctor::value>> { static Status Call(KernelContext* ctx, const ArrayData& cond, const ArrayData& left, const ArrayData& right, ArrayData* out) { return Status::OK(); @@ -71,19 +253,19 @@ struct ResolveExec { // out->mutable_array()); // } } - } else { // when cond is scalar, output will also be scalar + } else { if (batch[1].kind() == Datum::ARRAY) { return Status::Invalid(""); // if (batch[2].kind() == Datum::ARRAY) { // SAA // return IfElseFunctor::Call(ctx, *batch[0].scalar(), // *batch[1].array(), // *batch[2].array(), - // out->scalar().get()); + // out->mutable_array()); // } else { // SAS // return IfElseFunctor::Call(ctx, *batch[0].scalar(), // *batch[1].array(), // *batch[2].scalar(), - // out->scalar().get()); + // out->mutable_array()); // } } else { if (batch[2].kind() == Datum::ARRAY) { // SSA @@ -91,7 +273,7 @@ struct ResolveExec { // return IfElseFunctor::Call(ctx, *batch[0].scalar(), // *batch[1].scalar(), // *batch[2].array(), - // out->scalar().get()); + // out->mutable_array()); } else { // SSS return IfElseFunctor::Call(ctx, *batch[0].scalar(), *batch[1].scalar(), *batch[2].scalar(), out->scalar().get()); @@ -127,6 +309,7 @@ void RegisterScalarIfElse(FunctionRegistry* registry) { auto func = std::make_shared("if_else", Arity::Ternary(), &if_else_doc); AddPrimitiveKernels(func, NumericTypes()); + AddPrimitiveKernels(func, TemporalTypes()); // todo add temporal, boolean, null and binary kernels DCHECK_OK(registry->AddFunction(std::move(func))); diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc index 5cd17fb5a64..9b5dac71ff4 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc @@ -15,3 +15,95 @@ // specific language governing permissions and limitations // under the License. +#include +#include +#include +#include +#include + +namespace arrow { +namespace compute { + +void CheckIfElseOutputArray(const Datum& cond, const Datum& left, const Datum& right, + const Datum& expected, bool all_valid = true) { + ASSERT_OK_AND_ASSIGN(Datum datum_out, IfElse(cond, left, right)); + std::shared_ptr result = datum_out.make_array(); + ASSERT_OK(result->ValidateFull()); + AssertArraysEqual(*expected.make_array(), *result, /*verbose=*/true); + if (all_valid) { + // Check null count of ArrayData is set, not the computed Array.null_count + ASSERT_EQ(result->data()->null_count, 0); + } +} + +void CheckIfElseOutputArray(const std::shared_ptr& type, + const std::string& cond, const std::string& left, + const std::string& right, const std::string& expected, + bool all_valid = true) { + const std::shared_ptr& cond_ = ArrayFromJSON(boolean(), cond); + const std::shared_ptr& left_ = ArrayFromJSON(type, left); + const std::shared_ptr& right_ = ArrayFromJSON(type, right); + const std::shared_ptr& expected_ = ArrayFromJSON(type, expected); + CheckIfElseOutputArray(cond_, left_, right_, expected_, all_valid); +} + +class TestIfElseNullKernel : public ::testing::Test {}; + +template +class TestIfElsePrimitive : public ::testing::Test {}; + +using PrimitiveTypes = ::testing::Types; + + +TYPED_TEST_SUITE(TestIfElsePrimitive, PrimitiveTypes); + +TYPED_TEST(TestIfElsePrimitive, IfElseFixedSize) { + // using ScalarType = typename TypeTraits::ScalarType; + auto type = TypeTraits::type_singleton(); + // auto scalar = std::make_shared(static_cast(5)); + // No Nulls + CheckIfElseOutputArray(type, "[]", "[]", "[]", "[]"); + + CheckIfElseOutputArray(type, "[true, true, true, false]", "[1, 2, 3, 4]", + "[5, 6, 7, 8]", "[1, 2, 3, 8]"); + + CheckIfElseOutputArray(type, "[true, true, null, false]", "[1, 2, 3, 4]", + "[5, 6, 7, 8]", "[1, 2, null, 8]", false); + + CheckIfElseOutputArray(type, "[true, true, true, false]", "[1, 2, null, null]", + "[null, 6, 7, null]", "[1, 2, null, null]", false); + + using ArrayType = typename TypeTraits::ArrayType; + random::RandomArrayGenerator rand(/*seed=*/0); + int64_t len = 1000; + auto cond = std::static_pointer_cast( + rand.ArrayOf(boolean(), len, /*null_probability=*/0.01)); + auto left = std::static_pointer_cast( + rand.ArrayOf(type, len, /*null_probability=*/0.01)); + auto right = std::static_pointer_cast( + rand.ArrayOf(type, len, /*null_probability=*/0.01)); + + typename TypeTraits::BuilderType builder; + + for (int64_t i = 0; i < len; ++i) { + if (!cond->IsValid(i) || (cond->Value(i) && !left->IsValid(i)) || + (!cond->Value(i) && !right->IsValid(i))) { + ASSERT_OK(builder.AppendNull()); + continue; + } + + if (cond->Value(i)) { + ASSERT_OK(builder.Append(left->Value(i))); + } else { + ASSERT_OK(builder.Append(right->Value(i))); + } + } + ASSERT_OK_AND_ASSIGN(auto expected_data, builder.Finish()); + + CheckIfElseOutputArray(cond, left, right, expected_data, false); +} + +} // namespace compute +} // namespace arrow \ No newline at end of file From d2779e648921f955031bf66c4551cbc2a9fb51cb Mon Sep 17 00:00:00 2001 From: niranda perera Date: Fri, 21 May 2021 17:18:19 -0400 Subject: [PATCH 3/5] adding bool type --- .../arrow/compute/kernels/scalar_if_else.cc | 76 +++++-------------- .../compute/kernels/scalar_if_else_test.cc | 58 ++++++++++++-- 2 files changed, 71 insertions(+), 63 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else.cc b/cpp/src/arrow/compute/kernels/scalar_if_else.cc index 0d1b68da5e5..fe3e31268ac 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else.cc @@ -49,11 +49,11 @@ Status promote_nulls(KernelContext* ctx, const ArrayData& cond, const ArrayData& } if (left.MayHaveNulls()) { - ARROW_ASSIGN_OR_RAISE(std::shared_ptr temp_buf, ctx->AllocateBitmap(len)); // tmp_buf = left.val && cond.data - arrow::internal::BitmapAnd(left.buffers[0]->data(), left.offset, - cond.buffers[1]->data(), cond.offset, len, 0, - temp_buf->mutable_data()); + ARROW_ASSIGN_OR_RAISE(std::shared_ptr temp_buf, + arrow::internal::BitmapAnd( + ctx->memory_pool(), left.buffers[0]->data(), left.offset, + cond.buffers[1]->data(), cond.offset, len, 0)); // out_validity = cond.data && left.val || ~cond.data && right.val arrow::internal::BitmapOr(out_validity->data(), 0, temp_buf->data(), 0, len, 0, out_validity->mutable_data()); @@ -82,9 +82,7 @@ struct IfElseFunctor::value>> { ARROW_RETURN_NOT_OK(promote_nulls(ctx, cond, left, right, out)); ARROW_ASSIGN_OR_RAISE(std::shared_ptr out_buf, - arrow::internal::CopyBitmap(ctx->memory_pool(), ) - - ctx->Allocate(cond.length * sizeof(T))); + ctx->Allocate(cond.length * sizeof(T))); T* out_values = reinterpret_cast(out_buf->mutable_data()); // copy right data to out_buff @@ -139,39 +137,20 @@ struct IfElseFunctor::value>> { const ArrayData& right, ArrayData* out) { ARROW_RETURN_NOT_OK(promote_nulls(ctx, cond, left, right, out)); + // out_buff = right & ~cond ARROW_ASSIGN_OR_RAISE(std::shared_ptr out_buf, - ctx->AllocateBitmap(cond.length)); - uint8_t* out_values = out_buf->mutable_data(); - - // copy right data to out_buff - const T* right_data = right.GetValues(1); - std::memcpy(out_values, right_data, right.length * sizeof(T)); - - const auto* cond_data = cond.buffers[1]->data(); // this is a BoolArray - BitBlockCounter bit_counter(cond_data, cond.offset, cond.length); - - // selectively copy values from left data - const T* left_data = left.GetValues(1); - int64_t offset = cond.offset; - - // todo this can be improved by intrinsics. ex: _mm*_mask_store_e* (vmovdqa*) - while (offset < cond.offset + cond.length) { - const BitBlockCount& block = bit_counter.NextWord(); - if (block.AllSet()) { // all from left - std::memcpy(out_values, left_data, block.length * sizeof(T)); - } else if (block.popcount) { // selectively copy from left - for (int64_t i = 0; i < block.length; ++i) { - if (BitUtil::GetBit(cond_data, offset + i)) { - out_values[i] = left_data[i]; - } - } - } - - offset += block.length; - out_values += block.length; - left_data += block.length; - } - + arrow::internal::BitmapAndNot( + ctx->memory_pool(), right.buffers[1]->data(), right.offset, + cond.buffers[1]->data(), cond.offset, cond.length, 0)); + + // out_buff = left & cond + ARROW_ASSIGN_OR_RAISE(std::shared_ptr temp_buf, + arrow::internal::BitmapAnd( + ctx->memory_pool(), left.buffers[1]->data(), left.offset, + cond.buffers[1]->data(), cond.offset, cond.length, 0)); + + arrow::internal::BitmapOr(out_buf->data(), 0, temp_buf->data(), 0, cond.length, 0, + out_buf->mutable_data()); out->buffers[1] = std::move(out_buf); return Status::OK(); } @@ -189,24 +168,6 @@ struct IfElseFunctor::value>> { } }; -template -struct IfElseFunctor::value>> { - static Status Call(KernelContext* ctx, const ArrayData& cond, const ArrayData& left, - const ArrayData& right, ArrayData* out) { - return Status::OK(); - } - - static Status Call(KernelContext* ctx, const ArrayData& cond, const ArrayData& left, - const Scalar& right, ArrayData* out) { - return Status::OK(); - } - - static Status Call(KernelContext* ctx, const Scalar& cond, const Scalar& left, - const Scalar& right, Scalar* out) { - return Status::OK(); - } -}; - template struct IfElseFunctor::value>> { static Status Call(KernelContext* ctx, const ArrayData& cond, const ArrayData& left, @@ -310,6 +271,7 @@ void RegisterScalarIfElse(FunctionRegistry* registry) { AddPrimitiveKernels(func, NumericTypes()); AddPrimitiveKernels(func, TemporalTypes()); + AddPrimitiveKernels(func, {boolean()}); // todo add temporal, boolean, null and binary kernels DCHECK_OK(registry->AddFunction(std::move(func))); diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc index 9b5dac71ff4..f970b99456e 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc @@ -47,7 +47,7 @@ void CheckIfElseOutputArray(const std::shared_ptr& type, CheckIfElseOutputArray(cond_, left_, right_, expected_, all_valid); } -class TestIfElseNullKernel : public ::testing::Test {}; +class TestIfElseKernel : public ::testing::Test {}; template class TestIfElsePrimitive : public ::testing::Test {}; @@ -75,18 +75,16 @@ TYPED_TEST(TestIfElsePrimitive, IfElseFixedSize) { CheckIfElseOutputArray(type, "[true, true, true, false]", "[1, 2, null, null]", "[null, 6, 7, null]", "[1, 2, null, null]", false); - using ArrayType = typename TypeTraits::ArrayType; random::RandomArrayGenerator rand(/*seed=*/0); int64_t len = 1000; auto cond = std::static_pointer_cast( rand.ArrayOf(boolean(), len, /*null_probability=*/0.01)); - auto left = std::static_pointer_cast( + auto left = std::static_pointer_cast( rand.ArrayOf(type, len, /*null_probability=*/0.01)); - auto right = std::static_pointer_cast( + auto right = std::static_pointer_cast( rand.ArrayOf(type, len, /*null_probability=*/0.01)); - typename TypeTraits::BuilderType builder; - + BooleanBuilder builder; for (int64_t i = 0; i < len; ++i) { if (!cond->IsValid(i) || (cond->Value(i) && !left->IsValid(i)) || (!cond->Value(i) && !right->IsValid(i))) { @@ -105,5 +103,53 @@ TYPED_TEST(TestIfElsePrimitive, IfElseFixedSize) { CheckIfElseOutputArray(cond, left, right, expected_data, false); } +TEST_F(TestIfElseKernel, IfElseBoolean) { + // using ScalarType = typename TypeTraits::ScalarType; + // auto scalar = std::make_shared(static_cast(5)); + auto type = boolean(); + // No Nulls + CheckIfElseOutputArray(type, "[]", "[]", "[]", "[]"); + + CheckIfElseOutputArray(type, "[true, true, true, false]", + "[false, false, false, false]", "[true, true, true, true]", + "[false, false, false, true]"); + + CheckIfElseOutputArray(type, "[true, true, null, false]", + "[false, false, false, false]", "[true, true, true, true]", + "[false, false, null, true]", false); + + CheckIfElseOutputArray(type, "[true, true, true, false]", "[true, false, null, null]", + "[null, false, true, null]", "[true, false, null, null]", false); + +// using ArrayType = typename TypeTraits::ArrayType; +// random::RandomArrayGenerator rand(/*seed=*/0); +// int64_t len = 1000; +// auto cond = std::static_pointer_cast( +// rand.ArrayOf(boolean(), len, /*null_probability=*/0.01)); +// auto left = std::static_pointer_cast( +// rand.ArrayOf(type, len, /*null_probability=*/0.01)); +// auto right = std::static_pointer_cast( +// rand.ArrayOf(type, len, /*null_probability=*/0.01)); +// +// typename TypeTraits::BuilderType builder; +// +// for (int64_t i = 0; i < len; ++i) { +// if (!cond->IsValid(i) || (cond->Value(i) && !left->IsValid(i)) || +// (!cond->Value(i) && !right->IsValid(i))) { +// ASSERT_OK(builder.AppendNull()); +// continue; +// } +// +// if (cond->Value(i)) { +// ASSERT_OK(builder.Append(left->Value(i))); +// } else { +// ASSERT_OK(builder.Append(right->Value(i))); +// } +// } +// ASSERT_OK_AND_ASSIGN(auto expected_data, builder.Finish()); +// +// CheckIfElseOutputArray(cond, left, right, expected_data, false); +} + } // namespace compute } // namespace arrow \ No newline at end of file From 93d842adb097e22610eea18dcc48cdae1a84035e Mon Sep 17 00:00:00 2001 From: niranda perera Date: Fri, 21 May 2021 17:36:06 -0400 Subject: [PATCH 4/5] adding null kernel --- .../arrow/compute/kernels/scalar_if_else.cc | 7 +- .../compute/kernels/scalar_if_else_test.cc | 73 ++++++++++--------- 2 files changed, 43 insertions(+), 37 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else.cc b/cpp/src/arrow/compute/kernels/scalar_if_else.cc index fe3e31268ac..e00ee7bb2c6 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else.cc @@ -172,6 +172,8 @@ template struct IfElseFunctor::value>> { static Status Call(KernelContext* ctx, const ArrayData& cond, const ArrayData& left, const ArrayData& right, ArrayData* out) { + // Nothing preallocated, so we assign left into the output + *out = left; return Status::OK(); } @@ -248,6 +250,7 @@ void AddPrimitiveKernels(const std::shared_ptr& scalar_function, const std::vector>& types) { for (auto&& type : types) { auto exec = internal::GenerateTypeAgnosticPrimitive(*type); + // cond array needs to be boolean always ScalarKernel kernel({boolean(), type, type}, type, exec); kernel.null_handling = NullHandling::COMPUTED_NO_PREALLOCATE; kernel.mem_allocation = MemAllocation::NO_PREALLOCATE; @@ -271,8 +274,8 @@ void RegisterScalarIfElse(FunctionRegistry* registry) { AddPrimitiveKernels(func, NumericTypes()); AddPrimitiveKernels(func, TemporalTypes()); - AddPrimitiveKernels(func, {boolean()}); - // todo add temporal, boolean, null and binary kernels + AddPrimitiveKernels(func, {boolean(), null()}); + // todo add binary kernels DCHECK_OK(registry->AddFunction(std::move(func))); } diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc index f970b99456e..f1574bd3d28 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc @@ -60,9 +60,8 @@ using PrimitiveTypes = ::testing::Types::ScalarType; auto type = TypeTraits::type_singleton(); - // auto scalar = std::make_shared(static_cast(5)); + // No Nulls CheckIfElseOutputArray(type, "[]", "[]", "[]", "[]"); @@ -75,16 +74,18 @@ TYPED_TEST(TestIfElsePrimitive, IfElseFixedSize) { CheckIfElseOutputArray(type, "[true, true, true, false]", "[1, 2, null, null]", "[null, 6, 7, null]", "[1, 2, null, null]", false); + using ArrayType = typename TypeTraits::ArrayType; random::RandomArrayGenerator rand(/*seed=*/0); int64_t len = 1000; auto cond = std::static_pointer_cast( rand.ArrayOf(boolean(), len, /*null_probability=*/0.01)); - auto left = std::static_pointer_cast( + auto left = std::static_pointer_cast( rand.ArrayOf(type, len, /*null_probability=*/0.01)); - auto right = std::static_pointer_cast( + auto right = std::static_pointer_cast( rand.ArrayOf(type, len, /*null_probability=*/0.01)); - BooleanBuilder builder; + typename TypeTraits::BuilderType builder; + for (int64_t i = 0; i < len; ++i) { if (!cond->IsValid(i) || (cond->Value(i) && !left->IsValid(i)) || (!cond->Value(i) && !right->IsValid(i))) { @@ -104,8 +105,6 @@ TYPED_TEST(TestIfElsePrimitive, IfElseFixedSize) { } TEST_F(TestIfElseKernel, IfElseBoolean) { - // using ScalarType = typename TypeTraits::ScalarType; - // auto scalar = std::make_shared(static_cast(5)); auto type = boolean(); // No Nulls CheckIfElseOutputArray(type, "[]", "[]", "[]", "[]"); @@ -121,34 +120,38 @@ TEST_F(TestIfElseKernel, IfElseBoolean) { CheckIfElseOutputArray(type, "[true, true, true, false]", "[true, false, null, null]", "[null, false, true, null]", "[true, false, null, null]", false); -// using ArrayType = typename TypeTraits::ArrayType; -// random::RandomArrayGenerator rand(/*seed=*/0); -// int64_t len = 1000; -// auto cond = std::static_pointer_cast( -// rand.ArrayOf(boolean(), len, /*null_probability=*/0.01)); -// auto left = std::static_pointer_cast( -// rand.ArrayOf(type, len, /*null_probability=*/0.01)); -// auto right = std::static_pointer_cast( -// rand.ArrayOf(type, len, /*null_probability=*/0.01)); -// -// typename TypeTraits::BuilderType builder; -// -// for (int64_t i = 0; i < len; ++i) { -// if (!cond->IsValid(i) || (cond->Value(i) && !left->IsValid(i)) || -// (!cond->Value(i) && !right->IsValid(i))) { -// ASSERT_OK(builder.AppendNull()); -// continue; -// } -// -// if (cond->Value(i)) { -// ASSERT_OK(builder.Append(left->Value(i))); -// } else { -// ASSERT_OK(builder.Append(right->Value(i))); -// } -// } -// ASSERT_OK_AND_ASSIGN(auto expected_data, builder.Finish()); -// -// CheckIfElseOutputArray(cond, left, right, expected_data, false); + random::RandomArrayGenerator rand(/*seed=*/0); + int64_t len = 1000; + auto cond = std::static_pointer_cast( + rand.ArrayOf(boolean(), len, /*null_probability=*/0.01)); + auto left = std::static_pointer_cast( + rand.ArrayOf(type, len, /*null_probability=*/0.01)); + auto right = std::static_pointer_cast( + rand.ArrayOf(type, len, /*null_probability=*/0.01)); + + BooleanBuilder builder; + for (int64_t i = 0; i < len; ++i) { + if (!cond->IsValid(i) || (cond->Value(i) && !left->IsValid(i)) || + (!cond->Value(i) && !right->IsValid(i))) { + ASSERT_OK(builder.AppendNull()); + continue; + } + + if (cond->Value(i)) { + ASSERT_OK(builder.Append(left->Value(i))); + } else { + ASSERT_OK(builder.Append(right->Value(i))); + } + } + ASSERT_OK_AND_ASSIGN(auto expected_data, builder.Finish()); + + CheckIfElseOutputArray(cond, left, right, expected_data, false); +} + +TEST_F(TestIfElseKernel, IfElseNull) { + CheckIfElseOutputArray(null(), "[null, null, null, null]", "[null, null, null, null]", + "[null, null, null, null]", "[null, null, null, null]", + /*all_valid=*/false); } } // namespace compute From a1da1593ee791f2044431037be3e89938f38e767 Mon Sep 17 00:00:00 2001 From: niranda perera Date: Mon, 24 May 2021 10:08:18 -0400 Subject: [PATCH 5/5] adding PR comments --- cpp/src/arrow/compute/api_scalar.h | 2 +- .../arrow/compute/kernels/scalar_if_else.cc | 28 +++++++++---------- .../compute/kernels/scalar_if_else_test.cc | 1 - 3 files changed, 15 insertions(+), 16 deletions(-) diff --git a/cpp/src/arrow/compute/api_scalar.h b/cpp/src/arrow/compute/api_scalar.h index 5ff8cc09e83..b0e5dc5325e 100644 --- a/cpp/src/arrow/compute/api_scalar.h +++ b/cpp/src/arrow/compute/api_scalar.h @@ -464,7 +464,7 @@ Result FillNull(const Datum& values, const Datum& fill_value, /// \note API not yet finalized ARROW_EXPORT Result IfElse(const Datum& cond, const Datum& left, const Datum& right, - ExecContext* ctx = NULLPTR); + ExecContext* ctx = NULLPTR); } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else.cc b/cpp/src/arrow/compute/kernels/scalar_if_else.cc index e00ee7bb2c6..2b971cf3236 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else.cc @@ -31,8 +31,8 @@ namespace { // nulls will be promoted as follows // cond.val && (cond.data && left.val || ~cond.data && right.val) -Status promote_nulls(KernelContext* ctx, const ArrayData& cond, const ArrayData& left, - const ArrayData& right, ArrayData* output) { +Status PromoteNulls(KernelContext* ctx, const ArrayData& cond, const ArrayData& left, + const ArrayData& right, ArrayData* output) { if (!cond.MayHaveNulls() && !left.MayHaveNulls() && !right.MayHaveNulls()) { return Status::OK(); // no nulls to handle } @@ -74,12 +74,14 @@ template struct IfElseFunctor {}; template -struct IfElseFunctor::value>> { +struct IfElseFunctor< + Type, swap, + enable_if_t::value | is_temporal_type::value>> { using T = typename TypeTraits::CType; static Status Call(KernelContext* ctx, const ArrayData& cond, const ArrayData& left, const ArrayData& right, ArrayData* out) { - ARROW_RETURN_NOT_OK(promote_nulls(ctx, cond, left, right, out)); + ARROW_RETURN_NOT_OK(PromoteNulls(ctx, cond, left, right, out)); ARROW_ASSIGN_OR_RAISE(std::shared_ptr out_buf, ctx->Allocate(cond.length * sizeof(T))); @@ -132,10 +134,10 @@ struct IfElseFunctor::value>> { }; template -struct IfElseFunctor::value>> { +struct IfElseFunctor> { static Status Call(KernelContext* ctx, const ArrayData& cond, const ArrayData& left, const ArrayData& right, ArrayData* out) { - ARROW_RETURN_NOT_OK(promote_nulls(ctx, cond, left, right, out)); + ARROW_RETURN_NOT_OK(PromoteNulls(ctx, cond, left, right, out)); // out_buff = right & ~cond ARROW_ASSIGN_OR_RAISE(std::shared_ptr out_buf, @@ -169,7 +171,7 @@ struct IfElseFunctor::value>> { }; template -struct IfElseFunctor::value>> { +struct IfElseFunctor> { static Status Call(KernelContext* ctx, const ArrayData& cond, const ArrayData& left, const ArrayData& right, ArrayData* out) { // Nothing preallocated, so we assign left into the output @@ -191,8 +193,6 @@ struct IfElseFunctor::value>> { template struct ResolveExec { static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { - if (batch.length == 0) return Status::OK(); - if (batch[0].kind() == Datum::ARRAY) { if (batch[1].kind() == Datum::ARRAY) { if (batch[2].kind() == Datum::ARRAY) { // AAA @@ -246,8 +246,8 @@ struct ResolveExec { } }; -void AddPrimitiveKernels(const std::shared_ptr& scalar_function, - const std::vector>& types) { +void AddPrimitiveIfElseKernels(const std::shared_ptr& scalar_function, + const std::vector>& types) { for (auto&& type : types) { auto exec = internal::GenerateTypeAgnosticPrimitive(*type); // cond array needs to be boolean always @@ -272,9 +272,9 @@ void RegisterScalarIfElse(FunctionRegistry* registry) { auto func = std::make_shared("if_else", Arity::Ternary(), &if_else_doc); - AddPrimitiveKernels(func, NumericTypes()); - AddPrimitiveKernels(func, TemporalTypes()); - AddPrimitiveKernels(func, {boolean(), null()}); + AddPrimitiveIfElseKernels(func, NumericTypes()); + AddPrimitiveIfElseKernels(func, TemporalTypes()); + AddPrimitiveIfElseKernels(func, {boolean(), null()}); // todo add binary kernels DCHECK_OK(registry->AddFunction(std::move(func))); diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc index f1574bd3d28..cf81ebf9441 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc @@ -56,7 +56,6 @@ using PrimitiveTypes = ::testing::Types; - TYPED_TEST_SUITE(TestIfElsePrimitive, PrimitiveTypes); TYPED_TEST(TestIfElsePrimitive, IfElseFixedSize) {