From 44bf6222075929617d034091ec5dc249e7730ef8 Mon Sep 17 00:00:00 2001 From: David Li Date: Fri, 4 Jun 2021 10:45:57 -0500 Subject: [PATCH 1/2] ARROW-12952: [C++] Add count_substring_regex --- .../arrow/compute/kernels/scalar_string.cc | 113 ++++++++++++-- .../compute/kernels/scalar_string_test.cc | 143 ++++++++++++------ docs/source/cpp/compute.rst | 2 + docs/source/python/api/compute.rst | 1 + python/pyarrow/compute.py | 23 ++- python/pyarrow/tests/test_compute.py | 33 ++-- 6 files changed, 248 insertions(+), 67 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_string.cc b/cpp/src/arrow/compute/kernels/scalar_string.cc index b6c1b8f6261..0ace5d843c2 100644 --- a/cpp/src/arrow/compute/kernels/scalar_string.cc +++ b/cpp/src/arrow/compute/kernels/scalar_string.cc @@ -980,13 +980,70 @@ struct CountSubstring { } }; +#ifdef ARROW_WITH_RE2 +struct CountSubstringRegex { + std::unique_ptr regex_match_; + + explicit CountSubstringRegex(const MatchSubstringOptions& options, bool literal = false) + : regex_match_(new RE2(options.pattern, + RegexSubstringMatcher::MakeRE2Options(options, literal))) {} + + static Result Make(const MatchSubstringOptions& options, + bool literal = false) { + CountSubstringRegex counter(options, literal); + RETURN_NOT_OK(RegexStatus(*counter.regex_match_)); + return std::move(counter); + } + + template + OutValue Call(KernelContext*, util::string_view val, Status*) const { + OutValue count = 0; + re2::StringPiece input(val.data(), val.size()); + auto last_size = input.size(); + while (re2::RE2::FindAndConsume(&input, *regex_match_)) { + count++; + if (last_size == input.size()) { + // 0-length match + if (input.size() > 0) { + input.remove_prefix(1); + } else { + break; + } + } + last_size = input.size(); + } + return count; + } +}; + +template +struct CountSubstringRegexExec { + using OffsetType = typename TypeTraits::OffsetType; + static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + const 
MatchSubstringOptions& options = MatchSubstringState::Get(ctx); + ARROW_ASSIGN_OR_RAISE(auto counter, CountSubstringRegex::Make(options)); + applicator::ScalarUnaryNotNullStateful + kernel{std::move(counter)}; + return kernel.Exec(ctx, batch, out); + } +}; +#endif + template struct CountSubstringExec { using OffsetType = typename TypeTraits::OffsetType; static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { const MatchSubstringOptions& options = MatchSubstringState::Get(ctx); if (options.ignore_case) { - return Status::NotImplemented("count_substring with ignore_case"); +#ifdef ARROW_WITH_RE2 + ARROW_ASSIGN_OR_RAISE(auto counter, + CountSubstringRegex::Make(options, /*literal=*/true)); + applicator::ScalarUnaryNotNullStateful + kernel{std::move(counter)}; + return kernel.Exec(ctx, batch, out); +#else + return Status::NotImplemented("ignore_case requires RE2"); +#endif } applicator::ScalarUnaryNotNullStateful kernel{ CountSubstring(PlainSubstringMatcher(options))}; @@ -1001,21 +1058,51 @@ const FunctionDoc count_substring_doc( "Null inputs emit null. The pattern must be given in MatchSubstringOptions."), {"strings"}, "MatchSubstringOptions"); +#ifdef ARROW_WITH_RE2 +const FunctionDoc count_substring_regex_doc( + "Count occurrences of substring", + ("For each string in `strings`, emit the number of occurrences of the given " + "regex pattern.\n" + "Null inputs emit null. 
The pattern must be given in MatchSubstringOptions."), + {"strings"}, "MatchSubstringOptions"); +#endif + void AddCountSubstring(FunctionRegistry* registry) { - auto func = std::make_shared("count_substring", Arity::Unary(), - &count_substring_doc); - for (const auto& ty : BaseBinaryTypes()) { - std::shared_ptr offset_type; - if (ty->id() == Type::type::LARGE_BINARY || ty->id() == Type::type::LARGE_STRING) { - offset_type = int64(); - } else { - offset_type = int32(); + { + auto func = std::make_shared("count_substring", Arity::Unary(), + &count_substring_doc); + for (const auto& ty : BaseBinaryTypes()) { + std::shared_ptr offset_type; + if (ty->id() == Type::type::LARGE_BINARY || ty->id() == Type::type::LARGE_STRING) { + offset_type = int64(); + } else { + offset_type = int32(); + } + DCHECK_OK(func->AddKernel({ty}, offset_type, + GenerateTypeAgnosticVarBinaryBase(ty), + MatchSubstringState::Init)); } - DCHECK_OK(func->AddKernel({ty}, offset_type, - GenerateTypeAgnosticVarBinaryBase(ty), - MatchSubstringState::Init)); + DCHECK_OK(registry->AddFunction(std::move(func))); } - DCHECK_OK(registry->AddFunction(std::move(func))); +#ifdef ARROW_WITH_RE2 + { + auto func = std::make_shared("count_substring_regex", Arity::Unary(), + &count_substring_regex_doc); + for (const auto& ty : BaseBinaryTypes()) { + std::shared_ptr offset_type; + if (ty->id() == Type::type::LARGE_BINARY || ty->id() == Type::type::LARGE_STRING) { + offset_type = int64(); + } else { + offset_type = int32(); + } + DCHECK_OK( + func->AddKernel({ty}, offset_type, + GenerateTypeAgnosticVarBinaryBase(ty), + MatchSubstringState::Init)); + } + DCHECK_OK(registry->AddFunction(std::move(func))); + } +#endif } // Slicing diff --git a/cpp/src/arrow/compute/kernels/scalar_string_test.cc b/cpp/src/arrow/compute/kernels/scalar_string_test.cc index 7d52d6aacf2..2053dbaa971 100644 --- a/cpp/src/arrow/compute/kernels/scalar_string_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_string_test.cc @@ -80,49 +80,7 @@ 
TYPED_TEST(TestBinaryKernels, BinaryLength) { this->offset_type(), "[3, null, 10, 0, 1]"); } -TYPED_TEST(TestBinaryKernels, FindSubstring) { - MatchSubstringOptions options{"ab"}; - this->CheckUnary("find_substring", "[]", this->offset_type(), "[]", &options); - this->CheckUnary("find_substring", R"(["abc", "acb", "cab", null, "bac"])", - this->offset_type(), "[0, -1, 1, null, -1]", &options); - - MatchSubstringOptions options_repeated{"abab"}; - this->CheckUnary("find_substring", R"(["abab", "ab", "cababc", null, "bac"])", - this->offset_type(), "[0, -1, 1, null, -1]", &options_repeated); - - MatchSubstringOptions options_double_char{"aab"}; - this->CheckUnary("find_substring", R"(["aacb", "aab", "ab", "aaab"])", - this->offset_type(), "[-1, 0, -1, 1]", &options_double_char); - - MatchSubstringOptions options_double_char_2{"bbcaa"}; - this->CheckUnary("find_substring", R"(["abcbaabbbcaabccabaab"])", this->offset_type(), - "[7]", &options_double_char_2); - - MatchSubstringOptions options_empty{""}; - this->CheckUnary("find_substring", R"(["", "a", null])", this->offset_type(), - "[0, 0, null]", &options_empty); -} - -TYPED_TEST(TestBinaryKernels, CountSubstring) { - MatchSubstringOptions options{"aba"}; - this->CheckUnary("count_substring", "[]", this->offset_type(), "[]", &options); - this->CheckUnary( - "count_substring", - R"(["", null, "ab", "aba", "baba", "ababa", "abaaba", "babacaba", "ABA"])", - this->offset_type(), "[0, null, 0, 1, 1, 1, 2, 2, 0]", &options); - - MatchSubstringOptions options_empty{""}; - this->CheckUnary("count_substring", R"(["", null, "abc"])", this->offset_type(), - "[1, null, 4]", &options_empty); - - MatchSubstringOptions options_repeated{"aaa"}; - this->CheckUnary("count_substring", R"(["", "aaaa", "aaaaa", "aaaaaa", "aaá"])", - this->offset_type(), "[0, 1, 1, 2, 0]", &options_repeated); - - // TODO: case-insensitive -} - -TYPED_TEST(TestBinaryKernels, AsciiReplaceSlice) { +TYPED_TEST(TestBinaryKernels, BinaryReplaceSlice) { 
ReplaceSliceOptions options{0, 1, "XX"}; this->CheckUnary("binary_replace_slice", "[]", this->type(), "[]", &options); this->CheckUnary("binary_replace_slice", R"([null, "", "a", "ab", "abc"])", @@ -172,6 +130,105 @@ TYPED_TEST(TestBinaryKernels, AsciiReplaceSlice) { &options_neg_flip); } +TYPED_TEST(TestBinaryKernels, FindSubstring) { + MatchSubstringOptions options{"ab"}; + this->CheckUnary("find_substring", "[]", this->offset_type(), "[]", &options); + this->CheckUnary("find_substring", R"(["abc", "acb", "cab", null, "bac"])", + this->offset_type(), "[0, -1, 1, null, -1]", &options); + + MatchSubstringOptions options_repeated{"abab"}; + this->CheckUnary("find_substring", R"(["abab", "ab", "cababc", null, "bac"])", + this->offset_type(), "[0, -1, 1, null, -1]", &options_repeated); + + MatchSubstringOptions options_double_char{"aab"}; + this->CheckUnary("find_substring", R"(["aacb", "aab", "ab", "aaab"])", + this->offset_type(), "[-1, 0, -1, 1]", &options_double_char); + + MatchSubstringOptions options_double_char_2{"bbcaa"}; + this->CheckUnary("find_substring", R"(["abcbaabbbcaabccabaab"])", this->offset_type(), + "[7]", &options_double_char_2); + + MatchSubstringOptions options_empty{""}; + this->CheckUnary("find_substring", R"(["", "a", null])", this->offset_type(), + "[0, 0, null]", &options_empty); +} + +TYPED_TEST(TestBinaryKernels, CountSubstring) { + MatchSubstringOptions options{"aba"}; + this->CheckUnary("count_substring", "[]", this->offset_type(), "[]", &options); + this->CheckUnary( + "count_substring", + R"(["", null, "ab", "aba", "baba", "ababa", "abaaba", "babacaba", "ABA"])", + this->offset_type(), "[0, null, 0, 1, 1, 1, 2, 2, 0]", &options); + + MatchSubstringOptions options_empty{""}; + this->CheckUnary("count_substring", R"(["", null, "abc"])", this->offset_type(), + "[1, null, 4]", &options_empty); + + MatchSubstringOptions options_repeated{"aaa"}; + this->CheckUnary("count_substring", R"(["", "aaaa", "aaaaa", "aaaaaa", "aaá"])", + 
this->offset_type(), "[0, 1, 1, 2, 0]", &options_repeated); +} + +#ifdef ARROW_WITH_RE2 +TYPED_TEST(TestBinaryKernels, CountSubstringRegex) { + MatchSubstringOptions options{"aba"}; + this->CheckUnary("count_substring_regex", "[]", this->offset_type(), "[]", &options); + this->CheckUnary( + "count_substring_regex", + R"(["", null, "ab", "aba", "baba", "ababa", "abaaba", "babacaba", "ABA"])", + this->offset_type(), "[0, null, 0, 1, 1, 1, 2, 2, 0]", &options); + + MatchSubstringOptions options_empty{""}; + this->CheckUnary("count_substring_regex", R"(["", null, "abc"])", this->offset_type(), + "[1, null, 4]", &options_empty); + + MatchSubstringOptions options_as{"a+"}; + this->CheckUnary("count_substring_regex", R"(["", "bacaaadaaaa", "c", "AAA"])", + this->offset_type(), "[0, 3, 0, 0]", &options_as); + + MatchSubstringOptions options_empty_match{"a*"}; + this->CheckUnary("count_substring_regex", R"(["", "bacaaadaaaa", "c", "AAA"])", + // 7 is because it matches at |b|a|c|aaa|d|aaaa| + this->offset_type(), "[1, 7, 2, 4]", &options_empty_match); + + MatchSubstringOptions options_repeated{"aaa"}; + this->CheckUnary("count_substring_regex", R"(["", "aaaa", "aaaaa", "aaaaaa", "aaá"])", + this->offset_type(), "[0, 1, 1, 2, 0]", &options_repeated); +} + +TYPED_TEST(TestBinaryKernels, CountSubstringIgnoreCase) { + MatchSubstringOptions options{"aba", /*ignore_case=*/true}; + this->CheckUnary("count_substring", "[]", this->offset_type(), "[]", &options); + this->CheckUnary( + "count_substring", + R"(["", null, "ab", "aBa", "bAbA", "aBaBa", "abaAbA", "babacaba", "ABA"])", + this->offset_type(), "[0, null, 0, 1, 1, 1, 2, 2, 1]", &options); + + MatchSubstringOptions options_empty{"", /*ignore_case=*/true}; + this->CheckUnary("count_substring", R"(["", null, "abc"])", this->offset_type(), + "[1, null, 4]", &options_empty); +} + +TYPED_TEST(TestBinaryKernels, CountSubstringRegexIgnoreCase) { + MatchSubstringOptions options_as{"a+", /*ignore_case=*/true}; +
this->CheckUnary("count_substring_regex", R"(["", "bacAaAdaAaA", "c", "AAA"])", + this->offset_type(), "[0, 3, 0, 1]", &options_as); + + MatchSubstringOptions options_empty_match{"a*", /*ignore_case=*/true}; + this->CheckUnary("count_substring_regex", R"(["", "bacAaAdaAaA", "c", "AAA"])", + this->offset_type(), "[1, 7, 2, 2]", &options_empty_match); +} +#else +TYPED_TEST(TestBinaryKernels, CountSubstringIgnoreCase) { + Datum input = ArrayFromJSON(this->type(), R"(["a"])"); + MatchSubstringOptions options{"a", /*ignore_case=*/true}; + EXPECT_RAISES_WITH_MESSAGE_THAT(NotImplemented, + ::testing::HasSubstr("ignore_case requires RE2"), + CallFunction("count_substring", {input}, &options)); +} +#endif + template class TestStringKernels : public BaseTestStringKernels {}; diff --git a/docs/source/cpp/compute.rst b/docs/source/cpp/compute.rst index b28e3928a74..91ee6bdf599 100644 --- a/docs/source/cpp/compute.rst +++ b/docs/source/cpp/compute.rst @@ -572,6 +572,8 @@ Containment tests +===========================+============+====================================+====================+========================================+ | count_substring | Unary | String-like | Int32 or Int64 (1) | :struct:`MatchSubstringOptions` | +---------------------------+------------+------------------------------------+--------------------+----------------------------------------+ +| count_substring_regex | Unary | String-like | Int32 or Int64 (1) | :struct:`MatchSubstringOptions` | ++---------------------------+------------+------------------------------------+--------------------+----------------------------------------+ | ends_with | Unary | String-like | Boolean (2) | :struct:`MatchSubstringOptions` | +---------------------------+------------+------------------------------------+--------------------+----------------------------------------+ | find_substring | Unary | String-like | Int32 or Int64 (3) | :struct:`MatchSubstringOptions` | diff --git a/docs/source/python/api/compute.rst 
b/docs/source/python/api/compute.rst index 2e37f9169a7..dd722e44f05 100644 --- a/docs/source/python/api/compute.rst +++ b/docs/source/python/api/compute.rst @@ -185,6 +185,7 @@ Containment tests :toctree: ../generated/ count_substring + count_substring_regex ends_with find_substring index_in diff --git a/python/pyarrow/compute.py b/python/pyarrow/compute.py index 44282369f87..b8bd9e65f17 100644 --- a/python/pyarrow/compute.py +++ b/python/pyarrow/compute.py @@ -292,7 +292,7 @@ def cast(arr, target_type, safe=True): return call_function("cast", [arr], options) -def count_substring(array, pattern): +def count_substring(array, pattern, *, ignore_case=False): """ Count the occurrences of substring *pattern* in each value of a string array. @@ -308,7 +308,26 @@ def count_substring(array, pattern): result : pyarrow.Array or pyarrow.ChunkedArray """ return call_function("count_substring", [array], - MatchSubstringOptions(pattern)) + MatchSubstringOptions(pattern, ignore_case)) + + +def count_substring_regex(array, pattern, *, ignore_case=False): + """ + Count the non-overlapping matches of regex *pattern* in each value + of a string array. 
+ + Parameters + ---------- + array : pyarrow.Array or pyarrow.ChunkedArray + pattern : str + regex pattern to search for + + Returns + ------- + result : pyarrow.Array or pyarrow.ChunkedArray + """ + return call_function("count_substring_regex", [array], + MatchSubstringOptions(pattern, ignore_case)) def find_substring(array, pattern): diff --git a/python/pyarrow/tests/test_compute.py b/python/pyarrow/tests/test_compute.py index 8de24c8c249..1ed582db831 100644 --- a/python/pyarrow/tests/test_compute.py +++ b/python/pyarrow/tests/test_compute.py @@ -291,16 +291,31 @@ def test_variance(): def test_count_substring(): - arr = pa.array(["ab", "cab", "abcab", "ba", "AB", None]) - result = pc.count_substring(arr, "ab") - expected = pa.array([1, 1, 2, 0, 0, None], type=pa.int32()) - assert expected.equals(result) + for (ty, offset) in [(pa.string(), pa.int32()), + (pa.large_string(), pa.int64())]: + arr = pa.array(["ab", "cab", "abcab", "ba", "AB", None], type=ty) - arr = pa.array(["ab", "cab", "abcab", "ba", "AB", None], - type=pa.large_string()) - result = pc.count_substring(arr, "ab") - expected = pa.array([1, 1, 2, 0, 0, None], type=pa.int64()) - assert expected.equals(result) + result = pc.count_substring(arr, "ab") + expected = pa.array([1, 1, 2, 0, 0, None], type=offset) + assert expected.equals(result) + + result = pc.count_substring(arr, "ab", ignore_case=True) + expected = pa.array([1, 1, 2, 0, 1, None], type=offset) + assert expected.equals(result) + + +def test_count_substring_regex(): + for (ty, offset) in [(pa.string(), pa.int32()), + (pa.large_string(), pa.int64())]: + arr = pa.array(["ab", "cab", "baAacaa", "ba", "AB", None], type=ty) + + result = pc.count_substring_regex(arr, "a+") + expected = pa.array([1, 1, 3, 1, 0, None], type=offset) + assert expected.equals(result) + + result = pc.count_substring_regex(arr, "a+", ignore_case=True) + expected = pa.array([1, 1, 2, 1, 1, None], type=offset) + assert expected.equals(result) def
test_find_substring(): From f4fa220a71d465f55ab3934e5f27b3273dd32ac4 Mon Sep 17 00:00:00 2001 From: David Li Date: Thu, 10 Jun 2021 12:23:10 -0400 Subject: [PATCH 2/2] ARROW-12952: [C++] Simplify count_substring kernel registration --- cpp/src/arrow/compute/kernels/scalar_string.cc | 14 ++------------ 1 file changed, 2 insertions(+), 12 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_string.cc b/cpp/src/arrow/compute/kernels/scalar_string.cc index 0ace5d843c2..cd054fcea0e 100644 --- a/cpp/src/arrow/compute/kernels/scalar_string.cc +++ b/cpp/src/arrow/compute/kernels/scalar_string.cc @@ -1072,12 +1072,7 @@ void AddCountSubstring(FunctionRegistry* registry) { auto func = std::make_shared("count_substring", Arity::Unary(), &count_substring_doc); for (const auto& ty : BaseBinaryTypes()) { - std::shared_ptr offset_type; - if (ty->id() == Type::type::LARGE_BINARY || ty->id() == Type::type::LARGE_STRING) { - offset_type = int64(); - } else { - offset_type = int32(); - } + auto offset_type = offset_bit_width(ty->id()) == 64 ? int64() : int32(); DCHECK_OK(func->AddKernel({ty}, offset_type, GenerateTypeAgnosticVarBinaryBase(ty), MatchSubstringState::Init)); @@ -1089,12 +1084,7 @@ void AddCountSubstring(FunctionRegistry* registry) { auto func = std::make_shared("count_substring_regex", Arity::Unary(), &count_substring_regex_doc); for (const auto& ty : BaseBinaryTypes()) { - std::shared_ptr offset_type; - if (ty->id() == Type::type::LARGE_BINARY || ty->id() == Type::type::LARGE_STRING) { - offset_type = int64(); - } else { - offset_type = int32(); - } + auto offset_type = offset_bit_width(ty->id()) == 64 ? int64() : int32(); DCHECK_OK( func->AddKernel({ty}, offset_type, GenerateTypeAgnosticVarBinaryBase(ty),