diff --git a/cpp/src/arrow/compute/api_scalar.h b/cpp/src/arrow/compute/api_scalar.h index 190696f6ed5..6e9a9340f2c 100644 --- a/cpp/src/arrow/compute/api_scalar.h +++ b/cpp/src/arrow/compute/api_scalar.h @@ -77,6 +77,18 @@ struct ARROW_EXPORT SplitPatternOptions : public SplitOptions { std::string pattern; }; +struct ARROW_EXPORT ReplaceSliceOptions : public FunctionOptions { + explicit ReplaceSliceOptions(int64_t start, int64_t stop, std::string replacement) + : start(start), stop(stop), replacement(std::move(replacement)) {} + + /// Index to start slicing at + int64_t start; + /// Index to stop slicing at + int64_t stop; + /// String to replace the slice with + std::string replacement; +}; + struct ARROW_EXPORT ReplaceSubstringOptions : public FunctionOptions { explicit ReplaceSubstringOptions(std::string pattern, std::string replacement, int64_t max_replacements = -1) diff --git a/cpp/src/arrow/compute/kernels/scalar_string.cc b/cpp/src/arrow/compute/kernels/scalar_string.cc index 8b740f3742a..b6c1b8f6261 100644 --- a/cpp/src/arrow/compute/kernels/scalar_string.cc +++ b/cpp/src/arrow/compute/kernels/scalar_string.cc @@ -2288,6 +2288,165 @@ const FunctionDoc replace_substring_regex_doc( {"strings"}, "ReplaceSubstringOptions"); #endif +// ---------------------------------------------------------------------- +// Replace slice + +struct ReplaceSliceTransformBase : public StringTransformBase { + using State = OptionsWrapper; + + const ReplaceSliceOptions* options; + + explicit ReplaceSliceTransformBase(const ReplaceSliceOptions& options) + : options{&options} {} + + int64_t MaxCodeunits(int64_t ninputs, int64_t input_ncodeunits) override { + return ninputs * options->replacement.size() + input_ncodeunits; + } +}; + +struct BinaryReplaceSliceTransform : ReplaceSliceTransformBase { + using ReplaceSliceTransformBase::ReplaceSliceTransformBase; + int64_t Transform(const uint8_t* input, int64_t input_string_ncodeunits, + uint8_t* output) { + const auto& opts = *options; + int64_t before_slice = 0; + int64_t after_slice = 0; + uint8_t* output_start = output; + + if (opts.start >= 0) { + // Count from left + before_slice = std::min(input_string_ncodeunits, opts.start); + } else { + // Count from right + before_slice = std::max(0, input_string_ncodeunits + opts.start); + } + // Mimic Pandas: if stop would be before start, treat as 0-length slice + if (opts.stop >= 0) { + // Count from left + after_slice = + std::min(input_string_ncodeunits, std::max(before_slice, opts.stop)); + } else { + // Count from right + after_slice = std::max(before_slice, input_string_ncodeunits + opts.stop); + } + output = std::copy(input, input + before_slice, output); + output = std::copy(opts.replacement.begin(), opts.replacement.end(), output); + output = std::copy(input + after_slice, input + input_string_ncodeunits, output); + return output - output_start; + } +}; + +struct Utf8ReplaceSliceTransform : ReplaceSliceTransformBase { + using ReplaceSliceTransformBase::ReplaceSliceTransformBase; + int64_t Transform(const uint8_t* input, int64_t input_string_ncodeunits, + uint8_t* output) { + const auto& opts = *options; + const uint8_t* begin = input; + const uint8_t* end = input + input_string_ncodeunits; + const uint8_t *begin_sliced, *end_sliced; + uint8_t* output_start = output; + + // Mimic Pandas: if stop would be before start, treat as 0-length slice + if (opts.start >= 0) { + // Count from left + if (!arrow::util::UTF8AdvanceCodepoints(begin, end, &begin_sliced, opts.start)) { + return kTransformError; + } + if (opts.stop > options->start) { + // Continue counting from left + const int64_t length = opts.stop - options->start; + if (!arrow::util::UTF8AdvanceCodepoints(begin_sliced, end, &end_sliced, length)) { + return kTransformError; + } + } else if (opts.stop < 0) { + // Count from right + if (!arrow::util::UTF8AdvanceCodepointsReverse(begin_sliced, end, &end_sliced, + -opts.stop)) { + return kTransformError; + } + } else { + // Zero-length slice + end_sliced = begin_sliced; + } + } else { + // Count from right + if (!arrow::util::UTF8AdvanceCodepointsReverse(begin, end, &begin_sliced, + -opts.start)) { + return kTransformError; + } + if (opts.stop >= 0) { + // Restart counting from left + if (!arrow::util::UTF8AdvanceCodepoints(begin, end, &end_sliced, opts.stop)) { + return kTransformError; + } + if (end_sliced <= begin_sliced) { + // Zero-length slice + end_sliced = begin_sliced; + } + } else if ((opts.stop < 0) && (options->stop > options->start)) { + // Count from right + if (!arrow::util::UTF8AdvanceCodepointsReverse(begin_sliced, end, &end_sliced, + -opts.stop)) { + return kTransformError; + } + } else { + // zero-length slice + end_sliced = begin_sliced; + } + } + output = std::copy(begin, begin_sliced, output); + output = std::copy(opts.replacement.begin(), options->replacement.end(), output); + output = std::copy(end_sliced, end, output); + return output - output_start; + } +}; + +template +using BinaryReplaceSlice = + StringTransformExecWithState; +template +using Utf8ReplaceSlice = StringTransformExecWithState; + +const FunctionDoc binary_replace_slice_doc( + "Replace a slice of a binary string with `replacement`", + ("For each string in `strings`, replace a slice of the string defined by `start`" + "and `stop` with `replacement`. `start` is inclusive and `stop` is exclusive, " + "and both are measured in bytes.\n" + "Null values emit null."), + {"strings"}, "ReplaceSliceOptions"); + +const FunctionDoc utf8_replace_slice_doc( + "Replace a slice of a string with `replacement`", + ("For each string in `strings`, replace a slice of the string defined by `start`" + "and `stop` with `replacement`. `start` is inclusive and `stop` is exclusive, " + "and both are measured in codeunits.\n" + "Null values emit null."), + {"strings"}, "ReplaceSliceOptions"); + +void AddReplaceSlice(FunctionRegistry* registry) { + { + auto func = std::make_shared("binary_replace_slice", Arity::Unary(), + &binary_replace_slice_doc); + for (const auto& ty : BaseBinaryTypes()) { + DCHECK_OK(func->AddKernel({ty}, ty, + GenerateTypeAgnosticVarBinaryBase(ty), + ReplaceSliceTransformBase::State::Init)); + } + DCHECK_OK(registry->AddFunction(std::move(func))); + } + + { + auto func = std::make_shared("utf8_replace_slice", Arity::Unary(), + &utf8_replace_slice_doc); + DCHECK_OK(func->AddKernel({utf8()}, utf8(), Utf8ReplaceSlice::Exec, + ReplaceSliceTransformBase::State::Init)); + DCHECK_OK(func->AddKernel({large_utf8()}, large_utf8(), + Utf8ReplaceSlice::Exec, + ReplaceSliceTransformBase::State::Init)); + DCHECK_OK(registry->AddFunction(std::move(func))); + } +} + // ---------------------------------------------------------------------- // Extract with regex @@ -3434,6 +3593,7 @@ void RegisterScalarStringAscii(FunctionRegistry* registry) { MemAllocation::NO_PREALLOCATE); AddExtractRegex(registry); #endif + AddReplaceSlice(registry); AddSlice(registry); AddSplit(registry); AddStrptime(registry); diff --git a/cpp/src/arrow/compute/kernels/scalar_string_test.cc b/cpp/src/arrow/compute/kernels/scalar_string_test.cc index c4b6956be2b..7d52d6aacf2 100644 --- a/cpp/src/arrow/compute/kernels/scalar_string_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_string_test.cc @@ -122,6 +122,56 @@ TYPED_TEST(TestBinaryKernels, CountSubstring) { // TODO: case-insensitive } +TYPED_TEST(TestBinaryKernels, AsciiReplaceSlice) { + ReplaceSliceOptions options{0, 1, "XX"}; + this->CheckUnary("binary_replace_slice", "[]", this->type(), "[]", &options); + this->CheckUnary("binary_replace_slice", R"([null, "", "a", "ab", "abc"])", + this->type(), R"([null, "XX", "XX", "XXb", "XXbc"])", &options); + + ReplaceSliceOptions options_whole{0, 5, "XX"}; + this->CheckUnary("binary_replace_slice", + R"([null, "", "a", "ab", "abc", "abcde", "abcdef"])", this->type(), + R"([null, "XX", "XX", "XX", "XX", "XX", "XXf"])", &options_whole); + + ReplaceSliceOptions options_middle{2, 4, "XX"}; + this->CheckUnary("binary_replace_slice", + R"([null, "", "a", "ab", "abc", "abcd", "abcde"])", this->type(), + R"([null, "XX", "aXX", "abXX", "abXX", "abXX", "abXXe"])", + &options_middle); + + ReplaceSliceOptions options_neg_start{-3, -2, "XX"}; + this->CheckUnary("binary_replace_slice", + R"([null, "", "a", "ab", "abc", "abcd", "abcde"])", this->type(), + R"([null, "XX", "XXa", "XXab", "XXbc", "aXXcd", "abXXde"])", + &options_neg_start); + + ReplaceSliceOptions options_neg_end{2, -2, "XX"}; + this->CheckUnary("binary_replace_slice", + R"([null, "", "a", "ab", "abc", "abcd", "abcde"])", this->type(), + R"([null, "XX", "aXX", "abXX", "abXXc", "abXXcd", "abXXde"])", + &options_neg_end); + + ReplaceSliceOptions options_neg_pos{-1, 2, "XX"}; + this->CheckUnary("binary_replace_slice", + R"([null, "", "a", "ab", "abc", "abcd", "abcde"])", this->type(), + R"([null, "XX", "XX", "aXX", "abXXc", "abcXXd", "abcdXXe"])", + &options_neg_pos); + + // Effectively the same as [2, 2) + ReplaceSliceOptions options_flip{2, 0, "XX"}; + this->CheckUnary("binary_replace_slice", + R"([null, "", "a", "ab", "abc", "abcd", "abcde"])", this->type(), + R"([null, "XX", "aXX", "abXX", "abXXc", "abXXcd", "abXXcde"])", + &options_flip); + + // Effectively the same as [-3, -3) + ReplaceSliceOptions options_neg_flip{-3, -5, "XX"}; + this->CheckUnary("binary_replace_slice", + R"([null, "", "a", "ab", "abc", "abcd", "abcde"])", this->type(), + R"([null, "XX", "XXa", "XXab", "XXabc", "aXXbcd", "abXXcde"])", + &options_neg_flip); +} + template class TestStringKernels : public BaseTestStringKernels {}; @@ -745,6 +795,56 @@ TYPED_TEST(TestStringKernels, SplitRegexReverse) { } #endif +TYPED_TEST(TestStringKernels, Utf8ReplaceSlice) { + ReplaceSliceOptions options{0, 1, "χχ"}; + this->CheckUnary("utf8_replace_slice", "[]", this->type(), "[]", &options); + this->CheckUnary("utf8_replace_slice", R"([null, "", "π", "πb", "πbθ"])", this->type(), + R"([null, "χχ", "χχ", "χχb", "χχbθ"])", &options); + + ReplaceSliceOptions options_whole{0, 5, "χχ"}; + this->CheckUnary("utf8_replace_slice", + R"([null, "", "π", "πb", "πbθ", "πbθde", "πbθdef"])", this->type(), + R"([null, "χχ", "χχ", "χχ", "χχ", "χχ", "χχf"])", &options_whole); + + ReplaceSliceOptions options_middle{2, 4, "χχ"}; + this->CheckUnary("utf8_replace_slice", + R"([null, "", "π", "πb", "πbθ", "πbθd", "πbθde"])", this->type(), + R"([null, "χχ", "πχχ", "πbχχ", "πbχχ", "πbχχ", "πbχχe"])", + &options_middle); + + ReplaceSliceOptions options_neg_start{-3, -2, "χχ"}; + this->CheckUnary("utf8_replace_slice", + R"([null, "", "π", "πb", "πbθ", "πbθd", "πbθde"])", this->type(), + R"([null, "χχ", "χχπ", "χχπb", "χχbθ", "πχχθd", "πbχχde"])", + &options_neg_start); + + ReplaceSliceOptions options_neg_end{2, -2, "χχ"}; + this->CheckUnary("utf8_replace_slice", + R"([null, "", "π", "πb", "πbθ", "πbθd", "πbθde"])", this->type(), + R"([null, "χχ", "πχχ", "πbχχ", "πbχχθ", "πbχχθd", "πbχχde"])", + &options_neg_end); + + ReplaceSliceOptions options_neg_pos{-1, 2, "χχ"}; + this->CheckUnary("utf8_replace_slice", + R"([null, "", "π", "πb", "πbθ", "πbθd", "πbθde"])", this->type(), + R"([null, "χχ", "χχ", "πχχ", "πbχχθ", "πbθχχd", "πbθdχχe"])", + &options_neg_pos); + + // Effectively the same as [2, 2) + ReplaceSliceOptions options_flip{2, 0, "χχ"}; + this->CheckUnary("utf8_replace_slice", + R"([null, "", "π", "πb", "πbθ", "πbθd", "πbθde"])", this->type(), + R"([null, "χχ", "πχχ", "πbχχ", "πbχχθ", "πbχχθd", "πbχχθde"])", + &options_flip); + + // Effectively the same as [-3, -3) + ReplaceSliceOptions options_neg_flip{-3, -5, "χχ"}; + this->CheckUnary("utf8_replace_slice", + R"([null, "", "π", "πb", "πbθ", "πbθd", "πbθde"])", this->type(), + R"([null, "χχ", "χχπ", "χχπb", "χχπbθ", "πχχbθd", "πbχχθde"])", + &options_neg_flip); +} + TYPED_TEST(TestStringKernels, ReplaceSubstring) { ReplaceSubstringOptions options{"foo", "bazz"}; this->CheckUnary("replace_substring", R"(["foo", "this foo that foo", null])", diff --git a/docs/source/cpp/compute.rst b/docs/source/cpp/compute.rst index ad2d9f8f5d2..b28e3928a74 100644 --- a/docs/source/cpp/compute.rst +++ b/docs/source/cpp/compute.rst @@ -451,29 +451,33 @@ The third set of functions examines string elements on a byte-per-byte basis: String transforms ~~~~~~~~~~~~~~~~~ -+--------------------------+------------+-------------------------+---------------------+-------------------------------------------------+ -| Function name | Arity | Input types | Output type | Notes | Options class | -+==========================+============+=========================+=====================+=========+=======================================+ -| ascii_lower | Unary | String-like | String-like | \(1) | | -+--------------------------+------------+-------------------------+---------------------+---------+---------------------------------------+ -| ascii_reverse | Unary | String-like | String-like | \(2) | | -+--------------------------+------------+-------------------------+---------------------+---------+---------------------------------------+ -| ascii_upper | Unary | String-like | String-like | \(1) | | -+--------------------------+------------+-------------------------+---------------------+---------+---------------------------------------+ -| binary_length | Unary | Binary- or String-like | Int32 or Int64 | \(3) | | -+--------------------------+------------+-------------------------+---------------------+---------+---------------------------------------+ -| replace_substring | Unary | String-like | String-like | \(4) | :struct:`ReplaceSubstringOptions` | -+--------------------------+------------+-------------------------+---------------------+---------+---------------------------------------+ -| replace_substring_regex | Unary | String-like | String-like | \(5) | :struct:`ReplaceSubstringOptions` | -+--------------------------+------------+-------------------------+---------------------+---------+---------------------------------------+ -| utf8_length | Unary | String-like | Int32 or Int64 | \(6) | | -+--------------------------+------------+-------------------------+---------------------+---------+---------------------------------------+ -| utf8_lower | Unary | String-like | String-like | \(7) | | -+--------------------------+------------+-------------------------+---------------------+---------+---------------------------------------+ -| utf8_reverse | Unary | String-like | String-like | \(8) | | -+--------------------------+------------+-------------------------+---------------------+---------+---------------------------------------+ -| utf8_upper | Unary | String-like | String-like | \(7) | | -+--------------------------+------------+-------------------------+---------------------+---------+---------------------------------------+ ++--------------------------+------------+-------------------------+------------------------+-------------------------------------------------+ +| Function name | Arity | Input types | Output type | Notes | Options class | ++==========================+============+=========================+========================+=========+=======================================+ +| ascii_lower | Unary | String-like | String-like | \(1) | | ++--------------------------+------------+-------------------------+------------------------+---------+---------------------------------------+ +| ascii_reverse | Unary | String-like | String-like | \(2) | | ++--------------------------+------------+-------------------------+------------------------+---------+---------------------------------------+ +| ascii_upper | Unary | String-like | String-like | \(1) | | ++--------------------------+------------+-------------------------+------------------------+---------+---------------------------------------+ +| binary_length | Unary | Binary- or String-like | Int32 or Int64 | \(3) | | ++--------------------------+------------+-------------------------+------------------------+---------+---------------------------------------+ +| binary_replace_slice | Unary | String-like | Binary- or String-like | \(4) | :struct:`ReplaceSliceOptions` | ++--------------------------+------------+-------------------------+------------------------+---------+---------------------------------------+ +| replace_substring | Unary | String-like | String-like | \(5) | :struct:`ReplaceSubstringOptions` | ++--------------------------+------------+-------------------------+------------------------+---------+---------------------------------------+ +| replace_substring_regex | Unary | String-like | String-like | \(6) | :struct:`ReplaceSubstringOptions` | ++--------------------------+------------+-------------------------+------------------------+---------+---------------------------------------+ +| utf8_length | Unary | String-like | Int32 or Int64 | \(7) | | ++--------------------------+------------+-------------------------+------------------------+---------+---------------------------------------+ +| utf8_lower | Unary | String-like | String-like | \(8) | | ++--------------------------+------------+-------------------------+------------------------+---------+---------------------------------------+ +| utf8_replace_slice | Unary | String-like | String-like | \(4) | :struct:`ReplaceSliceOptions` | ++--------------------------+------------+-------------------------+------------------------+---------+---------------------------------------+ +| utf8_reverse | Unary | String-like | String-like | \(9) | | ++--------------------------+------------+-------------------------+------------------------+---------+---------------------------------------+ +| utf8_upper | Unary | String-like | String-like | \(8) | | ++--------------------------+------------+-------------------------+------------------------+---------+---------------------------------------+ * \(1) Each ASCII character in the input is converted to lowercase or @@ -485,26 +489,31 @@ String transforms * \(3) Output is the physical length in bytes of each input element. Output type is Int32 for Binary / String, Int64 for LargeBinary / LargeString. -* \(4) Replace non-overlapping substrings that match to +* \(4) Replace the slice of the substring from :member:`ReplaceSliceOptions::start` + (inclusive) to :member:`ReplaceSliceOptions::stop` (exclusive) by + :member:`ReplaceSubstringOptions::replacement`. The binary kernel measures the + slice in bytes, while the UTF8 kernel measures the slice in codeunits. + +* \(5) Replace non-overlapping substrings that match to :member:`ReplaceSubstringOptions::pattern` by :member:`ReplaceSubstringOptions::replacement`. If :member:`ReplaceSubstringOptions::max_replacements` != -1, it determines the maximum number of replacements made, counting from the left. -* \(5) Replace non-overlapping substrings that match to the regular expression +* \(6) Replace non-overlapping substrings that match to the regular expression :member:`ReplaceSubstringOptions::pattern` by :member:`ReplaceSubstringOptions::replacement`, using the Google RE2 library. If :member:`ReplaceSubstringOptions::max_replacements` != -1, it determines the maximum number of replacements made, counting from the left. Note that if the pattern contains groups, backreferencing can be used. -* \(6) Output is the number of characters (not bytes) of each input element. +* \(7) Output is the number of characters (not bytes) of each input element. Output type is Int32 for String, Int64 for LargeString. -* \(7) Each UTF8-encoded character in the input is converted to lowercase or +* \(8) Each UTF8-encoded character in the input is converted to lowercase or uppercase. -* \(8) Each UTF8-encoded code unit is written in reverse order to the output. +* \(9) Each UTF8-encoded code unit is written in reverse order to the output. If the input is not valid UTF8, then the output is undefined (but the size of output buffers will be preserved). diff --git a/docs/source/python/api/compute.rst b/docs/source/python/api/compute.rst index 1dbcb3073ca..2e37f9169a7 100644 --- a/docs/source/python/api/compute.rst +++ b/docs/source/python/api/compute.rst @@ -168,7 +168,13 @@ String Transforms ascii_lower ascii_reverse ascii_upper + binary_length + binary_replace_slice + replace_substring + replace_substring_regex + utf8_length utf8_lower + utf8_replace_slice utf8_reverse utf8_upper diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx index 8da0ea05006..104cd1bac1f 100644 --- a/python/pyarrow/_compute.pyx +++ b/python/pyarrow/_compute.pyx @@ -701,6 +701,24 @@ class TrimOptions(_TrimOptions): self._set_options(characters) +cdef class _ReplaceSliceOptions(FunctionOptions): + cdef: + unique_ptr[CReplaceSliceOptions] replace_slice_options + + cdef const CFunctionOptions* get_options(self) except NULL: + return self.replace_slice_options.get() + + def _set_options(self, start, stop, replacement): + self.replace_slice_options.reset( + new CReplaceSliceOptions(start, stop, tobytes(replacement)) + ) + + +class ReplaceSliceOptions(_ReplaceSliceOptions): + def __init__(self, start, stop, replacement): + self._set_options(start, stop, replacement) + + cdef class _ReplaceSubstringOptions(FunctionOptions): cdef: unique_ptr[CReplaceSubstringOptions] replace_substring_options diff --git a/python/pyarrow/compute.py b/python/pyarrow/compute.py index 8dc7181514c..44282369f87 100644 --- a/python/pyarrow/compute.py +++ b/python/pyarrow/compute.py @@ -41,6 +41,7 @@ PartitionNthOptions, ProjectOptions, QuantileOptions, + ReplaceSliceOptions, ReplaceSubstringOptions, ScalarAggregateOptions, SetLookupOptions, diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd index eefca44605c..d5ce98d9a88 100644 --- a/python/pyarrow/includes/libarrow.pxd +++ b/python/pyarrow/includes/libarrow.pxd @@ -1817,6 +1817,13 @@ cdef extern from "arrow/compute/api.h" namespace "arrow::compute" nogil: c_bool reverse) c_string pattern + cdef cppclass CReplaceSliceOptions \ + "arrow::compute::ReplaceSliceOptions"(CFunctionOptions): + CReplaceSliceOptions(int64_t start, int64_t stop, c_string replacement) + int64_t start + int64_t stop + c_string replacement + cdef cppclass CReplaceSubstringOptions \ "arrow::compute::ReplaceSubstringOptions"(CFunctionOptions): CReplaceSubstringOptions(c_string pattern, c_string replacement, diff --git a/python/pyarrow/tests/test_compute.py b/python/pyarrow/tests/test_compute.py index 64d5ad0a30d..8de24c8c249 100644 --- a/python/pyarrow/tests/test_compute.py +++ b/python/pyarrow/tests/test_compute.py @@ -25,6 +25,11 @@ import numpy as np +try: + import pandas as pd +except ImportError: + pd = None + import pyarrow as pa import pyarrow.compute as pc @@ -693,6 +698,29 @@ def test_string_py_compat_boolean(function_name, variant): assert arrow_func(ar)[0].as_py() == getattr(c, py_name)() +@pytest.mark.pandas +def test_replace_slice(): + offsets = range(-3, 4) + + arr = pa.array([None, '', 'a', 'ab', 'abc', 'abcd', 'abcde']) + series = arr.to_pandas() + for start in offsets: + for stop in offsets: + expected = series.str.slice_replace(start, stop, 'XX') + actual = pc.binary_replace_slice( + arr, start=start, stop=stop, replacement='XX') + assert actual.tolist() == expected.tolist() + + arr = pa.array([None, '', 'π', 'πb', 'πbθ', 'πbθd', 'πbθde']) + series = arr.to_pandas() + for start in offsets: + for stop in offsets: + expected = series.str.slice_replace(start, stop, 'XX') + actual = pc.utf8_replace_slice( + arr, start=start, stop=stop, replacement='XX') + assert actual.tolist() == expected.tolist() + + def test_replace_plain(): ar = pa.array(['foo', 'food', None]) ar = pc.replace_substring(ar, pattern='foo', replacement='bar')