From d569fd6f29162bf234a24fc1f3cec8d539f64fd9 Mon Sep 17 00:00:00 2001 From: David Li Date: Tue, 8 Jun 2021 16:46:51 -0400 Subject: [PATCH 1/3] ARROW-12948: [C++][Python] Add slice_replace kernel --- cpp/src/arrow/compute/api_scalar.h | 12 ++ .../arrow/compute/kernels/scalar_string.cc | 159 ++++++++++++++++++ .../compute/kernels/scalar_string_test.cc | 88 ++++++++++ docs/source/cpp/compute.rst | 69 ++++---- docs/source/python/api/compute.rst | 6 + python/pyarrow/_compute.pyx | 18 ++ python/pyarrow/compute.py | 1 + python/pyarrow/includes/libarrow.pxd | 7 + python/pyarrow/tests/test_compute.py | 18 ++ 9 files changed, 348 insertions(+), 30 deletions(-) diff --git a/cpp/src/arrow/compute/api_scalar.h b/cpp/src/arrow/compute/api_scalar.h index 190696f6ed5..9a89340d4ed 100644 --- a/cpp/src/arrow/compute/api_scalar.h +++ b/cpp/src/arrow/compute/api_scalar.h @@ -77,6 +77,18 @@ struct ARROW_EXPORT SplitPatternOptions : public SplitOptions { std::string pattern; }; +struct ARROW_EXPORT ReplaceSliceOptions : public FunctionOptions { + explicit ReplaceSliceOptions(int64_t start, int64_t stop, std::string replacement) + : start(start), stop(stop), replacement(std::move(replacement)) {} + + /// Index to start slicing at + int64_t start = 0; + /// Index to stop slicing at + int64_t stop = std::numeric_limits::max(); + /// String to replace the slice with + std::string replacement; +}; + struct ARROW_EXPORT ReplaceSubstringOptions : public FunctionOptions { explicit ReplaceSubstringOptions(std::string pattern, std::string replacement, int64_t max_replacements = -1) diff --git a/cpp/src/arrow/compute/kernels/scalar_string.cc b/cpp/src/arrow/compute/kernels/scalar_string.cc index 8b740f3742a..7cb90592a8b 100644 --- a/cpp/src/arrow/compute/kernels/scalar_string.cc +++ b/cpp/src/arrow/compute/kernels/scalar_string.cc @@ -2288,6 +2288,164 @@ const FunctionDoc replace_substring_regex_doc( {"strings"}, "ReplaceSubstringOptions"); #endif +// ---------------------------------------------------------------------- +// Replace slice + +struct ReplaceSliceTransformBase : public StringTransformBase { + using State = OptionsWrapper; + + const ReplaceSliceOptions* options; + + explicit ReplaceSliceTransformBase(const ReplaceSliceOptions& options) + : options{&options} {} + + int64_t MaxCodeunits(int64_t ninputs, int64_t input_ncodeunits) override { + return ninputs * options->replacement.size() + input_ncodeunits; + } +}; + +struct AsciiReplaceSliceTransform : ReplaceSliceTransformBase { + using ReplaceSliceTransformBase::ReplaceSliceTransformBase; + int64_t Transform(const uint8_t* input, int64_t input_string_ncodeunits, + uint8_t* output) { + const auto& opts = *options; + int64_t before_slice = 0; + int64_t after_slice = 0; + uint8_t* output_start = output; + + if (opts.start >= 0) { + // Count from left + before_slice = std::min(input_string_ncodeunits, opts.start); + } else { + // Count from right + before_slice = std::max(0, input_string_ncodeunits + opts.start); + } + // Mimic Pandas: if stop would be before start, treat as 0-length slice + if (opts.stop >= 0) { + // Count from left + after_slice = + std::min(input_string_ncodeunits, std::max(opts.start, opts.stop)); + } else { + // Count from right + after_slice = std::max(before_slice, input_string_ncodeunits + opts.stop); + } + output = std::copy(input, input + before_slice, output); + output = std::copy(opts.replacement.begin(), opts.replacement.end(), output); + output = std::copy(input + after_slice, input + input_string_ncodeunits, output); + return std::distance(output_start, output); + } +}; + +struct Utf8ReplaceSliceTransform : ReplaceSliceTransformBase { + using ReplaceSliceTransformBase::ReplaceSliceTransformBase; + int64_t Transform(const uint8_t* input, int64_t input_string_ncodeunits, + uint8_t* output) { + const auto& opts = *options; + const uint8_t* begin = input; + const uint8_t* end = input + input_string_ncodeunits; + const uint8_t *begin_sliced, *end_sliced; + uint8_t* output_start = output; + + // Mimic Pandas: if stop would be before start, treat as 0-length slice + if (opts.start >= 0) { + // Count from left + if (!arrow::util::UTF8AdvanceCodepoints(begin, end, &begin_sliced, opts.start)) { + return kTransformError; + } + if (opts.stop > options->start) { + // Continue counting from left + const int64_t length = opts.stop - options->start; + if (!arrow::util::UTF8AdvanceCodepoints(begin_sliced, end, &end_sliced, length)) { + return kTransformError; + } + } else if (opts.stop < 0) { + // Count from right + if (!arrow::util::UTF8AdvanceCodepointsReverse(begin_sliced, end, &end_sliced, + -opts.stop)) { + return kTransformError; + } + } else { + // Zero-length slice + end_sliced = begin_sliced; + } + } else { + // Count from right + if (!arrow::util::UTF8AdvanceCodepointsReverse(begin, end, &begin_sliced, + -opts.start)) { + return kTransformError; + } + if (opts.stop >= 0) { + // Restart counting from left + if (!arrow::util::UTF8AdvanceCodepoints(begin, end, &end_sliced, opts.stop)) { + return kTransformError; + } + if (end_sliced <= begin_sliced) { + // Zero-length slice + end_sliced = begin_sliced; + } + } else if ((opts.stop < 0) && (options->stop > options->start)) { + // Count from right + if (!arrow::util::UTF8AdvanceCodepointsReverse(begin_sliced, end, &end_sliced, + -opts.stop)) { + return kTransformError; + } + } else { + // zero-length slice + end_sliced = begin_sliced; + } + } + output = std::copy(begin, begin_sliced, output); + output = std::copy(opts.replacement.begin(), options->replacement.end(), output); + output = std::copy(end_sliced, end, output); + return std::distance(output_start, output); + } +}; + +template +using AsciiReplaceSlice = StringTransformExecWithState; +template +using Utf8ReplaceSlice = StringTransformExecWithState; + +const FunctionDoc ascii_replace_slice_doc( + "Replace a slice of a string with `replacement`", + ("For each string in `strings`, replace a slice of the string defined by `start`" + "and `stop` with `replacement`. `start` is inclusive and `stop` is exclusive, " + "and both are measured in bytes.\n" + "Null values emit null."), + {"strings"}, "ReplaceSliceOptions"); + +const FunctionDoc utf8_replace_slice_doc( + "Replace a slice of a string with `replacement`", + ("For each string in `strings`, replace a slice of the string defined by `start`" + "and `stop` with `replacement`. `start` is inclusive and `stop` is exclusive, " + "and both are measured in codeunits.\n" + "Null values emit null."), + {"strings"}, "ReplaceSliceOptions"); + +void AddReplaceSlice(FunctionRegistry* registry) { + { + auto func = std::make_shared("ascii_replace_slice", Arity::Unary(), + &ascii_replace_slice_doc); + for (const auto& ty : BaseBinaryTypes()) { + DCHECK_OK(func->AddKernel({ty}, ty, + GenerateTypeAgnosticVarBinaryBase(ty), + ReplaceSliceTransformBase::State::Init)); + } + DCHECK_OK(registry->AddFunction(std::move(func))); + } + + { + auto func = std::make_shared("utf8_replace_slice", Arity::Unary(), + &ascii_replace_slice_doc); + DCHECK_OK(func->AddKernel({utf8()}, utf8(), Utf8ReplaceSlice::Exec, + ReplaceSliceTransformBase::State::Init)); + DCHECK_OK(func->AddKernel({large_utf8()}, large_utf8(), + Utf8ReplaceSlice::Exec, + ReplaceSliceTransformBase::State::Init)); + DCHECK_OK(registry->AddFunction(std::move(func))); + } +} + // ---------------------------------------------------------------------- // Extract with regex @@ -3434,6 +3592,7 @@ void RegisterScalarStringAscii(FunctionRegistry* registry) { MemAllocation::NO_PREALLOCATE); AddExtractRegex(registry); #endif + AddReplaceSlice(registry); AddSlice(registry); AddSplit(registry); AddStrptime(registry); diff --git a/cpp/src/arrow/compute/kernels/scalar_string_test.cc b/cpp/src/arrow/compute/kernels/scalar_string_test.cc index c4b6956be2b..b988e987c3a 100644 --- a/cpp/src/arrow/compute/kernels/scalar_string_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_string_test.cc @@ -122,6 +122,50 @@ TYPED_TEST(TestBinaryKernels, CountSubstring) { // TODO: case-insensitive } +TYPED_TEST(TestBinaryKernels, AsciiReplaceSlice) { + ReplaceSliceOptions options{0, 1, "XX"}; + this->CheckUnary("ascii_replace_slice", "[]", this->type(), "[]", &options); + this->CheckUnary("ascii_replace_slice", R"([null, "", "a", "ab", "abc"])", this->type(), + R"([null, "XX", "XX", "XXb", "XXbc"])", &options); + + ReplaceSliceOptions options_whole{0, 5, "XX"}; + this->CheckUnary("ascii_replace_slice", + R"([null, "", "a", "ab", "abc", "abcde", "abcdef"])", this->type(), + R"([null, "XX", "XX", "XX", "XX", "XX", "XXf"])", &options_whole); + + ReplaceSliceOptions options_middle{2, 4, "XX"}; + this->CheckUnary("ascii_replace_slice", + R"([null, "", "a", "ab", "abc", "abcd", "abcde"])", this->type(), + R"([null, "XX", "aXX", "abXX", "abXX", "abXX", "abXXe"])", + &options_middle); + + ReplaceSliceOptions options_neg_start{-3, -2, "XX"}; + this->CheckUnary("ascii_replace_slice", + R"([null, "", "a", "ab", "abc", "abcd", "abcde"])", this->type(), + R"([null, "XX", "XXa", "XXab", "XXbc", "aXXcd", "abXXde"])", + &options_neg_start); + + ReplaceSliceOptions options_neg_end{2, -2, "XX"}; + this->CheckUnary("ascii_replace_slice", + R"([null, "", "a", "ab", "abc", "abcd", "abcde"])", this->type(), + R"([null, "XX", "aXX", "abXX", "abXXc", "abXXcd", "abXXde"])", + &options_neg_end); + + // Effectively the same as [2, 2) + ReplaceSliceOptions options_flip{2, 0, "XX"}; + this->CheckUnary("ascii_replace_slice", + R"([null, "", "a", "ab", "abc", "abcd", "abcde"])", this->type(), + R"([null, "XX", "aXX", "abXX", "abXXc", "abXXcd", "abXXcde"])", + &options_flip); + + // Effectively the same as [-3, -3) + ReplaceSliceOptions options_neg_flip{-3, -5, "XX"}; + this->CheckUnary("ascii_replace_slice", + R"([null, "", "a", "ab", "abc", "abcd", "abcde"])", this->type(), + R"([null, "XX", "XXa", "XXab", "XXabc", "aXXbcd", "abXXcde"])", + &options_neg_flip); +} + template class TestStringKernels : public BaseTestStringKernels {}; @@ -745,6 +789,50 @@ TYPED_TEST(TestStringKernels, SplitRegexReverse) { } #endif +TYPED_TEST(TestStringKernels, Utf8ReplaceSlice) { + ReplaceSliceOptions options{0, 1, "χχ"}; + this->CheckUnary("utf8_replace_slice", "[]", this->type(), "[]", &options); + this->CheckUnary("utf8_replace_slice", R"([null, "", "π", "πb", "πbθ"])", this->type(), + R"([null, "χχ", "χχ", "χχb", "χχbθ"])", &options); + + ReplaceSliceOptions options_whole{0, 5, "χχ"}; + this->CheckUnary("utf8_replace_slice", + R"([null, "", "π", "πb", "πbθ", "πbθde", "πbθdef"])", this->type(), + R"([null, "χχ", "χχ", "χχ", "χχ", "χχ", "χχf"])", &options_whole); + + ReplaceSliceOptions options_middle{2, 4, "χχ"}; + this->CheckUnary("utf8_replace_slice", + R"([null, "", "π", "πb", "πbθ", "πbθd", "πbθde"])", this->type(), + R"([null, "χχ", "πχχ", "πbχχ", "πbχχ", "πbχχ", "πbχχe"])", + &options_middle); + + ReplaceSliceOptions options_neg_start{-3, -2, "χχ"}; + this->CheckUnary("utf8_replace_slice", + R"([null, "", "π", "πb", "πbθ", "πbθd", "πbθde"])", this->type(), + R"([null, "χχ", "χχπ", "χχπb", "χχbθ", "πχχθd", "πbχχde"])", + &options_neg_start); + + ReplaceSliceOptions options_neg_end{2, -2, "χχ"}; + this->CheckUnary("utf8_replace_slice", + R"([null, "", "π", "πb", "πbθ", "πbθd", "πbθde"])", this->type(), + R"([null, "χχ", "πχχ", "πbχχ", "πbχχθ", "πbχχθd", "πbχχde"])", + &options_neg_end); + + // Effectively the same as [2, 2) + ReplaceSliceOptions options_flip{2, 0, "χχ"}; + this->CheckUnary("utf8_replace_slice", + R"([null, "", "π", "πb", "πbθ", "πbθd", "πbθde"])", this->type(), + R"([null, "χχ", "πχχ", "πbχχ", "πbχχθ", "πbχχθd", "πbχχθde"])", + &options_flip); + + // Effectively the same as [-3, -3) + ReplaceSliceOptions options_neg_flip{-3, -5, "χχ"}; + this->CheckUnary("utf8_replace_slice", + R"([null, "", "π", "πb", "πbθ", "πbθd", "πbθde"])", this->type(), + R"([null, "χχ", "χχπ", "χχπb", "χχπbθ", "πχχbθd", "πbχχθde"])", + &options_neg_flip); +} + TYPED_TEST(TestStringKernels, ReplaceSubstring) { ReplaceSubstringOptions options{"foo", "bazz"}; this->CheckUnary("replace_substring", R"(["foo", "this foo that foo", null])", diff --git a/docs/source/cpp/compute.rst b/docs/source/cpp/compute.rst index ad2d9f8f5d2..bc831ecc4fd 100644 --- a/docs/source/cpp/compute.rst +++ b/docs/source/cpp/compute.rst @@ -451,60 +451,69 @@ The third set of functions examines string elements on a byte-per-byte basis: String transforms ~~~~~~~~~~~~~~~~~ -+--------------------------+------------+-------------------------+---------------------+-------------------------------------------------+ -| Function name | Arity | Input types | Output type | Notes | Options class | -+==========================+============+=========================+=====================+=========+=======================================+ -| ascii_lower | Unary | String-like | String-like | \(1) | | -+--------------------------+------------+-------------------------+---------------------+---------+---------------------------------------+ -| ascii_reverse | Unary | String-like | String-like | \(2) | | -+--------------------------+------------+-------------------------+---------------------+---------+---------------------------------------+ -| ascii_upper | Unary | String-like | String-like | \(1) | | -+--------------------------+------------+-------------------------+---------------------+---------+---------------------------------------+ -| binary_length | Unary | Binary- or String-like | Int32 or Int64 | \(3) | | -+--------------------------+------------+-------------------------+---------------------+---------+---------------------------------------+ -| replace_substring | Unary | String-like | String-like | \(4) | :struct:`ReplaceSubstringOptions` | -+--------------------------+------------+-------------------------+---------------------+---------+---------------------------------------+ -| replace_substring_regex | Unary | String-like | String-like | \(5) | :struct:`ReplaceSubstringOptions` | -+--------------------------+------------+-------------------------+---------------------+---------+---------------------------------------+ -| utf8_length | Unary | String-like | Int32 or Int64 | \(6) | | -+--------------------------+------------+-------------------------+---------------------+---------+---------------------------------------+ -| utf8_lower | Unary | String-like | String-like | \(7) | | -+--------------------------+------------+-------------------------+---------------------+---------+---------------------------------------+ -| utf8_reverse | Unary | String-like | String-like | \(8) | | -+--------------------------+------------+-------------------------+---------------------+---------+---------------------------------------+ -| utf8_upper | Unary | String-like | String-like | \(7) | | -+--------------------------+------------+-------------------------+---------------------+---------+---------------------------------------+ ++--------------------------+------------+-------------------------+------------------------+-------------------------------------------------+ +| Function name | Arity | Input types | Output type | Notes | Options class | ++==========================+============+=========================+========================+=========+=======================================+ +| ascii_lower | Unary | String-like | String-like | \(1) | | ++--------------------------+------------+-------------------------+------------------------+---------+---------------------------------------+ +| ascii_replace_slice | Unary | String-like | Binary- or String-like | \(2) | :struct:`ReplaceSliceOptions` | ++--------------------------+------------+-------------------------+------------------------+---------+---------------------------------------+ +| ascii_reverse | Unary | String-like | String-like | \(3) | | ++--------------------------+------------+-------------------------+------------------------+---------+---------------------------------------+ +| ascii_upper | Unary | String-like | String-like | \(2) | | ++--------------------------+------------+-------------------------+------------------------+---------+---------------------------------------+ +| binary_length | Unary | Binary- or String-like | Int32 or Int64 | \(4) | | ++--------------------------+------------+-------------------------+------------------------+---------+---------------------------------------+ +| replace_substring | Unary | String-like | String-like | \(5) | :struct:`ReplaceSubstringOptions` | ++--------------------------+------------+-------------------------+------------------------+---------+---------------------------------------+ +| replace_substring_regex | Unary | String-like | String-like | \(6) | :struct:`ReplaceSubstringOptions` | ++--------------------------+------------+-------------------------+------------------------+---------+---------------------------------------+ +| utf8_length | Unary | String-like | Int32 or Int64 | \(7) | | ++--------------------------+------------+-------------------------+------------------------+---------+---------------------------------------+ +| utf8_lower | Unary | String-like | String-like | \(8) | | ++--------------------------+------------+-------------------------+------------------------+---------+---------------------------------------+ +| utf8_replace_slice | Unary | String-like | String-like | \(2) | :struct:`ReplaceSliceOptions` | ++--------------------------+------------+-------------------------+------------------------+---------+---------------------------------------+ +| utf8_reverse | Unary | String-like | String-like | \(9) | | ++--------------------------+------------+-------------------------+------------------------+---------+---------------------------------------+ +| utf8_upper | Unary | String-like | String-like | \(8) | | ++--------------------------+------------+-------------------------+------------------------+---------+---------------------------------------+ * \(1) Each ASCII character in the input is converted to lowercase or uppercase. Non-ASCII characters are left untouched. -* \(2) ASCII input is reversed to the output. If non-ASCII characters +* \(2) Replace the slice of the substring from :member:`ReplaceSliceOptions::start` + (inclusive) to :member:`ReplaceSliceOptions::stop` (exclusive) by + :member:`ReplaceSubstringOptions::replacement`. The ASCII kernel measures the slice + in bytes, while the UTF8 kernel measures the slice in codeunits. + +* \(3) ASCII input is reversed to the output. If non-ASCII characters are present, ``Invalid`` :class:`Status` will be returned. -* \(3) Output is the physical length in bytes of each input element. Output +* \(4) Output is the physical length in bytes of each input element. Output type is Int32 for Binary / String, Int64 for LargeBinary / LargeString. -* \(4) Replace non-overlapping substrings that match to +* \(5) Replace non-overlapping substrings that match to :member:`ReplaceSubstringOptions::pattern` by :member:`ReplaceSubstringOptions::replacement`. If :member:`ReplaceSubstringOptions::max_replacements` != -1, it determines the maximum number of replacements made, counting from the left. -* \(5) Replace non-overlapping substrings that match to the regular expression +* \(6) Replace non-overlapping substrings that match to the regular expression :member:`ReplaceSubstringOptions::pattern` by :member:`ReplaceSubstringOptions::replacement`, using the Google RE2 library. If :member:`ReplaceSubstringOptions::max_replacements` != -1, it determines the maximum number of replacements made, counting from the left. Note that if the pattern contains groups, backreferencing can be used. -* \(6) Output is the number of characters (not bytes) of each input element. +* \(7) Output is the number of characters (not bytes) of each input element. Output type is Int32 for String, Int64 for LargeString. -* \(7) Each UTF8-encoded character in the input is converted to lowercase or +* \(8) Each UTF8-encoded character in the input is converted to lowercase or uppercase. -* \(8) Each UTF8-encoded code unit is written in reverse order to the output. +* \(9) Each UTF8-encoded code unit is written in reverse order to the output. If the input is not valid UTF8, then the output is undefined (but the size of output buffers will be preserved). diff --git a/docs/source/python/api/compute.rst b/docs/source/python/api/compute.rst index 1dbcb3073ca..469dc932e18 100644 --- a/docs/source/python/api/compute.rst +++ b/docs/source/python/api/compute.rst @@ -166,9 +166,15 @@ String Transforms :toctree: ../generated/ ascii_lower + ascii_replace_slice ascii_reverse ascii_upper + binary_length + replace_substring + replace_substring_regex + utf8_length utf8_lower + utf8_replace_slice utf8_reverse utf8_upper diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx index 8da0ea05006..104cd1bac1f 100644 --- a/python/pyarrow/_compute.pyx +++ b/python/pyarrow/_compute.pyx @@ -701,6 +701,24 @@ class TrimOptions(_TrimOptions): self._set_options(characters) +cdef class _ReplaceSliceOptions(FunctionOptions): + cdef: + unique_ptr[CReplaceSliceOptions] replace_slice_options + + cdef const CFunctionOptions* get_options(self) except NULL: + return self.replace_slice_options.get() + + def _set_options(self, start, stop, replacement): + self.replace_slice_options.reset( + new CReplaceSliceOptions(start, stop, tobytes(replacement)) + ) + + +class ReplaceSliceOptions(_ReplaceSliceOptions): + def __init__(self, start, stop, replacement): + self._set_options(start, stop, replacement) + + cdef class _ReplaceSubstringOptions(FunctionOptions): cdef: unique_ptr[CReplaceSubstringOptions] replace_substring_options diff --git a/python/pyarrow/compute.py b/python/pyarrow/compute.py index 8dc7181514c..44282369f87 100644 --- a/python/pyarrow/compute.py +++ b/python/pyarrow/compute.py @@ -41,6 +41,7 @@ PartitionNthOptions, ProjectOptions, QuantileOptions, + ReplaceSliceOptions, ReplaceSubstringOptions, ScalarAggregateOptions, SetLookupOptions, diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd index eefca44605c..d5ce98d9a88 100644 --- a/python/pyarrow/includes/libarrow.pxd +++ b/python/pyarrow/includes/libarrow.pxd @@ -1817,6 +1817,13 @@ cdef extern from "arrow/compute/api.h" namespace "arrow::compute" nogil: c_bool reverse) c_string pattern + cdef cppclass CReplaceSliceOptions \ + "arrow::compute::ReplaceSliceOptions"(CFunctionOptions): + CReplaceSliceOptions(int64_t start, int64_t stop, c_string replacement) + int64_t start + int64_t stop + c_string replacement + cdef cppclass CReplaceSubstringOptions \ "arrow::compute::ReplaceSubstringOptions"(CFunctionOptions): CReplaceSubstringOptions(c_string pattern, c_string replacement, diff --git a/python/pyarrow/tests/test_compute.py b/python/pyarrow/tests/test_compute.py index 64d5ad0a30d..b16b6ef3c24 100644 --- a/python/pyarrow/tests/test_compute.py +++ b/python/pyarrow/tests/test_compute.py @@ -693,6 +693,24 @@ def test_string_py_compat_boolean(function_name, variant): assert arrow_func(ar)[0].as_py() == getattr(c, py_name)() +def test_replace_slice(): + arr = pa.array([None, '', 'a', 'ab', 'abc', 'abcd']) + res = pc.ascii_replace_slice(arr, start=1, stop=3, replacement='XX') + assert res.tolist() == [None, 'XX', 'aXX', 'aXX', 'aXX', 'aXXd'] + res = pc.ascii_replace_slice(arr, start=-2, stop=3, replacement='XX') + assert res.tolist() == [None, 'XX', 'XX', 'XX', 'aXX', 'abXXd'] + res = pc.ascii_replace_slice(arr, start=-3, stop=-2, replacement='XX') + assert res.tolist() == [None, 'XX', 'XXa', 'XXab', 'XXbc', 'aXXcd'] + + arr = pa.array([None, '', 'π', 'πb', 'πbθ', 'πbθd']) + res = pc.utf8_replace_slice(arr, start=1, stop=3, replacement='χχ') + assert res.tolist() == [None, 'χχ', 'πχχ', 'πχχ', 'πχχ', 'πχχd'] + res = pc.utf8_replace_slice(arr, start=-2, stop=3, replacement='χχ') + assert res.tolist() == [None, 'χχ', 'χχ', 'χχ', 'πχχ', 'πbχχd'] + res = pc.utf8_replace_slice(arr, start=-3, stop=-2, replacement='χχ') + assert res.tolist() == [None, 'χχ', 'χχπ', 'χχπb', 'χχbθ', 'πχχθd'] + + def test_replace_plain(): ar = pa.array(['foo', 'food', None]) ar = pc.replace_substring(ar, pattern='foo', replacement='bar') From b8b62313ff14c079f2639f3fc9e6cab254819864 Mon Sep 17 00:00:00 2001 From: David Li Date: Thu, 10 Jun 2021 09:44:26 -0400 Subject: [PATCH 2/3] ARROW-12948: [C++] Fix edge case, rename to binary_replace_slice --- cpp/src/arrow/compute/api_scalar.h | 4 +- .../arrow/compute/kernels/scalar_string.cc | 23 ++++++----- .../compute/kernels/scalar_string_test.cc | 30 +++++++++----- docs/source/cpp/compute.rst | 22 +++++----- docs/source/python/api/compute.rst | 2 +- python/pyarrow/tests/test_compute.py | 40 ++++++++++++------- 6 files changed, 72 insertions(+), 49 deletions(-) diff --git a/cpp/src/arrow/compute/api_scalar.h b/cpp/src/arrow/compute/api_scalar.h index 9a89340d4ed..6e9a9340f2c 100644 --- a/cpp/src/arrow/compute/api_scalar.h +++ b/cpp/src/arrow/compute/api_scalar.h @@ -82,9 +82,9 @@ struct ARROW_EXPORT ReplaceSliceOptions : public FunctionOptions { : start(start), stop(stop), replacement(std::move(replacement)) {} /// Index to start slicing at - int64_t start = 0; + int64_t start; /// Index to stop slicing at - int64_t stop = std::numeric_limits::max(); + int64_t stop; /// String to replace the slice with std::string replacement; }; diff --git a/cpp/src/arrow/compute/kernels/scalar_string.cc b/cpp/src/arrow/compute/kernels/scalar_string.cc index 7cb90592a8b..b6c1b8f6261 100644 --- a/cpp/src/arrow/compute/kernels/scalar_string.cc +++ b/cpp/src/arrow/compute/kernels/scalar_string.cc @@ -2304,7 +2304,7 @@ struct ReplaceSliceTransformBase : public StringTransformBase { } }; -struct AsciiReplaceSliceTransform : ReplaceSliceTransformBase { +struct BinaryReplaceSliceTransform : ReplaceSliceTransformBase { using ReplaceSliceTransformBase::ReplaceSliceTransformBase; int64_t Transform(const uint8_t* input, int64_t input_string_ncodeunits, uint8_t* output) { @@ -2324,7 +2324,7 @@ struct AsciiReplaceSliceTransform : ReplaceSliceTransformBase { if (opts.stop >= 0) { // Count from left after_slice = - std::min(input_string_ncodeunits, std::max(opts.start, opts.stop)); + std::min(input_string_ncodeunits, std::max(before_slice, opts.stop)); } else { // Count from right after_slice = std::max(before_slice, input_string_ncodeunits + opts.stop); @@ -2332,7 +2332,7 @@ struct AsciiReplaceSliceTransform : ReplaceSliceTransformBase { output = std::copy(input, input + before_slice, output); output = std::copy(opts.replacement.begin(), opts.replacement.end(), output); output = std::copy(input + after_slice, input + input_string_ncodeunits, output); - return std::distance(output_start, output); + return output - output_start; } }; @@ -2397,17 +2397,18 @@ struct Utf8ReplaceSliceTransform : ReplaceSliceTransformBase { output = std::copy(begin, begin_sliced, output); output = std::copy(opts.replacement.begin(), options->replacement.end(), output); output = std::copy(end_sliced, end, output); - return std::distance(output_start, output); + return output - output_start; } }; template -using AsciiReplaceSlice = StringTransformExecWithState; +using BinaryReplaceSlice = + StringTransformExecWithState; template using Utf8ReplaceSlice = StringTransformExecWithState; -const FunctionDoc ascii_replace_slice_doc( - "Replace a slice of a string with `replacement`", +const FunctionDoc binary_replace_slice_doc( + "Replace a slice of a binary string with `replacement`", ("For each string in `strings`, replace a slice of the string defined by `start`" "and `stop` with `replacement`. `start` is inclusive and `stop` is exclusive, " "and both are measured in bytes.\n" @@ -2424,11 +2425,11 @@ const FunctionDoc utf8_replace_slice_doc( void AddReplaceSlice(FunctionRegistry* registry) { { - auto func = std::make_shared("ascii_replace_slice", Arity::Unary(), - &ascii_replace_slice_doc); + auto func = std::make_shared("binary_replace_slice", Arity::Unary(), + &binary_replace_slice_doc); for (const auto& ty : BaseBinaryTypes()) { DCHECK_OK(func->AddKernel({ty}, ty, - GenerateTypeAgnosticVarBinaryBase(ty), + GenerateTypeAgnosticVarBinaryBase(ty), ReplaceSliceTransformBase::State::Init)); } DCHECK_OK(registry->AddFunction(std::move(func))); @@ -2436,7 +2437,7 @@ void AddReplaceSlice(FunctionRegistry* registry) { { auto func = std::make_shared("utf8_replace_slice", Arity::Unary(), - &ascii_replace_slice_doc); + &utf8_replace_slice_doc); DCHECK_OK(func->AddKernel({utf8()}, utf8(), Utf8ReplaceSlice::Exec, ReplaceSliceTransformBase::State::Init)); DCHECK_OK(func->AddKernel({large_utf8()}, large_utf8(), diff --git a/cpp/src/arrow/compute/kernels/scalar_string_test.cc b/cpp/src/arrow/compute/kernels/scalar_string_test.cc index b988e987c3a..7d52d6aacf2 100644 --- a/cpp/src/arrow/compute/kernels/scalar_string_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_string_test.cc @@ -124,43 +124,49 @@ TYPED_TEST(TestBinaryKernels, CountSubstring) { TYPED_TEST(TestBinaryKernels, AsciiReplaceSlice) { ReplaceSliceOptions options{0, 1, "XX"}; - this->CheckUnary("ascii_replace_slice", "[]", this->type(), "[]", &options); - this->CheckUnary("ascii_replace_slice", R"([null, "", "a", "ab", "abc"])", this->type(), - R"([null, "XX", "XX", "XXb", "XXbc"])", &options); + this->CheckUnary("binary_replace_slice", "[]", this->type(), "[]", &options); + this->CheckUnary("binary_replace_slice", R"([null, "", "a", "ab", "abc"])", + this->type(), R"([null, "XX", "XX", "XXb", "XXbc"])", &options); ReplaceSliceOptions options_whole{0, 5, "XX"}; - this->CheckUnary("ascii_replace_slice", + this->CheckUnary("binary_replace_slice", R"([null, "", "a", "ab", "abc", "abcde", "abcdef"])", this->type(), R"([null, "XX", "XX", "XX", "XX", "XX", "XXf"])", &options_whole); ReplaceSliceOptions options_middle{2, 4, "XX"}; - this->CheckUnary("ascii_replace_slice", + this->CheckUnary("binary_replace_slice", R"([null, "", "a", "ab", "abc", "abcd", "abcde"])", this->type(), R"([null, "XX", "aXX", "abXX", "abXX", "abXX", "abXXe"])", &options_middle); ReplaceSliceOptions options_neg_start{-3, -2, "XX"}; - this->CheckUnary("ascii_replace_slice", + this->CheckUnary("binary_replace_slice", R"([null, "", "a", "ab", "abc", "abcd", "abcde"])", this->type(), R"([null, "XX", "XXa", "XXab", "XXbc", "aXXcd", "abXXde"])", &options_neg_start); ReplaceSliceOptions options_neg_end{2, -2, "XX"}; - this->CheckUnary("ascii_replace_slice", + this->CheckUnary("binary_replace_slice", R"([null, "", "a", "ab", "abc", "abcd", "abcde"])", this->type(), R"([null, "XX", "aXX", "abXX", "abXXc", "abXXcd", "abXXde"])", &options_neg_end); + ReplaceSliceOptions options_neg_pos{-1, 2, "XX"}; + this->CheckUnary("binary_replace_slice", + R"([null, "", "a", "ab", "abc", "abcd", "abcde"])", this->type(), + R"([null, "XX", "XX", "aXX", "abXXc", "abcXXd", "abcdXXe"])", + &options_neg_pos); + // Effectively the same as [2, 2) ReplaceSliceOptions options_flip{2, 0, "XX"}; - this->CheckUnary("ascii_replace_slice", + this->CheckUnary("binary_replace_slice", R"([null, "", "a", "ab", "abc", "abcd", "abcde"])", this->type(), R"([null, "XX", "aXX", "abXX", "abXXc", "abXXcd", "abXXcde"])", &options_flip); // Effectively the same as [-3, -3) ReplaceSliceOptions options_neg_flip{-3, -5, "XX"}; - this->CheckUnary("ascii_replace_slice", + this->CheckUnary("binary_replace_slice", R"([null, "", "a", "ab", "abc", "abcd", "abcde"])", this->type(), R"([null, "XX", "XXa", "XXab", "XXabc", "aXXbcd", "abXXcde"])", &options_neg_flip); @@ -818,6 +824,12 @@ TYPED_TEST(TestStringKernels, Utf8ReplaceSlice) { R"([null, "χχ", "πχχ", "πbχχ", "πbχχθ", "πbχχθd", "πbχχde"])", &options_neg_end); + ReplaceSliceOptions options_neg_pos{-1, 2, "χχ"}; + this->CheckUnary("utf8_replace_slice", + R"([null, "", "π", "πb", "πbθ", "πbθd", "πbθde"])", this->type(), + R"([null, "χχ", "χχ", "πχχ", "πbχχθ", "πbθχχd", "πbθdχχe"])", + &options_neg_pos); + // Effectively the same as [2, 2) ReplaceSliceOptions options_flip{2, 0, "χχ"}; this->CheckUnary("utf8_replace_slice", diff --git a/docs/source/cpp/compute.rst b/docs/source/cpp/compute.rst index bc831ecc4fd..3428f1c6add 100644 --- a/docs/source/cpp/compute.rst +++ b/docs/source/cpp/compute.rst @@ -456,13 +456,13 @@ String transforms +==========================+============+=========================+========================+=========+=======================================+ | ascii_lower | Unary | String-like | String-like | \(1) | | +--------------------------+------------+-------------------------+------------------------+---------+---------------------------------------+ -| ascii_replace_slice | Unary | String-like | Binary- or String-like | \(2) | :struct:`ReplaceSliceOptions` | +| ascii_reverse | Unary | String-like | String-like | \(2) | | +--------------------------+------------+-------------------------+------------------------+---------+---------------------------------------+ -| ascii_reverse | Unary | String-like | String-like | \(3) | | +| ascii_upper | Unary | String-like | String-like | \(1) | | +--------------------------+------------+-------------------------+------------------------+---------+---------------------------------------+ -| ascii_upper | Unary | String-like | String-like | \(2) | | +| binary_length | Unary | Binary- or String-like | Int32 or Int64 | \(3) | | +--------------------------+------------+-------------------------+------------------------+---------+---------------------------------------+ -| binary_length | Unary | Binary- or String-like | Int32 or Int64 | \(4) | | +| binary_replace_slice | Unary | String-like | Binary- or String-like | \(4) | :struct:`ReplaceSliceOptions` | +--------------------------+------------+-------------------------+------------------------+---------+---------------------------------------+ | replace_substring | Unary | String-like | String-like | \(5) | :struct:`ReplaceSubstringOptions` | +--------------------------+------------+-------------------------+------------------------+---------+---------------------------------------+ @@ -483,17 +483,17 @@ String transforms * \(1) Each ASCII character in the input is converted to lowercase or uppercase. Non-ASCII characters are left untouched. -* \(2) Replace the slice of the substring from :member:`ReplaceSliceOptions::start` - (inclusive) to :member:`ReplaceSliceOptions::stop` (exclusive) by - :member:`ReplaceSubstringOptions::replacement`. The ASCII kernel measures the slice - in bytes, while the UTF8 kernel measures the slice in codeunits. - -* \(3) ASCII input is reversed to the output. If non-ASCII characters +* \(2) ASCII input is reversed to the output. If non-ASCII characters are present, ``Invalid`` :class:`Status` will be returned. -* \(4) Output is the physical length in bytes of each input element. Output +* \(3) Output is the physical length in bytes of each input element. Output type is Int32 for Binary / String, Int64 for LargeBinary / LargeString. +* \(4) Replace the slice of the substring from :member:`ReplaceSliceOptions::start` + (inclusive) to :member:`ReplaceSliceOptions::stop` (exclusive) by + :member:`ReplaceSubstringOptions::replacement`. The binary kernel measures the + slice in bytes, while the UTF8 kernel measures the slice in codeunits. + * \(5) Replace non-overlapping substrings that match to :member:`ReplaceSubstringOptions::pattern` by :member:`ReplaceSubstringOptions::replacement`. If diff --git a/docs/source/python/api/compute.rst b/docs/source/python/api/compute.rst index 469dc932e18..2e37f9169a7 100644 --- a/docs/source/python/api/compute.rst +++ b/docs/source/python/api/compute.rst @@ -166,10 +166,10 @@ String Transforms :toctree: ../generated/ ascii_lower - ascii_replace_slice ascii_reverse ascii_upper binary_length + binary_replace_slice replace_substring replace_substring_regex utf8_length diff --git a/python/pyarrow/tests/test_compute.py b/python/pyarrow/tests/test_compute.py index b16b6ef3c24..8de24c8c249 100644 --- a/python/pyarrow/tests/test_compute.py +++ b/python/pyarrow/tests/test_compute.py @@ -25,6 +25,11 @@ import numpy as np +try: + import pandas as pd +except ImportError: + pd = None + import pyarrow as pa import pyarrow.compute as pc @@ -693,22 +698,27 @@ def test_string_py_compat_boolean(function_name, variant): assert arrow_func(ar)[0].as_py() == getattr(c, py_name)() +@pytest.mark.pandas def test_replace_slice(): - arr = pa.array([None, '', 'a', 'ab', 'abc', 'abcd']) - res = pc.ascii_replace_slice(arr, start=1, stop=3, replacement='XX') - assert res.tolist() == [None, 'XX', 'aXX', 'aXX', 'aXX', 'aXXd'] - res = pc.ascii_replace_slice(arr, start=-2, stop=3, replacement='XX') - assert res.tolist() == [None, 'XX', 'XX', 'XX', 'aXX', 'abXXd'] - res = pc.ascii_replace_slice(arr, start=-3, stop=-2, replacement='XX') - assert res.tolist() == [None, 'XX', 'XXa', 'XXab', 'XXbc', 'aXXcd'] - - arr = pa.array([None, '', 'π', 'πb', 'πbθ', 'πbθd']) - res = pc.utf8_replace_slice(arr, start=1, stop=3, replacement='χχ') - assert res.tolist() == [None, 'χχ', 'πχχ', 'πχχ', 'πχχ', 'πχχd'] - res = pc.utf8_replace_slice(arr, start=-2, stop=3, replacement='χχ') - assert res.tolist() == [None, 'χχ', 'χχ', 'χχ', 'πχχ', 'πbχχd'] - res = pc.utf8_replace_slice(arr, start=-3, stop=-2, replacement='χχ') - assert res.tolist() == [None, 'χχ', 'χχπ', 'χχπb', 'χχbθ', 'πχχθd'] + offsets = range(-3, 4) + + arr = pa.array([None, '', 'a', 'ab', 'abc', 'abcd', 'abcde']) + series = arr.to_pandas() + for start in offsets: + for stop in offsets: + expected = series.str.slice_replace(start, stop, 'XX') + actual = pc.binary_replace_slice( + arr, start=start, stop=stop, replacement='XX') + assert actual.tolist() == expected.tolist() + + arr = pa.array([None, '', 'π', 'πb', 'πbθ', 'πbθd', 'πbθde']) + series = arr.to_pandas() + for start in offsets: + for stop in offsets: + expected = series.str.slice_replace(start, stop, 'XX') + actual = pc.utf8_replace_slice( + arr, start=start, stop=stop, replacement='XX') + assert actual.tolist() == expected.tolist() def test_replace_plain(): From 30e164d76850ae0edbb56797288f24fc6e54e5bd Mon Sep 17 00:00:00 2001 From: David Li Date: Thu, 10 Jun 2021 11:15:18 -0400 Subject: [PATCH 3/3] ARROW-12948: [C++] Fix note number --- docs/source/cpp/compute.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/cpp/compute.rst b/docs/source/cpp/compute.rst index 3428f1c6add..b28e3928a74 100644 --- a/docs/source/cpp/compute.rst +++ b/docs/source/cpp/compute.rst @@ -472,7 +472,7 @@ String transforms +--------------------------+------------+-------------------------+------------------------+---------+---------------------------------------+ | utf8_lower | Unary | String-like | String-like | \(8) | | +--------------------------+------------+-------------------------+------------------------+---------+---------------------------------------+ -| utf8_replace_slice | Unary | String-like | String-like | \(2) | :struct:`ReplaceSliceOptions` | +| utf8_replace_slice | Unary | String-like | String-like | \(4) | :struct:`ReplaceSliceOptions` | +--------------------------+------------+-------------------------+------------------------+---------+---------------------------------------+ | utf8_reverse | Unary | String-like | String-like | \(9) | | +--------------------------+------------+-------------------------+------------------------+---------+---------------------------------------+