From f6165c8cd6e178a5c82c40d4dfd1afb61a5d814c Mon Sep 17 00:00:00 2001
From: David Li
Date: Wed, 22 Sep 2021 16:23:01 -0400
Subject: [PATCH 1/5] ARROW-13358: [C++] Improve type support in if_else
---
.../arrow/compute/kernels/codegen_internal.cc | 43 +-
.../arrow/compute/kernels/codegen_internal.h | 13 +-
.../compute/kernels/codegen_internal_test.cc | 42 +-
.../arrow/compute/kernels/scalar_compare.cc | 4 +-
.../arrow/compute/kernels/scalar_if_else.cc | 279 ++++++++--
.../compute/kernels/scalar_if_else_test.cc | 488 +++++++++++++++++-
6 files changed, 761 insertions(+), 108 deletions(-)
diff --git a/cpp/src/arrow/compute/kernels/codegen_internal.cc b/cpp/src/arrow/compute/kernels/codegen_internal.cc
index 8a8292c1c89..9077c7e9f0b 100644
--- a/cpp/src/arrow/compute/kernels/codegen_internal.cc
+++ b/cpp/src/arrow/compute/kernels/codegen_internal.cc
@@ -66,29 +66,45 @@ Result FirstType(KernelContext*, const std::vector& desc
return result;
}
+Result LastType(KernelContext*, const std::vector& descrs) {
+ ValueDescr result = descrs.back();
+ result.shape = GetBroadcastShape(descrs);
+ return result;
+}
+
Result ListValuesType(KernelContext*, const std::vector& args) {
const auto& list_type = checked_cast(*args[0].type);
return ValueDescr(list_type.value_type(), GetBroadcastShape(args));
}
void EnsureDictionaryDecoded(std::vector* descrs) {
- for (ValueDescr& descr : *descrs) {
- if (descr.type->id() == Type::DICTIONARY) {
- descr.type = checked_cast(*descr.type).value_type();
+ EnsureDictionaryDecoded(descrs->data(), descrs->size());
+}
+
+void EnsureDictionaryDecoded(ValueDescr* begin, size_t count) {
+ auto* end = begin + count;
+ for (auto it = begin; it != end; it++) {
+ if (it->type->id() == Type::DICTIONARY) {
+ it->type = checked_cast(*it->type).value_type();
}
}
}
void ReplaceNullWithOtherType(std::vector* descrs) {
- DCHECK_EQ(descrs->size(), 2);
+ ReplaceNullWithOtherType(descrs->data(), descrs->size());
+}
+
+void ReplaceNullWithOtherType(ValueDescr* first, size_t count) {
+ DCHECK_EQ(count, 2);
- if (descrs->at(0).type->id() == Type::NA) {
- descrs->at(0).type = descrs->at(1).type;
+ ValueDescr* second = first++;
+ if (first->type->id() == Type::NA) {
+ first->type = second->type;
return;
}
- if (descrs->at(1).type->id() == Type::NA) {
- descrs->at(1).type = descrs->at(0).type;
+ if (second->type->id() == Type::NA) {
+ second->type = first->type;
return;
}
}
@@ -164,14 +180,15 @@ std::shared_ptr CommonNumeric(const ValueDescr* begin, size_t count) {
return int8();
}
-std::shared_ptr CommonTemporal(const std::vector& descrs) {
+std::shared_ptr CommonTemporal(const ValueDescr* begin, size_t count) {
TimeUnit::type finest_unit = TimeUnit::SECOND;
const std::string* timezone = nullptr;
bool saw_date32 = false;
bool saw_date64 = false;
- for (const auto& descr : descrs) {
- auto id = descr.type->id();
+ const ValueDescr* end = begin + count;
+ for (auto it = begin; it != end; it++) {
+ auto id = it->type->id();
// a common timestamp is only possible if all types are timestamp like
switch (id) {
case Type::DATE32:
@@ -183,9 +200,7 @@ std::shared_ptr CommonTemporal(const std::vector& descrs)
saw_date64 = true;
continue;
case Type::TIMESTAMP: {
- const auto& ty = checked_cast(*descr.type);
- // Don't cast to common timezone by default (may not make
- // sense for all kernels)
+ const auto& ty = checked_cast(*it->type);
if (timezone && *timezone != ty.timezone()) return nullptr;
timezone = &ty.timezone();
finest_unit = std::max(finest_unit, ty.unit());
diff --git a/cpp/src/arrow/compute/kernels/codegen_internal.h b/cpp/src/arrow/compute/kernels/codegen_internal.h
index 52e8f326d89..a4e2c3d6d1a 100644
--- a/cpp/src/arrow/compute/kernels/codegen_internal.h
+++ b/cpp/src/arrow/compute/kernels/codegen_internal.h
@@ -296,14 +296,14 @@ struct UnboxScalar> {
template <>
struct UnboxScalar {
- static Decimal128 Unbox(const Scalar& val) {
+ static const Decimal128& Unbox(const Scalar& val) {
return checked_cast(val).value;
}
};
template <>
struct UnboxScalar {
- static Decimal256 Unbox(const Scalar& val) {
+ static const Decimal256& Unbox(const Scalar& val) {
return checked_cast(val).value;
}
};
@@ -397,6 +397,7 @@ static void VisitTwoArrayValuesInline(const ArrayData& arr0, const ArrayData& ar
// Reusable type resolvers
Result FirstType(KernelContext*, const std::vector& descrs);
+Result LastType(KernelContext*, const std::vector& descrs);
Result ListValuesType(KernelContext*, const std::vector& args);
// ----------------------------------------------------------------------
@@ -1279,9 +1280,15 @@ ArrayKernelExec GenerateDecimal(detail::GetTypeId get_id) {
ARROW_EXPORT
void EnsureDictionaryDecoded(std::vector* descrs);
+ARROW_EXPORT
+void EnsureDictionaryDecoded(ValueDescr* begin, size_t count);
+
ARROW_EXPORT
void ReplaceNullWithOtherType(std::vector* descrs);
+ARROW_EXPORT
+void ReplaceNullWithOtherType(ValueDescr* begin, size_t count);
+
ARROW_EXPORT
void ReplaceTypes(const std::shared_ptr&, std::vector* descrs);
@@ -1295,7 +1302,7 @@ ARROW_EXPORT
std::shared_ptr CommonNumeric(const ValueDescr* begin, size_t count);
ARROW_EXPORT
-std::shared_ptr CommonTemporal(const std::vector& descrs);
+std::shared_ptr CommonTemporal(const ValueDescr* begin, size_t count);
ARROW_EXPORT
std::shared_ptr CommonBinary(const std::vector& descrs);
diff --git a/cpp/src/arrow/compute/kernels/codegen_internal_test.cc b/cpp/src/arrow/compute/kernels/codegen_internal_test.cc
index a830d0c7636..d64143dea31 100644
--- a/cpp/src/arrow/compute/kernels/codegen_internal_test.cc
+++ b/cpp/src/arrow/compute/kernels/codegen_internal_test.cc
@@ -130,24 +130,32 @@ TEST(TestDispatchBest, CastDecimalArgs) {
}
TEST(TestDispatchBest, CommonTemporal) {
- AssertTypeEqual(timestamp(TimeUnit::NANO), CommonTemporal({timestamp(TimeUnit::SECOND),
- timestamp(TimeUnit::NANO)}));
+ std::vector args;
+
+ args = {timestamp(TimeUnit::SECOND), timestamp(TimeUnit::NANO)};
+ AssertTypeEqual(timestamp(TimeUnit::NANO), CommonTemporal(args.data(), args.size()));
+ args = {timestamp(TimeUnit::SECOND, "UTC"), timestamp(TimeUnit::NANO, "UTC")};
AssertTypeEqual(timestamp(TimeUnit::NANO, "UTC"),
- CommonTemporal({timestamp(TimeUnit::SECOND, "UTC"),
- timestamp(TimeUnit::NANO, "UTC")}));
- AssertTypeEqual(timestamp(TimeUnit::NANO),
- CommonTemporal({date32(), timestamp(TimeUnit::NANO)}));
- AssertTypeEqual(timestamp(TimeUnit::MILLI),
- CommonTemporal({date64(), timestamp(TimeUnit::SECOND)}));
- AssertTypeEqual(date32(), CommonTemporal({date32(), date32()}));
- AssertTypeEqual(date64(), CommonTemporal({date64(), date64()}));
- AssertTypeEqual(date64(), CommonTemporal({date32(), date64()}));
- ASSERT_EQ(nullptr, CommonTemporal({}));
- ASSERT_EQ(nullptr, CommonTemporal({float64(), int32()}));
- ASSERT_EQ(nullptr, CommonTemporal({timestamp(TimeUnit::SECOND),
- timestamp(TimeUnit::SECOND, "UTC")}));
- ASSERT_EQ(nullptr, CommonTemporal({timestamp(TimeUnit::SECOND, "America/Phoenix"),
- timestamp(TimeUnit::SECOND, "UTC")}));
+ CommonTemporal(args.data(), args.size()));
+ args = {date32(), timestamp(TimeUnit::NANO)};
+ AssertTypeEqual(timestamp(TimeUnit::NANO), CommonTemporal(args.data(), args.size()));
+ args = {date64(), timestamp(TimeUnit::SECOND)};
+ AssertTypeEqual(timestamp(TimeUnit::MILLI), CommonTemporal(args.data(), args.size()));
+ args = {date32(), date32()};
+ AssertTypeEqual(date32(), CommonTemporal(args.data(), args.size()));
+ args = {date64(), date64()};
+ AssertTypeEqual(date64(), CommonTemporal(args.data(), args.size()));
+ args = {date32(), date64()};
+ AssertTypeEqual(date64(), CommonTemporal(args.data(), args.size()));
+ args = {};
+ ASSERT_EQ(nullptr, CommonTemporal(args.data(), args.size()));
+ args = {float64(), int32()};
+ ASSERT_EQ(nullptr, CommonTemporal(args.data(), args.size()));
+ args = {timestamp(TimeUnit::SECOND), timestamp(TimeUnit::SECOND, "UTC")};
+ ASSERT_EQ(nullptr, CommonTemporal(args.data(), args.size()));
+ args = {timestamp(TimeUnit::SECOND, "America/Phoenix"),
+ timestamp(TimeUnit::SECOND, "UTC")};
+ ASSERT_EQ(nullptr, CommonTemporal(args.data(), args.size()));
}
} // namespace internal
diff --git a/cpp/src/arrow/compute/kernels/scalar_compare.cc b/cpp/src/arrow/compute/kernels/scalar_compare.cc
index c42b0ddac24..e42bcb7b25c 100644
--- a/cpp/src/arrow/compute/kernels/scalar_compare.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_compare.cc
@@ -171,7 +171,7 @@ struct CompareFunction : ScalarFunction {
if (auto type = CommonNumeric(*values)) {
ReplaceTypes(type, values);
- } else if (auto type = CommonTemporal(*values)) {
+ } else if (auto type = CommonTemporal(values->data(), values->size())) {
ReplaceTypes(type, values);
} else if (auto type = CommonBinary(*values)) {
ReplaceTypes(type, values);
@@ -195,7 +195,7 @@ struct VarArgsCompareFunction : ScalarFunction {
if (auto type = CommonNumeric(*values)) {
ReplaceTypes(type, values);
- } else if (auto type = CommonTemporal(*values)) {
+ } else if (auto type = CommonTemporal(values->data(), values->size())) {
ReplaceTypes(type, values);
}
diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else.cc b/cpp/src/arrow/compute/kernels/scalar_if_else.cc
index 3b0da0a489a..e9432868ee8 100644
--- a/cpp/src/arrow/compute/kernels/scalar_if_else.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_if_else.cc
@@ -810,12 +810,11 @@ struct IfElseFunctor> {
},
/*BroadcastScalar*/
[&](const Scalar& scalar, ArrayData* out_array) {
- const util::string_view& scalar_data =
- internal::UnboxScalar::Unbox(scalar);
+ const uint8_t* scalar_data = UnboxBinaryScalar(scalar);
uint8_t* start =
out_array->buffers[1]->mutable_data() + out_array->offset * byte_width;
for (int64_t i = 0; i < out_array->length; i++) {
- std::memcpy(start + i * byte_width, scalar_data.data(), scalar_data.size());
+ std::memcpy(start + i * byte_width, scalar_data, byte_width);
}
});
}
@@ -852,14 +851,12 @@ struct IfElseFunctor> {
std::memcpy(out_values, right_data, right.length * byte_width);
// selectively copy values from left data
- const util::string_view& left_data =
- internal::UnboxScalar::Unbox(left);
+ const uint8_t* left_data = UnboxBinaryScalar(left);
RunIfElseLoop(cond, [&](int64_t data_offset, int64_t num_elems) {
- if (left_data.data()) {
+ if (left_data) {
for (int64_t i = 0; i < num_elems; i++) {
- std::memcpy(out_values + (data_offset + i) * byte_width, left_data.data(),
- left_data.size());
+ std::memcpy(out_values + (data_offset + i) * byte_width, left_data, byte_width);
}
}
});
@@ -877,14 +874,13 @@ struct IfElseFunctor> {
const uint8_t* left_data = left.buffers[1]->data() + left.offset * byte_width;
std::memcpy(out_values, left_data, left.length * byte_width);
- const util::string_view& right_data =
- internal::UnboxScalar::Unbox(right);
+ const uint8_t* right_data = UnboxBinaryScalar(right);
RunIfElseLoopInverted(cond, [&](int64_t data_offset, int64_t num_elems) {
- if (right_data.data()) {
+ if (right_data) {
for (int64_t i = 0; i < num_elems; i++) {
- std::memcpy(out_values + (data_offset + i) * byte_width, right_data.data(),
- right_data.size());
+ std::memcpy(out_values + (data_offset + i) * byte_width, right_data,
+ byte_width);
}
}
});
@@ -899,23 +895,19 @@ struct IfElseFunctor> {
auto* out_values = out->buffers[1]->mutable_data() + out->offset * byte_width;
// copy right data to out_buff
- const util::string_view& right_data =
- internal::UnboxScalar::Unbox(right);
- if (right_data.data()) {
+ const uint8_t* right_data = UnboxBinaryScalar(right);
+ if (right_data) {
for (int64_t i = 0; i < cond.length; i++) {
- std::memcpy(out_values + i * byte_width, right_data.data(), right_data.size());
+ std::memcpy(out_values + i * byte_width, right_data, byte_width);
}
}
// selectively copy values from left data
- const util::string_view& left_data =
- internal::UnboxScalar::Unbox(left);
-
+ const uint8_t* left_data = UnboxBinaryScalar(left);
RunIfElseLoop(cond, [&](int64_t data_offset, int64_t num_elems) {
- if (left_data.data()) {
+ if (left_data) {
for (int64_t i = 0; i < num_elems; i++) {
- std::memcpy(out_values + (data_offset + i) * byte_width, left_data.data(),
- left_data.size());
+ std::memcpy(out_values + (data_offset + i) * byte_width, left_data, byte_width);
}
}
});
@@ -923,20 +915,162 @@ struct IfElseFunctor> {
return Status::OK();
}
- static Result GetByteWidth(const DataType& left_type,
- const DataType& right_type) {
- int width = checked_cast(left_type).byte_width();
- if (width == checked_cast(right_type).byte_width()) {
- return width;
+ template
+ static enable_if_t::value, const uint8_t*> UnboxBinaryScalar(
+ const Scalar& scalar) {
+ return reinterpret_cast(
+ internal::UnboxScalar::Unbox(scalar).data());
+ }
+
+ template
+ static enable_if_decimal UnboxBinaryScalar(const Scalar& scalar) {
+ return internal::UnboxScalar::Unbox(scalar).native_endian_bytes();
+ }
+
+ template
+ static enable_if_t::value, Result> GetByteWidth(
+ const DataType& left_type, const DataType& right_type) {
+ const int32_t width =
+ checked_cast(left_type).byte_width();
+ DCHECK_EQ(width, checked_cast(right_type).byte_width());
+ return width;
+ }
+
+ template
+ static enable_if_decimal> GetByteWidth(const DataType& left_type,
+ const DataType& right_type) {
+ const auto& left = checked_cast(left_type);
+ const auto& right = checked_cast(right_type);
+ DCHECK_EQ(left.precision(), right.precision());
+ DCHECK_EQ(left.scale(), right.scale());
+ return left.byte_width();
+ }
+};
+
+// Use builders for dictionaries - slower, but allows us to unify dictionaries
+template
+struct IfElseFunctor<
+ Type, enable_if_t::value || is_dictionary_type::value>> {
+ // A - Array, S - Scalar, X = Array/Scalar
+
+ // SXX
+ static Status Call(KernelContext* ctx, const BooleanScalar& cond, const Datum& left,
+ const Datum& right, Datum* out) {
+ if (left.is_scalar() && right.is_scalar()) {
+ if (cond.is_valid) {
+ *out = cond.value ? left.scalar() : right.scalar();
+ } else {
+ *out = MakeNullScalar(left.type());
+ }
+ return Status::OK();
+ }
+ // either left or right is an array. Output is always an array
+ int64_t out_arr_len = std::max(left.length(), right.length());
+ if (!cond.is_valid) {
+ // cond is null; just create a null array
+ ARROW_ASSIGN_OR_RAISE(*out,
+ MakeArrayOfNull(left.type(), out_arr_len, ctx->memory_pool()))
+ return Status::OK();
+ }
+
+ const auto& valid_data = cond.value ? left : right;
+ if (valid_data.is_array()) {
+ *out = valid_data;
+ } else {
+ // valid data is a scalar that needs to be broadcasted
+ ARROW_ASSIGN_OR_RAISE(*out, MakeArrayFromScalar(*valid_data.scalar(), out_arr_len,
+ ctx->memory_pool()));
+ }
+ return Status::OK();
+ }
+
+ // AAA
+ static Status Call(KernelContext* ctx, const ArrayData& cond, const ArrayData& left,
+ const ArrayData& right, ArrayData* out) {
+ return RunLoop(
+ ctx, cond, out,
+ [&](ArrayBuilder* builder, int64_t i) {
+ return builder->AppendArraySlice(left, i, /*length=*/1);
+ },
+ [&](ArrayBuilder* builder, int64_t i) {
+ return builder->AppendArraySlice(right, i, /*length=*/1);
+ });
+ }
+
+ // ASA
+ static Status Call(KernelContext* ctx, const ArrayData& cond, const Scalar& left,
+ const ArrayData& right, ArrayData* out) {
+ return RunLoop(
+ ctx, cond, out,
+ [&](ArrayBuilder* builder, int64_t i) { return builder->AppendScalar(left); },
+ [&](ArrayBuilder* builder, int64_t i) {
+ return builder->AppendArraySlice(right, i, /*length=*/1);
+ });
+ }
+
+ // AAS
+ static Status Call(KernelContext* ctx, const ArrayData& cond, const ArrayData& left,
+ const Scalar& right, ArrayData* out) {
+ return RunLoop(
+ ctx, cond, out,
+ [&](ArrayBuilder* builder, int64_t i) {
+ return builder->AppendArraySlice(left, i, /*length=*/1);
+ },
+ [&](ArrayBuilder* builder, int64_t i) { return builder->AppendScalar(right); });
+ }
+
+ // ASS
+ static Status Call(KernelContext* ctx, const ArrayData& cond, const Scalar& left,
+ const Scalar& right, ArrayData* out) {
+ return RunLoop(
+ ctx, cond, out,
+ [&](ArrayBuilder* builder, int64_t i) { return builder->AppendScalar(left); },
+ [&](ArrayBuilder* builder, int64_t i) { return builder->AppendScalar(right); });
+ }
+
+ template
+ static Status RunLoop(KernelContext* ctx, const ArrayData& cond, ArrayData* out,
+ HandleLeft&& handle_left, HandleRight&& handle_right) {
+ std::unique_ptr raw_builder;
+ RETURN_NOT_OK(MakeBuilderExactIndex(ctx->memory_pool(), out->type, &raw_builder));
+ RETURN_NOT_OK(raw_builder->Reserve(out->length));
+
+ const auto* cond_data = cond.buffers[1]->data();
+ if (out->buffers[0]) {
+ const uint8_t* out_valid = out->buffers[0]->data();
+ for (int64_t i = 0; i < cond.length; i++) {
+ if (BitUtil::GetBit(out_valid, i)) {
+ if (BitUtil::GetBit(cond_data, cond.offset + i)) {
+ RETURN_NOT_OK(handle_left(raw_builder.get(), i));
+ } else {
+ RETURN_NOT_OK(handle_right(raw_builder.get(), i));
+ }
+ } else {
+ RETURN_NOT_OK(raw_builder->AppendNull());
+ }
+ }
} else {
- return Status::Invalid("FixedSizeBinaryType byte_widths should be equal");
+ for (int64_t i = 0; i < cond.length; i++) {
+ if (BitUtil::GetBit(cond_data, cond.offset + i)) {
+ RETURN_NOT_OK(handle_left(raw_builder.get(), i));
+ } else {
+ RETURN_NOT_OK(handle_right(raw_builder.get(), i));
+ }
+ }
}
+ ARROW_ASSIGN_OR_RAISE(auto out_arr, raw_builder->Finish());
+ *out = std::move(*out_arr->data());
+ return Status::OK();
}
};
template
struct ResolveIfElseExec {
static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ // Check is unconditional because parametric types like timestamp
+ // are templated as integer
+ RETURN_NOT_OK(CheckIdenticalTypes(&batch.values[1], /*count=*/2));
+
// cond is scalar
if (batch[0].is_scalar()) {
const auto& cond = batch[0].scalar_as();
@@ -988,7 +1122,8 @@ struct IfElseFunction : ScalarFunction {
RETURN_NOT_OK(CheckArity(*values));
using arrow::compute::detail::DispatchExactImpl;
- if (auto kernel = DispatchExactImpl(this, *values)) return kernel;
+ // Do not DispatchExact here because it'll let through something like (bool,
+ // timestamp[s], timestamp[s, "UTC"])
// if 0th descriptor is null, replace with bool
if (values->at(0).type->id() == Type::NA) {
@@ -996,15 +1131,26 @@ struct IfElseFunction : ScalarFunction {
}
// if-else 0'th descriptor is bool, so skip it
- std::vector values_copy(values->begin() + 1, values->end());
- internal::EnsureDictionaryDecoded(&values_copy);
- internal::ReplaceNullWithOtherType(&values_copy);
+ ValueDescr* left_arg = &(*values)[1];
+ constexpr size_t num_args = 2;
- if (auto type = internal::CommonNumeric(values_copy)) {
- internal::ReplaceTypes(type, &values_copy);
+ internal::ReplaceNullWithOtherType(left_arg, num_args);
+
+ if (is_dictionary((*values)[1].type->id()) &&
+ is_dictionary((*values)[2].type->id())) {
+ auto kernel = DispatchExactImpl(this, *values);
+ DCHECK(kernel);
+ return kernel;
}
- std::move(values_copy.begin(), values_copy.end(), values->begin() + 1);
+ internal::EnsureDictionaryDecoded(left_arg, num_args);
+
+ if (auto type = internal::CommonNumeric(left_arg, num_args)) {
+ internal::ReplaceTypes(type, left_arg, num_args);
+ }
+ if (auto type = internal::CommonTemporal(left_arg, num_args)) {
+ internal::ReplaceTypes(type, left_arg, num_args);
+ }
if (auto kernel = DispatchExactImpl(this, *values)) return kernel;
@@ -1030,7 +1176,16 @@ void AddPrimitiveIfElseKernels(const std::shared_ptr& scalar_fun
internal::GenerateTypeAgnosticPrimitive(*type);
// cond array needs to be boolean always
- ScalarKernel kernel({boolean(), type, type}, type, exec);
+ std::shared_ptr sig;
+ if (type->id() == Type::TIMESTAMP) {
+ auto unit = checked_cast(*type).unit();
+ sig = KernelSignature::Make(
+ {boolean(), match::TimestampTypeUnit(unit), match::TimestampTypeUnit(unit)},
+ OutputType(LastType));
+ } else {
+ sig = KernelSignature::Make({boolean(), type, type}, type);
+ }
+ ScalarKernel kernel(std::move(sig), exec);
kernel.null_handling = NullHandling::COMPUTED_PREALLOCATE;
kernel.mem_allocation = MemAllocation::PREALLOCATE;
kernel.can_write_into_slices = true;
@@ -1056,14 +1211,12 @@ void AddBinaryIfElseKernels(const std::shared_ptr& scalar_functi
}
}
-void AddFSBinaryIfElseKernel(const std::shared_ptr& scalar_function) {
- // cond array needs to be boolean always
- ScalarKernel kernel(
- {boolean(), InputType(Type::FIXED_SIZE_BINARY), InputType(Type::FIXED_SIZE_BINARY)},
- OutputType([](KernelContext*, const std::vector& descrs) {
- return ValueDescr(descrs[1].type, ValueDescr::ANY);
- }),
- ResolveIfElseExec::Exec);
+template
+void AddFixedWidthIfElseKernel(const std::shared_ptr& scalar_function) {
+ auto type_id = T::type_id;
+ ScalarKernel kernel({boolean(), InputType(type_id), InputType(type_id)},
+ OutputType(LastType),
+ ResolveIfElseExec::Exec);
kernel.null_handling = NullHandling::COMPUTED_PREALLOCATE;
kernel.mem_allocation = MemAllocation::PREALLOCATE;
kernel.can_write_into_slices = true;
@@ -1071,6 +1224,16 @@ void AddFSBinaryIfElseKernel(const std::shared_ptr& scalar_funct
DCHECK_OK(scalar_function->AddKernel(std::move(kernel)));
}
+void AddNestedIfElseKernel(const std::shared_ptr& scalar_function,
+ detail::GetTypeId get_id, ArrayKernelExec exec) {
+ ScalarKernel kernel({boolean(), InputType(get_id.id), InputType(get_id.id)},
+ OutputType(LastType), std::move(exec));
+ kernel.null_handling = NullHandling::COMPUTED_NO_PREALLOCATE;
+ kernel.mem_allocation = MemAllocation::NO_PREALLOCATE;
+ kernel.can_write_into_slices = false;
+ DCHECK_OK(scalar_function->AddKernel(std::move(kernel)));
+}
+
// Helper to copy or broadcast fixed-width values between buffers.
template
struct CopyFixedWidth {};
@@ -1758,7 +1921,7 @@ struct CoalesceFunction : ScalarFunction {
if (auto type = CommonBinary(*values)) {
ReplaceTypes(type, values);
}
- if (auto type = CommonTemporal(*values)) {
+ if (auto type = CommonTemporal(values->data(), values->size())) {
ReplaceTypes(type, values);
}
if (HasDecimal(*values)) {
@@ -2500,12 +2663,6 @@ struct ChooseFunction : ScalarFunction {
}
};
-Result LastType(KernelContext*, const std::vector& descrs) {
- ValueDescr result = descrs.back();
- result.shape = GetBroadcastShape(descrs);
- return result;
-}
-
void AddCaseWhenKernel(const std::shared_ptr& scalar_function,
detail::GetTypeId get_id, ArrayKernelExec exec) {
ScalarKernel kernel(
@@ -2629,7 +2786,23 @@ void RegisterScalarIfElse(FunctionRegistry* registry) {
AddPrimitiveIfElseKernels(func, {boolean()});
AddNullIfElseKernel(func);
AddBinaryIfElseKernels(func, BaseBinaryTypes());
- AddFSBinaryIfElseKernel(func);
+ AddFixedWidthIfElseKernel(func);
+ AddFixedWidthIfElseKernel(func);
+ AddFixedWidthIfElseKernel(func);
+ AddNestedIfElseKernel(func, Type::LIST,
+ ResolveIfElseExec::Exec);
+ AddNestedIfElseKernel(func, Type::LARGE_LIST,
+ ResolveIfElseExec::Exec);
+ AddNestedIfElseKernel(func, Type::FIXED_SIZE_LIST,
+ ResolveIfElseExec::Exec);
+ AddNestedIfElseKernel(func, Type::STRUCT,
+ ResolveIfElseExec::Exec);
+ AddNestedIfElseKernel(func, Type::DENSE_UNION,
+ ResolveIfElseExec::Exec);
+ AddNestedIfElseKernel(func, Type::SPARSE_UNION,
+ ResolveIfElseExec::Exec);
+ AddNestedIfElseKernel(func, Type::DICTIONARY,
+ ResolveIfElseExec::Exec);
DCHECK_OK(registry->AddFunction(std::move(func)));
}
{
diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc
index 907c7b0e638..5e74d4a60e4 100644
--- a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc
@@ -19,6 +19,7 @@
#include "arrow/array.h"
#include "arrow/array/concatenate.h"
#include "arrow/compute/api_scalar.h"
+#include "arrow/compute/cast.h"
#include "arrow/compute/kernels/test_util.h"
#include "arrow/compute/registry.h"
#include "arrow/testing/gtest_util.h"
@@ -103,7 +104,7 @@ void CheckWithDifferentShapes(const std::shared_ptr& cond,
auto len = left->length();
enum { COND_SCALAR = 1, LEFT_SCALAR = 2, RIGHT_SCALAR = 4 };
- for (int mask = 0; mask < (COND_SCALAR | LEFT_SCALAR | RIGHT_SCALAR); ++mask) {
+ for (int mask = 1; mask <= (COND_SCALAR | LEFT_SCALAR | RIGHT_SCALAR); ++mask) {
for (int64_t cond_idx = 0; cond_idx < len; ++cond_idx) {
Datum cond_in, cond_bcast;
std::string trace_cond = "Cond";
@@ -113,6 +114,7 @@ void CheckWithDifferentShapes(const std::shared_ptr& cond,
trace_cond += "@" + std::to_string(cond_idx) + "=" + cond_in.scalar()->ToString();
} else {
cond_in = cond_bcast = cond;
+ trace_cond += "=Array";
}
SCOPED_TRACE(trace_cond);
@@ -122,10 +124,11 @@ void CheckWithDifferentShapes(const std::shared_ptr& cond,
if (mask & LEFT_SCALAR) {
ASSERT_OK_AND_ASSIGN(left_in, left->GetScalar(left_idx).As());
ASSERT_OK_AND_ASSIGN(left_bcast, MakeArrayFromScalar(*left_in.scalar(), len));
- trace_cond +=
+ trace_left +=
"@" + std::to_string(left_idx) + "=" + left_in.scalar()->ToString();
} else {
left_in = left_bcast = left;
+ trace_left += "=Array";
}
SCOPED_TRACE(trace_left);
@@ -140,12 +143,27 @@ void CheckWithDifferentShapes(const std::shared_ptr& cond,
"@" + std::to_string(right_idx) + "=" + right_in.scalar()->ToString();
} else {
right_in = right_bcast = right;
+ trace_right += "=Array";
}
SCOPED_TRACE(trace_right);
- ASSERT_OK_AND_ASSIGN(auto exp, IfElse(cond_bcast, left_bcast, right_bcast));
+ Datum expected;
ASSERT_OK_AND_ASSIGN(auto actual, IfElse(cond_in, left_in, right_in));
- AssertDatumsEqual(exp, actual, /*verbose=*/true);
+ if (mask & COND_SCALAR && mask & LEFT_SCALAR && mask & RIGHT_SCALAR) {
+ const auto& scalar = cond_in.scalar_as();
+ if (scalar.is_valid) {
+ expected = scalar.value ? left_in : right_in;
+ } else {
+ expected = MakeNullScalar(left_in.type());
+ }
+ if (!left_in.type()->Equals(*right_in.type())) {
+ ASSERT_OK_AND_ASSIGN(expected,
+ Cast(expected, CastOptions::Safe(actual.type())));
+ }
+ } else {
+ ASSERT_OK_AND_ASSIGN(expected, IfElse(cond_bcast, left_bcast, right_bcast));
+ }
+ AssertDatumsEqual(actual, expected, /*verbose=*/true);
if (right_in.is_array()) break;
}
@@ -288,8 +306,43 @@ TEST_F(TestIfElseKernel, IfElseMultiType) {
ArrayFromJSON(float32(), "[1, 2, 3, 8]"));
}
+TEST_F(TestIfElseKernel, TimestampTypes) {
+ for (const auto unit : TimeUnit::values()) {
+ auto ty = timestamp(unit);
+ CheckWithDifferentShapes(ArrayFromJSON(boolean(), "[true, true, true, false]"),
+ ArrayFromJSON(ty, "[1, 2, 3, 4]"),
+ ArrayFromJSON(ty, "[5, 6, 7, 8]"),
+ ArrayFromJSON(ty, "[1, 2, 3, 8]"));
+
+ ty = timestamp(unit, "America/Phoenix");
+ CheckWithDifferentShapes(ArrayFromJSON(boolean(), "[true, true, true, false]"),
+ ArrayFromJSON(ty, "[1, 2, 3, 4]"),
+ ArrayFromJSON(ty, "[5, 6, 7, 8]"),
+ ArrayFromJSON(ty, "[1, 2, 3, 8]"));
+ }
+}
+
+TEST_F(TestIfElseKernel, TemporalTypes) {
+ for (const auto& ty : TemporalTypes()) {
+ CheckWithDifferentShapes(ArrayFromJSON(boolean(), "[true, true, true, false]"),
+ ArrayFromJSON(ty, "[1, 2, 3, 4]"),
+ ArrayFromJSON(ty, "[5, 6, 7, 8]"),
+ ArrayFromJSON(ty, "[1, 2, 3, 8]"));
+ }
+}
+
+TEST_F(TestIfElseKernel, DayTimeInterval) {
+ auto ty = day_time_interval();
+ CheckWithDifferentShapes(
+ ArrayFromJSON(boolean(), "[true, true, true, false]"),
+ ArrayFromJSON(ty, "[[1, 2], [3, -4], [-5, 6], [-7, -8]]"),
+ ArrayFromJSON(ty, "[[-9, -10], [11, -12], [-13, 14], [15, 16]]"),
+ ArrayFromJSON(ty, "[[1, 2], [3, -4], [-5, 6], [15, 16]]"));
+}
+
TEST_F(TestIfElseKernel, IfElseDispatchBest) {
std::string name = "if_else";
+ ASSERT_OK_AND_ASSIGN(auto function, GetFunctionRegistry()->GetFunction(name));
CheckDispatchBest(name, {boolean(), int32(), int32()}, {boolean(), int32(), int32()});
CheckDispatchBest(name, {boolean(), int32(), null()}, {boolean(), int32(), int32()});
CheckDispatchBest(name, {boolean(), null(), int32()}, {boolean(), int32(), int32()});
@@ -316,6 +369,16 @@ TEST_F(TestIfElseKernel, IfElseDispatchBest) {
{boolean(), float64(), float64()});
CheckDispatchBest(name, {null(), uint8(), int8()}, {boolean(), int16(), int16()});
+
+ CheckDispatchBest(name,
+ {boolean(), timestamp(TimeUnit::SECOND), timestamp(TimeUnit::MILLI)},
+ {boolean(), timestamp(TimeUnit::MILLI), timestamp(TimeUnit::MILLI)});
+ CheckDispatchBest(name, {boolean(), date32(), timestamp(TimeUnit::MILLI)},
+ {boolean(), timestamp(TimeUnit::MILLI), timestamp(TimeUnit::MILLI)});
+ CheckDispatchBest(name, {boolean(), date32(), date64()},
+ {boolean(), date64(), date64()});
+ CheckDispatchBest(name, {boolean(), date32(), date32()},
+ {boolean(), date32(), date32()});
}
template
@@ -412,7 +475,7 @@ TYPED_TEST(TestIfElseBaseBinary, IfElseBaseBinaryRand) {
}
TEST_F(TestIfElseKernel, IfElseFSBinary) {
- auto type = std::make_shared(4);
+ auto type = fixed_size_binary(4);
CheckWithDifferentShapes(ArrayFromJSON(boolean(), "[true, true, true, false]"),
ArrayFromJSON(type, R"(["aaaa", "abab", "abca", "abcd"])"),
@@ -453,18 +516,10 @@ TEST_F(TestIfElseKernel, IfElseFSBinary) {
ArrayFromJSON(type, R"(["aaaa", "abab", "abca", "abcd"])"),
ArrayFromJSON(type, R"(["lmno", "lmnl", "lmlm", "llll"])"),
ArrayFromJSON(type, R"([null, "abab", "abca", "llll"])"));
-
- // should fails for non-equal byte_widths
- auto type1 = std::make_shared(5);
- EXPECT_RAISES_WITH_MESSAGE_THAT(
- Invalid, ::testing::HasSubstr("FixedSizeBinaryType byte_widths should be equal"),
- CallFunction("if_else", {ArrayFromJSON(boolean(), "[true]"),
- ArrayFromJSON(type, R"(["aaaa"])"),
- ArrayFromJSON(type1, R"(["aaaaa"])")}));
}
TEST_F(TestIfElseKernel, IfElseFSBinaryRand) {
- auto type = std::make_shared(5);
+ auto type = fixed_size_binary(5);
random::RandomArrayGenerator rand(/*seed=*/0);
int64_t len = 1000;
@@ -504,6 +559,406 @@ TEST_F(TestIfElseKernel, IfElseFSBinaryRand) {
CheckIfElseOutput(cond, left, right, expected_data);
}
+TEST_F(TestIfElseKernel, Decimal) {
+ for (const auto& ty : {decimal128(3, 2), decimal256(3, 2)}) {
+ CheckWithDifferentShapes(ArrayFromJSON(boolean(), "[true, true, true, false]"),
+ ArrayFromJSON(ty, R"(["1.23", "2.34", "-1.23", "3.45"])"),
+ ArrayFromJSON(ty, R"(["1.34", "-2.34", "0.00", "-4.56"])"),
+ ArrayFromJSON(ty, R"(["1.23", "2.34", "-1.23", "-4.56"])"));
+
+ CheckWithDifferentShapes(ArrayFromJSON(boolean(), R"([true, true, true, false])"),
+ ArrayFromJSON(ty, R"(["1.23", "2.34", "-1.23", "3.45"])"),
+ ArrayFromJSON(ty, R"(["1.34", "-2.34", "0.00", null])"),
+ ArrayFromJSON(ty, R"(["1.23", "2.34", "-1.23", null])"));
+
+ CheckWithDifferentShapes(ArrayFromJSON(boolean(), R"([true, true, true, false])"),
+ ArrayFromJSON(ty, R"(["1.23", "2.34", null, "3.45"])"),
+ ArrayFromJSON(ty, R"(["1.34", "-2.34", "0.00", null])"),
+ ArrayFromJSON(ty, R"(["1.23", "2.34", null, null])"));
+
+ CheckWithDifferentShapes(ArrayFromJSON(boolean(), R"([true, true, true, false])"),
+ ArrayFromJSON(ty, R"(["1.23", "2.34", null, "3.45"])"),
+ ArrayFromJSON(ty, R"(["1.34", "-2.34", "0.00", "-4.56"])"),
+ ArrayFromJSON(ty, R"(["1.23", "2.34", null, "-4.56"])"));
+
+ CheckWithDifferentShapes(ArrayFromJSON(boolean(), R"([null, true, true, false])"),
+ ArrayFromJSON(ty, R"(["1.23", "2.34", null, "3.45"])"),
+ ArrayFromJSON(ty, R"(["1.34", "-2.34", "0.00", "-4.56"])"),
+ ArrayFromJSON(ty, R"([null, "2.34", null, "-4.56"])"));
+
+ CheckWithDifferentShapes(ArrayFromJSON(boolean(), R"([null, true, true, false])"),
+ ArrayFromJSON(ty, R"(["1.23", "2.34", null, "3.45"])"),
+ ArrayFromJSON(ty, R"(["1.34", "-2.34", "0.00", null])"),
+ ArrayFromJSON(ty, R"([null, "2.34", null, null])"));
+
+ CheckWithDifferentShapes(ArrayFromJSON(boolean(), R"([null, true, true, false])"),
+ ArrayFromJSON(ty, R"(["1.23", "2.34", "-1.23", "3.45"])"),
+ ArrayFromJSON(ty, R"(["1.34", "-2.34", "0.00", null])"),
+ ArrayFromJSON(ty, R"([null, "2.34", "-1.23", null])"));
+
+ CheckWithDifferentShapes(ArrayFromJSON(boolean(), R"([null, true, true, false])"),
+ ArrayFromJSON(ty, R"(["1.23", "2.34", "-1.23", "3.45"])"),
+ ArrayFromJSON(ty, R"(["1.34", "-2.34", "0.00", "-4.56"])"),
+ ArrayFromJSON(ty, R"([null, "2.34", "-1.23", "-4.56"])"));
+ }
+}
+
+template
+class TestIfElseList : public ::testing::Test {};
+
+TYPED_TEST_SUITE(TestIfElseList, ListArrowTypes);
+
+TYPED_TEST(TestIfElseList, ListOfInt) {
+ auto type = std::make_shared(int32());
+ CheckWithDifferentShapes(ArrayFromJSON(boolean(), "[true, true, false, false]"),
+ ArrayFromJSON(type, "[[], null, [1, null], [2, 3]]"),
+ ArrayFromJSON(type, "[[4, 5, 6], [7], [null], null]"),
+ ArrayFromJSON(type, "[[], null, [null], null]"));
+
+ CheckWithDifferentShapes(ArrayFromJSON(boolean(), "[null, null, null, null]"),
+ ArrayFromJSON(type, "[[], [2, 3, 4, 5], null, null]"),
+ ArrayFromJSON(type, "[[4, 5, 6], null, [null], null]"),
+ ArrayFromJSON(type, "[null, null, null, null]"));
+}
+
+TYPED_TEST(TestIfElseList, ListOfString) {
+ auto type = std::make_shared(utf8());
+ CheckWithDifferentShapes(
+ ArrayFromJSON(boolean(), "[true, true, false, false]"),
+ ArrayFromJSON(type, R"([[], null, ["xyz", null], ["ab", "c"]])"),
+ ArrayFromJSON(type, R"([["hi", "jk", "l"], ["defg"], [null], null])"),
+ ArrayFromJSON(type, R"([[], null, [null], null])"));
+
+ CheckWithDifferentShapes(
+ ArrayFromJSON(boolean(), "[null, null, null, null]"),
+ ArrayFromJSON(type, R"([[], ["b", "cd", "efg", "h"], null, null])"),
+ ArrayFromJSON(type, R"([["hi", "jk", "l"], null, [null], null])"),
+ ArrayFromJSON(type, R"([null, null, null, null])"));
+}
+
+TEST_F(TestIfElseKernel, FixedSizeList) {
+ auto type = fixed_size_list(int32(), 2);
+ CheckWithDifferentShapes(ArrayFromJSON(boolean(), "[true, true, false, false]"),
+ ArrayFromJSON(type, "[[1, 2], null, [1, null], [2, 3]]"),
+ ArrayFromJSON(type, "[[4, 5], [6, 7], [null, 8], null]"),
+ ArrayFromJSON(type, "[[1, 2], null, [null, 8], null]"));
+
+ CheckWithDifferentShapes(ArrayFromJSON(boolean(), "[null, null, null, null]"),
+ ArrayFromJSON(type, "[[2, 3], [4, 5], null, null]"),
+ ArrayFromJSON(type, "[[4, 5], null, [6, null], null]"),
+ ArrayFromJSON(type, "[null, null, null, null]"));
+}
+
+TEST_F(TestIfElseKernel, StructPrimitive) {
+ auto type = struct_({field("int", uint16()), field("str", utf8())});
+ CheckWithDifferentShapes(
+ ArrayFromJSON(boolean(), "[true, true, false, false]"),
+ ArrayFromJSON(type, R"([[null, "foo"], null, [1, null], [2, "spam"]])"),
+ ArrayFromJSON(type, R"([[1, "a"], [42, ""], [24, null], null])"),
+ ArrayFromJSON(type, R"([[null, "foo"], null, [24, null], null])"));
+
+ CheckWithDifferentShapes(
+ ArrayFromJSON(boolean(), "[null, null, null, null]"),
+ ArrayFromJSON(type, R"([[null, "foo"], [4, "abcd"], null, null])"),
+ ArrayFromJSON(type, R"([[1, "a"], null, [24, null], null])"),
+ ArrayFromJSON(type, R"([null, null, null, null])"));
+}
+
+TEST_F(TestIfElseKernel, StructNested) {
+ auto type = struct_({field("date", date32()), field("list", list(int32()))});
+ CheckWithDifferentShapes(
+ ArrayFromJSON(boolean(), "[true, true, false, false]"),
+ ArrayFromJSON(type, R"([[-1, [null]], null, [1, null], [2, [3, 4]]])"),
+ ArrayFromJSON(type, R"([[4, [5]], [6, [7, 8]], [null, [1, null, 42]], null])"),
+ ArrayFromJSON(type, R"([[-1, [null]], null, [null, [1, null, 42]], null])"));
+
+ CheckWithDifferentShapes(
+ ArrayFromJSON(boolean(), "[null, null, null, null]"),
+ ArrayFromJSON(type, R"([[-1, [null]], [4, [5, 6]], null, null])"),
+ ArrayFromJSON(type, R"([[4, [5]], null, [null, [1, null, 42]], null])"),
+ ArrayFromJSON(type, R"([null, null, null, null])"));
+}
+
+TEST_F(TestIfElseKernel, ParameterizedTypes) {
+ auto cond = ArrayFromJSON(boolean(), "[true]");
+
+ auto type0 = fixed_size_binary(4);
+ auto type1 = fixed_size_binary(5);
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ TypeError,
+ ::testing::HasSubstr("All types must be compatible, expected: "
+ "fixed_size_binary[4], but got: fixed_size_binary[5]"),
+ CallFunction("if_else", {cond, ArrayFromJSON(type0, R"(["aaaa"])"),
+ ArrayFromJSON(type1, R"(["aaaaa"])")}));
+
+ type0 = decimal128(3, 2);
+ type1 = decimal128(4, 2);
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ TypeError,
+ ::testing::HasSubstr("All types must be compatible, expected: decimal128(3, 2), "
+ "but got: decimal128(4, 2)"),
+ CallFunction("if_else", {cond, ArrayFromJSON(type0, R"(["1.23"])"),
+ ArrayFromJSON(type1, R"(["1.23"])")}));
+
+ type1 = decimal128(3, 4);
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ TypeError,
+ ::testing::HasSubstr("All types must be compatible, expected: decimal128(3, 2), "
+ "but got: decimal128(3, 4)"),
+ CallFunction("if_else", {cond, ArrayFromJSON(type0, R"(["1.23"])"),
+ ArrayFromJSON(type1, R"(["1.2345"])")}));
+
+ // TODO(ARROW-14105): in principle many of these could be implicitly castable too
+
+ type0 = struct_({field("a", int32())});
+ type1 = struct_({field("a", int64())});
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ TypeError,
+ ::testing::HasSubstr("All types must be compatible, expected: struct, "
+ "but got: struct"),
+ CallFunction("if_else",
+ {cond, ArrayFromJSON(type0, "[[0]]"), ArrayFromJSON(type1, "[[0]]")}));
+
+ type0 = dense_union({field("a", int32())});
+ type1 = dense_union({field("a", int64())});
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ TypeError,
+ ::testing::HasSubstr("All types must be compatible, expected: dense_union, but got: dense_union"),
+ CallFunction("if_else", {cond, ArrayFromJSON(type0, "[[0, -1]]"),
+ ArrayFromJSON(type1, "[[0, -1]]")}));
+
+ type0 = list(int16());
+ type1 = list(int32());
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ TypeError,
+ ::testing::HasSubstr("All types must be compatible, expected: list, "
+ "but got: list"),
+ CallFunction("if_else",
+ {cond, ArrayFromJSON(type0, "[[0]]"), ArrayFromJSON(type1, "[[0]]")}));
+
+ type0 = dictionary(int16(), utf8());
+ type1 = dictionary(int32(), utf8());
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ TypeError,
+ ::testing::HasSubstr("All types must be compatible, expected: "
+ "dictionary, but "
+ "got: dictionary"),
+ CallFunction("if_else", {cond, DictArrayFromJSON(type0, "[0]", R"(["a"])"),
+ DictArrayFromJSON(type1, "[0]", R"(["a"])")}));
+
+ type0 = timestamp(TimeUnit::SECOND);
+ type1 = timestamp(TimeUnit::SECOND, "America/Phoenix");
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ TypeError,
+ ::testing::HasSubstr("All types must be compatible, expected: timestamp[s], "
+ "but got: timestamp[s, tz=America/Phoenix]"),
+ CallFunction("if_else",
+ {cond, ArrayFromJSON(type0, "[0]"), ArrayFromJSON(type1, "[1]")}));
+
+ type0 = timestamp(TimeUnit::SECOND, "America/New_York");
+ type1 = timestamp(TimeUnit::SECOND, "America/Phoenix");
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ TypeError,
+ ::testing::HasSubstr(
+ "All types must be compatible, expected: timestamp[s, tz=America/New_York], "
+ "but got: timestamp[s, tz=America/Phoenix]"),
+ CallFunction("if_else",
+ {cond, ArrayFromJSON(type0, "[0]"), ArrayFromJSON(type1, "[1]")}));
+
+ type0 = timestamp(TimeUnit::MILLI, "America/New_York");
+ type1 = timestamp(TimeUnit::SECOND, "America/Phoenix");
+ // Casting fails so we never get to the kernel in the first place (since the units don't
+ // match)
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ NotImplemented,
+ ::testing::HasSubstr("Function if_else has no kernel matching input types "
+ "(array[bool], array[timestamp[ms, tz=America/New_York]], "
+ "array[timestamp[s, tz=America/Phoenix]]"),
+ CallFunction("if_else",
+ {cond, ArrayFromJSON(type0, "[0]"), ArrayFromJSON(type1, "[1]")}));
+}
+
+template
+class TestIfElseUnion : public ::testing::Test {};
+
+TYPED_TEST_SUITE(TestIfElseUnion, UnionArrowTypes);
+
+TYPED_TEST(TestIfElseUnion, UnionPrimitive) {
+ std::vector> fields = {field("int", uint16()),
+ field("str", utf8())};
+ std::vector codes = {2, 7};
+ auto type = std::make_shared(fields, codes);
+ CheckWithDifferentShapes(
+ ArrayFromJSON(boolean(), "[true, true, false, false]"),
+ ArrayFromJSON(type, R"([[7, "foo"], [7, null], [7, null], [7, "spam"]])"),
+ ArrayFromJSON(type, R"([[2, 15], [2, null], [2, 42], [2, null]])"),
+ ArrayFromJSON(type, R"([[7, "foo"], [7, null], [2, 42], [2, null]])"));
+
+ CheckWithDifferentShapes(
+ ArrayFromJSON(boolean(), "[null, null, null, null]"),
+ ArrayFromJSON(type, R"([[7, "foo"], [7, null], [7, null], [7, "spam"]])"),
+ ArrayFromJSON(type, R"([[2, 15], [2, null], [2, 42], [2, null]])"),
+ ArrayFromJSON(type, R"([null, null, null, null])"));
+}
+
+TYPED_TEST(TestIfElseUnion, UnionNested) {
+ std::vector> fields = {field("int", uint16()),
+ field("list", list(int16()))};
+ std::vector codes = {2, 7};
+ auto type = std::make_shared(fields, codes);
+ CheckWithDifferentShapes(
+ ArrayFromJSON(boolean(), "[true, true, false, false]"),
+ ArrayFromJSON(type, R"([[7, [1, 2]], [7, null], [7, []], [7, [3]]])"),
+ ArrayFromJSON(type, R"([[2, 15], [2, null], [2, 42], [2, null]])"),
+ ArrayFromJSON(type, R"([[7, [1, 2]], [7, null], [2, 42], [2, null]])"));
+
+ CheckWithDifferentShapes(
+ ArrayFromJSON(boolean(), "[null, null, null, null]"),
+ ArrayFromJSON(type, R"([[7, [1, 2]], [7, null], [7, []], [7, [3]]])"),
+ ArrayFromJSON(type, R"([[2, 15], [2, null], [2, 42], [2, null]])"),
+ ArrayFromJSON(type, R"([null, null, null, null])"));
+}
+
+template
+class TestIfElseDict : public ::testing::Test {};
+
+TYPED_TEST_SUITE(TestIfElseDict, IntegralArrowTypes);
+
+struct JsonDict {
+ std::shared_ptr type;
+ std::string value;
+};
+
+TYPED_TEST(TestIfElseDict, Simple) {
+ auto cond = ArrayFromJSON(boolean(), "[true, false, true, null]");
+ for (const auto& dict :
+ {JsonDict{utf8(), R"(["a", null, "bc", "def"])"},
+ JsonDict{int64(), "[1, null, 2, 3]"},
+ JsonDict{decimal256(3, 2), R"(["1.23", null, "3.45", "6.78"])"}}) {
+ auto type = dictionary(default_type_instance(), dict.type);
+ auto values_null = DictArrayFromJSON(type, "[null, null, null, null]", dict.value);
+ auto values1 = DictArrayFromJSON(type, "[0, null, 3, 1]", dict.value);
+ auto values2 = DictArrayFromJSON(type, "[2, 1, null, 0]", dict.value);
+ auto scalar = DictScalarFromJSON(type, "3", dict.value);
+
+ // Easy case: all arguments have the same dictionary
+ CheckDictionary("if_else", {cond, values1, values2});
+ CheckDictionary("if_else", {cond, values1, scalar});
+ CheckDictionary("if_else", {cond, scalar, values2});
+ CheckDictionary("if_else", {cond, values_null, values2});
+ CheckDictionary("if_else", {cond, values1, values_null});
+ CheckDictionary("if_else", {Datum(true), values1, values2});
+ CheckDictionary("if_else", {Datum(false), values1, values2});
+ CheckDictionary("if_else", {Datum(true), scalar, values2});
+ CheckDictionary("if_else", {Datum(true), values1, scalar});
+ CheckDictionary("if_else", {Datum(false), values1, scalar});
+ CheckDictionary("if_else", {Datum(false), scalar, values2});
+ CheckDictionary("if_else", {MakeNullScalar(boolean()), values1, values2});
+ }
+}
+
+TYPED_TEST(TestIfElseDict, Mixed) {
+ auto type = dictionary(default_type_instance(), utf8());
+ auto cond = ArrayFromJSON(boolean(), "[true, false, true, null]");
+ auto dict = R"(["a", null, "bc", "def"])";
+ auto values_null = DictArrayFromJSON(type, "[null, null, null, null]", dict);
+ auto values1_dict = DictArrayFromJSON(type, "[0, null, 3, 1]", dict);
+ auto values1_decoded = ArrayFromJSON(utf8(), R"(["a", null, "def", null])");
+ auto values2_dict = DictArrayFromJSON(type, "[2, 1, null, 0]", dict);
+ auto values2_decoded = ArrayFromJSON(utf8(), R"(["bc", null, null, "a"])");
+ auto scalar = ScalarFromJSON(utf8(), R"("bc")");
+
+ // If we have mixed dictionary/non-dictionary arguments, we decode dictionaries
+ CheckDictionary("if_else", {cond, values1_dict, values2_decoded},
+ /*result_is_encoded=*/false);
+ CheckDictionary("if_else", {cond, values1_dict, scalar}, /*result_is_encoded=*/false);
+ CheckDictionary("if_else", {cond, scalar, values2_dict}, /*result_is_encoded=*/false);
+ CheckDictionary("if_else", {cond, values_null, values2_decoded},
+ /*result_is_encoded=*/false);
+ CheckDictionary("if_else", {cond, values1_decoded, values_null},
+ /*result_is_encoded=*/false);
+ CheckDictionary("if_else", {Datum(true), values1_decoded, values2_dict},
+ /*result_is_encoded=*/false);
+ CheckDictionary("if_else", {Datum(false), values1_decoded, values2_dict},
+ /*result_is_encoded=*/false);
+ CheckDictionary("if_else", {Datum(true), scalar, values2_dict},
+ /*result_is_encoded=*/false);
+ CheckDictionary("if_else", {Datum(true), values1_dict, scalar},
+ /*result_is_encoded=*/false);
+ CheckDictionary("if_else", {Datum(false), values1_dict, scalar},
+ /*result_is_encoded=*/false);
+ CheckDictionary("if_else", {Datum(false), scalar, values2_dict},
+ /*result_is_encoded=*/false);
+ CheckDictionary("if_else", {MakeNullScalar(boolean()), values1_decoded, values2_dict},
+ /*result_is_encoded=*/false);
+}
+
+TYPED_TEST(TestIfElseDict, NestedSimple) {
+ auto make_list = [](const std::shared_ptr& indices,
+ const std::shared_ptr& backing_array) {
+ EXPECT_OK_AND_ASSIGN(auto result, ListArray::FromArrays(*indices, *backing_array));
+ return result;
+ };
+ auto index_type = default_type_instance();
+ auto inner_type = dictionary(index_type, utf8());
+ auto type = list(inner_type);
+ auto dict = R"(["a", null, "bc", "def"])";
+ auto cond = ArrayFromJSON(boolean(), "[true, false, true, null]");
+ auto values_null = make_list(ArrayFromJSON(int32(), "[null, null, null, null, 0]"),
+ DictArrayFromJSON(inner_type, "[]", dict));
+ auto values1_backing = DictArrayFromJSON(inner_type, "[0, null, 3, 1]", dict);
+ auto values2_backing = DictArrayFromJSON(inner_type, "[2, 1, null, 0]", dict);
+ auto values1 = make_list(ArrayFromJSON(int32(), "[0, 2, 2, 3, 4]"), values1_backing);
+ auto values2 = make_list(ArrayFromJSON(int32(), "[0, 1, 2, 2, 4]"), values2_backing);
+ auto scalar =
+ Datum(std::make_shared(DictArrayFromJSON(inner_type, "[0, 1]", dict)));
+
+ CheckDictionary("if_else", {cond, values1, values2}, /*result_is_encoded=*/false);
+ CheckDictionary("if_else", {cond, values1, scalar}, /*result_is_encoded=*/false);
+ CheckDictionary("if_else", {cond, scalar, values2}, /*result_is_encoded=*/false);
+ CheckDictionary("if_else", {cond, values_null, values2}, /*result_is_encoded=*/false);
+ CheckDictionary("if_else", {cond, values1, values_null}, /*result_is_encoded=*/false);
+ CheckDictionary("if_else", {Datum(true), values1, values2},
+ /*result_is_encoded=*/false);
+ CheckDictionary("if_else", {Datum(false), values1, values2},
+ /*result_is_encoded=*/false);
+ CheckDictionary("if_else", {Datum(true), scalar, values2}, /*result_is_encoded=*/false);
+ CheckDictionary("if_else", {Datum(true), values1, scalar}, /*result_is_encoded=*/false);
+ CheckDictionary("if_else", {Datum(false), values1, scalar},
+ /*result_is_encoded=*/false);
+ CheckDictionary("if_else", {Datum(false), scalar, values2},
+ /*result_is_encoded=*/false);
+ CheckDictionary("if_else", {MakeNullScalar(boolean()), values1, values2},
+ /*result_is_encoded=*/false);
+}
+
+TYPED_TEST(TestIfElseDict, DifferentDictionaries) {
+ auto type = dictionary(default_type_instance(), utf8());
+ auto cond = ArrayFromJSON(boolean(), "[true, false, true, null]");
+ auto dict1 = R"(["a", null, "bc", "def"])";
+ auto dict2 = R"(["bc", "foo", null, "a"])";
+ auto values1_null = DictArrayFromJSON(type, "[null, null, null, null]", dict1);
+ auto values2_null = DictArrayFromJSON(type, "[null, null, null, null]", dict2);
+ auto values1 = DictArrayFromJSON(type, "[null, 0, 3, 1]", dict1);
+ auto values2 = DictArrayFromJSON(type, "[2, 1, 0, null]", dict2);
+ auto scalar1 = DictScalarFromJSON(type, "0", dict1);
+ auto scalar2 = DictScalarFromJSON(type, "0", dict2);
+
+ CheckDictionary("if_else", {cond, values1, values2});
+ CheckDictionary("if_else", {cond, values1, scalar2});
+ CheckDictionary("if_else", {cond, scalar1, values2});
+ CheckDictionary("if_else", {cond, values1_null, values2});
+ CheckDictionary("if_else", {cond, values1, values2_null});
+ CheckDictionary("if_else", {Datum(true), values1, values2});
+ CheckDictionary("if_else", {Datum(false), values1, values2});
+ CheckDictionary("if_else", {Datum(true), scalar1, values2});
+ CheckDictionary("if_else", {Datum(true), values1, scalar2});
+ CheckDictionary("if_else", {Datum(false), values1, scalar2});
+ CheckDictionary("if_else", {Datum(false), scalar1, values2});
+ CheckDictionary("if_else", {MakeNullScalar(boolean()), values1, values2});
+}
+
template
class TestCaseWhenNumeric : public ::testing::Test {};
@@ -627,11 +1082,6 @@ TYPED_TEST(TestCaseWhenNumeric, ListOfType) {
template
class TestCaseWhenDict : public ::testing::Test {};
-struct JsonDict {
- std::shared_ptr type;
- std::string value;
-};
-
TYPED_TEST_SUITE(TestCaseWhenDict, IntegralArrowTypes);
TYPED_TEST(TestCaseWhenDict, Simple) {
From a53bc771d321cc5a57e07409dba1c8b2ee35a632 Mon Sep 17 00:00:00 2001
From: David Li
Date: Wed, 29 Sep 2021 12:46:10 -0400
Subject: [PATCH 2/5] ARROW-13358: [C++] Move out common utilities
---
.../compute/kernels/scalar_if_else_test.cc | 54 ++++++++++---------
1 file changed, 29 insertions(+), 25 deletions(-)
diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc
index 5e74d4a60e4..edf1d954404 100644
--- a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc
@@ -27,6 +27,20 @@
namespace arrow {
namespace compute {
+// Helper that combines a dictionary and the value type so it can
+// later be used with DictArrayFromJSON
+struct JsonDict {
+ std::shared_ptr type;
+ std::string value;
+};
+
+// Helper that makes a list of dictionary indices
+std::shared_ptr MakeListOfDict(const std::shared_ptr& indices,
+ const std::shared_ptr& backing_array) {
+ EXPECT_OK_AND_ASSIGN(auto result, ListArray::FromArrays(*indices, *backing_array));
+ return result;
+}
+
void CheckIfElseOutput(const Datum& cond, const Datum& left, const Datum& right,
const Datum& expected) {
ASSERT_OK_AND_ASSIGN(Datum datum_out, IfElse(cond, left, right));
@@ -825,11 +839,6 @@ class TestIfElseDict : public ::testing::Test {};
TYPED_TEST_SUITE(TestIfElseDict, IntegralArrowTypes);
-struct JsonDict {
- std::shared_ptr type;
- std::string value;
-};
-
TYPED_TEST(TestIfElseDict, Simple) {
auto cond = ArrayFromJSON(boolean(), "[true, false, true, null]");
for (const auto& dict :
@@ -895,22 +904,19 @@ TYPED_TEST(TestIfElseDict, Mixed) {
}
TYPED_TEST(TestIfElseDict, NestedSimple) {
- auto make_list = [](const std::shared_ptr& indices,
- const std::shared_ptr& backing_array) {
- EXPECT_OK_AND_ASSIGN(auto result, ListArray::FromArrays(*indices, *backing_array));
- return result;
- };
auto index_type = default_type_instance();
auto inner_type = dictionary(index_type, utf8());
auto type = list(inner_type);
auto dict = R"(["a", null, "bc", "def"])";
auto cond = ArrayFromJSON(boolean(), "[true, false, true, null]");
- auto values_null = make_list(ArrayFromJSON(int32(), "[null, null, null, null, 0]"),
- DictArrayFromJSON(inner_type, "[]", dict));
+ auto values_null = MakeListOfDict(ArrayFromJSON(int32(), "[null, null, null, null, 0]"),
+ DictArrayFromJSON(inner_type, "[]", dict));
auto values1_backing = DictArrayFromJSON(inner_type, "[0, null, 3, 1]", dict);
auto values2_backing = DictArrayFromJSON(inner_type, "[2, 1, null, 0]", dict);
- auto values1 = make_list(ArrayFromJSON(int32(), "[0, 2, 2, 3, 4]"), values1_backing);
- auto values2 = make_list(ArrayFromJSON(int32(), "[0, 1, 2, 2, 4]"), values2_backing);
+ auto values1 =
+ MakeListOfDict(ArrayFromJSON(int32(), "[0, 2, 2, 3, 4]"), values1_backing);
+ auto values2 =
+ MakeListOfDict(ArrayFromJSON(int32(), "[0, 1, 2, 2, 4]"), values2_backing);
auto scalar =
Datum(std::make_shared(DictArrayFromJSON(inner_type, "[0, 1]", dict)));
@@ -1133,35 +1139,33 @@ TYPED_TEST(TestCaseWhenDict, Mixed) {
}
TYPED_TEST(TestCaseWhenDict, NestedSimple) {
- auto make_list = [](const std::shared_ptr& indices,
- const std::shared_ptr& backing_array) {
- EXPECT_OK_AND_ASSIGN(auto result, ListArray::FromArrays(*indices, *backing_array));
- return result;
- };
auto index_type = default_type_instance();
auto inner_type = dictionary(index_type, utf8());
auto type = list(inner_type);
auto cond1 = ArrayFromJSON(boolean(), "[true, true, null, null]");
auto cond2 = ArrayFromJSON(boolean(), "[true, false, true, null]");
auto dict = R"(["a", null, "bc", "def"])";
- auto values_null = make_list(ArrayFromJSON(int32(), "[null, null, null, null, 0]"),
- DictArrayFromJSON(inner_type, "[]", dict));
+ auto values_null = MakeListOfDict(ArrayFromJSON(int32(), "[null, null, null, null, 0]"),
+ DictArrayFromJSON(inner_type, "[]", dict));
auto values1_backing = DictArrayFromJSON(inner_type, "[0, null, 3, 1]", dict);
auto values2_backing = DictArrayFromJSON(inner_type, "[2, 1, null, 0]", dict);
- auto values1 = make_list(ArrayFromJSON(int32(), "[0, 2, 2, 3, 4]"), values1_backing);
- auto values2 = make_list(ArrayFromJSON(int32(), "[0, 1, 2, 2, 4]"), values2_backing);
+ auto values1 =
+ MakeListOfDict(ArrayFromJSON(int32(), "[0, 2, 2, 3, 4]"), values1_backing);
+ auto values2 =
+ MakeListOfDict(ArrayFromJSON(int32(), "[0, 1, 2, 2, 4]"), values2_backing);
CheckDictionary("case_when", {MakeStruct({cond1, cond2}), values1, values2},
/*result_is_encoded=*/false);
CheckDictionary(
"case_when",
{MakeStruct({cond1, cond2}), values1,
- make_list(ArrayFromJSON(int32(), "[0, 1, null, 2, 4]"), values2_backing)},
+ MakeListOfDict(ArrayFromJSON(int32(), "[0, 1, null, 2, 4]"), values2_backing)},
/*result_is_encoded=*/false);
CheckDictionary(
"case_when",
{MakeStruct({cond1, cond2}), values1,
- make_list(ArrayFromJSON(int32(), "[0, 1, null, 2, 4]"), values2_backing), values1},
+ MakeListOfDict(ArrayFromJSON(int32(), "[0, 1, null, 2, 4]"), values2_backing),
+ values1},
/*result_is_encoded=*/false);
CheckDictionary("case_when",
From ba2e94704179ba34df264a2fa275fc4e4c8ab807 Mon Sep 17 00:00:00 2001
From: David Li
Date: Wed, 29 Sep 2021 13:54:10 -0400
Subject: [PATCH 3/5] ARROW-13358: [C++] Reconcile with ARROW-14167
---
.../arrow/compute/kernels/codegen_internal.cc | 7 ++--
.../arrow/compute/kernels/codegen_internal.h | 2 +-
.../arrow/compute/kernels/scalar_compare.cc | 2 +-
.../arrow/compute/kernels/scalar_if_else.cc | 10 ++++-
.../compute/kernels/scalar_if_else_test.cc | 41 ++++++-------------
5 files changed, 27 insertions(+), 35 deletions(-)
diff --git a/cpp/src/arrow/compute/kernels/codegen_internal.cc b/cpp/src/arrow/compute/kernels/codegen_internal.cc
index 9077c7e9f0b..70488759261 100644
--- a/cpp/src/arrow/compute/kernels/codegen_internal.cc
+++ b/cpp/src/arrow/compute/kernels/codegen_internal.cc
@@ -222,11 +222,12 @@ std::shared_ptr CommonTemporal(const ValueDescr* begin, size_t count)
return nullptr;
}
-std::shared_ptr CommonBinary(const std::vector& descrs) {
+std::shared_ptr CommonBinary(const ValueDescr* begin, size_t count) {
bool all_utf8 = true, all_offset32 = true;
- for (const auto& descr : descrs) {
- auto id = descr.type->id();
+ const ValueDescr* end = begin + count;
+ for (auto it = begin; it != end; ++it) {
+ auto id = it->type->id();
// a common varbinary type is only possible if all types are binary like
switch (id) {
case Type::STRING:
diff --git a/cpp/src/arrow/compute/kernels/codegen_internal.h b/cpp/src/arrow/compute/kernels/codegen_internal.h
index a4e2c3d6d1a..dc6068dd529 100644
--- a/cpp/src/arrow/compute/kernels/codegen_internal.h
+++ b/cpp/src/arrow/compute/kernels/codegen_internal.h
@@ -1305,7 +1305,7 @@ ARROW_EXPORT
std::shared_ptr CommonTemporal(const ValueDescr* begin, size_t count);
ARROW_EXPORT
-std::shared_ptr CommonBinary(const std::vector& descrs);
+std::shared_ptr CommonBinary(const ValueDescr* begin, size_t count);
/// How to promote decimal precision/scale in CastBinaryDecimalArgs.
enum class DecimalPromotion : uint8_t {
diff --git a/cpp/src/arrow/compute/kernels/scalar_compare.cc b/cpp/src/arrow/compute/kernels/scalar_compare.cc
index e42bcb7b25c..54230f88ff9 100644
--- a/cpp/src/arrow/compute/kernels/scalar_compare.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_compare.cc
@@ -173,7 +173,7 @@ struct CompareFunction : ScalarFunction {
ReplaceTypes(type, values);
} else if (auto type = CommonTemporal(values->data(), values->size())) {
ReplaceTypes(type, values);
- } else if (auto type = CommonBinary(*values)) {
+ } else if (auto type = CommonBinary(values->data(), values->size())) {
ReplaceTypes(type, values);
}
diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else.cc b/cpp/src/arrow/compute/kernels/scalar_if_else.cc
index e9432868ee8..bb7c417f1e5 100644
--- a/cpp/src/arrow/compute/kernels/scalar_if_else.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_if_else.cc
@@ -1137,7 +1137,7 @@ struct IfElseFunction : ScalarFunction {
internal::ReplaceNullWithOtherType(left_arg, num_args);
if (is_dictionary((*values)[1].type->id()) &&
- is_dictionary((*values)[2].type->id())) {
+ (*values)[1].type->Equals(*(*values)[2].type)) {
auto kernel = DispatchExactImpl(this, *values);
DCHECK(kernel);
return kernel;
@@ -1151,6 +1151,12 @@ struct IfElseFunction : ScalarFunction {
if (auto type = internal::CommonTemporal(left_arg, num_args)) {
internal::ReplaceTypes(type, left_arg, num_args);
}
+ if (auto type = internal::CommonBinary(left_arg, num_args)) {
+ internal::ReplaceTypes(type, left_arg, num_args);
+ }
+ if (HasDecimal(*values)) {
+ RETURN_NOT_OK(CastDecimalArgs(left_arg, num_args));
+ }
if (auto kernel = DispatchExactImpl(this, *values)) return kernel;
@@ -1918,7 +1924,7 @@ struct CoalesceFunction : ScalarFunction {
if (auto type = CommonNumeric(*values)) {
ReplaceTypes(type, values);
}
- if (auto type = CommonBinary(*values)) {
+ if (auto type = CommonBinary(values->data(), values->size())) {
ReplaceTypes(type, values);
}
if (auto type = CommonTemporal(values->data(), values->size())) {
diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc
index edf1d954404..5acb0b16b6c 100644
--- a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc
@@ -705,23 +705,6 @@ TEST_F(TestIfElseKernel, ParameterizedTypes) {
CallFunction("if_else", {cond, ArrayFromJSON(type0, R"(["aaaa"])"),
ArrayFromJSON(type1, R"(["aaaaa"])")}));
- type0 = decimal128(3, 2);
- type1 = decimal128(4, 2);
- EXPECT_RAISES_WITH_MESSAGE_THAT(
- TypeError,
- ::testing::HasSubstr("All types must be compatible, expected: decimal128(3, 2), "
- "but got: decimal128(4, 2)"),
- CallFunction("if_else", {cond, ArrayFromJSON(type0, R"(["1.23"])"),
- ArrayFromJSON(type1, R"(["1.23"])")}));
-
- type1 = decimal128(3, 4);
- EXPECT_RAISES_WITH_MESSAGE_THAT(
- TypeError,
- ::testing::HasSubstr("All types must be compatible, expected: decimal128(3, 2), "
- "but got: decimal128(3, 4)"),
- CallFunction("if_else", {cond, ArrayFromJSON(type0, R"(["1.23"])"),
- ArrayFromJSON(type1, R"(["1.2345"])")}));
-
// TODO(ARROW-14105): in principle many of these could be implicitly castable too
type0 = struct_({field("a", int32())});
@@ -751,16 +734,6 @@ TEST_F(TestIfElseKernel, ParameterizedTypes) {
CallFunction("if_else",
{cond, ArrayFromJSON(type0, "[[0]]"), ArrayFromJSON(type1, "[[0]]")}));
- type0 = dictionary(int16(), utf8());
- type1 = dictionary(int32(), utf8());
- EXPECT_RAISES_WITH_MESSAGE_THAT(
- TypeError,
- ::testing::HasSubstr("All types must be compatible, expected: "
- "dictionary, but "
- "got: dictionary"),
- CallFunction("if_else", {cond, DictArrayFromJSON(type0, "[0]", R"(["a"])"),
- DictArrayFromJSON(type1, "[0]", R"(["a"])")}));
-
type0 = timestamp(TimeUnit::SECOND);
type1 = timestamp(TimeUnit::SECOND, "America/Phoenix");
EXPECT_RAISES_WITH_MESSAGE_THAT(
@@ -868,7 +841,8 @@ TYPED_TEST(TestIfElseDict, Simple) {
}
TYPED_TEST(TestIfElseDict, Mixed) {
- auto type = dictionary(default_type_instance(), utf8());
+ auto index_type = default_type_instance();
+ auto type = dictionary(index_type, utf8());
auto cond = ArrayFromJSON(boolean(), "[true, false, true, null]");
auto dict = R"(["a", null, "bc", "def"])";
auto values_null = DictArrayFromJSON(type, "[null, null, null, null]", dict);
@@ -901,6 +875,17 @@ TYPED_TEST(TestIfElseDict, Mixed) {
/*result_is_encoded=*/false);
CheckDictionary("if_else", {MakeNullScalar(boolean()), values1_decoded, values2_dict},
/*result_is_encoded=*/false);
+
+ // If we have mismatched dictionary types, we decode (for now)
+ auto values3_dict =
+ DictArrayFromJSON(dictionary(index_type, binary()), "[2, 1, null, 0]", dict);
+ auto values4_dict = DictArrayFromJSON(
+ dictionary(index_type->id() == Type::UINT8 ? int8() : uint8(), utf8()),
+ "[2, 1, null, 0]", dict);
+ CheckDictionary("if_else", {cond, values1_dict, values3_dict},
+ /*result_is_encoded=*/false);
+ CheckDictionary("if_else", {cond, values1_dict, values4_dict},
+ /*result_is_encoded=*/false);
}
TYPED_TEST(TestIfElseDict, NestedSimple) {
From 36b2a77b3ef06c804afa0001862ec11129003e26 Mon Sep 17 00:00:00 2001
From: David Li
Date: Thu, 30 Sep 2021 11:09:34 -0400
Subject: [PATCH 4/5] ARROW-13358: [C++] Address feedback
---
cpp/src/arrow/array/array_test.cc | 2 +
cpp/src/arrow/array/util.cc | 3 +-
.../arrow/compute/kernels/scalar_if_else.cc | 147 +++++++++++-------
.../compute/kernels/scalar_if_else_test.cc | 74 +++++----
4 files changed, 129 insertions(+), 97 deletions(-)
diff --git a/cpp/src/arrow/array/array_test.cc b/cpp/src/arrow/array/array_test.cc
index b6cfa0dce63..e5dfd0ccd1a 100644
--- a/cpp/src/arrow/array/array_test.cc
+++ b/cpp/src/arrow/array/array_test.cc
@@ -544,10 +544,12 @@ static ScalarVector GetScalars() {
sparse_union_ty),
std::make_shared(std::make_shared(100), 42,
sparse_union_ty),
+ std::make_shared(42, sparse_union_ty),
std::make_shared(std::make_shared(101), 6,
dense_union_ty),
std::make_shared(std::make_shared(101), 42,
dense_union_ty),
+ std::make_shared(42, dense_union_ty),
DictionaryScalar::Make(ScalarFromJSON(int8(), "1"),
ArrayFromJSON(utf8(), R"(["foo", "bar"])")),
DictionaryScalar::Make(ScalarFromJSON(uint8(), "1"),
diff --git a/cpp/src/arrow/array/util.cc b/cpp/src/arrow/array/util.cc
index 71f91497f61..d639830f469 100644
--- a/cpp/src/arrow/array/util.cc
+++ b/cpp/src/arrow/array/util.cc
@@ -775,7 +775,8 @@ Result> MakeArrayOfNull(const std::shared_ptr&
Result> MakeArrayFromScalar(const Scalar& scalar, int64_t length,
MemoryPool* pool) {
- if (!scalar.is_valid) {
+ // Null union scalars still have a type code associated
+ if (!scalar.is_valid && !is_union(scalar.type->id())) {
return MakeArrayOfNull(scalar.type, length, pool);
}
return RepeatedArrayFactory(pool, scalar, length).Create();
diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else.cc b/cpp/src/arrow/compute/kernels/scalar_if_else.cc
index bb7c417f1e5..b3ebba8ea00 100644
--- a/cpp/src/arrow/compute/kernels/scalar_if_else.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_if_else.cc
@@ -948,9 +948,7 @@ struct IfElseFunctor> {
};
// Use builders for dictionaries - slower, but allows us to unify dictionaries
-template
-struct IfElseFunctor<
- Type, enable_if_t::value || is_dictionary_type::value>> {
+struct NestedIfElseExec {
// A - Array, S - Scalar, X = Array/Scalar
// SXX
@@ -989,11 +987,11 @@ struct IfElseFunctor<
const ArrayData& right, ArrayData* out) {
return RunLoop(
ctx, cond, out,
- [&](ArrayBuilder* builder, int64_t i) {
- return builder->AppendArraySlice(left, i, /*length=*/1);
+ [&](ArrayBuilder* builder, int64_t i, int64_t length) {
+ return builder->AppendArraySlice(left, i, length);
},
- [&](ArrayBuilder* builder, int64_t i) {
- return builder->AppendArraySlice(right, i, /*length=*/1);
+ [&](ArrayBuilder* builder, int64_t i, int64_t length) {
+ return builder->AppendArraySlice(right, i, length);
});
}
@@ -1002,9 +1000,11 @@ struct IfElseFunctor<
const ArrayData& right, ArrayData* out) {
return RunLoop(
ctx, cond, out,
- [&](ArrayBuilder* builder, int64_t i) { return builder->AppendScalar(left); },
- [&](ArrayBuilder* builder, int64_t i) {
- return builder->AppendArraySlice(right, i, /*length=*/1);
+ [&](ArrayBuilder* builder, int64_t i, int64_t length) {
+ return builder->AppendScalar(left, length);
+ },
+ [&](ArrayBuilder* builder, int64_t i, int64_t length) {
+ return builder->AppendArraySlice(right, i, length);
});
}
@@ -1013,10 +1013,12 @@ struct IfElseFunctor<
const Scalar& right, ArrayData* out) {
return RunLoop(
ctx, cond, out,
- [&](ArrayBuilder* builder, int64_t i) {
- return builder->AppendArraySlice(left, i, /*length=*/1);
+ [&](ArrayBuilder* builder, int64_t i, int64_t length) {
+ return builder->AppendArraySlice(left, i, length);
},
- [&](ArrayBuilder* builder, int64_t i) { return builder->AppendScalar(right); });
+ [&](ArrayBuilder* builder, int64_t i, int64_t length) {
+ return builder->AppendScalar(right, length);
+ });
}
// ASS
@@ -1024,8 +1026,12 @@ struct IfElseFunctor<
const Scalar& right, ArrayData* out) {
return RunLoop(
ctx, cond, out,
- [&](ArrayBuilder* builder, int64_t i) { return builder->AppendScalar(left); },
- [&](ArrayBuilder* builder, int64_t i) { return builder->AppendScalar(right); });
+ [&](ArrayBuilder* builder, int64_t i, int64_t length) {
+ return builder->AppendScalar(left, length);
+ },
+ [&](ArrayBuilder* builder, int64_t i, int64_t length) {
+ return builder->AppendScalar(right, length);
+ });
}
template
@@ -1036,32 +1042,68 @@ struct IfElseFunctor<
RETURN_NOT_OK(raw_builder->Reserve(out->length));
const auto* cond_data = cond.buffers[1]->data();
- if (out->buffers[0]) {
- const uint8_t* out_valid = out->buffers[0]->data();
- for (int64_t i = 0; i < cond.length; i++) {
- if (BitUtil::GetBit(out_valid, i)) {
- if (BitUtil::GetBit(cond_data, cond.offset + i)) {
- RETURN_NOT_OK(handle_left(raw_builder.get(), i));
- } else {
- RETURN_NOT_OK(handle_right(raw_builder.get(), i));
+ if (cond.buffers[0]) {
+ BitRunReader reader(cond.buffers[0]->data(), cond.offset, cond.length);
+ int64_t position = 0;
+ while (true) {
+ auto run = reader.NextRun();
+ if (run.length == 0) break;
+ if (run.set) {
+ for (int j = 0; j < run.length; j++) {
+ if (BitUtil::GetBit(cond_data, cond.offset + position + j)) {
+ RETURN_NOT_OK(handle_left(raw_builder.get(), position + j, 1));
+ } else {
+ RETURN_NOT_OK(handle_right(raw_builder.get(), position + j, 1));
+ }
}
} else {
- RETURN_NOT_OK(raw_builder->AppendNull());
+ RETURN_NOT_OK(raw_builder->AppendNulls(run.length));
}
+ position += run.length;
}
} else {
- for (int64_t i = 0; i < cond.length; i++) {
- if (BitUtil::GetBit(cond_data, cond.offset + i)) {
- RETURN_NOT_OK(handle_left(raw_builder.get(), i));
+ BitRunReader reader(cond_data, cond.offset, cond.length);
+ int64_t position = 0;
+ while (true) {
+ auto run = reader.NextRun();
+ if (run.length == 0) break;
+ if (run.set) {
+ RETURN_NOT_OK(handle_left(raw_builder.get(), position, run.length));
} else {
- RETURN_NOT_OK(handle_right(raw_builder.get(), i));
+ RETURN_NOT_OK(handle_right(raw_builder.get(), position, run.length));
}
+ position += run.length;
}
}
ARROW_ASSIGN_OR_RAISE(auto out_arr, raw_builder->Finish());
*out = std::move(*out_arr->data());
return Status::OK();
}
+
+ static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ RETURN_NOT_OK(CheckIdenticalTypes(&batch.values[1], /*count=*/2));
+ if (batch[0].is_scalar()) {
+ const auto& cond = batch[0].scalar_as();
+ return Call(ctx, cond, batch[1], batch[2], out);
+ }
+ if (batch[1].kind() == Datum::ARRAY) {
+ if (batch[2].kind() == Datum::ARRAY) { // AAA
+ return Call(ctx, *batch[0].array(), *batch[1].array(), *batch[2].array(),
+ out->mutable_array());
+ } else { // AAS
+ return Call(ctx, *batch[0].array(), *batch[1].array(), *batch[2].scalar(),
+ out->mutable_array());
+ }
+ } else {
+ if (batch[2].kind() == Datum::ARRAY) { // ASA
+ return Call(ctx, *batch[0].array(), *batch[1].scalar(), *batch[2].array(),
+ out->mutable_array());
+ } else { // ASS
+ return Call(ctx, *batch[0].array(), *batch[1].scalar(), *batch[2].scalar(),
+ out->mutable_array());
+ }
+ }
+ }
};
template
@@ -1136,8 +1178,10 @@ struct IfElseFunction : ScalarFunction {
internal::ReplaceNullWithOtherType(left_arg, num_args);
- if (is_dictionary((*values)[1].type->id()) &&
- (*values)[1].type->Equals(*(*values)[2].type)) {
+ // If both are identical dictionary types, dispatch to the dictionary kernel
+ // TODO(ARROW-14105): apply implicit casts to dictionary types too
+ ValueDescr* right_arg = &(*values)[2];
+ if (is_dictionary(left_arg->type->id()) && left_arg->type->Equals(right_arg->type)) {
auto kernel = DispatchExactImpl(this, *values);
DCHECK(kernel);
return kernel;
@@ -1147,14 +1191,11 @@ struct IfElseFunction : ScalarFunction {
if (auto type = internal::CommonNumeric(left_arg, num_args)) {
internal::ReplaceTypes(type, left_arg, num_args);
- }
- if (auto type = internal::CommonTemporal(left_arg, num_args)) {
+ } else if (auto type = internal::CommonTemporal(left_arg, num_args)) {
internal::ReplaceTypes(type, left_arg, num_args);
- }
- if (auto type = internal::CommonBinary(left_arg, num_args)) {
+ } else if (auto type = internal::CommonBinary(left_arg, num_args)) {
internal::ReplaceTypes(type, left_arg, num_args);
- }
- if (HasDecimal(*values)) {
+ } else if (HasDecimal(*values)) {
RETURN_NOT_OK(CastDecimalArgs(left_arg, num_args));
}
@@ -1230,14 +1271,17 @@ void AddFixedWidthIfElseKernel(const std::shared_ptr& scalar_fun
DCHECK_OK(scalar_function->AddKernel(std::move(kernel)));
}
-void AddNestedIfElseKernel(const std::shared_ptr& scalar_function,
- detail::GetTypeId get_id, ArrayKernelExec exec) {
- ScalarKernel kernel({boolean(), InputType(get_id.id), InputType(get_id.id)},
- OutputType(LastType), std::move(exec));
- kernel.null_handling = NullHandling::COMPUTED_NO_PREALLOCATE;
- kernel.mem_allocation = MemAllocation::NO_PREALLOCATE;
- kernel.can_write_into_slices = false;
- DCHECK_OK(scalar_function->AddKernel(std::move(kernel)));
+void AddNestedIfElseKernels(const std::shared_ptr& scalar_function) {
+ for (const auto type_id :
+ {Type::LIST, Type::LARGE_LIST, Type::FIXED_SIZE_LIST, Type::STRUCT,
+ Type::DENSE_UNION, Type::SPARSE_UNION, Type::DICTIONARY}) {
+ ScalarKernel kernel({boolean(), InputType(type_id), InputType(type_id)},
+ OutputType(LastType), NestedIfElseExec::Exec);
+ kernel.null_handling = NullHandling::COMPUTED_NO_PREALLOCATE;
+ kernel.mem_allocation = MemAllocation::NO_PREALLOCATE;
+ kernel.can_write_into_slices = false;
+ DCHECK_OK(scalar_function->AddKernel(std::move(kernel)));
+ }
}
// Helper to copy or broadcast fixed-width values between buffers.
@@ -2795,20 +2839,7 @@ void RegisterScalarIfElse(FunctionRegistry* registry) {
AddFixedWidthIfElseKernel(func);
AddFixedWidthIfElseKernel(func);
AddFixedWidthIfElseKernel(func);
- AddNestedIfElseKernel(func, Type::LIST,
- ResolveIfElseExec::Exec);
- AddNestedIfElseKernel(func, Type::LARGE_LIST,
- ResolveIfElseExec::Exec);
- AddNestedIfElseKernel(func, Type::FIXED_SIZE_LIST,
- ResolveIfElseExec::Exec);
- AddNestedIfElseKernel(func, Type::STRUCT,
- ResolveIfElseExec::Exec);
- AddNestedIfElseKernel(func, Type::DENSE_UNION,
- ResolveIfElseExec::Exec);
- AddNestedIfElseKernel(func, Type::SPARSE_UNION,
- ResolveIfElseExec::Exec);
- AddNestedIfElseKernel(func, Type::DICTIONARY,
- ResolveIfElseExec::Exec);
+ AddNestedIfElseKernels(func);
DCHECK_OK(registry->AddFunction(std::move(func)));
}
{
diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc
index 5acb0b16b6c..92e0582c6f1 100644
--- a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc
@@ -116,54 +116,49 @@ void CheckWithDifferentShapes(const std::shared_ptr& cond,
CheckScalar("if_else", {cond, left, right}, expected);
auto len = left->length();
+ std::vector array_indices = {-1}; // sentinel for make_input
+ std::vector scalar_indices(len);
+ std::iota(scalar_indices.begin(), scalar_indices.end(), 0);
+ auto make_input = [&](const std::shared_ptr& array, int64_t index, Datum* input,
+ Datum* input_broadcast, std::string* trace) {
+ if (index >= 0) {
+ // Use scalar from array[index] as input; broadcast scalar for computing expected
+ // result
+ ASSERT_OK_AND_ASSIGN(auto scalar, array->GetScalar(index));
+ *trace += "@" + std::to_string(index) + "=" + scalar->ToString();
+ *input = std::move(scalar);
+ ASSERT_OK_AND_ASSIGN(*input_broadcast, MakeArrayFromScalar(*input->scalar(), len));
+ } else {
+ // Use array as input
+ *trace += "=Array";
+ *input = *input_broadcast = array;
+ }
+ };
enum { COND_SCALAR = 1, LEFT_SCALAR = 2, RIGHT_SCALAR = 4 };
for (int mask = 1; mask <= (COND_SCALAR | LEFT_SCALAR | RIGHT_SCALAR); ++mask) {
- for (int64_t cond_idx = 0; cond_idx < len; ++cond_idx) {
+ for (int64_t cond_idx : (mask & COND_SCALAR) ? scalar_indices : array_indices) {
Datum cond_in, cond_bcast;
std::string trace_cond = "Cond";
- if (mask & COND_SCALAR) {
- ASSERT_OK_AND_ASSIGN(cond_in, cond->GetScalar(cond_idx));
- ASSERT_OK_AND_ASSIGN(cond_bcast, MakeArrayFromScalar(*cond_in.scalar(), len));
- trace_cond += "@" + std::to_string(cond_idx) + "=" + cond_in.scalar()->ToString();
- } else {
- cond_in = cond_bcast = cond;
- trace_cond += "=Array";
- }
- SCOPED_TRACE(trace_cond);
+ make_input(cond, cond_idx, &cond_in, &cond_bcast, &trace_cond);
- for (int64_t left_idx = 0; left_idx < len; ++left_idx) {
+ for (int64_t left_idx : (mask & LEFT_SCALAR) ? scalar_indices : array_indices) {
Datum left_in, left_bcast;
std::string trace_left = "Left";
- if (mask & LEFT_SCALAR) {
- ASSERT_OK_AND_ASSIGN(left_in, left->GetScalar(left_idx).As());
- ASSERT_OK_AND_ASSIGN(left_bcast, MakeArrayFromScalar(*left_in.scalar(), len));
- trace_left +=
- "@" + std::to_string(left_idx) + "=" + left_in.scalar()->ToString();
- } else {
- left_in = left_bcast = left;
- trace_left += "=Array";
- }
- SCOPED_TRACE(trace_left);
+ make_input(left, left_idx, &left_in, &left_bcast, &trace_left);
- for (int64_t right_idx = 0; right_idx < len; ++right_idx) {
+ for (int64_t right_idx : (mask & RIGHT_SCALAR) ? scalar_indices : array_indices) {
Datum right_in, right_bcast;
std::string trace_right = "Right";
- if (mask & RIGHT_SCALAR) {
- ASSERT_OK_AND_ASSIGN(right_in, right->GetScalar(right_idx));
- ASSERT_OK_AND_ASSIGN(right_bcast,
- MakeArrayFromScalar(*right_in.scalar(), len));
- trace_right +=
- "@" + std::to_string(right_idx) + "=" + right_in.scalar()->ToString();
- } else {
- right_in = right_bcast = right;
- trace_right += "=Array";
- }
+ make_input(right, right_idx, &right_in, &right_bcast, &trace_right);
+
SCOPED_TRACE(trace_right);
+ SCOPED_TRACE(trace_left);
+ SCOPED_TRACE(trace_cond);
Datum expected;
ASSERT_OK_AND_ASSIGN(auto actual, IfElse(cond_in, left_in, right_in));
- if (mask & COND_SCALAR && mask & LEFT_SCALAR && mask & RIGHT_SCALAR) {
+ if (mask == (COND_SCALAR | LEFT_SCALAR | RIGHT_SCALAR)) {
const auto& scalar = cond_in.scalar_as();
if (scalar.is_valid) {
expected = scalar.value ? left_in : right_in;
@@ -177,13 +172,9 @@ void CheckWithDifferentShapes(const std::shared_ptr& cond,
} else {
ASSERT_OK_AND_ASSIGN(expected, IfElse(cond_bcast, left_bcast, right_bcast));
}
- AssertDatumsEqual(actual, expected, /*verbose=*/true);
-
- if (right_in.is_array()) break;
+ AssertDatumsEqual(expected, actual, /*verbose=*/true);
}
- if (left_in.is_array()) break;
}
- if (cond_in.is_array()) break;
}
} // for (mask)
}
@@ -734,6 +725,13 @@ TEST_F(TestIfElseKernel, ParameterizedTypes) {
CallFunction("if_else",
{cond, ArrayFromJSON(type0, "[[0]]"), ArrayFromJSON(type1, "[[0]]")}));
+ type0 = timestamp(TimeUnit::SECOND);
+ type1 = timestamp(TimeUnit::MILLI);
+ CheckWithDifferentShapes(ArrayFromJSON(boolean(), "[true, true, true, false]"),
+ ArrayFromJSON(type0, "[1, 2, 3, 4]"),
+ ArrayFromJSON(type1, "[5, 6, 7, 8]"),
+ ArrayFromJSON(type1, "[1000, 2000, 3000, 8]"));
+
type0 = timestamp(TimeUnit::SECOND);
type1 = timestamp(TimeUnit::SECOND, "America/Phoenix");
EXPECT_RAISES_WITH_MESSAGE_THAT(
From fe8329e7aecdced6eb98510002dff15351940bf2 Mon Sep 17 00:00:00 2001
From: Antoine Pitrou
Date: Mon, 4 Oct 2021 15:38:35 +0200
Subject: [PATCH 5/5] Update docs
---
docs/source/cpp/compute.rst | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/docs/source/cpp/compute.rst b/docs/source/cpp/compute.rst
index 0a87b0b1d29..f699698e97b 100644
--- a/docs/source/cpp/compute.rst
+++ b/docs/source/cpp/compute.rst
@@ -1145,11 +1145,11 @@ depending on a condition.
+==================+============+===================================================+=====================+=========+
| case_when | Varargs | Struct of Boolean (Arg 0), Any (rest) | Input type | \(1) |
+------------------+------------+---------------------------------------------------+---------------------+---------+
-| choose | Varargs | Integral (Arg 0); Fixed-width/Binary-like (rest) | Input type | \(2) |
+| choose | Varargs | Integral (Arg 0), Fixed-width/Binary-like (rest) | Input type | \(2) |
+------------------+------------+---------------------------------------------------+---------------------+---------+
| coalesce | Varargs | Any | Input type | \(3) |
+------------------+------------+---------------------------------------------------+---------------------+---------+
-| if_else | Ternary | Boolean, Null, Numeric, Temporal | Input type | \(4) |
+| if_else | Ternary | Boolean (Arg 0), Any (rest) | Input type | \(4) |
+------------------+------------+---------------------------------------------------+---------------------+---------+
* \(1) This function acts like a SQL "case when" statement or switch-case. The