Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions cpp/src/gandiva/function_registry_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,17 @@ typedef std::unordered_map<const FunctionSignature*, const NativeFunction*, KeyH
#define BINARY_SYMMETRIC_UNSAFE_NULL_IF_NULL(NAME, ALIASES, TYPE) \
BINARY_UNSAFE_NULL_IF_NULL(NAME, ALIASES, TYPE, TYPE)

// Binary functions that :
// - have different input types, or output type
// - NULL handling is of type NULL_IF_NULL
// - are registered with the kNeedsContext and kCanReturnErrors flags
//
// NAME is stringified for the registered function name and is also the prefix of
// the pre-compiled symbol. ALIASES is a brace-enclosed list of alternate names.
// IN_TYPE1, IN_TYPE2 and OUT_TYPE are invoked as type factories (e.g. int32()).
//
// The pre-compiled fn name includes the base name & input type names. eg. mod_int64_int32
#define BINARY_GENERIC_UNSAFE_NULL_IF_NULL(NAME, ALIASES, IN_TYPE1, IN_TYPE2, OUT_TYPE) \
NativeFunction(#NAME, std::vector<std::string> ALIASES, \
DataTypeVector{IN_TYPE1(), IN_TYPE2()}, OUT_TYPE(), kResultNullIfNull, \
ARROW_STRINGIFY(NAME##_##IN_TYPE1##_##IN_TYPE2), \
NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors)

// Binary functions that :
// - have different input types, or output type
// - NULL handling is of type NULL_IF_NULL
Expand Down Expand Up @@ -127,6 +138,17 @@ typedef std::unordered_map<const FunctionSignature*, const NativeFunction*, KeyH
NativeFunction(#NAME, std::vector<std::string> ALIASES, DataTypeVector{IN_TYPE()}, \
OUT_TYPE(), kResultNullIfNull, ARROW_STRINGIFY(NAME##_##IN_TYPE))

// Unary functions in gdv stubs file that :
// - NULL handling is of type NULL_IF_NULL
// - are registered with the kNeedsContext flag
//
// The pre-compiled fn name is the "gdv_fn_" prefix followed by the base name &
// input type name. eg. gdv_fn_abs_float32
//
// NOTE(review): only kNeedsContext is set here. The signed-integer abs stubs
// report an overflow through the context, so confirm whether kCanReturnErrors
// should be set as well for those registrations.
#define STUBS_UNARY_SAFE_NULL_IF_NULL(NAME, ALIASES, IN_TYPE, OUT_TYPE) \
NativeFunction(#NAME, std::vector<std::string> ALIASES, DataTypeVector{IN_TYPE()}, \
OUT_TYPE(), kResultNullIfNull, \
ARROW_STRINGIFY(gdv_fn_##NAME##_##IN_TYPE), \
NativeFunction::kNeedsContext)

// Unary functions that :
// - NULL handling is of type NULL_NEVER
//
Expand Down
46 changes: 42 additions & 4 deletions cpp/src/gandiva/function_registry_math_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,22 @@ namespace gandiva {
UNARY_SAFE_NULL_IF_NULL(name, ALIASES, float32, float64), \
UNARY_SAFE_NULL_IF_NULL(name, ALIASES, float64, float64)

// Expands to six UNARY_SAFE_NULL_IF_NULL registrations of `name`, one each for
// int32, int64, uint32, uint64, float32 and float64. In every registration the
// output type equals the input type.
#define MATH_UNARY_OPS_SAME_TYPE_RETURN(name, ALIASES) \
UNARY_SAFE_NULL_IF_NULL(name, ALIASES, int32, int32), \
UNARY_SAFE_NULL_IF_NULL(name, ALIASES, int64, int64), \
UNARY_SAFE_NULL_IF_NULL(name, ALIASES, uint32, uint32), \
UNARY_SAFE_NULL_IF_NULL(name, ALIASES, uint64, uint64), \
UNARY_SAFE_NULL_IF_NULL(name, ALIASES, float32, float32), \
UNARY_SAFE_NULL_IF_NULL(name, ALIASES, float64, float64)

// Expands to six STUBS_UNARY_SAFE_NULL_IF_NULL registrations of `name`, one each
// for int32, int64, uint32, uint64, float32 and float64. In every registration
// the output type equals the input type.
#define STUBS_MATH_UNARY_OPS_SAME_TYPE_RETURN(name, ALIASES) \
STUBS_UNARY_SAFE_NULL_IF_NULL(name, ALIASES, int32, int32), \
STUBS_UNARY_SAFE_NULL_IF_NULL(name, ALIASES, int64, int64), \
STUBS_UNARY_SAFE_NULL_IF_NULL(name, ALIASES, uint32, uint32), \
STUBS_UNARY_SAFE_NULL_IF_NULL(name, ALIASES, uint64, uint64), \
STUBS_UNARY_SAFE_NULL_IF_NULL(name, ALIASES, float32, float32), \
STUBS_UNARY_SAFE_NULL_IF_NULL(name, ALIASES, float64, float64)

#define MATH_BINARY_UNSAFE(name, ALIASES) \
BINARY_UNSAFE_NULL_IF_NULL(name, ALIASES, int32, float64), \
BINARY_UNSAFE_NULL_IF_NULL(name, ALIASES, int64, float64), \
Expand Down Expand Up @@ -61,8 +77,8 @@ namespace gandiva {

std::vector<NativeFunction> GetMathOpsFunctionRegistry() {
static std::vector<NativeFunction> math_fn_registry_ = {
MATH_UNARY_OPS(cbrt, {}), MATH_UNARY_OPS(exp, {}), MATH_UNARY_OPS(log, {}),
MATH_UNARY_OPS(log10, {}),
MATH_UNARY_OPS(cbrt, {}), MATH_UNARY_OPS(exp, {}), MATH_UNARY_OPS(log, {"ln"}),
MATH_UNARY_OPS(log2, {}), MATH_UNARY_OPS(log10, {}),

MATH_BINARY_UNSAFE(log, {}),

Expand All @@ -86,6 +102,28 @@ std::vector<NativeFunction> GetMathOpsFunctionRegistry() {
MATH_UNARY_OPS(cot, {}), MATH_UNARY_OPS(radians, {}), MATH_UNARY_OPS(degrees, {}),
MATH_BINARY_SAFE(atan2, {}),

// extended functions
MATH_UNARY_OPS(sqrt, {}), STUBS_MATH_UNARY_OPS_SAME_TYPE_RETURN(abs, {}),
MATH_UNARY_OPS_SAME_TYPE_RETURN(sign, {}),
UNARY_SAFE_NULL_IF_NULL(ceil, {}, float32, float32),
UNARY_SAFE_NULL_IF_NULL(ceil, {}, float64, float64),
UNARY_SAFE_NULL_IF_NULL(floor, {}, float32, float32),
UNARY_SAFE_NULL_IF_NULL(floor, {}, float64, float64),
UNARY_SAFE_NULL_IF_NULL(lshift, {"shiftleft"}, int32, int32),
UNARY_SAFE_NULL_IF_NULL(lshift, {"shiftleft"}, int64, int64),
UNARY_SAFE_NULL_IF_NULL(rshift, {"shiftright"}, int32, int32),
UNARY_SAFE_NULL_IF_NULL(rshift, {"shiftright"}, int64, int64),
UNARY_SAFE_NULL_IF_NULL(rshift, {"shiftrightunsigned"}, uint32, uint32),
UNARY_SAFE_NULL_IF_NULL(rshift, {"shiftrightunsigned"}, uint64, uint64),
UNARY_SAFE_NULL_IF_NULL(truncate, {"trunc"}, int32, int32),
UNARY_SAFE_NULL_IF_NULL(truncate, {"trunc"}, int64, int64),
UNARY_UNSAFE_NULL_IF_NULL(truncate, {"trunc"}, float32, float32),
UNARY_UNSAFE_NULL_IF_NULL(truncate, {"trunc"}, float64, float64),
BINARY_GENERIC_UNSAFE_NULL_IF_NULL(truncate, {"trunc"}, int32, int32, int32),
BINARY_GENERIC_UNSAFE_NULL_IF_NULL(truncate, {"trunc"}, int64, int32, int64),
BINARY_GENERIC_UNSAFE_NULL_IF_NULL(truncate, {"trunc"}, float32, int32, float32),
BINARY_GENERIC_UNSAFE_NULL_IF_NULL(truncate, {"trunc"}, float64, int32, float64),

// decimal functions
UNARY_SAFE_NULL_IF_NULL(abs, {}, decimal128, decimal128),
UNARY_SAFE_NULL_IF_NULL(ceil, {}, decimal128, decimal128),
Expand All @@ -97,8 +135,8 @@ std::vector<NativeFunction> GetMathOpsFunctionRegistry() {
decimal128),
BINARY_SYMMETRIC_SAFE_NULL_NEVER_FN(nvl, {}),

NativeFunction("truncate", {"trunc"}, DataTypeVector{int64(), int32()}, int64(),
kResultNullIfNull, "truncate_int64_int32"),
NativeFunction("pi", {}, {}, float64(), kResultNullIfNull, "pi"),
NativeFunction("e", {}, {}, float64(), kResultNullIfNull, "e"),
NativeFunction("random", {"rand"}, DataTypeVector{}, float64(), kResultNullNever,
"gdv_fn_random", NativeFunction::kNeedsFunctionHolder),
NativeFunction("random", {"rand"}, DataTypeVector{int32()}, float64(),
Expand Down
102 changes: 102 additions & 0 deletions cpp/src/gandiva/gdv_function_stubs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@
#include <utf8proc.h>

#include <boost/crc.hpp>
#include <cmath>
#include <iostream>
#include <limits>
#include <string>
#include <vector>

Expand Down Expand Up @@ -1165,6 +1168,49 @@ int32_t gdv_fn_cast_intervalyear_utf8_int32(int64_t context_ptr, int64_t holder_
auto* holder = reinterpret_cast<gandiva::IntervalYearsHolder*>(holder_ptr);
return (*holder)(context, data, data_len, in1_validity, out_valid);
}

// Type lists for the abs stubs. Each macro applies INNER once per type suffix in
// its group; INNER is expected to expand to a full function definition.
#define ABS_SIGNED_INTEGER_TYPES(INNER) INNER(int32) INNER(int64)
#define ABS_UNSIGNED_INTEGER_TYPES(INNER) INNER(uint32) INNER(uint64)
#define ABS_REAL_TYPES(INNER) INNER(float32) INNER(float64)

// Abs for signed integer types. The magnitude of the type's minimum value is not
// representable, so for that input the function sets the error message
// "Overflow in abs execution" on the context and returns 0. Every other input
// yields its absolute value.
#define ABS_FOR_SIGNED_INTEGER(IN_TYPE)                                         \
  GANDIVA_EXPORT                                                                \
  gdv_##IN_TYPE gdv_fn_abs_##IN_TYPE(int64_t context, gdv_##IN_TYPE in) {       \
    if (in == std::numeric_limits<gdv_##IN_TYPE>::min()) {                      \
      const std::string overflow_msg("Overflow in abs execution");              \
      gdv_fn_context_set_error_msg(context, overflow_msg.data());               \
      return static_cast<gdv_##IN_TYPE>(0);                                     \
    }                                                                           \
                                                                                \
    return (in < 0) ? static_cast<gdv_##IN_TYPE>(-in) : in;                     \
  }

// Abs for floating-point types. The context argument is accepted only so the
// signature matches the other abs stubs; it is not used. std::fabs is called so
// the overload matching the input type is selected instead of the C
// double-only fabs.
#define ABS_FOR_REAL_TYPES(IN_TYPE)                                             \
  GANDIVA_EXPORT                                                                \
  gdv_##IN_TYPE gdv_fn_abs_##IN_TYPE(int64_t context, gdv_##IN_TYPE in) {       \
    return static_cast<gdv_##IN_TYPE>(std::fabs(in));                           \
  }

// Abs for unsigned integer types. An unsigned value is already its own absolute
// value, so the input is returned unchanged and the context is not used.
#define ABS_FOR_UNSIGNED_INTEGER(IN_TYPE) \
GANDIVA_EXPORT \
gdv_##IN_TYPE gdv_fn_abs_##IN_TYPE(int64_t context, gdv_##IN_TYPE in) { return in; }

// Define gdv_fn_abs_<type> for every supported type.
ABS_SIGNED_INTEGER_TYPES(ABS_FOR_SIGNED_INTEGER)
ABS_UNSIGNED_INTEGER_TYPES(ABS_FOR_UNSIGNED_INTEGER)
ABS_REAL_TYPES(ABS_FOR_REAL_TYPES)

// Undefine every helper macro so none of them leak into the rest of the file.
#undef ABS_SIGNED_INTEGER_TYPES
#undef ABS_UNSIGNED_INTEGER_TYPES
#undef ABS_REAL_TYPES
#undef ABS_FOR_SIGNED_INTEGER
#undef ABS_FOR_UNSIGNED_INTEGER
#undef ABS_FOR_REAL_TYPES
}

namespace gandiva {
Expand Down Expand Up @@ -2309,5 +2355,61 @@ void ExportedStubFunctions::AddMappings(Engine* engine) const {
engine->AddGlobalMappingForFunc(
"gdv_fn_cast_intervalyear_utf8_int32", types->i32_type() /*return_type*/, args,
reinterpret_cast<void*>(gdv_fn_cast_intervalyear_utf8_int32));

// gdv_fn_abs_int32
args = {
types->i64_type(), // context
types->i32_type(), // value
};

engine->AddGlobalMappingForFunc("gdv_fn_abs_int32", types->i32_type() /*return_type*/,
args, reinterpret_cast<void*>(gdv_fn_abs_int32));

// gdv_fn_abs_int64
args = {
types->i64_type(), // context
types->i64_type(), // value
};

engine->AddGlobalMappingForFunc("gdv_fn_abs_int64", types->i64_type() /*return_type*/,
args, reinterpret_cast<void*>(gdv_fn_abs_int64));

// gdv_fn_abs_uint32
args = {
types->i64_type(), // context
types->i32_type(), // value
};

engine->AddGlobalMappingForFunc("gdv_fn_abs_uint32", types->i32_type() /*return_type*/,
args, reinterpret_cast<void*>(gdv_fn_abs_uint32));

// gdv_fn_abs_uint64
args = {
types->i64_type(), // context
types->i64_type(), // value
};

engine->AddGlobalMappingForFunc("gdv_fn_abs_uint64", types->i64_type() /*return_type*/,
args, reinterpret_cast<void*>(gdv_fn_abs_uint64));

// gdv_fn_abs_float32
args = {
types->i64_type(), // context
types->float_type(), // value
};

engine->AddGlobalMappingForFunc("gdv_fn_abs_float32",
types->float_type() /*return_type*/, args,
reinterpret_cast<void*>(gdv_fn_abs_float32));

// gdv_fn_abs_float64
args = {
types->i64_type(), // context
types->double_type(), // value
};

engine->AddGlobalMappingForFunc("gdv_fn_abs_float64",
types->double_type() /*return_type*/, args,
reinterpret_cast<void*>(gdv_fn_abs_float64));
}
} // namespace gandiva
18 changes: 18 additions & 0 deletions cpp/src/gandiva/gdv_function_stubs.h
Original file line number Diff line number Diff line change
Expand Up @@ -200,4 +200,22 @@ GANDIVA_EXPORT
const char* gdv_mask_last_n_utf8_int32(int64_t context, const char* data,
int32_t data_len, int32_t n_to_mask,
int32_t* out_len);

// Absolute-value stubs. Each takes the execution context handle followed by the
// input value and returns a value of the same type as the input.

// Signed variants: for the type's minimum value an "Overflow in abs execution"
// error is set on the context and 0 is returned.
GANDIVA_EXPORT
int32_t gdv_fn_abs_int32(int64_t context, int32_t in);

GANDIVA_EXPORT
int64_t gdv_fn_abs_int64(int64_t context, int64_t in);

// Unsigned variants: the input is returned unchanged.
GANDIVA_EXPORT
uint32_t gdv_fn_abs_uint32(int64_t context, uint32_t in);

GANDIVA_EXPORT
uint64_t gdv_fn_abs_uint64(int64_t context, uint64_t in);

// Floating-point variants.
GANDIVA_EXPORT
float gdv_fn_abs_float32(int64_t context, float in);

GANDIVA_EXPORT
double gdv_fn_abs_float64(int64_t context, double in);
}
62 changes: 62 additions & 0 deletions cpp/src/gandiva/gdv_function_stubs_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,18 @@
#include <gmock/gmock.h>
#include <gtest/gtest.h>

#include <cfloat>
#include <cmath>

#include "arrow/util/logging.h"
#include "gandiva/execution_context.h"

namespace gandiva {

// Expects `actual` and `expected` to differ by strictly less than `max_error`.
// The tolerance is absolute, not relative, and defaults to FLT_EPSILON. If the
// difference is NaN the comparison is false and the expectation fails. On
// failure both values are appended to the message.
void VerifyAlmostEquals(double actual, double expected, double max_error = FLT_EPSILON) {
EXPECT_TRUE(fabs(actual - expected) < max_error) << actual << " != " << expected;
}

TEST(TestGdvFnStubs, TestCrc32) {
gandiva::ExecutionContext ctx;
auto ctx_ptr = reinterpret_cast<int64_t>(&ctx);
Expand Down Expand Up @@ -949,4 +956,59 @@ TEST(TestGdvFnStubs, TestMaskLastN) {
EXPECT_EQ(expected, std::string(result, out_len));
}

TEST(TestGdvFnStubs, TestAbs) {
gandiva::ExecutionContext ctx;
uint64_t ctx_ptr = reinterpret_cast<int64_t>(&ctx);

// Abs functions
EXPECT_EQ(gdv_fn_abs_int32(ctx_ptr, 0), 0);
EXPECT_EQ(gdv_fn_abs_uint32(ctx_ptr, 0), 0);
EXPECT_EQ(gdv_fn_abs_int64(ctx_ptr, 0), 0L);
EXPECT_EQ(gdv_fn_abs_uint64(ctx_ptr, 0), 0L);
VerifyAlmostEquals(gdv_fn_abs_float32(ctx_ptr, 0.0f), abs(0.0f));
VerifyAlmostEquals(gdv_fn_abs_float64(ctx_ptr, 0.0), abs(0.0));

EXPECT_EQ(gdv_fn_abs_int32(ctx_ptr, (INT32_MIN + 1)), abs(INT32_MIN + 1));
EXPECT_EQ(gdv_fn_abs_int64(ctx_ptr, (INT64_MIN + 1)), abs(INT64_MIN + 1));
VerifyAlmostEquals(gdv_fn_abs_float32(ctx_ptr, static_cast<float>(INT32_MIN + 1)),
abs(static_cast<float>(INT32_MIN + 1)));
VerifyAlmostEquals(gdv_fn_abs_float64(ctx_ptr, static_cast<double>(INT32_MIN + 1)),
abs(static_cast<double>(INT32_MIN + 1)));

VerifyAlmostEquals(gdv_fn_abs_float32(ctx_ptr, std::numeric_limits<float>::max()),
abs(std::numeric_limits<float>::max()));
VerifyAlmostEquals(gdv_fn_abs_float32(ctx_ptr, std::numeric_limits<float>::min()),
abs(std::numeric_limits<float>::min()));
VerifyAlmostEquals(gdv_fn_abs_float64(ctx_ptr, std::numeric_limits<double>::max()),
abs(std::numeric_limits<double>::max()));
VerifyAlmostEquals(gdv_fn_abs_float64(ctx_ptr, std::numeric_limits<double>::min()),
abs(std::numeric_limits<double>::min()));

EXPECT_EQ(gdv_fn_abs_int64(ctx_ptr, (INT64_MIN + 1)),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does it mean that we will accept the loss of digits in INT64?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@kiszk Thanks for your review! Could you please explain how digit loss would occur here? I did not understand what you meant.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, I thought the previous version of abs is to cast int64 to double long. But, now, gdv_fn_abs_int64 handles as int64. Correct?

Copy link
Contributor Author

@anthonylouisbsb anthonylouisbsb May 25, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@kiszk Yes, this line has the changes I made in the function. I handle the number as int64 and I add a check similar to the compute kernel to avoid overflow when getting the absolute value for the INT64_MIN value.

abs(static_cast<double>(INT64_MIN + 1)));
VerifyAlmostEquals(gdv_fn_abs_float64(ctx_ptr, static_cast<double>(INT64_MIN + 1)),
abs(static_cast<double>(INT64_MIN + 1)));

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about adding tests for min and max values of double and float by using std::numeric_limits?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tests added

VerifyAlmostEquals(gdv_fn_abs_float32(ctx_ptr, -3600.50f), abs(-3600.50f));
VerifyAlmostEquals(gdv_fn_abs_float64(ctx_ptr, -3600.50), abs(-3600.50));
}

// Checks that abs of the minimum signed value reports an overflow error through
// the execution context and returns 0, for both the int64 and int32 stubs.
TEST(TestGdvFnStubs, TestAbsOverflow) {
  gandiva::ExecutionContext ctx;
  // The stubs take the context as an int64_t handle, so keep the signed type
  // instead of converting the cast result to uint64_t and back.
  auto ctx_ptr = reinterpret_cast<int64_t>(&ctx);

  int64_t value = gdv_fn_abs_int64(ctx_ptr, INT64_MIN);
  EXPECT_TRUE(ctx.has_error());
  EXPECT_THAT(ctx.get_error(), ::testing::HasSubstr("Overflow in abs execution"));
  EXPECT_EQ(value, 0);

  // Clear the recorded error so the int32 check starts from a clean context.
  ctx.Reset();

  int32_t value_int32 = gdv_fn_abs_int32(ctx_ptr, INT32_MIN);
  EXPECT_TRUE(ctx.has_error());
  EXPECT_THAT(ctx.get_error(), ::testing::HasSubstr("Overflow in abs execution"));
  EXPECT_EQ(value_int32, 0);

  ctx.Reset();
}
} // namespace gandiva
27 changes: 17 additions & 10 deletions cpp/src/gandiva/precompiled/decimal_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -488,12 +488,27 @@ static std::array<double, DecimalTypeUtil::kMaxPrecision + 1> kDoubleScaleMultip
return values;
})();

BasicDecimal128 FromDouble(double in, int32_t precision, int32_t scale, bool* overflow) {
BasicDecimal128 FromDouble(double in, int32_t precision, int32_t scale, bool* overflow,
RoundType round_type) {
// Multiply decimal with the scale
auto unscaled = in * kDoubleScaleMultipliers[scale];
DECIMAL_OVERFLOW_IF(std::isnan(unscaled), overflow);

unscaled = std::round(unscaled);
switch (round_type) {
case RoundType::kRoundTypeCeil:
unscaled = std::ceil(unscaled);
break;
case RoundType::kRoundTypeFloor:
unscaled = std::floor(unscaled);
break;
case RoundType::kRoundTypeTrunc:
unscaled = std::trunc(unscaled);
break;
case RoundType::kRoundTypeHalfRoundUp:
default:
unscaled = std::round(unscaled);
break;
}

// convert scaled double to int128
int32_t sign = unscaled < 0 ? -1 : 1;
Expand Down Expand Up @@ -551,14 +566,6 @@ static BasicDecimal128 ModifyScaleAndPrecision(const BasicDecimalScalar128& x,
}
}

// Rounding modes that decide how trailing digits are handled when a value is
// reduced to fewer digits.
enum RoundType {
kRoundTypeCeil, // +1 if +ve and trailing value is > 0, else no rounding.
kRoundTypeFloor, // -1 if -ve and trailing value is < 0, else no rounding.
kRoundTypeTrunc, // no rounding, truncate the trailing digits.
kRoundTypeHalfRoundUp, // if +ve and trailing value is >= half of base, +1.
// else if -ve and trailing value is >= half of base, -1.
};

// Compute the rounding delta for the given rounding type.
static int32_t ComputeRoundingDelta(const BasicDecimal128& x, int32_t x_scale,
int32_t out_scale, RoundType type) {
Expand Down
Loading