From 8006de63400cc4292a1c91e708fbc068adb56fe5 Mon Sep 17 00:00:00 2001 From: David Li Date: Fri, 4 Jun 2021 09:54:22 -0500 Subject: [PATCH 1/2] ARROW-12950: [C++] Implement count_substring --- .../arrow/compute/kernels/scalar_string.cc | 70 ++++++++++++++++++- .../compute/kernels/scalar_string_test.cc | 15 ++++ docs/source/cpp/compute.rst | 30 ++++---- docs/source/python/api/compute.rst | 1 + python/pyarrow/compute.py | 19 +++++ python/pyarrow/tests/test_compute.py | 13 ++++ 6 files changed, 135 insertions(+), 13 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_string.cc b/cpp/src/arrow/compute/kernels/scalar_string.cc index 9db16e26ca5..08159ccddaa 100644 --- a/cpp/src/arrow/compute/kernels/scalar_string.cc +++ b/cpp/src/arrow/compute/kernels/scalar_string.cc @@ -741,8 +741,12 @@ template struct FindSubstringExec { using OffsetType = typename TypeTraits::OffsetType; static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + const MatchSubstringOptions& options = MatchSubstringState::Get(ctx); + if (options.ignore_case) { + return Status::NotImplemented("find_substring with ignore_case"); + } applicator::ScalarUnaryNotNullStateful kernel{ - FindSubstring(PlainSubstringMatcher(MatchSubstringState::Get(ctx)))}; + FindSubstring(PlainSubstringMatcher(options))}; return kernel.Exec(ctx, batch, out); } }; @@ -771,6 +775,69 @@ void AddFindSubstring(FunctionRegistry* registry) { DCHECK_OK(registry->AddFunction(std::move(func))); } +// Substring count + +struct CountSubstring { + const PlainSubstringMatcher matcher_; + + explicit CountSubstring(PlainSubstringMatcher matcher) : matcher_(std::move(matcher)) {} + + template + OutValue Call(KernelContext*, util::string_view val, Status*) const { + OutValue count = 0; + uint64_t start = 0; + const auto pattern_size = std::max(1, matcher_.options_.pattern.size()); + while (start <= val.size()) { + const int64_t index = matcher_.Find(val.substr(start)); + if (index >= 0) { + count++; + start += index + pattern_size; + } else { + break; + } + } + return count; + } +}; + +template +struct CountSubstringExec { + using OffsetType = typename TypeTraits::OffsetType; + static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + const MatchSubstringOptions& options = MatchSubstringState::Get(ctx); + if (options.ignore_case) { + return Status::NotImplemented("count_substring with ignore_case"); + } + applicator::ScalarUnaryNotNullStateful kernel{ + CountSubstring(PlainSubstringMatcher(options))}; + return kernel.Exec(ctx, batch, out); + } +}; + +const FunctionDoc count_substring_doc( + "Count occurrences of substring", + ("For each string in `strings`, emit the number of occurrences of the given " + "pattern.\n" + "Null inputs emit null. The pattern must be given in MatchSubstringOptions."), + {"strings"}, "MatchSubstringOptions"); + +void AddCountSubstring(FunctionRegistry* registry) { + auto func = std::make_shared("count_substring", Arity::Unary(), + &count_substring_doc); + for (const auto& ty : BaseBinaryTypes()) { + std::shared_ptr offset_type; + if (ty->id() == Type::type::LARGE_BINARY || ty->id() == Type::type::LARGE_STRING) { + offset_type = int64(); + } else { + offset_type = int32(); + } + DCHECK_OK(func->AddKernel({ty}, offset_type, + GenerateTypeAgnosticVarBinaryBase(ty), + MatchSubstringState::Init)); + } + DCHECK_OK(registry->AddFunction(std::move(func))); +} + // Slicing template @@ -2936,6 +3003,7 @@ void RegisterScalarStringAscii(FunctionRegistry* registry) { AddUtf8Length(registry); AddMatchSubstring(registry); AddFindSubstring(registry); + AddCountSubstring(registry); MakeUnaryStringBatchKernelWithState( "replace_substring", registry, &replace_substring_doc, MemAllocation::NO_PREALLOCATE); diff --git a/cpp/src/arrow/compute/kernels/scalar_string_test.cc b/cpp/src/arrow/compute/kernels/scalar_string_test.cc index bd5c8eec03f..06f9ad252e2 100644 --- a/cpp/src/arrow/compute/kernels/scalar_string_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_string_test.cc @@ -97,6 +97,21 @@ TYPED_TEST(TestBinaryKernels, FindSubstring) { "[0, 0, null]", &options_empty); } +TYPED_TEST(TestBinaryKernels, CountSubstring) { + MatchSubstringOptions options{"aba"}; + this->CheckUnary("count_substring", "[]", this->offset_type(), "[]", &options); + this->CheckUnary( + "count_substring", + R"(["", null, "ab", "aba", "baba", "ababa", "abaaba", "babacaba", "ABA"])", + this->offset_type(), "[0, null, 0, 1, 1, 1, 2, 2, 0]", &options); + + MatchSubstringOptions options_empty{""}; + this->CheckUnary("count_substring", R"(["", null, "abc"])", this->offset_type(), + "[1, null, 4]", &options_empty); + + // TODO: case-insensitive +} + template class TestStringKernels : public BaseTestStringKernels {}; diff --git a/docs/source/cpp/compute.rst b/docs/source/cpp/compute.rst index 3f30bbcaa06..c7ccc8a822d 100644 --- a/docs/source/cpp/compute.rst +++ b/docs/source/cpp/compute.rst @@ -561,45 +561,51 @@ Containment tests +---------------------------+------------+------------------------------------+--------------------+----------------------------------------+ | Function name | Arity | Input types | Output type | Options class | +===========================+============+====================================+====================+========================================+ -| find_substring | Unary | String-like | Int32 or Int64 (1) | :struct:`MatchSubstringOptions` | +| count_substring | Unary | String-like | Int32 or Int64 (1) | :struct:`MatchSubstringOptions` | +---------------------------+------------+------------------------------------+--------------------+----------------------------------------+ -| match_like | Unary | String-like | Boolean (2) | :struct:`MatchSubstringOptions` | +| find_substring | Unary | String-like | Int32 or Int64 (2) | :struct:`MatchSubstringOptions` | +---------------------------+------------+------------------------------------+--------------------+----------------------------------------+ -| match_substring | Unary | String-like | Boolean (3) | :struct:`MatchSubstringOptions` | +| match_like | Unary | String-like | Boolean (3) | :struct:`MatchSubstringOptions` | +---------------------------+------------+------------------------------------+--------------------+----------------------------------------+ -| match_substring_regex | Unary | String-like | Boolean (4) | :struct:`MatchSubstringOptions` | +| match_substring | Unary | String-like | Boolean (4) | :struct:`MatchSubstringOptions` | +---------------------------+------------+------------------------------------+--------------------+----------------------------------------+ -| index_in | Unary | Boolean, Null, Numeric, Temporal, | Int32 (5) | :struct:`SetLookupOptions` | +| match_substring_regex | Unary | String-like | Boolean (5) | :struct:`MatchSubstringOptions` | ++---------------------------+------------+------------------------------------+--------------------+----------------------------------------+ +| index_in | Unary | Boolean, Null, Numeric, Temporal, | Int32 (6) | :struct:`SetLookupOptions` | | | | Binary- and String-like | | | +---------------------------+------------+------------------------------------+--------------------+----------------------------------------+ -| is_in | Unary | Boolean, Null, Numeric, Temporal, | Boolean (6) | :struct:`SetLookupOptions` | +| is_in | Unary | Boolean, Null, Numeric, Temporal, | Boolean (7) | :struct:`SetLookupOptions` | | | | Binary- and String-like | | | +---------------------------+------------+------------------------------------+--------------------+----------------------------------------+ +* \(1) Output is the number of occurrences of + :member:`MatchSubstringOptions::pattern` in the corresponding input + string. Output type is Int32 for Binary/String, Int64 + for LargeBinary/LargeString. -* \(1) Output is the index of the first occurrence of +* \(2) Output is the index of the first occurrence of :member:`MatchSubstringOptions::pattern` in the corresponding input string, otherwise -1. Output type is Int32 for Binary/String, Int64 for LargeBinary/LargeString. -* \(2) Output is true iff the SQL-style LIKE pattern +* \(3) Output is true iff the SQL-style LIKE pattern :member:`MatchSubstringOptions::pattern` fully matches the corresponding input element. That is, ``%`` will match any number of characters, ``_`` will match exactly one character, and any other character matches itself. To match a literal percent sign or underscore, precede the character with a backslash. -* \(3) Output is true iff :member:`MatchSubstringOptions::pattern` +* \(4) Output is true iff :member:`MatchSubstringOptions::pattern` is a substring of the corresponding input element. -* \(4) Output is true iff :member:`MatchSubstringOptions::pattern` +* \(5) Output is true iff :member:`MatchSubstringOptions::pattern` matches the corresponding input element at any position. -* \(5) Output is the index of the corresponding input element in +* \(6) Output is the index of the corresponding input element in :member:`SetLookupOptions::value_set`, if found there. Otherwise, output is null. -* \(6) Output is true iff the corresponding input element is equal to one +* \(7) Output is true iff the corresponding input element is equal to one of the elements in :member:`SetLookupOptions::value_set`. diff --git a/docs/source/python/api/compute.rst b/docs/source/python/api/compute.rst index ccd530073aa..a586f9011fd 100644 --- a/docs/source/python/api/compute.rst +++ b/docs/source/python/api/compute.rst @@ -178,6 +178,7 @@ Containment tests .. autosummary:: :toctree: ../generated/ + count_substring find_substring index_in is_in diff --git a/python/pyarrow/compute.py b/python/pyarrow/compute.py index eb66f4407c8..8dc7181514c 100644 --- a/python/pyarrow/compute.py +++ b/python/pyarrow/compute.py @@ -291,6 +291,25 @@ def cast(arr, target_type, safe=True): return call_function("cast", [arr], options) +def count_substring(array, pattern): + """ + Count the occurrences of substring *pattern* in each value of a + string array. + + Parameters + ---------- + array : pyarrow.Array or pyarrow.ChunkedArray + pattern : str + pattern to search for exact matches + + Returns + ------- + result : pyarrow.Array or pyarrow.ChunkedArray + """ + return call_function("count_substring", [array], + MatchSubstringOptions(pattern)) + + def find_substring(array, pattern): """ Find the index of the first occurrence of substring *pattern* in each diff --git a/python/pyarrow/tests/test_compute.py b/python/pyarrow/tests/test_compute.py index b3f87127397..48bf537a4e9 100644 --- a/python/pyarrow/tests/test_compute.py +++ b/python/pyarrow/tests/test_compute.py @@ -285,6 +285,19 @@ def test_variance(): assert pc.variance(data, ddof=1).as_py() == 6.0 +def test_count_substring(): + arr = pa.array(["ab", "cab", "abcab", "ba", "AB", None]) + result = pc.count_substring(arr, "ab") + expected = pa.array([1, 1, 2, 0, 0, None], type=pa.int32()) + assert expected.equals(result) + + arr = pa.array(["ab", "cab", "abcab", "ba", "AB", None], + type=pa.large_string()) + result = pc.count_substring(arr, "ab") + expected = pa.array([1, 1, 2, 0, 0, None], type=pa.int64()) + assert expected.equals(result) + + def test_find_substring(): arr = pa.array(["ab", "cab", "ba", None]) result = pc.find_substring(arr, "ab") From 883f27df8ad955035c9bfbe31c63bf7ff0fc3940 Mon Sep 17 00:00:00 2001 From: David Li Date: Mon, 7 Jun 2021 12:26:36 -0400 Subject: [PATCH 2/2] ARROW-12950: [C++] Add more tests for count_substring --- cpp/src/arrow/compute/kernels/scalar_string_test.cc | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/cpp/src/arrow/compute/kernels/scalar_string_test.cc b/cpp/src/arrow/compute/kernels/scalar_string_test.cc index 06f9ad252e2..8eb891745ee 100644 --- a/cpp/src/arrow/compute/kernels/scalar_string_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_string_test.cc @@ -109,6 +109,10 @@ TYPED_TEST(TestBinaryKernels, CountSubstring) { this->CheckUnary("count_substring", R"(["", null, "abc"])", this->offset_type(), "[1, null, 4]", &options_empty); + MatchSubstringOptions options_repeated{"aaa"}; + this->CheckUnary("count_substring", R"(["", "aaaa", "aaaaa", "aaaaaa", "aaĆ”"])", + this->offset_type(), "[0, 1, 1, 2, 0]", &options_repeated); + // TODO: case-insensitive }