diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt
index bee14ae4ce3..0d04d967915 100644
--- a/cpp/src/arrow/CMakeLists.txt
+++ b/cpp/src/arrow/CMakeLists.txt
@@ -393,6 +393,7 @@ if(ARROW_COMPUTE)
        compute/kernels/scalar_string.cc
        compute/kernels/scalar_validity.cc
        compute/kernels/scalar_fill_null.cc
+       compute/kernels/scalar_if_else.cc
        compute/kernels/util_internal.cc
        compute/kernels/vector_hash.cc
        compute/kernels/vector_nested.cc
diff --git a/cpp/src/arrow/compute/api_scalar.cc b/cpp/src/arrow/compute/api_scalar.cc
index c7c049af980..b88850d2ce9 100644
--- a/cpp/src/arrow/compute/api_scalar.cc
+++ b/cpp/src/arrow/compute/api_scalar.cc
@@ -156,5 +156,10 @@ Result<Datum> FillNull(const Datum& values, const Datum& fill_value, ExecContext
   return CallFunction("fill_null", {values, fill_value}, ctx);
 }
 
+Result<Datum> IfElse(const Datum& cond, const Datum& left, const Datum& right,
+                     ExecContext* ctx) {
+  return CallFunction("if_else", {cond, left, right}, ctx);
+}
+
 }  // namespace compute
 }  // namespace arrow
diff --git a/cpp/src/arrow/compute/api_scalar.h b/cpp/src/arrow/compute/api_scalar.h
index 3e390df47e7..b0e5dc5325e 100644
--- a/cpp/src/arrow/compute/api_scalar.h
+++ b/cpp/src/arrow/compute/api_scalar.h
@@ -450,5 +450,21 @@ ARROW_EXPORT
 Result<Datum> FillNull(const Datum& values, const Datum& fill_value,
                        ExecContext* ctx = NULLPTR);
 
+/// \brief Choose elements from `left` or `right` depending on `cond`.
+/// A null in `cond`, or in the chosen input, yields a null in the result.
+///
+/// \param[in] cond Boolean condition (`BooleanArray`)
+/// \param[in] left values used where `cond` is true (scalar or Array)
+/// \param[in] right values used where `cond` is false (scalar or Array)
+/// \param[in] ctx the function execution context, optional
+///
+/// \return the resulting datum
+///
+/// \since x.x.x
+/// \note API not yet finalized
+ARROW_EXPORT
+Result<Datum> IfElse(const Datum& cond, const Datum& left, const Datum& right,
+                     ExecContext* ctx = NULLPTR);
+
 }  // namespace compute
 }  // namespace arrow
diff --git a/cpp/src/arrow/compute/kernels/CMakeLists.txt b/cpp/src/arrow/compute/kernels/CMakeLists.txt
index 5e223a1f906..fc11d144105 100644
--- a/cpp/src/arrow/compute/kernels/CMakeLists.txt
+++ b/cpp/src/arrow/compute/kernels/CMakeLists.txt
@@ -29,6 +29,7 @@ add_arrow_compute_test(scalar_test
                        scalar_string_test.cc
                        scalar_validity_test.cc
                        scalar_fill_null_test.cc
+                       scalar_if_else_test.cc
                        test_util.cc)
 
 add_arrow_benchmark(scalar_arithmetic_benchmark PREFIX "arrow-compute")
diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else.cc b/cpp/src/arrow/compute/kernels/scalar_if_else.cc
new file mode 100644
index 00000000000..2b971cf3236
--- /dev/null
+++ b/cpp/src/arrow/compute/kernels/scalar_if_else.cc
@@ -0,0 +1,285 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.
See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <arrow/compute/api.h>
+#include <arrow/util/bit_block_counter.h>
+#include <arrow/util/bitmap_ops.h>
+#include <cstring>
+
+#include "codegen_internal.h"
+
+namespace arrow {
+using internal::BitBlockCount;
+using internal::BitBlockCounter;
+
+namespace compute {
+
+namespace {
+
+// Computes the validity bitmap of the output and stores it in output->buffers[0].
+// Nulls are promoted as follows:
+//   cond.val && (cond.data && left.val || ~cond.data && right.val)
+// An input without a validity bitmap counts as all-valid in that formula.
+Status PromoteNulls(KernelContext* ctx, const ArrayData& cond, const ArrayData& left,
+                    const ArrayData& right, ArrayData* output) {
+  if (!cond.MayHaveNulls() && !left.MayHaveNulls() && !right.MayHaveNulls()) {
+    return Status::OK();  // no nulls to handle
+  }
+  const int64_t len = cond.length;
+
+  ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Buffer> out_validity, ctx->AllocateBitmap(len));
+  if (right.MayHaveNulls()) {
+    // out_validity = right.val && ~cond.data
+    arrow::internal::BitmapAndNot(right.buffers[0]->data(), right.offset,
+                                  cond.buffers[1]->data(), cond.offset, len, 0,
+                                  out_validity->mutable_data());
+  } else {
+    // right is all valid: out_validity = ~cond.data
+    arrow::internal::InvertBitmap(cond.buffers[1]->data(), cond.offset, len,
+                                  out_validity->mutable_data(), 0);
+  }
+
+  if (left.MayHaveNulls()) {
+    // tmp_buf = left.val && cond.data
+    ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Buffer> temp_buf,
+                          arrow::internal::BitmapAnd(
+                              ctx->memory_pool(), left.buffers[0]->data(), left.offset,
+                              cond.buffers[1]->data(), cond.offset, len, 0));
+    // out_validity = cond.data && left.val || ~cond.data && right.val
+    arrow::internal::BitmapOr(out_validity->data(), 0, temp_buf->data(), 0, len, 0,
+                              out_validity->mutable_data());
+  } else {
+    // left is all valid: out_validity = cond.data || ~cond.data && right.val
+    arrow::internal::BitmapOr(out_validity->data(), 0, cond.buffers[1]->data(),
+                              cond.offset, len, 0, out_validity->mutable_data());
+  }
+
+  if (cond.MayHaveNulls()) {
+    // out_validity &= cond.val
+    ::arrow::internal::BitmapAnd(out_validity->data(), 0, cond.buffers[0]->data(),
+                                 cond.offset, len, 0, out_validity->mutable_data());
+  }
+
+  output->buffers[0] = std::move(out_validity);
+  output->GetNullCount();  // update null count
+  return Status::OK();
+}
+
+// Per-type implementation of if_else; specialized below by type category. Each
+// Call overload handles one combination of array/scalar arguments.
+template <typename Type, bool swap = false, typename Enable = void>
+struct IfElseFunctor {};
+
+// Fixed-width numeric and temporal types.
+template <typename Type, bool swap>
+struct IfElseFunctor<
+    Type, swap,
+    enable_if_t<is_number_type<Type>::value || is_temporal_type<Type>::value>> {
+  using T = typename TypeTraits<Type>::CType;
+
+  // cond, left and right are all arrays
+  static Status Call(KernelContext* ctx, const ArrayData& cond, const ArrayData& left,
+                     const ArrayData& right, ArrayData* out) {
+    ARROW_RETURN_NOT_OK(PromoteNulls(ctx, cond, left, right, out));
+
+    ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Buffer> out_buf,
+                          ctx->Allocate(cond.length * sizeof(T)));
+    T* out_values = reinterpret_cast<T*>(out_buf->mutable_data());
+
+    // copy right data to out_buff
+    const T* right_data = right.GetValues<T>(1);
+    std::memcpy(out_values, right_data, right.length * sizeof(T));
+
+    const auto* cond_data = cond.buffers[1]->data();  // this is a BoolArray
+    BitBlockCounter bit_counter(cond_data, cond.offset, cond.length);
+
+    // selectively copy values from left data
+    const T* left_data = left.GetValues<T>(1);
+    int64_t offset = cond.offset;
+
+    // todo this can be improved by intrinsics. ex: _mm*_mask_store_e* (vmovdqa*)
+    while (offset < cond.offset + cond.length) {
+      const BitBlockCount block = bit_counter.NextWord();
+      if (block.AllSet()) {  // all from left
+        std::memcpy(out_values, left_data, block.length * sizeof(T));
+      } else if (block.popcount) {  // selectively copy from left
+        for (int64_t i = 0; i < block.length; ++i) {
+          if (BitUtil::GetBit(cond_data, offset + i)) {
+            out_values[i] = left_data[i];
+          }
+        }
+      }
+
+      offset += block.length;
+      out_values += block.length;
+      left_data += block.length;
+    }
+
+    out->buffers[1] = std::move(out_buf);
+    return Status::OK();
+  }
+
+  static Status Call(KernelContext* ctx, const ArrayData& cond, const ArrayData& left,
+                     const Scalar& right, ArrayData* out) {
+    // TODO: implement; fail loudly rather than return an output with no buffers
+    return Status::NotImplemented("if_else: array cond, array left, scalar right");
+  }
+
+  static Status Call(KernelContext* ctx, const Scalar& cond, const Scalar& left,
+                     const Scalar& right, Scalar* out) {
+    // TODO: implement; fail loudly rather than return an untouched scalar
+    return Status::NotImplemented("if_else: scalar cond, scalar left, scalar right");
+  }
+};
+
+// Boolean type: values are bit-packed, so the selection is done with bitmap ops.
+template <typename Type, bool swap>
+struct IfElseFunctor<Type, swap, enable_if_boolean<Type>> {
+  static Status Call(KernelContext* ctx, const ArrayData& cond, const ArrayData& left,
+                     const ArrayData& right, ArrayData* out) {
+    ARROW_RETURN_NOT_OK(PromoteNulls(ctx, cond, left, right, out));
+
+    // out_buf = right & ~cond
+    ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Buffer> out_buf,
+                          arrow::internal::BitmapAndNot(
+                              ctx->memory_pool(), right.buffers[1]->data(), right.offset,
+                              cond.buffers[1]->data(), cond.offset, cond.length, 0));
+
+    // temp_buf = left & cond
+    ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Buffer> temp_buf,
+                          arrow::internal::BitmapAnd(
+                              ctx->memory_pool(), left.buffers[1]->data(), left.offset,
+                              cond.buffers[1]->data(), cond.offset, cond.length, 0));
+
+    arrow::internal::BitmapOr(out_buf->data(), 0, temp_buf->data(), 0, cond.length, 0,
+                              out_buf->mutable_data());
+    out->buffers[1] = std::move(out_buf);
+    return Status::OK();
+  }
+
+  static Status Call(KernelContext* ctx, const ArrayData& cond, const ArrayData& left,
+                     const Scalar& right, ArrayData* out) {
+    // TODO: implement; fail loudly rather than return an output with no buffers
+    return Status::NotImplemented("if_else: array cond, array left, scalar right");
+  }
+
+  static Status Call(KernelContext* ctx, const Scalar& cond, const Scalar& left,
+                     const Scalar& right, Scalar* out) {
+    // TODO: implement; fail loudly rather than return an untouched scalar
+    return Status::NotImplemented("if_else: scalar cond, scalar left, scalar right");
+  }
+};
+
+// Null type: every value is null, so an input array can serve as the output.
+template <typename Type, bool swap>
+struct IfElseFunctor<Type, swap, enable_if_null<Type>> {
+  static Status Call(KernelContext* ctx, const ArrayData& cond, const ArrayData& left,
+                     const ArrayData& right, ArrayData* out) {
+    // Nothing preallocated, so we assign left into the output
+    *out = left;
+    return Status::OK();
+  }
+
+  static Status Call(KernelContext* ctx, const ArrayData& cond, const ArrayData& left,
+                     const Scalar& right, ArrayData* out) {
+    // Same as the all-array case: left is a null-type array, reuse it as output
+    *out = left;
+    return Status::OK();
+  }
+
+  static Status Call(KernelContext* ctx, const Scalar& cond, const Scalar& left,
+                     const Scalar& right, Scalar* out) {
+    return Status::OK();
+  }
+};
+
+// Dispatches on the array/scalar shape of (cond, left, right); the three letters
+// in the comments below give that shape, e.g. AAS = array, array, scalar.
+template <typename Type>
+struct ResolveExec {
+  static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+    if (batch[0].kind() == Datum::ARRAY) {
+      if (batch[1].kind() == Datum::ARRAY) {
+        if (batch[2].kind() == Datum::ARRAY) {  // AAA
+          return IfElseFunctor<Type>::Call(ctx, *batch[0].array(), *batch[1].array(),
+                                           *batch[2].array(), out->mutable_array());
+        } else {  // AAS
+          return IfElseFunctor<Type>::Call(ctx, *batch[0].array(), *batch[1].array(),
+                                           *batch[2].scalar(), out->mutable_array());
+        }
+      } else {
+        // TODO: implement ASA and ASS
+        return Status::Invalid("if_else: array cond with scalar left is unsupported");
+      }
+    } else {
+      if (batch[1].kind() == Datum::ARRAY) {
+        // TODO: implement SAA and SAS
+        return Status::Invalid("if_else: scalar cond with array left is unsupported");
+      } else {
+        if (batch[2].kind() == Datum::ARRAY) {  // SSA
+          // TODO: implement SSA
+          return Status::Invalid("if_else: scalar cond with array right is unsupported");
+        } else {  // SSS
+          return IfElseFunctor<Type>::Call(ctx, *batch[0].scalar(), *batch[1].scalar(),
+                                           *batch[2].scalar(), out->scalar().get());
+        }
+      }
+    }
+  }
+};
+
+// Adds one if_else kernel (boolean cond, then two inputs of `type`) per given type.
+void AddPrimitiveIfElseKernels(const std::shared_ptr<ScalarFunction>& scalar_function,
+                               const std::vector<std::shared_ptr<DataType>>& types) {
+  for (const auto& type : types) {
+    auto exec = internal::GenerateTypeAgnosticPrimitive<ResolveExec>(*type);
+    // cond array needs to be boolean always
+    ScalarKernel kernel({boolean(), type, type}, type, exec);
+    kernel.null_handling = NullHandling::COMPUTED_NO_PREALLOCATE;
+    kernel.mem_allocation = MemAllocation::NO_PREALLOCATE;
+
+    DCHECK_OK(scalar_function->AddKernel(std::move(kernel)));
+  }
+}
+
+}  // namespace
+
+const FunctionDoc if_else_doc{
+    "Choose values based on a condition",
+    ("`cond` must be a Boolean array; `left` and `right` must be arrays of the\n"
+     "same type. Each output element is taken from `left` where `cond` is true\n"
+     "and from `right` where it is false. The output is null where `cond` is\n"
+     "null or the selected element is null."),
+    {"cond", "left", "right"}};
+
+namespace internal {
+
+// Registers the "if_else" scalar function and its kernels.
+void RegisterScalarIfElse(FunctionRegistry* registry) {
+  auto func = std::make_shared<ScalarFunction>("if_else", Arity::Ternary(), &if_else_doc);
+
+  AddPrimitiveIfElseKernels(func, NumericTypes());
+  AddPrimitiveIfElseKernels(func, TemporalTypes());
+  AddPrimitiveIfElseKernels(func, {boolean(), null()});
+  // todo add binary kernels
+
+  DCHECK_OK(registry->AddFunction(std::move(func)));
+}
+
+}  // namespace internal
+}  // namespace compute
+}  // namespace arrow
\ No newline at end of file
diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc
new file mode 100644
index 00000000000..cf81ebf9441
--- /dev/null
+++ b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc
@@ -0,0 +1,157 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <arrow/array.h>
+#include <arrow/compute/api_scalar.h>
+#include <arrow/compute/kernels/test_util.h>
+#include <arrow/testing/gtest_util.h>
+#include <gtest/gtest.h>
+
+namespace arrow {
+namespace compute {
+
+void CheckIfElseOutputArray(const Datum& cond, const Datum& left, const Datum& right,
+                            const Datum& expected, bool all_valid = true) {
+  ASSERT_OK_AND_ASSIGN(Datum datum_out, IfElse(cond, left, right));
+  std::shared_ptr<Array> result = datum_out.make_array();
+  ASSERT_OK(result->ValidateFull());
+  AssertArraysEqual(*expected.make_array(), *result, /*verbose=*/true);
+  if (all_valid) {
+    // Check null count of ArrayData is set, not the computed Array.null_count
+    ASSERT_EQ(result->data()->null_count, 0);
+  }
+}
+
+void CheckIfElseOutputArray(const std::shared_ptr<DataType>& type,
+                            const std::string& cond, const std::string& left,
+                            const std::string& right, const std::string& expected,
+                            bool all_valid = true) {
+  const std::shared_ptr<Array>& cond_ = ArrayFromJSON(boolean(), cond);
+  const std::shared_ptr<Array>& left_ = ArrayFromJSON(type, left);
+  const std::shared_ptr<Array>& right_ = ArrayFromJSON(type, right);
+  const std::shared_ptr<Array>& expected_ = ArrayFromJSON(type, expected);
+  CheckIfElseOutputArray(cond_, left_, right_, expected_, all_valid);
+}
+
+class TestIfElseKernel : public ::testing::Test {};
+
+template <typename Type>
+class TestIfElsePrimitive : public ::testing::Test {};
+// NOTE(review): this type list was reconstructed; confirm it against the kernel.
+using PrimitiveTypes = ::testing::Types<Int8Type, UInt8Type, Int16Type, UInt16Type,
+                                        Int32Type, UInt32Type, Int64Type, UInt64Type,
+                                        FloatType, DoubleType, Date32Type, Date64Type>;
+
+TYPED_TEST_SUITE(TestIfElsePrimitive, PrimitiveTypes);
+
+TYPED_TEST(TestIfElsePrimitive, IfElseFixedSize) {
+  auto type = TypeTraits<TypeParam>::type_singleton();
+
+  // No Nulls
+  CheckIfElseOutputArray(type, "[]", "[]", "[]", "[]");
+
+  CheckIfElseOutputArray(type, "[true, true, true, false]", "[1, 2, 3, 4]",
+                         "[5, 6, 7, 8]", "[1, 2, 3, 8]");
+
+  CheckIfElseOutputArray(type, "[true, true, null, false]", "[1, 2, 3, 4]",
+                         "[5, 6, 7, 8]", "[1, 2, null, 8]", false);
+
+  CheckIfElseOutputArray(type, "[true, true, true, false]", "[1, 2, null, null]",
+                         "[null, 6, 7, null]", "[1, 2, null, null]", false);
+
+  using ArrayType = typename TypeTraits<TypeParam>::ArrayType;
+  random::RandomArrayGenerator rand(/*seed=*/0);
+  int64_t len = 1000;
+  auto cond = std::static_pointer_cast<BooleanArray>(
+      rand.ArrayOf(boolean(), len, /*null_probability=*/0.01));
+  auto left = std::static_pointer_cast<ArrayType>(
+      rand.ArrayOf(type, len, /*null_probability=*/0.01));
+  auto right = std::static_pointer_cast<ArrayType>(
+      rand.ArrayOf(type, len, /*null_probability=*/0.01));
+
+  typename TypeTraits<TypeParam>::BuilderType builder;
+
+  for (int64_t i = 0; i < len; ++i) {
+    if (!cond->IsValid(i) || (cond->Value(i) && !left->IsValid(i)) ||
+        (!cond->Value(i) && !right->IsValid(i))) {
+      ASSERT_OK(builder.AppendNull());
+      continue;
+    }
+
+    if (cond->Value(i)) {
+      ASSERT_OK(builder.Append(left->Value(i)));
+    } else {
+      ASSERT_OK(builder.Append(right->Value(i)));
+    }
+  }
+  ASSERT_OK_AND_ASSIGN(auto expected_data, builder.Finish());
+
+  CheckIfElseOutputArray(cond, left, right, expected_data, false);
+}
+
+TEST_F(TestIfElseKernel, IfElseBoolean) {
+  auto type = boolean();
+  // No Nulls
+  CheckIfElseOutputArray(type, "[]", "[]", "[]", "[]");
+
+  CheckIfElseOutputArray(type, "[true, true, true, false]",
+                         "[false, false, false, false]", "[true, true, true, true]",
+                         "[false, false, false, true]");
+
+  CheckIfElseOutputArray(type, "[true, true, null, false]",
+                         "[false, false, false, false]", "[true, true, true, true]",
+                         "[false, false, null, true]", false);
+
+  CheckIfElseOutputArray(type, "[true, true, true, false]", "[true, false, null, null]",
+                         "[null, false, true, null]", "[true, false, null, null]", false);
+
+  random::RandomArrayGenerator rand(/*seed=*/0);
+  int64_t len = 1000;
+  auto cond = std::static_pointer_cast<BooleanArray>(
+      rand.ArrayOf(boolean(), len, /*null_probability=*/0.01));
+  auto left = std::static_pointer_cast<BooleanArray>(
+      rand.ArrayOf(type, len, /*null_probability=*/0.01));
+  auto right = std::static_pointer_cast<BooleanArray>(
+      rand.ArrayOf(type, len, /*null_probability=*/0.01));
+
+  BooleanBuilder builder;
+  for (int64_t i = 0; i < len; ++i) {
+    if (!cond->IsValid(i) || (cond->Value(i) && !left->IsValid(i)) ||
+        (!cond->Value(i) && !right->IsValid(i))) {
+      ASSERT_OK(builder.AppendNull());
+      continue;
+    }
+
+    if (cond->Value(i)) {
+      ASSERT_OK(builder.Append(left->Value(i)));
+    } else {
+      ASSERT_OK(builder.Append(right->Value(i)));
+    }
+  }
+  ASSERT_OK_AND_ASSIGN(auto expected_data, builder.Finish());
+
+  CheckIfElseOutputArray(cond, left, right, expected_data, false);
+}
+
+TEST_F(TestIfElseKernel, IfElseNull) {
+  CheckIfElseOutputArray(null(), "[null, null, null, null]", "[null, null, null, null]",
+                         "[null, null, null, null]", "[null, null, null, null]",
+                         /*all_valid=*/false);
+}
+
+}  // namespace compute
+}  // namespace arrow
\ No newline at end of file
diff --git a/cpp/src/arrow/compute/registry.cc b/cpp/src/arrow/compute/registry.cc
index 3a8a3a0eb85..1d713b96e1e 100644
--- a/cpp/src/arrow/compute/registry.cc
+++ b/cpp/src/arrow/compute/registry.cc
@@ -125,6 +125,7 @@ static std::unique_ptr<FunctionRegistry> CreateBuiltInRegistry() {
   RegisterScalarStringAscii(registry.get());
   RegisterScalarValidity(registry.get());
   RegisterScalarFillNull(registry.get());
+  RegisterScalarIfElse(registry.get());
 
   // Vector functions
   RegisterVectorHash(registry.get());
diff --git a/cpp/src/arrow/compute/registry_internal.h b/cpp/src/arrow/compute/registry_internal.h
index e4008cf3f27..f97553af4b1 100644
--- a/cpp/src/arrow/compute/registry_internal.h
+++ b/cpp/src/arrow/compute/registry_internal.h
@@ -34,6 +34,7 @@ void RegisterScalarSetLookup(FunctionRegistry* registry);
 void RegisterScalarStringAscii(FunctionRegistry* registry);
 void RegisterScalarValidity(FunctionRegistry* registry);
 void RegisterScalarFillNull(FunctionRegistry* registry);
+void RegisterScalarIfElse(FunctionRegistry* registry);
 
 // Vector functions
 void RegisterVectorHash(FunctionRegistry* registry);