diff --git a/cpp/src/arrow/compute/kernels/scalar_string.cc b/cpp/src/arrow/compute/kernels/scalar_string.cc index 23a94d9eb92..4d83e1ec24e 100644 --- a/cpp/src/arrow/compute/kernels/scalar_string.cc +++ b/cpp/src/arrow/compute/kernels/scalar_string.cc @@ -488,21 +488,25 @@ struct PlainSubstringMatcher { } } - bool Match(util::string_view current) const { + int64_t Find(util::string_view current) const { // Phase 2: Find the prefix in the data const auto pattern_length = options_.pattern.size(); int64_t pattern_pos = 0; + int64_t pos = 0; for (const auto c : current) { while ((pattern_pos >= 0) && (options_.pattern[pattern_pos] != c)) { pattern_pos = prefix_table[pattern_pos]; } pattern_pos++; if (static_cast(pattern_pos) == pattern_length) { - return true; + return pos + 1 - pattern_length; } + pos++; } - return false; + return -1; } + + bool Match(util::string_view current) const { return Find(current) >= 0; } }; const FunctionDoc match_substring_doc( @@ -664,6 +668,48 @@ void AddMatchSubstring(FunctionRegistry* registry) { #endif } +// Substring find - lfind/index/etc. + +struct FindSubstring { + const PlainSubstringMatcher matcher_; + + explicit FindSubstring(PlainSubstringMatcher matcher) : matcher_(std::move(matcher)) {} + + template + OutValue Call(KernelContext*, util::string_view val, Status*) const { + return static_cast(matcher_.Find(val)); + } +}; + +template +Status FindSubstringExec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + using offset_type = typename TypeTraits::OffsetType; + applicator::ScalarUnaryNotNullStateful kernel{ + FindSubstring(PlainSubstringMatcher(MatchSubstringState::Get(ctx)))}; + return kernel.Exec(ctx, batch, out); +} + +const FunctionDoc find_substring_doc( + "Find first occurrence of substring", + ("For each string in `strings`, emit the index of the first occurrence of the given " + "pattern, or -1 if not found.\n" + "Null inputs emit null. The pattern must be given in MatchSubstringOptions."), + {"strings"}, "MatchSubstringOptions"); + +void AddFindSubstring(FunctionRegistry* registry) { + auto func = std::make_shared("find_substring", Arity::Unary(), + &find_substring_doc); + DCHECK_OK(func->AddKernel({binary()}, int32(), FindSubstringExec, + MatchSubstringState::Init)); + DCHECK_OK(func->AddKernel({utf8()}, int32(), FindSubstringExec, + MatchSubstringState::Init)); + DCHECK_OK(func->AddKernel({large_binary()}, int64(), FindSubstringExec, + MatchSubstringState::Init)); + DCHECK_OK(func->AddKernel({large_utf8()}, int64(), FindSubstringExec, + MatchSubstringState::Init)); + DCHECK_OK(registry->AddFunction(std::move(func))); +} + // IsAlpha/Digit etc #ifdef ARROW_WITH_UTF8PROC @@ -2626,6 +2672,7 @@ void RegisterScalarStringAscii(FunctionRegistry* registry) { AddBinaryLength(registry); AddUtf8Length(registry); AddMatchSubstring(registry); + AddFindSubstring(registry); MakeUnaryStringBatchKernelWithState( "replace_substring", registry, &replace_substring_doc, MemAllocation::NO_PREALLOCATE); diff --git a/cpp/src/arrow/compute/kernels/scalar_string_test.cc b/cpp/src/arrow/compute/kernels/scalar_string_test.cc index 5ec7f579fff..7f2126828ce 100644 --- a/cpp/src/arrow/compute/kernels/scalar_string_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_string_test.cc @@ -74,6 +74,25 @@ TYPED_TEST(TestBinaryKernels, BinaryLength) { this->offset_type(), "[3, null, 10, 0, 1]"); } +TYPED_TEST(TestBinaryKernels, FindSubstring) { + MatchSubstringOptions options{"ab"}; + this->CheckUnary("find_substring", "[]", this->offset_type(), "[]", &options); + this->CheckUnary("find_substring", R"(["abc", "acb", "cab", null, "bac"])", + this->offset_type(), "[0, -1, 1, null, -1]", &options); + + MatchSubstringOptions options_repeated{"abab"}; + this->CheckUnary("find_substring", R"(["abab", "ab", "cababc", null, "bac"])", + this->offset_type(), "[0, -1, 1, null, -1]", &options_repeated); + + MatchSubstringOptions options_double_char{"aab"}; + this->CheckUnary("find_substring", R"(["aacb", "aab", "ab", "aaab"])", + this->offset_type(), "[-1, 0, -1, 1]", &options_double_char); + + MatchSubstringOptions options_double_char_2{"bbcaa"}; + this->CheckUnary("find_substring", R"(["abcbaabbbcaabccabaab"])", this->offset_type(), + "[7]", &options_double_char_2); +} + template class TestStringKernels : public BaseTestStringKernels {}; @@ -470,6 +489,25 @@ TYPED_TEST(TestStringKernels, MatchLikeEscaping) { } #endif +TYPED_TEST(TestStringKernels, FindSubstring) { + MatchSubstringOptions options{"ab"}; + this->CheckUnary("find_substring", "[]", this->offset_type(), "[]", &options); + this->CheckUnary("find_substring", R"(["abc", "acb", "cab", null, "bac"])", + this->offset_type(), "[0, -1, 1, null, -1]", &options); + + MatchSubstringOptions options_repeated{"abab"}; + this->CheckUnary("find_substring", R"(["abab", "ab", "cababc", null, "bac"])", + this->offset_type(), "[0, -1, 1, null, -1]", &options_repeated); + + MatchSubstringOptions options_double_char{"aab"}; + this->CheckUnary("find_substring", R"(["aacb", "aab", "ab", "aaab"])", + this->offset_type(), "[-1, 0, -1, 1]", &options_double_char); + + MatchSubstringOptions options_double_char_2{"bbcaa"}; + this->CheckUnary("find_substring", R"(["abcbaabbbcaabccabaab"])", this->offset_type(), + "[7]", &options_double_char_2); +} + TYPED_TEST(TestStringKernels, SplitBasics) { SplitPatternOptions options{" "}; // basics diff --git a/docs/source/cpp/compute.rst b/docs/source/cpp/compute.rst index 79140257a9b..ca68a31cc21 100644 --- a/docs/source/cpp/compute.rst +++ b/docs/source/cpp/compute.rst @@ -462,7 +462,7 @@ String transforms * \(1) Each ASCII character in the input is converted to lowercase or uppercase. Non-ASCII characters are left untouched. -* \(2) ASCII input is reversed to the output. If non-ASCII characters +* \(2) ASCII input is reversed to the output. If non-ASCII characters are present, ``Invalid`` :class:`Status` will be returned. * \(3) Output is the physical length in bytes of each input element. Output @@ -482,7 +482,7 @@ String transforms pattern contains groups, backreferencing can be used. * \(6) Output is the number of characters (not bytes) of each input element. - Output type is Int32 for String, Int64 for LargeString. + Output type is Int32 for String, Int64 for LargeString. * \(7) Each UTF8-encoded character in the input is converted to lowercase or uppercase. @@ -541,40 +541,48 @@ These functions trim off characters on both sides (trim), or the left (ltrim) or Containment tests ~~~~~~~~~~~~~~~~~ -+---------------------------+------------+------------------------------------+---------------+----------------------------------------+ -| Function name | Arity | Input types | Output type | Options class | -+===========================+============+====================================+===============+========================================+ -| match_like | Unary | String-like | Boolean (1) | :struct:`MatchSubstringOptions` | -+---------------------------+------------+------------------------------------+---------------+----------------------------------------+ -| match_substring | Unary | String-like | Boolean (2) | :struct:`MatchSubstringOptions` | -+---------------------------+------------+------------------------------------+---------------+----------------------------------------+ -| match_substring_regex | Unary | String-like | Boolean (3) | :struct:`MatchSubstringOptions` | -+---------------------------+------------+------------------------------------+---------------+----------------------------------------+ -| index_in | Unary | Boolean, Null, Numeric, Temporal, | Int32 (4) | :struct:`SetLookupOptions` | -| | | Binary- and String-like | | | -+---------------------------+------------+------------------------------------+---------------+----------------------------------------+ -| is_in | Unary | Boolean, Null, Numeric, Temporal, | Boolean (5) | :struct:`SetLookupOptions` | -| | | Binary- and String-like | | | -+---------------------------+------------+------------------------------------+---------------+----------------------------------------+ - -* \(1) Output is true iff the SQL-style LIKE pattern ++---------------------------+------------+------------------------------------+--------------------+----------------------------------------+ +| Function name | Arity | Input types | Output type | Options class | ++===========================+============+====================================+====================+========================================+ +| find_substring | Unary | String-like | Int32 or Int64 (1) | :struct:`MatchSubstringOptions` | ++---------------------------+------------+------------------------------------+--------------------+----------------------------------------+ +| match_like | Unary | String-like | Boolean (2) | :struct:`MatchSubstringOptions` | ++---------------------------+------------+------------------------------------+--------------------+----------------------------------------+ +| match_substring | Unary | String-like | Boolean (3) | :struct:`MatchSubstringOptions` | ++---------------------------+------------+------------------------------------+--------------------+----------------------------------------+ +| match_substring_regex | Unary | String-like | Boolean (4) | :struct:`MatchSubstringOptions` | ++---------------------------+------------+------------------------------------+--------------------+----------------------------------------+ +| index_in | Unary | Boolean, Null, Numeric, Temporal, | Int32 (5) | :struct:`SetLookupOptions` | +| | | Binary- and String-like | | | ++---------------------------+------------+------------------------------------+--------------------+----------------------------------------+ +| is_in | Unary | Boolean, Null, Numeric, Temporal, | Boolean (6) | :struct:`SetLookupOptions` | +| | | Binary- and String-like | | | ++---------------------------+------------+------------------------------------+--------------------+----------------------------------------+ + + +* \(1) Output is the index of the first occurrence of + :member:`MatchSubstringOptions::pattern` in the corresponding input + string, otherwise -1. Output type is Int32 for Binary/String, Int64 + for LargeBinary/LargeString. + +* \(2) Output is true iff the SQL-style LIKE pattern :member:`MatchSubstringOptions::pattern` fully matches the corresponding input element. That is, ``%`` will match any number of characters, ``_`` will match exactly one character, and any other character matches itself. To match a literal percent sign or underscore, precede the character with a backslash. -* \(2) Output is true iff :member:`MatchSubstringOptions::pattern` +* \(3) Output is true iff :member:`MatchSubstringOptions::pattern` is a substring of the corresponding input element. -* \(3) Output is true iff :member:`MatchSubstringOptions::pattern` +* \(4) Output is true iff :member:`MatchSubstringOptions::pattern` matches the corresponding input element at any position. -* \(4) Output is the index of the corresponding input element in +* \(5) Output is the index of the corresponding input element in :member:`SetLookupOptions::value_set`, if found there. Otherwise, output is null. -* \(5) Output is true iff the corresponding input element is equal to one +* \(6) Output is true iff the corresponding input element is equal to one of the elements in :member:`SetLookupOptions::value_set`. @@ -878,4 +886,3 @@ Structural transforms * \(2) For each value in the list child array, the index at which it is found in the list array is appended to the output. Nulls in the parent list array are discarded. - diff --git a/docs/source/python/api/compute.rst b/docs/source/python/api/compute.rst index d206cbc9e50..61482f49f19 100644 --- a/docs/source/python/api/compute.rst +++ b/docs/source/python/api/compute.rst @@ -169,6 +169,7 @@ Containment tests .. autosummary:: :toctree: ../generated/ + find_substring index_in is_in match_like diff --git a/python/pyarrow/compute.py b/python/pyarrow/compute.py index cb6ba475b5f..6bb0efb5963 100644 --- a/python/pyarrow/compute.py +++ b/python/pyarrow/compute.py @@ -288,6 +288,25 @@ def cast(arr, target_type, safe=True): return call_function("cast", [arr], options) +def find_substring(array, pattern): + """ + Find the index of the first occurrence of substring *pattern* in each + value of a string array. + + Parameters + ---------- + array : pyarrow.Array or pyarrow.ChunkedArray + pattern : str + pattern to search for exact matches + + Returns + ------- + result : pyarrow.Array or pyarrow.ChunkedArray + """ + return call_function("find_substring", [array], + MatchSubstringOptions(pattern)) + + def match_like(array, pattern): """ Test if the SQL-style LIKE pattern *pattern* matches a value of a diff --git a/python/pyarrow/tests/test_compute.py b/python/pyarrow/tests/test_compute.py index b014dcc0c8a..c62ff72acd5 100644 --- a/python/pyarrow/tests/test_compute.py +++ b/python/pyarrow/tests/test_compute.py @@ -280,6 +280,28 @@ def test_variance(): assert pc.variance(data, ddof=1).as_py() == 6.0 +def test_find_substring(): + arr = pa.array(["ab", "cab", "ba", None]) + result = pc.find_substring(arr, "ab") + expected = pa.array([0, 1, -1, None], type=pa.int32()) + assert expected.equals(result) + + arr = pa.array(["ab", "cab", "ba", None], type=pa.large_string()) + result = pc.find_substring(arr, "ab") + expected = pa.array([0, 1, -1, None], type=pa.int64()) + assert expected.equals(result) + + arr = pa.array([b"ab", b"cab", b"ba", None]) + result = pc.find_substring(arr, b"ab") + expected = pa.array([0, 1, -1, None], type=pa.int32()) + assert expected.equals(result) + + arr = pa.array([b"ab", b"cab", b"ba", None], type=pa.large_binary()) + result = pc.find_substring(arr, b"ab") + expected = pa.array([0, 1, -1, None], type=pa.int64()) + assert expected.equals(result) + + def test_match_like(): arr = pa.array(["ab", "ba%", "ba", "ca%d", None]) result = pc.match_like(arr, r"_a\%%")