Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 83 additions & 5 deletions cpp/src/arrow/compute/kernels/scalar_string.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1004,17 +1004,17 @@ struct SplitBaseTransform {
return Status::OK();
}

static Status CheckOptions(const Options& options) { return Status::OK(); }

static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
Options options = State::Get(ctx);
Derived splitter(options); // we make an instance to reuse the parts vectors
RETURN_NOT_OK(splitter.CheckOptions());
return splitter.Split(ctx, batch, out);
}

Status CheckOptions() { return Status::OK(); }

Status Split(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
EnsureLookupTablesFilled(); // only needed for unicode
RETURN_NOT_OK(Derived::CheckOptions(options));

if (batch[0].kind() == Datum::ARRAY) {
const ArrayData& input = *batch[0].array();
Expand Down Expand Up @@ -1080,8 +1080,8 @@ struct SplitPatternTransform : SplitBaseTransform<Type, ListType, SplitPatternOp
using string_offset_type = typename Type::offset_type;
using Base::Base;

static Status CheckOptions(const SplitPatternOptions& options) {
if (options.pattern.length() == 0) {
Status CheckOptions() {
if (Base::options.pattern.length() == 0) {
return Status::Invalid("Empty separator");
}
return Status::OK();
Expand Down Expand Up @@ -1300,12 +1300,90 @@ void AddSplitWhitespaceUTF8(FunctionRegistry* registry) {
}
#endif

#ifdef ARROW_WITH_RE2
template <typename Type, typename ListType>
struct SplitRegexTransform : SplitBaseTransform<Type, ListType, SplitPatternOptions,
SplitRegexTransform<Type, ListType>> {
using Base = SplitBaseTransform<Type, ListType, SplitPatternOptions,
SplitRegexTransform<Type, ListType>>;
using ArrayType = typename TypeTraits<Type>::ArrayType;
using string_offset_type = typename Type::offset_type;
using ScalarType = typename TypeTraits<Type>::ScalarType;

const RE2 regex_split;

explicit SplitRegexTransform(SplitPatternOptions options)
: Base(options), regex_split(MakePattern(options)) {}

static std::string MakePattern(const SplitPatternOptions& options) {
// RE2 does *not* give you the full match! Must wrap the regex in a capture group
// There is FindAndConsume, but it would give only the end of the separator
std::string pattern = "(";
pattern.reserve(options.pattern.size() + 2);
pattern += options.pattern;
pattern += ')';
return pattern;
}

Status CheckOptions() {
if (Base::options.reverse) {
return Status::NotImplemented("Cannot split in reverse with regex");
}
return RegexStatus(regex_split);
}

bool Find(const uint8_t* begin, const uint8_t* end, const uint8_t** separator_begin,
const uint8_t** separator_end, const SplitOptions& options) {
re2::StringPiece piece(reinterpret_cast<const char*>(begin),
std::distance(begin, end));
// "StringPiece is mutated to point to matched piece"
re2::StringPiece result;
if (!re2::RE2::PartialMatch(piece, regex_split, &result)) {
return false;
}
*separator_begin = reinterpret_cast<const uint8_t*>(result.data());
*separator_end = reinterpret_cast<const uint8_t*>(result.data() + result.size());
return true;
}
bool FindReverse(const uint8_t* begin, const uint8_t* end,
const uint8_t** separator_begin, const uint8_t** separator_end,
const SplitOptions& options) {
// Not easily supportable, unfortunately
return false;
}
};

const FunctionDoc split_pattern_regex_doc(
"Split string according to regex pattern",
("Split each string according to the regex `pattern` defined in\n"
"SplitPatternOptions. The output for each string input is a list\n"
"of strings.\n"
"\n"
"The maximum number of splits and direction of splitting\n"
"(forward, reverse) can optionally be defined in SplitPatternOptions."),
{"strings"}, "SplitPatternOptions");

void AddSplitRegex(FunctionRegistry* registry) {
auto func = std::make_shared<ScalarFunction>("split_pattern_regex", Arity::Unary(),
&split_pattern_regex_doc);
using t32 = SplitRegexTransform<StringType, ListType>;
using t64 = SplitRegexTransform<LargeStringType, ListType>;
DCHECK_OK(func->AddKernel({utf8()}, {list(utf8())}, t32::Exec, t32::State::Init));
DCHECK_OK(
func->AddKernel({large_utf8()}, {list(large_utf8())}, t64::Exec, t64::State::Init));
DCHECK_OK(registry->AddFunction(std::move(func)));
}
#endif

void AddSplit(FunctionRegistry* registry) {
AddSplitPattern(registry);
AddSplitWhitespaceAscii(registry);
#ifdef ARROW_WITH_UTF8PROC
AddSplitWhitespaceUTF8(registry);
#endif
#ifdef ARROW_WITH_RE2
AddSplitRegex(registry);
#endif
}

// ----------------------------------------------------------------------
Expand Down
28 changes: 28 additions & 0 deletions cpp/src/arrow/compute/kernels/scalar_string_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -472,6 +472,34 @@ TYPED_TEST(TestStringKernels, SplitWhitespaceUTF8Reverse) {
&options_max);
}

#ifdef ARROW_WITH_RE2
TYPED_TEST(TestStringKernels, SplitRegex) {
SplitPatternOptions options{"a+|b"};

this->CheckUnary(
"split_pattern_regex", R"(["aaaab", "foob", "foo bar", "foo", "AaaaBaaaC", null])",
list(this->type()),
R"([["", "", ""], ["foo", ""], ["foo ", "", "r"], ["foo"], ["A", "B", "C"], null])",
&options);

options.max_splits = 1;
this->CheckUnary(
"split_pattern_regex", R"(["aaaab", "foob", "foo bar", "foo", "AaaaBaaaC", null])",
list(this->type()),
R"([["", "b"], ["foo", ""], ["foo ", "ar"], ["foo"], ["A", "BaaaC"], null])",
&options);
}

TYPED_TEST(TestStringKernels, SplitRegexReverse) {
SplitPatternOptions options{"a+|b", /*max_splits=*/1, /*reverse=*/true};
Datum input = ArrayFromJSON(this->type(), R"(["a"])");

EXPECT_RAISES_WITH_MESSAGE_THAT(
NotImplemented, ::testing::HasSubstr("Cannot split in reverse with regex"),
CallFunction("split_pattern_regex", {input}, &options));
}
#endif

TYPED_TEST(TestStringKernels, ReplaceSubstring) {
ReplaceSubstringOptions options{"foo", "bazz"};
this->CheckUnary("replace_substring", R"(["foo", "this foo that foo", null])",
Expand Down
13 changes: 9 additions & 4 deletions docs/source/cpp/compute.rst
Original file line number Diff line number Diff line change
Expand Up @@ -568,18 +568,23 @@ when a positive ``max_splits`` is given.
+==========================+============+=========================+===================+==================================+=========+
| split_pattern | Unary | String-like | List-like | :struct:`SplitPatternOptions` | \(1) |
+--------------------------+------------+-------------------------+-------------------+----------------------------------+---------+
| utf8_split_whitespace | Unary | String-like | List-like | :struct:`SplitOptions` | \(2) |
| split_pattern_regex | Unary | String-like | List-like | :struct:`SplitPatternOptions` | \(2) |
+--------------------------+------------+-------------------------+-------------------+----------------------------------+---------+
| ascii_split_whitespace | Unary | String-like | List-like | :struct:`SplitOptions` | \(3) |
| utf8_split_whitespace | Unary | String-like | List-like | :struct:`SplitOptions` | \(3) |
+--------------------------+------------+-------------------------+-------------------+----------------------------------+---------+
| ascii_split_whitespace | Unary | String-like | List-like | :struct:`SplitOptions` | \(4) |
+--------------------------+------------+-------------------------+-------------------+----------------------------------+---------+

* \(1) The string is split when an exact pattern is found (the pattern itself
is not included in the output).

* \(2) A non-zero length sequence of Unicode defined whitespace codepoints
* \(2) The string is split when a regex match is found (the matched
substring itself is not included in the output).

* \(3) A non-zero length sequence of Unicode defined whitespace codepoints
is seen as separator.

* \(3) A non-zero length sequence of ASCII defined whitespace bytes
* \(4) A non-zero length sequence of ASCII defined whitespace bytes
(``'\t'``, ``'\n'``, ``'\v'``, ``'\f'``, ``'\r'`` and ``' '``) is seen
as separator.

Expand Down
11 changes: 11 additions & 0 deletions docs/source/python/api/compute.rst
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,17 @@ a byte-by-byte basis.

string_is_ascii

String Splitting
----------------

.. autosummary::
:toctree: ../generated/

split_pattern
split_pattern_regex
ascii_split_whitespace
utf8_split_whitespace

String Transforms
-----------------

Expand Down
16 changes: 16 additions & 0 deletions python/pyarrow/tests/test_compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,22 @@ def test_split_whitespace_ascii():
assert expected.equals(result)


def test_split_pattern_regex():
arr = pa.array(["-foo---bar--", "---foo---b"])
result = pc.split_pattern_regex(arr, pattern="-+")
expected = pa.array([["", "foo", "bar", ""], ["", "foo", "b"]])
assert expected.equals(result)

result = pc.split_pattern_regex(arr, pattern="-+", max_splits=1)
expected = pa.array([["", "foo---bar--"], ["", "foo---b"]])
assert expected.equals(result)

with pytest.raises(NotImplementedError,
match="Cannot split in reverse with regex"):
result = pc.split_pattern_regex(
arr, pattern="---", max_splits=1, reverse=True)


def test_min_max():
# An example generated function wrapper with possible options
data = [4, 5, 6, None, 1]
Expand Down
20 changes: 4 additions & 16 deletions r/R/dplyr-functions.R
Original file line number Diff line number Diff line change
Expand Up @@ -201,17 +201,7 @@ nse_funcs$strsplit <- function(x,
useBytes = FALSE) {
assert_that(is.string(split))

# The Arrow C++ library does not support splitting a string by a regular
# expression pattern (ARROW-12608) but the default behavior of
# base::strsplit() is to interpret the split pattern as a regex
# (fixed = FALSE). R users commonly pass non-regex split patterns to
# strsplit() without bothering to set fixed = TRUE. It would be annoying if
# that didn't work here. So: if fixed = FALSE, let's check the split pattern
# to see if it is a regex (if it contains any regex metacharacters). If not,
# then allow to proceed.
if (!fixed && contains_regex(split)) {
arrow_not_supported("Regular expression matching in strsplit()")
}
arrow_fun <- ifelse(fixed, "split_pattern", "split_pattern_regex")
# warn when the user specifies both fixed = TRUE and perl = TRUE, for
# consistency with the behavior of base::strsplit()
if (fixed && perl) {
Expand All @@ -221,17 +211,15 @@ nse_funcs$strsplit <- function(x,
# regardless of the value of perl, for consistency with the behavior of
# base::strsplit()
Expression$create(
"split_pattern",
arrow_fun,
x,
options = list(pattern = split, reverse = FALSE, max_splits = -1L)
)
}

nse_funcs$str_split <- function(string, pattern, n = Inf, simplify = FALSE) {
opts <- get_stringr_pattern_options(enexpr(pattern))
if (!opts$fixed && contains_regex(opts$pattern)) {
arrow_not_supported("Regular expression matching in str_split()")
}
arrow_fun <- ifelse(opts$fixed, "split_pattern", "split_pattern_regex")
if (opts$ignore_case) {
arrow_not_supported("Case-insensitive string splitting")
}
Expand All @@ -249,7 +237,7 @@ nse_funcs$str_split <- function(string, pattern, n = Inf, simplify = FALSE) {
# str_split() controls the maximum number of pieces to return. So we must
# subtract 1 from n to get max_splits.
Expression$create(
"split_pattern",
arrow_fun,
string,
options = list(
pattern =
Expand Down
2 changes: 1 addition & 1 deletion r/src/compute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ std::shared_ptr<arrow::compute::FunctionOptions> make_compute_options(
max_replacements);
}

if (func_name == "split_pattern") {
if (func_name == "split_pattern" || func_name == "split_pattern_regex") {
using Options = arrow::compute::SplitPatternOptions;
int64_t max_splits = -1;
if (!Rf_isNull(options["max_splits"])) {
Expand Down
28 changes: 12 additions & 16 deletions r/tests/testthat/test-dplyr-string-functions.R
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,12 @@ test_that("strsplit and str_split", {
collect(),
df
)
expect_dplyr_equal(
input %>%
mutate(x = strsplit(x, " +and +")) %>%
collect(),
df
)
expect_dplyr_equal(
input %>%
mutate(x = str_split(x, "and")) %>%
Expand All @@ -295,7 +301,12 @@ test_that("strsplit and str_split", {
collect(),
df
)

expect_dplyr_equal(
input %>%
mutate(x = str_split(x, "Foo|bar", n = 2)) %>%
collect(),
df
)
})

test_that("arrow_*_split_whitespace functions", {
Expand Down Expand Up @@ -352,21 +363,6 @@ test_that("errors and warnings in string splitting", {
# so here we can just call the functions directly

x <- Expression$field_ref("x")
expect_error(
nse_funcs$strsplit(x, "and.*", fixed = FALSE),
'Regular expression matching in strsplit() not supported by Arrow',
fixed = TRUE
)
expect_error(
nse_funcs$str_split(x, "and.?"),
'Regular expression matching in str_split() not supported by Arrow',
fixed = TRUE
)
expect_error(
nse_funcs$str_split(x, regex("and.*")),
'Regular expression matching in str_split() not supported by Arrow',
fixed = TRUE
)
expect_error(
nse_funcs$str_split(x, fixed("and", ignore_case = TRUE)),
"Case-insensitive string splitting not supported by Arrow"
Expand Down