diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_string.cc b/cpp/src/arrow/compute/kernels/scalar_cast_string.cc index 713ec5b73c0..9d3812e5967 100644 --- a/cpp/src/arrow/compute/kernels/scalar_cast_string.cc +++ b/cpp/src/arrow/compute/kernels/scalar_cast_string.cc @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -// Implementation of casting to integer or floating point types +#include #include "arrow/array/array_base.h" #include "arrow/array/builder_binary.h" @@ -23,6 +23,7 @@ #include "arrow/compute/kernels/scalar_cast_internal.h" #include "arrow/result.h" #include "arrow/util/formatting.h" +#include "arrow/util/int_util.h" #include "arrow/util/optional.h" #include "arrow/util/utf8.h" #include "arrow/visitor_inline.h" @@ -36,13 +37,13 @@ using util::ValidateUTF8; namespace compute { namespace internal { +namespace { + // ---------------------------------------------------------------------- // Number / Boolean to String -template -struct CastFunctor::value && - (is_number_type::value || is_boolean_type::value)>> { +template +struct NumericToStringCastFunctor { using value_type = typename TypeTraits::CType; using BuilderType = typename TypeTraits::BuilderType; using FormatterType = StringFormatter; @@ -71,7 +72,7 @@ struct CastFunctor -struct BinaryToStringSameWidthCastFunctor { +struct CastBinaryToBinaryOffsets; + +// Cast same-width offsets (no-op) +template <> +struct CastBinaryToBinaryOffsets { + static void CastOffsets(KernelContext* ctx, const ArrayData& input, ArrayData* output) { + } +}; +template <> +struct CastBinaryToBinaryOffsets { + static void CastOffsets(KernelContext* ctx, const ArrayData& input, ArrayData* output) { + } +}; + +// Upcast offsets +template <> +struct CastBinaryToBinaryOffsets { + static void CastOffsets(KernelContext* ctx, const ArrayData& input, ArrayData* output) { + using input_offset_type = int32_t; + using output_offset_type = int64_t; + KERNEL_ASSIGN_OR_RAISE(output->buffers[1], ctx, + ctx->Allocate((output->length + output->offset + 1) * + sizeof(output_offset_type))); + memset(output->buffers[1]->mutable_data(), 0, + output->offset * sizeof(output_offset_type)); + ::arrow::internal::CastInts(input.GetValues(1), + output->GetMutableValues(1), + output->length + 1); + } +}; + +// Downcast offsets +template <> +struct CastBinaryToBinaryOffsets { + static void CastOffsets(KernelContext* ctx, const ArrayData& input, ArrayData* output) { + using input_offset_type = int64_t; + using output_offset_type = int32_t; + + constexpr input_offset_type kMaxOffset = + std::numeric_limits::max(); + + auto input_offsets = input.GetValues(1); + + // Binary offsets are ascending, so it's enough to check the last one for overflow. + if (input_offsets[input.length] > kMaxOffset) { + ctx->SetStatus(Status::Invalid("Failed casting from ", input.type->ToString(), + " to ", output->type->ToString(), + ": input array too large")); + } else { + KERNEL_ASSIGN_OR_RAISE(output->buffers[1], ctx, + ctx->Allocate((output->length + output->offset + 1) * + sizeof(output_offset_type))); + memset(output->buffers[1]->mutable_data(), 0, + output->offset * sizeof(output_offset_type)); + ::arrow::internal::CastInts(input.GetValues(1), + output->GetMutableValues(1), + output->length + 1); + } + } +}; + +template +struct BinaryToBinaryCastFunctor { + using input_offset_type = typename I::offset_type; + using output_offset_type = typename O::offset_type; + static void Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { const CastOptions& options = checked_cast(*ctx->state()).options; - if (!options.allow_invalid_utf8) { + const ArrayData& input = *batch[0].array(); + + if (!I::is_utf8 && O::is_utf8 && !options.allow_invalid_utf8) { InitializeUTF8(); - const ArrayData& input = *batch[0].array(); ArrayDataVisitor visitor; Utf8Validator validator; @@ -107,75 +174,84 @@ struct BinaryToStringSameWidthCastFunctor { return; } } - // It's OK to call this because base binary types do not preallocate - // anything + + // Start with a zero-copy cast, but change indices to expected size ZeroCopyCastExec(ctx, batch, out); + CastBinaryToBinaryOffsets::CastOffsets( + ctx, input, out->mutable_array()); } }; -template <> -struct CastFunctor - : public BinaryToStringSameWidthCastFunctor {}; - -template <> -struct CastFunctor - : public BinaryToStringSameWidthCastFunctor {}; - #if defined(_MSC_VER) #pragma warning(pop) #endif -// String casts available -// -// * Numbers and boolean to String / LargeString -// * Binary / LargeBinary to String / LargeString with UTF8 validation +// ---------------------------------------------------------------------- +// Cast functions registration template -void AddNumberToStringCasts(std::shared_ptr out_ty, CastFunction* func) { +void AddNumberToStringCasts(CastFunction* func) { + auto out_ty = TypeTraits::type_singleton(); + DCHECK_OK(func->AddKernel(Type::BOOL, {boolean()}, out_ty, - CastFunctor::Exec, + NumericToStringCastFunctor::Exec, NullHandling::COMPUTED_NO_PREALLOCATE)); for (const std::shared_ptr& in_ty : NumericTypes()) { - DCHECK_OK(func->AddKernel(in_ty->id(), {in_ty}, out_ty, - GenerateNumeric(*in_ty), - NullHandling::COMPUTED_NO_PREALLOCATE)); + DCHECK_OK( + func->AddKernel(in_ty->id(), {in_ty}, out_ty, + GenerateNumeric(*in_ty), + NullHandling::COMPUTED_NO_PREALLOCATE)); } } +template +void AddBinaryToBinaryCast(CastFunction* func) { + auto in_ty = TypeTraits::type_singleton(); + auto out_ty = TypeTraits::type_singleton(); + + DCHECK_OK(func->AddKernel(OutType::type_id, {in_ty}, out_ty, + BinaryToBinaryCastFunctor::Exec, + NullHandling::COMPUTED_NO_PREALLOCATE)); +} + +} // namespace + std::vector> GetBinaryLikeCasts() { auto cast_binary = std::make_shared("cast_binary", Type::BINARY); AddCommonCasts(Type::BINARY, binary(), cast_binary.get()); - AddZeroCopyCast(Type::STRING, {utf8()}, binary(), cast_binary.get()); + AddBinaryToBinaryCast(cast_binary.get()); + AddBinaryToBinaryCast(cast_binary.get()); + AddBinaryToBinaryCast(cast_binary.get()); auto cast_large_binary = std::make_shared("cast_large_binary", Type::LARGE_BINARY); AddCommonCasts(Type::LARGE_BINARY, large_binary(), cast_large_binary.get()); - AddZeroCopyCast(Type::LARGE_STRING, {large_utf8()}, large_binary(), - cast_large_binary.get()); - - auto cast_fsb = - std::make_shared("cast_fixed_size_binary", Type::FIXED_SIZE_BINARY); - AddCommonCasts(Type::FIXED_SIZE_BINARY, OutputType(ResolveOutputFromOptions), - cast_fsb.get()); + AddBinaryToBinaryCast(cast_large_binary.get()); + AddBinaryToBinaryCast(cast_large_binary.get()); + AddBinaryToBinaryCast(cast_large_binary.get()); auto cast_string = std::make_shared("cast_string", Type::STRING); AddCommonCasts(Type::STRING, utf8(), cast_string.get()); - AddNumberToStringCasts(utf8(), cast_string.get()); - DCHECK_OK(cast_string->AddKernel(Type::BINARY, {binary()}, utf8(), - CastFunctor::Exec, - NullHandling::COMPUTED_NO_PREALLOCATE)); + AddNumberToStringCasts(cast_string.get()); + AddBinaryToBinaryCast(cast_string.get()); + AddBinaryToBinaryCast(cast_string.get()); + AddBinaryToBinaryCast(cast_string.get()); auto cast_large_string = std::make_shared("cast_large_string", Type::LARGE_STRING); AddCommonCasts(Type::LARGE_STRING, large_utf8(), cast_large_string.get()); - AddNumberToStringCasts(large_utf8(), cast_large_string.get()); - DCHECK_OK( - cast_large_string->AddKernel(Type::LARGE_BINARY, {large_binary()}, large_utf8(), - CastFunctor::Exec, - NullHandling::COMPUTED_NO_PREALLOCATE)); + AddNumberToStringCasts(cast_large_string.get()); + AddBinaryToBinaryCast(cast_large_string.get()); + AddBinaryToBinaryCast(cast_large_string.get()); + AddBinaryToBinaryCast(cast_large_string.get()); + + auto cast_fsb = + std::make_shared("cast_fixed_size_binary", Type::FIXED_SIZE_BINARY); + AddCommonCasts(Type::FIXED_SIZE_BINARY, OutputType(ResolveOutputFromOptions), + cast_fsb.get()); - return {cast_binary, cast_fsb, cast_large_binary, cast_string, cast_large_string}; + return {cast_binary, cast_large_binary, cast_string, cast_large_string, cast_fsb}; } } // namespace internal diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_test.cc b/cpp/src/arrow/compute/kernels/scalar_cast_test.cc index c340a8a50be..68d6f09c76d 100644 --- a/cpp/src/arrow/compute/kernels/scalar_cast_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_cast_test.cc @@ -223,7 +223,7 @@ class TestCast : public TestBase { } template - void TestCastBinaryToString() { + void TestCastBinaryToBinary() { CastOptions options; auto src_type = TypeTraits::type_singleton(); auto dest_type = TypeTraits::type_singleton(); @@ -233,41 +233,28 @@ class TestCast : public TestBase { std::vector valid = {1, 1, 1, 1, 0}; std::vector strings = {"Hi", "olá mundo", "你好世界", "", kInvalidUtf8}; - std::shared_ptr array; - // Should accept when invalid but null. - ArrayFromVector(src_type, valid, strings, &array); - CheckZeroCopy(*array, dest_type); - - // Should refuse due to invalid utf8 payload - CheckFails(strings, all, dest_type, options, - /*check_scalar=*/false); - - // Should accept due to option override - options.allow_invalid_utf8 = true; - CheckCase(strings, all, strings, options, - /*check_scalar=*/false, /*validate_full=*/false); - } - - template - void TestCastStringToBinary() { - CastOptions options; - auto src_type = TypeTraits::type_singleton(); - auto dest_type = TypeTraits::type_singleton(); - - // All valid except the last one - std::vector all = {1, 1, 1, 1, 1}; - std::vector valid = {1, 1, 1, 1, 0}; - std::vector strings = {"Hi", "olá mundo", "你好世界", "", kInvalidUtf8}; - - std::shared_ptr array; - - // Should accept when invalid but null. - ArrayFromVector(src_type, valid, strings, &array); - CheckZeroCopy(*array, dest_type); - - CheckCase(src_type, strings, all, dest_type, strings, options, + CheckCase(strings, valid, strings, options, /*check_scalar=*/false); + + // Should accept empty array + CheckCaseJSON(src_type, dest_type, "[]", "[]", /*check_scalar=*/false); + + if (!SourceType::is_utf8 && DestType::is_utf8) { + // Should refuse due to invalid utf8 payload + CheckFails(strings, all, dest_type, options, + /*check_scalar=*/false); + // Should accept due to option override + options.allow_invalid_utf8 = true; + CheckCase(strings, all, strings, options, + /*check_scalar=*/false, /*validate_full=*/false); + } else { + // Destination type allows non-utf8 data, + // or source type also enforces utf8 data. + const bool validate_full = !DestType::is_utf8; + CheckCase(strings, all, strings, options, + /*check_scalar=*/false, validate_full); + } } template @@ -1577,16 +1564,48 @@ TEST_F(TestCast, StringToTimestampErrors) { } } -TEST_F(TestCast, BinaryToString) { TestCastBinaryToString(); } +TEST_F(TestCast, BinaryToString) { TestCastBinaryToBinary(); } + +TEST_F(TestCast, BinaryToLargeBinary) { + TestCastBinaryToBinary(); +} + +TEST_F(TestCast, BinaryToLargeString) { + TestCastBinaryToBinary(); +} + +TEST_F(TestCast, LargeBinaryToBinary) { + TestCastBinaryToBinary(); +} + +TEST_F(TestCast, LargeBinaryToString) { + TestCastBinaryToBinary(); +} TEST_F(TestCast, LargeBinaryToLargeString) { - TestCastBinaryToString(); + TestCastBinaryToBinary(); } -TEST_F(TestCast, StringToBinary) { TestCastStringToBinary(); } +TEST_F(TestCast, StringToBinary) { TestCastBinaryToBinary(); } + +TEST_F(TestCast, StringToLargeBinary) { + TestCastBinaryToBinary(); +} + +TEST_F(TestCast, StringToLargeString) { + TestCastBinaryToBinary(); +} + +TEST_F(TestCast, LargeStringToBinary) { + TestCastBinaryToBinary(); +} + +TEST_F(TestCast, LargeStringToString) { + TestCastBinaryToBinary(); +} TEST_F(TestCast, LargeStringToLargeBinary) { - TestCastStringToBinary(); + TestCastBinaryToBinary(); } TEST_F(TestCast, NumberToString) { TestCastNumberToString(); } diff --git a/cpp/src/arrow/util/int_util.cc b/cpp/src/arrow/util/int_util.cc index ee7f2ec956c..ad7d88d7902 100644 --- a/cpp/src/arrow/util/int_util.cc +++ b/cpp/src/arrow/util/int_util.cc @@ -57,7 +57,7 @@ static const uint64_t max_uints[] = {0, max_uint8, max_uint16, 0, max_ui 0, 0, 0, max_uint64}; // Check if we would need to expand the underlying storage type -inline uint8_t ExpandedUIntWidth(uint64_t val, uint8_t current_width) { +static inline uint8_t ExpandedUIntWidth(uint64_t val, uint8_t current_width) { // Optimize for the common case where width doesn't change if (ARROW_PREDICT_TRUE(val <= max_uints[current_width])) { return current_width; @@ -364,7 +364,7 @@ uint8_t DetectIntWidth(const int64_t* values, const uint8_t* valid_bytes, int64_ } template -inline void DowncastIntsInternal(const Source* src, Dest* dest, int64_t length) { +static inline void CastIntsInternal(const Source* src, Dest* dest, int64_t length) { while (length >= 4) { dest[0] = static_cast(src[0]); dest[1] = static_cast(src[1]); @@ -381,15 +381,15 @@ inline void DowncastIntsInternal(const Source* src, Dest* dest, int64_t length) } void DowncastInts(const int64_t* source, int8_t* dest, int64_t length) { - DowncastIntsInternal(source, dest, length); + CastIntsInternal(source, dest, length); } void DowncastInts(const int64_t* source, int16_t* dest, int64_t length) { - DowncastIntsInternal(source, dest, length); + CastIntsInternal(source, dest, length); } void DowncastInts(const int64_t* source, int32_t* dest, int64_t length) { - DowncastIntsInternal(source, dest, length); + CastIntsInternal(source, dest, length); } void DowncastInts(const int64_t* source, int64_t* dest, int64_t length) { @@ -397,21 +397,25 @@ void DowncastInts(const int64_t* source, int64_t* dest, int64_t length) { } void DowncastUInts(const uint64_t* source, uint8_t* dest, int64_t length) { - DowncastIntsInternal(source, dest, length); + CastIntsInternal(source, dest, length); } void DowncastUInts(const uint64_t* source, uint16_t* dest, int64_t length) { - DowncastIntsInternal(source, dest, length); + CastIntsInternal(source, dest, length); } void DowncastUInts(const uint64_t* source, uint32_t* dest, int64_t length) { - DowncastIntsInternal(source, dest, length); + CastIntsInternal(source, dest, length); } void DowncastUInts(const uint64_t* source, uint64_t* dest, int64_t length) { memcpy(dest, source, length * sizeof(int64_t)); } +void UpcastInts(const int32_t* source, int64_t* dest, int64_t length) { + CastIntsInternal(source, dest, length); +} + template void TransposeInts(const InputInt* src, OutputInt* dest, int64_t length, const int32_t* transpose_map) { @@ -461,12 +465,12 @@ INSTANTIATE_ALL() #undef INSTANTIATE_ALL_DEST template -std::string FormatInt(T val) { +static std::string FormatInt(T val) { return std::to_string(val); } template ::value> -Status CheckIndexBoundsImpl(const ArrayData& indices, uint64_t upper_limit) { +static Status CheckIndexBoundsImpl(const ArrayData& indices, uint64_t upper_limit) { // For unsigned integers, if the values array is larger than the maximum // index value (e.g. especially for UINT8 / UINT16), then there is no need to // boundscheck. @@ -569,6 +573,8 @@ Status CheckIndexBounds(const ArrayData& indices, uint64_t upper_limit) { // ---------------------------------------------------------------------- // Utilities for casting from one integer type to another +namespace { + template Status IntegersInRange(const Datum& datum, CType bound_lower, CType bound_upper) { if (std::numeric_limits::lowest() >= bound_lower && @@ -667,6 +673,8 @@ Status CheckIntegersInRangeImpl(const Datum& datum, const Scalar& bound_lower, checked_cast(bound_upper).value); } +} // namespace + Status CheckIntegersInRange(const Datum& datum, const Scalar& bound_lower, const Scalar& bound_upper) { Type::type type_id = datum.type()->id(); @@ -698,6 +706,8 @@ Status CheckIntegersInRange(const Datum& datum, const Scalar& bound_lower, } } +namespace { + template struct is_number_downcast { static constexpr bool value = false; @@ -886,6 +896,8 @@ Status IntegersCanFitImpl(const Datum& datum, const DataType& target_type) { return CheckIntegersInRange(datum, ScalarType(bound_min), ScalarType(bound_max)); } +} // namespace + Status IntegersCanFit(const Datum& datum, const DataType& target_type) { if (!is_integer(target_type.id())) { return Status::Invalid("Target type is not an integer type: ", target_type); diff --git a/cpp/src/arrow/util/int_util.h b/cpp/src/arrow/util/int_util.h index 11fd9745a03..4db1624f231 100644 --- a/cpp/src/arrow/util/int_util.h +++ b/cpp/src/arrow/util/int_util.h @@ -18,6 +18,7 @@ #pragma once #include +#include #include "arrow/status.h" #include "arrow/util/visibility.h" @@ -69,6 +70,21 @@ void DowncastUInts(const uint64_t* source, uint32_t* dest, int64_t length); ARROW_EXPORT void DowncastUInts(const uint64_t* source, uint64_t* dest, int64_t length); +ARROW_EXPORT +void UpcastInts(const int32_t* source, int64_t* dest, int64_t length); + +template +inline typename std::enable_if<(sizeof(InputInt) >= sizeof(OutputInt))>::type CastInts( + const InputInt* source, OutputInt* dest, int64_t length) { + DowncastInts(source, dest, length); +} + +template +inline typename std::enable_if<(sizeof(InputInt) < sizeof(OutputInt))>::type CastInts( + const InputInt* source, OutputInt* dest, int64_t length) { + UpcastInts(source, dest, length); +} + template ARROW_EXPORT void TransposeInts(const InputInt* source, OutputInt* dest, int64_t length, const int32_t* transpose_map);