Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cpp/src/arrow/compute/api_scalar.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ struct ArithmeticOptions : public FunctionOptions {
struct ARROW_EXPORT MatchSubstringOptions : public FunctionOptions {
explicit MatchSubstringOptions(std::string pattern) : pattern(std::move(pattern)) {}

/// The exact substring to look for inside input values.
/// The exact substring (or regex, depending on kernel) to look for inside input values.
std::string pattern;
};

Expand Down
165 changes: 106 additions & 59 deletions cpp/src/arrow/compute/kernels/scalar_string.cc
Original file line number Diff line number Diff line change
Expand Up @@ -368,83 +368,130 @@ void StringBoolTransform(KernelContext* ctx, const ExecBatch& batch,
}
}

template <typename offset_type>
void TransformMatchSubstring(const uint8_t* pattern, int64_t pattern_length,
const offset_type* offsets, const uint8_t* data,
int64_t length, int64_t output_offset, uint8_t* output) {
// This is an implementation of the Knuth-Morris-Pratt algorithm

// Phase 1: Build the prefix table
std::vector<offset_type> prefix_table(pattern_length + 1);
offset_type prefix_length = -1;
prefix_table[0] = -1;
for (offset_type pos = 0; pos < pattern_length; ++pos) {
// The prefix cannot be expanded, reset.
while (prefix_length >= 0 && pattern[pos] != pattern[prefix_length]) {
prefix_length = prefix_table[prefix_length];
}
prefix_length++;
prefix_table[pos + 1] = prefix_length;
}

// Phase 2: Find the prefix in the data
FirstTimeBitmapWriter bitmap_writer(output, output_offset, length);
for (int64_t i = 0; i < length; ++i) {
const uint8_t* current_data = data + offsets[i];
int64_t current_length = offsets[i + 1] - offsets[i];

int64_t pattern_pos = 0;
for (int64_t k = 0; k < current_length; k++) {
while ((pattern_pos >= 0) && (pattern[pattern_pos] != current_data[k])) {
pattern_pos = prefix_table[pattern_pos];
}
pattern_pos++;
if (pattern_pos == pattern_length) {
bitmap_writer.Set();
break;
}
}
bitmap_writer.Next();
}
bitmap_writer.Finish();
}

using MatchSubstringState = OptionsWrapper<MatchSubstringOptions>;

template <typename Type>
template <typename Type, typename Matcher>
struct MatchSubstring {
using offset_type = typename Type::offset_type;
static void Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
MatchSubstringOptions arg = MatchSubstringState::Get(ctx);
const uint8_t* pat = reinterpret_cast<const uint8_t*>(arg.pattern.c_str());
const int64_t pat_size = arg.pattern.length();
// TODO Cache matcher across invocations (for regex compilation)
Matcher matcher(ctx, MatchSubstringState::Get(ctx));
if (ctx->HasError()) return;
StringBoolTransform<Type>(
ctx, batch,
[pat, pat_size](const void* offsets, const uint8_t* data, int64_t length,
int64_t output_offset, uint8_t* output) {
TransformMatchSubstring<offset_type>(
pat, pat_size, reinterpret_cast<const offset_type*>(offsets), data, length,
output_offset, output);
[&matcher](const void* raw_offsets, const uint8_t* data, int64_t length,
int64_t output_offset, uint8_t* output) {
const offset_type* offsets = reinterpret_cast<const offset_type*>(raw_offsets);
FirstTimeBitmapWriter bitmap_writer(output, output_offset, length);
for (int64_t i = 0; i < length; ++i) {
const char* current_data = reinterpret_cast<const char*>(data + offsets[i]);
int64_t current_length = offsets[i + 1] - offsets[i];
if (matcher.Match(util::string_view(current_data, current_length))) {
bitmap_writer.Set();
}
bitmap_writer.Next();
}
bitmap_writer.Finish();
},
out);
}
};

// This is an implementation of the Knuth-Morris-Pratt algorithm
struct PlainSubstringMatcher {
const MatchSubstringOptions& options_;
std::vector<int64_t> prefix_table;

PlainSubstringMatcher(KernelContext* ctx, const MatchSubstringOptions& options)
: options_(options) {
// Phase 1: Build the prefix table
const auto pattern_length = options_.pattern.size();
prefix_table.resize(pattern_length + 1, /*value=*/0);
int64_t prefix_length = -1;
prefix_table[0] = -1;
for (size_t pos = 0; pos < pattern_length; ++pos) {
// The prefix cannot be expanded, reset.
while (prefix_length >= 0 &&
options_.pattern[pos] != options_.pattern[prefix_length]) {
prefix_length = prefix_table[prefix_length];
}
prefix_length++;
prefix_table[pos + 1] = prefix_length;
}
}

bool Match(util::string_view current) {
// Phase 2: Find the prefix in the data
const auto pattern_length = options_.pattern.size();
int64_t pattern_pos = 0;
for (const auto c : current) {
while ((pattern_pos >= 0) && (options_.pattern[pattern_pos] != c)) {
pattern_pos = prefix_table[pattern_pos];
}
pattern_pos++;
if (static_cast<size_t>(pattern_pos) == pattern_length) {
return true;
}
}
return false;
}
};

const FunctionDoc match_substring_doc(
"Match strings against literal pattern",
("For each string in `strings`, emit true iff it contains a given pattern.\n"
"Null inputs emit null. The pattern must be given in MatchSubstringOptions."),
{"strings"}, "MatchSubstringOptions");

#ifdef ARROW_WITH_RE2
struct RegexSubstringMatcher {
const MatchSubstringOptions& options_;
const RE2 regex_match_;

RegexSubstringMatcher(KernelContext* ctx, const MatchSubstringOptions& options)
: options_(options), regex_match_(options_.pattern) {
if (!regex_match_.ok()) {
ctx->SetStatus(Status::Invalid("Regular expression error"));
}
}

bool Match(util::string_view current) {
auto piece = re2::StringPiece(current.data(), current.length());
return re2::RE2::PartialMatch(piece, regex_match_);
}
};

const FunctionDoc match_substring_regex_doc(
"Match strings against regex pattern",
("For each string in `strings`, emit true iff it matches a given pattern at any "
"position.\n"
"Null inputs emit null. The pattern must be given in MatchSubstringOptions."),
{"strings"}, "MatchSubstringOptions");
#endif

void AddMatchSubstring(FunctionRegistry* registry) {
auto func = std::make_shared<ScalarFunction>("match_substring", Arity::Unary(),
&match_substring_doc);
auto exec_32 = MatchSubstring<StringType>::Exec;
auto exec_64 = MatchSubstring<LargeStringType>::Exec;
DCHECK_OK(func->AddKernel({utf8()}, boolean(), exec_32, MatchSubstringState::Init));
DCHECK_OK(
func->AddKernel({large_utf8()}, boolean(), exec_64, MatchSubstringState::Init));
DCHECK_OK(registry->AddFunction(std::move(func)));
{
auto func = std::make_shared<ScalarFunction>("match_substring", Arity::Unary(),
&match_substring_doc);
auto exec_32 = MatchSubstring<StringType, PlainSubstringMatcher>::Exec;
auto exec_64 = MatchSubstring<LargeStringType, PlainSubstringMatcher>::Exec;
DCHECK_OK(func->AddKernel({utf8()}, boolean(), exec_32, MatchSubstringState::Init));
DCHECK_OK(
func->AddKernel({large_utf8()}, boolean(), exec_64, MatchSubstringState::Init));
DCHECK_OK(registry->AddFunction(std::move(func)));
}
#ifdef ARROW_WITH_RE2
{
auto func = std::make_shared<ScalarFunction>("match_substring_regex", Arity::Unary(),
&match_substring_regex_doc);
auto exec_32 = MatchSubstring<StringType, RegexSubstringMatcher>::Exec;
auto exec_64 = MatchSubstring<LargeStringType, RegexSubstringMatcher>::Exec;
DCHECK_OK(func->AddKernel({utf8()}, boolean(), exec_32, MatchSubstringState::Init));
DCHECK_OK(
func->AddKernel({large_utf8()}, boolean(), exec_64, MatchSubstringState::Init));
DCHECK_OK(registry->AddFunction(std::move(func)));
}
#endif
}

// IsAlpha/Digit etc
Expand Down Expand Up @@ -1246,7 +1293,7 @@ struct ReplaceSubString {
using State = OptionsWrapper<ReplaceSubstringOptions>;

static void Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
// TODO Cache replacer accross invocations (for regex compilation)
// TODO Cache replacer across invocations (for regex compilation)
Replacer replacer{ctx, State::Get(ctx)};
if (!ctx->HasError()) {
Replace(ctx, batch, &replacer, out);
Expand Down
21 changes: 21 additions & 0 deletions cpp/src/arrow/compute/kernels/scalar_string_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,27 @@ TYPED_TEST(TestStringKernels, MatchSubstring) {
&options_double_char_2);
}

#ifdef ARROW_WITH_RE2
TYPED_TEST(TestStringKernels, MatchSubstringRegex) {
MatchSubstringOptions options{"ab"};
this->CheckUnary("match_substring_regex", "[]", boolean(), "[]", &options);
this->CheckUnary("match_substring_regex", R"(["abc", "acb", "cab", null, "bac"])",
boolean(), "[true, false, true, null, false]", &options);
MatchSubstringOptions options_repeated{"(ab){2}"};
this->CheckUnary("match_substring_regex", R"(["abab", "ab", "cababc", null, "bac"])",
boolean(), "[true, false, true, null, false]", &options_repeated);
MatchSubstringOptions options_digit{"\\d"};
this->CheckUnary("match_substring_regex", R"(["aacb", "a2ab", "", "24"])", boolean(),
"[false, true, false, true]", &options_digit);
MatchSubstringOptions options_star{"a*b"};
this->CheckUnary("match_substring_regex", R"(["aacb", "aab", "dab", "caaab", "b", ""])",
boolean(), "[true, true, true, true, true, false]", &options_star);
MatchSubstringOptions options_plus{"a+b"};
this->CheckUnary("match_substring_regex", R"(["aacb", "aab", "dab", "caaab", "b", ""])",
boolean(), "[false, true, true, true, false, false]", &options_plus);
}
#endif

TYPED_TEST(TestStringKernels, SplitBasics) {
SplitPatternOptions options{" "};
// basics
Expand Down
31 changes: 18 additions & 13 deletions docs/source/cpp/compute.rst
Original file line number Diff line number Diff line change
Expand Up @@ -522,26 +522,31 @@ These functions trim off characters on both sides (trim), or the left (ltrim) or
Containment tests
~~~~~~~~~~~~~~~~~

+--------------------+------------+------------------------------------+---------------+----------------------------------------+
| Function name | Arity | Input types | Output type | Options class |
+====================+============+====================================+===============+========================================+
| match_substring | Unary | String-like | Boolean (1) | :struct:`MatchSubstringOptions` |
+--------------------+------------+------------------------------------+---------------+----------------------------------------+
| index_in | Unary | Boolean, Null, Numeric, Temporal, | Int32 (2) | :struct:`SetLookupOptions` |
| | | Binary- and String-like | | |
+--------------------+------------+------------------------------------+---------------+----------------------------------------+
| is_in | Unary | Boolean, Null, Numeric, Temporal, | Boolean (3) | :struct:`SetLookupOptions` |
| | | Binary- and String-like | | |
+--------------------+------------+------------------------------------+---------------+----------------------------------------+
+---------------------------+------------+------------------------------------+---------------+----------------------------------------+
| Function name | Arity | Input types | Output type | Options class |
+===========================+============+====================================+===============+========================================+
| match_substring | Unary | String-like | Boolean (1) | :struct:`MatchSubstringOptions` |
+---------------------------+------------+------------------------------------+---------------+----------------------------------------+
| match_substring_regex | Unary | String-like | Boolean (2) | :struct:`MatchSubstringOptions` |
+---------------------------+------------+------------------------------------+---------------+----------------------------------------+
| index_in | Unary | Boolean, Null, Numeric, Temporal, | Int32 (3) | :struct:`SetLookupOptions` |
| | | Binary- and String-like | | |
+---------------------------+------------+------------------------------------+---------------+----------------------------------------+
| is_in | Unary | Boolean, Null, Numeric, Temporal, | Boolean (4) | :struct:`SetLookupOptions` |
| | | Binary- and String-like | | |
+---------------------------+------------+------------------------------------+---------------+----------------------------------------+

* \(1) Output is true iff :member:`MatchSubstringOptions::pattern`
is a substring of the corresponding input element.

* \(2) Output is the index of the corresponding input element in
* \(2) Output is true iff :member:`MatchSubstringOptions::pattern`
matches the corresponding input element at any position.

* \(3) Output is the index of the corresponding input element in
:member:`SetLookupOptions::value_set`, if found there. Otherwise,
output is null.

* \(3) Output is true iff the corresponding input element is equal to one
* \(4) Output is true iff the corresponding input element is equal to one
of the elements in :member:`SetLookupOptions::value_set`.


Expand Down
1 change: 1 addition & 0 deletions docs/source/python/api/compute.rst
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ Containment tests
index_in
is_in
match_substring
match_substring_regex

Conversions
-----------
Expand Down
18 changes: 18 additions & 0 deletions python/pyarrow/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,24 @@ def match_substring(array, pattern):
MatchSubstringOptions(pattern))


def match_substring_regex(array, pattern):
"""
Test if regex *pattern* matches at any position a value of a string array.

Parameters
----------
array : pyarrow.Array or pyarrow.ChunkedArray
pattern : str
regex pattern to search

Returns
-------
result : pyarrow.Array or pyarrow.ChunkedArray
"""
return call_function("match_substring_regex", [array],
MatchSubstringOptions(pattern))


def sum(array):
"""
Sum the values in a numerical (chunked) array.
Expand Down
7 changes: 7 additions & 0 deletions python/pyarrow/tests/test_compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,13 @@ def test_match_substring():
assert expected.equals(result)


def test_match_substring_regex():
arr = pa.array(["ab", "abc", "ba", "c", None])
result = pc.match_substring_regex(arr, "^a?b")
expected = pa.array([True, True, True, False, None])
assert expected.equals(result)


def test_trim():
# \u3000 is unicode whitespace
arr = pa.array([" foo", None, " \u3000foo bar \t"])
Expand Down