Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cpp/src/arrow/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,7 @@ if(ARROW_COMPUTE)
compute/kernels/scalar_string.cc
compute/kernels/scalar_validity.cc
compute/kernels/scalar_fill_null.cc
compute/kernels/scalar_if_else.cc
compute/kernels/util_internal.cc
compute/kernels/vector_hash.cc
compute/kernels/vector_nested.cc
Expand Down
5 changes: 5 additions & 0 deletions cpp/src/arrow/compute/api_scalar.cc
Original file line number Diff line number Diff line change
Expand Up @@ -156,5 +156,10 @@ Result<Datum> FillNull(const Datum& values, const Datum& fill_value, ExecContext
return CallFunction("fill_null", {values, fill_value}, ctx);
}

Result<Datum> IfElse(const Datum& cond, const Datum& if_true, const Datum& if_false,
ExecContext* ctx) {
return CallFunction("if_else", {cond, if_true, if_false}, ctx);
}

} // namespace compute
} // namespace arrow
16 changes: 16 additions & 0 deletions cpp/src/arrow/compute/api_scalar.h
Original file line number Diff line number Diff line change
Expand Up @@ -450,5 +450,21 @@ ARROW_EXPORT
Result<Datum> FillNull(const Datum& values, const Datum& fill_value,
ExecContext* ctx = NULLPTR);

/// \brief IfElse returns elements chosen from `left` or `right`
/// depending on `cond`. `Null` values would be promoted to the result
///
/// \param[in] cond `BooleanArray` condition array
/// \param[in] left scalar/ Array
/// \param[in] right scalar/ Array
/// \param[in] ctx the function execution context, optional
///
/// \return the resulting datum
///
/// \since x.x.x
/// \note API not yet finalized
ARROW_EXPORT
Result<Datum> IfElse(const Datum& cond, const Datum& left, const Datum& right,
ExecContext* ctx = NULLPTR);

} // namespace compute
} // namespace arrow
1 change: 1 addition & 0 deletions cpp/src/arrow/compute/kernels/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ add_arrow_compute_test(scalar_test
scalar_string_test.cc
scalar_validity_test.cc
scalar_fill_null_test.cc
scalar_if_else_test.cc
test_util.cc)

add_arrow_benchmark(scalar_arithmetic_benchmark PREFIX "arrow-compute")
Expand Down
285 changes: 285 additions & 0 deletions cpp/src/arrow/compute/kernels/scalar_if_else.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,285 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

#include <arrow/compute/api.h>
#include <arrow/util/bit_block_counter.h>
#include <arrow/util/bitmap_ops.h>

#include "codegen_internal.h"

namespace arrow {
using internal::BitBlockCount;
using internal::BitBlockCounter;

namespace compute {

namespace {

// nulls will be promoted as follows
// cond.val && (cond.data && left.val || ~cond.data && right.val)
Status PromoteNulls(KernelContext* ctx, const ArrayData& cond, const ArrayData& left,
const ArrayData& right, ArrayData* output) {
if (!cond.MayHaveNulls() && !left.MayHaveNulls() && !right.MayHaveNulls()) {
return Status::OK(); // no nulls to handle
}
const int64_t len = cond.length;

ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Buffer> out_validity, ctx->AllocateBitmap(len));
arrow::internal::InvertBitmap(out_validity->data(), 0, len,
out_validity->mutable_data(), 0);
if (right.MayHaveNulls()) {
// out_validity = right.val && ~cond.data
arrow::internal::BitmapAndNot(right.buffers[0]->data(), right.offset,
cond.buffers[1]->data(), cond.offset, len, 0,
out_validity->mutable_data());
}

if (left.MayHaveNulls()) {
// tmp_buf = left.val && cond.data
ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Buffer> temp_buf,
arrow::internal::BitmapAnd(
ctx->memory_pool(), left.buffers[0]->data(), left.offset,
cond.buffers[1]->data(), cond.offset, len, 0));
// out_validity = cond.data && left.val || ~cond.data && right.val
arrow::internal::BitmapOr(out_validity->data(), 0, temp_buf->data(), 0, len, 0,
out_validity->mutable_data());
}

if (cond.MayHaveNulls()) {
// out_validity &= cond.val
::arrow::internal::BitmapAnd(out_validity->data(), 0, cond.buffers[0]->data(),
cond.offset, len, 0, out_validity->mutable_data());
}

output->buffers[0] = std::move(out_validity);
output->GetNullCount(); // update null count
return Status::OK();
}

template <typename Type, bool swap = false, typename Enable = void>
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is swap doing?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My idea is to reuse the impl for the cases like, cond, left: Array, right: Scalar and cond, left: Sca;ar, right: Array. In the second scenario, I can swap left and right and invert the cond without changing the loop mechanism.

struct IfElseFunctor {};

template <typename Type, bool swap>
struct IfElseFunctor<
Type, swap,
enable_if_t<is_number_type<Type>::value | is_temporal_type<Type>::value>> {
using T = typename TypeTraits<Type>::CType;

static Status Call(KernelContext* ctx, const ArrayData& cond, const ArrayData& left,
const ArrayData& right, ArrayData* out) {
ARROW_RETURN_NOT_OK(PromoteNulls(ctx, cond, left, right, out));

ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Buffer> out_buf,
ctx->Allocate(cond.length * sizeof(T)));
T* out_values = reinterpret_cast<T*>(out_buf->mutable_data());

// copy right data to out_buff
const T* right_data = right.GetValues<T>(1);
std::memcpy(out_values, right_data, right.length * sizeof(T));

const auto* cond_data = cond.buffers[1]->data(); // this is a BoolArray
BitBlockCounter bit_counter(cond_data, cond.offset, cond.length);

// selectively copy values from left data
const T* left_data = left.GetValues<T>(1);
int64_t offset = cond.offset;

// todo this can be improved by intrinsics. ex: _mm*_mask_store_e* (vmovdqa*)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would assume memcpy already does this for you.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you mean, load with mask?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, I think I misunderstood the optimization you're thinking about. How would SIMD help here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Say, you first copy right to ouput. Then, cond becomes a mask to store left onto output. For that there are specialized SIMD instructions.
https://software.intel.com/sites/landingpage/IntrinsicsGuide/#cats=Store&text=mask_store&expand=5564,5566

So, we can drop the BitBlock objects, and remove all the loops and memcpy inside the while loop. We'd have to handle the alignment though.

while (offset < cond.offset + cond.length) {
const BitBlockCount& block = bit_counter.NextWord();
if (block.AllSet()) { // all from left
std::memcpy(out_values, left_data, block.length * sizeof(T));
} else if (block.popcount) { // selectively copy from left
for (int64_t i = 0; i < block.length; ++i) {
if (BitUtil::GetBit(cond_data, offset + i)) {
out_values[i] = left_data[i];
}
}
}

offset += block.length;
out_values += block.length;
left_data += block.length;
}

out->buffers[1] = std::move(out_buf);
return Status::OK();
}

static Status Call(KernelContext* ctx, const ArrayData& cond, const ArrayData& left,
const Scalar& right, ArrayData* out) {
// todo impl
return Status::OK();
}

static Status Call(KernelContext* ctx, const Scalar& cond, const Scalar& left,
const Scalar& right, Scalar* out) {
// todo impl
return Status::OK();
}
};

template <typename Type, bool swap>
struct IfElseFunctor<Type, swap, enable_if_boolean<Type>> {
static Status Call(KernelContext* ctx, const ArrayData& cond, const ArrayData& left,
const ArrayData& right, ArrayData* out) {
ARROW_RETURN_NOT_OK(PromoteNulls(ctx, cond, left, right, out));

// out_buff = right & ~cond
ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Buffer> out_buf,
arrow::internal::BitmapAndNot(
ctx->memory_pool(), right.buffers[1]->data(), right.offset,
cond.buffers[1]->data(), cond.offset, cond.length, 0));

// out_buff = left & cond
ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Buffer> temp_buf,
arrow::internal::BitmapAnd(
ctx->memory_pool(), left.buffers[1]->data(), left.offset,
cond.buffers[1]->data(), cond.offset, cond.length, 0));

arrow::internal::BitmapOr(out_buf->data(), 0, temp_buf->data(), 0, cond.length, 0,
out_buf->mutable_data());
out->buffers[1] = std::move(out_buf);
return Status::OK();
}

static Status Call(KernelContext* ctx, const ArrayData& cond, const ArrayData& left,
const Scalar& right, ArrayData* out) {
// todo impl
return Status::OK();
}

static Status Call(KernelContext* ctx, const Scalar& cond, const Scalar& left,
const Scalar& right, Scalar* out) {
// todo impl
return Status::OK();
}
};

template <typename Type, bool swap>
struct IfElseFunctor<Type, swap, enable_if_null<Type>> {
static Status Call(KernelContext* ctx, const ArrayData& cond, const ArrayData& left,
const ArrayData& right, ArrayData* out) {
// Nothing preallocated, so we assign left into the output
*out = left;
return Status::OK();
}

static Status Call(KernelContext* ctx, const ArrayData& cond, const ArrayData& left,
const Scalar& right, ArrayData* out) {
return Status::OK();
}

static Status Call(KernelContext* ctx, const Scalar& cond, const Scalar& left,
const Scalar& right, Scalar* out) {
return Status::OK();
}
};

template <typename Type>
struct ResolveExec {
static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
if (batch[0].kind() == Datum::ARRAY) {
if (batch[1].kind() == Datum::ARRAY) {
if (batch[2].kind() == Datum::ARRAY) { // AAA
return IfElseFunctor<Type>::Call(ctx, *batch[0].array(), *batch[1].array(),
*batch[2].array(), out->mutable_array());
} else { // AAS
return IfElseFunctor<Type>::Call(ctx, *batch[0].array(), *batch[1].array(),
*batch[2].scalar(), out->mutable_array());
}
} else {
return Status::Invalid("");
// if (batch[2].kind() == Datum::ARRAY) { // ASA
// return IfElseFunctor<Type, true>::Call(ctx, *batch[0].array(),
// *batch[2].array(),
// *batch[1].scalar(),
// out->mutable_array());
// } else { // ASS
// return IfElseFunctor<Type>::Call(ctx, *batch[0].array(),
// *batch[1].scalar(),
// *batch[2].scalar(),
// out->mutable_array());
// }
}
} else {
if (batch[1].kind() == Datum::ARRAY) {
return Status::Invalid("");
// if (batch[2].kind() == Datum::ARRAY) { // SAA
// return IfElseFunctor<Type>::Call(ctx, *batch[0].scalar(),
// *batch[1].array(),
// *batch[2].array(),
// out->mutable_array());
// } else { // SAS
// return IfElseFunctor<Type>::Call(ctx, *batch[0].scalar(),
// *batch[1].array(),
// *batch[2].scalar(),
// out->mutable_array());
// }
} else {
if (batch[2].kind() == Datum::ARRAY) { // SSA
return Status::Invalid("");
// return IfElseFunctor<Type>::Call(ctx, *batch[0].scalar(),
// *batch[1].scalar(),
// *batch[2].array(),
// out->mutable_array());
} else { // SSS
return IfElseFunctor<Type>::Call(ctx, *batch[0].scalar(), *batch[1].scalar(),
*batch[2].scalar(), out->scalar().get());
}
}
}
}
};

void AddPrimitiveIfElseKernels(const std::shared_ptr<ScalarFunction>& scalar_function,
const std::vector<std::shared_ptr<DataType>>& types) {
for (auto&& type : types) {
auto exec = internal::GenerateTypeAgnosticPrimitive<ResolveExec>(*type);
// cond array needs to be boolean always
ScalarKernel kernel({boolean(), type, type}, type, exec);
kernel.null_handling = NullHandling::COMPUTED_NO_PREALLOCATE;
kernel.mem_allocation = MemAllocation::NO_PREALLOCATE;

DCHECK_OK(scalar_function->AddKernel(std::move(kernel)));
}
}

} // namespace

const FunctionDoc if_else_doc{"<fill this>", ("`<fill this>"), {"cond", "left", "right"}};

namespace internal {

void RegisterScalarIfElse(FunctionRegistry* registry) {
ScalarKernel scalar_kernel;
scalar_kernel.null_handling = NullHandling::COMPUTED_NO_PREALLOCATE;
scalar_kernel.mem_allocation = MemAllocation::NO_PREALLOCATE;

auto func = std::make_shared<ScalarFunction>("if_else", Arity::Ternary(), &if_else_doc);

AddPrimitiveIfElseKernels(func, NumericTypes());
AddPrimitiveIfElseKernels(func, TemporalTypes());
AddPrimitiveIfElseKernels(func, {boolean(), null()});
// todo add binary kernels

DCHECK_OK(registry->AddFunction(std::move(func)));
}

} // namespace internal
} // namespace compute
} // namespace arrow
Loading