From 821e27bdbd25f94c2125694f1edc5fa09e465093 Mon Sep 17 00:00:00 2001
From: David Li
Date: Tue, 18 May 2021 15:44:06 -0400
Subject: [PATCH 1/2] ARROW-12608: [C++][Python][R] Add split_pattern_regex
kernel
---
.../arrow/compute/kernels/scalar_string.cc | 88 +++++++++++++++++--
.../compute/kernels/scalar_string_test.cc | 28 ++++++
docs/source/cpp/compute.rst | 13 ++-
docs/source/python/api/compute.rst | 11 +++
python/pyarrow/tests/test_compute.py | 16 ++++
r/R/dplyr-functions.R | 20 +----
r/src/compute.cpp | 2 +-
.../testthat/test-dplyr-string-functions.R | 28 +++---
8 files changed, 164 insertions(+), 42 deletions(-)
diff --git a/cpp/src/arrow/compute/kernels/scalar_string.cc b/cpp/src/arrow/compute/kernels/scalar_string.cc
index 65196b2a491..94f9681de70 100644
--- a/cpp/src/arrow/compute/kernels/scalar_string.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_string.cc
@@ -1004,17 +1004,17 @@ struct SplitBaseTransform {
return Status::OK();
}
- static Status CheckOptions(const Options& options) { return Status::OK(); }
-
static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
Options options = State::Get(ctx);
Derived splitter(options); // we make an instance to reuse the parts vectors
+ RETURN_NOT_OK(splitter.CheckOptions());
return splitter.Split(ctx, batch, out);
}
+ Status CheckOptions() { return Status::OK(); }
+
Status Split(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
EnsureLookupTablesFilled(); // only needed for unicode
- RETURN_NOT_OK(Derived::CheckOptions(options));
if (batch[0].kind() == Datum::ARRAY) {
const ArrayData& input = *batch[0].array();
@@ -1080,8 +1080,8 @@ struct SplitPatternTransform : SplitBaseTransform
+struct SplitRegexTransform : SplitBaseTransform> {
+ using Base = SplitBaseTransform>;
+ using ArrayType = typename TypeTraits::ArrayType;
+ using string_offset_type = typename Type::offset_type;
+ using ScalarType = typename TypeTraits::ScalarType;
+
+ const RE2 regex_split;
+
+ explicit SplitRegexTransform(SplitPatternOptions options)
+ : Base(options), regex_split(MakePattern(options)) {}
+
+ static std::string MakePattern(const SplitPatternOptions& options) {
+ // RE2 does *not* give you the full match! Must wrap the regex in a capture group
+ // There is FindAndConsume, but it would give only the end of the separator
+ std::string pattern = "(";
+ pattern.reserve(options.pattern.size() + 1);
+ pattern += options.pattern;
+ pattern += ')';
+ return pattern;
+ }
+
+ Status CheckOptions() {
+ if (Base::options.reverse) {
+ return Status::NotImplemented("Cannot split in reverse with regex");
+ }
+ return RegexStatus(regex_split);
+ }
+
+ bool Find(const uint8_t* begin, const uint8_t* end, const uint8_t** separator_begin,
+ const uint8_t** separator_end, const SplitOptions& options) {
+ re2::StringPiece piece(reinterpret_cast(begin),
+ std::distance(begin, end));
+ // "StringPiece is mutated to point to matched piece"
+ re2::StringPiece result;
+ if (!re2::RE2::PartialMatch(piece, regex_split, &result)) {
+ return false;
+ }
+ *separator_begin = reinterpret_cast(result.data());
+ *separator_end = reinterpret_cast(result.data() + result.size());
+ return true;
+ }
+ bool FindReverse(const uint8_t* begin, const uint8_t* end,
+ const uint8_t** separator_begin, const uint8_t** separator_end,
+ const SplitOptions& options) {
+ // Not easily supportable, unfortunately
+ return false;
+ }
+};
+
+const FunctionDoc split_pattern_regex_doc(
+ "Split string according to regex pattern",
+ ("Split each string according to the regex `pattern` defined in\n"
+ "SplitPatternOptions. The output for each string input is a list\n"
+ "of strings.\n"
+ "\n"
+ "The maximum number of splits and direction of splitting\n"
+ "(forward, reverse) can optionally be defined in SplitPatternOptions."),
+ {"strings"}, "SplitPatternOptions");
+
+void AddSplitRegex(FunctionRegistry* registry) {
+ auto func = std::make_shared("split_pattern_regex", Arity::Unary(),
+ &split_pattern_regex_doc);
+ using t32 = SplitRegexTransform;
+ using t64 = SplitRegexTransform;
+ DCHECK_OK(func->AddKernel({utf8()}, {list(utf8())}, t32::Exec, t32::State::Init));
+ DCHECK_OK(
+ func->AddKernel({large_utf8()}, {list(large_utf8())}, t64::Exec, t64::State::Init));
+ DCHECK_OK(registry->AddFunction(std::move(func)));
+}
+#endif
+
void AddSplit(FunctionRegistry* registry) {
AddSplitPattern(registry);
AddSplitWhitespaceAscii(registry);
#ifdef ARROW_WITH_UTF8PROC
AddSplitWhitespaceUTF8(registry);
#endif
+#ifdef ARROW_WITH_RE2
+ AddSplitRegex(registry);
+#endif
}
// ----------------------------------------------------------------------
diff --git a/cpp/src/arrow/compute/kernels/scalar_string_test.cc b/cpp/src/arrow/compute/kernels/scalar_string_test.cc
index a59634b7be8..c50b8091d9d 100644
--- a/cpp/src/arrow/compute/kernels/scalar_string_test.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_string_test.cc
@@ -472,6 +472,34 @@ TYPED_TEST(TestStringKernels, SplitWhitespaceUTF8Reverse) {
&options_max);
}
+#ifdef ARROW_WITH_RE2
+TYPED_TEST(TestStringKernels, SplitRegex) {
+ SplitPatternOptions options{"a+|b"};
+
+ this->CheckUnary(
+ "split_pattern_regex", R"(["aaaab", "foob", "foo bar", "foo", "AaaaBaaaC", null])",
+ list(this->type()),
+ R"([["", "", ""], ["foo", ""], ["foo ", "", "r"], ["foo"], ["A", "B", "C"], null])",
+ &options);
+
+ options.max_splits = 1;
+ this->CheckUnary(
+ "split_pattern_regex", R"(["aaaab", "foob", "foo bar", "foo", "AaaaBaaaC", null])",
+ list(this->type()),
+ R"([["", "b"], ["foo", ""], ["foo ", "ar"], ["foo"], ["A", "BaaaC"], null])",
+ &options);
+}
+
+TYPED_TEST(TestStringKernels, SplitRegexReverse) {
+ SplitPatternOptions options{"a+|b", /*max_splits=*/1, /*reverse=*/true};
+ Datum input = ArrayFromJSON(this->type(), R"(["a"])");
+
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ NotImplemented, ::testing::HasSubstr("Cannot split in reverse with regex"),
+ CallFunction("split_pattern_regex", {input}, &options));
+}
+#endif
+
TYPED_TEST(TestStringKernels, ReplaceSubstring) {
ReplaceSubstringOptions options{"foo", "bazz"};
this->CheckUnary("replace_substring", R"(["foo", "this foo that foo", null])",
diff --git a/docs/source/cpp/compute.rst b/docs/source/cpp/compute.rst
index 592dc4ec1b0..8146fed81d4 100644
--- a/docs/source/cpp/compute.rst
+++ b/docs/source/cpp/compute.rst
@@ -568,18 +568,23 @@ when a positive ``max_splits`` is given.
+==========================+============+=========================+===================+==================================+=========+
| split_pattern | Unary | String-like | List-like | :struct:`SplitPatternOptions` | \(1) |
+--------------------------+------------+-------------------------+-------------------+----------------------------------+---------+
-| utf8_split_whitespace | Unary | String-like | List-like | :struct:`SplitOptions` | \(2) |
+| split_pattern_regex | Unary | String-like | List-like | :struct:`SplitPatternOptions` | \(2) |
+--------------------------+------------+-------------------------+-------------------+----------------------------------+---------+
-| ascii_split_whitespace | Unary | String-like | List-like | :struct:`SplitOptions` | \(3) |
+| utf8_split_whitespace | Unary | String-like | List-like | :struct:`SplitOptions` | \(3) |
++--------------------------+------------+-------------------------+-------------------+----------------------------------+---------+
+| ascii_split_whitespace | Unary | String-like | List-like | :struct:`SplitOptions` | \(4) |
+--------------------------+------------+-------------------------+-------------------+----------------------------------+---------+
* \(1) The string is split when an exact pattern is found (the pattern itself
is not included in the output).
-* \(2) A non-zero length sequence of Unicode defined whitespace codepoints
+* \(1) The string is split when a regex match is found (the matched
+ substring itself is not included in the output).
+
+* \(3) A non-zero length sequence of Unicode defined whitespace codepoints
is seen as separator.
-* \(3) A non-zero length sequence of ASCII defined whitespace bytes
+* \(4) A non-zero length sequence of ASCII defined whitespace bytes
(``'\t'``, ``'\n'``, ``'\v'``, ``'\f'``, ``'\r'`` and ``' '``) is seen
as separator.
diff --git a/docs/source/python/api/compute.rst b/docs/source/python/api/compute.rst
index da16ccdfa29..c8907424772 100644
--- a/docs/source/python/api/compute.rst
+++ b/docs/source/python/api/compute.rst
@@ -137,6 +137,17 @@ a byte-by-byte basis.
string_is_ascii
+String Splitting
+----------------
+
+.. autosummary::
+ :toctree: ../generated/
+
+ split_pattern
+ split_pattern_regex
+ ascii_split_whitespace
+ utf8_split_whitespace
+
String Transforms
-----------------
diff --git a/python/pyarrow/tests/test_compute.py b/python/pyarrow/tests/test_compute.py
index 8e045fb4f2d..7205c48e0e3 100644
--- a/python/pyarrow/tests/test_compute.py
+++ b/python/pyarrow/tests/test_compute.py
@@ -349,6 +349,22 @@ def test_split_whitespace_ascii():
assert expected.equals(result)
+def test_split_pattern_regex():
+ arr = pa.array(["-foo---bar--", "---foo---b"])
+ result = pc.split_pattern_regex(arr, pattern="-+")
+ expected = pa.array([["", "foo", "bar", ""], ["", "foo", "b"]])
+ assert expected.equals(result)
+
+ result = pc.split_pattern_regex(arr, pattern="-+", max_splits=1)
+ expected = pa.array([["", "foo---bar--"], ["", "foo---b"]])
+ assert expected.equals(result)
+
+ with pytest.raises(NotImplementedError,
+ match="Cannot split in reverse with regex"):
+ result = pc.split_pattern_regex(
+ arr, pattern="---", max_splits=1, reverse=True)
+
+
def test_min_max():
# An example generated function wrapper with possible options
data = [4, 5, 6, None, 1]
diff --git a/r/R/dplyr-functions.R b/r/R/dplyr-functions.R
index bee06a7cb6a..e3ff5cecebd 100644
--- a/r/R/dplyr-functions.R
+++ b/r/R/dplyr-functions.R
@@ -201,17 +201,7 @@ nse_funcs$strsplit <- function(x,
useBytes = FALSE) {
assert_that(is.string(split))
- # The Arrow C++ library does not support splitting a string by a regular
- # expression pattern (ARROW-12608) but the default behavior of
- # base::strsplit() is to interpret the split pattern as a regex
- # (fixed = FALSE). R users commonly pass non-regex split patterns to
- # strsplit() without bothering to set fixed = TRUE. It would be annoying if
- # that didn't work here. So: if fixed = FALSE, let's check the split pattern
- # to see if it is a regex (if it contains any regex metacharacters). If not,
- # then allow to proceed.
- if (!fixed && contains_regex(split)) {
- arrow_not_supported("Regular expression matching in strsplit()")
- }
+ arrow_fun <- ifelse(fixed, "split_pattern", "split_pattern_regex")
# warn when the user specifies both fixed = TRUE and perl = TRUE, for
# consistency with the behavior of base::strsplit()
if (fixed && perl) {
@@ -221,7 +211,7 @@ nse_funcs$strsplit <- function(x,
# regardless of the value of perl, for consistency with the behavior of
# base::strsplit()
Expression$create(
- "split_pattern",
+ arrow_fun,
x,
options = list(pattern = split, reverse = FALSE, max_splits = -1L)
)
@@ -229,9 +219,7 @@ nse_funcs$strsplit <- function(x,
nse_funcs$str_split <- function(string, pattern, n = Inf, simplify = FALSE) {
opts <- get_stringr_pattern_options(enexpr(pattern))
- if (!opts$fixed && contains_regex(opts$pattern)) {
- arrow_not_supported("Regular expression matching in str_split()")
- }
+ arrow_fun <- ifelse(opts$fixed, "split_pattern", "split_pattern_regex")
if (opts$ignore_case) {
arrow_not_supported("Case-insensitive string splitting")
}
@@ -249,7 +237,7 @@ nse_funcs$str_split <- function(string, pattern, n = Inf, simplify = FALSE) {
# str_split() controls the maximum number of pieces to return. So we must
# subtract 1 from n to get max_splits.
Expression$create(
- "split_pattern",
+ arrow_fun,
string,
options = list(
pattern =
diff --git a/r/src/compute.cpp b/r/src/compute.cpp
index 0ffe53578c4..59fbb6bfb74 100644
--- a/r/src/compute.cpp
+++ b/r/src/compute.cpp
@@ -233,7 +233,7 @@ std::shared_ptr make_compute_options(
max_replacements);
}
- if (func_name == "split_pattern") {
+ if (func_name == "split_pattern" || func_name == "split_pattern_regex") {
using Options = arrow::compute::SplitPatternOptions;
int64_t max_splits = -1;
if (!Rf_isNull(options["max_splits"])) {
diff --git a/r/tests/testthat/test-dplyr-string-functions.R b/r/tests/testthat/test-dplyr-string-functions.R
index fb5e6752709..bb4794ef4c5 100644
--- a/r/tests/testthat/test-dplyr-string-functions.R
+++ b/r/tests/testthat/test-dplyr-string-functions.R
@@ -271,6 +271,12 @@ test_that("strsplit and str_split", {
collect(),
df
)
+ expect_dplyr_equal(
+ input %>%
+ mutate(x = strsplit(x, " +and +")) %>%
+ collect(),
+ df
+ )
expect_dplyr_equal(
input %>%
mutate(x = str_split(x, "and")) %>%
@@ -295,7 +301,12 @@ test_that("strsplit and str_split", {
collect(),
df
)
-
+ expect_dplyr_equal(
+ input %>%
+ mutate(x = str_split(x, "Foo|bar", n = 2)) %>%
+ collect(),
+ df
+ )
})
test_that("arrow_*_split_whitespace functions", {
@@ -352,21 +363,6 @@ test_that("errors and warnings in string splitting", {
# so here we can just call the functions directly
x <- Expression$field_ref("x")
- expect_error(
- nse_funcs$strsplit(x, "and.*", fixed = FALSE),
- 'Regular expression matching in strsplit() not supported by Arrow',
- fixed = TRUE
- )
- expect_error(
- nse_funcs$str_split(x, "and.?"),
- 'Regular expression matching in str_split() not supported by Arrow',
- fixed = TRUE
- )
- expect_error(
- nse_funcs$str_split(x, regex("and.*")),
- 'Regular expression matching in str_split() not supported by Arrow',
- fixed = TRUE
- )
expect_error(
nse_funcs$str_split(x, fixed("and", ignore_case = TRUE)),
"Case-insensitive string splitting not supported by Arrow"
From 7bbcecbc224b7eff767213aacff706cd159a5b7b Mon Sep 17 00:00:00 2001
From: David Li
Date: Wed, 19 May 2021 08:01:55 -0400
Subject: [PATCH 2/2] ARROW-12608: [C++] Address review feedback
---
cpp/src/arrow/compute/kernels/scalar_string.cc | 2 +-
docs/source/cpp/compute.rst | 2 +-
2 files changed, 2 insertions(+), 2 deletions(-)
diff --git a/cpp/src/arrow/compute/kernels/scalar_string.cc b/cpp/src/arrow/compute/kernels/scalar_string.cc
index 94f9681de70..3c365593c1c 100644
--- a/cpp/src/arrow/compute/kernels/scalar_string.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_string.cc
@@ -1319,7 +1319,7 @@ struct SplitRegexTransform : SplitBaseTransform