From 6648df07fe3accfabd079be1ceae6a4413b38156 Mon Sep 17 00:00:00 2001 From: "Maarten A. Breddels" Date: Wed, 14 Oct 2020 14:06:51 +0200 Subject: [PATCH 1/6] ARROW-10195: [C++] Add string struct extract kernel using re2 --- cpp/src/arrow/compute/api_scalar.h | 7 + .../arrow/compute/kernels/scalar_string.cc | 130 ++++++++++++++++++ .../compute/kernels/scalar_string_test.cc | 12 ++ docs/source/cpp/compute.rst | 14 ++ python/pyarrow/_compute.pyx | 17 +++ python/pyarrow/compute.py | 1 + python/pyarrow/includes/libarrow.pxd | 5 + python/pyarrow/tests/test_compute.py | 7 + 8 files changed, 193 insertions(+) diff --git a/cpp/src/arrow/compute/api_scalar.h b/cpp/src/arrow/compute/api_scalar.h index 6032f656c4a..4c3e5b39614 100644 --- a/cpp/src/arrow/compute/api_scalar.h +++ b/cpp/src/arrow/compute/api_scalar.h @@ -81,6 +81,13 @@ struct ARROW_EXPORT ReplaceSubstringOptions : public FunctionOptions { int64_t max_replacements; }; +struct ARROW_EXPORT RE2Options : public FunctionOptions { + explicit RE2Options(std::string regex) : regex(regex) {} + + /// Regular expression + std::string regex; +}; + /// Options for IsIn and IndexIn functions struct ARROW_EXPORT SetLookupOptions : public FunctionOptions { explicit SetLookupOptions(Datum value_set, bool skip_nulls = false) diff --git a/cpp/src/arrow/compute/kernels/scalar_string.cc b/cpp/src/arrow/compute/kernels/scalar_string.cc index 9ec1fe005d4..e51da4c5e12 100644 --- a/cpp/src/arrow/compute/kernels/scalar_string.cc +++ b/cpp/src/arrow/compute/kernels/scalar_string.cc @@ -30,6 +30,8 @@ #include "arrow/array/builder_binary.h" #include "arrow/array/builder_nested.h" #include "arrow/buffer_builder.h" + +#include "arrow/builder.h" #include "arrow/compute/api_scalar.h" #include "arrow/compute/kernels/common.h" #include "arrow/util/utf8.h" @@ -1472,6 +1474,131 @@ const FunctionDoc replace_substring_regex_doc( {"strings"}, "ReplaceSubstringOptions"); #endif +// re2 regex + +#ifdef ARROW_WITH_RE2 +template +struct ExtractRE2 { + using ArrayType = typename TypeTraits::ArrayType; + using ScalarType = typename TypeTraits::ScalarType; + using BuilderType = typename TypeTraits::BuilderType; + using State = OptionsWrapper; + + static void Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + RE2Options options = State::Get(ctx); + RE2 regex(options.regex); + + if (!regex.ok()) { + ctx->SetStatus(Status::Invalid("Regular expression error")); + return; + } + std::vector> fields; + int group_count = regex.NumberOfCapturingGroups(); + fields.reserve(group_count); + const std::map name_map = regex.CapturingGroupNames(); + + // We need to pass RE2 a Args* array, which all point to a std::string + std::vector found_values(group_count); + std::vector args; + std::vector args_pointers; + args.reserve(group_count); + args_pointers.reserve(group_count); + + for (int i = 0; i < group_count; i++) { + auto item = name_map.find(i + 1); // re2 starts counting from 1 + if (item == name_map.end()) { + ctx->SetStatus(Status::Invalid("Regular expression contains unnamed groups")); + return; + } + fields.emplace_back(new Field(item->second, batch[0].type())); + args.emplace_back(&found_values[i]); + // since we reserved capacity, we're guaranteed std::vector does not reallocate + // (which would cause the pointer to be invalid) + args_pointers.push_back(&args[i]); + } + auto type = struct_(fields); + + if (batch[0].kind() == Datum::ARRAY) { + std::unique_ptr array_builder_tmp; + MakeBuilder(ctx->memory_pool(), type, &array_builder_tmp); + std::shared_ptr struct_builder; + struct_builder.reset(checked_cast(array_builder_tmp.release())); + + const ArrayData& input = *batch[0].array(); + KERNEL_RETURN_IF_ERROR( + ctx, + VisitArrayDataInline( + input, + [&](util::string_view s) { + re2::StringPiece piece(s.data(), s.length()); + if (re2::RE2::FullMatchN(piece, regex, &args_pointers[0], group_count)) { + for (int i = 0; i < group_count; i++) { + BuilderType* builder = + static_cast(struct_builder->field_builder(i)); + RETURN_NOT_OK(builder->Append(found_values[i])); + } + RETURN_NOT_OK(struct_builder->Append()); + } else { + RETURN_NOT_OK(struct_builder->AppendNull()); + } + return Status::OK(); + }, + [&]() { + RETURN_NOT_OK(struct_builder->AppendNull()); + return Status::OK(); + })); + std::shared_ptr struct_array = + std::make_shared(out->array()); + KERNEL_RETURN_IF_ERROR(ctx, struct_builder->Finish(&struct_array)); + *out = struct_array; + } else { + const auto& input = checked_cast(*batch[0].scalar()); + auto result = std::make_shared(type); + if (input.is_valid) { + util::string_view s = static_cast(*input.value); + re2::StringPiece piece(s.data(), s.length()); + if (re2::RE2::FullMatchN(piece, regex, &args_pointers[0], group_count)) { + for (int i = 0; i < group_count; i++) { + result->value.push_back(std::make_shared(found_values[i])); + } + result->is_valid = true; + } else { + result->is_valid = false; + } + } else { + result->is_valid = false; + } + out->value = result; + } + } +}; + +const FunctionDoc utf8_extract_re2_doc("Extract", ("Long.."), {"strings"}, "RE2Options"); + +void AddExtractRE2(FunctionRegistry* registry) { + auto func = std::make_shared("utf8_extract_re2", Arity::Unary(), + &utf8_extract_re2_doc); + using t32 = ExtractRE2; + using t64 = ExtractRE2; + ScalarKernel kernel; + // null values will be computed based on regex match or not + kernel.null_handling = NullHandling::COMPUTED_NO_PREALLOCATE; + kernel.mem_allocation = MemAllocation::NO_PREALLOCATE; + + kernel.signature = KernelSignature::Make({utf8()}, {struct_({})}); + kernel.exec = t32::Exec; + kernel.init = t32::State::Init; + DCHECK_OK(func->AddKernel(kernel)); + kernel.signature = KernelSignature::Make({large_utf8()}, {struct_({})}); + kernel.exec = t64::Exec; + kernel.init = t64::State::Init; + DCHECK_OK(func->AddKernel(kernel)); + + DCHECK_OK(registry->AddFunction(std::move(func))); +} +void AddRE2(FunctionRegistry* registry) { AddExtractRE2(registry); } +#endif // ARROW_WITH_RE2 + // ---------------------------------------------------------------------- // strptime string parsing @@ -2143,6 +2270,9 @@ void RegisterScalarStringAscii(FunctionRegistry* registry) { #endif AddSplit(registry); +#ifdef ARROW_WITH_RE2 + AddRE2(registry); +#endif AddBinaryLength(registry); AddUtf8Length(registry); AddMatchSubstring(registry); diff --git a/cpp/src/arrow/compute/kernels/scalar_string_test.cc b/cpp/src/arrow/compute/kernels/scalar_string_test.cc index 2dd0a4d8c74..af4648b8376 100644 --- a/cpp/src/arrow/compute/kernels/scalar_string_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_string_test.cc @@ -495,6 +495,18 @@ TYPED_TEST(TestStringKernels, ReplaceSubstringRegexNoOptions) { Datum input = ArrayFromJSON(this->type(), "[]"); ASSERT_RAISES(Invalid, CallFunction("replace_substring_regex", {input})); } + +TYPED_TEST(TestStringKernels, ExtractRE2) { + RE2Options options{"(?P[ab])(?P\\d)"}; + auto type = struct_({field("letter", this->type()), field("digit", this->type())}); + this->CheckUnary( + "utf8_extract_re2", R"(["a1", "b2", "c3", null])", type, + R"([{"letter": "a", "digit": "1"}, {"letter": "b", "digit": "2"}, null, null])", + &options); + this->CheckUnary("utf8_extract_re2", R"(["a1", "b2"])", type, + R"([{"letter": "a", "digit": "1"}, {"letter": "b", "digit": "2"}])", + &options); +} #endif TYPED_TEST(TestStringKernels, Strptime) { diff --git a/docs/source/cpp/compute.rst b/docs/source/cpp/compute.rst index 92ac8886f87..90e0ed13c75 100644 --- a/docs/source/cpp/compute.rst +++ b/docs/source/cpp/compute.rst @@ -579,6 +579,20 @@ when a positive ``max_splits`` is given. (``'\t'``, ``'\n'``, ``'\v'``, ``'\f'``, ``'\r'`` and ``' '``) is seen as separator. +String extraction +~~~~~~~~~~~~~~~~~ + ++--------------------+------------+------------------------------------+---------------+----------------------------------------+ +| Function name | Arity | Input types | Output type | Options class | ++====================+============+====================================+===============+========================================+ +| utf8_extract_re2 | Unary | String-like | Struct (1) | :struct:`RE2Options` | ++--------------------+------------+------------------------------------+---------------+----------------------------------------+ + +* \(1) Extract substrings defined by a regular expression using the Google RE2 +library. Struct field names refer to the named groups, e.g. 'letter' and 'digit' +for following regular expression: '(?P[ab])(?P\\d)'. + + Structural transforms ~~~~~~~~~~~~~~~~~~~~~ diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx index 1515bdcfd36..22882dfe877 100644 --- a/python/pyarrow/_compute.pyx +++ b/python/pyarrow/_compute.pyx @@ -704,6 +704,23 @@ class ReplaceSubstringOptions(_ReplaceSubstringOptions): self._set_options(pattern, replacement, max_replacements) +cdef class _RE2Options(FunctionOptions): + cdef: + unique_ptr[CRE2Options] re2_options + + cdef const CFunctionOptions* get_options(self) except NULL: + return self.re2_options.get() + + def _set_options(self, regex): + self.re2_options.reset( + new CRE2Options(tobytes(regex))) + + +class RE2Options(_RE2Options): + def __init__(self, regex): + self._set_options(regex) + + cdef class _FilterOptions(FunctionOptions): cdef: unique_ptr[CFilterOptions] filter_options diff --git a/python/pyarrow/compute.py b/python/pyarrow/compute.py index 3928b9cb904..c26f060eba5 100644 --- a/python/pyarrow/compute.py +++ b/python/pyarrow/compute.py @@ -43,6 +43,7 @@ ProjectOptions, QuantileOptions, ReplaceSubstringOptions, + RE2Options, SetLookupOptions, SortOptions, StrptimeOptions, diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd index ebdcd08334c..aba21c1a9c7 100644 --- a/python/pyarrow/includes/libarrow.pxd +++ b/python/pyarrow/includes/libarrow.pxd @@ -1823,6 +1823,11 @@ cdef extern from "arrow/compute/api.h" namespace "arrow::compute" nogil: c_string replacement int64_t max_replacements + cdef cppclass CRE2Options \ + "arrow::compute::RE2Options"(CFunctionOptions): + CRE2Options(c_string regex) + c_string regex + cdef cppclass CCastOptions" arrow::compute::CastOptions"(CFunctionOptions): CCastOptions() CCastOptions(c_bool safe) diff --git a/python/pyarrow/tests/test_compute.py b/python/pyarrow/tests/test_compute.py index 94a6189f41c..54e2bfc4ddf 100644 --- a/python/pyarrow/tests/test_compute.py +++ b/python/pyarrow/tests/test_compute.py @@ -598,6 +598,13 @@ def test_replace_regex(): assert ar.tolist() == ['f00', 'm00d', None] +def test_extract_re2(): + ar = pa.array(['a1', 'b2']) + struct = pc.utf8_extract_re2(ar, regex='(?P[ab])(?P\\d)') + assert struct.tolist() == [{'letter': 'a', 'digit': '1'}, { + 'letter': 'b', 'digit': '2'}] + + @pytest.mark.parametrize(('ty', 'values'), all_array_types) def test_take(ty, values): arr = pa.array(values, type=ty) From 886a16fe15cfe4580cdc8bbe904a2950ddf8d7f8 Mon Sep 17 00:00:00 2001 From: "Maarten A. Breddels" Date: Fri, 12 Mar 2021 14:29:00 +0100 Subject: [PATCH 2/6] skip zero slice for struct output --- cpp/src/arrow/compute/kernels/test_util.cc | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/test_util.cc b/cpp/src/arrow/compute/kernels/test_util.cc index a8a0c8b95f3..7fb573cfc10 100644 --- a/cpp/src/arrow/compute/kernels/test_util.cc +++ b/cpp/src/arrow/compute/kernels/test_util.cc @@ -119,8 +119,11 @@ void CheckScalar(std::string func_name, const ArrayVector& inputs, } // should also work with an empty slice - CheckScalarNonRecursive(func_name, SliceAll(inputs, 0, 0), expected->Slice(0, 0), - options); + // a zero slice will not call the kernel, so we cannot cuntruct an empty struct with + // fields + if (expected->type_id() != Type::STRUCT) + CheckScalarNonRecursive(func_name, SliceAll(inputs, 0, 0), expected->Slice(0, 0), + options); // Ditto with ChunkedArray inputs if (slice_length > 0) { From 5cca63bf2d68c5ee984f0a15cb29409e8bc8f3e2 Mon Sep 17 00:00:00 2001 From: "Maarten A. Breddels" Date: Fri, 12 Mar 2021 14:29:27 +0100 Subject: [PATCH 3/6] attempt to fix failed test with ChunkedArray input --- cpp/src/arrow/compute/kernels/scalar_string.cc | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/cpp/src/arrow/compute/kernels/scalar_string.cc b/cpp/src/arrow/compute/kernels/scalar_string.cc index e51da4c5e12..1682ca0228b 100644 --- a/cpp/src/arrow/compute/kernels/scalar_string.cc +++ b/cpp/src/arrow/compute/kernels/scalar_string.cc @@ -1539,11 +1539,21 @@ struct ExtractRE2 { } RETURN_NOT_OK(struct_builder->Append()); } else { + for (int i = 0; i < group_count; i++) { + BuilderType* builder = + static_cast(struct_builder->field_builder(i)); + RETURN_NOT_OK(builder->Append("")); + } RETURN_NOT_OK(struct_builder->AppendNull()); } return Status::OK(); }, [&]() { + for (int i = 0; i < group_count; i++) { + BuilderType* builder = + static_cast(struct_builder->field_builder(i)); + RETURN_NOT_OK(builder->Append("")); + } RETURN_NOT_OK(struct_builder->AppendNull()); return Status::OK(); })); From 33e81920e25e5740eac4b208bd2ff1e4ea2f60e3 Mon Sep 17 00:00:00 2001 From: "Maarten A. Breddels" Date: Fri, 12 Mar 2021 16:17:36 +0100 Subject: [PATCH 4/6] do not ignore return value --- cpp/src/arrow/compute/kernels/scalar_string.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_string.cc b/cpp/src/arrow/compute/kernels/scalar_string.cc index 1682ca0228b..387ba975cf0 100644 --- a/cpp/src/arrow/compute/kernels/scalar_string.cc +++ b/cpp/src/arrow/compute/kernels/scalar_string.cc @@ -1520,7 +1520,8 @@ struct ExtractRE2 { if (batch[0].kind() == Datum::ARRAY) { std::unique_ptr array_builder_tmp; - MakeBuilder(ctx->memory_pool(), type, &array_builder_tmp); + KERNEL_RETURN_IF_ERROR(ctx, + MakeBuilder(ctx->memory_pool(), type, &array_builder_tmp)); std::shared_ptr struct_builder; struct_builder.reset(checked_cast(array_builder_tmp.release())); From afff2282cf191d6a9b4cc4197d07b2784fc2f68f Mon Sep 17 00:00:00 2001 From: Antoine Pitrou Date: Tue, 20 Apr 2021 19:48:33 +0200 Subject: [PATCH 5/6] Various fixes and improvements --- cpp/src/arrow/compute/api_scalar.h | 10 +- .../arrow/compute/kernels/scalar_string.cc | 299 +++++++++++------- .../compute/kernels/scalar_string_test.cc | 74 ++++- cpp/src/arrow/compute/kernels/test_util.cc | 9 +- docs/source/cpp/compute.rst | 9 +- python/pyarrow/_compute.pyx | 8 +- python/pyarrow/compute.py | 2 +- python/pyarrow/includes/libarrow.pxd | 6 +- python/pyarrow/tests/test_compute.py | 6 +- 9 files changed, 287 insertions(+), 136 deletions(-) diff --git a/cpp/src/arrow/compute/api_scalar.h b/cpp/src/arrow/compute/api_scalar.h index 4c3e5b39614..f003bac42b4 100644 --- a/cpp/src/arrow/compute/api_scalar.h +++ b/cpp/src/arrow/compute/api_scalar.h @@ -71,7 +71,9 @@ struct ARROW_EXPORT SplitPatternOptions : public SplitOptions { struct ARROW_EXPORT ReplaceSubstringOptions : public FunctionOptions { explicit ReplaceSubstringOptions(std::string pattern, std::string replacement, int64_t max_replacements = -1) - : pattern(pattern), replacement(replacement), max_replacements(max_replacements) {} + : pattern(std::move(pattern)), + replacement(std::move(replacement)), + max_replacements(max_replacements) {} /// Pattern to match, literal, or regular expression depending on which kernel is used std::string pattern; @@ -81,10 +83,10 @@ struct ARROW_EXPORT ReplaceSubstringOptions : public FunctionOptions { int64_t max_replacements; }; -struct ARROW_EXPORT RE2Options : public FunctionOptions { - explicit RE2Options(std::string regex) : regex(regex) {} +struct ARROW_EXPORT ExtractRegexOptions : public FunctionOptions { + explicit ExtractRegexOptions(std::string regex) : regex(std::move(regex)) {} - /// Regular expression + /// Regular expression with named capture fields std::string regex; }; diff --git a/cpp/src/arrow/compute/kernels/scalar_string.cc b/cpp/src/arrow/compute/kernels/scalar_string.cc index 387ba975cf0..54135f0312b 100644 --- a/cpp/src/arrow/compute/kernels/scalar_string.cc +++ b/cpp/src/arrow/compute/kernels/scalar_string.cc @@ -17,6 +17,7 @@ #include #include +#include #include #ifdef ARROW_WITH_UTF8PROC @@ -34,15 +35,36 @@ #include "arrow/builder.h" #include "arrow/compute/api_scalar.h" #include "arrow/compute/kernels/common.h" +#include "arrow/util/checked_cast.h" #include "arrow/util/utf8.h" #include "arrow/util/value_parsing.h" namespace arrow { + +using internal::checked_cast; + namespace compute { namespace internal { namespace { +#ifdef ARROW_WITH_RE2 +util::string_view ToStringView(re2::StringPiece piece) { + return {piece.data(), piece.length()}; +} + +re2::StringPiece ToStringPiece(util::string_view view) { + return {view.data(), view.length()}; +} + +Status RegexStatus(const RE2& regex) { + if (!regex.ok()) { + return Status::Invalid("Invalid regular expression: ", regex.error()); + } + return Status::OK(); +} +#endif + // Code units in the range [a-z] can only be an encoding of an ascii // character/codepoint, not the 2nd, 3rd or 4th code unit (byte) of an different // codepoint. This guaranteed by non-overlap design of the unicode standard. (see @@ -451,10 +473,8 @@ struct RegexSubstringMatcher { const RE2 regex_match_; RegexSubstringMatcher(KernelContext* ctx, const MatchSubstringOptions& options) - : options_(options), regex_match_(options_.pattern) { - if (!regex_match_.ok()) { - ctx->SetStatus(Status::Invalid("Regular expression error")); - } + : options_(options), regex_match_(options_.pattern, RE2::Quiet) { + KERNEL_RETURN_IF_ERROR(ctx, RegexStatus(regex_match_)); } bool Match(util::string_view current) { @@ -1392,16 +1412,21 @@ struct RegexSubStringReplacer { // we have 2 regexes, one with () around it, one without. RegexSubStringReplacer(KernelContext* ctx, const ReplaceSubstringOptions& options) : options_(options), - regex_find_("(" + options_.pattern + ")"), - regex_replacement_(options_.pattern) { - if (!(regex_find_.ok() && regex_replacement_.ok())) { - ctx->SetStatus(Status::Invalid("Regular expression error")); - return; + regex_find_("(" + options_.pattern + ")", RE2::Quiet), + regex_replacement_(options_.pattern, RE2::Quiet) { + KERNEL_RETURN_IF_ERROR(ctx, RegexStatus(regex_find_)); + KERNEL_RETURN_IF_ERROR(ctx, RegexStatus(regex_replacement_)); + std::string replacement_error; + if (!regex_replacement_.CheckRewriteString(options_.replacement, + &replacement_error)) { + ctx->SetStatus( + Status::Invalid("Invalid replacement string: ", std::move(replacement_error))); } } Status ReplaceString(util::string_view s, TypedBufferBuilder* builder) { re2::StringPiece replacement(options_.replacement); + if (options_.max_replacements == -1) { std::string s_copy(s.to_string()); re2::RE2::GlobalReplace(&s_copy, regex_replacement_, replacement); @@ -1474,140 +1499,202 @@ const FunctionDoc replace_substring_regex_doc( {"strings"}, "ReplaceSubstringOptions"); #endif -// re2 regex +// ---------------------------------------------------------------------- +// Extract with regex #ifdef ARROW_WITH_RE2 -template -struct ExtractRE2 { - using ArrayType = typename TypeTraits::ArrayType; - using ScalarType = typename TypeTraits::ScalarType; - using BuilderType = typename TypeTraits::BuilderType; - using State = OptionsWrapper; - static void Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { - RE2Options options = State::Get(ctx); - RE2 regex(options.regex); +// TODO cache this once per ExtractRegexOptions +struct ExtractRegexData { + // Use unique_ptr<> because RE2 is non-movable + std::unique_ptr regex; + std::vector group_names; - if (!regex.ok()) { - ctx->SetStatus(Status::Invalid("Regular expression error")); - return; - } - std::vector> fields; - int group_count = regex.NumberOfCapturingGroups(); - fields.reserve(group_count); - const std::map name_map = regex.CapturingGroupNames(); - - // We need to pass RE2 a Args* array, which all point to a std::string - std::vector found_values(group_count); - std::vector args; - std::vector args_pointers; - args.reserve(group_count); - args_pointers.reserve(group_count); + static Result Make(const ExtractRegexOptions& options) { + ExtractRegexData data(options.regex); + RETURN_NOT_OK(RegexStatus(*data.regex)); + + const int group_count = data.regex->NumberOfCapturingGroups(); + const auto& name_map = data.regex->CapturingGroupNames(); + data.group_names.reserve(group_count); for (int i = 0; i < group_count; i++) { auto item = name_map.find(i + 1); // re2 starts counting from 1 if (item == name_map.end()) { - ctx->SetStatus(Status::Invalid("Regular expression contains unnamed groups")); - return; + // XXX should we instead just create fields with an empty name? + return Status::Invalid("Regular expression contains unnamed groups"); } - fields.emplace_back(new Field(item->second, batch[0].type())); + data.group_names.emplace_back(item->second); + } + return std::move(data); + } + + Result ResolveOutputType(const std::vector& args) const { + const auto& input_type = args[0].type; + if (input_type == nullptr) { + // No input type specified => propagate shape + return args[0]; + } + // Input type is either String or LargeString and is also the type of each + // field in the output struct type. + DCHECK(input_type->id() == Type::STRING || input_type->id() == Type::LARGE_STRING); + FieldVector fields; + fields.reserve(group_names.size()); + std::transform(group_names.begin(), group_names.end(), std::back_inserter(fields), + [&](const std::string& name) { return field(name, input_type); }); + return struct_(std::move(fields)); + } + + private: + explicit ExtractRegexData(const std::string& pattern) + : regex(new RE2(pattern, RE2::Quiet)) {} +}; + +Result ResolveExtractRegexOutput(KernelContext* ctx, + const std::vector& args) { + using State = OptionsWrapper; + ExtractRegexOptions options = State::Get(ctx); + ARROW_ASSIGN_OR_RAISE(auto data, ExtractRegexData::Make(options)); + return data.ResolveOutputType(args); +} + +struct ExtractRegexBase { + const ExtractRegexData& data; + const int group_count; + std::vector found_values; + std::vector args; + std::vector args_pointers; + const re2::RE2::Arg** args_pointers_start; + const re2::RE2::Arg* null_arg = nullptr; + + explicit ExtractRegexBase(const ExtractRegexData& data) + : data(data), + group_count(static_cast(data.group_names.size())), + found_values(group_count) { + args.reserve(group_count); + args_pointers.reserve(group_count); + + for (int i = 0; i < group_count; i++) { args.emplace_back(&found_values[i]); - // since we reserved capacity, we're guaranteed std::vector does not reallocate - // (which would cause the pointer to be invalid) + // Since we reserved capacity, we're guaranteed the pointer remains valid args_pointers.push_back(&args[i]); } - auto type = struct_(fields); + // Avoid null pointer if there is no capture group + args_pointers_start = (group_count > 0) ? args_pointers.data() : &null_arg; + } + + bool Match(util::string_view s) { + return re2::RE2::PartialMatchN(ToStringPiece(s), *data.regex, args_pointers_start, + group_count); + } +}; + +template +struct ExtractRegex : public ExtractRegexBase { + using ArrayType = typename TypeTraits::ArrayType; + using ScalarType = typename TypeTraits::ScalarType; + using BuilderType = typename TypeTraits::BuilderType; + using State = OptionsWrapper; + + using ExtractRegexBase::ExtractRegexBase; + + static void Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + ExtractRegexOptions options = State::Get(ctx); + KERNEL_ASSIGN_OR_RAISE(auto data, ctx, ExtractRegexData::Make(options)); + ExtractRegex{data}.Extract(ctx, batch, out); + } + + void Extract(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + KERNEL_ASSIGN_OR_RAISE(auto descr, ctx, + data.ResolveOutputType(batch.GetDescriptors())); + DCHECK_NE(descr.type, nullptr); + const auto& type = descr.type; if (batch[0].kind() == Datum::ARRAY) { - std::unique_ptr array_builder_tmp; - KERNEL_RETURN_IF_ERROR(ctx, - MakeBuilder(ctx->memory_pool(), type, &array_builder_tmp)); - std::shared_ptr struct_builder; - struct_builder.reset(checked_cast(array_builder_tmp.release())); + std::unique_ptr array_builder; + KERNEL_RETURN_IF_ERROR(ctx, MakeBuilder(ctx->memory_pool(), type, &array_builder)); + StructBuilder* struct_builder = checked_cast(array_builder.get()); + + std::vector field_builders; + field_builders.reserve(group_count); + for (int i = 0; i < group_count; i++) { + field_builders.push_back( + checked_cast(struct_builder->field_builder(i))); + } + auto visit_null = [&]() { + for (int i = 0; i < group_count; i++) { + RETURN_NOT_OK(field_builders[i]->AppendEmptyValue()); + } + return struct_builder->AppendNull(); + }; + auto visit_value = [&](util::string_view s) { + if (Match(s)) { + for (int i = 0; i < group_count; i++) { + RETURN_NOT_OK(field_builders[i]->Append(ToStringView(found_values[i]))); + } + return struct_builder->Append(); + } else { + return visit_null(); + } + }; const ArrayData& input = *batch[0].array(); - KERNEL_RETURN_IF_ERROR( - ctx, - VisitArrayDataInline( - input, - [&](util::string_view s) { - re2::StringPiece piece(s.data(), s.length()); - if (re2::RE2::FullMatchN(piece, regex, &args_pointers[0], group_count)) { - for (int i = 0; i < group_count; i++) { - BuilderType* builder = - static_cast(struct_builder->field_builder(i)); - RETURN_NOT_OK(builder->Append(found_values[i])); - } - RETURN_NOT_OK(struct_builder->Append()); - } else { - for (int i = 0; i < group_count; i++) { - BuilderType* builder = - static_cast(struct_builder->field_builder(i)); - RETURN_NOT_OK(builder->Append("")); - } - RETURN_NOT_OK(struct_builder->AppendNull()); - } - return Status::OK(); - }, - [&]() { - for (int i = 0; i < group_count; i++) { - BuilderType* builder = - static_cast(struct_builder->field_builder(i)); - RETURN_NOT_OK(builder->Append("")); - } - RETURN_NOT_OK(struct_builder->AppendNull()); - return Status::OK(); - })); - std::shared_ptr struct_array = - std::make_shared(out->array()); - KERNEL_RETURN_IF_ERROR(ctx, struct_builder->Finish(&struct_array)); - *out = struct_array; + KERNEL_RETURN_IF_ERROR(ctx, + VisitArrayDataInline(input, visit_value, visit_null)); + + std::shared_ptr out_array; + KERNEL_RETURN_IF_ERROR(ctx, struct_builder->Finish(&out_array)); + *out = std::move(out_array); } else { const auto& input = checked_cast(*batch[0].scalar()); auto result = std::make_shared(type); - if (input.is_valid) { - util::string_view s = static_cast(*input.value); - re2::StringPiece piece(s.data(), s.length()); - if (re2::RE2::FullMatchN(piece, regex, &args_pointers[0], group_count)) { - for (int i = 0; i < group_count; i++) { - result->value.push_back(std::make_shared(found_values[i])); - } - result->is_valid = true; - } else { - result->is_valid = false; + if (input.is_valid && Match(util::string_view(*input.value))) { + result->value.reserve(group_count); + for (int i = 0; i < group_count; i++) { + result->value.push_back( + std::make_shared(found_values[i].as_string())); } + result->is_valid = true; } else { result->is_valid = false; } - out->value = result; + out->value = std::move(result); } } }; -const FunctionDoc utf8_extract_re2_doc("Extract", ("Long.."), {"strings"}, "RE2Options"); - -void AddExtractRE2(FunctionRegistry* registry) { - auto func = std::make_shared("utf8_extract_re2", Arity::Unary(), - &utf8_extract_re2_doc); - using t32 = ExtractRE2; - using t64 = ExtractRE2; +const FunctionDoc extract_regex_doc( + "Extract substrings captured by a regex pattern", + ("For each string in `strings`, match the regular expression and, if\n" + "successful, emit a struct with field names and values coming from the\n" + "regular expression's named capture groups. If the input is null or the\n" + "regular expression fails matching, a null output value is emitted.\n" + "\n" + "Regular expression matching is done using the Google RE2 library."), + {"strings"}, "ExtractRegexOptions"); + +void AddExtractRegex(FunctionRegistry* registry) { + auto func = std::make_shared("extract_regex", Arity::Unary(), + &extract_regex_doc); + using t32 = ExtractRegex; + using t64 = ExtractRegex; + OutputType out_ty(ResolveExtractRegexOutput); ScalarKernel kernel; - // null values will be computed based on regex match or not + + // Null values will be computed based on regex match or not kernel.null_handling = NullHandling::COMPUTED_NO_PREALLOCATE; kernel.mem_allocation = MemAllocation::NO_PREALLOCATE; - - kernel.signature = KernelSignature::Make({utf8()}, {struct_({})}); + kernel.signature.reset(new KernelSignature({utf8()}, out_ty)); kernel.exec = t32::Exec; kernel.init = t32::State::Init; DCHECK_OK(func->AddKernel(kernel)); - kernel.signature = KernelSignature::Make({large_utf8()}, {struct_({})}); + kernel.signature.reset(new KernelSignature({large_utf8()}, out_ty)); kernel.exec = t64::Exec; kernel.init = t64::State::Init; DCHECK_OK(func->AddKernel(kernel)); DCHECK_OK(registry->AddFunction(std::move(func))); } -void AddRE2(FunctionRegistry* registry) { AddExtractRE2(registry); } #endif // ARROW_WITH_RE2 // ---------------------------------------------------------------------- @@ -2281,9 +2368,6 @@ void RegisterScalarStringAscii(FunctionRegistry* registry) { #endif AddSplit(registry); -#ifdef ARROW_WITH_RE2 - AddRE2(registry); -#endif AddBinaryLength(registry); AddUtf8Length(registry); AddMatchSubstring(registry); @@ -2294,6 +2378,7 @@ void RegisterScalarStringAscii(FunctionRegistry* registry) { MakeUnaryStringBatchKernelWithState( "replace_substring_regex", registry, &replace_substring_regex_doc, MemAllocation::NO_PREALLOCATE); + AddExtractRegex(registry); #endif AddStrptime(registry); } diff --git a/cpp/src/arrow/compute/kernels/scalar_string_test.cc b/cpp/src/arrow/compute/kernels/scalar_string_test.cc index af4648b8376..577493913b5 100644 --- a/cpp/src/arrow/compute/kernels/scalar_string_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_string_test.cc @@ -17,6 +17,7 @@ #include +#include #include #ifdef ARROW_WITH_UTF8PROC @@ -366,6 +367,26 @@ TYPED_TEST(TestStringKernels, MatchSubstringRegex) { MatchSubstringOptions options_plus{"a+b"}; this->CheckUnary("match_substring_regex", R"(["aacb", "aab", "dab", "caaab", "b", ""])", boolean(), "[false, true, true, true, false, false]", &options_plus); + + // Unicode character semantics + // "\pL" means: unicode category "letter" + // (re2 interprets "\w" as ASCII-only: https://github.com/google/re2/wiki/Syntax) + MatchSubstringOptions options_unicode{"^\\pL+$"}; + this->CheckUnary("match_substring_regex", R"(["été", "ß", "€", ""])", boolean(), + "[true, true, false, false]", &options_unicode); +} + +TYPED_TEST(TestStringKernels, MatchSubstringRegexNoOptions) { + Datum input = ArrayFromJSON(this->type(), "[]"); + ASSERT_RAISES(Invalid, CallFunction("match_substring_regex", {input})); +} + +TYPED_TEST(TestStringKernels, MatchSubstringRegexInvalid) { + Datum input = ArrayFromJSON(this->type(), "[null]"); + MatchSubstringOptions options{"invalid["}; + EXPECT_RAISES_WITH_MESSAGE_THAT( + Invalid, ::testing::HasSubstr("Invalid regular expression: missing ]"), + CallFunction("match_substring_regex", {input}, &options)); } #endif @@ -496,17 +517,62 @@ TYPED_TEST(TestStringKernels, ReplaceSubstringRegexNoOptions) { ASSERT_RAISES(Invalid, CallFunction("replace_substring_regex", {input})); } -TYPED_TEST(TestStringKernels, ExtractRE2) { - RE2Options options{"(?P[ab])(?P\\d)"}; +TYPED_TEST(TestStringKernels, ReplaceSubstringRegexInvalid) { + Datum input = ArrayFromJSON(this->type(), R"(["foo"])"); + ReplaceSubstringOptions options{"invalid[", ""}; + EXPECT_RAISES_WITH_MESSAGE_THAT( + Invalid, ::testing::HasSubstr("Invalid regular expression: missing ]"), + CallFunction("replace_substring_regex", {input}, &options)); + + // Capture group number out of range + options = ReplaceSubstringOptions{"(.)", "\\9"}; + EXPECT_RAISES_WITH_MESSAGE_THAT( + Invalid, ::testing::HasSubstr("Invalid replacement string"), + CallFunction("replace_substring_regex", {input}, &options)); +} + +TYPED_TEST(TestStringKernels, ExtractRegex) { + ExtractRegexOptions options{"(?P[ab])(?P\\d)"}; auto type = struct_({field("letter", this->type()), field("digit", this->type())}); + this->CheckUnary("extract_regex", R"([])", type, R"([])", &options); this->CheckUnary( - "utf8_extract_re2", R"(["a1", "b2", "c3", null])", type, + "extract_regex", R"(["a1", "b2", "c3", null])", type, R"([{"letter": "a", "digit": "1"}, {"letter": "b", "digit": "2"}, null, null])", &options); - this->CheckUnary("utf8_extract_re2", R"(["a1", "b2"])", type, + this->CheckUnary("extract_regex", R"(["a1", "b2"])", type, R"([{"letter": "a", "digit": "1"}, {"letter": "b", "digit": "2"}])", &options); + this->CheckUnary("extract_regex", R"(["a1", "zb3z"])", type, + R"([{"letter": "a", "digit": "1"}, {"letter": "b", "digit": "3"}])", + &options); } + +TYPED_TEST(TestStringKernels, ExtractRegexNoCapture) { + // XXX Should we accept this or is it a user error? + ExtractRegexOptions options{"foo"}; + auto type = struct_({}); + this->CheckUnary("extract_regex", R"(["oofoo", "bar", null])", type, + R"([{}, null, null])", &options); +} + +TYPED_TEST(TestStringKernels, ExtractRegexNoOptions) { + Datum input = ArrayFromJSON(this->type(), "[]"); + ASSERT_RAISES(Invalid, CallFunction("extract_regex", {input})); +} + +TYPED_TEST(TestStringKernels, ExtractRegexInvalid) { + Datum input = ArrayFromJSON(this->type(), "[]"); + ExtractRegexOptions options{"invalid["}; + EXPECT_RAISES_WITH_MESSAGE_THAT( + Invalid, ::testing::HasSubstr("Invalid regular expression: missing ]"), + CallFunction("extract_regex", {input}, &options)); + + options = ExtractRegexOptions{"(.)"}; + EXPECT_RAISES_WITH_MESSAGE_THAT( + Invalid, ::testing::HasSubstr("Regular expression contains unnamed groups"), + CallFunction("extract_regex", {input}, &options)); +} + #endif TYPED_TEST(TestStringKernels, Strptime) { diff --git a/cpp/src/arrow/compute/kernels/test_util.cc b/cpp/src/arrow/compute/kernels/test_util.cc index 7fb573cfc10..11d5e76d342 100644 --- a/cpp/src/arrow/compute/kernels/test_util.cc +++ b/cpp/src/arrow/compute/kernels/test_util.cc @@ -118,12 +118,9 @@ void CheckScalar(std::string func_name, const ArrayVector& inputs, expected->Slice(2 * slice_length), options); } - // should also work with an empty slice - // a zero slice will not call the kernel, so we cannot cuntruct an empty struct with - // fields - if (expected->type_id() != Type::STRUCT) - CheckScalarNonRecursive(func_name, SliceAll(inputs, 0, 0), expected->Slice(0, 0), - options); + // Should also work with an empty slice + CheckScalarNonRecursive(func_name, SliceAll(inputs, 0, 0), expected->Slice(0, 0), + options); // Ditto with ChunkedArray inputs if (slice_length > 0) { diff --git a/docs/source/cpp/compute.rst b/docs/source/cpp/compute.rst index 90e0ed13c75..fb50f8cef65 100644 --- a/docs/source/cpp/compute.rst +++ b/docs/source/cpp/compute.rst @@ -579,19 +579,20 @@ when a positive ``max_splits`` is given. (``'\t'``, ``'\n'``, ``'\v'``, ``'\f'``, ``'\r'`` and ``' '``) is seen as separator. + String extraction ~~~~~~~~~~~~~~~~~ +--------------------+------------+------------------------------------+---------------+----------------------------------------+ | Function name | Arity | Input types | Output type | Options class | +====================+============+====================================+===============+========================================+ -| utf8_extract_re2 | Unary | String-like | Struct (1) | :struct:`RE2Options` | +| extract_regex | Unary | String-like | Struct (1) | :struct:`ExtractRegexOptions` | +--------------------+------------+------------------------------------+---------------+----------------------------------------+ * \(1) Extract substrings defined by a regular expression using the Google RE2 -library. Struct field names refer to the named groups, e.g. 'letter' and 'digit' -for following regular expression: '(?P[ab])(?P\\d)'. - + library. The output struct field names refer to the named capture groups, + e.g. 'letter' and 'digit' for the regular expression + ``(?P[ab])(?P\\d)``. Structural transforms diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx index 22882dfe877..f0519ef36c6 100644 --- a/python/pyarrow/_compute.pyx +++ b/python/pyarrow/_compute.pyx @@ -704,19 +704,19 @@ class ReplaceSubstringOptions(_ReplaceSubstringOptions): self._set_options(pattern, replacement, max_replacements) -cdef class _RE2Options(FunctionOptions): +cdef class _ExtractRegexOptions(FunctionOptions): cdef: - unique_ptr[CRE2Options] re2_options + unique_ptr[CExtractRegexOptions] re2_options cdef const CFunctionOptions* get_options(self) except NULL: return self.re2_options.get() def _set_options(self, regex): self.re2_options.reset( - new CRE2Options(tobytes(regex))) + new CExtractRegexOptions(tobytes(regex))) -class RE2Options(_RE2Options): +class ExtractRegexOptions(_ExtractRegexOptions): def __init__(self, regex): self._set_options(regex) diff --git a/python/pyarrow/compute.py b/python/pyarrow/compute.py index c26f060eba5..ec38710b023 100644 --- a/python/pyarrow/compute.py +++ b/python/pyarrow/compute.py @@ -33,6 +33,7 @@ CastOptions, CountOptions, DictionaryEncodeOptions, + ExtractRegexOptions, FilterOptions, MatchSubstringOptions, MinMaxOptions, @@ -43,7 +44,6 @@ ProjectOptions, QuantileOptions, ReplaceSubstringOptions, - RE2Options, SetLookupOptions, SortOptions, StrptimeOptions, diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd index aba21c1a9c7..3bc3c026d34 100644 --- a/python/pyarrow/includes/libarrow.pxd +++ b/python/pyarrow/includes/libarrow.pxd @@ -1823,9 +1823,9 @@ cdef extern from "arrow/compute/api.h" namespace "arrow::compute" nogil: c_string replacement int64_t max_replacements - cdef cppclass CRE2Options \ - "arrow::compute::RE2Options"(CFunctionOptions): - CRE2Options(c_string regex) + cdef cppclass CExtractRegexOptions \ + "arrow::compute::ExtractRegexOptions"(CFunctionOptions): + CExtractRegexOptions(c_string regex) c_string regex cdef cppclass CCastOptions" arrow::compute::CastOptions"(CFunctionOptions): diff --git a/python/pyarrow/tests/test_compute.py b/python/pyarrow/tests/test_compute.py index 54e2bfc4ddf..1e2371b8363 100644 --- a/python/pyarrow/tests/test_compute.py +++ b/python/pyarrow/tests/test_compute.py @@ -598,9 +598,9 @@ def test_replace_regex(): assert ar.tolist() == ['f00', 'm00d', None] -def test_extract_re2(): - ar = pa.array(['a1', 'b2']) - struct = pc.utf8_extract_re2(ar, regex='(?P[ab])(?P\\d)') +def test_extract_regex(): + ar = pa.array(['a1', 'zb2z']) + struct = pc.extract_regex(ar, regex=r'(?P[ab])(?P\d)') assert struct.tolist() == [{'letter': 'a', 'digit': '1'}, { 'letter': 'b', 'digit': '2'}] From 88ba0dc7b0771adae07c867bcf9b3f9b1b73bde1 Mon Sep 17 00:00:00 2001 From: Antoine Pitrou Date: Wed, 21 Apr 2021 11:13:37 +0200 Subject: [PATCH 6/6] Use "pattern" --- cpp/src/arrow/compute/api_scalar.h | 4 ++-- cpp/src/arrow/compute/kernels/scalar_string.cc | 2 +- python/pyarrow/_compute.pyx | 14 +++++++------- python/pyarrow/includes/libarrow.pxd | 4 ++-- python/pyarrow/tests/test_compute.py | 2 +- 5 files changed, 13 insertions(+), 13 deletions(-) diff --git a/cpp/src/arrow/compute/api_scalar.h b/cpp/src/arrow/compute/api_scalar.h index f003bac42b4..53892ff6b3c 100644 --- a/cpp/src/arrow/compute/api_scalar.h +++ b/cpp/src/arrow/compute/api_scalar.h @@ -84,10 +84,10 @@ struct ARROW_EXPORT ReplaceSubstringOptions : public FunctionOptions { }; struct ARROW_EXPORT ExtractRegexOptions : public FunctionOptions { - explicit ExtractRegexOptions(std::string regex) : regex(std::move(regex)) {} + explicit ExtractRegexOptions(std::string pattern) : pattern(std::move(pattern)) {} /// Regular expression with named capture fields - std::string regex; + std::string pattern; }; /// Options for IsIn and IndexIn functions diff --git a/cpp/src/arrow/compute/kernels/scalar_string.cc b/cpp/src/arrow/compute/kernels/scalar_string.cc index 54135f0312b..d5473749fe1 100644 --- a/cpp/src/arrow/compute/kernels/scalar_string.cc +++ b/cpp/src/arrow/compute/kernels/scalar_string.cc @@ -1511,7 +1511,7 @@ struct ExtractRegexData { std::vector group_names; static Result Make(const ExtractRegexOptions& options) { - ExtractRegexData data(options.regex); + ExtractRegexData data(options.pattern); RETURN_NOT_OK(RegexStatus(*data.regex)); const int group_count = data.regex->NumberOfCapturingGroups(); diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx index f0519ef36c6..3af485343f2 100644 --- a/python/pyarrow/_compute.pyx +++ b/python/pyarrow/_compute.pyx @@ -706,19 +706,19 @@ class ReplaceSubstringOptions(_ReplaceSubstringOptions): cdef class _ExtractRegexOptions(FunctionOptions): cdef: - unique_ptr[CExtractRegexOptions] re2_options + unique_ptr[CExtractRegexOptions] extract_regex_options cdef const CFunctionOptions* get_options(self) except NULL: - return self.re2_options.get() + return self.extract_regex_options.get() - def _set_options(self, regex): - self.re2_options.reset( - new CExtractRegexOptions(tobytes(regex))) + def _set_options(self, pattern): + self.extract_regex_options.reset( + new CExtractRegexOptions(tobytes(pattern))) class ExtractRegexOptions(_ExtractRegexOptions): - def __init__(self, regex): - self._set_options(regex) + def __init__(self, pattern): + self._set_options(pattern) cdef class _FilterOptions(FunctionOptions): diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd index 3bc3c026d34..45f7c4fee94 100644 --- a/python/pyarrow/includes/libarrow.pxd +++ b/python/pyarrow/includes/libarrow.pxd @@ -1825,8 +1825,8 @@ cdef extern from "arrow/compute/api.h" namespace "arrow::compute" nogil: cdef cppclass CExtractRegexOptions \ "arrow::compute::ExtractRegexOptions"(CFunctionOptions): - CExtractRegexOptions(c_string regex) - c_string regex + CExtractRegexOptions(c_string pattern) + c_string pattern cdef cppclass CCastOptions" arrow::compute::CastOptions"(CFunctionOptions): CCastOptions() diff --git a/python/pyarrow/tests/test_compute.py b/python/pyarrow/tests/test_compute.py index 1e2371b8363..5ad0d2db91b 100644 --- a/python/pyarrow/tests/test_compute.py +++ b/python/pyarrow/tests/test_compute.py @@ -600,7 +600,7 @@ def test_replace_regex(): def test_extract_regex(): ar = pa.array(['a1', 'zb2z']) - struct = pc.extract_regex(ar, regex=r'(?P[ab])(?P\d)') + struct = pc.extract_regex(ar, pattern=r'(?P[ab])(?P\d)') assert struct.tolist() == [{'letter': 'a', 'digit': '1'}, { 'letter': 'b', 'digit': '2'}]