diff --git a/cpp/apidoc/Doxyfile b/cpp/apidoc/Doxyfile index 794fc82d69d..f6b782276e3 100644 --- a/cpp/apidoc/Doxyfile +++ b/cpp/apidoc/Doxyfile @@ -913,7 +913,8 @@ EXCLUDE_SYMLINKS = NO EXCLUDE_PATTERNS = *-test.cc \ *test* \ *_generated.h \ - *-benchmark.cc + *-benchmark.cc \ + *internal* # The EXCLUDE_SYMBOLS tag can be used to specify one or more symbol names # (namespaces, classes, functions, etc.) that should be excluded from the diff --git a/cpp/src/arrow/compute/kernels/CMakeLists.txt b/cpp/src/arrow/compute/kernels/CMakeLists.txt index 0082799b212..44627252aa2 100644 --- a/cpp/src/arrow/compute/kernels/CMakeLists.txt +++ b/cpp/src/arrow/compute/kernels/CMakeLists.txt @@ -30,6 +30,7 @@ add_arrow_compute_test(scalar_test test_util.cc) add_arrow_benchmark(scalar_arithmetic_benchmark PREFIX "arrow-compute") +add_arrow_benchmark(scalar_cast_benchmark PREFIX "arrow-compute") add_arrow_benchmark(scalar_compare_benchmark PREFIX "arrow-compute") add_arrow_benchmark(scalar_string_benchmark PREFIX "arrow-compute") diff --git a/cpp/src/arrow/compute/kernels/aggregate_benchmark.cc b/cpp/src/arrow/compute/kernels/aggregate_benchmark.cc index 4e64c7a2d54..296abd39f1c 100644 --- a/cpp/src/arrow/compute/kernels/aggregate_benchmark.cc +++ b/cpp/src/arrow/compute/kernels/aggregate_benchmark.cc @@ -21,10 +21,10 @@ #include "arrow/builder.h" #include "arrow/compute/api.h" -#include "arrow/compute/benchmark_util.h" #include "arrow/memory_pool.h" #include "arrow/testing/gtest_util.h" #include "arrow/testing/random.h" +#include "arrow/util/benchmark_util.h" #include "arrow/util/bit_util.h" namespace arrow { diff --git a/cpp/src/arrow/compute/kernels/codegen_internal.h b/cpp/src/arrow/compute/kernels/codegen_internal.h index 2508fafe07d..00b42697d97 100644 --- a/cpp/src/arrow/compute/kernels/codegen_internal.h +++ b/cpp/src/arrow/compute/kernels/codegen_internal.h @@ -215,17 +215,15 @@ template struct BoxScalar> { using T = typename GetOutputType::T; using ScalarType = typename 
TypeTraits::ScalarType; - static std::shared_ptr Box(T val, const std::shared_ptr& type) { - return std::make_shared(val, type); - } + static void Box(T val, Scalar* out) { checked_cast(out)->value = val; } }; template struct BoxScalar> { using T = typename GetOutputType::T; using ScalarType = typename TypeTraits::ScalarType; - static std::shared_ptr Box(T val, const std::shared_ptr&) { - return std::make_shared(val); + static void Box(T val, Scalar* out) { + checked_cast(out)->value = std::make_shared(val); } }; @@ -233,9 +231,7 @@ template <> struct BoxScalar { using T = Decimal128; using ScalarType = Decimal128Scalar; - static std::shared_ptr Box(T val, const std::shared_ptr& type) { - return std::make_shared(val, type); - } + static void Box(T val, Scalar* out) { checked_cast(out)->value = val; } }; // ---------------------------------------------------------------------- @@ -396,8 +392,8 @@ struct ScalarUnary { static void Scalar(KernelContext* ctx, const Scalar& arg0, Datum* out) { if (arg0.is_valid) { ARG0 arg0_val = UnboxScalar::Unbox(arg0); - out->value = BoxScalar::Box(Op::template Call(ctx, arg0_val), - out->type()); + BoxScalar::Box(Op::template Call(ctx, arg0_val), + out->scalar().get()); } else { out->value = MakeNullScalar(arg0.type); } @@ -533,8 +529,8 @@ struct ScalarUnaryNotNullStateful { void Scalar(KernelContext* ctx, const Scalar& arg0, Datum* out) { if (arg0.is_valid) { ARG0 arg0_val = UnboxScalar::Unbox(arg0); - out->value = BoxScalar::Box( - this->op.template Call(ctx, arg0_val), out->type()); + BoxScalar::Box(this->op.template Call(ctx, arg0_val), + out->scalar().get()); } else { out->value = MakeNullScalar(arg0.type); } @@ -615,8 +611,8 @@ struct ScalarBinary { if (out->scalar()->is_valid) { auto arg0_val = UnboxScalar::Unbox(arg0); auto arg1_val = UnboxScalar::Unbox(arg1); - out->value = BoxScalar::Box(Op::template Call(ctx, arg0_val, arg1_val), - out->type()); + BoxScalar::Box(Op::template Call(ctx, arg0_val, arg1_val), + 
out->scalar().get()); } } diff --git a/cpp/src/arrow/compute/kernels/scalar_arithmetic_benchmark.cc b/cpp/src/arrow/compute/kernels/scalar_arithmetic_benchmark.cc index b301c95c680..cac5679c1a9 100644 --- a/cpp/src/arrow/compute/kernels/scalar_arithmetic_benchmark.cc +++ b/cpp/src/arrow/compute/kernels/scalar_arithmetic_benchmark.cc @@ -20,10 +20,10 @@ #include #include "arrow/compute/api_scalar.h" -#include "arrow/compute/benchmark_util.h" #include "arrow/compute/kernels/test_util.h" #include "arrow/testing/gtest_util.h" #include "arrow/testing/random.h" +#include "arrow/util/benchmark_util.h" namespace arrow { namespace compute { diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_benchmark.cc b/cpp/src/arrow/compute/kernels/scalar_cast_benchmark.cc new file mode 100644 index 00000000000..8eea8725ddf --- /dev/null +++ b/cpp/src/arrow/compute/kernels/scalar_cast_benchmark.cc @@ -0,0 +1,117 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include "benchmark/benchmark.h" + +#include + +#include "arrow/compute/cast.h" +#include "arrow/compute/kernels/test_util.h" +#include "arrow/testing/gtest_util.h" +#include "arrow/testing/random.h" +#include "arrow/util/benchmark_util.h" + +namespace arrow { +namespace compute { + +constexpr auto kSeed = 0x94378165; + +template +static void BenchmarkNumericCast(benchmark::State& state, + std::shared_ptr to_type, + const CastOptions& options, CType min, CType max) { + GenericItemsArgs args(state); + random::RandomArrayGenerator rand(kSeed); + auto array = rand.Numeric(args.size, min, max, args.null_proportion); + for (auto _ : state) { + ABORT_NOT_OK(Cast(array, to_type, options).status()); + } +} + +template +static void BenchmarkFloatingToIntegerCast(benchmark::State& state, + std::shared_ptr from_type, + std::shared_ptr to_type, + const CastOptions& options, CType min, + CType max) { + GenericItemsArgs args(state); + random::RandomArrayGenerator rand(kSeed); + auto array = rand.Numeric(args.size, min, max, args.null_proportion); + + std::shared_ptr values_as_float = *Cast(*array, from_type); + + for (auto _ : state) { + ABORT_NOT_OK(Cast(values_as_float, to_type, options).status()); + } +} + +std::vector g_data_sizes = {kL2Size}; + +void CastSetArgs(benchmark::internal::Benchmark* bench) { + for (int64_t size : g_data_sizes) { + for (auto nulls : std::vector({1000, 10, 2, 1, 0})) { + bench->Args({static_cast(size), nulls}); + } + } +} + +static constexpr int32_t kInt32Min = std::numeric_limits::min(); +static constexpr int32_t kInt32Max = std::numeric_limits::max(); + +static void CastInt64ToInt32Safe(benchmark::State& state) { + BenchmarkNumericCast(state, int32(), CastOptions::Safe(), kInt32Min, + kInt32Max); +} + +static void CastInt64ToInt32Unsafe(benchmark::State& state) { + BenchmarkNumericCast(state, int32(), CastOptions::Unsafe(), kInt32Min, + kInt32Max); +} + +static void CastUInt32ToInt32Safe(benchmark::State& state) { + BenchmarkNumericCast(state, 
int32(), CastOptions::Safe(), 0, kInt32Max); +} + +static void CastInt64ToDoubleSafe(benchmark::State& state) { + BenchmarkNumericCast(state, float64(), CastOptions::Safe(), 0, 1000); +} + +static void CastInt64ToDoubleUnsafe(benchmark::State& state) { + BenchmarkNumericCast(state, float64(), CastOptions::Unsafe(), 0, 1000); +} + +static void CastDoubleToInt32Safe(benchmark::State& state) { + BenchmarkFloatingToIntegerCast(state, float64(), int32(), + CastOptions::Safe(), -1000, 1000); +} + +static void CastDoubleToInt32Unsafe(benchmark::State& state) { + BenchmarkFloatingToIntegerCast(state, float64(), int32(), + CastOptions::Unsafe(), -1000, 1000); +} + +BENCHMARK(CastInt64ToInt32Safe)->Apply(CastSetArgs); +BENCHMARK(CastInt64ToInt32Unsafe)->Apply(CastSetArgs); +BENCHMARK(CastUInt32ToInt32Safe)->Apply(CastSetArgs); + +BENCHMARK(CastInt64ToDoubleSafe)->Apply(CastSetArgs); +BENCHMARK(CastInt64ToDoubleUnsafe)->Apply(CastSetArgs); +BENCHMARK(CastDoubleToInt32Safe)->Apply(CastSetArgs); +BENCHMARK(CastDoubleToInt32Unsafe)->Apply(CastSetArgs); + +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_numeric.cc b/cpp/src/arrow/compute/kernels/scalar_cast_numeric.cc index d09b20a0cbb..7207fd256a3 100644 --- a/cpp/src/arrow/compute/kernels/scalar_cast_numeric.cc +++ b/cpp/src/arrow/compute/kernels/scalar_cast_numeric.cc @@ -19,252 +19,362 @@ #include "arrow/compute/kernels/common.h" #include "arrow/compute/kernels/scalar_cast_internal.h" +#include "arrow/util/bit_block_counter.h" +#include "arrow/util/int_util.h" #include "arrow/util/value_parsing.h" namespace arrow { +using internal::BitBlockCount; +using internal::CheckIntegersInRange; +using internal::IntegersCanFit; +using internal::OptionalBitBlockCounter; using internal::ParseValue; namespace compute { namespace internal { -// ---------------------------------------------------------------------- -// Integers and Floating Point - -// Conversions pairs () are partitioned 
in 4 type traits: -// - is_number_downcast -// - is_integral_signed_to_unsigned -// - is_integral_unsigned_to_signed -// - is_float_truncate -// -// Each class has a different way of validation if the conversion is safe -// (either with bounded intervals or with explicit C casts) - -template -struct is_number_downcast { - static constexpr bool value = false; -}; - -template -struct is_number_downcast< - O, I, enable_if_t::value && is_number_type::value>> { - using O_T = typename O::c_type; - using I_T = typename I::c_type; - - static constexpr bool value = - ((!std::is_same::value) && - // Both types are of the same sign-ness. - ((std::is_signed::value == std::is_signed::value) && - // Both types are of the same integral-ness. - (std::is_floating_point::value == std::is_floating_point::value)) && - // Smaller output size - (sizeof(O_T) < sizeof(I_T))); -}; - -template -struct is_integral_signed_to_unsigned { - static constexpr bool value = false; -}; - -template -struct is_integral_signed_to_unsigned< - O, I, enable_if_t::value && is_integer_type::value>> { - using O_T = typename O::c_type; - using I_T = typename I::c_type; - - static constexpr bool value = - ((!std::is_same::value) && - ((std::is_unsigned::value && std::is_signed::value))); -}; +template +ARROW_DISABLE_UBSAN("float-cast-overflow") +void DoStaticCast(const void* in_data, int64_t in_offset, int64_t length, + int64_t out_offset, void* out_data) { + auto in = reinterpret_cast(in_data) + in_offset; + auto out = reinterpret_cast(out_data) + out_offset; + for (int64_t i = 0; i < length; ++i) { + *out++ = static_cast(*in++); + } +} -template -struct is_integral_unsigned_to_signed { - static constexpr bool value = false; +using StaticCastFunc = std::function; + +template +struct CastPrimitive { + static void Exec(const ExecBatch& batch, Datum* out) { + using OutT = typename OutType::c_type; + using InT = typename InType::c_type; + using OutScalar = typename TypeTraits::ScalarType; + using InScalar = 
typename TypeTraits::ScalarType; + + StaticCastFunc caster = DoStaticCast; + if (batch[0].kind() == Datum::ARRAY) { + const ArrayData& arr = *batch[0].array(); + ArrayData* out_arr = out->mutable_array(); + caster(arr.buffers[1]->data(), arr.offset, arr.length, out_arr->offset, + out_arr->buffers[1]->mutable_data()); + } else { + // Scalar path. Use the caster with length 1 to place the casted value into + // the output + const auto& in_scalar = batch[0].scalar_as(); + auto out_scalar = checked_cast(out->scalar().get()); + caster(&in_scalar.value, /*in_offset=*/0, /*length=*/1, /*out_offset=*/0, + &out_scalar->value); + } + } }; -template -struct is_integral_unsigned_to_signed< - O, I, enable_if_t::value && is_integer_type::value>> { - using O_T = typename O::c_type; - using I_T = typename I::c_type; - - static constexpr bool value = - ((!std::is_same::value) && - ((std::is_signed::value && std::is_unsigned::value))); +template +struct CastPrimitive::value>> { + // memcpy output + static void Exec(const ExecBatch& batch, Datum* out) { + using T = typename InType::c_type; + using OutScalar = typename TypeTraits::ScalarType; + using InScalar = typename TypeTraits::ScalarType; + + if (batch[0].kind() == Datum::ARRAY) { + const ArrayData& arr = *batch[0].array(); + ArrayData* out_arr = out->mutable_array(); + std::memcpy( + reinterpret_cast(out_arr->buffers[1]->mutable_data()) + out_arr->offset, + reinterpret_cast(arr.buffers[1]->data()) + arr.offset, + arr.length * sizeof(T)); + } else { + // Scalar path. Use the caster with length 1 to place the casted value into + // the output + const auto& in_scalar = batch[0].scalar_as(); + checked_cast(out->scalar().get())->value = in_scalar.value; + } + } }; -// This set of functions SafeMinimum/SafeMaximum would be simplified with -// C++17 and `if constexpr`. 
+template +void CastNumberImpl(const ExecBatch& batch, Datum* out) { + switch (out->type()->id()) { + case Type::INT8: + return CastPrimitive::Exec(batch, out); + case Type::INT16: + return CastPrimitive::Exec(batch, out); + case Type::INT32: + return CastPrimitive::Exec(batch, out); + case Type::INT64: + return CastPrimitive::Exec(batch, out); + case Type::UINT8: + return CastPrimitive::Exec(batch, out); + case Type::UINT16: + return CastPrimitive::Exec(batch, out); + case Type::UINT32: + return CastPrimitive::Exec(batch, out); + case Type::UINT64: + return CastPrimitive::Exec(batch, out); + case Type::FLOAT: + return CastPrimitive::Exec(batch, out); + case Type::DOUBLE: + return CastPrimitive::Exec(batch, out); + default: + break; + } +} -// clang-format doesn't handle this construct properly. Thus the macro, but it -// also improves readability. -// -// The effective return type of the function is always `I::c_type`, this is -// just how enable_if works with functions. -#define RET_TYPE(TRAIT) enable_if_t::value, typename I::c_type> +void CastNumberToNumberUnsafe(const ExecBatch& batch, Datum* out) { + switch (batch[0].type()->id()) { + case Type::INT8: + return CastNumberImpl(batch, out); + case Type::INT16: + return CastNumberImpl(batch, out); + case Type::INT32: + return CastNumberImpl(batch, out); + case Type::INT64: + return CastNumberImpl(batch, out); + case Type::UINT8: + return CastNumberImpl(batch, out); + case Type::UINT16: + return CastNumberImpl(batch, out); + case Type::UINT32: + return CastNumberImpl(batch, out); + case Type::UINT64: + return CastNumberImpl(batch, out); + case Type::FLOAT: + return CastNumberImpl(batch, out); + case Type::DOUBLE: + return CastNumberImpl(batch, out); + default: + DCHECK(false); + break; + } +} -template -constexpr RET_TYPE(is_number_downcast) SafeMinimum() { - using out_type = typename O::c_type; +void CastIntegerToInteger(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + const auto& options = 
checked_cast(ctx->state())->options; + if (!options.allow_int_overflow) { + KERNEL_RETURN_IF_ERROR(ctx, IntegersCanFit(batch[0], *out->type())); + } + CastNumberToNumberUnsafe(batch, out); +} - return std::numeric_limits::lowest(); +void CastFloatingToFloating(KernelContext*, const ExecBatch& batch, Datum* out) { + CastNumberToNumberUnsafe(batch, out); } -template -constexpr RET_TYPE(is_number_downcast) SafeMaximum() { - using out_type = typename O::c_type; +// ---------------------------------------------------------------------- +// Implement fast safe floating point to integer cast - return std::numeric_limits::max(); +template +std::string FormatInt(T val) { + return std::to_string(val); } -template -constexpr RET_TYPE(is_integral_unsigned_to_signed) SafeMinimum() { - return 0; -} +// InType is a floating point type we are planning to cast to integer +template +ARROW_DISABLE_UBSAN("float-cast-overflow") +Status CheckFloatTruncation(const Datum& input, const Datum& output) { + auto WasTruncated = [&](OutT out_val, InT in_val) -> bool { + return static_cast(out_val) != in_val; + }; + auto WasTruncatedMaybeNull = [&](OutT out_val, InT in_val, bool is_valid) -> bool { + return is_valid && static_cast(out_val) != in_val; + }; + auto GetErrorMessage = [&](InT val) { + return Status::Invalid("Float value ", FormatInt(val), " was truncated converting to", + *output.type()); + }; + + if (input.kind() == Datum::SCALAR) { + DCHECK_EQ(output.kind(), Datum::SCALAR); + const auto& in_scalar = input.scalar_as::ScalarType>(); + const auto& out_scalar = output.scalar_as::ScalarType>(); + if (WasTruncatedMaybeNull(out_scalar.value, in_scalar.value, out_scalar.is_valid)) { + return GetErrorMessage(in_scalar.value); + } + return Status::OK(); + } -template -constexpr RET_TYPE(is_integral_unsigned_to_signed) SafeMaximum() { - using in_type = typename I::c_type; - using out_type = typename O::c_type; + const ArrayData& in_array = *input.array(); + const ArrayData& out_array = 
*output.array(); - // Equality is missing because in_type::max() > out_type::max() when types - // are of the same width. - return static_cast(sizeof(in_type) < sizeof(out_type) - ? std::numeric_limits::max() - : std::numeric_limits::max()); -} + const InT* in_data = in_array.GetValues(1); + const OutT* out_data = out_array.GetValues(1); -template -constexpr RET_TYPE(is_integral_signed_to_unsigned) SafeMinimum() { - return 0; + const uint8_t* bitmap = nullptr; + if (in_array.buffers[0]) { + bitmap = in_array.buffers[0]->data(); + } + OptionalBitBlockCounter bit_counter(bitmap, in_array.offset, in_array.length); + int64_t position = 0; + int64_t offset_position = in_array.offset; + while (position < in_array.length) { + BitBlockCount block = bit_counter.NextBlock(); + bool block_out_of_bounds = false; + if (block.popcount == block.length) { + // Fast path: branchless + for (int64_t i = 0; i < block.length; ++i) { + block_out_of_bounds |= WasTruncated(out_data[i], in_data[i]); + } + } else if (block.popcount > 0) { + // Indices have nulls, must only boundscheck non-null values + for (int64_t i = 0; i < block.length; ++i) { + block_out_of_bounds |= WasTruncatedMaybeNull( + out_data[i], in_data[i], BitUtil::GetBit(bitmap, offset_position + i)); + } + } + if (ARROW_PREDICT_FALSE(block_out_of_bounds)) { + if (in_array.GetNullCount() > 0) { + for (int64_t i = 0; i < block.length; ++i) { + if (WasTruncatedMaybeNull(out_data[i], in_data[i], + BitUtil::GetBit(bitmap, offset_position + i))) { + return GetErrorMessage(in_data[i]); + } + } + } else { + for (int64_t i = 0; i < block.length; ++i) { + if (WasTruncated(out_data[i], in_data[i])) { + return GetErrorMessage(in_data[i]); + } + } + } + } + in_data += block.length; + out_data += block.length; + position += block.length; + offset_position += block.length; + } + return Status::OK(); } -template -constexpr RET_TYPE(is_integral_signed_to_unsigned) SafeMaximum() { - using in_type = typename I::c_type; - using out_type = 
typename O::c_type; +template +Status CheckFloatToIntTruncationImpl(const Datum& input, const Datum& output) { + switch (output.type()->id()) { + case Type::INT8: + return CheckFloatTruncation(input, output); + case Type::INT16: + return CheckFloatTruncation(input, output); + case Type::INT32: + return CheckFloatTruncation(input, output); + case Type::INT64: + return CheckFloatTruncation(input, output); + case Type::UINT8: + return CheckFloatTruncation(input, output); + case Type::UINT16: + return CheckFloatTruncation(input, output); + case Type::UINT32: + return CheckFloatTruncation(input, output); + case Type::UINT64: + return CheckFloatTruncation(input, output); + default: + break; + } + DCHECK(false); + return Status::OK(); +} - return static_cast(sizeof(in_type) <= sizeof(out_type) - ? std::numeric_limits::max() - : std::numeric_limits::max()); +Status CheckFloatToIntTruncation(const Datum& input, const Datum& output) { + switch (input.type()->id()) { + case Type::FLOAT: + return CheckFloatToIntTruncationImpl(input, output); + case Type::DOUBLE: + return CheckFloatToIntTruncationImpl(input, output); + default: + break; + } + DCHECK(false); + return Status::OK(); } -#undef RET_TYPE +void CastFloatingToInteger(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + const auto& options = checked_cast(ctx->state())->options; + CastNumberToNumberUnsafe(batch, out); + if (!options.allow_float_truncate) { + KERNEL_RETURN_IF_ERROR(ctx, CheckFloatToIntTruncation(batch[0], *out)); + } +} -// Float to Integer or Integer to Float -template -struct is_float_truncate { - static constexpr bool value = false; -}; +// ---------------------------------------------------------------------- +// Implement fast integer to floating point cast -template -struct is_float_truncate< - O, I, - enable_if_t<(is_integer_type::value && is_floating_type::value) || - (is_integer_type::value && is_floating_type::value)>> { - static constexpr bool value = true; -}; +// These are the limits 
for exact representation of whole numbers in floating +// point numbers +template +struct FloatingIntegerBound {}; -// Leftover of Number combinations that are safe to cast. -template -struct is_safe_numeric_cast { - static constexpr bool value = false; +template <> +struct FloatingIntegerBound { + static const int64_t value = 1LL << 24; }; -template -struct is_safe_numeric_cast< - O, I, enable_if_t::value && is_number_type::value>> { - using O_T = typename O::c_type; - using I_T = typename I::c_type; - - static constexpr bool value = - (std::is_signed::value == std::is_signed::value) && - (std::is_integral::value == std::is_integral::value) && - (sizeof(O_T) >= sizeof(I_T)) && (!std::is_same::value); +template <> +struct FloatingIntegerBound { + static const int64_t value = 1LL << 53; }; -// ---------------------------------------------------------------------- -// Integer to other number types +template ::value> +Status CheckIntegerFloatTruncateImpl(const Datum& input) { + using InScalarType = typename TypeTraits::ScalarType; + const int64_t limit = FloatingIntegerBound::value; + InScalarType bound_lower(IsSigned ? 
-limit : 0); + InScalarType bound_upper(limit); + return CheckIntegersInRange(input, bound_lower, bound_upper); +} -template -struct IntegerDowncastNoOverflow { - using InT = typename I::c_type; - static constexpr InT kMax = SafeMaximum(); - static constexpr InT kMin = SafeMinimum(); - - template - OutT Call(KernelContext* ctx, InT val) const { - if (ARROW_PREDICT_FALSE(val > kMax || val < kMin)) { - ctx->SetStatus(Status::Invalid("Integer value out of bounds")); +Status CheckForIntegerToFloatingTruncation(const Datum& input, Type::type out_type) { + switch (input.type()->id()) { + // Small integers are all exactly representable as whole numbers + case Type::INT8: + case Type::INT16: + case Type::UINT8: + case Type::UINT16: + return Status::OK(); + case Type::INT32: { + if (out_type == Type::DOUBLE) { + return Status::OK(); + } + return CheckIntegerFloatTruncateImpl(input); } - return static_cast(val); - } -}; - -struct StaticCast { - template - ARROW_DISABLE_UBSAN("float-cast-overflow") - static OutT Call(KernelContext*, InT val) { - return static_cast(val); - } -}; - -template -struct CastFunctor::value || - is_integral_signed_to_unsigned::value || - is_integral_unsigned_to_signed::value>> { - static void Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { - const auto& options = checked_cast(ctx->state())->options; - if (!options.allow_int_overflow) { - applicator::ScalarUnaryNotNull>::Exec( - ctx, batch, out); - } else { - applicator::ScalarUnary::Exec(ctx, batch, out); + case Type::UINT32: { + if (out_type == Type::DOUBLE) { + return Status::OK(); + } + return CheckIntegerFloatTruncateImpl(input); } - } -}; - -// ---------------------------------------------------------------------- -// Float to other number types - -struct FloatToIntegerNoTruncate { - template - ARROW_DISABLE_UBSAN("float-cast-overflow") - OutT Call(KernelContext* ctx, InT val) const { - auto out_value = static_cast(val); - if (ARROW_PREDICT_FALSE(static_cast(out_value) != val)) { 
- ctx->SetStatus(Status::Invalid("Floating point value truncated")); + case Type::INT64: { + if (out_type == Type::FLOAT) { + return CheckIntegerFloatTruncateImpl(input); + } else { + return CheckIntegerFloatTruncateImpl(input); + } } - return out_value; - } -}; - -template -struct CastFunctor::value>> { - ARROW_DISABLE_UBSAN("float-cast-overflow") - static void Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { - const auto& options = checked_cast(ctx->state())->options; - if (options.allow_float_truncate) { - applicator::ScalarUnary::Exec(ctx, batch, out); - } else { - applicator::ScalarUnaryNotNull::Exec(ctx, batch, - out); + case Type::UINT64: { + if (out_type == Type::FLOAT) { + return CheckIntegerFloatTruncateImpl(input); + } else { + return CheckIntegerFloatTruncateImpl(input); + } } + default: + break; } -}; + DCHECK(false); + return Status::OK(); +} -template -struct CastFunctor< - O, I, - enable_if_t::value && !is_float_truncate::value && - !is_number_downcast::value>> { - static void Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { - // Due to various checks done via type-trait, the cast is safe and bear - // no truncation. 
- applicator::ScalarUnary::Exec(ctx, batch, out); +void CastIntegerToFloating(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + const auto& options = checked_cast(ctx->state())->options; + if (!options.allow_float_truncate) { + KERNEL_RETURN_IF_ERROR( + ctx, CheckForIntegerToFloatingTruncation(batch[0], out->type()->id())); } -}; + CastNumberToNumberUnsafe(batch, out); +} // ---------------------------------------------------------------------- // Boolean to number @@ -475,21 +585,16 @@ struct CastFunctor { } }; +namespace { + template -void AddPrimitiveNumberCasts(const std::shared_ptr& out_ty, - CastFunction* func) { +void AddCommonNumberCasts(const std::shared_ptr& out_ty, CastFunction* func) { AddCommonCasts(out_ty->id(), out_ty, func); // Cast from boolean to number DCHECK_OK(func->AddKernel(Type::BOOL, {boolean()}, out_ty, CastFunctor::Exec)); - // Cast from other numbers - for (const std::shared_ptr& in_ty : NumericTypes()) { - auto exec = GenerateNumeric(*in_ty); - DCHECK_OK(func->AddKernel(in_ty->id(), {in_ty}, out_ty, exec)); - } - // Cast from other strings for (const std::shared_ptr& in_ty : BaseBinaryTypes()) { auto exec = GenerateVarBinaryBase(*in_ty); @@ -502,8 +607,17 @@ std::shared_ptr GetCastToInteger(std::string name) { auto func = std::make_shared(std::move(name), OutType::type_id); auto out_ty = TypeTraits::type_singleton(); + for (const std::shared_ptr& in_ty : IntTypes()) { + DCHECK_OK(func->AddKernel(in_ty->id(), {in_ty}, out_ty, CastIntegerToInteger)); + } + + // Cast from floating point + for (const std::shared_ptr& in_ty : FloatingPointTypes()) { + DCHECK_OK(func->AddKernel(in_ty->id(), {in_ty}, out_ty, CastFloatingToInteger)); + } + // From other numbers to integer - AddPrimitiveNumberCasts(out_ty, func.get()); + AddCommonNumberCasts(out_ty, func.get()); // From decimal to integer DCHECK_OK(func->AddKernel(Type::DECIMAL, {InputType::Array(Type::DECIMAL)}, out_ty, @@ -516,8 +630,18 @@ std::shared_ptr 
GetCastToFloating(std::string name) { auto func = std::make_shared(std::move(name), OutType::type_id); auto out_ty = TypeTraits::type_singleton(); + // Casts from integer to floating point + for (const std::shared_ptr& in_ty : IntTypes()) { + DCHECK_OK(func->AddKernel(in_ty->id(), {in_ty}, out_ty, CastIntegerToFloating)); + } + + // Cast from floating point + for (const std::shared_ptr& in_ty : FloatingPointTypes()) { + DCHECK_OK(func->AddKernel(in_ty->id(), {in_ty}, out_ty, CastFloatingToFloating)); + } + // From other numbers to integer - AddPrimitiveNumberCasts(out_ty, func.get()); + AddCommonNumberCasts(out_ty, func.get()); return func; } @@ -535,6 +659,8 @@ std::shared_ptr GetCastToDecimal() { return func; } +} // namespace + std::vector> GetNumericCasts() { std::vector> functions; diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_temporal.cc b/cpp/src/arrow/compute/kernels/scalar_cast_temporal.cc index bd631b3a432..0d82eb57b8a 100644 --- a/cpp/src/arrow/compute/kernels/scalar_cast_temporal.cc +++ b/cpp/src/arrow/compute/kernels/scalar_cast_temporal.cc @@ -126,6 +126,9 @@ struct CastFunctor< enable_if_t<(is_timestamp_type::value && is_timestamp_type::value) || (is_duration_type::value && is_duration_type::value)>> { static void Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + // TODO: Make this work on scalar inputs + DCHECK_EQ(batch[0].kind(), Datum::ARRAY); + const ArrayData& input = *batch[0].array(); ArrayData* output = out->mutable_array(); @@ -144,6 +147,9 @@ struct CastFunctor< template <> struct CastFunctor { static void Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + // TODO: Make this work on scalar inputs + DCHECK_EQ(batch[0].kind(), Datum::ARRAY); + const ArrayData& input = *batch[0].array(); ArrayData* output = out->mutable_array(); @@ -164,6 +170,9 @@ struct CastFunctor { template <> struct CastFunctor { static void Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + // TODO: Make this work on 
scalar inputs + DCHECK_EQ(batch[0].kind(), Datum::ARRAY); + const CastOptions& options = checked_cast(*ctx->state()).options; const ArrayData& input = *batch[0].array(); ArrayData* output = out->mutable_array(); @@ -215,6 +224,9 @@ struct CastFunctor::value && is_time_type:: using out_t = typename O::c_type; static void Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + // TODO: Make this work on scalar inputs + DCHECK_EQ(batch[0].kind(), Datum::ARRAY); + const ArrayData& input = *batch[0].array(); ArrayData* output = out->mutable_array(); @@ -233,6 +245,9 @@ struct CastFunctor::value && is_time_type:: template <> struct CastFunctor { static void Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + // TODO: Make this work on scalar inputs + DCHECK_EQ(batch[0].kind(), Datum::ARRAY); + ShiftTime(ctx, util::MULTIPLY, kMillisecondsInDay, *batch[0].array(), out->mutable_array()); } @@ -241,6 +256,9 @@ struct CastFunctor { template <> struct CastFunctor { static void Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + // TODO: Make this work on scalar inputs + DCHECK_EQ(batch[0].kind(), Datum::ARRAY); + ShiftTime(ctx, util::DIVIDE, kMillisecondsInDay, *batch[0].array(), out->mutable_array()); } diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_test.cc b/cpp/src/arrow/compute/kernels/scalar_cast_test.cc index a3a683c1af5..8fd33d3731c 100644 --- a/cpp/src/arrow/compute/kernels/scalar_cast_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_cast_test.cc @@ -51,6 +51,23 @@ using internal::checked_cast; namespace compute { +// Use std::string and Decimal128 for supplying test values for base binary types + +template +struct TestCType { + using type = typename T::c_type; +}; + +template +struct TestCType> { + using type = std::string; +}; + +template +struct TestCType> { + using type = Decimal128; +}; + static constexpr const char* kInvalidUtf8 = "\xa0\xa1"; static std::vector> kNumericTypes = { @@ -65,16 +82,30 @@ static void 
AssertBufferSame(const Array& left, const Array& right, int buffer_i class TestCast : public TestBase { public: void CheckPass(const Array& input, const Array& expected, - const std::shared_ptr& out_type, const CastOptions& options) { + const std::shared_ptr& out_type, const CastOptions& options, + bool check_scalar = true) { ASSERT_OK_AND_ASSIGN(std::shared_ptr result, Cast(input, out_type, options)); ASSERT_OK(result->ValidateFull()); AssertArraysEqual(expected, *result, /*verbose=*/true); + + if (input.type_id() == Type::DECIMAL || out_type->id() == Type::DECIMAL) { + // ARROW-9194 + check_scalar = false; + } + + if (check_scalar) { + for (int64_t i = 0; i < input.length(); ++i) { + ASSERT_OK_AND_ASSIGN(Datum out, Cast(*input.GetScalar(i), out_type, options)); + AssertScalarsEqual(**expected.GetScalar(i), *out.scalar(), /*verbose=*/true); + } + } } - template + template ::type> void CheckFails(const std::shared_ptr& in_type, const std::vector& in_values, const std::vector& is_valid, - const std::shared_ptr& out_type, const CastOptions& options) { + const std::shared_ptr& out_type, const CastOptions& options, + bool check_scalar = true) { std::shared_ptr input; if (is_valid.size() > 0) { ArrayFromVector(in_type, is_valid, in_values, &input); @@ -82,6 +113,31 @@ class TestCast : public TestBase { ArrayFromVector(in_type, in_values, &input); } ASSERT_RAISES(Invalid, Cast(*input, out_type, options)); + + if (in_type->id() == Type::DECIMAL || out_type->id() == Type::DECIMAL) { + // ARROW-9194 + check_scalar = false; + } + + // For the scalars, check that at least one of the input fails (since many + // of the tests contains a mix of passing and failing values). 
In some + // cases we will want to check more precisely + if (check_scalar) { + int64_t num_failing = 0; + for (int64_t i = 0; i < input->length(); ++i) { + auto maybe_out = Cast(*input->GetScalar(i), out_type, options); + num_failing += static_cast(maybe_out.status().IsInvalid()); + } + ASSERT_GT(num_failing, 0); + } + } + + template ::type> + void CheckFails(const std::vector& in_values, const std::vector& is_valid, + const std::shared_ptr& out_type, const CastOptions& options, + bool check_scalar = true) { + CheckFails(TypeTraits::type_singleton(), in_values, is_valid, + out_type, options, check_scalar); } void CheckZeroCopy(const Array& input, const std::shared_ptr& out_type) { @@ -93,11 +149,14 @@ class TestCast : public TestBase { } } - template + template ::type, + typename O_TYPE = typename TestCType::type> void CheckCase(const std::shared_ptr& in_type, const std::vector& in_values, const std::vector& is_valid, const std::shared_ptr& out_type, - const std::vector& out_values, const CastOptions& options) { + const std::vector& out_values, const CastOptions& options, + bool check_scalar = true) { ASSERT_EQ(in_values.size(), out_values.size()); std::shared_ptr input, expected; if (is_valid.size() > 0) { @@ -108,26 +167,38 @@ class TestCast : public TestBase { ArrayFromVector(in_type, in_values, &input); ArrayFromVector(out_type, out_values, &expected); } - CheckPass(*input, *expected, out_type, options); + CheckPass(*input, *expected, out_type, options, check_scalar); // Check a sliced variant if (input->length() > 1) { - CheckPass(*input->Slice(1), *expected->Slice(1), out_type, options); + CheckPass(*input->Slice(1), *expected->Slice(1), out_type, options, check_scalar); } } + template + void CheckCase(const std::vector& in_values, const std::vector& is_valid, + const std::vector& out_values, const CastOptions& options, + bool check_scalar = true) { + CheckCase( + TypeTraits::type_singleton(), in_values, is_valid, + TypeTraits::type_singleton(), out_values, 
options, check_scalar); + } + void CheckCaseJSON(const std::shared_ptr& in_type, const std::shared_ptr& out_type, const std::string& in_json, const std::string& expected_json, + bool check_scalar = true, const CastOptions& options = CastOptions()) { std::shared_ptr input = ArrayFromJSON(in_type, in_json); std::shared_ptr expected = ArrayFromJSON(out_type, expected_json); ASSERT_EQ(input->length(), expected->length()); - CheckPass(*input, *expected, out_type, options); + CheckPass(*input, *expected, out_type, options, check_scalar); // Check a sliced variant if (input->length() > 1) { - CheckPass(*input->Slice(1), *expected->Slice(1), out_type, options); + CheckPass(*input->Slice(1), *expected->Slice(1), out_type, options, + /*check_scalar=*/false); } } @@ -149,12 +220,13 @@ class TestCast : public TestBase { CheckZeroCopy(*array, dest_type); // Should refuse due to invalid utf8 payload - CheckFails(src_type, strings, all, dest_type, options); + CheckFails(strings, all, dest_type, options, + /*check_scalar=*/false); // Should accept due to option override options.allow_invalid_utf8 = true; - CheckCase( - src_type, strings, all, dest_type, strings, options); + CheckCase(strings, all, strings, options, + /*check_scalar=*/false); } template @@ -162,26 +234,31 @@ class TestCast : public TestBase { auto dest_type = TypeTraits::type_singleton(); CheckCaseJSON(int8(), dest_type, "[0, 1, 127, -128, null]", - R"(["0", "1", "127", "-128", null])"); - CheckCaseJSON(uint8(), dest_type, "[0, 1, 255, null]", R"(["0", "1", "255", null])"); + R"(["0", "1", "127", "-128", null])", /*check_scalar=*/false); + CheckCaseJSON(uint8(), dest_type, "[0, 1, 255, null]", R"(["0", "1", "255", null])", + /*check_scalar=*/false); CheckCaseJSON(int16(), dest_type, "[0, 1, 32767, -32768, null]", - R"(["0", "1", "32767", "-32768", null])"); + R"(["0", "1", "32767", "-32768", null])", /*check_scalar=*/false); CheckCaseJSON(uint16(), dest_type, "[0, 1, 65535, null]", - R"(["0", "1", "65535", 
null])"); + R"(["0", "1", "65535", null])", /*check_scalar=*/false); CheckCaseJSON(int32(), dest_type, "[0, 1, 2147483647, -2147483648, null]", - R"(["0", "1", "2147483647", "-2147483648", null])"); + R"(["0", "1", "2147483647", "-2147483648", null])", + /*check_scalar=*/false); CheckCaseJSON(uint32(), dest_type, "[0, 1, 4294967295, null]", - R"(["0", "1", "4294967295", null])"); + R"(["0", "1", "4294967295", null])", /*check_scalar=*/false); CheckCaseJSON(int64(), dest_type, "[0, 1, 9223372036854775807, -9223372036854775808, null]", - R"(["0", "1", "9223372036854775807", "-9223372036854775808", null])"); + R"(["0", "1", "9223372036854775807", "-9223372036854775808", null])", + /*check_scalar=*/false); CheckCaseJSON(uint64(), dest_type, "[0, 1, 18446744073709551615, null]", - R"(["0", "1", "18446744073709551615", null])"); + R"(["0", "1", "18446744073709551615", null])", /*check_scalar=*/false); CheckCaseJSON(float32(), dest_type, "[0.0, -0.0, 1.5, -Inf, Inf, NaN, null]", - R"(["0", "-0", "1.5", "-inf", "inf", "nan", null])"); + R"(["0", "-0", "1.5", "-inf", "inf", "nan", null])", + /*check_scalar=*/false); CheckCaseJSON(float64(), dest_type, "[0.0, -0.0, 1.5, -Inf, Inf, NaN, null]", - R"(["0", "-0", "1.5", "-inf", "inf", "nan", null])"); + R"(["0", "-0", "1.5", "-inf", "inf", "nan", null])", + /*check_scalar=*/false); } template @@ -189,7 +266,7 @@ class TestCast : public TestBase { auto dest_type = TypeTraits::type_singleton(); CheckCaseJSON(boolean(), dest_type, "[true, true, false, null]", - R"(["true", "true", "false", null])"); + R"(["true", "true", "false", null])", /*check_scalar=*/false); } template @@ -205,23 +282,17 @@ class TestCast : public TestBase { std::vector e_int16 = {0, 1, 127, -1, 0}; std::vector e_int32 = {0, 1, 127, -1, 0}; std::vector e_int64 = {0, 1, 127, -1, 0}; - CheckCase(src_type, v_int, is_valid, - int8(), e_int8, options); - CheckCase(src_type, v_int, is_valid, - int16(), e_int16, options); - CheckCase(src_type, v_int, is_valid, - 
int32(), e_int32, options); - CheckCase(src_type, v_int, is_valid, - int64(), e_int64, options); + CheckCase(v_int, is_valid, e_int8, options); + CheckCase(v_int, is_valid, e_int16, options); + CheckCase(v_int, is_valid, e_int32, options); + CheckCase(v_int, is_valid, e_int64, options); v_int = {"2147483647", "0", "-2147483648", "0", "0"}; e_int32 = {2147483647, 0, -2147483648LL, 0, 0}; - CheckCase(src_type, v_int, is_valid, - int32(), e_int32, options); + CheckCase(v_int, is_valid, e_int32, options); v_int = {"9223372036854775807", "0", "-9223372036854775808", "0", "0"}; e_int64 = {9223372036854775807LL, 0, (-9223372036854775807LL - 1), 0, 0}; - CheckCase(src_type, v_int, is_valid, - int64(), e_int64, options); + CheckCase(v_int, is_valid, e_int64, options); // string to uint std::vector v_uint = {"0", "1", "127", "255", "0"}; @@ -229,42 +300,32 @@ class TestCast : public TestBase { std::vector e_uint16 = {0, 1, 127, 255, 0}; std::vector e_uint32 = {0, 1, 127, 255, 0}; std::vector e_uint64 = {0, 1, 127, 255, 0}; - CheckCase(src_type, v_uint, is_valid, - uint8(), e_uint8, options); - CheckCase(src_type, v_uint, is_valid, - uint16(), e_uint16, options); - CheckCase(src_type, v_uint, is_valid, - uint32(), e_uint32, options); - CheckCase(src_type, v_uint, is_valid, - uint64(), e_uint64, options); + CheckCase(v_uint, is_valid, e_uint8, options); + CheckCase(v_uint, is_valid, e_uint16, options); + CheckCase(v_uint, is_valid, e_uint32, options); + CheckCase(v_uint, is_valid, e_uint64, options); v_uint = {"4294967295", "0", "0", "0", "0"}; e_uint32 = {4294967295, 0, 0, 0, 0}; - CheckCase(src_type, v_uint, is_valid, - uint32(), e_uint32, options); + CheckCase(v_uint, is_valid, e_uint32, options); v_uint = {"18446744073709551615", "0", "0", "0", "0"}; e_uint64 = {18446744073709551615ULL, 0, 0, 0, 0}; - CheckCase(src_type, v_uint, is_valid, - uint64(), e_uint64, options); + CheckCase(v_uint, is_valid, e_uint64, options); // string to float std::vector v_float = {"0.1", 
"1.2", "127.3", "200.4", "0.5"}; std::vector e_float = {0.1f, 1.2f, 127.3f, 200.4f, 0.5f}; std::vector e_double = {0.1, 1.2, 127.3, 200.4, 0.5}; - CheckCase(src_type, v_float, is_valid, - float32(), e_float, options); - CheckCase(src_type, v_float, is_valid, - float64(), e_double, options); + CheckCase(v_float, is_valid, e_float, options); + CheckCase(v_float, is_valid, e_double, options); #if !defined(_WIN32) || defined(NDEBUG) // Test that casting is locale-independent { // French locale uses the comma as decimal point LocaleGuard locale_guard("fr_FR.UTF-8"); - CheckCase(src_type, v_float, is_valid, - float32(), e_float, options); - CheckCase( - src_type, v_float, is_valid, float64(), e_double, options); + CheckCase(v_float, is_valid, e_float, options); + CheckCase(v_float, is_valid, e_double, options); } #endif } @@ -279,13 +340,11 @@ class TestCast : public TestBase { auto type = timestamp(TimeUnit::SECOND); std::vector e = {0, 0, 951782400}; - CheckCase( - src_type, strings, is_valid, type, e, options); + CheckCase(src_type, strings, is_valid, type, e, options); type = timestamp(TimeUnit::MICRO); e = {0, 0, 951782400000000LL}; - CheckCase( - src_type, strings, is_valid, type, e, options); + CheckCase(src_type, strings, is_valid, type, e, options); // NOTE: timestamp parsing is tested comprehensively in parsing-util-test.cc } @@ -322,8 +381,7 @@ TEST_F(TestCast, FromBoolean) { } } - CheckCase(boolean(), v1, is_valid, int32(), e1, - options); + CheckCase(v1, is_valid, e1, options); } TEST_F(TestCast, ToBoolean) { @@ -349,20 +407,17 @@ TEST_F(TestCast, ToIntUpcast) { // int8 to int32 std::vector v1 = {0, 1, 127, -1, 0}; std::vector e1 = {0, 1, 127, -1, 0}; - CheckCase(int8(), v1, is_valid, int32(), e1, - options); + CheckCase(v1, is_valid, e1, options); // bool to int8 std::vector v2 = {false, true, false, true, true}; std::vector e2 = {0, 1, 0, 1, 1}; - CheckCase(boolean(), v2, is_valid, int8(), e2, - options); + CheckCase(v2, is_valid, e2, options); // uint8 to 
int16, no overflow/underrun std::vector v3 = {0, 100, 200, 255, 0}; std::vector e3 = {0, 100, 200, 255, 0}; - CheckCase(uint8(), v3, is_valid, int16(), e3, - options); + CheckCase(v3, is_valid, e3, options); } TEST_F(TestCast, OverflowInNullSlot) { @@ -375,7 +430,7 @@ TEST_F(TestCast, OverflowInNullSlot) { std::vector e11 = {0, 0, 2000, 1000, 0}; std::shared_ptr expected; - ArrayFromVector(int16(), is_valid, e11, &expected); + ArrayFromVector(int16(), is_valid, e11, &expected); auto buf = Buffer::Wrap(v11.data(), v11.size()); Int32Array tmp11(5, buf, expected->null_bitmap(), -1); @@ -392,33 +447,31 @@ TEST_F(TestCast, ToIntDowncastSafe) { // int16 to uint8, no overflow/underrun std::vector v1 = {0, 100, 200, 1, 2}; std::vector e1 = {0, 100, 200, 1, 2}; - CheckCase(int16(), v1, is_valid, uint8(), e1, - options); + CheckCase(v1, is_valid, e1, options); // int16 to uint8, with overflow std::vector v2 = {0, 100, 256, 0, 0}; - CheckFails(int16(), v2, is_valid, uint8(), options); + CheckFails(v2, is_valid, uint8(), options); // underflow std::vector v3 = {0, 100, -1, 0, 0}; - CheckFails(int16(), v3, is_valid, uint8(), options); + CheckFails(v3, is_valid, uint8(), options); // int32 to int16, no overflow std::vector v4 = {0, 1000, 2000, 1, 2}; std::vector e4 = {0, 1000, 2000, 1, 2}; - CheckCase(int32(), v4, is_valid, int16(), e4, - options); + CheckCase(v4, is_valid, e4, options); // int32 to int16, overflow std::vector v5 = {0, 1000, 2000, 70000, 0}; - CheckFails(int32(), v5, is_valid, int16(), options); + CheckFails(v5, is_valid, int16(), options); // underflow std::vector v6 = {0, 1000, 2000, -70000, 0}; - CheckFails(int32(), v6, is_valid, int16(), options); + CheckFails(v6, is_valid, int16(), options); std::vector v7 = {0, 1000, 2000, -70000, 0}; - CheckFails(int32(), v7, is_valid, uint8(), options); + CheckFails(v7, is_valid, uint8(), options); } template @@ -440,26 +493,25 @@ TEST_F(TestCast, IntegerSignedToUnsigned) { std::vector v1 = {INT32_MIN, 100, -1, 
UINT16_MAX, INT32_MAX}; // Same width - CheckFails(int32(), v1, is_valid, uint32(), options); + CheckFails(v1, is_valid, uint32(), options); // Wider - CheckFails(int32(), v1, is_valid, uint64(), options); + CheckFails(v1, is_valid, uint64(), options); // Narrower - CheckFails(int32(), v1, is_valid, uint16(), options); + CheckFails(v1, is_valid, uint16(), options); // Fail because of overflow (instead of underflow). std::vector over = {0, -11, 0, UINT16_MAX + 1, INT32_MAX}; - CheckFails(int32(), over, is_valid, uint16(), options); + CheckFails(over, is_valid, uint16(), options); options.allow_int_overflow = true; - CheckCase( - int32(), v1, is_valid, uint32(), UnsafeVectorCast(v1), options); - CheckCase( - int32(), v1, is_valid, uint64(), UnsafeVectorCast(v1), options); - CheckCase( - int32(), v1, is_valid, uint16(), UnsafeVectorCast(v1), options); - CheckCase( - int32(), over, is_valid, uint16(), UnsafeVectorCast(over), - options); + CheckCase(v1, is_valid, UnsafeVectorCast(v1), + options); + CheckCase(v1, is_valid, UnsafeVectorCast(v1), + options); + CheckCase(v1, is_valid, UnsafeVectorCast(v1), + options); + CheckCase(over, is_valid, + UnsafeVectorCast(over), options); } TEST_F(TestCast, IntegerUnsignedToSigned) { @@ -471,21 +523,21 @@ TEST_F(TestCast, IntegerUnsignedToSigned) { std::vector v1 = {0, INT16_MAX + 1, UINT32_MAX}; std::vector v2 = {0, INT16_MAX + 1, 2}; // Same width - CheckFails(uint32(), v1, is_valid, int32(), options); + CheckFails(v1, is_valid, int32(), options); // Narrower - CheckFails(uint32(), v1, is_valid, int16(), options); - CheckFails(uint32(), v2, is_valid, int16(), options); + CheckFails(v1, is_valid, int16(), options); + CheckFails(v2, is_valid, int16(), options); options.allow_int_overflow = true; - CheckCase( - uint32(), v1, is_valid, int32(), UnsafeVectorCast(v1), options); - CheckCase( - uint32(), v1, is_valid, int64(), UnsafeVectorCast(v1), options); - CheckCase( - uint32(), v1, is_valid, int16(), UnsafeVectorCast(v1), options); 
- CheckCase( - uint32(), v2, is_valid, int16(), UnsafeVectorCast(v2), options); + CheckCase(v1, is_valid, UnsafeVectorCast(v1), + options); + CheckCase(v1, is_valid, UnsafeVectorCast(v1), + options); + CheckCase(v1, is_valid, UnsafeVectorCast(v1), + options); + CheckCase(v2, is_valid, UnsafeVectorCast(v2), + options); } TEST_F(TestCast, ToIntDowncastUnsafe) { @@ -497,40 +549,34 @@ TEST_F(TestCast, ToIntDowncastUnsafe) { // int16 to uint8, no overflow/underrun std::vector v1 = {0, 100, 200, 1, 2}; std::vector e1 = {0, 100, 200, 1, 2}; - CheckCase(int16(), v1, is_valid, uint8(), e1, - options); + CheckCase(v1, is_valid, e1, options); // int16 to uint8, with overflow std::vector v2 = {0, 100, 256, 0, 0}; std::vector e2 = {0, 100, 0, 0, 0}; - CheckCase(int16(), v2, is_valid, uint8(), e2, - options); + CheckCase(v2, is_valid, e2, options); // underflow std::vector v3 = {0, 100, -1, 0, 0}; std::vector e3 = {0, 100, 255, 0, 0}; - CheckCase(int16(), v3, is_valid, uint8(), e3, - options); + CheckCase(v3, is_valid, e3, options); // int32 to int16, no overflow std::vector v4 = {0, 1000, 2000, 1, 2}; std::vector e4 = {0, 1000, 2000, 1, 2}; - CheckCase(int32(), v4, is_valid, int16(), e4, - options); + CheckCase(v4, is_valid, e4, options); // int32 to int16, overflow // TODO(wesm): do we want to allow this? we could set to null std::vector v5 = {0, 1000, 2000, 70000, 0}; std::vector e5 = {0, 1000, 2000, 4464, 0}; - CheckCase(int32(), v5, is_valid, int16(), e5, - options); + CheckCase(v5, is_valid, e5, options); // underflow // TODO(wesm): do we want to allow this? 
we could set overflow to null std::vector v6 = {0, 1000, 2000, -70000, 0}; std::vector e6 = {0, 1000, 2000, -4464, 0}; - CheckCase(int32(), v6, is_valid, int16(), e6, - options); + CheckCase(v6, is_valid, e6, options); } TEST_F(TestCast, FloatingPointToInt) { @@ -543,72 +589,74 @@ TEST_F(TestCast, FloatingPointToInt) { // float32 to int32 no truncation std::vector v1 = {1.0, 0, 0.0, -1.0, 5.0}; std::vector e1 = {1, 0, 0, -1, 5}; - CheckCase(float32(), v1, is_valid, int32(), e1, - options); - CheckCase(float32(), v1, all_valid, int32(), e1, - options); + CheckCase(v1, is_valid, e1, options); + CheckCase(v1, all_valid, e1, options); // float64 to int32 no truncation std::vector v2 = {1.0, 0, 0.0, -1.0, 5.0}; std::vector e2 = {1, 0, 0, -1, 5}; - CheckCase(float64(), v2, is_valid, int32(), e2, - options); - CheckCase(float64(), v2, all_valid, int32(), e2, - options); + CheckCase(v2, is_valid, e2, options); + CheckCase(v2, all_valid, e2, options); // float64 to int64 no truncation std::vector v3 = {1.0, 0, 0.0, -1.0, 5.0}; std::vector e3 = {1, 0, 0, -1, 5}; - CheckCase(float64(), v3, is_valid, int64(), e3, - options); - CheckCase(float64(), v3, all_valid, int64(), e3, - options); + CheckCase(v3, is_valid, e3, options); + CheckCase(v3, all_valid, e3, options); // float64 to int32 truncate std::vector v4 = {1.5, 0, 0.5, -1.5, 5.5}; std::vector e4 = {1, 0, 0, -1, 5}; options.allow_float_truncate = false; - CheckFails(float64(), v4, is_valid, int32(), options); - CheckFails(float64(), v4, all_valid, int32(), options); + CheckFails(v4, is_valid, int32(), options); + CheckFails(v4, all_valid, int32(), options); options.allow_float_truncate = true; - CheckCase(float64(), v4, is_valid, int32(), e4, - options); - CheckCase(float64(), v4, all_valid, int32(), e4, - options); + CheckCase(v4, is_valid, e4, options); + CheckCase(v4, all_valid, e4, options); // float64 to int64 truncate std::vector v5 = {1.5, 0, 0.5, -1.5, 5.5}; std::vector e5 = {1, 0, 0, -1, 5}; 
options.allow_float_truncate = false; - CheckFails(float64(), v5, is_valid, int64(), options); - CheckFails(float64(), v5, all_valid, int64(), options); + CheckFails(v5, is_valid, int64(), options); + CheckFails(v5, all_valid, int64(), options); options.allow_float_truncate = true; - CheckCase(float64(), v5, is_valid, int64(), e5, - options); - CheckCase(float64(), v5, all_valid, int64(), e5, - options); + CheckCase(v5, is_valid, e5, options); + CheckCase(v5, all_valid, e5, options); } -#if ARROW_BITNESS >= 64 TEST_F(TestCast, IntToFloatingPoint) { auto options = CastOptions::Safe(); std::vector all_valid = {true, true, true, true, true}; std::vector all_invalid = {false, false, false, false, false}; + std::vector u32_v1 = {1LL << 24, (1LL << 24) + 1}; + CheckFails(u32_v1, {true, true}, float32(), options); + + std::vector u32_v2 = {1LL << 24, 1LL << 24}; + CheckCase(u32_v2, {true, true}, + UnsafeVectorCast(u32_v2), options); + + std::vector i32_v1 = {1LL << 24, (1LL << 24) + 1}; + std::vector i32_v2 = {1LL << 24, 1LL << 24}; + CheckFails(i32_v1, {true, true}, float32(), options); + CheckCase(i32_v2, {true, true}, + UnsafeVectorCast(i32_v2), options); + std::vector v1 = {INT64_MIN, INT64_MIN + 1, 0, INT64_MAX - 1, INT64_MAX}; - CheckFails(int64(), v1, all_valid, float32(), options); + CheckFails(v1, all_valid, float64(), options); // While it's not safe to convert, all values are null. 
- CheckCase(int64(), v1, all_invalid, float64(), - UnsafeVectorCast(v1), - options); + CheckCase(v1, all_invalid, UnsafeVectorCast(v1), + options); + + CheckFails({1LL << 53, (1LL << 53) + 1}, {true, true}, float64(), options); } -#endif TEST_F(TestCast, DecimalToInt) { CastOptions options; @@ -628,10 +676,10 @@ TEST_F(TestCast, DecimalToInt) { for (bool allow_decimal_truncate : {false, true}) { options.allow_int_overflow = allow_int_overflow; options.allow_decimal_truncate = allow_decimal_truncate; - CheckCase( - decimal(38, 10), v12, is_valid2, int64(), e12, options); - CheckCase( - decimal(38, 10), v13, is_valid3, int64(), e13, options); + CheckCase(decimal(38, 10), v12, is_valid2, int64(), e12, + options); + CheckCase(decimal(38, 10), v13, is_valid3, int64(), e13, + options); } } @@ -647,10 +695,10 @@ TEST_F(TestCast, DecimalToInt) { for (bool allow_int_overflow : {false, true}) { options.allow_int_overflow = allow_int_overflow; options.allow_decimal_truncate = true; - CheckCase( - decimal(38, 10), v22, is_valid2, int64(), e22, options); - CheckCase( - decimal(38, 10), v23, is_valid3, int64(), e23, options); + CheckCase(decimal(38, 10), v22, is_valid2, int64(), e22, + options); + CheckCase(decimal(38, 10), v23, is_valid3, int64(), e23, + options); options.allow_decimal_truncate = false; CheckFails(decimal(38, 10), v22, is_valid2, int64(), options); CheckFails(decimal(38, 10), v23, is_valid3, int64(), options); @@ -669,10 +717,10 @@ TEST_F(TestCast, DecimalToInt) { for (bool allow_decimal_truncate : {false, true}) { options.allow_decimal_truncate = allow_decimal_truncate; options.allow_int_overflow = true; - CheckCase( - decimal(38, 10), v32, is_valid2, int64(), e32, options); - CheckCase( - decimal(38, 10), v33, is_valid3, int64(), e33, options); + CheckCase(decimal(38, 10), v32, is_valid2, int64(), e32, + options); + CheckCase(decimal(38, 10), v33, is_valid3, int64(), e33, + options); options.allow_int_overflow = false; CheckFails(decimal(38, 10), v32, 
is_valid2, int64(), options); CheckFails(decimal(38, 10), v33, is_valid3, int64(), options); @@ -693,10 +741,10 @@ TEST_F(TestCast, DecimalToInt) { options.allow_int_overflow = allow_int_overflow; options.allow_decimal_truncate = allow_decimal_truncate; if (options.allow_int_overflow && options.allow_decimal_truncate) { - CheckCase( - decimal(38, 10), v42, is_valid2, int64(), e42, options); - CheckCase( - decimal(38, 10), v43, is_valid3, int64(), e43, options); + CheckCase(decimal(38, 10), v42, is_valid2, int64(), + e42, options); + CheckCase(decimal(38, 10), v43, is_valid3, int64(), + e43, options); } else { CheckFails(decimal(38, 10), v42, is_valid2, int64(), options); CheckFails(decimal(38, 10), v43, is_valid3, int64(), options); @@ -708,8 +756,8 @@ TEST_F(TestCast, DecimalToInt) { std::vector v5 = {Decimal128("1234567890000."), Decimal128("-120000.")}; for (int i = 0; i < 2; i++) v5[i] = v5[i].Rescale(0, -4).ValueOrDie(); std::vector e5 = {1234567890000, -120000}; - CheckCase( - decimal(38, -4), v5, is_valid2, int64(), e5, options); + CheckCase(decimal(38, -4), v5, is_valid2, int64(), e5, + options); } TEST_F(TestCast, DecimalToDecimal) { @@ -730,26 +778,26 @@ TEST_F(TestCast, DecimalToDecimal) { for (bool allow_decimal_truncate : {false, true}) { options.allow_decimal_truncate = allow_decimal_truncate; - CheckCase( - decimal(38, 10), v12, is_valid2, decimal(28, 0), e12, options); - CheckCase( - decimal(38, 10), v13, is_valid3, decimal(28, 0), e13, options); + CheckCase(decimal(38, 10), v12, is_valid2, + decimal(28, 0), e12, options); + CheckCase(decimal(38, 10), v13, is_valid3, + decimal(28, 0), e13, options); // and back - CheckCase( - decimal(28, 0), e12, is_valid2, decimal(38, 10), v12, options); - CheckCase( - decimal(28, 0), e13, is_valid3, decimal(38, 10), v13, options); + CheckCase(decimal(28, 0), e12, is_valid2, + decimal(38, 10), v12, options); + CheckCase(decimal(28, 0), e13, is_valid3, + decimal(38, 10), v13, options); } // Same scale, different 
precision std::vector v14 = {Decimal128("12.34"), Decimal128("0.56")}; for (bool allow_decimal_truncate : {false, true}) { options.allow_decimal_truncate = allow_decimal_truncate; - CheckCase( - decimal(5, 2), v14, is_valid2, decimal(4, 2), v14, options); + CheckCase(decimal(5, 2), v14, is_valid2, + decimal(4, 2), v14, options); // and back - CheckCase( - decimal(4, 2), v14, is_valid2, decimal(5, 2), v14, options); + CheckCase(decimal(4, 2), v14, is_valid2, + decimal(5, 2), v14, options); } auto check_truncate = [this](const std::shared_ptr& input_type, @@ -760,8 +808,8 @@ TEST_F(TestCast, DecimalToDecimal) { CastOptions options; options.allow_decimal_truncate = true; - CheckCase( - input_type, input, is_valid, output_type, expected_output, options); + CheckCase(input_type, input, is_valid, output_type, + expected_output, options); options.allow_decimal_truncate = false; CheckFails(input_type, input, is_valid, output_type, options); }; @@ -775,19 +823,19 @@ TEST_F(TestCast, DecimalToDecimal) { CastOptions options; options.allow_decimal_truncate = true; - CheckCase( - input_type, input, is_valid, output_type, expected_output, options); + CheckCase(input_type, input, is_valid, + output_type, expected_output, options); // and back - CheckCase( - output_type, expected_output, is_valid, input_type, expected_back_convert, - options); + CheckCase(output_type, expected_output, is_valid, + input_type, expected_back_convert, + options); options.allow_decimal_truncate = false; CheckFails(input_type, input, is_valid, output_type, options); // back case is valid - CheckCase( - output_type, expected_output, is_valid, input_type, expected_back_convert, - options); + CheckCase(output_type, expected_output, is_valid, + input_type, expected_back_convert, + options); }; // Rescale leads to truncation @@ -830,14 +878,16 @@ TEST_F(TestCast, DecimalToDecimal) { TEST_F(TestCast, TimestampToTimestamp) { CastOptions options; - auto CheckTimestampCast = - [this](const CastOptions& options, 
TimeUnit::type from_unit, TimeUnit::type to_unit, - const std::vector& from_values, - const std::vector& to_values, const std::vector& is_valid) { - CheckCase( - timestamp(from_unit), from_values, is_valid, timestamp(to_unit), to_values, - options); - }; + auto CheckTimestampCast = [this](const CastOptions& options, TimeUnit::type from_unit, + TimeUnit::type to_unit, + const std::vector& from_values, + const std::vector& to_values, + const std::vector& is_valid) { + // ARROW-9196: make temporal casts work with scalars + CheckCase(timestamp(from_unit), from_values, is_valid, + timestamp(to_unit), to_values, options, + /*check_scalar=*/false); + }; std::vector is_valid = {true, false, true, true, true}; @@ -869,8 +919,7 @@ TEST_F(TestCast, TimestampToTimestamp) { // Zero copy std::vector v7 = {0, 70000, 2000, 1000, 0}; std::shared_ptr arr; - ArrayFromVector(timestamp(TimeUnit::SECOND), is_valid, v7, - &arr); + ArrayFromVector(timestamp(TimeUnit::SECOND), is_valid, v7, &arr); CheckZeroCopy(*arr, timestamp(TimeUnit::SECOND)); // ARROW-1773, cast to integer @@ -897,17 +946,23 @@ TEST_F(TestCast, TimestampToTimestamp) { // Disallow truncate, failures options.allow_time_truncate = false; CheckFails(timestamp(TimeUnit::MILLI), v8, is_valid, - timestamp(TimeUnit::SECOND), options); + timestamp(TimeUnit::SECOND), options, + /*check_scalar=*/false); CheckFails(timestamp(TimeUnit::MICRO), v8, is_valid, - timestamp(TimeUnit::MILLI), options); + timestamp(TimeUnit::MILLI), options, + /*check_scalar=*/false); CheckFails(timestamp(TimeUnit::NANO), v8, is_valid, - timestamp(TimeUnit::MICRO), options); + timestamp(TimeUnit::MICRO), options, + /*check_scalar=*/false); CheckFails(timestamp(TimeUnit::MICRO), v9, is_valid, - timestamp(TimeUnit::SECOND), options); + timestamp(TimeUnit::SECOND), options, + /*check_scalar=*/false); CheckFails(timestamp(TimeUnit::NANO), v9, is_valid, - timestamp(TimeUnit::MILLI), options); + timestamp(TimeUnit::MILLI), options, + /*check_scalar=*/false); 
CheckFails(timestamp(TimeUnit::NANO), v10, is_valid, - timestamp(TimeUnit::SECOND), options); + timestamp(TimeUnit::SECOND), options, + /*check_scalar=*/false); // Multiply overflow @@ -917,7 +972,8 @@ TEST_F(TestCast, TimestampToTimestamp) { options.allow_time_overflow = false; CheckFails(timestamp(TimeUnit::SECOND), v11, is_valid, - timestamp(TimeUnit::NANO), options); + timestamp(TimeUnit::NANO), options, + /*check_scalar=*/false); } TEST_F(TestCast, TimestampToDate32_Date64) { @@ -933,23 +989,31 @@ TEST_F(TestCast, TimestampToDate32_Date64) { std::vector v_day = {10957, 10958, 0}; // Simple conversions - CheckCase( - timestamp(TimeUnit::NANO), v_nano, is_valid, date64(), v_milli, options); - CheckCase( - timestamp(TimeUnit::MICRO), v_micro, is_valid, date64(), v_milli, options); - CheckCase( - timestamp(TimeUnit::MILLI), v_milli, is_valid, date64(), v_milli, options); - CheckCase( - timestamp(TimeUnit::SECOND), v_second, is_valid, date64(), v_milli, options); - - CheckCase( - timestamp(TimeUnit::NANO), v_nano, is_valid, date32(), v_day, options); - CheckCase( - timestamp(TimeUnit::MICRO), v_micro, is_valid, date32(), v_day, options); - CheckCase( - timestamp(TimeUnit::MILLI), v_milli, is_valid, date32(), v_day, options); - CheckCase( - timestamp(TimeUnit::SECOND), v_second, is_valid, date32(), v_day, options); + CheckCase(timestamp(TimeUnit::NANO), v_nano, is_valid, + date64(), v_milli, options, + /*check_scalar=*/false); + CheckCase(timestamp(TimeUnit::MICRO), v_micro, is_valid, + date64(), v_milli, options, + /*check_scalar=*/false); + CheckCase(timestamp(TimeUnit::MILLI), v_milli, is_valid, + date64(), v_milli, options, + /*check_scalar=*/false); + CheckCase(timestamp(TimeUnit::SECOND), v_second, is_valid, + date64(), v_milli, options, + /*check_scalar=*/false); + + CheckCase(timestamp(TimeUnit::NANO), v_nano, is_valid, + date32(), v_day, options, + /*check_scalar=*/false); + CheckCase(timestamp(TimeUnit::MICRO), v_micro, is_valid, + date32(), v_day, 
options, + /*check_scalar=*/false); + CheckCase(timestamp(TimeUnit::MILLI), v_milli, is_valid, + date32(), v_day, options, + /*check_scalar=*/false); + CheckCase(timestamp(TimeUnit::SECOND), v_second, is_valid, + date32(), v_day, options, + /*check_scalar=*/false); // Disallow truncate, failures std::vector v_nano_fail = {946684800000000001, 946771200000000001, 0}; @@ -959,29 +1023,39 @@ TEST_F(TestCast, TimestampToDate32_Date64) { options.allow_time_truncate = false; CheckFails(timestamp(TimeUnit::NANO), v_nano_fail, is_valid, date64(), - options); + options, + /*check_scalar=*/false); CheckFails(timestamp(TimeUnit::MICRO), v_micro_fail, is_valid, date64(), - options); + options, + /*check_scalar=*/false); CheckFails(timestamp(TimeUnit::MILLI), v_milli_fail, is_valid, date64(), - options); + options, + /*check_scalar=*/false); CheckFails(timestamp(TimeUnit::SECOND), v_second_fail, is_valid, - date64(), options); + date64(), options, + /*check_scalar=*/false); CheckFails(timestamp(TimeUnit::NANO), v_nano_fail, is_valid, date32(), - options); + options, + /*check_scalar=*/false); CheckFails(timestamp(TimeUnit::MICRO), v_micro_fail, is_valid, date32(), - options); + options, + /*check_scalar=*/false); CheckFails(timestamp(TimeUnit::MILLI), v_milli_fail, is_valid, date32(), - options); + options, + /*check_scalar=*/false); CheckFails(timestamp(TimeUnit::SECOND), v_second_fail, is_valid, - date32(), options); + date32(), options, + /*check_scalar=*/false); // Make sure that nulls are excluded from the truncation checks std::vector v_second_nofail = {946684800, 946771200, 1}; - CheckCase( - timestamp(TimeUnit::SECOND), v_second_nofail, is_valid, date64(), v_milli, options); - CheckCase( - timestamp(TimeUnit::SECOND), v_second_nofail, is_valid, date32(), v_day, options); + CheckCase(timestamp(TimeUnit::SECOND), v_second_nofail, + is_valid, date64(), v_milli, options, + /*check_scalar=*/false); + CheckCase(timestamp(TimeUnit::SECOND), v_second_nofail, + is_valid, 
date32(), v_day, options, + /*check_scalar=*/false); } TEST_F(TestCast, TimeToCompatible) { @@ -992,45 +1066,51 @@ TEST_F(TestCast, TimeToCompatible) { // Multiply promotions std::vector v1 = {0, 100, 200, 1, 2}; std::vector e1 = {0, 100000, 200000, 1000, 2000}; - CheckCase( - time32(TimeUnit::SECOND), v1, is_valid, time32(TimeUnit::MILLI), e1, options); + CheckCase(time32(TimeUnit::SECOND), v1, is_valid, + time32(TimeUnit::MILLI), e1, options, + /*check_scalar=*/false); std::vector v2 = {0, 100, 200, 1, 2}; std::vector e2 = {0, 100000000L, 200000000L, 1000000, 2000000}; - CheckCase( - time32(TimeUnit::SECOND), v2, is_valid, time64(TimeUnit::MICRO), e2, options); + CheckCase(time32(TimeUnit::SECOND), v2, is_valid, + time64(TimeUnit::MICRO), e2, options, + /*check_scalar=*/false); std::vector v3 = {0, 100, 200, 1, 2}; std::vector e3 = {0, 100000000000L, 200000000000L, 1000000000L, 2000000000L}; - CheckCase( - time32(TimeUnit::SECOND), v3, is_valid, time64(TimeUnit::NANO), e3, options); + CheckCase(time32(TimeUnit::SECOND), v3, is_valid, + time64(TimeUnit::NANO), e3, options, + /*check_scalar=*/false); std::vector v4 = {0, 100, 200, 1, 2}; std::vector e4 = {0, 100000, 200000, 1000, 2000}; - CheckCase( - time32(TimeUnit::MILLI), v4, is_valid, time64(TimeUnit::MICRO), e4, options); + CheckCase(time32(TimeUnit::MILLI), v4, is_valid, + time64(TimeUnit::MICRO), e4, options, + /*check_scalar=*/false); std::vector v5 = {0, 100, 200, 1, 2}; std::vector e5 = {0, 100000000L, 200000000L, 1000000, 2000000}; - CheckCase( - time32(TimeUnit::MILLI), v5, is_valid, time64(TimeUnit::NANO), e5, options); + CheckCase(time32(TimeUnit::MILLI), v5, is_valid, + time64(TimeUnit::NANO), e5, options, + /*check_scalar=*/false); std::vector v6 = {0, 100, 200, 1, 2}; std::vector e6 = {0, 100000, 200000, 1000, 2000}; - CheckCase( - time64(TimeUnit::MICRO), v6, is_valid, time64(TimeUnit::NANO), e6, options); + CheckCase(time64(TimeUnit::MICRO), v6, is_valid, + time64(TimeUnit::NANO), e6, options, + 
/*check_scalar=*/false); // Zero copy std::vector v7 = {0, 70000, 2000, 1000, 0}; std::shared_ptr arr; - ArrayFromVector(time64(TimeUnit::MICRO), is_valid, v7, &arr); + ArrayFromVector(time64(TimeUnit::MICRO), is_valid, v7, &arr); CheckZeroCopy(*arr, time64(TimeUnit::MICRO)); // ARROW-1773: cast to int64 CheckZeroCopy(*arr, int64()); std::vector v7_2 = {0, 70000, 2000, 1000, 0}; - ArrayFromVector(time32(TimeUnit::SECOND), is_valid, v7_2, &arr); + ArrayFromVector(time32(TimeUnit::SECOND), is_valid, v7_2, &arr); CheckZeroCopy(*arr, time32(TimeUnit::SECOND)); // ARROW-1773: cast to int64 @@ -1041,40 +1121,46 @@ TEST_F(TestCast, TimeToCompatible) { std::vector e8 = {0, 100, 200, 1, 2}; options.allow_time_truncate = true; - CheckCase( - time32(TimeUnit::MILLI), v8, is_valid, time32(TimeUnit::SECOND), e8, options); - CheckCase( - time64(TimeUnit::MICRO), v8, is_valid, time32(TimeUnit::MILLI), e8, options); - CheckCase( - time64(TimeUnit::NANO), v8, is_valid, time64(TimeUnit::MICRO), e8, options); + CheckCase(time32(TimeUnit::MILLI), v8, is_valid, + time32(TimeUnit::SECOND), e8, options, + /*check_scalar=*/false); + CheckCase(time64(TimeUnit::MICRO), v8, is_valid, + time32(TimeUnit::MILLI), e8, options, + /*check_scalar=*/false); + CheckCase(time64(TimeUnit::NANO), v8, is_valid, + time64(TimeUnit::MICRO), e8, options, + /*check_scalar=*/false); std::vector v9 = {0, 100123000, 200456000, 1123000, 2456000}; std::vector e9 = {0, 100, 200, 1, 2}; - CheckCase( - time64(TimeUnit::MICRO), v9, is_valid, time32(TimeUnit::SECOND), e9, options); - CheckCase( - time64(TimeUnit::NANO), v9, is_valid, time32(TimeUnit::MILLI), e9, options); + CheckCase(time64(TimeUnit::MICRO), v9, is_valid, + time32(TimeUnit::SECOND), e9, options, + /*check_scalar=*/false); + CheckCase(time64(TimeUnit::NANO), v9, is_valid, + time32(TimeUnit::MILLI), e9, options, + /*check_scalar=*/false); std::vector v10 = {0, 100123000000L, 200456000000L, 1123000000L, 2456000000}; std::vector e10 = {0, 100, 200, 1, 2}; 
- CheckCase( - time64(TimeUnit::NANO), v10, is_valid, time32(TimeUnit::SECOND), e10, options); + CheckCase(time64(TimeUnit::NANO), v10, is_valid, + time32(TimeUnit::SECOND), e10, options, + /*check_scalar=*/false); // Disallow truncate, failures options.allow_time_truncate = false; CheckFails(time32(TimeUnit::MILLI), v8, is_valid, time32(TimeUnit::SECOND), - options); + options, /*check_scalar=*/false); CheckFails(time64(TimeUnit::MICRO), v8, is_valid, time32(TimeUnit::MILLI), - options); + options, /*check_scalar=*/false); CheckFails(time64(TimeUnit::NANO), v8, is_valid, time64(TimeUnit::MICRO), - options); + options, /*check_scalar=*/false); CheckFails(time64(TimeUnit::MICRO), v9, is_valid, time32(TimeUnit::SECOND), - options); + options, /*check_scalar=*/false); CheckFails(time64(TimeUnit::NANO), v9, is_valid, time32(TimeUnit::MILLI), - options); + options, /*check_scalar=*/false); CheckFails(time64(TimeUnit::NANO), v10, is_valid, time32(TimeUnit::SECOND), - options); + options, /*check_scalar=*/false); } TEST_F(TestCast, DateToCompatible) { @@ -1087,20 +1173,20 @@ TEST_F(TestCast, DateToCompatible) { // Multiply promotion std::vector v1 = {0, 100, 200, 1, 2}; std::vector e1 = {0, 100 * F, 200 * F, F, 2 * F}; - CheckCase(date32(), v1, is_valid, date64(), - e1, options); + CheckCase(date32(), v1, is_valid, date64(), e1, options, + /*check_scalar=*/false); // Zero copy std::vector v2 = {0, 70000, 2000, 1000, 0}; std::vector v3 = {0, 70000, 2000, 1000, 0}; std::shared_ptr arr; - ArrayFromVector(date32(), is_valid, v2, &arr); + ArrayFromVector(date32(), is_valid, v2, &arr); CheckZeroCopy(*arr, date32()); // ARROW-1773: zero copy cast to integer CheckZeroCopy(*arr, int32()); - ArrayFromVector(date64(), is_valid, v3, &arr); + ArrayFromVector(date64(), is_valid, v3, &arr); CheckZeroCopy(*arr, date64()); // ARROW-1773: zero copy cast to integer @@ -1111,12 +1197,12 @@ TEST_F(TestCast, DateToCompatible) { std::vector e8 = {0, 100, 200, 1, 2}; options.allow_time_truncate 
= true; - CheckCase(date64(), v8, is_valid, date32(), - e8, options); + CheckCase(date64(), v8, is_valid, date32(), e8, options, + /*check_scalar=*/false); // Disallow truncate, failures options.allow_time_truncate = false; - CheckFails(date64(), v8, is_valid, date32(), options); + CheckFails(v8, is_valid, date32(), options, /*check_scalar=*/false); } TEST_F(TestCast, DurationToCompatible) { @@ -1126,9 +1212,9 @@ TEST_F(TestCast, DurationToCompatible) { [this](const CastOptions& options, TimeUnit::type from_unit, TimeUnit::type to_unit, const std::vector& from_values, const std::vector& to_values, const std::vector& is_valid) { - CheckCase( - duration(from_unit), from_values, is_valid, duration(to_unit), to_values, - options); + CheckCase(duration(from_unit), from_values, is_valid, + duration(to_unit), to_values, options, + /*check_scalar=*/false); }; std::vector is_valid = {true, false, true, true, true}; @@ -1161,7 +1247,7 @@ TEST_F(TestCast, DurationToCompatible) { // Zero copy std::vector v7 = {0, 70000, 2000, 1000, 0}; std::shared_ptr arr; - ArrayFromVector(duration(TimeUnit::SECOND), is_valid, v7, &arr); + ArrayFromVector(duration(TimeUnit::SECOND), is_valid, v7, &arr); CheckZeroCopy(*arr, duration(TimeUnit::SECOND)); CheckZeroCopy(*arr, int64()); @@ -1186,17 +1272,17 @@ TEST_F(TestCast, DurationToCompatible) { // Disallow truncate, failures options.allow_time_truncate = false; CheckFails(duration(TimeUnit::MILLI), v8, is_valid, - duration(TimeUnit::SECOND), options); + duration(TimeUnit::SECOND), options, /*check_scalar=*/false); CheckFails(duration(TimeUnit::MICRO), v8, is_valid, - duration(TimeUnit::MILLI), options); + duration(TimeUnit::MILLI), options, /*check_scalar=*/false); CheckFails(duration(TimeUnit::NANO), v8, is_valid, - duration(TimeUnit::MICRO), options); + duration(TimeUnit::MICRO), options, /*check_scalar=*/false); CheckFails(duration(TimeUnit::MICRO), v9, is_valid, - duration(TimeUnit::SECOND), options); + duration(TimeUnit::SECOND), 
options, /*check_scalar=*/false); CheckFails(duration(TimeUnit::NANO), v9, is_valid, - duration(TimeUnit::MILLI), options); + duration(TimeUnit::MILLI), options, /*check_scalar=*/false); CheckFails(duration(TimeUnit::NANO), v10, is_valid, - duration(TimeUnit::SECOND), options); + duration(TimeUnit::SECOND), options, /*check_scalar=*/false); // Multiply overflow @@ -1205,7 +1291,7 @@ TEST_F(TestCast, DurationToCompatible) { options.allow_time_overflow = false; CheckFails(duration(TimeUnit::SECOND), v11, is_valid, - duration(TimeUnit::NANO), options); + duration(TimeUnit::NANO), options, /*check_scalar=*/false); } TEST_F(TestCast, ToDouble) { @@ -1215,20 +1301,17 @@ TEST_F(TestCast, ToDouble) { // int16 to double std::vector v1 = {0, 100, 200, 1, 2}; std::vector e1 = {0, 100, 200, 1, 2}; - CheckCase(int16(), v1, is_valid, float64(), e1, - options); + CheckCase(v1, is_valid, e1, options); // float to double std::vector v2 = {0, 100, 200, 1, 2}; std::vector e2 = {0, 100, 200, 1, 2}; - CheckCase(float32(), v2, is_valid, float64(), e2, - options); + CheckCase(v2, is_valid, e2, options); // bool to double std::vector v3 = {true, true, false, false, true}; std::vector e3 = {1, 1, 0, 0, 1}; - CheckCase(boolean(), v3, is_valid, float64(), e3, - options); + CheckCase(v3, is_valid, e3, options); } TEST_F(TestCast, ChunkedArray) { @@ -1267,7 +1350,7 @@ TEST_F(TestCast, UnsupportedTarget) { std::vector v1 = {0, 1, 2, 3, 4}; std::shared_ptr arr; - ArrayFromVector(int32(), is_valid, v1, &arr); + ArrayFromVector(int32(), is_valid, v1, &arr); ASSERT_RAISES(NotImplemented, Cast(*arr, list(utf8()))); } @@ -1277,13 +1360,13 @@ TEST_F(TestCast, DateTimeZeroCopy) { std::vector v1 = {0, 70000, 2000, 1000, 0}; std::shared_ptr arr; - ArrayFromVector(int32(), is_valid, v1, &arr); + ArrayFromVector(int32(), is_valid, v1, &arr); CheckZeroCopy(*arr, time32(TimeUnit::SECOND)); CheckZeroCopy(*arr, date32()); std::vector v2 = {0, 70000, 2000, 1000, 0}; - ArrayFromVector(int64(), is_valid, v2, 
&arr); + ArrayFromVector(int64(), is_valid, v2, &arr); CheckZeroCopy(*arr, time64(TimeUnit::MICRO)); CheckZeroCopy(*arr, date64()); @@ -1299,14 +1382,13 @@ TEST_F(TestCast, StringToBoolean) { std::vector v1 = {"False", "true", "true", "True", "false"}; std::vector v2 = {"0", "1", "1", "1", "0"}; std::vector e = {false, true, true, true, false}; - CheckCase(utf8(), v1, is_valid, boolean(), - e, options); - CheckCase(utf8(), v2, is_valid, boolean(), - e, options); + CheckCase(utf8(), v1, is_valid, boolean(), e, + options); + CheckCase(utf8(), v2, is_valid, boolean(), e, + options); // Same with LargeStringType - CheckCase(large_utf8(), v1, is_valid, - boolean(), e, options); + CheckCase(v1, is_valid, e, options); } TEST_F(TestCast, StringToBooleanErrors) { @@ -1314,10 +1396,9 @@ TEST_F(TestCast, StringToBooleanErrors) { std::vector is_valid = {true}; - CheckFails(utf8(), {"false "}, is_valid, boolean(), options); - CheckFails(utf8(), {"T"}, is_valid, boolean(), options); - CheckFails(large_utf8(), {"T"}, is_valid, boolean(), - options); + CheckFails({"false "}, is_valid, boolean(), options); + CheckFails({"T"}, is_valid, boolean(), options); + CheckFails({"T"}, is_valid, boolean(), options); } TEST_F(TestCast, StringToNumber) { TestCastStringToNumber(); } @@ -1329,16 +1410,16 @@ TEST_F(TestCast, StringToNumberErrors) { std::vector is_valid = {true}; - CheckFails(utf8(), {"z"}, is_valid, int8(), options); - CheckFails(utf8(), {"12 z"}, is_valid, int8(), options); - CheckFails(utf8(), {"128"}, is_valid, int8(), options); - CheckFails(utf8(), {"-129"}, is_valid, int8(), options); - CheckFails(utf8(), {"0.5"}, is_valid, int8(), options); + CheckFails({"z"}, is_valid, int8(), options); + CheckFails({"12 z"}, is_valid, int8(), options); + CheckFails({"128"}, is_valid, int8(), options); + CheckFails({"-129"}, is_valid, int8(), options); + CheckFails({"0.5"}, is_valid, int8(), options); - CheckFails(utf8(), {"256"}, is_valid, uint8(), options); - CheckFails(utf8(), {"-1"}, 
is_valid, uint8(), options); + CheckFails({"256"}, is_valid, uint8(), options); + CheckFails({"-1"}, is_valid, uint8(), options); - CheckFails(utf8(), {"z"}, is_valid, float32(), options); + CheckFails({"z"}, is_valid, float32(), options); } TEST_F(TestCast, StringToTimestamp) { TestCastStringToTimestamp(); } @@ -1352,8 +1433,8 @@ TEST_F(TestCast, StringToTimestampErrors) { for (auto unit : {TimeUnit::SECOND, TimeUnit::MILLI, TimeUnit::MICRO, TimeUnit::NANO}) { auto type = timestamp(unit); - CheckFails(utf8(), {""}, is_valid, type, options); - CheckFails(utf8(), {"xxx"}, is_valid, type, options); + CheckFails({""}, is_valid, type, options); + CheckFails({"xxx"}, is_valid, type, options); } } @@ -1385,7 +1466,7 @@ TEST_F(TestCast, ListToList) { std::vector offsets_values = {0, 1, 2, 5, 7, 7, 8, 10}; std::vector offsets_is_valid = {true, true, true, true, false, true, true, true}; - ArrayFromVector(offsets_is_valid, offsets_values, &offsets); + ArrayFromVector(offsets_is_valid, offsets_values, &offsets); std::shared_ptr int32_plain_array = TestBase::MakeRandomArray::ArrayType>(10, 2); @@ -1402,14 +1483,20 @@ TEST_F(TestCast, ListToList) { ASSERT_OK_AND_ASSIGN(auto float64_list_array, ListArray::FromArrays(*offsets, *float64_plain_array, pool_)); - CheckPass(*int32_list_array, *int64_list_array, int64_list_array->type(), options); - CheckPass(*int32_list_array, *float64_list_array, float64_list_array->type(), options); - CheckPass(*int64_list_array, *int32_list_array, int32_list_array->type(), options); - CheckPass(*int64_list_array, *float64_list_array, float64_list_array->type(), options); + CheckPass(*int32_list_array, *int64_list_array, int64_list_array->type(), options, + /*check_scalar=*/false); + CheckPass(*int32_list_array, *float64_list_array, float64_list_array->type(), options, + /*check_scalar=*/false); + CheckPass(*int64_list_array, *int32_list_array, int32_list_array->type(), options, + /*check_scalar=*/false); + CheckPass(*int64_list_array, 
*float64_list_array, float64_list_array->type(), options, + /*check_scalar=*/false); options.allow_float_truncate = true; - CheckPass(*float64_list_array, *int32_list_array, int32_list_array->type(), options); - CheckPass(*float64_list_array, *int64_list_array, int64_list_array->type(), options); + CheckPass(*float64_list_array, *int32_list_array, int32_list_array->type(), options, + /*check_scalar=*/false); + CheckPass(*float64_list_array, *int64_list_array, int64_list_array->type(), options, + /*check_scalar=*/false); } TEST_F(TestCast, LargeListToLargeList) { @@ -1419,7 +1506,7 @@ TEST_F(TestCast, LargeListToLargeList) { std::vector offsets_values = {0, 1, 2, 5, 7, 7, 8, 10}; std::vector offsets_is_valid = {true, true, true, true, false, true, true, true}; - ArrayFromVector(offsets_is_valid, offsets_values, &offsets); + ArrayFromVector(offsets_is_valid, offsets_values, &offsets); std::shared_ptr int32_plain_array = TestBase::MakeRandomArray::ArrayType>(10, 2); @@ -1431,10 +1518,12 @@ TEST_F(TestCast, LargeListToLargeList) { ASSERT_OK_AND_ASSIGN(auto float64_list_array, LargeListArray::FromArrays(*offsets, *float64_plain_array, pool_)); - CheckPass(*int32_list_array, *float64_list_array, float64_list_array->type(), options); + CheckPass(*int32_list_array, *float64_list_array, float64_list_array->type(), options, + /*check_scalar=*/false); options.allow_float_truncate = true; - CheckPass(*float64_list_array, *int32_list_array, int32_list_array->type(), options); + CheckPass(*float64_list_array, *int32_list_array, int32_list_array->type(), options, + /*check_scalar=*/false); } TEST_F(TestCast, IdentityCasts) { @@ -1544,8 +1633,9 @@ TYPED_TEST(TestDictionaryCast, Basic) { ASSERT_OK_AND_ASSIGN(Datum encoded, DictionaryEncode(plain_array->data())); ASSERT_EQ(encoded.array()->type->id(), Type::DICTIONARY); - this->CheckPass(*MakeArray(encoded.array()), *plain_array, plain_array->type(), - options); + // TODO: Should casting dictionary scalars work? 
+ this->CheckPass(*MakeArray(encoded.array()), *plain_array, plain_array->type(), options, + /*check_scalar=*/false); } TYPED_TEST(TestDictionaryCast, NoNulls) { @@ -1570,7 +1660,8 @@ TYPED_TEST(TestDictionaryCast, NoNulls) { std::shared_ptr dict_array = std::make_shared(data); ASSERT_OK(dict_array->ValidateFull()); - this->CheckPass(*dict_array, *plain_array, plain_array->type(), options); + this->CheckPass(*dict_array, *plain_array, plain_array->type(), options, + /*check_scalar=*/false); } // TODO: See how this might cause problems post-refactor @@ -1610,14 +1701,14 @@ TEST_F(TestCast, ExtensionTypeToIntDowncast) { // Smallint(int16) to uint8, no overflow/underrun auto v1 = SmallintArrayFromJSON("[0, 100, 200, 1, 2]"); auto e1 = ArrayFromJSON(uint8(), "[0, 100, 200, 1, 2]"); - CheckPass(*v1, *e1, uint8(), options); + CheckPass(*v1, *e1, uint8(), options, /*check_scalar=*/false); // Smallint(int16) to uint8, with overflow auto v2 = SmallintArrayFromJSON("[0, null, 256, 1, 3]"); auto e2 = ArrayFromJSON(uint8(), "[0, null, 0, 1, 3]"); // allow overflow options.allow_int_overflow = true; - CheckPass(*v2, *e2, uint8(), options); + CheckPass(*v2, *e2, uint8(), options, /*check_scalar=*/false); // disallow overflow options.allow_int_overflow = false; ASSERT_RAISES(Invalid, Cast(*v2, uint8(), options)); @@ -1627,7 +1718,7 @@ TEST_F(TestCast, ExtensionTypeToIntDowncast) { auto e3 = ArrayFromJSON(uint8(), "[0, null, 255, 1, 0]"); // allow overflow options.allow_int_overflow = true; - CheckPass(*v3, *e3, uint8(), options); + CheckPass(*v3, *e3, uint8(), options, /*check_scalar=*/false); // disallow overflow options.allow_int_overflow = false; ASSERT_RAISES(Invalid, Cast(*v3, uint8(), options)); diff --git a/cpp/src/arrow/compute/kernels/scalar_compare_benchmark.cc b/cpp/src/arrow/compute/kernels/scalar_compare_benchmark.cc index 8af575f5874..ce18365fb5d 100644 --- a/cpp/src/arrow/compute/kernels/scalar_compare_benchmark.cc +++ 
b/cpp/src/arrow/compute/kernels/scalar_compare_benchmark.cc @@ -20,10 +20,10 @@ #include #include "arrow/compute/api_scalar.h" -#include "arrow/compute/benchmark_util.h" #include "arrow/compute/kernels/test_util.h" #include "arrow/testing/gtest_util.h" #include "arrow/testing/random.h" +#include "arrow/util/benchmark_util.h" namespace arrow { namespace compute { diff --git a/cpp/src/arrow/compute/kernels/scalar_string_benchmark.cc b/cpp/src/arrow/compute/kernels/scalar_string_benchmark.cc index e561da0bde1..2bee05e21d4 100644 --- a/cpp/src/arrow/compute/kernels/scalar_string_benchmark.cc +++ b/cpp/src/arrow/compute/kernels/scalar_string_benchmark.cc @@ -18,10 +18,10 @@ #include "benchmark/benchmark.h" #include "arrow/compute/api_scalar.h" -#include "arrow/compute/benchmark_util.h" #include "arrow/compute/kernels/test_util.h" #include "arrow/testing/gtest_util.h" #include "arrow/testing/random.h" +#include "arrow/util/benchmark_util.h" namespace arrow { namespace compute { diff --git a/cpp/src/arrow/compute/kernels/vector_partition_benchmark.cc b/cpp/src/arrow/compute/kernels/vector_partition_benchmark.cc index e76b27146f2..eabac6bacb7 100644 --- a/cpp/src/arrow/compute/kernels/vector_partition_benchmark.cc +++ b/cpp/src/arrow/compute/kernels/vector_partition_benchmark.cc @@ -18,10 +18,10 @@ #include "benchmark/benchmark.h" #include "arrow/compute/api_vector.h" -#include "arrow/compute/benchmark_util.h" #include "arrow/compute/kernels/test_util.h" #include "arrow/testing/gtest_util.h" #include "arrow/testing/random.h" +#include "arrow/util/benchmark_util.h" namespace arrow { namespace compute { diff --git a/cpp/src/arrow/compute/kernels/vector_selection.cc b/cpp/src/arrow/compute/kernels/vector_selection.cc index cc141681afa..653e379ef02 100644 --- a/cpp/src/arrow/compute/kernels/vector_selection.cc +++ b/cpp/src/arrow/compute/kernels/vector_selection.cc @@ -46,10 +46,10 @@ using internal::BinaryBitBlockCounter; using internal::BitBlockCount; using 
internal::BitBlockCounter; using internal::BitmapReader; +using internal::CheckIndexBounds; using internal::CopyBitmap; using internal::CountSetBits; using internal::GetArrayView; -using internal::IndexBoundsCheck; using internal::OptionalBitBlockCounter; using internal::OptionalBitIndexer; @@ -487,7 +487,7 @@ void TakeIndexDispatch(const PrimitiveArg& values, const PrimitiveArg& indices, void PrimitiveTake(KernelContext* ctx, const ExecBatch& batch, Datum* out) { const auto& state = checked_cast(*ctx->state()); if (state.options.boundscheck) { - KERNEL_RETURN_IF_ERROR(ctx, IndexBoundsCheck(*batch[1].array(), batch[0].length())); + KERNEL_RETURN_IF_ERROR(ctx, CheckIndexBounds(*batch[1].array(), batch[0].length())); } PrimitiveArg values = GetPrimitiveArg(*batch[0].array()); @@ -849,7 +849,7 @@ void PrimitiveFilter(KernelContext* ctx, const ExecBatch& batch, Datum* out) { void NullTake(KernelContext* ctx, const ExecBatch& batch, Datum* out) { const auto& state = checked_cast(*ctx->state()); if (state.options.boundscheck) { - KERNEL_RETURN_IF_ERROR(ctx, IndexBoundsCheck(*batch[1].array(), batch[0].length())); + KERNEL_RETURN_IF_ERROR(ctx, CheckIndexBounds(*batch[1].array(), batch[0].length())); } out->value = std::make_shared(batch.length)->data(); } @@ -1737,7 +1737,7 @@ template void TakeExec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { const auto& state = checked_cast(*ctx->state()); if (state.options.boundscheck) { - KERNEL_RETURN_IF_ERROR(ctx, IndexBoundsCheck(*batch[1].array(), batch[0].length())); + KERNEL_RETURN_IF_ERROR(ctx, CheckIndexBounds(*batch[1].array(), batch[0].length())); } Impl kernel(ctx, batch, /*output_length=*/batch[1].length(), out); KERNEL_RETURN_IF_ERROR(ctx, kernel.ExecTake()); diff --git a/cpp/src/arrow/compute/kernels/vector_selection_benchmark.cc b/cpp/src/arrow/compute/kernels/vector_selection_benchmark.cc index 422088c09d3..c595736912d 100644 --- a/cpp/src/arrow/compute/kernels/vector_selection_benchmark.cc +++ 
b/cpp/src/arrow/compute/kernels/vector_selection_benchmark.cc @@ -21,10 +21,10 @@ #include #include "arrow/compute/api_vector.h" -#include "arrow/compute/benchmark_util.h" #include "arrow/compute/kernels/test_util.h" #include "arrow/testing/gtest_util.h" #include "arrow/testing/random.h" +#include "arrow/util/benchmark_util.h" namespace arrow { namespace compute { diff --git a/cpp/src/arrow/compute/kernels/vector_sort_benchmark.cc b/cpp/src/arrow/compute/kernels/vector_sort_benchmark.cc index 344de258ccd..663d003c29b 100644 --- a/cpp/src/arrow/compute/kernels/vector_sort_benchmark.cc +++ b/cpp/src/arrow/compute/kernels/vector_sort_benchmark.cc @@ -18,10 +18,10 @@ #include "benchmark/benchmark.h" #include "arrow/compute/api_vector.h" -#include "arrow/compute/benchmark_util.h" #include "arrow/compute/kernels/test_util.h" #include "arrow/testing/gtest_util.h" #include "arrow/testing/random.h" +#include "arrow/util/benchmark_util.h" namespace arrow { namespace compute { diff --git a/cpp/src/arrow/python/arrow_to_pandas.cc b/cpp/src/arrow/python/arrow_to_pandas.cc index d884583f94b..4cab9dca178 100644 --- a/cpp/src/arrow/python/arrow_to_pandas.cc +++ b/cpp/src/arrow/python/arrow_to_pandas.cc @@ -65,7 +65,7 @@ namespace arrow { class MemoryPool; using internal::checked_cast; -using internal::IndexBoundsCheck; +using internal::CheckIndexBounds; using internal::OptionalParallelFor; // ---------------------------------------------------------------------- @@ -1444,7 +1444,7 @@ class CategoricalWriter const auto& indices = checked_cast(*arr.indices()); auto values = reinterpret_cast(indices.raw_values()); - RETURN_NOT_OK(IndexBoundsCheck(*indices.data(), arr.dictionary()->length())); + RETURN_NOT_OK(CheckIndexBounds(*indices.data(), arr.dictionary()->length())); // Null is -1 in CategoricalBlock for (int i = 0; i < arr.length(); ++i) { if (indices.IsValid(i)) { @@ -1478,7 +1478,7 @@ class CategoricalWriter auto transpose = reinterpret_cast(transpose_buffer->data()); int64_t 
dict_length = arr.dictionary()->length(); - RETURN_NOT_OK(IndexBoundsCheck(*indices.data(), dict_length)); + RETURN_NOT_OK(CheckIndexBounds(*indices.data(), dict_length)); // Null is -1 in CategoricalBlock for (int i = 0; i < arr.length(); ++i) { @@ -1503,7 +1503,7 @@ class CategoricalWriter if (data.num_chunks() == 1 && indices_first->null_count() == 0) { RETURN_NOT_OK( - IndexBoundsCheck(*indices_first->data(), arr_first.dictionary()->length())); + CheckIndexBounds(*indices_first->data(), arr_first.dictionary()->length())); PyObject* wrapped; npy_intp dims[1] = {static_cast(this->num_rows_)}; diff --git a/cpp/src/arrow/compute/benchmark_util.h b/cpp/src/arrow/util/benchmark_util.h similarity index 83% rename from cpp/src/arrow/compute/benchmark_util.h rename to cpp/src/arrow/util/benchmark_util.h index edd2007c2b2..8379948bcbc 100644 --- a/cpp/src/arrow/compute/benchmark_util.h +++ b/cpp/src/arrow/util/benchmark_util.h @@ -15,20 +15,18 @@ // specific language governing permissions and limitations // under the License. -#pragma once - #include -#include +#include +#include + +#include "benchmark/benchmark.h" -#include "arrow/testing/gtest_util.h" #include "arrow/util/cpu_info.h" namespace arrow { using internal::CpuInfo; -namespace compute { - static CpuInfo* cpu_info = CpuInfo::GetInstance(); static const int64_t kL1Size = cpu_info->CacheSize(CpuInfo::L1_CACHE); @@ -54,6 +52,32 @@ struct BenchmarkArgsType::type; +struct GenericItemsArgs { + // number of items processed per iteration + const int64_t size; + + // proportion of nulls in generated arrays + double null_proportion; + + explicit GenericItemsArgs(benchmark::State& state) + : size(state.range(0)), state_(state) { + if (state.range(1) == 0) { + this->null_proportion = 0.0; + } else { + this->null_proportion = std::min(1., 1. 
/ static_cast(state.range(1))); + } + } + + ~GenericItemsArgs() { + state_.counters["size"] = static_cast(size); + state_.counters["null_percent"] = null_proportion * 100; + state_.SetItemsProcessed(state_.iterations() * size); + } + + private: + benchmark::State& state_; +}; + void BenchmarkSetArgsWithSizes(benchmark::internal::Benchmark* bench, const std::vector& sizes = kMemorySizes) { bench->Unit(benchmark::kMicrosecond); @@ -111,5 +135,4 @@ struct RegressionArgs { bool size_is_bytes_; }; -} // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/util/int_util.cc b/cpp/src/arrow/util/int_util.cc index d6e3c1e129b..c05e196c099 100644 --- a/cpp/src/arrow/util/int_util.cc +++ b/cpp/src/arrow/util/int_util.cc @@ -22,15 +22,21 @@ #include #include "arrow/array/data.h" +#include "arrow/datum.h" #include "arrow/type.h" +#include "arrow/type_traits.h" #include "arrow/util/bit_block_counter.h" #include "arrow/util/bit_util.h" +#include "arrow/util/checked_cast.h" #include "arrow/util/logging.h" #include "arrow/util/macros.h" +#include "arrow/util/ubsan.h" namespace arrow { namespace internal { +using internal::checked_cast; + static constexpr uint64_t max_uint8 = static_cast(std::numeric_limits::max()); static constexpr uint64_t max_uint16 = @@ -446,8 +452,13 @@ INSTANTIATE_ALL() #undef INSTANTIATE_ALL #undef INSTANTIATE_ALL_DEST +template +std::string FormatInt(T val) { + return std::to_string(val); +} + template ::value> -Status IndexBoundsCheckImpl(const ArrayData& indices, uint64_t upper_limit) { +Status CheckIndexBoundsImpl(const ArrayData& indices, uint64_t upper_limit) { // For unsigned integers, if the values array is larger than the maximum // index value (e.g. especially for UINT8 / UINT16), then there is no need to // boundscheck. 
@@ -461,43 +472,54 @@ Status IndexBoundsCheckImpl(const ArrayData& indices, uint64_t upper_limit) { if (indices.buffers[0]) { bitmap = indices.buffers[0]->data(); } - auto IsOutOfBounds = [&](int64_t i) -> bool { - return ( - (IsSigned && indices_data[i] < 0) || - (indices_data[i] >= 0 && static_cast(indices_data[i]) >= upper_limit)); + auto IsOutOfBounds = [&](IndexCType val) -> bool { + return ((IsSigned && val < 0) || + (val >= 0 && static_cast(val) >= upper_limit)); + }; + auto IsOutOfBoundsMaybeNull = [&](IndexCType val, bool is_valid) -> bool { + return is_valid && ((IsSigned && val < 0) || + (val >= 0 && static_cast(val) >= upper_limit)); }; OptionalBitBlockCounter indices_bit_counter(bitmap, indices.offset, indices.length); int64_t position = 0; + int64_t offset_position = indices.offset; while (position < indices.length) { BitBlockCount block = indices_bit_counter.NextBlock(); bool block_out_of_bounds = false; if (block.popcount == block.length) { // Fast path: branchless for (int64_t i = 0; i < block.length; ++i) { - block_out_of_bounds |= IsOutOfBounds(i); + block_out_of_bounds |= IsOutOfBounds(indices_data[i]); } } else if (block.popcount > 0) { // Indices have nulls, must only boundscheck non-null values - for (int64_t i = 0; i < block.length; ++i) { - if (BitUtil::GetBit(bitmap, indices.offset + position + i)) { - block_out_of_bounds |= IsOutOfBounds(i); + int64_t i = 0; + for (int64_t chunk = 0; chunk < block.length / 8; ++chunk) { + // Let the compiler unroll this + for (int j = 0; j < 8; ++j) { + block_out_of_bounds |= IsOutOfBoundsMaybeNull( + indices_data[i], BitUtil::GetBit(bitmap, offset_position + i)); + ++i; } } + for (; i < block.length; ++i) { + block_out_of_bounds |= IsOutOfBoundsMaybeNull( + indices_data[i], BitUtil::GetBit(bitmap, offset_position + i)); + } } if (ARROW_PREDICT_FALSE(block_out_of_bounds)) { if (indices.GetNullCount() > 0) { for (int64_t i = 0; i < block.length; ++i) { - if (BitUtil::GetBit(bitmap, indices.offset + 
position + i)) { - if (IsOutOfBounds(i)) { - return Status::IndexError("Index ", static_cast(indices_data[i]), - " out of bounds"); - } + if (IsOutOfBoundsMaybeNull(indices_data[i], + BitUtil::GetBit(bitmap, offset_position + i))) { + return Status::IndexError("Index ", FormatInt(indices_data[i]), + " out of bounds"); } } } else { for (int64_t i = 0; i < block.length; ++i) { - if (IsOutOfBounds(i)) { - return Status::IndexError("Index ", static_cast(indices_data[i]), + if (IsOutOfBounds(indices_data[i])) { + return Status::IndexError("Index ", FormatInt(indices_data[i]), " out of bounds"); } } @@ -505,6 +527,7 @@ Status IndexBoundsCheckImpl(const ArrayData& indices, uint64_t upper_limit) { } indices_data += block.length; position += block.length; + offset_position += block.length; } return Status::OK(); } @@ -512,28 +535,375 @@ Status IndexBoundsCheckImpl(const ArrayData& indices, uint64_t upper_limit) { /// \brief Branchless boundschecking of the indices. Processes batches of /// indices at a time and shortcircuits when encountering an out-of-bounds /// index in a batch -Status IndexBoundsCheck(const ArrayData& indices, uint64_t upper_limit) { +Status CheckIndexBounds(const ArrayData& indices, uint64_t upper_limit) { switch (indices.type->id()) { case Type::INT8: - return IndexBoundsCheckImpl(indices, upper_limit); + return CheckIndexBoundsImpl(indices, upper_limit); case Type::INT16: - return IndexBoundsCheckImpl(indices, upper_limit); + return CheckIndexBoundsImpl(indices, upper_limit); case Type::INT32: - return IndexBoundsCheckImpl(indices, upper_limit); + return CheckIndexBoundsImpl(indices, upper_limit); case Type::INT64: - return IndexBoundsCheckImpl(indices, upper_limit); + return CheckIndexBoundsImpl(indices, upper_limit); case Type::UINT8: - return IndexBoundsCheckImpl(indices, upper_limit); + return CheckIndexBoundsImpl(indices, upper_limit); case Type::UINT16: - return IndexBoundsCheckImpl(indices, upper_limit); + return CheckIndexBoundsImpl(indices, 
upper_limit); case Type::UINT32: - return IndexBoundsCheckImpl(indices, upper_limit); + return CheckIndexBoundsImpl(indices, upper_limit); case Type::UINT64: - return IndexBoundsCheckImpl(indices, upper_limit); + return CheckIndexBoundsImpl(indices, upper_limit); default: return Status::Invalid("Invalid index type for boundschecking"); } } +// ---------------------------------------------------------------------- +// Utilities for casting from one integer type to another + +template +Status IntegersInRange(const Datum& datum, CType bound_lower, CType bound_upper) { + if (std::numeric_limits::lowest() >= bound_lower && + std::numeric_limits::max() <= bound_upper) { + return Status::OK(); + } + + auto IsOutOfBounds = [&](CType val) -> bool { + return val < bound_lower || val > bound_upper; + }; + auto IsOutOfBoundsMaybeNull = [&](CType val, bool is_valid) -> bool { + return is_valid && (val < bound_lower || val > bound_upper); + }; + auto GetErrorMessage = [&](CType val) { + return Status::Invalid("Integer value ", FormatInt(val), + " not in range: ", FormatInt(bound_lower), " to ", + FormatInt(bound_upper)); + }; + + if (datum.kind() == Datum::SCALAR) { + const auto& scalar = datum.scalar_as::ScalarType>(); + if (IsOutOfBoundsMaybeNull(scalar.value, scalar.is_valid)) { + return GetErrorMessage(scalar.value); + } + return Status::OK(); + } + + const ArrayData& indices = *datum.array(); + const CType* indices_data = indices.GetValues(1); + const uint8_t* bitmap = nullptr; + if (indices.buffers[0]) { + bitmap = indices.buffers[0]->data(); + } + OptionalBitBlockCounter indices_bit_counter(bitmap, indices.offset, indices.length); + int64_t position = 0; + int64_t offset_position = indices.offset; + while (position < indices.length) { + BitBlockCount block = indices_bit_counter.NextBlock(); + bool block_out_of_bounds = false; + if (block.popcount == block.length) { + // Fast path: branchless + int64_t i = 0; + for (int64_t chunk = 0; chunk < block.length / 8; ++chunk) { + 
// Let the compiler unroll this + for (int j = 0; j < 8; ++j) { + block_out_of_bounds |= IsOutOfBounds(indices_data[i++]); + } + } + for (; i < block.length; ++i) { + block_out_of_bounds |= IsOutOfBounds(indices_data[i]); + } + } else if (block.popcount > 0) { + // Indices have nulls, must only boundscheck non-null values + int64_t i = 0; + for (int64_t chunk = 0; chunk < block.length / 8; ++chunk) { + // Let the compiler unroll this + for (int j = 0; j < 8; ++j) { + block_out_of_bounds |= IsOutOfBoundsMaybeNull( + indices_data[i], BitUtil::GetBit(bitmap, offset_position + i)); + ++i; + } + } + for (; i < block.length; ++i) { + block_out_of_bounds |= IsOutOfBoundsMaybeNull( + indices_data[i], BitUtil::GetBit(bitmap, offset_position + i)); + } + } + if (ARROW_PREDICT_FALSE(block_out_of_bounds)) { + if (indices.GetNullCount() > 0) { + for (int64_t i = 0; i < block.length; ++i) { + if (IsOutOfBoundsMaybeNull(indices_data[i], + BitUtil::GetBit(bitmap, offset_position + i))) { + return GetErrorMessage(indices_data[i]); + } + } + } else { + for (int64_t i = 0; i < block.length; ++i) { + if (IsOutOfBounds(indices_data[i])) { + return GetErrorMessage(indices_data[i]); + } + } + } + } + indices_data += block.length; + position += block.length; + offset_position += block.length; + } + return Status::OK(); +} + +template +Status CheckIntegersInRangeImpl(const Datum& datum, const Scalar& bound_lower, + const Scalar& bound_upper) { + using ScalarType = typename TypeTraits::ScalarType; + return IntegersInRange(datum, checked_cast(bound_lower).value, + checked_cast(bound_upper).value); +} + +Status CheckIntegersInRange(const Datum& datum, const Scalar& bound_lower, + const Scalar& bound_upper) { + Type::type type_id = datum.type()->id(); + + if (bound_lower.type->id() != type_id || bound_upper.type->id() != type_id || + !bound_lower.is_valid || !bound_upper.is_valid) { + return Status::Invalid("Scalar bound types must be non-null and same type as data"); + } + + switch (type_id) 
{ + case Type::INT8: + return CheckIntegersInRangeImpl(datum, bound_lower, bound_upper); + case Type::INT16: + return CheckIntegersInRangeImpl(datum, bound_lower, bound_upper); + case Type::INT32: + return CheckIntegersInRangeImpl(datum, bound_lower, bound_upper); + case Type::INT64: + return CheckIntegersInRangeImpl(datum, bound_lower, bound_upper); + case Type::UINT8: + return CheckIntegersInRangeImpl(datum, bound_lower, bound_upper); + case Type::UINT16: + return CheckIntegersInRangeImpl(datum, bound_lower, bound_upper); + case Type::UINT32: + return CheckIntegersInRangeImpl(datum, bound_lower, bound_upper); + case Type::UINT64: + return CheckIntegersInRangeImpl(datum, bound_lower, bound_upper); + default: + return Status::TypeError("Invalid index type for boundschecking"); + } +} + +template +struct is_number_downcast { + static constexpr bool value = false; +}; + +template +struct is_number_downcast< + O, I, enable_if_t::value && is_number_type::value>> { + using O_T = typename O::c_type; + using I_T = typename I::c_type; + + static constexpr bool value = + ((!std::is_same::value) && + // Both types are of the same sign-ness. + ((std::is_signed::value == std::is_signed::value) && + // Both types are of the same integral-ness. + (std::is_floating_point::value == std::is_floating_point::value)) && + // Smaller output size + (sizeof(O_T) < sizeof(I_T))); +}; + +template +struct is_number_upcast { + static constexpr bool value = false; +}; + +template +struct is_number_upcast< + O, I, enable_if_t::value && is_number_type::value>> { + using O_T = typename O::c_type; + using I_T = typename I::c_type; + + static constexpr bool value = + ((!std::is_same::value) && + // Both types are of the same sign-ness. + ((std::is_signed::value == std::is_signed::value) && + // Both types are of the same integral-ness. 
+ (std::is_floating_point::value == std::is_floating_point::value)) && + // Larger output size + (sizeof(O_T) > sizeof(I_T))); +}; + +template +struct is_integral_signed_to_unsigned { + static constexpr bool value = false; +}; + +template +struct is_integral_signed_to_unsigned< + O, I, enable_if_t::value && is_integer_type::value>> { + using O_T = typename O::c_type; + using I_T = typename I::c_type; + + static constexpr bool value = + ((!std::is_same::value) && + ((std::is_unsigned::value && std::is_signed::value))); +}; + +template +struct is_integral_unsigned_to_signed { + static constexpr bool value = false; +}; + +template +struct is_integral_unsigned_to_signed< + O, I, enable_if_t::value && is_integer_type::value>> { + using O_T = typename O::c_type; + using I_T = typename I::c_type; + + static constexpr bool value = + ((!std::is_same::value) && + ((std::is_signed::value && std::is_unsigned::value))); +}; + +// This set of functions SafeMinimum/SafeMaximum would be simplified with +// C++17 and `if constexpr`. + +// clang-format doesn't handle this construct properly. Thus the macro, but it +// also improves readability. +// +// The effective return type of the function is always `I::c_type`, this is +// just how enable_if works with functions. 
+#define RET_TYPE(TRAIT) enable_if_t::value, typename I::c_type> + +template +constexpr RET_TYPE(std::is_same) SafeMinimum() { + using out_type = typename O::c_type; + + return std::numeric_limits::lowest(); +} + +template +constexpr RET_TYPE(std::is_same) SafeMaximum() { + using out_type = typename O::c_type; + + return std::numeric_limits::max(); +} + +template +constexpr RET_TYPE(is_number_downcast) SafeMinimum() { + using out_type = typename O::c_type; + + return std::numeric_limits::lowest(); +} + +template +constexpr RET_TYPE(is_number_downcast) SafeMaximum() { + using out_type = typename O::c_type; + + return std::numeric_limits::max(); +} + +template +constexpr RET_TYPE(is_number_upcast) SafeMinimum() { + using in_type = typename I::c_type; + return std::numeric_limits::lowest(); +} + +template +constexpr RET_TYPE(is_number_upcast) SafeMaximum() { + using in_type = typename I::c_type; + return std::numeric_limits::max(); +} + +template +constexpr RET_TYPE(is_integral_unsigned_to_signed) SafeMinimum() { + return 0; +} + +template +constexpr RET_TYPE(is_integral_unsigned_to_signed) SafeMaximum() { + using in_type = typename I::c_type; + using out_type = typename O::c_type; + + // Equality is missing because in_type::max() > out_type::max() when types + // are of the same width. + return static_cast(sizeof(in_type) < sizeof(out_type) + ? std::numeric_limits::max() + : std::numeric_limits::max()); +} + +template +constexpr RET_TYPE(is_integral_signed_to_unsigned) SafeMinimum() { + return 0; +} + +template +constexpr RET_TYPE(is_integral_signed_to_unsigned) SafeMaximum() { + using in_type = typename I::c_type; + using out_type = typename O::c_type; + + return static_cast(sizeof(in_type) <= sizeof(out_type) + ? 
std::numeric_limits::max() + : std::numeric_limits::max()); +} + +#undef RET_TYPE + +#define GET_MIN_MAX_CASE(TYPE, OUT_TYPE) \ + case Type::TYPE: \ + *min = SafeMinimum(); \ + *max = SafeMaximum(); \ + break + +template +void GetSafeMinMax(Type::type out_type, T* min, T* max) { + switch (out_type) { + GET_MIN_MAX_CASE(INT8, Int8Type); + GET_MIN_MAX_CASE(INT16, Int16Type); + GET_MIN_MAX_CASE(INT32, Int32Type); + GET_MIN_MAX_CASE(INT64, Int64Type); + GET_MIN_MAX_CASE(UINT8, UInt8Type); + GET_MIN_MAX_CASE(UINT16, UInt16Type); + GET_MIN_MAX_CASE(UINT32, UInt32Type); + GET_MIN_MAX_CASE(UINT64, UInt64Type); + default: + break; + } +} + +template ::ScalarType> +Status IntegersCanFitImpl(const Datum& datum, const DataType& target_type) { + CType bound_min{}, bound_max{}; + GetSafeMinMax(target_type.id(), &bound_min, &bound_max); + return CheckIntegersInRange(datum, ScalarType(bound_min), ScalarType(bound_max)); +} + +Status IntegersCanFit(const Datum& datum, const DataType& target_type) { + if (!is_integer(target_type.id())) { + return Status::Invalid("Target type is not an integer type: ", target_type); + } + + switch (datum.type()->id()) { + case Type::INT8: + return IntegersCanFitImpl(datum, target_type); + case Type::INT16: + return IntegersCanFitImpl(datum, target_type); + case Type::INT32: + return IntegersCanFitImpl(datum, target_type); + case Type::INT64: + return IntegersCanFitImpl(datum, target_type); + case Type::UINT8: + return IntegersCanFitImpl(datum, target_type); + case Type::UINT16: + return IntegersCanFitImpl(datum, target_type); + case Type::UINT32: + return IntegersCanFitImpl(datum, target_type); + case Type::UINT64: + return IntegersCanFitImpl(datum, target_type); + default: + return Status::TypeError("Invalid index type for boundschecking"); + } +} + } // namespace internal } // namespace arrow diff --git a/cpp/src/arrow/util/int_util.h b/cpp/src/arrow/util/int_util.h index 3131476bfec..c4ed0eb7d5b 100644 --- a/cpp/src/arrow/util/int_util.h +++ 
b/cpp/src/arrow/util/int_util.h @@ -26,7 +26,10 @@ namespace arrow { +class DataType; struct ArrayData; +struct Datum; +struct Scalar; namespace internal { @@ -122,11 +125,24 @@ UpcastInt(Integer v) { return v; } -/// \brief Do vectorized boundschecking of integer-type indices. The indices -/// must be non-nonnegative and strictly less than the passed upper limit -/// (which is usually the length of an array that is being indexed-into). +/// \brief Do vectorized boundschecking of integer-type array indices. The +/// indices must be non-negative and strictly less than the passed upper +/// limit (which is usually the length of an array that is being indexed-into). ARROW_EXPORT -Status IndexBoundsCheck(const ArrayData& indices, uint64_t upper_limit); +Status CheckIndexBounds(const ArrayData& indices, uint64_t upper_limit); + +/// \brief Boundscheck integer values to determine if they are all between the +/// passed upper and lower limits (inclusive). Upper and lower bounds must be +/// the same type as the data and are not currently cast. +ARROW_EXPORT +Status CheckIntegersInRange(const Datum& datum, const Scalar& bound_lower, + const Scalar& bound_upper); + +/// \brief Use CheckIntegersInRange to determine whether the passed integers +/// can fit safely in the passed integer type. This helps quickly determine if +/// integer narrowing (e.g. int64->int32) is safe to do.
+ARROW_EXPORT +Status IntegersCanFit(const Datum& datum, const DataType& target_type); } // namespace internal } // namespace arrow diff --git a/cpp/src/arrow/util/int_util_benchmark.cc b/cpp/src/arrow/util/int_util_benchmark.cc index 2912ba02456..1eae604a7da 100644 --- a/cpp/src/arrow/util/int_util_benchmark.cc +++ b/cpp/src/arrow/util/int_util_benchmark.cc @@ -20,11 +20,17 @@ #include #include +#include "arrow/array/array_base.h" +#include "arrow/testing/gtest_util.h" +#include "arrow/testing/random.h" +#include "arrow/util/benchmark_util.h" #include "arrow/util/int_util.h" namespace arrow { namespace internal { +constexpr auto kSeed = 0x94378165; + std::vector GetUIntSequence(int n_values, uint64_t addend = 0) { std::vector values(n_values); for (int i = 0; i < n_values; ++i) { @@ -95,10 +101,43 @@ static void DetectIntWidthNulls(benchmark::State& state) { // NOLINT non-const state.SetBytesProcessed(state.iterations() * values.size() * sizeof(uint64_t)); } +static void CheckIndexBoundsInt32( + benchmark::State& state) { // NOLINT non-const reference + GenericItemsArgs args(state); + random::RandomArrayGenerator rand(kSeed); + auto arr = rand.Int32(args.size, 0, 100000, args.null_proportion); + for (auto _ : state) { + ABORT_NOT_OK(CheckIndexBounds(*arr->data(), 100001)); + } +} + +static void CheckIndexBoundsUInt32( + benchmark::State& state) { // NOLINT non-const reference + GenericItemsArgs args(state); + random::RandomArrayGenerator rand(kSeed); + auto arr = rand.UInt32(args.size, 0, 100000, args.null_proportion); + for (auto _ : state) { + ABORT_NOT_OK(CheckIndexBounds(*arr->data(), 100001)); + } +} + BENCHMARK(DetectUIntWidthNoNulls); BENCHMARK(DetectUIntWidthNulls); BENCHMARK(DetectIntWidthNoNulls); BENCHMARK(DetectIntWidthNulls); +std::vector g_data_sizes = {kL1Size, kL2Size}; + +void BoundsCheckSetArgs(benchmark::internal::Benchmark* bench) { + for (int64_t size : g_data_sizes) { + for (auto nulls : std::vector({1000, 10, 2, 1, 0})) { + 
bench->Args({static_cast(size), nulls}); + } + } +} + +BENCHMARK(CheckIndexBoundsInt32)->Apply(BoundsCheckSetArgs); +BENCHMARK(CheckIndexBoundsUInt32)->Apply(BoundsCheckSetArgs); + } // namespace internal } // namespace arrow diff --git a/cpp/src/arrow/util/int_util_test.cc b/cpp/src/arrow/util/int_util_test.cc index 20e6a2eb42d..6cb0e2506b1 100644 --- a/cpp/src/arrow/util/int_util_test.cc +++ b/cpp/src/arrow/util/int_util_test.cc @@ -24,6 +24,7 @@ #include +#include "arrow/datum.h" #include "arrow/testing/gtest_util.h" #include "arrow/testing/random.h" #include "arrow/type.h" @@ -389,16 +390,16 @@ TEST(TransposeInts, Int8ToInt64) { void BoundsCheckPasses(const std::shared_ptr& type, const std::string& indices_json, uint64_t upper_limit) { auto indices = ArrayFromJSON(type, indices_json); - ASSERT_OK(IndexBoundsCheck(*indices->data(), upper_limit)); + ASSERT_OK(CheckIndexBounds(*indices->data(), upper_limit)); } void BoundsCheckFails(const std::shared_ptr& type, const std::string& indices_json, uint64_t upper_limit) { auto indices = ArrayFromJSON(type, indices_json); - ASSERT_RAISES(IndexError, IndexBoundsCheck(*indices->data(), upper_limit)); + ASSERT_RAISES(IndexError, CheckIndexBounds(*indices->data(), upper_limit)); } -TEST(IndexBoundsCheck, Batching) { +TEST(CheckIndexBounds, Batching) { auto rand = random::RandomArrayGenerator(/*seed=*/0); const int64_t length = 200; @@ -411,25 +412,25 @@ TEST(IndexBoundsCheck, Batching) { uint8_t* bitmap = index_data->buffers[0]->mutable_data(); BitUtil::SetBitsTo(bitmap, 0, length, true); - ASSERT_OK(IndexBoundsCheck(*index_data, 1)); + ASSERT_OK(CheckIndexBounds(*index_data, 1)); // We'll place out of bounds indices at various locations values[99] = 1; - ASSERT_RAISES(IndexError, IndexBoundsCheck(*index_data, 1)); + ASSERT_RAISES(IndexError, CheckIndexBounds(*index_data, 1)); // Make that value null BitUtil::ClearBit(bitmap, 99); - ASSERT_OK(IndexBoundsCheck(*index_data, 1)); + ASSERT_OK(CheckIndexBounds(*index_data, 1)); 
values[199] = 1; - ASSERT_RAISES(IndexError, IndexBoundsCheck(*index_data, 1)); + ASSERT_RAISES(IndexError, CheckIndexBounds(*index_data, 1)); // Make that value null BitUtil::ClearBit(bitmap, 199); - ASSERT_OK(IndexBoundsCheck(*index_data, 1)); + ASSERT_OK(CheckIndexBounds(*index_data, 1)); } -TEST(IndexBoundsCheck, SignedInts) { +TEST(CheckIndexBounds, SignedInts) { auto CheckCommon = [&](const std::shared_ptr& ty) { BoundsCheckPasses(ty, "[0, 0, 0]", 1); BoundsCheckFails(ty, "[0, 0, 0]", 0); @@ -456,7 +457,7 @@ TEST(IndexBoundsCheck, SignedInts) { BoundsCheckFails(int64(), "[0, 10000000000, 10000000000]", 10000000000LL); } -TEST(IndexBoundsCheck, UnsignedInts) { +TEST(CheckIndexBounds, UnsignedInts) { auto CheckCommon = [&](const std::shared_ptr& ty) { BoundsCheckPasses(ty, "[0, 0, 0]", 1); BoundsCheckFails(ty, "[0, 0, 0]", 0); @@ -483,5 +484,113 @@ TEST(IndexBoundsCheck, UnsignedInts) { BoundsCheckFails(uint64(), "[0, 10000000000, 10000000000]", 10000000000LL); } +void CheckInRangePasses(const std::shared_ptr& type, + const std::string& values_json, const std::string& limits_json) { + auto values = ArrayFromJSON(type, values_json); + auto limits = ArrayFromJSON(type, limits_json); + ASSERT_OK(CheckIntegersInRange(Datum(values->data()), **limits->GetScalar(0), + **limits->GetScalar(1))); +} + +void CheckInRangeFails(const std::shared_ptr& type, + const std::string& values_json, const std::string& limits_json) { + auto values = ArrayFromJSON(type, values_json); + auto limits = ArrayFromJSON(type, limits_json); + ASSERT_RAISES(Invalid, + CheckIntegersInRange(Datum(values->data()), **limits->GetScalar(0), + **limits->GetScalar(1))); +} + +TEST(CheckIntegersInRange, Batching) { + auto rand = random::RandomArrayGenerator(/*seed=*/0); + + const int64_t length = 200; + + auto indices = rand.Int16(length, 0, 0, /*null_probability=*/0); + ArrayData* index_data = indices->data().get(); + index_data->buffers[0] = *AllocateBitmap(length); + + int16_t* values = 
index_data->GetMutableValues(1); + uint8_t* bitmap = index_data->buffers[0]->mutable_data(); + BitUtil::SetBitsTo(bitmap, 0, length, true); + + auto zero = std::make_shared(0); + auto one = std::make_shared(1); + + ASSERT_OK(CheckIntegersInRange(*index_data, *zero, *one)); + + // 1 is included + values[99] = 1; + ASSERT_OK(CheckIntegersInRange(*index_data, *zero, *one)); + + // We'll place out of bounds indices at various locations + values[99] = 2; + ASSERT_RAISES(Invalid, CheckIntegersInRange(*index_data, *zero, *one)); + + // Make that value null + BitUtil::ClearBit(bitmap, 99); + ASSERT_OK(CheckIntegersInRange(*index_data, *zero, *one)); + + values[199] = 2; + ASSERT_RAISES(Invalid, CheckIntegersInRange(*index_data, *zero, *one)); + + // Make that value null + BitUtil::ClearBit(bitmap, 199); + ASSERT_OK(CheckIntegersInRange(*index_data, *zero, *one)); +} + +TEST(CheckIntegersInRange, SignedInts) { + auto CheckCommon = [&](const std::shared_ptr& ty) { + CheckInRangePasses(ty, "[0, 0, 0]", "[0, 0]"); + CheckInRangeFails(ty, "[0, 1, 0]", "[0, 0]"); + CheckInRangeFails(ty, "[1, 1, 1]", "[2, 4]"); + CheckInRangeFails(ty, "[-1]", "[0, 0]"); + CheckInRangeFails(ty, "[-128]", "[-127, 0]"); + CheckInRangeFails(ty, "[0, 100, 127]", "[0, 126]"); + CheckInRangePasses(ty, "[0, 100, 127]", "[0, 127]"); + }; + + CheckCommon(int8()); + + CheckCommon(int16()); + CheckInRangePasses(int16(), "[0, 999, 999]", "[0, 999]"); + CheckInRangeFails(int16(), "[0, 1000, 1000]", "[0, 999]"); + + CheckCommon(int32()); + CheckInRangePasses(int32(), "[0, 999999, 999999]", "[0, 999999]"); + CheckInRangeFails(int32(), "[0, 1000000, 1000000]", "[0, 999999]"); + + CheckCommon(int64()); + CheckInRangePasses(int64(), "[0, 9999999999, 9999999999]", "[0, 9999999999]"); + CheckInRangeFails(int64(), "[0, 10000000000, 10000000000]", "[0, 9999999999]"); +} + +TEST(CheckIntegersInRange, UnsignedInts) { + auto CheckCommon = [&](const std::shared_ptr& ty) { + CheckInRangePasses(ty, "[0, 0, 0]", "[0, 0]"); + 
CheckInRangeFails(ty, "[0, 1, 0]", "[0, 0]"); + CheckInRangeFails(ty, "[1, 1, 1]", "[2, 4]"); + CheckInRangeFails(ty, "[0, 100, 200]", "[0, 199]"); + CheckInRangePasses(ty, "[0, 100, 200]", "[0, 200]"); + }; + + CheckCommon(uint8()); + CheckInRangePasses(uint8(), "[255, 255, 255]", "[0, 255]"); + + CheckCommon(uint16()); + CheckInRangePasses(uint16(), "[0, 999, 999]", "[0, 999]"); + CheckInRangeFails(uint16(), "[0, 1000, 1000]", "[0, 999]"); + CheckInRangePasses(uint16(), "[0, 65535]", "[0, 65535]"); + + CheckCommon(uint32()); + CheckInRangePasses(uint32(), "[0, 999999, 999999]", "[0, 999999]"); + CheckInRangeFails(uint32(), "[0, 1000000, 1000000]", "[0, 999999]"); + CheckInRangePasses(uint32(), "[0, 4294967295]", "[0, 4294967295]"); + + CheckCommon(uint64()); + CheckInRangePasses(uint64(), "[0, 9999999999, 9999999999]", "[0, 9999999999]"); + CheckInRangeFails(uint64(), "[0, 10000000000, 10000000000]", "[0, 9999999999]"); +} + } // namespace internal } // namespace arrow diff --git a/python/pyarrow/tests/test_array.py b/python/pyarrow/tests/test_array.py index 883261e2031..cd9fd14a18b 100644 --- a/python/pyarrow/tests/test_array.py +++ b/python/pyarrow/tests/test_array.py @@ -1035,8 +1035,7 @@ def test_floating_point_truncate_unsafe(): ] for case in unsafe_cases: # test safe casting raises - with pytest.raises(pa.ArrowInvalid, - match='Floating point value truncated'): + with pytest.raises(pa.ArrowInvalid, match='truncated'): _check_cast_case(case, safe=True) # test unsafe casting truncates @@ -1172,8 +1171,7 @@ def test_decimal_to_decimal(): def test_safe_cast_nan_to_int_raises(): arr = pa.array([np.nan, 1.]) - with pytest.raises(pa.ArrowInvalid, - match='Floating point value truncated'): + with pytest.raises(pa.ArrowInvalid, match='truncated'): arr.cast(pa.int64(), safe=True) diff --git a/python/pyarrow/tests/test_table.py b/python/pyarrow/tests/test_table.py index e99fad95fd6..bd221248792 100644 --- a/python/pyarrow/tests/test_table.py +++ 
b/python/pyarrow/tests/test_table.py @@ -1059,8 +1059,7 @@ def test_table_unsafe_casting(): pa.field('d', pa.string()) ]) - with pytest.raises(pa.ArrowInvalid, - match='Floating point value truncated'): + with pytest.raises(pa.ArrowInvalid, match='truncated'): table.cast(target_schema) casted_table = table.cast(target_schema, safe=False) diff --git a/r/tests/testthat/test-Array.R b/r/tests/testthat/test-Array.R index 19c5d9eb973..3dc1f83af01 100644 --- a/r/tests/testthat/test-Array.R +++ b/r/tests/testthat/test-Array.R @@ -283,7 +283,7 @@ test_that("integer types casts (ARROW-3741)", { test_that("integer types cast safety (ARROW-3741, ARROW-5541)", { a <- Array$create(-(1:10)) for (type in uint_types) { - expect_error(a$cast(type), regexp = "Integer value out of bounds") + expect_error(a$cast(type), regexp = "Integer value -1 not in range") expect_error(a$cast(type, safe = FALSE), NA) } })