From 15779bf288487144fa7bb24458aaaef8ae05082c Mon Sep 17 00:00:00 2001 From: Jacek Pliszka Date: Fri, 14 Feb 2020 19:45:51 +0100 Subject: [PATCH 1/5] ARROW-3329 [C++] Added casts Decimal128 to Decimal128 and Decimal128 to Int64 --- cpp/src/arrow/compute/kernels/cast.cc | 268 ++++++++++++++++++ cpp/src/arrow/compute/kernels/cast.h | 3 + cpp/src/arrow/compute/kernels/cast_test.cc | 165 +++++++++++ .../kernels/generated/cast_codegen_internal.h | 4 + .../compute/kernels/generated/codegen.py | 7 +- cpp/src/arrow/type.h | 1 + cpp/src/arrow/type_fwd.h | 1 + 7 files changed, 446 insertions(+), 3 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/cast.cc b/cpp/src/arrow/compute/kernels/cast.cc index c0b68539e80..14d1ba0ddf7 100644 --- a/cpp/src/arrow/compute/kernels/cast.cc +++ b/cpp/src/arrow/compute/kernels/cast.cc @@ -404,6 +404,272 @@ struct CastFunctor< } }; +// Decimal to Integer + +template +struct is_decimal_truncate { + static constexpr bool value = false; +}; + +template +struct is_decimal_truncate< + O, I, enable_if_t<(is_integer_type::value && is_decimal_type::value)>> { + static constexpr bool value = true; +}; + +template +struct CastFunctor::value>> { + void operator()(FunctionContext* ctx, const CastOptions& options, + const ArrayData& input, ArrayData* output) { + using in_type = typename I::c_type; + using out_type = typename O::c_type; + const auto& in_type_inst = checked_cast(*input.type); + auto in_scale = in_type_inst.scale(); + + const in_type* in_data = input.GetValues(1); + auto out_data = output->GetMutableValues(1); + + if (input.null_count == 0) { // no nulls + if (options.allow_decimal_truncate) { // no nulls, truncate + if (options.allow_int_overflow) { // no nulls, truncate, overflow + if (in_scale < 0) { + for (int64_t i = 0; i < input.length; ++i) { + *out_data++ = + static_cast(in_data++->IncreaseScaleBy(-in_scale).low_bits()); + } + } else { + for (int64_t i = 0; i < input.length; ++i) { + *out_data++ = static_cast( + in_data++->ReduceScaleBy(in_scale, false).low_bits()); + } + } + } else { // no nulls, truncate, no overflow + auto min_value = std::numeric_limits::min(); + auto max_value = std::numeric_limits::max(); + + if (in_scale < 0) { + for (int64_t i = 0; i < input.length; ++i) { + auto result = in_data++->IncreaseScaleBy(-in_scale); + if (result < min_value || result > max_value) { + ctx->SetStatus(Status::Invalid("Integer value out of bounds")); + } else { + *out_data++ = static_cast(result.low_bits()); + } + } + } else { + for (int64_t i = 0; i < input.length; ++i) { + auto result = in_data++->ReduceScaleBy(in_scale, false); + if (result < min_value || result > max_value) { + ctx->SetStatus(Status::Invalid("Integer value out of bounds")); + } else { + *out_data++ = static_cast(result.low_bits()); + } + } + } + } + } else { // no nulls, no truncate + if (options.allow_int_overflow) { // no nulls, no truncate, overflow + for (int64_t i = 0; i < input.length; ++i) { + auto result = in_data++->Rescale(in_scale, 0); + if (result.ok()) { + auto result_value = result.ValueOrDie(); + *out_data++ = static_cast(result_value.low_bits()); + } else { + ctx->SetStatus(result.status()); + } + } + } else { // no nulls, no truncate, no overflow + auto min_value = std::numeric_limits::min(); + auto max_value = std::numeric_limits::max(); + for (int64_t i = 0; i < input.length; ++i) { + auto result = in_data++->Rescale(in_scale, 0); + if (result.ok()) { + auto result_value = result.ValueOrDie(); + if (result_value < min_value || result_value > max_value) { + ctx->SetStatus(Status::Invalid("Invalid cast from Decimal128 to ", + sizeof(out_type), " byte integer")); + } else { + *out_data++ = static_cast(result_value.low_bits()); + } + } else { + ctx->SetStatus(result.status()); + } + } + } + } + } else { // nulls + if (options.allow_decimal_truncate) { // nulls, truncate + if (options.allow_int_overflow) { // nulls, truncate, overflow + if (in_scale < 0) { + for (int64_t i = 0; i < input.length; ++i) { + *out_data++ = + static_cast(in_data++->IncreaseScaleBy(-in_scale).low_bits()); + } + } else { + for (int64_t i = 0; i < input.length; ++i) { + *out_data++ = static_cast( + in_data++->ReduceScaleBy(in_scale, false).low_bits()); + } + } + } else { // nulls, truncate, no overflow + auto min_value = std::numeric_limits::min(); + auto max_value = std::numeric_limits::max(); + internal::BitmapReader is_valid_reader(input.buffers[0]->data(), input.offset, + input.length); + + if (in_scale < 0) { + for (int64_t i = 0; i < input.length; ++i) { + if (ARROW_PREDICT_FALSE(is_valid_reader.IsSet())) { + if (*in_data > max_value || *in_data < min_value) { + ctx->SetStatus(Status::Invalid("Integer value out of bounds")); + } else { + *out_data = static_cast( + in_data->IncreaseScaleBy(-in_scale).low_bits()); + } + } + out_data++; + in_data++; + is_valid_reader.Next(); + } + } else { + for (int64_t i = 0; i < input.length; ++i) { + if (ARROW_PREDICT_FALSE(is_valid_reader.IsSet())) { + if (*in_data > max_value || *in_data < min_value) { + ctx->SetStatus(Status::Invalid("Integer value out of bounds")); + } else { + *out_data = static_cast( + in_data->ReduceScaleBy(in_scale, false).low_bits()); + } + } + out_data++; + in_data++; + is_valid_reader.Next(); + } + } + } + } else { // nulls, no truncate + internal::BitmapReader is_valid_reader(input.buffers[0]->data(), input.offset, + input.length); + + if (options.allow_int_overflow) { // nulls, no truncate, overflow + for (int64_t i = 0; i < input.length; ++i) { + if (ARROW_PREDICT_FALSE(is_valid_reader.IsSet())) { + auto result = in_data->Rescale(in_scale, 0); + if (result.ok()) { + auto result_value = result.ValueOrDie(); + *out_data = static_cast(result_value.low_bits()); + } else { + ctx->SetStatus(result.status()); + } + } + out_data++; + in_data++; + is_valid_reader.Next(); + } + } else { // nulls, no truncate, no overflow + auto min_value = std::numeric_limits::min(); + auto max_value = std::numeric_limits::max(); + + for (int64_t i = 0; i < input.length; ++i) { + if (ARROW_PREDICT_FALSE(is_valid_reader.IsSet())) { + if (*in_data > max_value || *in_data < min_value) { + ctx->SetStatus(Status::Invalid("Integer value out of bounds")); + } else { + auto result = in_data->Rescale(in_scale, 0); + if (result.ok()) { + auto result_value = result.ValueOrDie(); + *out_data = static_cast(result_value.low_bits()); + } else { + ctx->SetStatus(result.status()); + } + } + } + out_data++; + in_data++; + is_valid_reader.Next(); + } + } + } + } + } +}; + +// ---------------------------------------------------------------------- +// Decimal to Decimal + +template +struct CastFunctor< + O, I, enable_if_t<(is_decimal_type::value && is_decimal_type::value)>> { + void operator()(FunctionContext* ctx, const CastOptions& options, + const ArrayData& input, ArrayData* output) { + using in_type = typename I::c_type; + using out_type = typename O::c_type; + + const auto& in_type_inst = checked_cast(*input.type); + const auto& out_type_inst = checked_cast(*output->type); + auto in_scale = in_type_inst.scale(); + auto out_scale = out_type_inst.scale(); + + const in_type* in_data = input.GetValues(1); + auto out_data = output->GetMutableValues(1); + + if (input.null_count == 0) { // no nulls + if (options.allow_decimal_truncate) { // no nulls, truncate + if (in_scale < out_scale) { + for (int64_t i = 0; i < input.length; ++i) { + *out_data++ = + static_cast(in_data++->IncreaseScaleBy(out_scale - in_scale)); + } + } else { + for (int64_t i = 0; i < input.length; ++i) { + *out_data++ = static_cast( + in_data++->ReduceScaleBy(in_scale - out_scale, false)); + } + } + } else { // no nulls, no truncate + for (int64_t i = 0; i < input.length; ++i) { + auto result = in_data++->Rescale(in_scale, out_scale); + if (result.ok()) { + *out_data = static_cast(result.ValueOrDie()); + } else { + ctx->SetStatus(result.status()); + } + out_data++; + } + } + } else { // nulls + if (options.allow_decimal_truncate) { // nulls, truncate + if (in_scale < out_scale) { + for (int64_t i = 0; i < input.length; ++i) { + *out_data++ = + static_cast(in_data++->IncreaseScaleBy(out_scale - in_scale)); + } + } else { + for (int64_t i = 0; i < input.length; ++i) { + *out_data++ = static_cast( + in_data++->ReduceScaleBy(in_scale - out_scale, false)); + } + } + } else { // nulls, no truncate + internal::BitmapReader is_valid_reader(input.buffers[0]->data(), input.offset, + input.length); + for (int64_t i = 0; i < input.length; ++i) { + if (ARROW_PREDICT_FALSE(is_valid_reader.IsSet())) { + auto result = in_data++->Rescale(in_scale, out_scale); + if (result.ok()) { + *out_data = static_cast(result.ValueOrDie()); + } else { + ctx->SetStatus(result.status()); + } + } + out_data++; + is_valid_reader.Next(); + } + } + } + } +}; + // ---------------------------------------------------------------------- // From one timestamp to another @@ -1203,6 +1469,7 @@ GET_CAST_FUNCTION(UINT64_CASES, UInt64Type, CastKernel) GET_CAST_FUNCTION(INT64_CASES, Int64Type, CastKernel) GET_CAST_FUNCTION(FLOAT_CASES, FloatType, CastKernel) GET_CAST_FUNCTION(DOUBLE_CASES, DoubleType, CastKernel) +GET_CAST_FUNCTION(DECIMAL128_CASES, Decimal128Type, CastKernel) GET_CAST_FUNCTION(DATE32_CASES, Date32Type, CastKernel) GET_CAST_FUNCTION(DATE64_CASES, Date64Type, CastKernel) GET_CAST_FUNCTION(TIME32_CASES, Time32Type, CastKernel) @@ -1293,6 +1560,7 @@ Status GetCastFunction(const DataType& in_type, std::shared_ptr out_ty CAST_FUNCTION_CASE(Int64Type); CAST_FUNCTION_CASE(FloatType); CAST_FUNCTION_CASE(DoubleType); + CAST_FUNCTION_CASE(Decimal128Type); CAST_FUNCTION_CASE(Date32Type); CAST_FUNCTION_CASE(Date64Type); CAST_FUNCTION_CASE(Time32Type); diff --git a/cpp/src/arrow/compute/kernels/cast.h b/cpp/src/arrow/compute/kernels/cast.h index c5b873195f8..fec63a91026 100644 --- a/cpp/src/arrow/compute/kernels/cast.h +++ b/cpp/src/arrow/compute/kernels/cast.h @@ -38,6 +38,7 @@ struct ARROW_EXPORT CastOptions { : allow_int_overflow(false), allow_time_truncate(false), allow_time_overflow(false), + allow_decimal_truncate(false), allow_float_truncate(false), allow_invalid_utf8(false) {} @@ -45,6 +46,7 @@ struct ARROW_EXPORT CastOptions { : allow_int_overflow(!safe), allow_time_truncate(!safe), allow_time_overflow(!safe), + allow_decimal_truncate(!safe), allow_float_truncate(!safe), allow_invalid_utf8(!safe) {} @@ -55,6 +57,7 @@ struct ARROW_EXPORT CastOptions { bool allow_int_overflow; bool allow_time_truncate; bool allow_time_overflow; + bool allow_decimal_truncate; bool allow_float_truncate; // Indicate if conversions from Binary/FixedSizeBinary to string must // validate the utf8 payload. diff --git a/cpp/src/arrow/compute/kernels/cast_test.cc b/cpp/src/arrow/compute/kernels/cast_test.cc index 7198a10c4d1..cd00f06aac0 100644 --- a/cpp/src/arrow/compute/kernels/cast_test.cc +++ b/cpp/src/arrow/compute/kernels/cast_test.cc @@ -604,6 +604,171 @@ TEST_F(TestCast, IntToFloatingPoint) { } #endif +TEST_F(TestCast, DecimalToInt) { + CastOptions options; + std::vector is_valid2 = {true, true}; + std::vector is_valid3 = {true, true, false}; + + // no overflow no truncation + std::vector v12 = {Decimal128("02.0000000000"), + Decimal128("-11.0000000000")}; + std::vector v13 = {Decimal128("02.0000000000"), + Decimal128("-11.0000000000"), + Decimal128("-12.0000000000")}; + std::vector e12 = {2, -11}; + std::vector e13 = {2, -11, 0}; + + for (int i = 0; i < 4; i++) { + options.allow_int_overflow = bool(i / 2); + options.allow_decimal_truncate = bool(i % 2); + CheckCase( + decimal(38, 10), v12, is_valid2, int64(), e12, options); + CheckCase( + decimal(38, 10), v13, is_valid3, int64(), e13, options); + } + + // truncation, no overflow + std::vector v22 = {Decimal128("02.1000000000"), + Decimal128("-11.0000004500")}; + std::vector v23 = {Decimal128("02.1000000000"), + Decimal128("-11.0000004500"), + Decimal128("-12.0000004500")}; + std::vector e22 = {2, -11}; + std::vector e23 = {2, -11, 0}; + + for (int i = 0; i < 4; i++) { + options.allow_int_overflow = bool(i / 2); + options.allow_decimal_truncate = bool(i % 2); + if (options.allow_decimal_truncate) { + CheckCase( + decimal(38, 10), v22, is_valid2, int64(), e22, options); + CheckCase( + decimal(38, 10), v23, is_valid3, int64(), e23, options); + } else { + CheckFails(decimal(38, 10), v22, is_valid2, int64(), options); + CheckFails(decimal(38, 10), v23, is_valid3, int64(), options); + } + } + + // overflow, no truncation + std::vector v32 = {Decimal128("12345678901234567890000.0000000000"), + Decimal128("99999999999999999999999.0000000000")}; + std::vector v33 = {Decimal128("12345678901234567890000.0000000000"), + Decimal128("99999999999999999999999.0000000000"), + Decimal128("99999999999999999999999.0000000000")}; + // 12345678901234567890000 % 2**64, 99999999999999999999999 % 2**64 + std::vector e32 = {4807115922877858896, 200376420520689663}; + std::vector e33 = {4807115922877858896, 200376420520689663, -2}; + + for (int i = 0; i < 4; i++) { + options.allow_int_overflow = bool(i / 2); + options.allow_decimal_truncate = bool(i % 2); + if (options.allow_int_overflow) { + CheckCase( + decimal(38, 10), v32, is_valid2, int64(), e32, options); + CheckCase( + decimal(38, 10), v33, is_valid3, int64(), e33, options); + } else { + CheckFails(decimal(38, 10), v32, is_valid2, int64(), options); + CheckFails(decimal(38, 10), v33, is_valid3, int64(), options); + } + } + + // overflow, truncation + std::vector v42 = {Decimal128("12345678901234567890000.0045345000"), + Decimal128("99999999999999999999999.0000005430")}; + std::vector v43 = {Decimal128("12345678901234567890000.0005345340"), + Decimal128("99999999999999999999999.0000344300"), + Decimal128("99999999999999999999999.0004354000")}; + // 12345678901234567890000 % 2**64, 99999999999999999999999 % 2**64 + std::vector e42 = {4807115922877858896, 200376420520689663}; + std::vector e43 = {4807115922877858896, 200376420520689663, -2}; + + for (int i = 0; i < 4; i++) { + options.allow_int_overflow = bool(i / 2); + options.allow_decimal_truncate = bool(i % 2); + if (options.allow_int_overflow && options.allow_decimal_truncate) { + CheckCase( + decimal(38, 10), v42, is_valid2, int64(), e42, options); + CheckCase( + decimal(38, 10), v43, is_valid3, int64(), e43, options); + } else { + CheckFails(decimal(38, 10), v42, is_valid2, int64(), options); + CheckFails(decimal(38, 10), v43, is_valid3, int64(), options); + } + } + + // negative scale + std::vector v5 = {Decimal128("1234567890000."), Decimal128("-120000.")}; + for (int i = 0; i < 2; i++) v5[i] = v5[i].Rescale(0, -4).ValueOrDie(); + std::vector e5 = {1234567890000, -120000}; + CheckCase( + decimal(38, -4), v5, is_valid2, int64(), e5, options); +} + +TEST_F(TestCast, DecimalToDecimal) { + CastOptions options; + + std::vector is_valid2 = {true, true}; + std::vector is_valid3 = {true, true, false}; + + // simple cases decimal + + std::vector v12 = {Decimal128("02.0000000000"), + Decimal128("30.0000000000")}; + std::vector e12 = {Decimal128("02."), Decimal128("30.")}; + std::vector v13 = {Decimal128("02.0000000000"), Decimal128("30.0000000000"), + Decimal128("30.0000000000")}; + std::vector e13 = {Decimal128("02."), Decimal128("30."), Decimal128("-1.")}; + + for (int i = 0; i < 2; i++) { + options.allow_decimal_truncate = bool(i % 2); + CheckCase( + decimal(38, 10), v12, is_valid2, decimal(28, 0), e12, options); + CheckCase( + decimal(38, 10), v13, is_valid3, decimal(28, 0), e13, options); + // and back + CheckCase( + decimal(28, 0), e12, is_valid2, decimal(38, 10), v12, options); + CheckCase( + decimal(28, 0), e13, is_valid3, decimal(38, 10), v13, options); + } + + std::vector v22 = {Decimal128("-02.1234567890"), + Decimal128("30.1234567890")}; + std::vector e22 = {Decimal128("-02."), Decimal128("30.")}; + std::vector f22 = {Decimal128("-02.0000000000"), + Decimal128("30.0000000000")}; + std::vector v23 = {Decimal128("-02.1234567890"), + Decimal128("30.1234567890"), + Decimal128("30.1234567890")}; + std::vector e23 = {Decimal128("-02."), Decimal128("30."), + Decimal128("-70.")}; + std::vector f23 = {Decimal128("-02.0000000000"), + Decimal128("30.0000000000"), + Decimal128("80.0000000000")}; + + for (int i = 0; i < 2; i++) { + options.allow_decimal_truncate = bool(i % 2); + if (options.allow_decimal_truncate) { + CheckCase( + decimal(38, 10), v22, is_valid2, decimal(28, 0), e22, options); + CheckCase( + decimal(38, 10), v23, is_valid3, decimal(28, 0), e23, options); + // and back + CheckCase( + decimal(28, 0), e22, is_valid2, decimal(38, 10), f22, options); + CheckCase( + decimal(28, 0), e23, is_valid3, decimal(38, 10), f23, options); + } else { + CheckFails(decimal(38, 10), v22, is_valid2, decimal(28, 0), + options); + CheckFails(decimal(38, 10), v23, is_valid3, decimal(28, 0), + options); + } + } +} + TEST_F(TestCast, TimestampToTimestamp) { CastOptions options; diff --git a/cpp/src/arrow/compute/kernels/generated/cast_codegen_internal.h b/cpp/src/arrow/compute/kernels/generated/cast_codegen_internal.h index 2f7cb291d10..85b0c4afcdd 100644 --- a/cpp/src/arrow/compute/kernels/generated/cast_codegen_internal.h +++ b/cpp/src/arrow/compute/kernels/generated/cast_codegen_internal.h @@ -171,6 +171,10 @@ TEMPLATE(DoubleType, StringType) \ TEMPLATE(DoubleType, LargeStringType) +#define DECIMAL128_CASES(TEMPLATE) \ + TEMPLATE(Decimal128Type, Decimal128Type) \ + TEMPLATE(Decimal128Type, Int64Type) + #define DATE32_CASES(TEMPLATE) \ TEMPLATE(Date32Type, Date64Type) diff --git a/cpp/src/arrow/compute/kernels/generated/codegen.py b/cpp/src/arrow/compute/kernels/generated/codegen.py index 76d27ad3344..e3bdd58a0b9 100644 --- a/cpp/src/arrow/compute/kernels/generated/codegen.py +++ b/cpp/src/arrow/compute/kernels/generated/codegen.py @@ -23,9 +23,8 @@ import io import os - -INTEGER_TYPES = ['UInt8', 'Int8', 'UInt16', 'Int16', - 'UInt32', 'Int32', 'UInt64', 'Int64'] +SIGNED_INTEGER_TYPES = ['Int8', 'Int16', 'Int32', 'Int64'] +INTEGER_TYPES = SIGNED_INTEGER_TYPES + ['UInt8', 'UInt16', 'UInt32', 'UInt64'] FLOATING_TYPES = ['Float', 'Double'] NUMERIC_TYPES = ['Boolean'] + INTEGER_TYPES + FLOATING_TYPES STRING_TYPES = ['String', 'LargeString'] @@ -78,6 +77,8 @@ def generate(self): CastCodeGenerator('Int64', NUMERIC_TYPES + STRING_TYPES), CastCodeGenerator('Float', NUMERIC_TYPES + STRING_TYPES), CastCodeGenerator('Double', NUMERIC_TYPES + STRING_TYPES), + CastCodeGenerator('Decimal128', ['Decimal128'] + SIGNED_INTEGER_TYPES, + parametric=True), CastCodeGenerator('Date32', ['Date64']), CastCodeGenerator('Date64', ['Date32']), CastCodeGenerator('Time32', ['Time32', 'Time64'], diff --git a/cpp/src/arrow/type.h b/cpp/src/arrow/type.h index 7f58369b38b..b700d4ebceb 100644 --- a/cpp/src/arrow/type.h +++ b/cpp/src/arrow/type.h @@ -991,6 +991,7 @@ class ARROW_EXPORT Decimal128Type : public DecimalType { static constexpr int32_t kMinPrecision = 1; static constexpr int32_t kMaxPrecision = 38; + using c_type = arrow::Decimal128; }; struct UnionMode { diff --git a/cpp/src/arrow/type_fwd.h b/cpp/src/arrow/type_fwd.h index 918b28960d5..7aed75a1a57 100644 --- a/cpp/src/arrow/type_fwd.h +++ b/cpp/src/arrow/type_fwd.h @@ -130,6 +130,7 @@ class StructArray; class StructBuilder; struct StructScalar; +class Decimal128; class Decimal128Type; class Decimal128Array; class Decimal128Builder; From 74f1c9446e881c0da005263580fc2a31f473f687 Mon Sep 17 00:00:00 2001 From: Antoine Pitrou Date: Tue, 17 Mar 2020 17:23:09 +0100 Subject: [PATCH 2/5] Some simplifications. Also, regenerate generated code. --- cpp/src/arrow/compute/kernels/cast.cc | 315 +++++------------- cpp/src/arrow/compute/kernels/cast_test.cc | 126 +++---- .../kernels/generated/cast_codegen_internal.h | 99 +++--- cpp/src/arrow/type.h | 1 - 4 files changed, 204 insertions(+), 337 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/cast.cc b/cpp/src/arrow/compute/kernels/cast.cc index 14d1ba0ddf7..829957e82d3 100644 --- a/cpp/src/arrow/compute/kernels/cast.cc +++ b/cpp/src/arrow/compute/kernels/cast.cc @@ -404,268 +404,133 @@ struct CastFunctor< } }; -// Decimal to Integer - -template -struct is_decimal_truncate { - static constexpr bool value = false; -}; +// ---------------------------------------------------------------------- +// Decimals -template -struct is_decimal_truncate< - O, I, enable_if_t<(is_integer_type::value && is_decimal_type::value)>> { - static constexpr bool value = true; -}; +// Decimal to Integer -template -struct CastFunctor::value>> { +template +struct CastFunctor::value>> { void operator()(FunctionContext* ctx, const CastOptions& options, const ArrayData& input, ArrayData* output) { - using in_type = typename I::c_type; using out_type = typename O::c_type; - const auto& in_type_inst = checked_cast(*input.type); + const auto& in_type_inst = checked_cast(*input.type); auto in_scale = in_type_inst.scale(); - const in_type* in_data = input.GetValues(1); auto out_data = output->GetMutableValues(1); - if (input.null_count == 0) { // no nulls - if (options.allow_decimal_truncate) { // no nulls, truncate - if (options.allow_int_overflow) { // no nulls, truncate, overflow - if (in_scale < 0) { - for (int64_t i = 0; i < input.length; ++i) { - *out_data++ = - static_cast(in_data++->IncreaseScaleBy(-in_scale).low_bits()); - } - } else { - for (int64_t i = 0; i < input.length; ++i) { - *out_data++ = static_cast( - in_data++->ReduceScaleBy(in_scale, false).low_bits()); - } - } - } else { // no nulls, truncate, no overflow - auto min_value = std::numeric_limits::min(); - auto max_value = std::numeric_limits::max(); - - if (in_scale < 0) { - for (int64_t i = 0; i < input.length; ++i) { - auto result = in_data++->IncreaseScaleBy(-in_scale); - if (result < min_value || result > max_value) { - ctx->SetStatus(Status::Invalid("Integer value out of bounds")); - } else { - *out_data++ = static_cast(result.low_bits()); - } - } - } else { - for (int64_t i = 0; i < input.length; ++i) { - auto result = in_data++->ReduceScaleBy(in_scale, false); - if (result < min_value || result > max_value) { - ctx->SetStatus(Status::Invalid("Integer value out of bounds")); - } else { - *out_data++ = static_cast(result.low_bits()); - } - } - } - } - } else { // no nulls, no truncate - if (options.allow_int_overflow) { // no nulls, no truncate, overflow - for (int64_t i = 0; i < input.length; ++i) { - auto result = in_data++->Rescale(in_scale, 0); - if (result.ok()) { - auto result_value = result.ValueOrDie(); - *out_data++ = static_cast(result_value.low_bits()); + constexpr auto min_value = std::numeric_limits::min(); + constexpr auto max_value = std::numeric_limits::max(); + + if (options.allow_decimal_truncate) { + if (in_scale < 0) { + // Unsafe upscale + auto convert_value = [&](util::optional v) { + if (v.has_value()) { + auto dec_value = Decimal128(reinterpret_cast(v->data())); + auto converted = dec_value.IncreaseScaleBy(-in_scale); + if (!options.allow_int_overflow && + ARROW_PREDICT_FALSE(converted < min_value || converted > max_value)) { + ctx->SetStatus(Status::Invalid("Integer value out of bounds")); } else { - ctx->SetStatus(result.status()); + *out_data = static_cast(converted.low_bits()); } } - } else { // no nulls, no truncate, no overflow - auto min_value = std::numeric_limits::min(); - auto max_value = std::numeric_limits::max(); - for (int64_t i = 0; i < input.length; ++i) { - auto result = in_data++->Rescale(in_scale, 0); - if (result.ok()) { - auto result_value = result.ValueOrDie(); - if (result_value < min_value || result_value > max_value) { - ctx->SetStatus(Status::Invalid("Invalid cast from Decimal128 to ", - sizeof(out_type), " byte integer")); - } else { - *out_data++ = static_cast(result_value.low_bits()); - } + ++out_data; + }; + VisitArrayDataInline(input, std::move(convert_value)); + } else { + // Unsafe downscale + auto convert_value = [&](util::optional v) { + if (v.has_value()) { + auto dec_value = Decimal128(reinterpret_cast(v->data())); + auto converted = dec_value.ReduceScaleBy(in_scale, false); + if (!options.allow_int_overflow && + ARROW_PREDICT_FALSE(converted < min_value || converted > max_value)) { + ctx->SetStatus(Status::Invalid("Integer value out of bounds")); } else { - ctx->SetStatus(result.status()); + *out_data = static_cast(converted.low_bits()); } } - } + ++out_data; + }; + VisitArrayDataInline(input, std::move(convert_value)); } - } else { // nulls - if (options.allow_decimal_truncate) { // nulls, truncate - if (options.allow_int_overflow) { // nulls, truncate, overflow - if (in_scale < 0) { - for (int64_t i = 0; i < input.length; ++i) { - *out_data++ = - static_cast(in_data++->IncreaseScaleBy(-in_scale).low_bits()); - } - } else { - for (int64_t i = 0; i < input.length; ++i) { - *out_data++ = static_cast( - in_data++->ReduceScaleBy(in_scale, false).low_bits()); - } - } - } else { // nulls, truncate, no overflow - auto min_value = std::numeric_limits::min(); - auto max_value = std::numeric_limits::max(); - internal::BitmapReader is_valid_reader(input.buffers[0]->data(), input.offset, - input.length); - - if (in_scale < 0) { - for (int64_t i = 0; i < input.length; ++i) { - if (ARROW_PREDICT_FALSE(is_valid_reader.IsSet())) { - if (*in_data > max_value || *in_data < min_value) { - ctx->SetStatus(Status::Invalid("Integer value out of bounds")); - } else { - *out_data = static_cast( - in_data->IncreaseScaleBy(-in_scale).low_bits()); - } - } - out_data++; - in_data++; - is_valid_reader.Next(); - } + } else { + // Safe rescale + auto convert_value = [&](util::optional v) { + if (v.has_value()) { + auto dec_value = Decimal128(reinterpret_cast(v->data())); + auto result = dec_value.Rescale(in_scale, 0); + if (ARROW_PREDICT_FALSE(!result.ok())) { + ctx->SetStatus(result.status()); } else { - for (int64_t i = 0; i < input.length; ++i) { - if (ARROW_PREDICT_FALSE(is_valid_reader.IsSet())) { - if (*in_data > max_value || *in_data < min_value) { - ctx->SetStatus(Status::Invalid("Integer value out of bounds")); - } else { - *out_data = static_cast( - in_data->ReduceScaleBy(in_scale, false).low_bits()); - } - } - out_data++; - in_data++; - is_valid_reader.Next(); - } - } - } - } else { // nulls, no truncate - internal::BitmapReader is_valid_reader(input.buffers[0]->data(), input.offset, - input.length); - - if (options.allow_int_overflow) { // nulls, no truncate, overflow - for (int64_t i = 0; i < input.length; ++i) { - if (ARROW_PREDICT_FALSE(is_valid_reader.IsSet())) { - auto result = in_data->Rescale(in_scale, 0); - if (result.ok()) { - auto result_value = result.ValueOrDie(); - *out_data = static_cast(result_value.low_bits()); - } else { - ctx->SetStatus(result.status()); - } - } - out_data++; - in_data++; - is_valid_reader.Next(); - } - } else { // nulls, no truncate, no overflow - auto min_value = std::numeric_limits::min(); - auto max_value = std::numeric_limits::max(); - - for (int64_t i = 0; i < input.length; ++i) { - if (ARROW_PREDICT_FALSE(is_valid_reader.IsSet())) { - if (*in_data > max_value || *in_data < min_value) { - ctx->SetStatus(Status::Invalid("Integer value out of bounds")); - } else { - auto result = in_data->Rescale(in_scale, 0); - if (result.ok()) { - auto result_value = result.ValueOrDie(); - *out_data = static_cast(result_value.low_bits()); - } else { - ctx->SetStatus(result.status()); - } - } + auto converted = *std::move(result); + if (!options.allow_int_overflow && + ARROW_PREDICT_FALSE(converted < min_value || converted > max_value)) { + ctx->SetStatus(Status::Invalid("Integer value out of bounds")); + } else { + *out_data = static_cast(converted.low_bits()); } - out_data++; - in_data++; - is_valid_reader.Next(); } } - } + ++out_data; + }; + VisitArrayDataInline(input, std::move(convert_value)); } } }; -// ---------------------------------------------------------------------- // Decimal to Decimal -template -struct CastFunctor< - O, I, enable_if_t<(is_decimal_type::value && is_decimal_type::value)>> { +template <> +struct CastFunctor { void operator()(FunctionContext* ctx, const CastOptions& options, const ArrayData& input, ArrayData* output) { - using in_type = typename I::c_type; - using out_type = typename O::c_type; - - const auto& in_type_inst = checked_cast(*input.type); - const auto& out_type_inst = checked_cast(*output->type); + const auto& in_type_inst = checked_cast(*input.type); + const auto& out_type_inst = checked_cast(*output->type); auto in_scale = in_type_inst.scale(); auto out_scale = out_type_inst.scale(); - const in_type* in_data = input.GetValues(1); - auto out_data = output->GetMutableValues(1); + auto out_data = output->GetMutableValues(1); - if (input.null_count == 0) { // no nulls - if (options.allow_decimal_truncate) { // no nulls, truncate - if (in_scale < out_scale) { - for (int64_t i = 0; i < input.length; ++i) { - *out_data++ = - static_cast(in_data++->IncreaseScaleBy(out_scale - in_scale)); + if (options.allow_decimal_truncate) { + if (in_scale < out_scale) { + // Unsafe upscale + auto convert_value = [&](util::optional v) { + if (v.has_value()) { + auto dec_value = Decimal128(reinterpret_cast(v->data())); + dec_value.IncreaseScaleBy(out_scale - in_scale).ToBytes(out_data); } - } else { - for (int64_t i = 0; i < input.length; ++i) { - *out_data++ = static_cast( - in_data++->ReduceScaleBy(in_scale - out_scale, false)); - } - } - } else { // no nulls, no truncate - for (int64_t i = 0; i < input.length; ++i) { - auto result = in_data++->Rescale(in_scale, out_scale); - if (result.ok()) { - *out_data = static_cast(result.ValueOrDie()); - } else { - ctx->SetStatus(result.status()); + out_data += 16; + }; + VisitArrayDataInline(input, std::move(convert_value)); + } else { + // Unsafe downscale + auto convert_value = [&](util::optional v) { + if (v.has_value()) { + auto dec_value = Decimal128(reinterpret_cast(v->data())); + dec_value.ReduceScaleBy(in_scale - out_scale, false).ToBytes(out_data); } - out_data++; - } + out_data += 16; + }; + VisitArrayDataInline(input, std::move(convert_value)); } - } else { // nulls - if (options.allow_decimal_truncate) { // nulls, truncate - if (in_scale < out_scale) { - for (int64_t i = 0; i < input.length; ++i) { - *out_data++ = - static_cast(in_data++->IncreaseScaleBy(out_scale - in_scale)); - } - } else { - for (int64_t i = 0; i < input.length; ++i) { - *out_data++ = static_cast( - in_data++->ReduceScaleBy(in_scale - out_scale, false)); - } - } - } else { // nulls, no truncate - internal::BitmapReader is_valid_reader(input.buffers[0]->data(), input.offset, - input.length); - for (int64_t i = 0; i < input.length; ++i) { - if (ARROW_PREDICT_FALSE(is_valid_reader.IsSet())) { - auto result = in_data++->Rescale(in_scale, out_scale); - if (result.ok()) { - *out_data = static_cast(result.ValueOrDie()); - } else { - ctx->SetStatus(result.status()); - } + } else { + // Safe rescale + auto convert_value = [&](util::optional v) { + if (v.has_value()) { + auto dec_value = Decimal128(reinterpret_cast(v->data())); + auto result = dec_value.Rescale(in_scale, out_scale); + if (ARROW_PREDICT_FALSE(!result.ok())) { + ctx->SetStatus(result.status()); + } else { + (*std::move(result)).ToBytes(out_data); } - out_data++; - is_valid_reader.Next(); } - } + out_data += 16; + }; + VisitArrayDataInline(input, std::move(convert_value)); } } }; diff --git a/cpp/src/arrow/compute/kernels/cast_test.cc b/cpp/src/arrow/compute/kernels/cast_test.cc index cd00f06aac0..82b4d95e6de 100644 --- a/cpp/src/arrow/compute/kernels/cast_test.cc +++ b/cpp/src/arrow/compute/kernels/cast_test.cc @@ -618,13 +618,15 @@ TEST_F(TestCast, DecimalToInt) { std::vector e12 = {2, -11}; std::vector e13 = {2, -11, 0}; - for (int i = 0; i < 4; i++) { - options.allow_int_overflow = bool(i / 2); - options.allow_decimal_truncate = bool(i % 2); - CheckCase( - decimal(38, 10), v12, is_valid2, int64(), e12, options); - CheckCase( - decimal(38, 10), v13, is_valid3, int64(), e13, options); + for (bool allow_int_overflow : {false, true}) { + for (bool allow_decimal_truncate : {false, true}) { + options.allow_int_overflow = allow_int_overflow; + options.allow_decimal_truncate = allow_decimal_truncate; + CheckCase( + decimal(38, 10), v12, is_valid2, int64(), e12, options); + CheckCase( + decimal(38, 10), v13, is_valid3, int64(), e13, options); + } } // truncation, no overflow @@ -636,18 +638,16 @@ TEST_F(TestCast, DecimalToInt) { std::vector e22 = {2, -11}; std::vector e23 = {2, -11, 0}; - for (int i = 0; i < 4; i++) { - options.allow_int_overflow = bool(i / 2); - options.allow_decimal_truncate = bool(i % 2); - if (options.allow_decimal_truncate) { - CheckCase( - decimal(38, 10), v22, is_valid2, int64(), e22, options); - CheckCase( - decimal(38, 10), v23, is_valid3, int64(), e23, options); - } else { - CheckFails(decimal(38, 10), v22, is_valid2, int64(), options); - CheckFails(decimal(38, 10), v23, is_valid3, int64(), options); - } + for (bool allow_int_overflow : {false, true}) { + options.allow_int_overflow = allow_int_overflow; + options.allow_decimal_truncate = true; + CheckCase( + decimal(38, 10), v22, is_valid2, int64(), e22, options); + CheckCase( + decimal(38, 10), v23, is_valid3, int64(), e23, options); + options.allow_decimal_truncate = false; + CheckFails(decimal(38, 10), v22, is_valid2, int64(), options); + CheckFails(decimal(38, 10), v23, is_valid3, int64(), options); } // overflow, no truncation @@ -660,18 +660,16 @@ TEST_F(TestCast, DecimalToInt) { std::vector e32 = {4807115922877858896, 200376420520689663}; std::vector e33 = {4807115922877858896, 200376420520689663, -2}; - for (int i = 0; i < 4; i++) { - options.allow_int_overflow = bool(i / 2); - options.allow_decimal_truncate = bool(i % 2); - if (options.allow_int_overflow) { - CheckCase( - decimal(38, 10), v32, is_valid2, int64(), e32, options); - CheckCase( - decimal(38, 10), v33, is_valid3, int64(), e33, options); - } else { - CheckFails(decimal(38, 10), v32, is_valid2, int64(), options); - CheckFails(decimal(38, 10), v33, is_valid3, int64(), options); - } + for (bool allow_decimal_truncate : {false, true}) { + options.allow_decimal_truncate = allow_decimal_truncate; + options.allow_int_overflow = true; + CheckCase( + decimal(38, 10), v32, is_valid2, int64(), e32, options); + CheckCase( + decimal(38, 10), v33, is_valid3, int64(), e33, options); + options.allow_int_overflow = false; + CheckFails(decimal(38, 10), v32, is_valid2, int64(), options); + CheckFails(decimal(38, 10), v33, is_valid3, int64(), options); } // overflow, truncation @@ -684,17 +682,19 @@ TEST_F(TestCast, DecimalToInt) { std::vector e42 = {4807115922877858896, 200376420520689663}; std::vector e43 = {4807115922877858896, 200376420520689663, -2}; - for (int i = 0; i < 4; i++) { - options.allow_int_overflow = bool(i / 2); - options.allow_decimal_truncate = bool(i % 2); - if (options.allow_int_overflow && options.allow_decimal_truncate) { - CheckCase( - decimal(38, 10), v42, is_valid2, int64(), e42, options); - CheckCase( - decimal(38, 10), v43, is_valid3, int64(), e43, options); - } else { - CheckFails(decimal(38, 10), v42, is_valid2, int64(), options); - CheckFails(decimal(38, 10), v43, is_valid3, int64(), options); + for (bool allow_int_overflow : {false, true}) { + for (bool allow_decimal_truncate : {false, true}) { + options.allow_int_overflow = allow_int_overflow; + options.allow_decimal_truncate = allow_decimal_truncate; + if (options.allow_int_overflow && options.allow_decimal_truncate) { + CheckCase( + decimal(38, 10), v42, is_valid2, int64(), e42, options); + CheckCase( + decimal(38, 10), v43, is_valid3, int64(), e43, options); + } else { + CheckFails(decimal(38, 10), v42, is_valid2, int64(), options); + CheckFails(decimal(38, 10), v43, is_valid3, int64(), options); + } } } @@ -721,8 +721,8 @@ TEST_F(TestCast, DecimalToDecimal) { Decimal128("30.0000000000")}; std::vector e13 = {Decimal128("02."), Decimal128("30."), Decimal128("-1.")}; - for (int i = 0; i < 2; i++) { - options.allow_decimal_truncate = bool(i % 2); + for (bool allow_decimal_truncate : {false, true}) { + options.allow_decimal_truncate = allow_decimal_truncate; CheckCase( decimal(38, 10), v12, is_valid2, decimal(28, 0), e12, options); CheckCase( @@ -748,25 +748,25 @@ TEST_F(TestCast, DecimalToDecimal) { Decimal128("30.0000000000"), Decimal128("80.0000000000")}; - for (int i = 0; i < 2; i++) { - options.allow_decimal_truncate = bool(i % 2); - if (options.allow_decimal_truncate) { - CheckCase( - decimal(38, 10), v22, is_valid2, decimal(28, 0), e22, options); - CheckCase( - decimal(38, 10), v23, is_valid3, decimal(28, 0), e23, options); - // and back - CheckCase( - decimal(28, 0), e22, is_valid2, decimal(38, 10), f22, options); - CheckCase( - decimal(28, 0), e23, is_valid3, decimal(38, 10), f23, options); - } else { - CheckFails(decimal(38, 10), v22, is_valid2, decimal(28, 0), - options); - CheckFails(decimal(38, 10), v23, is_valid3, decimal(28, 0), - options); - } - } + options.allow_decimal_truncate = true; + CheckCase( + decimal(38, 10), v22, is_valid2, decimal(28, 0), e22, options); + CheckCase( + decimal(38, 10), v23, is_valid3, decimal(28, 0), e23, options); + // and back + CheckCase( + decimal(28, 0), e22, is_valid2, decimal(38, 10), f22, options); + CheckCase( + decimal(28, 0), e23, is_valid3, decimal(38, 10), f23, options); + + options.allow_decimal_truncate = false; + CheckFails(decimal(38, 10), v22, is_valid2, decimal(28, 0), options); + CheckFails(decimal(38, 10), v23, is_valid3, decimal(28, 0), options); + // back case is ok + CheckCase( + decimal(28, 0), e22, is_valid2, decimal(38, 10), f22, options); + CheckCase( + decimal(28, 0), e23, is_valid3, decimal(38, 10), f23, options); } TEST_F(TestCast, TimestampToTimestamp) { diff --git a/cpp/src/arrow/compute/kernels/generated/cast_codegen_internal.h b/cpp/src/arrow/compute/kernels/generated/cast_codegen_internal.h index 85b0c4afcdd..bc4ce64d0fb 100644 --- a/cpp/src/arrow/compute/kernels/generated/cast_codegen_internal.h +++ b/cpp/src/arrow/compute/kernels/generated/cast_codegen_internal.h @@ -18,14 +18,14 @@ // THIS FILE IS AUTOMATICALLY GENERATED, DO NOT EDIT // Generated by codegen.py script #define BOOLEAN_CASES(TEMPLATE) \ - TEMPLATE(BooleanType, UInt8Type) \ TEMPLATE(BooleanType, Int8Type) \ - TEMPLATE(BooleanType, UInt16Type) \ TEMPLATE(BooleanType, Int16Type) \ - TEMPLATE(BooleanType, UInt32Type) \ TEMPLATE(BooleanType, Int32Type) \ - TEMPLATE(BooleanType, UInt64Type) \ TEMPLATE(BooleanType, Int64Type) \ + TEMPLATE(BooleanType, UInt8Type) \ + TEMPLATE(BooleanType, UInt16Type) \ + TEMPLATE(BooleanType, UInt32Type) \ + TEMPLATE(BooleanType, UInt64Type) \ TEMPLATE(BooleanType, FloatType) \ TEMPLATE(BooleanType, DoubleType) \ TEMPLATE(BooleanType, StringType) \ @@ -34,12 +34,12 @@ #define UINT8_CASES(TEMPLATE) \ TEMPLATE(UInt8Type, BooleanType) \ TEMPLATE(UInt8Type, Int8Type) \ - TEMPLATE(UInt8Type, UInt16Type) \ TEMPLATE(UInt8Type, Int16Type) \ - TEMPLATE(UInt8Type, UInt32Type) \ TEMPLATE(UInt8Type, Int32Type) \ - TEMPLATE(UInt8Type, UInt64Type) \ TEMPLATE(UInt8Type, Int64Type) \ + TEMPLATE(UInt8Type, UInt16Type) \ + TEMPLATE(UInt8Type, UInt32Type) \ + TEMPLATE(UInt8Type, UInt64Type) \ TEMPLATE(UInt8Type, FloatType) \ TEMPLATE(UInt8Type, DoubleType) \ TEMPLATE(UInt8Type, StringType) \ @@ -47,13 +47,13 @@ #define INT8_CASES(TEMPLATE) \ TEMPLATE(Int8Type, BooleanType) \ + TEMPLATE(Int8Type, Int16Type) \ + TEMPLATE(Int8Type, Int32Type) \ + TEMPLATE(Int8Type, Int64Type) \ TEMPLATE(Int8Type, UInt8Type) \ TEMPLATE(Int8Type, UInt16Type) \ - TEMPLATE(Int8Type, Int16Type) \ TEMPLATE(Int8Type, UInt32Type) \ - TEMPLATE(Int8Type, Int32Type) \ TEMPLATE(Int8Type, UInt64Type) \ - TEMPLATE(Int8Type, Int64Type) \ TEMPLATE(Int8Type, FloatType) \ TEMPLATE(Int8Type, DoubleType) \ TEMPLATE(Int8Type, StringType) \ @@ -61,13 +61,13 @@ #define UINT16_CASES(TEMPLATE) \ TEMPLATE(UInt16Type, BooleanType) \ - TEMPLATE(UInt16Type, UInt8Type) \ TEMPLATE(UInt16Type, Int8Type) \ TEMPLATE(UInt16Type, Int16Type) \ - TEMPLATE(UInt16Type, UInt32Type) \ TEMPLATE(UInt16Type, Int32Type) \ - TEMPLATE(UInt16Type, UInt64Type) \ TEMPLATE(UInt16Type, Int64Type) \ + TEMPLATE(UInt16Type, UInt8Type) \ + TEMPLATE(UInt16Type, UInt32Type) \ + TEMPLATE(UInt16Type, UInt64Type) \ TEMPLATE(UInt16Type, FloatType) \ TEMPLATE(UInt16Type, DoubleType) \ TEMPLATE(UInt16Type, StringType) \ @@ -75,13 +75,13 @@ #define INT16_CASES(TEMPLATE) \ TEMPLATE(Int16Type, BooleanType) \ - TEMPLATE(Int16Type, UInt8Type) \ TEMPLATE(Int16Type, Int8Type) \ + TEMPLATE(Int16Type, Int32Type) \ + TEMPLATE(Int16Type, Int64Type) \ + TEMPLATE(Int16Type, UInt8Type) \ TEMPLATE(Int16Type, UInt16Type) \ TEMPLATE(Int16Type, UInt32Type) \ - TEMPLATE(Int16Type, Int32Type) \ TEMPLATE(Int16Type, UInt64Type) \ - TEMPLATE(Int16Type, Int64Type) \ TEMPLATE(Int16Type, FloatType) \ TEMPLATE(Int16Type, DoubleType) \ TEMPLATE(Int16Type, StringType) \ @@ -89,13 +89,13 @@ #define UINT32_CASES(TEMPLATE) \ TEMPLATE(UInt32Type, BooleanType) \ - TEMPLATE(UInt32Type, UInt8Type) \ TEMPLATE(UInt32Type, Int8Type) \ - TEMPLATE(UInt32Type, UInt16Type) \ TEMPLATE(UInt32Type, Int16Type) \ TEMPLATE(UInt32Type, Int32Type) \ - TEMPLATE(UInt32Type, UInt64Type) \ TEMPLATE(UInt32Type, Int64Type) \ + TEMPLATE(UInt32Type, UInt8Type) \ + TEMPLATE(UInt32Type, UInt16Type) \ + TEMPLATE(UInt32Type, UInt64Type) \ TEMPLATE(UInt32Type, FloatType) \ TEMPLATE(UInt32Type, DoubleType) \ TEMPLATE(UInt32Type, StringType) \ @@ -103,13 +103,13 @@ #define UINT64_CASES(TEMPLATE) \ TEMPLATE(UInt64Type, BooleanType) \ - TEMPLATE(UInt64Type, UInt8Type) \ TEMPLATE(UInt64Type, Int8Type) \ - TEMPLATE(UInt64Type, UInt16Type) \ TEMPLATE(UInt64Type, Int16Type) \ - TEMPLATE(UInt64Type, UInt32Type) \ TEMPLATE(UInt64Type, Int32Type) \ TEMPLATE(UInt64Type, Int64Type) \ + TEMPLATE(UInt64Type, UInt8Type) \ + TEMPLATE(UInt64Type, UInt16Type) \ + TEMPLATE(UInt64Type, UInt32Type) \ TEMPLATE(UInt64Type, FloatType) \ TEMPLATE(UInt64Type, DoubleType) \ TEMPLATE(UInt64Type, StringType) \ @@ -117,13 +117,13 @@ #define INT32_CASES(TEMPLATE) \ TEMPLATE(Int32Type, BooleanType) \ - TEMPLATE(Int32Type, UInt8Type) \ TEMPLATE(Int32Type, Int8Type) \ - TEMPLATE(Int32Type, UInt16Type) \ TEMPLATE(Int32Type, Int16Type) \ + TEMPLATE(Int32Type, Int64Type) \ + TEMPLATE(Int32Type, UInt8Type) \ + TEMPLATE(Int32Type, UInt16Type) \ TEMPLATE(Int32Type, UInt32Type) \ TEMPLATE(Int32Type, UInt64Type) \ - TEMPLATE(Int32Type, Int64Type) \ TEMPLATE(Int32Type, FloatType) \ TEMPLATE(Int32Type, DoubleType) \ TEMPLATE(Int32Type, StringType) \ @@ -131,12 +131,12 @@ #define INT64_CASES(TEMPLATE) \ TEMPLATE(Int64Type, BooleanType) \ - TEMPLATE(Int64Type, UInt8Type) \ TEMPLATE(Int64Type, Int8Type) \ - TEMPLATE(Int64Type, UInt16Type) \ TEMPLATE(Int64Type, Int16Type) \ - TEMPLATE(Int64Type, UInt32Type) \ TEMPLATE(Int64Type, Int32Type) \ + TEMPLATE(Int64Type, UInt8Type) \ + TEMPLATE(Int64Type, UInt16Type) \ + TEMPLATE(Int64Type, UInt32Type) \ TEMPLATE(Int64Type, UInt64Type) \ TEMPLATE(Int64Type, FloatType) \ TEMPLATE(Int64Type, DoubleType) \ @@ -145,34 +145,37 @@ #define FLOAT_CASES(TEMPLATE) \ TEMPLATE(FloatType, BooleanType) \ - TEMPLATE(FloatType, UInt8Type) \ TEMPLATE(FloatType, Int8Type) \ - TEMPLATE(FloatType, UInt16Type) \ TEMPLATE(FloatType, Int16Type) \ - TEMPLATE(FloatType, UInt32Type) \ TEMPLATE(FloatType, Int32Type) \ - TEMPLATE(FloatType, UInt64Type) \ TEMPLATE(FloatType, Int64Type) \ + TEMPLATE(FloatType, UInt8Type) \ + TEMPLATE(FloatType, UInt16Type) \ + TEMPLATE(FloatType, UInt32Type) \ + TEMPLATE(FloatType, UInt64Type) \ TEMPLATE(FloatType, DoubleType) \ TEMPLATE(FloatType, StringType) \ TEMPLATE(FloatType, LargeStringType) #define DOUBLE_CASES(TEMPLATE) \ TEMPLATE(DoubleType, BooleanType) \ - TEMPLATE(DoubleType, UInt8Type) \ TEMPLATE(DoubleType, Int8Type) \ - TEMPLATE(DoubleType, UInt16Type) \ TEMPLATE(DoubleType, Int16Type) \ - TEMPLATE(DoubleType, UInt32Type) \ TEMPLATE(DoubleType, Int32Type) \ - TEMPLATE(DoubleType, UInt64Type) \ TEMPLATE(DoubleType, Int64Type) \ + TEMPLATE(DoubleType, UInt8Type) \ + TEMPLATE(DoubleType, UInt16Type) \ + TEMPLATE(DoubleType, UInt32Type) \ + TEMPLATE(DoubleType, UInt64Type) \ TEMPLATE(DoubleType, FloatType) \ TEMPLATE(DoubleType, StringType) \ TEMPLATE(DoubleType, LargeStringType) #define DECIMAL128_CASES(TEMPLATE) \ TEMPLATE(Decimal128Type, Decimal128Type) \ + TEMPLATE(Decimal128Type, Int8Type) \ + TEMPLATE(Decimal128Type, Int16Type) \ + TEMPLATE(Decimal128Type, Int32Type) \ TEMPLATE(Decimal128Type, Int64Type) #define DATE32_CASES(TEMPLATE) \ @@ -205,41 +208,41 @@ #define STRING_CASES(TEMPLATE) \ TEMPLATE(StringType, BooleanType) \ - TEMPLATE(StringType, UInt8Type) \ TEMPLATE(StringType, Int8Type) \ - TEMPLATE(StringType, UInt16Type) \ TEMPLATE(StringType, Int16Type) \ - TEMPLATE(StringType, UInt32Type) \ TEMPLATE(StringType, Int32Type) \ - TEMPLATE(StringType, UInt64Type) \ TEMPLATE(StringType, Int64Type) \ + TEMPLATE(StringType, UInt8Type) \ + TEMPLATE(StringType, UInt16Type) \ + TEMPLATE(StringType, UInt32Type) \ + TEMPLATE(StringType, UInt64Type) \ TEMPLATE(StringType, FloatType) \ TEMPLATE(StringType, DoubleType) \ TEMPLATE(StringType, TimestampType) #define LARGESTRING_CASES(TEMPLATE) \ TEMPLATE(LargeStringType, BooleanType) \ - TEMPLATE(LargeStringType, UInt8Type) \ TEMPLATE(LargeStringType, Int8Type) \ - TEMPLATE(LargeStringType, UInt16Type) \ TEMPLATE(LargeStringType, Int16Type) \ - TEMPLATE(LargeStringType, UInt32Type) \ TEMPLATE(LargeStringType, Int32Type) \ - TEMPLATE(LargeStringType, UInt64Type) \ TEMPLATE(LargeStringType, Int64Type) \ + TEMPLATE(LargeStringType, UInt8Type) \ + TEMPLATE(LargeStringType, UInt16Type) \ + TEMPLATE(LargeStringType, UInt32Type) \ + TEMPLATE(LargeStringType, UInt64Type) \ TEMPLATE(LargeStringType, FloatType) \ TEMPLATE(LargeStringType, DoubleType) \ TEMPLATE(LargeStringType, TimestampType) #define DICTIONARY_CASES(TEMPLATE) \ - TEMPLATE(DictionaryType, UInt8Type) \ TEMPLATE(DictionaryType, Int8Type) \ - TEMPLATE(DictionaryType, UInt16Type) \ TEMPLATE(DictionaryType, Int16Type) \ - TEMPLATE(DictionaryType, UInt32Type) \ TEMPLATE(DictionaryType, Int32Type) \ - TEMPLATE(DictionaryType, UInt64Type) \ TEMPLATE(DictionaryType, Int64Type) \ + TEMPLATE(DictionaryType, UInt8Type) \ + TEMPLATE(DictionaryType, UInt16Type) \ + TEMPLATE(DictionaryType, UInt32Type) \ + TEMPLATE(DictionaryType, UInt64Type) \ TEMPLATE(DictionaryType, FloatType) \ TEMPLATE(DictionaryType, DoubleType) \ TEMPLATE(DictionaryType, Date32Type) \ diff --git a/cpp/src/arrow/type.h b/cpp/src/arrow/type.h index b700d4ebceb..7f58369b38b 100644 --- a/cpp/src/arrow/type.h +++ b/cpp/src/arrow/type.h @@ -991,7 +991,6 @@ class ARROW_EXPORT Decimal128Type : public DecimalType { static constexpr int32_t kMinPrecision = 1; static constexpr int32_t kMaxPrecision = 38; - using c_type = arrow::Decimal128; }; struct UnionMode { From c0cb30661d8a93da83dfbe24bcf620f32f7d8d58 Mon Sep 17 00:00:00 2001 From: Antoine Pitrou Date: Tue, 17 Mar 2020 18:30:42 +0100 Subject: [PATCH 3/5] Fix ARROW_PREDICT_* macros (!) --- cpp/src/arrow/util/macros.h | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/cpp/src/arrow/util/macros.h b/cpp/src/arrow/util/macros.h index 7d04a80e802..624b266337e 100644 --- a/cpp/src/arrow/util/macros.h +++ b/cpp/src/arrow/util/macros.h @@ -51,13 +51,13 @@ #define ARROW_PREFETCH(addr) __builtin_prefetch(addr) #elif defined(_MSC_VER) #define ARROW_NORETURN __declspec(noreturn) -#define ARROW_PREDICT_FALSE(x) x -#define ARROW_PREDICT_TRUE(x) x +#define ARROW_PREDICT_FALSE(x) (x) +#define ARROW_PREDICT_TRUE(x) (x) #define ARROW_PREFETCH(addr) #else #define ARROW_NORETURN -#define ARROW_PREDICT_FALSE(x) x -#define ARROW_PREDICT_TRUE(x) x +#define ARROW_PREDICT_FALSE(x) (x) +#define ARROW_PREDICT_TRUE(x) (x) #define ARROW_PREFETCH(addr) #endif From 38b77ed210d658ed7696104313f1494eb3ba9f69 Mon Sep 17 00:00:00 2001 From: Antoine Pitrou Date: Tue, 17 Mar 2020 18:40:51 +0100 Subject: [PATCH 4/5] Make sure all data is initialized --- cpp/src/arrow/compute/kernels/cast.cc | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/cpp/src/arrow/compute/kernels/cast.cc b/cpp/src/arrow/compute/kernels/cast.cc index 829957e82d3..09280496901 100644 --- a/cpp/src/arrow/compute/kernels/cast.cc +++ b/cpp/src/arrow/compute/kernels/cast.cc @@ -421,11 +421,13 @@ struct CastFunctor::value>> { constexpr auto min_value = std::numeric_limits::min(); constexpr auto max_value = std::numeric_limits::max(); + constexpr auto zero = out_type{}; if (options.allow_decimal_truncate) { if (in_scale < 0) { // Unsafe upscale auto convert_value = [&](util::optional v) { + *out_data = zero; if (v.has_value()) { auto dec_value = Decimal128(reinterpret_cast(v->data())); auto converted = dec_value.IncreaseScaleBy(-in_scale); @@ -442,6 +444,7 @@ struct CastFunctor::value>> { } else { // Unsafe downscale auto convert_value = [&](util::optional v) { + *out_data = zero; if (v.has_value()) { auto dec_value = Decimal128(reinterpret_cast(v->data())); auto converted = dec_value.ReduceScaleBy(in_scale, false); @@ -459,6 +462,7 @@ struct CastFunctor::value>> { } else { // Safe rescale auto convert_value = [&](util::optional v) { + *out_data = zero; if (v.has_value()) { auto dec_value = Decimal128(reinterpret_cast(v->data())); auto result = dec_value.Rescale(in_scale, 0); @@ -494,6 +498,8 @@ struct CastFunctor { auto out_data = output->GetMutableValues(1); + const auto write_zero = [](uint8_t* out_data) { memset(out_data, 0, 16); }; + if (options.allow_decimal_truncate) { if (in_scale < out_scale) { // Unsafe upscale @@ -501,6 +507,8 @@ struct CastFunctor { if (v.has_value()) { auto dec_value = Decimal128(reinterpret_cast(v->data())); dec_value.IncreaseScaleBy(out_scale - in_scale).ToBytes(out_data); + } else { + write_zero(out_data); } out_data += 16; }; @@ -511,6 +519,8 @@ struct CastFunctor { if (v.has_value()) { auto dec_value = Decimal128(reinterpret_cast(v->data())); dec_value.ReduceScaleBy(in_scale - out_scale, false).ToBytes(out_data); + } else { + write_zero(out_data); } out_data += 16; }; @@ -524,9 +534,12 @@ struct CastFunctor { auto result = dec_value.Rescale(in_scale, out_scale); if (ARROW_PREDICT_FALSE(!result.ok())) { ctx->SetStatus(result.status()); + write_zero(out_data); } else { (*std::move(result)).ToBytes(out_data); } + } else { + write_zero(out_data); } out_data += 16; }; From 2c303abd5ee5c87fbf9764c4b58bd02519904026 Mon Sep 17 00:00:00 2001 From: Antoine Pitrou Date: Tue, 17 Mar 2020 19:17:05 +0100 Subject: [PATCH 5/5] Also improve gcc / clang version of ARROW_PREDICT_FALSE --- cpp/src/arrow/util/macros.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/src/arrow/util/macros.h b/cpp/src/arrow/util/macros.h index 624b266337e..ed2d6edd2a2 100644 --- a/cpp/src/arrow/util/macros.h +++ b/cpp/src/arrow/util/macros.h @@ -45,7 +45,7 @@ // the absence of better information (ie. -fprofile-arcs). // #if defined(__GNUC__) -#define ARROW_PREDICT_FALSE(x) (__builtin_expect(x, 0)) +#define ARROW_PREDICT_FALSE(x) (__builtin_expect(!!(x), 0)) #define ARROW_PREDICT_TRUE(x) (__builtin_expect(!!(x), 1)) #define ARROW_NORETURN __attribute__((noreturn)) #define ARROW_PREFETCH(addr) __builtin_prefetch(addr)