From 458b9bd3999730950792db79b6f40ab52fd8bd5f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Pedro?= Date: Tue, 11 May 2021 11:25:43 -0300 Subject: [PATCH 1/2] Add math functions --- cpp/src/gandiva/function_registry_common.h | 11 + cpp/src/gandiva/function_registry_math_ops.cc | 46 +++- cpp/src/gandiva/gdv_function_stubs.cc | 102 +++++++++ cpp/src/gandiva/gdv_function_stubs.h | 18 ++ cpp/src/gandiva/gdv_function_stubs_test.cc | 62 ++++++ cpp/src/gandiva/precompiled/decimal_ops.cc | 27 ++- cpp/src/gandiva/precompiled/decimal_ops.h | 12 +- .../gandiva/precompiled/extended_math_ops.cc | 177 ++++++++++++++-- .../precompiled/extended_math_ops_test.cc | 200 +++++++++++++++++- cpp/src/gandiva/precompiled/types.h | 38 +++- cpp/src/gandiva/tests/projector_test.cc | 52 ++++- 11 files changed, 710 insertions(+), 35 deletions(-) diff --git a/cpp/src/gandiva/function_registry_common.h b/cpp/src/gandiva/function_registry_common.h index 5ce21125abe..d2f08322caf 100644 --- a/cpp/src/gandiva/function_registry_common.h +++ b/cpp/src/gandiva/function_registry_common.h @@ -127,6 +127,17 @@ typedef std::unordered_map ALIASES, DataTypeVector{IN_TYPE()}, \ OUT_TYPE(), kResultNullIfNull, ARROW_STRINGIFY(NAME##_##IN_TYPE)) +// Unary functions in gdv stubs file that : +// - NULL handling is of type NULL_IF_NULL +// +// The pre-compiled fn name includes the base name & input type name. +// eg. 
gdv_fn_abs_float32 +#define STUBS_UNARY_SAFE_NULL_IF_NULL(NAME, ALIASES, IN_TYPE, OUT_TYPE) \ + NativeFunction(#NAME, std::vector ALIASES, DataTypeVector{IN_TYPE()}, \ + OUT_TYPE(), kResultNullIfNull, \ + ARROW_STRINGIFY(gdv_fn_##NAME##_##IN_TYPE), \ + NativeFunction::kNeedsContext) + // Unary functions that : // - NULL handling is of type NULL_NEVER // diff --git a/cpp/src/gandiva/function_registry_math_ops.cc b/cpp/src/gandiva/function_registry_math_ops.cc index 87c7aed2ad9..b7e957e139e 100644 --- a/cpp/src/gandiva/function_registry_math_ops.cc +++ b/cpp/src/gandiva/function_registry_math_ops.cc @@ -28,6 +28,22 @@ namespace gandiva { UNARY_SAFE_NULL_IF_NULL(name, ALIASES, float32, float64), \ UNARY_SAFE_NULL_IF_NULL(name, ALIASES, float64, float64) +#define MATH_UNARY_OPS_SAME_TYPE_RETURN(name, ALIASES) \ + UNARY_SAFE_NULL_IF_NULL(name, ALIASES, int32, int32), \ + UNARY_SAFE_NULL_IF_NULL(name, ALIASES, int64, int64), \ + UNARY_SAFE_NULL_IF_NULL(name, ALIASES, uint32, uint32), \ + UNARY_SAFE_NULL_IF_NULL(name, ALIASES, uint64, uint64), \ + UNARY_SAFE_NULL_IF_NULL(name, ALIASES, float32, float32), \ + UNARY_SAFE_NULL_IF_NULL(name, ALIASES, float64, float64) + +#define STUBS_MATH_UNARY_OPS_SAME_TYPE_RETURN(name, ALIASES) \ + STUBS_UNARY_SAFE_NULL_IF_NULL(name, ALIASES, int32, int32), \ + STUBS_UNARY_SAFE_NULL_IF_NULL(name, ALIASES, int64, int64), \ + STUBS_UNARY_SAFE_NULL_IF_NULL(name, ALIASES, uint32, uint32), \ + STUBS_UNARY_SAFE_NULL_IF_NULL(name, ALIASES, uint64, uint64), \ + STUBS_UNARY_SAFE_NULL_IF_NULL(name, ALIASES, float32, float32), \ + STUBS_UNARY_SAFE_NULL_IF_NULL(name, ALIASES, float64, float64) + #define MATH_BINARY_UNSAFE(name, ALIASES) \ BINARY_UNSAFE_NULL_IF_NULL(name, ALIASES, int32, float64), \ BINARY_UNSAFE_NULL_IF_NULL(name, ALIASES, int64, float64), \ @@ -61,8 +77,8 @@ namespace gandiva { std::vector GetMathOpsFunctionRegistry() { static std::vector math_fn_registry_ = { - MATH_UNARY_OPS(cbrt, {}), MATH_UNARY_OPS(exp, {}), MATH_UNARY_OPS(log, 
{}), - MATH_UNARY_OPS(log10, {}), + MATH_UNARY_OPS(cbrt, {}), MATH_UNARY_OPS(exp, {}), MATH_UNARY_OPS(log, {"ln"}), + MATH_UNARY_OPS(log2, {}), MATH_UNARY_OPS(log10, {}), MATH_BINARY_UNSAFE(log, {}), @@ -86,6 +102,28 @@ std::vector GetMathOpsFunctionRegistry() { MATH_UNARY_OPS(cot, {}), MATH_UNARY_OPS(radians, {}), MATH_UNARY_OPS(degrees, {}), MATH_BINARY_SAFE(atan2, {}), + // extended functions + MATH_UNARY_OPS(sqrt, {}), STUBS_MATH_UNARY_OPS_SAME_TYPE_RETURN(abs, {}), + MATH_UNARY_OPS_SAME_TYPE_RETURN(sign, {}), + UNARY_SAFE_NULL_IF_NULL(ceil, {}, float32, float32), + UNARY_SAFE_NULL_IF_NULL(ceil, {}, float64, float64), + UNARY_SAFE_NULL_IF_NULL(floor, {}, float32, float32), + UNARY_SAFE_NULL_IF_NULL(floor, {}, float64, float64), + UNARY_SAFE_NULL_IF_NULL(lshift, {"shiftleft"}, int32, int32), + UNARY_SAFE_NULL_IF_NULL(lshift, {"shiftleft"}, int64, int64), + UNARY_SAFE_NULL_IF_NULL(rshift, {"shiftright"}, int32, int32), + UNARY_SAFE_NULL_IF_NULL(rshift, {"shiftright"}, int64, int64), + UNARY_SAFE_NULL_IF_NULL(rshift, {"shiftrightunsigned"}, uint32, uint32), + UNARY_SAFE_NULL_IF_NULL(rshift, {"shiftrightunsigned"}, uint64, uint64), + UNARY_SAFE_NULL_IF_NULL(truncate, {"trunc"}, int32, int32), + UNARY_SAFE_NULL_IF_NULL(truncate, {"trunc"}, int64, int64), + UNARY_SAFE_NULL_IF_NULL(truncate, {"trunc"}, float32, float32), + UNARY_SAFE_NULL_IF_NULL(truncate, {"trunc"}, float64, float64), + BINARY_GENERIC_SAFE_NULL_IF_NULL(truncate, {"trunc"}, int32, int32, int32), + BINARY_GENERIC_SAFE_NULL_IF_NULL(truncate, {"trunc"}, int64, int32, int64), + BINARY_GENERIC_SAFE_NULL_IF_NULL(truncate, {"trunc"}, float32, int32, float32), + BINARY_GENERIC_SAFE_NULL_IF_NULL(truncate, {"trunc"}, float64, int32, float64), + // decimal functions UNARY_SAFE_NULL_IF_NULL(abs, {}, decimal128, decimal128), UNARY_SAFE_NULL_IF_NULL(ceil, {}, decimal128, decimal128), @@ -97,8 +135,8 @@ std::vector GetMathOpsFunctionRegistry() { decimal128), BINARY_SYMMETRIC_SAFE_NULL_NEVER_FN(nvl, {}), - 
NativeFunction("truncate", {"trunc"}, DataTypeVector{int64(), int32()}, int64(), - kResultNullIfNull, "truncate_int64_int32"), + NativeFunction("pi", {}, {}, float64(), kResultNullIfNull, "pi"), + NativeFunction("e", {}, {}, float64(), kResultNullIfNull, "e"), NativeFunction("random", {"rand"}, DataTypeVector{}, float64(), kResultNullNever, "gdv_fn_random", NativeFunction::kNeedsFunctionHolder), NativeFunction("random", {"rand"}, DataTypeVector{int32()}, float64(), diff --git a/cpp/src/gandiva/gdv_function_stubs.cc b/cpp/src/gandiva/gdv_function_stubs.cc index d31d8cd63d1..cd70fe595fa 100644 --- a/cpp/src/gandiva/gdv_function_stubs.cc +++ b/cpp/src/gandiva/gdv_function_stubs.cc @@ -20,6 +20,9 @@ #include #include +#include +#include +#include #include #include @@ -1165,6 +1168,49 @@ int32_t gdv_fn_cast_intervalyear_utf8_int32(int64_t context_ptr, int64_t holder_ auto* holder = reinterpret_cast(holder_ptr); return (*holder)(context, data, data_len, in1_validity, out_valid); } + +#define ABS_SIGNED_INTEGER_TYPES(INNER) INNER(int32) INNER(int64) +#define ABS_UNSIGNED_INTEGER_TYPES(INNER) INNER(uint32) INNER(uint64) +#define ABS_REAL_TYPES(INNER) INNER(float32) INNER(float64) + +// Abs for signed integer types. 
The function reports an error if an overflow occurs +#define ABS_FOR_SIGNED_INTEGER(IN_TYPE) \ + GANDIVA_EXPORT \ + gdv_##IN_TYPE gdv_fn_abs_##IN_TYPE(int64_t context, gdv_##IN_TYPE in) { \ + if (in <= std::numeric_limits::min()) { \ + std::string error_msg("Overflow in abs execution"); \ + gdv_fn_context_set_error_msg(context, error_msg.data()); \ + return static_cast(0); \ + } \ + \ + gdv_##IN_TYPE zero = 0; \ + if (in < zero) { \ + gdv_##IN_TYPE minus_one = -1; \ + return (minus_one * in); \ + } \ + \ + return in; \ + } + +#define ABS_FOR_REAL_TYPES(IN_TYPE) \ + GANDIVA_EXPORT \ + gdv_##IN_TYPE gdv_fn_abs_##IN_TYPE(int64_t context, gdv_##IN_TYPE in) { \ + return static_cast(fabs(in)); \ + } + +// Optimization in abs function for unsigned types +#define ABS_FOR_UNSIGNED_INTEGER(IN_TYPE) \ + GANDIVA_EXPORT \ + gdv_##IN_TYPE gdv_fn_abs_##IN_TYPE(int64_t context, gdv_##IN_TYPE in) { return in; } + +ABS_SIGNED_INTEGER_TYPES(ABS_FOR_SIGNED_INTEGER) +ABS_UNSIGNED_INTEGER_TYPES(ABS_FOR_UNSIGNED_INTEGER) +ABS_REAL_TYPES(ABS_FOR_REAL_TYPES) + +#undef ABS_SIGNED_INTEGER_TYPES +#undef ABS_UNSIGNED_INTEGER_TYPES +#undef ABS_FOR_SIGNED_INTEGER +#undef ABS_FOR_UNSIGNED_INTEGER } namespace gandiva { @@ -2309,5 +2355,61 @@ void ExportedStubFunctions::AddMappings(Engine* engine) const { engine->AddGlobalMappingForFunc( "gdv_fn_cast_intervalyear_utf8_int32", types->i32_type() /*return_type*/, args, reinterpret_cast(gdv_fn_cast_intervalyear_utf8_int32)); + + // gdv_fn_abs_int32 + args = { + types->i64_type(), // context + types->i32_type(), // value + }; + + engine->AddGlobalMappingForFunc("gdv_fn_abs_int32", types->i32_type() /*return_type*/, + args, reinterpret_cast(gdv_fn_abs_int32)); + + // gdv_fn_abs_int64 + args = { + types->i64_type(), // context + types->i64_type(), // value + }; + + engine->AddGlobalMappingForFunc("gdv_fn_abs_int64", types->i64_type() /*return_type*/, + args, reinterpret_cast(gdv_fn_abs_int64)); + + // gdv_fn_abs_uint32 + args = { + types->i64_type(), 
// context + types->i32_type(), // value + }; + + engine->AddGlobalMappingForFunc("gdv_fn_abs_uint32", types->i32_type() /*return_type*/, + args, reinterpret_cast(gdv_fn_abs_uint32)); + + // gdv_fn_abs_uint64 + args = { + types->i64_type(), // context + types->i64_type(), // value + }; + + engine->AddGlobalMappingForFunc("gdv_fn_abs_uint64", types->i64_type() /*return_type*/, + args, reinterpret_cast(gdv_fn_abs_uint64)); + + // gdv_fn_abs_float32 + args = { + types->i64_type(), // context + types->float_type(), // value + }; + + engine->AddGlobalMappingForFunc("gdv_fn_abs_float32", + types->float_type() /*return_type*/, args, + reinterpret_cast(gdv_fn_abs_float32)); + + // gdv_fn_abs_float64 + args = { + types->i64_type(), // context + types->double_type(), // value + }; + + engine->AddGlobalMappingForFunc("gdv_fn_abs_float64", + types->double_type() /*return_type*/, args, + reinterpret_cast(gdv_fn_abs_float64)); } } // namespace gandiva diff --git a/cpp/src/gandiva/gdv_function_stubs.h b/cpp/src/gandiva/gdv_function_stubs.h index d39d2940423..b6edd9186ab 100644 --- a/cpp/src/gandiva/gdv_function_stubs.h +++ b/cpp/src/gandiva/gdv_function_stubs.h @@ -200,4 +200,22 @@ GANDIVA_EXPORT const char* gdv_mask_last_n_utf8_int32(int64_t context, const char* data, int32_t data_len, int32_t n_to_mask, int32_t* out_len); + +GANDIVA_EXPORT +int32_t gdv_fn_abs_int32(int64_t context, int32_t in); + +GANDIVA_EXPORT +int64_t gdv_fn_abs_int64(int64_t context, int64_t in); + +GANDIVA_EXPORT +uint32_t gdv_fn_abs_uint32(int64_t context, uint32_t in); + +GANDIVA_EXPORT +uint64_t gdv_fn_abs_uint64(int64_t context, uint64_t in); + +GANDIVA_EXPORT +float gdv_fn_abs_float32(int64_t context, float in); + +GANDIVA_EXPORT +double gdv_fn_abs_float64(int64_t context, double in); } diff --git a/cpp/src/gandiva/gdv_function_stubs_test.cc b/cpp/src/gandiva/gdv_function_stubs_test.cc index c0b938c94e1..974bee86d17 100644 --- a/cpp/src/gandiva/gdv_function_stubs_test.cc +++ 
b/cpp/src/gandiva/gdv_function_stubs_test.cc @@ -21,11 +21,18 @@ #include #include +#include +#include + #include "arrow/util/logging.h" #include "gandiva/execution_context.h" namespace gandiva { +void VerifyAlmostEquals(double actual, double expected, double max_error = FLT_EPSILON) { + EXPECT_TRUE(fabs(actual - expected) < max_error) << actual << " != " << expected; +} + TEST(TestGdvFnStubs, TestCrc32) { gandiva::ExecutionContext ctx; auto ctx_ptr = reinterpret_cast(&ctx); @@ -949,4 +956,59 @@ TEST(TestGdvFnStubs, TestMaskLastN) { EXPECT_EQ(expected, std::string(result, out_len)); } +TEST(TestGdvFnStubs, TestAbs) { + gandiva::ExecutionContext ctx; + uint64_t ctx_ptr = reinterpret_cast(&ctx); + + // Abs functions + EXPECT_EQ(gdv_fn_abs_int32(ctx_ptr, 0), 0); + EXPECT_EQ(gdv_fn_abs_uint32(ctx_ptr, 0), 0); + EXPECT_EQ(gdv_fn_abs_int64(ctx_ptr, 0), 0L); + EXPECT_EQ(gdv_fn_abs_uint64(ctx_ptr, 0), 0L); + VerifyAlmostEquals(gdv_fn_abs_float32(ctx_ptr, 0.0f), abs(0.0f)); + VerifyAlmostEquals(gdv_fn_abs_float64(ctx_ptr, 0.0), abs(0.0)); + + EXPECT_EQ(gdv_fn_abs_int32(ctx_ptr, (INT32_MIN + 1)), abs(INT32_MIN + 1)); + EXPECT_EQ(gdv_fn_abs_int64(ctx_ptr, (INT64_MIN + 1)), abs(INT64_MIN + 1)); + VerifyAlmostEquals(gdv_fn_abs_float32(ctx_ptr, static_cast(INT32_MIN + 1)), + abs(static_cast(INT32_MIN + 1))); + VerifyAlmostEquals(gdv_fn_abs_float64(ctx_ptr, static_cast(INT32_MIN + 1)), + abs(static_cast(INT32_MIN + 1))); + + VerifyAlmostEquals(gdv_fn_abs_float32(ctx_ptr, std::numeric_limits::max()), + abs(std::numeric_limits::max())); + VerifyAlmostEquals(gdv_fn_abs_float32(ctx_ptr, std::numeric_limits::min()), + abs(std::numeric_limits::min())); + VerifyAlmostEquals(gdv_fn_abs_float64(ctx_ptr, std::numeric_limits::max()), + abs(std::numeric_limits::max())); + VerifyAlmostEquals(gdv_fn_abs_float64(ctx_ptr, std::numeric_limits::min()), + abs(std::numeric_limits::min())); + + EXPECT_EQ(gdv_fn_abs_int64(ctx_ptr, (INT64_MIN + 1)), + abs(static_cast(INT64_MIN + 1))); + 
VerifyAlmostEquals(gdv_fn_abs_float64(ctx_ptr, static_cast(INT64_MIN + 1)), + abs(static_cast(INT64_MIN + 1))); + + VerifyAlmostEquals(gdv_fn_abs_float32(ctx_ptr, -3600.50f), abs(-3600.50f)); + VerifyAlmostEquals(gdv_fn_abs_float64(ctx_ptr, -3600.50), abs(-3600.50)); +} + +TEST(TestGdvFnStubs, TestAbsOverflow) { + gandiva::ExecutionContext ctx; + uint64_t ctx_ptr = reinterpret_cast(&ctx); + + int64_t value = gdv_fn_abs_int64(ctx_ptr, INT64_MIN); + EXPECT_TRUE(ctx.has_error()); + EXPECT_THAT(ctx.get_error(), ::testing::HasSubstr("Overflow in abs execution")); + EXPECT_EQ(value, 0); + + ctx.Reset(); + + int32_t value_int32 = gdv_fn_abs_int32(ctx_ptr, INT32_MIN); + EXPECT_TRUE(ctx.has_error()); + EXPECT_THAT(ctx.get_error(), ::testing::HasSubstr("Overflow in abs execution")); + EXPECT_EQ(value_int32, 0); + + ctx.Reset(); +} } // namespace gandiva diff --git a/cpp/src/gandiva/precompiled/decimal_ops.cc b/cpp/src/gandiva/precompiled/decimal_ops.cc index 61cac60624d..e6ada2780e1 100644 --- a/cpp/src/gandiva/precompiled/decimal_ops.cc +++ b/cpp/src/gandiva/precompiled/decimal_ops.cc @@ -488,12 +488,27 @@ static std::array kDoubleScaleMultip return values; })(); -BasicDecimal128 FromDouble(double in, int32_t precision, int32_t scale, bool* overflow) { +BasicDecimal128 FromDouble(double in, int32_t precision, int32_t scale, bool* overflow, + RoundType round_type) { // Multiply decimal with the scale auto unscaled = in * kDoubleScaleMultipliers[scale]; DECIMAL_OVERFLOW_IF(std::isnan(unscaled), overflow); - unscaled = std::round(unscaled); + switch (round_type) { + case RoundType::kRoundTypeCeil: + unscaled = std::ceil(unscaled); + break; + case RoundType::kRoundTypeFloor: + unscaled = std::floor(unscaled); + break; + case RoundType::kRoundTypeTrunc: + unscaled = std::trunc(unscaled); + break; + case RoundType::kRoundTypeHalfRoundUp: + default: + unscaled = std::round(unscaled); + break; + } // convert scaled double to int128 int32_t sign = unscaled < 0 ? 
-1 : 1; @@ -551,14 +566,6 @@ static BasicDecimal128 ModifyScaleAndPrecision(const BasicDecimalScalar128& x, } } -enum RoundType { - kRoundTypeCeil, // +1 if +ve and trailing value is > 0, else no rounding. - kRoundTypeFloor, // -1 if -ve and trailing value is < 0, else no rounding. - kRoundTypeTrunc, // no rounding, truncate the trailing digits. - kRoundTypeHalfRoundUp, // if +ve and trailing value is >= half of base, +1. - // else if -ve and trailing value is >= half of base, -1. -}; - // Compute the rounding delta for the givven rounding type. static int32_t ComputeRoundingDelta(const BasicDecimal128& x, int32_t x_scale, int32_t out_scale, RoundType type) { diff --git a/cpp/src/gandiva/precompiled/decimal_ops.h b/cpp/src/gandiva/precompiled/decimal_ops.h index 292dce2208c..20643794115 100644 --- a/cpp/src/gandiva/precompiled/decimal_ops.h +++ b/cpp/src/gandiva/precompiled/decimal_ops.h @@ -19,11 +19,20 @@ #include #include + #include "gandiva/basic_decimal_scalar.h" namespace gandiva { namespace decimalops { +enum RoundType { + kRoundTypeCeil, // +1 if +ve and trailing value is > 0, else no rounding. + kRoundTypeFloor, // -1 if -ve and trailing value is < 0, else no rounding. + kRoundTypeTrunc, // no rounding, truncate the trailing digits. + kRoundTypeHalfRoundUp, // if +ve and trailing value is >= half of base, +1. + // else if -ve and trailing value is >= half of base, -1. +}; + /// Return the sum of 'x' and 'y'. /// out_precision and out_scale are passed along for efficiency, they must match /// the rules in DecimalTypeSql::GetResultType. @@ -57,7 +66,8 @@ arrow::BasicDecimal128 Mod(int64_t context, const BasicDecimalScalar128& x, int32_t Compare(const BasicDecimalScalar128& x, const BasicDecimalScalar128& y); /// Convert to decimal from double. 
-BasicDecimal128 FromDouble(double in, int32_t precision, int32_t scale, bool* overflow); +BasicDecimal128 FromDouble(double in, int32_t precision, int32_t scale, bool* overflow, + RoundType roundType = RoundType::kRoundTypeHalfRoundUp); /// Convert from decimal to double. double ToDouble(const BasicDecimalScalar128& in, bool* overflow); diff --git a/cpp/src/gandiva/precompiled/extended_math_ops.cc b/cpp/src/gandiva/precompiled/extended_math_ops.cc index 96fe7fb9e3e..982328ddeaa 100644 --- a/cpp/src/gandiva/precompiled/extended_math_ops.cc +++ b/cpp/src/gandiva/precompiled/extended_math_ops.cc @@ -40,6 +40,23 @@ extern "C" { INNER(float32, OUT_TYPE) \ INNER(float64, OUT_TYPE) +#define ESIGNED_INTEGER_TYPES(INNER) \ + INNER(int32) \ + INNER(int64) + +#define EINTEGER_TYPES(INNER) \ + ESIGNED_INTEGER_TYPES(INNER) \ + INNER(uint32) \ + INNER(uint64) + +#define EREAL_TYPES(INNER) \ + INNER(float32) \ + INNER(float64) + +#define ENUMERIC_TYPES(INNER) \ + EINTEGER_TYPES(INNER) \ + EREAL_TYPES(INNER) + // Cubic root #define CBRT(IN_TYPE, OUT_TYPE) \ FORCE_INLINE \ @@ -78,6 +95,15 @@ ENUMERIC_TYPES_UNARY(LOG, float64) ENUMERIC_TYPES_UNARY(LOG10, float64) +// log base 2 +#define LOG2(IN_TYPE, OUT_TYPE) \ + FORCE_INLINE \ + gdv_##OUT_TYPE log2_##IN_TYPE(gdv_##IN_TYPE in) { \ + return static_cast(log2l(static_cast(in))); \ + } + +ENUMERIC_TYPES_UNARY(LOG2, float64) + FORCE_INLINE void set_error_for_logbase(int64_t execution_context, double base) { char const* prefix = "divide by zero error with log of base"; @@ -213,6 +239,24 @@ ENUMERIC_TYPES_UNARY(RADIANS, float64) } ENUMERIC_TYPES_UNARY(DEGREES, float64) +// Ceil +#define CEIL(IN_TYPE) \ + FORCE_INLINE \ + gdv_##IN_TYPE ceil_##IN_TYPE(gdv_##IN_TYPE in) { \ + return static_cast(ceil(static_cast(in))); \ + } +CEIL(float32) +CEIL(float64) + +// Floor +#define FLOOR(IN_TYPE) \ + FORCE_INLINE \ + gdv_##IN_TYPE floor_##IN_TYPE(gdv_##IN_TYPE in) { \ + return static_cast(floor(static_cast(in))); \ + } +FLOOR(float32) 
+FLOOR(float64) + // power #define POWER(IN_TYPE1, IN_TYPE2, OUT_TYPE) \ FORCE_INLINE \ @@ -221,6 +265,71 @@ ENUMERIC_TYPES_UNARY(DEGREES, float64) } POWER(float64, float64, float64) +// Sqrt +#define SQRT(IN_TYPE, OUT_TYPE) \ + FORCE_INLINE \ + gdv_##OUT_TYPE sqrt_##IN_TYPE(gdv_##IN_TYPE in) { \ + return static_cast(sqrt(static_cast(in))); \ + } +ENUMERIC_TYPES_UNARY(SQRT, float64) + +// Sign +#define SIGN(IN_TYPE) \ + FORCE_INLINE \ + gdv_##IN_TYPE sign_##IN_TYPE(gdv_##IN_TYPE in) { \ + if (in == 0.0) return in; \ + return static_cast(copysign(1.0, static_cast(in))); \ + } +ENUMERIC_TYPES(SIGN) + +// Lshift - Left Arithmetical Shift (multiplication by 2) +// In C/C++ left shift of a negative number is undefined (C++11 standard 5.8.2) +// Mimic Java/etc. and treat left shift as based on two's complement representation +// Assumes two's complement machine. +#define LSHIFT(IN_TYPE, MAX_DIGITS) \ + FORCE_INLINE \ + gdv_##IN_TYPE lshift_##IN_TYPE(gdv_##IN_TYPE in1, gdv_##IN_TYPE in2) { \ + if (in2 < 0 || in2 >= (MAX_DIGITS)) { \ + return in1; \ + } \ + \ + auto unsigned_in1 = static_cast(in1); \ + auto unsigned_shift_op = static_cast(in2); \ + \ + return static_cast(unsigned_in1 << unsigned_shift_op); \ + } + +// Logical right shift when Arg0 is unsigned +// Arithmetic otherwise (this is implementation-defined but GCC and MSVC document this +// as arithmetic right shift) +// https://gcc.gnu.org/onlinedocs/gcc/Integers-implementation.html#Integers-implementation +// https://docs.microsoft.com/en-us/cpp/cpp/left-shift-and-right-shift-operators-input-and-output?view=msvc-160 +// Clang doesn't document their behavior. 
+#define RSHIFT(IN_TYPE, MAX_DIGITS) \ + FORCE_INLINE \ + gdv_##IN_TYPE rshift_##IN_TYPE(gdv_##IN_TYPE in1, gdv_##IN_TYPE in2) { \ + if (in2 < 0 || in2 >= (MAX_DIGITS)) { \ + return in1; \ + } \ + return in1 >> in2; \ + } + +LSHIFT(int32, 32) +LSHIFT(int64, 64) +RSHIFT(int32, 32) +RSHIFT(int64, 64) +RSHIFT(uint32, 32) +RSHIFT(uint64, 64) + +#undef RSHIFT +#undef LSHIFT + +FORCE_INLINE +gdv_float64 pi() { return static_cast(M_PI); } + +FORCE_INLINE +gdv_float64 e() { return static_cast(exp(1.0)); } + FORCE_INLINE gdv_int32 round_int32(gdv_int32 num) { return num; } @@ -384,19 +493,61 @@ gdv_int64 get_power_of_10(gdv_int32 exp) { return power_of_10[exp]; } -FORCE_INLINE -gdv_int64 truncate_int64_int32(gdv_int64 in, gdv_int32 out_scale) { - bool overflow = false; - arrow::BasicDecimal128 decimal = gandiva::decimalops::FromInt64(in, 38, 0, &overflow); - arrow::BasicDecimal128 decimal_with_outscale = - gandiva::decimalops::Truncate(gandiva::BasicDecimalScalar128(decimal, 38, 0), 38, - out_scale, out_scale, &overflow); - if (out_scale < 0) { - out_scale = 0; - } - return gandiva::decimalops::ToInt64( - gandiva::BasicDecimalScalar128(decimal_with_outscale, 38, out_scale), &overflow); -} +#define TRUNCATE_INTEGER(TYPE) \ + FORCE_INLINE \ + gdv_##TYPE truncate_##TYPE(gdv_##TYPE value) { return value; } + +#define TRUNCATE_INTEGER_WITH_OUT_SCALE(TYPE) \ + FORCE_INLINE \ + gdv_##TYPE truncate_##TYPE##_int32(gdv_##TYPE in, gdv_int32 out_scale) { \ + bool overflow = false; \ + arrow::BasicDecimal128 decimal = \ + gandiva::decimalops::FromInt64(static_cast(in), 38, 0, &overflow); \ + arrow::BasicDecimal128 decimal_with_outscale = \ + gandiva::decimalops::Truncate(gandiva::BasicDecimalScalar128(decimal, 38, 0), \ + 38, out_scale, out_scale, &overflow); \ + if (out_scale < 0) { \ + out_scale = 0; \ + } \ + return static_cast(gandiva::decimalops::ToInt64( \ + gandiva::BasicDecimalScalar128(decimal_with_outscale, 38, out_scale), \ + &overflow)); \ + } + +#define TRUNCATE_REAL(TYPE) 
\ + FORCE_INLINE \ + gdv_##TYPE truncate_##TYPE(gdv_##TYPE in) { return truncate_##TYPE##_int32(in, 0); } + +#define TRUNCATE_REAL_WITH_OUT_SCALE(TYPE) \ + FORCE_INLINE \ + gdv_##TYPE truncate_##TYPE##_int32(gdv_##TYPE in, gdv_int32 out_scale) { \ + bool overflow = false; \ + auto roundType = gandiva::decimalops::RoundType::kRoundTypeTrunc; \ + int32_t conversionScale = out_scale >= 0 ? out_scale : 0; \ + \ + arrow::BasicDecimal128 decimal = gandiva::decimalops::FromDouble( \ + static_cast(in), 38, conversionScale, &overflow, roundType); \ + arrow::BasicDecimal128 decimal_with_outscale = gandiva::decimalops::Truncate( \ + gandiva::BasicDecimalScalar128(decimal, 38, conversionScale), 38, out_scale, \ + out_scale, &overflow); \ + if (out_scale < 0) { \ + out_scale = 0; \ + } \ + \ + return static_cast(gandiva::decimalops::ToDouble( \ + gandiva::BasicDecimalScalar128(decimal_with_outscale, 38, out_scale), \ + &overflow)); \ + } + +ESIGNED_INTEGER_TYPES(TRUNCATE_INTEGER) +ESIGNED_INTEGER_TYPES(TRUNCATE_INTEGER_WITH_OUT_SCALE) +EREAL_TYPES(TRUNCATE_REAL_WITH_OUT_SCALE) +EREAL_TYPES(TRUNCATE_REAL) + +#undef TRUNCATE_INTEGER +#undef TRUNCATE_INTEGER_WITH_OUT_SCALE +#undef TRUNCATE_REAL +#undef TRUNCATE_REAL_WITH_OUT_SCALE FORCE_INLINE gdv_float64 get_scale_multiplier(gdv_int32 scale) { diff --git a/cpp/src/gandiva/precompiled/extended_math_ops_test.cc b/cpp/src/gandiva/precompiled/extended_math_ops_test.cc index 3e9d8a5d2cd..31eea030b9b 100644 --- a/cpp/src/gandiva/precompiled/extended_math_ops_test.cc +++ b/cpp/src/gandiva/precompiled/extended_math_ops_test.cc @@ -41,7 +41,7 @@ TEST(TestExtendedMathOps, TestCbrt) { VerifyFuzzyEquals(cbrt_float64(27), 3); VerifyFuzzyEquals(cbrt_float64(-27), -3); - VerifyFuzzyEquals(cbrt_float32(15.625), 2.5); + VerifyFuzzyEquals(cbrt_float32(15.625f), 2.5f); VerifyFuzzyEquals(cbrt_float64(15.625), 2.5); } @@ -113,6 +113,13 @@ TEST(TestExtendedMathOps, TestLog10) { VerifyFuzzyEquals(log10_float64(100), 2); } +TEST(TestExtendedMathOps, 
TestLog2) { + VerifyFuzzyEquals(log2_int32(1024), 10); + VerifyFuzzyEquals(log2_int64(1024), 10); + VerifyFuzzyEquals(log2_float32(1024.0f), 10.0f); + VerifyFuzzyEquals(log2_float64(1024.0), 10.0); +} + TEST(TestExtendedMathOps, TestPower) { VerifyFuzzyEquals(power_float64_float64(2, 5.4), 42.22425314473263); VerifyFuzzyEquals(power_float64_float64(5.4, 2), 29.160000000000004); @@ -201,13 +208,64 @@ TEST(TestExtendedMathOps, TestRound) { EXPECT_EQ(round_int64_int32(345353425343, -12), 0); } -TEST(TestExtendedMathOps, TestTruncate) { +TEST(TestExtendedMathOps, TestTruncateWithScale) { + // Test the truncate function for longs EXPECT_EQ(truncate_int64_int32(1234, 4), 1234); EXPECT_EQ(truncate_int64_int32(-1234, 4), -1234); EXPECT_EQ(truncate_int64_int32(1234, -4), 0); EXPECT_EQ(truncate_int64_int32(-1234, -2), -1200); EXPECT_EQ(truncate_int64_int32(8124674407369523212, 0), 8124674407369523212); EXPECT_EQ(truncate_int64_int32(8124674407369523212, -2), 8124674407369523200); + + // Test truncate function for integers + EXPECT_EQ(truncate_int32_int32(1234, 4), 1234); + EXPECT_EQ(truncate_int32_int32(-1234, 4), -1234); + EXPECT_EQ(truncate_int32_int32(1234, -4), 0); + EXPECT_EQ(truncate_int32_int32(-1234, -2), -1200); + EXPECT_EQ(truncate_int32_int32(8124674, 0), 8124674); + EXPECT_EQ(truncate_int32_int32(8124674, -2), 8124600); +} + +TEST(TestExtendedMathOps, TestTruncateWithoutScale) { + // Test the truncate function for longs + EXPECT_EQ(truncate_int64(1234), 1234); + EXPECT_EQ(truncate_int64(-1234), -1234); + EXPECT_EQ(truncate_int64(8124674407369523212), 8124674407369523212); + + // Test truncate function for integers + EXPECT_EQ(truncate_int32(1234), 1234); + EXPECT_EQ(truncate_int32(-1234), -1234); + EXPECT_EQ(truncate_int32(8124674), 8124674); +} + +TEST(TestExtendedMathOps, TestTruncateFloat) { + VerifyFuzzyEquals(truncate_float32(1234.245f), 1234.0f); + VerifyFuzzyEquals(truncate_float32(-11.7892f), -11.0f); + VerifyFuzzyEquals(truncate_float32(1.4999999f), 
1.0f); + EXPECT_EQ(std::signbit(truncate_float32(0)), 0); + VerifyFuzzyEquals(truncate_float32_int32(1234.789f, 2), 1234.78f); + VerifyFuzzyEquals(truncate_float32_int32(1234.12345f, -3), 1000.0f); + VerifyFuzzyEquals(truncate_float32_int32(-1234.4567f, 3), -1234.456f); + VerifyFuzzyEquals(truncate_float32_int32(-1234.4567f, -3), -1000.0f); + VerifyFuzzyEquals(truncate_float32_int32(1234.4567f, 0), 1234); + VerifyFuzzyEquals(truncate_float32_int32(1.5499999523162842f, 1), 1.5f); + EXPECT_EQ(std::signbit(truncate_float32_int32(0, 5)), 0); + VerifyFuzzyEquals(truncate_float32_int32(static_cast(1.55), 1), 1.5f); + VerifyFuzzyEquals(truncate_float32_int32(static_cast(9.134123), 2), 9.13f); + VerifyFuzzyEquals(truncate_float32_int32(static_cast(-1.923), 1), -1.9f); + + VerifyFuzzyEquals(truncate_float64(1234.245), 1234.0); + VerifyFuzzyEquals(truncate_float64(-11.7892), -11.0); + VerifyFuzzyEquals(truncate_float64(1.4999999), 1.0); + EXPECT_EQ(std::signbit(truncate_float64(0)), 0); + VerifyFuzzyEquals(truncate_float64_int32(1234.789, 2), 1234.78); + VerifyFuzzyEquals(truncate_float64_int32(1234.12345, -3), 1000.0); + VerifyFuzzyEquals(truncate_float64_int32(-1234.4567, 3), -1234.456); + VerifyFuzzyEquals(truncate_float64_int32(-1234.4567, -3), -1000.0); + VerifyFuzzyEquals(truncate_float64_int32(1234.4567, 0), 1234.0); + EXPECT_EQ(std::signbit(truncate_float64_int32(0, -2)), 0); + VerifyFuzzyEquals(truncate_float64_int32((double)INT_MAX + 1, 0), (double)INT_MAX + 1); + VerifyFuzzyEquals(truncate_float64_int32((double)INT_MIN - 1, 0), (double)INT_MIN - 1); } TEST(TestExtendedMathOps, TestTrigonometricFunctions) { @@ -407,4 +465,142 @@ TEST(TestExtendedMathOps, TestBinRepresentation) { "1000000000000000000000000000000000000000000000000000000000000000"); EXPECT_FALSE(ctx.has_error()); } + +TEST(TestExtendedMathOps, TestCeil) { + // Ceil functions + VerifyFuzzyEquals(ceil_float32(0), ceil(0)); + VerifyFuzzyEquals(ceil_float64(0), ceil(0)); + + 
VerifyFuzzyEquals(ceil_float32(-5), ceil(-5)); + VerifyFuzzyEquals(ceil_float64(-5), ceil(-5)); + + VerifyFuzzyEquals(ceil_float32(-2371041), ceil(-2371041)); + VerifyFuzzyEquals(ceil_float64(-2371041), ceil(-2371041)); + + VerifyFuzzyEquals(ceil_float32(5.45f), ceil(5.45f)); + VerifyFuzzyEquals(ceil_float64(5.45), ceil(5.45)); + + VerifyFuzzyEquals(ceil_float32(-3600.50f), ceil(-3600.50f)); + VerifyFuzzyEquals(ceil_float64(-3600.50), ceil(-3600.50)); +} + +TEST(TestExtendedMathOps, TestFloor) { + // Floor functions + VerifyFuzzyEquals(floor_float32(0), floor(0)); + VerifyFuzzyEquals(floor_float64(0), floor(0)); + + VerifyFuzzyEquals(floor_float32(-5), floor(-5)); + VerifyFuzzyEquals(floor_float64(-5), floor(-5)); + + VerifyFuzzyEquals(floor_float32(-2371041), floor(-2371041)); + VerifyFuzzyEquals(floor_float64(-2371041), floor(-2371041)); + + VerifyFuzzyEquals(floor_float32(5.45f), floor(5.45f)); + VerifyFuzzyEquals(floor_float64(5.45), floor(5.45)); + + VerifyFuzzyEquals(floor_float32(-3600.50f), floor(-3600.50f)); + VerifyFuzzyEquals(floor_float64(-3600.50), floor(-3600.50)); +} + +TEST(TestExtendedMathOps, TestConstants) { + // Constants functions + VerifyFuzzyEquals(pi(), M_PI); + VerifyFuzzyEquals(e(), exp(1.0)); +} + +TEST(TestExtendedMathOps, TestSqrt) { + // Sqrt functions + VerifyFuzzyEquals(sqrt_int32(0), sqrt(0)); + VerifyFuzzyEquals(sqrt_int64(0), sqrt(0)); + VerifyFuzzyEquals(sqrt_float32(0), sqrt(0)); + VerifyFuzzyEquals(sqrt_float64(0), sqrt(0)); + + VerifyFuzzyEquals(sqrt_int32(5), sqrt(5)); + VerifyFuzzyEquals(sqrt_int64(5), sqrt(5)); + VerifyFuzzyEquals(sqrt_float32(5), sqrt(5)); + VerifyFuzzyEquals(sqrt_float64(5), sqrt(5)); + + VerifyFuzzyEquals(sqrt_int32(2371041), sqrt(2371041)); + VerifyFuzzyEquals(sqrt_int64(2371041), sqrt(2371041)); + VerifyFuzzyEquals(sqrt_float32(2371041), sqrt(2371041)); + VerifyFuzzyEquals(sqrt_float64(2371041), sqrt(2371041)); + + VerifyFuzzyEquals(sqrt_float32(3600.50f), sqrt(3600.50f)); + 
VerifyFuzzyEquals(sqrt_float64(3600.50), sqrt(3600.50)); +} + +TEST(TestExtendedMathOps, TestSign) { + // Sign functions + EXPECT_EQ(sign_int32(0), 0); + EXPECT_EQ(sign_int64(0), 0); + EXPECT_EQ(sign_float32(0), 0); + EXPECT_EQ(sign_float64(0), 0); + + EXPECT_EQ(sign_int32(5), 1); + EXPECT_EQ(sign_int64(5), 1); + EXPECT_EQ(sign_float32(5), 1); + EXPECT_EQ(sign_float64(5), 1); + + EXPECT_EQ(sign_int32(2371041), 1); + EXPECT_EQ(sign_int64(2371041), 1); + EXPECT_EQ(sign_float32(2371041), 1); + EXPECT_EQ(sign_float64(2371041), 1); + + EXPECT_EQ(sign_int32(-3600), -1); + EXPECT_EQ(sign_int64(-3600), -1); + EXPECT_EQ(sign_float32(-3600.50f), -1); + EXPECT_EQ(sign_float64(-3600.50), -1); +} + +TEST(TestExtendedMathOps, TestLshiftRshift) { + // Lshift functions + EXPECT_EQ(lshift_int32(1, 32), lshift_int32(1, 0)); + EXPECT_EQ(lshift_int64(1, 64), lshift_int64(1, 0)); + EXPECT_EQ(lshift_int32(-1, 33), lshift_int32(-1, 0)); + EXPECT_EQ(lshift_int32(-1, 100), lshift_int32(-1, 0)); + EXPECT_EQ(lshift_int64(-1, 65), lshift_int64(-1, 0)); + EXPECT_EQ(lshift_int64(-1, 125), lshift_int64(-1, 0)); + EXPECT_EQ(lshift_int32(-1, -1), lshift_int32(-1, 0)); + EXPECT_EQ(lshift_int64(-1, -1), lshift_int64(-1, 0)); + + EXPECT_EQ(lshift_int32(1, 31), lshift_int32(-1, 31)); + + EXPECT_EQ(lshift_int32(2, 31), 0); + EXPECT_EQ(lshift_int32(2, 31), lshift_int32(-2, 31)); + EXPECT_EQ(lshift_int64(2, 63), 0); + EXPECT_EQ(lshift_int64(2, 63), lshift_int64(-2, 63)); + + EXPECT_EQ(lshift_int32(-2, 16), -131072); + EXPECT_EQ(lshift_int32(2, 16), 131072); + EXPECT_EQ(lshift_int64(-2, 33), -17179869184L); + EXPECT_EQ(lshift_int64(2, 33), 17179869184L); + + // Rshift functions + EXPECT_EQ(rshift_int32(1, 32), rshift_int32(1, 0)); + EXPECT_EQ(rshift_int64(1, 64), rshift_int64(1, 0)); + EXPECT_EQ(rshift_int32(-1, 33), rshift_int32(-1, 1)); + EXPECT_EQ(rshift_int64(-1, 65), rshift_int64(-1, 1)); + EXPECT_EQ(rshift_int32(-1, 31), rshift_int32(-1, -1)); + EXPECT_EQ(rshift_int64(-1, 63), rshift_int64(-1, -1)); 
+ EXPECT_EQ(rshift_int32(-1, -62), rshift_int32(-1, 1)); + EXPECT_EQ(rshift_int64(-1, -126), rshift_int64(-1, 1)); + + EXPECT_EQ(rshift_int32(10, 31), 0); + EXPECT_EQ(rshift_int64(10, 63), 0); + EXPECT_EQ(rshift_int32(-10, 31), -1); + EXPECT_EQ(rshift_int64(-10, 63), -1); + + EXPECT_EQ(rshift_int32(-1024, 6), -16); + EXPECT_EQ(rshift_int64(-65536, 10), -64); + + int32_t val = -3; + int32_t num_shift = 16; + EXPECT_EQ(rshift_int32(lshift_int32(val, num_shift), num_shift), val); + + int64_t val_int64 = -3; + int64_t num_shift_int64 = 32; + EXPECT_EQ(rshift_int64(lshift_int64(val_int64, num_shift_int64), num_shift_int64), + val_int64); +} + } // namespace gandiva diff --git a/cpp/src/gandiva/precompiled/types.h b/cpp/src/gandiva/precompiled/types.h index e08dacda416..747c4fffb41 100644 --- a/cpp/src/gandiva/precompiled/types.h +++ b/cpp/src/gandiva/precompiled/types.h @@ -235,6 +235,11 @@ gdv_float64 log10_int64(gdv_int64); gdv_float64 log10_float32(gdv_float32); gdv_float64 log10_float64(gdv_float64); +gdv_float64 log2_int32(gdv_int32); +gdv_float64 log2_int64(gdv_int64); +gdv_float64 log2_float32(gdv_float32); +gdv_float64 log2_float64(gdv_float64); + gdv_float64 sin_int32(gdv_int32); gdv_float64 sin_int64(gdv_int64); gdv_float64 sin_float32(gdv_float32); @@ -286,6 +291,30 @@ gdv_float64 degrees_int64(gdv_int64); gdv_float64 degrees_float32(gdv_float32); gdv_float64 degrees_float64(gdv_float64); +gdv_int32 sign_int32(gdv_int32); +gdv_int64 sign_int64(gdv_int64); +gdv_float32 sign_float32(gdv_float32); +gdv_float64 sign_float64(gdv_float64); +gdv_int32 abs_int32(gdv_int32); +gdv_int64 abs_int64(gdv_int64); +gdv_float32 abs_float32(gdv_float32); +gdv_float64 abs_float64(gdv_float64); +gdv_float32 ceil_float32(gdv_float32); +gdv_float64 ceil_float64(gdv_float64); +gdv_float32 floor_float32(gdv_float32); +gdv_float64 floor_float64(gdv_float64); +gdv_float64 sqrt_int32(gdv_int32); +gdv_float64 sqrt_int64(gdv_int64); +gdv_float64 sqrt_float32(gdv_float32); +gdv_float64 
sqrt_float64(gdv_float64); +gdv_int32 lshift_int32(gdv_int32, gdv_int32); +gdv_int64 lshift_int64(gdv_int64, gdv_int64); +gdv_int32 rshift_int32(gdv_int32, gdv_int32); +gdv_int64 rshift_int64(gdv_int64, gdv_int64); + +gdv_float64 pi(); +gdv_float64 e(); + gdv_int32 bitwise_and_int32_int32(gdv_int32 in1, gdv_int32 in2); gdv_int64 bitwise_and_int64_int64(gdv_int64 in1, gdv_int64 in2); gdv_int32 bitwise_or_int32_int32(gdv_int32 in1, gdv_int32 in2); @@ -412,7 +441,14 @@ gdv_time32 castTIME_int32(int32_t int_val); const char* castVARCHAR_timestamp_int64(int64_t, gdv_timestamp, gdv_int64, gdv_int32*); gdv_date64 last_day_from_timestamp(gdv_date64 millis); -gdv_int64 truncate_int64_int32(gdv_int64 in, gdv_int32 out_scale); +int32_t truncate_int32(int32_t in); +int64_t truncate_int64(int64_t in); +float truncate_float32(float in); +double truncate_float64(double in); +int32_t truncate_int32_int32(int32_t in, int32_t out_scale); +int64_t truncate_int64_int32(int64_t in, int32_t out_scale); +float truncate_float32_int32(float in, int32_t out_scale); +double truncate_float64_int32(double in, int32_t out_scale); const char* repeat_utf8_int32(gdv_int64 context, const char* in, gdv_int32 in_len, gdv_int32 repeat_times, gdv_int32* out_len); diff --git a/cpp/src/gandiva/tests/projector_test.cc b/cpp/src/gandiva/tests/projector_test.cc index d12d4f8e5cd..e52cce55643 100644 --- a/cpp/src/gandiva/tests/projector_test.cc +++ b/cpp/src/gandiva/tests/projector_test.cc @@ -413,6 +413,13 @@ TEST_F(TestProjector, TestExtendedMath) { auto field_cot = arrow::field("cot", arrow::float64()); auto field_radians = arrow::field("radians", arrow::float64()); auto field_degrees = arrow::field("degrees", arrow::float64()); + auto field_abs = arrow::field("abs", arrow::float64()); + auto field_ceil = arrow::field("ceil", arrow::float64()); + auto field_floor = arrow::field("floor", arrow::float64()); + auto field_pi = arrow::field("pi", arrow::float64()); + auto field_e = arrow::field("euler", 
arrow::float64()); + auto field_sqrt = arrow::field("sqrt", arrow::float64()); + auto field_sign = arrow::field("sign", arrow::float64()); // Build expression auto cbrt_expr = TreeExprBuilder::MakeExpression("cbrt", {field0}, field_cbrt); @@ -436,13 +443,21 @@ TEST_F(TestProjector, TestExtendedMath) { auto cot_expr = TreeExprBuilder::MakeExpression("cot", {field0}, field_cot); auto radians_expr = TreeExprBuilder::MakeExpression("radians", {field0}, field_radians); auto degrees_expr = TreeExprBuilder::MakeExpression("degrees", {field0}, field_degrees); + auto abs_expr = TreeExprBuilder::MakeExpression("abs", {field0}, field_abs); + auto ceil_expr = TreeExprBuilder::MakeExpression("ceil", {field0}, field_ceil); + auto floor_expr = TreeExprBuilder::MakeExpression("floor", {field0}, field_floor); + auto pi_expr = TreeExprBuilder::MakeExpression("pi", {}, field_pi); + auto e_expr = TreeExprBuilder::MakeExpression("e", {}, field_e); + auto sqrt_expr = TreeExprBuilder::MakeExpression("sqrt", {field1}, field_sqrt); + auto sign_expr = TreeExprBuilder::MakeExpression("sign", {field0}, field_sign); std::shared_ptr projector; auto status = Projector::Make( - schema, - {cbrt_expr, exp_expr, log_expr, log10_expr, logb_expr, power_expr, sin_expr, - cos_expr, asin_expr, acos_expr, tan_expr, atan_expr, sinh_expr, cosh_expr, - tanh_expr, atan2_expr, cot_expr, radians_expr, degrees_expr}, + schema, {cbrt_expr, exp_expr, log_expr, log10_expr, logb_expr, power_expr, + sin_expr, cos_expr, asin_expr, acos_expr, tan_expr, atan_expr, + sinh_expr, cosh_expr, tanh_expr, atan2_expr, cot_expr, radians_expr, + degrees_expr, abs_expr, ceil_expr, floor_expr, pi_expr, e_expr, + sqrt_expr, sign_expr}, TestConfiguration(), &projector); EXPECT_TRUE(status.ok()); @@ -475,6 +490,13 @@ TEST_F(TestProjector, TestExtendedMath) { std::vector cot_vals; std::vector radians_vals; std::vector degrees_vals; + std::vector abs_vals; + std::vector ceil_vals; + std::vector floor_vals; + std::vector pi_vals; + 
std::vector e_vals; + std::vector sqrt_vals; + std::vector sign_vals; for (int i = 0; i < num_records; i++) { cbrt_vals.push_back(static_cast(cbrtl(input0[i]))); exp_vals.push_back(static_cast(expl(input0[i]))); @@ -495,6 +517,14 @@ TEST_F(TestProjector, TestExtendedMath) { cot_vals.push_back(static_cast(tan(M_PI / 2 - input0[i]))); radians_vals.push_back(static_cast(input0[i] * M_PI / 180.0)); degrees_vals.push_back(static_cast(input0[i] * 180.0 / M_PI)); + abs_vals.push_back(static_cast(fabs(input0[i]))); + ceil_vals.push_back(static_cast(ceil(input0[i]))); + floor_vals.push_back(static_cast(floor(input0[i]))); + pi_vals.push_back(static_cast(M_PI)); + e_vals.push_back(static_cast(exp(1.0))); + sqrt_vals.push_back(static_cast(sqrt(input1[i]))); + sign_vals.push_back( + static_cast(input0[i] == 0 ? input0[i] : copysign(1.0, input0[i]))); } auto expected_cbrt = MakeArrowArray(cbrt_vals, validity); auto expected_exp = MakeArrowArray(exp_vals, validity); @@ -517,6 +547,13 @@ TEST_F(TestProjector, TestExtendedMath) { MakeArrowArray(radians_vals, validity); auto expected_degrees = MakeArrowArray(degrees_vals, validity); + auto expected_abs = MakeArrowArray(abs_vals, validity); + auto expected_ceil = MakeArrowArray(ceil_vals, validity); + auto expected_floor = MakeArrowArray(floor_vals, validity); + auto expected_pi = MakeArrowArray(pi_vals, validity); + auto expected_e = MakeArrowArray(e_vals, validity); + auto expected_sqrt = MakeArrowArray(sqrt_vals, validity); + auto expected_sign = MakeArrowArray(sign_vals, validity); // prepare input record batch auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0, array1}); @@ -546,6 +583,13 @@ TEST_F(TestProjector, TestExtendedMath) { EXPECT_ARROW_ARRAY_APPROX_EQUALS(expected_cot, outputs.at(16), epsilon); EXPECT_ARROW_ARRAY_APPROX_EQUALS(expected_radians, outputs.at(17), epsilon); EXPECT_ARROW_ARRAY_APPROX_EQUALS(expected_degrees, outputs.at(18), epsilon); + EXPECT_ARROW_ARRAY_APPROX_EQUALS(expected_abs, 
outputs.at(19), epsilon); + EXPECT_ARROW_ARRAY_APPROX_EQUALS(expected_ceil, outputs.at(20), epsilon); + EXPECT_ARROW_ARRAY_APPROX_EQUALS(expected_floor, outputs.at(21), epsilon); + EXPECT_ARROW_ARRAY_APPROX_EQUALS(expected_pi, outputs.at(22), epsilon); + EXPECT_ARROW_ARRAY_APPROX_EQUALS(expected_e, outputs.at(23), epsilon); + EXPECT_ARROW_ARRAY_APPROX_EQUALS(expected_sqrt, outputs.at(24), epsilon); + EXPECT_ARROW_ARRAY_APPROX_EQUALS(expected_sign, outputs.at(25), epsilon); } TEST_F(TestProjector, TestFloatLessThan) { From 9f450bfa0fdd5d84b8fc23b9803283834477ccf9 Mon Sep 17 00:00:00 2001 From: Anthony Louis Date: Wed, 2 Feb 2022 19:17:23 -0300 Subject: [PATCH 2/2] Add treatment for truncate overflow --- cpp/src/gandiva/function_registry_common.h | 11 +++ cpp/src/gandiva/function_registry_math_ops.cc | 12 +-- .../gandiva/precompiled/extended_math_ops.cc | 46 ++++++++-- .../precompiled/extended_math_ops_test.cc | 86 ++++++++++--------- cpp/src/gandiva/precompiled/types.h | 12 +-- 5 files changed, 109 insertions(+), 58 deletions(-) diff --git a/cpp/src/gandiva/function_registry_common.h b/cpp/src/gandiva/function_registry_common.h index d2f08322caf..b33eabdf6fd 100644 --- a/cpp/src/gandiva/function_registry_common.h +++ b/cpp/src/gandiva/function_registry_common.h @@ -97,6 +97,17 @@ typedef std::unordered_map ALIASES, \ + DataTypeVector{IN_TYPE1(), IN_TYPE2()}, OUT_TYPE(), kResultNullIfNull, \ + ARROW_STRINGIFY(NAME##_##IN_TYPE1##_##IN_TYPE2), \ + NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors) + // Binary functions that : // - have different input types, or output type // - NULL handling is of type NULL_IF_NULL diff --git a/cpp/src/gandiva/function_registry_math_ops.cc b/cpp/src/gandiva/function_registry_math_ops.cc index b7e957e139e..93a3a85ce77 100644 --- a/cpp/src/gandiva/function_registry_math_ops.cc +++ b/cpp/src/gandiva/function_registry_math_ops.cc @@ -117,12 +117,12 @@ std::vector GetMathOpsFunctionRegistry() { 
UNARY_SAFE_NULL_IF_NULL(rshift, {"shiftrightunsigned"}, uint64, uint64), UNARY_SAFE_NULL_IF_NULL(truncate, {"trunc"}, int32, int32), UNARY_SAFE_NULL_IF_NULL(truncate, {"trunc"}, int64, int64), - UNARY_SAFE_NULL_IF_NULL(truncate, {"trunc"}, float32, float32), - UNARY_SAFE_NULL_IF_NULL(truncate, {"trunc"}, float64, float64), - BINARY_GENERIC_SAFE_NULL_IF_NULL(truncate, {"trunc"}, int32, int32, int32), - BINARY_GENERIC_SAFE_NULL_IF_NULL(truncate, {"trunc"}, int64, int32, int64), - BINARY_GENERIC_SAFE_NULL_IF_NULL(truncate, {"trunc"}, float32, int32, float32), - BINARY_GENERIC_SAFE_NULL_IF_NULL(truncate, {"trunc"}, float64, int32, float64), + UNARY_UNSAFE_NULL_IF_NULL(truncate, {"trunc"}, float32, float32), + UNARY_UNSAFE_NULL_IF_NULL(truncate, {"trunc"}, float64, float64), + BINARY_GENERIC_UNSAFE_NULL_IF_NULL(truncate, {"trunc"}, int32, int32, int32), + BINARY_GENERIC_UNSAFE_NULL_IF_NULL(truncate, {"trunc"}, int64, int32, int64), + BINARY_GENERIC_UNSAFE_NULL_IF_NULL(truncate, {"trunc"}, float32, int32, float32), + BINARY_GENERIC_UNSAFE_NULL_IF_NULL(truncate, {"trunc"}, float64, int32, float64), // decimal functions UNARY_SAFE_NULL_IF_NULL(abs, {}, decimal128, decimal128), diff --git a/cpp/src/gandiva/precompiled/extended_math_ops.cc b/cpp/src/gandiva/precompiled/extended_math_ops.cc index 982328ddeaa..543cd915fdb 100644 --- a/cpp/src/gandiva/precompiled/extended_math_ops.cc +++ b/cpp/src/gandiva/precompiled/extended_math_ops.cc @@ -499,28 +499,47 @@ gdv_int64 get_power_of_10(gdv_int32 exp) { #define TRUNCATE_INTEGER_WITH_OUT_SCALE(TYPE) \ FORCE_INLINE \ - gdv_##TYPE truncate_##TYPE##_int32(gdv_##TYPE in, gdv_int32 out_scale) { \ + gdv_##TYPE truncate_##TYPE##_int32(int64_t context, gdv_##TYPE in, \ + gdv_int32 out_scale) { \ bool overflow = false; \ arrow::BasicDecimal128 decimal = \ gandiva::decimalops::FromInt64(static_cast(in), 38, 0, &overflow); \ + \ arrow::BasicDecimal128 decimal_with_outscale = \ 
gandiva::decimalops::Truncate(gandiva::BasicDecimalScalar128(decimal, 38, 0), \ 38, out_scale, out_scale, &overflow); \ + \ + if (overflow) { \ + gdv_fn_context_set_error_msg(context, "The value overflow during truncate"); \ + return static_cast(0); \ + } \ + \ if (out_scale < 0) { \ out_scale = 0; \ } \ - return static_cast(gandiva::decimalops::ToInt64( \ + \ + gdv_##TYPE return_value = static_cast(gandiva::decimalops::ToInt64( \ gandiva::BasicDecimalScalar128(decimal_with_outscale, 38, out_scale), \ &overflow)); \ + \ + if (overflow) { \ + gdv_fn_context_set_error_msg(context, "The value overflow during truncate"); \ + return static_cast(0); \ + } \ + \ + return return_value; \ } -#define TRUNCATE_REAL(TYPE) \ - FORCE_INLINE \ - gdv_##TYPE truncate_##TYPE(gdv_##TYPE in) { return truncate_##TYPE##_int32(in, 0); } +#define TRUNCATE_REAL(TYPE) \ + FORCE_INLINE \ + gdv_##TYPE truncate_##TYPE(int64_t context, gdv_##TYPE in) { \ + return truncate_##TYPE##_int32(context, in, 0); \ + } #define TRUNCATE_REAL_WITH_OUT_SCALE(TYPE) \ FORCE_INLINE \ - gdv_##TYPE truncate_##TYPE##_int32(gdv_##TYPE in, gdv_int32 out_scale) { \ + gdv_##TYPE truncate_##TYPE##_int32(int64_t context, gdv_##TYPE in, \ + gdv_int32 out_scale) { \ bool overflow = false; \ auto roundType = gandiva::decimalops::RoundType::kRoundTypeTrunc; \ int32_t conversionScale = out_scale >= 0 ? 
out_scale : 0; \ @@ -530,13 +549,26 @@ gdv_int64 get_power_of_10(gdv_int32 exp) { arrow::BasicDecimal128 decimal_with_outscale = gandiva::decimalops::Truncate( \ gandiva::BasicDecimalScalar128(decimal, 38, conversionScale), 38, out_scale, \ out_scale, &overflow); \ + \ + if (overflow) { \ + gdv_fn_context_set_error_msg(context, "The value overflow during truncate"); \ + return static_cast(0); \ + } \ + \ if (out_scale < 0) { \ out_scale = 0; \ } \ \ - return static_cast(gandiva::decimalops::ToDouble( \ + gdv_##TYPE return_value = static_cast(gandiva::decimalops::ToDouble( \ gandiva::BasicDecimalScalar128(decimal_with_outscale, 38, out_scale), \ &overflow)); \ + \ + if (overflow) { \ + gdv_fn_context_set_error_msg(context, "The value overflow during truncate"); \ + return static_cast(0); \ + } \ + \ + return return_value; \ } ESIGNED_INTEGER_TYPES(TRUNCATE_INTEGER) diff --git a/cpp/src/gandiva/precompiled/extended_math_ops_test.cc b/cpp/src/gandiva/precompiled/extended_math_ops_test.cc index 31eea030b9b..24a76a81e1e 100644 --- a/cpp/src/gandiva/precompiled/extended_math_ops_test.cc +++ b/cpp/src/gandiva/precompiled/extended_math_ops_test.cc @@ -210,20 +210,23 @@ TEST(TestExtendedMathOps, TestRound) { TEST(TestExtendedMathOps, TestTruncateWithScale) { // Test the truncate function for longs - EXPECT_EQ(truncate_int64_int32(1234, 4), 1234); - EXPECT_EQ(truncate_int64_int32(-1234, 4), -1234); - EXPECT_EQ(truncate_int64_int32(1234, -4), 0); - EXPECT_EQ(truncate_int64_int32(-1234, -2), -1200); - EXPECT_EQ(truncate_int64_int32(8124674407369523212, 0), 8124674407369523212); - EXPECT_EQ(truncate_int64_int32(8124674407369523212, -2), 8124674407369523200); + gandiva::ExecutionContext context; + auto ctx = reinterpret_cast(&context); + + EXPECT_EQ(truncate_int64_int32(ctx, 1234, 4), 1234); + EXPECT_EQ(truncate_int64_int32(ctx, -1234, 4), -1234); + EXPECT_EQ(truncate_int64_int32(ctx, 1234, -4), 0); + EXPECT_EQ(truncate_int64_int32(ctx, -1234, -2), -1200); + 
EXPECT_EQ(truncate_int64_int32(ctx, 8124674407369523212, 0), 8124674407369523212); + EXPECT_EQ(truncate_int64_int32(ctx, 8124674407369523212, -2), 8124674407369523200); // Test truncate function for integers - EXPECT_EQ(truncate_int32_int32(1234, 4), 1234); - EXPECT_EQ(truncate_int32_int32(-1234, 4), -1234); - EXPECT_EQ(truncate_int32_int32(1234, -4), 0); - EXPECT_EQ(truncate_int32_int32(-1234, -2), -1200); - EXPECT_EQ(truncate_int32_int32(8124674, 0), 8124674); - EXPECT_EQ(truncate_int32_int32(8124674, -2), 8124600); + EXPECT_EQ(truncate_int32_int32(ctx, 1234, 4), 1234); + EXPECT_EQ(truncate_int32_int32(ctx, -1234, 4), -1234); + EXPECT_EQ(truncate_int32_int32(ctx, 1234, -4), 0); + EXPECT_EQ(truncate_int32_int32(ctx, -1234, -2), -1200); + EXPECT_EQ(truncate_int32_int32(ctx, 8124674, 0), 8124674); + EXPECT_EQ(truncate_int32_int32(ctx, 8124674, -2), 8124600); } TEST(TestExtendedMathOps, TestTruncateWithoutScale) { @@ -239,33 +242,38 @@ TEST(TestExtendedMathOps, TestTruncateWithoutScale) { } TEST(TestExtendedMathOps, TestTruncateFloat) { - VerifyFuzzyEquals(truncate_float32(1234.245f), 1234.0f); - VerifyFuzzyEquals(truncate_float32(-11.7892f), -11.0f); - VerifyFuzzyEquals(truncate_float32(1.4999999f), 1.0f); - EXPECT_EQ(std::signbit(truncate_float32(0)), 0); - VerifyFuzzyEquals(truncate_float32_int32(1234.789f, 2), 1234.78f); - VerifyFuzzyEquals(truncate_float32_int32(1234.12345f, -3), 1000.0f); - VerifyFuzzyEquals(truncate_float32_int32(-1234.4567f, 3), -1234.456f); - VerifyFuzzyEquals(truncate_float32_int32(-1234.4567f, -3), -1000.0f); - VerifyFuzzyEquals(truncate_float32_int32(1234.4567f, 0), 1234); - VerifyFuzzyEquals(truncate_float32_int32(1.5499999523162842f, 1), 1.5f); - EXPECT_EQ(std::signbit(truncate_float32_int32(0, 5)), 0); - VerifyFuzzyEquals(truncate_float32_int32(static_cast(1.55), 1), 1.5f); - VerifyFuzzyEquals(truncate_float32_int32(static_cast(9.134123), 2), 9.13f); - VerifyFuzzyEquals(truncate_float32_int32(static_cast(-1.923), 1), -1.9f); - - 
VerifyFuzzyEquals(truncate_float64(1234.245), 1234.0); - VerifyFuzzyEquals(truncate_float64(-11.7892), -11.0); - VerifyFuzzyEquals(truncate_float64(1.4999999), 1.0); - EXPECT_EQ(std::signbit(truncate_float64(0)), 0); - VerifyFuzzyEquals(truncate_float64_int32(1234.789, 2), 1234.78); - VerifyFuzzyEquals(truncate_float64_int32(1234.12345, -3), 1000.0); - VerifyFuzzyEquals(truncate_float64_int32(-1234.4567, 3), -1234.456); - VerifyFuzzyEquals(truncate_float64_int32(-1234.4567, -3), -1000.0); - VerifyFuzzyEquals(truncate_float64_int32(1234.4567, 0), 1234.0); - EXPECT_EQ(std::signbit(truncate_float64_int32(0, -2)), 0); - VerifyFuzzyEquals(truncate_float64_int32((double)INT_MAX + 1, 0), (double)INT_MAX + 1); - VerifyFuzzyEquals(truncate_float64_int32((double)INT_MIN - 1, 0), (double)INT_MIN - 1); + gandiva::ExecutionContext context; + auto ctx = reinterpret_cast(&context); + + VerifyFuzzyEquals(truncate_float32(ctx, 1234.245f), 1234.0f); + VerifyFuzzyEquals(truncate_float32(ctx, -11.7892f), -11.0f); + VerifyFuzzyEquals(truncate_float32(ctx, 1.4999999f), 1.0f); + EXPECT_EQ(std::signbit(truncate_float32(ctx, 0)), 0); + VerifyFuzzyEquals(truncate_float32_int32(ctx, 1234.789f, 2), 1234.78f); + VerifyFuzzyEquals(truncate_float32_int32(ctx, 1234.12345f, -3), 1000.0f); + VerifyFuzzyEquals(truncate_float32_int32(ctx, -1234.4567f, 3), -1234.456f); + VerifyFuzzyEquals(truncate_float32_int32(ctx, -1234.4567f, -3), -1000.0f); + VerifyFuzzyEquals(truncate_float32_int32(ctx, 1234.4567f, 0), 1234); + VerifyFuzzyEquals(truncate_float32_int32(ctx, 1.5499999523162842f, 1), 1.5f); + EXPECT_EQ(std::signbit(truncate_float32_int32(ctx, 0, 5)), 0); + VerifyFuzzyEquals(truncate_float32_int32(ctx, static_cast(1.55), 1), 1.5f); + VerifyFuzzyEquals(truncate_float32_int32(ctx, static_cast(9.134123), 2), 9.13f); + VerifyFuzzyEquals(truncate_float32_int32(ctx, static_cast(-1.923), 1), -1.9f); + + VerifyFuzzyEquals(truncate_float64(ctx, 1234.245), 1234.0); + VerifyFuzzyEquals(truncate_float64(ctx, 
-11.7892), -11.0); + VerifyFuzzyEquals(truncate_float64(ctx, 1.4999999), 1.0); + EXPECT_EQ(std::signbit(truncate_float64(ctx, 0)), 0); + VerifyFuzzyEquals(truncate_float64_int32(ctx, 1234.789, 2), 1234.78); + VerifyFuzzyEquals(truncate_float64_int32(ctx, 1234.12345, -3), 1000.0); + VerifyFuzzyEquals(truncate_float64_int32(ctx, -1234.4567, 3), -1234.456); + VerifyFuzzyEquals(truncate_float64_int32(ctx, -1234.4567, -3), -1000.0); + VerifyFuzzyEquals(truncate_float64_int32(ctx, 1234.4567, 0), 1234.0); + EXPECT_EQ(std::signbit(truncate_float64_int32(ctx, 0, -2)), 0); + VerifyFuzzyEquals(truncate_float64_int32(ctx, (double)INT_MAX + 1, 0), + (double)INT_MAX + 1); + VerifyFuzzyEquals(truncate_float64_int32(ctx, (double)INT_MIN - 1, 0), + (double)INT_MIN - 1); } TEST(TestExtendedMathOps, TestTrigonometricFunctions) { diff --git a/cpp/src/gandiva/precompiled/types.h b/cpp/src/gandiva/precompiled/types.h index 747c4fffb41..54ae818f632 100644 --- a/cpp/src/gandiva/precompiled/types.h +++ b/cpp/src/gandiva/precompiled/types.h @@ -443,12 +443,12 @@ gdv_date64 last_day_from_timestamp(gdv_date64 millis); int32_t truncate_int32(int32_t in); int64_t truncate_int64(int64_t in); -float truncate_float32(float in); -double truncate_float64(double in); -int32_t truncate_int32_int32(int32_t in, int32_t out_scale); -int64_t truncate_int64_int32(int64_t in, int32_t out_scale); -float truncate_float32_int32(float in, int32_t out_scale); -double truncate_float64_int32(double in, int32_t out_scale); +float truncate_float32(int64_t context, float in); +double truncate_float64(int64_t context, double in); +int32_t truncate_int32_int32(int64_t context, int32_t in, int32_t out_scale); +int64_t truncate_int64_int32(int64_t context, int64_t in, int32_t out_scale); +float truncate_float32_int32(int64_t context, float in, int32_t out_scale); +double truncate_float64_int32(int64_t context, double in, int32_t out_scale); const char* repeat_utf8_int32(gdv_int64 context, const char* in, gdv_int32 
in_len, gdv_int32 repeat_times, gdv_int32* out_len);