diff --git a/ci/scripts/PKGBUILD b/ci/scripts/PKGBUILD index 1d9e41bba7a..c5b55eef42a 100644 --- a/ci/scripts/PKGBUILD +++ b/ci/scripts/PKGBUILD @@ -79,8 +79,10 @@ build() { export CPPFLAGS="${CPPFLAGS} -I${MINGW_PREFIX}/include" export LIBS="-L${MINGW_PREFIX}/libs" export ARROW_S3=OFF + export ARROW_WITH_RE2=OFF else export ARROW_S3=ON + export ARROW_WITH_RE2=ON fi MSYS2_ARG_CONV_EXCL="-DCMAKE_INSTALL_PREFIX=" \ @@ -105,6 +107,7 @@ build() { -DARROW_SNAPPY_USE_SHARED=OFF \ -DARROW_USE_GLOG=OFF \ -DARROW_WITH_LZ4=ON \ + -DARROW_WITH_RE2="${ARROW_WITH_RE2}" \ -DARROW_WITH_SNAPPY=ON \ -DARROW_WITH_ZLIB=ON \ -DARROW_WITH_ZSTD=ON \ diff --git a/cpp/src/arrow/compute/api_scalar.h b/cpp/src/arrow/compute/api_scalar.h index 0d95092c95b..730836bd118 100644 --- a/cpp/src/arrow/compute/api_scalar.h +++ b/cpp/src/arrow/compute/api_scalar.h @@ -68,6 +68,19 @@ struct ARROW_EXPORT SplitPatternOptions : public SplitOptions { std::string pattern; }; +struct ARROW_EXPORT ReplaceSubstringOptions : public FunctionOptions { + explicit ReplaceSubstringOptions(std::string pattern, std::string replacement, + int64_t max_replacements = -1) + : pattern(pattern), replacement(replacement), max_replacements(max_replacements) {} + + /// Pattern to match, literal, or regular expression depending on which kernel is used + std::string pattern; + /// String to replace the pattern with + std::string replacement; + /// Max number of substrings to replace (-1 means unbounded) + int64_t max_replacements; +}; + /// Options for IsIn and IndexIn functions struct ARROW_EXPORT SetLookupOptions : public FunctionOptions { explicit SetLookupOptions(Datum value_set, bool skip_nulls = false) diff --git a/cpp/src/arrow/compute/kernels/scalar_string.cc b/cpp/src/arrow/compute/kernels/scalar_string.cc index 88c91a18818..39869879561 100644 --- a/cpp/src/arrow/compute/kernels/scalar_string.cc +++ b/cpp/src/arrow/compute/kernels/scalar_string.cc @@ -23,6 +23,10 @@ #include #endif +#ifdef ARROW_WITH_RE2 +#include +#endif + #include "arrow/array/builder_binary.h" #include "arrow/array/builder_nested.h" #include "arrow/buffer_builder.h" @@ -1230,6 +1234,197 @@ void AddSplit(FunctionRegistry* registry) { #endif } +// ---------------------------------------------------------------------- +// Replace substring (plain, regex) + +template +struct ReplaceSubString { + using ScalarType = typename TypeTraits::ScalarType; + using offset_type = typename Type::offset_type; + using ValueDataBuilder = TypedBufferBuilder; + using OffsetBuilder = TypedBufferBuilder; + using State = OptionsWrapper; + + static void Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + // TODO Cache replacer accross invocations (for regex compilation) + Replacer replacer{ctx, State::Get(ctx)}; + if (!ctx->HasError()) { + Replace(ctx, batch, &replacer, out); + } + } + + static void Replace(KernelContext* ctx, const ExecBatch& batch, Replacer* replacer, + Datum* out) { + ValueDataBuilder value_data_builder(ctx->memory_pool()); + OffsetBuilder offset_builder(ctx->memory_pool()); + + if (batch[0].kind() == Datum::ARRAY) { + // We already know how many strings we have, so we can use Reserve/UnsafeAppend + KERNEL_RETURN_IF_ERROR(ctx, offset_builder.Reserve(batch[0].array()->length)); + offset_builder.UnsafeAppend(0); // offsets start at 0 + + const ArrayData& input = *batch[0].array(); + KERNEL_RETURN_IF_ERROR( + ctx, VisitArrayDataInline( + input, + [&](util::string_view s) { + RETURN_NOT_OK(replacer->ReplaceString(s, &value_data_builder)); + offset_builder.UnsafeAppend( + static_cast(value_data_builder.length())); + return Status::OK(); + }, + [&]() { + // offset for null value + offset_builder.UnsafeAppend( + static_cast(value_data_builder.length())); + return Status::OK(); + })); + ArrayData* output = out->mutable_array(); + KERNEL_RETURN_IF_ERROR(ctx, value_data_builder.Finish(&output->buffers[2])); + KERNEL_RETURN_IF_ERROR(ctx, offset_builder.Finish(&output->buffers[1])); + } else { + const auto& input = checked_cast(*batch[0].scalar()); + auto result = std::make_shared(); + if (input.is_valid) { + util::string_view s = static_cast(*input.value); + KERNEL_RETURN_IF_ERROR(ctx, replacer->ReplaceString(s, &value_data_builder)); + KERNEL_RETURN_IF_ERROR(ctx, value_data_builder.Finish(&result->value)); + result->is_valid = true; + } + out->value = result; + } + } +}; + +struct PlainSubStringReplacer { + const ReplaceSubstringOptions& options_; + + PlainSubStringReplacer(KernelContext* ctx, const ReplaceSubstringOptions& options) + : options_(options) {} + + Status ReplaceString(util::string_view s, TypedBufferBuilder* builder) { + const char* i = s.begin(); + const char* end = s.end(); + int64_t max_replacements = options_.max_replacements; + while ((i < end) && (max_replacements != 0)) { + const char* pos = + std::search(i, end, options_.pattern.begin(), options_.pattern.end()); + if (pos == end) { + RETURN_NOT_OK(builder->Append(reinterpret_cast(i), + static_cast(end - i))); + i = end; + } else { + // the string before the pattern + RETURN_NOT_OK(builder->Append(reinterpret_cast(i), + static_cast(pos - i))); + // the replacement + RETURN_NOT_OK( + builder->Append(reinterpret_cast(options_.replacement.data()), + options_.replacement.length())); + // skip pattern + i = pos + options_.pattern.length(); + max_replacements--; + } + } + // if we exited early due to max_replacements, add the trailing part + RETURN_NOT_OK(builder->Append(reinterpret_cast(i), + static_cast(end - i))); + return Status::OK(); + } +}; + +#ifdef ARROW_WITH_RE2 +struct RegexSubStringReplacer { + const ReplaceSubstringOptions& options_; + const RE2 regex_find_; + const RE2 regex_replacement_; + + // Using RE2::FindAndConsume we can only find the pattern if it is a group, therefore + // we have 2 regexes, one with () around it, one without. + RegexSubStringReplacer(KernelContext* ctx, const ReplaceSubstringOptions& options) + : options_(options), + regex_find_("(" + options_.pattern + ")"), + regex_replacement_(options_.pattern) { + if (!(regex_find_.ok() && regex_replacement_.ok())) { + ctx->SetStatus(Status::Invalid("Regular expression error")); + return; + } + } + + Status ReplaceString(util::string_view s, TypedBufferBuilder* builder) { + re2::StringPiece replacement(options_.replacement); + if (options_.max_replacements == -1) { + std::string s_copy(s.to_string()); + re2::RE2::GlobalReplace(&s_copy, regex_replacement_, replacement); + RETURN_NOT_OK(builder->Append(reinterpret_cast(s_copy.data()), + s_copy.length())); + return Status::OK(); + } + + // Since RE2 does not have the concept of max_replacements, we have to do some work + // ourselves. + // We might do this faster similar to RE2::GlobalReplace using Match and Rewrite + const char* i = s.begin(); + const char* end = s.end(); + re2::StringPiece piece(s.data(), s.length()); + + int64_t max_replacements = options_.max_replacements; + while ((i < end) && (max_replacements != 0)) { + std::string found; + if (!re2::RE2::FindAndConsume(&piece, regex_find_, &found)) { + RETURN_NOT_OK(builder->Append(reinterpret_cast(i), + static_cast(end - i))); + i = end; + } else { + // wind back to the beginning of the match + const char* pos = piece.begin() - found.length(); + // the string before the pattern + RETURN_NOT_OK(builder->Append(reinterpret_cast(i), + static_cast(pos - i))); + // replace the pattern in what we found + if (!re2::RE2::Replace(&found, regex_replacement_, replacement)) { + return Status::Invalid("Regex found, but replacement failed"); + } + RETURN_NOT_OK(builder->Append(reinterpret_cast(found.data()), + static_cast(found.length()))); + // skip pattern + i = piece.begin(); + max_replacements--; + } + } + // If we exited early due to max_replacements, add the trailing part + RETURN_NOT_OK(builder->Append(reinterpret_cast(i), + static_cast(end - i))); + return Status::OK(); + } +}; +#endif + +template +using ReplaceSubStringPlain = ReplaceSubString; + +const FunctionDoc replace_substring_doc( + "Replace non-overlapping substrings that match pattern by replacement", + ("For each string in `strings`, replace non-overlapping substrings that match\n" + "`pattern` by `replacement`. If `max_replacements != -1`, it determines the\n" + "maximum amount of replacements made, counting from the left. Null values emit\n" + "null."), + {"strings"}, "ReplaceSubstringOptions"); + +#ifdef ARROW_WITH_RE2 +template +using ReplaceSubStringRegex = ReplaceSubString; + +const FunctionDoc replace_substring_regex_doc( + "Replace non-overlapping substrings that match regex `pattern` by `replacement`", + ("For each string in `strings`, replace non-overlapping substrings that match the\n" + "regular expression `pattern` by `replacement` using the Google RE2 library.\n" + "If `max_replacements != -1`, it determines the maximum amount of replacements\n" + "made, counting from the left. Note that if the pattern contains groups,\n" + "backreferencing macan be used. Null values emit null."), + {"strings"}, "ReplaceSubstringOptions"); +#endif + // ---------------------------------------------------------------------- // strptime string parsing @@ -1904,6 +2099,14 @@ void RegisterScalarStringAscii(FunctionRegistry* registry) { AddBinaryLength(registry); AddUtf8Length(registry); AddMatchSubstring(registry); + MakeUnaryStringBatchKernelWithState( + "replace_substring", registry, &replace_substring_doc, + MemAllocation::NO_PREALLOCATE); +#ifdef ARROW_WITH_RE2 + MakeUnaryStringBatchKernelWithState( + "replace_substring_regex", registry, &replace_substring_regex_doc, + MemAllocation::NO_PREALLOCATE); +#endif AddStrptime(registry); } diff --git a/cpp/src/arrow/compute/kernels/scalar_string_test.cc b/cpp/src/arrow/compute/kernels/scalar_string_test.cc index 281fcb5c7aa..88622e842d1 100644 --- a/cpp/src/arrow/compute/kernels/scalar_string_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_string_test.cc @@ -48,6 +48,14 @@ class BaseTestStringKernels : public ::testing::Test { CheckScalarUnary(func_name, type(), json_input, out_ty, json_expected, options); } + void CheckBinaryScalar(std::string func_name, std::string json_left_input, + std::string json_right_scalar, std::shared_ptr out_ty, + std::string json_expected, + const FunctionOptions* options = nullptr) { + CheckScalarBinaryScalar(func_name, type(), json_left_input, json_right_scalar, out_ty, + json_expected, options); + } + std::shared_ptr type() { return TypeTraits::type_singleton(); } std::shared_ptr offset_type() { @@ -422,6 +430,52 @@ TYPED_TEST(TestStringKernels, SplitWhitespaceUTF8Reverse) { &options_max); } +TYPED_TEST(TestStringKernels, ReplaceSubstring) { + ReplaceSubstringOptions options{"foo", "bazz"}; + this->CheckUnary("replace_substring", R"(["foo", "this foo that foo", null])", + this->type(), R"(["bazz", "this bazz that bazz", null])", &options); +} + +TYPED_TEST(TestStringKernels, ReplaceSubstringLimited) { + ReplaceSubstringOptions options{"foo", "bazz", 1}; + this->CheckUnary("replace_substring", R"(["foo", "this foo that foo", null])", + this->type(), R"(["bazz", "this bazz that foo", null])", &options); +} + +TYPED_TEST(TestStringKernels, ReplaceSubstringNoOptions) { + Datum input = ArrayFromJSON(this->type(), "[]"); + ASSERT_RAISES(Invalid, CallFunction("replace_substring", {input})); +} + +#ifdef ARROW_WITH_RE2 +TYPED_TEST(TestStringKernels, ReplaceSubstringRegex) { + ReplaceSubstringOptions options_regex{"(fo+)\\s*", "\\1-bazz"}; + this->CheckUnary("replace_substring_regex", R"(["foo ", "this foo that foo", null])", + this->type(), R"(["foo-bazz", "this foo-bazzthat foo-bazz", null])", + &options_regex); + // make sure we match non-overlapping + ReplaceSubstringOptions options_regex2{"(a.a)", "aba\\1"}; + this->CheckUnary("replace_substring_regex", R"(["aaaaaa"])", this->type(), + R"(["abaaaaabaaaa"])", &options_regex2); +} + +TYPED_TEST(TestStringKernels, ReplaceSubstringRegexLimited) { + // With a finite number of replacements + ReplaceSubstringOptions options1{"foo", "bazz", 1}; + this->CheckUnary("replace_substring", R"(["foo", "this foo that foo", null])", + this->type(), R"(["bazz", "this bazz that foo", null])", &options1); + ReplaceSubstringOptions options_regex1{"(fo+)\\s*", "\\1-bazz", 1}; + this->CheckUnary("replace_substring_regex", R"(["foo ", "this foo that foo", null])", + this->type(), R"(["foo-bazz", "this foo-bazzthat foo", null])", + &options_regex1); +} + +TYPED_TEST(TestStringKernels, ReplaceSubstringRegexNoOptions) { + Datum input = ArrayFromJSON(this->type(), "[]"); + ASSERT_RAISES(Invalid, CallFunction("replace_substring_regex", {input})); +} +#endif + TYPED_TEST(TestStringKernels, Strptime) { std::string input1 = R"(["5/1/2020", null, "12/11/1900"])"; std::string output1 = R"(["2020-05-01", null, "1900-12-11"])"; diff --git a/docs/source/cpp/compute.rst b/docs/source/cpp/compute.rst index e4eaa94bc59..065b80736aa 100644 --- a/docs/source/cpp/compute.rst +++ b/docs/source/cpp/compute.rst @@ -426,21 +426,25 @@ The third set of functions examines string elements on a byte-per-byte basis: String transforms ~~~~~~~~~~~~~~~~~ -+--------------------------+------------+-------------------------+---------------------+---------+ -| Function name | Arity | Input types | Output type | Notes | -+==========================+============+=========================+=====================+=========+ -| ascii_lower | Unary | String-like | String-like | \(1) | -+--------------------------+------------+-------------------------+---------------------+---------+ -| ascii_upper | Unary | String-like | String-like | \(1) | -+--------------------------+------------+-------------------------+---------------------+---------+ -| binary_length | Unary | Binary- or String-like | Int32 or Int64 | \(2) | -+--------------------------+------------+-------------------------+---------------------+---------+ -| utf8_length | Unary | String-like | Int32 or Int64 | \(3) | -+--------------------------+------------+-------------------------+---------------------+---------+ -| utf8_lower | Unary | String-like | String-like | \(4) | -+--------------------------+------------+-------------------------+---------------------+---------+ -| utf8_upper | Unary | String-like | String-like | \(4) | -+--------------------------+------------+-------------------------+---------------------+---------+ ++--------------------------+------------+-------------------------+---------------------+-------------------------------------------------+ +| Function name | Arity | Input types | Output type | Notes | Options class | ++==========================+============+=========================+=====================+=========+=======================================+ +| ascii_lower | Unary | String-like | String-like | \(1) | | ++--------------------------+------------+-------------------------+---------------------+---------+---------------------------------------+ +| ascii_upper | Unary | String-like | String-like | \(1) | | ++--------------------------+------------+-------------------------+---------------------+---------+---------------------------------------+ +| binary_length | Unary | Binary- or String-like | Int32 or Int64 | \(2) | | ++--------------------------+------------+-------------------------+---------------------+---------+---------------------------------------+ +| replace_substring | Unary | String-like | String-like | \(3) | :struct:`ReplaceSubstringOptions` | ++--------------------------+------------+-------------------------+---------------------+---------+---------------------------------------+ +| replace_substring_regex | Unary | String-like | String-like | \(4) | :struct:`ReplaceSubstringOptions` | ++--------------------------+------------+-------------------------+---------------------+---------+---------------------------------------+ +| utf8_length | Unary | String-like | Int32 or Int64 | \(5) | | ++--------------------------+------------+-------------------------+---------------------+---------+---------------------------------------+ +| utf8_lower | Unary | String-like | String-like | \(6) | | ++--------------------------+------------+-------------------------+---------------------+---------+---------------------------------------+ +| utf8_upper | Unary | String-like | String-like | \(6) | | ++--------------------------+------------+-------------------------+---------------------+---------+---------------------------------------+ * \(1) Each ASCII character in the input is converted to lowercase or @@ -449,10 +453,23 @@ String transforms * \(2) Output is the physical length in bytes of each input element. Output type is Int32 for Binary / String, Int64 for LargeBinary / LargeString. -* \(3) Output is the number of characters (not bytes) of each input element. +* \(3) Replace non-overlapping substrings that match to + :member:`ReplaceSubstringOptions::pattern` by + :member:`ReplaceSubstringOptions::replacement`. If + :member:`ReplaceSubstringOptions::max_replacements` != -1, it determines the + maximum number of replacements made, counting from the left. + +* \(4) Replace non-overlapping substrings that match to the regular expression + :member:`ReplaceSubstringOptions::pattern` by + :member:`ReplaceSubstringOptions::replacement`, using the Google RE2 library. If + :member:`ReplaceSubstringOptions::max_replacements` != -1, it determines the + maximum number of replacements made, counting from the left. Note that if the + pattern contains groups, backreferencing can be used. + +* \(5) Output is the number of characters (not bytes) of each input element. Output type is Int32 for String, Int64 for LargeString. -* \(4) Each UTF8-encoded character in the input is converted to lowercase or +* \(6) Each UTF8-encoded character in the input is converted to lowercase or uppercase. diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx index f3a8eb860d4..1515bdcfd36 100644 --- a/python/pyarrow/_compute.pyx +++ b/python/pyarrow/_compute.pyx @@ -684,6 +684,26 @@ class TrimOptions(_TrimOptions): self._set_options(characters) +cdef class _ReplaceSubstringOptions(FunctionOptions): + cdef: + unique_ptr[CReplaceSubstringOptions] replace_substring_options + + cdef const CFunctionOptions* get_options(self) except NULL: + return self.replace_substring_options.get() + + def _set_options(self, pattern, replacement, max_replacements): + self.replace_substring_options.reset( + new CReplaceSubstringOptions(tobytes(pattern), + tobytes(replacement), + max_replacements) + ) + + +class ReplaceSubstringOptions(_ReplaceSubstringOptions): + def __init__(self, pattern, replacement, max_replacements=-1): + self._set_options(pattern, replacement, max_replacements) + + cdef class _FilterOptions(FunctionOptions): cdef: unique_ptr[CFilterOptions] filter_options diff --git a/python/pyarrow/compute.py b/python/pyarrow/compute.py index 2cdd843d81a..1b46a08c402 100644 --- a/python/pyarrow/compute.py +++ b/python/pyarrow/compute.py @@ -42,6 +42,7 @@ PartitionNthOptions, ProjectOptions, QuantileOptions, + ReplaceSubstringOptions, SetLookupOptions, SortOptions, StrptimeOptions, diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd index 61deb658b0c..ebdcd08334c 100644 --- a/python/pyarrow/includes/libarrow.pxd +++ b/python/pyarrow/includes/libarrow.pxd @@ -1815,6 +1815,14 @@ cdef extern from "arrow/compute/api.h" namespace "arrow::compute" nogil: c_bool reverse) c_string pattern + cdef cppclass CReplaceSubstringOptions \ + "arrow::compute::ReplaceSubstringOptions"(CFunctionOptions): + CReplaceSubstringOptions(c_string pattern, c_string replacement, + int64_t max_replacements) + c_string pattern + c_string replacement + int64_t max_replacements + cdef cppclass CCastOptions" arrow::compute::CastOptions"(CFunctionOptions): CCastOptions() CCastOptions(c_bool safe) diff --git a/python/pyarrow/tests/test_compute.py b/python/pyarrow/tests/test_compute.py index 112629fc702..160375f93bd 100644 --- a/python/pyarrow/tests/test_compute.py +++ b/python/pyarrow/tests/test_compute.py @@ -579,6 +579,18 @@ def test_string_py_compat_boolean(function_name, variant): assert arrow_func(ar)[0].as_py() == getattr(c, py_name)() +def test_replace_plain(): + ar = pa.array(['foo', 'food', None]) + ar = pc.replace_substring(ar, pattern='foo', replacement='bar') + assert ar.tolist() == ['bar', 'bard', None] + + +def test_replace_regex(): + ar = pa.array(['foo', 'mood', None]) + ar = pc.replace_substring_regex(ar, pattern='(.)oo', replacement=r'\100') + assert ar.tolist() == ['f00', 'm00d', None] + + @pytest.mark.parametrize(('ty', 'values'), all_array_types) def test_take(ty, values): arr = pa.array(values, type=ty) diff --git a/r/configure.win b/r/configure.win index 88ac0e125e1..d645834fac8 100644 --- a/r/configure.win +++ b/r/configure.win @@ -50,13 +50,13 @@ AWS_LIBS="-laws-cpp-sdk-config -laws-cpp-sdk-transfer -laws-cpp-sdk-identity-man # NOTE: If you make changes to the libraries below, you should also change # ci/scripts/r_windows_build.sh and ci/scripts/PKGBUILD PKG_CFLAGS="-I${RWINLIB}/include -DARROW_STATIC -DPARQUET_STATIC -DARROW_DS_STATIC -DARROW_R_WITH_ARROW -DARROW_R_WITH_PARQUET -DARROW_R_WITH_DATASET" -PKG_LIBS="-L${RWINLIB}/lib"'$(subst gcc,,$(COMPILED_BY))$(R_ARCH) '"-L${RWINLIB}/lib"'$(R_ARCH) '"-lparquet -larrow_dataset -larrow -larrow_bundled_dependencies -lutf8proc -lre2 -lthrift -lsnappy -lz -lzstd -llz4 ${MIMALLOC_LIBS} ${OPENSSL_LIBS}" +PKG_LIBS="-L${RWINLIB}/lib"'$(subst gcc,,$(COMPILED_BY))$(R_ARCH) '"-L${RWINLIB}/lib"'$(R_ARCH) '"-lparquet -larrow_dataset -larrow -larrow_bundled_dependencies -lutf8proc -lthrift -lsnappy -lz -lzstd -llz4 ${MIMALLOC_LIBS} ${OPENSSL_LIBS}" -# S3 support only for Rtools40 (i.e. R >= 4.0) +# S3 and re2 support only for Rtools40 (i.e. R >= 4.0) "${R_HOME}/bin${R_ARCH_BIN}/Rscript.exe" -e 'R.version$major >= 4' | grep TRUE >/dev/null 2>&1 if [ $? -eq 0 ]; then PKG_CFLAGS="${PKG_CFLAGS} -DARROW_R_WITH_S3" - PKG_LIBS="${PKG_LIBS} ${AWS_LIBS}" + PKG_LIBS="${PKG_LIBS} -lre2 ${AWS_LIBS}" else # It seems that order matters PKG_LIBS="${PKG_LIBS} -lws2_32"