diff --git a/cpp/src/arrow/compute/kernels/scalar_string.cc b/cpp/src/arrow/compute/kernels/scalar_string.cc index 65196b2a491..3c365593c1c 100644 --- a/cpp/src/arrow/compute/kernels/scalar_string.cc +++ b/cpp/src/arrow/compute/kernels/scalar_string.cc @@ -1004,17 +1004,17 @@ struct SplitBaseTransform { return Status::OK(); } - static Status CheckOptions(const Options& options) { return Status::OK(); } - static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { Options options = State::Get(ctx); Derived splitter(options); // we make an instance to reuse the parts vectors + RETURN_NOT_OK(splitter.CheckOptions()); return splitter.Split(ctx, batch, out); } + Status CheckOptions() { return Status::OK(); } + Status Split(KernelContext* ctx, const ExecBatch& batch, Datum* out) { EnsureLookupTablesFilled(); // only needed for unicode - RETURN_NOT_OK(Derived::CheckOptions(options)); if (batch[0].kind() == Datum::ARRAY) { const ArrayData& input = *batch[0].array(); @@ -1080,8 +1080,8 @@ struct SplitPatternTransform : SplitBaseTransform +struct SplitRegexTransform : SplitBaseTransform> { + using Base = SplitBaseTransform>; + using ArrayType = typename TypeTraits::ArrayType; + using string_offset_type = typename Type::offset_type; + using ScalarType = typename TypeTraits::ScalarType; + + const RE2 regex_split; + + explicit SplitRegexTransform(SplitPatternOptions options) + : Base(options), regex_split(MakePattern(options)) {} + + static std::string MakePattern(const SplitPatternOptions& options) { + // RE2 does *not* give you the full match! 
Must wrap the regex in a capture group + // There is FindAndConsume, but it would give only the end of the separator + std::string pattern = "("; + pattern.reserve(options.pattern.size() + 2); + pattern += options.pattern; + pattern += ')'; + return pattern; + } + + Status CheckOptions() { + if (Base::options.reverse) { + return Status::NotImplemented("Cannot split in reverse with regex"); + } + return RegexStatus(regex_split); + } + + bool Find(const uint8_t* begin, const uint8_t* end, const uint8_t** separator_begin, + const uint8_t** separator_end, const SplitOptions& options) { + re2::StringPiece piece(reinterpret_cast<const char*>(begin), + std::distance(begin, end)); + // "StringPiece is mutated to point to matched piece" + re2::StringPiece result; + if (!re2::RE2::PartialMatch(piece, regex_split, &result)) { + return false; + } + *separator_begin = reinterpret_cast<const uint8_t*>(result.data()); + *separator_end = reinterpret_cast<const uint8_t*>(result.data() + result.size()); + return true; + } + bool FindReverse(const uint8_t* begin, const uint8_t* end, + const uint8_t** separator_begin, const uint8_t** separator_end, + const SplitOptions& options) { + // Not easily supportable, unfortunately + return false; + } +}; + +const FunctionDoc split_pattern_regex_doc( + "Split string according to regex pattern", + ("Split each string according to the regex `pattern` defined in\n" + "SplitPatternOptions. 
The output for each string input is a list\n" + "of strings.\n" + "\n" + "The maximum number of splits and direction of splitting\n" + "(forward, reverse) can optionally be defined in SplitPatternOptions."), + {"strings"}, "SplitPatternOptions"); + +void AddSplitRegex(FunctionRegistry* registry) { + auto func = std::make_shared<ScalarFunction>("split_pattern_regex", Arity::Unary(), + &split_pattern_regex_doc); + using t32 = SplitRegexTransform<StringType>; + using t64 = SplitRegexTransform<LargeStringType>; + DCHECK_OK(func->AddKernel({utf8()}, {list(utf8())}, t32::Exec, t32::State::Init)); + DCHECK_OK( + func->AddKernel({large_utf8()}, {list(large_utf8())}, t64::Exec, t64::State::Init)); + DCHECK_OK(registry->AddFunction(std::move(func))); +} +#endif + void AddSplit(FunctionRegistry* registry) { AddSplitPattern(registry); AddSplitWhitespaceAscii(registry); #ifdef ARROW_WITH_UTF8PROC AddSplitWhitespaceUTF8(registry); #endif +#ifdef ARROW_WITH_RE2 + AddSplitRegex(registry); +#endif } // ---------------------------------------------------------------------- diff --git a/cpp/src/arrow/compute/kernels/scalar_string_test.cc b/cpp/src/arrow/compute/kernels/scalar_string_test.cc index a59634b7be8..c50b8091d9d 100644 --- a/cpp/src/arrow/compute/kernels/scalar_string_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_string_test.cc @@ -472,6 +472,34 @@ TYPED_TEST(TestStringKernels, SplitWhitespaceUTF8Reverse) { &options_max); } +#ifdef ARROW_WITH_RE2 +TYPED_TEST(TestStringKernels, SplitRegex) { + SplitPatternOptions options{"a+|b"}; + + this->CheckUnary( + "split_pattern_regex", R"(["aaaab", "foob", "foo bar", "foo", "AaaaBaaaC", null])", + list(this->type()), + R"([["", "", ""], ["foo", ""], ["foo ", "", "r"], ["foo"], ["A", "B", "C"], null])", + &options); + + options.max_splits = 1; + this->CheckUnary( + "split_pattern_regex", R"(["aaaab", "foob", "foo bar", "foo", "AaaaBaaaC", null])", + list(this->type()), + R"([["", "b"], ["foo", ""], ["foo ", "ar"], ["foo"], ["A", "BaaaC"], null])", + &options); +} + 
+TYPED_TEST(TestStringKernels, SplitRegexReverse) { + SplitPatternOptions options{"a+|b", /*max_splits=*/1, /*reverse=*/true}; + Datum input = ArrayFromJSON(this->type(), R"(["a"])"); + + EXPECT_RAISES_WITH_MESSAGE_THAT( + NotImplemented, ::testing::HasSubstr("Cannot split in reverse with regex"), + CallFunction("split_pattern_regex", {input}, &options)); +} +#endif + TYPED_TEST(TestStringKernels, ReplaceSubstring) { ReplaceSubstringOptions options{"foo", "bazz"}; this->CheckUnary("replace_substring", R"(["foo", "this foo that foo", null])", diff --git a/docs/source/cpp/compute.rst b/docs/source/cpp/compute.rst index 592dc4ec1b0..280d303b16f 100644 --- a/docs/source/cpp/compute.rst +++ b/docs/source/cpp/compute.rst @@ -568,18 +568,23 @@ when a positive ``max_splits`` is given. +==========================+============+=========================+===================+==================================+=========+ | split_pattern | Unary | String-like | List-like | :struct:`SplitPatternOptions` | \(1) | +--------------------------+------------+-------------------------+-------------------+----------------------------------+---------+ -| utf8_split_whitespace | Unary | String-like | List-like | :struct:`SplitOptions` | \(2) | +| split_pattern_regex | Unary | String-like | List-like | :struct:`SplitPatternOptions` | \(2) | +--------------------------+------------+-------------------------+-------------------+----------------------------------+---------+ -| ascii_split_whitespace | Unary | String-like | List-like | :struct:`SplitOptions` | \(3) | +| utf8_split_whitespace | Unary | String-like | List-like | :struct:`SplitOptions` | \(3) | ++--------------------------+------------+-------------------------+-------------------+----------------------------------+---------+ +| ascii_split_whitespace | Unary | String-like | List-like | :struct:`SplitOptions` | \(4) | 
+--------------------------+------------+-------------------------+-------------------+----------------------------------+---------+ * \(1) The string is split when an exact pattern is found (the pattern itself is not included in the output). -* \(2) A non-zero length sequence of Unicode defined whitespace codepoints +* \(2) The string is split when a regex match is found (the matched + substring itself is not included in the output). + +* \(3) A non-zero length sequence of Unicode defined whitespace codepoints is seen as separator. -* \(3) A non-zero length sequence of ASCII defined whitespace bytes +* \(4) A non-zero length sequence of ASCII defined whitespace bytes (``'\t'``, ``'\n'``, ``'\v'``, ``'\f'``, ``'\r'`` and ``' '``) is seen as separator. diff --git a/docs/source/python/api/compute.rst b/docs/source/python/api/compute.rst index da16ccdfa29..c8907424772 100644 --- a/docs/source/python/api/compute.rst +++ b/docs/source/python/api/compute.rst @@ -137,6 +137,17 @@ a byte-by-byte basis. string_is_ascii +String Splitting +---------------- + +.. 
autosummary:: + :toctree: ../generated/ + + split_pattern + split_pattern_regex + ascii_split_whitespace + utf8_split_whitespace + String Transforms ----------------- diff --git a/python/pyarrow/tests/test_compute.py b/python/pyarrow/tests/test_compute.py index 8e045fb4f2d..7205c48e0e3 100644 --- a/python/pyarrow/tests/test_compute.py +++ b/python/pyarrow/tests/test_compute.py @@ -349,6 +349,22 @@ def test_split_whitespace_ascii(): assert expected.equals(result) +def test_split_pattern_regex(): + arr = pa.array(["-foo---bar--", "---foo---b"]) + result = pc.split_pattern_regex(arr, pattern="-+") + expected = pa.array([["", "foo", "bar", ""], ["", "foo", "b"]]) + assert expected.equals(result) + + result = pc.split_pattern_regex(arr, pattern="-+", max_splits=1) + expected = pa.array([["", "foo---bar--"], ["", "foo---b"]]) + assert expected.equals(result) + + with pytest.raises(NotImplementedError, + match="Cannot split in reverse with regex"): + result = pc.split_pattern_regex( + arr, pattern="---", max_splits=1, reverse=True) + + def test_min_max(): # An example generated function wrapper with possible options data = [4, 5, 6, None, 1] diff --git a/r/R/dplyr-functions.R b/r/R/dplyr-functions.R index bee06a7cb6a..e3ff5cecebd 100644 --- a/r/R/dplyr-functions.R +++ b/r/R/dplyr-functions.R @@ -201,17 +201,7 @@ nse_funcs$strsplit <- function(x, useBytes = FALSE) { assert_that(is.string(split)) - # The Arrow C++ library does not support splitting a string by a regular - # expression pattern (ARROW-12608) but the default behavior of - # base::strsplit() is to interpret the split pattern as a regex - # (fixed = FALSE). R users commonly pass non-regex split patterns to - # strsplit() without bothering to set fixed = TRUE. It would be annoying if - # that didn't work here. So: if fixed = FALSE, let's check the split pattern - # to see if it is a regex (if it contains any regex metacharacters). If not, - # then allow to proceed. 
- if (!fixed && contains_regex(split)) { - arrow_not_supported("Regular expression matching in strsplit()") - } + arrow_fun <- ifelse(fixed, "split_pattern", "split_pattern_regex") # warn when the user specifies both fixed = TRUE and perl = TRUE, for # consistency with the behavior of base::strsplit() if (fixed && perl) { @@ -221,7 +211,7 @@ nse_funcs$strsplit <- function(x, # regardless of the value of perl, for consistency with the behavior of # base::strsplit() Expression$create( - "split_pattern", + arrow_fun, x, options = list(pattern = split, reverse = FALSE, max_splits = -1L) ) @@ -229,9 +219,7 @@ nse_funcs$strsplit <- function(x, nse_funcs$str_split <- function(string, pattern, n = Inf, simplify = FALSE) { opts <- get_stringr_pattern_options(enexpr(pattern)) - if (!opts$fixed && contains_regex(opts$pattern)) { - arrow_not_supported("Regular expression matching in str_split()") - } + arrow_fun <- ifelse(opts$fixed, "split_pattern", "split_pattern_regex") if (opts$ignore_case) { arrow_not_supported("Case-insensitive string splitting") } @@ -249,7 +237,7 @@ nse_funcs$str_split <- function(string, pattern, n = Inf, simplify = FALSE) { # str_split() controls the maximum number of pieces to return. So we must # subtract 1 from n to get max_splits. 
Expression$create( - "split_pattern", + arrow_fun, string, options = list( pattern = diff --git a/r/src/compute.cpp b/r/src/compute.cpp index 0ffe53578c4..59fbb6bfb74 100644 --- a/r/src/compute.cpp +++ b/r/src/compute.cpp @@ -233,7 +233,7 @@ std::shared_ptr make_compute_options( max_replacements); } - if (func_name == "split_pattern") { + if (func_name == "split_pattern" || func_name == "split_pattern_regex") { using Options = arrow::compute::SplitPatternOptions; int64_t max_splits = -1; if (!Rf_isNull(options["max_splits"])) { diff --git a/r/tests/testthat/test-dplyr-string-functions.R b/r/tests/testthat/test-dplyr-string-functions.R index fb5e6752709..bb4794ef4c5 100644 --- a/r/tests/testthat/test-dplyr-string-functions.R +++ b/r/tests/testthat/test-dplyr-string-functions.R @@ -271,6 +271,12 @@ test_that("strsplit and str_split", { collect(), df ) + expect_dplyr_equal( + input %>% + mutate(x = strsplit(x, " +and +")) %>% + collect(), + df + ) expect_dplyr_equal( input %>% mutate(x = str_split(x, "and")) %>% @@ -295,7 +301,12 @@ test_that("strsplit and str_split", { collect(), df ) - + expect_dplyr_equal( + input %>% + mutate(x = str_split(x, "Foo|bar", n = 2)) %>% + collect(), + df + ) }) test_that("arrow_*_split_whitespace functions", { @@ -352,21 +363,6 @@ test_that("errors and warnings in string splitting", { # so here we can just call the functions directly x <- Expression$field_ref("x") - expect_error( - nse_funcs$strsplit(x, "and.*", fixed = FALSE), - 'Regular expression matching in strsplit() not supported by Arrow', - fixed = TRUE - ) - expect_error( - nse_funcs$str_split(x, "and.?"), - 'Regular expression matching in str_split() not supported by Arrow', - fixed = TRUE - ) - expect_error( - nse_funcs$str_split(x, regex("and.*")), - 'Regular expression matching in str_split() not supported by Arrow', - fixed = TRUE - ) expect_error( nse_funcs$str_split(x, fixed("and", ignore_case = TRUE)), "Case-insensitive string splitting not supported by Arrow"