diff --git a/cpp/src/arrow/compute/api_scalar.h b/cpp/src/arrow/compute/api_scalar.h index 730836bd118..f59426d8f1b 100644 --- a/cpp/src/arrow/compute/api_scalar.h +++ b/cpp/src/arrow/compute/api_scalar.h @@ -45,7 +45,7 @@ struct ArithmeticOptions : public FunctionOptions { struct ARROW_EXPORT MatchSubstringOptions : public FunctionOptions { explicit MatchSubstringOptions(std::string pattern) : pattern(std::move(pattern)) {} - /// The exact substring to look for inside input values. + /// The exact substring (or regex, depending on kernel) to look for inside input values. std::string pattern; }; diff --git a/cpp/src/arrow/compute/kernels/scalar_string.cc b/cpp/src/arrow/compute/kernels/scalar_string.cc index 39869879561..9ec1fe005d4 100644 --- a/cpp/src/arrow/compute/kernels/scalar_string.cc +++ b/cpp/src/arrow/compute/kernels/scalar_string.cc @@ -368,83 +368,130 @@ void StringBoolTransform(KernelContext* ctx, const ExecBatch& batch, } } -template -void TransformMatchSubstring(const uint8_t* pattern, int64_t pattern_length, - const offset_type* offsets, const uint8_t* data, - int64_t length, int64_t output_offset, uint8_t* output) { - // This is an implementation of the Knuth-Morris-Pratt algorithm - - // Phase 1: Build the prefix table - std::vector prefix_table(pattern_length + 1); - offset_type prefix_length = -1; - prefix_table[0] = -1; - for (offset_type pos = 0; pos < pattern_length; ++pos) { - // The prefix cannot be expanded, reset. - while (prefix_length >= 0 && pattern[pos] != pattern[prefix_length]) { - prefix_length = prefix_table[prefix_length]; - } - prefix_length++; - prefix_table[pos + 1] = prefix_length; - } - - // Phase 2: Find the prefix in the data - FirstTimeBitmapWriter bitmap_writer(output, output_offset, length); - for (int64_t i = 0; i < length; ++i) { - const uint8_t* current_data = data + offsets[i]; - int64_t current_length = offsets[i + 1] - offsets[i]; - - int64_t pattern_pos = 0; - for (int64_t k = 0; k < current_length; k++) { - while ((pattern_pos >= 0) && (pattern[pattern_pos] != current_data[k])) { - pattern_pos = prefix_table[pattern_pos]; - } - pattern_pos++; - if (pattern_pos == pattern_length) { - bitmap_writer.Set(); - break; - } - } - bitmap_writer.Next(); - } - bitmap_writer.Finish(); -} - using MatchSubstringState = OptionsWrapper; -template +template struct MatchSubstring { using offset_type = typename Type::offset_type; static void Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { - MatchSubstringOptions arg = MatchSubstringState::Get(ctx); - const uint8_t* pat = reinterpret_cast(arg.pattern.c_str()); - const int64_t pat_size = arg.pattern.length(); + // TODO Cache matcher across invocations (for regex compilation) + Matcher matcher(ctx, MatchSubstringState::Get(ctx)); + if (ctx->HasError()) return; StringBoolTransform( ctx, batch, - [pat, pat_size](const void* offsets, const uint8_t* data, int64_t length, - int64_t output_offset, uint8_t* output) { - TransformMatchSubstring( - pat, pat_size, reinterpret_cast(offsets), data, length, - output_offset, output); + [&matcher](const void* raw_offsets, const uint8_t* data, int64_t length, + int64_t output_offset, uint8_t* output) { + const offset_type* offsets = reinterpret_cast(raw_offsets); + FirstTimeBitmapWriter bitmap_writer(output, output_offset, length); + for (int64_t i = 0; i < length; ++i) { + const char* current_data = reinterpret_cast(data + offsets[i]); + int64_t current_length = offsets[i + 1] - offsets[i]; + if (matcher.Match(util::string_view(current_data, current_length))) { + bitmap_writer.Set(); + } + bitmap_writer.Next(); + } + bitmap_writer.Finish(); }, out); } }; +// This is an implementation of the Knuth-Morris-Pratt algorithm +struct PlainSubstringMatcher { + const MatchSubstringOptions& options_; + std::vector prefix_table; + + PlainSubstringMatcher(KernelContext* ctx, const MatchSubstringOptions& options) + : options_(options) { + // Phase 1: Build the prefix table + const auto pattern_length = options_.pattern.size(); + prefix_table.resize(pattern_length + 1, /*value=*/0); + int64_t prefix_length = -1; + prefix_table[0] = -1; + for (size_t pos = 0; pos < pattern_length; ++pos) { + // The prefix cannot be expanded, reset. + while (prefix_length >= 0 && + options_.pattern[pos] != options_.pattern[prefix_length]) { + prefix_length = prefix_table[prefix_length]; + } + prefix_length++; + prefix_table[pos + 1] = prefix_length; + } + } + + bool Match(util::string_view current) { + // Phase 2: Find the prefix in the data + const auto pattern_length = options_.pattern.size(); + int64_t pattern_pos = 0; + for (const auto c : current) { + while ((pattern_pos >= 0) && (options_.pattern[pattern_pos] != c)) { + pattern_pos = prefix_table[pattern_pos]; + } + pattern_pos++; + if (static_cast(pattern_pos) == pattern_length) { + return true; + } + } + return false; + } +}; + const FunctionDoc match_substring_doc( "Match strings against literal pattern", ("For each string in `strings`, emit true iff it contains a given pattern.\n" "Null inputs emit null. The pattern must be given in MatchSubstringOptions."), {"strings"}, "MatchSubstringOptions"); +#ifdef ARROW_WITH_RE2 +struct RegexSubstringMatcher { + const MatchSubstringOptions& options_; + const RE2 regex_match_; + + RegexSubstringMatcher(KernelContext* ctx, const MatchSubstringOptions& options) + : options_(options), regex_match_(options_.pattern) { + if (!regex_match_.ok()) { + ctx->SetStatus(Status::Invalid("Regular expression error")); + } + } + + bool Match(util::string_view current) { + auto piece = re2::StringPiece(current.data(), current.length()); + return re2::RE2::PartialMatch(piece, regex_match_); + } +}; + +const FunctionDoc match_substring_regex_doc( + "Match strings against regex pattern", + ("For each string in `strings`, emit true iff it matches a given pattern at any " + "position.\n" + "Null inputs emit null. The pattern must be given in MatchSubstringOptions."), + {"strings"}, "MatchSubstringOptions"); +#endif + void AddMatchSubstring(FunctionRegistry* registry) { - auto func = std::make_shared("match_substring", Arity::Unary(), - &match_substring_doc); - auto exec_32 = MatchSubstring::Exec; - auto exec_64 = MatchSubstring::Exec; - DCHECK_OK(func->AddKernel({utf8()}, boolean(), exec_32, MatchSubstringState::Init)); - DCHECK_OK( - func->AddKernel({large_utf8()}, boolean(), exec_64, MatchSubstringState::Init)); - DCHECK_OK(registry->AddFunction(std::move(func))); + { + auto func = std::make_shared("match_substring", Arity::Unary(), + &match_substring_doc); + auto exec_32 = MatchSubstring::Exec; + auto exec_64 = MatchSubstring::Exec; + DCHECK_OK(func->AddKernel({utf8()}, boolean(), exec_32, MatchSubstringState::Init)); + DCHECK_OK( + func->AddKernel({large_utf8()}, boolean(), exec_64, MatchSubstringState::Init)); + DCHECK_OK(registry->AddFunction(std::move(func))); + } +#ifdef ARROW_WITH_RE2 + { + auto func = std::make_shared("match_substring_regex", Arity::Unary(), + &match_substring_regex_doc); + auto exec_32 = MatchSubstring::Exec; + auto exec_64 = MatchSubstring::Exec; + DCHECK_OK(func->AddKernel({utf8()}, boolean(), exec_32, MatchSubstringState::Init)); + DCHECK_OK( + func->AddKernel({large_utf8()}, boolean(), exec_64, MatchSubstringState::Init)); + DCHECK_OK(registry->AddFunction(std::move(func))); + } +#endif } // IsAlpha/Digit etc @@ -1246,7 +1293,7 @@ struct ReplaceSubString { using State = OptionsWrapper; static void Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { - // TODO Cache replacer accross invocations (for regex compilation) + // TODO Cache replacer across invocations (for regex compilation) Replacer replacer{ctx, State::Get(ctx)}; if (!ctx->HasError()) { Replace(ctx, batch, &replacer, out); diff --git a/cpp/src/arrow/compute/kernels/scalar_string_test.cc b/cpp/src/arrow/compute/kernels/scalar_string_test.cc index 88622e842d1..2dd0a4d8c74 100644 --- a/cpp/src/arrow/compute/kernels/scalar_string_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_string_test.cc @@ -348,6 +348,27 @@ TYPED_TEST(TestStringKernels, MatchSubstring) { &options_double_char_2); } +#ifdef ARROW_WITH_RE2 +TYPED_TEST(TestStringKernels, MatchSubstringRegex) { + MatchSubstringOptions options{"ab"}; + this->CheckUnary("match_substring_regex", "[]", boolean(), "[]", &options); + this->CheckUnary("match_substring_regex", R"(["abc", "acb", "cab", null, "bac"])", + boolean(), "[true, false, true, null, false]", &options); + MatchSubstringOptions options_repeated{"(ab){2}"}; + this->CheckUnary("match_substring_regex", R"(["abab", "ab", "cababc", null, "bac"])", + boolean(), "[true, false, true, null, false]", &options_repeated); + MatchSubstringOptions options_digit{"\\d"}; + this->CheckUnary("match_substring_regex", R"(["aacb", "a2ab", "", "24"])", boolean(), + "[false, true, false, true]", &options_digit); + MatchSubstringOptions options_star{"a*b"}; + this->CheckUnary("match_substring_regex", R"(["aacb", "aab", "dab", "caaab", "b", ""])", + boolean(), "[true, true, true, true, true, false]", &options_star); + MatchSubstringOptions options_plus{"a+b"}; + this->CheckUnary("match_substring_regex", R"(["aacb", "aab", "dab", "caaab", "b", ""])", + boolean(), "[false, true, true, true, false, false]", &options_plus); +} +#endif + TYPED_TEST(TestStringKernels, SplitBasics) { SplitPatternOptions options{" "}; // basics diff --git a/docs/source/cpp/compute.rst b/docs/source/cpp/compute.rst index 065b80736aa..715d5036964 100644 --- a/docs/source/cpp/compute.rst +++ b/docs/source/cpp/compute.rst @@ -522,26 +522,31 @@ These functions trim off characters on both sides (trim), or the left (ltrim) or Containment tests ~~~~~~~~~~~~~~~~~ -+--------------------+------------+------------------------------------+---------------+----------------------------------------+ -| Function name | Arity | Input types | Output type | Options class | -+====================+============+====================================+===============+========================================+ -| match_substring | Unary | String-like | Boolean (1) | :struct:`MatchSubstringOptions` | -+--------------------+------------+------------------------------------+---------------+----------------------------------------+ -| index_in | Unary | Boolean, Null, Numeric, Temporal, | Int32 (2) | :struct:`SetLookupOptions` | -| | | Binary- and String-like | | | -+--------------------+------------+------------------------------------+---------------+----------------------------------------+ -| is_in | Unary | Boolean, Null, Numeric, Temporal, | Boolean (3) | :struct:`SetLookupOptions` | -| | | Binary- and String-like | | | -+--------------------+------------+------------------------------------+---------------+----------------------------------------+ ++---------------------------+------------+------------------------------------+---------------+----------------------------------------+ +| Function name | Arity | Input types | Output type | Options class | ++===========================+============+====================================+===============+========================================+ +| match_substring | Unary | String-like | Boolean (1) | :struct:`MatchSubstringOptions` | ++---------------------------+------------+------------------------------------+---------------+----------------------------------------+ +| match_substring_regex | Unary | String-like | Boolean (2) | :struct:`MatchSubstringOptions` | ++---------------------------+------------+------------------------------------+---------------+----------------------------------------+ +| index_in | Unary | Boolean, Null, Numeric, Temporal, | Int32 (3) | :struct:`SetLookupOptions` | +| | | Binary- and String-like | | | ++---------------------------+------------+------------------------------------+---------------+----------------------------------------+ +| is_in | Unary | Boolean, Null, Numeric, Temporal, | Boolean (4) | :struct:`SetLookupOptions` | +| | | Binary- and String-like | | | ++---------------------------+------------+------------------------------------+---------------+----------------------------------------+ * \(1) Output is true iff :member:`MatchSubstringOptions::pattern` is a substring of the corresponding input element. -* \(2) Output is the index of the corresponding input element in +* \(2) Output is true iff :member:`MatchSubstringOptions::pattern` + matches the corresponding input element at any position. + +* \(3) Output is the index of the corresponding input element in :member:`SetLookupOptions::value_set`, if found there. Otherwise, output is null. -* \(3) Output is true iff the corresponding input element is equal to one +* \(4) Output is true iff the corresponding input element is equal to one of the elements in :member:`SetLookupOptions::value_set`. diff --git a/docs/source/python/api/compute.rst b/docs/source/python/api/compute.rst index 2dafbd23c08..d6efc6a5fea 100644 --- a/docs/source/python/api/compute.rst +++ b/docs/source/python/api/compute.rst @@ -155,6 +155,7 @@ Containment tests index_in is_in match_substring + match_substring_regex Conversions ----------- diff --git a/python/pyarrow/compute.py b/python/pyarrow/compute.py index 1b46a08c402..3928b9cb904 100644 --- a/python/pyarrow/compute.py +++ b/python/pyarrow/compute.py @@ -306,6 +306,24 @@ def match_substring(array, pattern): MatchSubstringOptions(pattern)) +def match_substring_regex(array, pattern): + """ + Test if regex *pattern* matches at any position a value of a string array. + + Parameters + ---------- + array : pyarrow.Array or pyarrow.ChunkedArray + pattern : str + regex pattern to search + + Returns + ------- + result : pyarrow.Array or pyarrow.ChunkedArray + """ + return call_function("match_substring_regex", [array], + MatchSubstringOptions(pattern)) + + def sum(array): """ Sum the values in a numerical (chunked) array. diff --git a/python/pyarrow/tests/test_compute.py b/python/pyarrow/tests/test_compute.py index 160375f93bd..94a6189f41c 100644 --- a/python/pyarrow/tests/test_compute.py +++ b/python/pyarrow/tests/test_compute.py @@ -279,6 +279,13 @@ def test_match_substring(): assert expected.equals(result) +def test_match_substring_regex(): + arr = pa.array(["ab", "abc", "ba", "c", None]) + result = pc.match_substring_regex(arr, "^a?b") + expected = pa.array([True, True, True, False, None]) + assert expected.equals(result) + + def test_trim(): # \u3000 is unicode whitespace arr = pa.array([" foo", None, " \u3000foo bar \t"])