Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 90 additions & 13 deletions cpp/src/arrow/compute/kernels/scalar_string.cc
Original file line number Diff line number Diff line change
Expand Up @@ -980,13 +980,70 @@ struct CountSubstring {
}
};

#ifdef ARROW_WITH_RE2
struct CountSubstringRegex {
std::unique_ptr<RE2> regex_match_;

explicit CountSubstringRegex(const MatchSubstringOptions& options, bool literal = false)
: regex_match_(new RE2(options.pattern,
RegexSubstringMatcher::MakeRE2Options(options, literal))) {}

static Result<CountSubstringRegex> Make(const MatchSubstringOptions& options,
bool literal = false) {
CountSubstringRegex counter(options, literal);
RETURN_NOT_OK(RegexStatus(*counter.regex_match_));
return std::move(counter);
}

template <typename OutValue, typename... Ignored>
OutValue Call(KernelContext*, util::string_view val, Status*) const {
OutValue count = 0;
re2::StringPiece input(val.data(), val.size());
auto last_size = input.size();
while (re2::RE2::FindAndConsume(&input, *regex_match_)) {
count++;
if (last_size == input.size()) {
// 0-length match
if (input.size() > 0) {
input.remove_prefix(1);
} else {
break;
}
}
last_size = input.size();
}
return count;
}
};

// Kernel entry point for "count_substring_regex": compiles the pattern held in
// the kernel state and applies CountSubstringRegex to every non-null input.
template <typename InputType>
struct CountSubstringRegexExec {
  using OffsetType = typename TypeTraits<InputType>::OffsetType;

  static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
    using RegexCountKernel =
        applicator::ScalarUnaryNotNullStateful<OffsetType, InputType, CountSubstringRegex>;

    // Propagates the error status if the counter cannot be built.
    ARROW_ASSIGN_OR_RAISE(auto regex_counter,
                          CountSubstringRegex::Make(MatchSubstringState::Get(ctx)));
    RegexCountKernel kernel{std::move(regex_counter)};
    return kernel.Exec(ctx, batch, out);
  }
};
#endif

template <typename InputType>
struct CountSubstringExec {
using OffsetType = typename TypeTraits<InputType>::OffsetType;
static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
const MatchSubstringOptions& options = MatchSubstringState::Get(ctx);
if (options.ignore_case) {
return Status::NotImplemented("count_substring with ignore_case");
#ifdef ARROW_WITH_RE2
ARROW_ASSIGN_OR_RAISE(auto counter,
CountSubstringRegex::Make(options, /*literal=*/true));
applicator::ScalarUnaryNotNullStateful<OffsetType, InputType, CountSubstringRegex>
kernel{std::move(counter)};
return kernel.Exec(ctx, batch, out);
#else
return Status::NotImplemented("ignore_case requires RE2");
#endif
}
applicator::ScalarUnaryNotNullStateful<OffsetType, InputType, CountSubstring> kernel{
CountSubstring(PlainSubstringMatcher(options))};
Expand All @@ -1001,21 +1058,41 @@ const FunctionDoc count_substring_doc(
"Null inputs emit null. The pattern must be given in MatchSubstringOptions."),
{"strings"}, "MatchSubstringOptions");

#ifdef ARROW_WITH_RE2
// Documentation attached to the "count_substring_regex" scalar function.
const FunctionDoc count_substring_regex_doc(
    "Count occurrences of regex pattern",
    ("For each string in `strings`, emit the number of occurrences of the given "
     "regex pattern.\n"
     "Null inputs emit null. The pattern must be given in MatchSubstringOptions."),
    {"strings"}, "MatchSubstringOptions");
#endif

// Registers "count_substring" and, when built with RE2, "count_substring_regex".
// Both functions get one kernel per base binary type; the output type is int64
// when the input type's offsets are 64 bits wide and int32 otherwise.
void AddCountSubstring(FunctionRegistry* registry) {
  {
    auto func = std::make_shared<ScalarFunction>("count_substring", Arity::Unary(),
                                                 &count_substring_doc);
    for (const auto& ty : BaseBinaryTypes()) {
      auto offset_type = offset_bit_width(ty->id()) == 64 ? int64() : int32();
      DCHECK_OK(func->AddKernel({ty}, offset_type,
                                GenerateTypeAgnosticVarBinaryBase<CountSubstringExec>(ty),
                                MatchSubstringState::Init));
    }
    DCHECK_OK(registry->AddFunction(std::move(func)));
  }
#ifdef ARROW_WITH_RE2
  {
    auto func = std::make_shared<ScalarFunction>("count_substring_regex", Arity::Unary(),
                                                 &count_substring_regex_doc);
    for (const auto& ty : BaseBinaryTypes()) {
      auto offset_type = offset_bit_width(ty->id()) == 64 ? int64() : int32();
      DCHECK_OK(
          func->AddKernel({ty}, offset_type,
                          GenerateTypeAgnosticVarBinaryBase<CountSubstringRegexExec>(ty),
                          MatchSubstringState::Init));
    }
    DCHECK_OK(registry->AddFunction(std::move(func)));
  }
#endif
}

// Slicing
Expand Down
143 changes: 100 additions & 43 deletions cpp/src/arrow/compute/kernels/scalar_string_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -80,49 +80,7 @@ TYPED_TEST(TestBinaryKernels, BinaryLength) {
this->offset_type(), "[3, null, 10, 0, 1]");
}

TYPED_TEST(TestBinaryKernels, FindSubstring) {
MatchSubstringOptions options{"ab"};
this->CheckUnary("find_substring", "[]", this->offset_type(), "[]", &options);
this->CheckUnary("find_substring", R"(["abc", "acb", "cab", null, "bac"])",
this->offset_type(), "[0, -1, 1, null, -1]", &options);

MatchSubstringOptions options_repeated{"abab"};
this->CheckUnary("find_substring", R"(["abab", "ab", "cababc", null, "bac"])",
this->offset_type(), "[0, -1, 1, null, -1]", &options_repeated);

MatchSubstringOptions options_double_char{"aab"};
this->CheckUnary("find_substring", R"(["aacb", "aab", "ab", "aaab"])",
this->offset_type(), "[-1, 0, -1, 1]", &options_double_char);

MatchSubstringOptions options_double_char_2{"bbcaa"};
this->CheckUnary("find_substring", R"(["abcbaabbbcaabccabaab"])", this->offset_type(),
"[7]", &options_double_char_2);

MatchSubstringOptions options_empty{""};
this->CheckUnary("find_substring", R"(["", "a", null])", this->offset_type(),
"[0, 0, null]", &options_empty);
}

TYPED_TEST(TestBinaryKernels, CountSubstring) {
MatchSubstringOptions options{"aba"};
this->CheckUnary("count_substring", "[]", this->offset_type(), "[]", &options);
this->CheckUnary(
"count_substring",
R"(["", null, "ab", "aba", "baba", "ababa", "abaaba", "babacaba", "ABA"])",
this->offset_type(), "[0, null, 0, 1, 1, 1, 2, 2, 0]", &options);

MatchSubstringOptions options_empty{""};
this->CheckUnary("count_substring", R"(["", null, "abc"])", this->offset_type(),
"[1, null, 4]", &options_empty);

MatchSubstringOptions options_repeated{"aaa"};
this->CheckUnary("count_substring", R"(["", "aaaa", "aaaaa", "aaaaaa", "aaá"])",
this->offset_type(), "[0, 1, 1, 2, 0]", &options_repeated);

// TODO: case-insensitive
}

TYPED_TEST(TestBinaryKernels, AsciiReplaceSlice) {
TYPED_TEST(TestBinaryKernels, BinaryReplaceSlice) {
ReplaceSliceOptions options{0, 1, "XX"};
this->CheckUnary("binary_replace_slice", "[]", this->type(), "[]", &options);
this->CheckUnary("binary_replace_slice", R"([null, "", "a", "ab", "abc"])",
Expand Down Expand Up @@ -172,6 +130,105 @@ TYPED_TEST(TestBinaryKernels, AsciiReplaceSlice) {
&options_neg_flip);
}

// find_substring: per the expectations below, each output is the offset of the
// first occurrence of the pattern, -1 when it is absent, and null for null input.
TYPED_TEST(TestBinaryKernels, FindSubstring) {
  const auto out_type = this->offset_type();

  // Simple two-character pattern, including the empty-array case.
  MatchSubstringOptions ab_opts{"ab"};
  this->CheckUnary("find_substring", "[]", out_type, "[]", &ab_opts);
  this->CheckUnary("find_substring", R"(["abc", "acb", "cab", null, "bac"])", out_type,
                   "[0, -1, 1, null, -1]", &ab_opts);

  // Pattern made of a repeated unit.
  MatchSubstringOptions abab_opts{"abab"};
  this->CheckUnary("find_substring", R"(["abab", "ab", "cababc", null, "bac"])",
                   out_type, "[0, -1, 1, null, -1]", &abab_opts);

  // Patterns starting with a doubled character.
  MatchSubstringOptions aab_opts{"aab"};
  this->CheckUnary("find_substring", R"(["aacb", "aab", "ab", "aaab"])", out_type,
                   "[-1, 0, -1, 1]", &aab_opts);

  MatchSubstringOptions bbcaa_opts{"bbcaa"};
  this->CheckUnary("find_substring", R"(["abcbaabbbcaabccabaab"])", out_type, "[7]",
                   &bbcaa_opts);

  // The empty pattern is found at offset 0 of every non-null value.
  MatchSubstringOptions empty_opts{""};
  this->CheckUnary("find_substring", R"(["", "a", null])", out_type, "[0, 0, null]",
                   &empty_opts);
}

// count_substring with a case-sensitive literal pattern.
TYPED_TEST(TestBinaryKernels, CountSubstring) {
  const auto out_type = this->offset_type();

  // Occurrences do not overlap ("ababa" counts once) and "ABA" is not matched.
  MatchSubstringOptions aba_opts{"aba"};
  this->CheckUnary("count_substring", "[]", out_type, "[]", &aba_opts);
  this->CheckUnary(
      "count_substring",
      R"(["", null, "ab", "aba", "baba", "ababa", "abaaba", "babacaba", "ABA"])",
      out_type, "[0, null, 0, 1, 1, 1, 2, 2, 0]", &aba_opts);

  // The empty pattern is counted once per position, including the end.
  MatchSubstringOptions empty_opts{""};
  this->CheckUnary("count_substring", R"(["", null, "abc"])", out_type, "[1, null, 4]",
                   &empty_opts);

  // Runs of the pattern's own character; the last input contains a non-ASCII "á".
  MatchSubstringOptions aaa_opts{"aaa"};
  this->CheckUnary("count_substring", R"(["", "aaaa", "aaaaa", "aaaaaa", "aaá"])",
                   out_type, "[0, 1, 1, 2, 0]", &aaa_opts);
}

#ifdef ARROW_WITH_RE2
// count_substring_regex with case-sensitive patterns. Every check targets the
// regex kernel, including the plain-literal patterns.
TYPED_TEST(TestBinaryKernels, CountSubstringRegex) {
  // A pattern without metacharacters: matches do not overlap.
  MatchSubstringOptions options{"aba"};
  this->CheckUnary("count_substring_regex", "[]", this->offset_type(), "[]", &options);
  this->CheckUnary(
      "count_substring_regex",
      R"(["", null, "ab", "aba", "baba", "ababa", "abaaba", "babacaba", "ABA"])",
      this->offset_type(), "[0, null, 0, 1, 1, 1, 2, 2, 0]", &options);

  // The empty pattern matches once at every position, including the end.
  MatchSubstringOptions options_empty{""};
  this->CheckUnary("count_substring_regex", R"(["", null, "abc"])", this->offset_type(),
                   "[1, null, 4]", &options_empty);

  MatchSubstringOptions options_as{"a+"};
  this->CheckUnary("count_substring_regex", R"(["", "bacaaadaaaa", "c", "AAA"])",
                   this->offset_type(), "[0, 3, 0, 0]", &options_as);

  MatchSubstringOptions options_empty_match{"a*"};
  this->CheckUnary("count_substring_regex", R"(["", "bacaaadaaaa", "c", "AAA"])",
                   // 7 is because it matches at |b|a|c|aaa|d|aaaa|
                   this->offset_type(), "[1, 7, 2, 4]", &options_empty_match);

  MatchSubstringOptions options_repeated{"aaa"};
  this->CheckUnary("count_substring_regex", R"(["", "aaaa", "aaaaa", "aaaaaa", "aaá"])",
                   this->offset_type(), "[0, 1, 1, 2, 0]", &options_repeated);
}

// count_substring with ignore_case=true (RE2 build): mixed-case occurrences of
// the literal pattern are counted, still without overlapping.
TYPED_TEST(TestBinaryKernels, CountSubstringIgnoreCase) {
  const auto out_type = this->offset_type();

  MatchSubstringOptions aba_nocase{"aba", /*ignore_case=*/true};
  this->CheckUnary("count_substring", "[]", out_type, "[]", &aba_nocase);
  this->CheckUnary(
      "count_substring",
      R"(["", null, "ab", "aBa", "bAbA", "aBaBa", "abaAbA", "babacaba", "ABA"])",
      out_type, "[0, null, 0, 1, 1, 1, 2, 2, 1]", &aba_nocase);

  // The empty pattern gives the same counts as in the case-sensitive test.
  MatchSubstringOptions empty_nocase{"", /*ignore_case=*/true};
  this->CheckUnary("count_substring", R"(["", null, "abc"])", out_type, "[1, null, 4]",
                   &empty_nocase);
}

// count_substring_regex with ignore_case=true: runs of "a"/"A" form one match.
TYPED_TEST(TestBinaryKernels, CountSubstringRegexIgnoreCase) {
  const auto out_type = this->offset_type();

  // One or more: "AAA" is a single match.
  MatchSubstringOptions one_or_more{"a+", /*ignore_case=*/true};
  this->CheckUnary("count_substring_regex", R"(["", "bacAaAdaAaA", "c", "AAA"])",
                   out_type, "[0, 3, 0, 1]", &one_or_more);

  // Zero or more: 0-length matches are counted as well.
  MatchSubstringOptions zero_or_more{"a*", /*ignore_case=*/true};
  this->CheckUnary("count_substring_regex", R"(["", "bacAaAdaAaA", "c", "AAA"])",
                   out_type, "[1, 7, 2, 2]", &zero_or_more);
}
#else
// Without RE2, count_substring with ignore_case=true must fail with
// NotImplemented and a message mentioning the missing RE2 dependency.
TYPED_TEST(TestBinaryKernels, CountSubstringIgnoreCase) {
  MatchSubstringOptions nocase_options{"a", /*ignore_case=*/true};
  Datum single_value = ArrayFromJSON(this->type(), R"(["a"])");
  EXPECT_RAISES_WITH_MESSAGE_THAT(
      NotImplemented, ::testing::HasSubstr("ignore_case requires RE2"),
      CallFunction("count_substring", {single_value}, &nocase_options));
}
#endif

// Typed test fixture for the string kernel tests; it adds no members of its
// own beyond what BaseTestStringKernels<TestType> provides.
template <typename TestType>
class TestStringKernels : public BaseTestStringKernels<TestType> {};

Expand Down
2 changes: 2 additions & 0 deletions docs/source/cpp/compute.rst
Original file line number Diff line number Diff line change
Expand Up @@ -572,6 +572,8 @@ Containment tests
+===========================+============+====================================+====================+========================================+
| count_substring | Unary | String-like | Int32 or Int64 (1) | :struct:`MatchSubstringOptions` |
+---------------------------+------------+------------------------------------+--------------------+----------------------------------------+
| count_substring_regex | Unary | String-like | Int32 or Int64 (1) | :struct:`MatchSubstringOptions` |
+---------------------------+------------+------------------------------------+--------------------+----------------------------------------+
| ends_with | Unary | String-like | Boolean (2) | :struct:`MatchSubstringOptions` |
+---------------------------+------------+------------------------------------+--------------------+----------------------------------------+
| find_substring | Unary | String-like | Int32 or Int64 (3) | :struct:`MatchSubstringOptions` |
Expand Down
1 change: 1 addition & 0 deletions docs/source/python/api/compute.rst
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ Containment tests
:toctree: ../generated/

count_substring
count_substring_regex
ends_with
find_substring
index_in
Expand Down
23 changes: 21 additions & 2 deletions python/pyarrow/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ def cast(arr, target_type, safe=True):
return call_function("cast", [arr], options)


def count_substring(array, pattern):
def count_substring(array, pattern, *, ignore_case=False):
"""
Count the occurrences of substring *pattern* in each value of a
string array.
Expand All @@ -308,7 +308,26 @@ def count_substring(array, pattern):
result : pyarrow.Array or pyarrow.ChunkedArray
"""
return call_function("count_substring", [array],
MatchSubstringOptions(pattern))
MatchSubstringOptions(pattern, ignore_case))


def count_substring_regex(array, pattern, *, ignore_case=False):
    """
    Count the non-overlapping matches of regex *pattern* in each value
    of a string array.

    Parameters
    ----------
    array : pyarrow.Array or pyarrow.ChunkedArray
    pattern : str
        regular expression pattern to count matches of
    ignore_case : bool, default False
        whether to match case-insensitively

    Returns
    -------
    result : pyarrow.Array or pyarrow.ChunkedArray
    """
    options = MatchSubstringOptions(pattern, ignore_case)
    return call_function("count_substring_regex", [array], options)


def find_substring(array, pattern):
Expand Down
33 changes: 24 additions & 9 deletions python/pyarrow/tests/test_compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,16 +291,31 @@ def test_variance():


def test_count_substring():
    # Run the same checks for 32-bit and 64-bit offset string types; the
    # result type is expected to follow the offset width.
    for (ty, offset) in [(pa.string(), pa.int32()),
                         (pa.large_string(), pa.int64())]:
        arr = pa.array(["ab", "cab", "abcab", "ba", "AB", None], type=ty)

        # Case-sensitive by default: "AB" is not counted.
        result = pc.count_substring(arr, "ab")
        expected = pa.array([1, 1, 2, 0, 0, None], type=offset)
        assert expected.equals(result)

        # With ignore_case=True, "AB" counts as one occurrence.
        result = pc.count_substring(arr, "ab", ignore_case=True)
        expected = pa.array([1, 1, 2, 0, 1, None], type=offset)
        assert expected.equals(result)


def test_count_substring_regex():
    # Same inputs for both string types; the result type follows the
    # offset width of the input type.
    values = ["ab", "cab", "baAacaa", "ba", "AB", None]
    type_pairs = [(pa.string(), pa.int32()),
                  (pa.large_string(), pa.int64())]

    for ty, offset in type_pairs:
        arr = pa.array(values, type=ty)

        # Case-sensitive: "baAacaa" has three separate runs of "a".
        counted = pc.count_substring_regex(arr, "a+")
        assert pa.array([1, 1, 3, 1, 0, None], type=offset).equals(counted)

        # Case-insensitive: "aAa" merges into one run and "AB" matches once.
        counted = pc.count_substring_regex(arr, "a+", ignore_case=True)
        assert pa.array([1, 1, 2, 1, 1, None], type=offset).equals(counted)


def test_find_substring():
Expand Down