Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 27 additions & 2 deletions c_glib/gandiva-glib/native-function.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,10 +106,35 @@ ggandiva_native_function_get_signature(GGandivaNativeFunction *native_function)
{
auto gandiva_native_function =
ggandiva_native_function_get_raw(native_function);
auto &gandiva_function_signature = gandiva_native_function->signature();
auto &gandiva_function_signature = gandiva_native_function->signatures().front();
return ggandiva_function_signature_new_raw(&gandiva_function_signature);
}

/**
* ggandiva_native_function_get_all_signatures:
* @native_function: A #GGandivaNativeFunction.
*
* Returns: (transfer full): A List of #GGandivaFunctionSignature supported by
* the native function.
*
* Since: ??
*/
GList *
ggandiva_native_function_get_all_signatures(GGandivaNativeFunction *native_function)
{
auto gandiva_native_function =
ggandiva_native_function_get_raw(native_function);

GList *function_signatures = nullptr;
for (auto& function_sig_raw : gandiva_native_function->signatures()) {
auto function_sig = ggandiva_function_signature_new_raw(&function_sig_raw);
function_signatures = g_list_prepend(function_signatures, function_sig);
}
function_signatures = g_list_reverse(function_signatures);

return function_signatures;
}

/**
* ggandiva_native_function_equal:
* @native_function: A #GGandivaNativeFunction.
Expand Down Expand Up @@ -145,7 +170,7 @@ ggandiva_native_function_to_string(GGandivaNativeFunction *native_function)
{
auto gandiva_native_function =
ggandiva_native_function_get_raw(native_function);
auto gandiva_function_signature = gandiva_native_function->signature();
auto gandiva_function_signature = gandiva_native_function->signatures().front();
return g_strdup(gandiva_function_signature.ToString().c_str());
}

Expand Down
37 changes: 30 additions & 7 deletions cpp/src/gandiva/expression_registry.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,28 +30,51 @@ ExpressionRegistry::ExpressionRegistry() {

ExpressionRegistry::~ExpressionRegistry() {}

// to be used only to create function_signature_start
ExpressionRegistry::FunctionSignatureIterator::FunctionSignatureIterator(
native_func_iterator_type nf_it, native_func_iterator_type nf_it_end)
: native_func_it_{nf_it},
native_func_it_end_{nf_it_end},
func_sig_it_{&(nf_it->signatures().front())} {}

// to be used only to create function_signature_end
ExpressionRegistry::FunctionSignatureIterator::FunctionSignatureIterator(
func_sig_iterator_type fs_it)
: native_func_it_{nullptr}, native_func_it_end_{nullptr}, func_sig_it_{fs_it} {}

const ExpressionRegistry::FunctionSignatureIterator
ExpressionRegistry::function_signature_begin() {
return FunctionSignatureIterator(function_registry_->begin());
return FunctionSignatureIterator(function_registry_->begin(),
function_registry_->end());
}

const ExpressionRegistry::FunctionSignatureIterator
ExpressionRegistry::function_signature_end() const {
return FunctionSignatureIterator(function_registry_->end());
return FunctionSignatureIterator(&(*(function_registry_->back()->signatures().end())));
}

bool ExpressionRegistry::FunctionSignatureIterator::operator!=(
const FunctionSignatureIterator& func_sign_it) {
return func_sign_it.it_ != this->it_;
return func_sign_it.func_sig_it_ != this->func_sig_it_;
}

FunctionSignature ExpressionRegistry::FunctionSignatureIterator::operator*() {
return (*it_).signature();
return *func_sig_it_;
}

ExpressionRegistry::iterator ExpressionRegistry::FunctionSignatureIterator::operator++(
int increment) {
return it_++;
ExpressionRegistry::func_sig_iterator_type ExpressionRegistry::FunctionSignatureIterator::
operator++(int increment) {
++func_sig_it_;
// point func_sig_it_ to first signature of next nativefunction if func_sig_it_ is
// pointing to end
if (func_sig_it_ == &(*native_func_it_->signatures().end())) {
++native_func_it_;
if (native_func_it_ == native_func_it_end_) { // last native function
return func_sig_it_;
}
func_sig_it_ = &(native_func_it_->signatures().front());
}
return func_sig_it_;
}

DataTypeVector ExpressionRegistry::supported_types_ =
Expand Down
13 changes: 9 additions & 4 deletions cpp/src/gandiva/expression_registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,22 +36,27 @@ class FunctionRegistry;
/// data types and functions supported by Gandiva.
class GANDIVA_EXPORT ExpressionRegistry {
public:
using iterator = const NativeFunction*;
using native_func_iterator_type = const NativeFunction*;
using func_sig_iterator_type = const FunctionSignature*;
ExpressionRegistry();
~ExpressionRegistry();
static DataTypeVector supported_types() { return supported_types_; }
class GANDIVA_EXPORT FunctionSignatureIterator {
public:
explicit FunctionSignatureIterator(iterator it) : it_(it) {}
explicit FunctionSignatureIterator(native_func_iterator_type nf_it,
native_func_iterator_type nf_it_end_);
explicit FunctionSignatureIterator(func_sig_iterator_type fs_it);

bool operator!=(const FunctionSignatureIterator& func_sign_it);

FunctionSignature operator*();

iterator operator++(int);
func_sig_iterator_type operator++(int);

private:
iterator it_;
native_func_iterator_type native_func_it_;
const native_func_iterator_type native_func_it_end_;
func_sig_iterator_type func_sig_it_;
};
const FunctionSignatureIterator function_signature_begin();
const FunctionSignatureIterator function_signature_end() const;
Expand Down
9 changes: 5 additions & 4 deletions cpp/src/gandiva/expression_registry_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,11 @@ TEST_F(TestExpressionRegistry, VerifySupportedFunctions) {
functions.push_back((*iter));
}
for (auto& iter : registry_) {
auto function = iter.signature();
auto element = std::find(functions.begin(), functions.end(), function);
EXPECT_NE(element, functions.end())
<< "function " << iter.pc_name() << " missing in supported functions.\n";
for (auto& func_iter : iter.signatures()) {
auto element = std::find(functions.begin(), functions.end(), func_iter);
EXPECT_NE(element, functions.end()) << "function signature " << func_iter.ToString()
<< " missing in supported functions.\n";
}
}
}

Expand Down
8 changes: 7 additions & 1 deletion cpp/src/gandiva/function_registry.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ FunctionRegistry::iterator FunctionRegistry::end() const {
return &(*pc_registry_.end());
}

FunctionRegistry::iterator FunctionRegistry::back() const {
return &(pc_registry_.back());
}

std::vector<NativeFunction> FunctionRegistry::pc_registry_;

SignatureMap FunctionRegistry::pc_registry_map_ = InitPCMap();
Expand All @@ -62,7 +66,9 @@ SignatureMap FunctionRegistry::InitPCMap() {
pc_registry_.insert(std::end(pc_registry_), v6.begin(), v6.end());

for (auto& elem : pc_registry_) {
map.insert(std::make_pair(&(elem.signature()), &elem));
for (auto& func_signature : elem.signatures()) {
map.insert(std::make_pair(&(func_signature), &elem));
}
}

return map;
Expand Down
1 change: 1 addition & 0 deletions cpp/src/gandiva/function_registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ class GANDIVA_EXPORT FunctionRegistry {

iterator begin() const;
iterator end() const;
iterator back() const;

private:
static SignatureMap InitPCMap();
Expand Down
80 changes: 41 additions & 39 deletions cpp/src/gandiva/function_registry_arithmetic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,23 +20,24 @@

namespace gandiva {

#define BINARY_SYMMETRIC_FN(name) NUMERIC_TYPES(BINARY_SYMMETRIC_SAFE_NULL_IF_NULL, name)
#define BINARY_SYMMETRIC_FN(name, ALIASES) \
NUMERIC_TYPES(BINARY_SYMMETRIC_SAFE_NULL_IF_NULL, name, ALIASES)

#define BINARY_RELATIONAL_BOOL_FN(name) \
NUMERIC_BOOL_DATE_TYPES(BINARY_RELATIONAL_SAFE_NULL_IF_NULL, name)
#define BINARY_RELATIONAL_BOOL_FN(name, ALIASES) \
NUMERIC_BOOL_DATE_TYPES(BINARY_RELATIONAL_SAFE_NULL_IF_NULL, name, ALIASES)

#define BINARY_RELATIONAL_BOOL_DATE_FN(name) \
NUMERIC_DATE_TYPES(BINARY_RELATIONAL_SAFE_NULL_IF_NULL, name)
#define BINARY_RELATIONAL_BOOL_DATE_FN(name, ALIASES) \
NUMERIC_DATE_TYPES(BINARY_RELATIONAL_SAFE_NULL_IF_NULL, name, ALIASES)

#define UNARY_CAST_TO_FLOAT64(name) UNARY_SAFE_NULL_IF_NULL(castFLOAT8, name, float64)
#define UNARY_CAST_TO_FLOAT64(name) UNARY_SAFE_NULL_IF_NULL(castFLOAT8, {}, name, float64)

#define UNARY_CAST_TO_FLOAT32(name) UNARY_SAFE_NULL_IF_NULL(castFLOAT4, name, float32)
#define UNARY_CAST_TO_FLOAT32(name) UNARY_SAFE_NULL_IF_NULL(castFLOAT4, {}, name, float32)

std::vector<NativeFunction> GetArithmeticFunctionRegistry() {
static std::vector<NativeFunction> arithmetic_fn_registry_ = {
UNARY_SAFE_NULL_IF_NULL(not, boolean, boolean),
UNARY_SAFE_NULL_IF_NULL(castBIGINT, int32, int64),
UNARY_SAFE_NULL_IF_NULL(castBIGINT, decimal128, int64),
UNARY_SAFE_NULL_IF_NULL(not, {}, boolean, boolean),
UNARY_SAFE_NULL_IF_NULL(castBIGINT, {}, int32, int64),
UNARY_SAFE_NULL_IF_NULL(castBIGINT, {}, decimal128, int64),

// cast to float32
UNARY_CAST_TO_FLOAT32(int32), UNARY_CAST_TO_FLOAT32(int64),
Expand All @@ -46,40 +47,41 @@ std::vector<NativeFunction> GetArithmeticFunctionRegistry() {
UNARY_CAST_TO_FLOAT64(float32), UNARY_CAST_TO_FLOAT64(decimal128),

// cast to decimal
UNARY_SAFE_NULL_IF_NULL(castDECIMAL, int32, decimal128),
UNARY_SAFE_NULL_IF_NULL(castDECIMAL, int64, decimal128),
UNARY_SAFE_NULL_IF_NULL(castDECIMAL, float32, decimal128),
UNARY_SAFE_NULL_IF_NULL(castDECIMAL, float64, decimal128),
UNARY_SAFE_NULL_IF_NULL(castDECIMAL, decimal128, decimal128),
UNARY_UNSAFE_NULL_IF_NULL(castDECIMAL, utf8, decimal128),
UNARY_SAFE_NULL_IF_NULL(castDECIMAL, {}, int32, decimal128),
UNARY_SAFE_NULL_IF_NULL(castDECIMAL, {}, int64, decimal128),
UNARY_SAFE_NULL_IF_NULL(castDECIMAL, {}, float32, decimal128),
UNARY_SAFE_NULL_IF_NULL(castDECIMAL, {}, float64, decimal128),
UNARY_SAFE_NULL_IF_NULL(castDECIMAL, {}, decimal128, decimal128),
UNARY_UNSAFE_NULL_IF_NULL(castDECIMAL, {}, utf8, decimal128),

UNARY_SAFE_NULL_IF_NULL(castDATE, int64, date64),
UNARY_SAFE_NULL_IF_NULL(castDATE, {}, int64, date64),

// add/sub/multiply/divide/mod
BINARY_SYMMETRIC_FN(add), BINARY_SYMMETRIC_FN(subtract),
BINARY_SYMMETRIC_FN(multiply),
NUMERIC_TYPES(BINARY_SYMMETRIC_UNSAFE_NULL_IF_NULL, divide),
BINARY_GENERIC_SAFE_NULL_IF_NULL(mod, int64, int32, int32),
BINARY_GENERIC_SAFE_NULL_IF_NULL(mod, int64, int64, int64),
BINARY_SYMMETRIC_SAFE_NULL_IF_NULL(add, decimal128),
BINARY_SYMMETRIC_SAFE_NULL_IF_NULL(subtract, decimal128),
BINARY_SYMMETRIC_SAFE_NULL_IF_NULL(multiply, decimal128),
BINARY_SYMMETRIC_UNSAFE_NULL_IF_NULL(divide, decimal128),
BINARY_SYMMETRIC_UNSAFE_NULL_IF_NULL(mod, decimal128),
BINARY_SYMMETRIC_UNSAFE_NULL_IF_NULL(mod, float64),
BINARY_SYMMETRIC_FN(add, {}), BINARY_SYMMETRIC_FN(subtract, {}),
BINARY_SYMMETRIC_FN(multiply, {}),
NUMERIC_TYPES(BINARY_SYMMETRIC_UNSAFE_NULL_IF_NULL, divide, {"div"}),
BINARY_GENERIC_SAFE_NULL_IF_NULL(mod, {"modulo"}, int64, int32, int32),
BINARY_GENERIC_SAFE_NULL_IF_NULL(mod, {"modulo"}, int64, int64, int64),
BINARY_SYMMETRIC_SAFE_NULL_IF_NULL(add, {}, decimal128),
BINARY_SYMMETRIC_SAFE_NULL_IF_NULL(subtract, {}, decimal128),
BINARY_SYMMETRIC_SAFE_NULL_IF_NULL(multiply, {}, decimal128),
BINARY_SYMMETRIC_UNSAFE_NULL_IF_NULL(divide, {"div"}, decimal128),
BINARY_SYMMETRIC_UNSAFE_NULL_IF_NULL(mod, {"modulo"}, decimal128),
BINARY_SYMMETRIC_UNSAFE_NULL_IF_NULL(mod, {"modulo"}, float64),

// compare functions
BINARY_RELATIONAL_SAFE_NULL_IF_NULL(equal, decimal128),
BINARY_RELATIONAL_SAFE_NULL_IF_NULL(not_equal, decimal128),
BINARY_RELATIONAL_SAFE_NULL_IF_NULL(less_than, decimal128),
BINARY_RELATIONAL_SAFE_NULL_IF_NULL(less_than_or_equal_to, decimal128),
BINARY_RELATIONAL_SAFE_NULL_IF_NULL(greater_than, decimal128),
BINARY_RELATIONAL_SAFE_NULL_IF_NULL(greater_than_or_equal_to, decimal128),
BINARY_RELATIONAL_BOOL_FN(equal), BINARY_RELATIONAL_BOOL_FN(not_equal),
BINARY_RELATIONAL_BOOL_DATE_FN(less_than),
BINARY_RELATIONAL_BOOL_DATE_FN(less_than_or_equal_to),
BINARY_RELATIONAL_BOOL_DATE_FN(greater_than),
BINARY_RELATIONAL_BOOL_DATE_FN(greater_than_or_equal_to)};
BINARY_RELATIONAL_SAFE_NULL_IF_NULL(equal, {}, decimal128),
BINARY_RELATIONAL_SAFE_NULL_IF_NULL(not_equal, {}, decimal128),
BINARY_RELATIONAL_SAFE_NULL_IF_NULL(less_than, {}, decimal128),
BINARY_RELATIONAL_SAFE_NULL_IF_NULL(less_than_or_equal_to, {}, decimal128),
BINARY_RELATIONAL_SAFE_NULL_IF_NULL(greater_than, {}, decimal128),
BINARY_RELATIONAL_SAFE_NULL_IF_NULL(greater_than_or_equal_to, {}, decimal128),
BINARY_RELATIONAL_BOOL_FN(equal, ({"eq", "same"})),
BINARY_RELATIONAL_BOOL_FN(not_equal, {}),
BINARY_RELATIONAL_BOOL_DATE_FN(less_than, {}),
BINARY_RELATIONAL_BOOL_DATE_FN(less_than_or_equal_to, {}),
BINARY_RELATIONAL_BOOL_DATE_FN(greater_than, {}),
BINARY_RELATIONAL_BOOL_DATE_FN(greater_than_or_equal_to, {})};

return arithmetic_fn_registry_;
}
Expand Down
Loading