diff --git a/cpp/src/gandiva/function_registry_string.cc b/cpp/src/gandiva/function_registry_string.cc index 2bc6936d77b..bb28bec8c14 100644 --- a/cpp/src/gandiva/function_registry_string.cc +++ b/cpp/src/gandiva/function_registry_string.cc @@ -510,6 +510,10 @@ std::vector GetStringFunctionRegistry() { kResultNullIfNull, "gdv_mask_last_n_utf8_int32", NativeFunction::kNeedsContext), + NativeFunction("find_in_set", {}, DataTypeVector{utf8(), utf8()}, int32(), + kResultNullIfNull, "find_in_set_utf8_utf8", + NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors), + NativeFunction("instr", {}, DataTypeVector{utf8(), utf8()}, int32(), kResultNullIfNull, "instr_utf8"), diff --git a/cpp/src/gandiva/precompiled/string_ops.cc b/cpp/src/gandiva/precompiled/string_ops.cc index c255b9a11c0..81bc117f2fe 100644 --- a/cpp/src/gandiva/precompiled/string_ops.cc +++ b/cpp/src/gandiva/precompiled/string_ops.cc @@ -3034,4 +3034,45 @@ int32_t instr_utf8(const char* string, int32_t string_len, const char* substring } return 0; } + +FORCE_INLINE +int32_t find_in_set_utf8_utf8(int64_t context, const char* to_find, int32_t to_find_len, + const char* string_list, int32_t string_list_len) { + // Return 0 if to search entry have commas + if (is_substr_utf8_utf8(to_find, to_find_len, reinterpret_cast(","), 1)) { + return 0; + } + + int32_t cur_pos_in_array = 0; + int32_t cur_length = 0; + bool matching = true; + + for (int i = 0; i < string_list_len; i++) { + if (string_list[i] == ',') { + cur_pos_in_array++; + if (matching && cur_length == to_find_len) { + return cur_pos_in_array; + } else { + matching = true; + cur_length = 0; + } + } else { + if (cur_length + 1 <= string_list_len) { + if (!matching || (memcmp(string_list + i, to_find + cur_length, 1))) { + matching = false; + } + } else { + matching = false; + } + cur_length++; + } + } + + if (matching && cur_length == to_find_len) { + cur_pos_in_array++; + return cur_pos_in_array; + } else { + return 0; + } +} } // extern "C" diff --git a/cpp/src/gandiva/precompiled/string_ops_test.cc b/cpp/src/gandiva/precompiled/string_ops_test.cc index b84c51b3a6b..4bfa4709638 100644 --- a/cpp/src/gandiva/precompiled/string_ops_test.cc +++ b/cpp/src/gandiva/precompiled/string_ops_test.cc @@ -2702,4 +2702,30 @@ TEST(TestStringOps, TestInstr) { result = instr_utf8(s1.c_str(), s1_len, s2.c_str(), s2_len); EXPECT_EQ(result, 8); } + +TEST(TestStringOps, TestFindInSet) { + gandiva::ExecutionContext ctx; + auto ctx_ptr = reinterpret_cast(&ctx); + int32_t result; + result = find_in_set_utf8_utf8(ctx_ptr, "EE", 2, ",A,B,C,D,EE,F", 13); + EXPECT_EQ(result, 6); + result = find_in_set_utf8_utf8(ctx_ptr, "A", 1, "A,B,C,D,EE,F", 12); + EXPECT_EQ(result, 1); + result = find_in_set_utf8_utf8(ctx_ptr, "AAAB", 4, "A,B,C,D,EE,F", 12); + EXPECT_EQ(result, 0); + result = find_in_set_utf8_utf8(ctx_ptr, "E,E", 3, "A,B,C,D,EE,F", 12); + EXPECT_EQ(result, 0); + result = find_in_set_utf8_utf8(ctx_ptr, "C", 1, "A,B,,,,,,,C,,,,,", 16); + EXPECT_EQ(result, 9); + result = find_in_set_utf8_utf8(ctx_ptr, "", 0, "", 0); + EXPECT_EQ(result, 1); + result = find_in_set_utf8_utf8(ctx_ptr, "", 0, " ", 1); + EXPECT_EQ(result, 0); + result = find_in_set_utf8_utf8(ctx_ptr, " ", 1, "", 0); + EXPECT_EQ(result, 0); + result = find_in_set_utf8_utf8(ctx_ptr, "", 0, "a,b,,c,d", 8); + EXPECT_EQ(result, 3); + result = find_in_set_utf8_utf8(ctx_ptr, "", 0, ",", 1); + EXPECT_EQ(result, 1); +} } // namespace gandiva diff --git a/cpp/src/gandiva/precompiled/types.h b/cpp/src/gandiva/precompiled/types.h index a0a83f18dd4..67220659f8c 100644 --- a/cpp/src/gandiva/precompiled/types.h +++ b/cpp/src/gandiva/precompiled/types.h @@ -829,4 +829,7 @@ const char* elt_int32_utf8_utf8_utf8_utf8_utf8( int32_t instr_utf8(const char* string, int32_t string_len, const char* substring, int32_t substring_len); +int32_t find_in_set_utf8_utf8(int64_t context, const char* to_find, int32_t to_find_len, + const char* string_list, int32_t string_list_len); + } // extern "C" diff --git a/cpp/src/gandiva/tests/projector_test.cc b/cpp/src/gandiva/tests/projector_test.cc index 65597b38f0b..ebd9bab3473 100644 --- a/cpp/src/gandiva/tests/projector_test.cc +++ b/cpp/src/gandiva/tests/projector_test.cc @@ -2824,6 +2824,48 @@ TEST_F(TestProjector, TestInstr) { // Validate results EXPECT_ARROW_ARRAY_EQUALS(exp_sum, outputs.at(0)); } +TEST_F(TestProjector, TestFindInSet) { + // schema for input fields + auto field0 = field("f0", arrow::utf8()); + auto field1 = field("f1", arrow::utf8()); + auto schema = arrow::schema({field0, field1}); + + // output fields + auto output_find_in_set = field("find_in_set_output", int32()); + + // Build expression + auto find_in_set_expr = TreeExprBuilder::MakeExpression("find_in_set", {field0, field1}, + output_find_in_set); + + std::shared_ptr projector; + auto status = + Projector::Make(schema, {find_in_set_expr}, TestConfiguration(), &projector); + EXPECT_TRUE(status.ok()); + + // Create a row-batch with some sample data + int num_records = 8; + auto array0 = + MakeArrowArrayUtf8({"ABC", "...", "!C", "MORE", "学路", "b大", "路", "学路"}, + {true, true, true, true, true, true, true, true}); + auto array1 = MakeArrowArrayUtf8( + {"ZXL,KMY,DDD,ABC", "!!!,@@@,###,...,,,", ",A,,,,,,,,!C,,,,,", "MORE", + "学路,学路,学路,123", "大b,,,b大", "大b,,学路,学,b大", "学路"}, + {true, true, true, true, true, true, true, true}); + // expected output + auto exp_res = MakeArrowArrayInt32({4, 4, 10, 1, 1, 4, 0, 1}, + {true, true, true, true, true, true, true, true}); + + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0, array1}); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()); + + // Validate results + EXPECT_ARROW_ARRAY_EQUALS(exp_res, outputs.at(0)); +} TEST_F(TestProjector, TestNextDay) { // schema for input fields