Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion cpp/src/arrow/compute/api_scalar.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,9 @@ struct ARROW_EXPORT SplitPatternOptions : public SplitOptions {
struct ARROW_EXPORT ReplaceSubstringOptions : public FunctionOptions {
explicit ReplaceSubstringOptions(std::string pattern, std::string replacement,
int64_t max_replacements = -1)
: pattern(pattern), replacement(replacement), max_replacements(max_replacements) {}
: pattern(std::move(pattern)),
replacement(std::move(replacement)),
max_replacements(max_replacements) {}

/// Pattern to match, literal, or regular expression depending on which kernel is used
std::string pattern;
Expand All @@ -81,6 +83,13 @@ struct ARROW_EXPORT ReplaceSubstringOptions : public FunctionOptions {
int64_t max_replacements;
};

struct ARROW_EXPORT ExtractRegexOptions : public FunctionOptions {
explicit ExtractRegexOptions(std::string pattern) : pattern(std::move(pattern)) {}

/// Regular expression with named capture fields
std::string pattern;
};

/// Options for IsIn and IndexIn functions
struct ARROW_EXPORT SetLookupOptions : public FunctionOptions {
explicit SetLookupOptions(Datum value_set, bool skip_nulls = false)
Expand Down
244 changes: 235 additions & 9 deletions cpp/src/arrow/compute/kernels/scalar_string.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

#include <algorithm>
#include <cctype>
#include <iterator>
#include <string>

#ifdef ARROW_WITH_UTF8PROC
Expand All @@ -30,17 +31,40 @@
#include "arrow/array/builder_binary.h"
#include "arrow/array/builder_nested.h"
#include "arrow/buffer_builder.h"

#include "arrow/builder.h"
#include "arrow/compute/api_scalar.h"
#include "arrow/compute/kernels/common.h"
#include "arrow/util/checked_cast.h"
#include "arrow/util/utf8.h"
#include "arrow/util/value_parsing.h"

namespace arrow {

using internal::checked_cast;

namespace compute {
namespace internal {

namespace {

#ifdef ARROW_WITH_RE2
util::string_view ToStringView(re2::StringPiece piece) {
return {piece.data(), piece.length()};
}

re2::StringPiece ToStringPiece(util::string_view view) {
return {view.data(), view.length()};
}

Status RegexStatus(const RE2& regex) {
if (!regex.ok()) {
return Status::Invalid("Invalid regular expression: ", regex.error());
}
return Status::OK();
}
#endif

// Code units in the range [a-z] can only be an encoding of an ascii
// character/codepoint, not the 2nd, 3rd or 4th code unit (byte) of an different
// codepoint. This guaranteed by non-overlap design of the unicode standard. (see
Expand Down Expand Up @@ -449,10 +473,8 @@ struct RegexSubstringMatcher {
const RE2 regex_match_;

RegexSubstringMatcher(KernelContext* ctx, const MatchSubstringOptions& options)
: options_(options), regex_match_(options_.pattern) {
if (!regex_match_.ok()) {
ctx->SetStatus(Status::Invalid("Regular expression error"));
}
: options_(options), regex_match_(options_.pattern, RE2::Quiet) {
KERNEL_RETURN_IF_ERROR(ctx, RegexStatus(regex_match_));
}

bool Match(util::string_view current) {
Expand Down Expand Up @@ -1390,16 +1412,21 @@ struct RegexSubStringReplacer {
// we have 2 regexes, one with () around it, one without.
RegexSubStringReplacer(KernelContext* ctx, const ReplaceSubstringOptions& options)
: options_(options),
regex_find_("(" + options_.pattern + ")"),
regex_replacement_(options_.pattern) {
if (!(regex_find_.ok() && regex_replacement_.ok())) {
ctx->SetStatus(Status::Invalid("Regular expression error"));
return;
regex_find_("(" + options_.pattern + ")", RE2::Quiet),
regex_replacement_(options_.pattern, RE2::Quiet) {
KERNEL_RETURN_IF_ERROR(ctx, RegexStatus(regex_find_));
KERNEL_RETURN_IF_ERROR(ctx, RegexStatus(regex_replacement_));
std::string replacement_error;
if (!regex_replacement_.CheckRewriteString(options_.replacement,
&replacement_error)) {
ctx->SetStatus(
Status::Invalid("Invalid replacement string: ", std::move(replacement_error)));
}
}

Status ReplaceString(util::string_view s, TypedBufferBuilder<uint8_t>* builder) {
re2::StringPiece replacement(options_.replacement);

if (options_.max_replacements == -1) {
std::string s_copy(s.to_string());
re2::RE2::GlobalReplace(&s_copy, regex_replacement_, replacement);
Expand Down Expand Up @@ -1472,6 +1499,204 @@ const FunctionDoc replace_substring_regex_doc(
{"strings"}, "ReplaceSubstringOptions");
#endif

// ----------------------------------------------------------------------
// Extract with regex

#ifdef ARROW_WITH_RE2

// TODO cache this once per ExtractRegexOptions
struct ExtractRegexData {
// Use unique_ptr<> because RE2 is non-movable
std::unique_ptr<RE2> regex;
std::vector<std::string> group_names;

static Result<ExtractRegexData> Make(const ExtractRegexOptions& options) {
ExtractRegexData data(options.pattern);
RETURN_NOT_OK(RegexStatus(*data.regex));

const int group_count = data.regex->NumberOfCapturingGroups();
const auto& name_map = data.regex->CapturingGroupNames();
data.group_names.reserve(group_count);

for (int i = 0; i < group_count; i++) {
auto item = name_map.find(i + 1); // re2 starts counting from 1
if (item == name_map.end()) {
// XXX should we instead just create fields with an empty name?
return Status::Invalid("Regular expression contains unnamed groups");
}
data.group_names.emplace_back(item->second);
}
return std::move(data);
}

Result<ValueDescr> ResolveOutputType(const std::vector<ValueDescr>& args) const {
const auto& input_type = args[0].type;
if (input_type == nullptr) {
// No input type specified => propagate shape
return args[0];
}
// Input type is either String or LargeString and is also the type of each
// field in the output struct type.
DCHECK(input_type->id() == Type::STRING || input_type->id() == Type::LARGE_STRING);
FieldVector fields;
fields.reserve(group_names.size());
std::transform(group_names.begin(), group_names.end(), std::back_inserter(fields),
[&](const std::string& name) { return field(name, input_type); });
return struct_(std::move(fields));
}

private:
explicit ExtractRegexData(const std::string& pattern)
: regex(new RE2(pattern, RE2::Quiet)) {}
};

Result<ValueDescr> ResolveExtractRegexOutput(KernelContext* ctx,
const std::vector<ValueDescr>& args) {
using State = OptionsWrapper<ExtractRegexOptions>;
ExtractRegexOptions options = State::Get(ctx);
ARROW_ASSIGN_OR_RAISE(auto data, ExtractRegexData::Make(options));
return data.ResolveOutputType(args);
}

struct ExtractRegexBase {
const ExtractRegexData& data;
const int group_count;
std::vector<re2::StringPiece> found_values;
std::vector<re2::RE2::Arg> args;
std::vector<const re2::RE2::Arg*> args_pointers;
const re2::RE2::Arg** args_pointers_start;
const re2::RE2::Arg* null_arg = nullptr;

explicit ExtractRegexBase(const ExtractRegexData& data)
: data(data),
group_count(static_cast<int>(data.group_names.size())),
found_values(group_count) {
args.reserve(group_count);
args_pointers.reserve(group_count);

for (int i = 0; i < group_count; i++) {
args.emplace_back(&found_values[i]);
// Since we reserved capacity, we're guaranteed the pointer remains valid
args_pointers.push_back(&args[i]);
}
// Avoid null pointer if there is no capture group
args_pointers_start = (group_count > 0) ? args_pointers.data() : &null_arg;
}

bool Match(util::string_view s) {
return re2::RE2::PartialMatchN(ToStringPiece(s), *data.regex, args_pointers_start,
group_count);
}
};

template <typename Type>
struct ExtractRegex : public ExtractRegexBase {
using ArrayType = typename TypeTraits<Type>::ArrayType;
using ScalarType = typename TypeTraits<Type>::ScalarType;
using BuilderType = typename TypeTraits<Type>::BuilderType;
using State = OptionsWrapper<ExtractRegexOptions>;

using ExtractRegexBase::ExtractRegexBase;

static void Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
ExtractRegexOptions options = State::Get(ctx);
KERNEL_ASSIGN_OR_RAISE(auto data, ctx, ExtractRegexData::Make(options));
ExtractRegex{data}.Extract(ctx, batch, out);
}

void Extract(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
KERNEL_ASSIGN_OR_RAISE(auto descr, ctx,
data.ResolveOutputType(batch.GetDescriptors()));
DCHECK_NE(descr.type, nullptr);
const auto& type = descr.type;

if (batch[0].kind() == Datum::ARRAY) {
std::unique_ptr<ArrayBuilder> array_builder;
KERNEL_RETURN_IF_ERROR(ctx, MakeBuilder(ctx->memory_pool(), type, &array_builder));
StructBuilder* struct_builder = checked_cast<StructBuilder*>(array_builder.get());

std::vector<BuilderType*> field_builders;
field_builders.reserve(group_count);
for (int i = 0; i < group_count; i++) {
field_builders.push_back(
checked_cast<BuilderType*>(struct_builder->field_builder(i)));
}

auto visit_null = [&]() {
for (int i = 0; i < group_count; i++) {
RETURN_NOT_OK(field_builders[i]->AppendEmptyValue());
}
return struct_builder->AppendNull();
};
auto visit_value = [&](util::string_view s) {
if (Match(s)) {
for (int i = 0; i < group_count; i++) {
RETURN_NOT_OK(field_builders[i]->Append(ToStringView(found_values[i])));
}
return struct_builder->Append();
} else {
return visit_null();
}
};
const ArrayData& input = *batch[0].array();
KERNEL_RETURN_IF_ERROR(ctx,
VisitArrayDataInline<Type>(input, visit_value, visit_null));

std::shared_ptr<Array> out_array;
KERNEL_RETURN_IF_ERROR(ctx, struct_builder->Finish(&out_array));
*out = std::move(out_array);
} else {
const auto& input = checked_cast<const ScalarType&>(*batch[0].scalar());
auto result = std::make_shared<StructScalar>(type);
if (input.is_valid && Match(util::string_view(*input.value))) {
result->value.reserve(group_count);
for (int i = 0; i < group_count; i++) {
result->value.push_back(
std::make_shared<ScalarType>(found_values[i].as_string()));
}
result->is_valid = true;
} else {
result->is_valid = false;
}
out->value = std::move(result);
}
}
};

const FunctionDoc extract_regex_doc(
"Extract substrings captured by a regex pattern",
("For each string in `strings`, match the regular expression and, if\n"
"successful, emit a struct with field names and values coming from the\n"
"regular expression's named capture groups. If the input is null or the\n"
"regular expression fails matching, a null output value is emitted.\n"
"\n"
"Regular expression matching is done using the Google RE2 library."),
{"strings"}, "ExtractRegexOptions");

void AddExtractRegex(FunctionRegistry* registry) {
auto func = std::make_shared<ScalarFunction>("extract_regex", Arity::Unary(),
&extract_regex_doc);
using t32 = ExtractRegex<StringType>;
using t64 = ExtractRegex<LargeStringType>;
OutputType out_ty(ResolveExtractRegexOutput);
ScalarKernel kernel;

// Null values will be computed based on regex match or not
kernel.null_handling = NullHandling::COMPUTED_NO_PREALLOCATE;
kernel.mem_allocation = MemAllocation::NO_PREALLOCATE;
kernel.signature.reset(new KernelSignature({utf8()}, out_ty));
kernel.exec = t32::Exec;
kernel.init = t32::State::Init;
DCHECK_OK(func->AddKernel(kernel));
kernel.signature.reset(new KernelSignature({large_utf8()}, out_ty));
kernel.exec = t64::Exec;
kernel.init = t64::State::Init;
DCHECK_OK(func->AddKernel(kernel));

DCHECK_OK(registry->AddFunction(std::move(func)));
}
#endif // ARROW_WITH_RE2

// ----------------------------------------------------------------------
// strptime string parsing

Expand Down Expand Up @@ -2153,6 +2378,7 @@ void RegisterScalarStringAscii(FunctionRegistry* registry) {
MakeUnaryStringBatchKernelWithState<ReplaceSubStringRegex>(
"replace_substring_regex", registry, &replace_substring_regex_doc,
MemAllocation::NO_PREALLOCATE);
AddExtractRegex(registry);
#endif
AddStrptime(registry);
}
Expand Down
Loading