diff --git a/cpp/src/gandiva/function_registry_string.cc b/cpp/src/gandiva/function_registry_string.cc index 1a08dd6d422..844b221efe9 100644 --- a/cpp/src/gandiva/function_registry_string.cc +++ b/cpp/src/gandiva/function_registry_string.cc @@ -480,7 +480,15 @@ std::vector GetStringFunctionRegistry() { NativeFunction("aes_decrypt", {}, DataTypeVector{utf8(), utf8()}, utf8(), kResultNullIfNull, "gdv_fn_aes_decrypt", - NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors)}; + NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors), + + NativeFunction("mask_first_n", {}, DataTypeVector{utf8(), int32()}, utf8(), + kResultNullIfNull, "gdv_mask_first_n_utf8_int32", + NativeFunction::kNeedsContext), + + NativeFunction("mask_last_n", {}, DataTypeVector{utf8(), int32()}, utf8(), + kResultNullIfNull, "gdv_mask_last_n_utf8_int32", + NativeFunction::kNeedsContext)}; return string_fn_registry_; } diff --git a/cpp/src/gandiva/gdv_function_stubs.cc b/cpp/src/gandiva/gdv_function_stubs.cc index 421650d9ea9..0191ec0daee 100644 --- a/cpp/src/gandiva/gdv_function_stubs.cc +++ b/cpp/src/gandiva/gdv_function_stubs.cc @@ -19,7 +19,6 @@ #include -#include #include #include @@ -45,6 +44,24 @@ extern "C" { +static char mask_array[256] = { + (char)0, (char)1, (char)2, (char)3, (char)4, (char)5, (char)6, (char)7, + (char)8, (char)9, (char)10, (char)11, (char)12, (char)13, (char)14, (char)15, + (char)16, (char)17, (char)18, (char)19, (char)20, (char)21, (char)22, (char)23, + (char)24, (char)25, (char)26, (char)27, (char)28, (char)29, (char)30, (char)31, + (char)32, (char)33, (char)34, (char)35, (char)36, (char)37, (char)38, (char)39, + (char)40, (char)41, (char)42, (char)43, (char)44, (char)45, (char)46, (char)47, + 'n', 'n', 'n', 'n', 'n', 'n', 'n', 'n', + 'n', 'n', (char)58, (char)59, (char)60, (char)61, (char)62, (char)63, + (char)64, 'X', 'X', 'X', 'X', 'X', 'X', 'X', + 'X', 'X', 'X', 'X', 'X', 'X', 'X', 'X', + 'X', 'X', 'X', 'X', 'X', 'X', 'X', 'X', + 'X', 'X', 'X', (char)91, (char)92, (char)93, (char)94, (char)95, + (char)96, 'x', 'x', 'x', 'x', 'x', 'x', 'x', + 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', + 'x', 'x', 'x', 'x', 'x', 'x', 'x', 'x', + 'x', 'x', 'x', (char)123, (char)124, (char)125, (char)126, (char)127}; + bool gdv_fn_like_utf8_utf8(int64_t ptr, const char* data, int data_len, const char* pattern, int pattern_len) { gandiva::LikeHolder* holder = reinterpret_cast(ptr); @@ -892,6 +909,203 @@ const char* gdv_fn_aes_decrypt(int64_t context, const char* data, int32_t data_l return ret; } + +GANDIVA_EXPORT +const char* gdv_mask_first_n_utf8_int32(int64_t context, const char* data, + int32_t data_len, int32_t n_to_mask, + int32_t* out_len) { + if (data_len <= 0) { + *out_len = 0; + return nullptr; + } + + if (n_to_mask > data_len) { + n_to_mask = data_len; + } + + *out_len = data_len; + + if (n_to_mask <= 0) { + return data; + } + + char* out = reinterpret_cast(gdv_fn_context_arena_malloc(context, *out_len)); + if (out == nullptr) { + gdv_fn_context_set_error_msg(context, "Could not allocate memory for output string"); + *out_len = 0; + return nullptr; + } + + int bytes_masked; + for (bytes_masked = 0; bytes_masked < n_to_mask; bytes_masked++) { + unsigned char char_single_byte = data[bytes_masked]; + if (char_single_byte > 127) { + // found a multi-byte utf-8 char + break; + } + out[bytes_masked] = mask_array[char_single_byte]; + } + + int chars_masked = bytes_masked; + int out_idx = bytes_masked; + + // Handle multibyte utf8 characters + utf8proc_int32_t utf8_char; + while ((chars_masked < n_to_mask) && (bytes_masked < data_len)) { + auto char_len = + utf8proc_iterate(reinterpret_cast(data + bytes_masked), + data_len, &utf8_char); + + if (char_len < 0) { + gdv_fn_context_set_error_msg(context, utf8proc_errmsg(char_len)); + *out_len = 0; + return nullptr; + } + + switch (utf8proc_category(utf8_char)) { + case 1: + out[out_idx] = 'X'; + out_idx++; + break; + case 2: + out[out_idx] = 'x'; + out_idx++; + break; + case 9: + out[out_idx] = 'n'; + out_idx++; + break; + case 10: + out[out_idx] = 'n'; + out_idx++; + break; + default: + memcpy(out + out_idx, data + bytes_masked, char_len); + out_idx += static_cast(char_len); + break; + } + bytes_masked += static_cast(char_len); + chars_masked++; + } + + // Correct the out_len after masking multibyte characters with single byte characters + *out_len = *out_len - (bytes_masked - out_idx); + + if (bytes_masked < data_len) { + memcpy(out + out_idx, data + bytes_masked, data_len - bytes_masked); + } + + return out; +} + +GANDIVA_EXPORT +const char* gdv_mask_last_n_utf8_int32(int64_t context, const char* data, + int32_t data_len, int32_t n_to_mask, + int32_t* out_len) { + if (data_len <= 0) { + *out_len = 0; + return nullptr; + } + + if (n_to_mask > data_len) { + n_to_mask = data_len; + } + + *out_len = data_len; + + if (n_to_mask <= 0) { + return data; + } + + char* out = reinterpret_cast(gdv_fn_context_arena_malloc(context, *out_len)); + if (out == nullptr) { + gdv_fn_context_set_error_msg(context, "Could not allocate memory for output string"); + *out_len = 0; + return nullptr; + } + + bool has_multi_byte = false; + for (int i = 0; i < data_len; i++) { + unsigned char char_single_byte = data[i]; + if (char_single_byte > 127) { + // found a multi-byte utf-8 char + has_multi_byte = true; + break; + } + } + + if (!has_multi_byte) { + int start_idx = data_len - n_to_mask; + memcpy(out, data, start_idx); + for (int i = start_idx; i < data_len; ++i) { + unsigned char char_single_byte = data[i]; + out[i] = mask_array[char_single_byte]; + } + *out_len = data_len; + return out; + } + + utf8proc_int32_t utf8_char_buffer; + int num_of_chars = static_cast( + utf8proc_decompose(reinterpret_cast(data), data_len, + &utf8_char_buffer, 4, UTF8PROC_STABLE)); + + if (num_of_chars < 0) { + gdv_fn_context_set_error_msg(context, utf8proc_errmsg(num_of_chars)); + *out_len = 0; + return nullptr; + } + + utf8proc_int32_t utf8_char; + int chars_counter = 0; + int bytes_read = 0; + while ((bytes_read < data_len) && (chars_counter < (num_of_chars - n_to_mask))) { + auto char_len = + utf8proc_iterate(reinterpret_cast(data + bytes_read), + data_len, &utf8_char); + chars_counter++; + bytes_read += static_cast(char_len); + } + + int out_idx = bytes_read; + int offset_idx = bytes_read; + + // Populate the first chars, that are not masked + memcpy(out, data, offset_idx); + + while (bytes_read < data_len) { + auto char_len = + utf8proc_iterate(reinterpret_cast(data + bytes_read), + data_len, &utf8_char); + switch (utf8proc_category(utf8_char)) { + case 1: + out[out_idx] = 'X'; + out_idx++; + break; + case 2: + out[out_idx] = 'x'; + out_idx++; + break; + case 9: + out[out_idx] = 'n'; + out_idx++; + break; + case 10: + out[out_idx] = 'n'; + out_idx++; + break; + default: + memcpy(out + out_idx, data + bytes_read, char_len); + out_idx += static_cast(char_len); + break; + } + bytes_read += static_cast(char_len); + } + + *out_len = out_idx; + + return out; +} } namespace gandiva { @@ -1938,5 +2152,22 @@ void ExportedStubFunctions::AddMappings(Engine* engine) const { engine->AddGlobalMappingForFunc("gdv_fn_aes_decrypt", types->i8_ptr_type() /*return_type*/, args, reinterpret_cast(gdv_fn_aes_decrypt)); + + // gdv_mask_first_n and gdv_mask_last_n + std::vector mask_args = { + types->i64_type(), // context + types->i8_ptr_type(), // data + types->i32_type(), // data_length + types->i32_type(), // n_to_mask + types->i32_ptr_type() // out_length + }; + + engine->AddGlobalMappingForFunc("gdv_mask_first_n_utf8_int32", + types->i8_ptr_type() /*return_type*/, mask_args, + reinterpret_cast(gdv_mask_first_n_utf8_int32)); + + engine->AddGlobalMappingForFunc("gdv_mask_last_n_utf8_int32", + types->i8_ptr_type() /*return_type*/, mask_args, + reinterpret_cast(gdv_mask_last_n_utf8_int32)); } } // namespace gandiva diff --git a/cpp/src/gandiva/gdv_function_stubs.h b/cpp/src/gandiva/gdv_function_stubs.h index f045d3c3b1d..ca949259f08 100644 --- a/cpp/src/gandiva/gdv_function_stubs.h +++ b/cpp/src/gandiva/gdv_function_stubs.h @@ -187,4 +187,14 @@ GANDIVA_EXPORT const char* gdv_fn_aes_decrypt(int64_t context, const char* data, int32_t data_len, const char* key_data, int32_t key_data_len, int32_t* out_len); + +GANDIVA_EXPORT +const char* gdv_mask_first_n_utf8_int32(int64_t context, const char* data, + int32_t data_len, int32_t n_to_mask, + int32_t* out_len); + +GANDIVA_EXPORT +const char* gdv_mask_last_n_utf8_int32(int64_t context, const char* data, + int32_t data_len, int32_t n_to_mask, + int32_t* out_len); } diff --git a/cpp/src/gandiva/gdv_function_stubs_test.cc b/cpp/src/gandiva/gdv_function_stubs_test.cc index c44954f2020..c24f82c718b 100644 --- a/cpp/src/gandiva/gdv_function_stubs_test.cc +++ b/cpp/src/gandiva/gdv_function_stubs_test.cc @@ -798,4 +798,141 @@ TEST(TestGdvFnStubs, TestCastVarbinaryFloat8) { ::testing::HasSubstr("Failed to cast the string e to double")); ctx.Reset(); } + +TEST(TestGdvFnStubs, TestMaskFirstN) { + gandiva::ExecutionContext ctx; + int64_t ctx_ptr = reinterpret_cast(&ctx); + int32_t out_len = 0; + + std::string data = "a〜Çç&"; + auto data_len = static_cast(data.length()); + std::string expected = "x〜Xx&"; + const char* result = + gdv_mask_first_n_utf8_int32(ctx_ptr, data.c_str(), data_len, 4, &out_len); + EXPECT_EQ(expected, std::string(result, out_len)); + + data = "世界您"; + data_len = static_cast(data.length()); + expected = "世界您"; + result = gdv_mask_first_n_utf8_int32(ctx_ptr, data.c_str(), data_len, 4, &out_len); + EXPECT_EQ(expected, std::string(result, out_len)); + + data = "a6Ççé"; + data_len = static_cast(data.length()); + expected = "xnXxé"; + result = gdv_mask_first_n_utf8_int32(ctx_ptr, data.c_str(), data_len, 4, &out_len); + EXPECT_EQ(expected, std::string(result, out_len)); + + data = "0123456789"; + data_len = static_cast(data.length()); + expected = "nnnnnnnnnn"; + result = gdv_mask_first_n_utf8_int32(ctx_ptr, data.c_str(), data_len, 10, &out_len); + EXPECT_EQ(expected, std::string(result, out_len)); + + data = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"; + data_len = static_cast(data.length()); + expected = "XXXXXXXXXXXXXXXXXXXXXXXXXX"; + result = gdv_mask_first_n_utf8_int32(ctx_ptr, data.c_str(), data_len, 26, &out_len); + EXPECT_EQ(expected, std::string(result, out_len)); + + data = "abcdefghijklmnopqrstuvwxyz"; + expected = "xxxxxxxxxxxxxxxxxxxxxxxxxx"; + result = gdv_mask_first_n_utf8_int32(ctx_ptr, data.c_str(), data_len, 26, &out_len); + EXPECT_EQ(expected, std::string(result, out_len)); + + data = "aB-6"; + data_len = static_cast(data.length()); + expected = "xX-6"; + result = gdv_mask_first_n_utf8_int32(ctx_ptr, data.c_str(), data_len, 3, &out_len); + EXPECT_EQ(expected, std::string(result, out_len)); + + expected = "xX-n"; + result = gdv_mask_first_n_utf8_int32(ctx_ptr, data.c_str(), data_len, 5, &out_len); + EXPECT_EQ(expected, std::string(result, out_len)); + + expected = "aB-6"; + result = gdv_mask_first_n_utf8_int32(ctx_ptr, data.c_str(), data_len, -3, &out_len); + EXPECT_EQ(expected, std::string(result, out_len)); + + result = gdv_mask_first_n_utf8_int32(ctx_ptr, data.c_str(), data_len, 0, &out_len); + EXPECT_EQ(expected, std::string(result, out_len)); + + data = "ABcd-123456"; + data_len = static_cast(data.length()); + expected = "XXxx-n23456"; + result = gdv_mask_first_n_utf8_int32(ctx_ptr, data.c_str(), data_len, 6, &out_len); + EXPECT_EQ(expected, std::string(result, out_len)); + + data = ""; + data_len = 0; + expected = ""; + result = gdv_mask_first_n_utf8_int32(ctx_ptr, data.c_str(), data_len, 6, &out_len); + EXPECT_EQ(expected, std::string(result, out_len)); +} + +TEST(TestGdvFnStubs, TestMaskLastN) { + gandiva::ExecutionContext ctx; + int64_t ctx_ptr = reinterpret_cast(&ctx); + int32_t out_len = 0; + + std::string data = "a〜Çç&"; + int32_t data_len = static_cast(data.length()); + std::string expected = "a〜Xx&"; + const char* result = + gdv_mask_last_n_utf8_int32(ctx_ptr, data.c_str(), data_len, 4, &out_len); + EXPECT_EQ(expected, std::string(result, out_len)); + + data = "abÇçé"; + data_len = static_cast(data.length()); + expected = "axXxx"; + result = gdv_mask_last_n_utf8_int32(ctx_ptr, data.c_str(), data_len, 4, &out_len); + EXPECT_EQ(expected, std::string(result, out_len)); + + data = "0123456789"; + data_len = static_cast(data.length()); + expected = "nnnnnnnnnn"; + result = gdv_mask_last_n_utf8_int32(ctx_ptr, data.c_str(), data_len, 10, &out_len); + EXPECT_EQ(expected, std::string(result, out_len)); + + data = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"; + data_len = static_cast(data.length()); + expected = "XXXXXXXXXXXXXXXXXXXXXXXXXX"; + result = gdv_mask_last_n_utf8_int32(ctx_ptr, data.c_str(), data_len, 26, &out_len); + EXPECT_EQ(expected, std::string(result, out_len)); + + data = "abcdefghijklmnopqrstuvwxyz"; + expected = "xxxxxxxxxxxxxxxxxxxxxxxxxx"; + result = gdv_mask_last_n_utf8_int32(ctx_ptr, data.c_str(), data_len, 26, &out_len); + EXPECT_EQ(expected, std::string(result, out_len)); + + data = "aB-6"; + data_len = static_cast(data.length()); + expected = "aX-n"; + result = gdv_mask_last_n_utf8_int32(ctx_ptr, data.c_str(), data_len, 3, &out_len); + EXPECT_EQ(expected, std::string(result, out_len)); + + expected = "xX-n"; + result = gdv_mask_last_n_utf8_int32(ctx_ptr, data.c_str(), data_len, 5, &out_len); + EXPECT_EQ(expected, std::string(result, out_len)); + + expected = "aB-6"; + result = gdv_mask_last_n_utf8_int32(ctx_ptr, data.c_str(), data_len, -3, &out_len); + EXPECT_EQ(expected, std::string(result, out_len)); + + result = gdv_mask_last_n_utf8_int32(ctx_ptr, data.c_str(), data_len, 0, &out_len); + EXPECT_EQ(expected, std::string(result, out_len)); + + data = "ABcd-123456"; + data_len = static_cast(data.length()); + expected = "ABcd-nnnnnn"; + result = gdv_mask_last_n_utf8_int32(ctx_ptr, data.c_str(), data_len, 6, &out_len); + EXPECT_EQ(expected, std::string(result, out_len)); + + data = ""; + data_len = 0; + expected = ""; + result = gdv_mask_last_n_utf8_int32(ctx_ptr, data.c_str(), data_len, 6, &out_len); + EXPECT_EQ(expected, std::string(result, out_len)); +} + } // namespace gandiva diff --git a/cpp/src/gandiva/precompiled/string_ops_test.cc b/cpp/src/gandiva/precompiled/string_ops_test.cc index 9b4dee9f7ee..31e54409632 100644 --- a/cpp/src/gandiva/precompiled/string_ops_test.cc +++ b/cpp/src/gandiva/precompiled/string_ops_test.cc @@ -2326,5 +2326,4 @@ TEST(TestStringOps, TestSoundex) { auto out2 = soundex_utf8(ctx_ptr, "Smythe", 6, &out_len); EXPECT_EQ(std::string(out, out_len), std::string(out2, out_len)); } - } // namespace gandiva diff --git a/cpp/src/gandiva/tests/projector_test.cc b/cpp/src/gandiva/tests/projector_test.cc index 892b60d04b4..c05caa46f76 100644 --- a/cpp/src/gandiva/tests/projector_test.cc +++ b/cpp/src/gandiva/tests/projector_test.cc @@ -2200,4 +2200,49 @@ TEST_F(TestProjector, TestAesEncryptDecrypt) { EXPECT_ARROW_ARRAY_EQUALS(array_data, outputs_de.at(0)); } +TEST_F(TestProjector, TestMaskFirstMaskLastN) { + // schema for input fields + auto field0 = field("f0", arrow::utf8()); + auto field1 = field("f1", int32()); + auto schema = arrow::schema({field0, field1}); + + // output fields + auto res_mask_first_n = field("output", arrow::utf8()); + auto res_mask_last_n = field("output", arrow::utf8()); + + // Build expression + auto expr_mask_first_n = + TreeExprBuilder::MakeExpression("mask_first_n", {field0, field1}, res_mask_first_n); + auto expr_mask_last_n = + TreeExprBuilder::MakeExpression("mask_last_n", {field0, field1}, res_mask_last_n); + + std::shared_ptr projector; + auto status = Projector::Make(schema, {expr_mask_first_n, expr_mask_last_n}, + TestConfiguration(), &projector); + EXPECT_TRUE(status.ok()); + + // Create a row-batch with some sample data + int num_records = 4; + auto array0 = MakeArrowArrayUtf8({"aB-6", "ABcd-123456", "A#-c$%6", "A#-c$%6"}, + {true, true, true, true}); + auto array1 = MakeArrowArrayInt32({3, 6, 7, -2}, {true, true, true, true}); + // expected output + auto exp_mask_first_n = MakeArrowArrayUtf8( + {"xX-6", "XXxx-n23456", "X#-x$%n", "A#-c$%6"}, {true, true, true, true}); + auto exp_mask_last_n = MakeArrowArrayUtf8({"aX-n", "ABcd-nnnnnn", "X#-x$%n", "A#-c$%6"}, + {true, true, true, true}); + + // prepare input record batch + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0, array1}); + + // Evaluate expression + arrow::ArrayVector outputs; + status = projector->Evaluate(*in_batch, pool_, &outputs); + EXPECT_TRUE(status.ok()); + + // Validate results + EXPECT_ARROW_ARRAY_EQUALS(exp_mask_first_n, outputs.at(0)); + EXPECT_ARROW_ARRAY_EQUALS(exp_mask_last_n, outputs.at(1)); +} + } // namespace gandiva