Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions cpp/src/gandiva/function_registry_string.cc
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,22 @@ std::vector<NativeFunction> GetStringFunctionRegistry() {
kResultNullIfNull, "gdv_fn_castFLOAT8_utf8",
NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors),

// Numeric casts that take a binary() argument. Each entry maps to the
// gdv_fn_cast*_varbinary stub of the matching result type and carries the same
// kNeedsContext | kCanReturnErrors flags as the utf8 cast entry directly above.
NativeFunction("castINT", {}, DataTypeVector{binary()}, int32(), kResultNullIfNull,
"gdv_fn_castINT_varbinary",
NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors),

NativeFunction("castBIGINT", {}, DataTypeVector{binary()}, int64(),
kResultNullIfNull, "gdv_fn_castBIGINT_varbinary",
NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors),

NativeFunction("castFLOAT4", {}, DataTypeVector{binary()}, float32(),
kResultNullIfNull, "gdv_fn_castFLOAT4_varbinary",
NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors),

NativeFunction("castFLOAT8", {}, DataTypeVector{binary()}, float64(),
kResultNullIfNull, "gdv_fn_castFLOAT8_varbinary",
NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors),

NativeFunction("castVARCHAR", {}, DataTypeVector{boolean(), int64()}, utf8(),
kResultNullIfNull, "castVARCHAR_bool_int64",
NativeFunction::kNeedsContext),
Expand Down
52 changes: 43 additions & 9 deletions cpp/src/gandiva/gdv_function_stubs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
#include <string>
#include <vector>

#include "arrow/util/formatting.h"
#include "arrow/util/value_parsing.h"
#include "gandiva/engine.h"
#include "gandiva/exported_funcs.h"
Expand Down Expand Up @@ -276,10 +275,10 @@ char* gdv_fn_dec_to_string(int64_t context, int64_t x_high, uint64_t x_low,
return ret;
}

#define CAST_NUMERIC_FROM_STRING(OUT_TYPE, ARROW_TYPE, TYPE_NAME) \
#define CAST_NUMERIC_STRING(OUT_TYPE, ARROW_TYPE, TYPE_NAME, INNER_TYPE) \
GANDIVA_EXPORT \
OUT_TYPE gdv_fn_cast##TYPE_NAME##_utf8(int64_t context, const char* data, \
int32_t len) { \
OUT_TYPE gdv_fn_cast##TYPE_NAME##_##INNER_TYPE(int64_t context, const char* data, \
int32_t len) { \
OUT_TYPE val = 0; \
/* trim leading and trailing spaces */ \
int32_t trimmed_len; \
Expand All @@ -300,12 +299,17 @@ char* gdv_fn_dec_to_string(int64_t context, int64_t x_high, uint64_t x_low,
return val; \
}

CAST_NUMERIC_FROM_STRING(int32_t, arrow::Int32Type, INT)
CAST_NUMERIC_FROM_STRING(int64_t, arrow::Int64Type, BIGINT)
CAST_NUMERIC_FROM_STRING(float, arrow::FloatType, FLOAT4)
CAST_NUMERIC_FROM_STRING(double, arrow::DoubleType, FLOAT8)
CAST_NUMERIC_STRING(int32_t, arrow::Int32Type, INT, utf8)
CAST_NUMERIC_STRING(int64_t, arrow::Int64Type, BIGINT, utf8)
CAST_NUMERIC_STRING(float, arrow::FloatType, FLOAT4, utf8)
CAST_NUMERIC_STRING(double, arrow::DoubleType, FLOAT8, utf8)

#undef CAST_NUMERIC_FROM_STRING
// Same macro body exported a second time under the _varbinary suffix: the last macro
// argument is pasted into the function name (gdv_fn_cast<TYPE_NAME>_<INNER_TYPE>).
CAST_NUMERIC_STRING(int32_t, arrow::Int32Type, INT, varbinary)
CAST_NUMERIC_STRING(int64_t, arrow::Int64Type, BIGINT, varbinary)
CAST_NUMERIC_STRING(float, arrow::FloatType, FLOAT4, varbinary)
CAST_NUMERIC_STRING(double, arrow::DoubleType, FLOAT8, varbinary)

#undef CAST_NUMERIC_STRING

#define GDV_FN_CAST_VARCHAR_INTEGER(IN_TYPE, ARROW_TYPE) \
GANDIVA_EXPORT \
Expand Down Expand Up @@ -590,6 +594,36 @@ void ExportedStubFunctions::AddMappings(Engine* engine) const {
"gdv_fn_castVARCHAR_float64_int64", types->i8_ptr_type() /*return_type*/, args,
reinterpret_cast<void*>(gdv_fn_castVARCHAR_float64_int64));

// gdv_fn_castINT_varbinary
args = {types->i64_type(), // int64_t context_ptr
types->i8_ptr_type(), // const char* data
types->i32_type()}; // int32_t len

engine->AddGlobalMappingForFunc("gdv_fn_castINT_varbinary", types->i32_type(), args,
reinterpret_cast<void*>(gdv_fn_castINT_varbinary));

// gdv_fn_castBIGINT_varbinary
args = {types->i64_type(), // int64_t context_ptr
types->i8_ptr_type(), // const char* data
types->i32_type()}; // int32_t len

engine->AddGlobalMappingForFunc("gdv_fn_castBIGINT_varbinary", types->i64_type(), args,
reinterpret_cast<void*>(gdv_fn_castBIGINT_varbinary));

// gdv_fn_castFLOAT4_varbinary
args = {types->i64_type(), // int64_t context_ptr
types->i8_ptr_type(), // const char* data
types->i32_type()}; // int32_t len

engine->AddGlobalMappingForFunc("gdv_fn_castFLOAT4_varbinary", types->float_type(),
args,
reinterpret_cast<void*>(gdv_fn_castFLOAT4_varbinary));

// gdv_fn_castFLOAT8_varbinary
args = {types->i64_type(), // int64_t context_ptr
types->i8_ptr_type(), // const char* data
types->i32_type()}; // int32_t len

engine->AddGlobalMappingForFunc("gdv_fn_castFLOAT8_varbinary", types->double_type(),
args,
reinterpret_cast<void*>(gdv_fn_castFLOAT8_varbinary));

// gdv_fn_sha1_int8
args = {
types->i64_type(), // context
Expand Down
12 changes: 12 additions & 0 deletions cpp/src/gandiva/gdv_function_stubs.h
Original file line number Diff line number Diff line change
Expand Up @@ -108,4 +108,16 @@ const char* gdv_fn_castVARCHAR_float32_int64(int64_t context, float value, int64
GANDIVA_EXPORT
const char* gdv_fn_castVARCHAR_float64_int64(int64_t context, double value, int64_t len,
int32_t* out_len);

// Numeric casts from varbinary input. `context` is the address of a
// gandiva::ExecutionContext passed as an int64_t, `data`/`data_len` describe the bytes
// to parse. Leading and trailing spaces are trimmed before parsing; when the remaining
// text cannot be parsed as the target type an error message is recorded on the context.

// Parses `data` as a 32-bit signed integer.
GANDIVA_EXPORT
int32_t gdv_fn_castINT_varbinary(int64_t context, const char* data, int32_t data_len);

// Parses `data` as a 64-bit signed integer.
GANDIVA_EXPORT
int64_t gdv_fn_castBIGINT_varbinary(int64_t context, const char* data, int32_t data_len);

// Parses `data` as a single-precision float.
GANDIVA_EXPORT
float gdv_fn_castFLOAT4_varbinary(int64_t context, const char* data, int32_t data_len);

// Parses `data` as a double-precision float.
GANDIVA_EXPORT
double gdv_fn_castFLOAT8_varbinary(int64_t context, const char* data, int32_t data_len);
}
136 changes: 136 additions & 0 deletions cpp/src/gandiva/gdv_function_stubs_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -290,4 +290,140 @@ TEST(TestGdvFnStubs, TestCastVARCHARFromDouble) {
EXPECT_FALSE(ctx.has_error());
}

TEST(TestGdvFnStubs, TestCastVarbinaryINT) {
  gandiva::ExecutionContext ctx;
  const int64_t ctx_ptr = reinterpret_cast<int64_t>(&ctx);

  // Runs the cast on input that must be rejected, checks the error recorded on the
  // context, then resets the context so the next case starts clean.
  auto expect_cast_error = [&ctx, ctx_ptr](const char* input, int32_t input_len,
                                           const char* expected_error) {
    gdv_fn_castINT_varbinary(ctx_ptr, input, input_len);
    EXPECT_THAT(ctx.get_error(), ::testing::HasSubstr(expected_error));
    ctx.Reset();
  };

  // Accepted input: sign, leading zeros, both int32 limits and surrounding spaces.
  EXPECT_EQ(gdv_fn_castINT_varbinary(ctx_ptr, "-45", 3), -45);
  EXPECT_EQ(gdv_fn_castINT_varbinary(ctx_ptr, "0", 1), 0);
  EXPECT_EQ(gdv_fn_castINT_varbinary(ctx_ptr, "2147483647", 10), 2147483647);
  EXPECT_EQ(gdv_fn_castINT_varbinary(ctx_ptr, "02147483647", 11), 2147483647);
  EXPECT_EQ(gdv_fn_castINT_varbinary(ctx_ptr, "-2147483648", 11), -2147483648LL);
  EXPECT_EQ(gdv_fn_castINT_varbinary(ctx_ptr, "-02147483648", 12), -2147483648LL);
  EXPECT_EQ(gdv_fn_castINT_varbinary(ctx_ptr, " 12 ", 4), 12);

  // Rejected input: out of range on either side, non-integer text, empty, lone sign.
  expect_cast_error("2147483648", 10, "Failed to cast the string 2147483648 to int32");
  expect_cast_error("-2147483649", 11,
                    "Failed to cast the string -2147483649 to int32");
  expect_cast_error("12.34", 5, "Failed to cast the string 12.34 to int32");
  expect_cast_error("abc", 3, "Failed to cast the string abc to int32");
  expect_cast_error("", 0, "Failed to cast the string to int32");
  expect_cast_error("-", 1, "Failed to cast the string - to int32");
}

TEST(TestGdvFnStubs, TestCastVarbinaryBIGINT) {
  gandiva::ExecutionContext ctx;
  const int64_t ctx_ptr = reinterpret_cast<int64_t>(&ctx);

  // Runs the cast on input that must be rejected, checks the error recorded on the
  // context, then resets the context so the next case starts clean.
  auto expect_cast_error = [&ctx, ctx_ptr](const char* input, int32_t input_len,
                                           const char* expected_error) {
    gdv_fn_castBIGINT_varbinary(ctx_ptr, input, input_len);
    EXPECT_THAT(ctx.get_error(), ::testing::HasSubstr(expected_error));
    ctx.Reset();
  };

  // Accepted input: sign, leading zeros, both int64 limits and surrounding spaces.
  EXPECT_EQ(gdv_fn_castBIGINT_varbinary(ctx_ptr, "-45", 3), -45);
  EXPECT_EQ(gdv_fn_castBIGINT_varbinary(ctx_ptr, "0", 1), 0);
  EXPECT_EQ(gdv_fn_castBIGINT_varbinary(ctx_ptr, "9223372036854775807", 19),
            9223372036854775807LL);
  EXPECT_EQ(gdv_fn_castBIGINT_varbinary(ctx_ptr, "09223372036854775807", 20),
            9223372036854775807LL);
  EXPECT_EQ(gdv_fn_castBIGINT_varbinary(ctx_ptr, "-9223372036854775808", 20),
            -9223372036854775807LL - 1);
  EXPECT_EQ(gdv_fn_castBIGINT_varbinary(ctx_ptr, "-009223372036854775808", 22),
            -9223372036854775807LL - 1);
  EXPECT_EQ(gdv_fn_castBIGINT_varbinary(ctx_ptr, " 12 ", 4), 12);

  // Rejected input: out of range on either side, non-integer text, empty, lone sign.
  expect_cast_error("9223372036854775808", 19,
                    "Failed to cast the string 9223372036854775808 to int64");
  expect_cast_error("-9223372036854775809", 20,
                    "Failed to cast the string -9223372036854775809 to int64");
  expect_cast_error("12.34", 5, "Failed to cast the string 12.34 to int64");
  expect_cast_error("abc", 3, "Failed to cast the string abc to int64");
  expect_cast_error("", 0, "Failed to cast the string to int64");
  expect_cast_error("-", 1, "Failed to cast the string - to int64");
}

TEST(TestGdvFnStubs, TestCastVarbinaryFloat4) {
  gandiva::ExecutionContext ctx;
  const int64_t ctx_ptr = reinterpret_cast<int64_t>(&ctx);

  // Runs the cast on input that must be rejected, checks the error recorded on the
  // context, then resets the context so the next case starts clean.
  auto expect_cast_error = [&ctx, ctx_ptr](const char* input, int32_t input_len,
                                           const char* expected_error) {
    gdv_fn_castFLOAT4_varbinary(ctx_ptr, input, input_len);
    EXPECT_THAT(ctx.get_error(), ::testing::HasSubstr(expected_error));
    ctx.Reset();
  };

  // Accepted input: negative fraction, integers and a value padded with spaces.
  EXPECT_EQ(gdv_fn_castFLOAT4_varbinary(ctx_ptr, "-45.34", 6), -45.34f);
  EXPECT_EQ(gdv_fn_castFLOAT4_varbinary(ctx_ptr, "0", 1), 0.0f);
  EXPECT_EQ(gdv_fn_castFLOAT4_varbinary(ctx_ptr, "5", 1), 5.0f);
  EXPECT_EQ(gdv_fn_castFLOAT4_varbinary(ctx_ptr, " 3.4 ", 5), 3.4f);

  // Rejected input: empty text and a lone exponent marker.
  expect_cast_error("", 0, "Failed to cast the string to float");
  expect_cast_error("e", 1, "Failed to cast the string e to float");
}

TEST(TestGdvFnStubs, TestCastVarbinaryFloat8) {
  gandiva::ExecutionContext ctx;
  const int64_t ctx_ptr = reinterpret_cast<int64_t>(&ctx);

  // Runs the cast on input that must be rejected, checks the error recorded on the
  // context, then resets the context so the next case starts clean.
  auto expect_cast_error = [&ctx, ctx_ptr](const char* input, int32_t input_len,
                                           const char* expected_error) {
    gdv_fn_castFLOAT8_varbinary(ctx_ptr, input, input_len);
    EXPECT_THAT(ctx.get_error(), ::testing::HasSubstr(expected_error));
    ctx.Reset();
  };

  // Accepted input: negative fraction, integers and a value padded with spaces.
  EXPECT_EQ(gdv_fn_castFLOAT8_varbinary(ctx_ptr, "-45.34", 6), -45.34);
  EXPECT_EQ(gdv_fn_castFLOAT8_varbinary(ctx_ptr, "0", 1), 0.0);
  EXPECT_EQ(gdv_fn_castFLOAT8_varbinary(ctx_ptr, "5", 1), 5.0);
  EXPECT_EQ(gdv_fn_castFLOAT8_varbinary(ctx_ptr, " 3.4 ", 5), 3.4);

  // Rejected input: empty text and a lone exponent marker.
  expect_cast_error("", 0, "Failed to cast the string to double");
  expect_cast_error("e", 1, "Failed to cast the string e to double");
}

} // namespace gandiva
53 changes: 53 additions & 0 deletions cpp/src/gandiva/tests/projector_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -908,6 +908,59 @@ TEST_F(TestProjector, TestCastFunction) {
EXPECT_ARROW_ARRAY_EQUALS(out_int8, outputs.at(3));
}

// Exercises castFLOAT4 / castFLOAT8 / castINT / castBIGINT on a binary column through
// the projector, including a row whose validity bit is false. The cast functions should
// not run for that row, and the corresponding output slot is expected to be null.
TEST_F(TestProjector, TestCastVarbinaryFunction) {
  auto field0 = field("f0", arrow::binary());
  auto schema = arrow::schema({field0});

  // output fields
  auto res_float4 = field("res_float4", arrow::float32());
  auto res_float8 = field("res_float8", arrow::float64());
  auto res_int4 = field("res_int4", arrow::int32());
  auto res_int8 = field("res_int8", arrow::int64());

  // Build expression
  auto cast_expr_float4 =
      TreeExprBuilder::MakeExpression("castFLOAT4", {field0}, res_float4);
  auto cast_expr_float8 =
      TreeExprBuilder::MakeExpression("castFLOAT8", {field0}, res_float8);
  auto cast_expr_int4 = TreeExprBuilder::MakeExpression("castINT", {field0}, res_int4);
  auto cast_expr_int8 = TreeExprBuilder::MakeExpression("castBIGINT", {field0}, res_int8);

  std::shared_ptr<Projector> projector;

  auto status = Projector::Make(
      schema, {cast_expr_float4, cast_expr_float8, cast_expr_int4, cast_expr_int8},
      TestConfiguration(), &projector);
  // Fatal assertion: `projector` is dereferenced below, so the test must stop here if
  // the projector could not be built.
  ASSERT_TRUE(status.ok());

  // Create a row-batch with some sample data
  int num_records = 4;

  // The last row is marked invalid. The cast functions are registered with
  // NativeFunction::kCanReturnErrors, so they should not be evaluated for that row.
  auto array0 = MakeArrowArrayBinary({"1", "2", "3", "4"}, {true, true, true, false});
  auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0});

  auto out_float4 = MakeArrowArrayFloat32({1, 2, 3, 0}, {true, true, true, false});
  auto out_float8 = MakeArrowArrayFloat64({1, 2, 3, 0}, {true, true, true, false});
  auto out_int4 = MakeArrowArrayInt32({1, 2, 3, 0}, {true, true, true, false});
  auto out_int8 = MakeArrowArrayInt64({1, 2, 3, 0}, {true, true, true, false});

  arrow::ArrayVector outputs;

  // Evaluate expression
  status = projector->Evaluate(*in_batch, pool_, &outputs);
  // Fatal assertion: the indexed accesses into `outputs` below are only meaningful
  // when evaluation succeeded.
  ASSERT_TRUE(status.ok());

  EXPECT_ARROW_ARRAY_EQUALS(out_float4, outputs.at(0));
  EXPECT_ARROW_ARRAY_EQUALS(out_float8, outputs.at(1));
  EXPECT_ARROW_ARRAY_EQUALS(out_int4, outputs.at(2));
  EXPECT_ARROW_ARRAY_EQUALS(out_int8, outputs.at(3));
}

TEST_F(TestProjector, TestToDate) {
// schema for input fields
auto field0 = field("f0", arrow::utf8());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1950,6 +1950,59 @@ public void testCastFloat() throws Exception {
releaseValueVectors(output);
}

@Test
public void testCastFloatVarbinary() throws Exception {
  // Expression under test: castFLOAT8(input) where input is a nullable binary column.
  final Field inputField = Field.nullable("input", new ArrowType.Binary());
  final TreeNode inputNode = TreeBuilder.makeField(inputField);
  final TreeNode castNode =
      TreeBuilder.makeFunction("castFLOAT8", Lists.newArrayList(inputNode), float64);
  final Field outputField = Field.nullable("result", float64);
  final List<ExpressionTree> exprs =
      Lists.newArrayList(TreeBuilder.makeExpression(castNode, outputField));

  final Schema schema = new Schema(Lists.newArrayList(inputField));
  final Projector eval = Projector.make(schema, exprs);

  // Five rows, all valid: every bit of the single validity byte is set.
  final int numRows = 5;
  final byte[] validity = new byte[] {(byte) 0xFF};
  final String[] values = new String[] {"2.3", "-11.11", "0", "111", "12345.67"};
  final double[] expValues = new double[] {2.3, -11.11, 0, 111, 12345.67};

  final ArrowBuf bufValidity = buf(validity);
  final List<ArrowBuf> bufData = stringBufs(values);
  final ArrowFieldNode fieldNode = new ArrowFieldNode(numRows, 0);
  final ArrowRecordBatch batch =
      new ArrowRecordBatch(
          numRows,
          Lists.newArrayList(fieldNode),
          Lists.newArrayList(bufValidity, bufData.get(0), bufData.get(1)));

  // One pre-allocated output vector per expression.
  final List<ValueVector> output = new ArrayList<>();
  while (output.size() < exprs.size()) {
    final Float8Vector resultVector = new Float8Vector(EMPTY_SCHEMA_PATH, allocator);
    resultVector.allocateNew(numRows);
    output.add(resultVector);
  }

  eval.evaluate(batch, output);
  eval.close();

  // Every row must be non-null and equal to the expected double.
  for (int col = 0; col < output.size(); col++) {
    final Float8Vector resultVector = (Float8Vector) output.get(col);
    for (int row = 0; row < numRows; row++) {
      assertFalse(resultVector.isNull(row));
      assertTrue(expValues[row] == resultVector.get(row));
    }
  }

  releaseRecordBatch(batch);
  releaseValueVectors(output);
}

@Test(expected = GandivaException.class)
public void testCastFloatInvalidValue() throws Exception {
Field inField = Field.nullable("input", new ArrowType.Utf8());
Expand Down