diff --git a/cpp/src/arrow/compute/api_scalar.h b/cpp/src/arrow/compute/api_scalar.h index 6032f656c4a..53892ff6b3c 100644 --- a/cpp/src/arrow/compute/api_scalar.h +++ b/cpp/src/arrow/compute/api_scalar.h @@ -71,7 +71,9 @@ struct ARROW_EXPORT SplitPatternOptions : public SplitOptions { struct ARROW_EXPORT ReplaceSubstringOptions : public FunctionOptions { explicit ReplaceSubstringOptions(std::string pattern, std::string replacement, int64_t max_replacements = -1) - : pattern(pattern), replacement(replacement), max_replacements(max_replacements) {} + : pattern(std::move(pattern)), + replacement(std::move(replacement)), + max_replacements(max_replacements) {} /// Pattern to match, literal, or regular expression depending on which kernel is used std::string pattern; @@ -81,6 +83,13 @@ struct ARROW_EXPORT ReplaceSubstringOptions : public FunctionOptions { int64_t max_replacements; }; +struct ARROW_EXPORT ExtractRegexOptions : public FunctionOptions { + explicit ExtractRegexOptions(std::string pattern) : pattern(std::move(pattern)) {} + + /// Regular expression with named capture fields + std::string pattern; +}; + /// Options for IsIn and IndexIn functions struct ARROW_EXPORT SetLookupOptions : public FunctionOptions { explicit SetLookupOptions(Datum value_set, bool skip_nulls = false) diff --git a/cpp/src/arrow/compute/kernels/scalar_string.cc b/cpp/src/arrow/compute/kernels/scalar_string.cc index 9ec1fe005d4..d5473749fe1 100644 --- a/cpp/src/arrow/compute/kernels/scalar_string.cc +++ b/cpp/src/arrow/compute/kernels/scalar_string.cc @@ -17,6 +17,7 @@ #include #include +#include #include #ifdef ARROW_WITH_UTF8PROC @@ -30,17 +31,40 @@ #include "arrow/array/builder_binary.h" #include "arrow/array/builder_nested.h" #include "arrow/buffer_builder.h" + +#include "arrow/builder.h" #include "arrow/compute/api_scalar.h" #include "arrow/compute/kernels/common.h" +#include "arrow/util/checked_cast.h" #include "arrow/util/utf8.h" #include "arrow/util/value_parsing.h" namespace 
arrow { + +using internal::checked_cast; + namespace compute { namespace internal { namespace { +#ifdef ARROW_WITH_RE2 +util::string_view ToStringView(re2::StringPiece piece) { + return {piece.data(), piece.length()}; +} + +re2::StringPiece ToStringPiece(util::string_view view) { + return {view.data(), view.length()}; +} + +Status RegexStatus(const RE2& regex) { + if (!regex.ok()) { + return Status::Invalid("Invalid regular expression: ", regex.error()); + } + return Status::OK(); +} +#endif + // Code units in the range [a-z] can only be an encoding of an ascii // character/codepoint, not the 2nd, 3rd or 4th code unit (byte) of an different // codepoint. This guaranteed by non-overlap design of the unicode standard. (see @@ -449,10 +473,8 @@ struct RegexSubstringMatcher { const RE2 regex_match_; RegexSubstringMatcher(KernelContext* ctx, const MatchSubstringOptions& options) - : options_(options), regex_match_(options_.pattern) { - if (!regex_match_.ok()) { - ctx->SetStatus(Status::Invalid("Regular expression error")); - } + : options_(options), regex_match_(options_.pattern, RE2::Quiet) { + KERNEL_RETURN_IF_ERROR(ctx, RegexStatus(regex_match_)); } bool Match(util::string_view current) { @@ -1390,16 +1412,21 @@ struct RegexSubStringReplacer { // we have 2 regexes, one with () around it, one without. 
RegexSubStringReplacer(KernelContext* ctx, const ReplaceSubstringOptions& options) : options_(options), - regex_find_("(" + options_.pattern + ")"), - regex_replacement_(options_.pattern) { - if (!(regex_find_.ok() && regex_replacement_.ok())) { - ctx->SetStatus(Status::Invalid("Regular expression error")); - return; + regex_find_("(" + options_.pattern + ")", RE2::Quiet), + regex_replacement_(options_.pattern, RE2::Quiet) { + KERNEL_RETURN_IF_ERROR(ctx, RegexStatus(regex_find_)); + KERNEL_RETURN_IF_ERROR(ctx, RegexStatus(regex_replacement_)); + std::string replacement_error; + if (!regex_replacement_.CheckRewriteString(options_.replacement, + &replacement_error)) { + ctx->SetStatus( + Status::Invalid("Invalid replacement string: ", std::move(replacement_error))); } } Status ReplaceString(util::string_view s, TypedBufferBuilder* builder) { re2::StringPiece replacement(options_.replacement); + if (options_.max_replacements == -1) { std::string s_copy(s.to_string()); re2::RE2::GlobalReplace(&s_copy, regex_replacement_, replacement); @@ -1472,6 +1499,204 @@ const FunctionDoc replace_substring_regex_doc( {"strings"}, "ReplaceSubstringOptions"); #endif +// ---------------------------------------------------------------------- +// Extract with regex + +#ifdef ARROW_WITH_RE2 + +// TODO cache this once per ExtractRegexOptions +struct ExtractRegexData { + // Use unique_ptr<> because RE2 is non-movable + std::unique_ptr regex; + std::vector group_names; + + static Result Make(const ExtractRegexOptions& options) { + ExtractRegexData data(options.pattern); + RETURN_NOT_OK(RegexStatus(*data.regex)); + + const int group_count = data.regex->NumberOfCapturingGroups(); + const auto& name_map = data.regex->CapturingGroupNames(); + data.group_names.reserve(group_count); + + for (int i = 0; i < group_count; i++) { + auto item = name_map.find(i + 1); // re2 starts counting from 1 + if (item == name_map.end()) { + // XXX should we instead just create fields with an empty name? 
+ return Status::Invalid("Regular expression contains unnamed groups"); + } + data.group_names.emplace_back(item->second); + } + return std::move(data); + } + + Result ResolveOutputType(const std::vector& args) const { + const auto& input_type = args[0].type; + if (input_type == nullptr) { + // No input type specified => propagate shape + return args[0]; + } + // Input type is either String or LargeString and is also the type of each + // field in the output struct type. + DCHECK(input_type->id() == Type::STRING || input_type->id() == Type::LARGE_STRING); + FieldVector fields; + fields.reserve(group_names.size()); + std::transform(group_names.begin(), group_names.end(), std::back_inserter(fields), + [&](const std::string& name) { return field(name, input_type); }); + return struct_(std::move(fields)); + } + + private: + explicit ExtractRegexData(const std::string& pattern) + : regex(new RE2(pattern, RE2::Quiet)) {} +}; + +Result ResolveExtractRegexOutput(KernelContext* ctx, + const std::vector& args) { + using State = OptionsWrapper; + ExtractRegexOptions options = State::Get(ctx); + ARROW_ASSIGN_OR_RAISE(auto data, ExtractRegexData::Make(options)); + return data.ResolveOutputType(args); +} + +struct ExtractRegexBase { + const ExtractRegexData& data; + const int group_count; + std::vector found_values; + std::vector args; + std::vector args_pointers; + const re2::RE2::Arg** args_pointers_start; + const re2::RE2::Arg* null_arg = nullptr; + + explicit ExtractRegexBase(const ExtractRegexData& data) + : data(data), + group_count(static_cast(data.group_names.size())), + found_values(group_count) { + args.reserve(group_count); + args_pointers.reserve(group_count); + + for (int i = 0; i < group_count; i++) { + args.emplace_back(&found_values[i]); + // Since we reserved capacity, we're guaranteed the pointer remains valid + args_pointers.push_back(&args[i]); + } + // Avoid null pointer if there is no capture group + args_pointers_start = (group_count > 0) ? 
args_pointers.data() : &null_arg; + } + + bool Match(util::string_view s) { + return re2::RE2::PartialMatchN(ToStringPiece(s), *data.regex, args_pointers_start, + group_count); + } +}; + +template +struct ExtractRegex : public ExtractRegexBase { + using ArrayType = typename TypeTraits::ArrayType; + using ScalarType = typename TypeTraits::ScalarType; + using BuilderType = typename TypeTraits::BuilderType; + using State = OptionsWrapper; + + using ExtractRegexBase::ExtractRegexBase; + + static void Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + ExtractRegexOptions options = State::Get(ctx); + KERNEL_ASSIGN_OR_RAISE(auto data, ctx, ExtractRegexData::Make(options)); + ExtractRegex{data}.Extract(ctx, batch, out); + } + + void Extract(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + KERNEL_ASSIGN_OR_RAISE(auto descr, ctx, + data.ResolveOutputType(batch.GetDescriptors())); + DCHECK_NE(descr.type, nullptr); + const auto& type = descr.type; + + if (batch[0].kind() == Datum::ARRAY) { + std::unique_ptr array_builder; + KERNEL_RETURN_IF_ERROR(ctx, MakeBuilder(ctx->memory_pool(), type, &array_builder)); + StructBuilder* struct_builder = checked_cast(array_builder.get()); + + std::vector field_builders; + field_builders.reserve(group_count); + for (int i = 0; i < group_count; i++) { + field_builders.push_back( + checked_cast(struct_builder->field_builder(i))); + } + + auto visit_null = [&]() { + for (int i = 0; i < group_count; i++) { + RETURN_NOT_OK(field_builders[i]->AppendEmptyValue()); + } + return struct_builder->AppendNull(); + }; + auto visit_value = [&](util::string_view s) { + if (Match(s)) { + for (int i = 0; i < group_count; i++) { + RETURN_NOT_OK(field_builders[i]->Append(ToStringView(found_values[i]))); + } + return struct_builder->Append(); + } else { + return visit_null(); + } + }; + const ArrayData& input = *batch[0].array(); + KERNEL_RETURN_IF_ERROR(ctx, + VisitArrayDataInline(input, visit_value, visit_null)); + + std::shared_ptr 
out_array; + KERNEL_RETURN_IF_ERROR(ctx, struct_builder->Finish(&out_array)); + *out = std::move(out_array); + } else { + const auto& input = checked_cast(*batch[0].scalar()); + auto result = std::make_shared(type); + if (input.is_valid && Match(util::string_view(*input.value))) { + result->value.reserve(group_count); + for (int i = 0; i < group_count; i++) { + result->value.push_back( + std::make_shared(found_values[i].as_string())); + } + result->is_valid = true; + } else { + result->is_valid = false; + } + out->value = std::move(result); + } + } +}; + +const FunctionDoc extract_regex_doc( + "Extract substrings captured by a regex pattern", + ("For each string in `strings`, match the regular expression and, if\n" + "successful, emit a struct with field names and values coming from the\n" + "regular expression's named capture groups. If the input is null or the\n" + "regular expression fails matching, a null output value is emitted.\n" + "\n" + "Regular expression matching is done using the Google RE2 library."), + {"strings"}, "ExtractRegexOptions"); + +void AddExtractRegex(FunctionRegistry* registry) { + auto func = std::make_shared("extract_regex", Arity::Unary(), + &extract_regex_doc); + using t32 = ExtractRegex; + using t64 = ExtractRegex; + OutputType out_ty(ResolveExtractRegexOutput); + ScalarKernel kernel; + + // Null values will be computed based on regex match or not + kernel.null_handling = NullHandling::COMPUTED_NO_PREALLOCATE; + kernel.mem_allocation = MemAllocation::NO_PREALLOCATE; + kernel.signature.reset(new KernelSignature({utf8()}, out_ty)); + kernel.exec = t32::Exec; + kernel.init = t32::State::Init; + DCHECK_OK(func->AddKernel(kernel)); + kernel.signature.reset(new KernelSignature({large_utf8()}, out_ty)); + kernel.exec = t64::Exec; + kernel.init = t64::State::Init; + DCHECK_OK(func->AddKernel(kernel)); + + DCHECK_OK(registry->AddFunction(std::move(func))); +} +#endif // ARROW_WITH_RE2 + // 
---------------------------------------------------------------------- // strptime string parsing @@ -2153,6 +2378,7 @@ void RegisterScalarStringAscii(FunctionRegistry* registry) { MakeUnaryStringBatchKernelWithState( "replace_substring_regex", registry, &replace_substring_regex_doc, MemAllocation::NO_PREALLOCATE); + AddExtractRegex(registry); #endif AddStrptime(registry); } diff --git a/cpp/src/arrow/compute/kernels/scalar_string_test.cc b/cpp/src/arrow/compute/kernels/scalar_string_test.cc index 2dd0a4d8c74..577493913b5 100644 --- a/cpp/src/arrow/compute/kernels/scalar_string_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_string_test.cc @@ -17,6 +17,7 @@ #include +#include #include #ifdef ARROW_WITH_UTF8PROC @@ -366,6 +367,26 @@ TYPED_TEST(TestStringKernels, MatchSubstringRegex) { MatchSubstringOptions options_plus{"a+b"}; this->CheckUnary("match_substring_regex", R"(["aacb", "aab", "dab", "caaab", "b", ""])", boolean(), "[false, true, true, true, false, false]", &options_plus); + + // Unicode character semantics + // "\pL" means: unicode category "letter" + // (re2 interprets "\w" as ASCII-only: https://github.com/google/re2/wiki/Syntax) + MatchSubstringOptions options_unicode{"^\\pL+$"}; + this->CheckUnary("match_substring_regex", R"(["été", "ß", "€", ""])", boolean(), + "[true, true, false, false]", &options_unicode); +} + +TYPED_TEST(TestStringKernels, MatchSubstringRegexNoOptions) { + Datum input = ArrayFromJSON(this->type(), "[]"); + ASSERT_RAISES(Invalid, CallFunction("match_substring_regex", {input})); +} + +TYPED_TEST(TestStringKernels, MatchSubstringRegexInvalid) { + Datum input = ArrayFromJSON(this->type(), "[null]"); + MatchSubstringOptions options{"invalid["}; + EXPECT_RAISES_WITH_MESSAGE_THAT( + Invalid, ::testing::HasSubstr("Invalid regular expression: missing ]"), + CallFunction("match_substring_regex", {input}, &options)); } #endif @@ -495,6 +516,63 @@ TYPED_TEST(TestStringKernels, ReplaceSubstringRegexNoOptions) { Datum input = 
ArrayFromJSON(this->type(), "[]"); ASSERT_RAISES(Invalid, CallFunction("replace_substring_regex", {input})); } + +TYPED_TEST(TestStringKernels, ReplaceSubstringRegexInvalid) { + Datum input = ArrayFromJSON(this->type(), R"(["foo"])"); + ReplaceSubstringOptions options{"invalid[", ""}; + EXPECT_RAISES_WITH_MESSAGE_THAT( + Invalid, ::testing::HasSubstr("Invalid regular expression: missing ]"), + CallFunction("replace_substring_regex", {input}, &options)); + + // Capture group number out of range + options = ReplaceSubstringOptions{"(.)", "\\9"}; + EXPECT_RAISES_WITH_MESSAGE_THAT( + Invalid, ::testing::HasSubstr("Invalid replacement string"), + CallFunction("replace_substring_regex", {input}, &options)); +} + +TYPED_TEST(TestStringKernels, ExtractRegex) { + ExtractRegexOptions options{"(?P<letter>[ab])(?P<digit>\\d)"}; + auto type = struct_({field("letter", this->type()), field("digit", this->type())}); + this->CheckUnary("extract_regex", R"([])", type, R"([])", &options); + this->CheckUnary( + "extract_regex", R"(["a1", "b2", "c3", null])", type, + R"([{"letter": "a", "digit": "1"}, {"letter": "b", "digit": "2"}, null, null])", + &options); + this->CheckUnary("extract_regex", R"(["a1", "b2"])", type, + R"([{"letter": "a", "digit": "1"}, {"letter": "b", "digit": "2"}])", + &options); + this->CheckUnary("extract_regex", R"(["a1", "zb3z"])", type, + R"([{"letter": "a", "digit": "1"}, {"letter": "b", "digit": "3"}])", + &options); +} + +TYPED_TEST(TestStringKernels, ExtractRegexNoCapture) { + // XXX Should we accept this or is it a user error? 
+ ExtractRegexOptions options{"foo"}; + auto type = struct_({}); + this->CheckUnary("extract_regex", R"(["oofoo", "bar", null])", type, + R"([{}, null, null])", &options); +} + +TYPED_TEST(TestStringKernels, ExtractRegexNoOptions) { + Datum input = ArrayFromJSON(this->type(), "[]"); + ASSERT_RAISES(Invalid, CallFunction("extract_regex", {input})); +} + +TYPED_TEST(TestStringKernels, ExtractRegexInvalid) { + Datum input = ArrayFromJSON(this->type(), "[]"); + ExtractRegexOptions options{"invalid["}; + EXPECT_RAISES_WITH_MESSAGE_THAT( + Invalid, ::testing::HasSubstr("Invalid regular expression: missing ]"), + CallFunction("extract_regex", {input}, &options)); + + options = ExtractRegexOptions{"(.)"}; + EXPECT_RAISES_WITH_MESSAGE_THAT( + Invalid, ::testing::HasSubstr("Regular expression contains unnamed groups"), + CallFunction("extract_regex", {input}, &options)); +} + #endif TYPED_TEST(TestStringKernels, Strptime) { diff --git a/cpp/src/arrow/compute/kernels/test_util.cc b/cpp/src/arrow/compute/kernels/test_util.cc index a8a0c8b95f3..11d5e76d342 100644 --- a/cpp/src/arrow/compute/kernels/test_util.cc +++ b/cpp/src/arrow/compute/kernels/test_util.cc @@ -118,7 +118,7 @@ void CheckScalar(std::string func_name, const ArrayVector& inputs, expected->Slice(2 * slice_length), options); } - // should also work with an empty slice + // Should also work with an empty slice CheckScalarNonRecursive(func_name, SliceAll(inputs, 0, 0), expected->Slice(0, 0), options); diff --git a/docs/source/cpp/compute.rst b/docs/source/cpp/compute.rst index 92ac8886f87..fb50f8cef65 100644 --- a/docs/source/cpp/compute.rst +++ b/docs/source/cpp/compute.rst @@ -580,6 +580,21 @@ when a positive ``max_splits`` is given. as separator. 
+String extraction +~~~~~~~~~~~~~~~~~ + ++--------------------+------------+------------------------------------+---------------+----------------------------------------+ +| Function name | Arity | Input types | Output type | Options class | ++====================+============+====================================+===============+========================================+ +| extract_regex | Unary | String-like | Struct (1) | :struct:`ExtractRegexOptions` | ++--------------------+------------+------------------------------------+---------------+----------------------------------------+ + +* \(1) Extract substrings defined by a regular expression using the Google RE2 + library. The output struct field names refer to the named capture groups, + e.g. 'letter' and 'digit' for the regular expression + ``(?P<letter>[ab])(?P<digit>\\d)``. + + Structural transforms ~~~~~~~~~~~~~~~~~~~~~ diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx index 1515bdcfd36..3af485343f2 100644 --- a/python/pyarrow/_compute.pyx +++ b/python/pyarrow/_compute.pyx @@ -704,6 +704,23 @@ class ReplaceSubstringOptions(_ReplaceSubstringOptions): self._set_options(pattern, replacement, max_replacements) +cdef class _ExtractRegexOptions(FunctionOptions): + cdef: + unique_ptr[CExtractRegexOptions] extract_regex_options + + cdef const CFunctionOptions* get_options(self) except NULL: + return self.extract_regex_options.get() + + def _set_options(self, pattern): + self.extract_regex_options.reset( + new CExtractRegexOptions(tobytes(pattern))) + + +class ExtractRegexOptions(_ExtractRegexOptions): + def __init__(self, pattern): + self._set_options(pattern) + + cdef class _FilterOptions(FunctionOptions): cdef: unique_ptr[CFilterOptions] filter_options diff --git a/python/pyarrow/compute.py b/python/pyarrow/compute.py index 3928b9cb904..ec38710b023 100644 --- a/python/pyarrow/compute.py +++ b/python/pyarrow/compute.py @@ -33,6 +33,7 @@ CastOptions, CountOptions, DictionaryEncodeOptions, + ExtractRegexOptions, 
FilterOptions, MatchSubstringOptions, MinMaxOptions, diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd index ebdcd08334c..45f7c4fee94 100644 --- a/python/pyarrow/includes/libarrow.pxd +++ b/python/pyarrow/includes/libarrow.pxd @@ -1823,6 +1823,11 @@ cdef extern from "arrow/compute/api.h" namespace "arrow::compute" nogil: c_string replacement int64_t max_replacements + cdef cppclass CExtractRegexOptions \ + "arrow::compute::ExtractRegexOptions"(CFunctionOptions): + CExtractRegexOptions(c_string pattern) + c_string pattern + cdef cppclass CCastOptions" arrow::compute::CastOptions"(CFunctionOptions): CCastOptions() CCastOptions(c_bool safe) diff --git a/python/pyarrow/tests/test_compute.py b/python/pyarrow/tests/test_compute.py index 94a6189f41c..5ad0d2db91b 100644 --- a/python/pyarrow/tests/test_compute.py +++ b/python/pyarrow/tests/test_compute.py @@ -598,6 +598,13 @@ def test_replace_regex(): assert ar.tolist() == ['f00', 'm00d', None] +def test_extract_regex(): + ar = pa.array(['a1', 'zb2z']) + struct = pc.extract_regex(ar, pattern=r'(?P<letter>[ab])(?P<digit>\d)') + assert struct.tolist() == [{'letter': 'a', 'digit': '1'}, { + 'letter': 'b', 'digit': '2'}] + + @pytest.mark.parametrize(('ty', 'values'), all_array_types) def test_take(ty, values): arr = pa.array(values, type=ty)