diff --git a/cpp/src/gandiva/precompiled/string_ops.cc b/cpp/src/gandiva/precompiled/string_ops.cc index 06f4635e2ec..c521a57b004 100644 --- a/cpp/src/gandiva/precompiled/string_ops.cc +++ b/cpp/src/gandiva/precompiled/string_ops.cc @@ -243,10 +243,16 @@ UTF8_LENGTH(char_length, utf8) UTF8_LENGTH(length, utf8) UTF8_LENGTH(lengthUtf8, binary) +// set max/min str length for space_int32, space_int64, lpad_utf8_int32_utf8 +// and rpad_utf8_int32_utf8 to avoid exceptions +static const gdv_int32 max_str_length = 65536; +static const gdv_int32 min_str_length = 0; // Returns a string of 'n' spaces. #define SPACE_STR(IN_TYPE) \ GANDIVA_EXPORT \ const char* space_##IN_TYPE(gdv_int64 ctx, gdv_##IN_TYPE n, int32_t* out_len) { \ + n = std::min(static_cast(max_str_length), n); \ + n = std::max(static_cast(min_str_length), n); \ gdv_int32 n_times = static_cast(n); \ if (n_times <= 0) { \ *out_len = 0; \ @@ -1762,36 +1768,58 @@ const char* replace_utf8_utf8_utf8(gdv_int64 context, const char* text, out_len); } +FORCE_INLINE +gdv_int32 evaluate_return_char_length(gdv_int32 text_len, gdv_int32 actual_text_len, + gdv_int32 return_length, const char* fill_text, + gdv_int32 fill_text_len) { + gdv_int32 fill_actual_text_len = utf8_length_ignore_invalid(fill_text, fill_text_len); + gdv_int32 repeat_times = (return_length - actual_text_len) / fill_actual_text_len; + gdv_int32 return_char_length = repeat_times * fill_text_len + text_len; + gdv_int32 mod = (return_length - actual_text_len) % fill_actual_text_len; + gdv_int32 char_len = 0; + gdv_int32 fill_index = 0; + for (gdv_int32 i = 0; i < mod; i++) { + char_len = utf8_char_length(fill_text[fill_index]); + fill_index += char_len; + return_char_length += char_len; + } + return return_char_length; +} + FORCE_INLINE const char* lpad_utf8_int32_utf8(gdv_int64 context, const char* text, gdv_int32 text_len, gdv_int32 return_length, const char* fill_text, gdv_int32 fill_text_len, gdv_int32* out_len) { // if the text length or the defined return length (number of characters to return) // is <=0, then return an empty string. + return_length = std::min(max_str_length, return_length); + return_length = std::max(min_str_length, return_length); if (text_len == 0 || return_length <= 0) { *out_len = 0; return ""; } // count the number of utf8 characters on text, ignoring invalid bytes - int text_char_count = utf8_length_ignore_invalid(text, text_len); + int actual_text_len = utf8_length_ignore_invalid(text, text_len); - if (return_length == text_char_count || - (return_length > text_char_count && fill_text_len == 0)) { + if (return_length == actual_text_len || + (return_length > actual_text_len && fill_text_len == 0)) { // case where the return length is same as the text's length, or if it need to // fill into text but "fill_text" is empty, then return text directly. *out_len = text_len; return text; - } else if (return_length < text_char_count) { + } else if (return_length < actual_text_len) { // case where it truncates the result on return length. *out_len = utf8_byte_pos(context, text, text_len, return_length); return text; } else { - // case (return_length > text_char_count) + // case (return_length > actual_text_len) // case where it needs to copy "fill_text" on the string left. The total number - // of chars to copy is given by (return_length - text_char_count) - char* ret = - reinterpret_cast(gdv_fn_context_arena_malloc(context, return_length)); + // of chars to copy is given by (return_length - actual_text_len) + gdv_int32 return_char_length = evaluate_return_char_length( + text_len, actual_text_len, return_length, fill_text, fill_text_len); + char* ret = reinterpret_cast( + gdv_fn_context_arena_malloc(context, return_char_length)); if (ret == nullptr) { gdv_fn_context_set_error_msg(context, "Could not allocate memory for output string"); @@ -1801,12 +1829,12 @@ const char* lpad_utf8_int32_utf8(gdv_int64 context, const char* text, gdv_int32 // try to fulfill the return string with the "fill_text" continuously int32_t copied_chars_count = 0; int32_t copied_chars_position = 0; - while (copied_chars_count < return_length - text_char_count) { + while (copied_chars_count < return_length - actual_text_len) { int32_t char_len; int32_t fill_index; // for each char, evaluate its length to consider it when mem copying for (fill_index = 0; fill_index < fill_text_len; fill_index += char_len) { - if (copied_chars_count >= return_length - text_char_count) { + if (copied_chars_count >= return_length - actual_text_len) { break; } char_len = utf8_char_length(fill_text[fill_index]); @@ -1830,29 +1858,33 @@ const char* rpad_utf8_int32_utf8(gdv_int64 context, const char* text, gdv_int32 gdv_int32 fill_text_len, gdv_int32* out_len) { // if the text length or the defined return length (number of characters to return) // is <=0, then return an empty string. + return_length = std::min(max_str_length, return_length); + return_length = std::max(min_str_length, return_length); if (text_len == 0 || return_length <= 0) { *out_len = 0; return ""; } // count the number of utf8 characters on text, ignoring invalid bytes - int text_char_count = utf8_length_ignore_invalid(text, text_len); + int actual_text_len = utf8_length_ignore_invalid(text, text_len); - if (return_length == text_char_count || - (return_length > text_char_count && fill_text_len == 0)) { + if (return_length == actual_text_len || + (return_length > actual_text_len && fill_text_len == 0)) { // case where the return length is same as the text's length, or if it need to // fill into text but "fill_text" is empty, then return text directly. *out_len = text_len; return text; - } else if (return_length < text_char_count) { + } else if (return_length < actual_text_len) { // case where it truncates the result on return length. *out_len = utf8_byte_pos(context, text, text_len, return_length); return text; } else { - // case (return_length > text_char_count) + // case (return_length > actual_text_len) // case where it needs to copy "fill_text" on the string right - char* ret = - reinterpret_cast(gdv_fn_context_arena_malloc(context, return_length)); + gdv_int32 return_char_length = evaluate_return_char_length( + text_len, actual_text_len, return_length, fill_text, fill_text_len); + char* ret = reinterpret_cast( + gdv_fn_context_arena_malloc(context, return_char_length)); if (ret == nullptr) { gdv_fn_context_set_error_msg(context, "Could not allocate memory for output string"); @@ -1864,12 +1896,12 @@ const char* rpad_utf8_int32_utf8(gdv_int64 context, const char* text, gdv_int32 // try to fulfill the return string with the "fill_text" continuously int32_t copied_chars_count = 0; int32_t copied_chars_position = 0; - while (text_char_count + copied_chars_count < return_length) { + while (actual_text_len + copied_chars_count < return_length) { int32_t char_len; int32_t fill_length; // for each char, evaluate its length to consider it when mem copying for (fill_length = 0; fill_length < fill_text_len; fill_length += char_len) { - if (text_char_count + copied_chars_count >= return_length) { + if (actual_text_len + copied_chars_count >= return_length) { break; } char_len = utf8_char_length(fill_text[fill_length]); diff --git a/cpp/src/gandiva/precompiled/string_ops_test.cc b/cpp/src/gandiva/precompiled/string_ops_test.cc index e4e4a7db27e..c35535a02cd 100644 --- a/cpp/src/gandiva/precompiled/string_ops_test.cc +++ b/cpp/src/gandiva/precompiled/string_ops_test.cc @@ -83,6 +83,10 @@ TEST(TestStringOps, TestSpace) { EXPECT_EQ(std::string(out, out_len), " "); out = space_int32(ctx_ptr, -5, &out_len); EXPECT_EQ(std::string(out, out_len), ""); + out = space_int32(ctx_ptr, 65537, &out_len); + EXPECT_EQ(std::string(out, out_len), std::string(65536, ' ')); + out = space_int32(ctx_ptr, 2147483647, &out_len); + EXPECT_EQ(std::string(out, out_len), std::string(65536, ' ')); out = space_int64(ctx_ptr, 2, &out_len); EXPECT_EQ(std::string(out, out_len), " "); @@ -92,6 +96,12 @@ TEST(TestStringOps, TestSpace) { EXPECT_EQ(std::string(out, out_len), " "); out = space_int64(ctx_ptr, -5, &out_len); EXPECT_EQ(std::string(out, out_len), ""); + out = space_int64(ctx_ptr, 65536, &out_len); + EXPECT_EQ(std::string(out, out_len), std::string(65536, ' ')); + out = space_int64(ctx_ptr, 9223372036854775807, &out_len); + EXPECT_EQ(std::string(out, out_len), std::string(65536, ' ')); + out = space_int64(ctx_ptr, -2639077559LL, &out_len); + EXPECT_EQ(std::string(out, out_len), ""); } TEST(TestStringOps, TestIsSubstr) { @@ -1034,6 +1044,9 @@ TEST(TestStringOps, TestLpadString) { out_str = lpad_utf8_int32_utf8(ctx_ptr, "hello", 5, 6, "д", 2, &out_len); EXPECT_EQ(std::string(out_str, out_len), "дhello"); + out_str = lpad_utf8_int32_utf8(ctx_ptr, "大学路", 9, 65536, "哈", 3, &out_len); + EXPECT_EQ(out_len, 65536 * 3); + // LPAD function tests - with NO pad text out_str = lpad_utf8_int32(ctx_ptr, "TestString", 10, 4, &out_len); EXPECT_EQ(std::string(out_str, out_len), "Test"); @@ -1058,6 +1071,12 @@ TEST(TestStringOps, TestLpadString) { out_str = lpad_utf8_int32(ctx_ptr, "абвгд", 10, 7, &out_len); EXPECT_EQ(std::string(out_str, out_len), " абвгд"); + + out_str = lpad_utf8_int32(ctx_ptr, "TestString", 10, 65537, &out_len); + EXPECT_EQ(std::string(out_str, out_len), std::string(65526, ' ') + "TestString"); + + out_str = lpad_utf8_int32(ctx_ptr, "TestString", 10, -1, &out_len); + EXPECT_EQ(std::string(out_str, out_len), ""); } TEST(TestStringOps, TestRpadString) { @@ -1103,6 +1122,9 @@ TEST(TestStringOps, TestRpadString) { out_str = rpad_utf8_int32_utf8(ctx_ptr, "hello", 5, 6, "д", 2, &out_len); EXPECT_EQ(std::string(out_str, out_len), "helloд"); + out_str = rpad_utf8_int32_utf8(ctx_ptr, "大学路", 9, 655360, "哈雷路", 3, &out_len); + EXPECT_EQ(out_len, 65536 * 3); + // RPAD function tests - with NO pad text out_str = rpad_utf8_int32(ctx_ptr, "TestString", 10, 4, &out_len); EXPECT_EQ(std::string(out_str, out_len), "Test"); @@ -1127,6 +1149,12 @@ TEST(TestStringOps, TestRpadString) { out_str = rpad_utf8_int32(ctx_ptr, "абвгд", 10, 7, &out_len); EXPECT_EQ(std::string(out_str, out_len), "абвгд "); + + out_str = rpad_utf8_int32(ctx_ptr, "TestString", 10, 65537, &out_len); + EXPECT_EQ(std::string(out_str, out_len), "TestString" + std::string(65526, ' ')); + + out_str = rpad_utf8_int32(ctx_ptr, "TestString", 10, -1, &out_len); + EXPECT_EQ(std::string(out_str, out_len), ""); } TEST(TestStringOps, TestRtrim) {