From 44bf6222075929617d034091ec5dc249e7730ef8 Mon Sep 17 00:00:00 2001
From: David Li
Date: Fri, 4 Jun 2021 10:45:57 -0500
Subject: [PATCH 1/2] ARROW-12952: [C++] Add count_substring_regex
---
.../arrow/compute/kernels/scalar_string.cc | 113 ++++++++++++--
.../compute/kernels/scalar_string_test.cc | 143 ++++++++++++------
docs/source/cpp/compute.rst | 2 +
docs/source/python/api/compute.rst | 1 +
python/pyarrow/compute.py | 23 ++-
python/pyarrow/tests/test_compute.py | 33 ++--
6 files changed, 248 insertions(+), 67 deletions(-)
diff --git a/cpp/src/arrow/compute/kernels/scalar_string.cc b/cpp/src/arrow/compute/kernels/scalar_string.cc
index b6c1b8f6261..0ace5d843c2 100644
--- a/cpp/src/arrow/compute/kernels/scalar_string.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_string.cc
@@ -980,13 +980,70 @@ struct CountSubstring {
}
};
+#ifdef ARROW_WITH_RE2
+struct CountSubstringRegex {
+ std::unique_ptr regex_match_;
+
+ explicit CountSubstringRegex(const MatchSubstringOptions& options, bool literal = false)
+ : regex_match_(new RE2(options.pattern,
+ RegexSubstringMatcher::MakeRE2Options(options, literal))) {}
+
+ static Result Make(const MatchSubstringOptions& options,
+ bool literal = false) {
+ CountSubstringRegex counter(options, literal);
+ RETURN_NOT_OK(RegexStatus(*counter.regex_match_));
+ return std::move(counter);
+ }
+
+ template
+ OutValue Call(KernelContext*, util::string_view val, Status*) const {
+ OutValue count = 0;
+ re2::StringPiece input(val.data(), val.size());
+ auto last_size = input.size();
+ while (re2::RE2::FindAndConsume(&input, *regex_match_)) {
+ count++;
+ if (last_size == input.size()) {
+ // 0-length match
+ if (input.size() > 0) {
+ input.remove_prefix(1);
+ } else {
+ break;
+ }
+ }
+ last_size = input.size();
+ }
+ return count;
+ }
+};
+
+template
+struct CountSubstringRegexExec {
+ using OffsetType = typename TypeTraits::OffsetType;
+ static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ const MatchSubstringOptions& options = MatchSubstringState::Get(ctx);
+ ARROW_ASSIGN_OR_RAISE(auto counter, CountSubstringRegex::Make(options));
+ applicator::ScalarUnaryNotNullStateful
+ kernel{std::move(counter)};
+ return kernel.Exec(ctx, batch, out);
+ }
+};
+#endif
+
template
struct CountSubstringExec {
using OffsetType = typename TypeTraits::OffsetType;
static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
const MatchSubstringOptions& options = MatchSubstringState::Get(ctx);
if (options.ignore_case) {
- return Status::NotImplemented("count_substring with ignore_case");
+#ifdef ARROW_WITH_RE2
+ ARROW_ASSIGN_OR_RAISE(auto counter,
+ CountSubstringRegex::Make(options, /*literal=*/true));
+ applicator::ScalarUnaryNotNullStateful
+ kernel{std::move(counter)};
+ return kernel.Exec(ctx, batch, out);
+#else
+ return Status::NotImplemented("ignore_case requires RE2");
+#endif
}
applicator::ScalarUnaryNotNullStateful kernel{
CountSubstring(PlainSubstringMatcher(options))};
@@ -1001,21 +1058,51 @@ const FunctionDoc count_substring_doc(
"Null inputs emit null. The pattern must be given in MatchSubstringOptions."),
{"strings"}, "MatchSubstringOptions");
+#ifdef ARROW_WITH_RE2
+const FunctionDoc count_substring_regex_doc(
+ "Count occurrences of substring",
+ ("For each string in `strings`, emit the number of occurrences of the given "
+ "regex pattern.\n"
+ "Null inputs emit null. The pattern must be given in MatchSubstringOptions."),
+ {"strings"}, "MatchSubstringOptions");
+#endif
+
void AddCountSubstring(FunctionRegistry* registry) {
- auto func = std::make_shared("count_substring", Arity::Unary(),
- &count_substring_doc);
- for (const auto& ty : BaseBinaryTypes()) {
- std::shared_ptr offset_type;
- if (ty->id() == Type::type::LARGE_BINARY || ty->id() == Type::type::LARGE_STRING) {
- offset_type = int64();
- } else {
- offset_type = int32();
+ {
+ auto func = std::make_shared("count_substring", Arity::Unary(),
+ &count_substring_doc);
+ for (const auto& ty : BaseBinaryTypes()) {
+ std::shared_ptr offset_type;
+ if (ty->id() == Type::type::LARGE_BINARY || ty->id() == Type::type::LARGE_STRING) {
+ offset_type = int64();
+ } else {
+ offset_type = int32();
+ }
+ DCHECK_OK(func->AddKernel({ty}, offset_type,
+ GenerateTypeAgnosticVarBinaryBase(ty),
+ MatchSubstringState::Init));
}
- DCHECK_OK(func->AddKernel({ty}, offset_type,
- GenerateTypeAgnosticVarBinaryBase(ty),
- MatchSubstringState::Init));
+ DCHECK_OK(registry->AddFunction(std::move(func)));
}
- DCHECK_OK(registry->AddFunction(std::move(func)));
+#ifdef ARROW_WITH_RE2
+ {
+ auto func = std::make_shared("count_substring_regex", Arity::Unary(),
+ &count_substring_regex_doc);
+ for (const auto& ty : BaseBinaryTypes()) {
+ std::shared_ptr offset_type;
+ if (ty->id() == Type::type::LARGE_BINARY || ty->id() == Type::type::LARGE_STRING) {
+ offset_type = int64();
+ } else {
+ offset_type = int32();
+ }
+ DCHECK_OK(
+ func->AddKernel({ty}, offset_type,
+ GenerateTypeAgnosticVarBinaryBase(ty),
+ MatchSubstringState::Init));
+ }
+ DCHECK_OK(registry->AddFunction(std::move(func)));
+ }
+#endif
}
// Slicing
diff --git a/cpp/src/arrow/compute/kernels/scalar_string_test.cc b/cpp/src/arrow/compute/kernels/scalar_string_test.cc
index 7d52d6aacf2..2053dbaa971 100644
--- a/cpp/src/arrow/compute/kernels/scalar_string_test.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_string_test.cc
@@ -80,49 +80,7 @@ TYPED_TEST(TestBinaryKernels, BinaryLength) {
this->offset_type(), "[3, null, 10, 0, 1]");
}
-TYPED_TEST(TestBinaryKernels, FindSubstring) {
- MatchSubstringOptions options{"ab"};
- this->CheckUnary("find_substring", "[]", this->offset_type(), "[]", &options);
- this->CheckUnary("find_substring", R"(["abc", "acb", "cab", null, "bac"])",
- this->offset_type(), "[0, -1, 1, null, -1]", &options);
-
- MatchSubstringOptions options_repeated{"abab"};
- this->CheckUnary("find_substring", R"(["abab", "ab", "cababc", null, "bac"])",
- this->offset_type(), "[0, -1, 1, null, -1]", &options_repeated);
-
- MatchSubstringOptions options_double_char{"aab"};
- this->CheckUnary("find_substring", R"(["aacb", "aab", "ab", "aaab"])",
- this->offset_type(), "[-1, 0, -1, 1]", &options_double_char);
-
- MatchSubstringOptions options_double_char_2{"bbcaa"};
- this->CheckUnary("find_substring", R"(["abcbaabbbcaabccabaab"])", this->offset_type(),
- "[7]", &options_double_char_2);
-
- MatchSubstringOptions options_empty{""};
- this->CheckUnary("find_substring", R"(["", "a", null])", this->offset_type(),
- "[0, 0, null]", &options_empty);
-}
-
-TYPED_TEST(TestBinaryKernels, CountSubstring) {
- MatchSubstringOptions options{"aba"};
- this->CheckUnary("count_substring", "[]", this->offset_type(), "[]", &options);
- this->CheckUnary(
- "count_substring",
- R"(["", null, "ab", "aba", "baba", "ababa", "abaaba", "babacaba", "ABA"])",
- this->offset_type(), "[0, null, 0, 1, 1, 1, 2, 2, 0]", &options);
-
- MatchSubstringOptions options_empty{""};
- this->CheckUnary("count_substring", R"(["", null, "abc"])", this->offset_type(),
- "[1, null, 4]", &options_empty);
-
- MatchSubstringOptions options_repeated{"aaa"};
- this->CheckUnary("count_substring", R"(["", "aaaa", "aaaaa", "aaaaaa", "aaá"])",
- this->offset_type(), "[0, 1, 1, 2, 0]", &options_repeated);
-
- // TODO: case-insensitive
-}
-
-TYPED_TEST(TestBinaryKernels, AsciiReplaceSlice) {
+TYPED_TEST(TestBinaryKernels, BinaryReplaceSlice) {
ReplaceSliceOptions options{0, 1, "XX"};
this->CheckUnary("binary_replace_slice", "[]", this->type(), "[]", &options);
this->CheckUnary("binary_replace_slice", R"([null, "", "a", "ab", "abc"])",
@@ -172,6 +130,105 @@ TYPED_TEST(TestBinaryKernels, AsciiReplaceSlice) {
&options_neg_flip);
}
+TYPED_TEST(TestBinaryKernels, FindSubstring) {
+ MatchSubstringOptions options{"ab"};
+ this->CheckUnary("find_substring", "[]", this->offset_type(), "[]", &options);
+ this->CheckUnary("find_substring", R"(["abc", "acb", "cab", null, "bac"])",
+ this->offset_type(), "[0, -1, 1, null, -1]", &options);
+
+ MatchSubstringOptions options_repeated{"abab"};
+ this->CheckUnary("find_substring", R"(["abab", "ab", "cababc", null, "bac"])",
+ this->offset_type(), "[0, -1, 1, null, -1]", &options_repeated);
+
+ MatchSubstringOptions options_double_char{"aab"};
+ this->CheckUnary("find_substring", R"(["aacb", "aab", "ab", "aaab"])",
+ this->offset_type(), "[-1, 0, -1, 1]", &options_double_char);
+
+ MatchSubstringOptions options_double_char_2{"bbcaa"};
+ this->CheckUnary("find_substring", R"(["abcbaabbbcaabccabaab"])", this->offset_type(),
+ "[7]", &options_double_char_2);
+
+ MatchSubstringOptions options_empty{""};
+ this->CheckUnary("find_substring", R"(["", "a", null])", this->offset_type(),
+ "[0, 0, null]", &options_empty);
+}
+
+TYPED_TEST(TestBinaryKernels, CountSubstring) {
+ MatchSubstringOptions options{"aba"};
+ this->CheckUnary("count_substring", "[]", this->offset_type(), "[]", &options);
+ this->CheckUnary(
+ "count_substring",
+ R"(["", null, "ab", "aba", "baba", "ababa", "abaaba", "babacaba", "ABA"])",
+ this->offset_type(), "[0, null, 0, 1, 1, 1, 2, 2, 0]", &options);
+
+ MatchSubstringOptions options_empty{""};
+ this->CheckUnary("count_substring", R"(["", null, "abc"])", this->offset_type(),
+ "[1, null, 4]", &options_empty);
+
+ MatchSubstringOptions options_repeated{"aaa"};
+ this->CheckUnary("count_substring", R"(["", "aaaa", "aaaaa", "aaaaaa", "aaá"])",
+ this->offset_type(), "[0, 1, 1, 2, 0]", &options_repeated);
+}
+
+#ifdef ARROW_WITH_RE2
+TYPED_TEST(TestBinaryKernels, CountSubstringRegex) {
+ MatchSubstringOptions options{"aba"};
+ this->CheckUnary("count_substring_regex", "[]", this->offset_type(), "[]", &options);
+ this->CheckUnary(
+ "count_substring",
+ R"(["", null, "ab", "aba", "baba", "ababa", "abaaba", "babacaba", "ABA"])",
+ this->offset_type(), "[0, null, 0, 1, 1, 1, 2, 2, 0]", &options);
+
+ MatchSubstringOptions options_empty{""};
+ this->CheckUnary("count_substring_regex", R"(["", null, "abc"])", this->offset_type(),
+ "[1, null, 4]", &options_empty);
+
+ MatchSubstringOptions options_as{"a+"};
+ this->CheckUnary("count_substring_regex", R"(["", "bacaaadaaaa", "c", "AAA"])",
+ this->offset_type(), "[0, 3, 0, 0]", &options_as);
+
+ MatchSubstringOptions options_empty_match{"a*"};
+ this->CheckUnary("count_substring_regex", R"(["", "bacaaadaaaa", "c", "AAA"])",
+ // 7 is because it matches at |b|a|c|aaa|d|aaaa|
+ this->offset_type(), "[1, 7, 2, 4]", &options_empty_match);
+
+ MatchSubstringOptions options_repeated{"aaa"};
+ this->CheckUnary("count_substring", R"(["", "aaaa", "aaaaa", "aaaaaa", "aaá"])",
+ this->offset_type(), "[0, 1, 1, 2, 0]", &options_repeated);
+}
+
+TYPED_TEST(TestBinaryKernels, CountSubstringIgnoreCase) {
+ MatchSubstringOptions options{"aba", /*ignore_case=*/true};
+ this->CheckUnary("count_substring", "[]", this->offset_type(), "[]", &options);
+ this->CheckUnary(
+ "count_substring",
+ R"(["", null, "ab", "aBa", "bAbA", "aBaBa", "abaAbA", "babacaba", "ABA"])",
+ this->offset_type(), "[0, null, 0, 1, 1, 1, 2, 2, 1]", &options);
+
+ MatchSubstringOptions options_empty{"", /*ignore_case=*/true};
+ this->CheckUnary("count_substring", R"(["", null, "abc"])", this->offset_type(),
+ "[1, null, 4]", &options_empty);
+}
+
+TYPED_TEST(TestBinaryKernels, CountSubstringRegexIgnoreCase) {
+ MatchSubstringOptions options_as{"a+", /*ignore_case=*/true};
+ this->CheckUnary("count_substring_regex", R"(["", "bacAaAdaAaA", "c", "AAA"])",
+ this->offset_type(), "[0, 3, 0, 1]", &options_as);
+
+ MatchSubstringOptions options_empty_match{"a*", /*ignore_case=*/true};
+ this->CheckUnary("count_substring_regex", R"(["", "bacAaAdaAaA", "c", "AAA"])",
+ this->offset_type(), "[1, 7, 2, 2]", &options_empty_match);
+}
+#else
+TYPED_TEST(TestBinaryKernels, CountSubstringIgnoreCase) {
+ Datum input = ArrayFromJSON(this->type(), R"(["a"])");
+ MatchSubstringOptions options{"a", /*ignore_case=*/true};
+ EXPECT_RAISES_WITH_MESSAGE_THAT(NotImplemented,
+ ::testing::HasSubstr("ignore_case requires RE2"),
+ CallFunction("count_substring", {input}, &options));
+}
+#endif
+
template
class TestStringKernels : public BaseTestStringKernels {};
diff --git a/docs/source/cpp/compute.rst b/docs/source/cpp/compute.rst
index b28e3928a74..91ee6bdf599 100644
--- a/docs/source/cpp/compute.rst
+++ b/docs/source/cpp/compute.rst
@@ -572,6 +572,8 @@ Containment tests
+===========================+============+====================================+====================+========================================+
| count_substring | Unary | String-like | Int32 or Int64 (1) | :struct:`MatchSubstringOptions` |
+---------------------------+------------+------------------------------------+--------------------+----------------------------------------+
+| count_substring_regex | Unary | String-like | Int32 or Int64 (1) | :struct:`MatchSubstringOptions` |
++---------------------------+------------+------------------------------------+--------------------+----------------------------------------+
| ends_with | Unary | String-like | Boolean (2) | :struct:`MatchSubstringOptions` |
+---------------------------+------------+------------------------------------+--------------------+----------------------------------------+
| find_substring | Unary | String-like | Int32 or Int64 (3) | :struct:`MatchSubstringOptions` |
diff --git a/docs/source/python/api/compute.rst b/docs/source/python/api/compute.rst
index 2e37f9169a7..dd722e44f05 100644
--- a/docs/source/python/api/compute.rst
+++ b/docs/source/python/api/compute.rst
@@ -185,6 +185,7 @@ Containment tests
:toctree: ../generated/
count_substring
+ count_substring_regex
ends_with
find_substring
index_in
diff --git a/python/pyarrow/compute.py b/python/pyarrow/compute.py
index 44282369f87..b8bd9e65f17 100644
--- a/python/pyarrow/compute.py
+++ b/python/pyarrow/compute.py
@@ -292,7 +292,7 @@ def cast(arr, target_type, safe=True):
return call_function("cast", [arr], options)
-def count_substring(array, pattern):
+def count_substring(array, pattern, *, ignore_case=False):
"""
Count the occurrences of substring *pattern* in each value of a
string array.
@@ -308,7 +308,26 @@ def count_substring(array, pattern):
result : pyarrow.Array or pyarrow.ChunkedArray
"""
return call_function("count_substring", [array],
- MatchSubstringOptions(pattern))
+ MatchSubstringOptions(pattern, ignore_case))
+
+
+def count_substring_regex(array, pattern, *, ignore_case=False):
+ """
+ Count the non-overlapping matches of regex *pattern* in each value
+ of a string array.
+
+ Parameters
+ ----------
+ array : pyarrow.Array or pyarrow.ChunkedArray
+ pattern : str
+ pattern to search for exact matches
+
+ Returns
+ -------
+ result : pyarrow.Array or pyarrow.ChunkedArray
+ """
+ return call_function("count_substring_regex", [array],
+ MatchSubstringOptions(pattern, ignore_case))
def find_substring(array, pattern):
diff --git a/python/pyarrow/tests/test_compute.py b/python/pyarrow/tests/test_compute.py
index 8de24c8c249..1ed582db831 100644
--- a/python/pyarrow/tests/test_compute.py
+++ b/python/pyarrow/tests/test_compute.py
@@ -291,16 +291,31 @@ def test_variance():
def test_count_substring():
- arr = pa.array(["ab", "cab", "abcab", "ba", "AB", None])
- result = pc.count_substring(arr, "ab")
- expected = pa.array([1, 1, 2, 0, 0, None], type=pa.int32())
- assert expected.equals(result)
+ for (ty, offset) in [(pa.string(), pa.int32()),
+ (pa.large_string(), pa.int64())]:
+ arr = pa.array(["ab", "cab", "abcab", "ba", "AB", None], type=ty)
- arr = pa.array(["ab", "cab", "abcab", "ba", "AB", None],
- type=pa.large_string())
- result = pc.count_substring(arr, "ab")
- expected = pa.array([1, 1, 2, 0, 0, None], type=pa.int64())
- assert expected.equals(result)
+ result = pc.count_substring(arr, "ab")
+ expected = pa.array([1, 1, 2, 0, 0, None], type=offset)
+ assert expected.equals(result)
+
+ result = pc.count_substring(arr, "ab", ignore_case=True)
+ expected = pa.array([1, 1, 2, 0, 1, None], type=offset)
+ assert expected.equals(result)
+
+
+def test_count_substring_regex():
+ for (ty, offset) in [(pa.string(), pa.int32()),
+ (pa.large_string(), pa.int64())]:
+ arr = pa.array(["ab", "cab", "baAacaa", "ba", "AB", None], type=ty)
+
+ result = pc.count_substring_regex(arr, "a+")
+ expected = pa.array([1, 1, 3, 1, 0, None], type=offset)
+ assert expected.equals(result)
+
+ result = pc.count_substring_regex(arr, "a+", ignore_case=True)
+ expected = pa.array([1, 1, 2, 1, 1, None], type=offset)
+ assert expected.equals(result)
def test_find_substring():
From f4fa220a71d465f55ab3934e5f27b3273dd32ac4 Mon Sep 17 00:00:00 2001
From: David Li
Date: Thu, 10 Jun 2021 12:23:10 -0400
Subject: [PATCH 2/2] ARROW-12952: [C++] Simplify count_substring kernel
registration
---
cpp/src/arrow/compute/kernels/scalar_string.cc | 14 ++------------
1 file changed, 2 insertions(+), 12 deletions(-)
diff --git a/cpp/src/arrow/compute/kernels/scalar_string.cc b/cpp/src/arrow/compute/kernels/scalar_string.cc
index 0ace5d843c2..cd054fcea0e 100644
--- a/cpp/src/arrow/compute/kernels/scalar_string.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_string.cc
@@ -1072,12 +1072,7 @@ void AddCountSubstring(FunctionRegistry* registry) {
auto func = std::make_shared("count_substring", Arity::Unary(),
&count_substring_doc);
for (const auto& ty : BaseBinaryTypes()) {
- std::shared_ptr offset_type;
- if (ty->id() == Type::type::LARGE_BINARY || ty->id() == Type::type::LARGE_STRING) {
- offset_type = int64();
- } else {
- offset_type = int32();
- }
+ auto offset_type = offset_bit_width(ty->id()) == 64 ? int64() : int32();
DCHECK_OK(func->AddKernel({ty}, offset_type,
GenerateTypeAgnosticVarBinaryBase(ty),
MatchSubstringState::Init));
@@ -1089,12 +1084,7 @@ void AddCountSubstring(FunctionRegistry* registry) {
auto func = std::make_shared("count_substring_regex", Arity::Unary(),
&count_substring_regex_doc);
for (const auto& ty : BaseBinaryTypes()) {
- std::shared_ptr offset_type;
- if (ty->id() == Type::type::LARGE_BINARY || ty->id() == Type::type::LARGE_STRING) {
- offset_type = int64();
- } else {
- offset_type = int32();
- }
+ auto offset_type = offset_bit_width(ty->id()) == 64 ? int64() : int32();
DCHECK_OK(
func->AddKernel({ty}, offset_type,
GenerateTypeAgnosticVarBinaryBase(ty),