From f6165c8cd6e178a5c82c40d4dfd1afb61a5d814c Mon Sep 17 00:00:00 2001 From: David Li Date: Wed, 22 Sep 2021 16:23:01 -0400 Subject: [PATCH 1/5] ARROW-13358: [C++] Improve type support in if_else --- .../arrow/compute/kernels/codegen_internal.cc | 43 +- .../arrow/compute/kernels/codegen_internal.h | 13 +- .../compute/kernels/codegen_internal_test.cc | 42 +- .../arrow/compute/kernels/scalar_compare.cc | 4 +- .../arrow/compute/kernels/scalar_if_else.cc | 279 ++++++++-- .../compute/kernels/scalar_if_else_test.cc | 488 +++++++++++++++++- 6 files changed, 761 insertions(+), 108 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/codegen_internal.cc b/cpp/src/arrow/compute/kernels/codegen_internal.cc index 8a8292c1c89..9077c7e9f0b 100644 --- a/cpp/src/arrow/compute/kernels/codegen_internal.cc +++ b/cpp/src/arrow/compute/kernels/codegen_internal.cc @@ -66,29 +66,45 @@ Result FirstType(KernelContext*, const std::vector& desc return result; } +Result LastType(KernelContext*, const std::vector& descrs) { + ValueDescr result = descrs.back(); + result.shape = GetBroadcastShape(descrs); + return result; +} + Result ListValuesType(KernelContext*, const std::vector& args) { const auto& list_type = checked_cast(*args[0].type); return ValueDescr(list_type.value_type(), GetBroadcastShape(args)); } void EnsureDictionaryDecoded(std::vector* descrs) { - for (ValueDescr& descr : *descrs) { - if (descr.type->id() == Type::DICTIONARY) { - descr.type = checked_cast(*descr.type).value_type(); + EnsureDictionaryDecoded(descrs->data(), descrs->size()); +} + +void EnsureDictionaryDecoded(ValueDescr* begin, size_t count) { + auto* end = begin + count; + for (auto it = begin; it != end; it++) { + if (it->type->id() == Type::DICTIONARY) { + it->type = checked_cast(*it->type).value_type(); } } } void ReplaceNullWithOtherType(std::vector* descrs) { - DCHECK_EQ(descrs->size(), 2); + ReplaceNullWithOtherType(descrs->data(), descrs->size()); +} + +void ReplaceNullWithOtherType(ValueDescr* first, size_t count) { + DCHECK_EQ(count, 2); - if (descrs->at(0).type->id() == Type::NA) { - descrs->at(0).type = descrs->at(1).type; + ValueDescr* second = first++; + if (first->type->id() == Type::NA) { + first->type = second->type; return; } - if (descrs->at(1).type->id() == Type::NA) { - descrs->at(1).type = descrs->at(0).type; + if (second->type->id() == Type::NA) { + second->type = first->type; return; } } @@ -164,14 +180,15 @@ std::shared_ptr CommonNumeric(const ValueDescr* begin, size_t count) { return int8(); } -std::shared_ptr CommonTemporal(const std::vector& descrs) { +std::shared_ptr CommonTemporal(const ValueDescr* begin, size_t count) { TimeUnit::type finest_unit = TimeUnit::SECOND; const std::string* timezone = nullptr; bool saw_date32 = false; bool saw_date64 = false; - for (const auto& descr : descrs) { - auto id = descr.type->id(); + const ValueDescr* end = begin + count; + for (auto it = begin; it != end; it++) { + auto id = it->type->id(); // a common timestamp is only possible if all types are timestamp like switch (id) { case Type::DATE32: @@ -183,9 +200,7 @@ std::shared_ptr CommonTemporal(const std::vector& descrs) saw_date64 = true; continue; case Type::TIMESTAMP: { - const auto& ty = checked_cast(*descr.type); - // Don't cast to common timezone by default (may not make - // sense for all kernels) + const auto& ty = checked_cast(*it->type); if (timezone && *timezone != ty.timezone()) return nullptr; timezone = &ty.timezone(); finest_unit = std::max(finest_unit, ty.unit()); diff --git a/cpp/src/arrow/compute/kernels/codegen_internal.h b/cpp/src/arrow/compute/kernels/codegen_internal.h index 52e8f326d89..a4e2c3d6d1a 100644 --- a/cpp/src/arrow/compute/kernels/codegen_internal.h +++ b/cpp/src/arrow/compute/kernels/codegen_internal.h @@ -296,14 +296,14 @@ struct UnboxScalar> { template <> struct UnboxScalar { - static Decimal128 Unbox(const Scalar& val) { + static const Decimal128& Unbox(const Scalar& val) { return checked_cast(val).value; } }; template <> struct UnboxScalar { - static Decimal256 Unbox(const Scalar& val) { + static const Decimal256& Unbox(const Scalar& val) { return checked_cast(val).value; } }; @@ -397,6 +397,7 @@ static void VisitTwoArrayValuesInline(const ArrayData& arr0, const ArrayData& ar // Reusable type resolvers Result FirstType(KernelContext*, const std::vector& descrs); +Result LastType(KernelContext*, const std::vector& descrs); Result ListValuesType(KernelContext*, const std::vector& args); // ---------------------------------------------------------------------- @@ -1279,9 +1280,15 @@ ArrayKernelExec GenerateDecimal(detail::GetTypeId get_id) { ARROW_EXPORT void EnsureDictionaryDecoded(std::vector* descrs); +ARROW_EXPORT +void EnsureDictionaryDecoded(ValueDescr* begin, size_t count); + ARROW_EXPORT void ReplaceNullWithOtherType(std::vector* descrs); +ARROW_EXPORT +void ReplaceNullWithOtherType(ValueDescr* begin, size_t count); + ARROW_EXPORT void ReplaceTypes(const std::shared_ptr&, std::vector* descrs); @@ -1295,7 +1302,7 @@ ARROW_EXPORT std::shared_ptr CommonNumeric(const ValueDescr* begin, size_t count); ARROW_EXPORT -std::shared_ptr CommonTemporal(const std::vector& descrs); +std::shared_ptr CommonTemporal(const ValueDescr* begin, size_t count); ARROW_EXPORT std::shared_ptr CommonBinary(const std::vector& descrs); diff --git a/cpp/src/arrow/compute/kernels/codegen_internal_test.cc b/cpp/src/arrow/compute/kernels/codegen_internal_test.cc index a830d0c7636..d64143dea31 100644 --- a/cpp/src/arrow/compute/kernels/codegen_internal_test.cc +++ b/cpp/src/arrow/compute/kernels/codegen_internal_test.cc @@ -130,24 +130,32 @@ TEST(TestDispatchBest, CastDecimalArgs) { } TEST(TestDispatchBest, CommonTemporal) { - AssertTypeEqual(timestamp(TimeUnit::NANO), CommonTemporal({timestamp(TimeUnit::SECOND), - timestamp(TimeUnit::NANO)})); + std::vector args; + + args = {timestamp(TimeUnit::SECOND), timestamp(TimeUnit::NANO)}; + AssertTypeEqual(timestamp(TimeUnit::NANO), CommonTemporal(args.data(), args.size())); + args = {timestamp(TimeUnit::SECOND, "UTC"), timestamp(TimeUnit::NANO, "UTC")}; AssertTypeEqual(timestamp(TimeUnit::NANO, "UTC"), - CommonTemporal({timestamp(TimeUnit::SECOND, "UTC"), - timestamp(TimeUnit::NANO, "UTC")})); - AssertTypeEqual(timestamp(TimeUnit::NANO), - CommonTemporal({date32(), timestamp(TimeUnit::NANO)})); - AssertTypeEqual(timestamp(TimeUnit::MILLI), - CommonTemporal({date64(), timestamp(TimeUnit::SECOND)})); - AssertTypeEqual(date32(), CommonTemporal({date32(), date32()})); - AssertTypeEqual(date64(), CommonTemporal({date64(), date64()})); - AssertTypeEqual(date64(), CommonTemporal({date32(), date64()})); - ASSERT_EQ(nullptr, CommonTemporal({})); - ASSERT_EQ(nullptr, CommonTemporal({float64(), int32()})); - ASSERT_EQ(nullptr, CommonTemporal({timestamp(TimeUnit::SECOND), - timestamp(TimeUnit::SECOND, "UTC")})); - ASSERT_EQ(nullptr, CommonTemporal({timestamp(TimeUnit::SECOND, "America/Phoenix"), - timestamp(TimeUnit::SECOND, "UTC")})); + CommonTemporal(args.data(), args.size())); + args = {date32(), timestamp(TimeUnit::NANO)}; + AssertTypeEqual(timestamp(TimeUnit::NANO), CommonTemporal(args.data(), args.size())); + args = {date64(), timestamp(TimeUnit::SECOND)}; + AssertTypeEqual(timestamp(TimeUnit::MILLI), CommonTemporal(args.data(), args.size())); + args = {date32(), date32()}; + AssertTypeEqual(date32(), CommonTemporal(args.data(), args.size())); + args = {date64(), date64()}; + AssertTypeEqual(date64(), CommonTemporal(args.data(), args.size())); + args = {date32(), date64()}; + AssertTypeEqual(date64(), CommonTemporal(args.data(), args.size())); + args = {}; + ASSERT_EQ(nullptr, CommonTemporal(args.data(), args.size())); + args = {float64(), int32()}; + ASSERT_EQ(nullptr, CommonTemporal(args.data(), args.size())); + args = {timestamp(TimeUnit::SECOND), timestamp(TimeUnit::SECOND, "UTC")}; + ASSERT_EQ(nullptr, CommonTemporal(args.data(), args.size())); + args = {timestamp(TimeUnit::SECOND, "America/Phoenix"), + timestamp(TimeUnit::SECOND, "UTC")}; + ASSERT_EQ(nullptr, CommonTemporal(args.data(), args.size())); } } // namespace internal diff --git a/cpp/src/arrow/compute/kernels/scalar_compare.cc b/cpp/src/arrow/compute/kernels/scalar_compare.cc index c42b0ddac24..e42bcb7b25c 100644 --- a/cpp/src/arrow/compute/kernels/scalar_compare.cc +++ b/cpp/src/arrow/compute/kernels/scalar_compare.cc @@ -171,7 +171,7 @@ struct CompareFunction : ScalarFunction { if (auto type = CommonNumeric(*values)) { ReplaceTypes(type, values); - } else if (auto type = CommonTemporal(*values)) { + } else if (auto type = CommonTemporal(values->data(), values->size())) { ReplaceTypes(type, values); } else if (auto type = CommonBinary(*values)) { ReplaceTypes(type, values); @@ -195,7 +195,7 @@ struct VarArgsCompareFunction : ScalarFunction { if (auto type = CommonNumeric(*values)) { ReplaceTypes(type, values); - } else if (auto type = CommonTemporal(*values)) { + } else if (auto type = CommonTemporal(values->data(), values->size())) { ReplaceTypes(type, values); } diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else.cc b/cpp/src/arrow/compute/kernels/scalar_if_else.cc index 3b0da0a489a..e9432868ee8 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else.cc @@ -810,12 +810,11 @@ struct IfElseFunctor> { }, /*BroadcastScalar*/ [&](const Scalar& scalar, ArrayData* out_array) { - const util::string_view& scalar_data = - internal::UnboxScalar::Unbox(scalar); + const uint8_t* scalar_data = UnboxBinaryScalar(scalar); uint8_t* start = out_array->buffers[1]->mutable_data() + out_array->offset * byte_width; for (int64_t i = 0; i < out_array->length; i++) { - std::memcpy(start + i * byte_width, scalar_data.data(), scalar_data.size()); + std::memcpy(start + i * byte_width, scalar_data, byte_width); } }); } @@ -852,14 +851,12 @@ struct IfElseFunctor> { std::memcpy(out_values, right_data, right.length * byte_width); // selectively copy values from left data - const util::string_view& left_data = - internal::UnboxScalar::Unbox(left); + const uint8_t* left_data = UnboxBinaryScalar(left); RunIfElseLoop(cond, [&](int64_t data_offset, int64_t num_elems) { - if (left_data.data()) { + if (left_data) { for (int64_t i = 0; i < num_elems; i++) { - std::memcpy(out_values + (data_offset + i) * byte_width, left_data.data(), - left_data.size()); + std::memcpy(out_values + (data_offset + i) * byte_width, left_data, byte_width); } } }); @@ -877,14 +874,13 @@ struct IfElseFunctor> { const uint8_t* left_data = left.buffers[1]->data() + left.offset * byte_width; std::memcpy(out_values, left_data, left.length * byte_width); - const util::string_view& right_data = - internal::UnboxScalar::Unbox(right); + const uint8_t* right_data = UnboxBinaryScalar(right); RunIfElseLoopInverted(cond, [&](int64_t data_offset, int64_t num_elems) { - if (right_data.data()) { + if (right_data) { for (int64_t i = 0; i < num_elems; i++) { - std::memcpy(out_values + (data_offset + i) * byte_width, right_data.data(), - right_data.size()); + std::memcpy(out_values + (data_offset + i) * byte_width, right_data, + byte_width); } } }); @@ -899,23 +895,19 @@ struct IfElseFunctor> { auto* out_values = out->buffers[1]->mutable_data() + out->offset * byte_width; // copy right data to out_buff - const util::string_view& right_data = - internal::UnboxScalar::Unbox(right); - if (right_data.data()) { + const uint8_t* right_data = UnboxBinaryScalar(right); + if (right_data) { for (int64_t i = 0; i < cond.length; i++) { - std::memcpy(out_values + i * byte_width, right_data.data(), right_data.size()); + std::memcpy(out_values + i * byte_width, right_data, byte_width); } } // selectively copy values from left data - const util::string_view& left_data = - internal::UnboxScalar::Unbox(left); - + const uint8_t* left_data = UnboxBinaryScalar(left); RunIfElseLoop(cond, [&](int64_t data_offset, int64_t num_elems) { - if (left_data.data()) { + if (left_data) { for (int64_t i = 0; i < num_elems; i++) { - std::memcpy(out_values + (data_offset + i) * byte_width, left_data.data(), - left_data.size()); + std::memcpy(out_values + (data_offset + i) * byte_width, left_data, byte_width); } } }); @@ -923,20 +915,162 @@ struct IfElseFunctor> { return Status::OK(); } - static Result GetByteWidth(const DataType& left_type, - const DataType& right_type) { - int width = checked_cast(left_type).byte_width(); - if (width == checked_cast(right_type).byte_width()) { - return width; + template + static enable_if_t::value, const uint8_t*> UnboxBinaryScalar( + const Scalar& scalar) { + return reinterpret_cast( + internal::UnboxScalar::Unbox(scalar).data()); + } + + template + static enable_if_decimal UnboxBinaryScalar(const Scalar& scalar) { + return internal::UnboxScalar::Unbox(scalar).native_endian_bytes(); + } + + template + static enable_if_t::value, Result> GetByteWidth( + const DataType& left_type, const DataType& right_type) { + const int32_t width = + checked_cast(left_type).byte_width(); + DCHECK_EQ(width, checked_cast(right_type).byte_width()); + return width; + } + + template + static enable_if_decimal> GetByteWidth(const DataType& left_type, + const DataType& right_type) { + const auto& left = checked_cast(left_type); + const auto& right = checked_cast(right_type); + DCHECK_EQ(left.precision(), right.precision()); + DCHECK_EQ(left.scale(), right.scale()); + return left.byte_width(); + } +}; + +// Use builders for dictionaries - slower, but allows us to unify dictionaries +template +struct IfElseFunctor< + Type, enable_if_t::value || is_dictionary_type::value>> { + // A - Array, S - Scalar, X = Array/Scalar + + // SXX + static Status Call(KernelContext* ctx, const BooleanScalar& cond, const Datum& left, + const Datum& right, Datum* out) { + if (left.is_scalar() && right.is_scalar()) { + if (cond.is_valid) { + *out = cond.value ? left.scalar() : right.scalar(); + } else { + *out = MakeNullScalar(left.type()); + } + return Status::OK(); + } + // either left or right is an array. Output is always an array + int64_t out_arr_len = std::max(left.length(), right.length()); + if (!cond.is_valid) { + // cond is null; just create a null array + ARROW_ASSIGN_OR_RAISE(*out, + MakeArrayOfNull(left.type(), out_arr_len, ctx->memory_pool())) + return Status::OK(); + } + + const auto& valid_data = cond.value ? left : right; + if (valid_data.is_array()) { + *out = valid_data; + } else { + // valid data is a scalar that needs to be broadcasted + ARROW_ASSIGN_OR_RAISE(*out, MakeArrayFromScalar(*valid_data.scalar(), out_arr_len, + ctx->memory_pool())); + } + return Status::OK(); + } + + // AAA + static Status Call(KernelContext* ctx, const ArrayData& cond, const ArrayData& left, + const ArrayData& right, ArrayData* out) { + return RunLoop( + ctx, cond, out, + [&](ArrayBuilder* builder, int64_t i) { + return builder->AppendArraySlice(left, i, /*length=*/1); + }, + [&](ArrayBuilder* builder, int64_t i) { + return builder->AppendArraySlice(right, i, /*length=*/1); + }); + } + + // ASA + static Status Call(KernelContext* ctx, const ArrayData& cond, const Scalar& left, + const ArrayData& right, ArrayData* out) { + return RunLoop( + ctx, cond, out, + [&](ArrayBuilder* builder, int64_t i) { return builder->AppendScalar(left); }, + [&](ArrayBuilder* builder, int64_t i) { + return builder->AppendArraySlice(right, i, /*length=*/1); + }); + } + + // AAS + static Status Call(KernelContext* ctx, const ArrayData& cond, const ArrayData& left, + const Scalar& right, ArrayData* out) { + return RunLoop( + ctx, cond, out, + [&](ArrayBuilder* builder, int64_t i) { + return builder->AppendArraySlice(left, i, /*length=*/1); + }, + [&](ArrayBuilder* builder, int64_t i) { return builder->AppendScalar(right); }); + } + + // ASS + static Status Call(KernelContext* ctx, const ArrayData& cond, const Scalar& left, + const Scalar& right, ArrayData* out) { + return RunLoop( + ctx, cond, out, + [&](ArrayBuilder* builder, int64_t i) { return builder->AppendScalar(left); }, + [&](ArrayBuilder* builder, int64_t i) { return builder->AppendScalar(right); }); + } + + template + static Status RunLoop(KernelContext* ctx, const ArrayData& cond, ArrayData* out, + HandleLeft&& handle_left, HandleRight&& handle_right) { + std::unique_ptr raw_builder; + RETURN_NOT_OK(MakeBuilderExactIndex(ctx->memory_pool(), out->type, &raw_builder)); + RETURN_NOT_OK(raw_builder->Reserve(out->length)); + + const auto* cond_data = cond.buffers[1]->data(); + if (out->buffers[0]) { + const uint8_t* out_valid = out->buffers[0]->data(); + for (int64_t i = 0; i < cond.length; i++) { + if (BitUtil::GetBit(out_valid, i)) { + if (BitUtil::GetBit(cond_data, cond.offset + i)) { + RETURN_NOT_OK(handle_left(raw_builder.get(), i)); + } else { + RETURN_NOT_OK(handle_right(raw_builder.get(), i)); + } + } else { + RETURN_NOT_OK(raw_builder->AppendNull()); + } + } } else { - return Status::Invalid("FixedSizeBinaryType byte_widths should be equal"); + for (int64_t i = 0; i < cond.length; i++) { + if (BitUtil::GetBit(cond_data, cond.offset + i)) { + RETURN_NOT_OK(handle_left(raw_builder.get(), i)); + } else { + RETURN_NOT_OK(handle_right(raw_builder.get(), i)); + } + } } + ARROW_ASSIGN_OR_RAISE(auto out_arr, raw_builder->Finish()); + *out = std::move(*out_arr->data()); + return Status::OK(); } }; template struct ResolveIfElseExec { static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + // Check is unconditional because parametric types like timestamp + // are templated as integer + RETURN_NOT_OK(CheckIdenticalTypes(&batch.values[1], /*count=*/2)); + // cond is scalar if (batch[0].is_scalar()) { const auto& cond = batch[0].scalar_as(); @@ -988,7 +1122,8 @@ struct IfElseFunction : ScalarFunction { RETURN_NOT_OK(CheckArity(*values)); using arrow::compute::detail::DispatchExactImpl; - if (auto kernel = DispatchExactImpl(this, *values)) return kernel; + // Do not DispatchExact here because it'll let through something like (bool, + // timestamp[s], timestamp[s, "UTC"]) // if 0th descriptor is null, replace with bool if (values->at(0).type->id() == Type::NA) { @@ -996,15 +1131,26 @@ struct IfElseFunction : ScalarFunction { } // if-else 0'th descriptor is bool, so skip it - std::vector values_copy(values->begin() + 1, values->end()); - internal::EnsureDictionaryDecoded(&values_copy); - internal::ReplaceNullWithOtherType(&values_copy); + ValueDescr* left_arg = &(*values)[1]; + constexpr size_t num_args = 2; - if (auto type = internal::CommonNumeric(values_copy)) { - internal::ReplaceTypes(type, &values_copy); + internal::ReplaceNullWithOtherType(left_arg, num_args); + + if (is_dictionary((*values)[1].type->id()) && + is_dictionary((*values)[2].type->id())) { + auto kernel = DispatchExactImpl(this, *values); + DCHECK(kernel); + return kernel; } - std::move(values_copy.begin(), values_copy.end(), values->begin() + 1); + internal::EnsureDictionaryDecoded(left_arg, num_args); + + if (auto type = internal::CommonNumeric(left_arg, num_args)) { + internal::ReplaceTypes(type, left_arg, num_args); + } + if (auto type = internal::CommonTemporal(left_arg, num_args)) { + internal::ReplaceTypes(type, left_arg, num_args); + } if (auto kernel = DispatchExactImpl(this, *values)) return kernel; @@ -1030,7 +1176,16 @@ void AddPrimitiveIfElseKernels(const std::shared_ptr& scalar_fun internal::GenerateTypeAgnosticPrimitive(*type); // cond array needs to be boolean always - ScalarKernel kernel({boolean(), type, type}, type, exec); + std::shared_ptr sig; + if (type->id() == Type::TIMESTAMP) { + auto unit = checked_cast(*type).unit(); + sig = KernelSignature::Make( + {boolean(), match::TimestampTypeUnit(unit), match::TimestampTypeUnit(unit)}, + OutputType(LastType)); + } else { + sig = KernelSignature::Make({boolean(), type, type}, type); + } + ScalarKernel kernel(std::move(sig), exec); kernel.null_handling = NullHandling::COMPUTED_PREALLOCATE; kernel.mem_allocation = MemAllocation::PREALLOCATE; kernel.can_write_into_slices = true; @@ -1056,14 +1211,12 @@ void AddBinaryIfElseKernels(const std::shared_ptr& scalar_functi } } -void AddFSBinaryIfElseKernel(const std::shared_ptr& scalar_function) { - // cond array needs to be boolean always - ScalarKernel kernel( - {boolean(), InputType(Type::FIXED_SIZE_BINARY), InputType(Type::FIXED_SIZE_BINARY)}, - OutputType([](KernelContext*, const std::vector& descrs) { - return ValueDescr(descrs[1].type, ValueDescr::ANY); - }), - ResolveIfElseExec::Exec); +template +void AddFixedWidthIfElseKernel(const std::shared_ptr& scalar_function) { + auto type_id = T::type_id; + ScalarKernel kernel({boolean(), InputType(type_id), InputType(type_id)}, + OutputType(LastType), + ResolveIfElseExec::Exec); kernel.null_handling = NullHandling::COMPUTED_PREALLOCATE; kernel.mem_allocation = MemAllocation::PREALLOCATE; kernel.can_write_into_slices = true; @@ -1071,6 +1224,16 @@ void AddFSBinaryIfElseKernel(const std::shared_ptr& scalar_funct DCHECK_OK(scalar_function->AddKernel(std::move(kernel))); } +void AddNestedIfElseKernel(const std::shared_ptr& scalar_function, + detail::GetTypeId get_id, ArrayKernelExec exec) { + ScalarKernel kernel({boolean(), InputType(get_id.id), InputType(get_id.id)}, + OutputType(LastType), std::move(exec)); + kernel.null_handling = NullHandling::COMPUTED_NO_PREALLOCATE; + kernel.mem_allocation = MemAllocation::NO_PREALLOCATE; + kernel.can_write_into_slices = false; + DCHECK_OK(scalar_function->AddKernel(std::move(kernel))); +} + // Helper to copy or broadcast fixed-width values between buffers. template struct CopyFixedWidth {}; @@ -1758,7 +1921,7 @@ struct CoalesceFunction : ScalarFunction { if (auto type = CommonBinary(*values)) { ReplaceTypes(type, values); } - if (auto type = CommonTemporal(*values)) { + if (auto type = CommonTemporal(values->data(), values->size())) { ReplaceTypes(type, values); } if (HasDecimal(*values)) { @@ -2500,12 +2663,6 @@ struct ChooseFunction : ScalarFunction { } }; -Result LastType(KernelContext*, const std::vector& descrs) { - ValueDescr result = descrs.back(); - result.shape = GetBroadcastShape(descrs); - return result; -} - void AddCaseWhenKernel(const std::shared_ptr& scalar_function, detail::GetTypeId get_id, ArrayKernelExec exec) { ScalarKernel kernel( @@ -2629,7 +2786,23 @@ void RegisterScalarIfElse(FunctionRegistry* registry) { AddPrimitiveIfElseKernels(func, {boolean()}); AddNullIfElseKernel(func); AddBinaryIfElseKernels(func, BaseBinaryTypes()); - AddFSBinaryIfElseKernel(func); + AddFixedWidthIfElseKernel(func); + AddFixedWidthIfElseKernel(func); + AddFixedWidthIfElseKernel(func); + AddNestedIfElseKernel(func, Type::LIST, + ResolveIfElseExec::Exec); + AddNestedIfElseKernel(func, Type::LARGE_LIST, + ResolveIfElseExec::Exec); + AddNestedIfElseKernel(func, Type::FIXED_SIZE_LIST, + ResolveIfElseExec::Exec); + AddNestedIfElseKernel(func, Type::STRUCT, + ResolveIfElseExec::Exec); + AddNestedIfElseKernel(func, Type::DENSE_UNION, + ResolveIfElseExec::Exec); + AddNestedIfElseKernel(func, Type::SPARSE_UNION, + ResolveIfElseExec::Exec); + AddNestedIfElseKernel(func, Type::DICTIONARY, + ResolveIfElseExec::Exec); DCHECK_OK(registry->AddFunction(std::move(func))); } { diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc index 907c7b0e638..5e74d4a60e4 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc @@ -19,6 +19,7 @@ #include "arrow/array.h" #include "arrow/array/concatenate.h" #include "arrow/compute/api_scalar.h" +#include "arrow/compute/cast.h" #include "arrow/compute/kernels/test_util.h" #include "arrow/compute/registry.h" #include "arrow/testing/gtest_util.h" @@ -103,7 +104,7 @@ void CheckWithDifferentShapes(const std::shared_ptr& cond, auto len = left->length(); enum { COND_SCALAR = 1, LEFT_SCALAR = 2, RIGHT_SCALAR = 4 }; - for (int mask = 0; mask < (COND_SCALAR | LEFT_SCALAR | RIGHT_SCALAR); ++mask) { + for (int mask = 1; mask <= (COND_SCALAR | LEFT_SCALAR | RIGHT_SCALAR); ++mask) { for (int64_t cond_idx = 0; cond_idx < len; ++cond_idx) { Datum cond_in, cond_bcast; std::string trace_cond = "Cond"; @@ -113,6 +114,7 @@ void CheckWithDifferentShapes(const std::shared_ptr& cond, trace_cond += "@" + std::to_string(cond_idx) + "=" + cond_in.scalar()->ToString(); } else { cond_in = cond_bcast = cond; + trace_cond += "=Array"; } SCOPED_TRACE(trace_cond); @@ -122,10 +124,11 @@ void CheckWithDifferentShapes(const std::shared_ptr& cond, if (mask & LEFT_SCALAR) { ASSERT_OK_AND_ASSIGN(left_in, left->GetScalar(left_idx).As()); ASSERT_OK_AND_ASSIGN(left_bcast, MakeArrayFromScalar(*left_in.scalar(), len)); - trace_cond += + trace_left += "@" + std::to_string(left_idx) + "=" + left_in.scalar()->ToString(); } else { left_in = left_bcast = left; + trace_left += "=Array"; } SCOPED_TRACE(trace_left); @@ -140,12 +143,27 @@ void CheckWithDifferentShapes(const std::shared_ptr& cond, "@" + std::to_string(right_idx) + "=" + right_in.scalar()->ToString(); } else { right_in = right_bcast = right; + trace_right += "=Array"; } SCOPED_TRACE(trace_right); - ASSERT_OK_AND_ASSIGN(auto exp, IfElse(cond_bcast, left_bcast, right_bcast)); + Datum expected; ASSERT_OK_AND_ASSIGN(auto actual, IfElse(cond_in, left_in, right_in)); - AssertDatumsEqual(exp, actual, /*verbose=*/true); + if (mask & COND_SCALAR && mask & LEFT_SCALAR && mask & RIGHT_SCALAR) { + const auto& scalar = cond_in.scalar_as(); + if (scalar.is_valid) { + expected = scalar.value ? left_in : right_in; + } else { + expected = MakeNullScalar(left_in.type()); + } + if (!left_in.type()->Equals(*right_in.type())) { + ASSERT_OK_AND_ASSIGN(expected, + Cast(expected, CastOptions::Safe(actual.type()))); + } + } else { + ASSERT_OK_AND_ASSIGN(expected, IfElse(cond_bcast, left_bcast, right_bcast)); + } + AssertDatumsEqual(actual, expected, /*verbose=*/true); if (right_in.is_array()) break; } @@ -288,8 +306,43 @@ TEST_F(TestIfElseKernel, IfElseMultiType) { ArrayFromJSON(float32(), "[1, 2, 3, 8]")); } +TEST_F(TestIfElseKernel, TimestampTypes) { + for (const auto unit : TimeUnit::values()) { + auto ty = timestamp(unit); + CheckWithDifferentShapes(ArrayFromJSON(boolean(), "[true, true, true, false]"), + ArrayFromJSON(ty, "[1, 2, 3, 4]"), + ArrayFromJSON(ty, "[5, 6, 7, 8]"), + ArrayFromJSON(ty, "[1, 2, 3, 8]")); + + ty = timestamp(unit, "America/Phoenix"); + CheckWithDifferentShapes(ArrayFromJSON(boolean(), "[true, true, true, false]"), + ArrayFromJSON(ty, "[1, 2, 3, 4]"), + ArrayFromJSON(ty, "[5, 6, 7, 8]"), + ArrayFromJSON(ty, "[1, 2, 3, 8]")); + } +} + +TEST_F(TestIfElseKernel, TemporalTypes) { + for (const auto& ty : TemporalTypes()) { + CheckWithDifferentShapes(ArrayFromJSON(boolean(), "[true, true, true, false]"), + ArrayFromJSON(ty, "[1, 2, 3, 4]"), + ArrayFromJSON(ty, "[5, 6, 7, 8]"), + ArrayFromJSON(ty, "[1, 2, 3, 8]")); + } +} + +TEST_F(TestIfElseKernel, DayTimeInterval) { + auto ty = day_time_interval(); + CheckWithDifferentShapes( + ArrayFromJSON(boolean(), "[true, true, true, false]"), + ArrayFromJSON(ty, "[[1, 2], [3, -4], [-5, 6], [-7, -8]]"), + ArrayFromJSON(ty, "[[-9, -10], [11, -12], [-13, 14], [15, 16]]"), + ArrayFromJSON(ty, "[[1, 2], [3, -4], [-5, 6], [15, 16]]")); +} + TEST_F(TestIfElseKernel, IfElseDispatchBest) { std::string name = "if_else"; + ASSERT_OK_AND_ASSIGN(auto function, GetFunctionRegistry()->GetFunction(name)); CheckDispatchBest(name, {boolean(), int32(), int32()}, {boolean(), int32(), int32()}); CheckDispatchBest(name, {boolean(), int32(), null()}, {boolean(), int32(), int32()}); CheckDispatchBest(name, {boolean(), null(), int32()}, {boolean(), int32(), int32()}); @@ -316,6 +369,16 @@ TEST_F(TestIfElseKernel, IfElseDispatchBest) { {boolean(), float64(), float64()}); CheckDispatchBest(name, {null(), uint8(), int8()}, {boolean(), int16(), int16()}); + + CheckDispatchBest(name, + {boolean(), timestamp(TimeUnit::SECOND), timestamp(TimeUnit::MILLI)}, + {boolean(), timestamp(TimeUnit::MILLI), timestamp(TimeUnit::MILLI)}); + CheckDispatchBest(name, {boolean(), date32(), timestamp(TimeUnit::MILLI)}, + {boolean(), timestamp(TimeUnit::MILLI), timestamp(TimeUnit::MILLI)}); + CheckDispatchBest(name, {boolean(), date32(), date64()}, + {boolean(), date64(), date64()}); + CheckDispatchBest(name, {boolean(), date32(), date32()}, + {boolean(), date32(), date32()}); } template @@ -412,7 +475,7 @@ TYPED_TEST(TestIfElseBaseBinary, IfElseBaseBinaryRand) { } TEST_F(TestIfElseKernel, IfElseFSBinary) { - auto type = std::make_shared(4); + auto type = fixed_size_binary(4); CheckWithDifferentShapes(ArrayFromJSON(boolean(), "[true, true, true, false]"), ArrayFromJSON(type, R"(["aaaa", "abab", "abca", "abcd"])"), @@ -453,18 +516,10 @@ TEST_F(TestIfElseKernel, IfElseFSBinary) { ArrayFromJSON(type, R"(["aaaa", "abab", "abca", "abcd"])"), ArrayFromJSON(type, R"(["lmno", "lmnl", "lmlm", "llll"])"), ArrayFromJSON(type, R"([null, "abab", "abca", "llll"])")); - - // should fails for non-equal byte_widths - auto type1 = std::make_shared(5); - EXPECT_RAISES_WITH_MESSAGE_THAT( - Invalid, ::testing::HasSubstr("FixedSizeBinaryType byte_widths should be equal"), - CallFunction("if_else", {ArrayFromJSON(boolean(), "[true]"), - ArrayFromJSON(type, R"(["aaaa"])"), - ArrayFromJSON(type1, R"(["aaaaa"])")})); } TEST_F(TestIfElseKernel, IfElseFSBinaryRand) { - auto type = std::make_shared(5); + auto type = fixed_size_binary(5); random::RandomArrayGenerator rand(/*seed=*/0); int64_t len = 1000; @@ -504,6 +559,406 @@ TEST_F(TestIfElseKernel, IfElseFSBinaryRand) { CheckIfElseOutput(cond, left, right, expected_data); } +TEST_F(TestIfElseKernel, Decimal) { + for (const auto& ty : {decimal128(3, 2), decimal256(3, 2)}) { + CheckWithDifferentShapes(ArrayFromJSON(boolean(), "[true, true, true, false]"), + ArrayFromJSON(ty, R"(["1.23", "2.34", "-1.23", "3.45"])"), + ArrayFromJSON(ty, R"(["1.34", "-2.34", "0.00", "-4.56"])"), + ArrayFromJSON(ty, R"(["1.23", "2.34", "-1.23", "-4.56"])")); + + CheckWithDifferentShapes(ArrayFromJSON(boolean(), R"([true, true, true, false])"), + ArrayFromJSON(ty, R"(["1.23", "2.34", "-1.23", "3.45"])"), + ArrayFromJSON(ty, R"(["1.34", "-2.34", "0.00", null])"), + ArrayFromJSON(ty, R"(["1.23", "2.34", "-1.23", null])")); + + CheckWithDifferentShapes(ArrayFromJSON(boolean(), R"([true, true, true, false])"), + ArrayFromJSON(ty, R"(["1.23", "2.34", null, "3.45"])"), + ArrayFromJSON(ty, R"(["1.34", "-2.34", "0.00", null])"), + ArrayFromJSON(ty, R"(["1.23", "2.34", null, null])")); + + CheckWithDifferentShapes(ArrayFromJSON(boolean(), R"([true, true, true, false])"), + ArrayFromJSON(ty, R"(["1.23", "2.34", null, "3.45"])"), + ArrayFromJSON(ty, R"(["1.34", "-2.34", "0.00", "-4.56"])"), + ArrayFromJSON(ty, R"(["1.23", "2.34", null, "-4.56"])")); + + CheckWithDifferentShapes(ArrayFromJSON(boolean(), R"([null, true, true, false])"), + ArrayFromJSON(ty, R"(["1.23", "2.34", null, "3.45"])"), + ArrayFromJSON(ty, R"(["1.34", "-2.34", "0.00", "-4.56"])"), + ArrayFromJSON(ty, R"([null, "2.34", null, "-4.56"])")); + + CheckWithDifferentShapes(ArrayFromJSON(boolean(), R"([null, true, true, false])"), + ArrayFromJSON(ty, R"(["1.23", "2.34", null, "3.45"])"), + ArrayFromJSON(ty, R"(["1.34", "-2.34", "0.00", null])"), + ArrayFromJSON(ty, R"([null, "2.34", null, null])")); + + CheckWithDifferentShapes(ArrayFromJSON(boolean(), R"([null, true, true, false])"), + ArrayFromJSON(ty, R"(["1.23", "2.34", "-1.23", "3.45"])"), + ArrayFromJSON(ty, R"(["1.34", "-2.34", "0.00", null])"), + ArrayFromJSON(ty, R"([null, "2.34", "-1.23", null])")); + + CheckWithDifferentShapes(ArrayFromJSON(boolean(), R"([null, true, true, false])"), + ArrayFromJSON(ty, R"(["1.23", "2.34", "-1.23", "3.45"])"), + ArrayFromJSON(ty, R"(["1.34", "-2.34", "0.00", "-4.56"])"), + ArrayFromJSON(ty, R"([null, "2.34", "-1.23", "-4.56"])")); + } +} + +template +class TestIfElseList : public ::testing::Test {}; + +TYPED_TEST_SUITE(TestIfElseList, ListArrowTypes); + +TYPED_TEST(TestIfElseList, ListOfInt) { + auto type = std::make_shared(int32()); + CheckWithDifferentShapes(ArrayFromJSON(boolean(), "[true, true, false, false]"), + ArrayFromJSON(type, "[[], null, [1, null], [2, 3]]"), + ArrayFromJSON(type, "[[4, 5, 6], [7], [null], null]"), + ArrayFromJSON(type, "[[], null, [null], null]")); + + CheckWithDifferentShapes(ArrayFromJSON(boolean(), "[null, null, null, null]"), + ArrayFromJSON(type, "[[], [2, 3, 4, 5], null, null]"), + ArrayFromJSON(type, "[[4, 5, 6], null, [null], null]"), + ArrayFromJSON(type, "[null, null, null, null]")); +} + +TYPED_TEST(TestIfElseList, ListOfString) { + auto type = std::make_shared(utf8()); + CheckWithDifferentShapes( + ArrayFromJSON(boolean(), "[true, true, false, false]"), + ArrayFromJSON(type, R"([[], null, ["xyz", null], ["ab", "c"]])"), + ArrayFromJSON(type, R"([["hi", "jk", "l"], ["defg"], [null], null])"), + ArrayFromJSON(type, R"([[], null, [null], null])")); + + CheckWithDifferentShapes( + ArrayFromJSON(boolean(), "[null, null, null, null]"), + ArrayFromJSON(type, R"([[], ["b", "cd", "efg", "h"], null, null])"), + ArrayFromJSON(type, R"([["hi", "jk", "l"], null, [null], null])"), + ArrayFromJSON(type, R"([null, null, null, null])")); +} + +TEST_F(TestIfElseKernel, FixedSizeList) { + auto type = fixed_size_list(int32(), 2); + CheckWithDifferentShapes(ArrayFromJSON(boolean(), "[true, true, false, false]"), + ArrayFromJSON(type, "[[1, 2], null, [1, null], [2, 3]]"), + ArrayFromJSON(type, "[[4, 5], [6, 7], [null, 8], null]"), + ArrayFromJSON(type, "[[1, 2], null, [null, 8], null]")); + + CheckWithDifferentShapes(ArrayFromJSON(boolean(), "[null, null, null, null]"), + ArrayFromJSON(type, "[[2, 3], [4, 5], null, null]"), + ArrayFromJSON(type, "[[4, 5], null, [6, null], null]"), + ArrayFromJSON(type, "[null, null, null, null]")); +} + +TEST_F(TestIfElseKernel, StructPrimitive) { + auto type = struct_({field("int", uint16()), field("str", utf8())}); + CheckWithDifferentShapes( + ArrayFromJSON(boolean(), "[true, true, false, false]"), + ArrayFromJSON(type, R"([[null, "foo"], null, [1, null], [2, "spam"]])"), + ArrayFromJSON(type, R"([[1, "a"], [42, ""], [24, null], null])"), + ArrayFromJSON(type, R"([[null, "foo"], null, [24, null], null])")); + + CheckWithDifferentShapes( + ArrayFromJSON(boolean(), "[null, null, null, null]"), + ArrayFromJSON(type, R"([[null, "foo"], [4, "abcd"], null, null])"), + ArrayFromJSON(type, R"([[1, "a"], null, [24, null], null])"), + ArrayFromJSON(type, R"([null, null, null, null])")); +} + +TEST_F(TestIfElseKernel, StructNested) { + auto type = struct_({field("date", date32()), field("list", list(int32()))}); + CheckWithDifferentShapes( + ArrayFromJSON(boolean(), "[true, true, false, false]"), + ArrayFromJSON(type, R"([[-1, [null]], null, [1, null], [2, [3, 4]]])"), + ArrayFromJSON(type, R"([[4, [5]], [6, [7, 8]], [null, [1, null, 42]], null])"), + ArrayFromJSON(type, R"([[-1, [null]], null, [null, [1, null, 42]], null])")); + + CheckWithDifferentShapes( + ArrayFromJSON(boolean(), "[null, null, null, null]"), + ArrayFromJSON(type, R"([[-1, [null]], [4, [5, 6]], null, null])"), + ArrayFromJSON(type, R"([[4, [5]], null, [null, [1, null, 42]], null])"), + ArrayFromJSON(type, R"([null, null, null, null])")); +} + +TEST_F(TestIfElseKernel, ParameterizedTypes) { + auto cond = ArrayFromJSON(boolean(), "[true]"); + + auto type0 = fixed_size_binary(4); + auto type1 = fixed_size_binary(5); + EXPECT_RAISES_WITH_MESSAGE_THAT( + TypeError, + ::testing::HasSubstr("All types must be compatible, expected: " + "fixed_size_binary[4], but got: fixed_size_binary[5]"), + CallFunction("if_else", {cond, ArrayFromJSON(type0, R"(["aaaa"])"), + ArrayFromJSON(type1, R"(["aaaaa"])")})); + + type0 = decimal128(3, 2); + type1 = decimal128(4, 2); + EXPECT_RAISES_WITH_MESSAGE_THAT( + TypeError, + ::testing::HasSubstr("All types must be compatible, expected: decimal128(3, 2), " + "but got: decimal128(4, 2)"), + CallFunction("if_else", {cond, ArrayFromJSON(type0, R"(["1.23"])"), + ArrayFromJSON(type1, R"(["1.23"])")})); + + type1 = decimal128(3, 4); + EXPECT_RAISES_WITH_MESSAGE_THAT( + TypeError, + ::testing::HasSubstr("All types must be compatible, expected: decimal128(3, 2), " + "but got: decimal128(3, 4)"), + CallFunction("if_else", {cond, ArrayFromJSON(type0, R"(["1.23"])"), + ArrayFromJSON(type1, R"(["1.2345"])")})); + + // TODO(ARROW-14105): in principle many of these could be implicitly castable too + + type0 = struct_({field("a", int32())}); + type1 = struct_({field("a", int64())}); + EXPECT_RAISES_WITH_MESSAGE_THAT( + TypeError, + ::testing::HasSubstr("All types must be compatible, expected: struct, " + "but got: struct"), + CallFunction("if_else", + {cond, ArrayFromJSON(type0, "[[0]]"), ArrayFromJSON(type1, "[[0]]")})); + + type0 = dense_union({field("a", int32())}); + type1 = dense_union({field("a", int64())}); + EXPECT_RAISES_WITH_MESSAGE_THAT( + TypeError, + ::testing::HasSubstr("All types must be compatible, expected: dense_union, but got: dense_union"), + CallFunction("if_else", {cond, ArrayFromJSON(type0, "[[0, -1]]"), + ArrayFromJSON(type1, "[[0, -1]]")})); + + type0 = list(int16()); + type1 = list(int32()); + EXPECT_RAISES_WITH_MESSAGE_THAT( + TypeError, + ::testing::HasSubstr("All types must be compatible, expected: list, " + "but got: list"), + CallFunction("if_else", + {cond, ArrayFromJSON(type0, "[[0]]"), ArrayFromJSON(type1, "[[0]]")})); + + type0 = dictionary(int16(), utf8()); + type1 = dictionary(int32(), utf8()); + EXPECT_RAISES_WITH_MESSAGE_THAT( + TypeError, + ::testing::HasSubstr("All types must be compatible, expected: " + "dictionary, but " + "got: dictionary"), + CallFunction("if_else", {cond, DictArrayFromJSON(type0, "[0]", R"(["a"])"), + DictArrayFromJSON(type1, "[0]", R"(["a"])")})); + + type0 = timestamp(TimeUnit::SECOND); + type1 = timestamp(TimeUnit::SECOND, "America/Phoenix"); + EXPECT_RAISES_WITH_MESSAGE_THAT( + TypeError, + ::testing::HasSubstr("All types must be compatible, expected: timestamp[s], " + "but got: timestamp[s, tz=America/Phoenix]"), + CallFunction("if_else", + {cond, ArrayFromJSON(type0, "[0]"), ArrayFromJSON(type1, "[1]")})); + + type0 = timestamp(TimeUnit::SECOND, "America/New_York"); + type1 = timestamp(TimeUnit::SECOND, "America/Phoenix"); + EXPECT_RAISES_WITH_MESSAGE_THAT( + TypeError, + ::testing::HasSubstr( + "All types must be compatible, expected: timestamp[s, tz=America/New_York], " + "but got: timestamp[s, tz=America/Phoenix]"), + CallFunction("if_else", + {cond, ArrayFromJSON(type0, "[0]"), ArrayFromJSON(type1, "[1]")})); + + type0 = timestamp(TimeUnit::MILLI, "America/New_York"); + type1 = timestamp(TimeUnit::SECOND, "America/Phoenix"); + // Casting fails so we never get to the kernel in the first place (since the units don't + // match) + EXPECT_RAISES_WITH_MESSAGE_THAT( + NotImplemented, + ::testing::HasSubstr("Function if_else has no kernel matching input types " + "(array[bool], array[timestamp[ms, tz=America/New_York]], " + "array[timestamp[s, tz=America/Phoenix]]"), + CallFunction("if_else", + {cond, ArrayFromJSON(type0, "[0]"), ArrayFromJSON(type1, "[1]")})); +} + +template +class TestIfElseUnion : public ::testing::Test {}; + +TYPED_TEST_SUITE(TestIfElseUnion, UnionArrowTypes); + +TYPED_TEST(TestIfElseUnion, UnionPrimitive) { + std::vector> fields = {field("int", uint16()), + field("str", utf8())}; + std::vector codes = {2, 7}; + auto type = std::make_shared(fields, codes); + CheckWithDifferentShapes( + ArrayFromJSON(boolean(), "[true, true, false, false]"), + ArrayFromJSON(type, R"([[7, "foo"], [7, null], [7, null], [7, "spam"]])"), + ArrayFromJSON(type, R"([[2, 15], [2, null], [2, 42], [2, null]])"), + ArrayFromJSON(type, R"([[7, "foo"], [7, null], [2, 42], [2, null]])")); + + CheckWithDifferentShapes( + ArrayFromJSON(boolean(), "[null, null, null, null]"), + ArrayFromJSON(type, R"([[7, "foo"], [7, null], [7, null], [7, "spam"]])"), + ArrayFromJSON(type, R"([[2, 15], [2, null], [2, 42], [2, null]])"), + ArrayFromJSON(type, R"([null, null, null, null])")); +} + +TYPED_TEST(TestIfElseUnion, UnionNested) { + std::vector> fields = {field("int", uint16()), + field("list", list(int16()))}; + std::vector codes = {2, 7}; + auto type = std::make_shared(fields, codes); + CheckWithDifferentShapes( + ArrayFromJSON(boolean(), "[true, true, false, false]"), + ArrayFromJSON(type, R"([[7, [1, 2]], [7, null], [7, []], [7, [3]]])"), + ArrayFromJSON(type, R"([[2, 15], [2, null], [2, 42], [2, null]])"), + ArrayFromJSON(type, R"([[7, [1, 2]], [7, null], [2, 42], [2, null]])")); + + CheckWithDifferentShapes( + ArrayFromJSON(boolean(), "[null, null, null, null]"), + ArrayFromJSON(type, R"([[7, [1, 2]], [7, null], [7, []], [7, [3]]])"), + ArrayFromJSON(type, R"([[2, 15], [2, null], [2, 42], [2, null]])"), + ArrayFromJSON(type, R"([null, null, null, null])")); +} + +template +class TestIfElseDict : public ::testing::Test {}; + +TYPED_TEST_SUITE(TestIfElseDict, IntegralArrowTypes); + +struct JsonDict { + std::shared_ptr type; + std::string value; +}; + +TYPED_TEST(TestIfElseDict, Simple) { + auto cond = ArrayFromJSON(boolean(), "[true, false, true, null]"); + for (const auto& dict : + {JsonDict{utf8(), R"(["a", null, "bc", "def"])"}, + JsonDict{int64(), "[1, null, 2, 3]"}, + JsonDict{decimal256(3, 2), R"(["1.23", null, "3.45", "6.78"])"}}) { + auto type = dictionary(default_type_instance(), dict.type); + auto values_null = DictArrayFromJSON(type, "[null, null, null, null]", dict.value); + auto values1 = DictArrayFromJSON(type, "[0, null, 3, 1]", dict.value); + auto values2 = DictArrayFromJSON(type, "[2, 1, null, 0]", dict.value); + auto scalar = DictScalarFromJSON(type, "3", dict.value); + + // Easy case: all arguments have the same dictionary + CheckDictionary("if_else", {cond, values1, values2}); + CheckDictionary("if_else", {cond, values1, scalar}); + CheckDictionary("if_else", {cond, scalar, values2}); + CheckDictionary("if_else", {cond, values_null, values2}); + CheckDictionary("if_else", {cond, values1, values_null}); + CheckDictionary("if_else", {Datum(true), values1, values2}); + CheckDictionary("if_else", {Datum(false), values1, values2}); + CheckDictionary("if_else", {Datum(true), scalar, values2}); + CheckDictionary("if_else", {Datum(true), values1, scalar}); + CheckDictionary("if_else", {Datum(false), values1, scalar}); + CheckDictionary("if_else", {Datum(false), scalar, values2}); + CheckDictionary("if_else", {MakeNullScalar(boolean()), values1, values2}); + } +} + +TYPED_TEST(TestIfElseDict, Mixed) { + auto type = dictionary(default_type_instance(), utf8()); + auto cond = ArrayFromJSON(boolean(), "[true, false, true, null]"); + auto dict = R"(["a", null, "bc", "def"])"; + auto values_null = DictArrayFromJSON(type, "[null, null, null, null]", dict); + auto values1_dict = DictArrayFromJSON(type, "[0, null, 3, 1]", dict); + auto values1_decoded = ArrayFromJSON(utf8(), R"(["a", null, "def", null])"); + auto values2_dict = DictArrayFromJSON(type, "[2, 1, null, 0]", dict); + auto values2_decoded = ArrayFromJSON(utf8(), R"(["bc", null, null, "a"])"); + auto scalar = ScalarFromJSON(utf8(), R"("bc")"); + + // If we have mixed dictionary/non-dictionary arguments, we decode dictionaries + CheckDictionary("if_else", {cond, values1_dict, values2_decoded}, + /*result_is_encoded=*/false); + CheckDictionary("if_else", {cond, values1_dict, scalar}, /*result_is_encoded=*/false); + CheckDictionary("if_else", {cond, scalar, values2_dict}, /*result_is_encoded=*/false); + CheckDictionary("if_else", {cond, values_null, values2_decoded}, + /*result_is_encoded=*/false); + CheckDictionary("if_else", {cond, values1_decoded, values_null}, + /*result_is_encoded=*/false); + CheckDictionary("if_else", {Datum(true), values1_decoded, values2_dict}, + /*result_is_encoded=*/false); + CheckDictionary("if_else", {Datum(false), values1_decoded, values2_dict}, + /*result_is_encoded=*/false); + CheckDictionary("if_else", {Datum(true), scalar, values2_dict}, + /*result_is_encoded=*/false); + CheckDictionary("if_else", {Datum(true), values1_dict, scalar}, + /*result_is_encoded=*/false); + CheckDictionary("if_else", {Datum(false), values1_dict, scalar}, + /*result_is_encoded=*/false); + CheckDictionary("if_else", {Datum(false), scalar, values2_dict}, + /*result_is_encoded=*/false); + CheckDictionary("if_else", {MakeNullScalar(boolean()), values1_decoded, values2_dict}, + /*result_is_encoded=*/false); +} + +TYPED_TEST(TestIfElseDict, NestedSimple) { + auto make_list = [](const std::shared_ptr& indices, + const std::shared_ptr& backing_array) { + EXPECT_OK_AND_ASSIGN(auto result, ListArray::FromArrays(*indices, *backing_array)); + return result; + }; + auto index_type = default_type_instance(); + auto inner_type = dictionary(index_type, utf8()); + auto type = list(inner_type); + auto dict = R"(["a", null, "bc", "def"])"; + auto cond = ArrayFromJSON(boolean(), "[true, false, true, null]"); + auto values_null = make_list(ArrayFromJSON(int32(), "[null, null, null, null, 0]"), + DictArrayFromJSON(inner_type, "[]", dict)); + auto values1_backing = DictArrayFromJSON(inner_type, "[0, null, 3, 1]", dict); + auto values2_backing = DictArrayFromJSON(inner_type, "[2, 1, null, 0]", dict); + auto values1 = make_list(ArrayFromJSON(int32(), "[0, 2, 2, 3, 4]"), values1_backing); + auto values2 = make_list(ArrayFromJSON(int32(), "[0, 1, 2, 2, 4]"), values2_backing); + auto scalar = + Datum(std::make_shared(DictArrayFromJSON(inner_type, "[0, 1]", dict))); + + CheckDictionary("if_else", {cond, values1, values2}, /*result_is_encoded=*/false); + CheckDictionary("if_else", {cond, values1, scalar}, /*result_is_encoded=*/false); + CheckDictionary("if_else", {cond, scalar, values2}, /*result_is_encoded=*/false); + CheckDictionary("if_else", {cond, values_null, values2}, /*result_is_encoded=*/false); + CheckDictionary("if_else", {cond, values1, values_null}, /*result_is_encoded=*/false); + CheckDictionary("if_else", {Datum(true), values1, values2}, + /*result_is_encoded=*/false); + CheckDictionary("if_else", {Datum(false), values1, values2}, + /*result_is_encoded=*/false); + CheckDictionary("if_else", {Datum(true), scalar, values2}, /*result_is_encoded=*/false); + CheckDictionary("if_else", {Datum(true), values1, scalar}, /*result_is_encoded=*/false); + CheckDictionary("if_else", {Datum(false), values1, scalar}, + /*result_is_encoded=*/false); + CheckDictionary("if_else", {Datum(false), scalar, values2}, + /*result_is_encoded=*/false); + CheckDictionary("if_else", {MakeNullScalar(boolean()), values1, values2}, + /*result_is_encoded=*/false); +} + +TYPED_TEST(TestIfElseDict, DifferentDictionaries) { + auto type = dictionary(default_type_instance(), utf8()); + auto cond = ArrayFromJSON(boolean(), "[true, false, true, null]"); + auto dict1 = R"(["a", null, "bc", "def"])"; + auto dict2 = R"(["bc", "foo", null, "a"])"; + auto values1_null = DictArrayFromJSON(type, "[null, null, null, null]", dict1); + auto values2_null = DictArrayFromJSON(type, "[null, null, null, null]", dict2); + auto values1 = DictArrayFromJSON(type, "[null, 0, 3, 1]", dict1); + auto values2 = DictArrayFromJSON(type, "[2, 1, 0, null]", dict2); + auto scalar1 = DictScalarFromJSON(type, "0", dict1); + auto scalar2 = DictScalarFromJSON(type, "0", dict2); + + CheckDictionary("if_else", {cond, values1, values2}); + CheckDictionary("if_else", {cond, values1, scalar2}); + CheckDictionary("if_else", {cond, scalar1, values2}); + CheckDictionary("if_else", {cond, values1_null, values2}); + CheckDictionary("if_else", {cond, values1, values2_null}); + CheckDictionary("if_else", {Datum(true), values1, values2}); + CheckDictionary("if_else", {Datum(false), values1, values2}); + CheckDictionary("if_else", {Datum(true), scalar1, values2}); + CheckDictionary("if_else", {Datum(true), values1, scalar2}); + CheckDictionary("if_else", {Datum(false), values1, scalar2}); + CheckDictionary("if_else", {Datum(false), scalar1, values2}); + CheckDictionary("if_else", {MakeNullScalar(boolean()), values1, values2}); +} + template class TestCaseWhenNumeric : public ::testing::Test {}; @@ -627,11 +1082,6 @@ TYPED_TEST(TestCaseWhenNumeric, ListOfType) { template class TestCaseWhenDict : public ::testing::Test {}; -struct JsonDict { - std::shared_ptr type; - std::string value; -}; - TYPED_TEST_SUITE(TestCaseWhenDict, IntegralArrowTypes); TYPED_TEST(TestCaseWhenDict, Simple) { From a53bc771d321cc5a57e07409dba1c8b2ee35a632 Mon Sep 17 00:00:00 2001 From: David Li Date: Wed, 29 Sep 2021 12:46:10 -0400 Subject: [PATCH 2/5] ARROW-13358: [C++] Move out common utilities --- .../compute/kernels/scalar_if_else_test.cc | 54 ++++++++++--------- 1 file changed, 29 insertions(+), 25 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc index 5e74d4a60e4..edf1d954404 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc @@ -27,6 +27,20 @@ namespace arrow { namespace compute { +// Helper that combines a dictionary and the value type so it can +// later be used with DictArrayFromJSON +struct JsonDict { + std::shared_ptr type; + std::string value; +}; + +// Helper that makes a list of dictionary indices +std::shared_ptr MakeListOfDict(const std::shared_ptr& indices, + const std::shared_ptr& backing_array) { + EXPECT_OK_AND_ASSIGN(auto result, ListArray::FromArrays(*indices, *backing_array)); + return result; +} + void CheckIfElseOutput(const Datum& cond, const Datum& left, const Datum& right, const Datum& expected) { ASSERT_OK_AND_ASSIGN(Datum datum_out, IfElse(cond, left, right)); @@ -825,11 +839,6 @@ class TestIfElseDict : public ::testing::Test {}; TYPED_TEST_SUITE(TestIfElseDict, IntegralArrowTypes); -struct JsonDict { - std::shared_ptr type; - std::string value; -}; - TYPED_TEST(TestIfElseDict, Simple) { auto cond = ArrayFromJSON(boolean(), "[true, false, true, null]"); for (const auto& dict : @@ -895,22 +904,19 @@ TYPED_TEST(TestIfElseDict, Mixed) { } TYPED_TEST(TestIfElseDict, NestedSimple) { - auto make_list = [](const std::shared_ptr& indices, - const std::shared_ptr& backing_array) { - EXPECT_OK_AND_ASSIGN(auto result, ListArray::FromArrays(*indices, *backing_array)); - return result; - }; auto index_type = default_type_instance(); auto inner_type = dictionary(index_type, utf8()); auto type = list(inner_type); auto dict = R"(["a", null, "bc", "def"])"; auto cond = ArrayFromJSON(boolean(), "[true, false, true, null]"); - auto values_null = make_list(ArrayFromJSON(int32(), "[null, null, null, null, 0]"), - DictArrayFromJSON(inner_type, "[]", dict)); + auto values_null = MakeListOfDict(ArrayFromJSON(int32(), "[null, null, null, null, 0]"), + DictArrayFromJSON(inner_type, "[]", dict)); auto values1_backing = DictArrayFromJSON(inner_type, "[0, null, 3, 1]", dict); auto values2_backing = DictArrayFromJSON(inner_type, "[2, 1, null, 0]", dict); - auto values1 = make_list(ArrayFromJSON(int32(), "[0, 2, 2, 3, 4]"), values1_backing); - auto values2 = make_list(ArrayFromJSON(int32(), "[0, 1, 2, 2, 4]"), values2_backing); + auto values1 = + MakeListOfDict(ArrayFromJSON(int32(), "[0, 2, 2, 3, 4]"), values1_backing); + auto values2 = + MakeListOfDict(ArrayFromJSON(int32(), "[0, 1, 2, 2, 4]"), values2_backing); auto scalar = Datum(std::make_shared(DictArrayFromJSON(inner_type, "[0, 1]", dict))); @@ -1133,35 +1139,33 @@ TYPED_TEST(TestCaseWhenDict, Mixed) { } TYPED_TEST(TestCaseWhenDict, NestedSimple) { - auto make_list = [](const std::shared_ptr& indices, - const std::shared_ptr& backing_array) { - EXPECT_OK_AND_ASSIGN(auto result, ListArray::FromArrays(*indices, *backing_array)); - return result; - }; auto index_type = default_type_instance(); auto inner_type = dictionary(index_type, utf8()); auto type = list(inner_type); auto cond1 = ArrayFromJSON(boolean(), "[true, true, null, null]"); auto cond2 = ArrayFromJSON(boolean(), "[true, false, true, null]"); auto dict = R"(["a", null, "bc", "def"])"; - auto values_null = make_list(ArrayFromJSON(int32(), "[null, null, null, null, 0]"), - DictArrayFromJSON(inner_type, "[]", dict)); + auto values_null = MakeListOfDict(ArrayFromJSON(int32(), "[null, null, null, null, 0]"), + DictArrayFromJSON(inner_type, "[]", dict)); auto values1_backing = DictArrayFromJSON(inner_type, "[0, null, 3, 1]", dict); auto values2_backing = DictArrayFromJSON(inner_type, "[2, 1, null, 0]", dict); - auto values1 = make_list(ArrayFromJSON(int32(), "[0, 2, 2, 3, 4]"), values1_backing); - auto values2 = make_list(ArrayFromJSON(int32(), "[0, 1, 2, 2, 4]"), values2_backing); + auto values1 = + MakeListOfDict(ArrayFromJSON(int32(), "[0, 2, 2, 3, 4]"), values1_backing); + auto values2 = + MakeListOfDict(ArrayFromJSON(int32(), "[0, 1, 2, 2, 4]"), values2_backing); CheckDictionary("case_when", {MakeStruct({cond1, cond2}), values1, values2}, /*result_is_encoded=*/false); CheckDictionary( "case_when", {MakeStruct({cond1, cond2}), values1, - make_list(ArrayFromJSON(int32(), "[0, 1, null, 2, 4]"), values2_backing)}, + MakeListOfDict(ArrayFromJSON(int32(), "[0, 1, null, 2, 4]"), values2_backing)}, /*result_is_encoded=*/false); CheckDictionary( "case_when", {MakeStruct({cond1, cond2}), values1, - make_list(ArrayFromJSON(int32(), "[0, 1, null, 2, 4]"), values2_backing), values1}, + MakeListOfDict(ArrayFromJSON(int32(), "[0, 1, null, 2, 4]"), values2_backing), + values1}, /*result_is_encoded=*/false); CheckDictionary("case_when", From ba2e94704179ba34df264a2fa275fc4e4c8ab807 Mon Sep 17 00:00:00 2001 From: David Li Date: Wed, 29 Sep 2021 13:54:10 -0400 Subject: [PATCH 3/5] ARROW-13358: [C++] Reconcile with ARROW-14167 --- .../arrow/compute/kernels/codegen_internal.cc | 7 ++-- .../arrow/compute/kernels/codegen_internal.h | 2 +- .../arrow/compute/kernels/scalar_compare.cc | 2 +- .../arrow/compute/kernels/scalar_if_else.cc | 10 ++++- .../compute/kernels/scalar_if_else_test.cc | 41 ++++++------------- 5 files changed, 27 insertions(+), 35 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/codegen_internal.cc b/cpp/src/arrow/compute/kernels/codegen_internal.cc index 9077c7e9f0b..70488759261 100644 --- a/cpp/src/arrow/compute/kernels/codegen_internal.cc +++ b/cpp/src/arrow/compute/kernels/codegen_internal.cc @@ -222,11 +222,12 @@ std::shared_ptr CommonTemporal(const ValueDescr* begin, size_t count) return nullptr; } -std::shared_ptr CommonBinary(const std::vector& descrs) { +std::shared_ptr CommonBinary(const ValueDescr* begin, size_t count) { bool all_utf8 = true, all_offset32 = true; - for (const auto& descr : descrs) { - auto id = descr.type->id(); + const ValueDescr* end = begin + count; + for (auto it = begin; it != end; ++it) { + auto id = it->type->id(); // a common varbinary type is only possible if all types are binary like switch (id) { case Type::STRING: diff --git a/cpp/src/arrow/compute/kernels/codegen_internal.h b/cpp/src/arrow/compute/kernels/codegen_internal.h index a4e2c3d6d1a..dc6068dd529 100644 --- a/cpp/src/arrow/compute/kernels/codegen_internal.h +++ b/cpp/src/arrow/compute/kernels/codegen_internal.h @@ -1305,7 +1305,7 @@ ARROW_EXPORT std::shared_ptr CommonTemporal(const ValueDescr* begin, size_t count); ARROW_EXPORT -std::shared_ptr CommonBinary(const std::vector& descrs); +std::shared_ptr CommonBinary(const ValueDescr* begin, size_t count); /// How to promote decimal precision/scale in CastBinaryDecimalArgs. enum class DecimalPromotion : uint8_t { diff --git a/cpp/src/arrow/compute/kernels/scalar_compare.cc b/cpp/src/arrow/compute/kernels/scalar_compare.cc index e42bcb7b25c..54230f88ff9 100644 --- a/cpp/src/arrow/compute/kernels/scalar_compare.cc +++ b/cpp/src/arrow/compute/kernels/scalar_compare.cc @@ -173,7 +173,7 @@ struct CompareFunction : ScalarFunction { ReplaceTypes(type, values); } else if (auto type = CommonTemporal(values->data(), values->size())) { ReplaceTypes(type, values); - } else if (auto type = CommonBinary(*values)) { + } else if (auto type = CommonBinary(values->data(), values->size())) { ReplaceTypes(type, values); } diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else.cc b/cpp/src/arrow/compute/kernels/scalar_if_else.cc index e9432868ee8..bb7c417f1e5 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else.cc @@ -1137,7 +1137,7 @@ struct IfElseFunction : ScalarFunction { internal::ReplaceNullWithOtherType(left_arg, num_args); if (is_dictionary((*values)[1].type->id()) && - is_dictionary((*values)[2].type->id())) { + (*values)[1].type->Equals(*(*values)[2].type)) { auto kernel = DispatchExactImpl(this, *values); DCHECK(kernel); return kernel; @@ -1151,6 +1151,12 @@ struct IfElseFunction : ScalarFunction { if (auto type = internal::CommonTemporal(left_arg, num_args)) { internal::ReplaceTypes(type, left_arg, num_args); } + if (auto type = internal::CommonBinary(left_arg, num_args)) { + internal::ReplaceTypes(type, left_arg, num_args); + } + if (HasDecimal(*values)) { + RETURN_NOT_OK(CastDecimalArgs(left_arg, num_args)); + } if (auto kernel = DispatchExactImpl(this, *values)) return kernel; @@ -1918,7 +1924,7 @@ struct CoalesceFunction : ScalarFunction { if (auto type = CommonNumeric(*values)) { ReplaceTypes(type, values); } - if (auto type = CommonBinary(*values)) { + if (auto type = CommonBinary(values->data(), values->size())) { ReplaceTypes(type, values); } if (auto type = CommonTemporal(values->data(), values->size())) { diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc index edf1d954404..5acb0b16b6c 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc @@ -705,23 +705,6 @@ TEST_F(TestIfElseKernel, ParameterizedTypes) { CallFunction("if_else", {cond, ArrayFromJSON(type0, R"(["aaaa"])"), ArrayFromJSON(type1, R"(["aaaaa"])")})); - type0 = decimal128(3, 2); - type1 = decimal128(4, 2); - EXPECT_RAISES_WITH_MESSAGE_THAT( - TypeError, - ::testing::HasSubstr("All types must be compatible, expected: decimal128(3, 2), " - "but got: decimal128(4, 2)"), - CallFunction("if_else", {cond, ArrayFromJSON(type0, R"(["1.23"])"), - ArrayFromJSON(type1, R"(["1.23"])")})); - - type1 = decimal128(3, 4); - EXPECT_RAISES_WITH_MESSAGE_THAT( - TypeError, - ::testing::HasSubstr("All types must be compatible, expected: decimal128(3, 2), " - "but got: decimal128(3, 4)"), - CallFunction("if_else", {cond, ArrayFromJSON(type0, R"(["1.23"])"), - ArrayFromJSON(type1, R"(["1.2345"])")})); - // TODO(ARROW-14105): in principle many of these could be implicitly castable too type0 = struct_({field("a", int32())}); @@ -751,16 +734,6 @@ TEST_F(TestIfElseKernel, ParameterizedTypes) { CallFunction("if_else", {cond, ArrayFromJSON(type0, "[[0]]"), ArrayFromJSON(type1, "[[0]]")})); - type0 = dictionary(int16(), utf8()); - type1 = dictionary(int32(), utf8()); - EXPECT_RAISES_WITH_MESSAGE_THAT( - TypeError, - ::testing::HasSubstr("All types must be compatible, expected: " - "dictionary, but " - "got: dictionary"), - CallFunction("if_else", {cond, DictArrayFromJSON(type0, "[0]", R"(["a"])"), - DictArrayFromJSON(type1, "[0]", R"(["a"])")})); - type0 = timestamp(TimeUnit::SECOND); type1 = timestamp(TimeUnit::SECOND, "America/Phoenix"); EXPECT_RAISES_WITH_MESSAGE_THAT( @@ -868,7 +841,8 @@ TYPED_TEST(TestIfElseDict, Simple) { } TYPED_TEST(TestIfElseDict, Mixed) { - auto type = dictionary(default_type_instance(), utf8()); + auto index_type = default_type_instance(); + auto type = dictionary(index_type, utf8()); auto cond = ArrayFromJSON(boolean(), "[true, false, true, null]"); auto dict = R"(["a", null, "bc", "def"])"; auto values_null = DictArrayFromJSON(type, "[null, null, null, null]", dict); @@ -901,6 +875,17 @@ TYPED_TEST(TestIfElseDict, Mixed) { /*result_is_encoded=*/false); CheckDictionary("if_else", {MakeNullScalar(boolean()), values1_decoded, values2_dict}, /*result_is_encoded=*/false); + + // If we have mismatched dictionary types, we decode (for now) + auto values3_dict = + DictArrayFromJSON(dictionary(index_type, binary()), "[2, 1, null, 0]", dict); + auto values4_dict = DictArrayFromJSON( + dictionary(index_type->id() == Type::UINT8 ? int8() : uint8(), utf8()), + "[2, 1, null, 0]", dict); + CheckDictionary("if_else", {cond, values1_dict, values3_dict}, + /*result_is_encoded=*/false); + CheckDictionary("if_else", {cond, values1_dict, values4_dict}, + /*result_is_encoded=*/false); } TYPED_TEST(TestIfElseDict, NestedSimple) { From 36b2a77b3ef06c804afa0001862ec11129003e26 Mon Sep 17 00:00:00 2001 From: David Li Date: Thu, 30 Sep 2021 11:09:34 -0400 Subject: [PATCH 4/5] ARROW-13358: [C++] Address feedback --- cpp/src/arrow/array/array_test.cc | 2 + cpp/src/arrow/array/util.cc | 3 +- .../arrow/compute/kernels/scalar_if_else.cc | 147 +++++++++++------- .../compute/kernels/scalar_if_else_test.cc | 74 +++++---- 4 files changed, 129 insertions(+), 97 deletions(-) diff --git a/cpp/src/arrow/array/array_test.cc b/cpp/src/arrow/array/array_test.cc index b6cfa0dce63..e5dfd0ccd1a 100644 --- a/cpp/src/arrow/array/array_test.cc +++ b/cpp/src/arrow/array/array_test.cc @@ -544,10 +544,12 @@ static ScalarVector GetScalars() { sparse_union_ty), std::make_shared(std::make_shared(100), 42, sparse_union_ty), + std::make_shared(42, sparse_union_ty), std::make_shared(std::make_shared(101), 6, dense_union_ty), std::make_shared(std::make_shared(101), 42, dense_union_ty), + std::make_shared(42, dense_union_ty), DictionaryScalar::Make(ScalarFromJSON(int8(), "1"), ArrayFromJSON(utf8(), R"(["foo", "bar"])")), DictionaryScalar::Make(ScalarFromJSON(uint8(), "1"), diff --git a/cpp/src/arrow/array/util.cc b/cpp/src/arrow/array/util.cc index 71f91497f61..d639830f469 100644 --- a/cpp/src/arrow/array/util.cc +++ b/cpp/src/arrow/array/util.cc @@ -775,7 +775,8 @@ Result> MakeArrayOfNull(const std::shared_ptr& Result> MakeArrayFromScalar(const Scalar& scalar, int64_t length, MemoryPool* pool) { - if (!scalar.is_valid) { + // Null union scalars still have a type code associated + if (!scalar.is_valid && !is_union(scalar.type->id())) { return MakeArrayOfNull(scalar.type, length, pool); } return RepeatedArrayFactory(pool, scalar, length).Create(); diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else.cc b/cpp/src/arrow/compute/kernels/scalar_if_else.cc index bb7c417f1e5..b3ebba8ea00 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else.cc @@ -948,9 +948,7 @@ struct IfElseFunctor> { }; // Use builders for dictionaries - slower, but allows us to unify dictionaries -template -struct IfElseFunctor< - Type, enable_if_t::value || is_dictionary_type::value>> { +struct NestedIfElseExec { // A - Array, S - Scalar, X = Array/Scalar // SXX @@ -989,11 +987,11 @@ struct IfElseFunctor< const ArrayData& right, ArrayData* out) { return RunLoop( ctx, cond, out, - [&](ArrayBuilder* builder, int64_t i) { - return builder->AppendArraySlice(left, i, /*length=*/1); + [&](ArrayBuilder* builder, int64_t i, int64_t length) { + return builder->AppendArraySlice(left, i, length); }, - [&](ArrayBuilder* builder, int64_t i) { - return builder->AppendArraySlice(right, i, /*length=*/1); + [&](ArrayBuilder* builder, int64_t i, int64_t length) { + return builder->AppendArraySlice(right, i, length); }); } @@ -1002,9 +1000,11 @@ struct IfElseFunctor< const ArrayData& right, ArrayData* out) { return RunLoop( ctx, cond, out, - [&](ArrayBuilder* builder, int64_t i) { return builder->AppendScalar(left); }, - [&](ArrayBuilder* builder, int64_t i) { - return builder->AppendArraySlice(right, i, /*length=*/1); + [&](ArrayBuilder* builder, int64_t i, int64_t length) { + return builder->AppendScalar(left, length); + }, + [&](ArrayBuilder* builder, int64_t i, int64_t length) { + return builder->AppendArraySlice(right, i, length); }); } @@ -1013,10 +1013,12 @@ struct IfElseFunctor< const Scalar& right, ArrayData* out) { return RunLoop( ctx, cond, out, - [&](ArrayBuilder* builder, int64_t i) { - return builder->AppendArraySlice(left, i, /*length=*/1); + [&](ArrayBuilder* builder, int64_t i, int64_t length) { + return builder->AppendArraySlice(left, i, length); }, - [&](ArrayBuilder* builder, int64_t i) { return builder->AppendScalar(right); }); + [&](ArrayBuilder* builder, int64_t i, int64_t length) { + return builder->AppendScalar(right, length); + }); } // ASS @@ -1024,8 +1026,12 @@ struct IfElseFunctor< const Scalar& right, ArrayData* out) { return RunLoop( ctx, cond, out, - [&](ArrayBuilder* builder, int64_t i) { return builder->AppendScalar(left); }, - [&](ArrayBuilder* builder, int64_t i) { return builder->AppendScalar(right); }); + [&](ArrayBuilder* builder, int64_t i, int64_t length) { + return builder->AppendScalar(left, length); + }, + [&](ArrayBuilder* builder, int64_t i, int64_t length) { + return builder->AppendScalar(right, length); + }); } template @@ -1036,32 +1042,68 @@ struct IfElseFunctor< RETURN_NOT_OK(raw_builder->Reserve(out->length)); const auto* cond_data = cond.buffers[1]->data(); - if (out->buffers[0]) { - const uint8_t* out_valid = out->buffers[0]->data(); - for (int64_t i = 0; i < cond.length; i++) { - if (BitUtil::GetBit(out_valid, i)) { - if (BitUtil::GetBit(cond_data, cond.offset + i)) { - RETURN_NOT_OK(handle_left(raw_builder.get(), i)); - } else { - RETURN_NOT_OK(handle_right(raw_builder.get(), i)); + if (cond.buffers[0]) { + BitRunReader reader(cond.buffers[0]->data(), cond.offset, cond.length); + int64_t position = 0; + while (true) { + auto run = reader.NextRun(); + if (run.length == 0) break; + if (run.set) { + for (int j = 0; j < run.length; j++) { + if (BitUtil::GetBit(cond_data, cond.offset + position + j)) { + RETURN_NOT_OK(handle_left(raw_builder.get(), position + j, 1)); + } else { + RETURN_NOT_OK(handle_right(raw_builder.get(), position + j, 1)); + } } } else { - RETURN_NOT_OK(raw_builder->AppendNull()); + RETURN_NOT_OK(raw_builder->AppendNulls(run.length)); } + position += run.length; } } else { - for (int64_t i = 0; i < cond.length; i++) { - if (BitUtil::GetBit(cond_data, cond.offset + i)) { - RETURN_NOT_OK(handle_left(raw_builder.get(), i)); + BitRunReader reader(cond_data, cond.offset, cond.length); + int64_t position = 0; + while (true) { + auto run = reader.NextRun(); + if (run.length == 0) break; + if (run.set) { + RETURN_NOT_OK(handle_left(raw_builder.get(), position, run.length)); } else { - RETURN_NOT_OK(handle_right(raw_builder.get(), i)); + RETURN_NOT_OK(handle_right(raw_builder.get(), position, run.length)); } + position += run.length; } } ARROW_ASSIGN_OR_RAISE(auto out_arr, raw_builder->Finish()); *out = std::move(*out_arr->data()); return Status::OK(); } + + static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + RETURN_NOT_OK(CheckIdenticalTypes(&batch.values[1], /*count=*/2)); + if (batch[0].is_scalar()) { + const auto& cond = batch[0].scalar_as(); + return Call(ctx, cond, batch[1], batch[2], out); + } + if (batch[1].kind() == Datum::ARRAY) { + if (batch[2].kind() == Datum::ARRAY) { // AAA + return Call(ctx, *batch[0].array(), *batch[1].array(), *batch[2].array(), + out->mutable_array()); + } else { // AAS + return Call(ctx, *batch[0].array(), *batch[1].array(), *batch[2].scalar(), + out->mutable_array()); + } + } else { + if (batch[2].kind() == Datum::ARRAY) { // ASA + return Call(ctx, *batch[0].array(), *batch[1].scalar(), *batch[2].array(), + out->mutable_array()); + } else { // ASS + return Call(ctx, *batch[0].array(), *batch[1].scalar(), *batch[2].scalar(), + out->mutable_array()); + } + } + } }; template @@ -1136,8 +1178,10 @@ struct IfElseFunction : ScalarFunction { internal::ReplaceNullWithOtherType(left_arg, num_args); - if (is_dictionary((*values)[1].type->id()) && - (*values)[1].type->Equals(*(*values)[2].type)) { + // If both are identical dictionary types, dispatch to the dictionary kernel + // TODO(ARROW-14105): apply implicit casts to dictionary types too + ValueDescr* right_arg = &(*values)[2]; + if (is_dictionary(left_arg->type->id()) && left_arg->type->Equals(right_arg->type)) { auto kernel = DispatchExactImpl(this, *values); DCHECK(kernel); return kernel; @@ -1147,14 +1191,11 @@ struct IfElseFunction : ScalarFunction { if (auto type = internal::CommonNumeric(left_arg, num_args)) { internal::ReplaceTypes(type, left_arg, num_args); - } - if (auto type = internal::CommonTemporal(left_arg, num_args)) { + } else if (auto type = internal::CommonTemporal(left_arg, num_args)) { internal::ReplaceTypes(type, left_arg, num_args); - } - if (auto type = internal::CommonBinary(left_arg, num_args)) { + } else if (auto type = internal::CommonBinary(left_arg, num_args)) { internal::ReplaceTypes(type, left_arg, num_args); - } - if (HasDecimal(*values)) { + } else if (HasDecimal(*values)) { RETURN_NOT_OK(CastDecimalArgs(left_arg, num_args)); } @@ -1230,14 +1271,17 @@ void AddFixedWidthIfElseKernel(const std::shared_ptr& scalar_fun DCHECK_OK(scalar_function->AddKernel(std::move(kernel))); } -void AddNestedIfElseKernel(const std::shared_ptr& scalar_function, - detail::GetTypeId get_id, ArrayKernelExec exec) { - ScalarKernel kernel({boolean(), InputType(get_id.id), InputType(get_id.id)}, - OutputType(LastType), std::move(exec)); - kernel.null_handling = NullHandling::COMPUTED_NO_PREALLOCATE; - kernel.mem_allocation = MemAllocation::NO_PREALLOCATE; - kernel.can_write_into_slices = false; - DCHECK_OK(scalar_function->AddKernel(std::move(kernel))); +void AddNestedIfElseKernels(const std::shared_ptr& scalar_function) { + for (const auto type_id : + {Type::LIST, Type::LARGE_LIST, Type::FIXED_SIZE_LIST, Type::STRUCT, + Type::DENSE_UNION, Type::SPARSE_UNION, Type::DICTIONARY}) { + ScalarKernel kernel({boolean(), InputType(type_id), InputType(type_id)}, + OutputType(LastType), NestedIfElseExec::Exec); + kernel.null_handling = NullHandling::COMPUTED_NO_PREALLOCATE; + kernel.mem_allocation = MemAllocation::NO_PREALLOCATE; + kernel.can_write_into_slices = false; + DCHECK_OK(scalar_function->AddKernel(std::move(kernel))); + } } // Helper to copy or broadcast fixed-width values between buffers. @@ -2795,20 +2839,7 @@ void RegisterScalarIfElse(FunctionRegistry* registry) { AddFixedWidthIfElseKernel(func); AddFixedWidthIfElseKernel(func); AddFixedWidthIfElseKernel(func); - AddNestedIfElseKernel(func, Type::LIST, - ResolveIfElseExec::Exec); - AddNestedIfElseKernel(func, Type::LARGE_LIST, - ResolveIfElseExec::Exec); - AddNestedIfElseKernel(func, Type::FIXED_SIZE_LIST, - ResolveIfElseExec::Exec); - AddNestedIfElseKernel(func, Type::STRUCT, - ResolveIfElseExec::Exec); - AddNestedIfElseKernel(func, Type::DENSE_UNION, - ResolveIfElseExec::Exec); - AddNestedIfElseKernel(func, Type::SPARSE_UNION, - ResolveIfElseExec::Exec); - AddNestedIfElseKernel(func, Type::DICTIONARY, - ResolveIfElseExec::Exec); + AddNestedIfElseKernels(func); DCHECK_OK(registry->AddFunction(std::move(func))); } { diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc index 5acb0b16b6c..92e0582c6f1 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc @@ -116,54 +116,49 @@ void CheckWithDifferentShapes(const std::shared_ptr& cond, CheckScalar("if_else", {cond, left, right}, expected); auto len = left->length(); + std::vector array_indices = {-1}; // sentinel for make_input + std::vector scalar_indices(len); + std::iota(scalar_indices.begin(), scalar_indices.end(), 0); + auto make_input = [&](const std::shared_ptr& array, int64_t index, Datum* input, + Datum* input_broadcast, std::string* trace) { + if (index >= 0) { + // Use scalar from array[index] as input; broadcast scalar for computing expected + // result + ASSERT_OK_AND_ASSIGN(auto scalar, array->GetScalar(index)); + *trace += "@" + std::to_string(index) + "=" + scalar->ToString(); + *input = std::move(scalar); + ASSERT_OK_AND_ASSIGN(*input_broadcast, MakeArrayFromScalar(*input->scalar(), len)); + } else { + // Use array as input + *trace += "=Array"; + *input = *input_broadcast = array; + } + }; enum { COND_SCALAR = 1, LEFT_SCALAR = 2, RIGHT_SCALAR = 4 }; for (int mask = 1; mask <= (COND_SCALAR | LEFT_SCALAR | RIGHT_SCALAR); ++mask) { - for (int64_t cond_idx = 0; cond_idx < len; ++cond_idx) { + for (int64_t cond_idx : (mask & COND_SCALAR) ? scalar_indices : array_indices) { Datum cond_in, cond_bcast; std::string trace_cond = "Cond"; - if (mask & COND_SCALAR) { - ASSERT_OK_AND_ASSIGN(cond_in, cond->GetScalar(cond_idx)); - ASSERT_OK_AND_ASSIGN(cond_bcast, MakeArrayFromScalar(*cond_in.scalar(), len)); - trace_cond += "@" + std::to_string(cond_idx) + "=" + cond_in.scalar()->ToString(); - } else { - cond_in = cond_bcast = cond; - trace_cond += "=Array"; - } - SCOPED_TRACE(trace_cond); + make_input(cond, cond_idx, &cond_in, &cond_bcast, &trace_cond); - for (int64_t left_idx = 0; left_idx < len; ++left_idx) { + for (int64_t left_idx : (mask & LEFT_SCALAR) ? scalar_indices : array_indices) { Datum left_in, left_bcast; std::string trace_left = "Left"; - if (mask & LEFT_SCALAR) { - ASSERT_OK_AND_ASSIGN(left_in, left->GetScalar(left_idx).As()); - ASSERT_OK_AND_ASSIGN(left_bcast, MakeArrayFromScalar(*left_in.scalar(), len)); - trace_left += - "@" + std::to_string(left_idx) + "=" + left_in.scalar()->ToString(); - } else { - left_in = left_bcast = left; - trace_left += "=Array"; - } - SCOPED_TRACE(trace_left); + make_input(left, left_idx, &left_in, &left_bcast, &trace_left); - for (int64_t right_idx = 0; right_idx < len; ++right_idx) { + for (int64_t right_idx : (mask & RIGHT_SCALAR) ? scalar_indices : array_indices) { Datum right_in, right_bcast; std::string trace_right = "Right"; - if (mask & RIGHT_SCALAR) { - ASSERT_OK_AND_ASSIGN(right_in, right->GetScalar(right_idx)); - ASSERT_OK_AND_ASSIGN(right_bcast, - MakeArrayFromScalar(*right_in.scalar(), len)); - trace_right += - "@" + std::to_string(right_idx) + "=" + right_in.scalar()->ToString(); - } else { - right_in = right_bcast = right; - trace_right += "=Array"; - } + make_input(right, right_idx, &right_in, &right_bcast, &trace_right); + SCOPED_TRACE(trace_right); + SCOPED_TRACE(trace_left); + SCOPED_TRACE(trace_cond); Datum expected; ASSERT_OK_AND_ASSIGN(auto actual, IfElse(cond_in, left_in, right_in)); - if (mask & COND_SCALAR && mask & LEFT_SCALAR && mask & RIGHT_SCALAR) { + if (mask == (COND_SCALAR | LEFT_SCALAR | RIGHT_SCALAR)) { const auto& scalar = cond_in.scalar_as(); if (scalar.is_valid) { expected = scalar.value ? left_in : right_in; @@ -177,13 +172,9 @@ void CheckWithDifferentShapes(const std::shared_ptr& cond, } else { ASSERT_OK_AND_ASSIGN(expected, IfElse(cond_bcast, left_bcast, right_bcast)); } - AssertDatumsEqual(actual, expected, /*verbose=*/true); - - if (right_in.is_array()) break; + AssertDatumsEqual(expected, actual, /*verbose=*/true); } - if (left_in.is_array()) break; } - if (cond_in.is_array()) break; } } // for (mask) } @@ -734,6 +725,13 @@ TEST_F(TestIfElseKernel, ParameterizedTypes) { CallFunction("if_else", {cond, ArrayFromJSON(type0, "[[0]]"), ArrayFromJSON(type1, "[[0]]")})); + type0 = timestamp(TimeUnit::SECOND); + type1 = timestamp(TimeUnit::MILLI); + CheckWithDifferentShapes(ArrayFromJSON(boolean(), "[true, true, true, false]"), + ArrayFromJSON(type0, "[1, 2, 3, 4]"), + ArrayFromJSON(type1, "[5, 6, 7, 8]"), + ArrayFromJSON(type1, "[1000, 2000, 3000, 8]")); + type0 = timestamp(TimeUnit::SECOND); type1 = timestamp(TimeUnit::SECOND, "America/Phoenix"); EXPECT_RAISES_WITH_MESSAGE_THAT( From fe8329e7aecdced6eb98510002dff15351940bf2 Mon Sep 17 00:00:00 2001 From: Antoine Pitrou Date: Mon, 4 Oct 2021 15:38:35 +0200 Subject: [PATCH 5/5] Update docs --- docs/source/cpp/compute.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/cpp/compute.rst b/docs/source/cpp/compute.rst index 0a87b0b1d29..f699698e97b 100644 --- a/docs/source/cpp/compute.rst +++ b/docs/source/cpp/compute.rst @@ -1145,11 +1145,11 @@ depending on a condition. +==================+============+===================================================+=====================+=========+ | case_when | Varargs | Struct of Boolean (Arg 0), Any (rest) | Input type | \(1) | +------------------+------------+---------------------------------------------------+---------------------+---------+ -| choose | Varargs | Integral (Arg 0); Fixed-width/Binary-like (rest) | Input type | \(2) | +| choose | Varargs | Integral (Arg 0), Fixed-width/Binary-like (rest) | Input type | \(2) | +------------------+------------+---------------------------------------------------+---------------------+---------+ | coalesce | Varargs | Any | Input type | \(3) | +------------------+------------+---------------------------------------------------+---------------------+---------+ -| if_else | Ternary | Boolean, Null, Numeric, Temporal | Input type | \(4) | +| if_else | Ternary | Boolean (Arg 0), Any (rest) | Input type | \(4) | +------------------+------------+---------------------------------------------------+---------------------+---------+ * \(1) This function acts like a SQL "case when" statement or switch-case. The