From 8006de63400cc4292a1c91e708fbc068adb56fe5 Mon Sep 17 00:00:00 2001
From: David Li
Date: Fri, 4 Jun 2021 09:54:22 -0500
Subject: [PATCH 1/2] ARROW-12950: [C++] Implement count_substring
---
.../arrow/compute/kernels/scalar_string.cc | 70 ++++++++++++++++++-
.../compute/kernels/scalar_string_test.cc | 15 ++++
docs/source/cpp/compute.rst | 30 ++++----
docs/source/python/api/compute.rst | 1 +
python/pyarrow/compute.py | 19 +++++
python/pyarrow/tests/test_compute.py | 13 ++++
6 files changed, 135 insertions(+), 13 deletions(-)
diff --git a/cpp/src/arrow/compute/kernels/scalar_string.cc b/cpp/src/arrow/compute/kernels/scalar_string.cc
index 9db16e26ca5..08159ccddaa 100644
--- a/cpp/src/arrow/compute/kernels/scalar_string.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_string.cc
@@ -741,8 +741,12 @@ template
struct FindSubstringExec {
using OffsetType = typename TypeTraits::OffsetType;
static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ const MatchSubstringOptions& options = MatchSubstringState::Get(ctx);
+ if (options.ignore_case) {
+ return Status::NotImplemented("find_substring with ignore_case");
+ }
applicator::ScalarUnaryNotNullStateful kernel{
- FindSubstring(PlainSubstringMatcher(MatchSubstringState::Get(ctx)))};
+ FindSubstring(PlainSubstringMatcher(options))};
return kernel.Exec(ctx, batch, out);
}
};
@@ -771,6 +775,69 @@ void AddFindSubstring(FunctionRegistry* registry) {
DCHECK_OK(registry->AddFunction(std::move(func)));
}
+// Substring count
+
+/// Functor that counts non-overlapping occurrences of a fixed pattern.
+///
+/// Searching is delegated to PlainSubstringMatcher::Find. After every hit the
+/// scan resumes just past the matched bytes, so overlapping occurrences are
+/// not counted (e.g. "aba" is found once in "ababa").
+struct CountSubstring {
+  const PlainSubstringMatcher matcher_;
+
+  explicit CountSubstring(PlainSubstringMatcher matcher) : matcher_(std::move(matcher)) {}
+
+  /// \brief Count the matches of the configured pattern in `val`.
+  ///
+  /// \tparam OutValue integer type used for the returned count
+  /// \param val the input string to scan
+  /// \return the number of non-overlapping matches; the unit tests show an
+  ///   empty pattern yielding val.size() + 1
+  ///
+  /// NOTE(review): the template parameter list was missing from this hunk and
+  /// has been reconstructed (trailing arguments are accepted and ignored) --
+  /// confirm it matches what the applicator passes to Call.
+  template <typename OutValue, typename... Ignored>
+  OutValue Call(KernelContext*, util::string_view val, Status*) const {
+    OutValue count = 0;
+    uint64_t start = 0;
+    // Step forward by at least one byte per match so that an empty pattern
+    // still makes progress and the loop terminates.
+    const auto pattern_size = std::max<uint64_t>(1, matcher_.options_.pattern.size());
+    // `start == val.size()` is deliberately included: it lets an empty pattern
+    // match once more at the very end of the input.
+    while (start <= val.size()) {
+      const int64_t index = matcher_.Find(val.substr(start));
+      if (index < 0) {
+        break;
+      }
+      ++count;
+      start += index + pattern_size;
+    }
+    return count;
+  }
+};
+
+/// Kernel entry point for "count_substring" on a single binary/string type.
+///
+/// The output value type is the offset type of the input type. Requests with
+/// MatchSubstringOptions::ignore_case set are rejected with
+/// Status::NotImplemented.
+///
+/// NOTE(review): the template parameter/argument lists were missing from this
+/// hunk and have been reconstructed -- confirm against FindSubstringExec.
+template <typename InputType>
+struct CountSubstringExec {
+  using OffsetType = typename TypeTraits<InputType>::OffsetType;
+  static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+    const MatchSubstringOptions& options = MatchSubstringState::Get(ctx);
+    if (options.ignore_case) {
+      return Status::NotImplemented("count_substring with ignore_case");
+    }
+    applicator::ScalarUnaryNotNullStateful<OffsetType, InputType, CountSubstring> kernel{
+        CountSubstring(PlainSubstringMatcher(options))};
+    return kernel.Exec(ctx, batch, out);
+  }
+};
+
+// Documentation attached to the "count_substring" function when it is
+// registered: summary, description, argument names and options class name.
+const FunctionDoc count_substring_doc(
+ "Count occurrences of substring",
+ ("For each string in `strings`, emit the number of occurrences of the given "
+ "pattern.\n"
+ "Null inputs emit null. The pattern must be given in MatchSubstringOptions."),
+ {"strings"}, "MatchSubstringOptions");
+
+/// Register the unary "count_substring" scalar function.
+///
+/// One kernel is added for every type returned by BaseBinaryTypes(). The
+/// output type is int64 for LARGE_BINARY/LARGE_STRING inputs and int32 for
+/// all other inputs.
+///
+/// NOTE(review): the template arguments were missing from this hunk and have
+/// been reconstructed -- confirm against AddFindSubstring.
+void AddCountSubstring(FunctionRegistry* registry) {
+  auto func = std::make_shared<ScalarFunction>("count_substring", Arity::Unary(),
+                                               &count_substring_doc);
+  for (const auto& ty : BaseBinaryTypes()) {
+    std::shared_ptr<DataType> offset_type;
+    if (ty->id() == Type::type::LARGE_BINARY || ty->id() == Type::type::LARGE_STRING) {
+      offset_type = int64();
+    } else {
+      offset_type = int32();
+    }
+    DCHECK_OK(func->AddKernel({ty}, offset_type,
+                              GenerateTypeAgnosticVarBinaryBase<CountSubstringExec>(ty),
+                              MatchSubstringState::Init));
+  }
+  DCHECK_OK(registry->AddFunction(std::move(func)));
+}
+
// Slicing
template
@@ -2936,6 +3003,7 @@ void RegisterScalarStringAscii(FunctionRegistry* registry) {
AddUtf8Length(registry);
AddMatchSubstring(registry);
AddFindSubstring(registry);
+ AddCountSubstring(registry);
MakeUnaryStringBatchKernelWithState(
"replace_substring", registry, &replace_substring_doc,
MemAllocation::NO_PREALLOCATE);
diff --git a/cpp/src/arrow/compute/kernels/scalar_string_test.cc b/cpp/src/arrow/compute/kernels/scalar_string_test.cc
index bd5c8eec03f..06f9ad252e2 100644
--- a/cpp/src/arrow/compute/kernels/scalar_string_test.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_string_test.cc
@@ -97,6 +97,21 @@ TYPED_TEST(TestBinaryKernels, FindSubstring) {
"[0, 0, null]", &options_empty);
}
+TYPED_TEST(TestBinaryKernels, CountSubstring) {
+ // Matches are counted without overlap: "ababa" contains "aba" once and
+ // "abaaba" twice. "ABA" yields 0 because matching is case-sensitive.
+ MatchSubstringOptions options{"aba"};
+ this->CheckUnary("count_substring", "[]", this->offset_type(), "[]", &options);
+ this->CheckUnary(
+ "count_substring",
+ R"(["", null, "ab", "aba", "baba", "ababa", "abaaba", "babacaba", "ABA"])",
+ this->offset_type(), "[0, null, 0, 1, 1, 1, 2, 2, 0]", &options);
+
+ // An empty pattern is counted once per position plus once at the end of the
+ // input, i.e. length + 1.
+ MatchSubstringOptions options_empty{""};
+ this->CheckUnary("count_substring", R"(["", null, "abc"])", this->offset_type(),
+ "[1, null, 4]", &options_empty);
+
+ // TODO: case-insensitive
+}
+
template
class TestStringKernels : public BaseTestStringKernels {};
diff --git a/docs/source/cpp/compute.rst b/docs/source/cpp/compute.rst
index 3f30bbcaa06..c7ccc8a822d 100644
--- a/docs/source/cpp/compute.rst
+++ b/docs/source/cpp/compute.rst
@@ -561,45 +561,51 @@ Containment tests
+---------------------------+------------+------------------------------------+--------------------+----------------------------------------+
| Function name | Arity | Input types | Output type | Options class |
+===========================+============+====================================+====================+========================================+
-| find_substring | Unary | String-like | Int32 or Int64 (1) | :struct:`MatchSubstringOptions` |
+| count_substring | Unary | String-like | Int32 or Int64 (1) | :struct:`MatchSubstringOptions` |
+---------------------------+------------+------------------------------------+--------------------+----------------------------------------+
-| match_like | Unary | String-like | Boolean (2) | :struct:`MatchSubstringOptions` |
+| find_substring | Unary | String-like | Int32 or Int64 (2) | :struct:`MatchSubstringOptions` |
+---------------------------+------------+------------------------------------+--------------------+----------------------------------------+
-| match_substring | Unary | String-like | Boolean (3) | :struct:`MatchSubstringOptions` |
+| match_like | Unary | String-like | Boolean (3) | :struct:`MatchSubstringOptions` |
+---------------------------+------------+------------------------------------+--------------------+----------------------------------------+
-| match_substring_regex | Unary | String-like | Boolean (4) | :struct:`MatchSubstringOptions` |
+| match_substring | Unary | String-like | Boolean (4) | :struct:`MatchSubstringOptions` |
+---------------------------+------------+------------------------------------+--------------------+----------------------------------------+
-| index_in | Unary | Boolean, Null, Numeric, Temporal, | Int32 (5) | :struct:`SetLookupOptions` |
+| match_substring_regex | Unary | String-like | Boolean (5) | :struct:`MatchSubstringOptions` |
++---------------------------+------------+------------------------------------+--------------------+----------------------------------------+
+| index_in | Unary | Boolean, Null, Numeric, Temporal, | Int32 (6) | :struct:`SetLookupOptions` |
| | | Binary- and String-like | | |
+---------------------------+------------+------------------------------------+--------------------+----------------------------------------+
-| is_in | Unary | Boolean, Null, Numeric, Temporal, | Boolean (6) | :struct:`SetLookupOptions` |
+| is_in | Unary | Boolean, Null, Numeric, Temporal, | Boolean (7) | :struct:`SetLookupOptions` |
| | | Binary- and String-like | | |
+---------------------------+------------+------------------------------------+--------------------+----------------------------------------+
+* \(1) Output is the number of non-overlapping occurrences of
+ :member:`MatchSubstringOptions::pattern` in the corresponding input
+ string. Output type is Int32 for Binary/String, Int64
+ for LargeBinary/LargeString.
-* \(1) Output is the index of the first occurrence of
+* \(2) Output is the index of the first occurrence of
:member:`MatchSubstringOptions::pattern` in the corresponding input
string, otherwise -1. Output type is Int32 for Binary/String, Int64
for LargeBinary/LargeString.
-* \(2) Output is true iff the SQL-style LIKE pattern
+* \(3) Output is true iff the SQL-style LIKE pattern
:member:`MatchSubstringOptions::pattern` fully matches the
corresponding input element. That is, ``%`` will match any number of
characters, ``_`` will match exactly one character, and any other
character matches itself. To match a literal percent sign or
underscore, precede the character with a backslash.
-* \(3) Output is true iff :member:`MatchSubstringOptions::pattern`
+* \(4) Output is true iff :member:`MatchSubstringOptions::pattern`
is a substring of the corresponding input element.
-* \(4) Output is true iff :member:`MatchSubstringOptions::pattern`
+* \(5) Output is true iff :member:`MatchSubstringOptions::pattern`
matches the corresponding input element at any position.
-* \(5) Output is the index of the corresponding input element in
+* \(6) Output is the index of the corresponding input element in
:member:`SetLookupOptions::value_set`, if found there. Otherwise,
output is null.
-* \(6) Output is true iff the corresponding input element is equal to one
+* \(7) Output is true iff the corresponding input element is equal to one
of the elements in :member:`SetLookupOptions::value_set`.
diff --git a/docs/source/python/api/compute.rst b/docs/source/python/api/compute.rst
index ccd530073aa..a586f9011fd 100644
--- a/docs/source/python/api/compute.rst
+++ b/docs/source/python/api/compute.rst
@@ -178,6 +178,7 @@ Containment tests
.. autosummary::
:toctree: ../generated/
+ count_substring
find_substring
index_in
is_in
diff --git a/python/pyarrow/compute.py b/python/pyarrow/compute.py
index eb66f4407c8..8dc7181514c 100644
--- a/python/pyarrow/compute.py
+++ b/python/pyarrow/compute.py
@@ -291,6 +291,25 @@ def cast(arr, target_type, safe=True):
return call_function("cast", [arr], options)
+def count_substring(array, pattern):
+ """
+ Count the non-overlapping occurrences of substring *pattern* in each
+ value of a string array.
+
+ Parameters
+ ----------
+ array : pyarrow.Array or pyarrow.ChunkedArray
+ pattern : str
+ pattern to search for exact matches
+
+ Returns
+ -------
+ result : pyarrow.Array or pyarrow.ChunkedArray
+ Number of matches per value, null where the input is null. The
+ type is int32 for string input and int64 for large_string input.
+ """
+ return call_function("count_substring", [array],
+ MatchSubstringOptions(pattern))
+
+
def find_substring(array, pattern):
"""
Find the index of the first occurrence of substring *pattern* in each
diff --git a/python/pyarrow/tests/test_compute.py b/python/pyarrow/tests/test_compute.py
index b3f87127397..48bf537a4e9 100644
--- a/python/pyarrow/tests/test_compute.py
+++ b/python/pyarrow/tests/test_compute.py
@@ -285,6 +285,19 @@ def test_variance():
assert pc.variance(data, ddof=1).as_py() == 6.0
+def test_count_substring():
+ # "AB" checks that matching is case-sensitive; None must stay null.
+ # A string array is expected to produce int32 counts.
+ arr = pa.array(["ab", "cab", "abcab", "ba", "AB", None])
+ result = pc.count_substring(arr, "ab")
+ expected = pa.array([1, 1, 2, 0, 0, None], type=pa.int32())
+ assert expected.equals(result)
+
+ # Same values as large_string: the counts are expected to be int64.
+ arr = pa.array(["ab", "cab", "abcab", "ba", "AB", None],
+ type=pa.large_string())
+ result = pc.count_substring(arr, "ab")
+ expected = pa.array([1, 1, 2, 0, 0, None], type=pa.int64())
+ assert expected.equals(result)
+
+
def test_find_substring():
arr = pa.array(["ab", "cab", "ba", None])
result = pc.find_substring(arr, "ab")
From 883f27df8ad955035c9bfbe31c63bf7ff0fc3940 Mon Sep 17 00:00:00 2001
From: David Li
Date: Mon, 7 Jun 2021 12:26:36 -0400
Subject: [PATCH 2/2] ARROW-12950: [C++] Add more tests for count_substring
---
cpp/src/arrow/compute/kernels/scalar_string_test.cc | 4 ++++
1 file changed, 4 insertions(+)
diff --git a/cpp/src/arrow/compute/kernels/scalar_string_test.cc b/cpp/src/arrow/compute/kernels/scalar_string_test.cc
index 06f9ad252e2..8eb891745ee 100644
--- a/cpp/src/arrow/compute/kernels/scalar_string_test.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_string_test.cc
@@ -109,6 +109,10 @@ TYPED_TEST(TestBinaryKernels, CountSubstring) {
this->CheckUnary("count_substring", R"(["", null, "abc"])", this->offset_type(),
"[1, null, 4]", &options_empty);
+ MatchSubstringOptions options_repeated{"aaa"};
+ this->CheckUnary("count_substring", R"(["", "aaaa", "aaaaa", "aaaaaa", "aaá"])",
+ this->offset_type(), "[0, 1, 1, 2, 0]", &options_repeated);
+
// TODO: case-insensitive
}