diff --git a/cpp/src/arrow/compute/kernels/cast.cc b/cpp/src/arrow/compute/kernels/cast.cc index c0b68539e80..09280496901 100644 --- a/cpp/src/arrow/compute/kernels/cast.cc +++ b/cpp/src/arrow/compute/kernels/cast.cc @@ -404,6 +404,150 @@ struct CastFunctor< } }; +// ---------------------------------------------------------------------- +// Decimals + +// Decimal to Integer + +template +struct CastFunctor::value>> { + void operator()(FunctionContext* ctx, const CastOptions& options, + const ArrayData& input, ArrayData* output) { + using out_type = typename O::c_type; + const auto& in_type_inst = checked_cast(*input.type); + auto in_scale = in_type_inst.scale(); + + auto out_data = output->GetMutableValues(1); + + constexpr auto min_value = std::numeric_limits::min(); + constexpr auto max_value = std::numeric_limits::max(); + constexpr auto zero = out_type{}; + + if (options.allow_decimal_truncate) { + if (in_scale < 0) { + // Unsafe upscale + auto convert_value = [&](util::optional v) { + *out_data = zero; + if (v.has_value()) { + auto dec_value = Decimal128(reinterpret_cast(v->data())); + auto converted = dec_value.IncreaseScaleBy(-in_scale); + if (!options.allow_int_overflow && + ARROW_PREDICT_FALSE(converted < min_value || converted > max_value)) { + ctx->SetStatus(Status::Invalid("Integer value out of bounds")); + } else { + *out_data = static_cast(converted.low_bits()); + } + } + ++out_data; + }; + VisitArrayDataInline(input, std::move(convert_value)); + } else { + // Unsafe downscale + auto convert_value = [&](util::optional v) { + *out_data = zero; + if (v.has_value()) { + auto dec_value = Decimal128(reinterpret_cast(v->data())); + auto converted = dec_value.ReduceScaleBy(in_scale, false); + if (!options.allow_int_overflow && + ARROW_PREDICT_FALSE(converted < min_value || converted > max_value)) { + ctx->SetStatus(Status::Invalid("Integer value out of bounds")); + } else { + *out_data = static_cast(converted.low_bits()); + } + } + ++out_data; + }; + VisitArrayDataInline(input, std::move(convert_value)); + } + } else { + // Safe rescale + auto convert_value = [&](util::optional v) { + *out_data = zero; + if (v.has_value()) { + auto dec_value = Decimal128(reinterpret_cast(v->data())); + auto result = dec_value.Rescale(in_scale, 0); + if (ARROW_PREDICT_FALSE(!result.ok())) { + ctx->SetStatus(result.status()); + } else { + auto converted = *std::move(result); + if (!options.allow_int_overflow && + ARROW_PREDICT_FALSE(converted < min_value || converted > max_value)) { + ctx->SetStatus(Status::Invalid("Integer value out of bounds")); + } else { + *out_data = static_cast(converted.low_bits()); + } + } + } + ++out_data; + }; + VisitArrayDataInline(input, std::move(convert_value)); + } + } +}; + +// Decimal to Decimal + +template <> +struct CastFunctor { + void operator()(FunctionContext* ctx, const CastOptions& options, + const ArrayData& input, ArrayData* output) { + const auto& in_type_inst = checked_cast(*input.type); + const auto& out_type_inst = checked_cast(*output->type); + auto in_scale = in_type_inst.scale(); + auto out_scale = out_type_inst.scale(); + + auto out_data = output->GetMutableValues(1); + + const auto write_zero = [](uint8_t* out_data) { memset(out_data, 0, 16); }; + + if (options.allow_decimal_truncate) { + if (in_scale < out_scale) { + // Unsafe upscale + auto convert_value = [&](util::optional v) { + if (v.has_value()) { + auto dec_value = Decimal128(reinterpret_cast(v->data())); + dec_value.IncreaseScaleBy(out_scale - in_scale).ToBytes(out_data); + } else { + write_zero(out_data); + } + out_data += 16; + }; + VisitArrayDataInline(input, std::move(convert_value)); + } else { + // Unsafe downscale + auto convert_value = [&](util::optional v) { + if (v.has_value()) { + auto dec_value = Decimal128(reinterpret_cast(v->data())); + dec_value.ReduceScaleBy(in_scale - out_scale, false).ToBytes(out_data); + } else { + write_zero(out_data); + } + out_data += 16; + }; + VisitArrayDataInline(input, std::move(convert_value)); + } + } else { + // Safe rescale + auto convert_value = [&](util::optional v) { + if (v.has_value()) { + auto dec_value = Decimal128(reinterpret_cast(v->data())); + auto result = dec_value.Rescale(in_scale, out_scale); + if (ARROW_PREDICT_FALSE(!result.ok())) { + ctx->SetStatus(result.status()); + write_zero(out_data); + } else { + (*std::move(result)).ToBytes(out_data); + } + } else { + write_zero(out_data); + } + out_data += 16; + }; + VisitArrayDataInline(input, std::move(convert_value)); + } + } +}; + // ---------------------------------------------------------------------- // From one timestamp to another @@ -1203,6 +1347,7 @@ GET_CAST_FUNCTION(UINT64_CASES, UInt64Type, CastKernel) GET_CAST_FUNCTION(INT64_CASES, Int64Type, CastKernel) GET_CAST_FUNCTION(FLOAT_CASES, FloatType, CastKernel) GET_CAST_FUNCTION(DOUBLE_CASES, DoubleType, CastKernel) +GET_CAST_FUNCTION(DECIMAL128_CASES, Decimal128Type, CastKernel) GET_CAST_FUNCTION(DATE32_CASES, Date32Type, CastKernel) GET_CAST_FUNCTION(DATE64_CASES, Date64Type, CastKernel) GET_CAST_FUNCTION(TIME32_CASES, Time32Type, CastKernel) @@ -1293,6 +1438,7 @@ Status GetCastFunction(const DataType& in_type, std::shared_ptr out_ty CAST_FUNCTION_CASE(Int64Type); CAST_FUNCTION_CASE(FloatType); CAST_FUNCTION_CASE(DoubleType); + CAST_FUNCTION_CASE(Decimal128Type); CAST_FUNCTION_CASE(Date32Type); CAST_FUNCTION_CASE(Date64Type); CAST_FUNCTION_CASE(Time32Type); diff --git a/cpp/src/arrow/compute/kernels/cast.h b/cpp/src/arrow/compute/kernels/cast.h index c5b873195f8..fec63a91026 100644 --- a/cpp/src/arrow/compute/kernels/cast.h +++ b/cpp/src/arrow/compute/kernels/cast.h @@ -38,6 +38,7 @@ struct ARROW_EXPORT CastOptions { : allow_int_overflow(false), allow_time_truncate(false), allow_time_overflow(false), + allow_decimal_truncate(false), allow_float_truncate(false), allow_invalid_utf8(false) {} @@ -45,6 +46,7 @@ struct ARROW_EXPORT CastOptions { : allow_int_overflow(!safe), allow_time_truncate(!safe), allow_time_overflow(!safe), + allow_decimal_truncate(!safe), allow_float_truncate(!safe), allow_invalid_utf8(!safe) {} @@ -55,6 +57,7 @@ struct ARROW_EXPORT CastOptions { bool allow_int_overflow; bool allow_time_truncate; bool allow_time_overflow; + bool allow_decimal_truncate; bool allow_float_truncate; // Indicate if conversions from Binary/FixedSizeBinary to string must // validate the utf8 payload. diff --git a/cpp/src/arrow/compute/kernels/cast_test.cc b/cpp/src/arrow/compute/kernels/cast_test.cc index 7198a10c4d1..82b4d95e6de 100644 --- a/cpp/src/arrow/compute/kernels/cast_test.cc +++ b/cpp/src/arrow/compute/kernels/cast_test.cc @@ -604,6 +604,171 @@ TEST_F(TestCast, IntToFloatingPoint) { } #endif +TEST_F(TestCast, DecimalToInt) { + CastOptions options; + std::vector is_valid2 = {true, true}; + std::vector is_valid3 = {true, true, false}; + + // no overflow no truncation + std::vector v12 = {Decimal128("02.0000000000"), + Decimal128("-11.0000000000")}; + std::vector v13 = {Decimal128("02.0000000000"), + Decimal128("-11.0000000000"), + Decimal128("-12.0000000000")}; + std::vector e12 = {2, -11}; + std::vector e13 = {2, -11, 0}; + + for (bool allow_int_overflow : {false, true}) { + for (bool allow_decimal_truncate : {false, true}) { + options.allow_int_overflow = allow_int_overflow; + options.allow_decimal_truncate = allow_decimal_truncate; + CheckCase( + decimal(38, 10), v12, is_valid2, int64(), e12, options); + CheckCase( + decimal(38, 10), v13, is_valid3, int64(), e13, options); + } + } + + // truncation, no overflow + std::vector v22 = {Decimal128("02.1000000000"), + Decimal128("-11.0000004500")}; + std::vector v23 = {Decimal128("02.1000000000"), + Decimal128("-11.0000004500"), + Decimal128("-12.0000004500")}; + std::vector e22 = {2, -11}; + std::vector e23 = {2, -11, 0}; + + for (bool allow_int_overflow : {false, true}) { + options.allow_int_overflow = allow_int_overflow; + options.allow_decimal_truncate = true; + CheckCase( + decimal(38, 10), v22, is_valid2, int64(), e22, options); + CheckCase( + decimal(38, 10), v23, is_valid3, int64(), e23, options); + options.allow_decimal_truncate = false; + CheckFails(decimal(38, 10), v22, is_valid2, int64(), options); + CheckFails(decimal(38, 10), v23, is_valid3, int64(), options); + } + + // overflow, no truncation + std::vector v32 = {Decimal128("12345678901234567890000.0000000000"), + Decimal128("99999999999999999999999.0000000000")}; + std::vector v33 = {Decimal128("12345678901234567890000.0000000000"), + Decimal128("99999999999999999999999.0000000000"), + Decimal128("99999999999999999999999.0000000000")}; + // 12345678901234567890000 % 2**64, 99999999999999999999999 % 2**64 + std::vector e32 = {4807115922877858896, 200376420520689663}; + std::vector e33 = {4807115922877858896, 200376420520689663, -2}; + + for (bool allow_decimal_truncate : {false, true}) { + options.allow_decimal_truncate = allow_decimal_truncate; + options.allow_int_overflow = true; + CheckCase( + decimal(38, 10), v32, is_valid2, int64(), e32, options); + CheckCase( + decimal(38, 10), v33, is_valid3, int64(), e33, options); + options.allow_int_overflow = false; + CheckFails(decimal(38, 10), v32, is_valid2, int64(), options); + CheckFails(decimal(38, 10), v33, is_valid3, int64(), options); + } + + // overflow, truncation + std::vector v42 = {Decimal128("12345678901234567890000.0045345000"), + Decimal128("99999999999999999999999.0000005430")}; + std::vector v43 = {Decimal128("12345678901234567890000.0005345340"), + Decimal128("99999999999999999999999.0000344300"), + Decimal128("99999999999999999999999.0004354000")}; + // 12345678901234567890000 % 2**64, 99999999999999999999999 % 2**64 + std::vector e42 = {4807115922877858896, 200376420520689663}; + std::vector e43 = {4807115922877858896, 200376420520689663, -2}; + + for (bool allow_int_overflow : {false, true}) { + for (bool allow_decimal_truncate : {false, true}) { + options.allow_int_overflow = allow_int_overflow; + options.allow_decimal_truncate = allow_decimal_truncate; + if (options.allow_int_overflow && options.allow_decimal_truncate) { + CheckCase( + decimal(38, 10), v42, is_valid2, int64(), e42, options); + CheckCase( + decimal(38, 10), v43, is_valid3, int64(), e43, options); + } else { + CheckFails(decimal(38, 10), v42, is_valid2, int64(), options); + CheckFails(decimal(38, 10), v43, is_valid3, int64(), options); + } + } + } + + // negative scale + std::vector v5 = {Decimal128("1234567890000."), Decimal128("-120000.")}; + for (int i = 0; i < 2; i++) v5[i] = v5[i].Rescale(0, -4).ValueOrDie(); + std::vector e5 = {1234567890000, -120000}; + CheckCase( + decimal(38, -4), v5, is_valid2, int64(), e5, options); +} + +TEST_F(TestCast, DecimalToDecimal) { + CastOptions options; + + std::vector is_valid2 = {true, true}; + std::vector is_valid3 = {true, true, false}; + + // simple cases decimal + + std::vector v12 = {Decimal128("02.0000000000"), + Decimal128("30.0000000000")}; + std::vector e12 = {Decimal128("02."), Decimal128("30.")}; + std::vector v13 = {Decimal128("02.0000000000"), Decimal128("30.0000000000"), + Decimal128("30.0000000000")}; + std::vector e13 = {Decimal128("02."), Decimal128("30."), Decimal128("-1.")}; + + for (bool allow_decimal_truncate : {false, true}) { + options.allow_decimal_truncate = allow_decimal_truncate; + CheckCase( + decimal(38, 10), v12, is_valid2, decimal(28, 0), e12, options); + CheckCase( + decimal(38, 10), v13, is_valid3, decimal(28, 0), e13, options); + // and back + CheckCase( + decimal(28, 0), e12, is_valid2, decimal(38, 10), v12, options); + CheckCase( + decimal(28, 0), e13, is_valid3, decimal(38, 10), v13, options); + } + + std::vector v22 = {Decimal128("-02.1234567890"), + Decimal128("30.1234567890")}; + std::vector e22 = {Decimal128("-02."), Decimal128("30.")}; + std::vector f22 = {Decimal128("-02.0000000000"), + Decimal128("30.0000000000")}; + std::vector v23 = {Decimal128("-02.1234567890"), + Decimal128("30.1234567890"), + Decimal128("30.1234567890")}; + std::vector e23 = {Decimal128("-02."), Decimal128("30."), + Decimal128("-70.")}; + std::vector f23 = {Decimal128("-02.0000000000"), + Decimal128("30.0000000000"), + Decimal128("80.0000000000")}; + + options.allow_decimal_truncate = true; + CheckCase( + decimal(38, 10), v22, is_valid2, decimal(28, 0), e22, options); + CheckCase( + decimal(38, 10), v23, is_valid3, decimal(28, 0), e23, options); + // and back + CheckCase( + decimal(28, 0), e22, is_valid2, decimal(38, 10), f22, options); + CheckCase( + decimal(28, 0), e23, is_valid3, decimal(38, 10), f23, options); + + options.allow_decimal_truncate = false; + CheckFails(decimal(38, 10), v22, is_valid2, decimal(28, 0), options); + CheckFails(decimal(38, 10), v23, is_valid3, decimal(28, 0), options); + // back case is ok + CheckCase( + decimal(28, 0), e22, is_valid2, decimal(38, 10), f22, options); + CheckCase( + decimal(28, 0), e23, is_valid3, decimal(38, 10), f23, options); +} + TEST_F(TestCast, TimestampToTimestamp) { CastOptions options; diff --git a/cpp/src/arrow/compute/kernels/generated/cast_codegen_internal.h b/cpp/src/arrow/compute/kernels/generated/cast_codegen_internal.h index 2f7cb291d10..bc4ce64d0fb 100644 --- a/cpp/src/arrow/compute/kernels/generated/cast_codegen_internal.h +++ b/cpp/src/arrow/compute/kernels/generated/cast_codegen_internal.h @@ -18,14 +18,14 @@ // THIS FILE IS AUTOMATICALLY GENERATED, DO NOT EDIT // Generated by codegen.py script #define BOOLEAN_CASES(TEMPLATE) \ - TEMPLATE(BooleanType, UInt8Type) \ TEMPLATE(BooleanType, Int8Type) \ - TEMPLATE(BooleanType, UInt16Type) \ TEMPLATE(BooleanType, Int16Type) \ - TEMPLATE(BooleanType, UInt32Type) \ TEMPLATE(BooleanType, Int32Type) \ - TEMPLATE(BooleanType, UInt64Type) \ TEMPLATE(BooleanType, Int64Type) \ + TEMPLATE(BooleanType, UInt8Type) \ + TEMPLATE(BooleanType, UInt16Type) \ + TEMPLATE(BooleanType, UInt32Type) \ + TEMPLATE(BooleanType, UInt64Type) \ TEMPLATE(BooleanType, FloatType) \ TEMPLATE(BooleanType, DoubleType) \ TEMPLATE(BooleanType, StringType) \ @@ -34,12 +34,12 @@ #define UINT8_CASES(TEMPLATE) \ TEMPLATE(UInt8Type, BooleanType) \ TEMPLATE(UInt8Type, Int8Type) \ - TEMPLATE(UInt8Type, UInt16Type) \ TEMPLATE(UInt8Type, Int16Type) \ - TEMPLATE(UInt8Type, UInt32Type) \ TEMPLATE(UInt8Type, Int32Type) \ - TEMPLATE(UInt8Type, UInt64Type) \ TEMPLATE(UInt8Type, Int64Type) \ + TEMPLATE(UInt8Type, UInt16Type) \ + TEMPLATE(UInt8Type, UInt32Type) \ + TEMPLATE(UInt8Type, UInt64Type) \ TEMPLATE(UInt8Type, FloatType) \ TEMPLATE(UInt8Type, DoubleType) \ TEMPLATE(UInt8Type, StringType) \ @@ -47,13 +47,13 @@ #define INT8_CASES(TEMPLATE) \ TEMPLATE(Int8Type, BooleanType) \ + TEMPLATE(Int8Type, Int16Type) \ + TEMPLATE(Int8Type, Int32Type) \ + TEMPLATE(Int8Type, Int64Type) \ TEMPLATE(Int8Type, UInt8Type) \ TEMPLATE(Int8Type, UInt16Type) \ - TEMPLATE(Int8Type, Int16Type) \ TEMPLATE(Int8Type, UInt32Type) \ - TEMPLATE(Int8Type, Int32Type) \ TEMPLATE(Int8Type, UInt64Type) \ - TEMPLATE(Int8Type, Int64Type) \ TEMPLATE(Int8Type, FloatType) \ TEMPLATE(Int8Type, DoubleType) \ TEMPLATE(Int8Type, StringType) \ @@ -61,13 +61,13 @@ #define UINT16_CASES(TEMPLATE) \ TEMPLATE(UInt16Type, BooleanType) \ - TEMPLATE(UInt16Type, UInt8Type) \ TEMPLATE(UInt16Type, Int8Type) \ TEMPLATE(UInt16Type, Int16Type) \ - TEMPLATE(UInt16Type, UInt32Type) \ TEMPLATE(UInt16Type, Int32Type) \ - TEMPLATE(UInt16Type, UInt64Type) \ TEMPLATE(UInt16Type, Int64Type) \ + TEMPLATE(UInt16Type, UInt8Type) \ + TEMPLATE(UInt16Type, UInt32Type) \ + TEMPLATE(UInt16Type, UInt64Type) \ TEMPLATE(UInt16Type, FloatType) \ TEMPLATE(UInt16Type, DoubleType) \ TEMPLATE(UInt16Type, StringType) \ @@ -75,13 +75,13 @@ #define INT16_CASES(TEMPLATE) \ TEMPLATE(Int16Type, BooleanType) \ - TEMPLATE(Int16Type, UInt8Type) \ TEMPLATE(Int16Type, Int8Type) \ + TEMPLATE(Int16Type, Int32Type) \ + TEMPLATE(Int16Type, Int64Type) \ + TEMPLATE(Int16Type, UInt8Type) \ TEMPLATE(Int16Type, UInt16Type) \ TEMPLATE(Int16Type, UInt32Type) \ - TEMPLATE(Int16Type, Int32Type) \ TEMPLATE(Int16Type, UInt64Type) \ - TEMPLATE(Int16Type, Int64Type) \ TEMPLATE(Int16Type, FloatType) \ TEMPLATE(Int16Type, DoubleType) \ TEMPLATE(Int16Type, StringType) \ @@ -89,13 +89,13 @@ #define UINT32_CASES(TEMPLATE) \ TEMPLATE(UInt32Type, BooleanType) \ - TEMPLATE(UInt32Type, UInt8Type) \ TEMPLATE(UInt32Type, Int8Type) \ - TEMPLATE(UInt32Type, UInt16Type) \ TEMPLATE(UInt32Type, Int16Type) \ TEMPLATE(UInt32Type, Int32Type) \ - TEMPLATE(UInt32Type, UInt64Type) \ TEMPLATE(UInt32Type, Int64Type) \ + TEMPLATE(UInt32Type, UInt8Type) \ + TEMPLATE(UInt32Type, UInt16Type) \ + TEMPLATE(UInt32Type, UInt64Type) \ TEMPLATE(UInt32Type, FloatType) \ TEMPLATE(UInt32Type, DoubleType) \ TEMPLATE(UInt32Type, StringType) \ @@ -103,13 +103,13 @@ #define UINT64_CASES(TEMPLATE) \ TEMPLATE(UInt64Type, BooleanType) \ - TEMPLATE(UInt64Type, UInt8Type) \ TEMPLATE(UInt64Type, Int8Type) \ - TEMPLATE(UInt64Type, UInt16Type) \ TEMPLATE(UInt64Type, Int16Type) \ - TEMPLATE(UInt64Type, UInt32Type) \ TEMPLATE(UInt64Type, Int32Type) \ TEMPLATE(UInt64Type, Int64Type) \ + TEMPLATE(UInt64Type, UInt8Type) \ + TEMPLATE(UInt64Type, UInt16Type) \ + TEMPLATE(UInt64Type, UInt32Type) \ TEMPLATE(UInt64Type, FloatType) \ TEMPLATE(UInt64Type, DoubleType) \ TEMPLATE(UInt64Type, StringType) \ @@ -117,13 +117,13 @@ #define INT32_CASES(TEMPLATE) \ TEMPLATE(Int32Type, BooleanType) \ - TEMPLATE(Int32Type, UInt8Type) \ TEMPLATE(Int32Type, Int8Type) \ - TEMPLATE(Int32Type, UInt16Type) \ TEMPLATE(Int32Type, Int16Type) \ + TEMPLATE(Int32Type, Int64Type) \ + TEMPLATE(Int32Type, UInt8Type) \ + TEMPLATE(Int32Type, UInt16Type) \ TEMPLATE(Int32Type, UInt32Type) \ TEMPLATE(Int32Type, UInt64Type) \ - TEMPLATE(Int32Type, Int64Type) \ TEMPLATE(Int32Type, FloatType) \ TEMPLATE(Int32Type, DoubleType) \ TEMPLATE(Int32Type, StringType) \ @@ -131,12 +131,12 @@ #define INT64_CASES(TEMPLATE) \ TEMPLATE(Int64Type, BooleanType) \ - TEMPLATE(Int64Type, UInt8Type) \ TEMPLATE(Int64Type, Int8Type) \ - TEMPLATE(Int64Type, UInt16Type) \ TEMPLATE(Int64Type, Int16Type) \ - TEMPLATE(Int64Type, UInt32Type) \ TEMPLATE(Int64Type, Int32Type) \ + TEMPLATE(Int64Type, UInt8Type) \ + TEMPLATE(Int64Type, UInt16Type) \ + TEMPLATE(Int64Type, UInt32Type) \ TEMPLATE(Int64Type, UInt64Type) \ TEMPLATE(Int64Type, FloatType) \ TEMPLATE(Int64Type, DoubleType) \ @@ -145,32 +145,39 @@ #define FLOAT_CASES(TEMPLATE) \ TEMPLATE(FloatType, BooleanType) \ - TEMPLATE(FloatType, UInt8Type) \ TEMPLATE(FloatType, Int8Type) \ - TEMPLATE(FloatType, UInt16Type) \ TEMPLATE(FloatType, Int16Type) \ - TEMPLATE(FloatType, UInt32Type) \ TEMPLATE(FloatType, Int32Type) \ - TEMPLATE(FloatType, UInt64Type) \ TEMPLATE(FloatType, Int64Type) \ + TEMPLATE(FloatType, UInt8Type) \ + TEMPLATE(FloatType, UInt16Type) \ + TEMPLATE(FloatType, UInt32Type) \ + TEMPLATE(FloatType, UInt64Type) \ TEMPLATE(FloatType, DoubleType) \ TEMPLATE(FloatType, StringType) \ TEMPLATE(FloatType, LargeStringType) #define DOUBLE_CASES(TEMPLATE) \ TEMPLATE(DoubleType, BooleanType) \ - TEMPLATE(DoubleType, UInt8Type) \ TEMPLATE(DoubleType, Int8Type) \ - TEMPLATE(DoubleType, UInt16Type) \ TEMPLATE(DoubleType, Int16Type) \ - TEMPLATE(DoubleType, UInt32Type) \ TEMPLATE(DoubleType, Int32Type) \ - TEMPLATE(DoubleType, UInt64Type) \ TEMPLATE(DoubleType, Int64Type) \ + TEMPLATE(DoubleType, UInt8Type) \ + TEMPLATE(DoubleType, UInt16Type) \ + TEMPLATE(DoubleType, UInt32Type) \ + TEMPLATE(DoubleType, UInt64Type) \ TEMPLATE(DoubleType, FloatType) \ TEMPLATE(DoubleType, StringType) \ TEMPLATE(DoubleType, LargeStringType) +#define DECIMAL128_CASES(TEMPLATE) \ + TEMPLATE(Decimal128Type, Decimal128Type) \ + TEMPLATE(Decimal128Type, Int8Type) \ + TEMPLATE(Decimal128Type, Int16Type) \ + TEMPLATE(Decimal128Type, Int32Type) \ + TEMPLATE(Decimal128Type, Int64Type) + #define DATE32_CASES(TEMPLATE) \ TEMPLATE(Date32Type, Date64Type) @@ -201,41 +208,41 @@ #define STRING_CASES(TEMPLATE) \ TEMPLATE(StringType, BooleanType) \ - TEMPLATE(StringType, UInt8Type) \ TEMPLATE(StringType, Int8Type) \ - TEMPLATE(StringType, UInt16Type) \ TEMPLATE(StringType, Int16Type) \ - TEMPLATE(StringType, UInt32Type) \ TEMPLATE(StringType, Int32Type) \ - TEMPLATE(StringType, UInt64Type) \ TEMPLATE(StringType, Int64Type) \ + TEMPLATE(StringType, UInt8Type) \ + TEMPLATE(StringType, UInt16Type) \ + TEMPLATE(StringType, UInt32Type) \ + TEMPLATE(StringType, UInt64Type) \ TEMPLATE(StringType, FloatType) \ TEMPLATE(StringType, DoubleType) \ TEMPLATE(StringType, TimestampType) #define LARGESTRING_CASES(TEMPLATE) \ TEMPLATE(LargeStringType, BooleanType) \ - TEMPLATE(LargeStringType, UInt8Type) \ TEMPLATE(LargeStringType, Int8Type) \ - TEMPLATE(LargeStringType, UInt16Type) \ TEMPLATE(LargeStringType, Int16Type) \ - TEMPLATE(LargeStringType, UInt32Type) \ TEMPLATE(LargeStringType, Int32Type) \ - TEMPLATE(LargeStringType, UInt64Type) \ TEMPLATE(LargeStringType, Int64Type) \ + TEMPLATE(LargeStringType, UInt8Type) \ + TEMPLATE(LargeStringType, UInt16Type) \ + TEMPLATE(LargeStringType, UInt32Type) \ + TEMPLATE(LargeStringType, UInt64Type) \ TEMPLATE(LargeStringType, FloatType) \ TEMPLATE(LargeStringType, DoubleType) \ TEMPLATE(LargeStringType, TimestampType) #define DICTIONARY_CASES(TEMPLATE) \ - TEMPLATE(DictionaryType, UInt8Type) \ TEMPLATE(DictionaryType, Int8Type) \ - TEMPLATE(DictionaryType, UInt16Type) \ TEMPLATE(DictionaryType, Int16Type) \ - TEMPLATE(DictionaryType, UInt32Type) \ TEMPLATE(DictionaryType, Int32Type) \ - TEMPLATE(DictionaryType, UInt64Type) \ TEMPLATE(DictionaryType, Int64Type) \ + TEMPLATE(DictionaryType, UInt8Type) \ + TEMPLATE(DictionaryType, UInt16Type) \ + TEMPLATE(DictionaryType, UInt32Type) \ + TEMPLATE(DictionaryType, UInt64Type) \ TEMPLATE(DictionaryType, FloatType) \ TEMPLATE(DictionaryType, DoubleType) \ TEMPLATE(DictionaryType, Date32Type) \ diff --git a/cpp/src/arrow/compute/kernels/generated/codegen.py b/cpp/src/arrow/compute/kernels/generated/codegen.py index 76d27ad3344..e3bdd58a0b9 100644 --- a/cpp/src/arrow/compute/kernels/generated/codegen.py +++ b/cpp/src/arrow/compute/kernels/generated/codegen.py @@ -23,9 +23,8 @@ import io import os - -INTEGER_TYPES = ['UInt8', 'Int8', 'UInt16', 'Int16', - 'UInt32', 'Int32', 'UInt64', 'Int64'] +SIGNED_INTEGER_TYPES = ['Int8', 'Int16', 'Int32', 'Int64'] +INTEGER_TYPES = SIGNED_INTEGER_TYPES + ['UInt8', 'UInt16', 'UInt32', 'UInt64'] FLOATING_TYPES = ['Float', 'Double'] NUMERIC_TYPES = ['Boolean'] + INTEGER_TYPES + FLOATING_TYPES STRING_TYPES = ['String', 'LargeString'] @@ -78,6 +77,8 @@ def generate(self): CastCodeGenerator('Int64', NUMERIC_TYPES + STRING_TYPES), CastCodeGenerator('Float', NUMERIC_TYPES + STRING_TYPES), CastCodeGenerator('Double', NUMERIC_TYPES + STRING_TYPES), + CastCodeGenerator('Decimal128', ['Decimal128'] + SIGNED_INTEGER_TYPES, + parametric=True), CastCodeGenerator('Date32', ['Date64']), CastCodeGenerator('Date64', ['Date32']), CastCodeGenerator('Time32', ['Time32', 'Time64'], diff --git a/cpp/src/arrow/type_fwd.h b/cpp/src/arrow/type_fwd.h index 918b28960d5..7aed75a1a57 100644 --- a/cpp/src/arrow/type_fwd.h +++ b/cpp/src/arrow/type_fwd.h @@ -130,6 +130,7 @@ class StructArray; class StructBuilder; struct StructScalar; +class Decimal128; class Decimal128Type; class Decimal128Array; class Decimal128Builder; diff --git a/cpp/src/arrow/util/macros.h b/cpp/src/arrow/util/macros.h index 7d04a80e802..ed2d6edd2a2 100644 --- a/cpp/src/arrow/util/macros.h +++ b/cpp/src/arrow/util/macros.h @@ -45,19 +45,19 @@ // the absence of better information (ie. -fprofile-arcs). // #if defined(__GNUC__) -#define ARROW_PREDICT_FALSE(x) (__builtin_expect(x, 0)) +#define ARROW_PREDICT_FALSE(x) (__builtin_expect(!!(x), 0)) #define ARROW_PREDICT_TRUE(x) (__builtin_expect(!!(x), 1)) #define ARROW_NORETURN __attribute__((noreturn)) #define ARROW_PREFETCH(addr) __builtin_prefetch(addr) #elif defined(_MSC_VER) #define ARROW_NORETURN __declspec(noreturn) -#define ARROW_PREDICT_FALSE(x) x -#define ARROW_PREDICT_TRUE(x) x +#define ARROW_PREDICT_FALSE(x) (x) +#define ARROW_PREDICT_TRUE(x) (x) #define ARROW_PREFETCH(addr) #else #define ARROW_NORETURN -#define ARROW_PREDICT_FALSE(x) x -#define ARROW_PREDICT_TRUE(x) x +#define ARROW_PREDICT_FALSE(x) (x) +#define ARROW_PREDICT_TRUE(x) (x) #define ARROW_PREFETCH(addr) #endif