diff --git a/cpp/src/gandiva/function_registry_string.cc b/cpp/src/gandiva/function_registry_string.cc index d1f97cdb3e8..1acdc910e43 100644 --- a/cpp/src/gandiva/function_registry_string.cc +++ b/cpp/src/gandiva/function_registry_string.cc @@ -84,6 +84,22 @@ std::vector GetStringFunctionRegistry() { kResultNullIfNull, "gdv_fn_castFLOAT8_utf8", NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors), + NativeFunction("castINT", {}, DataTypeVector{binary()}, int32(), kResultNullIfNull, + "gdv_fn_castINT_varbinary", + NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors), + + NativeFunction("castBIGINT", {}, DataTypeVector{binary()}, int64(), + kResultNullIfNull, "gdv_fn_castBIGINT_varbinary", + NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors), + + NativeFunction("castFLOAT4", {}, DataTypeVector{binary()}, float32(), + kResultNullIfNull, "gdv_fn_castFLOAT4_varbinary", + NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors), + + NativeFunction("castFLOAT8", {}, DataTypeVector{binary()}, float64(), + kResultNullIfNull, "gdv_fn_castFLOAT8_varbinary", + NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors), + NativeFunction("castVARCHAR", {}, DataTypeVector{boolean(), int64()}, utf8(), kResultNullIfNull, "castVARCHAR_bool_int64", NativeFunction::kNeedsContext), diff --git a/cpp/src/gandiva/gdv_function_stubs.cc b/cpp/src/gandiva/gdv_function_stubs.cc index 832eebcaa1a..00542d875cd 100644 --- a/cpp/src/gandiva/gdv_function_stubs.cc +++ b/cpp/src/gandiva/gdv_function_stubs.cc @@ -20,7 +20,6 @@ #include #include -#include "arrow/util/formatting.h" #include "arrow/util/value_parsing.h" #include "gandiva/engine.h" #include "gandiva/exported_funcs.h" @@ -276,10 +275,10 @@ char* gdv_fn_dec_to_string(int64_t context, int64_t x_high, uint64_t x_low, return ret; } -#define CAST_NUMERIC_FROM_STRING(OUT_TYPE, ARROW_TYPE, TYPE_NAME) \ +#define CAST_NUMERIC_STRING(OUT_TYPE, ARROW_TYPE, TYPE_NAME, INNER_TYPE) \ 
GANDIVA_EXPORT \ - OUT_TYPE gdv_fn_cast##TYPE_NAME##_utf8(int64_t context, const char* data, \ - int32_t len) { \ + OUT_TYPE gdv_fn_cast##TYPE_NAME##_##INNER_TYPE(int64_t context, const char* data, \ + int32_t len) { \ OUT_TYPE val = 0; \ /* trim leading and trailing spaces */ \ int32_t trimmed_len; \ @@ -300,12 +299,17 @@ char* gdv_fn_dec_to_string(int64_t context, int64_t x_high, uint64_t x_low, return val; \ } -CAST_NUMERIC_FROM_STRING(int32_t, arrow::Int32Type, INT) -CAST_NUMERIC_FROM_STRING(int64_t, arrow::Int64Type, BIGINT) -CAST_NUMERIC_FROM_STRING(float, arrow::FloatType, FLOAT4) -CAST_NUMERIC_FROM_STRING(double, arrow::DoubleType, FLOAT8) +CAST_NUMERIC_STRING(int32_t, arrow::Int32Type, INT, utf8) +CAST_NUMERIC_STRING(int64_t, arrow::Int64Type, BIGINT, utf8) +CAST_NUMERIC_STRING(float, arrow::FloatType, FLOAT4, utf8) +CAST_NUMERIC_STRING(double, arrow::DoubleType, FLOAT8, utf8) -#undef CAST_NUMERIC_FROM_STRING +CAST_NUMERIC_STRING(int32_t, arrow::Int32Type, INT, varbinary) +CAST_NUMERIC_STRING(int64_t, arrow::Int64Type, BIGINT, varbinary) +CAST_NUMERIC_STRING(float, arrow::FloatType, FLOAT4, varbinary) +CAST_NUMERIC_STRING(double, arrow::DoubleType, FLOAT8, varbinary) + +#undef CAST_NUMERIC_STRING #define GDV_FN_CAST_VARCHAR_INTEGER(IN_TYPE, ARROW_TYPE) \ GANDIVA_EXPORT \ @@ -590,6 +594,36 @@ void ExportedStubFunctions::AddMappings(Engine* engine) const { "gdv_fn_castVARCHAR_float64_int64", types->i8_ptr_type() /*return_type*/, args, reinterpret_cast(gdv_fn_castVARCHAR_float64_int64)); + args = {types->i64_type(), // int64_t context_ptr + types->i8_ptr_type(), // const char* data + types->i32_type()}; // int32_t len + + engine->AddGlobalMappingForFunc("gdv_fn_castINT_varbinary", types->i32_type(), args, + reinterpret_cast(gdv_fn_castINT_varbinary)); + + args = {types->i64_type(), // int64_t context_ptr + types->i8_ptr_type(), // const char* data + types->i32_type()}; // int32_t len + + engine->AddGlobalMappingForFunc("gdv_fn_castBIGINT_varbinary", 
types->i64_type(), args, + reinterpret_cast(gdv_fn_castBIGINT_varbinary)); + + args = {types->i64_type(), // int64_t context_ptr + types->i8_ptr_type(), // const char* data + types->i32_type()}; // int32_t len + + engine->AddGlobalMappingForFunc("gdv_fn_castFLOAT4_varbinary", types->float_type(), + args, + reinterpret_cast(gdv_fn_castFLOAT4_varbinary)); + + args = {types->i64_type(), // int64_t context_ptr + types->i8_ptr_type(), // const char* data + types->i32_type()}; // int32_t len + + engine->AddGlobalMappingForFunc("gdv_fn_castFLOAT8_varbinary", types->double_type(), + args, + reinterpret_cast(gdv_fn_castFLOAT8_varbinary)); + // gdv_fn_sha1_int8 args = { types->i64_type(), // context diff --git a/cpp/src/gandiva/gdv_function_stubs.h b/cpp/src/gandiva/gdv_function_stubs.h index 0a6cd70ca7c..e66c275d609 100644 --- a/cpp/src/gandiva/gdv_function_stubs.h +++ b/cpp/src/gandiva/gdv_function_stubs.h @@ -108,4 +108,16 @@ const char* gdv_fn_castVARCHAR_float32_int64(int64_t context, float value, int64 GANDIVA_EXPORT const char* gdv_fn_castVARCHAR_float64_int64(int64_t context, double value, int64_t len, int32_t* out_len); + +GANDIVA_EXPORT +int32_t gdv_fn_castINT_varbinary(int64_t context, const char* data, int32_t data_len); + +GANDIVA_EXPORT +int64_t gdv_fn_castBIGINT_varbinary(int64_t context, const char* data, int32_t data_len); + +GANDIVA_EXPORT +float gdv_fn_castFLOAT4_varbinary(int64_t context, const char* data, int32_t data_len); + +GANDIVA_EXPORT +double gdv_fn_castFLOAT8_varbinary(int64_t context, const char* data, int32_t data_len); } diff --git a/cpp/src/gandiva/gdv_function_stubs_test.cc b/cpp/src/gandiva/gdv_function_stubs_test.cc index 8f44ce27982..1bce4d10d2f 100644 --- a/cpp/src/gandiva/gdv_function_stubs_test.cc +++ b/cpp/src/gandiva/gdv_function_stubs_test.cc @@ -290,4 +290,140 @@ TEST(TestGdvFnStubs, TestCastVARCHARFromDouble) { EXPECT_FALSE(ctx.has_error()); } +TEST(TestGdvFnStubs, TestCastVarbinaryINT) { + gandiva::ExecutionContext ctx; + + 
int64_t ctx_ptr = reinterpret_cast(&ctx); + + EXPECT_EQ(gdv_fn_castINT_varbinary(ctx_ptr, "-45", 3), -45); + EXPECT_EQ(gdv_fn_castINT_varbinary(ctx_ptr, "0", 1), 0); + EXPECT_EQ(gdv_fn_castINT_varbinary(ctx_ptr, "2147483647", 10), 2147483647); + EXPECT_EQ(gdv_fn_castINT_varbinary(ctx_ptr, "02147483647", 11), 2147483647); + EXPECT_EQ(gdv_fn_castINT_varbinary(ctx_ptr, "-2147483648", 11), -2147483648LL); + EXPECT_EQ(gdv_fn_castINT_varbinary(ctx_ptr, "-02147483648", 12), -2147483648LL); + EXPECT_EQ(gdv_fn_castINT_varbinary(ctx_ptr, " 12 ", 4), 12); + + gdv_fn_castINT_varbinary(ctx_ptr, "2147483648", 10); + EXPECT_THAT(ctx.get_error(), + ::testing::HasSubstr("Failed to cast the string 2147483648 to int32")); + ctx.Reset(); + + gdv_fn_castINT_varbinary(ctx_ptr, "-2147483649", 11); + EXPECT_THAT(ctx.get_error(), + ::testing::HasSubstr("Failed to cast the string -2147483649 to int32")); + ctx.Reset(); + + gdv_fn_castINT_varbinary(ctx_ptr, "12.34", 5); + EXPECT_THAT(ctx.get_error(), + ::testing::HasSubstr("Failed to cast the string 12.34 to int32")); + ctx.Reset(); + + gdv_fn_castINT_varbinary(ctx_ptr, "abc", 3); + EXPECT_THAT(ctx.get_error(), + ::testing::HasSubstr("Failed to cast the string abc to int32")); + ctx.Reset(); + + gdv_fn_castINT_varbinary(ctx_ptr, "", 0); + EXPECT_THAT(ctx.get_error(), + ::testing::HasSubstr("Failed to cast the string to int32")); + ctx.Reset(); + + gdv_fn_castINT_varbinary(ctx_ptr, "-", 1); + EXPECT_THAT(ctx.get_error(), + ::testing::HasSubstr("Failed to cast the string - to int32")); + ctx.Reset(); +} + +TEST(TestGdvFnStubs, TestCastVarbinaryBIGINT) { + gandiva::ExecutionContext ctx; + + int64_t ctx_ptr = reinterpret_cast(&ctx); + + EXPECT_EQ(gdv_fn_castBIGINT_varbinary(ctx_ptr, "-45", 3), -45); + EXPECT_EQ(gdv_fn_castBIGINT_varbinary(ctx_ptr, "0", 1), 0); + EXPECT_EQ(gdv_fn_castBIGINT_varbinary(ctx_ptr, "9223372036854775807", 19), + 9223372036854775807LL); + EXPECT_EQ(gdv_fn_castBIGINT_varbinary(ctx_ptr, "09223372036854775807", 20), + 
9223372036854775807LL); + EXPECT_EQ(gdv_fn_castBIGINT_varbinary(ctx_ptr, "-9223372036854775808", 20), + -9223372036854775807LL - 1); + EXPECT_EQ(gdv_fn_castBIGINT_varbinary(ctx_ptr, "-009223372036854775808", 22), + -9223372036854775807LL - 1); + EXPECT_EQ(gdv_fn_castBIGINT_varbinary(ctx_ptr, " 12 ", 4), 12); + + gdv_fn_castBIGINT_varbinary(ctx_ptr, "9223372036854775808", 19); + EXPECT_THAT( + ctx.get_error(), + ::testing::HasSubstr("Failed to cast the string 9223372036854775808 to int64")); + ctx.Reset(); + + gdv_fn_castBIGINT_varbinary(ctx_ptr, "-9223372036854775809", 20); + EXPECT_THAT( + ctx.get_error(), + ::testing::HasSubstr("Failed to cast the string -9223372036854775809 to int64")); + ctx.Reset(); + + gdv_fn_castBIGINT_varbinary(ctx_ptr, "12.34", 5); + EXPECT_THAT(ctx.get_error(), + ::testing::HasSubstr("Failed to cast the string 12.34 to int64")); + ctx.Reset(); + + gdv_fn_castBIGINT_varbinary(ctx_ptr, "abc", 3); + EXPECT_THAT(ctx.get_error(), + ::testing::HasSubstr("Failed to cast the string abc to int64")); + ctx.Reset(); + + gdv_fn_castBIGINT_varbinary(ctx_ptr, "", 0); + EXPECT_THAT(ctx.get_error(), + ::testing::HasSubstr("Failed to cast the string to int64")); + ctx.Reset(); + + gdv_fn_castBIGINT_varbinary(ctx_ptr, "-", 1); + EXPECT_THAT(ctx.get_error(), + ::testing::HasSubstr("Failed to cast the string - to int64")); + ctx.Reset(); +} + +TEST(TestGdvFnStubs, TestCastVarbinaryFloat4) { + gandiva::ExecutionContext ctx; + + int64_t ctx_ptr = reinterpret_cast(&ctx); + + EXPECT_EQ(gdv_fn_castFLOAT4_varbinary(ctx_ptr, "-45.34", 6), -45.34f); + EXPECT_EQ(gdv_fn_castFLOAT4_varbinary(ctx_ptr, "0", 1), 0.0f); + EXPECT_EQ(gdv_fn_castFLOAT4_varbinary(ctx_ptr, "5", 1), 5.0f); + EXPECT_EQ(gdv_fn_castFLOAT4_varbinary(ctx_ptr, " 3.4 ", 5), 3.4f); + + gdv_fn_castFLOAT4_varbinary(ctx_ptr, "", 0); + EXPECT_THAT(ctx.get_error(), + ::testing::HasSubstr("Failed to cast the string to float")); + ctx.Reset(); + + gdv_fn_castFLOAT4_varbinary(ctx_ptr, "e", 1); + 
EXPECT_THAT(ctx.get_error(), + ::testing::HasSubstr("Failed to cast the string e to float")); + ctx.Reset(); +} + +TEST(TestGdvFnStubs, TestCastVarbinaryFloat8) { + gandiva::ExecutionContext ctx; + + int64_t ctx_ptr = reinterpret_cast(&ctx); + + EXPECT_EQ(gdv_fn_castFLOAT8_varbinary(ctx_ptr, "-45.34", 6), -45.34); + EXPECT_EQ(gdv_fn_castFLOAT8_varbinary(ctx_ptr, "0", 1), 0.0); + EXPECT_EQ(gdv_fn_castFLOAT8_varbinary(ctx_ptr, "5", 1), 5.0); + EXPECT_EQ(gdv_fn_castFLOAT8_varbinary(ctx_ptr, " 3.4 ", 5), 3.4); + + gdv_fn_castFLOAT8_varbinary(ctx_ptr, "", 0); + EXPECT_THAT(ctx.get_error(), + ::testing::HasSubstr("Failed to cast the string to double")); + ctx.Reset(); + + gdv_fn_castFLOAT8_varbinary(ctx_ptr, "e", 1); + EXPECT_THAT(ctx.get_error(), + ::testing::HasSubstr("Failed to cast the string e to double")); + ctx.Reset(); +} + } // namespace gandiva diff --git a/cpp/src/gandiva/tests/projector_test.cc b/cpp/src/gandiva/tests/projector_test.cc index b63af40d359..b0a3a5dc481 100644 --- a/cpp/src/gandiva/tests/projector_test.cc +++ b/cpp/src/gandiva/tests/projector_test.cc @@ -908,6 +908,59 @@ TEST_F(TestProjector, TestCastFunction) { EXPECT_ARROW_ARRAY_EQUALS(out_int8, outputs.at(3)); } +// Test to ensure behaviour of cast functions when the validity is false for an input. The +// function should not run for that input. 
+TEST_F(TestProjector, TestCastVarbinaryFunction) { + auto field0 = field("f0", arrow::binary()); + auto schema = arrow::schema({field0}); + + // output fields + auto res_float4 = field("res_float4", arrow::float32()); + auto res_float8 = field("res_float8", arrow::float64()); + auto res_int4 = field("res_int4", arrow::int32()); + auto res_int8 = field("res_int8", arrow::int64()); + + // Build expression + auto cast_expr_float4 = + TreeExprBuilder::MakeExpression("castFLOAT4", {field0}, res_float4); + auto cast_expr_float8 = + TreeExprBuilder::MakeExpression("castFLOAT8", {field0}, res_float8); + auto cast_expr_int4 = TreeExprBuilder::MakeExpression("castINT", {field0}, res_int4); + auto cast_expr_int8 = TreeExprBuilder::MakeExpression("castBIGINT", {field0}, res_int8); + + std::shared_ptr projector; + + // {cast_expr_float4, cast_expr_float8, cast_expr_int4, cast_expr_int8} + auto status = Projector::Make( + schema, {cast_expr_float4, cast_expr_float8, cast_expr_int4, cast_expr_int8}, + TestConfiguration(), &projector); + EXPECT_TRUE(status.ok()); + + // Create a row-batch with some sample data + int num_records = 4; + + // Last validity is false and the cast functions throw error when input is empty. 
Should + // not be evaluated due to addition of NativeFunction::kCanReturnErrors + auto array0 = MakeArrowArrayBinary({"1", "2", "3", "4"}, {true, true, true, false}); + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0}); + + auto out_float4 = MakeArrowArrayFloat32({1, 2, 3, 0}, {true, true, true, false}); + auto out_float8 = MakeArrowArrayFloat64({1, 2, 3, 0}, {true, true, true, false}); + auto out_int4 = MakeArrowArrayInt32({1, 2, 3, 0}, {true, true, true, false}); + auto out_int8 = MakeArrowArrayInt64({1, 2, 3, 0}, {true, true, true, false}); + + arrow::ArrayVector outputs; + + // Evaluate expression + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()); + + EXPECT_ARROW_ARRAY_EQUALS(out_float4, outputs.at(0)); + EXPECT_ARROW_ARRAY_EQUALS(out_float8, outputs.at(1)); + EXPECT_ARROW_ARRAY_EQUALS(out_int4, outputs.at(2)); + EXPECT_ARROW_ARRAY_EQUALS(out_int8, outputs.at(3)); +} + TEST_F(TestProjector, TestToDate) { // schema for input fields auto field0 = field("f0", arrow::utf8()); diff --git a/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/ProjectorTest.java b/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/ProjectorTest.java index 606c1a922e5..295a2fca5a3 100644 --- a/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/ProjectorTest.java +++ b/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/ProjectorTest.java @@ -1950,6 +1950,59 @@ public void testCastFloat() throws Exception { releaseValueVectors(output); } + @Test + public void testCastFloatVarbinary() throws Exception { + Field inField = Field.nullable("input", new ArrowType.Binary()); + TreeNode inNode = TreeBuilder.makeField(inField); + TreeNode castFLOAT8Fn = TreeBuilder.makeFunction("castFLOAT8", Lists.newArrayList(inNode), + float64); + Field resultField = Field.nullable("result", float64); + List exprs = + Lists.newArrayList( + TreeBuilder.makeExpression(castFLOAT8Fn, resultField)); + Schema schema 
= new Schema(Lists.newArrayList(inField)); + Projector eval = Projector.make(schema, exprs); + int numRows = 5; + byte[] validity = new byte[] {(byte) 255}; + String[] values = + new String[] { + "2.3", + "-11.11", + "0", + "111", + "12345.67" + }; + double[] expValues = + new double[] { + 2.3, -11.11, 0, 111, 12345.67 + }; + ArrowBuf bufValidity = buf(validity); + List bufData = stringBufs(values); + ArrowFieldNode fieldNode = new ArrowFieldNode(numRows, 0); + ArrowRecordBatch batch = + new ArrowRecordBatch( + numRows, + Lists.newArrayList(fieldNode), + Lists.newArrayList(bufValidity, bufData.get(0), bufData.get(1))); + List output = new ArrayList<>(); + for (int i = 0; i < exprs.size(); i++) { + Float8Vector float8Vector = new Float8Vector(EMPTY_SCHEMA_PATH, allocator); + float8Vector.allocateNew(numRows); + output.add(float8Vector); + } + eval.evaluate(batch, output); + eval.close(); + for (ValueVector valueVector : output) { + Float8Vector float8Vector = (Float8Vector) valueVector; + for (int j = 0; j < numRows; j++) { + assertFalse(float8Vector.isNull(j)); + assertTrue(expValues[j] == float8Vector.get(j)); + } + } + releaseRecordBatch(batch); + releaseValueVectors(output); + } + @Test(expected = GandivaException.class) public void testCastFloatInvalidValue() throws Exception { Field inField = Field.nullable("input", new ArrowType.Utf8());